diff options
Diffstat (limited to 'internal/simd_wrappers_neon.h')
-rw-r--r-- | internal/simd_wrappers_neon.h | 31 |
1 files changed, 31 insertions, 0 deletions
diff --git a/internal/simd_wrappers_neon.h b/internal/simd_wrappers_neon.h index c992b15..2949173 100644 --- a/internal/simd_wrappers_neon.h +++ b/internal/simd_wrappers_neon.h @@ -22,6 +22,8 @@ namespace gemmlowp { using Int32x4 = int32x4_t; +using Int16x4 = int16x4_t; +using Int16x8 = int16x8_t; using Uint8x8 = uint8x8_t; template <int ScalarCount> @@ -31,6 +33,14 @@ struct RegisterType<std::int32_t, ScalarCount> { }; template <int ScalarCount> +struct RegisterType<std::int16_t, ScalarCount> { + using Type = typename std::conditional< + ScalarCount >= 8, Int16x8, + typename std::conditional<ScalarCount >= 4, Int16x4, + std::int16_t>::type>::type; +}; + +template <int ScalarCount> struct RegisterType<std::uint8_t, ScalarCount> { using Type = typename std::conditional< ScalarCount >= 8, Uint8x8, @@ -39,11 +49,21 @@ struct RegisterType<std::uint8_t, ScalarCount> { }; inline Int32x4 LoadInt32x4(const std::int32_t* src) { return vld1q_s32(src); } +inline Int16x4 LoadInt16x4(const std::int16_t* src) { return vld1_s16(src); } +inline Int16x8 LoadInt16x8(const std::int16_t* src) { return vld1q_s16(src); } inline void StoreInt32x4(std::int32_t* dst, Int32x4 value) { vst1q_s32(dst, value); } +inline void StoreInt16x4(std::int16_t* dst, Int16x4 value) { + vst1_s16(dst, value); +} + +inline void StoreInt16x8(std::int16_t* dst, Int16x8 value) { + vst1q_s16(dst, value); +} + template <int Lane> std::int32_t GetLane(Int32x4 value) { return vgetq_lane_s32(value, Lane); @@ -122,6 +142,17 @@ inline void MulAddByRhsLane(Int32x4 lhs, Int32x4 rhs, Int32x4* acc) { } template <> +struct LoadContiguousImpl<RegBlockInt16<8, 8>> { + static RegBlockInt16<8, 8> Run(const std::int16_t* src) { + RegBlockInt16<8, 8> result; + for (int i = 0; i < 8; i++) { + result.buf.reg[i] = vld1q_s16(src + 8 * i); + } + return result; + } +}; + +template <> struct LoadContiguousImpl<RegBlockUint8<8, 8>> { static RegBlockUint8<8, 8> Run(const std::uint8_t* src) { RegBlockUint8<8, 8> result; |