diff options
author | Frank Barchard <fbarchard@google.com> | 2021-03-01 11:05:08 -0800 |
---|---|---|
committer | XNNPACK Team <xnnpack-github-robot@google.com> | 2021-03-01 11:06:15 -0800 |
commit | 6d8ca7d88ead578661f47ce8c5c6c24b3edc4928 (patch) | |
tree | 2bb651b2e2d8945e16cd5cc39996523012d076eb | |
parent | 02121caa363ea04fda5f79ef073cf4884ab35279 (diff) | |
download | XNNPACK-6d8ca7d88ead578661f47ce8c5c6c24b3edc4928.tar.gz |
Quantized GEMM/IGEMM microkernels bump kc to be a multiple of channels.
Rewind A pointers by KC.
Remove last partial channel of remainder code. It's now handled by the main loop.
PiperOrigin-RevId: 360231001
235 files changed, 906 insertions, 1780 deletions
diff --git a/src/qs8-gemm/MRx16c8-avx512skx.c.in b/src/qs8-gemm/MRx16c8-avx512skx.c.in index 973c81cdf..f28a7db4d 100644 --- a/src/qs8-gemm/MRx16c8-avx512skx.c.in +++ b/src/qs8-gemm/MRx16c8-avx512skx.c.in @@ -12,6 +12,7 @@ $assert MR <= 4 #include <xnnpack/gemm.h> #include <xnnpack/intrinsics-polyfill.h> +#include <xnnpack/math.h> $GEMM_SUFFIX = "_xw" if VARIANT == "EXTENDED" else "" @@ -36,6 +37,7 @@ void xnn_qs8_gemm${GEMM_SUFFIX}_minmax_ukernel_${MR}x16c8__avx512skx( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); const int8_t* a0 = a; int8_t* c0 = c; $for M in range(1, MR): @@ -196,10 +198,10 @@ void xnn_qs8_gemm${GEMM_SUFFIX}_minmax_ukernel_${MR}x16c8__avx512skx( _mm_storeu_si128((__m128i*) c1, _mm256_extracti128_si256(vout01x0123456789ABCDEF, 1)); $for M in range(MR): - a${M} = (const int8_t*) ((uintptr_t) a${M} - k); + c${M} = (int8_t*) ((uintptr_t) c${M} + cn_stride); $for M in range(MR): - c${M} = (int8_t*) ((uintptr_t) c${M} + cn_stride); + a${M} = (const int8_t*) ((uintptr_t) a${M} - kc); nc -= 16; } else { diff --git a/src/qs8-gemm/MRx4c2-sse.c.in b/src/qs8-gemm/MRx4c2-sse.c.in index d48eaa577..53eb91841 100644 --- a/src/qs8-gemm/MRx4c2-sse.c.in +++ b/src/qs8-gemm/MRx4c2-sse.c.in @@ -19,6 +19,7 @@ $else: #include <${SSE_HEADER}> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> $LOAD_SUFFIX = {"LD128": "_ld128", "LD64": "_ld64", "EXTENDED": ""}[VARIANT] @@ -45,6 +46,7 @@ void xnn_qs8_gemm${GEMM_SUFFIX}_minmax_ukernel_${MR}x4c2__${ISA}${LOAD_SUFFIX}( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 2); const int8_t* a0 = a; int8_t* c0 = c; $for M in range(1, MR): @@ -205,27 +207,6 @@ void xnn_qs8_gemm${GEMM_SUFFIX}_minmax_ukernel_${MR}x4c2__${ISA}${LOAD_SUFFIX}( $else: vacc${M}x0123 = _mm_add_epi32(vacc${M}x0123, _mm_madd_epi16(_mm_shuffle_epi32(vxa${M}, _MM_SHUFFLE(2, 2, 2, 2)), vxb2)); - - if (k > 6 * sizeof(int8_t)) { - $if VARIANT == "EXTENDED": - const __m128i vxb3 = _mm_load_si128((const __m128i*) w); - w = (const 
void*) ((uintptr_t) w + 8 * sizeof(int16_t)); - $else: - const __m128i vb3 = _mm_loadl_epi64((const __m128i*) w); - $if SSE >= 4: - const __m128i vxb3 = _mm_cvtepi8_epi16(vb3); - $else: - const __m128i vxb3 = _mm_unpacklo_epi8(vb3, _mm_cmpgt_epi8(_mm_setzero_si128(), vb3)); - w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - - $for M in range(MR): - $if SSE == 5: - vacc${M}x0123 = _mm_maddd_epi16( - _mm_shuffle_epi32(vxa${M}, _MM_SHUFFLE(3, 3, 3, 3)), vxb3, vacc${M}x0123); - $else: - vacc${M}x0123 = _mm_add_epi32(vacc${M}x0123, - _mm_madd_epi16(_mm_shuffle_epi32(vxa${M}, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); - } } } } @@ -325,10 +306,10 @@ void xnn_qs8_gemm${GEMM_SUFFIX}_minmax_ukernel_${MR}x4c2__${ISA}${LOAD_SUFFIX}( *((uint32_t*) c${M}) = (uint32_t) _mm_cvtsi128_si32(vout); $for M in range(MR): - a${M} = (const int8_t*) ((uintptr_t) a${M} - kc); + c${M} = (int8_t*) ((uintptr_t) c${M} + cn_stride); $for M in range(MR): - c${M} = (int8_t*) ((uintptr_t) c${M} + cn_stride); + a${M} = (const int8_t*) ((uintptr_t) a${M} - kc); nc -= 4; } else { diff --git a/src/qs8-gemm/MRx4c8-sse.c.in b/src/qs8-gemm/MRx4c8-sse.c.in index 2e8dbe5a4..9e98cfd0c 100644 --- a/src/qs8-gemm/MRx4c8-sse.c.in +++ b/src/qs8-gemm/MRx4c8-sse.c.in @@ -19,6 +19,7 @@ $else: #include <${SSE_HEADER}> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> $LOAD_SUFFIX = {"LD128": "_ld128", "LD64": "_ld64", "EXTENDED": ""}[VARIANT] @@ -45,6 +46,7 @@ void xnn_qs8_gemm${GEMM_SUFFIX}_minmax_ukernel_${MR}x4c8__${ISA}${LOAD_SUFFIX}( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); const int8_t* a0 = a; int8_t* c0 = c; $for M in range(1, MR): @@ -244,10 +246,10 @@ void xnn_qs8_gemm${GEMM_SUFFIX}_minmax_ukernel_${MR}x4c8__${ISA}${LOAD_SUFFIX}( *((uint32_t*) c${M}) = (uint32_t) _mm_cvtsi128_si32(vout); $for M in range(MR): - a${M} = (const int8_t*) ((uintptr_t) a${M} - k); + c${M} = (int8_t*) ((uintptr_t) c${M} + cn_stride); $for M in range(MR): - c${M} = (int8_t*) ((uintptr_t) c${M} + 
cn_stride); + a${M} = (const int8_t*) ((uintptr_t) a${M} - kc); nc -= 4; } else { diff --git a/src/qs8-gemm/MRx4c8-wasmsimd.c.in b/src/qs8-gemm/MRx4c8-wasmsimd.c.in index dbc36b27c..a10c7ddf9 100644 --- a/src/qs8-gemm/MRx4c8-wasmsimd.c.in +++ b/src/qs8-gemm/MRx4c8-wasmsimd.c.in @@ -10,6 +10,7 @@ $assert MR <= 4 #include <wasm_simd128.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> $LOAD_SUFFIX = {"LD128": "_ld128", "LD64": "_ld64", "EXTENDED": ""}[VARIANT] @@ -35,6 +36,7 @@ void xnn_qs8_gemm${GEMM_SUFFIX}_minmax_ukernel_${MR}x4c8__wasmsimd${LOAD_SUFFIX} assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); const int8_t* a0 = a; int8_t* c0 = c; $for M in range(1, MR): @@ -170,10 +172,10 @@ void xnn_qs8_gemm${GEMM_SUFFIX}_minmax_ukernel_${MR}x4c8__wasmsimd${LOAD_SUFFIX} *((float*) c${M}) = (float) wasm_f32x4_extract_lane(vout, ${M}); $for M in range(MR): - a${M} = (const int8_t*) ((uintptr_t) a${M} - k); + c${M} = (int8_t*) ((uintptr_t) c${M} + cn_stride); $for M in range(MR): - c${M} = (int8_t*) ((uintptr_t) c${M} + cn_stride); + a${M} = (const int8_t*) ((uintptr_t) a${M} - kc); nc -= 4; } else { diff --git a/src/qs8-gemm/MRx8c8-avx2.c.in b/src/qs8-gemm/MRx8c8-avx2.c.in index e40381024..3f17e73ba 100644 --- a/src/qs8-gemm/MRx8c8-avx2.c.in +++ b/src/qs8-gemm/MRx8c8-avx2.c.in @@ -11,6 +11,7 @@ $assert MR <= 4 #include <xnnpack/gemm.h> #include <xnnpack/intrinsics-polyfill.h> +#include <xnnpack/math.h> $GEMM_SUFFIX = "_xw" if VARIANT == "EXTENDED" else "" @@ -35,6 +36,7 @@ void xnn_qs8_gemm${GEMM_SUFFIX}_minmax_ukernel_${MR}x8c8__avx2( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); const int8_t* a0 = a; int8_t* c0 = c; $for M in range(1, MR): @@ -169,10 +171,10 @@ void xnn_qs8_gemm${GEMM_SUFFIX}_minmax_ukernel_${MR}x8c8__avx2( _mm_storeh_pi((__m64*) c3, _mm_castsi128_ps(vout_hi)); $for M in range(MR): - a${M} = (const int8_t*) ((uintptr_t) a${M} - k); + c${M} = (int8_t*) ((uintptr_t) c${M} + cn_stride); $for M in range(MR): - 
c${M} = (int8_t*) ((uintptr_t) c${M} + cn_stride); + a${M} = (const int8_t*) ((uintptr_t) a${M} - kc); nc -= 8; } else { diff --git a/src/qs8-gemm/MRxNRc4-neondot.c.in b/src/qs8-gemm/MRxNRc4-neondot.c.in index c26b7d3a9..49b43d237 100644 --- a/src/qs8-gemm/MRxNRc4-neondot.c.in +++ b/src/qs8-gemm/MRxNRc4-neondot.c.in @@ -6,12 +6,12 @@ $ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ" $assert NR % 8 == 0 $assert 8 <= NR <= 16 - #include <assert.h> #include <arm_neon.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_minmax_ukernel_${MR}x${NR}c4__neondot( @@ -29,7 +29,12 @@ void xnn_qs8_gemm_minmax_ukernel_${MR}x${NR}c4__neondot( assert(mr <= ${MR}); assert(nc != 0); assert(kc != 0); + assert(kc % sizeof(int8_t) == 0); + assert(a != NULL); + assert(w != NULL); + assert(c != NULL); + kc = round_up_po2(kc, 4); const int8_t* a0 = a; int8_t* c0 = c; $for M in range(1, MR): @@ -82,7 +87,7 @@ void xnn_qs8_gemm_minmax_ukernel_${MR}x${NR}c4__neondot( k -= 8 * sizeof(int8_t); } - // Handle up to 7 final positions of `k` + // Handle up to 6 final positions of `k` if XNN_UNLIKELY(k != 0) { // Load a ${MR}x4 block of activations. $for M in range(MR): @@ -108,13 +113,8 @@ void xnn_qs8_gemm_minmax_ukernel_${MR}x${NR}c4__neondot( vacc${M}x${ABC[N:N+4]} = vdotq_lane_s32(vacc${M}x${ABC[N:N+4]}, vb4567x${ABC[N:N+4]}, va${M}x01234567, 1); } } - // End of accumulation loop. The variable `kc` contains the amount by which - // we advanced the `va` pointers, so we rewind by this amount now. 
- $for M in range(MR): - a${M} = (const int8_t*) ((uintptr_t) a${M} - kc); // Post-accumulation work - const int32x4_t vright_shift = vld1q_dup_s32(¶ms->neon.right_shift); const int32x4_t vzero_shift_mask = vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0))); @@ -205,6 +205,9 @@ void xnn_qs8_gemm_minmax_ukernel_${MR}x${NR}c4__neondot( $for M in range(MR): c${M} = (int8_t*) ((uintptr_t) c${M} + cn_stride); + $for M in range(MR): + a${M} = (const int8_t*) ((uintptr_t) a${M} - kc); + nc -= ${NR}; } else { // Final case where not all of the ${NR} columns fit in the destination. diff --git a/src/qs8-gemm/c16-neon-mlal-padal.c.in b/src/qs8-gemm/c16-neon-mlal-padal.c.in index a6594efc2..ed65c1ed2 100644 --- a/src/qs8-gemm/c16-neon-mlal-padal.c.in +++ b/src/qs8-gemm/c16-neon-mlal-padal.c.in @@ -10,8 +10,8 @@ $assert 8 <= NR <= 16 #include <arm_neon.h> -#include <xnnpack/common.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_minmax_ukernel_${MR}x${NR}c16__neon_mlal_padal( @@ -35,6 +35,7 @@ void xnn_qs8_gemm_minmax_ukernel_${MR}x${NR}c16__neon_mlal_padal( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 16); const int8_t* a0 = a; int8_t* c0 = c; $for M in range(1, MR): @@ -63,7 +64,7 @@ void xnn_qs8_gemm_minmax_ukernel_${MR}x${NR}c16__neon_mlal_padal( $for N in range(NR): int32x4_t vacc${M}x${N} = vacc0x${N}; - // KC loop of 16 with up to 15 remainder + // KC loop of 16 size_t k = 0; while (k < kc) { $for M in range(MR): @@ -191,7 +192,7 @@ void xnn_qs8_gemm_minmax_ukernel_${MR}x${NR}c16__neon_mlal_padal( c${M} = (int8_t*) ((uintptr_t) c${M} + cn_stride); $for M in range(MR): - a${M} = (const int8_t*) ((uintptr_t) a${M} - k); + a${M} = (const int8_t*) ((uintptr_t) a${M} - kc); nc -= ${NR}; } else { diff --git a/src/qs8-gemm/c2-neon-mull-padal-dup.c.in b/src/qs8-gemm/c2-neon-mull-padal-dup.c.in index ba19f77fc..cbc793cc9 100644 --- a/src/qs8-gemm/c2-neon-mull-padal-dup.c.in +++ b/src/qs8-gemm/c2-neon-mull-padal-dup.c.in @@ 
-10,8 +10,8 @@ $assert 8 <= NR <= 16 #include <arm_neon.h> -#include <xnnpack/common.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_minmax_ukernel_${MR}x${NR}c2__neon_${"mlal" if MLA else "mull"}_padal_dup( @@ -35,6 +35,7 @@ void xnn_qs8_gemm_minmax_ukernel_${MR}x${NR}c2__neon_${"mlal" if MLA else "mull" assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 2); const int8_t* a0 = a; int8_t* c0 = c; $for M in range(1, MR): @@ -135,16 +136,6 @@ void xnn_qs8_gemm_minmax_ukernel_${MR}x${NR}c2__neon_${"mlal" if MLA else "mull" $for N in range(0, NR, 4): const int16x8_t vprod${M}x${ABC[N:N+4]}c2 = vmull_s8(vb${ABC[N:N+4]}c2, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va${M}), 2))); vacc${M}x${ABC[N:N+4]} = vpadalq_s16(vacc${M}x${ABC[N:N+4]}, vprod${M}x${ABC[N:N+4]}c2); - - if (k > 6 * sizeof(int8_t)) { - $for N in range(0, NR, 4): - const int8x8_t vb${ABC[N:N+4]}c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - - $for M in range(MR): - $for N in range(0, NR, 4): - const int16x8_t vprod${M}x${ABC[N:N+4]}c3 = vmull_s8(vb${ABC[N:N+4]}c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va${M}), 3))); - vacc${M}x${ABC[N:N+4]} = vpadalq_s16(vacc${M}x${ABC[N:N+4]}, vprod${M}x${ABC[N:N+4]}c3); - } } } } diff --git a/src/qs8-gemm/c8-neon-mull-padal.c.in b/src/qs8-gemm/c8-neon-mull-padal.c.in index 12edb18f9..205f65046 100644 --- a/src/qs8-gemm/c8-neon-mull-padal.c.in +++ b/src/qs8-gemm/c8-neon-mull-padal.c.in @@ -10,8 +10,8 @@ $assert 8 <= NR <= 16 #include <arm_neon.h> -#include <xnnpack/common.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_minmax_ukernel_${MR}x${NR}c8__neon_mull_padal( @@ -35,6 +35,7 @@ void xnn_qs8_gemm_minmax_ukernel_${MR}x${NR}c8__neon_mull_padal( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); const int8_t* a0 = a; int8_t* c0 = c; $for M in range(1, MR): @@ -210,7 +211,7 @@ void xnn_qs8_gemm_minmax_ukernel_${MR}x${NR}c8__neon_mull_padal( 
c${M} = (int8_t*) ((uintptr_t) c${M} + cn_stride); $for M in range(MR): - a${M} = (const int8_t*) ((uintptr_t) a${M} - (kc - k)); + a${M} = (const int8_t*) ((uintptr_t) a${M} - kc); nc -= ${NR}; } else { diff --git a/src/qs8-gemm/gen/12x8c4-minmax-neondot.c b/src/qs8-gemm/gen/12x8c4-minmax-neondot.c index 2dabf60e4..c1665fd52 100644 --- a/src/qs8-gemm/gen/12x8c4-minmax-neondot.c +++ b/src/qs8-gemm/gen/12x8c4-minmax-neondot.c @@ -7,12 +7,12 @@ // This source code is licensed under the BSD-style license found in the // LICENSE file in the root directory of this source tree. - #include <assert.h> #include <arm_neon.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_minmax_ukernel_12x8c4__neondot( @@ -30,7 +30,12 @@ void xnn_qs8_gemm_minmax_ukernel_12x8c4__neondot( assert(mr <= 12); assert(nc != 0); assert(kc != 0); + assert(kc % sizeof(int8_t) == 0); + assert(a != NULL); + assert(w != NULL); + assert(c != NULL); + kc = round_up_po2(kc, 4); const int8_t* a0 = a; int8_t* c0 = c; const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride); @@ -205,7 +210,7 @@ void xnn_qs8_gemm_minmax_ukernel_12x8c4__neondot( k -= 8 * sizeof(int8_t); } - // Handle up to 7 final positions of `k` + // Handle up to 6 final positions of `k` if XNN_UNLIKELY(k != 0) { // Load a 12x4 block of activations. const int8x8_t va0x01234567 = vld1_s8(a0); a0 += k; @@ -283,23 +288,8 @@ void xnn_qs8_gemm_minmax_ukernel_12x8c4__neondot( vacc11x4567 = vdotq_lane_s32(vacc11x4567, vb4567x4567, va11x01234567, 1); } } - // End of accumulation loop. The variable `kc` contains the amount by which - // we advanced the `va` pointers, so we rewind by this amount now. 
- a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - a3 = (const int8_t*) ((uintptr_t) a3 - kc); - a4 = (const int8_t*) ((uintptr_t) a4 - kc); - a5 = (const int8_t*) ((uintptr_t) a5 - kc); - a6 = (const int8_t*) ((uintptr_t) a6 - kc); - a7 = (const int8_t*) ((uintptr_t) a7 - kc); - a8 = (const int8_t*) ((uintptr_t) a8 - kc); - a9 = (const int8_t*) ((uintptr_t) a9 - kc); - a10 = (const int8_t*) ((uintptr_t) a10 - kc); - a11 = (const int8_t*) ((uintptr_t) a11 - kc); // Post-accumulation work - const int32x4_t vright_shift = vld1q_dup_s32(¶ms->neon.right_shift); const int32x4_t vzero_shift_mask = vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0))); @@ -466,6 +456,19 @@ void xnn_qs8_gemm_minmax_ukernel_12x8c4__neondot( c10 = (int8_t*) ((uintptr_t) c10 + cn_stride); c11 = (int8_t*) ((uintptr_t) c11 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + a3 = (const int8_t*) ((uintptr_t) a3 - kc); + a4 = (const int8_t*) ((uintptr_t) a4 - kc); + a5 = (const int8_t*) ((uintptr_t) a5 - kc); + a6 = (const int8_t*) ((uintptr_t) a6 - kc); + a7 = (const int8_t*) ((uintptr_t) a7 - kc); + a8 = (const int8_t*) ((uintptr_t) a8 - kc); + a9 = (const int8_t*) ((uintptr_t) a9 - kc); + a10 = (const int8_t*) ((uintptr_t) a10 - kc); + a11 = (const int8_t*) ((uintptr_t) a11 - kc); + nc -= 8; } else { // Final case where not all of the 8 columns fit in the destination. 
diff --git a/src/qs8-gemm/gen/1x16c16-minmax-neon-mlal-padal.c b/src/qs8-gemm/gen/1x16c16-minmax-neon-mlal-padal.c index a3dd9866f..ce7d5d9dd 100644 --- a/src/qs8-gemm/gen/1x16c16-minmax-neon-mlal-padal.c +++ b/src/qs8-gemm/gen/1x16c16-minmax-neon-mlal-padal.c @@ -11,8 +11,8 @@ #include <arm_neon.h> -#include <xnnpack/common.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_minmax_ukernel_1x16c16__neon_mlal_padal( @@ -36,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_1x16c16__neon_mlal_padal( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 16); const int8_t* a0 = a; int8_t* c0 = c; @@ -57,7 +58,7 @@ void xnn_qs8_gemm_minmax_ukernel_1x16c16__neon_mlal_padal( int32x4_t vacc0x14 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); int32x4_t vacc0x15 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); - // KC loop of 16 with up to 15 remainder + // KC loop of 16 size_t k = 0; while (k < kc) { const int8x16_t va0 = vld1q_s8(a0); a0 += 16; @@ -216,7 +217,7 @@ void xnn_qs8_gemm_minmax_ukernel_1x16c16__neon_mlal_padal( c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); - a0 = (const int8_t*) ((uintptr_t) a0 - k); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); nc -= 16; } else { diff --git a/src/qs8-gemm/gen/1x16c2-minmax-neon-mlal-padal-dup.c b/src/qs8-gemm/gen/1x16c2-minmax-neon-mlal-padal-dup.c index 43eec2937..7c5f25da6 100644 --- a/src/qs8-gemm/gen/1x16c2-minmax-neon-mlal-padal-dup.c +++ b/src/qs8-gemm/gen/1x16c2-minmax-neon-mlal-padal-dup.c @@ -11,8 +11,8 @@ #include <arm_neon.h> -#include <xnnpack/common.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_minmax_ukernel_1x16c2__neon_mlal_padal_dup( @@ -36,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_1x16c2__neon_mlal_padal_dup( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 2); const int8_t* a0 = a; int8_t* c0 = c; @@ -238,22 +239,6 @@ void 
xnn_qs8_gemm_minmax_ukernel_1x16c2__neon_mlal_padal_dup( vacc0x89AB = vpadalq_s16(vacc0x89AB, vprod0x89ABc2); const int16x8_t vprod0xCDEFc2 = vmull_s8(vbCDEFc2, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 2))); vacc0xCDEF = vpadalq_s16(vacc0xCDEF, vprod0xCDEFc2); - - if (k > 6 * sizeof(int8_t)) { - const int8x8_t vb0123c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb4567c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb89ABc3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vbCDEFc3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - - const int16x8_t vprod0x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0x0123 = vpadalq_s16(vacc0x0123, vprod0x0123c3); - const int16x8_t vprod0x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0x4567 = vpadalq_s16(vacc0x4567, vprod0x4567c3); - const int16x8_t vprod0x89ABc3 = vmull_s8(vb89ABc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0x89AB = vpadalq_s16(vacc0x89AB, vprod0x89ABc3); - const int16x8_t vprod0xCDEFc3 = vmull_s8(vbCDEFc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0xCDEF = vpadalq_s16(vacc0xCDEF, vprod0xCDEFc3); - } } } } diff --git a/src/qs8-gemm/gen/1x16c2-minmax-neon-mull-padal-dup.c b/src/qs8-gemm/gen/1x16c2-minmax-neon-mull-padal-dup.c index a862b808a..8447491f8 100644 --- a/src/qs8-gemm/gen/1x16c2-minmax-neon-mull-padal-dup.c +++ b/src/qs8-gemm/gen/1x16c2-minmax-neon-mull-padal-dup.c @@ -11,8 +11,8 @@ #include <arm_neon.h> -#include <xnnpack/common.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_minmax_ukernel_1x16c2__neon_mull_padal_dup( @@ -36,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_1x16c2__neon_mull_padal_dup( assert(w != NULL); assert(c != NULL); + kc = 
round_up_po2(kc, 2); const int8_t* a0 = a; int8_t* c0 = c; @@ -150,22 +151,6 @@ void xnn_qs8_gemm_minmax_ukernel_1x16c2__neon_mull_padal_dup( vacc0x89AB = vpadalq_s16(vacc0x89AB, vprod0x89ABc2); const int16x8_t vprod0xCDEFc2 = vmull_s8(vbCDEFc2, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 2))); vacc0xCDEF = vpadalq_s16(vacc0xCDEF, vprod0xCDEFc2); - - if (k > 6 * sizeof(int8_t)) { - const int8x8_t vb0123c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb4567c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb89ABc3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vbCDEFc3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - - const int16x8_t vprod0x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0x0123 = vpadalq_s16(vacc0x0123, vprod0x0123c3); - const int16x8_t vprod0x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0x4567 = vpadalq_s16(vacc0x4567, vprod0x4567c3); - const int16x8_t vprod0x89ABc3 = vmull_s8(vb89ABc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0x89AB = vpadalq_s16(vacc0x89AB, vprod0x89ABc3); - const int16x8_t vprod0xCDEFc3 = vmull_s8(vbCDEFc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0xCDEF = vpadalq_s16(vacc0xCDEF, vprod0xCDEFc3); - } } } } diff --git a/src/qs8-gemm/gen/1x16c4-minmax-neondot.c b/src/qs8-gemm/gen/1x16c4-minmax-neondot.c index 86441ddb7..964525e01 100644 --- a/src/qs8-gemm/gen/1x16c4-minmax-neondot.c +++ b/src/qs8-gemm/gen/1x16c4-minmax-neondot.c @@ -7,12 +7,12 @@ // This source code is licensed under the BSD-style license found in the // LICENSE file in the root directory of this source tree. 
- #include <assert.h> #include <arm_neon.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_minmax_ukernel_1x16c4__neondot( @@ -30,7 +30,12 @@ void xnn_qs8_gemm_minmax_ukernel_1x16c4__neondot( assert(mr <= 1); assert(nc != 0); assert(kc != 0); + assert(kc % sizeof(int8_t) == 0); + assert(a != NULL); + assert(w != NULL); + assert(c != NULL); + kc = round_up_po2(kc, 4); const int8_t* a0 = a; int8_t* c0 = c; @@ -72,7 +77,7 @@ void xnn_qs8_gemm_minmax_ukernel_1x16c4__neondot( k -= 8 * sizeof(int8_t); } - // Handle up to 7 final positions of `k` + // Handle up to 6 final positions of `k` if XNN_UNLIKELY(k != 0) { // Load a 1x4 block of activations. const int8x8_t va0x01234567 = vld1_s8(a0); a0 += k; @@ -103,12 +108,8 @@ void xnn_qs8_gemm_minmax_ukernel_1x16c4__neondot( vacc0xCDEF = vdotq_lane_s32(vacc0xCDEF, vb4567xCDEF, va0x01234567, 1); } } - // End of accumulation loop. The variable `kc` contains the amount by which - // we advanced the `va` pointers, so we rewind by this amount now. - a0 = (const int8_t*) ((uintptr_t) a0 - kc); // Post-accumulation work - const int32x4_t vright_shift = vld1q_dup_s32(¶ms->neon.right_shift); const int32x4_t vzero_shift_mask = vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0))); @@ -153,6 +154,8 @@ void xnn_qs8_gemm_minmax_ukernel_1x16c4__neondot( // Advance to the next 16 columns. c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + nc -= 16; } else { // Final case where not all of the 16 columns fit in the destination. 
diff --git a/src/qs8-gemm/gen/1x16c8-minmax-avx512skx.c b/src/qs8-gemm/gen/1x16c8-minmax-avx512skx.c index e03b1826a..52ac712a9 100644 --- a/src/qs8-gemm/gen/1x16c8-minmax-avx512skx.c +++ b/src/qs8-gemm/gen/1x16c8-minmax-avx512skx.c @@ -13,6 +13,7 @@ #include <xnnpack/gemm.h> #include <xnnpack/intrinsics-polyfill.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_minmax_ukernel_1x16c8__avx512skx( @@ -36,6 +37,7 @@ void xnn_qs8_gemm_minmax_ukernel_1x16c8__avx512skx( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); const int8_t* a0 = a; int8_t* c0 = c; diff --git a/src/qs8-gemm/gen/1x16c8-minmax-neon-mull-padal.c b/src/qs8-gemm/gen/1x16c8-minmax-neon-mull-padal.c index 880e8f21d..a345ea062 100644 --- a/src/qs8-gemm/gen/1x16c8-minmax-neon-mull-padal.c +++ b/src/qs8-gemm/gen/1x16c8-minmax-neon-mull-padal.c @@ -11,8 +11,8 @@ #include <arm_neon.h> -#include <xnnpack/common.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mull_padal( @@ -36,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mull_padal( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); const int8_t* a0 = a; int8_t* c0 = c; @@ -291,7 +292,7 @@ void xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mull_padal( c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); - a0 = (const int8_t*) ((uintptr_t) a0 - (kc - k)); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); nc -= 16; } else { diff --git a/src/qs8-gemm/gen/1x4c2-minmax-sse2-ld128.c b/src/qs8-gemm/gen/1x4c2-minmax-sse2-ld128.c index 36d08bb38..f0c02e1ee 100644 --- a/src/qs8-gemm/gen/1x4c2-minmax-sse2-ld128.c +++ b/src/qs8-gemm/gen/1x4c2-minmax-sse2-ld128.c @@ -12,6 +12,7 @@ #include <emmintrin.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_minmax_ukernel_1x4c2__sse2_ld128( @@ -35,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_1x4c2__sse2_ld128( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 2); const int8_t* a0 = a; int8_t* c0 = c; @@ -99,15 +101,6 
@@ void xnn_qs8_gemm_minmax_ukernel_1x4c2__sse2_ld128( vacc0x0123 = _mm_add_epi32(vacc0x0123, _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(2, 2, 2, 2)), vxb2)); - - if (k > 6 * sizeof(int8_t)) { - const __m128i vb3 = _mm_loadl_epi64((const __m128i*) w); - const __m128i vxb3 = _mm_unpacklo_epi8(vb3, _mm_cmpgt_epi8(_mm_setzero_si128(), vb3)); - w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - - vacc0x0123 = _mm_add_epi32(vacc0x0123, - _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); - } } } } @@ -163,10 +156,10 @@ void xnn_qs8_gemm_minmax_ukernel_1x4c2__sse2_ld128( if (nc >= 4) { *((uint32_t*) c0) = (uint32_t) _mm_cvtsi128_si32(vout); - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + nc -= 4; } else { if (nc & 2) { diff --git a/src/qs8-gemm/gen/1x4c2-minmax-sse2-ld64.c b/src/qs8-gemm/gen/1x4c2-minmax-sse2-ld64.c index 96aae54d8..1b2860592 100644 --- a/src/qs8-gemm/gen/1x4c2-minmax-sse2-ld64.c +++ b/src/qs8-gemm/gen/1x4c2-minmax-sse2-ld64.c @@ -12,6 +12,7 @@ #include <emmintrin.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_minmax_ukernel_1x4c2__sse2_ld64( @@ -35,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_1x4c2__sse2_ld64( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 2); const int8_t* a0 = a; int8_t* c0 = c; @@ -99,15 +101,6 @@ void xnn_qs8_gemm_minmax_ukernel_1x4c2__sse2_ld64( vacc0x0123 = _mm_add_epi32(vacc0x0123, _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(2, 2, 2, 2)), vxb2)); - - if (k > 6 * sizeof(int8_t)) { - const __m128i vb3 = _mm_loadl_epi64((const __m128i*) w); - const __m128i vxb3 = _mm_unpacklo_epi8(vb3, _mm_cmpgt_epi8(_mm_setzero_si128(), vb3)); - w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - - vacc0x0123 = _mm_add_epi32(vacc0x0123, - _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); - } } } } @@ -163,10 +156,10 @@ void 
xnn_qs8_gemm_minmax_ukernel_1x4c2__sse2_ld64( if (nc >= 4) { *((uint32_t*) c0) = (uint32_t) _mm_cvtsi128_si32(vout); - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + nc -= 4; } else { if (nc & 2) { diff --git a/src/qs8-gemm/gen/1x4c2-minmax-sse41-ld128.c b/src/qs8-gemm/gen/1x4c2-minmax-sse41-ld128.c index 45fee98c3..11eafd3ac 100644 --- a/src/qs8-gemm/gen/1x4c2-minmax-sse41-ld128.c +++ b/src/qs8-gemm/gen/1x4c2-minmax-sse41-ld128.c @@ -12,6 +12,7 @@ #include <smmintrin.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_minmax_ukernel_1x4c2__sse41_ld128( @@ -35,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_1x4c2__sse41_ld128( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 2); const int8_t* a0 = a; int8_t* c0 = c; @@ -99,15 +101,6 @@ void xnn_qs8_gemm_minmax_ukernel_1x4c2__sse41_ld128( vacc0x0123 = _mm_add_epi32(vacc0x0123, _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(2, 2, 2, 2)), vxb2)); - - if (k > 6 * sizeof(int8_t)) { - const __m128i vb3 = _mm_loadl_epi64((const __m128i*) w); - const __m128i vxb3 = _mm_cvtepi8_epi16(vb3); - w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - - vacc0x0123 = _mm_add_epi32(vacc0x0123, - _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); - } } } } @@ -147,10 +140,10 @@ void xnn_qs8_gemm_minmax_ukernel_1x4c2__sse41_ld128( if (nc >= 4) { *((uint32_t*) c0) = (uint32_t) _mm_cvtsi128_si32(vout); - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + nc -= 4; } else { if (nc & 2) { diff --git a/src/qs8-gemm/gen/1x4c2-minmax-sse41-ld64.c b/src/qs8-gemm/gen/1x4c2-minmax-sse41-ld64.c index de32f50e1..d1bcbd1c5 100644 --- a/src/qs8-gemm/gen/1x4c2-minmax-sse41-ld64.c +++ b/src/qs8-gemm/gen/1x4c2-minmax-sse41-ld64.c @@ -12,6 +12,7 @@ #include <smmintrin.h> #include <xnnpack/gemm.h> +#include 
<xnnpack/math.h> void xnn_qs8_gemm_minmax_ukernel_1x4c2__sse41_ld64( @@ -35,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_1x4c2__sse41_ld64( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 2); const int8_t* a0 = a; int8_t* c0 = c; @@ -99,15 +101,6 @@ void xnn_qs8_gemm_minmax_ukernel_1x4c2__sse41_ld64( vacc0x0123 = _mm_add_epi32(vacc0x0123, _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(2, 2, 2, 2)), vxb2)); - - if (k > 6 * sizeof(int8_t)) { - const __m128i vb3 = _mm_loadl_epi64((const __m128i*) w); - const __m128i vxb3 = _mm_cvtepi8_epi16(vb3); - w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - - vacc0x0123 = _mm_add_epi32(vacc0x0123, - _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); - } } } } @@ -147,10 +140,10 @@ void xnn_qs8_gemm_minmax_ukernel_1x4c2__sse41_ld64( if (nc >= 4) { *((uint32_t*) c0) = (uint32_t) _mm_cvtsi128_si32(vout); - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + nc -= 4; } else { if (nc & 2) { diff --git a/src/qs8-gemm/gen/1x4c2-minmax-ssse3-ld128.c b/src/qs8-gemm/gen/1x4c2-minmax-ssse3-ld128.c index 9e517ba9d..bdb39b9f9 100644 --- a/src/qs8-gemm/gen/1x4c2-minmax-ssse3-ld128.c +++ b/src/qs8-gemm/gen/1x4c2-minmax-ssse3-ld128.c @@ -12,6 +12,7 @@ #include <tmmintrin.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_minmax_ukernel_1x4c2__ssse3_ld128( @@ -35,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_1x4c2__ssse3_ld128( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 2); const int8_t* a0 = a; int8_t* c0 = c; @@ -99,15 +101,6 @@ void xnn_qs8_gemm_minmax_ukernel_1x4c2__ssse3_ld128( vacc0x0123 = _mm_add_epi32(vacc0x0123, _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(2, 2, 2, 2)), vxb2)); - - if (k > 6 * sizeof(int8_t)) { - const __m128i vb3 = _mm_loadl_epi64((const __m128i*) w); - const __m128i vxb3 = _mm_unpacklo_epi8(vb3, _mm_cmpgt_epi8(_mm_setzero_si128(), vb3)); - 
w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - - vacc0x0123 = _mm_add_epi32(vacc0x0123, - _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); - } } } } @@ -163,10 +156,10 @@ void xnn_qs8_gemm_minmax_ukernel_1x4c2__ssse3_ld128( if (nc >= 4) { *((uint32_t*) c0) = (uint32_t) _mm_cvtsi128_si32(vout); - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + nc -= 4; } else { if (nc & 2) { diff --git a/src/qs8-gemm/gen/1x4c2-minmax-ssse3-ld64.c b/src/qs8-gemm/gen/1x4c2-minmax-ssse3-ld64.c index 7529af643..11f165f91 100644 --- a/src/qs8-gemm/gen/1x4c2-minmax-ssse3-ld64.c +++ b/src/qs8-gemm/gen/1x4c2-minmax-ssse3-ld64.c @@ -12,6 +12,7 @@ #include <tmmintrin.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_minmax_ukernel_1x4c2__ssse3_ld64( @@ -35,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_1x4c2__ssse3_ld64( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 2); const int8_t* a0 = a; int8_t* c0 = c; @@ -99,15 +101,6 @@ void xnn_qs8_gemm_minmax_ukernel_1x4c2__ssse3_ld64( vacc0x0123 = _mm_add_epi32(vacc0x0123, _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(2, 2, 2, 2)), vxb2)); - - if (k > 6 * sizeof(int8_t)) { - const __m128i vb3 = _mm_loadl_epi64((const __m128i*) w); - const __m128i vxb3 = _mm_unpacklo_epi8(vb3, _mm_cmpgt_epi8(_mm_setzero_si128(), vb3)); - w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - - vacc0x0123 = _mm_add_epi32(vacc0x0123, - _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); - } } } } @@ -163,10 +156,10 @@ void xnn_qs8_gemm_minmax_ukernel_1x4c2__ssse3_ld64( if (nc >= 4) { *((uint32_t*) c0) = (uint32_t) _mm_cvtsi128_si32(vout); - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + nc -= 4; } else { if (nc & 2) { diff --git a/src/qs8-gemm/gen/1x4c2-minmax-xop-ld128.c 
b/src/qs8-gemm/gen/1x4c2-minmax-xop-ld128.c index 464f22f0c..7473e306a 100644 --- a/src/qs8-gemm/gen/1x4c2-minmax-xop-ld128.c +++ b/src/qs8-gemm/gen/1x4c2-minmax-xop-ld128.c @@ -17,6 +17,7 @@ #endif #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_minmax_ukernel_1x4c2__xop_ld128( @@ -40,6 +41,7 @@ void xnn_qs8_gemm_minmax_ukernel_1x4c2__xop_ld128( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 2); const int8_t* a0 = a; int8_t* c0 = c; @@ -104,15 +106,6 @@ void xnn_qs8_gemm_minmax_ukernel_1x4c2__xop_ld128( vacc0x0123 = _mm_maddd_epi16( _mm_shuffle_epi32(vxa0, _MM_SHUFFLE(2, 2, 2, 2)), vxb2, vacc0x0123); - - if (k > 6 * sizeof(int8_t)) { - const __m128i vb3 = _mm_loadl_epi64((const __m128i*) w); - const __m128i vxb3 = _mm_cvtepi8_epi16(vb3); - w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - - vacc0x0123 = _mm_maddd_epi16( - _mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3, vacc0x0123); - } } } } @@ -152,10 +145,10 @@ void xnn_qs8_gemm_minmax_ukernel_1x4c2__xop_ld128( if (nc >= 4) { *((uint32_t*) c0) = (uint32_t) _mm_cvtsi128_si32(vout); - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + nc -= 4; } else { if (nc & 2) { diff --git a/src/qs8-gemm/gen/1x4c2-minmax-xop-ld64.c b/src/qs8-gemm/gen/1x4c2-minmax-xop-ld64.c index 803824884..ece3f815d 100644 --- a/src/qs8-gemm/gen/1x4c2-minmax-xop-ld64.c +++ b/src/qs8-gemm/gen/1x4c2-minmax-xop-ld64.c @@ -17,6 +17,7 @@ #endif #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_minmax_ukernel_1x4c2__xop_ld64( @@ -40,6 +41,7 @@ void xnn_qs8_gemm_minmax_ukernel_1x4c2__xop_ld64( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 2); const int8_t* a0 = a; int8_t* c0 = c; @@ -104,15 +106,6 @@ void xnn_qs8_gemm_minmax_ukernel_1x4c2__xop_ld64( vacc0x0123 = _mm_maddd_epi16( _mm_shuffle_epi32(vxa0, _MM_SHUFFLE(2, 2, 2, 2)), vxb2, vacc0x0123); - - if (k > 6 * 
sizeof(int8_t)) { - const __m128i vb3 = _mm_loadl_epi64((const __m128i*) w); - const __m128i vxb3 = _mm_cvtepi8_epi16(vb3); - w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - - vacc0x0123 = _mm_maddd_epi16( - _mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3, vacc0x0123); - } } } } @@ -152,10 +145,10 @@ void xnn_qs8_gemm_minmax_ukernel_1x4c2__xop_ld64( if (nc >= 4) { *((uint32_t*) c0) = (uint32_t) _mm_cvtsi128_si32(vout); - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + nc -= 4; } else { if (nc & 2) { diff --git a/src/qs8-gemm/gen/1x4c2-xw-minmax-sse2.c b/src/qs8-gemm/gen/1x4c2-xw-minmax-sse2.c index 70c7dd080..6c6dd0d9f 100644 --- a/src/qs8-gemm/gen/1x4c2-xw-minmax-sse2.c +++ b/src/qs8-gemm/gen/1x4c2-xw-minmax-sse2.c @@ -12,6 +12,7 @@ #include <emmintrin.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_xw_minmax_ukernel_1x4c2__sse2( @@ -35,6 +36,7 @@ void xnn_qs8_gemm_xw_minmax_ukernel_1x4c2__sse2( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 2); const int8_t* a0 = a; int8_t* c0 = c; @@ -92,14 +94,6 @@ void xnn_qs8_gemm_xw_minmax_ukernel_1x4c2__sse2( vacc0x0123 = _mm_add_epi32(vacc0x0123, _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(2, 2, 2, 2)), vxb2)); - - if (k > 6 * sizeof(int8_t)) { - const __m128i vxb3 = _mm_load_si128((const __m128i*) w); - w = (const void*) ((uintptr_t) w + 8 * sizeof(int16_t)); - - vacc0x0123 = _mm_add_epi32(vacc0x0123, - _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); - } } } } @@ -155,10 +149,10 @@ void xnn_qs8_gemm_xw_minmax_ukernel_1x4c2__sse2( if (nc >= 4) { *((uint32_t*) c0) = (uint32_t) _mm_cvtsi128_si32(vout); - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + nc -= 4; } else { if (nc & 2) { diff --git a/src/qs8-gemm/gen/1x4c2-xw-minmax-sse41.c 
b/src/qs8-gemm/gen/1x4c2-xw-minmax-sse41.c index d5fa2c391..2783338d4 100644 --- a/src/qs8-gemm/gen/1x4c2-xw-minmax-sse41.c +++ b/src/qs8-gemm/gen/1x4c2-xw-minmax-sse41.c @@ -12,6 +12,7 @@ #include <smmintrin.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_xw_minmax_ukernel_1x4c2__sse41( @@ -35,6 +36,7 @@ void xnn_qs8_gemm_xw_minmax_ukernel_1x4c2__sse41( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 2); const int8_t* a0 = a; int8_t* c0 = c; @@ -92,14 +94,6 @@ void xnn_qs8_gemm_xw_minmax_ukernel_1x4c2__sse41( vacc0x0123 = _mm_add_epi32(vacc0x0123, _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(2, 2, 2, 2)), vxb2)); - - if (k > 6 * sizeof(int8_t)) { - const __m128i vxb3 = _mm_load_si128((const __m128i*) w); - w = (const void*) ((uintptr_t) w + 8 * sizeof(int16_t)); - - vacc0x0123 = _mm_add_epi32(vacc0x0123, - _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); - } } } } @@ -139,10 +133,10 @@ void xnn_qs8_gemm_xw_minmax_ukernel_1x4c2__sse41( if (nc >= 4) { *((uint32_t*) c0) = (uint32_t) _mm_cvtsi128_si32(vout); - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + nc -= 4; } else { if (nc & 2) { diff --git a/src/qs8-gemm/gen/1x4c2-xw-minmax-ssse3.c b/src/qs8-gemm/gen/1x4c2-xw-minmax-ssse3.c index 40b3ea8ab..6019e9c5a 100644 --- a/src/qs8-gemm/gen/1x4c2-xw-minmax-ssse3.c +++ b/src/qs8-gemm/gen/1x4c2-xw-minmax-ssse3.c @@ -12,6 +12,7 @@ #include <tmmintrin.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_xw_minmax_ukernel_1x4c2__ssse3( @@ -35,6 +36,7 @@ void xnn_qs8_gemm_xw_minmax_ukernel_1x4c2__ssse3( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 2); const int8_t* a0 = a; int8_t* c0 = c; @@ -92,14 +94,6 @@ void xnn_qs8_gemm_xw_minmax_ukernel_1x4c2__ssse3( vacc0x0123 = _mm_add_epi32(vacc0x0123, _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(2, 2, 2, 2)), vxb2)); - - if (k > 6 
* sizeof(int8_t)) { - const __m128i vxb3 = _mm_load_si128((const __m128i*) w); - w = (const void*) ((uintptr_t) w + 8 * sizeof(int16_t)); - - vacc0x0123 = _mm_add_epi32(vacc0x0123, - _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); - } } } } @@ -155,10 +149,10 @@ void xnn_qs8_gemm_xw_minmax_ukernel_1x4c2__ssse3( if (nc >= 4) { *((uint32_t*) c0) = (uint32_t) _mm_cvtsi128_si32(vout); - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + nc -= 4; } else { if (nc & 2) { diff --git a/src/qs8-gemm/gen/1x4c2-xw-minmax-xop.c b/src/qs8-gemm/gen/1x4c2-xw-minmax-xop.c index 624ac2b96..b4a2490f0 100644 --- a/src/qs8-gemm/gen/1x4c2-xw-minmax-xop.c +++ b/src/qs8-gemm/gen/1x4c2-xw-minmax-xop.c @@ -17,6 +17,7 @@ #endif #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_xw_minmax_ukernel_1x4c2__xop( @@ -40,6 +41,7 @@ void xnn_qs8_gemm_xw_minmax_ukernel_1x4c2__xop( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 2); const int8_t* a0 = a; int8_t* c0 = c; @@ -97,14 +99,6 @@ void xnn_qs8_gemm_xw_minmax_ukernel_1x4c2__xop( vacc0x0123 = _mm_maddd_epi16( _mm_shuffle_epi32(vxa0, _MM_SHUFFLE(2, 2, 2, 2)), vxb2, vacc0x0123); - - if (k > 6 * sizeof(int8_t)) { - const __m128i vxb3 = _mm_load_si128((const __m128i*) w); - w = (const void*) ((uintptr_t) w + 8 * sizeof(int16_t)); - - vacc0x0123 = _mm_maddd_epi16( - _mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3, vacc0x0123); - } } } } @@ -144,10 +138,10 @@ void xnn_qs8_gemm_xw_minmax_ukernel_1x4c2__xop( if (nc >= 4) { *((uint32_t*) c0) = (uint32_t) _mm_cvtsi128_si32(vout); - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + nc -= 4; } else { if (nc & 2) { diff --git a/src/qs8-gemm/gen/1x4c8-minmax-sse2-ld128.c b/src/qs8-gemm/gen/1x4c8-minmax-sse2-ld128.c index 87d81ff8b..135dcdb66 100644 --- 
a/src/qs8-gemm/gen/1x4c8-minmax-sse2-ld128.c +++ b/src/qs8-gemm/gen/1x4c8-minmax-sse2-ld128.c @@ -12,6 +12,7 @@ #include <emmintrin.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_minmax_ukernel_1x4c8__sse2_ld128( @@ -35,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_1x4c8__sse2_ld128( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); const int8_t* a0 = a; int8_t* c0 = c; @@ -126,10 +128,10 @@ void xnn_qs8_gemm_minmax_ukernel_1x4c8__sse2_ld128( if (nc >= 4) { *((uint32_t*) c0) = (uint32_t) _mm_cvtsi128_si32(vout); - a0 = (const int8_t*) ((uintptr_t) a0 - k); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + nc -= 4; } else { if (nc & 2) { diff --git a/src/qs8-gemm/gen/1x4c8-minmax-sse2-ld64.c b/src/qs8-gemm/gen/1x4c8-minmax-sse2-ld64.c index 3467c1e01..15bd4fde1 100644 --- a/src/qs8-gemm/gen/1x4c8-minmax-sse2-ld64.c +++ b/src/qs8-gemm/gen/1x4c8-minmax-sse2-ld64.c @@ -12,6 +12,7 @@ #include <emmintrin.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_minmax_ukernel_1x4c8__sse2_ld64( @@ -35,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_1x4c8__sse2_ld64( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); const int8_t* a0 = a; int8_t* c0 = c; @@ -128,10 +130,10 @@ void xnn_qs8_gemm_minmax_ukernel_1x4c8__sse2_ld64( if (nc >= 4) { *((uint32_t*) c0) = (uint32_t) _mm_cvtsi128_si32(vout); - a0 = (const int8_t*) ((uintptr_t) a0 - k); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + nc -= 4; } else { if (nc & 2) { diff --git a/src/qs8-gemm/gen/1x4c8-minmax-sse41-ld128.c b/src/qs8-gemm/gen/1x4c8-minmax-sse41-ld128.c index e7b623c77..df8bf837b 100644 --- a/src/qs8-gemm/gen/1x4c8-minmax-sse41-ld128.c +++ b/src/qs8-gemm/gen/1x4c8-minmax-sse41-ld128.c @@ -12,6 +12,7 @@ #include <smmintrin.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_minmax_ukernel_1x4c8__sse41_ld128( @@ -35,6 +36,7 @@ 
void xnn_qs8_gemm_minmax_ukernel_1x4c8__sse41_ld128( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); const int8_t* a0 = a; int8_t* c0 = c; @@ -110,10 +112,10 @@ void xnn_qs8_gemm_minmax_ukernel_1x4c8__sse41_ld128( if (nc >= 4) { *((uint32_t*) c0) = (uint32_t) _mm_cvtsi128_si32(vout); - a0 = (const int8_t*) ((uintptr_t) a0 - k); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + nc -= 4; } else { if (nc & 2) { diff --git a/src/qs8-gemm/gen/1x4c8-minmax-sse41-ld64.c b/src/qs8-gemm/gen/1x4c8-minmax-sse41-ld64.c index 6db42e28b..f2e0e3ac1 100644 --- a/src/qs8-gemm/gen/1x4c8-minmax-sse41-ld64.c +++ b/src/qs8-gemm/gen/1x4c8-minmax-sse41-ld64.c @@ -12,6 +12,7 @@ #include <smmintrin.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_minmax_ukernel_1x4c8__sse41_ld64( @@ -35,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_1x4c8__sse41_ld64( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); const int8_t* a0 = a; int8_t* c0 = c; @@ -112,10 +114,10 @@ void xnn_qs8_gemm_minmax_ukernel_1x4c8__sse41_ld64( if (nc >= 4) { *((uint32_t*) c0) = (uint32_t) _mm_cvtsi128_si32(vout); - a0 = (const int8_t*) ((uintptr_t) a0 - k); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + nc -= 4; } else { if (nc & 2) { diff --git a/src/qs8-gemm/gen/1x4c8-minmax-ssse3-ld128.c b/src/qs8-gemm/gen/1x4c8-minmax-ssse3-ld128.c index d0323ddfb..5f90494a8 100644 --- a/src/qs8-gemm/gen/1x4c8-minmax-ssse3-ld128.c +++ b/src/qs8-gemm/gen/1x4c8-minmax-ssse3-ld128.c @@ -12,6 +12,7 @@ #include <tmmintrin.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_minmax_ukernel_1x4c8__ssse3_ld128( @@ -35,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_1x4c8__ssse3_ld128( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); const int8_t* a0 = a; int8_t* c0 = c; @@ -126,10 +128,10 @@ void xnn_qs8_gemm_minmax_ukernel_1x4c8__ssse3_ld128( if (nc >= 4) { 
*((uint32_t*) c0) = (uint32_t) _mm_cvtsi128_si32(vout); - a0 = (const int8_t*) ((uintptr_t) a0 - k); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + nc -= 4; } else { if (nc & 2) { diff --git a/src/qs8-gemm/gen/1x4c8-minmax-ssse3-ld64.c b/src/qs8-gemm/gen/1x4c8-minmax-ssse3-ld64.c index c8cd8ad34..7031314c7 100644 --- a/src/qs8-gemm/gen/1x4c8-minmax-ssse3-ld64.c +++ b/src/qs8-gemm/gen/1x4c8-minmax-ssse3-ld64.c @@ -12,6 +12,7 @@ #include <tmmintrin.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_minmax_ukernel_1x4c8__ssse3_ld64( @@ -35,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_1x4c8__ssse3_ld64( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); const int8_t* a0 = a; int8_t* c0 = c; @@ -128,10 +130,10 @@ void xnn_qs8_gemm_minmax_ukernel_1x4c8__ssse3_ld64( if (nc >= 4) { *((uint32_t*) c0) = (uint32_t) _mm_cvtsi128_si32(vout); - a0 = (const int8_t*) ((uintptr_t) a0 - k); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + nc -= 4; } else { if (nc & 2) { diff --git a/src/qs8-gemm/gen/1x4c8-minmax-wasmsimd-ld128.c b/src/qs8-gemm/gen/1x4c8-minmax-wasmsimd-ld128.c index ba3e4f197..e14c5d13e 100644 --- a/src/qs8-gemm/gen/1x4c8-minmax-wasmsimd-ld128.c +++ b/src/qs8-gemm/gen/1x4c8-minmax-wasmsimd-ld128.c @@ -12,6 +12,7 @@ #include <wasm_simd128.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_minmax_ukernel_1x4c8__wasmsimd_ld128( @@ -35,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_1x4c8__wasmsimd_ld128( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); const int8_t* a0 = a; int8_t* c0 = c; @@ -119,10 +121,10 @@ void xnn_qs8_gemm_minmax_ukernel_1x4c8__wasmsimd_ld128( if (nc >= 4) { *((float*) c0) = (float) wasm_f32x4_extract_lane(vout, 0); - a0 = (const int8_t*) ((uintptr_t) a0 - k); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + nc -= 4; } else { if (nc & 2) 
{ diff --git a/src/qs8-gemm/gen/1x4c8-minmax-wasmsimd-ld64.c b/src/qs8-gemm/gen/1x4c8-minmax-wasmsimd-ld64.c index 428426a17..3ea0022b9 100644 --- a/src/qs8-gemm/gen/1x4c8-minmax-wasmsimd-ld64.c +++ b/src/qs8-gemm/gen/1x4c8-minmax-wasmsimd-ld64.c @@ -12,6 +12,7 @@ #include <wasm_simd128.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_minmax_ukernel_1x4c8__wasmsimd_ld64( @@ -35,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_1x4c8__wasmsimd_ld64( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); const int8_t* a0 = a; int8_t* c0 = c; @@ -115,10 +117,10 @@ void xnn_qs8_gemm_minmax_ukernel_1x4c8__wasmsimd_ld64( if (nc >= 4) { *((float*) c0) = (float) wasm_f32x4_extract_lane(vout, 0); - a0 = (const int8_t*) ((uintptr_t) a0 - k); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + nc -= 4; } else { if (nc & 2) { diff --git a/src/qs8-gemm/gen/1x4c8-minmax-xop-ld128.c b/src/qs8-gemm/gen/1x4c8-minmax-xop-ld128.c index e89b0a27c..5411da359 100644 --- a/src/qs8-gemm/gen/1x4c8-minmax-xop-ld128.c +++ b/src/qs8-gemm/gen/1x4c8-minmax-xop-ld128.c @@ -17,6 +17,7 @@ #endif #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_minmax_ukernel_1x4c8__xop_ld128( @@ -40,6 +41,7 @@ void xnn_qs8_gemm_minmax_ukernel_1x4c8__xop_ld128( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); const int8_t* a0 = a; int8_t* c0 = c; @@ -115,10 +117,10 @@ void xnn_qs8_gemm_minmax_ukernel_1x4c8__xop_ld128( if (nc >= 4) { *((uint32_t*) c0) = (uint32_t) _mm_cvtsi128_si32(vout); - a0 = (const int8_t*) ((uintptr_t) a0 - k); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + nc -= 4; } else { if (nc & 2) { diff --git a/src/qs8-gemm/gen/1x4c8-minmax-xop-ld64.c b/src/qs8-gemm/gen/1x4c8-minmax-xop-ld64.c index f2815028b..6bab96c29 100644 --- a/src/qs8-gemm/gen/1x4c8-minmax-xop-ld64.c +++ b/src/qs8-gemm/gen/1x4c8-minmax-xop-ld64.c @@ -17,6 +17,7 @@ #endif 
#include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_minmax_ukernel_1x4c8__xop_ld64( @@ -40,6 +41,7 @@ void xnn_qs8_gemm_minmax_ukernel_1x4c8__xop_ld64( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); const int8_t* a0 = a; int8_t* c0 = c; @@ -117,10 +119,10 @@ void xnn_qs8_gemm_minmax_ukernel_1x4c8__xop_ld64( if (nc >= 4) { *((uint32_t*) c0) = (uint32_t) _mm_cvtsi128_si32(vout); - a0 = (const int8_t*) ((uintptr_t) a0 - k); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + nc -= 4; } else { if (nc & 2) { diff --git a/src/qs8-gemm/gen/1x4c8-xw-minmax-sse2.c b/src/qs8-gemm/gen/1x4c8-xw-minmax-sse2.c index 84cd5420c..4d5066791 100644 --- a/src/qs8-gemm/gen/1x4c8-xw-minmax-sse2.c +++ b/src/qs8-gemm/gen/1x4c8-xw-minmax-sse2.c @@ -12,6 +12,7 @@ #include <emmintrin.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_xw_minmax_ukernel_1x4c8__sse2( @@ -35,6 +36,7 @@ void xnn_qs8_gemm_xw_minmax_ukernel_1x4c8__sse2( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); const int8_t* a0 = a; int8_t* c0 = c; @@ -124,10 +126,10 @@ void xnn_qs8_gemm_xw_minmax_ukernel_1x4c8__sse2( if (nc >= 4) { *((uint32_t*) c0) = (uint32_t) _mm_cvtsi128_si32(vout); - a0 = (const int8_t*) ((uintptr_t) a0 - k); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + nc -= 4; } else { if (nc & 2) { diff --git a/src/qs8-gemm/gen/1x4c8-xw-minmax-sse41.c b/src/qs8-gemm/gen/1x4c8-xw-minmax-sse41.c index 59b547c4e..3dc5959df 100644 --- a/src/qs8-gemm/gen/1x4c8-xw-minmax-sse41.c +++ b/src/qs8-gemm/gen/1x4c8-xw-minmax-sse41.c @@ -12,6 +12,7 @@ #include <smmintrin.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_xw_minmax_ukernel_1x4c8__sse41( @@ -35,6 +36,7 @@ void xnn_qs8_gemm_xw_minmax_ukernel_1x4c8__sse41( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); const int8_t* a0 = a; int8_t* c0 = c; @@ -108,10 +110,10 
@@ void xnn_qs8_gemm_xw_minmax_ukernel_1x4c8__sse41( if (nc >= 4) { *((uint32_t*) c0) = (uint32_t) _mm_cvtsi128_si32(vout); - a0 = (const int8_t*) ((uintptr_t) a0 - k); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + nc -= 4; } else { if (nc & 2) { diff --git a/src/qs8-gemm/gen/1x4c8-xw-minmax-ssse3.c b/src/qs8-gemm/gen/1x4c8-xw-minmax-ssse3.c index 6e1f9c99d..31d4fc5b9 100644 --- a/src/qs8-gemm/gen/1x4c8-xw-minmax-ssse3.c +++ b/src/qs8-gemm/gen/1x4c8-xw-minmax-ssse3.c @@ -12,6 +12,7 @@ #include <tmmintrin.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_xw_minmax_ukernel_1x4c8__ssse3( @@ -35,6 +36,7 @@ void xnn_qs8_gemm_xw_minmax_ukernel_1x4c8__ssse3( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); const int8_t* a0 = a; int8_t* c0 = c; @@ -124,10 +126,10 @@ void xnn_qs8_gemm_xw_minmax_ukernel_1x4c8__ssse3( if (nc >= 4) { *((uint32_t*) c0) = (uint32_t) _mm_cvtsi128_si32(vout); - a0 = (const int8_t*) ((uintptr_t) a0 - k); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + nc -= 4; } else { if (nc & 2) { diff --git a/src/qs8-gemm/gen/1x4c8-xw-minmax-wasmsimd.c b/src/qs8-gemm/gen/1x4c8-xw-minmax-wasmsimd.c index 7cc764350..ffb826b79 100644 --- a/src/qs8-gemm/gen/1x4c8-xw-minmax-wasmsimd.c +++ b/src/qs8-gemm/gen/1x4c8-xw-minmax-wasmsimd.c @@ -12,6 +12,7 @@ #include <wasm_simd128.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_xw_minmax_ukernel_1x4c8__wasmsimd( @@ -35,6 +36,7 @@ void xnn_qs8_gemm_xw_minmax_ukernel_1x4c8__wasmsimd( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); const int8_t* a0 = a; int8_t* c0 = c; @@ -115,10 +117,10 @@ void xnn_qs8_gemm_xw_minmax_ukernel_1x4c8__wasmsimd( if (nc >= 4) { *((float*) c0) = (float) wasm_f32x4_extract_lane(vout, 0); - a0 = (const int8_t*) ((uintptr_t) a0 - k); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - 
kc); + nc -= 4; } else { if (nc & 2) { diff --git a/src/qs8-gemm/gen/1x4c8-xw-minmax-xop.c b/src/qs8-gemm/gen/1x4c8-xw-minmax-xop.c index 29349f9f9..d4f4020b9 100644 --- a/src/qs8-gemm/gen/1x4c8-xw-minmax-xop.c +++ b/src/qs8-gemm/gen/1x4c8-xw-minmax-xop.c @@ -17,6 +17,7 @@ #endif #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_xw_minmax_ukernel_1x4c8__xop( @@ -40,6 +41,7 @@ void xnn_qs8_gemm_xw_minmax_ukernel_1x4c8__xop( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); const int8_t* a0 = a; int8_t* c0 = c; @@ -113,10 +115,10 @@ void xnn_qs8_gemm_xw_minmax_ukernel_1x4c8__xop( if (nc >= 4) { *((uint32_t*) c0) = (uint32_t) _mm_cvtsi128_si32(vout); - a0 = (const int8_t*) ((uintptr_t) a0 - k); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + nc -= 4; } else { if (nc & 2) { diff --git a/src/qs8-gemm/gen/1x8c16-minmax-neon-mlal-padal.c b/src/qs8-gemm/gen/1x8c16-minmax-neon-mlal-padal.c index 5b7e9721a..5bf0ecac7 100644 --- a/src/qs8-gemm/gen/1x8c16-minmax-neon-mlal-padal.c +++ b/src/qs8-gemm/gen/1x8c16-minmax-neon-mlal-padal.c @@ -11,8 +11,8 @@ #include <arm_neon.h> -#include <xnnpack/common.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_minmax_ukernel_1x8c16__neon_mlal_padal( @@ -36,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_1x8c16__neon_mlal_padal( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 16); const int8_t* a0 = a; int8_t* c0 = c; @@ -49,7 +50,7 @@ void xnn_qs8_gemm_minmax_ukernel_1x8c16__neon_mlal_padal( int32x4_t vacc0x6 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); int32x4_t vacc0x7 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); - // KC loop of 16 with up to 15 remainder + // KC loop of 16 size_t k = 0; while (k < kc) { const int8x16_t va0 = vld1q_s8(a0); a0 += 16; @@ -148,7 +149,7 @@ void xnn_qs8_gemm_minmax_ukernel_1x8c16__neon_mlal_padal( 
c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); - a0 = (const int8_t*) ((uintptr_t) a0 - k); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); nc -= 8; } else { diff --git a/src/qs8-gemm/gen/1x8c2-minmax-neon-mlal-padal-dup.c b/src/qs8-gemm/gen/1x8c2-minmax-neon-mlal-padal-dup.c index 01eea1584..378b1b43a 100644 --- a/src/qs8-gemm/gen/1x8c2-minmax-neon-mlal-padal-dup.c +++ b/src/qs8-gemm/gen/1x8c2-minmax-neon-mlal-padal-dup.c @@ -11,8 +11,8 @@ #include <arm_neon.h> -#include <xnnpack/common.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_minmax_ukernel_1x8c2__neon_mlal_padal_dup( @@ -36,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_1x8c2__neon_mlal_padal_dup( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 2); const int8_t* a0 = a; int8_t* c0 = c; @@ -154,16 +155,6 @@ void xnn_qs8_gemm_minmax_ukernel_1x8c2__neon_mlal_padal_dup( vacc0x0123 = vpadalq_s16(vacc0x0123, vprod0x0123c2); const int16x8_t vprod0x4567c2 = vmull_s8(vb4567c2, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 2))); vacc0x4567 = vpadalq_s16(vacc0x4567, vprod0x4567c2); - - if (k > 6 * sizeof(int8_t)) { - const int8x8_t vb0123c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb4567c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - - const int16x8_t vprod0x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0x0123 = vpadalq_s16(vacc0x0123, vprod0x0123c3); - const int16x8_t vprod0x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0x4567 = vpadalq_s16(vacc0x4567, vprod0x4567c3); - } } } } diff --git a/src/qs8-gemm/gen/1x8c2-minmax-neon-mull-padal-dup.c b/src/qs8-gemm/gen/1x8c2-minmax-neon-mull-padal-dup.c index 26b3be170..9576c96ec 100644 --- a/src/qs8-gemm/gen/1x8c2-minmax-neon-mull-padal-dup.c +++ b/src/qs8-gemm/gen/1x8c2-minmax-neon-mull-padal-dup.c @@ -11,8 +11,8 @@ #include <arm_neon.h> -#include 
<xnnpack/common.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_minmax_ukernel_1x8c2__neon_mull_padal_dup( @@ -36,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_1x8c2__neon_mull_padal_dup( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 2); const int8_t* a0 = a; int8_t* c0 = c; @@ -106,16 +107,6 @@ void xnn_qs8_gemm_minmax_ukernel_1x8c2__neon_mull_padal_dup( vacc0x0123 = vpadalq_s16(vacc0x0123, vprod0x0123c2); const int16x8_t vprod0x4567c2 = vmull_s8(vb4567c2, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 2))); vacc0x4567 = vpadalq_s16(vacc0x4567, vprod0x4567c2); - - if (k > 6 * sizeof(int8_t)) { - const int8x8_t vb0123c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb4567c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - - const int16x8_t vprod0x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0x0123 = vpadalq_s16(vacc0x0123, vprod0x0123c3); - const int16x8_t vprod0x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0x4567 = vpadalq_s16(vacc0x4567, vprod0x4567c3); - } } } } diff --git a/src/qs8-gemm/gen/1x8c4-minmax-neondot.c b/src/qs8-gemm/gen/1x8c4-minmax-neondot.c index 5fc71c051..8242be742 100644 --- a/src/qs8-gemm/gen/1x8c4-minmax-neondot.c +++ b/src/qs8-gemm/gen/1x8c4-minmax-neondot.c @@ -7,12 +7,12 @@ // This source code is licensed under the BSD-style license found in the // LICENSE file in the root directory of this source tree. 
- #include <assert.h> #include <arm_neon.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_minmax_ukernel_1x8c4__neondot( @@ -30,7 +30,12 @@ void xnn_qs8_gemm_minmax_ukernel_1x8c4__neondot( assert(mr <= 1); assert(nc != 0); assert(kc != 0); + assert(kc % sizeof(int8_t) == 0); + assert(a != NULL); + assert(w != NULL); + assert(c != NULL); + kc = round_up_po2(kc, 4); const int8_t* a0 = a; int8_t* c0 = c; @@ -62,7 +67,7 @@ void xnn_qs8_gemm_minmax_ukernel_1x8c4__neondot( k -= 8 * sizeof(int8_t); } - // Handle up to 7 final positions of `k` + // Handle up to 6 final positions of `k` if XNN_UNLIKELY(k != 0) { // Load a 1x4 block of activations. const int8x8_t va0x01234567 = vld1_s8(a0); a0 += k; @@ -85,12 +90,8 @@ void xnn_qs8_gemm_minmax_ukernel_1x8c4__neondot( vacc0x4567 = vdotq_lane_s32(vacc0x4567, vb4567x4567, va0x01234567, 1); } } - // End of accumulation loop. The variable `kc` contains the amount by which - // we advanced the `va` pointers, so we rewind by this amount now. - a0 = (const int8_t*) ((uintptr_t) a0 - kc); // Post-accumulation work - const int32x4_t vright_shift = vld1q_dup_s32(¶ms->neon.right_shift); const int32x4_t vzero_shift_mask = vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0))); @@ -127,6 +128,8 @@ void xnn_qs8_gemm_minmax_ukernel_1x8c4__neondot( // Advance to the next 8 columns. c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + nc -= 8; } else { // Final case where not all of the 8 columns fit in the destination. 
diff --git a/src/qs8-gemm/gen/1x8c8-minmax-avx2.c b/src/qs8-gemm/gen/1x8c8-minmax-avx2.c index 29b2e86f8..adbc9b527 100644 --- a/src/qs8-gemm/gen/1x8c8-minmax-avx2.c +++ b/src/qs8-gemm/gen/1x8c8-minmax-avx2.c @@ -13,6 +13,7 @@ #include <xnnpack/gemm.h> #include <xnnpack/intrinsics-polyfill.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_minmax_ukernel_1x8c8__avx2( @@ -36,6 +37,7 @@ void xnn_qs8_gemm_minmax_ukernel_1x8c8__avx2( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); const int8_t* a0 = a; int8_t* c0 = c; @@ -128,10 +130,10 @@ void xnn_qs8_gemm_minmax_ukernel_1x8c8__avx2( if (nc >= 8) { _mm_storel_epi64((__m128i*) c0, vout_lo); - a0 = (const int8_t*) ((uintptr_t) a0 - k); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + nc -= 8; } else { if (nc & 4) { diff --git a/src/qs8-gemm/gen/1x8c8-minmax-neon-mull-padal.c b/src/qs8-gemm/gen/1x8c8-minmax-neon-mull-padal.c index cdac06c29..5fb7fa695 100644 --- a/src/qs8-gemm/gen/1x8c8-minmax-neon-mull-padal.c +++ b/src/qs8-gemm/gen/1x8c8-minmax-neon-mull-padal.c @@ -11,8 +11,8 @@ #include <arm_neon.h> -#include <xnnpack/common.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mull_padal( @@ -36,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mull_padal( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); const int8_t* a0 = a; int8_t* c0 = c; @@ -191,7 +192,7 @@ void xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mull_padal( c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); - a0 = (const int8_t*) ((uintptr_t) a0 - (kc - k)); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); nc -= 8; } else { diff --git a/src/qs8-gemm/gen/1x8c8-xw-minmax-avx2.c b/src/qs8-gemm/gen/1x8c8-xw-minmax-avx2.c index e039738b2..9d51b83e9 100644 --- a/src/qs8-gemm/gen/1x8c8-xw-minmax-avx2.c +++ b/src/qs8-gemm/gen/1x8c8-xw-minmax-avx2.c @@ -13,6 +13,7 @@ #include <xnnpack/gemm.h> #include <xnnpack/intrinsics-polyfill.h> +#include 
<xnnpack/math.h> void xnn_qs8_gemm_xw_minmax_ukernel_1x8c8__avx2( @@ -36,6 +37,7 @@ void xnn_qs8_gemm_xw_minmax_ukernel_1x8c8__avx2( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); const int8_t* a0 = a; int8_t* c0 = c; @@ -124,10 +126,10 @@ void xnn_qs8_gemm_xw_minmax_ukernel_1x8c8__avx2( if (nc >= 8) { _mm_storel_epi64((__m128i*) c0, vout_lo); - a0 = (const int8_t*) ((uintptr_t) a0 - k); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + nc -= 8; } else { if (nc & 4) { diff --git a/src/qs8-gemm/gen/2x16c16-minmax-neon-mlal-padal.c b/src/qs8-gemm/gen/2x16c16-minmax-neon-mlal-padal.c index ba85d484f..bab89e52a 100644 --- a/src/qs8-gemm/gen/2x16c16-minmax-neon-mlal-padal.c +++ b/src/qs8-gemm/gen/2x16c16-minmax-neon-mlal-padal.c @@ -11,8 +11,8 @@ #include <arm_neon.h> -#include <xnnpack/common.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_minmax_ukernel_2x16c16__neon_mlal_padal( @@ -36,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_2x16c16__neon_mlal_padal( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 16); const int8_t* a0 = a; int8_t* c0 = c; const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride); @@ -79,7 +80,7 @@ void xnn_qs8_gemm_minmax_ukernel_2x16c16__neon_mlal_padal( int32x4_t vacc1x14 = vacc0x14; int32x4_t vacc1x15 = vacc0x15; - // KC loop of 16 with up to 15 remainder + // KC loop of 16 size_t k = 0; while (k < kc) { const int8x16_t va0 = vld1q_s8(a0); a0 += 16; @@ -349,8 +350,8 @@ void xnn_qs8_gemm_minmax_ukernel_2x16c16__neon_mlal_padal( c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); - a0 = (const int8_t*) ((uintptr_t) a0 - k); - a1 = (const int8_t*) ((uintptr_t) a1 - k); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); nc -= 16; } else { diff --git a/src/qs8-gemm/gen/2x16c2-minmax-neon-mlal-padal-dup.c b/src/qs8-gemm/gen/2x16c2-minmax-neon-mlal-padal-dup.c index 
91f2a9d5c..cbb3d9456 100644 --- a/src/qs8-gemm/gen/2x16c2-minmax-neon-mlal-padal-dup.c +++ b/src/qs8-gemm/gen/2x16c2-minmax-neon-mlal-padal-dup.c @@ -11,8 +11,8 @@ #include <arm_neon.h> -#include <xnnpack/common.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_minmax_ukernel_2x16c2__neon_mlal_padal_dup( @@ -36,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_2x16c2__neon_mlal_padal_dup( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 2); const int8_t* a0 = a; int8_t* c0 = c; const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride); @@ -356,30 +357,6 @@ void xnn_qs8_gemm_minmax_ukernel_2x16c2__neon_mlal_padal_dup( vacc1x89AB = vpadalq_s16(vacc1x89AB, vprod1x89ABc2); const int16x8_t vprod1xCDEFc2 = vmull_s8(vbCDEFc2, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 2))); vacc1xCDEF = vpadalq_s16(vacc1xCDEF, vprod1xCDEFc2); - - if (k > 6 * sizeof(int8_t)) { - const int8x8_t vb0123c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb4567c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb89ABc3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vbCDEFc3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - - const int16x8_t vprod0x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0x0123 = vpadalq_s16(vacc0x0123, vprod0x0123c3); - const int16x8_t vprod0x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0x4567 = vpadalq_s16(vacc0x4567, vprod0x4567c3); - const int16x8_t vprod0x89ABc3 = vmull_s8(vb89ABc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0x89AB = vpadalq_s16(vacc0x89AB, vprod0x89ABc3); - const int16x8_t vprod0xCDEFc3 = vmull_s8(vbCDEFc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0xCDEF = vpadalq_s16(vacc0xCDEF, vprod0xCDEFc3); 
- const int16x8_t vprod1x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3))); - vacc1x0123 = vpadalq_s16(vacc1x0123, vprod1x0123c3); - const int16x8_t vprod1x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3))); - vacc1x4567 = vpadalq_s16(vacc1x4567, vprod1x4567c3); - const int16x8_t vprod1x89ABc3 = vmull_s8(vb89ABc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3))); - vacc1x89AB = vpadalq_s16(vacc1x89AB, vprod1x89ABc3); - const int16x8_t vprod1xCDEFc3 = vmull_s8(vbCDEFc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3))); - vacc1xCDEF = vpadalq_s16(vacc1xCDEF, vprod1xCDEFc3); - } } } } diff --git a/src/qs8-gemm/gen/2x16c2-minmax-neon-mull-padal-dup.c b/src/qs8-gemm/gen/2x16c2-minmax-neon-mull-padal-dup.c index 4943ccd2d..3cc7e4e2a 100644 --- a/src/qs8-gemm/gen/2x16c2-minmax-neon-mull-padal-dup.c +++ b/src/qs8-gemm/gen/2x16c2-minmax-neon-mull-padal-dup.c @@ -11,8 +11,8 @@ #include <arm_neon.h> -#include <xnnpack/common.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_minmax_ukernel_2x16c2__neon_mull_padal_dup( @@ -36,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_2x16c2__neon_mull_padal_dup( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 2); const int8_t* a0 = a; int8_t* c0 = c; const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride); @@ -218,30 +219,6 @@ void xnn_qs8_gemm_minmax_ukernel_2x16c2__neon_mull_padal_dup( vacc1x89AB = vpadalq_s16(vacc1x89AB, vprod1x89ABc2); const int16x8_t vprod1xCDEFc2 = vmull_s8(vbCDEFc2, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 2))); vacc1xCDEF = vpadalq_s16(vacc1xCDEF, vprod1xCDEFc2); - - if (k > 6 * sizeof(int8_t)) { - const int8x8_t vb0123c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb4567c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb89ABc3 = vld1_s8(w); w = (const 
void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vbCDEFc3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - - const int16x8_t vprod0x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0x0123 = vpadalq_s16(vacc0x0123, vprod0x0123c3); - const int16x8_t vprod0x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0x4567 = vpadalq_s16(vacc0x4567, vprod0x4567c3); - const int16x8_t vprod0x89ABc3 = vmull_s8(vb89ABc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0x89AB = vpadalq_s16(vacc0x89AB, vprod0x89ABc3); - const int16x8_t vprod0xCDEFc3 = vmull_s8(vbCDEFc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0xCDEF = vpadalq_s16(vacc0xCDEF, vprod0xCDEFc3); - const int16x8_t vprod1x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3))); - vacc1x0123 = vpadalq_s16(vacc1x0123, vprod1x0123c3); - const int16x8_t vprod1x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3))); - vacc1x4567 = vpadalq_s16(vacc1x4567, vprod1x4567c3); - const int16x8_t vprod1x89ABc3 = vmull_s8(vb89ABc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3))); - vacc1x89AB = vpadalq_s16(vacc1x89AB, vprod1x89ABc3); - const int16x8_t vprod1xCDEFc3 = vmull_s8(vbCDEFc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3))); - vacc1xCDEF = vpadalq_s16(vacc1xCDEF, vprod1xCDEFc3); - } } } } diff --git a/src/qs8-gemm/gen/2x16c8-minmax-avx512skx.c b/src/qs8-gemm/gen/2x16c8-minmax-avx512skx.c index bfc594b14..1e187792c 100644 --- a/src/qs8-gemm/gen/2x16c8-minmax-avx512skx.c +++ b/src/qs8-gemm/gen/2x16c8-minmax-avx512skx.c @@ -13,6 +13,7 @@ #include <xnnpack/gemm.h> #include <xnnpack/intrinsics-polyfill.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_minmax_ukernel_2x16c8__avx512skx( @@ -36,6 +37,7 @@ void 
xnn_qs8_gemm_minmax_ukernel_2x16c8__avx512skx( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); const int8_t* a0 = a; int8_t* c0 = c; const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride); @@ -145,12 +147,12 @@ void xnn_qs8_gemm_minmax_ukernel_2x16c8__avx512skx( _mm_storeu_si128((__m128i*) c0, _mm256_castsi256_si128(vout01x0123456789ABCDEF)); _mm_storeu_si128((__m128i*) c1, _mm256_extracti128_si256(vout01x0123456789ABCDEF, 1)); - a0 = (const int8_t*) ((uintptr_t) a0 - k); - a1 = (const int8_t*) ((uintptr_t) a1 - k); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + nc -= 16; } else { // Prepare mask for valid 8-bit elements (depends on nc). diff --git a/src/qs8-gemm/gen/2x16c8-minmax-neon-mull-padal.c b/src/qs8-gemm/gen/2x16c8-minmax-neon-mull-padal.c index 7b6d7c603..2c0c62385 100644 --- a/src/qs8-gemm/gen/2x16c8-minmax-neon-mull-padal.c +++ b/src/qs8-gemm/gen/2x16c8-minmax-neon-mull-padal.c @@ -11,8 +11,8 @@ #include <arm_neon.h> -#include <xnnpack/common.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mull_padal( @@ -36,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mull_padal( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); const int8_t* a0 = a; int8_t* c0 = c; const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride); @@ -458,8 +459,8 @@ void xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mull_padal( c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); - a0 = (const int8_t*) ((uintptr_t) a0 - (kc - k)); - a1 = (const int8_t*) ((uintptr_t) a1 - (kc - k)); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); nc -= 16; } else { diff --git a/src/qs8-gemm/gen/2x4c8-minmax-sse2-ld128.c b/src/qs8-gemm/gen/2x4c8-minmax-sse2-ld128.c index 
2dea0d22a..d3f8bc073 100644 --- a/src/qs8-gemm/gen/2x4c8-minmax-sse2-ld128.c +++ b/src/qs8-gemm/gen/2x4c8-minmax-sse2-ld128.c @@ -12,6 +12,7 @@ #include <emmintrin.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_minmax_ukernel_2x4c8__sse2_ld128( @@ -35,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_2x4c8__sse2_ld128( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); const int8_t* a0 = a; int8_t* c0 = c; const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride); @@ -166,12 +168,12 @@ void xnn_qs8_gemm_minmax_ukernel_2x4c8__sse2_ld128( vout = _mm_srli_si128(vout, 4); *((uint32_t*) c1) = (uint32_t) _mm_cvtsi128_si32(vout); - a0 = (const int8_t*) ((uintptr_t) a0 - k); - a1 = (const int8_t*) ((uintptr_t) a1 - k); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + nc -= 4; } else { if (nc & 2) { diff --git a/src/qs8-gemm/gen/2x4c8-minmax-sse2-ld64.c b/src/qs8-gemm/gen/2x4c8-minmax-sse2-ld64.c index 52c118d4d..61a134468 100644 --- a/src/qs8-gemm/gen/2x4c8-minmax-sse2-ld64.c +++ b/src/qs8-gemm/gen/2x4c8-minmax-sse2-ld64.c @@ -12,6 +12,7 @@ #include <emmintrin.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_minmax_ukernel_2x4c8__sse2_ld64( @@ -35,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_2x4c8__sse2_ld64( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); const int8_t* a0 = a; int8_t* c0 = c; const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride); @@ -168,12 +170,12 @@ void xnn_qs8_gemm_minmax_ukernel_2x4c8__sse2_ld64( vout = _mm_srli_si128(vout, 4); *((uint32_t*) c1) = (uint32_t) _mm_cvtsi128_si32(vout); - a0 = (const int8_t*) ((uintptr_t) a0 - k); - a1 = (const int8_t*) ((uintptr_t) a1 - k); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + a1 = (const 
int8_t*) ((uintptr_t) a1 - kc); + nc -= 4; } else { if (nc & 2) { diff --git a/src/qs8-gemm/gen/2x4c8-minmax-sse41-ld128.c b/src/qs8-gemm/gen/2x4c8-minmax-sse41-ld128.c index 166947bc8..94bd5be07 100644 --- a/src/qs8-gemm/gen/2x4c8-minmax-sse41-ld128.c +++ b/src/qs8-gemm/gen/2x4c8-minmax-sse41-ld128.c @@ -12,6 +12,7 @@ #include <smmintrin.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_minmax_ukernel_2x4c8__sse41_ld128( @@ -35,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_2x4c8__sse41_ld128( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); const int8_t* a0 = a; int8_t* c0 = c; const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride); @@ -141,12 +143,12 @@ void xnn_qs8_gemm_minmax_ukernel_2x4c8__sse41_ld128( *((uint32_t*) c0) = (uint32_t) _mm_cvtsi128_si32(vout); *((uint32_t*) c1) = (uint32_t) _mm_extract_epi32(vout, 1); - a0 = (const int8_t*) ((uintptr_t) a0 - k); - a1 = (const int8_t*) ((uintptr_t) a1 - k); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + nc -= 4; } else { if (nc & 2) { diff --git a/src/qs8-gemm/gen/2x4c8-minmax-sse41-ld64.c b/src/qs8-gemm/gen/2x4c8-minmax-sse41-ld64.c index 139c4cfa0..3765c44d9 100644 --- a/src/qs8-gemm/gen/2x4c8-minmax-sse41-ld64.c +++ b/src/qs8-gemm/gen/2x4c8-minmax-sse41-ld64.c @@ -12,6 +12,7 @@ #include <smmintrin.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_minmax_ukernel_2x4c8__sse41_ld64( @@ -35,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_2x4c8__sse41_ld64( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); const int8_t* a0 = a; int8_t* c0 = c; const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride); @@ -143,12 +145,12 @@ void xnn_qs8_gemm_minmax_ukernel_2x4c8__sse41_ld64( *((uint32_t*) c0) = (uint32_t) _mm_cvtsi128_si32(vout); *((uint32_t*) c1) = (uint32_t) _mm_extract_epi32(vout, 1); - a0 = 
(const int8_t*) ((uintptr_t) a0 - k); - a1 = (const int8_t*) ((uintptr_t) a1 - k); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + nc -= 4; } else { if (nc & 2) { diff --git a/src/qs8-gemm/gen/2x4c8-minmax-ssse3-ld128.c b/src/qs8-gemm/gen/2x4c8-minmax-ssse3-ld128.c index 0d0262a08..10307ca20 100644 --- a/src/qs8-gemm/gen/2x4c8-minmax-ssse3-ld128.c +++ b/src/qs8-gemm/gen/2x4c8-minmax-ssse3-ld128.c @@ -12,6 +12,7 @@ #include <tmmintrin.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_minmax_ukernel_2x4c8__ssse3_ld128( @@ -35,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_2x4c8__ssse3_ld128( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); const int8_t* a0 = a; int8_t* c0 = c; const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride); @@ -166,12 +168,12 @@ void xnn_qs8_gemm_minmax_ukernel_2x4c8__ssse3_ld128( vout = _mm_srli_si128(vout, 4); *((uint32_t*) c1) = (uint32_t) _mm_cvtsi128_si32(vout); - a0 = (const int8_t*) ((uintptr_t) a0 - k); - a1 = (const int8_t*) ((uintptr_t) a1 - k); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + nc -= 4; } else { if (nc & 2) { diff --git a/src/qs8-gemm/gen/2x4c8-minmax-ssse3-ld64.c b/src/qs8-gemm/gen/2x4c8-minmax-ssse3-ld64.c index 08b9f6471..59cf669d8 100644 --- a/src/qs8-gemm/gen/2x4c8-minmax-ssse3-ld64.c +++ b/src/qs8-gemm/gen/2x4c8-minmax-ssse3-ld64.c @@ -12,6 +12,7 @@ #include <tmmintrin.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_minmax_ukernel_2x4c8__ssse3_ld64( @@ -35,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_2x4c8__ssse3_ld64( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); const int8_t* a0 = a; int8_t* c0 = c; const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 
+ a_stride); @@ -168,12 +170,12 @@ void xnn_qs8_gemm_minmax_ukernel_2x4c8__ssse3_ld64( vout = _mm_srli_si128(vout, 4); *((uint32_t*) c1) = (uint32_t) _mm_cvtsi128_si32(vout); - a0 = (const int8_t*) ((uintptr_t) a0 - k); - a1 = (const int8_t*) ((uintptr_t) a1 - k); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + nc -= 4; } else { if (nc & 2) { diff --git a/src/qs8-gemm/gen/2x4c8-minmax-wasmsimd-ld128.c b/src/qs8-gemm/gen/2x4c8-minmax-wasmsimd-ld128.c index a77acde9d..2ad04a97a 100644 --- a/src/qs8-gemm/gen/2x4c8-minmax-wasmsimd-ld128.c +++ b/src/qs8-gemm/gen/2x4c8-minmax-wasmsimd-ld128.c @@ -12,6 +12,7 @@ #include <wasm_simd128.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_minmax_ukernel_2x4c8__wasmsimd_ld128( @@ -35,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_2x4c8__wasmsimd_ld128( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); const int8_t* a0 = a; int8_t* c0 = c; const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride); @@ -155,12 +157,12 @@ void xnn_qs8_gemm_minmax_ukernel_2x4c8__wasmsimd_ld128( *((float*) c0) = (float) wasm_f32x4_extract_lane(vout, 0); *((float*) c1) = (float) wasm_f32x4_extract_lane(vout, 1); - a0 = (const int8_t*) ((uintptr_t) a0 - k); - a1 = (const int8_t*) ((uintptr_t) a1 - k); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + nc -= 4; } else { if (nc & 2) { diff --git a/src/qs8-gemm/gen/2x4c8-minmax-wasmsimd-ld64.c b/src/qs8-gemm/gen/2x4c8-minmax-wasmsimd-ld64.c index c3f548333..a00ad6c8d 100644 --- a/src/qs8-gemm/gen/2x4c8-minmax-wasmsimd-ld64.c +++ b/src/qs8-gemm/gen/2x4c8-minmax-wasmsimd-ld64.c @@ -12,6 +12,7 @@ #include <wasm_simd128.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void 
xnn_qs8_gemm_minmax_ukernel_2x4c8__wasmsimd_ld64( @@ -35,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_2x4c8__wasmsimd_ld64( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); const int8_t* a0 = a; int8_t* c0 = c; const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride); @@ -151,12 +153,12 @@ void xnn_qs8_gemm_minmax_ukernel_2x4c8__wasmsimd_ld64( *((float*) c0) = (float) wasm_f32x4_extract_lane(vout, 0); *((float*) c1) = (float) wasm_f32x4_extract_lane(vout, 1); - a0 = (const int8_t*) ((uintptr_t) a0 - k); - a1 = (const int8_t*) ((uintptr_t) a1 - k); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + nc -= 4; } else { if (nc & 2) { diff --git a/src/qs8-gemm/gen/2x4c8-minmax-xop-ld128.c b/src/qs8-gemm/gen/2x4c8-minmax-xop-ld128.c index 83ccb6a69..610ecd3b1 100644 --- a/src/qs8-gemm/gen/2x4c8-minmax-xop-ld128.c +++ b/src/qs8-gemm/gen/2x4c8-minmax-xop-ld128.c @@ -17,6 +17,7 @@ #endif #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_minmax_ukernel_2x4c8__xop_ld128( @@ -40,6 +41,7 @@ void xnn_qs8_gemm_minmax_ukernel_2x4c8__xop_ld128( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); const int8_t* a0 = a; int8_t* c0 = c; const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride); @@ -146,12 +148,12 @@ void xnn_qs8_gemm_minmax_ukernel_2x4c8__xop_ld128( *((uint32_t*) c0) = (uint32_t) _mm_cvtsi128_si32(vout); *((uint32_t*) c1) = (uint32_t) _mm_extract_epi32(vout, 1); - a0 = (const int8_t*) ((uintptr_t) a0 - k); - a1 = (const int8_t*) ((uintptr_t) a1 - k); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + nc -= 4; } else { if (nc & 2) { diff --git a/src/qs8-gemm/gen/2x4c8-minmax-xop-ld64.c b/src/qs8-gemm/gen/2x4c8-minmax-xop-ld64.c index 
e918a7a30..5ef963f81 100644 --- a/src/qs8-gemm/gen/2x4c8-minmax-xop-ld64.c +++ b/src/qs8-gemm/gen/2x4c8-minmax-xop-ld64.c @@ -17,6 +17,7 @@ #endif #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_minmax_ukernel_2x4c8__xop_ld64( @@ -40,6 +41,7 @@ void xnn_qs8_gemm_minmax_ukernel_2x4c8__xop_ld64( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); const int8_t* a0 = a; int8_t* c0 = c; const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride); @@ -148,12 +150,12 @@ void xnn_qs8_gemm_minmax_ukernel_2x4c8__xop_ld64( *((uint32_t*) c0) = (uint32_t) _mm_cvtsi128_si32(vout); *((uint32_t*) c1) = (uint32_t) _mm_extract_epi32(vout, 1); - a0 = (const int8_t*) ((uintptr_t) a0 - k); - a1 = (const int8_t*) ((uintptr_t) a1 - k); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + nc -= 4; } else { if (nc & 2) { diff --git a/src/qs8-gemm/gen/2x4c8-xw-minmax-sse2.c b/src/qs8-gemm/gen/2x4c8-xw-minmax-sse2.c index 38cc15511..5336fcb12 100644 --- a/src/qs8-gemm/gen/2x4c8-xw-minmax-sse2.c +++ b/src/qs8-gemm/gen/2x4c8-xw-minmax-sse2.c @@ -12,6 +12,7 @@ #include <emmintrin.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_xw_minmax_ukernel_2x4c8__sse2( @@ -35,6 +36,7 @@ void xnn_qs8_gemm_xw_minmax_ukernel_2x4c8__sse2( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); const int8_t* a0 = a; int8_t* c0 = c; const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride); @@ -164,12 +166,12 @@ void xnn_qs8_gemm_xw_minmax_ukernel_2x4c8__sse2( vout = _mm_srli_si128(vout, 4); *((uint32_t*) c1) = (uint32_t) _mm_cvtsi128_si32(vout); - a0 = (const int8_t*) ((uintptr_t) a0 - k); - a1 = (const int8_t*) ((uintptr_t) a1 - k); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + a1 = (const int8_t*) 
((uintptr_t) a1 - kc); + nc -= 4; } else { if (nc & 2) { diff --git a/src/qs8-gemm/gen/2x4c8-xw-minmax-sse41.c b/src/qs8-gemm/gen/2x4c8-xw-minmax-sse41.c index 0202b2705..c6cca2cd8 100644 --- a/src/qs8-gemm/gen/2x4c8-xw-minmax-sse41.c +++ b/src/qs8-gemm/gen/2x4c8-xw-minmax-sse41.c @@ -12,6 +12,7 @@ #include <smmintrin.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_xw_minmax_ukernel_2x4c8__sse41( @@ -35,6 +36,7 @@ void xnn_qs8_gemm_xw_minmax_ukernel_2x4c8__sse41( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); const int8_t* a0 = a; int8_t* c0 = c; const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride); @@ -139,12 +141,12 @@ void xnn_qs8_gemm_xw_minmax_ukernel_2x4c8__sse41( *((uint32_t*) c0) = (uint32_t) _mm_cvtsi128_si32(vout); *((uint32_t*) c1) = (uint32_t) _mm_extract_epi32(vout, 1); - a0 = (const int8_t*) ((uintptr_t) a0 - k); - a1 = (const int8_t*) ((uintptr_t) a1 - k); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + nc -= 4; } else { if (nc & 2) { diff --git a/src/qs8-gemm/gen/2x4c8-xw-minmax-ssse3.c b/src/qs8-gemm/gen/2x4c8-xw-minmax-ssse3.c index 79f8767ab..521105b64 100644 --- a/src/qs8-gemm/gen/2x4c8-xw-minmax-ssse3.c +++ b/src/qs8-gemm/gen/2x4c8-xw-minmax-ssse3.c @@ -12,6 +12,7 @@ #include <tmmintrin.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_xw_minmax_ukernel_2x4c8__ssse3( @@ -35,6 +36,7 @@ void xnn_qs8_gemm_xw_minmax_ukernel_2x4c8__ssse3( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); const int8_t* a0 = a; int8_t* c0 = c; const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride); @@ -164,12 +166,12 @@ void xnn_qs8_gemm_xw_minmax_ukernel_2x4c8__ssse3( vout = _mm_srli_si128(vout, 4); *((uint32_t*) c1) = (uint32_t) _mm_cvtsi128_si32(vout); - a0 = (const int8_t*) ((uintptr_t) a0 - k); - a1 = (const int8_t*) 
((uintptr_t) a1 - k); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + nc -= 4; } else { if (nc & 2) { diff --git a/src/qs8-gemm/gen/2x4c8-xw-minmax-wasmsimd.c b/src/qs8-gemm/gen/2x4c8-xw-minmax-wasmsimd.c index 00a234e53..2a1103147 100644 --- a/src/qs8-gemm/gen/2x4c8-xw-minmax-wasmsimd.c +++ b/src/qs8-gemm/gen/2x4c8-xw-minmax-wasmsimd.c @@ -12,6 +12,7 @@ #include <wasm_simd128.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_xw_minmax_ukernel_2x4c8__wasmsimd( @@ -35,6 +36,7 @@ void xnn_qs8_gemm_xw_minmax_ukernel_2x4c8__wasmsimd( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); const int8_t* a0 = a; int8_t* c0 = c; const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride); @@ -151,12 +153,12 @@ void xnn_qs8_gemm_xw_minmax_ukernel_2x4c8__wasmsimd( *((float*) c0) = (float) wasm_f32x4_extract_lane(vout, 0); *((float*) c1) = (float) wasm_f32x4_extract_lane(vout, 1); - a0 = (const int8_t*) ((uintptr_t) a0 - k); - a1 = (const int8_t*) ((uintptr_t) a1 - k); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + nc -= 4; } else { if (nc & 2) { diff --git a/src/qs8-gemm/gen/2x4c8-xw-minmax-xop.c b/src/qs8-gemm/gen/2x4c8-xw-minmax-xop.c index 1969160ab..6e41cc574 100644 --- a/src/qs8-gemm/gen/2x4c8-xw-minmax-xop.c +++ b/src/qs8-gemm/gen/2x4c8-xw-minmax-xop.c @@ -17,6 +17,7 @@ #endif #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_xw_minmax_ukernel_2x4c8__xop( @@ -40,6 +41,7 @@ void xnn_qs8_gemm_xw_minmax_ukernel_2x4c8__xop( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); const int8_t* a0 = a; int8_t* c0 = c; const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride); @@ -144,12 +146,12 @@ void 
xnn_qs8_gemm_xw_minmax_ukernel_2x4c8__xop( *((uint32_t*) c0) = (uint32_t) _mm_cvtsi128_si32(vout); *((uint32_t*) c1) = (uint32_t) _mm_extract_epi32(vout, 1); - a0 = (const int8_t*) ((uintptr_t) a0 - k); - a1 = (const int8_t*) ((uintptr_t) a1 - k); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + nc -= 4; } else { if (nc & 2) { diff --git a/src/qs8-gemm/gen/2x8c16-minmax-neon-mlal-padal.c b/src/qs8-gemm/gen/2x8c16-minmax-neon-mlal-padal.c index 269d94d6e..a2d58a294 100644 --- a/src/qs8-gemm/gen/2x8c16-minmax-neon-mlal-padal.c +++ b/src/qs8-gemm/gen/2x8c16-minmax-neon-mlal-padal.c @@ -11,8 +11,8 @@ #include <arm_neon.h> -#include <xnnpack/common.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_minmax_ukernel_2x8c16__neon_mlal_padal( @@ -36,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_2x8c16__neon_mlal_padal( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 16); const int8_t* a0 = a; int8_t* c0 = c; const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride); @@ -63,7 +64,7 @@ void xnn_qs8_gemm_minmax_ukernel_2x8c16__neon_mlal_padal( int32x4_t vacc1x6 = vacc0x6; int32x4_t vacc1x7 = vacc0x7; - // KC loop of 16 with up to 15 remainder + // KC loop of 16 size_t k = 0; while (k < kc) { const int8x16_t va0 = vld1q_s8(a0); a0 += 16; @@ -217,8 +218,8 @@ void xnn_qs8_gemm_minmax_ukernel_2x8c16__neon_mlal_padal( c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); - a0 = (const int8_t*) ((uintptr_t) a0 - k); - a1 = (const int8_t*) ((uintptr_t) a1 - k); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); nc -= 8; } else { diff --git a/src/qs8-gemm/gen/2x8c2-minmax-neon-mlal-padal-dup.c b/src/qs8-gemm/gen/2x8c2-minmax-neon-mlal-padal-dup.c index e3d3570f1..850d04db0 100644 --- a/src/qs8-gemm/gen/2x8c2-minmax-neon-mlal-padal-dup.c 
+++ b/src/qs8-gemm/gen/2x8c2-minmax-neon-mlal-padal-dup.c @@ -11,8 +11,8 @@ #include <arm_neon.h> -#include <xnnpack/common.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_minmax_ukernel_2x8c2__neon_mlal_padal_dup( @@ -36,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_2x8c2__neon_mlal_padal_dup( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 2); const int8_t* a0 = a; int8_t* c0 = c; const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride); @@ -218,20 +219,6 @@ void xnn_qs8_gemm_minmax_ukernel_2x8c2__neon_mlal_padal_dup( vacc1x0123 = vpadalq_s16(vacc1x0123, vprod1x0123c2); const int16x8_t vprod1x4567c2 = vmull_s8(vb4567c2, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 2))); vacc1x4567 = vpadalq_s16(vacc1x4567, vprod1x4567c2); - - if (k > 6 * sizeof(int8_t)) { - const int8x8_t vb0123c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb4567c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - - const int16x8_t vprod0x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0x0123 = vpadalq_s16(vacc0x0123, vprod0x0123c3); - const int16x8_t vprod0x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0x4567 = vpadalq_s16(vacc0x4567, vprod0x4567c3); - const int16x8_t vprod1x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3))); - vacc1x0123 = vpadalq_s16(vacc1x0123, vprod1x0123c3); - const int16x8_t vprod1x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3))); - vacc1x4567 = vpadalq_s16(vacc1x4567, vprod1x4567c3); - } } } } diff --git a/src/qs8-gemm/gen/2x8c2-minmax-neon-mull-padal-dup.c b/src/qs8-gemm/gen/2x8c2-minmax-neon-mull-padal-dup.c index 5d4e73502..a98d1e9c4 100644 --- a/src/qs8-gemm/gen/2x8c2-minmax-neon-mull-padal-dup.c +++ b/src/qs8-gemm/gen/2x8c2-minmax-neon-mull-padal-dup.c @@ 
-11,8 +11,8 @@ #include <arm_neon.h> -#include <xnnpack/common.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_minmax_ukernel_2x8c2__neon_mull_padal_dup( @@ -36,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_2x8c2__neon_mull_padal_dup( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 2); const int8_t* a0 = a; int8_t* c0 = c; const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride); @@ -144,20 +145,6 @@ void xnn_qs8_gemm_minmax_ukernel_2x8c2__neon_mull_padal_dup( vacc1x0123 = vpadalq_s16(vacc1x0123, vprod1x0123c2); const int16x8_t vprod1x4567c2 = vmull_s8(vb4567c2, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 2))); vacc1x4567 = vpadalq_s16(vacc1x4567, vprod1x4567c2); - - if (k > 6 * sizeof(int8_t)) { - const int8x8_t vb0123c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb4567c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - - const int16x8_t vprod0x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0x0123 = vpadalq_s16(vacc0x0123, vprod0x0123c3); - const int16x8_t vprod0x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0x4567 = vpadalq_s16(vacc0x4567, vprod0x4567c3); - const int16x8_t vprod1x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3))); - vacc1x0123 = vpadalq_s16(vacc1x0123, vprod1x0123c3); - const int16x8_t vprod1x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3))); - vacc1x4567 = vpadalq_s16(vacc1x4567, vprod1x4567c3); - } } } } diff --git a/src/qs8-gemm/gen/2x8c8-minmax-avx2.c b/src/qs8-gemm/gen/2x8c8-minmax-avx2.c index 677a96b05..f44952804 100644 --- a/src/qs8-gemm/gen/2x8c8-minmax-avx2.c +++ b/src/qs8-gemm/gen/2x8c8-minmax-avx2.c @@ -13,6 +13,7 @@ #include <xnnpack/gemm.h> #include <xnnpack/intrinsics-polyfill.h> +#include <xnnpack/math.h> void 
xnn_qs8_gemm_minmax_ukernel_2x8c8__avx2( @@ -36,6 +37,7 @@ void xnn_qs8_gemm_minmax_ukernel_2x8c8__avx2( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); const int8_t* a0 = a; int8_t* c0 = c; const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride); @@ -160,12 +162,12 @@ void xnn_qs8_gemm_minmax_ukernel_2x8c8__avx2( _mm_storel_epi64((__m128i*) c0, vout_lo); _mm_storel_epi64((__m128i*) c1, vout_hi); - a0 = (const int8_t*) ((uintptr_t) a0 - k); - a1 = (const int8_t*) ((uintptr_t) a1 - k); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + nc -= 8; } else { if (nc & 4) { diff --git a/src/qs8-gemm/gen/2x8c8-minmax-neon-mull-padal.c b/src/qs8-gemm/gen/2x8c8-minmax-neon-mull-padal.c index ac2af6c87..c98b99c10 100644 --- a/src/qs8-gemm/gen/2x8c8-minmax-neon-mull-padal.c +++ b/src/qs8-gemm/gen/2x8c8-minmax-neon-mull-padal.c @@ -11,8 +11,8 @@ #include <arm_neon.h> -#include <xnnpack/common.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mull_padal( @@ -36,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mull_padal( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); const int8_t* a0 = a; int8_t* c0 = c; const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride); @@ -278,8 +279,8 @@ void xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mull_padal( c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); - a0 = (const int8_t*) ((uintptr_t) a0 - (kc - k)); - a1 = (const int8_t*) ((uintptr_t) a1 - (kc - k)); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); nc -= 8; } else { diff --git a/src/qs8-gemm/gen/2x8c8-xw-minmax-avx2.c b/src/qs8-gemm/gen/2x8c8-xw-minmax-avx2.c index bc7a3b380..e22891634 100644 --- a/src/qs8-gemm/gen/2x8c8-xw-minmax-avx2.c +++ 
b/src/qs8-gemm/gen/2x8c8-xw-minmax-avx2.c @@ -13,6 +13,7 @@ #include <xnnpack/gemm.h> #include <xnnpack/intrinsics-polyfill.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_xw_minmax_ukernel_2x8c8__avx2( @@ -36,6 +37,7 @@ void xnn_qs8_gemm_xw_minmax_ukernel_2x8c8__avx2( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); const int8_t* a0 = a; int8_t* c0 = c; const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride); @@ -156,12 +158,12 @@ void xnn_qs8_gemm_xw_minmax_ukernel_2x8c8__avx2( _mm_storel_epi64((__m128i*) c0, vout_lo); _mm_storel_epi64((__m128i*) c1, vout_hi); - a0 = (const int8_t*) ((uintptr_t) a0 - k); - a1 = (const int8_t*) ((uintptr_t) a1 - k); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + nc -= 8; } else { if (nc & 4) { diff --git a/src/qs8-gemm/gen/3x16c16-minmax-neon-mlal-padal.c b/src/qs8-gemm/gen/3x16c16-minmax-neon-mlal-padal.c index c15c35ec7..8dd63601e 100644 --- a/src/qs8-gemm/gen/3x16c16-minmax-neon-mlal-padal.c +++ b/src/qs8-gemm/gen/3x16c16-minmax-neon-mlal-padal.c @@ -11,8 +11,8 @@ #include <arm_neon.h> -#include <xnnpack/common.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_minmax_ukernel_3x16c16__neon_mlal_padal( @@ -36,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_3x16c16__neon_mlal_padal( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 16); const int8_t* a0 = a; int8_t* c0 = c; const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride); @@ -101,7 +102,7 @@ void xnn_qs8_gemm_minmax_ukernel_3x16c16__neon_mlal_padal( int32x4_t vacc2x14 = vacc0x14; int32x4_t vacc2x15 = vacc0x15; - // KC loop of 16 with up to 15 remainder + // KC loop of 16 size_t k = 0; while (k < kc) { const int8x16_t va0 = vld1q_s8(a0); a0 += 16; @@ -482,9 +483,9 @@ void xnn_qs8_gemm_minmax_ukernel_3x16c16__neon_mlal_padal( c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); c2 = 
(int8_t*) ((uintptr_t) c2 + cn_stride); - a0 = (const int8_t*) ((uintptr_t) a0 - k); - a1 = (const int8_t*) ((uintptr_t) a1 - k); - a2 = (const int8_t*) ((uintptr_t) a2 - k); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); nc -= 16; } else { diff --git a/src/qs8-gemm/gen/3x16c2-minmax-neon-mlal-padal-dup.c b/src/qs8-gemm/gen/3x16c2-minmax-neon-mlal-padal-dup.c index 7750eb308..e3d71f4b4 100644 --- a/src/qs8-gemm/gen/3x16c2-minmax-neon-mlal-padal-dup.c +++ b/src/qs8-gemm/gen/3x16c2-minmax-neon-mlal-padal-dup.c @@ -11,8 +11,8 @@ #include <arm_neon.h> -#include <xnnpack/common.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_minmax_ukernel_3x16c2__neon_mlal_padal_dup( @@ -36,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_3x16c2__neon_mlal_padal_dup( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 2); const int8_t* a0 = a; int8_t* c0 = c; const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride); @@ -474,38 +475,6 @@ void xnn_qs8_gemm_minmax_ukernel_3x16c2__neon_mlal_padal_dup( vacc2x89AB = vpadalq_s16(vacc2x89AB, vprod2x89ABc2); const int16x8_t vprod2xCDEFc2 = vmull_s8(vbCDEFc2, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 2))); vacc2xCDEF = vpadalq_s16(vacc2xCDEF, vprod2xCDEFc2); - - if (k > 6 * sizeof(int8_t)) { - const int8x8_t vb0123c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb4567c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb89ABc3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vbCDEFc3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - - const int16x8_t vprod0x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0x0123 = vpadalq_s16(vacc0x0123, vprod0x0123c3); - const int16x8_t vprod0x4567c3 = vmull_s8(vb4567c3, 
vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0x4567 = vpadalq_s16(vacc0x4567, vprod0x4567c3); - const int16x8_t vprod0x89ABc3 = vmull_s8(vb89ABc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0x89AB = vpadalq_s16(vacc0x89AB, vprod0x89ABc3); - const int16x8_t vprod0xCDEFc3 = vmull_s8(vbCDEFc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0xCDEF = vpadalq_s16(vacc0xCDEF, vprod0xCDEFc3); - const int16x8_t vprod1x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3))); - vacc1x0123 = vpadalq_s16(vacc1x0123, vprod1x0123c3); - const int16x8_t vprod1x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3))); - vacc1x4567 = vpadalq_s16(vacc1x4567, vprod1x4567c3); - const int16x8_t vprod1x89ABc3 = vmull_s8(vb89ABc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3))); - vacc1x89AB = vpadalq_s16(vacc1x89AB, vprod1x89ABc3); - const int16x8_t vprod1xCDEFc3 = vmull_s8(vbCDEFc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3))); - vacc1xCDEF = vpadalq_s16(vacc1xCDEF, vprod1xCDEFc3); - const int16x8_t vprod2x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3))); - vacc2x0123 = vpadalq_s16(vacc2x0123, vprod2x0123c3); - const int16x8_t vprod2x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3))); - vacc2x4567 = vpadalq_s16(vacc2x4567, vprod2x4567c3); - const int16x8_t vprod2x89ABc3 = vmull_s8(vb89ABc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3))); - vacc2x89AB = vpadalq_s16(vacc2x89AB, vprod2x89ABc3); - const int16x8_t vprod2xCDEFc3 = vmull_s8(vbCDEFc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3))); - vacc2xCDEF = vpadalq_s16(vacc2xCDEF, vprod2xCDEFc3); - } } } } diff --git a/src/qs8-gemm/gen/3x16c2-minmax-neon-mull-padal-dup.c b/src/qs8-gemm/gen/3x16c2-minmax-neon-mull-padal-dup.c index 
bb497f575..68671c7db 100644 --- a/src/qs8-gemm/gen/3x16c2-minmax-neon-mull-padal-dup.c +++ b/src/qs8-gemm/gen/3x16c2-minmax-neon-mull-padal-dup.c @@ -11,8 +11,8 @@ #include <arm_neon.h> -#include <xnnpack/common.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_minmax_ukernel_3x16c2__neon_mull_padal_dup( @@ -36,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_3x16c2__neon_mull_padal_dup( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 2); const int8_t* a0 = a; int8_t* c0 = c; const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride); @@ -286,38 +287,6 @@ void xnn_qs8_gemm_minmax_ukernel_3x16c2__neon_mull_padal_dup( vacc2x89AB = vpadalq_s16(vacc2x89AB, vprod2x89ABc2); const int16x8_t vprod2xCDEFc2 = vmull_s8(vbCDEFc2, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 2))); vacc2xCDEF = vpadalq_s16(vacc2xCDEF, vprod2xCDEFc2); - - if (k > 6 * sizeof(int8_t)) { - const int8x8_t vb0123c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb4567c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb89ABc3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vbCDEFc3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - - const int16x8_t vprod0x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0x0123 = vpadalq_s16(vacc0x0123, vprod0x0123c3); - const int16x8_t vprod0x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0x4567 = vpadalq_s16(vacc0x4567, vprod0x4567c3); - const int16x8_t vprod0x89ABc3 = vmull_s8(vb89ABc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0x89AB = vpadalq_s16(vacc0x89AB, vprod0x89ABc3); - const int16x8_t vprod0xCDEFc3 = vmull_s8(vbCDEFc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0xCDEF = vpadalq_s16(vacc0xCDEF, vprod0xCDEFc3); 
- const int16x8_t vprod1x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3))); - vacc1x0123 = vpadalq_s16(vacc1x0123, vprod1x0123c3); - const int16x8_t vprod1x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3))); - vacc1x4567 = vpadalq_s16(vacc1x4567, vprod1x4567c3); - const int16x8_t vprod1x89ABc3 = vmull_s8(vb89ABc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3))); - vacc1x89AB = vpadalq_s16(vacc1x89AB, vprod1x89ABc3); - const int16x8_t vprod1xCDEFc3 = vmull_s8(vbCDEFc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3))); - vacc1xCDEF = vpadalq_s16(vacc1xCDEF, vprod1xCDEFc3); - const int16x8_t vprod2x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3))); - vacc2x0123 = vpadalq_s16(vacc2x0123, vprod2x0123c3); - const int16x8_t vprod2x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3))); - vacc2x4567 = vpadalq_s16(vacc2x4567, vprod2x4567c3); - const int16x8_t vprod2x89ABc3 = vmull_s8(vb89ABc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3))); - vacc2x89AB = vpadalq_s16(vacc2x89AB, vprod2x89ABc3); - const int16x8_t vprod2xCDEFc3 = vmull_s8(vbCDEFc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3))); - vacc2xCDEF = vpadalq_s16(vacc2xCDEF, vprod2xCDEFc3); - } } } } diff --git a/src/qs8-gemm/gen/3x16c8-minmax-avx512skx.c b/src/qs8-gemm/gen/3x16c8-minmax-avx512skx.c index b5e91ec3a..9970bc3b1 100644 --- a/src/qs8-gemm/gen/3x16c8-minmax-avx512skx.c +++ b/src/qs8-gemm/gen/3x16c8-minmax-avx512skx.c @@ -13,6 +13,7 @@ #include <xnnpack/gemm.h> #include <xnnpack/intrinsics-polyfill.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_minmax_ukernel_3x16c8__avx512skx( @@ -36,6 +37,7 @@ void xnn_qs8_gemm_minmax_ukernel_3x16c8__avx512skx( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); const int8_t* a0 = a; int8_t* c0 = c; const int8_t* a1 = (const 
int8_t*) ((uintptr_t) a0 + a_stride); diff --git a/src/qs8-gemm/gen/3x16c8-minmax-neon-mull-padal.c b/src/qs8-gemm/gen/3x16c8-minmax-neon-mull-padal.c index 3b9b1ec6b..28ac128bd 100644 --- a/src/qs8-gemm/gen/3x16c8-minmax-neon-mull-padal.c +++ b/src/qs8-gemm/gen/3x16c8-minmax-neon-mull-padal.c @@ -11,8 +11,8 @@ #include <arm_neon.h> -#include <xnnpack/common.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mull_padal( @@ -36,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mull_padal( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); const int8_t* a0 = a; int8_t* c0 = c; const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride); @@ -625,9 +626,9 @@ void xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mull_padal( c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); - a0 = (const int8_t*) ((uintptr_t) a0 - (kc - k)); - a1 = (const int8_t*) ((uintptr_t) a1 - (kc - k)); - a2 = (const int8_t*) ((uintptr_t) a2 - (kc - k)); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); nc -= 16; } else { diff --git a/src/qs8-gemm/gen/3x4c8-minmax-sse2-ld128.c b/src/qs8-gemm/gen/3x4c8-minmax-sse2-ld128.c index 709ef30c1..3e40c433f 100644 --- a/src/qs8-gemm/gen/3x4c8-minmax-sse2-ld128.c +++ b/src/qs8-gemm/gen/3x4c8-minmax-sse2-ld128.c @@ -12,6 +12,7 @@ #include <emmintrin.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_minmax_ukernel_3x4c8__sse2_ld128( @@ -35,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_3x4c8__sse2_ld128( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); const int8_t* a0 = a; int8_t* c0 = c; const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride); @@ -208,14 +210,14 @@ void xnn_qs8_gemm_minmax_ukernel_3x4c8__sse2_ld128( vout = _mm_srli_si128(vout, 4); *((uint32_t*) c2) = (uint32_t) _mm_cvtsi128_si32(vout); - a0 = (const 
int8_t*) ((uintptr_t) a0 - k); - a1 = (const int8_t*) ((uintptr_t) a1 - k); - a2 = (const int8_t*) ((uintptr_t) a2 - k); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + nc -= 4; } else { if (nc & 2) { diff --git a/src/qs8-gemm/gen/3x4c8-minmax-sse2-ld64.c b/src/qs8-gemm/gen/3x4c8-minmax-sse2-ld64.c index c01210e04..ec8e70823 100644 --- a/src/qs8-gemm/gen/3x4c8-minmax-sse2-ld64.c +++ b/src/qs8-gemm/gen/3x4c8-minmax-sse2-ld64.c @@ -12,6 +12,7 @@ #include <emmintrin.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_minmax_ukernel_3x4c8__sse2_ld64( @@ -35,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_3x4c8__sse2_ld64( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); const int8_t* a0 = a; int8_t* c0 = c; const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride); @@ -210,14 +212,14 @@ void xnn_qs8_gemm_minmax_ukernel_3x4c8__sse2_ld64( vout = _mm_srli_si128(vout, 4); *((uint32_t*) c2) = (uint32_t) _mm_cvtsi128_si32(vout); - a0 = (const int8_t*) ((uintptr_t) a0 - k); - a1 = (const int8_t*) ((uintptr_t) a1 - k); - a2 = (const int8_t*) ((uintptr_t) a2 - k); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + nc -= 4; } else { if (nc & 2) { diff --git a/src/qs8-gemm/gen/3x4c8-minmax-sse41-ld128.c b/src/qs8-gemm/gen/3x4c8-minmax-sse41-ld128.c index fe38e9302..eefed2874 100644 --- a/src/qs8-gemm/gen/3x4c8-minmax-sse41-ld128.c +++ b/src/qs8-gemm/gen/3x4c8-minmax-sse41-ld128.c @@ -12,6 +12,7 @@ #include <smmintrin.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void 
xnn_qs8_gemm_minmax_ukernel_3x4c8__sse41_ld128( @@ -35,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_3x4c8__sse41_ld128( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); const int8_t* a0 = a; int8_t* c0 = c; const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride); @@ -174,14 +176,14 @@ void xnn_qs8_gemm_minmax_ukernel_3x4c8__sse41_ld128( *((uint32_t*) c1) = (uint32_t) _mm_extract_epi32(vout, 1); *((uint32_t*) c2) = (uint32_t) _mm_extract_epi32(vout, 2); - a0 = (const int8_t*) ((uintptr_t) a0 - k); - a1 = (const int8_t*) ((uintptr_t) a1 - k); - a2 = (const int8_t*) ((uintptr_t) a2 - k); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + nc -= 4; } else { if (nc & 2) { diff --git a/src/qs8-gemm/gen/3x4c8-minmax-sse41-ld64.c b/src/qs8-gemm/gen/3x4c8-minmax-sse41-ld64.c index f511ba128..1fac46b6b 100644 --- a/src/qs8-gemm/gen/3x4c8-minmax-sse41-ld64.c +++ b/src/qs8-gemm/gen/3x4c8-minmax-sse41-ld64.c @@ -12,6 +12,7 @@ #include <smmintrin.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_minmax_ukernel_3x4c8__sse41_ld64( @@ -35,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_3x4c8__sse41_ld64( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); const int8_t* a0 = a; int8_t* c0 = c; const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride); @@ -176,14 +178,14 @@ void xnn_qs8_gemm_minmax_ukernel_3x4c8__sse41_ld64( *((uint32_t*) c1) = (uint32_t) _mm_extract_epi32(vout, 1); *((uint32_t*) c2) = (uint32_t) _mm_extract_epi32(vout, 2); - a0 = (const int8_t*) ((uintptr_t) a0 - k); - a1 = (const int8_t*) ((uintptr_t) a1 - k); - a2 = (const int8_t*) ((uintptr_t) a2 - k); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); c2 = (int8_t*) ((uintptr_t) c2 + 
cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + nc -= 4; } else { if (nc & 2) { diff --git a/src/qs8-gemm/gen/3x4c8-minmax-ssse3-ld128.c b/src/qs8-gemm/gen/3x4c8-minmax-ssse3-ld128.c index 166acff33..d5dedb761 100644 --- a/src/qs8-gemm/gen/3x4c8-minmax-ssse3-ld128.c +++ b/src/qs8-gemm/gen/3x4c8-minmax-ssse3-ld128.c @@ -12,6 +12,7 @@ #include <tmmintrin.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_minmax_ukernel_3x4c8__ssse3_ld128( @@ -35,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_3x4c8__ssse3_ld128( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); const int8_t* a0 = a; int8_t* c0 = c; const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride); @@ -208,14 +210,14 @@ void xnn_qs8_gemm_minmax_ukernel_3x4c8__ssse3_ld128( vout = _mm_srli_si128(vout, 4); *((uint32_t*) c2) = (uint32_t) _mm_cvtsi128_si32(vout); - a0 = (const int8_t*) ((uintptr_t) a0 - k); - a1 = (const int8_t*) ((uintptr_t) a1 - k); - a2 = (const int8_t*) ((uintptr_t) a2 - k); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + nc -= 4; } else { if (nc & 2) { diff --git a/src/qs8-gemm/gen/3x4c8-minmax-ssse3-ld64.c b/src/qs8-gemm/gen/3x4c8-minmax-ssse3-ld64.c index 255c8e2c2..661e5ec8b 100644 --- a/src/qs8-gemm/gen/3x4c8-minmax-ssse3-ld64.c +++ b/src/qs8-gemm/gen/3x4c8-minmax-ssse3-ld64.c @@ -12,6 +12,7 @@ #include <tmmintrin.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_minmax_ukernel_3x4c8__ssse3_ld64( @@ -35,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_3x4c8__ssse3_ld64( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); const int8_t* a0 = a; int8_t* c0 = c; const int8_t* a1 = (const 
int8_t*) ((uintptr_t) a0 + a_stride); @@ -210,14 +212,14 @@ void xnn_qs8_gemm_minmax_ukernel_3x4c8__ssse3_ld64( vout = _mm_srli_si128(vout, 4); *((uint32_t*) c2) = (uint32_t) _mm_cvtsi128_si32(vout); - a0 = (const int8_t*) ((uintptr_t) a0 - k); - a1 = (const int8_t*) ((uintptr_t) a1 - k); - a2 = (const int8_t*) ((uintptr_t) a2 - k); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + nc -= 4; } else { if (nc & 2) { diff --git a/src/qs8-gemm/gen/3x4c8-minmax-wasmsimd-ld128.c b/src/qs8-gemm/gen/3x4c8-minmax-wasmsimd-ld128.c index 8a037edc5..138f883a5 100644 --- a/src/qs8-gemm/gen/3x4c8-minmax-wasmsimd-ld128.c +++ b/src/qs8-gemm/gen/3x4c8-minmax-wasmsimd-ld128.c @@ -12,6 +12,7 @@ #include <wasm_simd128.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_minmax_ukernel_3x4c8__wasmsimd_ld128( @@ -35,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_3x4c8__wasmsimd_ld128( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); const int8_t* a0 = a; int8_t* c0 = c; const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride); @@ -192,14 +194,14 @@ void xnn_qs8_gemm_minmax_ukernel_3x4c8__wasmsimd_ld128( *((float*) c1) = (float) wasm_f32x4_extract_lane(vout, 1); *((float*) c2) = (float) wasm_f32x4_extract_lane(vout, 2); - a0 = (const int8_t*) ((uintptr_t) a0 - k); - a1 = (const int8_t*) ((uintptr_t) a1 - k); - a2 = (const int8_t*) ((uintptr_t) a2 - k); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + nc -= 4; } else { if (nc & 2) { diff --git a/src/qs8-gemm/gen/3x4c8-minmax-wasmsimd-ld64.c 
b/src/qs8-gemm/gen/3x4c8-minmax-wasmsimd-ld64.c index f63c0f0b3..63303a922 100644 --- a/src/qs8-gemm/gen/3x4c8-minmax-wasmsimd-ld64.c +++ b/src/qs8-gemm/gen/3x4c8-minmax-wasmsimd-ld64.c @@ -12,6 +12,7 @@ #include <wasm_simd128.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_minmax_ukernel_3x4c8__wasmsimd_ld64( @@ -35,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_3x4c8__wasmsimd_ld64( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); const int8_t* a0 = a; int8_t* c0 = c; const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride); @@ -188,14 +190,14 @@ void xnn_qs8_gemm_minmax_ukernel_3x4c8__wasmsimd_ld64( *((float*) c1) = (float) wasm_f32x4_extract_lane(vout, 1); *((float*) c2) = (float) wasm_f32x4_extract_lane(vout, 2); - a0 = (const int8_t*) ((uintptr_t) a0 - k); - a1 = (const int8_t*) ((uintptr_t) a1 - k); - a2 = (const int8_t*) ((uintptr_t) a2 - k); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + nc -= 4; } else { if (nc & 2) { diff --git a/src/qs8-gemm/gen/3x4c8-minmax-xop-ld128.c b/src/qs8-gemm/gen/3x4c8-minmax-xop-ld128.c index f3beb05c9..df0e5d5a0 100644 --- a/src/qs8-gemm/gen/3x4c8-minmax-xop-ld128.c +++ b/src/qs8-gemm/gen/3x4c8-minmax-xop-ld128.c @@ -17,6 +17,7 @@ #endif #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_minmax_ukernel_3x4c8__xop_ld128( @@ -40,6 +41,7 @@ void xnn_qs8_gemm_minmax_ukernel_3x4c8__xop_ld128( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); const int8_t* a0 = a; int8_t* c0 = c; const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride); @@ -179,14 +181,14 @@ void xnn_qs8_gemm_minmax_ukernel_3x4c8__xop_ld128( *((uint32_t*) c1) = (uint32_t) _mm_extract_epi32(vout, 1); *((uint32_t*) c2) = (uint32_t) _mm_extract_epi32(vout, 
2); - a0 = (const int8_t*) ((uintptr_t) a0 - k); - a1 = (const int8_t*) ((uintptr_t) a1 - k); - a2 = (const int8_t*) ((uintptr_t) a2 - k); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + nc -= 4; } else { if (nc & 2) { diff --git a/src/qs8-gemm/gen/3x4c8-minmax-xop-ld64.c b/src/qs8-gemm/gen/3x4c8-minmax-xop-ld64.c index f82a244e6..090bd7332 100644 --- a/src/qs8-gemm/gen/3x4c8-minmax-xop-ld64.c +++ b/src/qs8-gemm/gen/3x4c8-minmax-xop-ld64.c @@ -17,6 +17,7 @@ #endif #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_minmax_ukernel_3x4c8__xop_ld64( @@ -40,6 +41,7 @@ void xnn_qs8_gemm_minmax_ukernel_3x4c8__xop_ld64( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); const int8_t* a0 = a; int8_t* c0 = c; const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride); @@ -181,14 +183,14 @@ void xnn_qs8_gemm_minmax_ukernel_3x4c8__xop_ld64( *((uint32_t*) c1) = (uint32_t) _mm_extract_epi32(vout, 1); *((uint32_t*) c2) = (uint32_t) _mm_extract_epi32(vout, 2); - a0 = (const int8_t*) ((uintptr_t) a0 - k); - a1 = (const int8_t*) ((uintptr_t) a1 - k); - a2 = (const int8_t*) ((uintptr_t) a2 - k); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + nc -= 4; } else { if (nc & 2) { diff --git a/src/qs8-gemm/gen/3x4c8-xw-minmax-sse2.c b/src/qs8-gemm/gen/3x4c8-xw-minmax-sse2.c index 1da482fb6..eb77d6e79 100644 --- a/src/qs8-gemm/gen/3x4c8-xw-minmax-sse2.c +++ b/src/qs8-gemm/gen/3x4c8-xw-minmax-sse2.c @@ -12,6 +12,7 @@ #include <emmintrin.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void 
xnn_qs8_gemm_xw_minmax_ukernel_3x4c8__sse2( @@ -35,6 +36,7 @@ void xnn_qs8_gemm_xw_minmax_ukernel_3x4c8__sse2( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); const int8_t* a0 = a; int8_t* c0 = c; const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride); @@ -206,14 +208,14 @@ void xnn_qs8_gemm_xw_minmax_ukernel_3x4c8__sse2( vout = _mm_srli_si128(vout, 4); *((uint32_t*) c2) = (uint32_t) _mm_cvtsi128_si32(vout); - a0 = (const int8_t*) ((uintptr_t) a0 - k); - a1 = (const int8_t*) ((uintptr_t) a1 - k); - a2 = (const int8_t*) ((uintptr_t) a2 - k); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + nc -= 4; } else { if (nc & 2) { diff --git a/src/qs8-gemm/gen/3x4c8-xw-minmax-sse41.c b/src/qs8-gemm/gen/3x4c8-xw-minmax-sse41.c index 27ec2777a..167ab3159 100644 --- a/src/qs8-gemm/gen/3x4c8-xw-minmax-sse41.c +++ b/src/qs8-gemm/gen/3x4c8-xw-minmax-sse41.c @@ -12,6 +12,7 @@ #include <smmintrin.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_xw_minmax_ukernel_3x4c8__sse41( @@ -35,6 +36,7 @@ void xnn_qs8_gemm_xw_minmax_ukernel_3x4c8__sse41( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); const int8_t* a0 = a; int8_t* c0 = c; const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride); @@ -172,14 +174,14 @@ void xnn_qs8_gemm_xw_minmax_ukernel_3x4c8__sse41( *((uint32_t*) c1) = (uint32_t) _mm_extract_epi32(vout, 1); *((uint32_t*) c2) = (uint32_t) _mm_extract_epi32(vout, 2); - a0 = (const int8_t*) ((uintptr_t) a0 - k); - a1 = (const int8_t*) ((uintptr_t) a1 - k); - a2 = (const int8_t*) ((uintptr_t) a2 - k); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + a1 
= (const int8_t*) ((uintptr_t) a1 - kc); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + nc -= 4; } else { if (nc & 2) { diff --git a/src/qs8-gemm/gen/3x4c8-xw-minmax-ssse3.c b/src/qs8-gemm/gen/3x4c8-xw-minmax-ssse3.c index 93fe44bde..d29985525 100644 --- a/src/qs8-gemm/gen/3x4c8-xw-minmax-ssse3.c +++ b/src/qs8-gemm/gen/3x4c8-xw-minmax-ssse3.c @@ -12,6 +12,7 @@ #include <tmmintrin.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_xw_minmax_ukernel_3x4c8__ssse3( @@ -35,6 +36,7 @@ void xnn_qs8_gemm_xw_minmax_ukernel_3x4c8__ssse3( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); const int8_t* a0 = a; int8_t* c0 = c; const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride); @@ -206,14 +208,14 @@ void xnn_qs8_gemm_xw_minmax_ukernel_3x4c8__ssse3( vout = _mm_srli_si128(vout, 4); *((uint32_t*) c2) = (uint32_t) _mm_cvtsi128_si32(vout); - a0 = (const int8_t*) ((uintptr_t) a0 - k); - a1 = (const int8_t*) ((uintptr_t) a1 - k); - a2 = (const int8_t*) ((uintptr_t) a2 - k); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + nc -= 4; } else { if (nc & 2) { diff --git a/src/qs8-gemm/gen/3x4c8-xw-minmax-wasmsimd.c b/src/qs8-gemm/gen/3x4c8-xw-minmax-wasmsimd.c index 425833760..7cc2d0115 100644 --- a/src/qs8-gemm/gen/3x4c8-xw-minmax-wasmsimd.c +++ b/src/qs8-gemm/gen/3x4c8-xw-minmax-wasmsimd.c @@ -12,6 +12,7 @@ #include <wasm_simd128.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_xw_minmax_ukernel_3x4c8__wasmsimd( @@ -35,6 +36,7 @@ void xnn_qs8_gemm_xw_minmax_ukernel_3x4c8__wasmsimd( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); const int8_t* a0 = a; int8_t* c0 = c; const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride); @@ -188,14 +190,14 @@ void 
xnn_qs8_gemm_xw_minmax_ukernel_3x4c8__wasmsimd( *((float*) c1) = (float) wasm_f32x4_extract_lane(vout, 1); *((float*) c2) = (float) wasm_f32x4_extract_lane(vout, 2); - a0 = (const int8_t*) ((uintptr_t) a0 - k); - a1 = (const int8_t*) ((uintptr_t) a1 - k); - a2 = (const int8_t*) ((uintptr_t) a2 - k); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + nc -= 4; } else { if (nc & 2) { diff --git a/src/qs8-gemm/gen/3x4c8-xw-minmax-xop.c b/src/qs8-gemm/gen/3x4c8-xw-minmax-xop.c index 0f8932bcd..79cdbf1de 100644 --- a/src/qs8-gemm/gen/3x4c8-xw-minmax-xop.c +++ b/src/qs8-gemm/gen/3x4c8-xw-minmax-xop.c @@ -17,6 +17,7 @@ #endif #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_xw_minmax_ukernel_3x4c8__xop( @@ -40,6 +41,7 @@ void xnn_qs8_gemm_xw_minmax_ukernel_3x4c8__xop( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); const int8_t* a0 = a; int8_t* c0 = c; const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride); @@ -177,14 +179,14 @@ void xnn_qs8_gemm_xw_minmax_ukernel_3x4c8__xop( *((uint32_t*) c1) = (uint32_t) _mm_extract_epi32(vout, 1); *((uint32_t*) c2) = (uint32_t) _mm_extract_epi32(vout, 2); - a0 = (const int8_t*) ((uintptr_t) a0 - k); - a1 = (const int8_t*) ((uintptr_t) a1 - k); - a2 = (const int8_t*) ((uintptr_t) a2 - k); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + nc -= 4; } else { if (nc & 2) { diff --git a/src/qs8-gemm/gen/3x8c16-minmax-neon-mlal-padal.c b/src/qs8-gemm/gen/3x8c16-minmax-neon-mlal-padal.c index 079d6d2be..0b1835649 100644 --- 
a/src/qs8-gemm/gen/3x8c16-minmax-neon-mlal-padal.c +++ b/src/qs8-gemm/gen/3x8c16-minmax-neon-mlal-padal.c @@ -11,8 +11,8 @@ #include <arm_neon.h> -#include <xnnpack/common.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_minmax_ukernel_3x8c16__neon_mlal_padal( @@ -36,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_3x8c16__neon_mlal_padal( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 16); const int8_t* a0 = a; int8_t* c0 = c; const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride); @@ -77,7 +78,7 @@ void xnn_qs8_gemm_minmax_ukernel_3x8c16__neon_mlal_padal( int32x4_t vacc2x6 = vacc0x6; int32x4_t vacc2x7 = vacc0x7; - // KC loop of 16 with up to 15 remainder + // KC loop of 16 size_t k = 0; while (k < kc) { const int8x16_t va0 = vld1q_s8(a0); a0 += 16; @@ -290,9 +291,9 @@ void xnn_qs8_gemm_minmax_ukernel_3x8c16__neon_mlal_padal( c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); - a0 = (const int8_t*) ((uintptr_t) a0 - k); - a1 = (const int8_t*) ((uintptr_t) a1 - k); - a2 = (const int8_t*) ((uintptr_t) a2 - k); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); nc -= 8; } else { diff --git a/src/qs8-gemm/gen/3x8c2-minmax-neon-mlal-padal-dup.c b/src/qs8-gemm/gen/3x8c2-minmax-neon-mlal-padal-dup.c index 7f095a9d7..4d90e3a42 100644 --- a/src/qs8-gemm/gen/3x8c2-minmax-neon-mlal-padal-dup.c +++ b/src/qs8-gemm/gen/3x8c2-minmax-neon-mlal-padal-dup.c @@ -11,8 +11,8 @@ #include <arm_neon.h> -#include <xnnpack/common.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_minmax_ukernel_3x8c2__neon_mlal_padal_dup( @@ -36,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_3x8c2__neon_mlal_padal_dup( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 2); const int8_t* a0 = a; int8_t* c0 = c; const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride); @@ -282,24 +283,6 @@ void 
xnn_qs8_gemm_minmax_ukernel_3x8c2__neon_mlal_padal_dup( vacc2x0123 = vpadalq_s16(vacc2x0123, vprod2x0123c2); const int16x8_t vprod2x4567c2 = vmull_s8(vb4567c2, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 2))); vacc2x4567 = vpadalq_s16(vacc2x4567, vprod2x4567c2); - - if (k > 6 * sizeof(int8_t)) { - const int8x8_t vb0123c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb4567c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - - const int16x8_t vprod0x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0x0123 = vpadalq_s16(vacc0x0123, vprod0x0123c3); - const int16x8_t vprod0x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0x4567 = vpadalq_s16(vacc0x4567, vprod0x4567c3); - const int16x8_t vprod1x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3))); - vacc1x0123 = vpadalq_s16(vacc1x0123, vprod1x0123c3); - const int16x8_t vprod1x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3))); - vacc1x4567 = vpadalq_s16(vacc1x4567, vprod1x4567c3); - const int16x8_t vprod2x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3))); - vacc2x0123 = vpadalq_s16(vacc2x0123, vprod2x0123c3); - const int16x8_t vprod2x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3))); - vacc2x4567 = vpadalq_s16(vacc2x4567, vprod2x4567c3); - } } } } diff --git a/src/qs8-gemm/gen/3x8c2-minmax-neon-mull-padal-dup.c b/src/qs8-gemm/gen/3x8c2-minmax-neon-mull-padal-dup.c index 2a8d3d909..cdd6df2c4 100644 --- a/src/qs8-gemm/gen/3x8c2-minmax-neon-mull-padal-dup.c +++ b/src/qs8-gemm/gen/3x8c2-minmax-neon-mull-padal-dup.c @@ -11,8 +11,8 @@ #include <arm_neon.h> -#include <xnnpack/common.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void 
xnn_qs8_gemm_minmax_ukernel_3x8c2__neon_mull_padal_dup( @@ -36,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_3x8c2__neon_mull_padal_dup( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 2); const int8_t* a0 = a; int8_t* c0 = c; const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride); @@ -182,24 +183,6 @@ void xnn_qs8_gemm_minmax_ukernel_3x8c2__neon_mull_padal_dup( vacc2x0123 = vpadalq_s16(vacc2x0123, vprod2x0123c2); const int16x8_t vprod2x4567c2 = vmull_s8(vb4567c2, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 2))); vacc2x4567 = vpadalq_s16(vacc2x4567, vprod2x4567c2); - - if (k > 6 * sizeof(int8_t)) { - const int8x8_t vb0123c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb4567c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - - const int16x8_t vprod0x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0x0123 = vpadalq_s16(vacc0x0123, vprod0x0123c3); - const int16x8_t vprod0x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0x4567 = vpadalq_s16(vacc0x4567, vprod0x4567c3); - const int16x8_t vprod1x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3))); - vacc1x0123 = vpadalq_s16(vacc1x0123, vprod1x0123c3); - const int16x8_t vprod1x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3))); - vacc1x4567 = vpadalq_s16(vacc1x4567, vprod1x4567c3); - const int16x8_t vprod2x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3))); - vacc2x0123 = vpadalq_s16(vacc2x0123, vprod2x0123c3); - const int16x8_t vprod2x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3))); - vacc2x4567 = vpadalq_s16(vacc2x4567, vprod2x4567c3); - } } } } diff --git a/src/qs8-gemm/gen/3x8c8-minmax-avx2.c b/src/qs8-gemm/gen/3x8c8-minmax-avx2.c index 
f4d3c06e7..d703f98a9 100644 --- a/src/qs8-gemm/gen/3x8c8-minmax-avx2.c +++ b/src/qs8-gemm/gen/3x8c8-minmax-avx2.c @@ -13,6 +13,7 @@ #include <xnnpack/gemm.h> #include <xnnpack/intrinsics-polyfill.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_minmax_ukernel_3x8c8__avx2( @@ -36,6 +37,7 @@ void xnn_qs8_gemm_minmax_ukernel_3x8c8__avx2( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); const int8_t* a0 = a; int8_t* c0 = c; const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride); @@ -195,14 +197,14 @@ void xnn_qs8_gemm_minmax_ukernel_3x8c8__avx2( _mm_storel_epi64((__m128i*) c1, vout_hi); _mm_storeh_pi((__m64*) c2, _mm_castsi128_ps(vout_lo)); - a0 = (const int8_t*) ((uintptr_t) a0 - k); - a1 = (const int8_t*) ((uintptr_t) a1 - k); - a2 = (const int8_t*) ((uintptr_t) a2 - k); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + nc -= 8; } else { if (nc & 4) { diff --git a/src/qs8-gemm/gen/3x8c8-minmax-neon-mull-padal.c b/src/qs8-gemm/gen/3x8c8-minmax-neon-mull-padal.c index b1c926f0b..e90507529 100644 --- a/src/qs8-gemm/gen/3x8c8-minmax-neon-mull-padal.c +++ b/src/qs8-gemm/gen/3x8c8-minmax-neon-mull-padal.c @@ -11,8 +11,8 @@ #include <arm_neon.h> -#include <xnnpack/common.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mull_padal( @@ -36,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mull_padal( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); const int8_t* a0 = a; int8_t* c0 = c; const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride); @@ -369,9 +370,9 @@ void xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mull_padal( c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); - a0 = (const int8_t*) ((uintptr_t) a0 - (kc 
- k)); - a1 = (const int8_t*) ((uintptr_t) a1 - (kc - k)); - a2 = (const int8_t*) ((uintptr_t) a2 - (kc - k)); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); nc -= 8; } else { diff --git a/src/qs8-gemm/gen/3x8c8-xw-minmax-avx2.c b/src/qs8-gemm/gen/3x8c8-xw-minmax-avx2.c index 370b8d7e7..290c6af15 100644 --- a/src/qs8-gemm/gen/3x8c8-xw-minmax-avx2.c +++ b/src/qs8-gemm/gen/3x8c8-xw-minmax-avx2.c @@ -13,6 +13,7 @@ #include <xnnpack/gemm.h> #include <xnnpack/intrinsics-polyfill.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_xw_minmax_ukernel_3x8c8__avx2( @@ -36,6 +37,7 @@ void xnn_qs8_gemm_xw_minmax_ukernel_3x8c8__avx2( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); const int8_t* a0 = a; int8_t* c0 = c; const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride); @@ -191,14 +193,14 @@ void xnn_qs8_gemm_xw_minmax_ukernel_3x8c8__avx2( _mm_storel_epi64((__m128i*) c1, vout_hi); _mm_storeh_pi((__m64*) c2, _mm_castsi128_ps(vout_lo)); - a0 = (const int8_t*) ((uintptr_t) a0 - k); - a1 = (const int8_t*) ((uintptr_t) a1 - k); - a2 = (const int8_t*) ((uintptr_t) a2 - k); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + nc -= 8; } else { if (nc & 4) { diff --git a/src/qs8-gemm/gen/4x16c16-minmax-neon-mlal-padal.c b/src/qs8-gemm/gen/4x16c16-minmax-neon-mlal-padal.c index 9ad5300f9..53cbee31b 100644 --- a/src/qs8-gemm/gen/4x16c16-minmax-neon-mlal-padal.c +++ b/src/qs8-gemm/gen/4x16c16-minmax-neon-mlal-padal.c @@ -11,8 +11,8 @@ #include <arm_neon.h> -#include <xnnpack/common.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_minmax_ukernel_4x16c16__neon_mlal_padal( @@ -36,6 +36,7 @@ void 
xnn_qs8_gemm_minmax_ukernel_4x16c16__neon_mlal_padal( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 16); const int8_t* a0 = a; int8_t* c0 = c; const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride); @@ -123,7 +124,7 @@ void xnn_qs8_gemm_minmax_ukernel_4x16c16__neon_mlal_padal( int32x4_t vacc3x14 = vacc0x14; int32x4_t vacc3x15 = vacc0x15; - // KC loop of 16 with up to 15 remainder + // KC loop of 16 size_t k = 0; while (k < kc) { const int8x16_t va0 = vld1q_s8(a0); a0 += 16; @@ -615,10 +616,10 @@ void xnn_qs8_gemm_minmax_ukernel_4x16c16__neon_mlal_padal( c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); c3 = (int8_t*) ((uintptr_t) c3 + cn_stride); - a0 = (const int8_t*) ((uintptr_t) a0 - k); - a1 = (const int8_t*) ((uintptr_t) a1 - k); - a2 = (const int8_t*) ((uintptr_t) a2 - k); - a3 = (const int8_t*) ((uintptr_t) a3 - k); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + a3 = (const int8_t*) ((uintptr_t) a3 - kc); nc -= 16; } else { diff --git a/src/qs8-gemm/gen/4x16c2-minmax-neon-mlal-padal-dup.c b/src/qs8-gemm/gen/4x16c2-minmax-neon-mlal-padal-dup.c index 19d1d86d0..c8685b236 100644 --- a/src/qs8-gemm/gen/4x16c2-minmax-neon-mlal-padal-dup.c +++ b/src/qs8-gemm/gen/4x16c2-minmax-neon-mlal-padal-dup.c @@ -11,8 +11,8 @@ #include <arm_neon.h> -#include <xnnpack/common.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_minmax_ukernel_4x16c2__neon_mlal_padal_dup( @@ -36,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_4x16c2__neon_mlal_padal_dup( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 2); const int8_t* a0 = a; int8_t* c0 = c; const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride); @@ -592,46 +593,6 @@ void xnn_qs8_gemm_minmax_ukernel_4x16c2__neon_mlal_padal_dup( vacc3x89AB = vpadalq_s16(vacc3x89AB, vprod3x89ABc2); const int16x8_t vprod3xCDEFc2 = vmull_s8(vbCDEFc2, 
vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va3), 2))); vacc3xCDEF = vpadalq_s16(vacc3xCDEF, vprod3xCDEFc2); - - if (k > 6 * sizeof(int8_t)) { - const int8x8_t vb0123c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb4567c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb89ABc3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vbCDEFc3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - - const int16x8_t vprod0x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0x0123 = vpadalq_s16(vacc0x0123, vprod0x0123c3); - const int16x8_t vprod0x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0x4567 = vpadalq_s16(vacc0x4567, vprod0x4567c3); - const int16x8_t vprod0x89ABc3 = vmull_s8(vb89ABc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0x89AB = vpadalq_s16(vacc0x89AB, vprod0x89ABc3); - const int16x8_t vprod0xCDEFc3 = vmull_s8(vbCDEFc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0xCDEF = vpadalq_s16(vacc0xCDEF, vprod0xCDEFc3); - const int16x8_t vprod1x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3))); - vacc1x0123 = vpadalq_s16(vacc1x0123, vprod1x0123c3); - const int16x8_t vprod1x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3))); - vacc1x4567 = vpadalq_s16(vacc1x4567, vprod1x4567c3); - const int16x8_t vprod1x89ABc3 = vmull_s8(vb89ABc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3))); - vacc1x89AB = vpadalq_s16(vacc1x89AB, vprod1x89ABc3); - const int16x8_t vprod1xCDEFc3 = vmull_s8(vbCDEFc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3))); - vacc1xCDEF = vpadalq_s16(vacc1xCDEF, vprod1xCDEFc3); - const int16x8_t vprod2x0123c3 = vmull_s8(vb0123c3, 
vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3))); - vacc2x0123 = vpadalq_s16(vacc2x0123, vprod2x0123c3); - const int16x8_t vprod2x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3))); - vacc2x4567 = vpadalq_s16(vacc2x4567, vprod2x4567c3); - const int16x8_t vprod2x89ABc3 = vmull_s8(vb89ABc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3))); - vacc2x89AB = vpadalq_s16(vacc2x89AB, vprod2x89ABc3); - const int16x8_t vprod2xCDEFc3 = vmull_s8(vbCDEFc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3))); - vacc2xCDEF = vpadalq_s16(vacc2xCDEF, vprod2xCDEFc3); - const int16x8_t vprod3x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va3), 3))); - vacc3x0123 = vpadalq_s16(vacc3x0123, vprod3x0123c3); - const int16x8_t vprod3x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va3), 3))); - vacc3x4567 = vpadalq_s16(vacc3x4567, vprod3x4567c3); - const int16x8_t vprod3x89ABc3 = vmull_s8(vb89ABc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va3), 3))); - vacc3x89AB = vpadalq_s16(vacc3x89AB, vprod3x89ABc3); - const int16x8_t vprod3xCDEFc3 = vmull_s8(vbCDEFc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va3), 3))); - vacc3xCDEF = vpadalq_s16(vacc3xCDEF, vprod3xCDEFc3); - } } } } diff --git a/src/qs8-gemm/gen/4x16c2-minmax-neon-mull-padal-dup.c b/src/qs8-gemm/gen/4x16c2-minmax-neon-mull-padal-dup.c index bcd00726f..ea9270cf7 100644 --- a/src/qs8-gemm/gen/4x16c2-minmax-neon-mull-padal-dup.c +++ b/src/qs8-gemm/gen/4x16c2-minmax-neon-mull-padal-dup.c @@ -11,8 +11,8 @@ #include <arm_neon.h> -#include <xnnpack/common.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_minmax_ukernel_4x16c2__neon_mull_padal_dup( @@ -36,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_4x16c2__neon_mull_padal_dup( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 2); const int8_t* a0 = a; int8_t* c0 = c; const 
int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride); @@ -354,46 +355,6 @@ void xnn_qs8_gemm_minmax_ukernel_4x16c2__neon_mull_padal_dup( vacc3x89AB = vpadalq_s16(vacc3x89AB, vprod3x89ABc2); const int16x8_t vprod3xCDEFc2 = vmull_s8(vbCDEFc2, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va3), 2))); vacc3xCDEF = vpadalq_s16(vacc3xCDEF, vprod3xCDEFc2); - - if (k > 6 * sizeof(int8_t)) { - const int8x8_t vb0123c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb4567c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb89ABc3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vbCDEFc3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - - const int16x8_t vprod0x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0x0123 = vpadalq_s16(vacc0x0123, vprod0x0123c3); - const int16x8_t vprod0x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0x4567 = vpadalq_s16(vacc0x4567, vprod0x4567c3); - const int16x8_t vprod0x89ABc3 = vmull_s8(vb89ABc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0x89AB = vpadalq_s16(vacc0x89AB, vprod0x89ABc3); - const int16x8_t vprod0xCDEFc3 = vmull_s8(vbCDEFc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0xCDEF = vpadalq_s16(vacc0xCDEF, vprod0xCDEFc3); - const int16x8_t vprod1x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3))); - vacc1x0123 = vpadalq_s16(vacc1x0123, vprod1x0123c3); - const int16x8_t vprod1x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3))); - vacc1x4567 = vpadalq_s16(vacc1x4567, vprod1x4567c3); - const int16x8_t vprod1x89ABc3 = vmull_s8(vb89ABc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3))); - vacc1x89AB = vpadalq_s16(vacc1x89AB, 
vprod1x89ABc3); - const int16x8_t vprod1xCDEFc3 = vmull_s8(vbCDEFc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3))); - vacc1xCDEF = vpadalq_s16(vacc1xCDEF, vprod1xCDEFc3); - const int16x8_t vprod2x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3))); - vacc2x0123 = vpadalq_s16(vacc2x0123, vprod2x0123c3); - const int16x8_t vprod2x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3))); - vacc2x4567 = vpadalq_s16(vacc2x4567, vprod2x4567c3); - const int16x8_t vprod2x89ABc3 = vmull_s8(vb89ABc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3))); - vacc2x89AB = vpadalq_s16(vacc2x89AB, vprod2x89ABc3); - const int16x8_t vprod2xCDEFc3 = vmull_s8(vbCDEFc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3))); - vacc2xCDEF = vpadalq_s16(vacc2xCDEF, vprod2xCDEFc3); - const int16x8_t vprod3x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va3), 3))); - vacc3x0123 = vpadalq_s16(vacc3x0123, vprod3x0123c3); - const int16x8_t vprod3x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va3), 3))); - vacc3x4567 = vpadalq_s16(vacc3x4567, vprod3x4567c3); - const int16x8_t vprod3x89ABc3 = vmull_s8(vb89ABc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va3), 3))); - vacc3x89AB = vpadalq_s16(vacc3x89AB, vprod3x89ABc3); - const int16x8_t vprod3xCDEFc3 = vmull_s8(vbCDEFc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va3), 3))); - vacc3xCDEF = vpadalq_s16(vacc3xCDEF, vprod3xCDEFc3); - } } } } diff --git a/src/qs8-gemm/gen/4x16c4-minmax-neondot.c b/src/qs8-gemm/gen/4x16c4-minmax-neondot.c index 9b169938f..719db2270 100644 --- a/src/qs8-gemm/gen/4x16c4-minmax-neondot.c +++ b/src/qs8-gemm/gen/4x16c4-minmax-neondot.c @@ -7,12 +7,12 @@ // This source code is licensed under the BSD-style license found in the // LICENSE file in the root directory of this source tree. 
- #include <assert.h> #include <arm_neon.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_minmax_ukernel_4x16c4__neondot( @@ -30,7 +30,12 @@ void xnn_qs8_gemm_minmax_ukernel_4x16c4__neondot( assert(mr <= 4); assert(nc != 0); assert(kc != 0); + assert(kc % sizeof(int8_t) == 0); + assert(a != NULL); + assert(w != NULL); + assert(c != NULL); + kc = round_up_po2(kc, 4); const int8_t* a0 = a; int8_t* c0 = c; const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride); @@ -129,7 +134,7 @@ void xnn_qs8_gemm_minmax_ukernel_4x16c4__neondot( k -= 8 * sizeof(int8_t); } - // Handle up to 7 final positions of `k` + // Handle up to 6 final positions of `k` if XNN_UNLIKELY(k != 0) { // Load a 4x4 block of activations. const int8x8_t va0x01234567 = vld1_s8(a0); a0 += k; @@ -187,15 +192,8 @@ void xnn_qs8_gemm_minmax_ukernel_4x16c4__neondot( vacc3xCDEF = vdotq_lane_s32(vacc3xCDEF, vb4567xCDEF, va3x01234567, 1); } } - // End of accumulation loop. The variable `kc` contains the amount by which - // we advanced the `va` pointers, so we rewind by this amount now. - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - a3 = (const int8_t*) ((uintptr_t) a3 - kc); // Post-accumulation work - const int32x4_t vright_shift = vld1q_dup_s32(¶ms->neon.right_shift); const int32x4_t vzero_shift_mask = vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0))); @@ -306,6 +304,11 @@ void xnn_qs8_gemm_minmax_ukernel_4x16c4__neondot( c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); c3 = (int8_t*) ((uintptr_t) c3 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + a3 = (const int8_t*) ((uintptr_t) a3 - kc); + nc -= 16; } else { // Final case where not all of the 16 columns fit in the destination. 
diff --git a/src/qs8-gemm/gen/4x16c8-minmax-avx512skx.c b/src/qs8-gemm/gen/4x16c8-minmax-avx512skx.c index 0ded25f4e..afa0f41a3 100644 --- a/src/qs8-gemm/gen/4x16c8-minmax-avx512skx.c +++ b/src/qs8-gemm/gen/4x16c8-minmax-avx512skx.c @@ -13,6 +13,7 @@ #include <xnnpack/gemm.h> #include <xnnpack/intrinsics-polyfill.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_minmax_ukernel_4x16c8__avx512skx( @@ -36,6 +37,7 @@ void xnn_qs8_gemm_minmax_ukernel_4x16c8__avx512skx( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); const int8_t* a0 = a; int8_t* c0 = c; const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride); diff --git a/src/qs8-gemm/gen/4x16c8-minmax-neon-mull-padal.c b/src/qs8-gemm/gen/4x16c8-minmax-neon-mull-padal.c index 9121043a8..ed64d8bed 100644 --- a/src/qs8-gemm/gen/4x16c8-minmax-neon-mull-padal.c +++ b/src/qs8-gemm/gen/4x16c8-minmax-neon-mull-padal.c @@ -11,8 +11,8 @@ #include <arm_neon.h> -#include <xnnpack/common.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mull_padal( @@ -36,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mull_padal( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); const int8_t* a0 = a; int8_t* c0 = c; const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride); @@ -792,10 +793,10 @@ void xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mull_padal( c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); c3 = (int8_t*) ((uintptr_t) c3 + cn_stride); - a0 = (const int8_t*) ((uintptr_t) a0 - (kc - k)); - a1 = (const int8_t*) ((uintptr_t) a1 - (kc - k)); - a2 = (const int8_t*) ((uintptr_t) a2 - (kc - k)); - a3 = (const int8_t*) ((uintptr_t) a3 - (kc - k)); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + a3 = (const int8_t*) ((uintptr_t) a3 - kc); nc -= 16; } else { diff --git a/src/qs8-gemm/gen/4x4c2-minmax-sse2-ld128.c b/src/qs8-gemm/gen/4x4c2-minmax-sse2-ld128.c 
index 88b2a448c..fc682f906 100644 --- a/src/qs8-gemm/gen/4x4c2-minmax-sse2-ld128.c +++ b/src/qs8-gemm/gen/4x4c2-minmax-sse2-ld128.c @@ -12,6 +12,7 @@ #include <emmintrin.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_minmax_ukernel_4x4c2__sse2_ld128( @@ -35,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_4x4c2__sse2_ld128( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 2); const int8_t* a0 = a; int8_t* c0 = c; const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride); @@ -180,21 +182,6 @@ void xnn_qs8_gemm_minmax_ukernel_4x4c2__sse2_ld128( _mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(2, 2, 2, 2)), vxb2)); vacc3x0123 = _mm_add_epi32(vacc3x0123, _mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(2, 2, 2, 2)), vxb2)); - - if (k > 6 * sizeof(int8_t)) { - const __m128i vb3 = _mm_loadl_epi64((const __m128i*) w); - const __m128i vxb3 = _mm_unpacklo_epi8(vb3, _mm_cmpgt_epi8(_mm_setzero_si128(), vb3)); - w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - - vacc0x0123 = _mm_add_epi32(vacc0x0123, - _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); - vacc1x0123 = _mm_add_epi32(vacc1x0123, - _mm_madd_epi16(_mm_shuffle_epi32(vxa1, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); - vacc2x0123 = _mm_add_epi32(vacc2x0123, - _mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); - vacc3x0123 = _mm_add_epi32(vacc3x0123, - _mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); - } } } } @@ -312,16 +299,16 @@ void xnn_qs8_gemm_minmax_ukernel_4x4c2__sse2_ld128( vout = _mm_srli_si128(vout, 4); *((uint32_t*) c3) = (uint32_t) _mm_cvtsi128_si32(vout); - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - a3 = (const int8_t*) ((uintptr_t) a3 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); c3 = (int8_t*) 
((uintptr_t) c3 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + a3 = (const int8_t*) ((uintptr_t) a3 - kc); + nc -= 4; } else { if (nc & 2) { diff --git a/src/qs8-gemm/gen/4x4c2-minmax-sse2-ld64.c b/src/qs8-gemm/gen/4x4c2-minmax-sse2-ld64.c index bd28fe4f5..4e99dd2d7 100644 --- a/src/qs8-gemm/gen/4x4c2-minmax-sse2-ld64.c +++ b/src/qs8-gemm/gen/4x4c2-minmax-sse2-ld64.c @@ -12,6 +12,7 @@ #include <emmintrin.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_minmax_ukernel_4x4c2__sse2_ld64( @@ -35,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_4x4c2__sse2_ld64( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 2); const int8_t* a0 = a; int8_t* c0 = c; const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride); @@ -180,21 +182,6 @@ void xnn_qs8_gemm_minmax_ukernel_4x4c2__sse2_ld64( _mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(2, 2, 2, 2)), vxb2)); vacc3x0123 = _mm_add_epi32(vacc3x0123, _mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(2, 2, 2, 2)), vxb2)); - - if (k > 6 * sizeof(int8_t)) { - const __m128i vb3 = _mm_loadl_epi64((const __m128i*) w); - const __m128i vxb3 = _mm_unpacklo_epi8(vb3, _mm_cmpgt_epi8(_mm_setzero_si128(), vb3)); - w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - - vacc0x0123 = _mm_add_epi32(vacc0x0123, - _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); - vacc1x0123 = _mm_add_epi32(vacc1x0123, - _mm_madd_epi16(_mm_shuffle_epi32(vxa1, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); - vacc2x0123 = _mm_add_epi32(vacc2x0123, - _mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); - vacc3x0123 = _mm_add_epi32(vacc3x0123, - _mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); - } } } } @@ -312,16 +299,16 @@ void xnn_qs8_gemm_minmax_ukernel_4x4c2__sse2_ld64( vout = _mm_srli_si128(vout, 4); *((uint32_t*) c3) = (uint32_t) _mm_cvtsi128_si32(vout); - a0 = 
(const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - a3 = (const int8_t*) ((uintptr_t) a3 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); c3 = (int8_t*) ((uintptr_t) c3 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + a3 = (const int8_t*) ((uintptr_t) a3 - kc); + nc -= 4; } else { if (nc & 2) { diff --git a/src/qs8-gemm/gen/4x4c2-minmax-sse41-ld128.c b/src/qs8-gemm/gen/4x4c2-minmax-sse41-ld128.c index 8395f05df..49935f0fb 100644 --- a/src/qs8-gemm/gen/4x4c2-minmax-sse41-ld128.c +++ b/src/qs8-gemm/gen/4x4c2-minmax-sse41-ld128.c @@ -12,6 +12,7 @@ #include <smmintrin.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_minmax_ukernel_4x4c2__sse41_ld128( @@ -35,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_4x4c2__sse41_ld128( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 2); const int8_t* a0 = a; int8_t* c0 = c; const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride); @@ -180,21 +182,6 @@ void xnn_qs8_gemm_minmax_ukernel_4x4c2__sse41_ld128( _mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(2, 2, 2, 2)), vxb2)); vacc3x0123 = _mm_add_epi32(vacc3x0123, _mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(2, 2, 2, 2)), vxb2)); - - if (k > 6 * sizeof(int8_t)) { - const __m128i vb3 = _mm_loadl_epi64((const __m128i*) w); - const __m128i vxb3 = _mm_cvtepi8_epi16(vb3); - w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - - vacc0x0123 = _mm_add_epi32(vacc0x0123, - _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); - vacc1x0123 = _mm_add_epi32(vacc1x0123, - _mm_madd_epi16(_mm_shuffle_epi32(vxa1, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); - vacc2x0123 = _mm_add_epi32(vacc2x0123, - _mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(3, 3, 3, 3)), 
vxb3)); - vacc3x0123 = _mm_add_epi32(vacc3x0123, - _mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); - } } } } @@ -269,16 +256,16 @@ void xnn_qs8_gemm_minmax_ukernel_4x4c2__sse41_ld128( *((uint32_t*) c2) = (uint32_t) _mm_extract_epi32(vout, 2); *((uint32_t*) c3) = (uint32_t) _mm_extract_epi32(vout, 3); - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - a3 = (const int8_t*) ((uintptr_t) a3 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); c3 = (int8_t*) ((uintptr_t) c3 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + a3 = (const int8_t*) ((uintptr_t) a3 - kc); + nc -= 4; } else { if (nc & 2) { diff --git a/src/qs8-gemm/gen/4x4c2-minmax-sse41-ld64.c b/src/qs8-gemm/gen/4x4c2-minmax-sse41-ld64.c index c06b1f39b..2e985a254 100644 --- a/src/qs8-gemm/gen/4x4c2-minmax-sse41-ld64.c +++ b/src/qs8-gemm/gen/4x4c2-minmax-sse41-ld64.c @@ -12,6 +12,7 @@ #include <smmintrin.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_minmax_ukernel_4x4c2__sse41_ld64( @@ -35,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_4x4c2__sse41_ld64( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 2); const int8_t* a0 = a; int8_t* c0 = c; const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride); @@ -180,21 +182,6 @@ void xnn_qs8_gemm_minmax_ukernel_4x4c2__sse41_ld64( _mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(2, 2, 2, 2)), vxb2)); vacc3x0123 = _mm_add_epi32(vacc3x0123, _mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(2, 2, 2, 2)), vxb2)); - - if (k > 6 * sizeof(int8_t)) { - const __m128i vb3 = _mm_loadl_epi64((const __m128i*) w); - const __m128i vxb3 = _mm_cvtepi8_epi16(vb3); - w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - - vacc0x0123 
= _mm_add_epi32(vacc0x0123, - _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); - vacc1x0123 = _mm_add_epi32(vacc1x0123, - _mm_madd_epi16(_mm_shuffle_epi32(vxa1, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); - vacc2x0123 = _mm_add_epi32(vacc2x0123, - _mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); - vacc3x0123 = _mm_add_epi32(vacc3x0123, - _mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); - } } } } @@ -269,16 +256,16 @@ void xnn_qs8_gemm_minmax_ukernel_4x4c2__sse41_ld64( *((uint32_t*) c2) = (uint32_t) _mm_extract_epi32(vout, 2); *((uint32_t*) c3) = (uint32_t) _mm_extract_epi32(vout, 3); - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - a3 = (const int8_t*) ((uintptr_t) a3 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); c3 = (int8_t*) ((uintptr_t) c3 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + a3 = (const int8_t*) ((uintptr_t) a3 - kc); + nc -= 4; } else { if (nc & 2) { diff --git a/src/qs8-gemm/gen/4x4c2-minmax-ssse3-ld128.c b/src/qs8-gemm/gen/4x4c2-minmax-ssse3-ld128.c index 0d91bcb54..81ed87207 100644 --- a/src/qs8-gemm/gen/4x4c2-minmax-ssse3-ld128.c +++ b/src/qs8-gemm/gen/4x4c2-minmax-ssse3-ld128.c @@ -12,6 +12,7 @@ #include <tmmintrin.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_minmax_ukernel_4x4c2__ssse3_ld128( @@ -35,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_4x4c2__ssse3_ld128( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 2); const int8_t* a0 = a; int8_t* c0 = c; const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride); @@ -180,21 +182,6 @@ void xnn_qs8_gemm_minmax_ukernel_4x4c2__ssse3_ld128( _mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(2, 2, 2, 2)), 
vxb2)); vacc3x0123 = _mm_add_epi32(vacc3x0123, _mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(2, 2, 2, 2)), vxb2)); - - if (k > 6 * sizeof(int8_t)) { - const __m128i vb3 = _mm_loadl_epi64((const __m128i*) w); - const __m128i vxb3 = _mm_unpacklo_epi8(vb3, _mm_cmpgt_epi8(_mm_setzero_si128(), vb3)); - w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - - vacc0x0123 = _mm_add_epi32(vacc0x0123, - _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); - vacc1x0123 = _mm_add_epi32(vacc1x0123, - _mm_madd_epi16(_mm_shuffle_epi32(vxa1, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); - vacc2x0123 = _mm_add_epi32(vacc2x0123, - _mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); - vacc3x0123 = _mm_add_epi32(vacc3x0123, - _mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); - } } } } @@ -312,16 +299,16 @@ void xnn_qs8_gemm_minmax_ukernel_4x4c2__ssse3_ld128( vout = _mm_srli_si128(vout, 4); *((uint32_t*) c3) = (uint32_t) _mm_cvtsi128_si32(vout); - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - a3 = (const int8_t*) ((uintptr_t) a3 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); c3 = (int8_t*) ((uintptr_t) c3 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + a3 = (const int8_t*) ((uintptr_t) a3 - kc); + nc -= 4; } else { if (nc & 2) { diff --git a/src/qs8-gemm/gen/4x4c2-minmax-ssse3-ld64.c b/src/qs8-gemm/gen/4x4c2-minmax-ssse3-ld64.c index e3c11416f..053a94763 100644 --- a/src/qs8-gemm/gen/4x4c2-minmax-ssse3-ld64.c +++ b/src/qs8-gemm/gen/4x4c2-minmax-ssse3-ld64.c @@ -12,6 +12,7 @@ #include <tmmintrin.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_minmax_ukernel_4x4c2__ssse3_ld64( @@ -35,6 +36,7 @@ void 
xnn_qs8_gemm_minmax_ukernel_4x4c2__ssse3_ld64( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 2); const int8_t* a0 = a; int8_t* c0 = c; const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride); @@ -180,21 +182,6 @@ void xnn_qs8_gemm_minmax_ukernel_4x4c2__ssse3_ld64( _mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(2, 2, 2, 2)), vxb2)); vacc3x0123 = _mm_add_epi32(vacc3x0123, _mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(2, 2, 2, 2)), vxb2)); - - if (k > 6 * sizeof(int8_t)) { - const __m128i vb3 = _mm_loadl_epi64((const __m128i*) w); - const __m128i vxb3 = _mm_unpacklo_epi8(vb3, _mm_cmpgt_epi8(_mm_setzero_si128(), vb3)); - w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - - vacc0x0123 = _mm_add_epi32(vacc0x0123, - _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); - vacc1x0123 = _mm_add_epi32(vacc1x0123, - _mm_madd_epi16(_mm_shuffle_epi32(vxa1, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); - vacc2x0123 = _mm_add_epi32(vacc2x0123, - _mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); - vacc3x0123 = _mm_add_epi32(vacc3x0123, - _mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); - } } } } @@ -312,16 +299,16 @@ void xnn_qs8_gemm_minmax_ukernel_4x4c2__ssse3_ld64( vout = _mm_srli_si128(vout, 4); *((uint32_t*) c3) = (uint32_t) _mm_cvtsi128_si32(vout); - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - a3 = (const int8_t*) ((uintptr_t) a3 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); c3 = (int8_t*) ((uintptr_t) c3 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + a3 = (const int8_t*) ((uintptr_t) a3 - kc); + nc -= 4; } else { if (nc & 2) { diff --git a/src/qs8-gemm/gen/4x4c2-minmax-xop-ld128.c 
b/src/qs8-gemm/gen/4x4c2-minmax-xop-ld128.c index 24f9fde14..860d51b5a 100644 --- a/src/qs8-gemm/gen/4x4c2-minmax-xop-ld128.c +++ b/src/qs8-gemm/gen/4x4c2-minmax-xop-ld128.c @@ -17,6 +17,7 @@ #endif #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_minmax_ukernel_4x4c2__xop_ld128( @@ -40,6 +41,7 @@ void xnn_qs8_gemm_minmax_ukernel_4x4c2__xop_ld128( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 2); const int8_t* a0 = a; int8_t* c0 = c; const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride); @@ -185,21 +187,6 @@ void xnn_qs8_gemm_minmax_ukernel_4x4c2__xop_ld128( _mm_shuffle_epi32(vxa2, _MM_SHUFFLE(2, 2, 2, 2)), vxb2, vacc2x0123); vacc3x0123 = _mm_maddd_epi16( _mm_shuffle_epi32(vxa3, _MM_SHUFFLE(2, 2, 2, 2)), vxb2, vacc3x0123); - - if (k > 6 * sizeof(int8_t)) { - const __m128i vb3 = _mm_loadl_epi64((const __m128i*) w); - const __m128i vxb3 = _mm_cvtepi8_epi16(vb3); - w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - - vacc0x0123 = _mm_maddd_epi16( - _mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3, vacc0x0123); - vacc1x0123 = _mm_maddd_epi16( - _mm_shuffle_epi32(vxa1, _MM_SHUFFLE(3, 3, 3, 3)), vxb3, vacc1x0123); - vacc2x0123 = _mm_maddd_epi16( - _mm_shuffle_epi32(vxa2, _MM_SHUFFLE(3, 3, 3, 3)), vxb3, vacc2x0123); - vacc3x0123 = _mm_maddd_epi16( - _mm_shuffle_epi32(vxa3, _MM_SHUFFLE(3, 3, 3, 3)), vxb3, vacc3x0123); - } } } } @@ -274,16 +261,16 @@ void xnn_qs8_gemm_minmax_ukernel_4x4c2__xop_ld128( *((uint32_t*) c2) = (uint32_t) _mm_extract_epi32(vout, 2); *((uint32_t*) c3) = (uint32_t) _mm_extract_epi32(vout, 3); - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - a3 = (const int8_t*) ((uintptr_t) a3 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); c3 = (int8_t*) ((uintptr_t) c3 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - 
kc); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + a3 = (const int8_t*) ((uintptr_t) a3 - kc); + nc -= 4; } else { if (nc & 2) { diff --git a/src/qs8-gemm/gen/4x4c2-minmax-xop-ld64.c b/src/qs8-gemm/gen/4x4c2-minmax-xop-ld64.c index a06cd6c9e..5f894889e 100644 --- a/src/qs8-gemm/gen/4x4c2-minmax-xop-ld64.c +++ b/src/qs8-gemm/gen/4x4c2-minmax-xop-ld64.c @@ -17,6 +17,7 @@ #endif #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_minmax_ukernel_4x4c2__xop_ld64( @@ -40,6 +41,7 @@ void xnn_qs8_gemm_minmax_ukernel_4x4c2__xop_ld64( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 2); const int8_t* a0 = a; int8_t* c0 = c; const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride); @@ -185,21 +187,6 @@ void xnn_qs8_gemm_minmax_ukernel_4x4c2__xop_ld64( _mm_shuffle_epi32(vxa2, _MM_SHUFFLE(2, 2, 2, 2)), vxb2, vacc2x0123); vacc3x0123 = _mm_maddd_epi16( _mm_shuffle_epi32(vxa3, _MM_SHUFFLE(2, 2, 2, 2)), vxb2, vacc3x0123); - - if (k > 6 * sizeof(int8_t)) { - const __m128i vb3 = _mm_loadl_epi64((const __m128i*) w); - const __m128i vxb3 = _mm_cvtepi8_epi16(vb3); - w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - - vacc0x0123 = _mm_maddd_epi16( - _mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3, vacc0x0123); - vacc1x0123 = _mm_maddd_epi16( - _mm_shuffle_epi32(vxa1, _MM_SHUFFLE(3, 3, 3, 3)), vxb3, vacc1x0123); - vacc2x0123 = _mm_maddd_epi16( - _mm_shuffle_epi32(vxa2, _MM_SHUFFLE(3, 3, 3, 3)), vxb3, vacc2x0123); - vacc3x0123 = _mm_maddd_epi16( - _mm_shuffle_epi32(vxa3, _MM_SHUFFLE(3, 3, 3, 3)), vxb3, vacc3x0123); - } } } } @@ -274,16 +261,16 @@ void xnn_qs8_gemm_minmax_ukernel_4x4c2__xop_ld64( *((uint32_t*) c2) = (uint32_t) _mm_extract_epi32(vout, 2); *((uint32_t*) c3) = (uint32_t) _mm_extract_epi32(vout, 3); - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - a3 = (const int8_t*) ((uintptr_t) a3 - 
kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); c3 = (int8_t*) ((uintptr_t) c3 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + a3 = (const int8_t*) ((uintptr_t) a3 - kc); + nc -= 4; } else { if (nc & 2) { diff --git a/src/qs8-gemm/gen/4x4c2-xw-minmax-sse2.c b/src/qs8-gemm/gen/4x4c2-xw-minmax-sse2.c index 0b90db831..030105fa5 100644 --- a/src/qs8-gemm/gen/4x4c2-xw-minmax-sse2.c +++ b/src/qs8-gemm/gen/4x4c2-xw-minmax-sse2.c @@ -12,6 +12,7 @@ #include <emmintrin.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_xw_minmax_ukernel_4x4c2__sse2( @@ -35,6 +36,7 @@ void xnn_qs8_gemm_xw_minmax_ukernel_4x4c2__sse2( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 2); const int8_t* a0 = a; int8_t* c0 = c; const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride); @@ -173,20 +175,6 @@ void xnn_qs8_gemm_xw_minmax_ukernel_4x4c2__sse2( _mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(2, 2, 2, 2)), vxb2)); vacc3x0123 = _mm_add_epi32(vacc3x0123, _mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(2, 2, 2, 2)), vxb2)); - - if (k > 6 * sizeof(int8_t)) { - const __m128i vxb3 = _mm_load_si128((const __m128i*) w); - w = (const void*) ((uintptr_t) w + 8 * sizeof(int16_t)); - - vacc0x0123 = _mm_add_epi32(vacc0x0123, - _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); - vacc1x0123 = _mm_add_epi32(vacc1x0123, - _mm_madd_epi16(_mm_shuffle_epi32(vxa1, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); - vacc2x0123 = _mm_add_epi32(vacc2x0123, - _mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); - vacc3x0123 = _mm_add_epi32(vacc3x0123, - _mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); - } } } } @@ -304,16 +292,16 @@ void xnn_qs8_gemm_xw_minmax_ukernel_4x4c2__sse2( vout = _mm_srli_si128(vout, 4); 
*((uint32_t*) c3) = (uint32_t) _mm_cvtsi128_si32(vout); - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - a3 = (const int8_t*) ((uintptr_t) a3 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); c3 = (int8_t*) ((uintptr_t) c3 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + a3 = (const int8_t*) ((uintptr_t) a3 - kc); + nc -= 4; } else { if (nc & 2) { diff --git a/src/qs8-gemm/gen/4x4c2-xw-minmax-sse41.c b/src/qs8-gemm/gen/4x4c2-xw-minmax-sse41.c index 100c93255..d10630bfc 100644 --- a/src/qs8-gemm/gen/4x4c2-xw-minmax-sse41.c +++ b/src/qs8-gemm/gen/4x4c2-xw-minmax-sse41.c @@ -12,6 +12,7 @@ #include <smmintrin.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_xw_minmax_ukernel_4x4c2__sse41( @@ -35,6 +36,7 @@ void xnn_qs8_gemm_xw_minmax_ukernel_4x4c2__sse41( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 2); const int8_t* a0 = a; int8_t* c0 = c; const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride); @@ -173,20 +175,6 @@ void xnn_qs8_gemm_xw_minmax_ukernel_4x4c2__sse41( _mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(2, 2, 2, 2)), vxb2)); vacc3x0123 = _mm_add_epi32(vacc3x0123, _mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(2, 2, 2, 2)), vxb2)); - - if (k > 6 * sizeof(int8_t)) { - const __m128i vxb3 = _mm_load_si128((const __m128i*) w); - w = (const void*) ((uintptr_t) w + 8 * sizeof(int16_t)); - - vacc0x0123 = _mm_add_epi32(vacc0x0123, - _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); - vacc1x0123 = _mm_add_epi32(vacc1x0123, - _mm_madd_epi16(_mm_shuffle_epi32(vxa1, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); - vacc2x0123 = _mm_add_epi32(vacc2x0123, - _mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(3, 3, 3, 3)), 
vxb3)); - vacc3x0123 = _mm_add_epi32(vacc3x0123, - _mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); - } } } } @@ -261,16 +249,16 @@ void xnn_qs8_gemm_xw_minmax_ukernel_4x4c2__sse41( *((uint32_t*) c2) = (uint32_t) _mm_extract_epi32(vout, 2); *((uint32_t*) c3) = (uint32_t) _mm_extract_epi32(vout, 3); - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - a3 = (const int8_t*) ((uintptr_t) a3 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); c3 = (int8_t*) ((uintptr_t) c3 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + a3 = (const int8_t*) ((uintptr_t) a3 - kc); + nc -= 4; } else { if (nc & 2) { diff --git a/src/qs8-gemm/gen/4x4c2-xw-minmax-ssse3.c b/src/qs8-gemm/gen/4x4c2-xw-minmax-ssse3.c index 4ae78f7d3..637a1ca7d 100644 --- a/src/qs8-gemm/gen/4x4c2-xw-minmax-ssse3.c +++ b/src/qs8-gemm/gen/4x4c2-xw-minmax-ssse3.c @@ -12,6 +12,7 @@ #include <tmmintrin.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_xw_minmax_ukernel_4x4c2__ssse3( @@ -35,6 +36,7 @@ void xnn_qs8_gemm_xw_minmax_ukernel_4x4c2__ssse3( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 2); const int8_t* a0 = a; int8_t* c0 = c; const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride); @@ -173,20 +175,6 @@ void xnn_qs8_gemm_xw_minmax_ukernel_4x4c2__ssse3( _mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(2, 2, 2, 2)), vxb2)); vacc3x0123 = _mm_add_epi32(vacc3x0123, _mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(2, 2, 2, 2)), vxb2)); - - if (k > 6 * sizeof(int8_t)) { - const __m128i vxb3 = _mm_load_si128((const __m128i*) w); - w = (const void*) ((uintptr_t) w + 8 * sizeof(int16_t)); - - vacc0x0123 = _mm_add_epi32(vacc0x0123, - 
_mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); - vacc1x0123 = _mm_add_epi32(vacc1x0123, - _mm_madd_epi16(_mm_shuffle_epi32(vxa1, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); - vacc2x0123 = _mm_add_epi32(vacc2x0123, - _mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); - vacc3x0123 = _mm_add_epi32(vacc3x0123, - _mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); - } } } } @@ -304,16 +292,16 @@ void xnn_qs8_gemm_xw_minmax_ukernel_4x4c2__ssse3( vout = _mm_srli_si128(vout, 4); *((uint32_t*) c3) = (uint32_t) _mm_cvtsi128_si32(vout); - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - a3 = (const int8_t*) ((uintptr_t) a3 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); c3 = (int8_t*) ((uintptr_t) c3 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + a3 = (const int8_t*) ((uintptr_t) a3 - kc); + nc -= 4; } else { if (nc & 2) { diff --git a/src/qs8-gemm/gen/4x4c2-xw-minmax-xop.c b/src/qs8-gemm/gen/4x4c2-xw-minmax-xop.c index b6275df78..471b856c1 100644 --- a/src/qs8-gemm/gen/4x4c2-xw-minmax-xop.c +++ b/src/qs8-gemm/gen/4x4c2-xw-minmax-xop.c @@ -17,6 +17,7 @@ #endif #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_xw_minmax_ukernel_4x4c2__xop( @@ -40,6 +41,7 @@ void xnn_qs8_gemm_xw_minmax_ukernel_4x4c2__xop( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 2); const int8_t* a0 = a; int8_t* c0 = c; const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride); @@ -178,20 +180,6 @@ void xnn_qs8_gemm_xw_minmax_ukernel_4x4c2__xop( _mm_shuffle_epi32(vxa2, _MM_SHUFFLE(2, 2, 2, 2)), vxb2, vacc2x0123); vacc3x0123 = _mm_maddd_epi16( _mm_shuffle_epi32(vxa3, _MM_SHUFFLE(2, 2, 2, 2)), vxb2, vacc3x0123); - - if (k 
> 6 * sizeof(int8_t)) { - const __m128i vxb3 = _mm_load_si128((const __m128i*) w); - w = (const void*) ((uintptr_t) w + 8 * sizeof(int16_t)); - - vacc0x0123 = _mm_maddd_epi16( - _mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3, vacc0x0123); - vacc1x0123 = _mm_maddd_epi16( - _mm_shuffle_epi32(vxa1, _MM_SHUFFLE(3, 3, 3, 3)), vxb3, vacc1x0123); - vacc2x0123 = _mm_maddd_epi16( - _mm_shuffle_epi32(vxa2, _MM_SHUFFLE(3, 3, 3, 3)), vxb3, vacc2x0123); - vacc3x0123 = _mm_maddd_epi16( - _mm_shuffle_epi32(vxa3, _MM_SHUFFLE(3, 3, 3, 3)), vxb3, vacc3x0123); - } } } } @@ -266,16 +254,16 @@ void xnn_qs8_gemm_xw_minmax_ukernel_4x4c2__xop( *((uint32_t*) c2) = (uint32_t) _mm_extract_epi32(vout, 2); *((uint32_t*) c3) = (uint32_t) _mm_extract_epi32(vout, 3); - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - a3 = (const int8_t*) ((uintptr_t) a3 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); c3 = (int8_t*) ((uintptr_t) c3 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + a3 = (const int8_t*) ((uintptr_t) a3 - kc); + nc -= 4; } else { if (nc & 2) { diff --git a/src/qs8-gemm/gen/4x8c16-minmax-neon-mlal-padal.c b/src/qs8-gemm/gen/4x8c16-minmax-neon-mlal-padal.c index 76309f4eb..45b6a5e5b 100644 --- a/src/qs8-gemm/gen/4x8c16-minmax-neon-mlal-padal.c +++ b/src/qs8-gemm/gen/4x8c16-minmax-neon-mlal-padal.c @@ -11,8 +11,8 @@ #include <arm_neon.h> -#include <xnnpack/common.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_minmax_ukernel_4x8c16__neon_mlal_padal( @@ -36,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_4x8c16__neon_mlal_padal( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 16); const int8_t* a0 = a; int8_t* c0 = c; const int8_t* a1 = (const int8_t*) 
((uintptr_t) a0 + a_stride); @@ -91,7 +92,7 @@ void xnn_qs8_gemm_minmax_ukernel_4x8c16__neon_mlal_padal( int32x4_t vacc3x6 = vacc0x6; int32x4_t vacc3x7 = vacc0x7; - // KC loop of 16 with up to 15 remainder + // KC loop of 16 size_t k = 0; while (k < kc) { const int8x16_t va0 = vld1q_s8(a0); a0 += 16; @@ -359,10 +360,10 @@ void xnn_qs8_gemm_minmax_ukernel_4x8c16__neon_mlal_padal( c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); c3 = (int8_t*) ((uintptr_t) c3 + cn_stride); - a0 = (const int8_t*) ((uintptr_t) a0 - k); - a1 = (const int8_t*) ((uintptr_t) a1 - k); - a2 = (const int8_t*) ((uintptr_t) a2 - k); - a3 = (const int8_t*) ((uintptr_t) a3 - k); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + a3 = (const int8_t*) ((uintptr_t) a3 - kc); nc -= 8; } else { diff --git a/src/qs8-gemm/gen/4x8c2-minmax-neon-mlal-padal-dup.c b/src/qs8-gemm/gen/4x8c2-minmax-neon-mlal-padal-dup.c index b7403d47d..ddb72c8de 100644 --- a/src/qs8-gemm/gen/4x8c2-minmax-neon-mlal-padal-dup.c +++ b/src/qs8-gemm/gen/4x8c2-minmax-neon-mlal-padal-dup.c @@ -11,8 +11,8 @@ #include <arm_neon.h> -#include <xnnpack/common.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_minmax_ukernel_4x8c2__neon_mlal_padal_dup( @@ -36,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_4x8c2__neon_mlal_padal_dup( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 2); const int8_t* a0 = a; int8_t* c0 = c; const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride); @@ -346,28 +347,6 @@ void xnn_qs8_gemm_minmax_ukernel_4x8c2__neon_mlal_padal_dup( vacc3x0123 = vpadalq_s16(vacc3x0123, vprod3x0123c2); const int16x8_t vprod3x4567c2 = vmull_s8(vb4567c2, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va3), 2))); vacc3x4567 = vpadalq_s16(vacc3x4567, vprod3x4567c2); - - if (k > 6 * sizeof(int8_t)) { - const int8x8_t vb0123c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const 
int8x8_t vb4567c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - - const int16x8_t vprod0x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0x0123 = vpadalq_s16(vacc0x0123, vprod0x0123c3); - const int16x8_t vprod0x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0x4567 = vpadalq_s16(vacc0x4567, vprod0x4567c3); - const int16x8_t vprod1x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3))); - vacc1x0123 = vpadalq_s16(vacc1x0123, vprod1x0123c3); - const int16x8_t vprod1x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3))); - vacc1x4567 = vpadalq_s16(vacc1x4567, vprod1x4567c3); - const int16x8_t vprod2x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3))); - vacc2x0123 = vpadalq_s16(vacc2x0123, vprod2x0123c3); - const int16x8_t vprod2x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3))); - vacc2x4567 = vpadalq_s16(vacc2x4567, vprod2x4567c3); - const int16x8_t vprod3x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va3), 3))); - vacc3x0123 = vpadalq_s16(vacc3x0123, vprod3x0123c3); - const int16x8_t vprod3x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va3), 3))); - vacc3x4567 = vpadalq_s16(vacc3x4567, vprod3x4567c3); - } } } } diff --git a/src/qs8-gemm/gen/4x8c2-minmax-neon-mull-padal-dup.c b/src/qs8-gemm/gen/4x8c2-minmax-neon-mull-padal-dup.c index 95d72578a..a7f10cd16 100644 --- a/src/qs8-gemm/gen/4x8c2-minmax-neon-mull-padal-dup.c +++ b/src/qs8-gemm/gen/4x8c2-minmax-neon-mull-padal-dup.c @@ -11,8 +11,8 @@ #include <arm_neon.h> -#include <xnnpack/common.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_minmax_ukernel_4x8c2__neon_mull_padal_dup( @@ -36,6 +36,7 @@ void 
xnn_qs8_gemm_minmax_ukernel_4x8c2__neon_mull_padal_dup( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 2); const int8_t* a0 = a; int8_t* c0 = c; const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride); @@ -220,28 +221,6 @@ void xnn_qs8_gemm_minmax_ukernel_4x8c2__neon_mull_padal_dup( vacc3x0123 = vpadalq_s16(vacc3x0123, vprod3x0123c2); const int16x8_t vprod3x4567c2 = vmull_s8(vb4567c2, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va3), 2))); vacc3x4567 = vpadalq_s16(vacc3x4567, vprod3x4567c2); - - if (k > 6 * sizeof(int8_t)) { - const int8x8_t vb0123c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb4567c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - - const int16x8_t vprod0x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0x0123 = vpadalq_s16(vacc0x0123, vprod0x0123c3); - const int16x8_t vprod0x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0x4567 = vpadalq_s16(vacc0x4567, vprod0x4567c3); - const int16x8_t vprod1x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3))); - vacc1x0123 = vpadalq_s16(vacc1x0123, vprod1x0123c3); - const int16x8_t vprod1x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3))); - vacc1x4567 = vpadalq_s16(vacc1x4567, vprod1x4567c3); - const int16x8_t vprod2x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3))); - vacc2x0123 = vpadalq_s16(vacc2x0123, vprod2x0123c3); - const int16x8_t vprod2x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3))); - vacc2x4567 = vpadalq_s16(vacc2x4567, vprod2x4567c3); - const int16x8_t vprod3x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va3), 3))); - vacc3x0123 = vpadalq_s16(vacc3x0123, vprod3x0123c3); - const int16x8_t 
vprod3x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va3), 3))); - vacc3x4567 = vpadalq_s16(vacc3x4567, vprod3x4567c3); - } } } } diff --git a/src/qs8-gemm/gen/4x8c4-minmax-neondot.c b/src/qs8-gemm/gen/4x8c4-minmax-neondot.c index ce6a0fb0f..d32b05ecd 100644 --- a/src/qs8-gemm/gen/4x8c4-minmax-neondot.c +++ b/src/qs8-gemm/gen/4x8c4-minmax-neondot.c @@ -7,12 +7,12 @@ // This source code is licensed under the BSD-style license found in the // LICENSE file in the root directory of this source tree. - #include <assert.h> #include <arm_neon.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_minmax_ukernel_4x8c4__neondot( @@ -30,7 +30,12 @@ void xnn_qs8_gemm_minmax_ukernel_4x8c4__neondot( assert(mr <= 4); assert(nc != 0); assert(kc != 0); + assert(kc % sizeof(int8_t) == 0); + assert(a != NULL); + assert(w != NULL); + assert(c != NULL); + kc = round_up_po2(kc, 4); const int8_t* a0 = a; int8_t* c0 = c; const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride); @@ -101,7 +106,7 @@ void xnn_qs8_gemm_minmax_ukernel_4x8c4__neondot( k -= 8 * sizeof(int8_t); } - // Handle up to 7 final positions of `k` + // Handle up to 6 final positions of `k` if XNN_UNLIKELY(k != 0) { // Load a 4x4 block of activations. const int8x8_t va0x01234567 = vld1_s8(a0); a0 += k; @@ -139,15 +144,8 @@ void xnn_qs8_gemm_minmax_ukernel_4x8c4__neondot( vacc3x4567 = vdotq_lane_s32(vacc3x4567, vb4567x4567, va3x01234567, 1); } } - // End of accumulation loop. The variable `kc` contains the amount by which - // we advanced the `va` pointers, so we rewind by this amount now. 
- a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - a3 = (const int8_t*) ((uintptr_t) a3 - kc); // Post-accumulation work - const int32x4_t vright_shift = vld1q_dup_s32(¶ms->neon.right_shift); const int32x4_t vzero_shift_mask = vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0))); @@ -218,6 +216,11 @@ void xnn_qs8_gemm_minmax_ukernel_4x8c4__neondot( c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); c3 = (int8_t*) ((uintptr_t) c3 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + a3 = (const int8_t*) ((uintptr_t) a3 - kc); + nc -= 8; } else { // Final case where not all of the 8 columns fit in the destination. diff --git a/src/qs8-gemm/gen/4x8c8-minmax-neon-mull-padal.c b/src/qs8-gemm/gen/4x8c8-minmax-neon-mull-padal.c index 08f5386b7..288eac7fa 100644 --- a/src/qs8-gemm/gen/4x8c8-minmax-neon-mull-padal.c +++ b/src/qs8-gemm/gen/4x8c8-minmax-neon-mull-padal.c @@ -11,8 +11,8 @@ #include <arm_neon.h> -#include <xnnpack/common.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mull_padal( @@ -36,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mull_padal( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); const int8_t* a0 = a; int8_t* c0 = c; const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride); @@ -456,10 +457,10 @@ void xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mull_padal( c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); c3 = (int8_t*) ((uintptr_t) c3 + cn_stride); - a0 = (const int8_t*) ((uintptr_t) a0 - (kc - k)); - a1 = (const int8_t*) ((uintptr_t) a1 - (kc - k)); - a2 = (const int8_t*) ((uintptr_t) a2 - (kc - k)); - a3 = (const int8_t*) ((uintptr_t) a3 - (kc - k)); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + a2 = (const int8_t*) ((uintptr_t) a2 - 
kc); + a3 = (const int8_t*) ((uintptr_t) a3 - kc); nc -= 8; } else { diff --git a/src/qs8-gemm/gen/6x16c4-minmax-neondot.c b/src/qs8-gemm/gen/6x16c4-minmax-neondot.c index dc3c135fb..3a9277eaf 100644 --- a/src/qs8-gemm/gen/6x16c4-minmax-neondot.c +++ b/src/qs8-gemm/gen/6x16c4-minmax-neondot.c @@ -7,12 +7,12 @@ // This source code is licensed under the BSD-style license found in the // LICENSE file in the root directory of this source tree. - #include <assert.h> #include <arm_neon.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_minmax_ukernel_6x16c4__neondot( @@ -30,7 +30,12 @@ void xnn_qs8_gemm_minmax_ukernel_6x16c4__neondot( assert(mr <= 6); assert(nc != 0); assert(kc != 0); + assert(kc % sizeof(int8_t) == 0); + assert(a != NULL); + assert(w != NULL); + assert(c != NULL); + kc = round_up_po2(kc, 4); const int8_t* a0 = a; int8_t* c0 = c; const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride); @@ -167,7 +172,7 @@ void xnn_qs8_gemm_minmax_ukernel_6x16c4__neondot( k -= 8 * sizeof(int8_t); } - // Handle up to 7 final positions of `k` + // Handle up to 6 final positions of `k` if XNN_UNLIKELY(k != 0) { // Load a 6x4 block of activations. const int8x8_t va0x01234567 = vld1_s8(a0); a0 += k; @@ -243,17 +248,8 @@ void xnn_qs8_gemm_minmax_ukernel_6x16c4__neondot( vacc5xCDEF = vdotq_lane_s32(vacc5xCDEF, vb4567xCDEF, va5x01234567, 1); } } - // End of accumulation loop. The variable `kc` contains the amount by which - // we advanced the `va` pointers, so we rewind by this amount now. 
- a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - a3 = (const int8_t*) ((uintptr_t) a3 - kc); - a4 = (const int8_t*) ((uintptr_t) a4 - kc); - a5 = (const int8_t*) ((uintptr_t) a5 - kc); // Post-accumulation work - const int32x4_t vright_shift = vld1q_dup_s32(¶ms->neon.right_shift); const int32x4_t vzero_shift_mask = vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0))); @@ -408,6 +404,13 @@ void xnn_qs8_gemm_minmax_ukernel_6x16c4__neondot( c4 = (int8_t*) ((uintptr_t) c4 + cn_stride); c5 = (int8_t*) ((uintptr_t) c5 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + a3 = (const int8_t*) ((uintptr_t) a3 - kc); + a4 = (const int8_t*) ((uintptr_t) a4 - kc); + a5 = (const int8_t*) ((uintptr_t) a5 - kc); + nc -= 16; } else { // Final case where not all of the 16 columns fit in the destination. diff --git a/src/qs8-gemm/gen/6x8c4-minmax-neondot.c b/src/qs8-gemm/gen/6x8c4-minmax-neondot.c index 61ee8a977..b1fc09f9f 100644 --- a/src/qs8-gemm/gen/6x8c4-minmax-neondot.c +++ b/src/qs8-gemm/gen/6x8c4-minmax-neondot.c @@ -7,12 +7,12 @@ // This source code is licensed under the BSD-style license found in the // LICENSE file in the root directory of this source tree. 
- #include <assert.h> #include <arm_neon.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_minmax_ukernel_6x8c4__neondot( @@ -30,7 +30,12 @@ void xnn_qs8_gemm_minmax_ukernel_6x8c4__neondot( assert(mr <= 6); assert(nc != 0); assert(kc != 0); + assert(kc % sizeof(int8_t) == 0); + assert(a != NULL); + assert(w != NULL); + assert(c != NULL); + kc = round_up_po2(kc, 4); const int8_t* a0 = a; int8_t* c0 = c; const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride); @@ -127,7 +132,7 @@ void xnn_qs8_gemm_minmax_ukernel_6x8c4__neondot( k -= 8 * sizeof(int8_t); } - // Handle up to 7 final positions of `k` + // Handle up to 6 final positions of `k` if XNN_UNLIKELY(k != 0) { // Load a 6x4 block of activations. const int8x8_t va0x01234567 = vld1_s8(a0); a0 += k; @@ -175,17 +180,8 @@ void xnn_qs8_gemm_minmax_ukernel_6x8c4__neondot( vacc5x4567 = vdotq_lane_s32(vacc5x4567, vb4567x4567, va5x01234567, 1); } } - // End of accumulation loop. The variable `kc` contains the amount by which - // we advanced the `va` pointers, so we rewind by this amount now. 
- a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - a3 = (const int8_t*) ((uintptr_t) a3 - kc); - a4 = (const int8_t*) ((uintptr_t) a4 - kc); - a5 = (const int8_t*) ((uintptr_t) a5 - kc); // Post-accumulation work - const int32x4_t vright_shift = vld1q_dup_s32(¶ms->neon.right_shift); const int32x4_t vzero_shift_mask = vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0))); @@ -280,6 +276,13 @@ void xnn_qs8_gemm_minmax_ukernel_6x8c4__neondot( c4 = (int8_t*) ((uintptr_t) c4 + cn_stride); c5 = (int8_t*) ((uintptr_t) c5 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + a3 = (const int8_t*) ((uintptr_t) a3 - kc); + a4 = (const int8_t*) ((uintptr_t) a4 - kc); + a5 = (const int8_t*) ((uintptr_t) a5 - kc); + nc -= 8; } else { // Final case where not all of the 8 columns fit in the destination. diff --git a/src/qs8-gemm/gen/8x16c4-minmax-neondot.c b/src/qs8-gemm/gen/8x16c4-minmax-neondot.c index 666ef0def..93dab4cf8 100644 --- a/src/qs8-gemm/gen/8x16c4-minmax-neondot.c +++ b/src/qs8-gemm/gen/8x16c4-minmax-neondot.c @@ -7,12 +7,12 @@ // This source code is licensed under the BSD-style license found in the // LICENSE file in the root directory of this source tree. 
- #include <assert.h> #include <arm_neon.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_minmax_ukernel_8x16c4__neondot( @@ -30,7 +30,12 @@ void xnn_qs8_gemm_minmax_ukernel_8x16c4__neondot( assert(mr <= 8); assert(nc != 0); assert(kc != 0); + assert(kc % sizeof(int8_t) == 0); + assert(a != NULL); + assert(w != NULL); + assert(c != NULL); + kc = round_up_po2(kc, 4); const int8_t* a0 = a; int8_t* c0 = c; const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride); @@ -205,7 +210,7 @@ void xnn_qs8_gemm_minmax_ukernel_8x16c4__neondot( k -= 8 * sizeof(int8_t); } - // Handle up to 7 final positions of `k` + // Handle up to 6 final positions of `k` if XNN_UNLIKELY(k != 0) { // Load a 8x4 block of activations. const int8x8_t va0x01234567 = vld1_s8(a0); a0 += k; @@ -299,19 +304,8 @@ void xnn_qs8_gemm_minmax_ukernel_8x16c4__neondot( vacc7xCDEF = vdotq_lane_s32(vacc7xCDEF, vb4567xCDEF, va7x01234567, 1); } } - // End of accumulation loop. The variable `kc` contains the amount by which - // we advanced the `va` pointers, so we rewind by this amount now. 
- a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - a3 = (const int8_t*) ((uintptr_t) a3 - kc); - a4 = (const int8_t*) ((uintptr_t) a4 - kc); - a5 = (const int8_t*) ((uintptr_t) a5 - kc); - a6 = (const int8_t*) ((uintptr_t) a6 - kc); - a7 = (const int8_t*) ((uintptr_t) a7 - kc); // Post-accumulation work - const int32x4_t vright_shift = vld1q_dup_s32(¶ms->neon.right_shift); const int32x4_t vzero_shift_mask = vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0))); @@ -510,6 +504,15 @@ void xnn_qs8_gemm_minmax_ukernel_8x16c4__neondot( c6 = (int8_t*) ((uintptr_t) c6 + cn_stride); c7 = (int8_t*) ((uintptr_t) c7 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + a3 = (const int8_t*) ((uintptr_t) a3 - kc); + a4 = (const int8_t*) ((uintptr_t) a4 - kc); + a5 = (const int8_t*) ((uintptr_t) a5 - kc); + a6 = (const int8_t*) ((uintptr_t) a6 - kc); + a7 = (const int8_t*) ((uintptr_t) a7 - kc); + nc -= 16; } else { // Final case where not all of the 16 columns fit in the destination. diff --git a/src/qs8-gemm/gen/8x8c4-minmax-neondot.c b/src/qs8-gemm/gen/8x8c4-minmax-neondot.c index 159363bc2..6dbeb3afc 100644 --- a/src/qs8-gemm/gen/8x8c4-minmax-neondot.c +++ b/src/qs8-gemm/gen/8x8c4-minmax-neondot.c @@ -7,12 +7,12 @@ // This source code is licensed under the BSD-style license found in the // LICENSE file in the root directory of this source tree. 
- #include <assert.h> #include <arm_neon.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_gemm_minmax_ukernel_8x8c4__neondot( @@ -30,7 +30,12 @@ void xnn_qs8_gemm_minmax_ukernel_8x8c4__neondot( assert(mr <= 8); assert(nc != 0); assert(kc != 0); + assert(kc % sizeof(int8_t) == 0); + assert(a != NULL); + assert(w != NULL); + assert(c != NULL); + kc = round_up_po2(kc, 4); const int8_t* a0 = a; int8_t* c0 = c; const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride); @@ -153,7 +158,7 @@ void xnn_qs8_gemm_minmax_ukernel_8x8c4__neondot( k -= 8 * sizeof(int8_t); } - // Handle up to 7 final positions of `k` + // Handle up to 6 final positions of `k` if XNN_UNLIKELY(k != 0) { // Load a 8x4 block of activations. const int8x8_t va0x01234567 = vld1_s8(a0); a0 += k; @@ -211,19 +216,8 @@ void xnn_qs8_gemm_minmax_ukernel_8x8c4__neondot( vacc7x4567 = vdotq_lane_s32(vacc7x4567, vb4567x4567, va7x01234567, 1); } } - // End of accumulation loop. The variable `kc` contains the amount by which - // we advanced the `va` pointers, so we rewind by this amount now. 
- a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - a3 = (const int8_t*) ((uintptr_t) a3 - kc); - a4 = (const int8_t*) ((uintptr_t) a4 - kc); - a5 = (const int8_t*) ((uintptr_t) a5 - kc); - a6 = (const int8_t*) ((uintptr_t) a6 - kc); - a7 = (const int8_t*) ((uintptr_t) a7 - kc); // Post-accumulation work - const int32x4_t vright_shift = vld1q_dup_s32(¶ms->neon.right_shift); const int32x4_t vzero_shift_mask = vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0))); @@ -342,6 +336,15 @@ void xnn_qs8_gemm_minmax_ukernel_8x8c4__neondot( c6 = (int8_t*) ((uintptr_t) c6 + cn_stride); c7 = (int8_t*) ((uintptr_t) c7 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + a3 = (const int8_t*) ((uintptr_t) a3 - kc); + a4 = (const int8_t*) ((uintptr_t) a4 - kc); + a5 = (const int8_t*) ((uintptr_t) a5 - kc); + a6 = (const int8_t*) ((uintptr_t) a6 - kc); + a7 = (const int8_t*) ((uintptr_t) a7 - kc); + nc -= 8; } else { // Final case where not all of the 8 columns fit in the destination. 
diff --git a/src/qs8-igemm/MRx16c8-avx512skx.c.in b/src/qs8-igemm/MRx16c8-avx512skx.c.in index 0244207b5..af171479c 100644 --- a/src/qs8-igemm/MRx16c8-avx512skx.c.in +++ b/src/qs8-igemm/MRx16c8-avx512skx.c.in @@ -12,6 +12,7 @@ $assert MR <= 4 #include <xnnpack/igemm.h> #include <xnnpack/intrinsics-polyfill.h> +#include <xnnpack/math.h> $GEMM_SUFFIX = "_xw" if VARIANT == "EXTENDED" else "" @@ -38,6 +39,7 @@ void xnn_qs8_igemm${GEMM_SUFFIX}_minmax_ukernel_${MR}x16c8__avx512skx( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); int8_t* c0 = c; $for M in range(1, MR): int8_t* c${M} = (int8_t*) ((uintptr_t) c${M-1} + cm_stride); diff --git a/src/qs8-igemm/MRx4c2-sse.c.in b/src/qs8-igemm/MRx4c2-sse.c.in index b917f66c4..d111fa186 100644 --- a/src/qs8-igemm/MRx4c2-sse.c.in +++ b/src/qs8-igemm/MRx4c2-sse.c.in @@ -18,6 +18,7 @@ $else: #include <${SSE_HEADER}> #include <xnnpack/igemm.h> +#include <xnnpack/math.h> $ISA = {2: "sse2", 3: "ssse3", 4: "sse41", 5: "xop"}[SSE] @@ -46,6 +47,7 @@ void xnn_qs8_igemm_minmax_ukernel_${MR}x4c2__${ISA}_${"ld128" if LD128 else "ld6 assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 2); int8_t* c0 = c; $for M in range(1, MR): int8_t* c${M} = (int8_t*) ((uintptr_t) c${M-1} + cm_stride); @@ -180,20 +182,6 @@ void xnn_qs8_igemm_minmax_ukernel_${MR}x4c2__${ISA}_${"ld128" if LD128 else "ld6 $else: vacc${M}x0123 = _mm_add_epi32(vacc${M}x0123, _mm_madd_epi16(_mm_shuffle_epi32(vxa${M}, _MM_SHUFFLE(2, 2, 2, 2)), vxb2)); - - if (k > 6 * sizeof(int8_t)) { - const __m128i vb3 = _mm_loadl_epi64((const __m128i*) w); - w = (const void*) ((uintptr_t) w + 8); - const __m128i vxb3 = _mm_unpacklo_epi8(vb3, _mm_cmpgt_epi8(_mm_setzero_si128(), vb3)); - - $for M in range(MR): - $if SSE == 5: - vacc${M}x0123 = _mm_maddd_epi16( - _mm_shuffle_epi32(vxa${M}, _MM_SHUFFLE(3, 3, 3, 3)), vxb3, vacc${M}x0123); - $else: - vacc${M}x0123 = _mm_add_epi32(vacc${M}x0123, - _mm_madd_epi16(_mm_shuffle_epi32(vxa${M}, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); - 
} } } } diff --git a/src/qs8-igemm/MRx4c8-sse.c.in b/src/qs8-igemm/MRx4c8-sse.c.in index 6a9b5df79..030d1446c 100644 --- a/src/qs8-igemm/MRx4c8-sse.c.in +++ b/src/qs8-igemm/MRx4c8-sse.c.in @@ -18,6 +18,7 @@ $else: #include <${SSE_HEADER}> #include <xnnpack/igemm.h> +#include <xnnpack/math.h> $ISA = {2: "sse2", 3: "ssse3", 4: "sse41", 5: "xop"}[SSE] @@ -46,6 +47,7 @@ void xnn_qs8_igemm_minmax_ukernel_${MR}x4c8__${ISA}_${"ld128" if LD128 else "ld6 assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); int8_t* c0 = c; $for M in range(1, MR): int8_t* c${M} = (int8_t*) ((uintptr_t) c${M-1} + cm_stride); diff --git a/src/qs8-igemm/MRx4c8-wasmsimd.c.in b/src/qs8-igemm/MRx4c8-wasmsimd.c.in index 51dc5cb09..77282fc29 100644 --- a/src/qs8-igemm/MRx4c8-wasmsimd.c.in +++ b/src/qs8-igemm/MRx4c8-wasmsimd.c.in @@ -10,6 +10,7 @@ $assert MR <= 4 #include <wasm_simd128.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> $LOAD_SUFFIX = {"LD128": "_ld128", "LD64": "_ld64", "EXTENDED": ""}[VARIANT] @@ -39,6 +40,7 @@ void xnn_qs8_igemm${GEMM_SUFFIX}_minmax_ukernel_${MR}x4c8__wasmsimd${LOAD_SUFFIX assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); int8_t* c0 = c; $for M in range(1, MR): int8_t* c${M} = (int8_t*) ((uintptr_t) c${M-1} + cm_stride); diff --git a/src/qs8-igemm/MRx8c8-avx2.c.in b/src/qs8-igemm/MRx8c8-avx2.c.in index f30a5e8a9..ed1816254 100644 --- a/src/qs8-igemm/MRx8c8-avx2.c.in +++ b/src/qs8-igemm/MRx8c8-avx2.c.in @@ -10,6 +10,7 @@ $assert MR <= 4 #include <xnnpack/igemm.h> #include <xnnpack/intrinsics-polyfill.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_${MR}x8c8__avx2( @@ -37,6 +38,7 @@ void xnn_qs8_igemm_minmax_ukernel_${MR}x8c8__avx2( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); int8_t* c0 = c; $for M in range(1, MR): int8_t* c${M} = (int8_t*) ((uintptr_t) c${M-1} + cm_stride); diff --git a/src/qs8-igemm/MRxNRc4-neondot.c.in b/src/qs8-igemm/MRxNRc4-neondot.c.in index 00b0382c6..e49b634ed 100644 --- 
a/src/qs8-igemm/MRxNRc4-neondot.c.in +++ b/src/qs8-igemm/MRxNRc4-neondot.c.in @@ -10,8 +10,8 @@ $assert 8 <= NR <= 16 #include <arm_neon.h> -#include <xnnpack/common.h> #include <xnnpack/igemm.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_${MR}x${NR}c4__neondot( @@ -39,6 +39,7 @@ void xnn_qs8_igemm_minmax_ukernel_${MR}x${NR}c4__neondot( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 4); int8_t* c0 = c; $for M in range(1, MR): int8_t* c${M} = (int8_t*) ((uintptr_t) c${M-1} + cm_stride); @@ -92,7 +93,7 @@ void xnn_qs8_igemm_minmax_ukernel_${MR}x${NR}c4__neondot( k -= 8 * sizeof(int8_t); } - // Handle up to 7 final positions of `k` + // Handle up to 6 final positions of `k` if XNN_UNLIKELY(k != 0) { // Load a ${MR}x4 block of activations. $for M in range(MR): diff --git a/src/qs8-igemm/c16-neon-mlal-padal.c.in b/src/qs8-igemm/c16-neon-mlal-padal.c.in index f2202a4f7..1d95ae3b6 100644 --- a/src/qs8-igemm/c16-neon-mlal-padal.c.in +++ b/src/qs8-igemm/c16-neon-mlal-padal.c.in @@ -10,8 +10,8 @@ $assert 8 <= NR <= 16 #include <arm_neon.h> -#include <xnnpack/common.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_${MR}x${NR}c16__neon_mlal_padal( @@ -39,6 +39,7 @@ void xnn_qs8_igemm_minmax_ukernel_${MR}x${NR}c16__neon_mlal_padal( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 16); int8_t* c0 = c; $for M in range(1, MR): int8_t* c${M} = (int8_t*) ((uintptr_t) c${M-1} + cm_stride); diff --git a/src/qs8-igemm/c2-neon-mull-padal-dup.c.in b/src/qs8-igemm/c2-neon-mull-padal-dup.c.in index 08dc2a038..79e350b3f 100644 --- a/src/qs8-igemm/c2-neon-mull-padal-dup.c.in +++ b/src/qs8-igemm/c2-neon-mull-padal-dup.c.in @@ -10,8 +10,8 @@ $assert 8 <= NR <= 16 #include <arm_neon.h> -#include <xnnpack/common.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_${MR}x${NR}c2__neon_${"mlal" if MLA else "mull"}_padal_dup( @@ -39,6 +39,7 @@ void 
xnn_qs8_igemm_minmax_ukernel_${MR}x${NR}c2__neon_${"mlal" if MLA else "mull assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 2); int8_t* c0 = c; $for M in range(1, MR): int8_t* c${M} = (int8_t*) ((uintptr_t) c${M-1} + cm_stride); @@ -143,16 +144,6 @@ void xnn_qs8_igemm_minmax_ukernel_${MR}x${NR}c2__neon_${"mlal" if MLA else "mull $for N in range(0, NR, 4): const int16x8_t vprod${M}x${ABC[N:N+4]}c2 = vmull_s8(vb${ABC[N:N+4]}c2, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va${M}), 2))); vacc${M}x${ABC[N:N+4]} = vpadalq_s16(vacc${M}x${ABC[N:N+4]}, vprod${M}x${ABC[N:N+4]}c2); - - if (k > 6 * sizeof(int8_t)) { - $for N in range(0, NR, 4): - const int8x8_t vb${ABC[N:N+4]}c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - - $for M in range(MR): - $for N in range(0, NR, 4): - const int16x8_t vprod${M}x${ABC[N:N+4]}c3 = vmull_s8(vb${ABC[N:N+4]}c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va${M}), 3))); - vacc${M}x${ABC[N:N+4]} = vpadalq_s16(vacc${M}x${ABC[N:N+4]}, vprod${M}x${ABC[N:N+4]}c3); - } } } } diff --git a/src/qs8-igemm/c8-neon-mull-padal.c.in b/src/qs8-igemm/c8-neon-mull-padal.c.in index 26c833689..66b17b1d0 100644 --- a/src/qs8-igemm/c8-neon-mull-padal.c.in +++ b/src/qs8-igemm/c8-neon-mull-padal.c.in @@ -10,8 +10,8 @@ $assert 8 <= NR <= 16 #include <arm_neon.h> -#include <xnnpack/common.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_${MR}x${NR}c8__neon_mull_padal( @@ -39,6 +39,7 @@ void xnn_qs8_igemm_minmax_ukernel_${MR}x${NR}c8__neon_mull_padal( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); int8_t* c0 = c; $for M in range(1, MR): int8_t* c${M} = (int8_t*) ((uintptr_t) c${M-1} + cm_stride); @@ -93,9 +94,6 @@ void xnn_qs8_igemm_minmax_ukernel_${MR}x${NR}c8__neon_mull_padal( k -= 16 * sizeof(int8_t); } // Handle up to 8 final positions of `k` - // If kc was 0 or 16, there is no remainder. k is 0. - // If kc was 1 to 8, there is a remainder of k. 
- // If kc was 9 to 15, the main loop handled the remainder; k underflowed. if XNN_UNLIKELY(k > 0) { $for M in range(MR): const int8x8_t va${M} = vld1_s8(a${M}); diff --git a/src/qs8-igemm/gen/12x8c4-minmax-neondot.c b/src/qs8-igemm/gen/12x8c4-minmax-neondot.c index d6c5eca85..7a74253bd 100644 --- a/src/qs8-igemm/gen/12x8c4-minmax-neondot.c +++ b/src/qs8-igemm/gen/12x8c4-minmax-neondot.c @@ -11,8 +11,8 @@ #include <arm_neon.h> -#include <xnnpack/common.h> #include <xnnpack/igemm.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_12x8c4__neondot( @@ -40,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_12x8c4__neondot( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 4); int8_t* c0 = c; int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride); if XNN_UNPREDICTABLE(mr < 2) { @@ -240,7 +241,7 @@ void xnn_qs8_igemm_minmax_ukernel_12x8c4__neondot( k -= 8 * sizeof(int8_t); } - // Handle up to 7 final positions of `k` + // Handle up to 6 final positions of `k` if XNN_UNLIKELY(k != 0) { // Load a 12x4 block of activations. 
const int8x8_t va0x01234567 = vld1_s8(a0); diff --git a/src/qs8-igemm/gen/1x16c16-minmax-neon-mlal-padal.c b/src/qs8-igemm/gen/1x16c16-minmax-neon-mlal-padal.c index 293022ad5..302733aed 100644 --- a/src/qs8-igemm/gen/1x16c16-minmax-neon-mlal-padal.c +++ b/src/qs8-igemm/gen/1x16c16-minmax-neon-mlal-padal.c @@ -11,8 +11,8 @@ #include <arm_neon.h> -#include <xnnpack/common.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_1x16c16__neon_mlal_padal( @@ -40,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_1x16c16__neon_mlal_padal( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 16); int8_t* c0 = c; do { diff --git a/src/qs8-igemm/gen/1x16c2-minmax-neon-mlal-padal-dup.c b/src/qs8-igemm/gen/1x16c2-minmax-neon-mlal-padal-dup.c index 21f0a45cb..84cc0e78c 100644 --- a/src/qs8-igemm/gen/1x16c2-minmax-neon-mlal-padal-dup.c +++ b/src/qs8-igemm/gen/1x16c2-minmax-neon-mlal-padal-dup.c @@ -11,8 +11,8 @@ #include <arm_neon.h> -#include <xnnpack/common.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_1x16c2__neon_mlal_padal_dup( @@ -40,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_1x16c2__neon_mlal_padal_dup( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 2); int8_t* c0 = c; do { @@ -249,22 +250,6 @@ void xnn_qs8_igemm_minmax_ukernel_1x16c2__neon_mlal_padal_dup( vacc0x89AB = vpadalq_s16(vacc0x89AB, vprod0x89ABc2); const int16x8_t vprod0xCDEFc2 = vmull_s8(vbCDEFc2, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 2))); vacc0xCDEF = vpadalq_s16(vacc0xCDEF, vprod0xCDEFc2); - - if (k > 6 * sizeof(int8_t)) { - const int8x8_t vb0123c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb4567c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb89ABc3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vbCDEFc3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * 
sizeof(int8_t)); - - const int16x8_t vprod0x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0x0123 = vpadalq_s16(vacc0x0123, vprod0x0123c3); - const int16x8_t vprod0x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0x4567 = vpadalq_s16(vacc0x4567, vprod0x4567c3); - const int16x8_t vprod0x89ABc3 = vmull_s8(vb89ABc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0x89AB = vpadalq_s16(vacc0x89AB, vprod0x89ABc3); - const int16x8_t vprod0xCDEFc3 = vmull_s8(vbCDEFc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0xCDEF = vpadalq_s16(vacc0xCDEF, vprod0xCDEFc3); - } } } } diff --git a/src/qs8-igemm/gen/1x16c2-minmax-neon-mull-padal-dup.c b/src/qs8-igemm/gen/1x16c2-minmax-neon-mull-padal-dup.c index f38446247..9ae6fc711 100644 --- a/src/qs8-igemm/gen/1x16c2-minmax-neon-mull-padal-dup.c +++ b/src/qs8-igemm/gen/1x16c2-minmax-neon-mull-padal-dup.c @@ -11,8 +11,8 @@ #include <arm_neon.h> -#include <xnnpack/common.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_1x16c2__neon_mull_padal_dup( @@ -40,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_1x16c2__neon_mull_padal_dup( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 2); int8_t* c0 = c; do { @@ -161,22 +162,6 @@ void xnn_qs8_igemm_minmax_ukernel_1x16c2__neon_mull_padal_dup( vacc0x89AB = vpadalq_s16(vacc0x89AB, vprod0x89ABc2); const int16x8_t vprod0xCDEFc2 = vmull_s8(vbCDEFc2, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 2))); vacc0xCDEF = vpadalq_s16(vacc0xCDEF, vprod0xCDEFc2); - - if (k > 6 * sizeof(int8_t)) { - const int8x8_t vb0123c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb4567c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb89ABc3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const 
int8x8_t vbCDEFc3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - - const int16x8_t vprod0x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0x0123 = vpadalq_s16(vacc0x0123, vprod0x0123c3); - const int16x8_t vprod0x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0x4567 = vpadalq_s16(vacc0x4567, vprod0x4567c3); - const int16x8_t vprod0x89ABc3 = vmull_s8(vb89ABc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0x89AB = vpadalq_s16(vacc0x89AB, vprod0x89ABc3); - const int16x8_t vprod0xCDEFc3 = vmull_s8(vbCDEFc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0xCDEF = vpadalq_s16(vacc0xCDEF, vprod0xCDEFc3); - } } } } diff --git a/src/qs8-igemm/gen/1x16c4-minmax-neondot.c b/src/qs8-igemm/gen/1x16c4-minmax-neondot.c index fb5c18659..a72ae5b02 100644 --- a/src/qs8-igemm/gen/1x16c4-minmax-neondot.c +++ b/src/qs8-igemm/gen/1x16c4-minmax-neondot.c @@ -11,8 +11,8 @@ #include <arm_neon.h> -#include <xnnpack/common.h> #include <xnnpack/igemm.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_1x16c4__neondot( @@ -40,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_1x16c4__neondot( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 4); int8_t* c0 = c; do { @@ -85,7 +86,7 @@ void xnn_qs8_igemm_minmax_ukernel_1x16c4__neondot( k -= 8 * sizeof(int8_t); } - // Handle up to 7 final positions of `k` + // Handle up to 6 final positions of `k` if XNN_UNLIKELY(k != 0) { // Load a 1x4 block of activations. 
const int8x8_t va0x01234567 = vld1_s8(a0); diff --git a/src/qs8-igemm/gen/1x16c8-minmax-avx512skx.c b/src/qs8-igemm/gen/1x16c8-minmax-avx512skx.c index 924734f1f..34e6a1910 100644 --- a/src/qs8-igemm/gen/1x16c8-minmax-avx512skx.c +++ b/src/qs8-igemm/gen/1x16c8-minmax-avx512skx.c @@ -13,6 +13,7 @@ #include <xnnpack/igemm.h> #include <xnnpack/intrinsics-polyfill.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_1x16c8__avx512skx( @@ -38,6 +39,7 @@ void xnn_qs8_igemm_minmax_ukernel_1x16c8__avx512skx( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); int8_t* c0 = c; const __mmask16 vbias_mask = _cvtu32_mask16(0x1111); diff --git a/src/qs8-igemm/gen/1x16c8-minmax-neon-mull-padal.c b/src/qs8-igemm/gen/1x16c8-minmax-neon-mull-padal.c index ca428ce16..811cb56de 100644 --- a/src/qs8-igemm/gen/1x16c8-minmax-neon-mull-padal.c +++ b/src/qs8-igemm/gen/1x16c8-minmax-neon-mull-padal.c @@ -11,8 +11,8 @@ #include <arm_neon.h> -#include <xnnpack/common.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mull_padal( @@ -40,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mull_padal( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); int8_t* c0 = c; do { @@ -159,9 +160,6 @@ void xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mull_padal( k -= 16 * sizeof(int8_t); } // Handle up to 8 final positions of `k` - // If kc was 0 or 16, there is no remainder. k is 0. - // If kc was 1 to 8, there is a remainder of k. - // If kc was 9 to 15, the main loop handled the remainder; k underflowed. 
if XNN_UNLIKELY(k > 0) { const int8x8_t va0 = vld1_s8(a0); diff --git a/src/qs8-igemm/gen/1x4c2-minmax-sse2-ld128.c b/src/qs8-igemm/gen/1x4c2-minmax-sse2-ld128.c index 42dc29ea7..cdff4d552 100644 --- a/src/qs8-igemm/gen/1x4c2-minmax-sse2-ld128.c +++ b/src/qs8-igemm/gen/1x4c2-minmax-sse2-ld128.c @@ -12,6 +12,7 @@ #include <emmintrin.h> #include <xnnpack/igemm.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_1x4c2__sse2_ld128( @@ -39,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_1x4c2__sse2_ld128( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 2); int8_t* c0 = c; do { @@ -110,15 +112,6 @@ void xnn_qs8_igemm_minmax_ukernel_1x4c2__sse2_ld128( vacc0x0123 = _mm_add_epi32(vacc0x0123, _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(2, 2, 2, 2)), vxb2)); - - if (k > 6 * sizeof(int8_t)) { - const __m128i vb3 = _mm_loadl_epi64((const __m128i*) w); - w = (const void*) ((uintptr_t) w + 8); - const __m128i vxb3 = _mm_unpacklo_epi8(vb3, _mm_cmpgt_epi8(_mm_setzero_si128(), vb3)); - - vacc0x0123 = _mm_add_epi32(vacc0x0123, - _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); - } } } } diff --git a/src/qs8-igemm/gen/1x4c2-minmax-sse2-ld64.c b/src/qs8-igemm/gen/1x4c2-minmax-sse2-ld64.c index d78c42517..bd6ce42b6 100644 --- a/src/qs8-igemm/gen/1x4c2-minmax-sse2-ld64.c +++ b/src/qs8-igemm/gen/1x4c2-minmax-sse2-ld64.c @@ -12,6 +12,7 @@ #include <emmintrin.h> #include <xnnpack/igemm.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_1x4c2__sse2_ld64( @@ -39,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_1x4c2__sse2_ld64( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 2); int8_t* c0 = c; do { @@ -110,15 +112,6 @@ void xnn_qs8_igemm_minmax_ukernel_1x4c2__sse2_ld64( vacc0x0123 = _mm_add_epi32(vacc0x0123, _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(2, 2, 2, 2)), vxb2)); - - if (k > 6 * sizeof(int8_t)) { - const __m128i vb3 = _mm_loadl_epi64((const __m128i*) w); - w = (const void*) ((uintptr_t) w + 8); - 
const __m128i vxb3 = _mm_unpacklo_epi8(vb3, _mm_cmpgt_epi8(_mm_setzero_si128(), vb3)); - - vacc0x0123 = _mm_add_epi32(vacc0x0123, - _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); - } } } } diff --git a/src/qs8-igemm/gen/1x4c2-minmax-sse41-ld128.c b/src/qs8-igemm/gen/1x4c2-minmax-sse41-ld128.c index 0d5b1dc9b..312afa93d 100644 --- a/src/qs8-igemm/gen/1x4c2-minmax-sse41-ld128.c +++ b/src/qs8-igemm/gen/1x4c2-minmax-sse41-ld128.c @@ -12,6 +12,7 @@ #include <smmintrin.h> #include <xnnpack/igemm.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_1x4c2__sse41_ld128( @@ -39,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_1x4c2__sse41_ld128( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 2); int8_t* c0 = c; do { @@ -110,15 +112,6 @@ void xnn_qs8_igemm_minmax_ukernel_1x4c2__sse41_ld128( vacc0x0123 = _mm_add_epi32(vacc0x0123, _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(2, 2, 2, 2)), vxb2)); - - if (k > 6 * sizeof(int8_t)) { - const __m128i vb3 = _mm_loadl_epi64((const __m128i*) w); - w = (const void*) ((uintptr_t) w + 8); - const __m128i vxb3 = _mm_unpacklo_epi8(vb3, _mm_cmpgt_epi8(_mm_setzero_si128(), vb3)); - - vacc0x0123 = _mm_add_epi32(vacc0x0123, - _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); - } } } } diff --git a/src/qs8-igemm/gen/1x4c2-minmax-sse41-ld64.c b/src/qs8-igemm/gen/1x4c2-minmax-sse41-ld64.c index 2827d03fe..af2b11611 100644 --- a/src/qs8-igemm/gen/1x4c2-minmax-sse41-ld64.c +++ b/src/qs8-igemm/gen/1x4c2-minmax-sse41-ld64.c @@ -12,6 +12,7 @@ #include <smmintrin.h> #include <xnnpack/igemm.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_1x4c2__sse41_ld64( @@ -39,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_1x4c2__sse41_ld64( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 2); int8_t* c0 = c; do { @@ -110,15 +112,6 @@ void xnn_qs8_igemm_minmax_ukernel_1x4c2__sse41_ld64( vacc0x0123 = _mm_add_epi32(vacc0x0123, _mm_madd_epi16(_mm_shuffle_epi32(vxa0, 
_MM_SHUFFLE(2, 2, 2, 2)), vxb2)); - - if (k > 6 * sizeof(int8_t)) { - const __m128i vb3 = _mm_loadl_epi64((const __m128i*) w); - w = (const void*) ((uintptr_t) w + 8); - const __m128i vxb3 = _mm_unpacklo_epi8(vb3, _mm_cmpgt_epi8(_mm_setzero_si128(), vb3)); - - vacc0x0123 = _mm_add_epi32(vacc0x0123, - _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); - } } } } diff --git a/src/qs8-igemm/gen/1x4c2-minmax-ssse3-ld128.c b/src/qs8-igemm/gen/1x4c2-minmax-ssse3-ld128.c index b9aba9745..5f308dd69 100644 --- a/src/qs8-igemm/gen/1x4c2-minmax-ssse3-ld128.c +++ b/src/qs8-igemm/gen/1x4c2-minmax-ssse3-ld128.c @@ -12,6 +12,7 @@ #include <tmmintrin.h> #include <xnnpack/igemm.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_1x4c2__ssse3_ld128( @@ -39,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_1x4c2__ssse3_ld128( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 2); int8_t* c0 = c; do { @@ -110,15 +112,6 @@ void xnn_qs8_igemm_minmax_ukernel_1x4c2__ssse3_ld128( vacc0x0123 = _mm_add_epi32(vacc0x0123, _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(2, 2, 2, 2)), vxb2)); - - if (k > 6 * sizeof(int8_t)) { - const __m128i vb3 = _mm_loadl_epi64((const __m128i*) w); - w = (const void*) ((uintptr_t) w + 8); - const __m128i vxb3 = _mm_unpacklo_epi8(vb3, _mm_cmpgt_epi8(_mm_setzero_si128(), vb3)); - - vacc0x0123 = _mm_add_epi32(vacc0x0123, - _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); - } } } } diff --git a/src/qs8-igemm/gen/1x4c2-minmax-ssse3-ld64.c b/src/qs8-igemm/gen/1x4c2-minmax-ssse3-ld64.c index 862a3f6d2..ff8241aa4 100644 --- a/src/qs8-igemm/gen/1x4c2-minmax-ssse3-ld64.c +++ b/src/qs8-igemm/gen/1x4c2-minmax-ssse3-ld64.c @@ -12,6 +12,7 @@ #include <tmmintrin.h> #include <xnnpack/igemm.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_1x4c2__ssse3_ld64( @@ -39,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_1x4c2__ssse3_ld64( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 2); 
int8_t* c0 = c; do { @@ -110,15 +112,6 @@ void xnn_qs8_igemm_minmax_ukernel_1x4c2__ssse3_ld64( vacc0x0123 = _mm_add_epi32(vacc0x0123, _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(2, 2, 2, 2)), vxb2)); - - if (k > 6 * sizeof(int8_t)) { - const __m128i vb3 = _mm_loadl_epi64((const __m128i*) w); - w = (const void*) ((uintptr_t) w + 8); - const __m128i vxb3 = _mm_unpacklo_epi8(vb3, _mm_cmpgt_epi8(_mm_setzero_si128(), vb3)); - - vacc0x0123 = _mm_add_epi32(vacc0x0123, - _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); - } } } } diff --git a/src/qs8-igemm/gen/1x4c2-minmax-xop-ld128.c b/src/qs8-igemm/gen/1x4c2-minmax-xop-ld128.c index 5882450b5..fe4d34575 100644 --- a/src/qs8-igemm/gen/1x4c2-minmax-xop-ld128.c +++ b/src/qs8-igemm/gen/1x4c2-minmax-xop-ld128.c @@ -17,6 +17,7 @@ #endif #include <xnnpack/igemm.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_1x4c2__xop_ld128( @@ -44,6 +45,7 @@ void xnn_qs8_igemm_minmax_ukernel_1x4c2__xop_ld128( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 2); int8_t* c0 = c; do { @@ -115,15 +117,6 @@ void xnn_qs8_igemm_minmax_ukernel_1x4c2__xop_ld128( vacc0x0123 = _mm_maddd_epi16( _mm_shuffle_epi32(vxa0, _MM_SHUFFLE(2, 2, 2, 2)), vxb2, vacc0x0123); - - if (k > 6 * sizeof(int8_t)) { - const __m128i vb3 = _mm_loadl_epi64((const __m128i*) w); - w = (const void*) ((uintptr_t) w + 8); - const __m128i vxb3 = _mm_unpacklo_epi8(vb3, _mm_cmpgt_epi8(_mm_setzero_si128(), vb3)); - - vacc0x0123 = _mm_maddd_epi16( - _mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3, vacc0x0123); - } } } } diff --git a/src/qs8-igemm/gen/1x4c2-minmax-xop-ld64.c b/src/qs8-igemm/gen/1x4c2-minmax-xop-ld64.c index 5aafdf044..7417f8e7a 100644 --- a/src/qs8-igemm/gen/1x4c2-minmax-xop-ld64.c +++ b/src/qs8-igemm/gen/1x4c2-minmax-xop-ld64.c @@ -17,6 +17,7 @@ #endif #include <xnnpack/igemm.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_1x4c2__xop_ld64( @@ -44,6 +45,7 @@ void 
xnn_qs8_igemm_minmax_ukernel_1x4c2__xop_ld64( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 2); int8_t* c0 = c; do { @@ -115,15 +117,6 @@ void xnn_qs8_igemm_minmax_ukernel_1x4c2__xop_ld64( vacc0x0123 = _mm_maddd_epi16( _mm_shuffle_epi32(vxa0, _MM_SHUFFLE(2, 2, 2, 2)), vxb2, vacc0x0123); - - if (k > 6 * sizeof(int8_t)) { - const __m128i vb3 = _mm_loadl_epi64((const __m128i*) w); - w = (const void*) ((uintptr_t) w + 8); - const __m128i vxb3 = _mm_unpacklo_epi8(vb3, _mm_cmpgt_epi8(_mm_setzero_si128(), vb3)); - - vacc0x0123 = _mm_maddd_epi16( - _mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3, vacc0x0123); - } } } } diff --git a/src/qs8-igemm/gen/1x4c8-minmax-sse2-ld128.c b/src/qs8-igemm/gen/1x4c8-minmax-sse2-ld128.c index fe9eb8c03..683721286 100644 --- a/src/qs8-igemm/gen/1x4c8-minmax-sse2-ld128.c +++ b/src/qs8-igemm/gen/1x4c8-minmax-sse2-ld128.c @@ -12,6 +12,7 @@ #include <emmintrin.h> #include <xnnpack/igemm.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_1x4c8__sse2_ld128( @@ -39,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_1x4c8__sse2_ld128( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); int8_t* c0 = c; do { diff --git a/src/qs8-igemm/gen/1x4c8-minmax-sse2-ld64.c b/src/qs8-igemm/gen/1x4c8-minmax-sse2-ld64.c index 2dd1d3165..fc9152950 100644 --- a/src/qs8-igemm/gen/1x4c8-minmax-sse2-ld64.c +++ b/src/qs8-igemm/gen/1x4c8-minmax-sse2-ld64.c @@ -12,6 +12,7 @@ #include <emmintrin.h> #include <xnnpack/igemm.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_1x4c8__sse2_ld64( @@ -39,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_1x4c8__sse2_ld64( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); int8_t* c0 = c; do { diff --git a/src/qs8-igemm/gen/1x4c8-minmax-sse41-ld128.c b/src/qs8-igemm/gen/1x4c8-minmax-sse41-ld128.c index 10b9d02d4..66f8bbb67 100644 --- a/src/qs8-igemm/gen/1x4c8-minmax-sse41-ld128.c +++ b/src/qs8-igemm/gen/1x4c8-minmax-sse41-ld128.c @@ -12,6 +12,7 @@ #include 
<smmintrin.h> #include <xnnpack/igemm.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_1x4c8__sse41_ld128( @@ -39,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_1x4c8__sse41_ld128( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); int8_t* c0 = c; do { diff --git a/src/qs8-igemm/gen/1x4c8-minmax-sse41-ld64.c b/src/qs8-igemm/gen/1x4c8-minmax-sse41-ld64.c index c171ead40..4031b0820 100644 --- a/src/qs8-igemm/gen/1x4c8-minmax-sse41-ld64.c +++ b/src/qs8-igemm/gen/1x4c8-minmax-sse41-ld64.c @@ -12,6 +12,7 @@ #include <smmintrin.h> #include <xnnpack/igemm.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_1x4c8__sse41_ld64( @@ -39,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_1x4c8__sse41_ld64( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); int8_t* c0 = c; do { diff --git a/src/qs8-igemm/gen/1x4c8-minmax-ssse3-ld128.c b/src/qs8-igemm/gen/1x4c8-minmax-ssse3-ld128.c index 561d72f02..b04f5bf0b 100644 --- a/src/qs8-igemm/gen/1x4c8-minmax-ssse3-ld128.c +++ b/src/qs8-igemm/gen/1x4c8-minmax-ssse3-ld128.c @@ -12,6 +12,7 @@ #include <tmmintrin.h> #include <xnnpack/igemm.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_1x4c8__ssse3_ld128( @@ -39,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_1x4c8__ssse3_ld128( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); int8_t* c0 = c; do { diff --git a/src/qs8-igemm/gen/1x4c8-minmax-ssse3-ld64.c b/src/qs8-igemm/gen/1x4c8-minmax-ssse3-ld64.c index f95561dca..6036ab673 100644 --- a/src/qs8-igemm/gen/1x4c8-minmax-ssse3-ld64.c +++ b/src/qs8-igemm/gen/1x4c8-minmax-ssse3-ld64.c @@ -12,6 +12,7 @@ #include <tmmintrin.h> #include <xnnpack/igemm.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_1x4c8__ssse3_ld64( @@ -39,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_1x4c8__ssse3_ld64( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); int8_t* c0 = c; do { diff --git a/src/qs8-igemm/gen/1x4c8-minmax-wasmsimd-ld128.c 
b/src/qs8-igemm/gen/1x4c8-minmax-wasmsimd-ld128.c index 8a104c0b7..9af046ad0 100644 --- a/src/qs8-igemm/gen/1x4c8-minmax-wasmsimd-ld128.c +++ b/src/qs8-igemm/gen/1x4c8-minmax-wasmsimd-ld128.c @@ -12,6 +12,7 @@ #include <wasm_simd128.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_1x4c8__wasmsimd_ld128( @@ -39,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_1x4c8__wasmsimd_ld128( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); int8_t* c0 = c; const v128_t vzero = wasm_f64x2_splat(0.0); diff --git a/src/qs8-igemm/gen/1x4c8-minmax-wasmsimd-ld64.c b/src/qs8-igemm/gen/1x4c8-minmax-wasmsimd-ld64.c index 15bc8bb1c..faadba9e3 100644 --- a/src/qs8-igemm/gen/1x4c8-minmax-wasmsimd-ld64.c +++ b/src/qs8-igemm/gen/1x4c8-minmax-wasmsimd-ld64.c @@ -12,6 +12,7 @@ #include <wasm_simd128.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_1x4c8__wasmsimd_ld64( @@ -39,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_1x4c8__wasmsimd_ld64( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); int8_t* c0 = c; const v128_t vzero = wasm_f64x2_splat(0.0); diff --git a/src/qs8-igemm/gen/1x4c8-minmax-xop-ld128.c b/src/qs8-igemm/gen/1x4c8-minmax-xop-ld128.c index 0686d4e17..aa7eca9e9 100644 --- a/src/qs8-igemm/gen/1x4c8-minmax-xop-ld128.c +++ b/src/qs8-igemm/gen/1x4c8-minmax-xop-ld128.c @@ -17,6 +17,7 @@ #endif #include <xnnpack/igemm.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_1x4c8__xop_ld128( @@ -44,6 +45,7 @@ void xnn_qs8_igemm_minmax_ukernel_1x4c8__xop_ld128( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); int8_t* c0 = c; do { diff --git a/src/qs8-igemm/gen/1x4c8-minmax-xop-ld64.c b/src/qs8-igemm/gen/1x4c8-minmax-xop-ld64.c index fb46c83fe..1204750d8 100644 --- a/src/qs8-igemm/gen/1x4c8-minmax-xop-ld64.c +++ b/src/qs8-igemm/gen/1x4c8-minmax-xop-ld64.c @@ -17,6 +17,7 @@ #endif #include <xnnpack/igemm.h> +#include <xnnpack/math.h> void 
xnn_qs8_igemm_minmax_ukernel_1x4c8__xop_ld64( @@ -44,6 +45,7 @@ void xnn_qs8_igemm_minmax_ukernel_1x4c8__xop_ld64( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); int8_t* c0 = c; do { diff --git a/src/qs8-igemm/gen/1x8c16-minmax-neon-mlal-padal.c b/src/qs8-igemm/gen/1x8c16-minmax-neon-mlal-padal.c index de8a4138d..e81fbe71d 100644 --- a/src/qs8-igemm/gen/1x8c16-minmax-neon-mlal-padal.c +++ b/src/qs8-igemm/gen/1x8c16-minmax-neon-mlal-padal.c @@ -11,8 +11,8 @@ #include <arm_neon.h> -#include <xnnpack/common.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_1x8c16__neon_mlal_padal( @@ -40,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_1x8c16__neon_mlal_padal( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 16); int8_t* c0 = c; do { diff --git a/src/qs8-igemm/gen/1x8c2-minmax-neon-mlal-padal-dup.c b/src/qs8-igemm/gen/1x8c2-minmax-neon-mlal-padal-dup.c index 0f4c1cf50..ccd894ec1 100644 --- a/src/qs8-igemm/gen/1x8c2-minmax-neon-mlal-padal-dup.c +++ b/src/qs8-igemm/gen/1x8c2-minmax-neon-mlal-padal-dup.c @@ -11,8 +11,8 @@ #include <arm_neon.h> -#include <xnnpack/common.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_1x8c2__neon_mlal_padal_dup( @@ -40,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_1x8c2__neon_mlal_padal_dup( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 2); int8_t* c0 = c; do { @@ -165,16 +166,6 @@ void xnn_qs8_igemm_minmax_ukernel_1x8c2__neon_mlal_padal_dup( vacc0x0123 = vpadalq_s16(vacc0x0123, vprod0x0123c2); const int16x8_t vprod0x4567c2 = vmull_s8(vb4567c2, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 2))); vacc0x4567 = vpadalq_s16(vacc0x4567, vprod0x4567c2); - - if (k > 6 * sizeof(int8_t)) { - const int8x8_t vb0123c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb4567c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - - const int16x8_t vprod0x0123c3 
= vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0x0123 = vpadalq_s16(vacc0x0123, vprod0x0123c3); - const int16x8_t vprod0x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0x4567 = vpadalq_s16(vacc0x4567, vprod0x4567c3); - } } } } diff --git a/src/qs8-igemm/gen/1x8c2-minmax-neon-mull-padal-dup.c b/src/qs8-igemm/gen/1x8c2-minmax-neon-mull-padal-dup.c index d8d83677d..c2f45d1a7 100644 --- a/src/qs8-igemm/gen/1x8c2-minmax-neon-mull-padal-dup.c +++ b/src/qs8-igemm/gen/1x8c2-minmax-neon-mull-padal-dup.c @@ -11,8 +11,8 @@ #include <arm_neon.h> -#include <xnnpack/common.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_1x8c2__neon_mull_padal_dup( @@ -40,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_1x8c2__neon_mull_padal_dup( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 2); int8_t* c0 = c; do { @@ -117,16 +118,6 @@ void xnn_qs8_igemm_minmax_ukernel_1x8c2__neon_mull_padal_dup( vacc0x0123 = vpadalq_s16(vacc0x0123, vprod0x0123c2); const int16x8_t vprod0x4567c2 = vmull_s8(vb4567c2, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 2))); vacc0x4567 = vpadalq_s16(vacc0x4567, vprod0x4567c2); - - if (k > 6 * sizeof(int8_t)) { - const int8x8_t vb0123c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb4567c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - - const int16x8_t vprod0x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0x0123 = vpadalq_s16(vacc0x0123, vprod0x0123c3); - const int16x8_t vprod0x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0x4567 = vpadalq_s16(vacc0x4567, vprod0x4567c3); - } } } } diff --git a/src/qs8-igemm/gen/1x8c4-minmax-neondot.c b/src/qs8-igemm/gen/1x8c4-minmax-neondot.c index 07c1ad1f9..ed4260516 100644 --- 
a/src/qs8-igemm/gen/1x8c4-minmax-neondot.c +++ b/src/qs8-igemm/gen/1x8c4-minmax-neondot.c @@ -11,8 +11,8 @@ #include <arm_neon.h> -#include <xnnpack/common.h> #include <xnnpack/igemm.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_1x8c4__neondot( @@ -40,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_1x8c4__neondot( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 4); int8_t* c0 = c; do { @@ -75,7 +76,7 @@ void xnn_qs8_igemm_minmax_ukernel_1x8c4__neondot( k -= 8 * sizeof(int8_t); } - // Handle up to 7 final positions of `k` + // Handle up to 6 final positions of `k` if XNN_UNLIKELY(k != 0) { // Load a 1x4 block of activations. const int8x8_t va0x01234567 = vld1_s8(a0); diff --git a/src/qs8-igemm/gen/1x8c8-minmax-avx2.c b/src/qs8-igemm/gen/1x8c8-minmax-avx2.c index 4eaa40009..4c41935aa 100644 --- a/src/qs8-igemm/gen/1x8c8-minmax-avx2.c +++ b/src/qs8-igemm/gen/1x8c8-minmax-avx2.c @@ -13,6 +13,7 @@ #include <xnnpack/igemm.h> #include <xnnpack/intrinsics-polyfill.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_1x8c8__avx2( @@ -40,6 +41,7 @@ void xnn_qs8_igemm_minmax_ukernel_1x8c8__avx2( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); int8_t* c0 = c; do { diff --git a/src/qs8-igemm/gen/1x8c8-minmax-neon-mull-padal.c b/src/qs8-igemm/gen/1x8c8-minmax-neon-mull-padal.c index 86e477510..d4c86c9d1 100644 --- a/src/qs8-igemm/gen/1x8c8-minmax-neon-mull-padal.c +++ b/src/qs8-igemm/gen/1x8c8-minmax-neon-mull-padal.c @@ -11,8 +11,8 @@ #include <arm_neon.h> -#include <xnnpack/common.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mull_padal( @@ -40,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mull_padal( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); int8_t* c0 = c; do { @@ -111,9 +112,6 @@ void xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mull_padal( k -= 16 * sizeof(int8_t); } // Handle up to 8 final positions of `k` - // If kc was 0 or 16, 
there is no remainder. k is 0. - // If kc was 1 to 8, there is a remainder of k. - // If kc was 9 to 15, the main loop handled the remainder; k underflowed. if XNN_UNLIKELY(k > 0) { const int8x8_t va0 = vld1_s8(a0); diff --git a/src/qs8-igemm/gen/2x16c16-minmax-neon-mlal-padal.c b/src/qs8-igemm/gen/2x16c16-minmax-neon-mlal-padal.c index 293e5e5e8..db2455bbe 100644 --- a/src/qs8-igemm/gen/2x16c16-minmax-neon-mlal-padal.c +++ b/src/qs8-igemm/gen/2x16c16-minmax-neon-mlal-padal.c @@ -11,8 +11,8 @@ #include <arm_neon.h> -#include <xnnpack/common.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_2x16c16__neon_mlal_padal( @@ -40,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_2x16c16__neon_mlal_padal( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 16); int8_t* c0 = c; int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride); if XNN_UNPREDICTABLE(mr != 2) { diff --git a/src/qs8-igemm/gen/2x16c2-minmax-neon-mlal-padal-dup.c b/src/qs8-igemm/gen/2x16c2-minmax-neon-mlal-padal-dup.c index 8fa5b3cd4..d35ce3fef 100644 --- a/src/qs8-igemm/gen/2x16c2-minmax-neon-mlal-padal-dup.c +++ b/src/qs8-igemm/gen/2x16c2-minmax-neon-mlal-padal-dup.c @@ -11,8 +11,8 @@ #include <arm_neon.h> -#include <xnnpack/common.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_2x16c2__neon_mlal_padal_dup( @@ -40,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_2x16c2__neon_mlal_padal_dup( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 2); int8_t* c0 = c; int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride); if XNN_UNPREDICTABLE(mr != 2) { @@ -369,30 +370,6 @@ void xnn_qs8_igemm_minmax_ukernel_2x16c2__neon_mlal_padal_dup( vacc1x89AB = vpadalq_s16(vacc1x89AB, vprod1x89ABc2); const int16x8_t vprod1xCDEFc2 = vmull_s8(vbCDEFc2, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 2))); vacc1xCDEF = vpadalq_s16(vacc1xCDEF, vprod1xCDEFc2); - - if (k > 6 * sizeof(int8_t)) { - const int8x8_t vb0123c3 = vld1_s8(w); 
w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb4567c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb89ABc3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vbCDEFc3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - - const int16x8_t vprod0x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0x0123 = vpadalq_s16(vacc0x0123, vprod0x0123c3); - const int16x8_t vprod0x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0x4567 = vpadalq_s16(vacc0x4567, vprod0x4567c3); - const int16x8_t vprod0x89ABc3 = vmull_s8(vb89ABc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0x89AB = vpadalq_s16(vacc0x89AB, vprod0x89ABc3); - const int16x8_t vprod0xCDEFc3 = vmull_s8(vbCDEFc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0xCDEF = vpadalq_s16(vacc0xCDEF, vprod0xCDEFc3); - const int16x8_t vprod1x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3))); - vacc1x0123 = vpadalq_s16(vacc1x0123, vprod1x0123c3); - const int16x8_t vprod1x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3))); - vacc1x4567 = vpadalq_s16(vacc1x4567, vprod1x4567c3); - const int16x8_t vprod1x89ABc3 = vmull_s8(vb89ABc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3))); - vacc1x89AB = vpadalq_s16(vacc1x89AB, vprod1x89ABc3); - const int16x8_t vprod1xCDEFc3 = vmull_s8(vbCDEFc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3))); - vacc1xCDEF = vpadalq_s16(vacc1xCDEF, vprod1xCDEFc3); - } } } } diff --git a/src/qs8-igemm/gen/2x16c2-minmax-neon-mull-padal-dup.c b/src/qs8-igemm/gen/2x16c2-minmax-neon-mull-padal-dup.c index 867c6c0cf..b13ce668d 100644 --- a/src/qs8-igemm/gen/2x16c2-minmax-neon-mull-padal-dup.c +++ 
b/src/qs8-igemm/gen/2x16c2-minmax-neon-mull-padal-dup.c @@ -11,8 +11,8 @@ #include <arm_neon.h> -#include <xnnpack/common.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_2x16c2__neon_mull_padal_dup( @@ -40,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_2x16c2__neon_mull_padal_dup( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 2); int8_t* c0 = c; int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride); if XNN_UNPREDICTABLE(mr != 2) { @@ -231,30 +232,6 @@ void xnn_qs8_igemm_minmax_ukernel_2x16c2__neon_mull_padal_dup( vacc1x89AB = vpadalq_s16(vacc1x89AB, vprod1x89ABc2); const int16x8_t vprod1xCDEFc2 = vmull_s8(vbCDEFc2, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 2))); vacc1xCDEF = vpadalq_s16(vacc1xCDEF, vprod1xCDEFc2); - - if (k > 6 * sizeof(int8_t)) { - const int8x8_t vb0123c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb4567c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb89ABc3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vbCDEFc3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - - const int16x8_t vprod0x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0x0123 = vpadalq_s16(vacc0x0123, vprod0x0123c3); - const int16x8_t vprod0x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0x4567 = vpadalq_s16(vacc0x4567, vprod0x4567c3); - const int16x8_t vprod0x89ABc3 = vmull_s8(vb89ABc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0x89AB = vpadalq_s16(vacc0x89AB, vprod0x89ABc3); - const int16x8_t vprod0xCDEFc3 = vmull_s8(vbCDEFc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0xCDEF = vpadalq_s16(vacc0xCDEF, vprod0xCDEFc3); - const int16x8_t vprod1x0123c3 = vmull_s8(vb0123c3, 
vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3))); - vacc1x0123 = vpadalq_s16(vacc1x0123, vprod1x0123c3); - const int16x8_t vprod1x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3))); - vacc1x4567 = vpadalq_s16(vacc1x4567, vprod1x4567c3); - const int16x8_t vprod1x89ABc3 = vmull_s8(vb89ABc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3))); - vacc1x89AB = vpadalq_s16(vacc1x89AB, vprod1x89ABc3); - const int16x8_t vprod1xCDEFc3 = vmull_s8(vbCDEFc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3))); - vacc1xCDEF = vpadalq_s16(vacc1xCDEF, vprod1xCDEFc3); - } } } } diff --git a/src/qs8-igemm/gen/2x16c8-minmax-avx512skx.c b/src/qs8-igemm/gen/2x16c8-minmax-avx512skx.c index 347ed788f..328bc55f2 100644 --- a/src/qs8-igemm/gen/2x16c8-minmax-avx512skx.c +++ b/src/qs8-igemm/gen/2x16c8-minmax-avx512skx.c @@ -13,6 +13,7 @@ #include <xnnpack/igemm.h> #include <xnnpack/intrinsics-polyfill.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_2x16c8__avx512skx( @@ -38,6 +39,7 @@ void xnn_qs8_igemm_minmax_ukernel_2x16c8__avx512skx( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); int8_t* c0 = c; int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride); if XNN_UNPREDICTABLE(mr != 2) { diff --git a/src/qs8-igemm/gen/2x16c8-minmax-neon-mull-padal.c b/src/qs8-igemm/gen/2x16c8-minmax-neon-mull-padal.c index 2595047b5..c6cd3ef2a 100644 --- a/src/qs8-igemm/gen/2x16c8-minmax-neon-mull-padal.c +++ b/src/qs8-igemm/gen/2x16c8-minmax-neon-mull-padal.c @@ -11,8 +11,8 @@ #include <arm_neon.h> -#include <xnnpack/common.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_2x16c8__neon_mull_padal( @@ -40,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_2x16c8__neon_mull_padal( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); int8_t* c0 = c; int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride); if XNN_UNPREDICTABLE(mr != 2) { @@ -233,9 +234,6 @@ 
void xnn_qs8_igemm_minmax_ukernel_2x16c8__neon_mull_padal( k -= 16 * sizeof(int8_t); } // Handle up to 8 final positions of `k` - // If kc was 0 or 16, there is no remainder. k is 0. - // If kc was 1 to 8, there is a remainder of k. - // If kc was 9 to 15, the main loop handled the remainder; k underflowed. if XNN_UNLIKELY(k > 0) { const int8x8_t va0 = vld1_s8(a0); const int8x8_t va1 = vld1_s8(a1); diff --git a/src/qs8-igemm/gen/2x4c8-minmax-sse2-ld128.c b/src/qs8-igemm/gen/2x4c8-minmax-sse2-ld128.c index 906861e50..c771d140f 100644 --- a/src/qs8-igemm/gen/2x4c8-minmax-sse2-ld128.c +++ b/src/qs8-igemm/gen/2x4c8-minmax-sse2-ld128.c @@ -12,6 +12,7 @@ #include <emmintrin.h> #include <xnnpack/igemm.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_2x4c8__sse2_ld128( @@ -39,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_2x4c8__sse2_ld128( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); int8_t* c0 = c; int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride); if XNN_UNPREDICTABLE(mr != 2) { diff --git a/src/qs8-igemm/gen/2x4c8-minmax-sse2-ld64.c b/src/qs8-igemm/gen/2x4c8-minmax-sse2-ld64.c index 3a728bd7a..e49603bba 100644 --- a/src/qs8-igemm/gen/2x4c8-minmax-sse2-ld64.c +++ b/src/qs8-igemm/gen/2x4c8-minmax-sse2-ld64.c @@ -12,6 +12,7 @@ #include <emmintrin.h> #include <xnnpack/igemm.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_2x4c8__sse2_ld64( @@ -39,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_2x4c8__sse2_ld64( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); int8_t* c0 = c; int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride); if XNN_UNPREDICTABLE(mr != 2) { diff --git a/src/qs8-igemm/gen/2x4c8-minmax-sse41-ld128.c b/src/qs8-igemm/gen/2x4c8-minmax-sse41-ld128.c index 6d68d363c..e5e18fd63 100644 --- a/src/qs8-igemm/gen/2x4c8-minmax-sse41-ld128.c +++ b/src/qs8-igemm/gen/2x4c8-minmax-sse41-ld128.c @@ -12,6 +12,7 @@ #include <smmintrin.h> #include <xnnpack/igemm.h> +#include <xnnpack/math.h> void 
xnn_qs8_igemm_minmax_ukernel_2x4c8__sse41_ld128( @@ -39,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_2x4c8__sse41_ld128( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); int8_t* c0 = c; int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride); if XNN_UNPREDICTABLE(mr != 2) { diff --git a/src/qs8-igemm/gen/2x4c8-minmax-sse41-ld64.c b/src/qs8-igemm/gen/2x4c8-minmax-sse41-ld64.c index c7b739279..09d8a2703 100644 --- a/src/qs8-igemm/gen/2x4c8-minmax-sse41-ld64.c +++ b/src/qs8-igemm/gen/2x4c8-minmax-sse41-ld64.c @@ -12,6 +12,7 @@ #include <smmintrin.h> #include <xnnpack/igemm.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_2x4c8__sse41_ld64( @@ -39,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_2x4c8__sse41_ld64( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); int8_t* c0 = c; int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride); if XNN_UNPREDICTABLE(mr != 2) { diff --git a/src/qs8-igemm/gen/2x4c8-minmax-ssse3-ld128.c b/src/qs8-igemm/gen/2x4c8-minmax-ssse3-ld128.c index 2b1d86acb..d10dc8644 100644 --- a/src/qs8-igemm/gen/2x4c8-minmax-ssse3-ld128.c +++ b/src/qs8-igemm/gen/2x4c8-minmax-ssse3-ld128.c @@ -12,6 +12,7 @@ #include <tmmintrin.h> #include <xnnpack/igemm.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_2x4c8__ssse3_ld128( @@ -39,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_2x4c8__ssse3_ld128( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); int8_t* c0 = c; int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride); if XNN_UNPREDICTABLE(mr != 2) { diff --git a/src/qs8-igemm/gen/2x4c8-minmax-ssse3-ld64.c b/src/qs8-igemm/gen/2x4c8-minmax-ssse3-ld64.c index 15e19f9e8..2b97a3c90 100644 --- a/src/qs8-igemm/gen/2x4c8-minmax-ssse3-ld64.c +++ b/src/qs8-igemm/gen/2x4c8-minmax-ssse3-ld64.c @@ -12,6 +12,7 @@ #include <tmmintrin.h> #include <xnnpack/igemm.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_2x4c8__ssse3_ld64( @@ -39,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_2x4c8__ssse3_ld64( 
assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); int8_t* c0 = c; int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride); if XNN_UNPREDICTABLE(mr != 2) { diff --git a/src/qs8-igemm/gen/2x4c8-minmax-wasmsimd-ld128.c b/src/qs8-igemm/gen/2x4c8-minmax-wasmsimd-ld128.c index 11188d49a..3aa04349a 100644 --- a/src/qs8-igemm/gen/2x4c8-minmax-wasmsimd-ld128.c +++ b/src/qs8-igemm/gen/2x4c8-minmax-wasmsimd-ld128.c @@ -12,6 +12,7 @@ #include <wasm_simd128.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_2x4c8__wasmsimd_ld128( @@ -39,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_2x4c8__wasmsimd_ld128( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); int8_t* c0 = c; int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride); if XNN_UNPREDICTABLE(mr != 2) { diff --git a/src/qs8-igemm/gen/2x4c8-minmax-wasmsimd-ld64.c b/src/qs8-igemm/gen/2x4c8-minmax-wasmsimd-ld64.c index a85d4d773..d2675e611 100644 --- a/src/qs8-igemm/gen/2x4c8-minmax-wasmsimd-ld64.c +++ b/src/qs8-igemm/gen/2x4c8-minmax-wasmsimd-ld64.c @@ -12,6 +12,7 @@ #include <wasm_simd128.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_2x4c8__wasmsimd_ld64( @@ -39,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_2x4c8__wasmsimd_ld64( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); int8_t* c0 = c; int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride); if XNN_UNPREDICTABLE(mr != 2) { diff --git a/src/qs8-igemm/gen/2x4c8-minmax-xop-ld128.c b/src/qs8-igemm/gen/2x4c8-minmax-xop-ld128.c index 00d6188ea..e58f4254c 100644 --- a/src/qs8-igemm/gen/2x4c8-minmax-xop-ld128.c +++ b/src/qs8-igemm/gen/2x4c8-minmax-xop-ld128.c @@ -17,6 +17,7 @@ #endif #include <xnnpack/igemm.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_2x4c8__xop_ld128( @@ -44,6 +45,7 @@ void xnn_qs8_igemm_minmax_ukernel_2x4c8__xop_ld128( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); int8_t* c0 = c; int8_t* c1 = (int8_t*) 
((uintptr_t) c0 + cm_stride); if XNN_UNPREDICTABLE(mr != 2) { diff --git a/src/qs8-igemm/gen/2x4c8-minmax-xop-ld64.c b/src/qs8-igemm/gen/2x4c8-minmax-xop-ld64.c index 70458bd04..02bbdf4b2 100644 --- a/src/qs8-igemm/gen/2x4c8-minmax-xop-ld64.c +++ b/src/qs8-igemm/gen/2x4c8-minmax-xop-ld64.c @@ -17,6 +17,7 @@ #endif #include <xnnpack/igemm.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_2x4c8__xop_ld64( @@ -44,6 +45,7 @@ void xnn_qs8_igemm_minmax_ukernel_2x4c8__xop_ld64( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); int8_t* c0 = c; int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride); if XNN_UNPREDICTABLE(mr != 2) { diff --git a/src/qs8-igemm/gen/2x8c16-minmax-neon-mlal-padal.c b/src/qs8-igemm/gen/2x8c16-minmax-neon-mlal-padal.c index c95128a67..bc9ba0ed6 100644 --- a/src/qs8-igemm/gen/2x8c16-minmax-neon-mlal-padal.c +++ b/src/qs8-igemm/gen/2x8c16-minmax-neon-mlal-padal.c @@ -11,8 +11,8 @@ #include <arm_neon.h> -#include <xnnpack/common.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_2x8c16__neon_mlal_padal( @@ -40,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_2x8c16__neon_mlal_padal( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 16); int8_t* c0 = c; int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride); if XNN_UNPREDICTABLE(mr != 2) { diff --git a/src/qs8-igemm/gen/2x8c2-minmax-neon-mlal-padal-dup.c b/src/qs8-igemm/gen/2x8c2-minmax-neon-mlal-padal-dup.c index 478e2c521..f48331d13 100644 --- a/src/qs8-igemm/gen/2x8c2-minmax-neon-mlal-padal-dup.c +++ b/src/qs8-igemm/gen/2x8c2-minmax-neon-mlal-padal-dup.c @@ -11,8 +11,8 @@ #include <arm_neon.h> -#include <xnnpack/common.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_2x8c2__neon_mlal_padal_dup( @@ -40,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_2x8c2__neon_mlal_padal_dup( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 2); int8_t* c0 = c; int8_t* c1 = (int8_t*) ((uintptr_t) 
c0 + cm_stride); if XNN_UNPREDICTABLE(mr != 2) { @@ -231,20 +232,6 @@ void xnn_qs8_igemm_minmax_ukernel_2x8c2__neon_mlal_padal_dup( vacc1x0123 = vpadalq_s16(vacc1x0123, vprod1x0123c2); const int16x8_t vprod1x4567c2 = vmull_s8(vb4567c2, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 2))); vacc1x4567 = vpadalq_s16(vacc1x4567, vprod1x4567c2); - - if (k > 6 * sizeof(int8_t)) { - const int8x8_t vb0123c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb4567c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - - const int16x8_t vprod0x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0x0123 = vpadalq_s16(vacc0x0123, vprod0x0123c3); - const int16x8_t vprod0x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0x4567 = vpadalq_s16(vacc0x4567, vprod0x4567c3); - const int16x8_t vprod1x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3))); - vacc1x0123 = vpadalq_s16(vacc1x0123, vprod1x0123c3); - const int16x8_t vprod1x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3))); - vacc1x4567 = vpadalq_s16(vacc1x4567, vprod1x4567c3); - } } } } diff --git a/src/qs8-igemm/gen/2x8c2-minmax-neon-mull-padal-dup.c b/src/qs8-igemm/gen/2x8c2-minmax-neon-mull-padal-dup.c index 03fa52dae..517813235 100644 --- a/src/qs8-igemm/gen/2x8c2-minmax-neon-mull-padal-dup.c +++ b/src/qs8-igemm/gen/2x8c2-minmax-neon-mull-padal-dup.c @@ -11,8 +11,8 @@ #include <arm_neon.h> -#include <xnnpack/common.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_2x8c2__neon_mull_padal_dup( @@ -40,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_2x8c2__neon_mull_padal_dup( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 2); int8_t* c0 = c; int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride); if XNN_UNPREDICTABLE(mr != 2) { @@ 
-157,20 +158,6 @@ void xnn_qs8_igemm_minmax_ukernel_2x8c2__neon_mull_padal_dup( vacc1x0123 = vpadalq_s16(vacc1x0123, vprod1x0123c2); const int16x8_t vprod1x4567c2 = vmull_s8(vb4567c2, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 2))); vacc1x4567 = vpadalq_s16(vacc1x4567, vprod1x4567c2); - - if (k > 6 * sizeof(int8_t)) { - const int8x8_t vb0123c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb4567c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - - const int16x8_t vprod0x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0x0123 = vpadalq_s16(vacc0x0123, vprod0x0123c3); - const int16x8_t vprod0x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0x4567 = vpadalq_s16(vacc0x4567, vprod0x4567c3); - const int16x8_t vprod1x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3))); - vacc1x0123 = vpadalq_s16(vacc1x0123, vprod1x0123c3); - const int16x8_t vprod1x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3))); - vacc1x4567 = vpadalq_s16(vacc1x4567, vprod1x4567c3); - } } } } diff --git a/src/qs8-igemm/gen/2x8c8-minmax-avx2.c b/src/qs8-igemm/gen/2x8c8-minmax-avx2.c index d3060c342..92918457d 100644 --- a/src/qs8-igemm/gen/2x8c8-minmax-avx2.c +++ b/src/qs8-igemm/gen/2x8c8-minmax-avx2.c @@ -13,6 +13,7 @@ #include <xnnpack/igemm.h> #include <xnnpack/intrinsics-polyfill.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_2x8c8__avx2( @@ -40,6 +41,7 @@ void xnn_qs8_igemm_minmax_ukernel_2x8c8__avx2( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); int8_t* c0 = c; int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride); if XNN_UNPREDICTABLE(mr != 2) { diff --git a/src/qs8-igemm/gen/2x8c8-minmax-neon-mull-padal.c b/src/qs8-igemm/gen/2x8c8-minmax-neon-mull-padal.c index 9da18a4d3..875fe8eae 100644 --- 
a/src/qs8-igemm/gen/2x8c8-minmax-neon-mull-padal.c +++ b/src/qs8-igemm/gen/2x8c8-minmax-neon-mull-padal.c @@ -11,8 +11,8 @@ #include <arm_neon.h> -#include <xnnpack/common.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mull_padal( @@ -40,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mull_padal( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); int8_t* c0 = c; int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride); if XNN_UNPREDICTABLE(mr != 2) { @@ -153,9 +154,6 @@ void xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mull_padal( k -= 16 * sizeof(int8_t); } // Handle up to 8 final positions of `k` - // If kc was 0 or 16, there is no remainder. k is 0. - // If kc was 1 to 8, there is a remainder of k. - // If kc was 9 to 15, the main loop handled the remainder; k underflowed. if XNN_UNLIKELY(k > 0) { const int8x8_t va0 = vld1_s8(a0); const int8x8_t va1 = vld1_s8(a1); diff --git a/src/qs8-igemm/gen/3x16c16-minmax-neon-mlal-padal.c b/src/qs8-igemm/gen/3x16c16-minmax-neon-mlal-padal.c index 70f26b24a..2b9bfae53 100644 --- a/src/qs8-igemm/gen/3x16c16-minmax-neon-mlal-padal.c +++ b/src/qs8-igemm/gen/3x16c16-minmax-neon-mlal-padal.c @@ -11,8 +11,8 @@ #include <arm_neon.h> -#include <xnnpack/common.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_3x16c16__neon_mlal_padal( @@ -40,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_3x16c16__neon_mlal_padal( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 16); int8_t* c0 = c; int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride); if XNN_UNPREDICTABLE(mr < 2) { diff --git a/src/qs8-igemm/gen/3x16c2-minmax-neon-mlal-padal-dup.c b/src/qs8-igemm/gen/3x16c2-minmax-neon-mlal-padal-dup.c index 6d304e2de..5fa8f78d4 100644 --- a/src/qs8-igemm/gen/3x16c2-minmax-neon-mlal-padal-dup.c +++ b/src/qs8-igemm/gen/3x16c2-minmax-neon-mlal-padal-dup.c @@ -11,8 +11,8 @@ #include <arm_neon.h> -#include <xnnpack/common.h> #include 
<xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_3x16c2__neon_mlal_padal_dup( @@ -40,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_3x16c2__neon_mlal_padal_dup( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 2); int8_t* c0 = c; int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride); if XNN_UNPREDICTABLE(mr < 2) { @@ -489,38 +490,6 @@ void xnn_qs8_igemm_minmax_ukernel_3x16c2__neon_mlal_padal_dup( vacc2x89AB = vpadalq_s16(vacc2x89AB, vprod2x89ABc2); const int16x8_t vprod2xCDEFc2 = vmull_s8(vbCDEFc2, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 2))); vacc2xCDEF = vpadalq_s16(vacc2xCDEF, vprod2xCDEFc2); - - if (k > 6 * sizeof(int8_t)) { - const int8x8_t vb0123c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb4567c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb89ABc3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vbCDEFc3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - - const int16x8_t vprod0x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0x0123 = vpadalq_s16(vacc0x0123, vprod0x0123c3); - const int16x8_t vprod0x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0x4567 = vpadalq_s16(vacc0x4567, vprod0x4567c3); - const int16x8_t vprod0x89ABc3 = vmull_s8(vb89ABc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0x89AB = vpadalq_s16(vacc0x89AB, vprod0x89ABc3); - const int16x8_t vprod0xCDEFc3 = vmull_s8(vbCDEFc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0xCDEF = vpadalq_s16(vacc0xCDEF, vprod0xCDEFc3); - const int16x8_t vprod1x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3))); - vacc1x0123 = vpadalq_s16(vacc1x0123, vprod1x0123c3); - const int16x8_t vprod1x4567c3 = 
vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3))); - vacc1x4567 = vpadalq_s16(vacc1x4567, vprod1x4567c3); - const int16x8_t vprod1x89ABc3 = vmull_s8(vb89ABc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3))); - vacc1x89AB = vpadalq_s16(vacc1x89AB, vprod1x89ABc3); - const int16x8_t vprod1xCDEFc3 = vmull_s8(vbCDEFc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3))); - vacc1xCDEF = vpadalq_s16(vacc1xCDEF, vprod1xCDEFc3); - const int16x8_t vprod2x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3))); - vacc2x0123 = vpadalq_s16(vacc2x0123, vprod2x0123c3); - const int16x8_t vprod2x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3))); - vacc2x4567 = vpadalq_s16(vacc2x4567, vprod2x4567c3); - const int16x8_t vprod2x89ABc3 = vmull_s8(vb89ABc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3))); - vacc2x89AB = vpadalq_s16(vacc2x89AB, vprod2x89ABc3); - const int16x8_t vprod2xCDEFc3 = vmull_s8(vbCDEFc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3))); - vacc2xCDEF = vpadalq_s16(vacc2xCDEF, vprod2xCDEFc3); - } } } } diff --git a/src/qs8-igemm/gen/3x16c2-minmax-neon-mull-padal-dup.c b/src/qs8-igemm/gen/3x16c2-minmax-neon-mull-padal-dup.c index 53161f3c9..c971f6630 100644 --- a/src/qs8-igemm/gen/3x16c2-minmax-neon-mull-padal-dup.c +++ b/src/qs8-igemm/gen/3x16c2-minmax-neon-mull-padal-dup.c @@ -11,8 +11,8 @@ #include <arm_neon.h> -#include <xnnpack/common.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_3x16c2__neon_mull_padal_dup( @@ -40,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_3x16c2__neon_mull_padal_dup( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 2); int8_t* c0 = c; int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride); if XNN_UNPREDICTABLE(mr < 2) { @@ -301,38 +302,6 @@ void xnn_qs8_igemm_minmax_ukernel_3x16c2__neon_mull_padal_dup( vacc2x89AB = 
vpadalq_s16(vacc2x89AB, vprod2x89ABc2); const int16x8_t vprod2xCDEFc2 = vmull_s8(vbCDEFc2, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 2))); vacc2xCDEF = vpadalq_s16(vacc2xCDEF, vprod2xCDEFc2); - - if (k > 6 * sizeof(int8_t)) { - const int8x8_t vb0123c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb4567c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb89ABc3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vbCDEFc3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - - const int16x8_t vprod0x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0x0123 = vpadalq_s16(vacc0x0123, vprod0x0123c3); - const int16x8_t vprod0x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0x4567 = vpadalq_s16(vacc0x4567, vprod0x4567c3); - const int16x8_t vprod0x89ABc3 = vmull_s8(vb89ABc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0x89AB = vpadalq_s16(vacc0x89AB, vprod0x89ABc3); - const int16x8_t vprod0xCDEFc3 = vmull_s8(vbCDEFc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0xCDEF = vpadalq_s16(vacc0xCDEF, vprod0xCDEFc3); - const int16x8_t vprod1x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3))); - vacc1x0123 = vpadalq_s16(vacc1x0123, vprod1x0123c3); - const int16x8_t vprod1x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3))); - vacc1x4567 = vpadalq_s16(vacc1x4567, vprod1x4567c3); - const int16x8_t vprod1x89ABc3 = vmull_s8(vb89ABc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3))); - vacc1x89AB = vpadalq_s16(vacc1x89AB, vprod1x89ABc3); - const int16x8_t vprod1xCDEFc3 = vmull_s8(vbCDEFc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3))); - vacc1xCDEF = 
vpadalq_s16(vacc1xCDEF, vprod1xCDEFc3); - const int16x8_t vprod2x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3))); - vacc2x0123 = vpadalq_s16(vacc2x0123, vprod2x0123c3); - const int16x8_t vprod2x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3))); - vacc2x4567 = vpadalq_s16(vacc2x4567, vprod2x4567c3); - const int16x8_t vprod2x89ABc3 = vmull_s8(vb89ABc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3))); - vacc2x89AB = vpadalq_s16(vacc2x89AB, vprod2x89ABc3); - const int16x8_t vprod2xCDEFc3 = vmull_s8(vbCDEFc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3))); - vacc2xCDEF = vpadalq_s16(vacc2xCDEF, vprod2xCDEFc3); - } } } } diff --git a/src/qs8-igemm/gen/3x16c8-minmax-avx512skx.c b/src/qs8-igemm/gen/3x16c8-minmax-avx512skx.c index 3fe63e4b4..f9708b792 100644 --- a/src/qs8-igemm/gen/3x16c8-minmax-avx512skx.c +++ b/src/qs8-igemm/gen/3x16c8-minmax-avx512skx.c @@ -13,6 +13,7 @@ #include <xnnpack/igemm.h> #include <xnnpack/intrinsics-polyfill.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_3x16c8__avx512skx( @@ -38,6 +39,7 @@ void xnn_qs8_igemm_minmax_ukernel_3x16c8__avx512skx( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); int8_t* c0 = c; int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride); if XNN_UNPREDICTABLE(mr < 2) { diff --git a/src/qs8-igemm/gen/3x16c8-minmax-neon-mull-padal.c b/src/qs8-igemm/gen/3x16c8-minmax-neon-mull-padal.c index 97fd9d725..9a305da2a 100644 --- a/src/qs8-igemm/gen/3x16c8-minmax-neon-mull-padal.c +++ b/src/qs8-igemm/gen/3x16c8-minmax-neon-mull-padal.c @@ -11,8 +11,8 @@ #include <arm_neon.h> -#include <xnnpack/common.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mull_padal( @@ -40,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mull_padal( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); int8_t* c0 = c; int8_t* c1 = 
(int8_t*) ((uintptr_t) c0 + cm_stride); if XNN_UNPREDICTABLE(mr < 2) { @@ -307,9 +308,6 @@ void xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mull_padal( k -= 16 * sizeof(int8_t); } // Handle up to 8 final positions of `k` - // If kc was 0 or 16, there is no remainder. k is 0. - // If kc was 1 to 8, there is a remainder of k. - // If kc was 9 to 15, the main loop handled the remainder; k underflowed. if XNN_UNLIKELY(k > 0) { const int8x8_t va0 = vld1_s8(a0); const int8x8_t va1 = vld1_s8(a1); diff --git a/src/qs8-igemm/gen/3x4c8-minmax-sse2-ld128.c b/src/qs8-igemm/gen/3x4c8-minmax-sse2-ld128.c index dd940c01d..47dba04e7 100644 --- a/src/qs8-igemm/gen/3x4c8-minmax-sse2-ld128.c +++ b/src/qs8-igemm/gen/3x4c8-minmax-sse2-ld128.c @@ -12,6 +12,7 @@ #include <emmintrin.h> #include <xnnpack/igemm.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_3x4c8__sse2_ld128( @@ -39,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_3x4c8__sse2_ld128( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); int8_t* c0 = c; int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride); if XNN_UNPREDICTABLE(mr < 2) { diff --git a/src/qs8-igemm/gen/3x4c8-minmax-sse2-ld64.c b/src/qs8-igemm/gen/3x4c8-minmax-sse2-ld64.c index 4904a1cd8..eaffe043e 100644 --- a/src/qs8-igemm/gen/3x4c8-minmax-sse2-ld64.c +++ b/src/qs8-igemm/gen/3x4c8-minmax-sse2-ld64.c @@ -12,6 +12,7 @@ #include <emmintrin.h> #include <xnnpack/igemm.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_3x4c8__sse2_ld64( @@ -39,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_3x4c8__sse2_ld64( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); int8_t* c0 = c; int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride); if XNN_UNPREDICTABLE(mr < 2) { diff --git a/src/qs8-igemm/gen/3x4c8-minmax-sse41-ld128.c b/src/qs8-igemm/gen/3x4c8-minmax-sse41-ld128.c index 14d7f8c5e..b4c8d56d9 100644 --- a/src/qs8-igemm/gen/3x4c8-minmax-sse41-ld128.c +++ b/src/qs8-igemm/gen/3x4c8-minmax-sse41-ld128.c @@ -12,6 +12,7 @@ #include 
<smmintrin.h> #include <xnnpack/igemm.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_3x4c8__sse41_ld128( @@ -39,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_3x4c8__sse41_ld128( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); int8_t* c0 = c; int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride); if XNN_UNPREDICTABLE(mr < 2) { diff --git a/src/qs8-igemm/gen/3x4c8-minmax-sse41-ld64.c b/src/qs8-igemm/gen/3x4c8-minmax-sse41-ld64.c index 545038760..6092d915f 100644 --- a/src/qs8-igemm/gen/3x4c8-minmax-sse41-ld64.c +++ b/src/qs8-igemm/gen/3x4c8-minmax-sse41-ld64.c @@ -12,6 +12,7 @@ #include <smmintrin.h> #include <xnnpack/igemm.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_3x4c8__sse41_ld64( @@ -39,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_3x4c8__sse41_ld64( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); int8_t* c0 = c; int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride); if XNN_UNPREDICTABLE(mr < 2) { diff --git a/src/qs8-igemm/gen/3x4c8-minmax-ssse3-ld128.c b/src/qs8-igemm/gen/3x4c8-minmax-ssse3-ld128.c index 3f3aefadf..f428f767b 100644 --- a/src/qs8-igemm/gen/3x4c8-minmax-ssse3-ld128.c +++ b/src/qs8-igemm/gen/3x4c8-minmax-ssse3-ld128.c @@ -12,6 +12,7 @@ #include <tmmintrin.h> #include <xnnpack/igemm.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_3x4c8__ssse3_ld128( @@ -39,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_3x4c8__ssse3_ld128( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); int8_t* c0 = c; int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride); if XNN_UNPREDICTABLE(mr < 2) { diff --git a/src/qs8-igemm/gen/3x4c8-minmax-ssse3-ld64.c b/src/qs8-igemm/gen/3x4c8-minmax-ssse3-ld64.c index 8cca813fc..141be7043 100644 --- a/src/qs8-igemm/gen/3x4c8-minmax-ssse3-ld64.c +++ b/src/qs8-igemm/gen/3x4c8-minmax-ssse3-ld64.c @@ -12,6 +12,7 @@ #include <tmmintrin.h> #include <xnnpack/igemm.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_3x4c8__ssse3_ld64( @@ 
-39,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_3x4c8__ssse3_ld64( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); int8_t* c0 = c; int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride); if XNN_UNPREDICTABLE(mr < 2) { diff --git a/src/qs8-igemm/gen/3x4c8-minmax-wasmsimd-ld128.c b/src/qs8-igemm/gen/3x4c8-minmax-wasmsimd-ld128.c index 0c95932f1..f232658ab 100644 --- a/src/qs8-igemm/gen/3x4c8-minmax-wasmsimd-ld128.c +++ b/src/qs8-igemm/gen/3x4c8-minmax-wasmsimd-ld128.c @@ -12,6 +12,7 @@ #include <wasm_simd128.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_3x4c8__wasmsimd_ld128( @@ -39,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_3x4c8__wasmsimd_ld128( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); int8_t* c0 = c; int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride); if XNN_UNPREDICTABLE(mr < 2) { diff --git a/src/qs8-igemm/gen/3x4c8-minmax-wasmsimd-ld64.c b/src/qs8-igemm/gen/3x4c8-minmax-wasmsimd-ld64.c index 7f22e3d64..1b9efcc74 100644 --- a/src/qs8-igemm/gen/3x4c8-minmax-wasmsimd-ld64.c +++ b/src/qs8-igemm/gen/3x4c8-minmax-wasmsimd-ld64.c @@ -12,6 +12,7 @@ #include <wasm_simd128.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_3x4c8__wasmsimd_ld64( @@ -39,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_3x4c8__wasmsimd_ld64( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); int8_t* c0 = c; int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride); if XNN_UNPREDICTABLE(mr < 2) { diff --git a/src/qs8-igemm/gen/3x4c8-minmax-xop-ld128.c b/src/qs8-igemm/gen/3x4c8-minmax-xop-ld128.c index 1760fc04f..4adc8c103 100644 --- a/src/qs8-igemm/gen/3x4c8-minmax-xop-ld128.c +++ b/src/qs8-igemm/gen/3x4c8-minmax-xop-ld128.c @@ -17,6 +17,7 @@ #endif #include <xnnpack/igemm.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_3x4c8__xop_ld128( @@ -44,6 +45,7 @@ void xnn_qs8_igemm_minmax_ukernel_3x4c8__xop_ld128( assert(w != NULL); assert(c != NULL); + 
kc = round_up_po2(kc, 8); int8_t* c0 = c; int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride); if XNN_UNPREDICTABLE(mr < 2) { diff --git a/src/qs8-igemm/gen/3x4c8-minmax-xop-ld64.c b/src/qs8-igemm/gen/3x4c8-minmax-xop-ld64.c index c215aa93c..6bb8aaa16 100644 --- a/src/qs8-igemm/gen/3x4c8-minmax-xop-ld64.c +++ b/src/qs8-igemm/gen/3x4c8-minmax-xop-ld64.c @@ -17,6 +17,7 @@ #endif #include <xnnpack/igemm.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_3x4c8__xop_ld64( @@ -44,6 +45,7 @@ void xnn_qs8_igemm_minmax_ukernel_3x4c8__xop_ld64( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); int8_t* c0 = c; int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride); if XNN_UNPREDICTABLE(mr < 2) { diff --git a/src/qs8-igemm/gen/3x8c16-minmax-neon-mlal-padal.c b/src/qs8-igemm/gen/3x8c16-minmax-neon-mlal-padal.c index 17ddc358a..626dfa426 100644 --- a/src/qs8-igemm/gen/3x8c16-minmax-neon-mlal-padal.c +++ b/src/qs8-igemm/gen/3x8c16-minmax-neon-mlal-padal.c @@ -11,8 +11,8 @@ #include <arm_neon.h> -#include <xnnpack/common.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_3x8c16__neon_mlal_padal( @@ -40,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_3x8c16__neon_mlal_padal( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 16); int8_t* c0 = c; int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride); if XNN_UNPREDICTABLE(mr < 2) { diff --git a/src/qs8-igemm/gen/3x8c2-minmax-neon-mlal-padal-dup.c b/src/qs8-igemm/gen/3x8c2-minmax-neon-mlal-padal-dup.c index fac3944ec..bb5b54f4e 100644 --- a/src/qs8-igemm/gen/3x8c2-minmax-neon-mlal-padal-dup.c +++ b/src/qs8-igemm/gen/3x8c2-minmax-neon-mlal-padal-dup.c @@ -11,8 +11,8 @@ #include <arm_neon.h> -#include <xnnpack/common.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_3x8c2__neon_mlal_padal_dup( @@ -40,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_3x8c2__neon_mlal_padal_dup( assert(w != NULL); assert(c != NULL); + kc = 
round_up_po2(kc, 2); int8_t* c0 = c; int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride); if XNN_UNPREDICTABLE(mr < 2) { @@ -297,24 +298,6 @@ void xnn_qs8_igemm_minmax_ukernel_3x8c2__neon_mlal_padal_dup( vacc2x0123 = vpadalq_s16(vacc2x0123, vprod2x0123c2); const int16x8_t vprod2x4567c2 = vmull_s8(vb4567c2, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 2))); vacc2x4567 = vpadalq_s16(vacc2x4567, vprod2x4567c2); - - if (k > 6 * sizeof(int8_t)) { - const int8x8_t vb0123c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb4567c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - - const int16x8_t vprod0x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0x0123 = vpadalq_s16(vacc0x0123, vprod0x0123c3); - const int16x8_t vprod0x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0x4567 = vpadalq_s16(vacc0x4567, vprod0x4567c3); - const int16x8_t vprod1x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3))); - vacc1x0123 = vpadalq_s16(vacc1x0123, vprod1x0123c3); - const int16x8_t vprod1x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3))); - vacc1x4567 = vpadalq_s16(vacc1x4567, vprod1x4567c3); - const int16x8_t vprod2x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3))); - vacc2x0123 = vpadalq_s16(vacc2x0123, vprod2x0123c3); - const int16x8_t vprod2x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3))); - vacc2x4567 = vpadalq_s16(vacc2x4567, vprod2x4567c3); - } } } } diff --git a/src/qs8-igemm/gen/3x8c2-minmax-neon-mull-padal-dup.c b/src/qs8-igemm/gen/3x8c2-minmax-neon-mull-padal-dup.c index 3f7f6b31b..0177c1848 100644 --- a/src/qs8-igemm/gen/3x8c2-minmax-neon-mull-padal-dup.c +++ b/src/qs8-igemm/gen/3x8c2-minmax-neon-mull-padal-dup.c @@ -11,8 +11,8 @@ 
#include <arm_neon.h> -#include <xnnpack/common.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_3x8c2__neon_mull_padal_dup( @@ -40,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_3x8c2__neon_mull_padal_dup( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 2); int8_t* c0 = c; int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride); if XNN_UNPREDICTABLE(mr < 2) { @@ -197,24 +198,6 @@ void xnn_qs8_igemm_minmax_ukernel_3x8c2__neon_mull_padal_dup( vacc2x0123 = vpadalq_s16(vacc2x0123, vprod2x0123c2); const int16x8_t vprod2x4567c2 = vmull_s8(vb4567c2, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 2))); vacc2x4567 = vpadalq_s16(vacc2x4567, vprod2x4567c2); - - if (k > 6 * sizeof(int8_t)) { - const int8x8_t vb0123c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb4567c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - - const int16x8_t vprod0x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0x0123 = vpadalq_s16(vacc0x0123, vprod0x0123c3); - const int16x8_t vprod0x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0x4567 = vpadalq_s16(vacc0x4567, vprod0x4567c3); - const int16x8_t vprod1x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3))); - vacc1x0123 = vpadalq_s16(vacc1x0123, vprod1x0123c3); - const int16x8_t vprod1x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3))); - vacc1x4567 = vpadalq_s16(vacc1x4567, vprod1x4567c3); - const int16x8_t vprod2x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3))); - vacc2x0123 = vpadalq_s16(vacc2x0123, vprod2x0123c3); - const int16x8_t vprod2x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3))); - vacc2x4567 = vpadalq_s16(vacc2x4567, vprod2x4567c3); - } } } } 
diff --git a/src/qs8-igemm/gen/3x8c8-minmax-avx2.c b/src/qs8-igemm/gen/3x8c8-minmax-avx2.c index 4269041fb..231c884e7 100644 --- a/src/qs8-igemm/gen/3x8c8-minmax-avx2.c +++ b/src/qs8-igemm/gen/3x8c8-minmax-avx2.c @@ -13,6 +13,7 @@ #include <xnnpack/igemm.h> #include <xnnpack/intrinsics-polyfill.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_3x8c8__avx2( @@ -40,6 +41,7 @@ void xnn_qs8_igemm_minmax_ukernel_3x8c8__avx2( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); int8_t* c0 = c; int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride); if XNN_UNPREDICTABLE(mr < 2) { diff --git a/src/qs8-igemm/gen/3x8c8-minmax-neon-mull-padal.c b/src/qs8-igemm/gen/3x8c8-minmax-neon-mull-padal.c index 0e70a8bcd..5c7810d00 100644 --- a/src/qs8-igemm/gen/3x8c8-minmax-neon-mull-padal.c +++ b/src/qs8-igemm/gen/3x8c8-minmax-neon-mull-padal.c @@ -11,8 +11,8 @@ #include <arm_neon.h> -#include <xnnpack/common.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mull_padal( @@ -40,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mull_padal( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); int8_t* c0 = c; int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride); if XNN_UNPREDICTABLE(mr < 2) { @@ -195,9 +196,6 @@ void xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mull_padal( k -= 16 * sizeof(int8_t); } // Handle up to 8 final positions of `k` - // If kc was 0 or 16, there is no remainder. k is 0. - // If kc was 1 to 8, there is a remainder of k. - // If kc was 9 to 15, the main loop handled the remainder; k underflowed. 
if XNN_UNLIKELY(k > 0) { const int8x8_t va0 = vld1_s8(a0); const int8x8_t va1 = vld1_s8(a1); diff --git a/src/qs8-igemm/gen/4x16c16-minmax-neon-mlal-padal.c b/src/qs8-igemm/gen/4x16c16-minmax-neon-mlal-padal.c index 8368cd23f..df63cae29 100644 --- a/src/qs8-igemm/gen/4x16c16-minmax-neon-mlal-padal.c +++ b/src/qs8-igemm/gen/4x16c16-minmax-neon-mlal-padal.c @@ -11,8 +11,8 @@ #include <arm_neon.h> -#include <xnnpack/common.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_4x16c16__neon_mlal_padal( @@ -40,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_4x16c16__neon_mlal_padal( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 16); int8_t* c0 = c; int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride); if XNN_UNPREDICTABLE(mr < 2) { diff --git a/src/qs8-igemm/gen/4x16c2-minmax-neon-mlal-padal-dup.c b/src/qs8-igemm/gen/4x16c2-minmax-neon-mlal-padal-dup.c index f6d24d228..ef26e6ac3 100644 --- a/src/qs8-igemm/gen/4x16c2-minmax-neon-mlal-padal-dup.c +++ b/src/qs8-igemm/gen/4x16c2-minmax-neon-mlal-padal-dup.c @@ -11,8 +11,8 @@ #include <arm_neon.h> -#include <xnnpack/common.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_4x16c2__neon_mlal_padal_dup( @@ -40,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_4x16c2__neon_mlal_padal_dup( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 2); int8_t* c0 = c; int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride); if XNN_UNPREDICTABLE(mr < 2) { @@ -609,46 +610,6 @@ void xnn_qs8_igemm_minmax_ukernel_4x16c2__neon_mlal_padal_dup( vacc3x89AB = vpadalq_s16(vacc3x89AB, vprod3x89ABc2); const int16x8_t vprod3xCDEFc2 = vmull_s8(vbCDEFc2, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va3), 2))); vacc3xCDEF = vpadalq_s16(vacc3xCDEF, vprod3xCDEFc2); - - if (k > 6 * sizeof(int8_t)) { - const int8x8_t vb0123c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb4567c3 = vld1_s8(w); w = (const void*) 
((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb89ABc3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vbCDEFc3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - - const int16x8_t vprod0x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0x0123 = vpadalq_s16(vacc0x0123, vprod0x0123c3); - const int16x8_t vprod0x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0x4567 = vpadalq_s16(vacc0x4567, vprod0x4567c3); - const int16x8_t vprod0x89ABc3 = vmull_s8(vb89ABc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0x89AB = vpadalq_s16(vacc0x89AB, vprod0x89ABc3); - const int16x8_t vprod0xCDEFc3 = vmull_s8(vbCDEFc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0xCDEF = vpadalq_s16(vacc0xCDEF, vprod0xCDEFc3); - const int16x8_t vprod1x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3))); - vacc1x0123 = vpadalq_s16(vacc1x0123, vprod1x0123c3); - const int16x8_t vprod1x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3))); - vacc1x4567 = vpadalq_s16(vacc1x4567, vprod1x4567c3); - const int16x8_t vprod1x89ABc3 = vmull_s8(vb89ABc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3))); - vacc1x89AB = vpadalq_s16(vacc1x89AB, vprod1x89ABc3); - const int16x8_t vprod1xCDEFc3 = vmull_s8(vbCDEFc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3))); - vacc1xCDEF = vpadalq_s16(vacc1xCDEF, vprod1xCDEFc3); - const int16x8_t vprod2x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3))); - vacc2x0123 = vpadalq_s16(vacc2x0123, vprod2x0123c3); - const int16x8_t vprod2x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3))); - vacc2x4567 = vpadalq_s16(vacc2x4567, vprod2x4567c3); - const int16x8_t 
vprod2x89ABc3 = vmull_s8(vb89ABc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3))); - vacc2x89AB = vpadalq_s16(vacc2x89AB, vprod2x89ABc3); - const int16x8_t vprod2xCDEFc3 = vmull_s8(vbCDEFc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3))); - vacc2xCDEF = vpadalq_s16(vacc2xCDEF, vprod2xCDEFc3); - const int16x8_t vprod3x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va3), 3))); - vacc3x0123 = vpadalq_s16(vacc3x0123, vprod3x0123c3); - const int16x8_t vprod3x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va3), 3))); - vacc3x4567 = vpadalq_s16(vacc3x4567, vprod3x4567c3); - const int16x8_t vprod3x89ABc3 = vmull_s8(vb89ABc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va3), 3))); - vacc3x89AB = vpadalq_s16(vacc3x89AB, vprod3x89ABc3); - const int16x8_t vprod3xCDEFc3 = vmull_s8(vbCDEFc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va3), 3))); - vacc3xCDEF = vpadalq_s16(vacc3xCDEF, vprod3xCDEFc3); - } } } } diff --git a/src/qs8-igemm/gen/4x16c2-minmax-neon-mull-padal-dup.c b/src/qs8-igemm/gen/4x16c2-minmax-neon-mull-padal-dup.c index d7a469ba5..6b8b47e6f 100644 --- a/src/qs8-igemm/gen/4x16c2-minmax-neon-mull-padal-dup.c +++ b/src/qs8-igemm/gen/4x16c2-minmax-neon-mull-padal-dup.c @@ -11,8 +11,8 @@ #include <arm_neon.h> -#include <xnnpack/common.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_4x16c2__neon_mull_padal_dup( @@ -40,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_4x16c2__neon_mull_padal_dup( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 2); int8_t* c0 = c; int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride); if XNN_UNPREDICTABLE(mr < 2) { @@ -371,46 +372,6 @@ void xnn_qs8_igemm_minmax_ukernel_4x16c2__neon_mull_padal_dup( vacc3x89AB = vpadalq_s16(vacc3x89AB, vprod3x89ABc2); const int16x8_t vprod3xCDEFc2 = vmull_s8(vbCDEFc2, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va3), 2))); 
vacc3xCDEF = vpadalq_s16(vacc3xCDEF, vprod3xCDEFc2); - - if (k > 6 * sizeof(int8_t)) { - const int8x8_t vb0123c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb4567c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb89ABc3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vbCDEFc3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - - const int16x8_t vprod0x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0x0123 = vpadalq_s16(vacc0x0123, vprod0x0123c3); - const int16x8_t vprod0x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0x4567 = vpadalq_s16(vacc0x4567, vprod0x4567c3); - const int16x8_t vprod0x89ABc3 = vmull_s8(vb89ABc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0x89AB = vpadalq_s16(vacc0x89AB, vprod0x89ABc3); - const int16x8_t vprod0xCDEFc3 = vmull_s8(vbCDEFc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0xCDEF = vpadalq_s16(vacc0xCDEF, vprod0xCDEFc3); - const int16x8_t vprod1x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3))); - vacc1x0123 = vpadalq_s16(vacc1x0123, vprod1x0123c3); - const int16x8_t vprod1x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3))); - vacc1x4567 = vpadalq_s16(vacc1x4567, vprod1x4567c3); - const int16x8_t vprod1x89ABc3 = vmull_s8(vb89ABc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3))); - vacc1x89AB = vpadalq_s16(vacc1x89AB, vprod1x89ABc3); - const int16x8_t vprod1xCDEFc3 = vmull_s8(vbCDEFc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3))); - vacc1xCDEF = vpadalq_s16(vacc1xCDEF, vprod1xCDEFc3); - const int16x8_t vprod2x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3))); - vacc2x0123 = 
vpadalq_s16(vacc2x0123, vprod2x0123c3); - const int16x8_t vprod2x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3))); - vacc2x4567 = vpadalq_s16(vacc2x4567, vprod2x4567c3); - const int16x8_t vprod2x89ABc3 = vmull_s8(vb89ABc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3))); - vacc2x89AB = vpadalq_s16(vacc2x89AB, vprod2x89ABc3); - const int16x8_t vprod2xCDEFc3 = vmull_s8(vbCDEFc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3))); - vacc2xCDEF = vpadalq_s16(vacc2xCDEF, vprod2xCDEFc3); - const int16x8_t vprod3x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va3), 3))); - vacc3x0123 = vpadalq_s16(vacc3x0123, vprod3x0123c3); - const int16x8_t vprod3x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va3), 3))); - vacc3x4567 = vpadalq_s16(vacc3x4567, vprod3x4567c3); - const int16x8_t vprod3x89ABc3 = vmull_s8(vb89ABc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va3), 3))); - vacc3x89AB = vpadalq_s16(vacc3x89AB, vprod3x89ABc3); - const int16x8_t vprod3xCDEFc3 = vmull_s8(vbCDEFc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va3), 3))); - vacc3xCDEF = vpadalq_s16(vacc3xCDEF, vprod3xCDEFc3); - } } } } diff --git a/src/qs8-igemm/gen/4x16c4-minmax-neondot.c b/src/qs8-igemm/gen/4x16c4-minmax-neondot.c index a31584133..95dde6d3b 100644 --- a/src/qs8-igemm/gen/4x16c4-minmax-neondot.c +++ b/src/qs8-igemm/gen/4x16c4-minmax-neondot.c @@ -11,8 +11,8 @@ #include <arm_neon.h> -#include <xnnpack/common.h> #include <xnnpack/igemm.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_4x16c4__neondot( @@ -40,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_4x16c4__neondot( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 4); int8_t* c0 = c; int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride); if XNN_UNPREDICTABLE(mr < 2) { @@ -148,7 +149,7 @@ void xnn_qs8_igemm_minmax_ukernel_4x16c4__neondot( k -= 8 * 
sizeof(int8_t); } - // Handle up to 7 final positions of `k` + // Handle up to 6 final positions of `k` if XNN_UNLIKELY(k != 0) { // Load a 4x4 block of activations. const int8x8_t va0x01234567 = vld1_s8(a0); diff --git a/src/qs8-igemm/gen/4x16c8-minmax-avx512skx.c b/src/qs8-igemm/gen/4x16c8-minmax-avx512skx.c index b16ed2c01..4f724c419 100644 --- a/src/qs8-igemm/gen/4x16c8-minmax-avx512skx.c +++ b/src/qs8-igemm/gen/4x16c8-minmax-avx512skx.c @@ -13,6 +13,7 @@ #include <xnnpack/igemm.h> #include <xnnpack/intrinsics-polyfill.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_4x16c8__avx512skx( @@ -38,6 +39,7 @@ void xnn_qs8_igemm_minmax_ukernel_4x16c8__avx512skx( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); int8_t* c0 = c; int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride); if XNN_UNPREDICTABLE(mr < 2) { diff --git a/src/qs8-igemm/gen/4x16c8-minmax-neon-mull-padal.c b/src/qs8-igemm/gen/4x16c8-minmax-neon-mull-padal.c index 9430f211c..b03befc19 100644 --- a/src/qs8-igemm/gen/4x16c8-minmax-neon-mull-padal.c +++ b/src/qs8-igemm/gen/4x16c8-minmax-neon-mull-padal.c @@ -11,8 +11,8 @@ #include <arm_neon.h> -#include <xnnpack/common.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mull_padal( @@ -40,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mull_padal( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); int8_t* c0 = c; int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride); if XNN_UNPREDICTABLE(mr < 2) { @@ -381,9 +382,6 @@ void xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mull_padal( k -= 16 * sizeof(int8_t); } // Handle up to 8 final positions of `k` - // If kc was 0 or 16, there is no remainder. k is 0. - // If kc was 1 to 8, there is a remainder of k. - // If kc was 9 to 15, the main loop handled the remainder; k underflowed. 
if XNN_UNLIKELY(k > 0) { const int8x8_t va0 = vld1_s8(a0); const int8x8_t va1 = vld1_s8(a1); diff --git a/src/qs8-igemm/gen/4x4c2-minmax-sse2-ld128.c b/src/qs8-igemm/gen/4x4c2-minmax-sse2-ld128.c index 89c7807e0..012a594c2 100644 --- a/src/qs8-igemm/gen/4x4c2-minmax-sse2-ld128.c +++ b/src/qs8-igemm/gen/4x4c2-minmax-sse2-ld128.c @@ -12,6 +12,7 @@ #include <emmintrin.h> #include <xnnpack/igemm.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_4x4c2__sse2_ld128( @@ -39,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_4x4c2__sse2_ld128( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 2); int8_t* c0 = c; int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride); if XNN_UNPREDICTABLE(mr < 2) { @@ -197,21 +199,6 @@ void xnn_qs8_igemm_minmax_ukernel_4x4c2__sse2_ld128( _mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(2, 2, 2, 2)), vxb2)); vacc3x0123 = _mm_add_epi32(vacc3x0123, _mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(2, 2, 2, 2)), vxb2)); - - if (k > 6 * sizeof(int8_t)) { - const __m128i vb3 = _mm_loadl_epi64((const __m128i*) w); - w = (const void*) ((uintptr_t) w + 8); - const __m128i vxb3 = _mm_unpacklo_epi8(vb3, _mm_cmpgt_epi8(_mm_setzero_si128(), vb3)); - - vacc0x0123 = _mm_add_epi32(vacc0x0123, - _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); - vacc1x0123 = _mm_add_epi32(vacc1x0123, - _mm_madd_epi16(_mm_shuffle_epi32(vxa1, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); - vacc2x0123 = _mm_add_epi32(vacc2x0123, - _mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); - vacc3x0123 = _mm_add_epi32(vacc3x0123, - _mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); - } } } } diff --git a/src/qs8-igemm/gen/4x4c2-minmax-sse2-ld64.c b/src/qs8-igemm/gen/4x4c2-minmax-sse2-ld64.c index 7b5ce73f6..3d83c87e9 100644 --- a/src/qs8-igemm/gen/4x4c2-minmax-sse2-ld64.c +++ b/src/qs8-igemm/gen/4x4c2-minmax-sse2-ld64.c @@ -12,6 +12,7 @@ #include <emmintrin.h> #include <xnnpack/igemm.h> +#include 
<xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_4x4c2__sse2_ld64( @@ -39,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_4x4c2__sse2_ld64( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 2); int8_t* c0 = c; int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride); if XNN_UNPREDICTABLE(mr < 2) { @@ -197,21 +199,6 @@ void xnn_qs8_igemm_minmax_ukernel_4x4c2__sse2_ld64( _mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(2, 2, 2, 2)), vxb2)); vacc3x0123 = _mm_add_epi32(vacc3x0123, _mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(2, 2, 2, 2)), vxb2)); - - if (k > 6 * sizeof(int8_t)) { - const __m128i vb3 = _mm_loadl_epi64((const __m128i*) w); - w = (const void*) ((uintptr_t) w + 8); - const __m128i vxb3 = _mm_unpacklo_epi8(vb3, _mm_cmpgt_epi8(_mm_setzero_si128(), vb3)); - - vacc0x0123 = _mm_add_epi32(vacc0x0123, - _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); - vacc1x0123 = _mm_add_epi32(vacc1x0123, - _mm_madd_epi16(_mm_shuffle_epi32(vxa1, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); - vacc2x0123 = _mm_add_epi32(vacc2x0123, - _mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); - vacc3x0123 = _mm_add_epi32(vacc3x0123, - _mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); - } } } } diff --git a/src/qs8-igemm/gen/4x4c2-minmax-sse41-ld128.c b/src/qs8-igemm/gen/4x4c2-minmax-sse41-ld128.c index 6fc97ccf7..d0a7c2a7c 100644 --- a/src/qs8-igemm/gen/4x4c2-minmax-sse41-ld128.c +++ b/src/qs8-igemm/gen/4x4c2-minmax-sse41-ld128.c @@ -12,6 +12,7 @@ #include <smmintrin.h> #include <xnnpack/igemm.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_4x4c2__sse41_ld128( @@ -39,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_4x4c2__sse41_ld128( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 2); int8_t* c0 = c; int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride); if XNN_UNPREDICTABLE(mr < 2) { @@ -197,21 +199,6 @@ void xnn_qs8_igemm_minmax_ukernel_4x4c2__sse41_ld128( 
_mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(2, 2, 2, 2)), vxb2)); vacc3x0123 = _mm_add_epi32(vacc3x0123, _mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(2, 2, 2, 2)), vxb2)); - - if (k > 6 * sizeof(int8_t)) { - const __m128i vb3 = _mm_loadl_epi64((const __m128i*) w); - w = (const void*) ((uintptr_t) w + 8); - const __m128i vxb3 = _mm_unpacklo_epi8(vb3, _mm_cmpgt_epi8(_mm_setzero_si128(), vb3)); - - vacc0x0123 = _mm_add_epi32(vacc0x0123, - _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); - vacc1x0123 = _mm_add_epi32(vacc1x0123, - _mm_madd_epi16(_mm_shuffle_epi32(vxa1, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); - vacc2x0123 = _mm_add_epi32(vacc2x0123, - _mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); - vacc3x0123 = _mm_add_epi32(vacc3x0123, - _mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); - } } } } diff --git a/src/qs8-igemm/gen/4x4c2-minmax-sse41-ld64.c b/src/qs8-igemm/gen/4x4c2-minmax-sse41-ld64.c index 4978b7b86..9c34aa337 100644 --- a/src/qs8-igemm/gen/4x4c2-minmax-sse41-ld64.c +++ b/src/qs8-igemm/gen/4x4c2-minmax-sse41-ld64.c @@ -12,6 +12,7 @@ #include <smmintrin.h> #include <xnnpack/igemm.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_4x4c2__sse41_ld64( @@ -39,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_4x4c2__sse41_ld64( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 2); int8_t* c0 = c; int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride); if XNN_UNPREDICTABLE(mr < 2) { @@ -197,21 +199,6 @@ void xnn_qs8_igemm_minmax_ukernel_4x4c2__sse41_ld64( _mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(2, 2, 2, 2)), vxb2)); vacc3x0123 = _mm_add_epi32(vacc3x0123, _mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(2, 2, 2, 2)), vxb2)); - - if (k > 6 * sizeof(int8_t)) { - const __m128i vb3 = _mm_loadl_epi64((const __m128i*) w); - w = (const void*) ((uintptr_t) w + 8); - const __m128i vxb3 = _mm_unpacklo_epi8(vb3, _mm_cmpgt_epi8(_mm_setzero_si128(), vb3)); - - 
vacc0x0123 = _mm_add_epi32(vacc0x0123, - _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); - vacc1x0123 = _mm_add_epi32(vacc1x0123, - _mm_madd_epi16(_mm_shuffle_epi32(vxa1, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); - vacc2x0123 = _mm_add_epi32(vacc2x0123, - _mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); - vacc3x0123 = _mm_add_epi32(vacc3x0123, - _mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); - } } } } diff --git a/src/qs8-igemm/gen/4x4c2-minmax-ssse3-ld128.c b/src/qs8-igemm/gen/4x4c2-minmax-ssse3-ld128.c index 82cfd1f36..9c1304ffe 100644 --- a/src/qs8-igemm/gen/4x4c2-minmax-ssse3-ld128.c +++ b/src/qs8-igemm/gen/4x4c2-minmax-ssse3-ld128.c @@ -12,6 +12,7 @@ #include <tmmintrin.h> #include <xnnpack/igemm.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_4x4c2__ssse3_ld128( @@ -39,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_4x4c2__ssse3_ld128( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 2); int8_t* c0 = c; int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride); if XNN_UNPREDICTABLE(mr < 2) { @@ -197,21 +199,6 @@ void xnn_qs8_igemm_minmax_ukernel_4x4c2__ssse3_ld128( _mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(2, 2, 2, 2)), vxb2)); vacc3x0123 = _mm_add_epi32(vacc3x0123, _mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(2, 2, 2, 2)), vxb2)); - - if (k > 6 * sizeof(int8_t)) { - const __m128i vb3 = _mm_loadl_epi64((const __m128i*) w); - w = (const void*) ((uintptr_t) w + 8); - const __m128i vxb3 = _mm_unpacklo_epi8(vb3, _mm_cmpgt_epi8(_mm_setzero_si128(), vb3)); - - vacc0x0123 = _mm_add_epi32(vacc0x0123, - _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); - vacc1x0123 = _mm_add_epi32(vacc1x0123, - _mm_madd_epi16(_mm_shuffle_epi32(vxa1, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); - vacc2x0123 = _mm_add_epi32(vacc2x0123, - _mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); - vacc3x0123 = _mm_add_epi32(vacc3x0123, - 
_mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); - } } } } diff --git a/src/qs8-igemm/gen/4x4c2-minmax-ssse3-ld64.c b/src/qs8-igemm/gen/4x4c2-minmax-ssse3-ld64.c index 8a2def13f..2d56cdffb 100644 --- a/src/qs8-igemm/gen/4x4c2-minmax-ssse3-ld64.c +++ b/src/qs8-igemm/gen/4x4c2-minmax-ssse3-ld64.c @@ -12,6 +12,7 @@ #include <tmmintrin.h> #include <xnnpack/igemm.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_4x4c2__ssse3_ld64( @@ -39,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_4x4c2__ssse3_ld64( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 2); int8_t* c0 = c; int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride); if XNN_UNPREDICTABLE(mr < 2) { @@ -197,21 +199,6 @@ void xnn_qs8_igemm_minmax_ukernel_4x4c2__ssse3_ld64( _mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(2, 2, 2, 2)), vxb2)); vacc3x0123 = _mm_add_epi32(vacc3x0123, _mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(2, 2, 2, 2)), vxb2)); - - if (k > 6 * sizeof(int8_t)) { - const __m128i vb3 = _mm_loadl_epi64((const __m128i*) w); - w = (const void*) ((uintptr_t) w + 8); - const __m128i vxb3 = _mm_unpacklo_epi8(vb3, _mm_cmpgt_epi8(_mm_setzero_si128(), vb3)); - - vacc0x0123 = _mm_add_epi32(vacc0x0123, - _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); - vacc1x0123 = _mm_add_epi32(vacc1x0123, - _mm_madd_epi16(_mm_shuffle_epi32(vxa1, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); - vacc2x0123 = _mm_add_epi32(vacc2x0123, - _mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); - vacc3x0123 = _mm_add_epi32(vacc3x0123, - _mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); - } } } } diff --git a/src/qs8-igemm/gen/4x4c2-minmax-xop-ld128.c b/src/qs8-igemm/gen/4x4c2-minmax-xop-ld128.c index b1b68800a..e6faba23b 100644 --- a/src/qs8-igemm/gen/4x4c2-minmax-xop-ld128.c +++ b/src/qs8-igemm/gen/4x4c2-minmax-xop-ld128.c @@ -17,6 +17,7 @@ #endif #include <xnnpack/igemm.h> +#include <xnnpack/math.h> void 
xnn_qs8_igemm_minmax_ukernel_4x4c2__xop_ld128( @@ -44,6 +45,7 @@ void xnn_qs8_igemm_minmax_ukernel_4x4c2__xop_ld128( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 2); int8_t* c0 = c; int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride); if XNN_UNPREDICTABLE(mr < 2) { @@ -202,21 +204,6 @@ void xnn_qs8_igemm_minmax_ukernel_4x4c2__xop_ld128( _mm_shuffle_epi32(vxa2, _MM_SHUFFLE(2, 2, 2, 2)), vxb2, vacc2x0123); vacc3x0123 = _mm_maddd_epi16( _mm_shuffle_epi32(vxa3, _MM_SHUFFLE(2, 2, 2, 2)), vxb2, vacc3x0123); - - if (k > 6 * sizeof(int8_t)) { - const __m128i vb3 = _mm_loadl_epi64((const __m128i*) w); - w = (const void*) ((uintptr_t) w + 8); - const __m128i vxb3 = _mm_unpacklo_epi8(vb3, _mm_cmpgt_epi8(_mm_setzero_si128(), vb3)); - - vacc0x0123 = _mm_maddd_epi16( - _mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3, vacc0x0123); - vacc1x0123 = _mm_maddd_epi16( - _mm_shuffle_epi32(vxa1, _MM_SHUFFLE(3, 3, 3, 3)), vxb3, vacc1x0123); - vacc2x0123 = _mm_maddd_epi16( - _mm_shuffle_epi32(vxa2, _MM_SHUFFLE(3, 3, 3, 3)), vxb3, vacc2x0123); - vacc3x0123 = _mm_maddd_epi16( - _mm_shuffle_epi32(vxa3, _MM_SHUFFLE(3, 3, 3, 3)), vxb3, vacc3x0123); - } } } } diff --git a/src/qs8-igemm/gen/4x4c2-minmax-xop-ld64.c b/src/qs8-igemm/gen/4x4c2-minmax-xop-ld64.c index 7f3c7786a..0253dfa95 100644 --- a/src/qs8-igemm/gen/4x4c2-minmax-xop-ld64.c +++ b/src/qs8-igemm/gen/4x4c2-minmax-xop-ld64.c @@ -17,6 +17,7 @@ #endif #include <xnnpack/igemm.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_4x4c2__xop_ld64( @@ -44,6 +45,7 @@ void xnn_qs8_igemm_minmax_ukernel_4x4c2__xop_ld64( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 2); int8_t* c0 = c; int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride); if XNN_UNPREDICTABLE(mr < 2) { @@ -202,21 +204,6 @@ void xnn_qs8_igemm_minmax_ukernel_4x4c2__xop_ld64( _mm_shuffle_epi32(vxa2, _MM_SHUFFLE(2, 2, 2, 2)), vxb2, vacc2x0123); vacc3x0123 = _mm_maddd_epi16( _mm_shuffle_epi32(vxa3, _MM_SHUFFLE(2, 2, 2, 2)), vxb2, 
vacc3x0123); - - if (k > 6 * sizeof(int8_t)) { - const __m128i vb3 = _mm_loadl_epi64((const __m128i*) w); - w = (const void*) ((uintptr_t) w + 8); - const __m128i vxb3 = _mm_unpacklo_epi8(vb3, _mm_cmpgt_epi8(_mm_setzero_si128(), vb3)); - - vacc0x0123 = _mm_maddd_epi16( - _mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3, vacc0x0123); - vacc1x0123 = _mm_maddd_epi16( - _mm_shuffle_epi32(vxa1, _MM_SHUFFLE(3, 3, 3, 3)), vxb3, vacc1x0123); - vacc2x0123 = _mm_maddd_epi16( - _mm_shuffle_epi32(vxa2, _MM_SHUFFLE(3, 3, 3, 3)), vxb3, vacc2x0123); - vacc3x0123 = _mm_maddd_epi16( - _mm_shuffle_epi32(vxa3, _MM_SHUFFLE(3, 3, 3, 3)), vxb3, vacc3x0123); - } } } } diff --git a/src/qs8-igemm/gen/4x8c16-minmax-neon-mlal-padal.c b/src/qs8-igemm/gen/4x8c16-minmax-neon-mlal-padal.c index 948638121..ef871fbd8 100644 --- a/src/qs8-igemm/gen/4x8c16-minmax-neon-mlal-padal.c +++ b/src/qs8-igemm/gen/4x8c16-minmax-neon-mlal-padal.c @@ -11,8 +11,8 @@ #include <arm_neon.h> -#include <xnnpack/common.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_4x8c16__neon_mlal_padal( @@ -40,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_4x8c16__neon_mlal_padal( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 16); int8_t* c0 = c; int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride); if XNN_UNPREDICTABLE(mr < 2) { diff --git a/src/qs8-igemm/gen/4x8c2-minmax-neon-mlal-padal-dup.c b/src/qs8-igemm/gen/4x8c2-minmax-neon-mlal-padal-dup.c index d1353543a..80aa9ac2a 100644 --- a/src/qs8-igemm/gen/4x8c2-minmax-neon-mlal-padal-dup.c +++ b/src/qs8-igemm/gen/4x8c2-minmax-neon-mlal-padal-dup.c @@ -11,8 +11,8 @@ #include <arm_neon.h> -#include <xnnpack/common.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_4x8c2__neon_mlal_padal_dup( @@ -40,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_4x8c2__neon_mlal_padal_dup( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 2); int8_t* c0 = c; int8_t* c1 = (int8_t*) 
((uintptr_t) c0 + cm_stride); if XNN_UNPREDICTABLE(mr < 2) { @@ -363,28 +364,6 @@ void xnn_qs8_igemm_minmax_ukernel_4x8c2__neon_mlal_padal_dup( vacc3x0123 = vpadalq_s16(vacc3x0123, vprod3x0123c2); const int16x8_t vprod3x4567c2 = vmull_s8(vb4567c2, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va3), 2))); vacc3x4567 = vpadalq_s16(vacc3x4567, vprod3x4567c2); - - if (k > 6 * sizeof(int8_t)) { - const int8x8_t vb0123c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb4567c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - - const int16x8_t vprod0x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0x0123 = vpadalq_s16(vacc0x0123, vprod0x0123c3); - const int16x8_t vprod0x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0x4567 = vpadalq_s16(vacc0x4567, vprod0x4567c3); - const int16x8_t vprod1x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3))); - vacc1x0123 = vpadalq_s16(vacc1x0123, vprod1x0123c3); - const int16x8_t vprod1x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3))); - vacc1x4567 = vpadalq_s16(vacc1x4567, vprod1x4567c3); - const int16x8_t vprod2x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3))); - vacc2x0123 = vpadalq_s16(vacc2x0123, vprod2x0123c3); - const int16x8_t vprod2x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3))); - vacc2x4567 = vpadalq_s16(vacc2x4567, vprod2x4567c3); - const int16x8_t vprod3x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va3), 3))); - vacc3x0123 = vpadalq_s16(vacc3x0123, vprod3x0123c3); - const int16x8_t vprod3x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va3), 3))); - vacc3x4567 = vpadalq_s16(vacc3x4567, vprod3x4567c3); - } } } } 
diff --git a/src/qs8-igemm/gen/4x8c2-minmax-neon-mull-padal-dup.c b/src/qs8-igemm/gen/4x8c2-minmax-neon-mull-padal-dup.c index 8afdcf0aa..5f635fa27 100644 --- a/src/qs8-igemm/gen/4x8c2-minmax-neon-mull-padal-dup.c +++ b/src/qs8-igemm/gen/4x8c2-minmax-neon-mull-padal-dup.c @@ -11,8 +11,8 @@ #include <arm_neon.h> -#include <xnnpack/common.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_4x8c2__neon_mull_padal_dup( @@ -40,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_4x8c2__neon_mull_padal_dup( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 2); int8_t* c0 = c; int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride); if XNN_UNPREDICTABLE(mr < 2) { @@ -237,28 +238,6 @@ void xnn_qs8_igemm_minmax_ukernel_4x8c2__neon_mull_padal_dup( vacc3x0123 = vpadalq_s16(vacc3x0123, vprod3x0123c2); const int16x8_t vprod3x4567c2 = vmull_s8(vb4567c2, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va3), 2))); vacc3x4567 = vpadalq_s16(vacc3x4567, vprod3x4567c2); - - if (k > 6 * sizeof(int8_t)) { - const int8x8_t vb0123c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb4567c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - - const int16x8_t vprod0x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0x0123 = vpadalq_s16(vacc0x0123, vprod0x0123c3); - const int16x8_t vprod0x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3))); - vacc0x4567 = vpadalq_s16(vacc0x4567, vprod0x4567c3); - const int16x8_t vprod1x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3))); - vacc1x0123 = vpadalq_s16(vacc1x0123, vprod1x0123c3); - const int16x8_t vprod1x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3))); - vacc1x4567 = vpadalq_s16(vacc1x4567, vprod1x4567c3); - const int16x8_t vprod2x0123c3 = vmull_s8(vb0123c3, 
vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3))); - vacc2x0123 = vpadalq_s16(vacc2x0123, vprod2x0123c3); - const int16x8_t vprod2x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3))); - vacc2x4567 = vpadalq_s16(vacc2x4567, vprod2x4567c3); - const int16x8_t vprod3x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va3), 3))); - vacc3x0123 = vpadalq_s16(vacc3x0123, vprod3x0123c3); - const int16x8_t vprod3x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va3), 3))); - vacc3x4567 = vpadalq_s16(vacc3x4567, vprod3x4567c3); - } } } } diff --git a/src/qs8-igemm/gen/4x8c4-minmax-neondot.c b/src/qs8-igemm/gen/4x8c4-minmax-neondot.c index b7f687147..5758e21fc 100644 --- a/src/qs8-igemm/gen/4x8c4-minmax-neondot.c +++ b/src/qs8-igemm/gen/4x8c4-minmax-neondot.c @@ -11,8 +11,8 @@ #include <arm_neon.h> -#include <xnnpack/common.h> #include <xnnpack/igemm.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_4x8c4__neondot( @@ -40,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_4x8c4__neondot( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 4); int8_t* c0 = c; int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride); if XNN_UNPREDICTABLE(mr < 2) { @@ -120,7 +121,7 @@ void xnn_qs8_igemm_minmax_ukernel_4x8c4__neondot( k -= 8 * sizeof(int8_t); } - // Handle up to 7 final positions of `k` + // Handle up to 6 final positions of `k` if XNN_UNLIKELY(k != 0) { // Load a 4x4 block of activations. 
const int8x8_t va0x01234567 = vld1_s8(a0); diff --git a/src/qs8-igemm/gen/4x8c8-minmax-neon-mull-padal.c b/src/qs8-igemm/gen/4x8c8-minmax-neon-mull-padal.c index 594adc938..ae0601936 100644 --- a/src/qs8-igemm/gen/4x8c8-minmax-neon-mull-padal.c +++ b/src/qs8-igemm/gen/4x8c8-minmax-neon-mull-padal.c @@ -11,8 +11,8 @@ #include <arm_neon.h> -#include <xnnpack/common.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_4x8c8__neon_mull_padal( @@ -40,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_4x8c8__neon_mull_padal( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 8); int8_t* c0 = c; int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride); if XNN_UNPREDICTABLE(mr < 2) { @@ -237,9 +238,6 @@ void xnn_qs8_igemm_minmax_ukernel_4x8c8__neon_mull_padal( k -= 16 * sizeof(int8_t); } // Handle up to 8 final positions of `k` - // If kc was 0 or 16, there is no remainder. k is 0. - // If kc was 1 to 8, there is a remainder of k. - // If kc was 9 to 15, the main loop handled the remainder; k underflowed. 
if XNN_UNLIKELY(k > 0) { const int8x8_t va0 = vld1_s8(a0); const int8x8_t va1 = vld1_s8(a1); diff --git a/src/qs8-igemm/gen/6x16c4-minmax-neondot.c b/src/qs8-igemm/gen/6x16c4-minmax-neondot.c index b013d4b42..ab885b523 100644 --- a/src/qs8-igemm/gen/6x16c4-minmax-neondot.c +++ b/src/qs8-igemm/gen/6x16c4-minmax-neondot.c @@ -11,8 +11,8 @@ #include <arm_neon.h> -#include <xnnpack/common.h> #include <xnnpack/igemm.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_6x16c4__neondot( @@ -40,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_6x16c4__neondot( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 4); int8_t* c0 = c; int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride); if XNN_UNPREDICTABLE(mr < 2) { @@ -190,7 +191,7 @@ void xnn_qs8_igemm_minmax_ukernel_6x16c4__neondot( k -= 8 * sizeof(int8_t); } - // Handle up to 7 final positions of `k` + // Handle up to 6 final positions of `k` if XNN_UNLIKELY(k != 0) { // Load a 6x4 block of activations. const int8x8_t va0x01234567 = vld1_s8(a0); diff --git a/src/qs8-igemm/gen/6x8c4-minmax-neondot.c b/src/qs8-igemm/gen/6x8c4-minmax-neondot.c index 13bb06037..b4c3bf21b 100644 --- a/src/qs8-igemm/gen/6x8c4-minmax-neondot.c +++ b/src/qs8-igemm/gen/6x8c4-minmax-neondot.c @@ -11,8 +11,8 @@ #include <arm_neon.h> -#include <xnnpack/common.h> #include <xnnpack/igemm.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_6x8c4__neondot( @@ -40,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_6x8c4__neondot( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 4); int8_t* c0 = c; int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride); if XNN_UNPREDICTABLE(mr < 2) { @@ -150,7 +151,7 @@ void xnn_qs8_igemm_minmax_ukernel_6x8c4__neondot( k -= 8 * sizeof(int8_t); } - // Handle up to 7 final positions of `k` + // Handle up to 6 final positions of `k` if XNN_UNLIKELY(k != 0) { // Load a 6x4 block of activations. 
const int8x8_t va0x01234567 = vld1_s8(a0); diff --git a/src/qs8-igemm/gen/8x16c4-minmax-neondot.c b/src/qs8-igemm/gen/8x16c4-minmax-neondot.c index 40f89110c..001ce3388 100644 --- a/src/qs8-igemm/gen/8x16c4-minmax-neondot.c +++ b/src/qs8-igemm/gen/8x16c4-minmax-neondot.c @@ -11,8 +11,8 @@ #include <arm_neon.h> -#include <xnnpack/common.h> #include <xnnpack/igemm.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_8x16c4__neondot( @@ -40,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_8x16c4__neondot( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 4); int8_t* c0 = c; int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride); if XNN_UNPREDICTABLE(mr < 2) { @@ -232,7 +233,7 @@ void xnn_qs8_igemm_minmax_ukernel_8x16c4__neondot( k -= 8 * sizeof(int8_t); } - // Handle up to 7 final positions of `k` + // Handle up to 6 final positions of `k` if XNN_UNLIKELY(k != 0) { // Load a 8x4 block of activations. const int8x8_t va0x01234567 = vld1_s8(a0); diff --git a/src/qs8-igemm/gen/8x8c4-minmax-neondot.c b/src/qs8-igemm/gen/8x8c4-minmax-neondot.c index d8f422343..9ce993726 100644 --- a/src/qs8-igemm/gen/8x8c4-minmax-neondot.c +++ b/src/qs8-igemm/gen/8x8c4-minmax-neondot.c @@ -11,8 +11,8 @@ #include <arm_neon.h> -#include <xnnpack/common.h> #include <xnnpack/igemm.h> +#include <xnnpack/math.h> void xnn_qs8_igemm_minmax_ukernel_8x8c4__neondot( @@ -40,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_8x8c4__neondot( assert(w != NULL); assert(c != NULL); + kc = round_up_po2(kc, 4); int8_t* c0 = c; int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride); if XNN_UNPREDICTABLE(mr < 2) { @@ -180,7 +181,7 @@ void xnn_qs8_igemm_minmax_ukernel_8x8c4__neondot( k -= 8 * sizeof(int8_t); } - // Handle up to 7 final positions of `k` + // Handle up to 6 final positions of `k` if XNN_UNLIKELY(k != 0) { // Load a 8x4 block of activations. 
const int8x8_t va0x01234567 = vld1_s8(a0); diff --git a/src/qu8-gemm/2x4c8-minmax-sse2.c b/src/qu8-gemm/2x4c8-minmax-sse2.c index 371fe74a2..e82204d80 100644 --- a/src/qu8-gemm/2x4c8-minmax-sse2.c +++ b/src/qu8-gemm/2x4c8-minmax-sse2.c @@ -48,7 +48,12 @@ void xnn_qu8_gemm_minmax_ukernel_2x4c8__sse2( assert(mr <= 2); assert(nc != 0); assert(kc != 0); + assert(kc % sizeof(int8_t) == 0); + assert(a != NULL); + assert(w != NULL); + assert(c != NULL); + kc = round_up_po2(kc, 8); const uint8_t* a0 = a; uint8_t* c0 = c; const uint8_t* a1 = (const uint8_t*) ((uintptr_t) a0 + a_stride); @@ -58,7 +63,6 @@ void xnn_qu8_gemm_minmax_ukernel_2x4c8__sse2( c1 = c0; } - const size_t kc_stride = round_up_po2(kc, 8); const __m128i vb_zero_point = _mm_load_si128((const __m128i*) params->sse2.kernel_zero_point); do { @@ -173,8 +177,8 @@ void xnn_qu8_gemm_minmax_ukernel_2x4c8__sse2( *((uint32_t*) c0) = (uint32_t) _mm_cvtsi128_si32(vout); *((uint32_t*) c1) = (uint32_t) _mm_cvtsi128_si32(_mm_srli_epi64(vout, 32)); - a0 = (const uint8_t*) ((uintptr_t) a0 - kc_stride); - a1 = (const uint8_t*) ((uintptr_t) a1 - kc_stride); + a0 = (const uint8_t*) ((uintptr_t) a0 - kc); + a1 = (const uint8_t*) ((uintptr_t) a1 - kc); c0 = (uint8_t*) ((uintptr_t) c0 + cn_stride); c1 = (uint8_t*) ((uintptr_t) c1 + cn_stride); diff --git a/src/qu8-gemm/4x4c2-minmax-sse2.c b/src/qu8-gemm/4x4c2-minmax-sse2.c index 45afdd216..6f687ce00 100644 --- a/src/qu8-gemm/4x4c2-minmax-sse2.c +++ b/src/qu8-gemm/4x4c2-minmax-sse2.c @@ -11,6 +11,7 @@ #include <immintrin.h> #include <xnnpack/gemm.h> +#include <xnnpack/math.h> void xnn_qu8_gemm_minmax_ukernel_4x4c2__sse2( @@ -29,7 +30,12 @@ void xnn_qu8_gemm_minmax_ukernel_4x4c2__sse2( assert(mr <= 4); assert(nc != 0); assert(kc != 0); + assert(kc % sizeof(int8_t) == 0); + assert(a != NULL); + assert(w != NULL); + assert(c != NULL); + kc = round_up_po2(kc, 2); const uint8_t* a0 = a; uint8_t* c0 = c; const uint8_t* a1 = (const uint8_t*) ((uintptr_t) a0 + a_stride); @@ -181,21 +187,6 
@@ void xnn_qu8_gemm_minmax_ukernel_4x4c2__sse2( _mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(2, 2, 2, 2)), vxb2)); vacc3x0123 = _mm_add_epi32(vacc3x0123, _mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(2, 2, 2, 2)), vxb2)); - - if (k > 6 * sizeof(uint8_t)) { - const __m128i vb3 = _mm_loadl_epi64((const __m128i*) w); - w = (const void*) ((uintptr_t) w + 8); - const __m128i vxb3 = _mm_sub_epi16(_mm_unpacklo_epi8(vb3, vzero), vb_zero_point); - - vacc0x0123 = _mm_add_epi32(vacc0x0123, - _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); - vacc1x0123 = _mm_add_epi32(vacc1x0123, - _mm_madd_epi16(_mm_shuffle_epi32(vxa1, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); - vacc2x0123 = _mm_add_epi32(vacc2x0123, - _mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); - vacc3x0123 = _mm_add_epi32(vacc3x0123, - _mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); - } } } } diff --git a/src/qu8-igemm/4x4c2-minmax-sse2.c b/src/qu8-igemm/4x4c2-minmax-sse2.c index ba7573c05..cef6f8180 100644 --- a/src/qu8-igemm/4x4c2-minmax-sse2.c +++ b/src/qu8-igemm/4x4c2-minmax-sse2.c @@ -11,6 +11,7 @@ #include <immintrin.h> #include <xnnpack/igemm.h> +#include <xnnpack/math.h> void xnn_qu8_igemm_minmax_ukernel_4x4c2__sse2( @@ -33,7 +34,12 @@ void xnn_qu8_igemm_minmax_ukernel_4x4c2__sse2( assert(kc != 0); assert(ks != 0); assert(ks % (4 * sizeof(void*)) == 0); + assert(a_offset % sizeof(int8_t) == 0); + assert(a != NULL); + assert(w != NULL); + assert(c != NULL); + kc = round_up_po2(kc, 2); uint8_t* c0 = c; uint8_t* c1 = (uint8_t*) ((uintptr_t) c0 + cm_stride); if XNN_UNPREDICTABLE(mr < 2) { @@ -163,17 +169,6 @@ void xnn_qu8_igemm_minmax_ukernel_4x4c2__sse2( vacc1x0123 = _mm_add_epi32(vacc1x0123, _mm_madd_epi16(_mm_shuffle_epi32(vxa1, _MM_SHUFFLE(2, 2, 2, 2)), vxb2)); vacc2x0123 = _mm_add_epi32(vacc2x0123, _mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(2, 2, 2, 2)), vxb2)); vacc3x0123 = _mm_add_epi32(vacc3x0123, 
_mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(2, 2, 2, 2)), vxb2)); - - if (k > 6 * sizeof(uint8_t)) { - const __m128i vb3 = _mm_loadl_epi64((const __m128i*) w); - const __m128i vxb3 = _mm_sub_epi16(_mm_unpacklo_epi8(vb3, vzero), vb_zero_point); - w = (void*) ((uintptr_t) w + 8); - - vacc0x0123 = _mm_add_epi32(vacc0x0123, _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); - vacc1x0123 = _mm_add_epi32(vacc1x0123, _mm_madd_epi16(_mm_shuffle_epi32(vxa1, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); - vacc2x0123 = _mm_add_epi32(vacc2x0123, _mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); - vacc3x0123 = _mm_add_epi32(vacc3x0123, _mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); - } } } } |