Diffstat (limited to 'standalone/neon-gemm-kernel-benchmark.cc')
-rw-r--r--  standalone/neon-gemm-kernel-benchmark.cc | 1458 ++++++++++++++++-
1 file changed, 1452 insertions(+), 6 deletions(-)
diff --git a/standalone/neon-gemm-kernel-benchmark.cc b/standalone/neon-gemm-kernel-benchmark.cc
index 2a936c1..bff33fb 100644
--- a/standalone/neon-gemm-kernel-benchmark.cc
+++ b/standalone/neon-gemm-kernel-benchmark.cc
@@ -61,15 +61,30 @@
#include <cassert>
#include <cstdint>
#include <cstdlib>
+#include <cstring>
#include <iostream>
#include <random>
#include <type_traits>
-#if !defined __arm__ && !defined __aarch64__
-#error This benchmark assumes ARM (for inline assembly sections).
+#if !defined(__arm__) && !defined(__aarch64__) && \
+ !(defined(__mips) && (__mips_isa_rev >= 5) && defined(__mips_msa))
+#error This benchmark assumes ARM or MIPS (for intrinsics and inline assembly sections).
#endif
+#if defined(__arm__) || defined(__aarch64__)
#include <arm_neon.h>
+#endif
+
+#if defined(__mips)
+#include <msa.h>
+
+// Some convenience macros to hide differences between MIPS32 and MIPS64.
+#ifdef __LP64__
+#define GEMMLOWP_MIPS_XADDIU "daddiu"
+#else
+#define GEMMLOWP_MIPS_XADDIU "addiu"
+#endif
+#endif
// Typically one wants to fit in L1 cache, and GEMM implementations
// are carefully optimized to tune their access patterns to that effect.
@@ -2501,6 +2516,291 @@ struct NEON_64bit_GEMM_Int8Operands_AccumTwoWithin16Bits {
}
};
+#ifdef __ARM_FEATURE_DOTPROD
+// Kernels utilizing the Armv8.2 Dot Product extension.
+//
+// The dot product instructions work by taking 4 consecutive 8-bit depth
+// values from each operand, multiplying the 4 pairs together and
+// accumulating all the results into the corresponding 32-bit accumulator
+// lane. As such, the operation is identical to a 32-bit instruction (like
+// FMLA used in SGEMM), except that 4 depth values are processed at a time
+// instead of 1.
+
+// Thus, this first kernel is a carbon copy of the
+// "NEON_64bit_GEMM_Float32_WithScalar_A57" kernel below (which should
+// provide good performance on most processors), with only the opcode
+// (fmla -> udot) and types (float32 -> uint8/uint32) changed.
+//
+// A signed version of this kernel could be produced by replacing "udot"
+// with "sdot" - performance should be identical to this udot kernel.
+struct NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators_dotproduct {
+ typedef std::uint8_t OperandType;
+ typedef std::uint32_t AccumulatorType;
+ typedef KernelFormat<
+ KernelSideFormat<CellFormat<4, 4, CellOrder::WidthMajor>, 3>,
+ KernelSideFormat<CellFormat<4, 4, CellOrder::WidthMajor>, 2> >
+ Format;
+ static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr,
+ AccumulatorType* accum_ptr, int depth) {
+ asm volatile(
+ // Load accumulators
+ "mov x0, %[accum_ptr]\n"
+ "ld1 {v8.4s}, [x0], #16\n"
+ "ld1 {v16.4s}, [x0], #16\n"
+ "ld1 {v24.4s}, [x0], #16\n"
+ "ld1 {v9.4s}, [x0], #16\n"
+ "ld1 {v17.4s}, [x0], #16\n"
+ "ld1 {v25.4s}, [x0], #16\n"
+ "ld1 {v10.4s}, [x0], #16\n"
+ "ld1 {v18.4s}, [x0], #16\n"
+ "ld1 {v26.4s}, [x0], #16\n"
+ "ld1 {v11.4s}, [x0], #16\n"
+ "ld1 {v19.4s}, [x0], #16\n"
+ "ld1 {v27.4s}, [x0], #16\n"
+ "ld1 {v12.4s}, [x0], #16\n"
+ "ld1 {v20.4s}, [x0], #16\n"
+ "ld1 {v28.4s}, [x0], #16\n"
+ "ld1 {v13.4s}, [x0], #16\n"
+ "ld1 {v21.4s}, [x0], #16\n"
+ "ld1 {v29.4s}, [x0], #16\n"
+ "ld1 {v14.4s}, [x0], #16\n"
+ "ld1 {v22.4s}, [x0], #16\n"
+ "ld1 {v30.4s}, [x0], #16\n"
+ "ld1 {v15.4s}, [x0], #16\n"
+ "ld1 {v23.4s}, [x0], #16\n"
+ "ld1 {v31.4s}, [x0], #16\n"
+
+ // The start of the loop assumes the first Rhs cell is already loaded,
+ // so load it here for the first iteration.
+ "ld1 {v0.16b}, [%[rhs_ptr]], #16\n"
+
+ // And the same for the first Lhs cell.
+ "ld1 {v2.16b}, [%[lhs_ptr]], #16\n"
+
+ GEMMLOWP_LABEL_LOOP
+ ":\n"
+
+ // Start the MACs at the head of the loop - 1st cell from each side
+ // already loaded.
+ "udot v8.4s, v2.16b, v0.b[0]\n"
+ "udot v9.4s, v2.16b, v0.b[1]\n"
+ "ld1 {v1.16b}, [%[rhs_ptr]], #16\n" // Load second Rhs cell.
+ "udot v10.4s, v2.16b, v0.b[2]\n"
+ "udot v11.4s, v2.16b, v0.b[3]\n"
+ "ld1 {v3.16b}, [%[lhs_ptr]], #16\n" // Load second Lhs cell.
+ "udot v12.4s, v2.16b, v1.b[0]\n"
+ "udot v13.4s, v2.16b, v1.b[1]\n"
+ "ld1 {v4.16b}, [%[lhs_ptr]], #16\n" // Load third Lhs cell.
+ "udot v14.4s, v2.16b, v1.b[2]\n"
+ "udot v15.4s, v2.16b, v1.b[3]\n"
+ "ld1 {v2.16b}, [%[lhs_ptr]], #16\n" // Done with first Lhs cell - load
+ // for the next iteration early.
+ "udot v16.4s, v3.16b, v0.b[0]\n"
+ "udot v17.4s, v3.16b, v0.b[1]\n"
+ "udot v18.4s, v3.16b, v0.b[2]\n"
+ "udot v19.4s, v3.16b, v0.b[3]\n"
+ "udot v20.4s, v3.16b, v1.b[0]\n"
+ "udot v21.4s, v3.16b, v1.b[1]\n"
+ "udot v22.4s, v3.16b, v1.b[2]\n"
+ "udot v23.4s, v3.16b, v1.b[3]\n"
+ "udot v24.4s, v4.16b, v0.b[0]\n"
+ "udot v25.4s, v4.16b, v0.b[1]\n"
+ "udot v26.4s, v4.16b, v0.b[2]\n"
+ "udot v27.4s, v4.16b, v0.b[3]\n"
+ "ld1 {v0.16b}, [%[rhs_ptr]], #16\n" // Done with the first Rhs cell -
+ // load for the next iteration early.
+ "udot v28.4s, v4.16b, v1.b[0]\n"
+ "udot v29.4s, v4.16b, v1.b[1]\n"
+
+ // Loop. Decrement loop index (depth) by 4 as udot processes 4
+ // depth values.
+ "subs %w[depth], %w[depth], #4\n"
+ "udot v30.4s, v4.16b, v1.b[2]\n"
+ "udot v31.4s, v4.16b, v1.b[3]\n"
+
+ "bne " GEMMLOWP_LABEL_LOOP
+ "b\n"
+
+ // Store accumulators
+ "mov x0, %[accum_ptr]\n"
+ "st1 {v8.16b}, [x0], #16\n"
+ "st1 {v16.16b}, [x0], #16\n"
+ "st1 {v24.16b}, [x0], #16\n"
+ "st1 {v9.16b}, [x0], #16\n"
+ "st1 {v17.16b}, [x0], #16\n"
+ "st1 {v25.16b}, [x0], #16\n"
+ "st1 {v10.16b}, [x0], #16\n"
+ "st1 {v18.16b}, [x0], #16\n"
+ "st1 {v26.16b}, [x0], #16\n"
+ "st1 {v11.16b}, [x0], #16\n"
+ "st1 {v19.16b}, [x0], #16\n"
+ "st1 {v27.16b}, [x0], #16\n"
+ "st1 {v12.16b}, [x0], #16\n"
+ "st1 {v20.16b}, [x0], #16\n"
+ "st1 {v28.16b}, [x0], #16\n"
+ "st1 {v13.16b}, [x0], #16\n"
+ "st1 {v21.16b}, [x0], #16\n"
+ "st1 {v29.16b}, [x0], #16\n"
+ "st1 {v14.16b}, [x0], #16\n"
+ "st1 {v22.16b}, [x0], #16\n"
+ "st1 {v30.16b}, [x0], #16\n"
+ "st1 {v15.16b}, [x0], #16\n"
+ "st1 {v23.16b}, [x0], #16\n"
+ "st1 {v31.16b}, [x0], #16\n"
+ : // outputs
+ [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
+ [depth] "+r"(depth)
+ : // inputs
+ [accum_ptr] "r"(accum_ptr)
+ : // clobbers
+ "cc", "memory", "x0", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7",
+ "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17",
+ "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27",
+ "v28", "v29", "v30", "v31");
+ }
+};
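For reference, each "udot vd.4s, vn.16b, vm.b[k]" above updates every 32-bit accumulator lane with a 4-way dot product: four consecutive uint8 values from vn against the k-th group of four uint8 values from vm. A minimal scalar sketch of that semantics (a hypothetical reference helper, not part of the benchmark):

#include <cstdint>

// Scalar model of one "udot vd.4s, vn.16b, vm.b[k]" (sketch only).
// 'a' holds the 16 uint8 operands of vn (4 per accumulator lane) and
// 'b4' the group of 4 uint8 values selected from vm by the lane index.
inline void udot_lane_ref(std::uint32_t acc[4], const std::uint8_t a[16],
                          const std::uint8_t b4[4]) {
  for (int lane = 0; lane < 4; ++lane) {
    std::uint32_t sum = 0;
    for (int d = 0; d < 4; ++d) {
      sum += std::uint32_t(a[4 * lane + d]) * std::uint32_t(b4[d]);
    }
    acc[lane] += sum;  // accumulate in place, as udot does
  }
}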
+
+// As above, except tuned for Cortex-A55r1.
+//
+// Similarly, this is a clone of NEON_64bit_GEMM_Float32_WithScalar_A55r1
+// with the names changed.
+struct NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators_dotproduct_A55r1 {
+ typedef std::uint8_t OperandType;
+ typedef std::uint32_t AccumulatorType;
+ typedef KernelFormat<
+ KernelSideFormat<CellFormat<4, 4, CellOrder::WidthMajor>, 3>,
+ KernelSideFormat<CellFormat<4, 4, CellOrder::WidthMajor>, 2> >
+ Format;
+ static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr,
+ AccumulatorType* accum_ptr, int depth) {
+ asm volatile(
+ // Load accumulators
+ "mov x0, %[accum_ptr]\n"
+ "ld1 {v8.4s}, [x0], #16\n"
+ "ld1 {v16.4s}, [x0], #16\n"
+ "ld1 {v24.4s}, [x0], #16\n"
+ "ld1 {v9.4s}, [x0], #16\n"
+ "ld1 {v17.4s}, [x0], #16\n"
+ "ld1 {v25.4s}, [x0], #16\n"
+ "ld1 {v10.4s}, [x0], #16\n"
+ "ld1 {v18.4s}, [x0], #16\n"
+ "ld1 {v26.4s}, [x0], #16\n"
+ "ld1 {v11.4s}, [x0], #16\n"
+ "ld1 {v19.4s}, [x0], #16\n"
+ "ld1 {v27.4s}, [x0], #16\n"
+ "ld1 {v12.4s}, [x0], #16\n"
+ "ld1 {v20.4s}, [x0], #16\n"
+ "ld1 {v28.4s}, [x0], #16\n"
+ "ld1 {v13.4s}, [x0], #16\n"
+ "ld1 {v21.4s}, [x0], #16\n"
+ "ld1 {v29.4s}, [x0], #16\n"
+ "ld1 {v14.4s}, [x0], #16\n"
+ "ld1 {v22.4s}, [x0], #16\n"
+ "ld1 {v30.4s}, [x0], #16\n"
+ "ld1 {v15.4s}, [x0], #16\n"
+ "ld1 {v23.4s}, [x0], #16\n"
+ "ld1 {v31.4s}, [x0], #16\n"
+
+ // For details on how this kernel works, see the Float32 kernel below.
+
+ "ldr d0, [%[rhs_ptr]]\n"
+ "ldr x18, [%[rhs_ptr], #8]\n"
+
+ "ldr q2, [%[lhs_ptr]]\n"
+ "ldr q3, [%[lhs_ptr], #16]\n"
+
+ GEMMLOWP_LABEL_LOOP
+ ":\n"
+
+ "udot v8.4s, v2.16b, v0.b[0]\n"
+ "ldr d1, [%[rhs_ptr], #16]\n" // Bottom half of v1
+ "udot v9.4s, v2.16b, v0.b[1]\n"
+ "ins v0.d[1], x18\n" // Finish loading v0
+ "udot v16.4s, v3.16b, v0.b[0]\n" // out of sequence - used to reduce load/use pressure.
+ "ldr x18, [%[rhs_ptr], #24]\n" // Top half of v1 to X register
+ "udot v17.4s, v3.16b, v0.b[1]\n" // out of sequence - used to reduce load/use pressure.
+ "add %[rhs_ptr], %[rhs_ptr], #32\n" // RHS loads complete - increment pointer.
+ "udot v10.4s, v2.16b, v0.b[2]\n"
+ "ldr d4, [%[lhs_ptr], #32]\n" // Bottom half of v4
+ "udot v11.4s, v2.16b, v0.b[3]\n"
+ "ins v1.d[1], x18\n" // Finish loading v1
+ "udot v12.4s, v2.16b, v1.b[0]\n"
+ "ldr x18, [%[lhs_ptr], #40]\n" // Top half of v4 to X register
+ "udot v13.4s, v2.16b, v1.b[1]\n"
+ "add %[lhs_ptr], %[lhs_ptr], #48\n" // LHS loads complete - increment pointer.
+ "udot v14.4s, v2.16b, v1.b[2]\n"
+
+ "udot v15.4s, v2.16b, v1.b[3]\n"
+ "ldr d2, [%[lhs_ptr]]\n" // Bottom half of v2 (for next time)
+ "udot v18.4s, v3.16b, v0.b[2]\n"
+ "ins v4.d[1], x18\n" // Finish loading v4
+ "udot v19.4s, v3.16b, v0.b[3]\n"
+ "ldr x18, [%[lhs_ptr], #8]\n" // Top half of next v2 to X register
+ "udot v20.4s, v3.16b, v1.b[0]\n"
+ "subs %w[depth], %w[depth], #4\n"
+ "udot v21.4s, v3.16b, v1.b[1]\n"
+
+ "udot v22.4s, v3.16b, v1.b[2]\n"
+
+ "udot v23.4s, v3.16b, v1.b[3]\n"
+ "ldr d3, [%[lhs_ptr], #16]\n" // Bottom half of v3 (for next time)
+ "udot v24.4s, v4.16b, v0.b[0]\n"
+ "ins v2.d[1], x18\n" // Finish loading next v2
+ "udot v25.4s, v4.16b, v0.b[1]\n"
+ "ldr x18, [%[lhs_ptr], #24]\n" // Top half of next v3 to X register
+ "udot v26.4s, v4.16b, v0.b[2]\n"
+
+ "udot v27.4s, v4.16b, v0.b[3]\n"
+ "ldr d0, [%[rhs_ptr]]\n" // Bottom half of v0 (for next time)
+ "udot v28.4s, v4.16b, v1.b[0]\n"
+ "ins v3.d[1], x18\n" // Finish loading next v3
+ "udot v29.4s, v4.16b, v1.b[1]\n"
+ "ldr x18, [%[rhs_ptr], #8]\n" // Top half of next v0 to X register
+ "udot v30.4s, v4.16b, v1.b[2]\n"
+
+ "udot v31.4s, v4.16b, v1.b[3]\n"
+ "bne " GEMMLOWP_LABEL_LOOP "b\n"
+
+ // Store accumulators
+ "mov x0, %[accum_ptr]\n"
+ "st1 {v8.4s}, [x0], #16\n"
+ "st1 {v16.4s}, [x0], #16\n"
+ "st1 {v24.4s}, [x0], #16\n"
+ "st1 {v9.4s}, [x0], #16\n"
+ "st1 {v17.4s}, [x0], #16\n"
+ "st1 {v25.4s}, [x0], #16\n"
+ "st1 {v10.4s}, [x0], #16\n"
+ "st1 {v18.4s}, [x0], #16\n"
+ "st1 {v26.4s}, [x0], #16\n"
+ "st1 {v11.4s}, [x0], #16\n"
+ "st1 {v19.4s}, [x0], #16\n"
+ "st1 {v27.4s}, [x0], #16\n"
+ "st1 {v12.4s}, [x0], #16\n"
+ "st1 {v20.4s}, [x0], #16\n"
+ "st1 {v28.4s}, [x0], #16\n"
+ "st1 {v13.4s}, [x0], #16\n"
+ "st1 {v21.4s}, [x0], #16\n"
+ "st1 {v29.4s}, [x0], #16\n"
+ "st1 {v14.4s}, [x0], #16\n"
+ "st1 {v22.4s}, [x0], #16\n"
+ "st1 {v30.4s}, [x0], #16\n"
+ "st1 {v15.4s}, [x0], #16\n"
+ "st1 {v23.4s}, [x0], #16\n"
+ "st1 {v31.4s}, [x0], #16\n"
+ : // outputs
+ [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
+ [depth] "+r"(depth)
+ : // inputs
+ [accum_ptr] "r"(accum_ptr)
+ : // clobbers
+ "cc", "memory", "x0", "x18", "v0", "v1", "v2", "v3", "v4", "v5", "v6",
+ "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16",
+ "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26",
+ "v27", "v28", "v29", "v30", "v31");
+ }
+};
+#endif // __ARM_FEATURE_DOTPROD
+
// We don't actually use int32*int32 in production. This is just an
// experiment to help dissociate the effect of integer-vs-float, from the
// effect of operands width.
@@ -3203,8 +3503,172 @@ struct NEON_64bit_GEMM_Float32_WithScalar_A53 {
};
#endif
+// Faster kernel contributed by ARM. Tuned for A55r1.
+struct NEON_64bit_GEMM_Float32_WithScalar_A55r1 {
+ typedef float OperandType;
+ typedef float AccumulatorType;
+ typedef KernelFormat<
+ KernelSideFormat<CellFormat<4, 1, CellOrder::DepthMajor>, 3>,
+ KernelSideFormat<CellFormat<4, 1, CellOrder::DepthMajor>, 2> >
+ Format;
+ static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr,
+ AccumulatorType* accum_ptr, int depth) {
+ asm volatile(
+ // Load accumulators
+ "mov x0, %[accum_ptr]\n"
+ "ld1 {v8.4s}, [x0], #16\n"
+ "ld1 {v16.4s}, [x0], #16\n"
+ "ld1 {v24.4s}, [x0], #16\n"
+ "ld1 {v9.4s}, [x0], #16\n"
+ "ld1 {v17.4s}, [x0], #16\n"
+ "ld1 {v25.4s}, [x0], #16\n"
+ "ld1 {v10.4s}, [x0], #16\n"
+ "ld1 {v18.4s}, [x0], #16\n"
+ "ld1 {v26.4s}, [x0], #16\n"
+ "ld1 {v11.4s}, [x0], #16\n"
+ "ld1 {v19.4s}, [x0], #16\n"
+ "ld1 {v27.4s}, [x0], #16\n"
+ "ld1 {v12.4s}, [x0], #16\n"
+ "ld1 {v20.4s}, [x0], #16\n"
+ "ld1 {v28.4s}, [x0], #16\n"
+ "ld1 {v13.4s}, [x0], #16\n"
+ "ld1 {v21.4s}, [x0], #16\n"
+ "ld1 {v29.4s}, [x0], #16\n"
+ "ld1 {v14.4s}, [x0], #16\n"
+ "ld1 {v22.4s}, [x0], #16\n"
+ "ld1 {v30.4s}, [x0], #16\n"
+ "ld1 {v15.4s}, [x0], #16\n"
+ "ld1 {v23.4s}, [x0], #16\n"
+ "ld1 {v31.4s}, [x0], #16\n"
+
+ // A55r1 requires a hybrid of the A53 and standard approaches.
+ //
+ // Like A53, this processor prefers 64-bit loads.
+ //
+ // Unlike A53, it is capable of dual-issuing a 64-bit vector load
+ // (or INS) with a FMLA instruction.
+ //
+ // Therefore we aim to issue an FMLA instruction every cycle.
+ // Alongside three FMLAs we can dual issue a (vector) 64-bit load, a
+ // scalar 64-bit load and finally an INS to replicate the effect of
+ // a single 128-bit load.
+ //
+ // The loop contains 24 FMLA instructions, and 5 vector registers
+ // need to be loaded, consuming 15 dual issue slots. This leaves 9
+ // dual issue slots. Four of these are used for loop housekeeping
+ // (2 pointer adds, 1 counter update and 1 branch), leaving 5 left
+ // over (marked by blank lines).
+ //
+ // Choice of x18 to store the upper halves on their way into the
+ // vector registers is arbitrary. Added to the clobber list so that
+ // the compiler will make it available.
+
+
+ // At the start of the loop, it is assumed that v0 is "half loaded" -
+ // bottom half in place in d0 and the upper half in x18 ready to
+ // insert. So set that up here for the first iteration:
+ "ldr d0, [%[rhs_ptr]]\n" // Bottom half of first Rhs cell
+ "ldr x18, [%[rhs_ptr], #8]\n" // Upper half
+
+ // v2-v3 should be fully loaded - as it's outside the loop proper it's fine
+ // to use a 128-bit load here.
+ "ldr q2, [%[lhs_ptr]]\n" // first Lhs cell
+ "ldr q3, [%[lhs_ptr], #16]\n" // second Lhs cell
+
+ GEMMLOWP_LABEL_LOOP
+ ":\n"
+
+ "fmla v8.4s, v2.4s, v0.s[0]\n"
+ "ldr d1, [%[rhs_ptr], #16]\n" // Bottom half of v1
+ "fmla v9.4s, v2.4s, v0.s[1]\n"
+ "ins v0.d[1], x18\n" // Finish loading v0
+ "fmla v16.4s, v3.4s, v0.s[0]\n" // out of sequence - used to reduce load/use pressure.
+ "ldr x18, [%[rhs_ptr], #24]\n" // Top half of v1 to X register
+ "fmla v17.4s, v3.4s, v0.s[1]\n" // out of sequence - used to reduce load/use pressure.
+ "add %[rhs_ptr], %[rhs_ptr], #32\n" // RHS loads complete - increment pointer.
+ "fmla v10.4s, v2.4s, v0.s[2]\n"
+ "ldr d4, [%[lhs_ptr], #32]\n" // Bottom half of v4
+ "fmla v11.4s, v2.4s, v0.s[3]\n"
+ "ins v1.d[1], x18\n" // Finish loading v1
+ "fmla v12.4s, v2.4s, v1.s[0]\n"
+ "ldr x18, [%[lhs_ptr], #40]\n" // Top half of v4 to X register
+ "fmla v13.4s, v2.4s, v1.s[1]\n"
+ "add %[lhs_ptr], %[lhs_ptr], #48\n" // LHS loads complete - increment pointer.
+ "fmla v14.4s, v2.4s, v1.s[2]\n"
+
+ "fmla v15.4s, v2.4s, v1.s[3]\n"
+ "ldr d2, [%[lhs_ptr]]\n" // Bottom half of v2 (for next time)
+ "fmla v18.4s, v3.4s, v0.s[2]\n"
+ "ins v4.d[1], x18\n" // Finish loading v4
+ "fmla v19.4s, v3.4s, v0.s[3]\n"
+ "ldr x18, [%[lhs_ptr], #8]\n" // Top half of next v2 to X register
+ "fmla v20.4s, v3.4s, v1.s[0]\n"
+ "subs %w[depth], %w[depth], #1\n"
+ "fmla v21.4s, v3.4s, v1.s[1]\n"
+
+ "fmla v22.4s, v3.4s, v1.s[2]\n"
+
+ "fmla v23.4s, v3.4s, v1.s[3]\n"
+ "ldr d3, [%[lhs_ptr], #16]\n" // Bottom half of v3 (for next time)
+ "fmla v24.4s, v4.4s, v0.s[0]\n"
+ "ins v2.d[1], x18\n" // Finish loading next v2
+ "fmla v25.4s, v4.4s, v0.s[1]\n"
+ "ldr x18, [%[lhs_ptr], #24]\n" // Top half of next v3 to X register
+ "fmla v26.4s, v4.4s, v0.s[2]\n"
+
+ "fmla v27.4s, v4.4s, v0.s[3]\n"
+ "ldr d0, [%[rhs_ptr]]\n" // Bottom half of v0 (for next time)
+ "fmla v28.4s, v4.4s, v1.s[0]\n"
+ "ins v3.d[1], x18\n" // Finish loading next v3
+ "fmla v29.4s, v4.4s, v1.s[1]\n"
+ "ldr x18, [%[rhs_ptr], #8]\n" // Top half of next v0 to X register
+ "fmla v30.4s, v4.4s, v1.s[2]\n"
+
+ "fmla v31.4s, v4.4s, v1.s[3]\n"
+ "bne " GEMMLOWP_LABEL_LOOP "b\n"
+
+ // Store accumulators
+ "mov x0, %[accum_ptr]\n"
+ "st1 {v8.4s}, [x0], #16\n"
+ "st1 {v16.4s}, [x0], #16\n"
+ "st1 {v24.4s}, [x0], #16\n"
+ "st1 {v9.4s}, [x0], #16\n"
+ "st1 {v17.4s}, [x0], #16\n"
+ "st1 {v25.4s}, [x0], #16\n"
+ "st1 {v10.4s}, [x0], #16\n"
+ "st1 {v18.4s}, [x0], #16\n"
+ "st1 {v26.4s}, [x0], #16\n"
+ "st1 {v11.4s}, [x0], #16\n"
+ "st1 {v19.4s}, [x0], #16\n"
+ "st1 {v27.4s}, [x0], #16\n"
+ "st1 {v12.4s}, [x0], #16\n"
+ "st1 {v20.4s}, [x0], #16\n"
+ "st1 {v28.4s}, [x0], #16\n"
+ "st1 {v13.4s}, [x0], #16\n"
+ "st1 {v21.4s}, [x0], #16\n"
+ "st1 {v29.4s}, [x0], #16\n"
+ "st1 {v14.4s}, [x0], #16\n"
+ "st1 {v22.4s}, [x0], #16\n"
+ "st1 {v30.4s}, [x0], #16\n"
+ "st1 {v15.4s}, [x0], #16\n"
+ "st1 {v23.4s}, [x0], #16\n"
+ "st1 {v31.4s}, [x0], #16\n"
+ : // outputs
+ [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
+ [depth] "+r"(depth)
+ : // inputs
+ [accum_ptr] "r"(accum_ptr)
+ : // clobbers
+ "cc", "memory", "x0", "x18", "v0", "v1", "v2", "v3", "v4", "v5", "v6",
+ "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16",
+ "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26",
+ "v27", "v28", "v29", "v30", "v31");
+ }
+};
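The "half loaded" trick in this kernel and its udot sibling above splits what would be one 128-bit load into three micro-steps (ldr d, ldr x, ins), each of which can dual-issue with an FMLA on A55r1: 24 FMLAs give 24 dual-issue slots per iteration, 5 such 3-step loads use 15, and loop housekeeping uses 4, leaving 5 spare. A sketch of the equivalence in intrinsics (assumes AArch64; the real kernel interleaves the steps through the loop in assembly rather than calling a helper):

#include <arm_neon.h>
#include <cstdint>
#include <cstring>

// Sketch: load a 128-bit NEON register as two 64-bit halves, mirroring
// the ldr d / ldr x / ins pattern used in the kernel.
inline uint32x4_t load_q_in_halves(const std::uint32_t* ptr) {
  std::uint64_t hi;
  std::memcpy(&hi, ptr + 2, sizeof(hi));  // "ldr x18, [ptr, #8]"
  uint64x2_t v = vcombine_u64(
      vld1_u64(reinterpret_cast<const std::uint64_t*>(ptr)),  // "ldr d0, [ptr]"
      vdup_n_u64(hi));                                        // "ins v0.d[1], x18"
  return vreinterpretq_u32_u64(v);
}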
+
#endif // __aarch64__
+#if defined(__arm__) || defined(__aarch64__)
#ifndef __aarch64__
inline int32x4_t vpaddq_s32(int32x4_t a, int32x4_t b) {
const int32x2_t c = vpadd_s32(vget_low_s32(a), vget_high_s32(a));
@@ -3388,6 +3852,974 @@ using NEON_32bit_GEMM_Float32_WithScalar_intrinsics =
using NEON_64bit_GEMM_Float32_WithScalar_intrinsics =
NEON_GEMM_Float32_WithScalar_intrinsics<2>;
+#endif // __arm__ || __aarch64__
+
+#ifdef __mips
+static inline v4i32 workaround_msa_maddv_w(v4i32 a, v4i32 b, v4i32 c) {
+ // Workaround for incorrect encoding of maddv.df in gcc (a exchanged with c).
+#if 0
+ return __builtin_msa_maddv_w(a, b, c);
+#else
+ asm volatile("maddv.w %w[a], %w[b], %w[c]\n"
+ // Outputs
+ : [a] "+f"(a)
+ // Inputs
+ : [b] "f"(b), [c] "f"(c));
+ return a;
+#endif
+}
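The workaround keeps the accumulator pinned as the in/out operand, matching maddv.w semantics (wd += ws * wt). A tiny scalar cross-check of that contract (a hypothetical test helper, relying on GCC vector subscripting; not part of the benchmark):

#ifdef __mips
// Sketch: verify workaround_msa_maddv_w(a, b, c) == a + b * c, lane-wise.
static bool check_maddv_w(v4i32 a, v4i32 b, v4i32 c) {
  v4i32 r = workaround_msa_maddv_w(a, b, c);
  for (int i = 0; i < 4; ++i) {
    if (r[i] != a[i] + b[i] * c[i]) return false;
  }
  return true;
}
#endif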
+
+// Using 32x32=32 multiplications.
+// 20 MSA regs used:
+// - 12 accumulators
+// - 6 lhs
+// - 1 rhs
+// - 1 temps/zeroes
+// ~55 instructions in the loop.
+struct MSA_GEMM_12x4_Uint8Operands_Uint32Accumulators_intrinsics {
+ typedef std::uint8_t OperandType;
+ typedef std::int32_t AccumulatorType;
+ typedef KernelFormat<
+ KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 3>,
+ KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 1> >
+ Format;
+ static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr,
+ AccumulatorType* accum_ptr, int depth) {
+ const v16i8 zeroes = __builtin_msa_ldi_b(0);
+ v4i32 acc[3][4];
+ // Load accumulators.
+ for (int i = 0; i < 3; i++) {
+ for (int j = 0; j < 4; j++) {
+ acc[i][j] = __builtin_msa_ld_w(accum_ptr + 4 * (i + 3 * j), 0);
+ }
+ }
+
+ while (depth > 0) {
+ // Load 3 x 8 bytes of lhs[] with 2 16-byte overlapped loads.
+ v8i16 lhs[6];
+ lhs[0] = reinterpret_cast<v8i16>(__builtin_msa_ld_b(const_cast<OperandType*>(lhs_ptr), 0));
+ lhs[1] =
+ reinterpret_cast<v8i16>(__builtin_msa_ld_b(const_cast<OperandType*>(lhs_ptr + 8), 0));
+
+ // Zero-extend 8-bit elements of lhs[] to 16 bits.
+ lhs[0] = reinterpret_cast<v8i16>(__builtin_msa_ilvr_b(zeroes,
+ reinterpret_cast<v16i8>(lhs[0])));
+ lhs[2] = reinterpret_cast<v8i16>(__builtin_msa_ilvl_b(zeroes,
+ reinterpret_cast<v16i8>(lhs[1])));
+ lhs[1] = reinterpret_cast<v8i16>(__builtin_msa_ilvr_b(zeroes,
+ reinterpret_cast<v16i8>(lhs[1])));
+
+ // Zero-extend 16-bit elements of lhs[] to 32 bits.
+ lhs[3] = __builtin_msa_ilvl_h(reinterpret_cast<v8i16>(zeroes), lhs[0]);
+ lhs[4] = __builtin_msa_ilvl_h(reinterpret_cast<v8i16>(zeroes), lhs[1]);
+ lhs[5] = __builtin_msa_ilvl_h(reinterpret_cast<v8i16>(zeroes), lhs[2]);
+ lhs[0] = __builtin_msa_ilvr_h(reinterpret_cast<v8i16>(zeroes), lhs[0]);
+ lhs[1] = __builtin_msa_ilvr_h(reinterpret_cast<v8i16>(zeroes), lhs[1]);
+ lhs[2] = __builtin_msa_ilvr_h(reinterpret_cast<v8i16>(zeroes), lhs[2]);
+
+ // Depth 0.
+ for (int j = 0; j < 4; j++) {
+ // Load 1 byte of rhs, making 4 32-bit replicas of it.
+ v4i32 rhs = reinterpret_cast<v4i32>(__builtin_msa_fill_w(rhs_ptr[j]));
+ // Multiply-add into accumulators.
+ for (int i = 0; i < 3; i++) {
+ acc[i][j] = workaround_msa_maddv_w(acc[i][j], reinterpret_cast<v4i32>(lhs[i]), rhs);
+ }
+ }
+
+ // Depth 1.
+ for (int j = 0; j < 4; j++) {
+ // Load 1 byte of rhs, making 4 32-bit replicas of it.
+ v4i32 rhs = reinterpret_cast<v4i32>(__builtin_msa_fill_w(rhs_ptr[j + 4]));
+ // Multiply-add into accumulators.
+ for (int i = 0; i < 3; i++) {
+ acc[i][j] = workaround_msa_maddv_w(acc[i][j], reinterpret_cast<v4i32>(lhs[i + 3]), rhs);
+ }
+ }
+
+ lhs_ptr += 24;
+ rhs_ptr += 8;
+ depth -= 2;
+ }
+
+ // Store accumulators.
+ for (int i = 0; i < 3; i++) {
+ for (int j = 0; j < 4; j++) {
+ __builtin_msa_st_w(acc[i][j], accum_ptr + 4 * (i + 3 * j), 0);
+ }
+ }
+ }
+};
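The interleave-with-zero steps above are the MSA idiom for widening: on a little-endian target, "ilvr.b wd, zero, ws" places a zero byte above each low data byte, i.e. zero-extends uint8 to uint16, and the ilvr.h/ilvl.h pair repeats the trick to reach 32 bits. A scalar sketch of the first step (hypothetical helper):

#include <cstdint>

// Scalar model of "ilvr.b wd, zero, ws" on little-endian MSA: the low 8
// bytes of 'src' interleave with zero bytes, yielding zero-extended
// 16-bit lanes. ilvl.b does the same for the high 8 bytes.
inline void ilvr_b_zero_ref(std::uint16_t dst[8], const std::uint8_t src[16]) {
  for (int i = 0; i < 8; ++i) {
    dst[i] = src[i];  // the high byte of each 16-bit lane is zero
  }
}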
+
+// Assembly implementation of the above
+// MSA_GEMM_12x4_Uint8Operands_Uint32Accumulators_intrinsics.
+// Using 32x32=32 multiplications.
+// 20 MSA regs used:
+// - 12 accumulators
+// - 6 lhs
+// - 1 rhs
+// - 1 temps/zeroes
+// ~55 instructions in the loop.
+struct MSA_GEMM_12x4_Uint8Operands_Uint32Accumulators_assembly {
+ typedef std::uint8_t OperandType;
+ typedef std::int32_t AccumulatorType;
+ typedef KernelFormat<
+ KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 3>,
+ KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 1> >
+ Format;
+ static void Run(OperandType* lhs_ptr, OperandType* rhs_ptr,
+ AccumulatorType* accum_ptr, int depth) {
+ asm volatile(
+ // Load accumulators
+ "ld.w $w0, (0*16)(%[accum_ptr])\n"
+ "ld.w $w4, (1*16)(%[accum_ptr])\n"
+ "ld.w $w8, (2*16)(%[accum_ptr])\n"
+ "ld.w $w1, (3*16)(%[accum_ptr])\n"
+ "ld.w $w5, (4*16)(%[accum_ptr])\n"
+ "ld.w $w9, (5*16)(%[accum_ptr])\n"
+ "ld.w $w2, (6*16)(%[accum_ptr])\n"
+ "ld.w $w6, (7*16)(%[accum_ptr])\n"
+ "ld.w $w10, (8*16)(%[accum_ptr])\n"
+ "ld.w $w3, (9*16)(%[accum_ptr])\n"
+ "ld.w $w7, (10*16)(%[accum_ptr])\n"
+ "ld.w $w11, (11*16)(%[accum_ptr])\n"
+ // Set a temp to all zeroes.
+ "ldi.b $w19, 0\n"
+
+ GEMMLOWP_LABEL_LOOP ":\n"
+ // Overview of register layout:
+ //
+ // A half of the 2x4 cell of Rhs is stored in 32bit in w18.
+ // A 12x2 block of 3 4x2 cells Lhs is stored in 32bit in w12-w17.
+ // A 12x4 block of accumulators is stored in 32bit in w0-w11.
+ //
+ // +------+------+------+------+
+ // Rhs |w18[0]|w18[1]|w18[2]|w18[3]|
+ // +------+------+------+------+
+ //
+ // | | | | |
+ //
+ // Lhs | | | | |
+ //
+ // +---+---+ - - - - +------+------+------+------+
+ // |w12|w15| | w0 | w1 | w2 | w3 |
+ // |w12|w15| | w0 | w1 | w2 | w3 |
+ // |w12|w15| | w0 | w1 | w2 | w3 |
+ // |w12|w15| | w0 | w1 | w2 | w3 |
+ // +---+---+ - - - - +------+------+------+------+
+ // |w13|w16| | w4 | w5 | w6 | w7 |
+ // |w13|w16| | w4 | w5 | w6 | w7 |
+ // |w13|w16| | w4 | w5 | w6 | w7 |
+ // |w13|w16| | w4 | w5 | w6 | w7 |
+ // +---+---+ - - - - +------+------+------+------+
+ // |w14|w17| | w8 | w9 | w10 | w11 |
+ // |w14|w17| | w8 | w9 | w10 | w11 |
+ // |w14|w17| | w8 | w9 | w10 | w11 |
+ // |w14|w17| | w8 | w9 | w10 | w11 |
+ // +---+---+ - - - - +------+------+------+------+
+ //
+ // Accumulator
+
+ // Load 3 x 8 bytes of lhs[] with 2 16-byte overlapped loads.
+ "ld.b $w12, 0(%[lhs_ptr])\n"
+ "ld.b $w13, 8(%[lhs_ptr])\n"
+
+ // Load 4 bytes of rhs[] for depth 0.
+ "lbu $a0, 0(%[rhs_ptr])\n"
+ "lbu $a1, 1(%[rhs_ptr])\n"
+ "lbu $a2, 2(%[rhs_ptr])\n"
+ "lbu $a3, 3(%[rhs_ptr])\n"
+
+ // Zero-extend 8-bit elements of lhs[] to 16 bits.
+ "ilvr.b $w12, $w19, $w12\n"
+ "ilvl.b $w14, $w19, $w13\n"
+ "ilvr.b $w13, $w19, $w13\n"
+ // Zero-extend 16-bit elements of lhs[] to 32 bits.
+ "ilvl.h $w15, $w19, $w12\n"
+ "ilvl.h $w16, $w19, $w13\n"
+ "ilvl.h $w17, $w19, $w14\n"
+ "ilvr.h $w12, $w19, $w12\n"
+ "ilvr.h $w13, $w19, $w13\n"
+ "ilvr.h $w14, $w19, $w14\n"
+
+ // Depth 0.
+ "fill.w $w18, $a0\n"
+ "lbu $a0, 4(%[rhs_ptr])\n"
+ "maddv.w $w0, $w12, $w18\n"
+ "maddv.w $w4, $w13, $w18\n"
+ "maddv.w $w8, $w14, $w18\n"
+ "fill.w $w18, $a1\n"
+ "lbu $a1, 5(%[rhs_ptr])\n"
+ "maddv.w $w1, $w12, $w18\n"
+ "maddv.w $w5, $w13, $w18\n"
+ "maddv.w $w9, $w14, $w18\n"
+ "fill.w $w18, $a2\n"
+ "lbu $a2, 6(%[rhs_ptr])\n"
+ "maddv.w $w2, $w12, $w18\n"
+ "maddv.w $w6, $w13, $w18\n"
+ "maddv.w $w10, $w14, $w18\n"
+ "fill.w $w18, $a3\n"
+ "lbu $a3, 7(%[rhs_ptr])\n"
+ "maddv.w $w3, $w12, $w18\n"
+ "maddv.w $w7, $w13, $w18\n"
+ "maddv.w $w11, $w14, $w18\n"
+
+ // Depth 1.
+ "fill.w $w18, $a0\n"
+ "maddv.w $w0, $w15, $w18\n"
+ "maddv.w $w4, $w16, $w18\n"
+ "maddv.w $w8, $w17, $w18\n"
+ "fill.w $w18, $a1\n"
+ "maddv.w $w1, $w15, $w18\n"
+ "maddv.w $w5, $w16, $w18\n"
+ "maddv.w $w9, $w17, $w18\n"
+ "fill.w $w18, $a2\n"
+ "maddv.w $w2, $w15, $w18\n"
+ "maddv.w $w6, $w16, $w18\n"
+ "maddv.w $w10, $w17, $w18\n"
+ "fill.w $w18, $a3\n"
+ "maddv.w $w3, $w15, $w18\n"
+ "maddv.w $w7, $w16, $w18\n"
+ "maddv.w $w11, $w17, $w18\n"
+
+ "addiu %[depth], -2\n"
+ GEMMLOWP_MIPS_XADDIU " %[lhs_ptr], 24\n"
+ GEMMLOWP_MIPS_XADDIU " %[rhs_ptr], 8\n"
+ "bnez %[depth]," GEMMLOWP_LABEL_LOOP "b\n"
+
+ // Store accumulators.
+ "st.w $w0, (0*16)(%[accum_ptr])\n"
+ "st.w $w4, (1*16)(%[accum_ptr])\n"
+ "st.w $w8, (2*16)(%[accum_ptr])\n"
+ "st.w $w1, (3*16)(%[accum_ptr])\n"
+ "st.w $w5, (4*16)(%[accum_ptr])\n"
+ "st.w $w9, (5*16)(%[accum_ptr])\n"
+ "st.w $w2, (6*16)(%[accum_ptr])\n"
+ "st.w $w6, (7*16)(%[accum_ptr])\n"
+ "st.w $w10, (8*16)(%[accum_ptr])\n"
+ "st.w $w3, (9*16)(%[accum_ptr])\n"
+ "st.w $w7, (10*16)(%[accum_ptr])\n"
+ "st.w $w11, (11*16)(%[accum_ptr])\n"
+ : // outputs
+ [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
+ [depth] "+r"(depth)
+ : // inputs
+ [accum_ptr] "r"(accum_ptr)
+ : // clobbers
+ "memory",
+ "a0", "a1", "a2", "a3",
+ "$f0", "$f1", "$f2", "$f3", "$f4", "$f5", "$f6", "$f7",
+ "$f8", "$f9", "$f10", "$f11", "$f12", "$f13", "$f14", "$f15",
+ "$f16", "$f17", "$f18", "$f19");
+ }
+};
+
+// Assembly implementation of the above
+// MSA_GEMM_12x4_Uint8Operands_Uint32Accumulators_intrinsics2 (TODO).
+// Using 16x16=32 multiplications.
+// 20 MSA regs used:
+// - 12 accumulators
+// - 3 lhs
+// - 4 rhs
+// - 1 temps/zeroes
+// ~45 instructions in the loop.
+struct MSA_GEMM_12x4_Uint8Operands_Uint32Accumulators_assembly2 {
+ typedef std::uint8_t OperandType;
+ typedef std::int32_t AccumulatorType;
+ typedef KernelFormat<
+ KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 3>,
+ KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 1> >
+ Format;
+ static void Run(OperandType* lhs_ptr, OperandType* rhs_ptr,
+ AccumulatorType* accum_ptr, int depth) {
+ asm volatile(
+ // Load accumulators
+ "ld.w $w0, (0*16)(%[accum_ptr])\n"
+ "ld.w $w4, (1*16)(%[accum_ptr])\n"
+ "ld.w $w8, (2*16)(%[accum_ptr])\n"
+ "ld.w $w1, (3*16)(%[accum_ptr])\n"
+ "ld.w $w5, (4*16)(%[accum_ptr])\n"
+ "ld.w $w9, (5*16)(%[accum_ptr])\n"
+ "ld.w $w2, (6*16)(%[accum_ptr])\n"
+ "ld.w $w6, (7*16)(%[accum_ptr])\n"
+ "ld.w $w10, (8*16)(%[accum_ptr])\n"
+ "ld.w $w3, (9*16)(%[accum_ptr])\n"
+ "ld.w $w7, (10*16)(%[accum_ptr])\n"
+ "ld.w $w11, (11*16)(%[accum_ptr])\n"
+ // Set a temp to all zeroes.
+ "ldi.b $w19, 0\n"
+
+ GEMMLOWP_LABEL_LOOP ":\n"
+ // Overview of register layout:
+ //
+ // A 2x4 cell of Rhs is stored in 16bit in w15-w18 (each register
+ // contains 4 replicas of a pair of elements).
+ // A 12x2 block of 3 4x2 cells Lhs is stored in 16bit in w12-w14.
+ // A 12x4 block of accumulators is stored in 32bit in w0-w11.
+ //
+ // +-----+-----+-----+-----+
+ // Rhs | w15 | w16 | w17 | w18 |
+ // +-----+-----+-----+-----+
+ //
+ // | | | | |
+ //
+ // Lhs | | | | |
+ //
+ // +---+ - - - - +-----+-----+-----+-----+
+ // |w12| | w0 | w1 | w2 | w3 |
+ // |w12| | w0 | w1 | w2 | w3 |
+ // |w12| | w0 | w1 | w2 | w3 |
+ // |w12| | w0 | w1 | w2 | w3 |
+ // +---+ - - - - +-----+-----+-----+-----+
+ // |w13| | w4 | w5 | w6 | w7 |
+ // |w13| | w4 | w5 | w6 | w7 |
+ // |w13| | w4 | w5 | w6 | w7 |
+ // |w13| | w4 | w5 | w6 | w7 |
+ // +---+ - - - - +-----+-----+-----+-----+
+ // |w14| | w8 | w9 | w10 | w11 |
+ // |w14| | w8 | w9 | w10 | w11 |
+ // |w14| | w8 | w9 | w10 | w11 |
+ // |w14| | w8 | w9 | w10 | w11 |
+ // +---+ - - - - +-----+-----+-----+-----+
+ //
+ // Accumulators
+
+ // Load 3 x 8 bytes of lhs[] with 2 16-byte overlapped loads.
+ "ld.b $w12, 0(%[lhs_ptr])\n"
+ "ld.b $w13, 8(%[lhs_ptr])\n"
+
+ // Load 4 bytes of rhs[] for depth 0.
+ "lbu $a0, 0(%[rhs_ptr])\n"
+ "lbu $a1, 1(%[rhs_ptr])\n"
+ "lbu $a2, 2(%[rhs_ptr])\n"
+ "lbu $a3, 3(%[rhs_ptr])\n"
+ // Load 4 bytes of rhs[] for depth 1.
+ "lbu $v0, 4(%[rhs_ptr])\n"
+ "lbu $v1, 5(%[rhs_ptr])\n"
+ "lbu $t8, 6(%[rhs_ptr])\n"
+ "lbu $t9, 7(%[rhs_ptr])\n"
+
+ // Zero-extend 8-bit elements of lhs[] to 16 bits.
+ "ilvr.b $w12, $w19, $w12\n"
+ "ilvl.b $w14, $w19, $w13\n"
+ "ilvr.b $w13, $w19, $w13\n"
+ // Interleave depth 0 and depth 1 elements of lhs[] for dpadd_u.w.
+ "ilvl.d $w15, $w19, $w12\n"
+ "ilvl.d $w16, $w19, $w13\n"
+ "ilvl.d $w17, $w19, $w14\n"
+ "ilvr.h $w12, $w15, $w12\n"
+ "ilvr.h $w13, $w16, $w13\n"
+ "ilvr.h $w14, $w17, $w14\n"
+
+ // Combine and interleave depth 0 and depth 1 elements of rhs[] for dpadd_u.w.
+ "ins $a0, $v0, 16, 8\n"
+ "ins $a1, $v1, 16, 8\n"
+ "ins $a2, $t8, 16, 8\n"
+ "ins $a3, $t9, 16, 8\n"
+ // Make 4 replicas of every pair of rhs[] elements.
+ "fill.w $w15, $a0\n"
+ "fill.w $w16, $a1\n"
+ "fill.w $w17, $a2\n"
+ "fill.w $w18, $a3\n"
+
+ // Depths 0 and 1.
+ // Dot-product-(and)-add doubles multiplicand width.
+ "dpadd_u.w $w0, $w12, $w15\n"
+ "dpadd_u.w $w4, $w13, $w15\n"
+ "dpadd_u.w $w8, $w14, $w15\n"
+ "dpadd_u.w $w1, $w12, $w16\n"
+ "dpadd_u.w $w5, $w13, $w16\n"
+ "dpadd_u.w $w9, $w14, $w16\n"
+ "dpadd_u.w $w2, $w12, $w17\n"
+ "dpadd_u.w $w6, $w13, $w17\n"
+ "dpadd_u.w $w10, $w14, $w17\n"
+ "dpadd_u.w $w3, $w12, $w18\n"
+ "dpadd_u.w $w7, $w13, $w18\n"
+ "dpadd_u.w $w11, $w14, $w18\n"
+
+ "addiu %[depth], -2\n"
+ GEMMLOWP_MIPS_XADDIU " %[lhs_ptr], 24\n"
+ GEMMLOWP_MIPS_XADDIU " %[rhs_ptr], 8\n"
+ "bnez %[depth]," GEMMLOWP_LABEL_LOOP "b\n"
+
+ // Store accumulators.
+ "st.w $w0, (0*16)(%[accum_ptr])\n"
+ "st.w $w4, (1*16)(%[accum_ptr])\n"
+ "st.w $w8, (2*16)(%[accum_ptr])\n"
+ "st.w $w1, (3*16)(%[accum_ptr])\n"
+ "st.w $w5, (4*16)(%[accum_ptr])\n"
+ "st.w $w9, (5*16)(%[accum_ptr])\n"
+ "st.w $w2, (6*16)(%[accum_ptr])\n"
+ "st.w $w6, (7*16)(%[accum_ptr])\n"
+ "st.w $w10, (8*16)(%[accum_ptr])\n"
+ "st.w $w3, (9*16)(%[accum_ptr])\n"
+ "st.w $w7, (10*16)(%[accum_ptr])\n"
+ "st.w $w11, (11*16)(%[accum_ptr])\n"
+ : // outputs
+ [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
+ [depth] "+r"(depth)
+ : // inputs
+ [accum_ptr] "r"(accum_ptr)
+ : // clobbers
+ "memory",
+ "v0", "v1",
+ "a0", "a1", "a2", "a3",
+ "t8", "t9",
+ "$f0", "$f1", "$f2", "$f3", "$f4", "$f5", "$f6", "$f7",
+ "$f8", "$f9", "$f10", "$f11", "$f12", "$f13", "$f14", "$f15",
+ "$f16", "$f17", "$f18", "$f19");
+ }
+};
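This variant halves the MAC count because dpadd_u.w consumes both depths at once: per 32-bit lane it adds the dot product of a pair of unsigned 16-bit elements to the accumulator. A scalar sketch (hypothetical helper):

#include <cstdint>

// Scalar model of "dpadd_u.w wd, ws, wt": each 32-bit accumulator lane
// gains the sum of two unsigned 16x16 products from the corresponding
// pair of 16-bit elements - here, depth 0 and depth 1.
inline void dpadd_u_w_ref(std::uint32_t acc[4], const std::uint16_t a[8],
                          const std::uint16_t b[8]) {
  for (int lane = 0; lane < 4; ++lane) {
    acc[lane] += std::uint32_t(a[2 * lane]) * b[2 * lane] +
                 std::uint32_t(a[2 * lane + 1]) * b[2 * lane + 1];
  }
}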
+
+// Using 32x32=32 multiplications.
+// 32 MSA regs used:
+// - 24 accumulators
+// - 6 lhs
+// - 1 rhs
+// - 1 temps/zeroes
+// ~95 instructions in the loop.
+struct MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators_intrinsics {
+ typedef std::uint8_t OperandType;
+ typedef std::uint32_t AccumulatorType;
+ typedef KernelFormat<
+ KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 3>,
+ KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 2> >
+ Format;
+ static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr,
+ AccumulatorType* accum_ptr, int depth) {
+ const v16i8 zeroes = __builtin_msa_ldi_b(0);
+ v4i32 acc[3][8];
+ // Load accumulators.
+ for (int i = 0; i < 3; i++) {
+ for (int j = 0; j < 8; j++) {
+ acc[i][j] = __builtin_msa_ld_w(accum_ptr + 4 * (i + 3 * j), 0);
+ }
+ }
+
+ while (depth > 0) {
+ // Load 3 x 8 bytes of lhs[] with 2 16-byte overlapped loads.
+ v8i16 lhs[6];
+ lhs[0] = reinterpret_cast<v8i16>(__builtin_msa_ld_b(const_cast<OperandType*>(lhs_ptr), 0));
+ lhs[1] =
+ reinterpret_cast<v8i16>(__builtin_msa_ld_b(const_cast<OperandType*>(lhs_ptr + 8), 0));
+
+ // Zero-extend 8-bit elements of lhs[] to 16 bits.
+ lhs[0] = reinterpret_cast<v8i16>(__builtin_msa_ilvr_b(zeroes,
+ reinterpret_cast<v16i8>(lhs[0])));
+ lhs[2] = reinterpret_cast<v8i16>(__builtin_msa_ilvl_b(zeroes,
+ reinterpret_cast<v16i8>(lhs[1])));
+ lhs[1] = reinterpret_cast<v8i16>(__builtin_msa_ilvr_b(zeroes,
+ reinterpret_cast<v16i8>(lhs[1])));
+
+ // Zero-extend 16-bit elements of lhs[] to 32 bits.
+ lhs[3] = __builtin_msa_ilvl_h(reinterpret_cast<v8i16>(zeroes), lhs[0]);
+ lhs[4] = __builtin_msa_ilvl_h(reinterpret_cast<v8i16>(zeroes), lhs[1]);
+ lhs[5] = __builtin_msa_ilvl_h(reinterpret_cast<v8i16>(zeroes), lhs[2]);
+ lhs[0] = __builtin_msa_ilvr_h(reinterpret_cast<v8i16>(zeroes), lhs[0]);
+ lhs[1] = __builtin_msa_ilvr_h(reinterpret_cast<v8i16>(zeroes), lhs[1]);
+ lhs[2] = __builtin_msa_ilvr_h(reinterpret_cast<v8i16>(zeroes), lhs[2]);
+
+ // Depth 0.
+ for (int j = 0; j < 4; j++) {
+ // Load 1 byte of rhs, making 4 32-bit replicas of it.
+ v4i32 rhs = reinterpret_cast<v4i32>(__builtin_msa_fill_w(rhs_ptr[j]));
+ // Multiply-add into accumulators.
+ for (int i = 0; i < 3; i++) {
+ acc[i][j] = workaround_msa_maddv_w(acc[i][j], reinterpret_cast<v4i32>(lhs[i]), rhs);
+ }
+ }
+ for (int j = 4; j < 8; j++) {
+ // Load 1 byte of rhs, making 4 32-bit replicas of it.
+ v4i32 rhs = reinterpret_cast<v4i32>(__builtin_msa_fill_w(rhs_ptr[j + 4]));
+ // Multiply-add into accumulators.
+ for (int i = 0; i < 3; i++) {
+ acc[i][j] = workaround_msa_maddv_w(acc[i][j], reinterpret_cast<v4i32>(lhs[i]), rhs);
+ }
+ }
+
+ // Depth 1.
+ for (int j = 0; j < 4; j++) {
+ // Load 1 byte of rhs, making 4 32-bit replicas of it.
+ v4i32 rhs = reinterpret_cast<v4i32>(__builtin_msa_fill_w(rhs_ptr[j + 4]));
+ // Multiply-add into accumulators.
+ for (int i = 0; i < 3; i++) {
+ acc[i][j] = workaround_msa_maddv_w(acc[i][j], reinterpret_cast<v4i32>(lhs[i + 3]), rhs);
+ }
+ }
+ for (int j = 4; j < 8; j++) {
+ // Load 1 byte of rhs, making 4 32-bit replicas of it.
+ v4i32 rhs = reinterpret_cast<v4i32>(__builtin_msa_fill_w(rhs_ptr[j + 8]));
+ // Multiply-add into accumulators.
+ for (int i = 0; i < 3; i++) {
+ acc[i][j] = workaround_msa_maddv_w(acc[i][j], reinterpret_cast<v4i32>(lhs[i + 3]), rhs);
+ }
+ }
+
+ lhs_ptr += 24;
+ rhs_ptr += 16;
+ depth -= 2;
+ }
+
+ // Store accumulators.
+ for (int i = 0; i < 3; i++) {
+ for (int j = 0; j < 8; j++) {
+ __builtin_msa_st_w(acc[i][j], accum_ptr + 4 * (i + 3 * j), 0);
+ }
+ }
+ }
+};
+
+// Assembly implementation of the above
+// MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators_intrinsics.
+// Using 32x32=32 multiplications.
+// 32 MSA regs used:
+// - 24 accumulators
+// - 6 lhs
+// - 1 rhs
+// - 1 temps/zeroes
+// ~95 instructions in the loop.
+struct MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators_assembly {
+ typedef std::uint8_t OperandType;
+ typedef std::uint32_t AccumulatorType;
+ typedef KernelFormat<
+ KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 3>,
+ KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 2> >
+ Format;
+ static void Run(OperandType* lhs_ptr, OperandType* rhs_ptr,
+ AccumulatorType* accum_ptr, int depth) {
+ asm volatile(
+ // Load accumulators
+ "ld.w $w0, (0*16)(%[accum_ptr])\n"
+ "ld.w $w4, (1*16)(%[accum_ptr])\n"
+ "ld.w $w8, (2*16)(%[accum_ptr])\n"
+ "ld.w $w1, (3*16)(%[accum_ptr])\n"
+ "ld.w $w5, (4*16)(%[accum_ptr])\n"
+ "ld.w $w9, (5*16)(%[accum_ptr])\n"
+ "ld.w $w2, (6*16)(%[accum_ptr])\n"
+ "ld.w $w6, (7*16)(%[accum_ptr])\n"
+ "ld.w $w10, (8*16)(%[accum_ptr])\n"
+ "ld.w $w3, (9*16)(%[accum_ptr])\n"
+ "ld.w $w7, (10*16)(%[accum_ptr])\n"
+ "ld.w $w11, (11*16)(%[accum_ptr])\n"
+ "ld.w $w12, (12*16)(%[accum_ptr])\n"
+ "ld.w $w16, (13*16)(%[accum_ptr])\n"
+ "ld.w $w20, (14*16)(%[accum_ptr])\n"
+ "ld.w $w13, (15*16)(%[accum_ptr])\n"
+ "ld.w $w17, (16*16)(%[accum_ptr])\n"
+ "ld.w $w21, (17*16)(%[accum_ptr])\n"
+ "ld.w $w14, (18*16)(%[accum_ptr])\n"
+ "ld.w $w18, (19*16)(%[accum_ptr])\n"
+ "ld.w $w22, (20*16)(%[accum_ptr])\n"
+ "ld.w $w15, (21*16)(%[accum_ptr])\n"
+ "ld.w $w19, (22*16)(%[accum_ptr])\n"
+ "ld.w $w23, (23*16)(%[accum_ptr])\n"
+ // Set a temp to all zeroes.
+ "ldi.b $w31, 0\n"
+
+ GEMMLOWP_LABEL_LOOP ":\n"
+ // Overview of register layout:
+ //
+ // A quarter of the 2 2x4 cells of Rhs is stored in 32bit in w30.
+ // A 12x2 block of 3 4x2 cells Lhs is stored in 32bit in w24-w29.
+ // A 12x8 block of accumulators is stored in 32bit in w0-w23.
+ //
+ // +------+------+------+------+
+ // Rhs |w30[0]|w30[1]|w30[2]|w30[3]|
+ // +------+------+------+------+
+ //
+ // | | | | |
+ //
+ // Lhs | | | | |
+ //
+ // +---+---+ - - - - +------+------+------+------+
+ // |w24|w27| |w0/12 |w1/13 |w2/14 |w3/15 |
+ // |w24|w27| |w0/12 |w1/13 |w2/14 |w3/15 |
+ // |w24|w27| |w0/12 |w1/13 |w2/14 |w3/15 |
+ // |w24|w27| |w0/12 |w1/13 |w2/14 |w3/15 |
+ // +---+---+ - - - - +------+------+------+------+
+ // |w25|w28| |w4/16 |w5/17 |w6/18 |w7/19 |
+ // |w25|w28| |w4/16 |w5/17 |w6/18 |w7/19 |
+ // |w25|w28| |w4/16 |w5/17 |w6/18 |w7/19 |
+ // |w25|w28| |w4/16 |w5/17 |w6/18 |w7/19 |
+ // +---+---+ - - - - +------+------+------+------+
+ // |w26|w29| |w8/20 |w9/21 |w10/22|w11/23|
+ // |w26|w29| |w8/20 |w9/21 |w10/22|w11/23|
+ // |w26|w29| |w8/20 |w9/21 |w10/22|w11/23|
+ // |w26|w29| |w8/20 |w9/21 |w10/22|w11/23|
+ // +---+---+ - - - - +------+------+------+------+
+ //
+ // Accumulator
+
+ // Load 3 x 8 bytes of lhs[] with 2 16-byte overlapped loads.
+ "ld.b $w24, 0(%[lhs_ptr])\n"
+ "ld.b $w25, 8(%[lhs_ptr])\n"
+
+ // Load 4 bytes of rhs[] for the first half of depth 0.
+ "lbu $a0, 0(%[rhs_ptr])\n"
+ "lbu $a1, 1(%[rhs_ptr])\n"
+ "lbu $a2, 2(%[rhs_ptr])\n"
+ "lbu $a3, 3(%[rhs_ptr])\n"
+
+ // Zero-extend 8-bit elements of lhs[] to 16 bits.
+ "ilvr.b $w24, $w31, $w24\n"
+ "ilvl.b $w26, $w31, $w25\n"
+ "ilvr.b $w25, $w31, $w25\n"
+ // Zero-extend 16-bit elements of lhs[] to 32 bits.
+ "ilvl.h $w27, $w31, $w24\n"
+ "ilvl.h $w28, $w31, $w25\n"
+ "ilvl.h $w29, $w31, $w26\n"
+ "ilvr.h $w24, $w31, $w24\n"
+ "ilvr.h $w25, $w31, $w25\n"
+ "ilvr.h $w26, $w31, $w26\n"
+
+ // Depth 0.
+ "fill.w $w30, $a0\n"
+ "lbu $a0, 8(%[rhs_ptr])\n"
+ "maddv.w $w0, $w24, $w30\n"
+ "maddv.w $w4, $w25, $w30\n"
+ "maddv.w $w8, $w26, $w30\n"
+ "fill.w $w30, $a1\n"
+ "lbu $a1, 9(%[rhs_ptr])\n"
+ "maddv.w $w1, $w24, $w30\n"
+ "maddv.w $w5, $w25, $w30\n"
+ "maddv.w $w9, $w26, $w30\n"
+ "fill.w $w30, $a2\n"
+ "lbu $a2, 10(%[rhs_ptr])\n"
+ "maddv.w $w2, $w24, $w30\n"
+ "maddv.w $w6, $w25, $w30\n"
+ "maddv.w $w10, $w26, $w30\n"
+ "fill.w $w30, $a3\n"
+ "lbu $a3, 11(%[rhs_ptr])\n"
+ "maddv.w $w3, $w24, $w30\n"
+ "maddv.w $w7, $w25, $w30\n"
+ "maddv.w $w11, $w26, $w30\n"
+
+ "fill.w $w30, $a0\n"
+ "lbu $a0, 4(%[rhs_ptr])\n"
+ "maddv.w $w12, $w24, $w30\n"
+ "maddv.w $w16, $w25, $w30\n"
+ "maddv.w $w20, $w26, $w30\n"
+ "fill.w $w30, $a1\n"
+ "lbu $a1, 5(%[rhs_ptr])\n"
+ "maddv.w $w13, $w24, $w30\n"
+ "maddv.w $w17, $w25, $w30\n"
+ "maddv.w $w21, $w26, $w30\n"
+ "fill.w $w30, $a2\n"
+ "lbu $a2, 6(%[rhs_ptr])\n"
+ "maddv.w $w14, $w24, $w30\n"
+ "maddv.w $w18, $w25, $w30\n"
+ "maddv.w $w22, $w26, $w30\n"
+ "fill.w $w30, $a3\n"
+ "lbu $a3, 7(%[rhs_ptr])\n"
+ "maddv.w $w15, $w24, $w30\n"
+ "maddv.w $w19, $w25, $w30\n"
+ "maddv.w $w23, $w26, $w30\n"
+
+ // Depth 1.
+ "fill.w $w30, $a0\n"
+ "lbu $a0, 12(%[rhs_ptr])\n"
+ "maddv.w $w0, $w27, $w30\n"
+ "maddv.w $w4, $w28, $w30\n"
+ "maddv.w $w8, $w29, $w30\n"
+ "fill.w $w30, $a1\n"
+ "lbu $a1, 13(%[rhs_ptr])\n"
+ "maddv.w $w1, $w27, $w30\n"
+ "maddv.w $w5, $w28, $w30\n"
+ "maddv.w $w9, $w29, $w30\n"
+ "fill.w $w30, $a2\n"
+ "lbu $a2, 14(%[rhs_ptr])\n"
+ "maddv.w $w2, $w27, $w30\n"
+ "maddv.w $w6, $w28, $w30\n"
+ "maddv.w $w10, $w29, $w30\n"
+ "fill.w $w30, $a3\n"
+ "lbu $a3, 15(%[rhs_ptr])\n"
+ "maddv.w $w3, $w27, $w30\n"
+ "maddv.w $w7, $w28, $w30\n"
+ "maddv.w $w11, $w29, $w30\n"
+
+ "fill.w $w30, $a0\n"
+ "maddv.w $w12, $w27, $w30\n"
+ "maddv.w $w16, $w28, $w30\n"
+ "maddv.w $w20, $w29, $w30\n"
+ "fill.w $w30, $a1\n"
+ "maddv.w $w13, $w27, $w30\n"
+ "maddv.w $w17, $w28, $w30\n"
+ "maddv.w $w21, $w29, $w30\n"
+ "fill.w $w30, $a2\n"
+ "maddv.w $w14, $w27, $w30\n"
+ "maddv.w $w18, $w28, $w30\n"
+ "maddv.w $w22, $w29, $w30\n"
+ "fill.w $w30, $a3\n"
+ "maddv.w $w15, $w27, $w30\n"
+ "maddv.w $w19, $w28, $w30\n"
+ "maddv.w $w23, $w29, $w30\n"
+
+ "addiu %[depth], -2\n"
+ GEMMLOWP_MIPS_XADDIU " %[lhs_ptr], 24\n"
+ GEMMLOWP_MIPS_XADDIU " %[rhs_ptr], 16\n"
+ "bnez %[depth]," GEMMLOWP_LABEL_LOOP "b\n"
+
+ // Store accumulators.
+ "st.w $w0, (0*16)(%[accum_ptr])\n"
+ "st.w $w4, (1*16)(%[accum_ptr])\n"
+ "st.w $w8, (2*16)(%[accum_ptr])\n"
+ "st.w $w1, (3*16)(%[accum_ptr])\n"
+ "st.w $w5, (4*16)(%[accum_ptr])\n"
+ "st.w $w9, (5*16)(%[accum_ptr])\n"
+ "st.w $w2, (6*16)(%[accum_ptr])\n"
+ "st.w $w6, (7*16)(%[accum_ptr])\n"
+ "st.w $w10, (8*16)(%[accum_ptr])\n"
+ "st.w $w3, (9*16)(%[accum_ptr])\n"
+ "st.w $w7, (10*16)(%[accum_ptr])\n"
+ "st.w $w11, (11*16)(%[accum_ptr])\n"
+ "st.w $w12, (12*16)(%[accum_ptr])\n"
+ "st.w $w16, (13*16)(%[accum_ptr])\n"
+ "st.w $w20, (14*16)(%[accum_ptr])\n"
+ "st.w $w13, (15*16)(%[accum_ptr])\n"
+ "st.w $w17, (16*16)(%[accum_ptr])\n"
+ "st.w $w21, (17*16)(%[accum_ptr])\n"
+ "st.w $w14, (18*16)(%[accum_ptr])\n"
+ "st.w $w18, (19*16)(%[accum_ptr])\n"
+ "st.w $w22, (20*16)(%[accum_ptr])\n"
+ "st.w $w15, (21*16)(%[accum_ptr])\n"
+ "st.w $w19, (22*16)(%[accum_ptr])\n"
+ "st.w $w23, (23*16)(%[accum_ptr])\n"
+ : // outputs
+ [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
+ [depth] "+r"(depth)
+ : // inputs
+ [accum_ptr] "r"(accum_ptr)
+ : // clobbers
+ "memory",
+ "a0", "a1", "a2", "a3",
+ "$f0", "$f1", "$f2", "$f3", "$f4", "$f5", "$f6", "$f7",
+ "$f8", "$f9", "$f10", "$f11", "$f12", "$f13", "$f14", "$f15",
+ "$f16", "$f17", "$f18", "$f19", "$f20", "$f21", "$f22", "$f23",
+ "$f24", "$f25", "$f26", "$f27", "$f28", "$f29", "$f30", "$f31");
+ }
+};
+
+// Assembly implementation of the above
+// MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators_intrinsics2 (TODO).
+// Using 16x16=32 multiplications.
+// 32 MSA regs used:
+// - 24 accumulators
+// - 3 lhs
+// - 4 rhs
+// - 1 temps/zeroes
+// ~70 instructions in the loop.
+struct MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators_assembly2 {
+ typedef std::uint8_t OperandType;
+ typedef std::uint32_t AccumulatorType;
+ typedef KernelFormat<
+ KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 3>,
+ KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 2> >
+ Format;
+ static void Run(OperandType* lhs_ptr, OperandType* rhs_ptr,
+ AccumulatorType* accum_ptr, int depth) {
+ asm volatile(
+ // Load accumulators
+ "ld.w $w0, (0*16)(%[accum_ptr])\n"
+ "ld.w $w4, (1*16)(%[accum_ptr])\n"
+ "ld.w $w8, (2*16)(%[accum_ptr])\n"
+ "ld.w $w1, (3*16)(%[accum_ptr])\n"
+ "ld.w $w5, (4*16)(%[accum_ptr])\n"
+ "ld.w $w9, (5*16)(%[accum_ptr])\n"
+ "ld.w $w2, (6*16)(%[accum_ptr])\n"
+ "ld.w $w6, (7*16)(%[accum_ptr])\n"
+ "ld.w $w10, (8*16)(%[accum_ptr])\n"
+ "ld.w $w3, (9*16)(%[accum_ptr])\n"
+ "ld.w $w7, (10*16)(%[accum_ptr])\n"
+ "ld.w $w11, (11*16)(%[accum_ptr])\n"
+ "ld.w $w12, (12*16)(%[accum_ptr])\n"
+ "ld.w $w16, (13*16)(%[accum_ptr])\n"
+ "ld.w $w20, (14*16)(%[accum_ptr])\n"
+ "ld.w $w13, (15*16)(%[accum_ptr])\n"
+ "ld.w $w17, (16*16)(%[accum_ptr])\n"
+ "ld.w $w21, (17*16)(%[accum_ptr])\n"
+ "ld.w $w14, (18*16)(%[accum_ptr])\n"
+ "ld.w $w18, (19*16)(%[accum_ptr])\n"
+ "ld.w $w22, (20*16)(%[accum_ptr])\n"
+ "ld.w $w15, (21*16)(%[accum_ptr])\n"
+ "ld.w $w19, (22*16)(%[accum_ptr])\n"
+ "ld.w $w23, (23*16)(%[accum_ptr])\n"
+ // Set a temp to all zeroes.
+ "ldi.b $w31, 0\n"
+
+ GEMMLOWP_LABEL_LOOP ":\n"
+ // Overview of register layout:
+ //
+ // A half of the 2 2x4 cells of Rhs is stored in 16bit in w27-w30
+ // (each register contains 4 replicas of a pair of elements).
+ // A 12x2 block of 3 4x2 cells Lhs is stored in 16bit in w24-w26.
+ // A 12x8 block of accumulators is stored in 32bit in w0-w23.
+ //
+ // +------+------+------+------+
+ // Rhs |w27 |w28 |w29 |w30 |
+ // +------+------+------+------+
+ //
+ // | | | | |
+ //
+ // Lhs | | | | |
+ //
+ // +---+ - - - - +------+------+------+------+
+ // |w24| |w0/12 |w1/13 |w2/14 |w3/15 |
+ // |w24| |w0/12 |w1/13 |w2/14 |w3/15 |
+ // |w24| |w0/12 |w1/13 |w2/14 |w3/15 |
+ // |w24| |w0/12 |w1/13 |w2/14 |w3/15 |
+ // +---+ - - - - +------+------+------+------+
+ // |w25| |w4/16 |w5/17 |w6/18 |w7/19 |
+ // |w25| |w4/16 |w5/17 |w6/18 |w7/19 |
+ // |w25| |w4/16 |w5/17 |w6/18 |w7/19 |
+ // |w25| |w4/16 |w5/17 |w6/18 |w7/19 |
+ // +---+ - - - - +------+------+------+------+
+ // |w26| |w8/20 |w9/21 |w10/22|w11/23|
+ // |w26| |w8/20 |w9/21 |w10/22|w11/23|
+ // |w26| |w8/20 |w9/21 |w10/22|w11/23|
+ // |w26| |w8/20 |w9/21 |w10/22|w11/23|
+ // +---+ - - - - +------+------+------+------+
+ //
+ // Accumulators
+
+ // Load 3 x 8 bytes of lhs[] with 2 16-byte overlapped loads.
+ "ld.b $w24, 0(%[lhs_ptr])\n"
+ "ld.b $w25, 8(%[lhs_ptr])\n"
+
+ // Load 4 bytes of rhs[] for the first half of depth 0.
+ "lbu $a0, 0(%[rhs_ptr])\n"
+ "lbu $a1, 1(%[rhs_ptr])\n"
+ "lbu $a2, 2(%[rhs_ptr])\n"
+ "lbu $a3, 3(%[rhs_ptr])\n"
+ // Load 4 bytes of rhs[] for the first half of depth 1.
+ "lbu $v0, 4(%[rhs_ptr])\n"
+ "lbu $v1, 5(%[rhs_ptr])\n"
+ "lbu $t8, 6(%[rhs_ptr])\n"
+ "lbu $t9, 7(%[rhs_ptr])\n"
+
+ // Zero-extend 8-bit elements of lhs[] to 16 bits.
+ "ilvr.b $w24, $w31, $w24\n"
+ "ilvl.b $w26, $w31, $w25\n"
+ "ilvr.b $w25, $w31, $w25\n"
+ // Interleave depth 0 and depth 1 elements of lhs[] for dpadd_u.w.
+ "ilvl.d $w27, $w31, $w24\n"
+ "ilvl.d $w28, $w31, $w25\n"
+ "ilvl.d $w29, $w31, $w26\n"
+ "ilvr.h $w24, $w27, $w24\n"
+ "ilvr.h $w25, $w28, $w25\n"
+ "ilvr.h $w26, $w29, $w26\n"
+
+ // Combine and interleave depth 0 and depth 1 elements of rhs[] for dpadd_u.w
+ // (for the first half).
+ "ins $a0, $v0, 16, 8\n"
+ "ins $a1, $v1, 16, 8\n"
+ "ins $a2, $t8, 16, 8\n"
+ "ins $a3, $t9, 16, 8\n"
+ // Make 4 replicas of every pair of rhs[] elements.
+ "fill.w $w27, $a0\n"
+ "fill.w $w28, $a1\n"
+ "fill.w $w29, $a2\n"
+ "fill.w $w30, $a3\n"
+
+ // Load 4 bytes of rhs[] for the second half of depth 0.
+ "lbu $a0, 8(%[rhs_ptr])\n"
+ "lbu $a1, 9(%[rhs_ptr])\n"
+ "lbu $a2, 10(%[rhs_ptr])\n"
+ "lbu $a3, 11(%[rhs_ptr])\n"
+ // Load 4 bytes of rhs[] for the second half of depth 1.
+ "lbu $v0, 12(%[rhs_ptr])\n"
+ "lbu $v1, 13(%[rhs_ptr])\n"
+ "lbu $t8, 14(%[rhs_ptr])\n"
+ "lbu $t9, 15(%[rhs_ptr])\n"
+
+ // First half of depths 0 and 1.
+ // Dot-product-(and)-add doubles multiplicand width.
+ "dpadd_u.w $w0, $w24, $w27\n"
+ "dpadd_u.w $w4, $w25, $w27\n"
+ "dpadd_u.w $w8, $w26, $w27\n"
+ "dpadd_u.w $w1, $w24, $w28\n"
+ "dpadd_u.w $w5, $w25, $w28\n"
+ "dpadd_u.w $w9, $w26, $w28\n"
+ "dpadd_u.w $w2, $w24, $w29\n"
+ "dpadd_u.w $w6, $w25, $w29\n"
+ "dpadd_u.w $w10, $w26, $w29\n"
+ "dpadd_u.w $w3, $w24, $w30\n"
+ "dpadd_u.w $w7, $w25, $w30\n"
+ "dpadd_u.w $w11, $w26, $w30\n"
+
+ // Combine and interleave depth 0 and depth 1 elements of rhs[] for dpadd_u.w
+ // (for the second half).
+ "ins $a0, $v0, 16, 8\n"
+ "ins $a1, $v1, 16, 8\n"
+ "ins $a2, $t8, 16, 8\n"
+ "ins $a3, $t9, 16, 8\n"
+ // Make 4 replicas of every pair of rhs[] elements.
+ "fill.w $w27, $a0\n"
+ "fill.w $w28, $a1\n"
+ "fill.w $w29, $a2\n"
+ "fill.w $w30, $a3\n"
+
+ // Second half of depths 0 and 1.
+ // Dot-product-(and)-add doubles multiplicand width.
+ "dpadd_u.w $w12, $w24, $w27\n"
+ "dpadd_u.w $w16, $w25, $w27\n"
+ "dpadd_u.w $w20, $w26, $w27\n"
+ "dpadd_u.w $w13, $w24, $w28\n"
+ "dpadd_u.w $w17, $w25, $w28\n"
+ "dpadd_u.w $w21, $w26, $w28\n"
+ "dpadd_u.w $w14, $w24, $w29\n"
+ "dpadd_u.w $w18, $w25, $w29\n"
+ "dpadd_u.w $w22, $w26, $w29\n"
+ "dpadd_u.w $w15, $w24, $w30\n"
+ "dpadd_u.w $w19, $w25, $w30\n"
+ "dpadd_u.w $w23, $w26, $w30\n"
+
+ "addiu %[depth], -2\n"
+ GEMMLOWP_MIPS_XADDIU " %[lhs_ptr], 24\n"
+ GEMMLOWP_MIPS_XADDIU " %[rhs_ptr], 16\n"
+ "bnez %[depth]," GEMMLOWP_LABEL_LOOP "b\n"
+
+ // Store accumulators.
+ "st.w $w0, (0*16)(%[accum_ptr])\n"
+ "st.w $w4, (1*16)(%[accum_ptr])\n"
+ "st.w $w8, (2*16)(%[accum_ptr])\n"
+ "st.w $w1, (3*16)(%[accum_ptr])\n"
+ "st.w $w5, (4*16)(%[accum_ptr])\n"
+ "st.w $w9, (5*16)(%[accum_ptr])\n"
+ "st.w $w2, (6*16)(%[accum_ptr])\n"
+ "st.w $w6, (7*16)(%[accum_ptr])\n"
+ "st.w $w10, (8*16)(%[accum_ptr])\n"
+ "st.w $w3, (9*16)(%[accum_ptr])\n"
+ "st.w $w7, (10*16)(%[accum_ptr])\n"
+ "st.w $w11, (11*16)(%[accum_ptr])\n"
+ "st.w $w12, (12*16)(%[accum_ptr])\n"
+ "st.w $w16, (13*16)(%[accum_ptr])\n"
+ "st.w $w20, (14*16)(%[accum_ptr])\n"
+ "st.w $w13, (15*16)(%[accum_ptr])\n"
+ "st.w $w17, (16*16)(%[accum_ptr])\n"
+ "st.w $w21, (17*16)(%[accum_ptr])\n"
+ "st.w $w14, (18*16)(%[accum_ptr])\n"
+ "st.w $w18, (19*16)(%[accum_ptr])\n"
+ "st.w $w22, (20*16)(%[accum_ptr])\n"
+ "st.w $w15, (21*16)(%[accum_ptr])\n"
+ "st.w $w19, (22*16)(%[accum_ptr])\n"
+ "st.w $w23, (23*16)(%[accum_ptr])\n"
+ : // outputs
+ [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
+ [depth] "+r"(depth)
+ : // inputs
+ [accum_ptr] "r"(accum_ptr)
+ : // clobbers
+ "memory",
+ "v0", "v1",
+ "a0", "a1", "a2", "a3",
+ "t8", "t9",
+ "$f0", "$f1", "$f2", "$f3", "$f4", "$f5", "$f6", "$f7",
+ "$f8", "$f9", "$f10", "$f11", "$f12", "$f13", "$f14", "$f15",
+ "$f16", "$f17", "$f18", "$f19", "$f20", "$f21", "$f22", "$f23",
+ "$f24", "$f25", "$f26", "$f27", "$f28", "$f29", "$f30", "$f31");
+ }
+};
+#endif // __mips
// BEGIN code copied from gemmlowp/internal/kernel_reference.h
@@ -3451,8 +4883,9 @@ class CacheLineAlignedBuffer {
data_ = nullptr;
// Adds a few bytes of padding here, because the 64-bit 'A57' kernel
// reads one iteration past the end of the buffer, causing a crash on iOS.
- posix_memalign(reinterpret_cast<void**>(&data_), kCacheLineSize,
- size_ * sizeof(DataType) + 16);
+ int res = posix_memalign(reinterpret_cast<void**>(&data_), kCacheLineSize,
+ size_ * sizeof(DataType) + 16);
+ (void)res;
}
~CacheLineAlignedBuffer() { free(data_); }
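The (void)res cast only silences the unused-result warning from posix_memalign; a stricter variant (a sketch of an alternative, not what the benchmark does) would fail loudly:

#include <cstdio>
#include <cstdlib>

// Sketch: checked aligned allocation. The benchmark deliberately keeps
// the best-effort behavior above instead of aborting.
inline void* aligned_alloc_or_die(std::size_t alignment, std::size_t bytes) {
  void* p = nullptr;
  if (posix_memalign(&p, alignment, bytes) != 0) {
    std::fprintf(stderr, "posix_memalign(%zu, %zu) failed\n", alignment, bytes);
    std::abort();
  }
  return p;
}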
@@ -3460,7 +4893,7 @@ class CacheLineAlignedBuffer {
const DataType* data() const { return data_; }
DataType* data() { return data_; }
- const std::size_t size() const { return size_; }
+ std::size_t size() const { return size_; }
private:
const std::size_t size_;
@@ -3726,12 +5159,15 @@ int main() {
#endif
#ifdef __aarch64__
-
BENCHMARK(NEON_64bit_GEMM_Int8Operands_AccumTwoWithin16Bits);
BENCHMARK(NEON_64bit_GEMM_Int8Operands_AccumTwoWithin16Bits_intrinsics);
BENCHMARK(NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators);
BENCHMARK(NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators_intrinsics);
BENCHMARK(NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators_noexpand_A57);
+#ifdef __ARM_FEATURE_DOTPROD
+ BENCHMARK(NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators_dotproduct);
+ BENCHMARK(NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators_dotproduct_A55r1);
+#endif
BENCHMARK(NEON_64bit_GEMM_Int32_WithScalar);
BENCHMARK(NEON_64bit_GEMM_Float32_WithVectorDuplicatingScalar);
BENCHMARK(NEON_64bit_GEMM_Float32_WithScalar);
@@ -3740,6 +5176,16 @@ int main() {
#ifndef __APPLE__
BENCHMARK(NEON_64bit_GEMM_Float32_WithScalar_A53);
#endif
+ BENCHMARK(NEON_64bit_GEMM_Float32_WithScalar_A55r1);
+#endif
+
+#ifdef __mips
+ BENCHMARK(MSA_GEMM_12x4_Uint8Operands_Uint32Accumulators_intrinsics);
+ BENCHMARK(MSA_GEMM_12x4_Uint8Operands_Uint32Accumulators_assembly);
+ BENCHMARK(MSA_GEMM_12x4_Uint8Operands_Uint32Accumulators_assembly2);
+ BENCHMARK(MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators_intrinsics);
+ BENCHMARK(MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators_assembly);
+ BENCHMARK(MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators_assembly2);
#endif
return 0;