aboutsummaryrefslogtreecommitdiff
path: root/internal/simd_wrappers_common_neon_sse.h
diff options
context:
space:
mode:
Diffstat (limited to 'internal/simd_wrappers_common_neon_sse.h')
-rw-r--r--internal/simd_wrappers_common_neon_sse.h204
1 files changed, 204 insertions, 0 deletions
diff --git a/internal/simd_wrappers_common_neon_sse.h b/internal/simd_wrappers_common_neon_sse.h
index 3830eb1..694bf99 100644
--- a/internal/simd_wrappers_common_neon_sse.h
+++ b/internal/simd_wrappers_common_neon_sse.h
@@ -350,6 +350,210 @@ struct BroadcastAddImpl<RegBlockInt32<1, 8>, RegBlockInt32<1, 1>> {
}
};
+// 4x1 := 4x1 + 1x1
+template <>
+struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<4, 1>,
+ RegBlockInt32<1, 1>> {
+ static RegBlockInt32<4, 1> Run(const RegBlockInt32<4, 1>& lhs,
+ const RegBlockInt32<1, 1>& rhs) {
+ RegBlockInt32<4, 1> result;
+ result.buf.reg[0] = SaturatingRoundingDoublingHighMul(
+ lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0]));
+ return result;
+ }
+};
+
+// 1x4 := 1x4 + 1x1
+template <>
+struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<1, 4>,
+ RegBlockInt32<1, 1>> {
+ static RegBlockInt32<1, 4> Run(const RegBlockInt32<1, 4>& lhs,
+ const RegBlockInt32<1, 1>& rhs) {
+ RegBlockInt32<1, 4> result;
+ result.buf.reg[0] = SaturatingRoundingDoublingHighMul(
+ lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0]));
+ return result;
+ }
+};
+
+// 4x1 := 4x1 + 4x1
+template <>
+struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<4, 1>,
+ RegBlockInt32<4, 1>> {
+ static RegBlockInt32<4, 1> Run(const RegBlockInt32<4, 1>& lhs,
+ const RegBlockInt32<4, 1>& rhs) {
+ RegBlockInt32<4, 1> result;
+ result.buf.reg[0] =
+ SaturatingRoundingDoublingHighMul(lhs.buf.reg[0], rhs.buf.reg[0]);
+ return result;
+ }
+};
+
+// 1x4 := 1x4 + 1x4
+template <>
+struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<1, 4>,
+ RegBlockInt32<1, 4>> {
+ static RegBlockInt32<1, 4> Run(const RegBlockInt32<1, 4>& lhs,
+ const RegBlockInt32<1, 4>& rhs) {
+ RegBlockInt32<1, 4> result;
+ result.buf.reg[0] =
+ SaturatingRoundingDoublingHighMul(lhs.buf.reg[0], rhs.buf.reg[0]);
+ return result;
+ }
+};
+
+// 4x4 := 4x4 + 1x4
+template <>
+struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<4, 4>,
+ RegBlockInt32<1, 4>> {
+ static RegBlockInt32<4, 4> Run(const RegBlockInt32<4, 4>& lhs,
+ const RegBlockInt32<1, 4>& rhs) {
+ RegBlockInt32<4, 4> result;
+ result.buf.reg[0] = SaturatingRoundingDoublingHighMul(
+ lhs.buf.reg[0], DupLane<0>(rhs.buf.reg[0]));
+ result.buf.reg[1] = SaturatingRoundingDoublingHighMul(
+ lhs.buf.reg[1], DupLane<1>(rhs.buf.reg[0]));
+ result.buf.reg[2] = SaturatingRoundingDoublingHighMul(
+ lhs.buf.reg[2], DupLane<2>(rhs.buf.reg[0]));
+ result.buf.reg[3] = SaturatingRoundingDoublingHighMul(
+ lhs.buf.reg[3], DupLane<3>(rhs.buf.reg[0]));
+ return result;
+ }
+};
+
+// 4x4 := 4x4 + 4x1
+template <>
+struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<4, 4>,
+ RegBlockInt32<4, 1>> {
+ static RegBlockInt32<4, 4> Run(const RegBlockInt32<4, 4>& lhs,
+ const RegBlockInt32<4, 1>& rhs) {
+ RegBlockInt32<4, 4> result;
+ result.buf.reg[0] =
+ SaturatingRoundingDoublingHighMul(lhs.buf.reg[0], rhs.buf.reg[0]);
+ result.buf.reg[1] =
+ SaturatingRoundingDoublingHighMul(lhs.buf.reg[1], rhs.buf.reg[0]);
+ result.buf.reg[2] =
+ SaturatingRoundingDoublingHighMul(lhs.buf.reg[2], rhs.buf.reg[0]);
+ result.buf.reg[3] =
+ SaturatingRoundingDoublingHighMul(lhs.buf.reg[3], rhs.buf.reg[0]);
+ return result;
+ }
+};
+
+// 8x1 := 8x1 + 1x1
+template <>
+struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<8, 1>,
+ RegBlockInt32<1, 1>> {
+ static RegBlockInt32<8, 1> Run(const RegBlockInt32<8, 1>& lhs,
+ const RegBlockInt32<1, 1>& rhs) {
+ RegBlockInt32<8, 1> result;
+ const Int32x4 p = Dup<Int32x4>(rhs.buf.reg[0]);
+ for (int i = 0; i < 2; i++) {
+ result.buf.reg[i] = SaturatingRoundingDoublingHighMul(lhs.buf.reg[i], p);
+ }
+ return result;
+ }
+};
+
+// 8x1 := 8x1 + 8x1
+template <>
+struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<8, 1>,
+ RegBlockInt32<8, 1>> {
+ static RegBlockInt32<8, 1> Run(const RegBlockInt32<8, 1>& lhs,
+ const RegBlockInt32<8, 1>& rhs) {
+ RegBlockInt32<8, 1> result;
+ for (int i = 0; i < 2; i++) {
+ result.buf.reg[i] =
+ SaturatingRoundingDoublingHighMul(lhs.buf.reg[i], rhs.buf.reg[i]);
+ }
+ return result;
+ }
+};
+
+// 8x4 := 8x4 + 1x4
+template <>
+struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<8, 4>,
+ RegBlockInt32<1, 4>> {
+ static RegBlockInt32<8, 4> Run(const RegBlockInt32<8, 4>& lhs,
+ const RegBlockInt32<1, 4>& rhs) {
+ RegBlockInt32<8, 4> result;
+ result.buf.reg[0] = SaturatingRoundingDoublingHighMul(
+ lhs.buf.reg[0], DupLane<0>(rhs.buf.reg[0]));
+ result.buf.reg[1] = SaturatingRoundingDoublingHighMul(
+ lhs.buf.reg[1], DupLane<0>(rhs.buf.reg[0]));
+ result.buf.reg[2] = SaturatingRoundingDoublingHighMul(
+ lhs.buf.reg[2], DupLane<1>(rhs.buf.reg[0]));
+ result.buf.reg[3] = SaturatingRoundingDoublingHighMul(
+ lhs.buf.reg[3], DupLane<1>(rhs.buf.reg[0]));
+ result.buf.reg[4] = SaturatingRoundingDoublingHighMul(
+ lhs.buf.reg[4], DupLane<2>(rhs.buf.reg[0]));
+ result.buf.reg[5] = SaturatingRoundingDoublingHighMul(
+ lhs.buf.reg[5], DupLane<2>(rhs.buf.reg[0]));
+ result.buf.reg[6] = SaturatingRoundingDoublingHighMul(
+ lhs.buf.reg[6], DupLane<3>(rhs.buf.reg[0]));
+ result.buf.reg[7] = SaturatingRoundingDoublingHighMul(
+ lhs.buf.reg[7], DupLane<3>(rhs.buf.reg[0]));
+ return result;
+ }
+};
+
+// 8x4 := 8x4 + 8x1
+template <>
+struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<8, 4>,
+ RegBlockInt32<8, 1>> {
+ static RegBlockInt32<8, 4> Run(const RegBlockInt32<8, 4>& lhs,
+ const RegBlockInt32<8, 1>& rhs) {
+ RegBlockInt32<8, 4> result;
+ result.buf.reg[0] =
+ SaturatingRoundingDoublingHighMul(lhs.buf.reg[0], rhs.buf.reg[0]);
+ result.buf.reg[1] =
+ SaturatingRoundingDoublingHighMul(lhs.buf.reg[1], rhs.buf.reg[1]);
+ result.buf.reg[2] =
+ SaturatingRoundingDoublingHighMul(lhs.buf.reg[2], rhs.buf.reg[0]);
+ result.buf.reg[3] =
+ SaturatingRoundingDoublingHighMul(lhs.buf.reg[3], rhs.buf.reg[1]);
+ result.buf.reg[4] =
+ SaturatingRoundingDoublingHighMul(lhs.buf.reg[4], rhs.buf.reg[0]);
+ result.buf.reg[5] =
+ SaturatingRoundingDoublingHighMul(lhs.buf.reg[5], rhs.buf.reg[1]);
+ result.buf.reg[6] =
+ SaturatingRoundingDoublingHighMul(lhs.buf.reg[6], rhs.buf.reg[0]);
+ result.buf.reg[7] =
+ SaturatingRoundingDoublingHighMul(lhs.buf.reg[7], rhs.buf.reg[1]);
+ return result;
+ }
+};
+
+// 1x8 := 1x8 + 1x8
+template <>
+struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<1, 8>,
+ RegBlockInt32<1, 8>> {
+ static RegBlockInt32<1, 8> Run(const RegBlockInt32<1, 8>& lhs,
+ const RegBlockInt32<1, 8>& rhs) {
+ RegBlockInt32<1, 8> result;
+ result.buf.reg[0] =
+ SaturatingRoundingDoublingHighMul(lhs.buf.reg[0], rhs.buf.reg[0]);
+ result.buf.reg[1] =
+ SaturatingRoundingDoublingHighMul(lhs.buf.reg[1], rhs.buf.reg[1]);
+ return result;
+ }
+};
+
+// 1x8 := 1x8 + 1x1
+template <>
+struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<1, 8>,
+ RegBlockInt32<1, 1>> {
+ static RegBlockInt32<1, 8> Run(const RegBlockInt32<1, 8>& lhs,
+ const RegBlockInt32<1, 1>& rhs) {
+ RegBlockInt32<1, 8> result;
+ result.buf.reg[0] = SaturatingRoundingDoublingHighMul(
+ lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0]));
+ result.buf.reg[1] = SaturatingRoundingDoublingHighMul(
+ lhs.buf.reg[1], Dup<Int32x4>(rhs.buf.reg[0]));
+ return result;
+ }
+};
+
// 4x1 := 4x1 * 1x1
template <>
struct BroadcastMulImpl<RegBlockInt32<4, 1>, RegBlockInt32<1, 1>> {