diff options
Diffstat (limited to 'internal/simd_wrappers_common_neon_sse.h')
-rw-r--r-- | internal/simd_wrappers_common_neon_sse.h | 204 |
1 files changed, 204 insertions, 0 deletions
diff --git a/internal/simd_wrappers_common_neon_sse.h b/internal/simd_wrappers_common_neon_sse.h index 3830eb1..694bf99 100644 --- a/internal/simd_wrappers_common_neon_sse.h +++ b/internal/simd_wrappers_common_neon_sse.h @@ -350,6 +350,210 @@ struct BroadcastAddImpl<RegBlockInt32<1, 8>, RegBlockInt32<1, 1>> { } }; +// 4x1 := SaturatingRoundingDoublingHighMul(4x1, 1x1); rhs is passed through Dup<Int32x4> +template <> +struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<4, 1>, + RegBlockInt32<1, 1>> { + static RegBlockInt32<4, 1> Run(const RegBlockInt32<4, 1>& lhs, + const RegBlockInt32<1, 1>& rhs) { + RegBlockInt32<4, 1> result; + result.buf.reg[0] = SaturatingRoundingDoublingHighMul( + lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0])); + return result; + } +}; + +// 1x4 := SaturatingRoundingDoublingHighMul(1x4, 1x1); rhs is passed through Dup<Int32x4> +template <> +struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<1, 4>, + RegBlockInt32<1, 1>> { + static RegBlockInt32<1, 4> Run(const RegBlockInt32<1, 4>& lhs, + const RegBlockInt32<1, 1>& rhs) { + RegBlockInt32<1, 4> result; + result.buf.reg[0] = SaturatingRoundingDoublingHighMul( + lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0])); + return result; + } +}; + +// 4x1 := SaturatingRoundingDoublingHighMul(4x1, 4x1); registers are paired one-to-one +template <> +struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<4, 1>, + RegBlockInt32<4, 1>> { + static RegBlockInt32<4, 1> Run(const RegBlockInt32<4, 1>& lhs, + const RegBlockInt32<4, 1>& rhs) { + RegBlockInt32<4, 1> result; + result.buf.reg[0] = + SaturatingRoundingDoublingHighMul(lhs.buf.reg[0], rhs.buf.reg[0]); + return result; + } +}; + +// 1x4 := SaturatingRoundingDoublingHighMul(1x4, 1x4); registers are paired one-to-one +template <> +struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<1, 4>, + RegBlockInt32<1, 4>> { + static RegBlockInt32<1, 4> Run(const RegBlockInt32<1, 4>& lhs, + const RegBlockInt32<1, 4>& rhs) { + RegBlockInt32<1, 4> result; + result.buf.reg[0] = + SaturatingRoundingDoublingHighMul(lhs.buf.reg[0], rhs.buf.reg[0]); + return result; + } +}; + +// 4x4 := SaturatingRoundingDoublingHighMul(4x4, 1x4); lhs register i uses DupLane<i> of rhs +template <> +struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<4, 4>, 
RegBlockInt32<1, 4>> { + static RegBlockInt32<4, 4> Run(const RegBlockInt32<4, 4>& lhs, + const RegBlockInt32<1, 4>& rhs) { + RegBlockInt32<4, 4> result; + result.buf.reg[0] = SaturatingRoundingDoublingHighMul( + lhs.buf.reg[0], DupLane<0>(rhs.buf.reg[0])); + result.buf.reg[1] = SaturatingRoundingDoublingHighMul( + lhs.buf.reg[1], DupLane<1>(rhs.buf.reg[0])); + result.buf.reg[2] = SaturatingRoundingDoublingHighMul( + lhs.buf.reg[2], DupLane<2>(rhs.buf.reg[0])); + result.buf.reg[3] = SaturatingRoundingDoublingHighMul( + lhs.buf.reg[3], DupLane<3>(rhs.buf.reg[0])); + return result; + } +}; + +// 4x4 := SaturatingRoundingDoublingHighMul(4x4, 4x1); every lhs register uses rhs register 0 +template <> +struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<4, 4>, + RegBlockInt32<4, 1>> { + static RegBlockInt32<4, 4> Run(const RegBlockInt32<4, 4>& lhs, + const RegBlockInt32<4, 1>& rhs) { + RegBlockInt32<4, 4> result; + result.buf.reg[0] = + SaturatingRoundingDoublingHighMul(lhs.buf.reg[0], rhs.buf.reg[0]); + result.buf.reg[1] = + SaturatingRoundingDoublingHighMul(lhs.buf.reg[1], rhs.buf.reg[0]); + result.buf.reg[2] = + SaturatingRoundingDoublingHighMul(lhs.buf.reg[2], rhs.buf.reg[0]); + result.buf.reg[3] = + SaturatingRoundingDoublingHighMul(lhs.buf.reg[3], rhs.buf.reg[0]); + return result; + } +}; + +// 8x1 := SaturatingRoundingDoublingHighMul(8x1, 1x1); one Dup<Int32x4> of rhs is reused for both registers +template <> +struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<8, 1>, + RegBlockInt32<1, 1>> { + static RegBlockInt32<8, 1> Run(const RegBlockInt32<8, 1>& lhs, + const RegBlockInt32<1, 1>& rhs) { + RegBlockInt32<8, 1> result; + const Int32x4 p = Dup<Int32x4>(rhs.buf.reg[0]); + for (int i = 0; i < 2; i++) { + result.buf.reg[i] = SaturatingRoundingDoublingHighMul(lhs.buf.reg[i], p); + } + return result; + } +}; + +// 8x1 := SaturatingRoundingDoublingHighMul(8x1, 8x1); registers are paired one-to-one +template <> +struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<8, 1>, + RegBlockInt32<8, 1>> { + static RegBlockInt32<8, 1> Run(const RegBlockInt32<8, 1>& lhs, + const RegBlockInt32<8, 1>& rhs) { + RegBlockInt32<8, 1> result; + for (int i = 0; i < 2; 
i++) { + result.buf.reg[i] = + SaturatingRoundingDoublingHighMul(lhs.buf.reg[i], rhs.buf.reg[i]); + } + return result; + } +}; + +// 8x4 := SaturatingRoundingDoublingHighMul(8x4, 1x4); lhs register i uses DupLane<i / 2> of rhs +template <> +struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<8, 4>, + RegBlockInt32<1, 4>> { + static RegBlockInt32<8, 4> Run(const RegBlockInt32<8, 4>& lhs, + const RegBlockInt32<1, 4>& rhs) { + RegBlockInt32<8, 4> result; + result.buf.reg[0] = SaturatingRoundingDoublingHighMul( + lhs.buf.reg[0], DupLane<0>(rhs.buf.reg[0])); + result.buf.reg[1] = SaturatingRoundingDoublingHighMul( + lhs.buf.reg[1], DupLane<0>(rhs.buf.reg[0])); + result.buf.reg[2] = SaturatingRoundingDoublingHighMul( + lhs.buf.reg[2], DupLane<1>(rhs.buf.reg[0])); + result.buf.reg[3] = SaturatingRoundingDoublingHighMul( + lhs.buf.reg[3], DupLane<1>(rhs.buf.reg[0])); + result.buf.reg[4] = SaturatingRoundingDoublingHighMul( + lhs.buf.reg[4], DupLane<2>(rhs.buf.reg[0])); + result.buf.reg[5] = SaturatingRoundingDoublingHighMul( + lhs.buf.reg[5], DupLane<2>(rhs.buf.reg[0])); + result.buf.reg[6] = SaturatingRoundingDoublingHighMul( + lhs.buf.reg[6], DupLane<3>(rhs.buf.reg[0])); + result.buf.reg[7] = SaturatingRoundingDoublingHighMul( + lhs.buf.reg[7], DupLane<3>(rhs.buf.reg[0])); + return result; + } +}; + +// 8x4 := SaturatingRoundingDoublingHighMul(8x4, 8x1); lhs register i uses rhs register i % 2 +template <> +struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<8, 4>, + RegBlockInt32<8, 1>> { + static RegBlockInt32<8, 4> Run(const RegBlockInt32<8, 4>& lhs, + const RegBlockInt32<8, 1>& rhs) { + RegBlockInt32<8, 4> result; + result.buf.reg[0] = + SaturatingRoundingDoublingHighMul(lhs.buf.reg[0], rhs.buf.reg[0]); + result.buf.reg[1] = + SaturatingRoundingDoublingHighMul(lhs.buf.reg[1], rhs.buf.reg[1]); + result.buf.reg[2] = + SaturatingRoundingDoublingHighMul(lhs.buf.reg[2], rhs.buf.reg[0]); + result.buf.reg[3] = + SaturatingRoundingDoublingHighMul(lhs.buf.reg[3], rhs.buf.reg[1]); + result.buf.reg[4] = + SaturatingRoundingDoublingHighMul(lhs.buf.reg[4], rhs.buf.reg[0]); + result.buf.reg[5] = + 
SaturatingRoundingDoublingHighMul(lhs.buf.reg[5], rhs.buf.reg[1]); + result.buf.reg[6] = + SaturatingRoundingDoublingHighMul(lhs.buf.reg[6], rhs.buf.reg[0]); + result.buf.reg[7] = + SaturatingRoundingDoublingHighMul(lhs.buf.reg[7], rhs.buf.reg[1]); + return result; + } +}; + +// 1x8 := SaturatingRoundingDoublingHighMul(1x8, 1x8); registers are paired one-to-one +template <> +struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<1, 8>, + RegBlockInt32<1, 8>> { + static RegBlockInt32<1, 8> Run(const RegBlockInt32<1, 8>& lhs, + const RegBlockInt32<1, 8>& rhs) { + RegBlockInt32<1, 8> result; + result.buf.reg[0] = + SaturatingRoundingDoublingHighMul(lhs.buf.reg[0], rhs.buf.reg[0]); + result.buf.reg[1] = + SaturatingRoundingDoublingHighMul(lhs.buf.reg[1], rhs.buf.reg[1]); + return result; + } +}; + +// 1x8 := SaturatingRoundingDoublingHighMul(1x8, 1x1); rhs is passed through Dup<Int32x4> for each register +template <> +struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<1, 8>, + RegBlockInt32<1, 1>> { + static RegBlockInt32<1, 8> Run(const RegBlockInt32<1, 8>& lhs, + const RegBlockInt32<1, 1>& rhs) { + RegBlockInt32<1, 8> result; + result.buf.reg[0] = SaturatingRoundingDoublingHighMul( + lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0])); + result.buf.reg[1] = SaturatingRoundingDoublingHighMul( + lhs.buf.reg[1], Dup<Int32x4>(rhs.buf.reg[0])); + return result; + } +}; + // 4x1 := 4x1 * 1x1 template <> struct BroadcastMulImpl<RegBlockInt32<4, 1>, RegBlockInt32<1, 1>> { |