aboutsummaryrefslogtreecommitdiff
path: root/internal/simd_wrappers_neon.h
diff options
context:
space:
mode:
Diffstat (limited to 'internal/simd_wrappers_neon.h')
-rw-r--r-- internal/simd_wrappers_neon.h 31
1 files changed, 31 insertions, 0 deletions
diff --git a/internal/simd_wrappers_neon.h b/internal/simd_wrappers_neon.h
index c992b15..2949173 100644
--- a/internal/simd_wrappers_neon.h
+++ b/internal/simd_wrappers_neon.h
@@ -22,6 +22,8 @@
namespace gemmlowp {
using Int32x4 = int32x4_t;
+using Int16x4 = int16x4_t;
+using Int16x8 = int16x8_t;
using Uint8x8 = uint8x8_t;
template <int ScalarCount>
@@ -31,6 +33,14 @@ struct RegisterType<std::int32_t, ScalarCount> {
};
template <int ScalarCount>
+struct RegisterType<std::int16_t, ScalarCount> {  // Register type chosen for ScalarCount int16 scalars.
+ using Type = typename std::conditional<
+ ScalarCount >= 8, Int16x8,  // 8 or more scalars: Int16x8.
+ typename std::conditional<ScalarCount >= 4, Int16x4,  // 4 to 7 scalars: Int16x4.
+ std::int16_t>::type>::type;  // Fewer than 4: a plain std::int16_t.
+};
+
+template <int ScalarCount>
struct RegisterType<std::uint8_t, ScalarCount> {
using Type = typename std::conditional<
ScalarCount >= 8, Uint8x8,
@@ -39,11 +49,21 @@ struct RegisterType<std::uint8_t, ScalarCount> {
};
inline Int32x4 LoadInt32x4(const std::int32_t* src) { return vld1q_s32(src); }  // Loads an Int32x4 from src via vld1q_s32.
+inline Int16x4 LoadInt16x4(const std::int16_t* src) { return vld1_s16(src); }  // Loads an Int16x4 from src via vld1_s16.
+inline Int16x8 LoadInt16x8(const std::int16_t* src) { return vld1q_s16(src); }  // Loads an Int16x8 from src via vld1q_s16.
inline void StoreInt32x4(std::int32_t* dst, Int32x4 value) {  // Writes value to dst via vst1q_s32.
vst1q_s32(dst, value);
}
+inline void StoreInt16x4(std::int16_t* dst, Int16x4 value) {  // Writes value to dst via vst1_s16.
+ vst1_s16(dst, value);
+}
+
+inline void StoreInt16x8(std::int16_t* dst, Int16x8 value) {  // Writes value to dst via vst1q_s16.
+ vst1q_s16(dst, value);
+}
+
template <int Lane>
std::int32_t GetLane(Int32x4 value) {
return vgetq_lane_s32(value, Lane);
@@ -122,6 +142,17 @@ inline void MulAddByRhsLane(Int32x4 lhs, Int32x4 rhs, Int32x4* acc) {
}
template <>
+struct LoadContiguousImpl<RegBlockInt16<8, 8>> {  // Contiguous load specialised for RegBlockInt16<8, 8>.
+ static RegBlockInt16<8, 8> Run(const std::int16_t* src) {  // Fills and returns a block read from src.
+ RegBlockInt16<8, 8> result;
+ for (int i = 0; i < 8; i++) {
+ result.buf.reg[i] = vld1q_s16(src + 8 * i);  // Register i starts at element offset 8 * i of src.
+ }
+ return result;
+ }
+};
+
+template <>
struct LoadContiguousImpl<RegBlockUint8<8, 8>> {
static RegBlockUint8<8, 8> Run(const std::uint8_t* src) {
RegBlockUint8<8, 8> result;