diff options
Diffstat (limited to 'standalone/neon-gemm-kernel-benchmark.cc')
-rw-r--r-- | standalone/neon-gemm-kernel-benchmark.cc | 1458 |
1 file changed, 1452 insertions, 6 deletions
diff --git a/standalone/neon-gemm-kernel-benchmark.cc b/standalone/neon-gemm-kernel-benchmark.cc index 2a936c1..bff33fb 100644 --- a/standalone/neon-gemm-kernel-benchmark.cc +++ b/standalone/neon-gemm-kernel-benchmark.cc @@ -61,15 +61,30 @@ #include <cassert> #include <cstdint> #include <cstdlib> +#include <cstring> #include <iostream> #include <random> #include <type_traits> -#if !defined __arm__ && !defined __aarch64__ -#error This benchmark assumes ARM (for inline assembly sections). +#if !defined(__arm__) && !defined(__aarch64__) && \ + !(defined(__mips) && (__mips_isa_rev >= 5) && defined(__mips_msa)) +#error This benchmark assumes ARM or MIPS (for intrinsics and inline assembly sections). #endif +#if defined(__arm__) || defined(__aarch64__) #include <arm_neon.h> +#endif + +#if defined(__mips) +#include <msa.h> + +// Some convenience macros to hide differences between MIPS32 and MIPS64. +#ifdef __LP64__ +#define GEMMLOWP_MIPS_XADDIU "daddiu" +#else +#define GEMMLOWP_MIPS_XADDIU "addiu" +#endif +#endif // Typically one wants to fit in L1 cache, and GEMM implementations // are carefully optimized to tune their access patterns to that effect. @@ -2501,6 +2516,291 @@ struct NEON_64bit_GEMM_Int8Operands_AccumTwoWithin16Bits { } }; +#ifdef __ARM_FEATURE_DOTPROD +// Kernels utilizing the Armv8.2 Dot Product extension. +// +// The dot product instructions work by taking 4 consecutive 8-bit depth +// values from each operand, multiplying the 4 pairs together and +// accumulating all the results into the corresponding 32-bit accumulator +// lane. As such, the operation is identical to a 32-bit instruction (like +// FMLA used in SGEMM), except that 4 depth values are processed at a time +// instead of 1. + +// Thus, this first kernel is a carbon copy of +// "NEON_64bit_GEMM_Float32_WithScalar_A57" (which should provide good +// performance for most processors) below with the opcode (fmla -> udot) and +// types (float32 -> uint8/uint32) changed. 
+// +// A signed version of this kernel could be produced by replacing "udot" +// with "sdot" - performance should be identical to this udot kernel. +struct NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators_dotproduct { + typedef std::uint8_t OperandType; + typedef std::uint32_t AccumulatorType; + typedef KernelFormat< + KernelSideFormat<CellFormat<4, 4, CellOrder::WidthMajor>, 3>, + KernelSideFormat<CellFormat<4, 4, CellOrder::WidthMajor>, 2> > + Format; + static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr, + AccumulatorType* accum_ptr, int depth) { + asm volatile( + // Load accumulators + "mov x0, %[accum_ptr]\n" + "ld1 {v8.4s}, [x0], #16\n" + "ld1 {v16.4s}, [x0], #16\n" + "ld1 {v24.4s}, [x0], #16\n" + "ld1 {v9.4s}, [x0], #16\n" + "ld1 {v17.4s}, [x0], #16\n" + "ld1 {v25.4s}, [x0], #16\n" + "ld1 {v10.4s}, [x0], #16\n" + "ld1 {v18.4s}, [x0], #16\n" + "ld1 {v26.4s}, [x0], #16\n" + "ld1 {v11.4s}, [x0], #16\n" + "ld1 {v19.4s}, [x0], #16\n" + "ld1 {v27.4s}, [x0], #16\n" + "ld1 {v12.4s}, [x0], #16\n" + "ld1 {v20.4s}, [x0], #16\n" + "ld1 {v28.4s}, [x0], #16\n" + "ld1 {v13.4s}, [x0], #16\n" + "ld1 {v21.4s}, [x0], #16\n" + "ld1 {v29.4s}, [x0], #16\n" + "ld1 {v14.4s}, [x0], #16\n" + "ld1 {v22.4s}, [x0], #16\n" + "ld1 {v30.4s}, [x0], #16\n" + "ld1 {v15.4s}, [x0], #16\n" + "ld1 {v23.4s}, [x0], #16\n" + "ld1 {v31.4s}, [x0], #16\n" + + // The start of the loop assumes first Rhs cell is already loaded, so + // do it here for first iteration. + "ld1 {v0.16b}, [%[rhs_ptr]], #16\n" + + // And the same for the first Lhs cell. + "ld1 {v2.16b}, [%[lhs_ptr]], #16\n" + + GEMMLOWP_LABEL_LOOP + ":\n" + + // Start the MACs at the head of the loop - 1st cell from each side + // already loaded. + "udot v8.4s, v2.16b, v0.b[0]\n" + "udot v9.4s, v2.16b, v0.b[1]\n" + "ld1 {v1.16b}, [%[rhs_ptr]], #16\n" // Load second Rhs cell. + "udot v10.4s, v2.16b, v0.b[2]\n" + "udot v11.4s, v2.16b, v0.b[3]\n" + "ld1 {v3.16b}, [%[lhs_ptr]], #16\n" // Load second Lhs cell. 
+ "udot v12.4s, v2.16b, v1.b[0]\n" + "udot v13.4s, v2.16b, v1.b[1]\n" + "ld1 {v4.16b}, [%[lhs_ptr]], #16\n" // Load third Lhs cell. + "udot v14.4s, v2.16b, v1.b[2]\n" + "udot v15.4s, v2.16b, v1.b[3]\n" + "ld1 {v2.16b}, [%[lhs_ptr]], #16\n" // Done with first Lhs cell - load + // for the next iteration early. + "udot v16.4s, v3.16b, v0.b[0]\n" + "udot v17.4s, v3.16b, v0.b[1]\n" + "udot v18.4s, v3.16b, v0.b[2]\n" + "udot v19.4s, v3.16b, v0.b[3]\n" + "udot v20.4s, v3.16b, v1.b[0]\n" + "udot v21.4s, v3.16b, v1.b[1]\n" + "udot v22.4s, v3.16b, v1.b[2]\n" + "udot v23.4s, v3.16b, v1.b[3]\n" + "udot v24.4s, v4.16b, v0.b[0]\n" + "udot v25.4s, v4.16b, v0.b[1]\n" + "udot v26.4s, v4.16b, v0.b[2]\n" + "udot v27.4s, v4.16b, v0.b[3]\n" + "ld1 {v0.16b}, [%[rhs_ptr]], #16\n" // Done with the first Rhs cell - + // load for the next iteration early. + "udot v28.4s, v4.16b, v1.b[0]\n" + "udot v29.4s, v4.16b, v1.b[1]\n" + + // Loop. Decrement loop index (depth) by 4 as udot processes 4 + // depth values. + "subs %w[depth], %w[depth], #4\n" + "udot v30.4s, v4.16b, v1.b[2]\n" + "udot v31.4s, v4.16b, v1.b[3]\n" + + "bne " GEMMLOWP_LABEL_LOOP + "b\n" + + // Store accumulators + "mov x0, %[accum_ptr]\n" + "st1 {v8.16b}, [x0], #16\n" + "st1 {v16.16b}, [x0], #16\n" + "st1 {v24.16b}, [x0], #16\n" + "st1 {v9.16b}, [x0], #16\n" + "st1 {v17.16b}, [x0], #16\n" + "st1 {v25.16b}, [x0], #16\n" + "st1 {v10.16b}, [x0], #16\n" + "st1 {v18.16b}, [x0], #16\n" + "st1 {v26.16b}, [x0], #16\n" + "st1 {v11.16b}, [x0], #16\n" + "st1 {v19.16b}, [x0], #16\n" + "st1 {v27.16b}, [x0], #16\n" + "st1 {v12.16b}, [x0], #16\n" + "st1 {v20.16b}, [x0], #16\n" + "st1 {v28.16b}, [x0], #16\n" + "st1 {v13.16b}, [x0], #16\n" + "st1 {v21.16b}, [x0], #16\n" + "st1 {v29.16b}, [x0], #16\n" + "st1 {v14.16b}, [x0], #16\n" + "st1 {v22.16b}, [x0], #16\n" + "st1 {v30.16b}, [x0], #16\n" + "st1 {v15.16b}, [x0], #16\n" + "st1 {v23.16b}, [x0], #16\n" + "st1 {v31.16b}, [x0], #16\n" + : // outputs + [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] 
"+r"(rhs_ptr), + [depth] "+r"(depth) + : // inputs + [accum_ptr] "r"(accum_ptr) + : // clobbers + "cc", "memory", "x0", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", + "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", + "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", + "v28", "v29", "v30", "v31"); + } +}; + +// As above, except tuned for Cortex-A55r1. +// +// Similarly, this is a clone of NEON_64bit_GEMM_Float32_WithScalar_A55r1 +// with the names changed. +struct NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators_dotproduct_A55r1 { + typedef std::uint8_t OperandType; + typedef std::uint32_t AccumulatorType; + typedef KernelFormat< + KernelSideFormat<CellFormat<4, 4, CellOrder::WidthMajor>, 3>, + KernelSideFormat<CellFormat<4, 4, CellOrder::WidthMajor>, 2> > + Format; + static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr, + AccumulatorType* accum_ptr, int depth) { + asm volatile( + // Load accumulators + "mov x0, %[accum_ptr]\n" + "ld1 {v8.4s}, [x0], #16\n" + "ld1 {v16.4s}, [x0], #16\n" + "ld1 {v24.4s}, [x0], #16\n" + "ld1 {v9.4s}, [x0], #16\n" + "ld1 {v17.4s}, [x0], #16\n" + "ld1 {v25.4s}, [x0], #16\n" + "ld1 {v10.4s}, [x0], #16\n" + "ld1 {v18.4s}, [x0], #16\n" + "ld1 {v26.4s}, [x0], #16\n" + "ld1 {v11.4s}, [x0], #16\n" + "ld1 {v19.4s}, [x0], #16\n" + "ld1 {v27.4s}, [x0], #16\n" + "ld1 {v12.4s}, [x0], #16\n" + "ld1 {v20.4s}, [x0], #16\n" + "ld1 {v28.4s}, [x0], #16\n" + "ld1 {v13.4s}, [x0], #16\n" + "ld1 {v21.4s}, [x0], #16\n" + "ld1 {v29.4s}, [x0], #16\n" + "ld1 {v14.4s}, [x0], #16\n" + "ld1 {v22.4s}, [x0], #16\n" + "ld1 {v30.4s}, [x0], #16\n" + "ld1 {v15.4s}, [x0], #16\n" + "ld1 {v23.4s}, [x0], #16\n" + "ld1 {v31.4s}, [x0], #16\n" + + // For details on how this kernel works, see the Float32 kernel below. 
+ + "ldr d0, [%[rhs_ptr]]\n" + "ldr x18, [%[rhs_ptr], #8]\n" + + "ldr q2, [%[lhs_ptr]]\n" + "ldr q3, [%[lhs_ptr], #16]\n" + + GEMMLOWP_LABEL_LOOP + ":\n" + + "udot v8.4s, v2.16b, v0.b[0]\n" + "ldr d1, [%[rhs_ptr], #16]\n" // Bottom half of v1 + "udot v9.4s, v2.16b, v0.b[1]\n" + "ins v0.d[1], x18\n" // Finish loading v0 + "udot v16.4s, v3.16b, v0.b[0]\n" // out of sequence - used to reduce load/use pressure. + "ldr x18, [%[rhs_ptr], #24]\n" // Top half of v1 to X register + "udot v17.4s, v3.16b, v0.b[1]\n" // out of sequence - used to reduce load/use pressure. + "add %[rhs_ptr], %[rhs_ptr], #32\n" // RHS loads complete - increment pointer. + "udot v10.4s, v2.16b, v0.b[2]\n" + "ldr d4, [%[lhs_ptr], #32]\n" // Bottom half of v4 + "udot v11.4s, v2.16b, v0.b[3]\n" + "ins v1.d[1], x18\n" // Finish loading v1 + "udot v12.4s, v2.16b, v1.b[0]\n" + "ldr x18, [%[lhs_ptr], #40]\n" // Top half of v4 to X register + "udot v13.4s, v2.16b, v1.b[1]\n" + "add %[lhs_ptr], %[lhs_ptr], #48\n" // LHS loads complete - increment pointer. 
+ "udot v14.4s, v2.16b, v1.b[2]\n" + + "udot v15.4s, v2.16b, v1.b[3]\n" + "ldr d2, [%[lhs_ptr]]\n" // Bottom half of v2 (for next time) + "udot v18.4s, v3.16b, v0.b[2]\n" + "ins v4.d[1], x18\n" // Finish loading v4 + "udot v19.4s, v3.16b, v0.b[3]\n" + "ldr x18, [%[lhs_ptr], #8]\n" // Top half of next v2 to X register + "udot v20.4s, v3.16b, v1.b[0]\n" + "subs %w[depth], %w[depth], #4\n" + "udot v21.4s, v3.16b, v1.b[1]\n" + + "udot v22.4s, v3.16b, v1.b[2]\n" + + "udot v23.4s, v3.16b, v1.b[3]\n" + "ldr d3, [%[lhs_ptr], #16]\n" // Bottom half of v3 (for next time) + "udot v24.4s, v4.16b, v0.b[0]\n" + "ins v2.d[1], x18\n" // Finish loading next v2 + "udot v25.4s, v4.16b, v0.b[1]\n" + "ldr x18, [%[lhs_ptr], #24]\n" // Top half of next v3 to X register + "udot v26.4s, v4.16b, v0.b[2]\n" + + "udot v27.4s, v4.16b, v0.b[3]\n" + "ldr d0, [%[rhs_ptr]]\n" // Bottom half of v0 (for next time) + "udot v28.4s, v4.16b, v1.b[0]\n" + "ins v3.d[1], x18\n" // Finish loading next v3 + "udot v29.4s, v4.16b, v1.b[1]\n" + "ldr x18, [%[rhs_ptr], #8]\n" // Top half of next v0 to X register + "udot v30.4s, v4.16b, v1.b[2]\n" + + "udot v31.4s, v4.16b, v1.b[3]\n" + "bne " GEMMLOWP_LABEL_LOOP "b\n" + + // Store accumulators + "mov x0, %[accum_ptr]\n" + "st1 {v8.4s}, [x0], #16\n" + "st1 {v16.4s}, [x0], #16\n" + "st1 {v24.4s}, [x0], #16\n" + "st1 {v9.4s}, [x0], #16\n" + "st1 {v17.4s}, [x0], #16\n" + "st1 {v25.4s}, [x0], #16\n" + "st1 {v10.4s}, [x0], #16\n" + "st1 {v18.4s}, [x0], #16\n" + "st1 {v26.4s}, [x0], #16\n" + "st1 {v11.4s}, [x0], #16\n" + "st1 {v19.4s}, [x0], #16\n" + "st1 {v27.4s}, [x0], #16\n" + "st1 {v12.4s}, [x0], #16\n" + "st1 {v20.4s}, [x0], #16\n" + "st1 {v28.4s}, [x0], #16\n" + "st1 {v13.4s}, [x0], #16\n" + "st1 {v21.4s}, [x0], #16\n" + "st1 {v29.4s}, [x0], #16\n" + "st1 {v14.4s}, [x0], #16\n" + "st1 {v22.4s}, [x0], #16\n" + "st1 {v30.4s}, [x0], #16\n" + "st1 {v15.4s}, [x0], #16\n" + "st1 {v23.4s}, [x0], #16\n" + "st1 {v31.4s}, [x0], #16\n" + : // outputs + [lhs_ptr] 
"+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), + [depth] "+r"(depth) + : // inputs + [accum_ptr] "r"(accum_ptr) + : // clobbers + "cc", "memory", "x0", "x18", "v0", "v1", "v2", "v3", "v4", "v5", "v6", + "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", + "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", + "v27", "v28", "v29", "v30", "v31"); + } +}; +#endif // __ARM_FEATURE_DOTPROD + // We don't actually use int32*int32 in production. This is just an // experiment to help dissociate the effect of integer-vs-float, from the // effect of operands width. @@ -3203,8 +3503,172 @@ struct NEON_64bit_GEMM_Float32_WithScalar_A53 { }; #endif +// Faster kernel contributed by ARM. Tuned for A55r1. +struct NEON_64bit_GEMM_Float32_WithScalar_A55r1 { + typedef float OperandType; + typedef float AccumulatorType; + typedef KernelFormat< + KernelSideFormat<CellFormat<4, 1, CellOrder::DepthMajor>, 3>, + KernelSideFormat<CellFormat<4, 1, CellOrder::DepthMajor>, 2> > + Format; + static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr, + AccumulatorType* accum_ptr, int depth) { + asm volatile( + // Load accumulators + "mov x0, %[accum_ptr]\n" + "ld1 {v8.4s}, [x0], #16\n" + "ld1 {v16.4s}, [x0], #16\n" + "ld1 {v24.4s}, [x0], #16\n" + "ld1 {v9.4s}, [x0], #16\n" + "ld1 {v17.4s}, [x0], #16\n" + "ld1 {v25.4s}, [x0], #16\n" + "ld1 {v10.4s}, [x0], #16\n" + "ld1 {v18.4s}, [x0], #16\n" + "ld1 {v26.4s}, [x0], #16\n" + "ld1 {v11.4s}, [x0], #16\n" + "ld1 {v19.4s}, [x0], #16\n" + "ld1 {v27.4s}, [x0], #16\n" + "ld1 {v12.4s}, [x0], #16\n" + "ld1 {v20.4s}, [x0], #16\n" + "ld1 {v28.4s}, [x0], #16\n" + "ld1 {v13.4s}, [x0], #16\n" + "ld1 {v21.4s}, [x0], #16\n" + "ld1 {v29.4s}, [x0], #16\n" + "ld1 {v14.4s}, [x0], #16\n" + "ld1 {v22.4s}, [x0], #16\n" + "ld1 {v30.4s}, [x0], #16\n" + "ld1 {v15.4s}, [x0], #16\n" + "ld1 {v23.4s}, [x0], #16\n" + "ld1 {v31.4s}, [x0], #16\n" + + // A55r1 requires a hybrid of the A53 and standard approaches. 
+ // + // Like A53, this processor prefers 64-bit loads. + // + // Unlike A53, it is capable of dual-issuing a 64-bit vector load + // (or INS) with a FMLA instruction. + // + // Therefore we aim to issue an FMLA instruction every cycle. + // Alongside three FMLAs we can dual issue a (vector) 64-bit load, a + // scalar 64-bit load and finally an INS to replicate the effect of + // a single 128-bit load. + // + // The loop contains 24 FMLA instructions, and 5 vector registers + // need to be loaded, consuming 15 dual issue slots. This leaves 9 + // dual issue slots. Four of these are used for loop housekeeping + // (2 pointer adds, 1 counter update and 1 branch), leaving 5 left + // over (marked by blank lines). + // + // Choice of x18 to store the upper halves on their way into the + // vector registers is arbitrary. Added to the clobber list so that + // the compiler will make it available. + + + // At the start of the loop, it is assumed that v0 is "half loaded" - + // bottom half in place in d0 and the upper half in x18 ready to + // insert. So set that up here for the first iteration: + "ldr d0, [%[rhs_ptr]]\n" // Bottom half of first Rhs cell + "ldr x18, [%[rhs_ptr], #8]\n" // Upper half + + // v2-v3 should be fully loaded - as it's outside the loop proper it's fine + // to use a 128-bit load here. + "ldr q2, [%[lhs_ptr]]\n" // first Lhs cell + "ldr q3, [%[lhs_ptr], #16]\n" // second Lhs cell + + GEMMLOWP_LABEL_LOOP + ":\n" + + "fmla v8.4s, v2.4s, v0.s[0]\n" + "ldr d1, [%[rhs_ptr], #16]\n" // Bottom half of v1 + "fmla v9.4s, v2.4s, v0.s[1]\n" + "ins v0.d[1], x18\n" // Finish loading v0 + "fmla v16.4s, v3.4s, v0.s[0]\n" // out of sequence - used to reduce load/use pressure. + "ldr x18, [%[rhs_ptr], #24]\n" // Top half of v1 to X register + "fmla v17.4s, v3.4s, v0.s[1]\n" // out of sequence - used to reduce load/use pressure. + "add %[rhs_ptr], %[rhs_ptr], #32\n" // RHS loads complete - increment pointer. 
+ "fmla v10.4s, v2.4s, v0.s[2]\n" + "ldr d4, [%[lhs_ptr], #32]\n" // Bottom half of v4 + "fmla v11.4s, v2.4s, v0.s[3]\n" + "ins v1.d[1], x18\n" // Finish loading v1 + "fmla v12.4s, v2.4s, v1.s[0]\n" + "ldr x18, [%[lhs_ptr], #40]\n" // Top half of v4 to X register + "fmla v13.4s, v2.4s, v1.s[1]\n" + "add %[lhs_ptr], %[lhs_ptr], #48\n" // LHS loads complete - increment pointer. + "fmla v14.4s, v2.4s, v1.s[2]\n" + + "fmla v15.4s, v2.4s, v1.s[3]\n" + "ldr d2, [%[lhs_ptr]]\n" // Bottom half of v2 (for next time) + "fmla v18.4s, v3.4s, v0.s[2]\n" + "ins v4.d[1], x18\n" // Finish loading v4 + "fmla v19.4s, v3.4s, v0.s[3]\n" + "ldr x18, [%[lhs_ptr], #8]\n" // Top half of next v2 to X register + "fmla v20.4s, v3.4s, v1.s[0]\n" + "subs %w[depth], %w[depth], #1\n" + "fmla v21.4s, v3.4s, v1.s[1]\n" + + "fmla v22.4s, v3.4s, v1.s[2]\n" + + "fmla v23.4s, v3.4s, v1.s[3]\n" + "ldr d3, [%[lhs_ptr], #16]\n" // Bottom half of v3 (for next time) + "fmla v24.4s, v4.4s, v0.s[0]\n" + "ins v2.d[1], x18\n" // Finish loading next v2 + "fmla v25.4s, v4.4s, v0.s[1]\n" + "ldr x18, [%[lhs_ptr], #24]\n" // Top half of next v3 to X register + "fmla v26.4s, v4.4s, v0.s[2]\n" + + "fmla v27.4s, v4.4s, v0.s[3]\n" + "ldr d0, [%[rhs_ptr]]\n" // Bottom half of v0 (for next time) + "fmla v28.4s, v4.4s, v1.s[0]\n" + "ins v3.d[1], x18\n" // Finish loading next v3 + "fmla v29.4s, v4.4s, v1.s[1]\n" + "ldr x18, [%[rhs_ptr], #8]\n" // Top half of next v0 to X register + "fmla v30.4s, v4.4s, v1.s[2]\n" + + "fmla v31.4s, v4.4s, v1.s[3]\n" + "bne " GEMMLOWP_LABEL_LOOP "b\n" + + // Store accumulators + "mov x0, %[accum_ptr]\n" + "st1 {v8.4s}, [x0], #16\n" + "st1 {v16.4s}, [x0], #16\n" + "st1 {v24.4s}, [x0], #16\n" + "st1 {v9.4s}, [x0], #16\n" + "st1 {v17.4s}, [x0], #16\n" + "st1 {v25.4s}, [x0], #16\n" + "st1 {v10.4s}, [x0], #16\n" + "st1 {v18.4s}, [x0], #16\n" + "st1 {v26.4s}, [x0], #16\n" + "st1 {v11.4s}, [x0], #16\n" + "st1 {v19.4s}, [x0], #16\n" + "st1 {v27.4s}, [x0], #16\n" + "st1 {v12.4s}, [x0], #16\n" + "st1 
{v20.4s}, [x0], #16\n" + "st1 {v28.4s}, [x0], #16\n" + "st1 {v13.4s}, [x0], #16\n" + "st1 {v21.4s}, [x0], #16\n" + "st1 {v29.4s}, [x0], #16\n" + "st1 {v14.4s}, [x0], #16\n" + "st1 {v22.4s}, [x0], #16\n" + "st1 {v30.4s}, [x0], #16\n" + "st1 {v15.4s}, [x0], #16\n" + "st1 {v23.4s}, [x0], #16\n" + "st1 {v31.4s}, [x0], #16\n" + : // outputs + [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), + [depth] "+r"(depth) + : // inputs + [accum_ptr] "r"(accum_ptr) + : // clobbers + "cc", "memory", "x0", "x18", "v0", "v1", "v2", "v3", "v4", "v5", "v6", + "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", + "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", + "v27", "v28", "v29", "v30", "v31"); + } +}; + #endif // __aarch64__ +#if defined(__arm__) || defined(__aarch64__) #ifndef __aarch64__ inline int32x4_t vpaddq_s32(int32x4_t a, int32x4_t b) { const int32x2_t c = vpadd_s32(vget_low_s32(a), vget_high_s32(a)); @@ -3388,6 +3852,974 @@ using NEON_32bit_GEMM_Float32_WithScalar_intrinsics = using NEON_64bit_GEMM_Float32_WithScalar_intrinsics = NEON_GEMM_Float32_WithScalar_intrinsics<2>; +#endif // __arm__ || __aarch64__ + +#ifdef __mips +static inline v4i32 workaround_msa_maddv_w(v4i32 a, v4i32 b, v4i32 c) { + // Workaround for incorrect encoding of maddv.df in gcc (a exchanged with c). +#if 0 + return __builtin_msa_maddv_w(a, b, c); +#else + asm volatile("maddv.w %w[a], %w[b], %w[c]\n" + // Outputs + : [a] "+f"(a) + // Inputs + : [b] "f"(b), [c] "f"(c)); + return a; +#endif +} + +// Using 32x32=32 multiplications. +// 20 MSA regs used: +// - 12 accumulators +// - 6 lhs +// - 1 rhs +// - 1 temps/zeroes +// ~55 instructions in the loop. 
+struct MSA_GEMM_12x4_Uint8Operands_Uint32Accumulators_intrinsics { + typedef std::uint8_t OperandType; + typedef std::int32_t AccumulatorType; + typedef KernelFormat< + KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 3>, + KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 1> > + Format; + static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr, + AccumulatorType* accum_ptr, int depth) { + const v16i8 zeroes = __builtin_msa_ldi_b(0); + v4i32 acc[3][4]; + // Load accumulators. + for (int i = 0; i < 3; i++) { + for (int j = 0; j < 4; j++) { + acc[i][j] = __builtin_msa_ld_w(accum_ptr + 4 * (i + 3 * j), 0); + } + } + + while (depth > 0) { + // Load 3 x 8 bytes of lhs[] with 2 16-byte overlapped loads. + v8i16 lhs[6]; + lhs[0] = reinterpret_cast<v8i16>(__builtin_msa_ld_b(const_cast<OperandType*>(lhs_ptr), 0)); + lhs[1] = + reinterpret_cast<v8i16>(__builtin_msa_ld_b(const_cast<OperandType*>(lhs_ptr + 8), 0)); + + // Zero-extend 8-bit elements of lhs[] to 16 bits. + lhs[0] = reinterpret_cast<v8i16>(__builtin_msa_ilvr_b(zeroes, + reinterpret_cast<v16i8>(lhs[0]))); + lhs[2] = reinterpret_cast<v8i16>(__builtin_msa_ilvl_b(zeroes, + reinterpret_cast<v16i8>(lhs[1]))); + lhs[1] = reinterpret_cast<v8i16>(__builtin_msa_ilvr_b(zeroes, + reinterpret_cast<v16i8>(lhs[1]))); + + // Zero-extend 16-bit elements of lhs[] to 32 bits. + lhs[3] = __builtin_msa_ilvl_h(reinterpret_cast<v8i16>(zeroes), lhs[0]); + lhs[4] = __builtin_msa_ilvl_h(reinterpret_cast<v8i16>(zeroes), lhs[1]); + lhs[5] = __builtin_msa_ilvl_h(reinterpret_cast<v8i16>(zeroes), lhs[2]); + lhs[0] = __builtin_msa_ilvr_h(reinterpret_cast<v8i16>(zeroes), lhs[0]); + lhs[1] = __builtin_msa_ilvr_h(reinterpret_cast<v8i16>(zeroes), lhs[1]); + lhs[2] = __builtin_msa_ilvr_h(reinterpret_cast<v8i16>(zeroes), lhs[2]); + + // Depth 0. + for (int j = 0; j < 4; j++) { + // Load 1 byte of rhs, making 4 32-bit replicas of it. 
+ v4i32 rhs = reinterpret_cast<v4i32>(__builtin_msa_fill_w(rhs_ptr[j])); + // Multiply-add into accumulators. + for (int i = 0; i < 3; i++) { + acc[i][j] = workaround_msa_maddv_w(acc[i][j], reinterpret_cast<v4i32>(lhs[i]), rhs); + } + } + + // Depth 1. + for (int j = 0; j < 4; j++) { + // Load 1 byte of rhs, making 4 32-bit replicas of it. + v4i32 rhs = reinterpret_cast<v4i32>(__builtin_msa_fill_w(rhs_ptr[j + 4])); + // Multiply-add into accumulators. + for (int i = 0; i < 3; i++) { + acc[i][j] = workaround_msa_maddv_w(acc[i][j], reinterpret_cast<v4i32>(lhs[i + 3]), rhs); + } + } + + lhs_ptr += 24; + rhs_ptr += 8; + depth -= 2; + } + + // Store accumulators. + for (int i = 0; i < 3; i++) { + for (int j = 0; j < 4; j++) { + __builtin_msa_st_w(acc[i][j], accum_ptr + 4 * (i + 3 * j), 0); + } + } + } +}; + +// Assembly implementation of the above +// MSA_GEMM_12x4_Uint8Operands_Uint32Accumulators_intrinsics. +// Using 32x32=32 multiplications. +// 20 MSA regs used: +// - 12 accumulators +// - 6 lhs +// - 1 rhs +// - 1 temps/zeroes +// ~55 instructions in the loop. 
+struct MSA_GEMM_12x4_Uint8Operands_Uint32Accumulators_assembly { + typedef std::uint8_t OperandType; + typedef std::int32_t AccumulatorType; + typedef KernelFormat< + KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 3>, + KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 1> > + Format; + static void Run(OperandType* lhs_ptr, OperandType* rhs_ptr, + AccumulatorType* accum_ptr, int depth) { + asm volatile( + // Load accumulators + "ld.w $w0, (0*16)(%[accum_ptr])\n" + "ld.w $w4, (1*16)(%[accum_ptr])\n" + "ld.w $w8, (2*16)(%[accum_ptr])\n" + "ld.w $w1, (3*16)(%[accum_ptr])\n" + "ld.w $w5, (4*16)(%[accum_ptr])\n" + "ld.w $w9, (5*16)(%[accum_ptr])\n" + "ld.w $w2, (6*16)(%[accum_ptr])\n" + "ld.w $w6, (7*16)(%[accum_ptr])\n" + "ld.w $w10, (8*16)(%[accum_ptr])\n" + "ld.w $w3, (9*16)(%[accum_ptr])\n" + "ld.w $w7, (10*16)(%[accum_ptr])\n" + "ld.w $w11, (11*16)(%[accum_ptr])\n" + // Set a temp to all zeroes. + "ldi.b $w19, 0\n" + + GEMMLOWP_LABEL_LOOP ":\n" + // Overview of register layout: + // + // A half of the 2x4 cell of Rhs is stored in 32bit in w18. + // A 12x2 block of 3 4x2 cells Lhs is stored in 32bit in w12-w17. + // A 12x4 block of accumulators is stored in 32bit in w0-w11. 
+ // + // +------+------+------+------+ + // Rhs |w18[0]|w18[1]|w18[2]|w18[3]| + // +------+------+------+------+ + // + // | | | | | + // + // Lhs | | | | | + // + // +---+---+ - - - - +------+------+------+------+ + // |w12|w15| | w0 | w1 | w2 | w3 | + // |w12|w15| | w0 | w1 | w2 | w3 | + // |w12|w15| | w0 | w1 | w2 | w3 | + // |w12|w15| | w0 | w1 | w2 | w3 | + // +---+---+ - - - - +------+------+------+------+ + // |w13|w16| | w4 | w5 | w6 | w7 | + // |w13|w16| | w4 | w5 | w6 | w7 | + // |w13|w16| | w4 | w5 | w6 | w7 | + // |w13|w16| | w4 | w5 | w6 | w7 | + // +---+---+ - - - - +------+------+------+------+ + // |w14|w17| | w8 | w9 | w10 | w11 | + // |w14|w17| | w8 | w9 | w10 | w11 | + // |w14|w17| | w8 | w9 | w10 | w11 | + // |w14|w17| | w8 | w9 | w10 | w11 | + // +---+---+ - - - - +------+------+------+------+ + // + // Accumulator + + // Load 3 x 8 bytes of lhs[] with 2 16-byte overlapped loads. + "ld.b $w12, 0(%[lhs_ptr])\n" + "ld.b $w13, 8(%[lhs_ptr])\n" + + // Load 4 bytes of rhs[] for depth 0. + "lbu $a0, 0(%[rhs_ptr])\n" + "lbu $a1, 1(%[rhs_ptr])\n" + "lbu $a2, 2(%[rhs_ptr])\n" + "lbu $a3, 3(%[rhs_ptr])\n" + + // Zero-extend 8-bit elements of lhs[] to 16 bits. + "ilvr.b $w12, $w19, $w12\n" + "ilvl.b $w14, $w19, $w13\n" + "ilvr.b $w13, $w19, $w13\n" + // Zero-extend 16-bit elements of lhs[] to 32 bits. + "ilvl.h $w15, $w19, $w12\n" + "ilvl.h $w16, $w19, $w13\n" + "ilvl.h $w17, $w19, $w14\n" + "ilvr.h $w12, $w19, $w12\n" + "ilvr.h $w13, $w19, $w13\n" + "ilvr.h $w14, $w19, $w14\n" + + // Depth 0. 
+ "fill.w $w18, $a0\n" + "lbu $a0, 4(%[rhs_ptr])\n" + "maddv.w $w0, $w12, $w18\n" + "maddv.w $w4, $w13, $w18\n" + "maddv.w $w8, $w14, $w18\n" + "fill.w $w18, $a1\n" + "lbu $a1, 5(%[rhs_ptr])\n" + "maddv.w $w1, $w12, $w18\n" + "maddv.w $w5, $w13, $w18\n" + "maddv.w $w9, $w14, $w18\n" + "fill.w $w18, $a2\n" + "lbu $a2, 6(%[rhs_ptr])\n" + "maddv.w $w2, $w12, $w18\n" + "maddv.w $w6, $w13, $w18\n" + "maddv.w $w10, $w14, $w18\n" + "fill.w $w18, $a3\n" + "lbu $a3, 7(%[rhs_ptr])\n" + "maddv.w $w3, $w12, $w18\n" + "maddv.w $w7, $w13, $w18\n" + "maddv.w $w11, $w14, $w18\n" + + // Depth 1. + "fill.w $w18, $a0\n" + "maddv.w $w0, $w15, $w18\n" + "maddv.w $w4, $w16, $w18\n" + "maddv.w $w8, $w17, $w18\n" + "fill.w $w18, $a1\n" + "maddv.w $w1, $w15, $w18\n" + "maddv.w $w5, $w16, $w18\n" + "maddv.w $w9, $w17, $w18\n" + "fill.w $w18, $a2\n" + "maddv.w $w2, $w15, $w18\n" + "maddv.w $w6, $w16, $w18\n" + "maddv.w $w10, $w17, $w18\n" + "fill.w $w18, $a3\n" + "maddv.w $w3, $w15, $w18\n" + "maddv.w $w7, $w16, $w18\n" + "maddv.w $w11, $w17, $w18\n" + + "addiu %[depth], -2\n" + GEMMLOWP_MIPS_XADDIU " %[lhs_ptr], 24\n" + GEMMLOWP_MIPS_XADDIU " %[rhs_ptr], 8\n" + "bnez %[depth]," GEMMLOWP_LABEL_LOOP "b\n" + + // Store accumulators. 
+ "st.w $w0, (0*16)(%[accum_ptr])\n" + "st.w $w4, (1*16)(%[accum_ptr])\n" + "st.w $w8, (2*16)(%[accum_ptr])\n" + "st.w $w1, (3*16)(%[accum_ptr])\n" + "st.w $w5, (4*16)(%[accum_ptr])\n" + "st.w $w9, (5*16)(%[accum_ptr])\n" + "st.w $w2, (6*16)(%[accum_ptr])\n" + "st.w $w6, (7*16)(%[accum_ptr])\n" + "st.w $w10, (8*16)(%[accum_ptr])\n" + "st.w $w3, (9*16)(%[accum_ptr])\n" + "st.w $w7, (10*16)(%[accum_ptr])\n" + "st.w $w11, (11*16)(%[accum_ptr])\n" + : // outputs + [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), + [depth] "+r"(depth) + : // inputs + [accum_ptr] "r"(accum_ptr) + : // clobbers + "memory", + "a0", "a1", "a2", "a3", + "$f0", "$f1", "$f2", "$f3", "$f4", "$f5", "$f6", "$f7", + "$f8", "$f9", "$f10", "$f11", "$f12", "$f13", "$f14", "$f15", + "$f16", "$f17", "$f18", "$f19"); + } +}; + +// Assembly implementation of the above +// MSA_GEMM_12x4_Uint8Operands_Uint32Accumulators_intrinsics2 (TODO). +// Using 16x16=32 multiplications. +// 20 MSA regs used: +// - 12 accumulators +// - 3 lhs +// - 4 rhs +// - 1 temps/zeroes +// ~45 instructions in the loop. 
+struct MSA_GEMM_12x4_Uint8Operands_Uint32Accumulators_assembly2 { + typedef std::uint8_t OperandType; + typedef std::int32_t AccumulatorType; + typedef KernelFormat< + KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 3>, + KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 1> > + Format; + static void Run(OperandType* lhs_ptr, OperandType* rhs_ptr, + AccumulatorType* accum_ptr, int depth) { + asm volatile( + // Load accumulators + "ld.w $w0, (0*16)(%[accum_ptr])\n" + "ld.w $w4, (1*16)(%[accum_ptr])\n" + "ld.w $w8, (2*16)(%[accum_ptr])\n" + "ld.w $w1, (3*16)(%[accum_ptr])\n" + "ld.w $w5, (4*16)(%[accum_ptr])\n" + "ld.w $w9, (5*16)(%[accum_ptr])\n" + "ld.w $w2, (6*16)(%[accum_ptr])\n" + "ld.w $w6, (7*16)(%[accum_ptr])\n" + "ld.w $w10, (8*16)(%[accum_ptr])\n" + "ld.w $w3, (9*16)(%[accum_ptr])\n" + "ld.w $w7, (10*16)(%[accum_ptr])\n" + "ld.w $w11, (11*16)(%[accum_ptr])\n" + // Set a temp to all zeroes. + "ldi.b $w19, 0\n" + + GEMMLOWP_LABEL_LOOP ":\n" + // Overview of register layout: + // + // A 2x4 cell of Rhs is stored in 16bit in w15-w18 (each register + // contains 4 replicas of a pair of elements). + // A 12x2 block of 3 4x2 cells Lhs is stored in 16bit in w12-w14. + // A 12x4 block of accumulators is stored in 32bit in w0-w11. 
+ // + // +-----+-----+-----+-----+ + // Rhs | w15 | w16 | w17 | w18 | + // +-----+-----+-----+-----+ + // + // | | | | | + // + // Lhs | | | | | + // + // +---+ - - - - +-----+-----+-----+-----+ + // |w12| | w0 | w1 | w2 | w3 | + // |w12| | w0 | w1 | w2 | w3 | + // |w12| | w0 | w1 | w2 | w3 | + // |w12| | w0 | w1 | w2 | w3 | + // +---+ - - - - +-----+-----+-----+-----+ + // |w13| | w4 | w5 | w6 | w7 | + // |w13| | w4 | w5 | w6 | w7 | + // |w13| | w4 | w5 | w6 | w7 | + // |w13| | w4 | w5 | w6 | w7 | + // +---+ - - - - +-----+-----+-----+-----+ + // |w14| | w8 | w9 | w10 | w11 | + // |w14| | w8 | w9 | w10 | w11 | + // |w14| | w8 | w9 | w10 | w11 | + // |w14| | w8 | w9 | w10 | w11 | + // +---+ - - - - +-----+-----+-----+-----+ + // + // Accumulators + + // Load 3 x 8 bytes of lhs[] with 2 16-byte overlapped loads. + "ld.b $w12, 0(%[lhs_ptr])\n" + "ld.b $w13, 8(%[lhs_ptr])\n" + + // Load 4 bytes of rhs[] for depth 0. + "lbu $a0, 0(%[rhs_ptr])\n" + "lbu $a1, 1(%[rhs_ptr])\n" + "lbu $a2, 2(%[rhs_ptr])\n" + "lbu $a3, 3(%[rhs_ptr])\n" + // Load 4 bytes of rhs[] for depth 1. + "lbu $v0, 4(%[rhs_ptr])\n" + "lbu $v1, 5(%[rhs_ptr])\n" + "lbu $t8, 6(%[rhs_ptr])\n" + "lbu $t9, 7(%[rhs_ptr])\n" + + // Zero-extend 8-bit elements of lhs[] to 16 bits. + "ilvr.b $w12, $w19, $w12\n" + "ilvl.b $w14, $w19, $w13\n" + "ilvr.b $w13, $w19, $w13\n" + // Interleave depth 0 and depth 1 elements of lhs[] for dpadd_u.w. + "ilvl.d $w15, $w19, $w12\n" + "ilvl.d $w16, $w19, $w13\n" + "ilvl.d $w17, $w19, $w14\n" + "ilvr.h $w12, $w15, $w12\n" + "ilvr.h $w13, $w16, $w13\n" + "ilvr.h $w14, $w17, $w14\n" + + // Combine and interleave depth 0 and depth 1 elements of rhs[] for dpadd_u.w. + "ins $a0, $v0, 16, 8\n" + "ins $a1, $v1, 16, 8\n" + "ins $a2, $t8, 16, 8\n" + "ins $a3, $t9, 16, 8\n" + // Make 4 replicas of every pair of rhs[] elements. + "fill.w $w15, $a0\n" + "fill.w $w16, $a1\n" + "fill.w $w17, $a2\n" + "fill.w $w18, $a3\n" + + // Depths 0 and 1. 
+ // Dot-product-(and)-add doubles multiplicand width. + "dpadd_u.w $w0, $w12, $w15\n" + "dpadd_u.w $w4, $w13, $w15\n" + "dpadd_u.w $w8, $w14, $w15\n" + "dpadd_u.w $w1, $w12, $w16\n" + "dpadd_u.w $w5, $w13, $w16\n" + "dpadd_u.w $w9, $w14, $w16\n" + "dpadd_u.w $w2, $w12, $w17\n" + "dpadd_u.w $w6, $w13, $w17\n" + "dpadd_u.w $w10, $w14, $w17\n" + "dpadd_u.w $w3, $w12, $w18\n" + "dpadd_u.w $w7, $w13, $w18\n" + "dpadd_u.w $w11, $w14, $w18\n" + + "addiu %[depth], -2\n" + GEMMLOWP_MIPS_XADDIU " %[lhs_ptr], 24\n" + GEMMLOWP_MIPS_XADDIU " %[rhs_ptr], 8\n" + "bnez %[depth]," GEMMLOWP_LABEL_LOOP "b\n" + + // Store accumulators. + "st.w $w0, (0*16)(%[accum_ptr])\n" + "st.w $w4, (1*16)(%[accum_ptr])\n" + "st.w $w8, (2*16)(%[accum_ptr])\n" + "st.w $w1, (3*16)(%[accum_ptr])\n" + "st.w $w5, (4*16)(%[accum_ptr])\n" + "st.w $w9, (5*16)(%[accum_ptr])\n" + "st.w $w2, (6*16)(%[accum_ptr])\n" + "st.w $w6, (7*16)(%[accum_ptr])\n" + "st.w $w10, (8*16)(%[accum_ptr])\n" + "st.w $w3, (9*16)(%[accum_ptr])\n" + "st.w $w7, (10*16)(%[accum_ptr])\n" + "st.w $w11, (11*16)(%[accum_ptr])\n" + : // outputs + [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), + [depth] "+r"(depth) + : // inputs + [accum_ptr] "r"(accum_ptr) + : // clobbers + "memory", + "v0", "v1", + "a0", "a1", "a2", "a3", + "t8", "t9", + "$f0", "$f1", "$f2", "$f3", "$f4", "$f5", "$f6", "$f7", + "$f8", "$f9", "$f10", "$f11", "$f12", "$f13", "$f14", "$f15", + "$f16", "$f17", "$f18", "$f19"); + } +}; + +// Using 32x32=32 multiplications. +// 32 MSA regs used: +// - 24 accumulators +// - 6 lhs +// - 1 rhs +// - 1 temps/zeroes +// ~95 instructions in the loop. 
+struct MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators_intrinsics { + typedef std::uint8_t OperandType; + typedef std::uint32_t AccumulatorType; + typedef KernelFormat< + KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 3>, + KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 2> > + Format; + static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr, + AccumulatorType* accum_ptr, int depth) { + const v16i8 zeroes = __builtin_msa_ldi_b(0); + v4i32 acc[3][8]; + // Load accumulators. + for (int i = 0; i < 3; i++) { + for (int j = 0; j < 8; j++) { + acc[i][j] = __builtin_msa_ld_w(accum_ptr + 4 * (i + 3 * j), 0); + } + } + + while (depth > 0) { + // Load 3 x 8 bytes of lhs[] with 2 16-byte overlapped loads. + v8i16 lhs[6]; + lhs[0] = reinterpret_cast<v8i16>(__builtin_msa_ld_b(const_cast<OperandType*>(lhs_ptr), 0)); + lhs[1] = + reinterpret_cast<v8i16>(__builtin_msa_ld_b(const_cast<OperandType*>(lhs_ptr + 8), 0)); + + // Zero-extend 8-bit elements of lhs[] to 16 bits. + lhs[0] = reinterpret_cast<v8i16>(__builtin_msa_ilvr_b(zeroes, + reinterpret_cast<v16i8>(lhs[0]))); + lhs[2] = reinterpret_cast<v8i16>(__builtin_msa_ilvl_b(zeroes, + reinterpret_cast<v16i8>(lhs[1]))); + lhs[1] = reinterpret_cast<v8i16>(__builtin_msa_ilvr_b(zeroes, + reinterpret_cast<v16i8>(lhs[1]))); + + // Zero-extend 16-bit elements of lhs[] to 32 bits. + lhs[3] = __builtin_msa_ilvl_h(reinterpret_cast<v8i16>(zeroes), lhs[0]); + lhs[4] = __builtin_msa_ilvl_h(reinterpret_cast<v8i16>(zeroes), lhs[1]); + lhs[5] = __builtin_msa_ilvl_h(reinterpret_cast<v8i16>(zeroes), lhs[2]); + lhs[0] = __builtin_msa_ilvr_h(reinterpret_cast<v8i16>(zeroes), lhs[0]); + lhs[1] = __builtin_msa_ilvr_h(reinterpret_cast<v8i16>(zeroes), lhs[1]); + lhs[2] = __builtin_msa_ilvr_h(reinterpret_cast<v8i16>(zeroes), lhs[2]); + + // Depth 0. + for (int j = 0; j < 4; j++) { + // Load 1 byte of rhs, making 4 32-bit replicas of it. 
+ v4i32 rhs = reinterpret_cast<v4i32>(__builtin_msa_fill_w(rhs_ptr[j])); + // Multiply-add into accumulators. + for (int i = 0; i < 3; i++) { + acc[i][j] = workaround_msa_maddv_w(acc[i][j], reinterpret_cast<v4i32>(lhs[i]), rhs); + } + } + for (int j = 4; j < 8; j++) { + // Load 1 byte of rhs, making 4 32-bit replicas of it. + v4i32 rhs = reinterpret_cast<v4i32>(__builtin_msa_fill_w(rhs_ptr[j + 4])); + // Multiply-add into accumulators. + for (int i = 0; i < 3; i++) { + acc[i][j] = workaround_msa_maddv_w(acc[i][j], reinterpret_cast<v4i32>(lhs[i]), rhs); + } + } + + // Depth 1. + for (int j = 0; j < 4; j++) { + // Load 1 byte of rhs, making 4 32-bit replicas of it. + v4i32 rhs = reinterpret_cast<v4i32>(__builtin_msa_fill_w(rhs_ptr[j + 4])); + // Multiply-add into accumulators. + for (int i = 0; i < 3; i++) { + acc[i][j] = workaround_msa_maddv_w(acc[i][j], reinterpret_cast<v4i32>(lhs[i + 3]), rhs); + } + } + for (int j = 4; j < 8; j++) { + // Load 1 byte of rhs, making 4 32-bit replicas of it. + v4i32 rhs = reinterpret_cast<v4i32>(__builtin_msa_fill_w(rhs_ptr[j + 8])); + // Multiply-add into accumulators. + for (int i = 0; i < 3; i++) { + acc[i][j] = workaround_msa_maddv_w(acc[i][j], reinterpret_cast<v4i32>(lhs[i + 3]), rhs); + } + } + + lhs_ptr += 24; + rhs_ptr += 16; + depth -= 2; + } + + // Store accumulators. + for (int i = 0; i < 3; i++) { + for (int j = 0; j < 8; j++) { + __builtin_msa_st_w(acc[i][j], accum_ptr + 4 * (i + 3 * j), 0); + } + } + } +}; + +// Assembly implementation of the above +// MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators_intrinsics. +// Using 32x32=32 multiplications. +// 32 MSA regs used: +// - 24 accumulators +// - 6 lhs +// - 1 rhs +// - 1 temps/zeroes +// ~95 instructions in the loop. 
+struct MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators_assembly { + typedef std::uint8_t OperandType; + typedef std::uint32_t AccumulatorType; + typedef KernelFormat< + KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 3>, + KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 2> > + Format; + static void Run(OperandType* lhs_ptr, OperandType* rhs_ptr, + AccumulatorType* accum_ptr, int depth) { + asm volatile( + // Load accumulators + "ld.w $w0, (0*16)(%[accum_ptr])\n" + "ld.w $w4, (1*16)(%[accum_ptr])\n" + "ld.w $w8, (2*16)(%[accum_ptr])\n" + "ld.w $w1, (3*16)(%[accum_ptr])\n" + "ld.w $w5, (4*16)(%[accum_ptr])\n" + "ld.w $w9, (5*16)(%[accum_ptr])\n" + "ld.w $w2, (6*16)(%[accum_ptr])\n" + "ld.w $w6, (7*16)(%[accum_ptr])\n" + "ld.w $w10, (8*16)(%[accum_ptr])\n" + "ld.w $w3, (9*16)(%[accum_ptr])\n" + "ld.w $w7, (10*16)(%[accum_ptr])\n" + "ld.w $w11, (11*16)(%[accum_ptr])\n" + "ld.w $w12, (12*16)(%[accum_ptr])\n" + "ld.w $w16, (13*16)(%[accum_ptr])\n" + "ld.w $w20, (14*16)(%[accum_ptr])\n" + "ld.w $w13, (15*16)(%[accum_ptr])\n" + "ld.w $w17, (16*16)(%[accum_ptr])\n" + "ld.w $w21, (17*16)(%[accum_ptr])\n" + "ld.w $w14, (18*16)(%[accum_ptr])\n" + "ld.w $w18, (19*16)(%[accum_ptr])\n" + "ld.w $w22, (20*16)(%[accum_ptr])\n" + "ld.w $w15, (21*16)(%[accum_ptr])\n" + "ld.w $w19, (22*16)(%[accum_ptr])\n" + "ld.w $w23, (23*16)(%[accum_ptr])\n" + // Set a temp to all zeroes. + "ldi.b $w31, 0\n" + + GEMMLOWP_LABEL_LOOP ":\n" + // Overview of register layout: + // + // A quarter of the 2 2x4 cells of Rhs is stored in 32bit in w30. + // A 12x2 block of 3 4x2 cells Lhs is stored in 32bit in w24-w29. + // A 12x8 block of accumulators is stored in 32bit in w0-w23. 
+ // + // +------+------+------+------+ + // Rhs |w30[0]|w30[1]|w30[2]|w30[3]| + // +------+------+------+------+ + // + // | | | | | + // + // Lhs | | | | | + // + // +---+---+ - - - - +------+------+------+------+ + // |w24|w27| |w0/12 |w1/13 |w2/14 |w3/15 | + // |w24|w27| |w0/12 |w1/13 |w2/14 |w3/15 | + // |w24|w27| |w0/12 |w1/13 |w2/14 |w3/15 | + // |w24|w27| |w0/12 |w1/13 |w2/14 |w3/15 | + // +---+---+ - - - - +------+------+------+------+ + // |w25|w28| |w4/16 |w5/17 |w6/18 |w7/19 | + // |w25|w28| |w4/16 |w5/17 |w6/18 |w7/19 | + // |w25|w28| |w4/16 |w5/17 |w6/18 |w7/19 | + // |w25|w28| |w4/16 |w5/17 |w6/18 |w7/19 | + // +---+---+ - - - - +------+------+------+------+ + // |w26|w29| |w8/20 |w9/21 |w10/22|w11/23| + // |w26|w29| |w8/20 |w9/21 |w10/22|w11/23| + // |w26|w29| |w8/20 |w9/21 |w10/22|w11/23| + // |w26|w29| |w8/20 |w9/21 |w10/22|w11/23| + // +---+---+ - - - - +------+------+------+------+ + // + // Accumulator + + // Load 3 x 8 bytes of lhs[] with 2 16-byte overlapped loads. + "ld.b $w24, 0(%[lhs_ptr])\n" + "ld.b $w25, 8(%[lhs_ptr])\n" + + // Load 4 bytes of rhs[] for the first half of depth 0. + "lbu $a0, 0(%[rhs_ptr])\n" + "lbu $a1, 1(%[rhs_ptr])\n" + "lbu $a2, 2(%[rhs_ptr])\n" + "lbu $a3, 3(%[rhs_ptr])\n" + + // Zero-extend 8-bit elements of lhs[] to 16 bits. + "ilvr.b $w24, $w31, $w24\n" + "ilvl.b $w26, $w31, $w25\n" + "ilvr.b $w25, $w31, $w25\n" + // Zero-extend 16-bit elements of lhs[] to 32 bits. + "ilvl.h $w27, $w31, $w24\n" + "ilvl.h $w28, $w31, $w25\n" + "ilvl.h $w29, $w31, $w26\n" + "ilvr.h $w24, $w31, $w24\n" + "ilvr.h $w25, $w31, $w25\n" + "ilvr.h $w26, $w31, $w26\n" + + // Depth 0. 
+ "fill.w $w30, $a0\n" + "lbu $a0, 8(%[rhs_ptr])\n" + "maddv.w $w0, $w24, $w30\n" + "maddv.w $w4, $w25, $w30\n" + "maddv.w $w8, $w26, $w30\n" + "fill.w $w30, $a1\n" + "lbu $a1, 9(%[rhs_ptr])\n" + "maddv.w $w1, $w24, $w30\n" + "maddv.w $w5, $w25, $w30\n" + "maddv.w $w9, $w26, $w30\n" + "fill.w $w30, $a2\n" + "lbu $a2, 10(%[rhs_ptr])\n" + "maddv.w $w2, $w24, $w30\n" + "maddv.w $w6, $w25, $w30\n" + "maddv.w $w10, $w26, $w30\n" + "fill.w $w30, $a3\n" + "lbu $a3, 11(%[rhs_ptr])\n" + "maddv.w $w3, $w24, $w30\n" + "maddv.w $w7, $w25, $w30\n" + "maddv.w $w11, $w26, $w30\n" + + "fill.w $w30, $a0\n" + "lbu $a0, 4(%[rhs_ptr])\n" + "maddv.w $w12, $w24, $w30\n" + "maddv.w $w16, $w25, $w30\n" + "maddv.w $w20, $w26, $w30\n" + "fill.w $w30, $a1\n" + "lbu $a1, 5(%[rhs_ptr])\n" + "maddv.w $w13, $w24, $w30\n" + "maddv.w $w17, $w25, $w30\n" + "maddv.w $w21, $w26, $w30\n" + "fill.w $w30, $a2\n" + "lbu $a2, 6(%[rhs_ptr])\n" + "maddv.w $w14, $w24, $w30\n" + "maddv.w $w18, $w25, $w30\n" + "maddv.w $w22, $w26, $w30\n" + "fill.w $w30, $a3\n" + "lbu $a3, 7(%[rhs_ptr])\n" + "maddv.w $w15, $w24, $w30\n" + "maddv.w $w19, $w25, $w30\n" + "maddv.w $w23, $w26, $w30\n" + + // Depth 1. 
+ "fill.w $w30, $a0\n" + "lbu $a0, 12(%[rhs_ptr])\n" + "maddv.w $w0, $w27, $w30\n" + "maddv.w $w4, $w28, $w30\n" + "maddv.w $w8, $w29, $w30\n" + "fill.w $w30, $a1\n" + "lbu $a1, 13(%[rhs_ptr])\n" + "maddv.w $w1, $w27, $w30\n" + "maddv.w $w5, $w28, $w30\n" + "maddv.w $w9, $w29, $w30\n" + "fill.w $w30, $a2\n" + "lbu $a2, 14(%[rhs_ptr])\n" + "maddv.w $w2, $w27, $w30\n" + "maddv.w $w6, $w28, $w30\n" + "maddv.w $w10, $w29, $w30\n" + "fill.w $w30, $a3\n" + "lbu $a3, 15(%[rhs_ptr])\n" + "maddv.w $w3, $w27, $w30\n" + "maddv.w $w7, $w28, $w30\n" + "maddv.w $w11, $w29, $w30\n" + + "fill.w $w30, $a0\n" + "maddv.w $w12, $w27, $w30\n" + "maddv.w $w16, $w28, $w30\n" + "maddv.w $w20, $w29, $w30\n" + "fill.w $w30, $a1\n" + "maddv.w $w13, $w27, $w30\n" + "maddv.w $w17, $w28, $w30\n" + "maddv.w $w21, $w29, $w30\n" + "fill.w $w30, $a2\n" + "maddv.w $w14, $w27, $w30\n" + "maddv.w $w18, $w28, $w30\n" + "maddv.w $w22, $w29, $w30\n" + "fill.w $w30, $a3\n" + "maddv.w $w15, $w27, $w30\n" + "maddv.w $w19, $w28, $w30\n" + "maddv.w $w23, $w29, $w30\n" + + "addiu %[depth], -2\n" + GEMMLOWP_MIPS_XADDIU " %[lhs_ptr], 24\n" + GEMMLOWP_MIPS_XADDIU " %[rhs_ptr], 16\n" + "bnez %[depth]," GEMMLOWP_LABEL_LOOP "b\n" + + // Store accumulators. 
+ "st.w $w0, (0*16)(%[accum_ptr])\n" + "st.w $w4, (1*16)(%[accum_ptr])\n" + "st.w $w8, (2*16)(%[accum_ptr])\n" + "st.w $w1, (3*16)(%[accum_ptr])\n" + "st.w $w5, (4*16)(%[accum_ptr])\n" + "st.w $w9, (5*16)(%[accum_ptr])\n" + "st.w $w2, (6*16)(%[accum_ptr])\n" + "st.w $w6, (7*16)(%[accum_ptr])\n" + "st.w $w10, (8*16)(%[accum_ptr])\n" + "st.w $w3, (9*16)(%[accum_ptr])\n" + "st.w $w7, (10*16)(%[accum_ptr])\n" + "st.w $w11, (11*16)(%[accum_ptr])\n" + "st.w $w12, (12*16)(%[accum_ptr])\n" + "st.w $w16, (13*16)(%[accum_ptr])\n" + "st.w $w20, (14*16)(%[accum_ptr])\n" + "st.w $w13, (15*16)(%[accum_ptr])\n" + "st.w $w17, (16*16)(%[accum_ptr])\n" + "st.w $w21, (17*16)(%[accum_ptr])\n" + "st.w $w14, (18*16)(%[accum_ptr])\n" + "st.w $w18, (19*16)(%[accum_ptr])\n" + "st.w $w22, (20*16)(%[accum_ptr])\n" + "st.w $w15, (21*16)(%[accum_ptr])\n" + "st.w $w19, (22*16)(%[accum_ptr])\n" + "st.w $w23, (23*16)(%[accum_ptr])\n" + : // outputs + [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), + [depth] "+r"(depth) + : // inputs + [accum_ptr] "r"(accum_ptr) + : // clobbers + "memory", + "a0", "a1", "a2", "a3", + "$f0", "$f1", "$f2", "$f3", "$f4", "$f5", "$f6", "$f7", + "$f8", "$f9", "$f10", "$f11", "$f12", "$f13", "$f14", "$f15", + "$f16", "$f17", "$f18", "$f19", "$f20", "$f21", "$f22", "$f23", + "$f24", "$f25", "$f26", "$f27", "$f28", "$f29", "$f30", "$f31"); + } +}; + +// Assembly implementation of the above +// MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators_intrinsics2 (TODO). +// Using 16x16=32 multiplications. +// 32 MSA regs used: +// - 24 accumulators +// - 3 lhs +// - 4 rhs +// - 1 temps/zeroes +// ~70 instructions in the loop. 
+struct MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators_assembly2 { + typedef std::uint8_t OperandType; + typedef std::uint32_t AccumulatorType; + typedef KernelFormat< + KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 3>, + KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 2> > + Format; + static void Run(OperandType* lhs_ptr, OperandType* rhs_ptr, + AccumulatorType* accum_ptr, int depth) { + asm volatile( + // Load accumulators + "ld.w $w0, (0*16)(%[accum_ptr])\n" + "ld.w $w4, (1*16)(%[accum_ptr])\n" + "ld.w $w8, (2*16)(%[accum_ptr])\n" + "ld.w $w1, (3*16)(%[accum_ptr])\n" + "ld.w $w5, (4*16)(%[accum_ptr])\n" + "ld.w $w9, (5*16)(%[accum_ptr])\n" + "ld.w $w2, (6*16)(%[accum_ptr])\n" + "ld.w $w6, (7*16)(%[accum_ptr])\n" + "ld.w $w10, (8*16)(%[accum_ptr])\n" + "ld.w $w3, (9*16)(%[accum_ptr])\n" + "ld.w $w7, (10*16)(%[accum_ptr])\n" + "ld.w $w11, (11*16)(%[accum_ptr])\n" + "ld.w $w12, (12*16)(%[accum_ptr])\n" + "ld.w $w16, (13*16)(%[accum_ptr])\n" + "ld.w $w20, (14*16)(%[accum_ptr])\n" + "ld.w $w13, (15*16)(%[accum_ptr])\n" + "ld.w $w17, (16*16)(%[accum_ptr])\n" + "ld.w $w21, (17*16)(%[accum_ptr])\n" + "ld.w $w14, (18*16)(%[accum_ptr])\n" + "ld.w $w18, (19*16)(%[accum_ptr])\n" + "ld.w $w22, (20*16)(%[accum_ptr])\n" + "ld.w $w15, (21*16)(%[accum_ptr])\n" + "ld.w $w19, (22*16)(%[accum_ptr])\n" + "ld.w $w23, (23*16)(%[accum_ptr])\n" + // Set a temp to all zeroes. + "ldi.b $w31, 0\n" + + GEMMLOWP_LABEL_LOOP ":\n" + // Overview of register layout: + // + // A half of the 2 2x4 cells of Rhs is stored in 16bit in w27-w30 + // (each register contains 4 replicas of a pair of elements). + // A 12x2 block of 3 4x2 cells Lhs is stored in 16bit in w24-w26. + // A 12x8 block of accumulators is stored in 32bit in w0-w23. 
+ // + // +------+------+------+------+ + // Rhs |w27 |w28 |w29 |w30 | + // +------+------+------+------+ + // + // | | | | | + // + // Lhs | | | | | + // + // +---+ - - - - +------+------+------+------+ + // |w24| |w0/12 |w1/13 |w2/14 |w3/15 | + // |w24| |w0/12 |w1/13 |w2/14 |w3/15 | + // |w24| |w0/12 |w1/13 |w2/14 |w3/15 | + // |w24| |w0/12 |w1/13 |w2/14 |w3/15 | + // +---+ - - - - +------+------+------+------+ + // |w25| |w4/16 |w5/17 |w6/18 |w7/19 | + // |w25| |w4/16 |w5/17 |w6/18 |w7/19 | + // |w25| |w4/16 |w5/17 |w6/18 |w7/19 | + // |w25| |w4/16 |w5/17 |w6/18 |w7/19 | + // +---+ - - - - +------+------+------+------+ + // |w26| |w8/20 |w9/21 |w10/22|w11/23| + // |w26| |w8/20 |w9/21 |w10/22|w11/23| + // |w26| |w8/20 |w9/21 |w10/22|w11/23| + // |w26| |w8/20 |w9/21 |w10/22|w11/23| + // +---+ - - - - +------+------+------+------+ + // + // Accumulators + + // Load 3 x 8 bytes of lhs[] with 2 16-byte overlapped loads. + "ld.b $w24, 0(%[lhs_ptr])\n" + "ld.b $w25, 8(%[lhs_ptr])\n" + + // Load 4 bytes of rhs[] for the first half of depth 0. + "lbu $a0, 0(%[rhs_ptr])\n" + "lbu $a1, 1(%[rhs_ptr])\n" + "lbu $a2, 2(%[rhs_ptr])\n" + "lbu $a3, 3(%[rhs_ptr])\n" + // Load 4 bytes of rhs[] for the first half of depth 1. + "lbu $v0, 4(%[rhs_ptr])\n" + "lbu $v1, 5(%[rhs_ptr])\n" + "lbu $t8, 6(%[rhs_ptr])\n" + "lbu $t9, 7(%[rhs_ptr])\n" + + // Zero-extend 8-bit elements of lhs[] to 16 bits. + "ilvr.b $w24, $w31, $w24\n" + "ilvl.b $w26, $w31, $w25\n" + "ilvr.b $w25, $w31, $w25\n" + // Interleave depth 0 and depth 1 elements of lhs[] for dpadd_u.w. + "ilvl.d $w27, $w31, $w24\n" + "ilvl.d $w28, $w31, $w25\n" + "ilvl.d $w29, $w31, $w26\n" + "ilvr.h $w24, $w27, $w24\n" + "ilvr.h $w25, $w28, $w25\n" + "ilvr.h $w26, $w29, $w26\n" + + // Combine and interleave depth 0 and depth 1 elements of rhs[] for dpadd_u.w + // (for the first half). 
+ "ins $a0, $v0, 16, 8\n" + "ins $a1, $v1, 16, 8\n" + "ins $a2, $t8, 16, 8\n" + "ins $a3, $t9, 16, 8\n" + // Make 4 replicas of every pair of rhs[] elements. + "fill.w $w27, $a0\n" + "fill.w $w28, $a1\n" + "fill.w $w29, $a2\n" + "fill.w $w30, $a3\n" + + // Load 4 bytes of rhs[] for the second half of depth 0. + "lbu $a0, 8(%[rhs_ptr])\n" + "lbu $a1, 9(%[rhs_ptr])\n" + "lbu $a2, 10(%[rhs_ptr])\n" + "lbu $a3, 11(%[rhs_ptr])\n" + // Load 4 bytes of rhs[] for the second half of depth 1. + "lbu $v0, 12(%[rhs_ptr])\n" + "lbu $v1, 13(%[rhs_ptr])\n" + "lbu $t8, 14(%[rhs_ptr])\n" + "lbu $t9, 15(%[rhs_ptr])\n" + + // First half of depths 0 and 1. + // Dot-product-(and)-add doubles multiplicand width. + "dpadd_u.w $w0, $w24, $w27\n" + "dpadd_u.w $w4, $w25, $w27\n" + "dpadd_u.w $w8, $w26, $w27\n" + "dpadd_u.w $w1, $w24, $w28\n" + "dpadd_u.w $w5, $w25, $w28\n" + "dpadd_u.w $w9, $w26, $w28\n" + "dpadd_u.w $w2, $w24, $w29\n" + "dpadd_u.w $w6, $w25, $w29\n" + "dpadd_u.w $w10, $w26, $w29\n" + "dpadd_u.w $w3, $w24, $w30\n" + "dpadd_u.w $w7, $w25, $w30\n" + "dpadd_u.w $w11, $w26, $w30\n" + + // Combine and interleave depth 0 and depth 1 elements of rhs[] for dpadd_u.w + // (for the second half). + "ins $a0, $v0, 16, 8\n" + "ins $a1, $v1, 16, 8\n" + "ins $a2, $t8, 16, 8\n" + "ins $a3, $t9, 16, 8\n" + // Make 4 replicas of every pair of rhs[] elements. + "fill.w $w27, $a0\n" + "fill.w $w28, $a1\n" + "fill.w $w29, $a2\n" + "fill.w $w30, $a3\n" + + // Second half of depths 0 and 1. + // Dot-product-(and)-add doubles multiplicand width. 
+ "dpadd_u.w $w12, $w24, $w27\n" + "dpadd_u.w $w16, $w25, $w27\n" + "dpadd_u.w $w20, $w26, $w27\n" + "dpadd_u.w $w13, $w24, $w28\n" + "dpadd_u.w $w17, $w25, $w28\n" + "dpadd_u.w $w21, $w26, $w28\n" + "dpadd_u.w $w14, $w24, $w29\n" + "dpadd_u.w $w18, $w25, $w29\n" + "dpadd_u.w $w22, $w26, $w29\n" + "dpadd_u.w $w15, $w24, $w30\n" + "dpadd_u.w $w19, $w25, $w30\n" + "dpadd_u.w $w23, $w26, $w30\n" + + "addiu %[depth], -2\n" + GEMMLOWP_MIPS_XADDIU " %[lhs_ptr], 24\n" + GEMMLOWP_MIPS_XADDIU " %[rhs_ptr], 16\n" + "bnez %[depth]," GEMMLOWP_LABEL_LOOP "b\n" + + // Store accumulators. + "st.w $w0, (0*16)(%[accum_ptr])\n" + "st.w $w4, (1*16)(%[accum_ptr])\n" + "st.w $w8, (2*16)(%[accum_ptr])\n" + "st.w $w1, (3*16)(%[accum_ptr])\n" + "st.w $w5, (4*16)(%[accum_ptr])\n" + "st.w $w9, (5*16)(%[accum_ptr])\n" + "st.w $w2, (6*16)(%[accum_ptr])\n" + "st.w $w6, (7*16)(%[accum_ptr])\n" + "st.w $w10, (8*16)(%[accum_ptr])\n" + "st.w $w3, (9*16)(%[accum_ptr])\n" + "st.w $w7, (10*16)(%[accum_ptr])\n" + "st.w $w11, (11*16)(%[accum_ptr])\n" + "st.w $w12, (12*16)(%[accum_ptr])\n" + "st.w $w16, (13*16)(%[accum_ptr])\n" + "st.w $w20, (14*16)(%[accum_ptr])\n" + "st.w $w13, (15*16)(%[accum_ptr])\n" + "st.w $w17, (16*16)(%[accum_ptr])\n" + "st.w $w21, (17*16)(%[accum_ptr])\n" + "st.w $w14, (18*16)(%[accum_ptr])\n" + "st.w $w18, (19*16)(%[accum_ptr])\n" + "st.w $w22, (20*16)(%[accum_ptr])\n" + "st.w $w15, (21*16)(%[accum_ptr])\n" + "st.w $w19, (22*16)(%[accum_ptr])\n" + "st.w $w23, (23*16)(%[accum_ptr])\n" + : // outputs + [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), + [depth] "+r"(depth) + : // inputs + [accum_ptr] "r"(accum_ptr) + : // clobbers + "memory", + "v0", "v1", + "a0", "a1", "a2", "a3", + "t8", "t9", + "$f0", "$f1", "$f2", "$f3", "$f4", "$f5", "$f6", "$f7", + "$f8", "$f9", "$f10", "$f11", "$f12", "$f13", "$f14", "$f15", + "$f16", "$f17", "$f18", "$f19", "$f20", "$f21", "$f22", "$f23", + "$f24", "$f25", "$f26", "$f27", "$f28", "$f29", "$f30", "$f31"); + } +}; +#endif // __mips // 
BEGIN code copied from gemmlowp/internal/kernel_reference.h @@ -3451,8 +4883,9 @@ class CacheLineAlignedBuffer { data_ = nullptr; // Adds a few bytes of padding here, because the 64-bit 'A57' kernel // reads one iteration past the end the buffer, causing a crash on iOS. - posix_memalign(reinterpret_cast<void**>(&data_), kCacheLineSize, - size_ * sizeof(DataType) + 16); + int res = posix_memalign(reinterpret_cast<void**>(&data_), kCacheLineSize, + size_ * sizeof(DataType) + 16); + (void)res; } ~CacheLineAlignedBuffer() { free(data_); } @@ -3460,7 +4893,7 @@ class CacheLineAlignedBuffer { const DataType* data() const { return data_; } DataType* data() { return data_; } - const std::size_t size() const { return size_; } + std::size_t size() const { return size_; } private: const std::size_t size_; @@ -3726,12 +5159,15 @@ int main() { #endif #ifdef __aarch64__ - BENCHMARK(NEON_64bit_GEMM_Int8Operands_AccumTwoWithin16Bits); BENCHMARK(NEON_64bit_GEMM_Int8Operands_AccumTwoWithin16Bits_intrinsics); BENCHMARK(NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators); BENCHMARK(NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators_intrinsics); BENCHMARK(NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators_noexpand_A57); +#ifdef __ARM_FEATURE_DOTPROD + BENCHMARK(NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators_dotproduct); + BENCHMARK(NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators_dotproduct_A55r1); +#endif BENCHMARK(NEON_64bit_GEMM_Int32_WithScalar); BENCHMARK(NEON_64bit_GEMM_Float32_WithVectorDuplicatingScalar); BENCHMARK(NEON_64bit_GEMM_Float32_WithScalar); @@ -3740,6 +5176,16 @@ int main() { #ifndef __APPLE__ BENCHMARK(NEON_64bit_GEMM_Float32_WithScalar_A53); #endif + BENCHMARK(NEON_64bit_GEMM_Float32_WithScalar_A55r1); +#endif + +#ifdef __mips + BENCHMARK(MSA_GEMM_12x4_Uint8Operands_Uint32Accumulators_intrinsics); + BENCHMARK(MSA_GEMM_12x4_Uint8Operands_Uint32Accumulators_assembly); + BENCHMARK(MSA_GEMM_12x4_Uint8Operands_Uint32Accumulators_assembly2); + 
BENCHMARK(MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators_intrinsics); + BENCHMARK(MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators_assembly); + BENCHMARK(MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators_assembly2); #endif return 0; |