diff options
author | Frank Barchard <fbarchard@google.com> | 2021-03-11 02:18:15 -0800 |
---|---|---|
committer | XNNPACK Team <xnnpack-github-robot@google.com> | 2021-03-11 02:18:55 -0800 |
commit | 2f06150c8095b9674a5a48726802c073dc75a1b1 (patch) | |
tree | c47581fe7502b0bcc22fc101386815773f353bfc | |
parent | 1dc9fef1d2c493a99b756733dbc61b717367b86b (diff) | |
download | XNNPACK-2f06150c8095b9674a5a48726802c073dc75a1b1.tar.gz |
xnn_qs8_gemm_minmax_ukernel_2x8c8__aarch64_neon_mlal_padal GEMM microkernel
Similar math to C16 but load order is different.
PiperOrigin-RevId: 362245222
-rw-r--r-- | BUILD.bazel | 1 | ||||
-rwxr-xr-x | CMakeLists.txt | 1 | ||||
-rw-r--r-- | bench/qs8-gemm-e2e.cc | 20 | ||||
-rw-r--r-- | bench/qs8-gemm.cc | 8 | ||||
-rw-r--r-- | src/qs8-gemm/2x8c8-aarch64-neon-mlal-padal.S | 278 | ||||
-rw-r--r-- | src/xnnpack/gemm.h | 1 | ||||
-rw-r--r-- | test/qs8-gemm-minmax.cc | 456 | ||||
-rw-r--r-- | test/qs8-gemm-minmax.yaml | 2 |
8 files changed, 767 insertions, 0 deletions
diff --git a/BUILD.bazel b/BUILD.bazel index 7d3b50fa3..c5adf3697 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -3564,6 +3564,7 @@ AARCH64_ASM_UKERNELS = [ "src/f32-igemm/gen/6x8-minmax-aarch64-neonfma-cortex-a75.S", "src/qs8-gemm/1x16c4-aarch64-neondot-ld32.S", "src/qs8-gemm/1x16c4-aarch64-neondot-ld64.S", + "src/qs8-gemm/2x8c8-aarch64-neon-mlal-padal.S", "src/qs8-gemm/2x8c8-aarch64-neon-mull-padal.S", "src/qs8-gemm/2x8c16-aarch64-neon-mlal-padal.S", "src/qs8-gemm/4x16c4-aarch64-neondot-cortex-a55.S", diff --git a/CMakeLists.txt b/CMakeLists.txt index 673da4816..32740e43e 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -2790,6 +2790,7 @@ SET(XNNPACK_AARCH64_ASM_MICROKERNEL_SRCS src/f32-igemm/gen/6x8-minmax-aarch64-neonfma-cortex-a75.S src/qs8-gemm/1x16c4-aarch64-neondot-ld32.S src/qs8-gemm/1x16c4-aarch64-neondot-ld64.S + src/qs8-gemm/2x8c8-aarch64-neon-mlal-padal.S src/qs8-gemm/2x8c8-aarch64-neon-mull-padal.S src/qs8-gemm/2x8c16-aarch64-neon-mlal-padal.S src/qs8-gemm/4x16c4-aarch64-neondot-cortex-a55.S diff --git a/bench/qs8-gemm-e2e.cc b/bench/qs8-gemm-e2e.cc index 98b1eea6d..367381fbf 100644 --- a/bench/qs8-gemm-e2e.cc +++ b/bench/qs8-gemm-e2e.cc @@ -124,6 +124,24 @@ static void GEMMEnd2EndBenchmark( benchmark::utils::CheckNEONDOT); } + static void qs8_gemm_minmax_ukernel_2x8c8__aarch64_neon_mull_padal(benchmark::State& state, models::ExecutionPlanFactory model) { + GEMMEnd2EndBenchmark(state, model, + xnn_qs8_gemm_minmax_ukernel_2x8c8__aarch64_neon_mull_padal, + xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mull_padal, + xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mlal_padal, + xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mlal_padal, + 2 /* mr */, 8 /* nr */, 3 /* log2_kr */, 0 /* log2_sr */, + benchmark::utils::CheckNEON); + } + static void qs8_gemm_minmax_ukernel_2x8c8__aarch64_neon_mlal_padal(benchmark::State& state, models::ExecutionPlanFactory model) { + GEMMEnd2EndBenchmark(state, model, + xnn_qs8_gemm_minmax_ukernel_2x8c8__aarch64_neon_mlal_padal, + xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mull_padal, + xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mlal_padal, + xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mlal_padal, + 2 /* mr */, 8 /* nr */, 3 /* log2_kr */, 0 /* log2_sr */, + benchmark::utils::CheckNEON); + } static void qs8_gemm_minmax_ukernel_2x8c16__aarch64_neon_mlal_padal(benchmark::State& state, models::ExecutionPlanFactory model) { GEMMEnd2EndBenchmark(state, model, xnn_qs8_gemm_minmax_ukernel_2x8c16__aarch64_neon_mlal_padal, @@ -141,6 +159,8 @@ static void GEMMEnd2EndBenchmark( BENCHMARK_QS8_END2END(qs8_gemm_minmax_ukernel_4x16c4__aarch64_neondot_cortex_a55) BENCHMARK_QS8_END2END(qs8_gemm_minmax_ukernel_4x16c4__aarch64_neondot_ld32) BENCHMARK_QS8_END2END(qs8_gemm_minmax_ukernel_4x16c4__aarch64_neondot_ld64) + BENCHMARK_QS8_END2END(qs8_gemm_minmax_ukernel_2x8c8__aarch64_neon_mull_padal) + BENCHMARK_QS8_END2END(qs8_gemm_minmax_ukernel_2x8c8__aarch64_neon_mlal_padal) BENCHMARK_QS8_END2END(qs8_gemm_minmax_ukernel_2x8c16__aarch64_neon_mlal_padal) #endif // XNN_ARCH_ARM64 && XNN_ENABLE_ASSEMBLY diff --git a/bench/qs8-gemm.cc b/bench/qs8-gemm.cc index 639d89010..8a0d8f1d3 100644 --- a/bench/qs8-gemm.cc +++ b/bench/qs8-gemm.cc @@ -559,6 +559,12 @@ static void ruy_st(benchmark::State& state, const char* net) static void qs8_gemm_4x16c4__aarch64_neondot_ld64(benchmark::State& state, const char* net) { GEMMBenchmark(state, xnn_qs8_gemm_minmax_ukernel_4x16c4__aarch64_neondot_ld64, 4, 16, 4, 1, benchmark::utils::CheckNEONDOT); } + static void qs8_gemm_2x8c8__aarch64_neon_mull_padal(benchmark::State& state, const char* net) { + GEMMBenchmark(state, xnn_qs8_gemm_minmax_ukernel_2x8c8__aarch64_neon_mull_padal, 2, 8, 8, 1, benchmark::utils::CheckNEON); + } + static void qs8_gemm_2x8c8__aarch64_neon_mlal_padal(benchmark::State& state, const char* net) { + GEMMBenchmark(state, xnn_qs8_gemm_minmax_ukernel_2x8c8__aarch64_neon_mlal_padal, 2, 8, 8, 1, benchmark::utils::CheckNEON); + } static void qs8_gemm_2x8c16__aarch64_neon_mlal_padal(benchmark::State& state, const char* net) { GEMMBenchmark(state, xnn_qs8_gemm_minmax_ukernel_2x8c16__aarch64_neon_mlal_padal, 2, 8, 16, 1, benchmark::utils::CheckNEON); } @@ -568,6 +574,8 @@ static void ruy_st(benchmark::State& state, const char* net) BENCHMARK_GEMM(qs8_gemm_4x16c4__aarch64_neondot_ld32) BENCHMARK_GEMM(qs8_gemm_4x16c4__aarch64_neondot_ld64) BENCHMARK_GEMM(qs8_gemm_4x16c4__aarch64_neondot_cortex_a55) + BENCHMARK_GEMM(qs8_gemm_2x8c8__aarch64_neon_mull_padal) + BENCHMARK_GEMM(qs8_gemm_2x8c8__aarch64_neon_mlal_padal) BENCHMARK_GEMM(qs8_gemm_2x8c16__aarch64_neon_mlal_padal) #endif // XNN_ARCH_ARM64 diff --git a/src/qs8-gemm/2x8c8-aarch64-neon-mlal-padal.S b/src/qs8-gemm/2x8c8-aarch64-neon-mlal-padal.S new file mode 100644 index 000000000..8e6bccd70 --- /dev/null +++ b/src/qs8-gemm/2x8c8-aarch64-neon-mlal-padal.S @@ -0,0 +1,278 @@ +// Copyright 2021 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include <xnnpack/assembly.h> + +# void xnn_qs8_gemm_minmax_ukernel_2x8c8__aarch64_neon_mlal_padal( +# size_t mr, x0 +# size_t nc, x1 +# size_t kc, x2 / x0 +# const int8_t* restrict a, x3 +# size_t a_stride, x4 +# const void* restrict w, x5 +# int8_t* restrict c, x6 +# size_t cm_stride, x7 +# size_t cn_stride, [sp] -> x10 +# const union xnn_qs8_gemm_params params) [sp + 8] -> x9 + +# d8-d15, x19-x30 need to be preserved if used. x18 is reserved by the OS. + +# Register usage +# A0 x3 v0 v6 +# A1 x4 v1 v7 +# B x5 v4 v5 v8 v9 +# C0 x7 v16 v18 v20 v22 v24 v26 v28 v30 +# C1 x8 v17 v19 v21 v23 v25 v27 v29 v31 +# temp0 v2 v10 v12 v14 +# temp1 v3 v11 v13 v15 + +BEGIN_FUNCTION xnn_qs8_gemm_minmax_ukernel_2x8c8__aarch64_neon_mlal_padal + + # Clamp A and C pointers + CMP x0, 2 // if mr < 2 + STP d8, d9, [sp, -64]! + ADD x4, x3, x4 // a1 = a0 + a_stride + STP d10, d11, [sp, 16] + ADD x7, x6, x7 // c1 = c0 + cm_stride + STP d12, d13, [sp, 32] + CSEL x4, x3, x4, LO // a1 = a0 + STP d14, d15, [sp, 48] + ADD x2, x2, 7 // kc = (kc + 7) & ~7 + CSEL x7, x6, x7, LO // c1 = c0 + BIC x2, x2, 7 + + .p2align 3 +0: + # Load initial bias from w into accumulators + SUBS x0, x2, 16 // k = kc - 16 + LDP s16, s18, [x5], 8 + MOV v17.4s, v16.4s + MOV v19.4s, v18.4s + LDP s20, s22, [x5], 8 + MOV v21.4s, v20.4s + MOV v23.4s, v22.4s + LDP s24, s26, [x5], 8 + MOV v25.4s, v24.4s + MOV v27.4s, v26.4s + LDP s28, s30, [x5], 8 + MOV v29.4s, v28.4s + LDP x10, x9, [sp, 64] // cn_stride, params + MOV v31.4s, v30.4s + # Is there at least 16 bytes? + B.LO 3f + + # Main loop - 16 bytes of A + .p2align 3 +1: + LDP d0, d6, [x3], 16 + LDP d4, d5, [x5] + LDP d1, d7, [x4], 16 + LDP d8, d9, [x5, 64] + SMULL v2.8h, v4.8b, v0.8b + SMULL v3.8h, v4.8b, v1.8b + SMULL v10.8h, v5.8b, v0.8b + SMULL v11.8h, v5.8b, v1.8b + LDP d4, d5, [x5, 16] + SMLAL v2.8h, v8.8b, v6.8b + SMLAL v3.8h, v8.8b, v7.8b + SMLAL v10.8h, v9.8b, v6.8b + SMLAL v11.8h, v9.8b, v7.8b + + LDP d8, d9, [x5, 80] + SMULL v12.8h, v4.8b, v0.8b + SADALP v16.4s, v2.8h + SMULL v13.8h, v4.8b, v1.8b + SADALP v17.4s, v3.8h + SMULL v14.8h, v5.8b, v0.8b + SADALP v18.4s, v10.8h + SMULL v15.8h, v5.8b, v1.8b + SADALP v19.4s, v11.8h + LDP d4, d5, [x5, 32] + SMLAL v12.8h, v8.8b, v6.8b + SMLAL v13.8h, v8.8b, v7.8b + SMLAL v14.8h, v9.8b, v6.8b + SMLAL v15.8h, v9.8b, v7.8b + + LDP d8, d9, [x5, 96] + SMULL v2.8h, v4.8b, v0.8b + SADALP v20.4s, v12.8h + SMULL v3.8h, v4.8b, v1.8b + SADALP v21.4s, v13.8h + SMULL v10.8h, v5.8b, v0.8b + SADALP v22.4s, v14.8h + SMULL v11.8h, v5.8b, v1.8b + SADALP v23.4s, v15.8h + LDP d4, d5, [x5, 48] + SMLAL v2.8h, v8.8b, v6.8b + SMLAL v3.8h, v8.8b, v7.8b + SMLAL v10.8h, v9.8b, v6.8b + SMLAL v11.8h, v9.8b, v7.8b + + LDP d8, d9, [x5, 112] + SMULL v12.8h, v4.8b, v0.8b + SADALP v24.4s, v2.8h + SMULL v13.8h, v4.8b, v1.8b + SADALP v25.4s, v3.8h + SMULL v14.8h, v5.8b, v0.8b + SADALP v26.4s, v10.8h + SMULL v15.8h, v5.8b, v1.8b + SADALP v27.4s, v11.8h + SMLAL v12.8h, v8.8b, v6.8b + SMLAL v13.8h, v8.8b, v7.8b + SMLAL v14.8h, v9.8b, v6.8b + SMLAL v15.8h, v9.8b, v7.8b + ADD x5, x5, 128 + + SADALP v28.4s, v12.8h + SADALP v29.4s, v13.8h + SUBS x0, x0, 16 + SADALP v30.4s, v14.8h + SADALP v31.4s, v15.8h + B.HS 1b + + # Is there a remainder?- 8 bytes of A + TBNZ x0, 3, 3f + + .p2align 3 +2: + # Add columns + ADDP v16.4s, v16.4s, v18.4s + ADDP v20.4s, v20.4s, v22.4s + LD1R {v4.4s}, [x9], 4 + ADDP v24.4s, v24.4s, v26.4s + ADDP v28.4s, v28.4s, v30.4s + LD1R {v7.4s}, [x9], 4 + ADDP v17.4s, v17.4s, v19.4s + ADDP v21.4s, v21.4s, v23.4s + ADDP v25.4s, v25.4s, v27.4s + ADDP v29.4s, v29.4s, v31.4s + ADDP v0.4s, v16.4s, v20.4s + ADDP v1.4s, v24.4s, v28.4s + ADDP v2.4s, v17.4s, v21.4s + ADDP v3.4s, v25.4s, v29.4s + + # Apply params - scale, shift, bias and clamp + SQRDMULH v0.4s, v0.4s, v4.4s + SQRDMULH v1.4s, v1.4s, v4.4s + SQRDMULH v2.4s, v2.4s, v4.4s + SQRDMULH v3.4s, v3.4s, v4.4s + CMEQ v4.4s, v7.4s, 0 + LD1R {v5.8h}, [x9], 2 + BIC v6.16b, v0.16b, v4.16b + BIC v16.16b, v1.16b, v4.16b + BIC v17.16b, v2.16b, v4.16b + BIC v4.16b, v3.16b, v4.16b + SSRA v0.4s, v6.4s, 31 + SSRA v1.4s, v16.4s, 31 + SSRA v2.4s, v17.4s, 31 + SSRA v3.4s, v4.4s, 31 + SRSHL v0.4s, v0.4s, v7.4s + SRSHL v1.4s, v1.4s, v7.4s + SRSHL v2.4s, v2.4s, v7.4s + SRSHL v3.4s, v3.4s, v7.4s + SQXTN v0.4h, v0.4s + SQXTN v2.4h, v2.4s + SQXTN2 v0.8h, v1.4s + SQXTN2 v2.8h, v3.4s + SUBS x1, x1, 8 + SQADD v0.8h, v0.8h, v5.8h + SQADD v1.8h, v2.8h, v5.8h + SQXTN v0.8b, v0.8h + SQXTN2 v0.16b, v1.8h + LD1R {v1.16b}, [x9], 1 + LD1R {v2.16b}, [x9] + SMAX v0.16b, v0.16b, v1.16b + SMIN v0.16b, v0.16b, v2.16b + B.LO 4f + + # Store full 2 x 8 + ST1 {v0.8b}, [x6], x10 + SUB x3, x3, x2 // a0 -= kc + ST1 {v0.d}[1], [x7], x10 + SUB x4, x4, x2 // a1 -= kc + B.HI 0b + + # Restore d8-d15 from stack + LDP d14, d15, [sp, 48] + LDP d12, d13, [sp, 32] + LDP d10, d11, [sp, 16] + LDP d8, d9, [sp], 64 + RET + + # Remainder - 8 bytes of A + .p2align 3 +3: + LDR d0, [x3], 8 + LDP d4, d5, [x5] + LDR d1, [x4], 8 + LDP d6, d7, [x5, 16] + SMULL v2.8h, v4.8b, v0.8b + SMULL v3.8h, v4.8b, v1.8b + SMULL v10.8h, v5.8b, v0.8b + SMULL v11.8h, v5.8b, v1.8b + SMULL v12.8h, v6.8b, v0.8b + SADALP v16.4s, v2.8h + SMULL v13.8h, v6.8b, v1.8b + SADALP v17.4s, v3.8h + SMULL v14.8h, v7.8b, v0.8b + SADALP v18.4s, v10.8h + SMULL v15.8h, v7.8b, v1.8b + SADALP v19.4s, v11.8h + LDP d4, d5, [x5, 32] + SMULL v2.8h, v4.8b, v0.8b + SADALP v20.4s, v12.8h + SMULL v3.8h, v4.8b, v1.8b + SADALP v21.4s, v13.8h + SMULL v10.8h, v5.8b, v0.8b + SADALP v22.4s, v14.8h + SMULL v11.8h, v5.8b, v1.8b + SADALP v23.4s, v15.8h + LDP d6, d7, [x5, 48] + SMULL v12.8h, v6.8b, v0.8b + SADALP v24.4s, v2.8h + SMULL v13.8h, v6.8b, v1.8b + SADALP v25.4s, v3.8h + SMULL v14.8h, v7.8b, v0.8b + SADALP v26.4s, v10.8h + SMULL v15.8h, v7.8b, v1.8b + SADALP v27.4s, v11.8h + ADD x5, x5, 64 + SADALP v28.4s, v12.8h + SADALP v29.4s, v13.8h + SADALP v30.4s, v14.8h + SADALP v31.4s, v15.8h + B 2b + + # Store odd width + .p2align 3 +4: + TBZ x1, 2, 5f + STR s0, [x6], 4 + ST1 {v0.s}[2], [x7], 4 + EXT v0.16b, v0.16b, v0.16b, 4 + +5: + TBZ x1, 1, 6f + ST1 {v0.h}[0], [x6], 2 + ST1 {v0.h}[4], [x7], 2 + EXT v0.16b, v0.16b, v0.16b, 2 +6: + TBZ x1, 0, 7f + ST1 {v0.b}[0], [x6] + ST1 {v0.b}[8], [x7] +7: + # Restore d8-d15 from stack + LDP d14, d15, [sp, 48] + LDP d12, d13, [sp, 32] + LDP d10, d11, [sp, 16] + LDP d8, d9, [sp], 64 + RET + +END_FUNCTION xnn_qs8_gemm_minmax_ukernel_2x8c8__aarch64_neon_mlal_padal + +#ifdef __ELF__ +.section ".note.GNU-stack","",%progbits +#endif + diff --git a/src/xnnpack/gemm.h b/src/xnnpack/gemm.h index 4362c4c61..74c1665bd 100644 --- a/src/xnnpack/gemm.h +++ b/src/xnnpack/gemm.h @@ -599,6 +599,7 @@ DECLARE_QS8_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_gemm_minmax_ukernel_6x16c4__neo DECLARE_QS8_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_gemm_minmax_ukernel_8x16c4__neondot) DECLARE_QS8_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_gemm_minmax_ukernel_2x8c8__aarch64_neon_mull_padal) +DECLARE_QS8_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_gemm_minmax_ukernel_2x8c8__aarch64_neon_mlal_padal) DECLARE_QS8_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_gemm_minmax_ukernel_2x8c16__aarch64_neon_mlal_padal) DECLARE_QS8_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_gemm_minmax_ukernel_1x16c4__aarch64_neondot) diff --git a/test/qs8-gemm-minmax.cc b/test/qs8-gemm-minmax.cc index 4275bb682..87e94ef4d 100644 --- a/test/qs8-gemm-minmax.cc +++ b/test/qs8-gemm-minmax.cc @@ -23,6 +23,462 @@ #if XNN_ARCH_ARM64 + TEST(QS8_GEMM_MINMAX_2X8C8__AARCH64_NEON_MLAL_PADAL, k_eq_16) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(2) + .nr(8) + .kr(8) + .sr(1) + .m(2) + .n(8) + .k(16) + .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__aarch64_neon_mlal_padal); + } + + TEST(QS8_GEMM_MINMAX_2X8C8__AARCH64_NEON_MLAL_PADAL, strided_cn) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(2) + .nr(8) + .kr(8) + .sr(1) + .m(2) + .n(8) + .k(16) + .cn_stride(11) + .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__aarch64_neon_mlal_padal); + } + + TEST(QS8_GEMM_MINMAX_2X8C8__AARCH64_NEON_MLAL_PADAL, k_eq_16_strided_a) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(2) + .nr(8) + .kr(8) + .sr(1) + .m(2) + .n(8) + .k(16) + .a_stride(19) + .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__aarch64_neon_mlal_padal); + } + + TEST(QS8_GEMM_MINMAX_2X8C8__AARCH64_NEON_MLAL_PADAL, k_eq_16_subtile) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t m = 1; m <= 2; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(2) + .nr(8) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(16) + .iterations(1) + .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__aarch64_neon_mlal_padal); + } + } + } + + TEST(QS8_GEMM_MINMAX_2X8C8__AARCH64_NEON_MLAL_PADAL, k_eq_16_subtile_m) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t m = 1; m <= 2; m++) { + GemmMicrokernelTester() + .mr(2) + .nr(8) + .kr(8) + .sr(1) + .m(m) + .n(8) + .k(16) + .iterations(1) + .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__aarch64_neon_mlal_padal); + } + } + + TEST(QS8_GEMM_MINMAX_2X8C8__AARCH64_NEON_MLAL_PADAL, k_eq_16_subtile_n) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(2) + .nr(8) + .kr(8) + .sr(1) + .m(2) + .n(n) + .k(16) + .iterations(1) + .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__aarch64_neon_mlal_padal); + } + } + + TEST(QS8_GEMM_MINMAX_2X8C8__AARCH64_NEON_MLAL_PADAL, k_lt_16) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 1; k < 16; k++) { + GemmMicrokernelTester() + .mr(2) + .nr(8) + .kr(8) + .sr(1) + .m(2) + .n(8) + .k(k) + .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__aarch64_neon_mlal_padal); + } + } + + TEST(QS8_GEMM_MINMAX_2X8C8__AARCH64_NEON_MLAL_PADAL, k_lt_16_strided_a) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 1; k < 16; k++) { + GemmMicrokernelTester() + .mr(2) + .nr(8) + .kr(8) + .sr(1) + .m(2) + .n(8) + .k(k) + .a_stride(19) + .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__aarch64_neon_mlal_padal); + } + } + + TEST(QS8_GEMM_MINMAX_2X8C8__AARCH64_NEON_MLAL_PADAL, k_lt_16_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 1; k < 16; k++) { + for (uint32_t m = 1; m <= 2; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(2) + .nr(8) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__aarch64_neon_mlal_padal); + } + } + } + } + + TEST(QS8_GEMM_MINMAX_2X8C8__AARCH64_NEON_MLAL_PADAL, k_gt_16) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 17; k < 32; k++) { + GemmMicrokernelTester() + .mr(2) + .nr(8) + .kr(8) + .sr(1) + .m(2) + .n(8) + .k(k) + .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__aarch64_neon_mlal_padal); + } + } + + TEST(QS8_GEMM_MINMAX_2X8C8__AARCH64_NEON_MLAL_PADAL, k_gt_16_strided_a) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 17; k < 32; k++) { + GemmMicrokernelTester() + .mr(2) + .nr(8) + .kr(8) + .sr(1) + .m(2) + .n(8) + .k(k) + .a_stride(37) + .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__aarch64_neon_mlal_padal); + } + } + + TEST(QS8_GEMM_MINMAX_2X8C8__AARCH64_NEON_MLAL_PADAL, k_gt_16_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 17; k < 32; k++) { + for (uint32_t m = 1; m <= 2; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(2) + .nr(8) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__aarch64_neon_mlal_padal); + } + } + } + } + + TEST(QS8_GEMM_MINMAX_2X8C8__AARCH64_NEON_MLAL_PADAL, k_div_16) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 32; k <= 160; k += 16) { + GemmMicrokernelTester() + .mr(2) + .nr(8) + .kr(8) + .sr(1) + .m(2) + .n(8) + .k(k) + .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__aarch64_neon_mlal_padal); + } + } + + TEST(QS8_GEMM_MINMAX_2X8C8__AARCH64_NEON_MLAL_PADAL, k_div_16_strided_a) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 32; k <= 160; k += 16) { + GemmMicrokernelTester() + .mr(2) + .nr(8) + .kr(8) + .sr(1) + .m(2) + .n(8) + .k(k) + .a_stride(163) + .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__aarch64_neon_mlal_padal); + } + } + + TEST(QS8_GEMM_MINMAX_2X8C8__AARCH64_NEON_MLAL_PADAL, k_div_16_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 32; k <= 160; k += 16) { + for (uint32_t m = 1; m <= 2; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(2) + .nr(8) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__aarch64_neon_mlal_padal); + } + } + } + } + + TEST(QS8_GEMM_MINMAX_2X8C8__AARCH64_NEON_MLAL_PADAL, n_gt_8) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 9; n < 16; n++) { + for (size_t k = 1; k <= 80; k += 17) { + GemmMicrokernelTester() + .mr(2) + .nr(8) + .kr(8) + .sr(1) + .m(2) + .n(8) + .k(k) + .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__aarch64_neon_mlal_padal); + } + } + } + + TEST(QS8_GEMM_MINMAX_2X8C8__AARCH64_NEON_MLAL_PADAL, n_gt_8_strided_cn) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 9; n < 16; n++) { + for (size_t k = 1; k <= 80; k += 17) { + GemmMicrokernelTester() + .mr(2) + .nr(8) + .kr(8) + .sr(1) + .m(2) + .n(8) + .k(k) + .cn_stride(11) + .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__aarch64_neon_mlal_padal); + } + } + } + + TEST(QS8_GEMM_MINMAX_2X8C8__AARCH64_NEON_MLAL_PADAL, n_gt_8_strided_a) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 9; n < 16; n++) { + for (size_t k = 1; k <= 80; k += 17) { + GemmMicrokernelTester() + .mr(2) + .nr(8) + .kr(8) + .sr(1) + .m(2) + .n(n) + .k(k) + .a_stride(83) + .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__aarch64_neon_mlal_padal); + } + } + } + + TEST(QS8_GEMM_MINMAX_2X8C8__AARCH64_NEON_MLAL_PADAL, n_gt_8_subtile) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 9; n < 16; n++) { + for (size_t k = 1; k <= 80; k += 17) { + for (uint32_t m = 1; m <= 2; m++) { + GemmMicrokernelTester() + .mr(2) + .nr(8) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__aarch64_neon_mlal_padal); + } + } + } + } + + TEST(QS8_GEMM_MINMAX_2X8C8__AARCH64_NEON_MLAL_PADAL, n_div_8) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 16; n <= 24; n += 8) { + for (size_t k = 1; k <= 80; k += 17) { + GemmMicrokernelTester() + .mr(2) + .nr(8) + .kr(8) + .sr(1) + .m(2) + .n(8) + .k(k) + .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__aarch64_neon_mlal_padal); + } + } + } + + TEST(QS8_GEMM_MINMAX_2X8C8__AARCH64_NEON_MLAL_PADAL, n_div_8_strided_cn) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 16; n <= 24; n += 8) { + for (size_t k = 1; k <= 80; k += 17) { + GemmMicrokernelTester() + .mr(2) + .nr(8) + .kr(8) + .sr(1) + .m(2) + .n(n) + .k(k) + .cn_stride(11) + .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__aarch64_neon_mlal_padal); + } + } + } + + TEST(QS8_GEMM_MINMAX_2X8C8__AARCH64_NEON_MLAL_PADAL, n_div_8_strided_a) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 16; n <= 24; n += 8) { + for (size_t k = 1; k <= 80; k += 17) { + GemmMicrokernelTester() + .mr(2) + .nr(8) + .kr(8) + .sr(1) + .m(2) + .n(n) + .k(k) + .a_stride(83) + .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__aarch64_neon_mlal_padal); + } + } + } + + TEST(QS8_GEMM_MINMAX_2X8C8__AARCH64_NEON_MLAL_PADAL, n_div_8_subtile) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 16; n <= 24; n += 8) { + for (size_t k = 1; k <= 80; k += 17) { + for (uint32_t m = 1; m <= 2; m++) { + GemmMicrokernelTester() + .mr(2) + .nr(8) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__aarch64_neon_mlal_padal); + } + } + } + } + + TEST(QS8_GEMM_MINMAX_2X8C8__AARCH64_NEON_MLAL_PADAL, strided_cm_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 1; k <= 80; k += 17) { + for (uint32_t m = 1; m <= 2; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(2) + .nr(8) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .cm_stride(11) + .iterations(1) + .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__aarch64_neon_mlal_padal); + } + } + } + } + + TEST(QS8_GEMM_MINMAX_2X8C8__AARCH64_NEON_MLAL_PADAL, qmin) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(2) + .nr(8) + .kr(8) + .sr(1) + .m(2) + .n(8) + .k(16) + .qmin(128) + .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__aarch64_neon_mlal_padal); + } + + TEST(QS8_GEMM_MINMAX_2X8C8__AARCH64_NEON_MLAL_PADAL, qmax) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(2) + .nr(8) + .kr(8) + .sr(1) + .m(2) + .n(8) + .k(16) + .qmax(128) + .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__aarch64_neon_mlal_padal); + } + + TEST(QS8_GEMM_MINMAX_2X8C8__AARCH64_NEON_MLAL_PADAL, strided_cm) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(2) + .nr(8) + .kr(8) + .sr(1) + .m(2) + .n(8) + .k(16) + .cm_stride(11) + .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__aarch64_neon_mlal_padal); + } +#endif // XNN_ARCH_ARM64 + + +#if XNN_ARCH_ARM64 TEST(QS8_GEMM_MINMAX_2X8C8__AARCH64_NEON_MULL_PADAL, k_eq_8) { TEST_REQUIRES_ARM_NEON; GemmMicrokernelTester() diff --git a/test/qs8-gemm-minmax.yaml b/test/qs8-gemm-minmax.yaml index 0f1e7e2e4..e85ca934c 100644 --- a/test/qs8-gemm-minmax.yaml +++ b/test/qs8-gemm-minmax.yaml @@ -2,6 +2,8 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +- name: xnn_qs8_gemm_minmax_ukernel_2x8c8__aarch64_neon_mlal_padal + k-block: 16 - name: xnn_qs8_gemm_minmax_ukernel_2x8c8__aarch64_neon_mull_padal k-block: 8 - name: xnn_qs8_gemm_minmax_ukernel_1x8__neon_mlal_lane |