diff options
Diffstat (limited to 'internal/simd_wrappers_sse.h')
-rw-r--r-- | internal/simd_wrappers_sse.h | 26 |
1 files changed, 26 insertions, 0 deletions
diff --git a/internal/simd_wrappers_sse.h b/internal/simd_wrappers_sse.h index 6480b66..3b78cb4 100644 --- a/internal/simd_wrappers_sse.h +++ b/internal/simd_wrappers_sse.h @@ -22,6 +22,7 @@ namespace gemmlowp { using Int32x4 = __m128i; +using Int16x8 = __m128i; using Uint8x16 = __m128i; template <int ScalarCount> @@ -31,6 +32,12 @@ struct RegisterType<std::int32_t, ScalarCount> { }; template <int ScalarCount> +struct RegisterType<std::int16_t, ScalarCount> { + using Type = + typename std::conditional<ScalarCount >= 8, Int16x8, std::int16_t>::type; +}; + +template <int ScalarCount> struct RegisterType<std::uint8_t, ScalarCount> { using Type = typename std::conditional< ScalarCount >= 16, Uint8x16, @@ -42,10 +49,18 @@ inline Int32x4 LoadInt32x4(const std::int32_t* src) { return _mm_loadu_si128(reinterpret_cast<const Int32x4*>(src)); } +inline Int32x4 LoadInt16x8(const std::int16_t* src) { + return _mm_loadu_si128(reinterpret_cast<const Int16x8*>(src)); +} + inline void StoreInt32x4(std::int32_t* dst, Int32x4 value) { _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), value); } +inline void StoreInt16x8(std::int16_t* dst, Int16x8 value) { + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), value); +} + inline Uint8x16 LoadUint8x16(const std::uint8_t* src) { return _mm_loadu_si128(reinterpret_cast<const Uint8x16*>(src)); } @@ -116,6 +131,17 @@ struct LoadContiguousImpl<RegBlockInt32<8, 8>> { } }; +template <> +struct LoadContiguousImpl<RegBlockInt16<8, 8>> { + static RegBlockInt16<8, 8> Run(const std::int16_t* src) { + RegBlockInt16<8, 8> result; + for (int i = 0; i < 8; i++) { + result.buf.reg[i] = LoadInt16x8(src + 8 * i); + } + return result; + } +}; + } // end namespace gemmlowp #include "simd_wrappers_common_neon_sse.h" |