aboutsummaryrefslogtreecommitdiff
path: root/Eigen/src/Core/arch
diff options
context:
space:
mode:
Diffstat (limited to 'Eigen/src/Core/arch')
-rw-r--r--Eigen/src/Core/arch/AVX/Complex.h215
-rw-r--r--Eigen/src/Core/arch/AVX/MathFunctions.h463
-rw-r--r--Eigen/src/Core/arch/AVX/PacketMath.h1271
-rw-r--r--Eigen/src/Core/arch/AVX/TypeCasting.h66
-rw-r--r--Eigen/src/Core/arch/AVX512/Complex.h422
-rw-r--r--Eigen/src/Core/arch/AVX512/MathFunctions.h448
-rw-r--r--Eigen/src/Core/arch/AVX512/PacketMath.h1989
-rw-r--r--Eigen/src/Core/arch/AVX512/TypeCasting.h89
-rw-r--r--Eigen/src/Core/arch/AltiVec/Complex.h352
-rw-r--r--Eigen/src/Core/arch/AltiVec/MathFunctions.h270
-rw-r--r--Eigen/src/Core/arch/AltiVec/MatrixProduct.h2937
-rw-r--r--Eigen/src/Core/arch/AltiVec/MatrixProductCommon.h221
-rw-r--r--Eigen/src/Core/arch/AltiVec/MatrixProductMMA.h629
-rwxr-xr-xEigen/src/Core/arch/AltiVec/PacketMath.h2384
-rw-r--r--Eigen/src/Core/arch/CUDA/Complex.h315
-rw-r--r--Eigen/src/Core/arch/CUDA/Half.h636
-rw-r--r--Eigen/src/Core/arch/CUDA/PacketMath.h333
-rw-r--r--Eigen/src/Core/arch/CUDA/PacketMathHalf.h1123
-rw-r--r--Eigen/src/Core/arch/CUDA/TypeCasting.h212
-rw-r--r--Eigen/src/Core/arch/Default/BFloat16.h700
-rw-r--r--Eigen/src/Core/arch/Default/ConjHelper.h117
-rw-r--r--Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h1649
-rw-r--r--Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h110
-rw-r--r--Eigen/src/Core/arch/Default/Half.h942
-rw-r--r--Eigen/src/Core/arch/Default/Settings.h2
-rw-r--r--Eigen/src/Core/arch/Default/TypeCasting.h120
-rw-r--r--Eigen/src/Core/arch/GPU/MathFunctions.h (renamed from Eigen/src/Core/arch/CUDA/MathFunctions.h)20
-rw-r--r--Eigen/src/Core/arch/GPU/PacketMath.h1685
-rw-r--r--Eigen/src/Core/arch/GPU/TypeCasting.h80
-rw-r--r--Eigen/src/Core/arch/HIP/hcc/math_constants.h23
-rw-r--r--Eigen/src/Core/arch/MSA/Complex.h648
-rw-r--r--Eigen/src/Core/arch/MSA/MathFunctions.h387
-rw-r--r--Eigen/src/Core/arch/MSA/PacketMath.h1233
-rw-r--r--Eigen/src/Core/arch/NEON/Complex.h530
-rw-r--r--Eigen/src/Core/arch/NEON/GeneralBlockPanelKernel.h183
-rw-r--r--Eigen/src/Core/arch/NEON/MathFunctions.h124
-rw-r--r--Eigen/src/Core/arch/NEON/PacketMath.h4626
-rw-r--r--Eigen/src/Core/arch/NEON/TypeCasting.h1419
-rw-r--r--Eigen/src/Core/arch/SSE/Complex.h278
-rw-r--r--Eigen/src/Core/arch/SSE/MathFunctions.h493
-rwxr-xr-xEigen/src/Core/arch/SSE/PacketMath.h1196
-rw-r--r--Eigen/src/Core/arch/SSE/TypeCasting.h93
-rw-r--r--Eigen/src/Core/arch/SVE/MathFunctions.h44
-rw-r--r--Eigen/src/Core/arch/SVE/PacketMath.h752
-rw-r--r--Eigen/src/Core/arch/SVE/TypeCasting.h49
-rw-r--r--Eigen/src/Core/arch/SYCL/InteropHeaders.h232
-rw-r--r--Eigen/src/Core/arch/SYCL/MathFunctions.h301
-rw-r--r--Eigen/src/Core/arch/SYCL/PacketMath.h670
-rw-r--r--Eigen/src/Core/arch/SYCL/SyclMemoryModel.h694
-rw-r--r--Eigen/src/Core/arch/SYCL/TypeCasting.h85
-rw-r--r--Eigen/src/Core/arch/ZVector/Complex.h392
-rw-r--r--Eigen/src/Core/arch/ZVector/MathFunctions.h106
-rwxr-xr-xEigen/src/Core/arch/ZVector/PacketMath.h883
53 files changed, 28461 insertions, 6780 deletions
diff --git a/Eigen/src/Core/arch/AVX/Complex.h b/Eigen/src/Core/arch/AVX/Complex.h
index 99439c8aa..ab7bd6c65 100644
--- a/Eigen/src/Core/arch/AVX/Complex.h
+++ b/Eigen/src/Core/arch/AVX/Complex.h
@@ -22,6 +22,7 @@ struct Packet4cf
__m256 v;
};
+#ifndef EIGEN_VECTORIZE_AVX512
template<> struct packet_traits<std::complex<float> > : default_packet_traits
{
typedef Packet4cf type;
@@ -37,6 +38,7 @@ template<> struct packet_traits<std::complex<float> > : default_packet_traits
HasMul = 1,
HasDiv = 1,
HasNegate = 1,
+ HasSqrt = 1,
HasAbs = 0,
HasAbs2 = 0,
HasMin = 0,
@@ -44,8 +46,20 @@ template<> struct packet_traits<std::complex<float> > : default_packet_traits
HasSetLinear = 0
};
};
+#endif
-template<> struct unpacket_traits<Packet4cf> { typedef std::complex<float> type; enum {size=4, alignment=Aligned32}; typedef Packet2cf half; };
+template<> struct unpacket_traits<Packet4cf> {
+ typedef std::complex<float> type;
+ typedef Packet2cf half;
+ typedef Packet8f as_real;
+ enum {
+ size=4,
+ alignment=Aligned32,
+ vectorizable=true,
+ masked_load_available=false,
+ masked_store_available=false
+ };
+};
template<> EIGEN_STRONG_INLINE Packet4cf padd<Packet4cf>(const Packet4cf& a, const Packet4cf& b) { return Packet4cf(_mm256_add_ps(a.v,b.v)); }
template<> EIGEN_STRONG_INLINE Packet4cf psub<Packet4cf>(const Packet4cf& a, const Packet4cf& b) { return Packet4cf(_mm256_sub_ps(a.v,b.v)); }
@@ -67,10 +81,17 @@ template<> EIGEN_STRONG_INLINE Packet4cf pmul<Packet4cf>(const Packet4cf& a, con
return Packet4cf(result);
}
+template <>
+EIGEN_STRONG_INLINE Packet4cf pcmp_eq(const Packet4cf& a, const Packet4cf& b) {
+ __m256 eq = _mm256_cmp_ps(a.v, b.v, _CMP_EQ_OQ);
+ return Packet4cf(_mm256_and_ps(eq, _mm256_permute_ps(eq, 0xb1)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet4cf ptrue<Packet4cf>(const Packet4cf& a) { return Packet4cf(ptrue(Packet8f(a.v))); }
template<> EIGEN_STRONG_INLINE Packet4cf pand <Packet4cf>(const Packet4cf& a, const Packet4cf& b) { return Packet4cf(_mm256_and_ps(a.v,b.v)); }
template<> EIGEN_STRONG_INLINE Packet4cf por <Packet4cf>(const Packet4cf& a, const Packet4cf& b) { return Packet4cf(_mm256_or_ps(a.v,b.v)); }
template<> EIGEN_STRONG_INLINE Packet4cf pxor <Packet4cf>(const Packet4cf& a, const Packet4cf& b) { return Packet4cf(_mm256_xor_ps(a.v,b.v)); }
-template<> EIGEN_STRONG_INLINE Packet4cf pandnot<Packet4cf>(const Packet4cf& a, const Packet4cf& b) { return Packet4cf(_mm256_andnot_ps(a.v,b.v)); }
+template<> EIGEN_STRONG_INLINE Packet4cf pandnot<Packet4cf>(const Packet4cf& a, const Packet4cf& b) { return Packet4cf(_mm256_andnot_ps(b.v,a.v)); }
template<> EIGEN_STRONG_INLINE Packet4cf pload <Packet4cf>(const std::complex<float>* from) { EIGEN_DEBUG_ALIGNED_LOAD return Packet4cf(pload<Packet8f>(&numext::real_ref(*from))); }
template<> EIGEN_STRONG_INLINE Packet4cf ploadu<Packet4cf>(const std::complex<float>* from) { EIGEN_DEBUG_UNALIGNED_LOAD return Packet4cf(ploadu<Packet8f>(&numext::real_ref(*from))); }
@@ -140,87 +161,13 @@ template<> EIGEN_STRONG_INLINE std::complex<float> predux<Packet4cf>(const Packe
Packet2cf(_mm256_extractf128_ps(a.v,1))));
}
-template<> EIGEN_STRONG_INLINE Packet4cf preduxp<Packet4cf>(const Packet4cf* vecs)
-{
- Packet8f t0 = _mm256_shuffle_ps(vecs[0].v, vecs[0].v, _MM_SHUFFLE(3, 1, 2 ,0));
- Packet8f t1 = _mm256_shuffle_ps(vecs[1].v, vecs[1].v, _MM_SHUFFLE(3, 1, 2 ,0));
- t0 = _mm256_hadd_ps(t0,t1);
- Packet8f t2 = _mm256_shuffle_ps(vecs[2].v, vecs[2].v, _MM_SHUFFLE(3, 1, 2 ,0));
- Packet8f t3 = _mm256_shuffle_ps(vecs[3].v, vecs[3].v, _MM_SHUFFLE(3, 1, 2 ,0));
- t2 = _mm256_hadd_ps(t2,t3);
-
- t1 = _mm256_permute2f128_ps(t0,t2, 0 + (2<<4));
- t3 = _mm256_permute2f128_ps(t0,t2, 1 + (3<<4));
-
- return Packet4cf(_mm256_add_ps(t1,t3));
-}
-
template<> EIGEN_STRONG_INLINE std::complex<float> predux_mul<Packet4cf>(const Packet4cf& a)
{
return predux_mul(pmul(Packet2cf(_mm256_extractf128_ps(a.v, 0)),
Packet2cf(_mm256_extractf128_ps(a.v, 1))));
}
-template<int Offset>
-struct palign_impl<Offset,Packet4cf>
-{
- static EIGEN_STRONG_INLINE void run(Packet4cf& first, const Packet4cf& second)
- {
- if (Offset==0) return;
- palign_impl<Offset*2,Packet8f>::run(first.v, second.v);
- }
-};
-
-template<> struct conj_helper<Packet4cf, Packet4cf, false,true>
-{
- EIGEN_STRONG_INLINE Packet4cf pmadd(const Packet4cf& x, const Packet4cf& y, const Packet4cf& c) const
- { return padd(pmul(x,y),c); }
-
- EIGEN_STRONG_INLINE Packet4cf pmul(const Packet4cf& a, const Packet4cf& b) const
- {
- return internal::pmul(a, pconj(b));
- }
-};
-
-template<> struct conj_helper<Packet4cf, Packet4cf, true,false>
-{
- EIGEN_STRONG_INLINE Packet4cf pmadd(const Packet4cf& x, const Packet4cf& y, const Packet4cf& c) const
- { return padd(pmul(x,y),c); }
-
- EIGEN_STRONG_INLINE Packet4cf pmul(const Packet4cf& a, const Packet4cf& b) const
- {
- return internal::pmul(pconj(a), b);
- }
-};
-
-template<> struct conj_helper<Packet4cf, Packet4cf, true,true>
-{
- EIGEN_STRONG_INLINE Packet4cf pmadd(const Packet4cf& x, const Packet4cf& y, const Packet4cf& c) const
- { return padd(pmul(x,y),c); }
-
- EIGEN_STRONG_INLINE Packet4cf pmul(const Packet4cf& a, const Packet4cf& b) const
- {
- return pconj(internal::pmul(a, b));
- }
-};
-
-template<> struct conj_helper<Packet8f, Packet4cf, false,false>
-{
- EIGEN_STRONG_INLINE Packet4cf pmadd(const Packet8f& x, const Packet4cf& y, const Packet4cf& c) const
- { return padd(c, pmul(x,y)); }
-
- EIGEN_STRONG_INLINE Packet4cf pmul(const Packet8f& x, const Packet4cf& y) const
- { return Packet4cf(Eigen::internal::pmul(x, y.v)); }
-};
-
-template<> struct conj_helper<Packet4cf, Packet8f, false,false>
-{
- EIGEN_STRONG_INLINE Packet4cf pmadd(const Packet4cf& x, const Packet8f& y, const Packet4cf& c) const
- { return padd(c, pmul(x,y)); }
-
- EIGEN_STRONG_INLINE Packet4cf pmul(const Packet4cf& x, const Packet8f& y) const
- { return Packet4cf(Eigen::internal::pmul(x.v, y)); }
-};
+EIGEN_MAKE_CONJ_HELPER_CPLX_REAL(Packet4cf,Packet8f)
template<> EIGEN_STRONG_INLINE Packet4cf pdiv<Packet4cf>(const Packet4cf& a, const Packet4cf& b)
{
@@ -244,6 +191,7 @@ struct Packet2cd
__m256d v;
};
+#ifndef EIGEN_VECTORIZE_AVX512
template<> struct packet_traits<std::complex<double> > : default_packet_traits
{
typedef Packet2cd type;
@@ -259,6 +207,7 @@ template<> struct packet_traits<std::complex<double> > : default_packet_traits
HasMul = 1,
HasDiv = 1,
HasNegate = 1,
+ HasSqrt = 1,
HasAbs = 0,
HasAbs2 = 0,
HasMin = 0,
@@ -266,8 +215,20 @@ template<> struct packet_traits<std::complex<double> > : default_packet_traits
HasSetLinear = 0
};
};
+#endif
-template<> struct unpacket_traits<Packet2cd> { typedef std::complex<double> type; enum {size=2, alignment=Aligned32}; typedef Packet1cd half; };
+template<> struct unpacket_traits<Packet2cd> {
+ typedef std::complex<double> type;
+ typedef Packet1cd half;
+ typedef Packet4d as_real;
+ enum {
+ size=2,
+ alignment=Aligned32,
+ vectorizable=true,
+ masked_load_available=false,
+ masked_store_available=false
+ };
+};
template<> EIGEN_STRONG_INLINE Packet2cd padd<Packet2cd>(const Packet2cd& a, const Packet2cd& b) { return Packet2cd(_mm256_add_pd(a.v,b.v)); }
template<> EIGEN_STRONG_INLINE Packet2cd psub<Packet2cd>(const Packet2cd& a, const Packet2cd& b) { return Packet2cd(_mm256_sub_pd(a.v,b.v)); }
@@ -288,10 +249,17 @@ template<> EIGEN_STRONG_INLINE Packet2cd pmul<Packet2cd>(const Packet2cd& a, con
return Packet2cd(_mm256_addsub_pd(even, odd));
}
+template <>
+EIGEN_STRONG_INLINE Packet2cd pcmp_eq(const Packet2cd& a, const Packet2cd& b) {
+ __m256d eq = _mm256_cmp_pd(a.v, b.v, _CMP_EQ_OQ);
+ return Packet2cd(pand(eq, _mm256_permute_pd(eq, 0x5)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet2cd ptrue<Packet2cd>(const Packet2cd& a) { return Packet2cd(ptrue(Packet4d(a.v))); }
template<> EIGEN_STRONG_INLINE Packet2cd pand <Packet2cd>(const Packet2cd& a, const Packet2cd& b) { return Packet2cd(_mm256_and_pd(a.v,b.v)); }
template<> EIGEN_STRONG_INLINE Packet2cd por <Packet2cd>(const Packet2cd& a, const Packet2cd& b) { return Packet2cd(_mm256_or_pd(a.v,b.v)); }
template<> EIGEN_STRONG_INLINE Packet2cd pxor <Packet2cd>(const Packet2cd& a, const Packet2cd& b) { return Packet2cd(_mm256_xor_pd(a.v,b.v)); }
-template<> EIGEN_STRONG_INLINE Packet2cd pandnot<Packet2cd>(const Packet2cd& a, const Packet2cd& b) { return Packet2cd(_mm256_andnot_pd(a.v,b.v)); }
+template<> EIGEN_STRONG_INLINE Packet2cd pandnot<Packet2cd>(const Packet2cd& a, const Packet2cd& b) { return Packet2cd(_mm256_andnot_pd(b.v,a.v)); }
template<> EIGEN_STRONG_INLINE Packet2cd pload <Packet2cd>(const std::complex<double>* from)
{ EIGEN_DEBUG_ALIGNED_LOAD return Packet2cd(pload<Packet4d>((const double*)from)); }
@@ -343,80 +311,13 @@ template<> EIGEN_STRONG_INLINE std::complex<double> predux<Packet2cd>(const Pack
Packet1cd(_mm256_extractf128_pd(a.v,1))));
}
-template<> EIGEN_STRONG_INLINE Packet2cd preduxp<Packet2cd>(const Packet2cd* vecs)
-{
- Packet4d t0 = _mm256_permute2f128_pd(vecs[0].v,vecs[1].v, 0 + (2<<4));
- Packet4d t1 = _mm256_permute2f128_pd(vecs[0].v,vecs[1].v, 1 + (3<<4));
-
- return Packet2cd(_mm256_add_pd(t0,t1));
-}
-
template<> EIGEN_STRONG_INLINE std::complex<double> predux_mul<Packet2cd>(const Packet2cd& a)
{
return predux(pmul(Packet1cd(_mm256_extractf128_pd(a.v,0)),
Packet1cd(_mm256_extractf128_pd(a.v,1))));
}
-template<int Offset>
-struct palign_impl<Offset,Packet2cd>
-{
- static EIGEN_STRONG_INLINE void run(Packet2cd& first, const Packet2cd& second)
- {
- if (Offset==0) return;
- palign_impl<Offset*2,Packet4d>::run(first.v, second.v);
- }
-};
-
-template<> struct conj_helper<Packet2cd, Packet2cd, false,true>
-{
- EIGEN_STRONG_INLINE Packet2cd pmadd(const Packet2cd& x, const Packet2cd& y, const Packet2cd& c) const
- { return padd(pmul(x,y),c); }
-
- EIGEN_STRONG_INLINE Packet2cd pmul(const Packet2cd& a, const Packet2cd& b) const
- {
- return internal::pmul(a, pconj(b));
- }
-};
-
-template<> struct conj_helper<Packet2cd, Packet2cd, true,false>
-{
- EIGEN_STRONG_INLINE Packet2cd pmadd(const Packet2cd& x, const Packet2cd& y, const Packet2cd& c) const
- { return padd(pmul(x,y),c); }
-
- EIGEN_STRONG_INLINE Packet2cd pmul(const Packet2cd& a, const Packet2cd& b) const
- {
- return internal::pmul(pconj(a), b);
- }
-};
-
-template<> struct conj_helper<Packet2cd, Packet2cd, true,true>
-{
- EIGEN_STRONG_INLINE Packet2cd pmadd(const Packet2cd& x, const Packet2cd& y, const Packet2cd& c) const
- { return padd(pmul(x,y),c); }
-
- EIGEN_STRONG_INLINE Packet2cd pmul(const Packet2cd& a, const Packet2cd& b) const
- {
- return pconj(internal::pmul(a, b));
- }
-};
-
-template<> struct conj_helper<Packet4d, Packet2cd, false,false>
-{
- EIGEN_STRONG_INLINE Packet2cd pmadd(const Packet4d& x, const Packet2cd& y, const Packet2cd& c) const
- { return padd(c, pmul(x,y)); }
-
- EIGEN_STRONG_INLINE Packet2cd pmul(const Packet4d& x, const Packet2cd& y) const
- { return Packet2cd(Eigen::internal::pmul(x, y.v)); }
-};
-
-template<> struct conj_helper<Packet2cd, Packet4d, false,false>
-{
- EIGEN_STRONG_INLINE Packet2cd pmadd(const Packet2cd& x, const Packet4d& y, const Packet2cd& c) const
- { return padd(c, pmul(x,y)); }
-
- EIGEN_STRONG_INLINE Packet2cd pmul(const Packet2cd& x, const Packet4d& y) const
- { return Packet2cd(Eigen::internal::pmul(x.v, y)); }
-};
+EIGEN_MAKE_CONJ_HELPER_CPLX_REAL(Packet2cd,Packet4d)
template<> EIGEN_STRONG_INLINE Packet2cd pdiv<Packet2cd>(const Packet2cd& a, const Packet2cd& b)
{
@@ -456,24 +357,12 @@ ptranspose(PacketBlock<Packet2cd,2>& kernel) {
kernel.packet[0].v = tmp;
}
-template<> EIGEN_STRONG_INLINE Packet4cf pinsertfirst(const Packet4cf& a, std::complex<float> b)
-{
- return Packet4cf(_mm256_blend_ps(a.v,pset1<Packet4cf>(b).v,1|2));
-}
-
-template<> EIGEN_STRONG_INLINE Packet2cd pinsertfirst(const Packet2cd& a, std::complex<double> b)
-{
- return Packet2cd(_mm256_blend_pd(a.v,pset1<Packet2cd>(b).v,1|2));
+template<> EIGEN_STRONG_INLINE Packet2cd psqrt<Packet2cd>(const Packet2cd& a) {
+ return psqrt_complex<Packet2cd>(a);
}
-template<> EIGEN_STRONG_INLINE Packet4cf pinsertlast(const Packet4cf& a, std::complex<float> b)
-{
- return Packet4cf(_mm256_blend_ps(a.v,pset1<Packet4cf>(b).v,(1<<7)|(1<<6)));
-}
-
-template<> EIGEN_STRONG_INLINE Packet2cd pinsertlast(const Packet2cd& a, std::complex<double> b)
-{
- return Packet2cd(_mm256_blend_pd(a.v,pset1<Packet2cd>(b).v,(1<<3)|(1<<2)));
+template<> EIGEN_STRONG_INLINE Packet4cf psqrt<Packet4cf>(const Packet4cf& a) {
+ return psqrt_complex<Packet4cf>(a);
}
} // end namespace internal
diff --git a/Eigen/src/Core/arch/AVX/MathFunctions.h b/Eigen/src/Core/arch/AVX/MathFunctions.h
index 6af67ce2d..67041c812 100644
--- a/Eigen/src/Core/arch/AVX/MathFunctions.h
+++ b/Eigen/src/Core/arch/AVX/MathFunctions.h
@@ -10,7 +10,7 @@
#ifndef EIGEN_MATH_FUNCTIONS_AVX_H
#define EIGEN_MATH_FUNCTIONS_AVX_H
-/* The sin, cos, exp, and log functions of this file are loosely derived from
+/* The sin and cos functions of this file are loosely derived from
* Julien Pommier's sse math library: http://gruntthepeon.free.fr/ssemath/
*/
@@ -18,187 +18,50 @@ namespace Eigen {
namespace internal {
-inline Packet8i pshiftleft(Packet8i v, int n)
-{
-#ifdef EIGEN_VECTORIZE_AVX2
- return _mm256_slli_epi32(v, n);
-#else
- __m128i lo = _mm_slli_epi32(_mm256_extractf128_si256(v, 0), n);
- __m128i hi = _mm_slli_epi32(_mm256_extractf128_si256(v, 1), n);
- return _mm256_insertf128_si256(_mm256_castsi128_si256(lo), (hi), 1);
-#endif
+template <>
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet8f
+psin<Packet8f>(const Packet8f& _x) {
+ return psin_float(_x);
}
-inline Packet8f pshiftright(Packet8f v, int n)
-{
-#ifdef EIGEN_VECTORIZE_AVX2
- return _mm256_cvtepi32_ps(_mm256_srli_epi32(_mm256_castps_si256(v), n));
-#else
- __m128i lo = _mm_srli_epi32(_mm256_extractf128_si256(_mm256_castps_si256(v), 0), n);
- __m128i hi = _mm_srli_epi32(_mm256_extractf128_si256(_mm256_castps_si256(v), 1), n);
- return _mm256_cvtepi32_ps(_mm256_insertf128_si256(_mm256_castsi128_si256(lo), (hi), 1));
-#endif
+template <>
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet8f
+pcos<Packet8f>(const Packet8f& _x) {
+ return pcos_float(_x);
}
-// Sine function
-// Computes sin(x) by wrapping x to the interval [-Pi/4,3*Pi/4] and
-// evaluating interpolants in [-Pi/4,Pi/4] or [Pi/4,3*Pi/4]. The interpolants
-// are (anti-)symmetric and thus have only odd/even coefficients
template <>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet8f
-psin<Packet8f>(const Packet8f& _x) {
- Packet8f x = _x;
+plog<Packet8f>(const Packet8f& _x) {
+ return plog_float(_x);
+}
- // Some useful values.
- _EIGEN_DECLARE_CONST_Packet8i(one, 1);
- _EIGEN_DECLARE_CONST_Packet8f(one, 1.0f);
- _EIGEN_DECLARE_CONST_Packet8f(two, 2.0f);
- _EIGEN_DECLARE_CONST_Packet8f(one_over_four, 0.25f);
- _EIGEN_DECLARE_CONST_Packet8f(one_over_pi, 3.183098861837907e-01f);
- _EIGEN_DECLARE_CONST_Packet8f(neg_pi_first, -3.140625000000000e+00f);
- _EIGEN_DECLARE_CONST_Packet8f(neg_pi_second, -9.670257568359375e-04f);
- _EIGEN_DECLARE_CONST_Packet8f(neg_pi_third, -6.278329571784980e-07f);
- _EIGEN_DECLARE_CONST_Packet8f(four_over_pi, 1.273239544735163e+00f);
-
- // Map x from [-Pi/4,3*Pi/4] to z in [-1,3] and subtract the shifted period.
- Packet8f z = pmul(x, p8f_one_over_pi);
- Packet8f shift = _mm256_floor_ps(padd(z, p8f_one_over_four));
- x = pmadd(shift, p8f_neg_pi_first, x);
- x = pmadd(shift, p8f_neg_pi_second, x);
- x = pmadd(shift, p8f_neg_pi_third, x);
- z = pmul(x, p8f_four_over_pi);
-
- // Make a mask for the entries that need flipping, i.e. wherever the shift
- // is odd.
- Packet8i shift_ints = _mm256_cvtps_epi32(shift);
- Packet8i shift_isodd = _mm256_castps_si256(_mm256_and_ps(_mm256_castsi256_ps(shift_ints), _mm256_castsi256_ps(p8i_one)));
- Packet8i sign_flip_mask = pshiftleft(shift_isodd, 31);
-
- // Create a mask for which interpolant to use, i.e. if z > 1, then the mask
- // is set to ones for that entry.
- Packet8f ival_mask = _mm256_cmp_ps(z, p8f_one, _CMP_GT_OQ);
-
- // Evaluate the polynomial for the interval [1,3] in z.
- _EIGEN_DECLARE_CONST_Packet8f(coeff_right_0, 9.999999724233232e-01f);
- _EIGEN_DECLARE_CONST_Packet8f(coeff_right_2, -3.084242535619928e-01f);
- _EIGEN_DECLARE_CONST_Packet8f(coeff_right_4, 1.584991525700324e-02f);
- _EIGEN_DECLARE_CONST_Packet8f(coeff_right_6, -3.188805084631342e-04f);
- Packet8f z_minus_two = psub(z, p8f_two);
- Packet8f z_minus_two2 = pmul(z_minus_two, z_minus_two);
- Packet8f right = pmadd(p8f_coeff_right_6, z_minus_two2, p8f_coeff_right_4);
- right = pmadd(right, z_minus_two2, p8f_coeff_right_2);
- right = pmadd(right, z_minus_two2, p8f_coeff_right_0);
-
- // Evaluate the polynomial for the interval [-1,1] in z.
- _EIGEN_DECLARE_CONST_Packet8f(coeff_left_1, 7.853981525427295e-01f);
- _EIGEN_DECLARE_CONST_Packet8f(coeff_left_3, -8.074536727092352e-02f);
- _EIGEN_DECLARE_CONST_Packet8f(coeff_left_5, 2.489871967827018e-03f);
- _EIGEN_DECLARE_CONST_Packet8f(coeff_left_7, -3.587725841214251e-05f);
- Packet8f z2 = pmul(z, z);
- Packet8f left = pmadd(p8f_coeff_left_7, z2, p8f_coeff_left_5);
- left = pmadd(left, z2, p8f_coeff_left_3);
- left = pmadd(left, z2, p8f_coeff_left_1);
- left = pmul(left, z);
-
- // Assemble the results, i.e. select the left and right polynomials.
- left = _mm256_andnot_ps(ival_mask, left);
- right = _mm256_and_ps(ival_mask, right);
- Packet8f res = _mm256_or_ps(left, right);
-
- // Flip the sign on the odd intervals and return the result.
- res = _mm256_xor_ps(res, _mm256_castsi256_ps(sign_flip_mask));
- return res;
+template <>
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet4d
+plog<Packet4d>(const Packet4d& _x) {
+ return plog_double(_x);
}
-// Natural logarithm
-// Computes log(x) as log(2^e * m) = C*e + log(m), where the constant C =log(2)
-// and m is in the range [sqrt(1/2),sqrt(2)). In this range, the logarithm can
-// be easily approximated by a polynomial centered on m=1 for stability.
-// TODO(gonnet): Further reduce the interval allowing for lower-degree
-// polynomial interpolants -> ... -> profit!
template <>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet8f
-plog<Packet8f>(const Packet8f& _x) {
- Packet8f x = _x;
- _EIGEN_DECLARE_CONST_Packet8f(1, 1.0f);
- _EIGEN_DECLARE_CONST_Packet8f(half, 0.5f);
- _EIGEN_DECLARE_CONST_Packet8f(126f, 126.0f);
-
- _EIGEN_DECLARE_CONST_Packet8f_FROM_INT(inv_mant_mask, ~0x7f800000);
-
- // The smallest non denormalized float number.
- _EIGEN_DECLARE_CONST_Packet8f_FROM_INT(min_norm_pos, 0x00800000);
- _EIGEN_DECLARE_CONST_Packet8f_FROM_INT(minus_inf, 0xff800000);
-
- // Polynomial coefficients.
- _EIGEN_DECLARE_CONST_Packet8f(cephes_SQRTHF, 0.707106781186547524f);
- _EIGEN_DECLARE_CONST_Packet8f(cephes_log_p0, 7.0376836292E-2f);
- _EIGEN_DECLARE_CONST_Packet8f(cephes_log_p1, -1.1514610310E-1f);
- _EIGEN_DECLARE_CONST_Packet8f(cephes_log_p2, 1.1676998740E-1f);
- _EIGEN_DECLARE_CONST_Packet8f(cephes_log_p3, -1.2420140846E-1f);
- _EIGEN_DECLARE_CONST_Packet8f(cephes_log_p4, +1.4249322787E-1f);
- _EIGEN_DECLARE_CONST_Packet8f(cephes_log_p5, -1.6668057665E-1f);
- _EIGEN_DECLARE_CONST_Packet8f(cephes_log_p6, +2.0000714765E-1f);
- _EIGEN_DECLARE_CONST_Packet8f(cephes_log_p7, -2.4999993993E-1f);
- _EIGEN_DECLARE_CONST_Packet8f(cephes_log_p8, +3.3333331174E-1f);
- _EIGEN_DECLARE_CONST_Packet8f(cephes_log_q1, -2.12194440e-4f);
- _EIGEN_DECLARE_CONST_Packet8f(cephes_log_q2, 0.693359375f);
-
- Packet8f invalid_mask = _mm256_cmp_ps(x, _mm256_setzero_ps(), _CMP_NGE_UQ); // not greater equal is true if x is NaN
- Packet8f iszero_mask = _mm256_cmp_ps(x, _mm256_setzero_ps(), _CMP_EQ_OQ);
-
- // Truncate input values to the minimum positive normal.
- x = pmax(x, p8f_min_norm_pos);
-
- Packet8f emm0 = pshiftright(x,23);
- Packet8f e = _mm256_sub_ps(emm0, p8f_126f);
-
- // Set the exponents to -1, i.e. x are in the range [0.5,1).
- x = _mm256_and_ps(x, p8f_inv_mant_mask);
- x = _mm256_or_ps(x, p8f_half);
-
- // part2: Shift the inputs from the range [0.5,1) to [sqrt(1/2),sqrt(2))
- // and shift by -1. The values are then centered around 0, which improves
- // the stability of the polynomial evaluation.
- // if( x < SQRTHF ) {
- // e -= 1;
- // x = x + x - 1.0;
- // } else { x = x - 1.0; }
- Packet8f mask = _mm256_cmp_ps(x, p8f_cephes_SQRTHF, _CMP_LT_OQ);
- Packet8f tmp = _mm256_and_ps(x, mask);
- x = psub(x, p8f_1);
- e = psub(e, _mm256_and_ps(p8f_1, mask));
- x = padd(x, tmp);
-
- Packet8f x2 = pmul(x, x);
- Packet8f x3 = pmul(x2, x);
-
- // Evaluate the polynomial approximant of degree 8 in three parts, probably
- // to improve instruction-level parallelism.
- Packet8f y, y1, y2;
- y = pmadd(p8f_cephes_log_p0, x, p8f_cephes_log_p1);
- y1 = pmadd(p8f_cephes_log_p3, x, p8f_cephes_log_p4);
- y2 = pmadd(p8f_cephes_log_p6, x, p8f_cephes_log_p7);
- y = pmadd(y, x, p8f_cephes_log_p2);
- y1 = pmadd(y1, x, p8f_cephes_log_p5);
- y2 = pmadd(y2, x, p8f_cephes_log_p8);
- y = pmadd(y, x3, y1);
- y = pmadd(y, x3, y2);
- y = pmul(y, x3);
-
- // Add the logarithm of the exponent back to the result of the interpolation.
- y1 = pmul(e, p8f_cephes_log_q1);
- tmp = pmul(x2, p8f_half);
- y = padd(y, y1);
- x = psub(x, tmp);
- y2 = pmul(e, p8f_cephes_log_q2);
- x = padd(x, y);
- x = padd(x, y2);
-
- // Filter out invalid inputs, i.e. negative arg will be NAN, 0 will be -INF.
- return _mm256_or_ps(
- _mm256_andnot_ps(iszero_mask, _mm256_or_ps(x, invalid_mask)),
- _mm256_and_ps(iszero_mask, p8f_minus_inf));
+plog2<Packet8f>(const Packet8f& _x) {
+ return plog2_float(_x);
+}
+
+template <>
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet4d
+plog2<Packet4d>(const Packet4d& _x) {
+ return plog2_double(_x);
+}
+
+template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
+Packet8f plog1p<Packet8f>(const Packet8f& _x) {
+ return generic_plog1p(_x);
+}
+
+template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
+Packet8f pexpm1<Packet8f>(const Packet8f& _x) {
+ return generic_expm1(_x);
}
// Exponential function. Works by writing "x = m*log(2) + r" where
@@ -207,149 +70,21 @@ plog<Packet8f>(const Packet8f& _x) {
template <>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet8f
pexp<Packet8f>(const Packet8f& _x) {
- _EIGEN_DECLARE_CONST_Packet8f(1, 1.0f);
- _EIGEN_DECLARE_CONST_Packet8f(half, 0.5f);
- _EIGEN_DECLARE_CONST_Packet8f(127, 127.0f);
-
- _EIGEN_DECLARE_CONST_Packet8f(exp_hi, 88.3762626647950f);
- _EIGEN_DECLARE_CONST_Packet8f(exp_lo, -88.3762626647949f);
-
- _EIGEN_DECLARE_CONST_Packet8f(cephes_LOG2EF, 1.44269504088896341f);
-
- _EIGEN_DECLARE_CONST_Packet8f(cephes_exp_p0, 1.9875691500E-4f);
- _EIGEN_DECLARE_CONST_Packet8f(cephes_exp_p1, 1.3981999507E-3f);
- _EIGEN_DECLARE_CONST_Packet8f(cephes_exp_p2, 8.3334519073E-3f);
- _EIGEN_DECLARE_CONST_Packet8f(cephes_exp_p3, 4.1665795894E-2f);
- _EIGEN_DECLARE_CONST_Packet8f(cephes_exp_p4, 1.6666665459E-1f);
- _EIGEN_DECLARE_CONST_Packet8f(cephes_exp_p5, 5.0000001201E-1f);
-
- // Clamp x.
- Packet8f x = pmax(pmin(_x, p8f_exp_hi), p8f_exp_lo);
-
- // Express exp(x) as exp(m*ln(2) + r), start by extracting
- // m = floor(x/ln(2) + 0.5).
- Packet8f m = _mm256_floor_ps(pmadd(x, p8f_cephes_LOG2EF, p8f_half));
-
-// Get r = x - m*ln(2). If no FMA instructions are available, m*ln(2) is
-// subtracted out in two parts, m*C1+m*C2 = m*ln(2), to avoid accumulating
-// truncation errors. Note that we don't use the "pmadd" function here to
-// ensure that a precision-preserving FMA instruction is used.
-#ifdef EIGEN_VECTORIZE_FMA
- _EIGEN_DECLARE_CONST_Packet8f(nln2, -0.6931471805599453f);
- Packet8f r = _mm256_fmadd_ps(m, p8f_nln2, x);
-#else
- _EIGEN_DECLARE_CONST_Packet8f(cephes_exp_C1, 0.693359375f);
- _EIGEN_DECLARE_CONST_Packet8f(cephes_exp_C2, -2.12194440e-4f);
- Packet8f r = psub(x, pmul(m, p8f_cephes_exp_C1));
- r = psub(r, pmul(m, p8f_cephes_exp_C2));
-#endif
-
- Packet8f r2 = pmul(r, r);
-
- // TODO(gonnet): Split into odd/even polynomials and try to exploit
- // instruction-level parallelism.
- Packet8f y = p8f_cephes_exp_p0;
- y = pmadd(y, r, p8f_cephes_exp_p1);
- y = pmadd(y, r, p8f_cephes_exp_p2);
- y = pmadd(y, r, p8f_cephes_exp_p3);
- y = pmadd(y, r, p8f_cephes_exp_p4);
- y = pmadd(y, r, p8f_cephes_exp_p5);
- y = pmadd(y, r2, r);
- y = padd(y, p8f_1);
-
- // Build emm0 = 2^m.
- Packet8i emm0 = _mm256_cvttps_epi32(padd(m, p8f_127));
- emm0 = pshiftleft(emm0, 23);
-
- // Return 2^m * exp(r).
- return pmax(pmul(y, _mm256_castsi256_ps(emm0)), _x);
+ return pexp_float(_x);
}
// Hyperbolic Tangent function.
template <>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet8f
-ptanh<Packet8f>(const Packet8f& x) {
- return internal::generic_fast_tanh_float(x);
+ptanh<Packet8f>(const Packet8f& _x) {
+ return internal::generic_fast_tanh_float(_x);
}
+// Exponential function for doubles.
template <>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet4d
pexp<Packet4d>(const Packet4d& _x) {
- Packet4d x = _x;
-
- _EIGEN_DECLARE_CONST_Packet4d(1, 1.0);
- _EIGEN_DECLARE_CONST_Packet4d(2, 2.0);
- _EIGEN_DECLARE_CONST_Packet4d(half, 0.5);
-
- _EIGEN_DECLARE_CONST_Packet4d(exp_hi, 709.437);
- _EIGEN_DECLARE_CONST_Packet4d(exp_lo, -709.436139303);
-
- _EIGEN_DECLARE_CONST_Packet4d(cephes_LOG2EF, 1.4426950408889634073599);
-
- _EIGEN_DECLARE_CONST_Packet4d(cephes_exp_p0, 1.26177193074810590878e-4);
- _EIGEN_DECLARE_CONST_Packet4d(cephes_exp_p1, 3.02994407707441961300e-2);
- _EIGEN_DECLARE_CONST_Packet4d(cephes_exp_p2, 9.99999999999999999910e-1);
-
- _EIGEN_DECLARE_CONST_Packet4d(cephes_exp_q0, 3.00198505138664455042e-6);
- _EIGEN_DECLARE_CONST_Packet4d(cephes_exp_q1, 2.52448340349684104192e-3);
- _EIGEN_DECLARE_CONST_Packet4d(cephes_exp_q2, 2.27265548208155028766e-1);
- _EIGEN_DECLARE_CONST_Packet4d(cephes_exp_q3, 2.00000000000000000009e0);
-
- _EIGEN_DECLARE_CONST_Packet4d(cephes_exp_C1, 0.693145751953125);
- _EIGEN_DECLARE_CONST_Packet4d(cephes_exp_C2, 1.42860682030941723212e-6);
- _EIGEN_DECLARE_CONST_Packet4i(1023, 1023);
-
- Packet4d tmp, fx;
-
- // clamp x
- x = pmax(pmin(x, p4d_exp_hi), p4d_exp_lo);
- // Express exp(x) as exp(g + n*log(2)).
- fx = pmadd(p4d_cephes_LOG2EF, x, p4d_half);
-
- // Get the integer modulus of log(2), i.e. the "n" described above.
- fx = _mm256_floor_pd(fx);
-
- // Get the remainder modulo log(2), i.e. the "g" described above. Subtract
- // n*log(2) out in two steps, i.e. n*C1 + n*C2, C1+C2=log2 to get the last
- // digits right.
- tmp = pmul(fx, p4d_cephes_exp_C1);
- Packet4d z = pmul(fx, p4d_cephes_exp_C2);
- x = psub(x, tmp);
- x = psub(x, z);
-
- Packet4d x2 = pmul(x, x);
-
- // Evaluate the numerator polynomial of the rational interpolant.
- Packet4d px = p4d_cephes_exp_p0;
- px = pmadd(px, x2, p4d_cephes_exp_p1);
- px = pmadd(px, x2, p4d_cephes_exp_p2);
- px = pmul(px, x);
-
- // Evaluate the denominator polynomial of the rational interpolant.
- Packet4d qx = p4d_cephes_exp_q0;
- qx = pmadd(qx, x2, p4d_cephes_exp_q1);
- qx = pmadd(qx, x2, p4d_cephes_exp_q2);
- qx = pmadd(qx, x2, p4d_cephes_exp_q3);
-
- // I don't really get this bit, copied from the SSE2 routines, so...
- // TODO(gonnet): Figure out what is going on here, perhaps find a better
- // rational interpolant?
- x = _mm256_div_pd(px, psub(qx, px));
- x = pmadd(p4d_2, x, p4d_1);
-
- // Build e=2^n by constructing the exponents in a 128-bit vector and
- // shifting them to where they belong in double-precision values.
- __m128i emm0 = _mm256_cvtpd_epi32(fx);
- emm0 = _mm_add_epi32(emm0, p4i_1023);
- emm0 = _mm_shuffle_epi32(emm0, _MM_SHUFFLE(3, 1, 2, 0));
- __m128i lo = _mm_slli_epi64(emm0, 52);
- __m128i hi = _mm_slli_epi64(_mm_srli_epi64(emm0, 32), 52);
- __m256i e = _mm256_insertf128_si256(_mm256_setzero_si256(), lo, 0);
- e = _mm256_insertf128_si256(e, hi, 1);
-
- // Construct the result 2^n * exp(g) = e * x. The max is used to catch
- // non-finite values in the input.
- return pmax(pmul(x, _mm256_castsi256_pd(e)), _x);
+ return pexp_double(_x);
}
// Functions for sqrt.
@@ -362,37 +97,39 @@ pexp<Packet4d>(const Packet4d& _x) {
// For detail see here: http://www.beyond3d.com/content/articles/8/
#if EIGEN_FAST_MATH
template <>
-EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet8f
-psqrt<Packet8f>(const Packet8f& _x) {
- Packet8f half = pmul(_x, pset1<Packet8f>(.5f));
- Packet8f denormal_mask = _mm256_and_ps(
- _mm256_cmp_ps(_x, pset1<Packet8f>((std::numeric_limits<float>::min)()),
- _CMP_LT_OQ),
- _mm256_cmp_ps(_x, _mm256_setzero_ps(), _CMP_GE_OQ));
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
+Packet8f psqrt<Packet8f>(const Packet8f& _x) {
+ Packet8f minus_half_x = pmul(_x, pset1<Packet8f>(-0.5f));
+ Packet8f denormal_mask = pandnot(
+ pcmp_lt(_x, pset1<Packet8f>((std::numeric_limits<float>::min)())),
+ pcmp_lt(_x, pzero(_x)));
// Compute approximate reciprocal sqrt.
Packet8f x = _mm256_rsqrt_ps(_x);
// Do a single step of Newton's iteration.
- x = pmul(x, psub(pset1<Packet8f>(1.5f), pmul(half, pmul(x,x))));
+ x = pmul(x, pmadd(minus_half_x, pmul(x,x), pset1<Packet8f>(1.5f)));
// Flush results for denormals to zero.
- return _mm256_andnot_ps(denormal_mask, pmul(_x,x));
+ return pandnot(pmul(_x,x), denormal_mask);
}
+
#else
+
template <> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
-Packet8f psqrt<Packet8f>(const Packet8f& x) {
- return _mm256_sqrt_ps(x);
+Packet8f psqrt<Packet8f>(const Packet8f& _x) {
+ return _mm256_sqrt_ps(_x);
}
+
#endif
+
template <> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
-Packet4d psqrt<Packet4d>(const Packet4d& x) {
- return _mm256_sqrt_pd(x);
+Packet4d psqrt<Packet4d>(const Packet4d& _x) {
+ return _mm256_sqrt_pd(_x);
}
-#if EIGEN_FAST_MATH
+#if EIGEN_FAST_MATH
template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
Packet8f prsqrt<Packet8f>(const Packet8f& _x) {
_EIGEN_DECLARE_CONST_Packet8f_FROM_INT(inf, 0x7f800000);
- _EIGEN_DECLARE_CONST_Packet8f_FROM_INT(nan, 0x7fc00000);
_EIGEN_DECLARE_CONST_Packet8f(one_point_five, 1.5f);
_EIGEN_DECLARE_CONST_Packet8f(minus_half, -0.5f);
_EIGEN_DECLARE_CONST_Packet8f_FROM_INT(flt_min, 0x00800000);
@@ -401,36 +138,88 @@ Packet8f prsqrt<Packet8f>(const Packet8f& _x) {
// select only the inverse sqrt of positive normal inputs (denormals are
// flushed to zero and cause infs as well).
- Packet8f le_zero_mask = _mm256_cmp_ps(_x, p8f_flt_min, _CMP_LT_OQ);
- Packet8f x = _mm256_andnot_ps(le_zero_mask, _mm256_rsqrt_ps(_x));
-
- // Fill in NaNs and Infs for the negative/zero entries.
- Packet8f neg_mask = _mm256_cmp_ps(_x, _mm256_setzero_ps(), _CMP_LT_OQ);
- Packet8f zero_mask = _mm256_andnot_ps(neg_mask, le_zero_mask);
- Packet8f infs_and_nans = _mm256_or_ps(_mm256_and_ps(neg_mask, p8f_nan),
- _mm256_and_ps(zero_mask, p8f_inf));
-
- // Do a single step of Newton's iteration.
- x = pmul(x, pmadd(neg_half, pmul(x, x), p8f_one_point_five));
-
- // Insert NaNs and Infs in all the right places.
- return _mm256_or_ps(x, infs_and_nans);
+ Packet8f lt_min_mask = _mm256_cmp_ps(_x, p8f_flt_min, _CMP_LT_OQ);
+ Packet8f inf_mask = _mm256_cmp_ps(_x, p8f_inf, _CMP_EQ_OQ);
+ Packet8f not_normal_finite_mask = _mm256_or_ps(lt_min_mask, inf_mask);
+
+ // Compute an approximate result using the rsqrt intrinsic.
+ Packet8f y_approx = _mm256_rsqrt_ps(_x);
+
+ // Do a single step of Newton-Raphson iteration to improve the approximation.
+ // This uses the formula y_{n+1} = y_n * (1.5 - y_n * (0.5 * x) * y_n).
+ // It is essential to evaluate the inner term like this because forming
+ // y_n^2 may over- or underflow.
+ Packet8f y_newton = pmul(y_approx, pmadd(y_approx, pmul(neg_half, y_approx), p8f_one_point_five));
+
+ // Select the result of the Newton-Raphson step for positive normal arguments.
+ // For other arguments, choose the output of the intrinsic. This will
+ // return rsqrt(+inf) = 0, rsqrt(x) = NaN if x < 0, and rsqrt(x) = +inf if
+ // x is zero or a positive denormalized float (equivalent to flushing positive
+ // denormalized inputs to zero).
+ return pselect<Packet8f>(not_normal_finite_mask, y_approx, y_newton);
}
#else
template <> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
-Packet8f prsqrt<Packet8f>(const Packet8f& x) {
+Packet8f prsqrt<Packet8f>(const Packet8f& _x) {
_EIGEN_DECLARE_CONST_Packet8f(one, 1.0f);
- return _mm256_div_ps(p8f_one, _mm256_sqrt_ps(x));
+ return _mm256_div_ps(p8f_one, _mm256_sqrt_ps(_x));
}
#endif
template <> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
-Packet4d prsqrt<Packet4d>(const Packet4d& x) {
+Packet4d prsqrt<Packet4d>(const Packet4d& _x) {
_EIGEN_DECLARE_CONST_Packet4d(one, 1.0);
- return _mm256_div_pd(p4d_one, _mm256_sqrt_pd(x));
+ return _mm256_div_pd(p4d_one, _mm256_sqrt_pd(_x));
}
+F16_PACKET_FUNCTION(Packet8f, Packet8h, psin)
+F16_PACKET_FUNCTION(Packet8f, Packet8h, pcos)
+F16_PACKET_FUNCTION(Packet8f, Packet8h, plog)
+F16_PACKET_FUNCTION(Packet8f, Packet8h, plog2)
+F16_PACKET_FUNCTION(Packet8f, Packet8h, plog1p)
+F16_PACKET_FUNCTION(Packet8f, Packet8h, pexpm1)
+F16_PACKET_FUNCTION(Packet8f, Packet8h, pexp)
+F16_PACKET_FUNCTION(Packet8f, Packet8h, ptanh)
+F16_PACKET_FUNCTION(Packet8f, Packet8h, psqrt)
+F16_PACKET_FUNCTION(Packet8f, Packet8h, prsqrt)
+
+template <>
+EIGEN_STRONG_INLINE Packet8h pfrexp(const Packet8h& a, Packet8h& exponent) {
+ Packet8f fexponent;
+ const Packet8h out = float2half(pfrexp<Packet8f>(half2float(a), fexponent));
+ exponent = float2half(fexponent);
+ return out;
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8h pldexp(const Packet8h& a, const Packet8h& exponent) {
+ return float2half(pldexp<Packet8f>(half2float(a), half2float(exponent)));
+}
+
+BF16_PACKET_FUNCTION(Packet8f, Packet8bf, psin)
+BF16_PACKET_FUNCTION(Packet8f, Packet8bf, pcos)
+BF16_PACKET_FUNCTION(Packet8f, Packet8bf, plog)
+BF16_PACKET_FUNCTION(Packet8f, Packet8bf, plog2)
+BF16_PACKET_FUNCTION(Packet8f, Packet8bf, plog1p)
+BF16_PACKET_FUNCTION(Packet8f, Packet8bf, pexpm1)
+BF16_PACKET_FUNCTION(Packet8f, Packet8bf, pexp)
+BF16_PACKET_FUNCTION(Packet8f, Packet8bf, ptanh)
+BF16_PACKET_FUNCTION(Packet8f, Packet8bf, psqrt)
+BF16_PACKET_FUNCTION(Packet8f, Packet8bf, prsqrt)
+
+template <>
+EIGEN_STRONG_INLINE Packet8bf pfrexp(const Packet8bf& a, Packet8bf& exponent) {
+ Packet8f fexponent;
+ const Packet8bf out = F32ToBf16(pfrexp<Packet8f>(Bf16ToF32(a), fexponent));
+ exponent = F32ToBf16(fexponent);
+ return out;
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8bf pldexp(const Packet8bf& a, const Packet8bf& exponent) {
+ return F32ToBf16(pldexp<Packet8f>(Bf16ToF32(a), Bf16ToF32(exponent)));
+}
} // end namespace internal
diff --git a/Eigen/src/Core/arch/AVX/PacketMath.h b/Eigen/src/Core/arch/AVX/PacketMath.h
index 195d40fb4..7fc32fd71 100644
--- a/Eigen/src/Core/arch/AVX/PacketMath.h
+++ b/Eigen/src/Core/arch/AVX/PacketMath.h
@@ -18,11 +18,11 @@ namespace internal {
#define EIGEN_CACHEFRIENDLY_PRODUCT_THRESHOLD 8
#endif
-#ifndef EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS
-#define EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS (2*sizeof(void*))
+#if !defined(EIGEN_VECTORIZE_AVX512) && !defined(EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS)
+#define EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS 16
#endif
-#ifdef __FMA__
+#ifdef EIGEN_VECTORIZE_FMA
#ifndef EIGEN_HAS_SINGLE_INSTRUCTION_MADD
#define EIGEN_HAS_SINGLE_INSTRUCTION_MADD
#endif
@@ -31,10 +31,14 @@ namespace internal {
typedef __m256 Packet8f;
typedef __m256i Packet8i;
typedef __m256d Packet4d;
+typedef eigen_packet_wrapper<__m128i, 2> Packet8h;
+typedef eigen_packet_wrapper<__m128i, 3> Packet8bf;
template<> struct is_arithmetic<__m256> { enum { value = true }; };
template<> struct is_arithmetic<__m256i> { enum { value = true }; };
template<> struct is_arithmetic<__m256d> { enum { value = true }; };
+template<> struct is_arithmetic<Packet8h> { enum { value = true }; };
+template<> struct is_arithmetic<Packet8bf> { enum { value = true }; };
#define _EIGEN_DECLARE_CONST_Packet8f(NAME,X) \
const Packet8f p8f_##NAME = pset1<Packet8f>(X)
@@ -58,21 +62,28 @@ template<> struct packet_traits<float> : default_packet_traits
enum {
Vectorizable = 1,
AlignedOnScalar = 1,
- size=8,
+ size = 8,
HasHalfPacket = 1,
- HasDiv = 1,
- HasSin = EIGEN_FAST_MATH,
- HasCos = 0,
- HasLog = 1,
- HasExp = 1,
+ HasCmp = 1,
+ HasDiv = 1,
+ HasSin = EIGEN_FAST_MATH,
+ HasCos = EIGEN_FAST_MATH,
+ HasLog = 1,
+ HasLog1p = 1,
+ HasExpm1 = 1,
+ HasExp = 1,
+ HasNdtri = 1,
+ HasBessel = 1,
HasSqrt = 1,
HasRsqrt = 1,
- HasTanh = EIGEN_FAST_MATH,
+ HasTanh = EIGEN_FAST_MATH,
+ HasErf = EIGEN_FAST_MATH,
HasBlend = 1,
HasRound = 1,
HasFloor = 1,
- HasCeil = 1
+ HasCeil = 1,
+ HasRint = 1
};
};
template<> struct packet_traits<double> : default_packet_traits
@@ -85,14 +96,104 @@ template<> struct packet_traits<double> : default_packet_traits
size=4,
HasHalfPacket = 1,
+ HasCmp = 1,
HasDiv = 1,
+ HasLog = 1,
HasExp = 1,
HasSqrt = 1,
HasRsqrt = 1,
HasBlend = 1,
HasRound = 1,
HasFloor = 1,
- HasCeil = 1
+ HasCeil = 1,
+ HasRint = 1
+ };
+};
+
+template <>
+struct packet_traits<Eigen::half> : default_packet_traits {
+ typedef Packet8h type;
+ // There is no half-size packet for Packet8h.
+ typedef Packet8h half;
+ enum {
+ Vectorizable = 1,
+ AlignedOnScalar = 1,
+ size = 8,
+ HasHalfPacket = 0,
+
+ HasCmp = 1,
+ HasAdd = 1,
+ HasSub = 1,
+ HasMul = 1,
+ HasDiv = 1,
+ HasSin = EIGEN_FAST_MATH,
+ HasCos = EIGEN_FAST_MATH,
+ HasNegate = 1,
+ HasAbs = 1,
+ HasAbs2 = 0,
+ HasMin = 1,
+ HasMax = 1,
+ HasConj = 1,
+ HasSetLinear = 0,
+ HasLog = 1,
+ HasLog1p = 1,
+ HasExpm1 = 1,
+ HasExp = 1,
+ HasSqrt = 1,
+ HasRsqrt = 1,
+ HasTanh = EIGEN_FAST_MATH,
+ HasErf = EIGEN_FAST_MATH,
+ HasBlend = 0,
+ HasRound = 1,
+ HasFloor = 1,
+ HasCeil = 1,
+ HasRint = 1,
+ HasBessel = 1,
+ HasNdtri = 1
+ };
+};
+
+template <>
+struct packet_traits<bfloat16> : default_packet_traits {
+ typedef Packet8bf type;
+ // There is no half-size packet for current Packet8bf.
+ // TODO: support as SSE path.
+ typedef Packet8bf half;
+ enum {
+ Vectorizable = 1,
+ AlignedOnScalar = 1,
+ size = 8,
+ HasHalfPacket = 0,
+
+ HasCmp = 1,
+ HasAdd = 1,
+ HasSub = 1,
+ HasMul = 1,
+ HasDiv = 1,
+ HasSin = EIGEN_FAST_MATH,
+ HasCos = EIGEN_FAST_MATH,
+ HasNegate = 1,
+ HasAbs = 1,
+ HasAbs2 = 0,
+ HasMin = 1,
+ HasMax = 1,
+ HasConj = 1,
+ HasSetLinear = 0,
+ HasLog = 1,
+ HasLog1p = 1,
+ HasExpm1 = 1,
+ HasExp = 1,
+ HasSqrt = 1,
+ HasRsqrt = 1,
+ HasTanh = EIGEN_FAST_MATH,
+ HasErf = EIGEN_FAST_MATH,
+ HasBlend = 0,
+ HasRound = 1,
+ HasFloor = 1,
+ HasCeil = 1,
+ HasRint = 1,
+ HasBessel = 1,
+ HasNdtri = 1
};
};
#endif
@@ -113,14 +214,45 @@ template<> struct packet_traits<int> : default_packet_traits
};
*/
-template<> struct unpacket_traits<Packet8f> { typedef float type; typedef Packet4f half; enum {size=8, alignment=Aligned32}; };
-template<> struct unpacket_traits<Packet4d> { typedef double type; typedef Packet2d half; enum {size=4, alignment=Aligned32}; };
-template<> struct unpacket_traits<Packet8i> { typedef int type; typedef Packet4i half; enum {size=8, alignment=Aligned32}; };
+template<> struct unpacket_traits<Packet8f> {
+ typedef float type;
+ typedef Packet4f half;
+ typedef Packet8i integer_packet;
+ typedef uint8_t mask_t;
+ enum {size=8, alignment=Aligned32, vectorizable=true, masked_load_available=true, masked_store_available=true};
+};
+template<> struct unpacket_traits<Packet4d> {
+ typedef double type;
+ typedef Packet2d half;
+ enum {size=4, alignment=Aligned32, vectorizable=true, masked_load_available=false, masked_store_available=false};
+};
+template<> struct unpacket_traits<Packet8i> { typedef int type; typedef Packet4i half; enum {size=8, alignment=Aligned32, vectorizable=false, masked_load_available=false, masked_store_available=false}; };
+template<> struct unpacket_traits<Packet8bf> { typedef bfloat16 type; typedef Packet8bf half; enum {size=8, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false}; };
+
+// Helper function for bit packing snippet of low precision comparison.
+// It packs the flags from 16x16 to 8x16.
+EIGEN_STRONG_INLINE __m128i Pack16To8(Packet8f rf) {
+ return _mm_packs_epi32(_mm256_extractf128_si256(_mm256_castps_si256(rf), 0),
+ _mm256_extractf128_si256(_mm256_castps_si256(rf), 1));
+}
+
template<> EIGEN_STRONG_INLINE Packet8f pset1<Packet8f>(const float& from) { return _mm256_set1_ps(from); }
template<> EIGEN_STRONG_INLINE Packet4d pset1<Packet4d>(const double& from) { return _mm256_set1_pd(from); }
template<> EIGEN_STRONG_INLINE Packet8i pset1<Packet8i>(const int& from) { return _mm256_set1_epi32(from); }
+template<> EIGEN_STRONG_INLINE Packet8f pset1frombits<Packet8f>(unsigned int from) { return _mm256_castsi256_ps(pset1<Packet8i>(from)); }
+template<> EIGEN_STRONG_INLINE Packet4d pset1frombits<Packet4d>(uint64_t from) { return _mm256_castsi256_pd(_mm256_set1_epi64x(from)); }
+
+template<> EIGEN_STRONG_INLINE Packet8f pzero(const Packet8f& /*a*/) { return _mm256_setzero_ps(); }
+template<> EIGEN_STRONG_INLINE Packet4d pzero(const Packet4d& /*a*/) { return _mm256_setzero_pd(); }
+template<> EIGEN_STRONG_INLINE Packet8i pzero(const Packet8i& /*a*/) { return _mm256_setzero_si256(); }
+
+
+template<> EIGEN_STRONG_INLINE Packet8f peven_mask(const Packet8f& /*a*/) { return _mm256_castsi256_ps(_mm256_set_epi32(0, -1, 0, -1, 0, -1, 0, -1)); }
+template<> EIGEN_STRONG_INLINE Packet8i peven_mask(const Packet8i& /*a*/) { return _mm256_set_epi32(0, -1, 0, -1, 0, -1, 0, -1); }
+template<> EIGEN_STRONG_INLINE Packet4d peven_mask(const Packet4d& /*a*/) { return _mm256_castsi256_pd(_mm256_set_epi32(0, 0, -1, -1, 0, 0, -1, -1)); }
+
template<> EIGEN_STRONG_INLINE Packet8f pload1<Packet8f>(const float* from) { return _mm256_broadcast_ss(from); }
template<> EIGEN_STRONG_INLINE Packet4d pload1<Packet4d>(const double* from) { return _mm256_broadcast_sd(from); }
@@ -129,9 +261,27 @@ template<> EIGEN_STRONG_INLINE Packet4d plset<Packet4d>(const double& a) { retur
template<> EIGEN_STRONG_INLINE Packet8f padd<Packet8f>(const Packet8f& a, const Packet8f& b) { return _mm256_add_ps(a,b); }
template<> EIGEN_STRONG_INLINE Packet4d padd<Packet4d>(const Packet4d& a, const Packet4d& b) { return _mm256_add_pd(a,b); }
+template<> EIGEN_STRONG_INLINE Packet8i padd<Packet8i>(const Packet8i& a, const Packet8i& b) {
+#ifdef EIGEN_VECTORIZE_AVX2
+ return _mm256_add_epi32(a,b);
+#else
+ __m128i lo = _mm_add_epi32(_mm256_extractf128_si256(a, 0), _mm256_extractf128_si256(b, 0));
+ __m128i hi = _mm_add_epi32(_mm256_extractf128_si256(a, 1), _mm256_extractf128_si256(b, 1));
+ return _mm256_insertf128_si256(_mm256_castsi128_si256(lo), (hi), 1);
+#endif
+}
template<> EIGEN_STRONG_INLINE Packet8f psub<Packet8f>(const Packet8f& a, const Packet8f& b) { return _mm256_sub_ps(a,b); }
template<> EIGEN_STRONG_INLINE Packet4d psub<Packet4d>(const Packet4d& a, const Packet4d& b) { return _mm256_sub_pd(a,b); }
+template<> EIGEN_STRONG_INLINE Packet8i psub<Packet8i>(const Packet8i& a, const Packet8i& b) {
+#ifdef EIGEN_VECTORIZE_AVX2
+ return _mm256_sub_epi32(a,b);
+#else
+ __m128i lo = _mm_sub_epi32(_mm256_extractf128_si256(a, 0), _mm256_extractf128_si256(b, 0));
+ __m128i hi = _mm_sub_epi32(_mm256_extractf128_si256(a, 1), _mm256_extractf128_si256(b, 1));
+ return _mm256_insertf128_si256(_mm256_castsi128_si256(lo), (hi), 1);
+#endif
+}
template<> EIGEN_STRONG_INLINE Packet8f pnegate(const Packet8f& a)
{
@@ -148,7 +298,15 @@ template<> EIGEN_STRONG_INLINE Packet8i pconj(const Packet8i& a) { return a; }
template<> EIGEN_STRONG_INLINE Packet8f pmul<Packet8f>(const Packet8f& a, const Packet8f& b) { return _mm256_mul_ps(a,b); }
template<> EIGEN_STRONG_INLINE Packet4d pmul<Packet4d>(const Packet4d& a, const Packet4d& b) { return _mm256_mul_pd(a,b); }
-
+template<> EIGEN_STRONG_INLINE Packet8i pmul<Packet8i>(const Packet8i& a, const Packet8i& b) {
+#ifdef EIGEN_VECTORIZE_AVX2
+ return _mm256_mullo_epi32(a,b);
+#else
+ const __m128i lo = _mm_mullo_epi32(_mm256_extractf128_si256(a, 0), _mm256_extractf128_si256(b, 0));
+ const __m128i hi = _mm_mullo_epi32(_mm256_extractf128_si256(a, 1), _mm256_extractf128_si256(b, 1));
+ return _mm256_insertf128_si256(_mm256_castsi128_si256(lo), (hi), 1);
+#endif
+}
template<> EIGEN_STRONG_INLINE Packet8f pdiv<Packet8f>(const Packet8f& a, const Packet8f& b) { return _mm256_div_ps(a,b); }
template<> EIGEN_STRONG_INLINE Packet4d pdiv<Packet4d>(const Packet4d& a, const Packet4d& b) { return _mm256_div_pd(a,b); }
@@ -157,13 +315,14 @@ template<> EIGEN_STRONG_INLINE Packet8i pdiv<Packet8i>(const Packet8i& /*a*/, co
return pset1<Packet8i>(0);
}
-#ifdef __FMA__
+#ifdef EIGEN_VECTORIZE_FMA
template<> EIGEN_STRONG_INLINE Packet8f pmadd(const Packet8f& a, const Packet8f& b, const Packet8f& c) {
-#if ( EIGEN_COMP_GNUC_STRICT || (EIGEN_COMP_CLANG && (EIGEN_COMP_CLANG<308)) )
- // clang stupidly generates a vfmadd213ps instruction plus some vmovaps on registers,
- // and gcc stupidly generates a vfmadd132ps instruction,
- // so let's enforce it to generate a vfmadd231ps instruction since the most common use case is to accumulate
- // the result of the product.
+#if ( (EIGEN_COMP_GNUC_STRICT && EIGEN_COMP_GNUC<80) || (EIGEN_COMP_CLANG) )
+ // Clang stupidly generates a vfmadd213ps instruction plus some vmovaps on registers,
+ // and even register spilling with clang>=6.0 (bug 1637).
+ // Gcc stupidly generates a vfmadd132ps instruction.
+ // So let's enforce it to generate a vfmadd231ps instruction since the most common use
+ // case is to accumulate the result of the product.
Packet8f res = c;
__asm__("vfmadd231ps %[a], %[b], %[c]" : [c] "+x" (res) : [a] "x" (a), [b] "x" (b));
return res;
@@ -172,7 +331,7 @@ template<> EIGEN_STRONG_INLINE Packet8f pmadd(const Packet8f& a, const Packet8f&
#endif
}
template<> EIGEN_STRONG_INLINE Packet4d pmadd(const Packet4d& a, const Packet4d& b, const Packet4d& c) {
-#if ( EIGEN_COMP_GNUC_STRICT || (EIGEN_COMP_CLANG && (EIGEN_COMP_CLANG<308)) )
+#if ( (EIGEN_COMP_GNUC_STRICT && EIGEN_COMP_GNUC<80) || (EIGEN_COMP_CLANG) )
// see above
Packet4d res = c;
__asm__("vfmadd231pd %[a], %[b], %[c]" : [c] "+x" (res) : [a] "x" (a), [b] "x" (b));
@@ -183,14 +342,112 @@ template<> EIGEN_STRONG_INLINE Packet4d pmadd(const Packet4d& a, const Packet4d&
}
#endif
-template<> EIGEN_STRONG_INLINE Packet8f pmin<Packet8f>(const Packet8f& a, const Packet8f& b) { return _mm256_min_ps(a,b); }
-template<> EIGEN_STRONG_INLINE Packet4d pmin<Packet4d>(const Packet4d& a, const Packet4d& b) { return _mm256_min_pd(a,b); }
+template<> EIGEN_STRONG_INLINE Packet8f pcmp_le(const Packet8f& a, const Packet8f& b) { return _mm256_cmp_ps(a,b,_CMP_LE_OQ); }
+template<> EIGEN_STRONG_INLINE Packet8f pcmp_lt(const Packet8f& a, const Packet8f& b) { return _mm256_cmp_ps(a,b,_CMP_LT_OQ); }
+template<> EIGEN_STRONG_INLINE Packet8f pcmp_lt_or_nan(const Packet8f& a, const Packet8f& b) { return _mm256_cmp_ps(a, b, _CMP_NGE_UQ); }
+template<> EIGEN_STRONG_INLINE Packet8f pcmp_eq(const Packet8f& a, const Packet8f& b) { return _mm256_cmp_ps(a,b,_CMP_EQ_OQ); }
+
+template<> EIGEN_STRONG_INLINE Packet4d pcmp_le(const Packet4d& a, const Packet4d& b) { return _mm256_cmp_pd(a,b,_CMP_LE_OQ); }
+template<> EIGEN_STRONG_INLINE Packet4d pcmp_lt(const Packet4d& a, const Packet4d& b) { return _mm256_cmp_pd(a,b,_CMP_LT_OQ); }
+template<> EIGEN_STRONG_INLINE Packet4d pcmp_lt_or_nan(const Packet4d& a, const Packet4d& b) { return _mm256_cmp_pd(a, b, _CMP_NGE_UQ); }
+template<> EIGEN_STRONG_INLINE Packet4d pcmp_eq(const Packet4d& a, const Packet4d& b) { return _mm256_cmp_pd(a,b,_CMP_EQ_OQ); }
+
+
+template<> EIGEN_STRONG_INLINE Packet8i pcmp_eq(const Packet8i& a, const Packet8i& b) {
+#ifdef EIGEN_VECTORIZE_AVX2
+ return _mm256_cmpeq_epi32(a,b);
+#else
+ __m128i lo = _mm_cmpeq_epi32(_mm256_extractf128_si256(a, 0), _mm256_extractf128_si256(b, 0));
+ __m128i hi = _mm_cmpeq_epi32(_mm256_extractf128_si256(a, 1), _mm256_extractf128_si256(b, 1));
+ return _mm256_insertf128_si256(_mm256_castsi128_si256(lo), (hi), 1);
+#endif
+}
+
+template<> EIGEN_STRONG_INLINE Packet8f pmin<Packet8f>(const Packet8f& a, const Packet8f& b) {
+#if EIGEN_COMP_GNUC && EIGEN_COMP_GNUC < 63
+ // There appears to be a bug in GCC, by which the optimizer may flip
+ // the argument order in calls to _mm_min_ps/_mm_max_ps, so we have to
+ // resort to inline ASM here. This is supposed to be fixed in gcc6.3,
+ // see also: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=72867
+ Packet8f res;
+ asm("vminps %[a], %[b], %[res]" : [res] "=x" (res) : [a] "x" (a), [b] "x" (b));
+ return res;
+#else
+ // Arguments are swapped to match NaN propagation behavior of std::min.
+ return _mm256_min_ps(b,a);
+#endif
+}
+template<> EIGEN_STRONG_INLINE Packet4d pmin<Packet4d>(const Packet4d& a, const Packet4d& b) {
+#if EIGEN_COMP_GNUC && EIGEN_COMP_GNUC < 63
+ // See pmin above
+ Packet4d res;
+ asm("vminpd %[a], %[b], %[res]" : [res] "=x" (res) : [a] "x" (a), [b] "x" (b));
+ return res;
+#else
+ // Arguments are swapped to match NaN propagation behavior of std::min.
+ return _mm256_min_pd(b,a);
+#endif
+}
+
+template<> EIGEN_STRONG_INLINE Packet8f pmax<Packet8f>(const Packet8f& a, const Packet8f& b) {
+#if EIGEN_COMP_GNUC && EIGEN_COMP_GNUC < 63
+ // See pmin above
+ Packet8f res;
+ asm("vmaxps %[a], %[b], %[res]" : [res] "=x" (res) : [a] "x" (a), [b] "x" (b));
+ return res;
+#else
+ // Arguments are swapped to match NaN propagation behavior of std::max.
+ return _mm256_max_ps(b,a);
+#endif
+}
+template<> EIGEN_STRONG_INLINE Packet4d pmax<Packet4d>(const Packet4d& a, const Packet4d& b) {
+#if EIGEN_COMP_GNUC && EIGEN_COMP_GNUC < 63
+ // See pmin above
+ Packet4d res;
+ asm("vmaxpd %[a], %[b], %[res]" : [res] "=x" (res) : [a] "x" (a), [b] "x" (b));
+ return res;
+#else
+ // Arguments are swapped to match NaN propagation behavior of std::max.
+ return _mm256_max_pd(b,a);
+#endif
+}
-template<> EIGEN_STRONG_INLINE Packet8f pmax<Packet8f>(const Packet8f& a, const Packet8f& b) { return _mm256_max_ps(a,b); }
-template<> EIGEN_STRONG_INLINE Packet4d pmax<Packet4d>(const Packet4d& a, const Packet4d& b) { return _mm256_max_pd(a,b); }
+// Add specializations for min/max with prescribed NaN progation.
+template<>
+EIGEN_STRONG_INLINE Packet8f pmin<PropagateNumbers, Packet8f>(const Packet8f& a, const Packet8f& b) {
+ return pminmax_propagate_numbers(a, b, pmin<Packet8f>);
+}
+template<>
+EIGEN_STRONG_INLINE Packet4d pmin<PropagateNumbers, Packet4d>(const Packet4d& a, const Packet4d& b) {
+ return pminmax_propagate_numbers(a, b, pmin<Packet4d>);
+}
+template<>
+EIGEN_STRONG_INLINE Packet8f pmax<PropagateNumbers, Packet8f>(const Packet8f& a, const Packet8f& b) {
+ return pminmax_propagate_numbers(a, b, pmax<Packet8f>);
+}
+template<>
+EIGEN_STRONG_INLINE Packet4d pmax<PropagateNumbers, Packet4d>(const Packet4d& a, const Packet4d& b) {
+ return pminmax_propagate_numbers(a, b, pmax<Packet4d>);
+}
+template<>
+EIGEN_STRONG_INLINE Packet8f pmin<PropagateNaN, Packet8f>(const Packet8f& a, const Packet8f& b) {
+ return pminmax_propagate_nan(a, b, pmin<Packet8f>);
+}
+template<>
+EIGEN_STRONG_INLINE Packet4d pmin<PropagateNaN, Packet4d>(const Packet4d& a, const Packet4d& b) {
+ return pminmax_propagate_nan(a, b, pmin<Packet4d>);
+}
+template<>
+EIGEN_STRONG_INLINE Packet8f pmax<PropagateNaN, Packet8f>(const Packet8f& a, const Packet8f& b) {
+ return pminmax_propagate_nan(a, b, pmax<Packet8f>);
+}
+template<>
+EIGEN_STRONG_INLINE Packet4d pmax<PropagateNaN, Packet4d>(const Packet4d& a, const Packet4d& b) {
+ return pminmax_propagate_nan(a, b, pmax<Packet4d>);
+}
-template<> EIGEN_STRONG_INLINE Packet8f pround<Packet8f>(const Packet8f& a) { return _mm256_round_ps(a, _MM_FROUND_CUR_DIRECTION); }
-template<> EIGEN_STRONG_INLINE Packet4d pround<Packet4d>(const Packet4d& a) { return _mm256_round_pd(a, _MM_FROUND_CUR_DIRECTION); }
+template<> EIGEN_STRONG_INLINE Packet8f print<Packet8f>(const Packet8f& a) { return _mm256_round_ps(a, _MM_FROUND_CUR_DIRECTION); }
+template<> EIGEN_STRONG_INLINE Packet4d print<Packet4d>(const Packet4d& a) { return _mm256_round_pd(a, _MM_FROUND_CUR_DIRECTION); }
template<> EIGEN_STRONG_INLINE Packet8f pceil<Packet8f>(const Packet8f& a) { return _mm256_ceil_ps(a); }
template<> EIGEN_STRONG_INLINE Packet4d pceil<Packet4d>(const Packet4d& a) { return _mm256_ceil_pd(a); }
@@ -198,17 +455,124 @@ template<> EIGEN_STRONG_INLINE Packet4d pceil<Packet4d>(const Packet4d& a) { ret
template<> EIGEN_STRONG_INLINE Packet8f pfloor<Packet8f>(const Packet8f& a) { return _mm256_floor_ps(a); }
template<> EIGEN_STRONG_INLINE Packet4d pfloor<Packet4d>(const Packet4d& a) { return _mm256_floor_pd(a); }
+
+template<> EIGEN_STRONG_INLINE Packet8i ptrue<Packet8i>(const Packet8i& a) {
+#ifdef EIGEN_VECTORIZE_AVX2
+ // vpcmpeqd has lower latency than the more general vcmpps
+ return _mm256_cmpeq_epi32(a,a);
+#else
+ const __m256 b = _mm256_castsi256_ps(a);
+ return _mm256_castps_si256(_mm256_cmp_ps(b,b,_CMP_TRUE_UQ));
+#endif
+}
+
+template<> EIGEN_STRONG_INLINE Packet8f ptrue<Packet8f>(const Packet8f& a) {
+#ifdef EIGEN_VECTORIZE_AVX2
+ // vpcmpeqd has lower latency than the more general vcmpps
+ const __m256i b = _mm256_castps_si256(a);
+ return _mm256_castsi256_ps(_mm256_cmpeq_epi32(b,b));
+#else
+ return _mm256_cmp_ps(a,a,_CMP_TRUE_UQ);
+#endif
+}
+
+template<> EIGEN_STRONG_INLINE Packet4d ptrue<Packet4d>(const Packet4d& a) {
+#ifdef EIGEN_VECTORIZE_AVX2
+ // vpcmpeqq has lower latency than the more general vcmppd
+ const __m256i b = _mm256_castpd_si256(a);
+ return _mm256_castsi256_pd(_mm256_cmpeq_epi64(b,b));
+#else
+ return _mm256_cmp_pd(a,a,_CMP_TRUE_UQ);
+#endif
+}
+
template<> EIGEN_STRONG_INLINE Packet8f pand<Packet8f>(const Packet8f& a, const Packet8f& b) { return _mm256_and_ps(a,b); }
template<> EIGEN_STRONG_INLINE Packet4d pand<Packet4d>(const Packet4d& a, const Packet4d& b) { return _mm256_and_pd(a,b); }
+template<> EIGEN_STRONG_INLINE Packet8i pand<Packet8i>(const Packet8i& a, const Packet8i& b) {
+#ifdef EIGEN_VECTORIZE_AVX2
+ return _mm256_and_si256(a,b);
+#else
+ return _mm256_castps_si256(_mm256_and_ps(_mm256_castsi256_ps(a),_mm256_castsi256_ps(b)));
+#endif
+}
template<> EIGEN_STRONG_INLINE Packet8f por<Packet8f>(const Packet8f& a, const Packet8f& b) { return _mm256_or_ps(a,b); }
template<> EIGEN_STRONG_INLINE Packet4d por<Packet4d>(const Packet4d& a, const Packet4d& b) { return _mm256_or_pd(a,b); }
+template<> EIGEN_STRONG_INLINE Packet8i por<Packet8i>(const Packet8i& a, const Packet8i& b) {
+#ifdef EIGEN_VECTORIZE_AVX2
+ return _mm256_or_si256(a,b);
+#else
+ return _mm256_castps_si256(_mm256_or_ps(_mm256_castsi256_ps(a),_mm256_castsi256_ps(b)));
+#endif
+}
template<> EIGEN_STRONG_INLINE Packet8f pxor<Packet8f>(const Packet8f& a, const Packet8f& b) { return _mm256_xor_ps(a,b); }
template<> EIGEN_STRONG_INLINE Packet4d pxor<Packet4d>(const Packet4d& a, const Packet4d& b) { return _mm256_xor_pd(a,b); }
+template<> EIGEN_STRONG_INLINE Packet8i pxor<Packet8i>(const Packet8i& a, const Packet8i& b) {
+#ifdef EIGEN_VECTORIZE_AVX2
+ return _mm256_xor_si256(a,b);
+#else
+ return _mm256_castps_si256(_mm256_xor_ps(_mm256_castsi256_ps(a),_mm256_castsi256_ps(b)));
+#endif
+}
-template<> EIGEN_STRONG_INLINE Packet8f pandnot<Packet8f>(const Packet8f& a, const Packet8f& b) { return _mm256_andnot_ps(a,b); }
-template<> EIGEN_STRONG_INLINE Packet4d pandnot<Packet4d>(const Packet4d& a, const Packet4d& b) { return _mm256_andnot_pd(a,b); }
+template<> EIGEN_STRONG_INLINE Packet8f pandnot<Packet8f>(const Packet8f& a, const Packet8f& b) { return _mm256_andnot_ps(b,a); }
+template<> EIGEN_STRONG_INLINE Packet4d pandnot<Packet4d>(const Packet4d& a, const Packet4d& b) { return _mm256_andnot_pd(b,a); }
+template<> EIGEN_STRONG_INLINE Packet8i pandnot<Packet8i>(const Packet8i& a, const Packet8i& b) {
+#ifdef EIGEN_VECTORIZE_AVX2
+ return _mm256_andnot_si256(b,a);
+#else
+ return _mm256_castps_si256(_mm256_andnot_ps(_mm256_castsi256_ps(b),_mm256_castsi256_ps(a)));
+#endif
+}
+
+template<> EIGEN_STRONG_INLINE Packet8f pround<Packet8f>(const Packet8f& a)
+{
+ const Packet8f mask = pset1frombits<Packet8f>(static_cast<numext::uint32_t>(0x80000000u));
+ const Packet8f prev0dot5 = pset1frombits<Packet8f>(static_cast<numext::uint32_t>(0x3EFFFFFFu));
+ return _mm256_round_ps(padd(por(pand(a, mask), prev0dot5), a), _MM_FROUND_TO_ZERO);
+}
+template<> EIGEN_STRONG_INLINE Packet4d pround<Packet4d>(const Packet4d& a)
+{
+ const Packet4d mask = pset1frombits<Packet4d>(static_cast<numext::uint64_t>(0x8000000000000000ull));
+ const Packet4d prev0dot5 = pset1frombits<Packet4d>(static_cast<numext::uint64_t>(0x3FDFFFFFFFFFFFFFull));
+ return _mm256_round_pd(padd(por(pand(a, mask), prev0dot5), a), _MM_FROUND_TO_ZERO);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8f pselect<Packet8f>(const Packet8f& mask, const Packet8f& a, const Packet8f& b)
+{ return _mm256_blendv_ps(b,a,mask); }
+template<> EIGEN_STRONG_INLINE Packet4d pselect<Packet4d>(const Packet4d& mask, const Packet4d& a, const Packet4d& b)
+{ return _mm256_blendv_pd(b,a,mask); }
+
+template<int N> EIGEN_STRONG_INLINE Packet8i parithmetic_shift_right(Packet8i a) {
+#ifdef EIGEN_VECTORIZE_AVX2
+ return _mm256_srai_epi32(a, N);
+#else
+ __m128i lo = _mm_srai_epi32(_mm256_extractf128_si256(a, 0), N);
+ __m128i hi = _mm_srai_epi32(_mm256_extractf128_si256(a, 1), N);
+ return _mm256_insertf128_si256(_mm256_castsi128_si256(lo), (hi), 1);
+#endif
+}
+
+template<int N> EIGEN_STRONG_INLINE Packet8i plogical_shift_right(Packet8i a) {
+#ifdef EIGEN_VECTORIZE_AVX2
+ return _mm256_srli_epi32(a, N);
+#else
+ __m128i lo = _mm_srli_epi32(_mm256_extractf128_si256(a, 0), N);
+ __m128i hi = _mm_srli_epi32(_mm256_extractf128_si256(a, 1), N);
+ return _mm256_insertf128_si256(_mm256_castsi128_si256(lo), (hi), 1);
+#endif
+}
+
+template<int N> EIGEN_STRONG_INLINE Packet8i plogical_shift_left(Packet8i a) {
+#ifdef EIGEN_VECTORIZE_AVX2
+ return _mm256_slli_epi32(a, N);
+#else
+ __m128i lo = _mm_slli_epi32(_mm256_extractf128_si256(a, 0), N);
+ __m128i hi = _mm_slli_epi32(_mm256_extractf128_si256(a, 1), N);
+ return _mm256_insertf128_si256(_mm256_castsi128_si256(lo), (hi), 1);
+#endif
+}
template<> EIGEN_STRONG_INLINE Packet8f pload<Packet8f>(const float* from) { EIGEN_DEBUG_ALIGNED_LOAD return _mm256_load_ps(from); }
template<> EIGEN_STRONG_INLINE Packet4d pload<Packet4d>(const double* from) { EIGEN_DEBUG_ALIGNED_LOAD return _mm256_load_pd(from); }
@@ -218,6 +582,14 @@ template<> EIGEN_STRONG_INLINE Packet8f ploadu<Packet8f>(const float* from) { EI
template<> EIGEN_STRONG_INLINE Packet4d ploadu<Packet4d>(const double* from) { EIGEN_DEBUG_UNALIGNED_LOAD return _mm256_loadu_pd(from); }
template<> EIGEN_STRONG_INLINE Packet8i ploadu<Packet8i>(const int* from) { EIGEN_DEBUG_UNALIGNED_LOAD return _mm256_loadu_si256(reinterpret_cast<const __m256i*>(from)); }
+template<> EIGEN_STRONG_INLINE Packet8f ploadu<Packet8f>(const float* from, uint8_t umask) {
+ Packet8i mask = _mm256_set1_epi8(static_cast<char>(umask));
+ const Packet8i bit_mask = _mm256_set_epi32(0xffffff7f, 0xffffffbf, 0xffffffdf, 0xffffffef, 0xfffffff7, 0xfffffffb, 0xfffffffd, 0xfffffffe);
+ mask = por<Packet8i>(mask, bit_mask);
+ mask = pcmp_eq<Packet8i>(mask, _mm256_set1_epi32(0xffffffff));
+ EIGEN_DEBUG_UNALIGNED_LOAD return _mm256_maskload_ps(from, mask);
+}
+
// Loads 4 floats from memory a returns the packet {a0, a0 a1, a1, a2, a2, a3, a3}
template<> EIGEN_STRONG_INLINE Packet8f ploaddup<Packet8f>(const float* from)
{
@@ -225,7 +597,7 @@ template<> EIGEN_STRONG_INLINE Packet8f ploaddup<Packet8f>(const float* from)
// Packet8f tmp = _mm256_castps128_ps256(_mm_loadu_ps(from));
// tmp = _mm256_insertf128_ps(tmp, _mm_movehl_ps(_mm256_castps256_ps128(tmp),_mm256_castps256_ps128(tmp)), 1);
// return _mm256_unpacklo_ps(tmp,tmp);
-
+
// _mm256_insertf128_ps is very slow on Haswell, thus:
Packet8f tmp = _mm256_broadcast_ps((const __m128*)(const void*)from);
// mimic an "inplace" permutation of the lower 128bits using a blend
@@ -255,6 +627,14 @@ template<> EIGEN_STRONG_INLINE void pstoreu<float>(float* to, const Packet8f&
template<> EIGEN_STRONG_INLINE void pstoreu<double>(double* to, const Packet4d& from) { EIGEN_DEBUG_UNALIGNED_STORE _mm256_storeu_pd(to, from); }
template<> EIGEN_STRONG_INLINE void pstoreu<int>(int* to, const Packet8i& from) { EIGEN_DEBUG_UNALIGNED_STORE _mm256_storeu_si256(reinterpret_cast<__m256i*>(to), from); }
+template<> EIGEN_STRONG_INLINE void pstoreu<float>(float* to, const Packet8f& from, uint8_t umask) {
+ Packet8i mask = _mm256_set1_epi8(static_cast<char>(umask));
+ const Packet8i bit_mask = _mm256_set_epi32(0xffffff7f, 0xffffffbf, 0xffffffdf, 0xffffffef, 0xfffffff7, 0xfffffffb, 0xfffffffd, 0xfffffffe);
+ mask = por<Packet8i>(mask, bit_mask);
+ mask = pcmp_eq<Packet8i>(mask, _mm256_set1_epi32(0xffffffff));
+ EIGEN_DEBUG_UNALIGNED_STORE return _mm256_maskstore_ps(to, mask, from);
+}
+
// NOTE: leverage _mm256_i32gather_ps and _mm256_i32gather_pd if AVX2 instructions are available
// NOTE: for the record the following seems to be slower: return _mm256_i32gather_ps(from, _mm256_set1_epi32(stride), 4);
template<> EIGEN_DEVICE_FUNC inline Packet8f pgather<float, Packet8f>(const float* from, Index stride)
@@ -308,9 +688,9 @@ template<> EIGEN_STRONG_INLINE void pstore1<Packet8i>(int* to, const int& a)
}
#ifndef EIGEN_VECTORIZE_AVX512
-template<> EIGEN_STRONG_INLINE void prefetch<float>(const float* addr) { _mm_prefetch((const char*)(addr), _MM_HINT_T0); }
-template<> EIGEN_STRONG_INLINE void prefetch<double>(const double* addr) { _mm_prefetch((const char*)(addr), _MM_HINT_T0); }
-template<> EIGEN_STRONG_INLINE void prefetch<int>(const int* addr) { _mm_prefetch((const char*)(addr), _MM_HINT_T0); }
+template<> EIGEN_STRONG_INLINE void prefetch<float>(const float* addr) { _mm_prefetch((SsePrefetchPtrType)(addr), _MM_HINT_T0); }
+template<> EIGEN_STRONG_INLINE void prefetch<double>(const double* addr) { _mm_prefetch((SsePrefetchPtrType)(addr), _MM_HINT_T0); }
+template<> EIGEN_STRONG_INLINE void prefetch<int>(const int* addr) { _mm_prefetch((SsePrefetchPtrType)(addr), _MM_HINT_T0); }
#endif
template<> EIGEN_STRONG_INLINE float pfirst<Packet8f>(const Packet8f& a) {
@@ -333,9 +713,12 @@ template<> EIGEN_STRONG_INLINE Packet4d preverse(const Packet4d& a)
{
__m256d tmp = _mm256_shuffle_pd(a,a,5);
return _mm256_permute2f128_pd(tmp, tmp, 1);
-
+ #if 0
+ // This version is unlikely to be faster as _mm256_shuffle_ps and _mm256_permute_pd
+ // exhibit the same latency/throughput, but it is here for future reference/benchmarking...
__m256d swap_halves = _mm256_permute2f128_pd(a,a,1);
return _mm256_permute_pd(swap_halves,5);
+ #endif
}
// pabs should be ok
@@ -350,47 +733,66 @@ template<> EIGEN_STRONG_INLINE Packet4d pabs(const Packet4d& a)
return _mm256_and_pd(a,mask);
}
-// preduxp should be ok
-// FIXME: why is this ok? why isn't the simply implementation working as expected?
-template<> EIGEN_STRONG_INLINE Packet8f preduxp<Packet8f>(const Packet8f* vecs)
-{
- __m256 hsum1 = _mm256_hadd_ps(vecs[0], vecs[1]);
- __m256 hsum2 = _mm256_hadd_ps(vecs[2], vecs[3]);
- __m256 hsum3 = _mm256_hadd_ps(vecs[4], vecs[5]);
- __m256 hsum4 = _mm256_hadd_ps(vecs[6], vecs[7]);
-
- __m256 hsum5 = _mm256_hadd_ps(hsum1, hsum1);
- __m256 hsum6 = _mm256_hadd_ps(hsum2, hsum2);
- __m256 hsum7 = _mm256_hadd_ps(hsum3, hsum3);
- __m256 hsum8 = _mm256_hadd_ps(hsum4, hsum4);
-
- __m256 perm1 = _mm256_permute2f128_ps(hsum5, hsum5, 0x23);
- __m256 perm2 = _mm256_permute2f128_ps(hsum6, hsum6, 0x23);
- __m256 perm3 = _mm256_permute2f128_ps(hsum7, hsum7, 0x23);
- __m256 perm4 = _mm256_permute2f128_ps(hsum8, hsum8, 0x23);
+template<> EIGEN_STRONG_INLINE Packet8f pfrexp<Packet8f>(const Packet8f& a, Packet8f& exponent) {
+ return pfrexp_generic(a,exponent);
+}
- __m256 sum1 = _mm256_add_ps(perm1, hsum5);
- __m256 sum2 = _mm256_add_ps(perm2, hsum6);
- __m256 sum3 = _mm256_add_ps(perm3, hsum7);
- __m256 sum4 = _mm256_add_ps(perm4, hsum8);
+// Extract exponent without existence of Packet4l.
+template<>
+EIGEN_STRONG_INLINE
+Packet4d pfrexp_generic_get_biased_exponent(const Packet4d& a) {
+ const Packet4d cst_exp_mask = pset1frombits<Packet4d>(static_cast<uint64_t>(0x7ff0000000000000ull));
+ __m256i a_expo = _mm256_castpd_si256(pand(a, cst_exp_mask));
+#ifdef EIGEN_VECTORIZE_AVX2
+ a_expo = _mm256_srli_epi64(a_expo, 52);
+ __m128i lo = _mm256_extractf128_si256(a_expo, 0);
+ __m128i hi = _mm256_extractf128_si256(a_expo, 1);
+#else
+ __m128i lo = _mm256_extractf128_si256(a_expo, 0);
+ __m128i hi = _mm256_extractf128_si256(a_expo, 1);
+ lo = _mm_srli_epi64(lo, 52);
+ hi = _mm_srli_epi64(hi, 52);
+#endif
+ Packet2d exponent_lo = _mm_cvtepi32_pd(vec4i_swizzle1(lo, 0, 2, 1, 3));
+ Packet2d exponent_hi = _mm_cvtepi32_pd(vec4i_swizzle1(hi, 0, 2, 1, 3));
+ Packet4d exponent = _mm256_insertf128_pd(_mm256_setzero_pd(), exponent_lo, 0);
+ exponent = _mm256_insertf128_pd(exponent, exponent_hi, 1);
+ return exponent;
+}
- __m256 blend1 = _mm256_blend_ps(sum1, sum2, 0xcc);
- __m256 blend2 = _mm256_blend_ps(sum3, sum4, 0xcc);
- __m256 final = _mm256_blend_ps(blend1, blend2, 0xf0);
- return final;
+template<> EIGEN_STRONG_INLINE Packet4d pfrexp<Packet4d>(const Packet4d& a, Packet4d& exponent) {
+ return pfrexp_generic(a, exponent);
}
-template<> EIGEN_STRONG_INLINE Packet4d preduxp<Packet4d>(const Packet4d* vecs)
-{
- Packet4d tmp0, tmp1;
- tmp0 = _mm256_hadd_pd(vecs[0], vecs[1]);
- tmp0 = _mm256_add_pd(tmp0, _mm256_permute2f128_pd(tmp0, tmp0, 1));
-
- tmp1 = _mm256_hadd_pd(vecs[2], vecs[3]);
- tmp1 = _mm256_add_pd(tmp1, _mm256_permute2f128_pd(tmp1, tmp1, 1));
+template<> EIGEN_STRONG_INLINE Packet8f pldexp<Packet8f>(const Packet8f& a, const Packet8f& exponent) {
+ return pldexp_generic(a, exponent);
+}
- return _mm256_blend_pd(tmp0, tmp1, 0xC);
+template<> EIGEN_STRONG_INLINE Packet4d pldexp<Packet4d>(const Packet4d& a, const Packet4d& exponent) {
+ // Clamp exponent to [-2099, 2099]
+ const Packet4d max_exponent = pset1<Packet4d>(2099.0);
+ const Packet4i e = _mm256_cvtpd_epi32(pmin(pmax(exponent, pnegate(max_exponent)), max_exponent));
+
+ // Split 2^e into four factors and multiply.
+ const Packet4i bias = pset1<Packet4i>(1023);
+ Packet4i b = parithmetic_shift_right<2>(e); // floor(e/4)
+
+ // 2^b
+ Packet4i hi = vec4i_swizzle1(padd(b, bias), 0, 2, 1, 3);
+ Packet4i lo = _mm_slli_epi64(hi, 52);
+ hi = _mm_slli_epi64(_mm_srli_epi64(hi, 32), 52);
+ Packet4d c = _mm256_castsi256_pd(_mm256_insertf128_si256(_mm256_castsi128_si256(lo), hi, 1));
+ Packet4d out = pmul(pmul(pmul(a, c), c), c); // a * 2^(3b)
+
+ // 2^(e - 3b)
+ b = psub(psub(psub(e, b), b), b); // e - 3b
+ hi = vec4i_swizzle1(padd(b, bias), 0, 2, 1, 3);
+ lo = _mm_slli_epi64(hi, 52);
+ hi = _mm_slli_epi64(_mm_srli_epi64(hi, 32), 52);
+ c = _mm256_castsi256_pd(_mm256_insertf128_si256(_mm256_castsi128_si256(lo), hi, 1));
+ out = pmul(out, c); // a * 2^e
+ return out;
}
template<> EIGEN_STRONG_INLINE float predux<Packet8f>(const Packet8f& a)
@@ -402,7 +804,7 @@ template<> EIGEN_STRONG_INLINE double predux<Packet4d>(const Packet4d& a)
return predux(Packet2d(_mm_add_pd(_mm256_castpd256_pd128(a),_mm256_extractf128_pd(a,1))));
}
-template<> EIGEN_STRONG_INLINE Packet4f predux_downto4<Packet8f>(const Packet8f& a)
+template<> EIGEN_STRONG_INLINE Packet4f predux_half_dowto4<Packet8f>(const Packet8f& a)
{
return _mm_add_ps(_mm256_castps256_ps128(a),_mm256_extractf128_ps(a,1));
}
@@ -446,93 +848,16 @@ template<> EIGEN_STRONG_INLINE double predux_max<Packet4d>(const Packet4d& a)
return pfirst(_mm256_max_pd(tmp, _mm256_shuffle_pd(tmp, tmp, 1)));
}
+// not needed yet
+// template<> EIGEN_STRONG_INLINE bool predux_all(const Packet8f& x)
+// {
+// return _mm256_movemask_ps(x)==0xFF;
+// }
-template<int Offset>
-struct palign_impl<Offset,Packet8f>
+template<> EIGEN_STRONG_INLINE bool predux_any(const Packet8f& x)
{
- static EIGEN_STRONG_INLINE void run(Packet8f& first, const Packet8f& second)
- {
- if (Offset==1)
- {
- first = _mm256_blend_ps(first, second, 1);
- Packet8f tmp1 = _mm256_permute_ps (first, _MM_SHUFFLE(0,3,2,1));
- Packet8f tmp2 = _mm256_permute2f128_ps (tmp1, tmp1, 1);
- first = _mm256_blend_ps(tmp1, tmp2, 0x88);
- }
- else if (Offset==2)
- {
- first = _mm256_blend_ps(first, second, 3);
- Packet8f tmp1 = _mm256_permute_ps (first, _MM_SHUFFLE(1,0,3,2));
- Packet8f tmp2 = _mm256_permute2f128_ps (tmp1, tmp1, 1);
- first = _mm256_blend_ps(tmp1, tmp2, 0xcc);
- }
- else if (Offset==3)
- {
- first = _mm256_blend_ps(first, second, 7);
- Packet8f tmp1 = _mm256_permute_ps (first, _MM_SHUFFLE(2,1,0,3));
- Packet8f tmp2 = _mm256_permute2f128_ps (tmp1, tmp1, 1);
- first = _mm256_blend_ps(tmp1, tmp2, 0xee);
- }
- else if (Offset==4)
- {
- first = _mm256_blend_ps(first, second, 15);
- Packet8f tmp1 = _mm256_permute_ps (first, _MM_SHUFFLE(3,2,1,0));
- Packet8f tmp2 = _mm256_permute2f128_ps (tmp1, tmp1, 1);
- first = _mm256_permute_ps(tmp2, _MM_SHUFFLE(3,2,1,0));
- }
- else if (Offset==5)
- {
- first = _mm256_blend_ps(first, second, 31);
- first = _mm256_permute2f128_ps(first, first, 1);
- Packet8f tmp = _mm256_permute_ps (first, _MM_SHUFFLE(0,3,2,1));
- first = _mm256_permute2f128_ps(tmp, tmp, 1);
- first = _mm256_blend_ps(tmp, first, 0x88);
- }
- else if (Offset==6)
- {
- first = _mm256_blend_ps(first, second, 63);
- first = _mm256_permute2f128_ps(first, first, 1);
- Packet8f tmp = _mm256_permute_ps (first, _MM_SHUFFLE(1,0,3,2));
- first = _mm256_permute2f128_ps(tmp, tmp, 1);
- first = _mm256_blend_ps(tmp, first, 0xcc);
- }
- else if (Offset==7)
- {
- first = _mm256_blend_ps(first, second, 127);
- first = _mm256_permute2f128_ps(first, first, 1);
- Packet8f tmp = _mm256_permute_ps (first, _MM_SHUFFLE(2,1,0,3));
- first = _mm256_permute2f128_ps(tmp, tmp, 1);
- first = _mm256_blend_ps(tmp, first, 0xee);
- }
- }
-};
-
-template<int Offset>
-struct palign_impl<Offset,Packet4d>
-{
- static EIGEN_STRONG_INLINE void run(Packet4d& first, const Packet4d& second)
- {
- if (Offset==1)
- {
- first = _mm256_blend_pd(first, second, 1);
- __m256d tmp = _mm256_permute_pd(first, 5);
- first = _mm256_permute2f128_pd(tmp, tmp, 1);
- first = _mm256_blend_pd(tmp, first, 0xA);
- }
- else if (Offset==2)
- {
- first = _mm256_blend_pd(first, second, 3);
- first = _mm256_permute2f128_pd(first, first, 1);
- }
- else if (Offset==3)
- {
- first = _mm256_blend_pd(first, second, 7);
- __m256d tmp = _mm256_permute_pd(first, 5);
- first = _mm256_permute2f128_pd(tmp, tmp, 1);
- first = _mm256_blend_pd(tmp, first, 5);
- }
- }
-};
+ return _mm256_movemask_ps(x)!=0;
+}
EIGEN_DEVICE_FUNC inline void
ptranspose(PacketBlock<Packet8f,8>& kernel) {
@@ -606,24 +931,640 @@ template<> EIGEN_STRONG_INLINE Packet4d pblend(const Selector<4>& ifPacket, cons
return _mm256_blendv_pd(thenPacket, elsePacket, false_mask);
}
-template<> EIGEN_STRONG_INLINE Packet8f pinsertfirst(const Packet8f& a, float b)
+// Packet math for Eigen::half
+
+template<> struct unpacket_traits<Packet8h> { typedef Eigen::half type; enum {size=8, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false}; typedef Packet8h half; };
+
+template<> EIGEN_STRONG_INLINE Packet8h pset1<Packet8h>(const Eigen::half& from) {
+ return _mm_set1_epi16(numext::bit_cast<numext::uint16_t>(from));
+}
+
+template<> EIGEN_STRONG_INLINE Eigen::half pfirst<Packet8h>(const Packet8h& from) {
+ return numext::bit_cast<Eigen::half>(static_cast<numext::uint16_t>(_mm_extract_epi16(from, 0)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8h pload<Packet8h>(const Eigen::half* from) {
+ return _mm_load_si128(reinterpret_cast<const __m128i*>(from));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8h ploadu<Packet8h>(const Eigen::half* from) {
+ return _mm_loadu_si128(reinterpret_cast<const __m128i*>(from));
+}
+
+template<> EIGEN_STRONG_INLINE void pstore<Eigen::half>(Eigen::half* to, const Packet8h& from) {
+ _mm_store_si128(reinterpret_cast<__m128i*>(to), from);
+}
+
+template<> EIGEN_STRONG_INLINE void pstoreu<Eigen::half>(Eigen::half* to, const Packet8h& from) {
+ _mm_storeu_si128(reinterpret_cast<__m128i*>(to), from);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8h
+ploaddup<Packet8h>(const Eigen::half* from) {
+ const numext::uint16_t a = numext::bit_cast<numext::uint16_t>(from[0]);
+ const numext::uint16_t b = numext::bit_cast<numext::uint16_t>(from[1]);
+ const numext::uint16_t c = numext::bit_cast<numext::uint16_t>(from[2]);
+ const numext::uint16_t d = numext::bit_cast<numext::uint16_t>(from[3]);
+ return _mm_set_epi16(d, d, c, c, b, b, a, a);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8h
+ploadquad<Packet8h>(const Eigen::half* from) {
+ const numext::uint16_t a = numext::bit_cast<numext::uint16_t>(from[0]);
+ const numext::uint16_t b = numext::bit_cast<numext::uint16_t>(from[1]);
+ return _mm_set_epi16(b, b, b, b, a, a, a, a);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8h ptrue(const Packet8h& a) {
+ return _mm_cmpeq_epi32(a, a);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8h pabs(const Packet8h& a) {
+ const __m128i sign_mask = _mm_set1_epi16(static_cast<numext::uint16_t>(0x8000));
+ return _mm_andnot_si128(sign_mask, a);
+}
+
+EIGEN_STRONG_INLINE Packet8f half2float(const Packet8h& a) {
+#ifdef EIGEN_HAS_FP16_C
+ return _mm256_cvtph_ps(a);
+#else
+ EIGEN_ALIGN32 Eigen::half aux[8];
+ pstore(aux, a);
+ float f0(aux[0]);
+ float f1(aux[1]);
+ float f2(aux[2]);
+ float f3(aux[3]);
+ float f4(aux[4]);
+ float f5(aux[5]);
+ float f6(aux[6]);
+ float f7(aux[7]);
+
+ return _mm256_set_ps(f7, f6, f5, f4, f3, f2, f1, f0);
+#endif
+}
+
+EIGEN_STRONG_INLINE Packet8h float2half(const Packet8f& a) {
+#ifdef EIGEN_HAS_FP16_C
+ return _mm256_cvtps_ph(a, _MM_FROUND_TO_NEAREST_INT|_MM_FROUND_NO_EXC);
+#else
+ EIGEN_ALIGN32 float aux[8];
+ pstore(aux, a);
+ const numext::uint16_t s0 = numext::bit_cast<numext::uint16_t>(Eigen::half(aux[0]));
+ const numext::uint16_t s1 = numext::bit_cast<numext::uint16_t>(Eigen::half(aux[1]));
+ const numext::uint16_t s2 = numext::bit_cast<numext::uint16_t>(Eigen::half(aux[2]));
+ const numext::uint16_t s3 = numext::bit_cast<numext::uint16_t>(Eigen::half(aux[3]));
+ const numext::uint16_t s4 = numext::bit_cast<numext::uint16_t>(Eigen::half(aux[4]));
+ const numext::uint16_t s5 = numext::bit_cast<numext::uint16_t>(Eigen::half(aux[5]));
+ const numext::uint16_t s6 = numext::bit_cast<numext::uint16_t>(Eigen::half(aux[6]));
+ const numext::uint16_t s7 = numext::bit_cast<numext::uint16_t>(Eigen::half(aux[7]));
+ return _mm_set_epi16(s7, s6, s5, s4, s3, s2, s1, s0);
+#endif
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8h pmin<Packet8h>(const Packet8h& a,
+ const Packet8h& b) {
+ return float2half(pmin<Packet8f>(half2float(a), half2float(b)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8h pmax<Packet8h>(const Packet8h& a,
+ const Packet8h& b) {
+ return float2half(pmax<Packet8f>(half2float(a), half2float(b)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8h plset<Packet8h>(const half& a) {
+ return float2half(plset<Packet8f>(static_cast<float>(a)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8h por(const Packet8h& a,const Packet8h& b) {
+ // in some cases Packet4i is a wrapper around __m128i, so we either need to
+ // cast to Packet4i to directly call the intrinsics as below:
+ return _mm_or_si128(a,b);
+}
+template<> EIGEN_STRONG_INLINE Packet8h pxor(const Packet8h& a,const Packet8h& b) {
+ return _mm_xor_si128(a,b);
+}
+template<> EIGEN_STRONG_INLINE Packet8h pand(const Packet8h& a,const Packet8h& b) {
+ return _mm_and_si128(a,b);
+}
+template<> EIGEN_STRONG_INLINE Packet8h pandnot(const Packet8h& a,const Packet8h& b) {
+ return _mm_andnot_si128(b,a);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8h pselect(const Packet8h& mask, const Packet8h& a, const Packet8h& b) {
+ return _mm_blendv_epi8(b, a, mask);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8h pround<Packet8h>(const Packet8h& a) {
+ return float2half(pround<Packet8f>(half2float(a)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8h print<Packet8h>(const Packet8h& a) {
+ return float2half(print<Packet8f>(half2float(a)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8h pceil<Packet8h>(const Packet8h& a) {
+ return float2half(pceil<Packet8f>(half2float(a)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8h pfloor<Packet8h>(const Packet8h& a) {
+ return float2half(pfloor<Packet8f>(half2float(a)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8h pcmp_eq(const Packet8h& a,const Packet8h& b) {
+ return Pack16To8(pcmp_eq(half2float(a), half2float(b)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8h pcmp_le(const Packet8h& a,const Packet8h& b) {
+ return Pack16To8(pcmp_le(half2float(a), half2float(b)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8h pcmp_lt(const Packet8h& a,const Packet8h& b) {
+ return Pack16To8(pcmp_lt(half2float(a), half2float(b)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8h pcmp_lt_or_nan(const Packet8h& a,const Packet8h& b) {
+ return Pack16To8(pcmp_lt_or_nan(half2float(a), half2float(b)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8h pconj(const Packet8h& a) { return a; }
+
+template<> EIGEN_STRONG_INLINE Packet8h pnegate(const Packet8h& a) {
+ Packet8h sign_mask = _mm_set1_epi16(static_cast<numext::uint16_t>(0x8000));
+ return _mm_xor_si128(a, sign_mask);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8h padd<Packet8h>(const Packet8h& a, const Packet8h& b) {
+ Packet8f af = half2float(a);
+ Packet8f bf = half2float(b);
+ Packet8f rf = padd(af, bf);
+ return float2half(rf);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8h psub<Packet8h>(const Packet8h& a, const Packet8h& b) {
+ Packet8f af = half2float(a);
+ Packet8f bf = half2float(b);
+ Packet8f rf = psub(af, bf);
+ return float2half(rf);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8h pmul<Packet8h>(const Packet8h& a, const Packet8h& b) {
+ Packet8f af = half2float(a);
+ Packet8f bf = half2float(b);
+ Packet8f rf = pmul(af, bf);
+ return float2half(rf);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8h pdiv<Packet8h>(const Packet8h& a, const Packet8h& b) {
+ Packet8f af = half2float(a);
+ Packet8f bf = half2float(b);
+ Packet8f rf = pdiv(af, bf);
+ return float2half(rf);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8h pgather<Eigen::half, Packet8h>(const Eigen::half* from, Index stride)
{
- return _mm256_blend_ps(a,pset1<Packet8f>(b),1);
+ const numext::uint16_t s0 = numext::bit_cast<numext::uint16_t>(from[0*stride]);
+ const numext::uint16_t s1 = numext::bit_cast<numext::uint16_t>(from[1*stride]);
+ const numext::uint16_t s2 = numext::bit_cast<numext::uint16_t>(from[2*stride]);
+ const numext::uint16_t s3 = numext::bit_cast<numext::uint16_t>(from[3*stride]);
+ const numext::uint16_t s4 = numext::bit_cast<numext::uint16_t>(from[4*stride]);
+ const numext::uint16_t s5 = numext::bit_cast<numext::uint16_t>(from[5*stride]);
+ const numext::uint16_t s6 = numext::bit_cast<numext::uint16_t>(from[6*stride]);
+ const numext::uint16_t s7 = numext::bit_cast<numext::uint16_t>(from[7*stride]);
+ return _mm_set_epi16(s7, s6, s5, s4, s3, s2, s1, s0);
}
-template<> EIGEN_STRONG_INLINE Packet4d pinsertfirst(const Packet4d& a, double b)
+template<> EIGEN_STRONG_INLINE void pscatter<Eigen::half, Packet8h>(Eigen::half* to, const Packet8h& from, Index stride)
{
- return _mm256_blend_pd(a,pset1<Packet4d>(b),1);
+ EIGEN_ALIGN32 Eigen::half aux[8];
+ pstore(aux, from);
+ to[stride*0] = aux[0];
+ to[stride*1] = aux[1];
+ to[stride*2] = aux[2];
+ to[stride*3] = aux[3];
+ to[stride*4] = aux[4];
+ to[stride*5] = aux[5];
+ to[stride*6] = aux[6];
+ to[stride*7] = aux[7];
+}
+
+template<> EIGEN_STRONG_INLINE Eigen::half predux<Packet8h>(const Packet8h& a) {
+ Packet8f af = half2float(a);
+ float reduced = predux<Packet8f>(af);
+ return Eigen::half(reduced);
+}
+
+template<> EIGEN_STRONG_INLINE Eigen::half predux_max<Packet8h>(const Packet8h& a) {
+ Packet8f af = half2float(a);
+ float reduced = predux_max<Packet8f>(af);
+ return Eigen::half(reduced);
+}
+
+template<> EIGEN_STRONG_INLINE Eigen::half predux_min<Packet8h>(const Packet8h& a) {
+ Packet8f af = half2float(a);
+ float reduced = predux_min<Packet8f>(af);
+ return Eigen::half(reduced);
}
-template<> EIGEN_STRONG_INLINE Packet8f pinsertlast(const Packet8f& a, float b)
+template<> EIGEN_STRONG_INLINE Eigen::half predux_mul<Packet8h>(const Packet8h& a) {
+ Packet8f af = half2float(a);
+ float reduced = predux_mul<Packet8f>(af);
+ return Eigen::half(reduced);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8h preverse(const Packet8h& a)
{
- return _mm256_blend_ps(a,pset1<Packet8f>(b),(1<<7));
+ __m128i m = _mm_setr_epi8(14,15,12,13,10,11,8,9,6,7,4,5,2,3,0,1);
+ return _mm_shuffle_epi8(a,m);
+}
+
+EIGEN_STRONG_INLINE void
+ptranspose(PacketBlock<Packet8h,8>& kernel) {
+ __m128i a = kernel.packet[0];
+ __m128i b = kernel.packet[1];
+ __m128i c = kernel.packet[2];
+ __m128i d = kernel.packet[3];
+ __m128i e = kernel.packet[4];
+ __m128i f = kernel.packet[5];
+ __m128i g = kernel.packet[6];
+ __m128i h = kernel.packet[7];
+
+ __m128i a03b03 = _mm_unpacklo_epi16(a, b);
+ __m128i c03d03 = _mm_unpacklo_epi16(c, d);
+ __m128i e03f03 = _mm_unpacklo_epi16(e, f);
+ __m128i g03h03 = _mm_unpacklo_epi16(g, h);
+ __m128i a47b47 = _mm_unpackhi_epi16(a, b);
+ __m128i c47d47 = _mm_unpackhi_epi16(c, d);
+ __m128i e47f47 = _mm_unpackhi_epi16(e, f);
+ __m128i g47h47 = _mm_unpackhi_epi16(g, h);
+
+ __m128i a01b01c01d01 = _mm_unpacklo_epi32(a03b03, c03d03);
+ __m128i a23b23c23d23 = _mm_unpackhi_epi32(a03b03, c03d03);
+ __m128i e01f01g01h01 = _mm_unpacklo_epi32(e03f03, g03h03);
+ __m128i e23f23g23h23 = _mm_unpackhi_epi32(e03f03, g03h03);
+ __m128i a45b45c45d45 = _mm_unpacklo_epi32(a47b47, c47d47);
+ __m128i a67b67c67d67 = _mm_unpackhi_epi32(a47b47, c47d47);
+ __m128i e45f45g45h45 = _mm_unpacklo_epi32(e47f47, g47h47);
+ __m128i e67f67g67h67 = _mm_unpackhi_epi32(e47f47, g47h47);
+
+ __m128i a0b0c0d0e0f0g0h0 = _mm_unpacklo_epi64(a01b01c01d01, e01f01g01h01);
+ __m128i a1b1c1d1e1f1g1h1 = _mm_unpackhi_epi64(a01b01c01d01, e01f01g01h01);
+ __m128i a2b2c2d2e2f2g2h2 = _mm_unpacklo_epi64(a23b23c23d23, e23f23g23h23);
+ __m128i a3b3c3d3e3f3g3h3 = _mm_unpackhi_epi64(a23b23c23d23, e23f23g23h23);
+ __m128i a4b4c4d4e4f4g4h4 = _mm_unpacklo_epi64(a45b45c45d45, e45f45g45h45);
+ __m128i a5b5c5d5e5f5g5h5 = _mm_unpackhi_epi64(a45b45c45d45, e45f45g45h45);
+ __m128i a6b6c6d6e6f6g6h6 = _mm_unpacklo_epi64(a67b67c67d67, e67f67g67h67);
+ __m128i a7b7c7d7e7f7g7h7 = _mm_unpackhi_epi64(a67b67c67d67, e67f67g67h67);
+
+ kernel.packet[0] = a0b0c0d0e0f0g0h0;
+ kernel.packet[1] = a1b1c1d1e1f1g1h1;
+ kernel.packet[2] = a2b2c2d2e2f2g2h2;
+ kernel.packet[3] = a3b3c3d3e3f3g3h3;
+ kernel.packet[4] = a4b4c4d4e4f4g4h4;
+ kernel.packet[5] = a5b5c5d5e5f5g5h5;
+ kernel.packet[6] = a6b6c6d6e6f6g6h6;
+ kernel.packet[7] = a7b7c7d7e7f7g7h7;
+}
+
+EIGEN_STRONG_INLINE void
+ptranspose(PacketBlock<Packet8h,4>& kernel) {
+ EIGEN_ALIGN32 Eigen::half in[4][8];
+ pstore<Eigen::half>(in[0], kernel.packet[0]);
+ pstore<Eigen::half>(in[1], kernel.packet[1]);
+ pstore<Eigen::half>(in[2], kernel.packet[2]);
+ pstore<Eigen::half>(in[3], kernel.packet[3]);
+
+ EIGEN_ALIGN32 Eigen::half out[4][8];
+
+ for (int i = 0; i < 4; ++i) {
+ for (int j = 0; j < 4; ++j) {
+ out[i][j] = in[j][2*i];
+ }
+ for (int j = 0; j < 4; ++j) {
+ out[i][j+4] = in[j][2*i+1];
+ }
+ }
+
+ kernel.packet[0] = pload<Packet8h>(out[0]);
+ kernel.packet[1] = pload<Packet8h>(out[1]);
+ kernel.packet[2] = pload<Packet8h>(out[2]);
+ kernel.packet[3] = pload<Packet8h>(out[3]);
+}
+
+// BFloat16 implementation.
+
+EIGEN_STRONG_INLINE Packet8f Bf16ToF32(const Packet8bf& a) {
+#ifdef EIGEN_VECTORIZE_AVX2
+ __m256i extend = _mm256_cvtepu16_epi32(a);
+ return _mm256_castsi256_ps(_mm256_slli_epi32(extend, 16));
+#else
+ __m128i lo = _mm_cvtepu16_epi32(a);
+ __m128i hi = _mm_cvtepu16_epi32(_mm_srli_si128(a, 8));
+ __m128i lo_shift = _mm_slli_epi32(lo, 16);
+ __m128i hi_shift = _mm_slli_epi32(hi, 16);
+ return _mm256_castsi256_ps(_mm256_insertf128_si256(_mm256_castsi128_si256(lo_shift), hi_shift, 1));
+#endif
}
-template<> EIGEN_STRONG_INLINE Packet4d pinsertlast(const Packet4d& a, double b)
+// Convert float to bfloat16 according to round-to-nearest-even/denormals algorithm.
+EIGEN_STRONG_INLINE Packet8bf F32ToBf16(const Packet8f& a) {
+ Packet8bf r;
+
+ __m256i input = _mm256_castps_si256(a);
+
+#ifdef EIGEN_VECTORIZE_AVX2
+ // uint32_t lsb = (input >> 16);
+ __m256i t = _mm256_srli_epi32(input, 16);
+ // uint32_t lsb = lsb & 1;
+ t = _mm256_and_si256(t, _mm256_set1_epi32(1));
+ // uint32_t rounding_bias = 0x7fff + lsb;
+ t = _mm256_add_epi32(t, _mm256_set1_epi32(0x7fff));
+ // input += rounding_bias;
+ t = _mm256_add_epi32(t, input);
+ // input = input >> 16;
+ t = _mm256_srli_epi32(t, 16);
+ // Check NaN before converting back to bf16
+ __m256 mask = _mm256_cmp_ps(a, a, _CMP_ORD_Q);
+ __m256i nan = _mm256_set1_epi32(0x7fc0);
+ t = _mm256_blendv_epi8(nan, t, _mm256_castps_si256(mask));
+ // output = numext::bit_cast<uint16_t>(input);
+ return _mm_packus_epi32(_mm256_extractf128_si256(t, 0),
+ _mm256_extractf128_si256(t, 1));
+#else
+ // uint32_t lsb = (input >> 16);
+ __m128i lo = _mm_srli_epi32(_mm256_extractf128_si256(input, 0), 16);
+ __m128i hi = _mm_srli_epi32(_mm256_extractf128_si256(input, 1), 16);
+ // uint32_t lsb = lsb & 1;
+ lo = _mm_and_si128(lo, _mm_set1_epi32(1));
+ hi = _mm_and_si128(hi, _mm_set1_epi32(1));
+ // uint32_t rounding_bias = 0x7fff + lsb;
+ lo = _mm_add_epi32(lo, _mm_set1_epi32(0x7fff));
+ hi = _mm_add_epi32(hi, _mm_set1_epi32(0x7fff));
+ // input += rounding_bias;
+ lo = _mm_add_epi32(lo, _mm256_extractf128_si256(input, 0));
+ hi = _mm_add_epi32(hi, _mm256_extractf128_si256(input, 1));
+ // input = input >> 16;
+ lo = _mm_srli_epi32(lo, 16);
+ hi = _mm_srli_epi32(hi, 16);
+ // Check NaN before converting back to bf16
+ __m256 mask = _mm256_cmp_ps(a, a, _CMP_ORD_Q);
+ __m128i nan = _mm_set1_epi32(0x7fc0);
+ lo = _mm_blendv_epi8(nan, lo, _mm_castps_si128(_mm256_castps256_ps128(mask)));
+ hi = _mm_blendv_epi8(nan, hi, _mm_castps_si128(_mm256_extractf128_ps(mask, 1)));
+ // output = numext::bit_cast<uint16_t>(input);
+ return _mm_packus_epi32(lo, hi);
+#endif
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf pset1<Packet8bf>(const bfloat16& from) {
+ return _mm_set1_epi16(numext::bit_cast<numext::uint16_t>(from));
+}
+
+template<> EIGEN_STRONG_INLINE bfloat16 pfirst<Packet8bf>(const Packet8bf& from) {
+ return numext::bit_cast<bfloat16>(static_cast<numext::uint16_t>(_mm_extract_epi16(from, 0)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf pload<Packet8bf>(const bfloat16* from) {
+ return _mm_load_si128(reinterpret_cast<const __m128i*>(from));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf ploadu<Packet8bf>(const bfloat16* from) {
+ return _mm_loadu_si128(reinterpret_cast<const __m128i*>(from));
+}
+
+template<> EIGEN_STRONG_INLINE void pstore<bfloat16>(bfloat16* to, const Packet8bf& from) {
+ _mm_store_si128(reinterpret_cast<__m128i*>(to), from);
+}
+
+template<> EIGEN_STRONG_INLINE void pstoreu<bfloat16>(bfloat16* to, const Packet8bf& from) {
+ _mm_storeu_si128(reinterpret_cast<__m128i*>(to), from);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf
+ploaddup<Packet8bf>(const bfloat16* from) {
+ const numext::uint16_t a = numext::bit_cast<numext::uint16_t>(from[0]);
+ const numext::uint16_t b = numext::bit_cast<numext::uint16_t>(from[1]);
+ const numext::uint16_t c = numext::bit_cast<numext::uint16_t>(from[2]);
+ const numext::uint16_t d = numext::bit_cast<numext::uint16_t>(from[3]);
+ return _mm_set_epi16(d, d, c, c, b, b, a, a);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf
+ploadquad<Packet8bf>(const bfloat16* from) {
+ const numext::uint16_t a = numext::bit_cast<numext::uint16_t>(from[0]);
+ const numext::uint16_t b = numext::bit_cast<numext::uint16_t>(from[1]);
+ return _mm_set_epi16(b, b, b, b, a, a, a, a);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf ptrue(const Packet8bf& a) {
+ return _mm_cmpeq_epi32(a, a);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8bf pabs(const Packet8bf& a) {
+ const __m128i sign_mask = _mm_set1_epi16(static_cast<numext::uint16_t>(0x8000));
+ return _mm_andnot_si128(sign_mask, a);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8bf pmin<Packet8bf>(const Packet8bf& a,
+ const Packet8bf& b) {
+ return F32ToBf16(pmin<Packet8f>(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8bf pmax<Packet8bf>(const Packet8bf& a,
+ const Packet8bf& b) {
+ return F32ToBf16(pmax<Packet8f>(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8bf plset<Packet8bf>(const bfloat16& a) {
+ return F32ToBf16(plset<Packet8f>(static_cast<float>(a)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf por(const Packet8bf& a,const Packet8bf& b) {
+ return _mm_or_si128(a,b);
+}
+template<> EIGEN_STRONG_INLINE Packet8bf pxor(const Packet8bf& a,const Packet8bf& b) {
+ return _mm_xor_si128(a,b);
+}
+template<> EIGEN_STRONG_INLINE Packet8bf pand(const Packet8bf& a,const Packet8bf& b) {
+ return _mm_and_si128(a,b);
+}
+template<> EIGEN_STRONG_INLINE Packet8bf pandnot(const Packet8bf& a,const Packet8bf& b) {
+ return _mm_andnot_si128(b,a);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf pselect(const Packet8bf& mask, const Packet8bf& a, const Packet8bf& b) {
+ return _mm_blendv_epi8(b, a, mask);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf pround<Packet8bf>(const Packet8bf& a)
{
- return _mm256_blend_pd(a,pset1<Packet4d>(b),(1<<3));
+ return F32ToBf16(pround<Packet8f>(Bf16ToF32(a)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf print<Packet8bf>(const Packet8bf& a) {
+ return F32ToBf16(print<Packet8f>(Bf16ToF32(a)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf pceil<Packet8bf>(const Packet8bf& a) {
+ return F32ToBf16(pceil<Packet8f>(Bf16ToF32(a)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf pfloor<Packet8bf>(const Packet8bf& a) {
+ return F32ToBf16(pfloor<Packet8f>(Bf16ToF32(a)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf pcmp_eq(const Packet8bf& a,const Packet8bf& b) {
+ return Pack16To8(pcmp_eq(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf pcmp_le(const Packet8bf& a,const Packet8bf& b) {
+ return Pack16To8(pcmp_le(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf pcmp_lt(const Packet8bf& a,const Packet8bf& b) {
+ return Pack16To8(pcmp_lt(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf pcmp_lt_or_nan(const Packet8bf& a,const Packet8bf& b) {
+ return Pack16To8(pcmp_lt_or_nan(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf pconj(const Packet8bf& a) { return a; }
+
+template<> EIGEN_STRONG_INLINE Packet8bf pnegate(const Packet8bf& a) {
+ Packet8bf sign_mask = _mm_set1_epi16(static_cast<numext::uint16_t>(0x8000));
+ return _mm_xor_si128(a, sign_mask);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf padd<Packet8bf>(const Packet8bf& a, const Packet8bf& b) {
+ return F32ToBf16(padd<Packet8f>(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf psub<Packet8bf>(const Packet8bf& a, const Packet8bf& b) {
+ return F32ToBf16(psub<Packet8f>(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf pmul<Packet8bf>(const Packet8bf& a, const Packet8bf& b) {
+ return F32ToBf16(pmul<Packet8f>(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf pdiv<Packet8bf>(const Packet8bf& a, const Packet8bf& b) {
+ return F32ToBf16(pdiv<Packet8f>(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+
+template<> EIGEN_STRONG_INLINE Packet8bf pgather<bfloat16, Packet8bf>(const bfloat16* from, Index stride)
+{
+ const numext::uint16_t s0 = numext::bit_cast<numext::uint16_t>(from[0*stride]);
+ const numext::uint16_t s1 = numext::bit_cast<numext::uint16_t>(from[1*stride]);
+ const numext::uint16_t s2 = numext::bit_cast<numext::uint16_t>(from[2*stride]);
+ const numext::uint16_t s3 = numext::bit_cast<numext::uint16_t>(from[3*stride]);
+ const numext::uint16_t s4 = numext::bit_cast<numext::uint16_t>(from[4*stride]);
+ const numext::uint16_t s5 = numext::bit_cast<numext::uint16_t>(from[5*stride]);
+ const numext::uint16_t s6 = numext::bit_cast<numext::uint16_t>(from[6*stride]);
+ const numext::uint16_t s7 = numext::bit_cast<numext::uint16_t>(from[7*stride]);
+ return _mm_set_epi16(s7, s6, s5, s4, s3, s2, s1, s0);
+}
+
+template<> EIGEN_STRONG_INLINE void pscatter<bfloat16, Packet8bf>(bfloat16* to, const Packet8bf& from, Index stride)
+{
+ EIGEN_ALIGN32 bfloat16 aux[8];
+ pstore(aux, from);
+ to[stride*0] = aux[0];
+ to[stride*1] = aux[1];
+ to[stride*2] = aux[2];
+ to[stride*3] = aux[3];
+ to[stride*4] = aux[4];
+ to[stride*5] = aux[5];
+ to[stride*6] = aux[6];
+ to[stride*7] = aux[7];
+}
+
+template<> EIGEN_STRONG_INLINE bfloat16 predux<Packet8bf>(const Packet8bf& a) {
+ return static_cast<bfloat16>(predux<Packet8f>(Bf16ToF32(a)));
+}
+
+template<> EIGEN_STRONG_INLINE bfloat16 predux_max<Packet8bf>(const Packet8bf& a) {
+ return static_cast<bfloat16>(predux_max<Packet8f>(Bf16ToF32(a)));
+}
+
+template<> EIGEN_STRONG_INLINE bfloat16 predux_min<Packet8bf>(const Packet8bf& a) {
+ return static_cast<bfloat16>(predux_min<Packet8f>(Bf16ToF32(a)));
+}
+
+template<> EIGEN_STRONG_INLINE bfloat16 predux_mul<Packet8bf>(const Packet8bf& a) {
+ return static_cast<bfloat16>(predux_mul<Packet8f>(Bf16ToF32(a)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf preverse(const Packet8bf& a)
+{
+ __m128i m = _mm_setr_epi8(14,15,12,13,10,11,8,9,6,7,4,5,2,3,0,1);
+ return _mm_shuffle_epi8(a,m);
+}
+
+EIGEN_STRONG_INLINE void
+ptranspose(PacketBlock<Packet8bf,8>& kernel) {
+ __m128i a = kernel.packet[0];
+ __m128i b = kernel.packet[1];
+ __m128i c = kernel.packet[2];
+ __m128i d = kernel.packet[3];
+ __m128i e = kernel.packet[4];
+ __m128i f = kernel.packet[5];
+ __m128i g = kernel.packet[6];
+ __m128i h = kernel.packet[7];
+
+ __m128i a03b03 = _mm_unpacklo_epi16(a, b);
+ __m128i c03d03 = _mm_unpacklo_epi16(c, d);
+ __m128i e03f03 = _mm_unpacklo_epi16(e, f);
+ __m128i g03h03 = _mm_unpacklo_epi16(g, h);
+ __m128i a47b47 = _mm_unpackhi_epi16(a, b);
+ __m128i c47d47 = _mm_unpackhi_epi16(c, d);
+ __m128i e47f47 = _mm_unpackhi_epi16(e, f);
+ __m128i g47h47 = _mm_unpackhi_epi16(g, h);
+
+ __m128i a01b01c01d01 = _mm_unpacklo_epi32(a03b03, c03d03);
+ __m128i a23b23c23d23 = _mm_unpackhi_epi32(a03b03, c03d03);
+ __m128i e01f01g01h01 = _mm_unpacklo_epi32(e03f03, g03h03);
+ __m128i e23f23g23h23 = _mm_unpackhi_epi32(e03f03, g03h03);
+ __m128i a45b45c45d45 = _mm_unpacklo_epi32(a47b47, c47d47);
+ __m128i a67b67c67d67 = _mm_unpackhi_epi32(a47b47, c47d47);
+ __m128i e45f45g45h45 = _mm_unpacklo_epi32(e47f47, g47h47);
+ __m128i e67f67g67h67 = _mm_unpackhi_epi32(e47f47, g47h47);
+
+ kernel.packet[0] = _mm_unpacklo_epi64(a01b01c01d01, e01f01g01h01);
+ kernel.packet[1] = _mm_unpackhi_epi64(a01b01c01d01, e01f01g01h01);
+ kernel.packet[2] = _mm_unpacklo_epi64(a23b23c23d23, e23f23g23h23);
+ kernel.packet[3] = _mm_unpackhi_epi64(a23b23c23d23, e23f23g23h23);
+ kernel.packet[4] = _mm_unpacklo_epi64(a45b45c45d45, e45f45g45h45);
+ kernel.packet[5] = _mm_unpackhi_epi64(a45b45c45d45, e45f45g45h45);
+ kernel.packet[6] = _mm_unpacklo_epi64(a67b67c67d67, e67f67g67h67);
+ kernel.packet[7] = _mm_unpackhi_epi64(a67b67c67d67, e67f67g67h67);
+}
+
+EIGEN_STRONG_INLINE void
+ptranspose(PacketBlock<Packet8bf,4>& kernel) {
+ __m128i a = kernel.packet[0];
+ __m128i b = kernel.packet[1];
+ __m128i c = kernel.packet[2];
+ __m128i d = kernel.packet[3];
+
+ __m128i ab_03 = _mm_unpacklo_epi16(a, b);
+ __m128i cd_03 = _mm_unpacklo_epi16(c, d);
+ __m128i ab_47 = _mm_unpackhi_epi16(a, b);
+ __m128i cd_47 = _mm_unpackhi_epi16(c, d);
+
+ kernel.packet[0] = _mm_unpacklo_epi32(ab_03, cd_03);
+ kernel.packet[1] = _mm_unpackhi_epi32(ab_03, cd_03);
+ kernel.packet[2] = _mm_unpacklo_epi32(ab_47, cd_47);
+ kernel.packet[3] = _mm_unpackhi_epi32(ab_47, cd_47);
}
} // end namespace internal
diff --git a/Eigen/src/Core/arch/AVX/TypeCasting.h b/Eigen/src/Core/arch/AVX/TypeCasting.h
index 83bfdc604..d507fb67b 100644
--- a/Eigen/src/Core/arch/AVX/TypeCasting.h
+++ b/Eigen/src/Core/arch/AVX/TypeCasting.h
@@ -35,15 +35,79 @@ struct type_casting_traits<int, float> {
};
+#ifndef EIGEN_VECTORIZE_AVX512
+
+template <>
+struct type_casting_traits<Eigen::half, float> {
+ enum {
+ VectorizedCast = 1,
+ SrcCoeffRatio = 1,
+ TgtCoeffRatio = 1
+ };
+};
+
+
+template <>
+struct type_casting_traits<float, Eigen::half> {
+ enum {
+ VectorizedCast = 1,
+ SrcCoeffRatio = 1,
+ TgtCoeffRatio = 1
+ };
+};
+
+template <>
+struct type_casting_traits<bfloat16, float> {
+ enum {
+ VectorizedCast = 1,
+ SrcCoeffRatio = 1,
+ TgtCoeffRatio = 1
+ };
+};
+
+template <>
+struct type_casting_traits<float, bfloat16> {
+ enum {
+ VectorizedCast = 1,
+ SrcCoeffRatio = 1,
+ TgtCoeffRatio = 1
+ };
+};
+
+#endif // EIGEN_VECTORIZE_AVX512
template<> EIGEN_STRONG_INLINE Packet8i pcast<Packet8f, Packet8i>(const Packet8f& a) {
- return _mm256_cvtps_epi32(a);
+ return _mm256_cvttps_epi32(a);
}
template<> EIGEN_STRONG_INLINE Packet8f pcast<Packet8i, Packet8f>(const Packet8i& a) {
return _mm256_cvtepi32_ps(a);
}
+template<> EIGEN_STRONG_INLINE Packet8i preinterpret<Packet8i,Packet8f>(const Packet8f& a) {
+ return _mm256_castps_si256(a);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8f preinterpret<Packet8f,Packet8i>(const Packet8i& a) {
+ return _mm256_castsi256_ps(a);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8f pcast<Packet8h, Packet8f>(const Packet8h& a) {
+ return half2float(a);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8f pcast<Packet8bf, Packet8f>(const Packet8bf& a) {
+ return Bf16ToF32(a);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8h pcast<Packet8f, Packet8h>(const Packet8f& a) {
+ return float2half(a);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf pcast<Packet8f, Packet8bf>(const Packet8f& a) {
+ return F32ToBf16(a);
+}
+
} // end namespace internal
} // end namespace Eigen
diff --git a/Eigen/src/Core/arch/AVX512/Complex.h b/Eigen/src/Core/arch/AVX512/Complex.h
new file mode 100644
index 000000000..49c72b3f1
--- /dev/null
+++ b/Eigen/src/Core/arch/AVX512/Complex.h
@@ -0,0 +1,422 @@
+// This file is part of Eigen, a lightweight C++ template library
+// for linear algebra.
+//
+// Copyright (C) 2018 Gael Guennebaud <gael.guennebaud@inria.fr>
+//
+// This Source Code Form is subject to the terms of the Mozilla
+// Public License v. 2.0. If a copy of the MPL was not distributed
+// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+#ifndef EIGEN_COMPLEX_AVX512_H
+#define EIGEN_COMPLEX_AVX512_H
+
+namespace Eigen {
+
+namespace internal {
+
+//---------- float ----------
+struct Packet8cf
+{
+ EIGEN_STRONG_INLINE Packet8cf() {}
+ EIGEN_STRONG_INLINE explicit Packet8cf(const __m512& a) : v(a) {}
+ __m512 v;
+};
+
+template<> struct packet_traits<std::complex<float> > : default_packet_traits
+{
+ typedef Packet8cf type;
+ typedef Packet4cf half;
+ enum {
+ Vectorizable = 1,
+ AlignedOnScalar = 1,
+ size = 8,
+ HasHalfPacket = 1,
+
+ HasAdd = 1,
+ HasSub = 1,
+ HasMul = 1,
+ HasDiv = 1,
+ HasNegate = 1,
+ HasSqrt = 1,
+ HasAbs = 0,
+ HasAbs2 = 0,
+ HasMin = 0,
+ HasMax = 0,
+ HasSetLinear = 0
+ };
+};
+
+template<> struct unpacket_traits<Packet8cf> {
+ typedef std::complex<float> type;
+ typedef Packet4cf half;
+ typedef Packet16f as_real;
+ enum {
+ size = 8,
+ alignment=unpacket_traits<Packet16f>::alignment,
+ vectorizable=true,
+ masked_load_available=false,
+ masked_store_available=false
+ };
+};
+
+template<> EIGEN_STRONG_INLINE Packet8cf ptrue<Packet8cf>(const Packet8cf& a) { return Packet8cf(ptrue(Packet16f(a.v))); }
+template<> EIGEN_STRONG_INLINE Packet8cf padd<Packet8cf>(const Packet8cf& a, const Packet8cf& b) { return Packet8cf(_mm512_add_ps(a.v,b.v)); }
+template<> EIGEN_STRONG_INLINE Packet8cf psub<Packet8cf>(const Packet8cf& a, const Packet8cf& b) { return Packet8cf(_mm512_sub_ps(a.v,b.v)); }
+template<> EIGEN_STRONG_INLINE Packet8cf pnegate(const Packet8cf& a)
+{
+ return Packet8cf(pnegate(a.v));
+}
+template<> EIGEN_STRONG_INLINE Packet8cf pconj(const Packet8cf& a)
+{
+ const __m512 mask = _mm512_castsi512_ps(_mm512_setr_epi32(
+ 0x00000000,0x80000000,0x00000000,0x80000000,0x00000000,0x80000000,0x00000000,0x80000000,
+ 0x00000000,0x80000000,0x00000000,0x80000000,0x00000000,0x80000000,0x00000000,0x80000000));
+ return Packet8cf(pxor(a.v,mask));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8cf pmul<Packet8cf>(const Packet8cf& a, const Packet8cf& b)
+{
+ __m512 tmp2 = _mm512_mul_ps(_mm512_movehdup_ps(a.v), _mm512_permute_ps(b.v, _MM_SHUFFLE(2,3,0,1)));
+ return Packet8cf(_mm512_fmaddsub_ps(_mm512_moveldup_ps(a.v), b.v, tmp2));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8cf pand <Packet8cf>(const Packet8cf& a, const Packet8cf& b) { return Packet8cf(pand(a.v,b.v)); }
+template<> EIGEN_STRONG_INLINE Packet8cf por <Packet8cf>(const Packet8cf& a, const Packet8cf& b) { return Packet8cf(por(a.v,b.v)); }
+template<> EIGEN_STRONG_INLINE Packet8cf pxor <Packet8cf>(const Packet8cf& a, const Packet8cf& b) { return Packet8cf(pxor(a.v,b.v)); }
+template<> EIGEN_STRONG_INLINE Packet8cf pandnot<Packet8cf>(const Packet8cf& a, const Packet8cf& b) { return Packet8cf(pandnot(a.v,b.v)); }
+
+template <>
+EIGEN_STRONG_INLINE Packet8cf pcmp_eq(const Packet8cf& a, const Packet8cf& b) {
+ __m512 eq = pcmp_eq<Packet16f>(a.v, b.v);
+ return Packet8cf(pand(eq, _mm512_permute_ps(eq, 0xB1)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8cf pload <Packet8cf>(const std::complex<float>* from) { EIGEN_DEBUG_ALIGNED_LOAD return Packet8cf(pload<Packet16f>(&numext::real_ref(*from))); }
+template<> EIGEN_STRONG_INLINE Packet8cf ploadu<Packet8cf>(const std::complex<float>* from) { EIGEN_DEBUG_UNALIGNED_LOAD return Packet8cf(ploadu<Packet16f>(&numext::real_ref(*from))); }
+
+
+template<> EIGEN_STRONG_INLINE Packet8cf pset1<Packet8cf>(const std::complex<float>& from)
+{
+ return Packet8cf(_mm512_castpd_ps(pload1<Packet8d>((const double*)(const void*)&from)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8cf ploaddup<Packet8cf>(const std::complex<float>* from)
+{
+ return Packet8cf( _mm512_castpd_ps( ploaddup<Packet8d>((const double*)(const void*)from )) );
+}
+template<> EIGEN_STRONG_INLINE Packet8cf ploadquad<Packet8cf>(const std::complex<float>* from)
+{
+ return Packet8cf( _mm512_castpd_ps( ploadquad<Packet8d>((const double*)(const void*)from )) );
+}
+
+template<> EIGEN_STRONG_INLINE void pstore <std::complex<float> >(std::complex<float>* to, const Packet8cf& from) { EIGEN_DEBUG_ALIGNED_STORE pstore(&numext::real_ref(*to), from.v); }
+template<> EIGEN_STRONG_INLINE void pstoreu<std::complex<float> >(std::complex<float>* to, const Packet8cf& from) { EIGEN_DEBUG_UNALIGNED_STORE pstoreu(&numext::real_ref(*to), from.v); }
+
+template<> EIGEN_DEVICE_FUNC inline Packet8cf pgather<std::complex<float>, Packet8cf>(const std::complex<float>* from, Index stride)
+{
+ return Packet8cf(_mm512_castpd_ps(pgather<double,Packet8d>((const double*)(const void*)from, stride)));
+}
+
+template<> EIGEN_DEVICE_FUNC inline void pscatter<std::complex<float>, Packet8cf>(std::complex<float>* to, const Packet8cf& from, Index stride)
+{
+ pscatter((double*)(void*)to, _mm512_castps_pd(from.v), stride);
+}
+
+template<> EIGEN_STRONG_INLINE std::complex<float> pfirst<Packet8cf>(const Packet8cf& a)
+{
+ return pfirst(Packet2cf(_mm512_castps512_ps128(a.v)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8cf preverse(const Packet8cf& a) {
+ return Packet8cf(_mm512_castsi512_ps(
+ _mm512_permutexvar_epi64( _mm512_set_epi32(0, 0, 0, 1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 6, 0, 7),
+ _mm512_castps_si512(a.v))));
+}
+
+template<> EIGEN_STRONG_INLINE std::complex<float> predux<Packet8cf>(const Packet8cf& a)
+{
+ return predux(padd(Packet4cf(extract256<0>(a.v)),
+ Packet4cf(extract256<1>(a.v))));
+}
+
+template<> EIGEN_STRONG_INLINE std::complex<float> predux_mul<Packet8cf>(const Packet8cf& a)
+{
+ return predux_mul(pmul(Packet4cf(extract256<0>(a.v)),
+ Packet4cf(extract256<1>(a.v))));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4cf predux_half_dowto4<Packet8cf>(const Packet8cf& a) {
+ __m256 lane0 = extract256<0>(a.v);
+ __m256 lane1 = extract256<1>(a.v);
+ __m256 res = _mm256_add_ps(lane0, lane1);
+ return Packet4cf(res);
+}
+
+EIGEN_MAKE_CONJ_HELPER_CPLX_REAL(Packet8cf,Packet16f)
+
+template<> EIGEN_STRONG_INLINE Packet8cf pdiv<Packet8cf>(const Packet8cf& a, const Packet8cf& b)
+{
+ Packet8cf num = pmul(a, pconj(b));
+ __m512 tmp = _mm512_mul_ps(b.v, b.v);
+ __m512 tmp2 = _mm512_shuffle_ps(tmp,tmp,0xB1);
+ __m512 denom = _mm512_add_ps(tmp, tmp2);
+ return Packet8cf(_mm512_div_ps(num.v, denom));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8cf pcplxflip<Packet8cf>(const Packet8cf& x)
+{
+ return Packet8cf(_mm512_shuffle_ps(x.v, x.v, _MM_SHUFFLE(2, 3, 0 ,1)));
+}
+
+//---------- double ----------
+struct Packet4cd
+{
+ EIGEN_STRONG_INLINE Packet4cd() {}
+ EIGEN_STRONG_INLINE explicit Packet4cd(const __m512d& a) : v(a) {}
+ __m512d v;
+};
+
+template<> struct packet_traits<std::complex<double> > : default_packet_traits
+{
+ typedef Packet4cd type;
+ typedef Packet2cd half;
+ enum {
+ Vectorizable = 1,
+ AlignedOnScalar = 0,
+ size = 4,
+ HasHalfPacket = 1,
+
+ HasAdd = 1,
+ HasSub = 1,
+ HasMul = 1,
+ HasDiv = 1,
+ HasNegate = 1,
+ HasSqrt = 1,
+ HasAbs = 0,
+ HasAbs2 = 0,
+ HasMin = 0,
+ HasMax = 0,
+ HasSetLinear = 0
+ };
+};
+
+template<> struct unpacket_traits<Packet4cd> {
+ typedef std::complex<double> type;
+ typedef Packet2cd half;
+ typedef Packet8d as_real;
+ enum {
+ size = 4,
+ alignment = unpacket_traits<Packet8d>::alignment,
+ vectorizable=true,
+ masked_load_available=false,
+ masked_store_available=false
+ };
+};
+
+template<> EIGEN_STRONG_INLINE Packet4cd padd<Packet4cd>(const Packet4cd& a, const Packet4cd& b) { return Packet4cd(_mm512_add_pd(a.v,b.v)); }
+template<> EIGEN_STRONG_INLINE Packet4cd psub<Packet4cd>(const Packet4cd& a, const Packet4cd& b) { return Packet4cd(_mm512_sub_pd(a.v,b.v)); }
+template<> EIGEN_STRONG_INLINE Packet4cd pnegate(const Packet4cd& a) { return Packet4cd(pnegate(a.v)); }
+template<> EIGEN_STRONG_INLINE Packet4cd pconj(const Packet4cd& a)
+{
+ const __m512d mask = _mm512_castsi512_pd(
+ _mm512_set_epi32(0x80000000,0x0,0x0,0x0,0x80000000,0x0,0x0,0x0,
+ 0x80000000,0x0,0x0,0x0,0x80000000,0x0,0x0,0x0));
+ return Packet4cd(pxor(a.v,mask));
+}
+
+template<> EIGEN_STRONG_INLINE Packet4cd pmul<Packet4cd>(const Packet4cd& a, const Packet4cd& b)
+{
+ __m512d tmp1 = _mm512_shuffle_pd(a.v,a.v,0x0);
+ __m512d tmp2 = _mm512_shuffle_pd(a.v,a.v,0xFF);
+ __m512d tmp3 = _mm512_shuffle_pd(b.v,b.v,0x55);
+ __m512d odd = _mm512_mul_pd(tmp2, tmp3);
+ return Packet4cd(_mm512_fmaddsub_pd(tmp1, b.v, odd));
+}
+
+template<> EIGEN_STRONG_INLINE Packet4cd ptrue<Packet4cd>(const Packet4cd& a) { return Packet4cd(ptrue(Packet8d(a.v))); }
+template<> EIGEN_STRONG_INLINE Packet4cd pand <Packet4cd>(const Packet4cd& a, const Packet4cd& b) { return Packet4cd(pand(a.v,b.v)); }
+template<> EIGEN_STRONG_INLINE Packet4cd por <Packet4cd>(const Packet4cd& a, const Packet4cd& b) { return Packet4cd(por(a.v,b.v)); }
+template<> EIGEN_STRONG_INLINE Packet4cd pxor <Packet4cd>(const Packet4cd& a, const Packet4cd& b) { return Packet4cd(pxor(a.v,b.v)); }
+template<> EIGEN_STRONG_INLINE Packet4cd pandnot<Packet4cd>(const Packet4cd& a, const Packet4cd& b) { return Packet4cd(pandnot(a.v,b.v)); }
+
+template <>
+EIGEN_STRONG_INLINE Packet4cd pcmp_eq(const Packet4cd& a, const Packet4cd& b) {
+ __m512d eq = pcmp_eq<Packet8d>(a.v, b.v);
+ return Packet4cd(pand(eq, _mm512_permute_pd(eq, 0x55)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet4cd pload <Packet4cd>(const std::complex<double>* from)
+{ EIGEN_DEBUG_ALIGNED_LOAD return Packet4cd(pload<Packet8d>((const double*)from)); }
+template<> EIGEN_STRONG_INLINE Packet4cd ploadu<Packet4cd>(const std::complex<double>* from)
+{ EIGEN_DEBUG_UNALIGNED_LOAD return Packet4cd(ploadu<Packet8d>((const double*)from)); }
+
+template<> EIGEN_STRONG_INLINE Packet4cd pset1<Packet4cd>(const std::complex<double>& from)
+{
+ #ifdef EIGEN_VECTORIZE_AVX512DQ
+ return Packet4cd(_mm512_broadcast_f64x2(pset1<Packet1cd>(from).v));
+ #else
+ return Packet4cd(_mm512_castps_pd(_mm512_broadcast_f32x4( _mm_castpd_ps(pset1<Packet1cd>(from).v))));
+ #endif
+}
+
+template<> EIGEN_STRONG_INLINE Packet4cd ploaddup<Packet4cd>(const std::complex<double>* from) {
+ return Packet4cd(_mm512_insertf64x4(
+ _mm512_castpd256_pd512(ploaddup<Packet2cd>(from).v), ploaddup<Packet2cd>(from+1).v, 1));
+}
+
+template<> EIGEN_STRONG_INLINE void pstore <std::complex<double> >(std::complex<double> * to, const Packet4cd& from) { EIGEN_DEBUG_ALIGNED_STORE pstore((double*)to, from.v); }
+template<> EIGEN_STRONG_INLINE void pstoreu<std::complex<double> >(std::complex<double> * to, const Packet4cd& from) { EIGEN_DEBUG_UNALIGNED_STORE pstoreu((double*)to, from.v); }
+
+template<> EIGEN_DEVICE_FUNC inline Packet4cd pgather<std::complex<double>, Packet4cd>(const std::complex<double>* from, Index stride)
+{
+ return Packet4cd(_mm512_insertf64x4(_mm512_castpd256_pd512(
+ _mm256_insertf128_pd(_mm256_castpd128_pd256(ploadu<Packet1cd>(from+0*stride).v), ploadu<Packet1cd>(from+1*stride).v,1)),
+ _mm256_insertf128_pd(_mm256_castpd128_pd256(ploadu<Packet1cd>(from+2*stride).v), ploadu<Packet1cd>(from+3*stride).v,1), 1));
+}
+
+template<> EIGEN_DEVICE_FUNC inline void pscatter<std::complex<double>, Packet4cd>(std::complex<double>* to, const Packet4cd& from, Index stride)
+{
+ __m512i fromi = _mm512_castpd_si512(from.v);
+ double* tod = (double*)(void*)to;
+ _mm_storeu_pd(tod+0*stride, _mm_castsi128_pd(_mm512_extracti32x4_epi32(fromi,0)) );
+ _mm_storeu_pd(tod+2*stride, _mm_castsi128_pd(_mm512_extracti32x4_epi32(fromi,1)) );
+ _mm_storeu_pd(tod+4*stride, _mm_castsi128_pd(_mm512_extracti32x4_epi32(fromi,2)) );
+ _mm_storeu_pd(tod+6*stride, _mm_castsi128_pd(_mm512_extracti32x4_epi32(fromi,3)) );
+}
+
+template<> EIGEN_STRONG_INLINE std::complex<double> pfirst<Packet4cd>(const Packet4cd& a)
+{
+ __m128d low = extract128<0>(a.v);
+ EIGEN_ALIGN16 double res[2];
+ _mm_store_pd(res, low);
+ return std::complex<double>(res[0],res[1]);
+}
+
+template<> EIGEN_STRONG_INLINE Packet4cd preverse(const Packet4cd& a) {
+ return Packet4cd(_mm512_shuffle_f64x2(a.v, a.v, (shuffle_mask<3,2,1,0>::mask)));
+}
+
+template<> EIGEN_STRONG_INLINE std::complex<double> predux<Packet4cd>(const Packet4cd& a)
+{
+ return predux(padd(Packet2cd(_mm512_extractf64x4_pd(a.v,0)),
+ Packet2cd(_mm512_extractf64x4_pd(a.v,1))));
+}
+
+template<> EIGEN_STRONG_INLINE std::complex<double> predux_mul<Packet4cd>(const Packet4cd& a)
+{
+ return predux_mul(pmul(Packet2cd(_mm512_extractf64x4_pd(a.v,0)),
+ Packet2cd(_mm512_extractf64x4_pd(a.v,1))));
+}
+
+template<> struct conj_helper<Packet4cd, Packet4cd, false,true>
+{
+ EIGEN_STRONG_INLINE Packet4cd pmadd(const Packet4cd& x, const Packet4cd& y, const Packet4cd& c) const
+ { return padd(pmul(x,y),c); }
+
+ EIGEN_STRONG_INLINE Packet4cd pmul(const Packet4cd& a, const Packet4cd& b) const
+ {
+ return internal::pmul(a, pconj(b));
+ }
+};
+
+template<> struct conj_helper<Packet4cd, Packet4cd, true,false>
+{
+ EIGEN_STRONG_INLINE Packet4cd pmadd(const Packet4cd& x, const Packet4cd& y, const Packet4cd& c) const
+ { return padd(pmul(x,y),c); }
+
+ EIGEN_STRONG_INLINE Packet4cd pmul(const Packet4cd& a, const Packet4cd& b) const
+ {
+ return internal::pmul(pconj(a), b);
+ }
+};
+
+template<> struct conj_helper<Packet4cd, Packet4cd, true,true>
+{
+ EIGEN_STRONG_INLINE Packet4cd pmadd(const Packet4cd& x, const Packet4cd& y, const Packet4cd& c) const
+ { return padd(pmul(x,y),c); }
+
+ EIGEN_STRONG_INLINE Packet4cd pmul(const Packet4cd& a, const Packet4cd& b) const
+ {
+ return pconj(internal::pmul(a, b));
+ }
+};
+
+EIGEN_MAKE_CONJ_HELPER_CPLX_REAL(Packet4cd,Packet8d)
+
+template<> EIGEN_STRONG_INLINE Packet4cd pdiv<Packet4cd>(const Packet4cd& a, const Packet4cd& b)
+{
+ Packet4cd num = pmul(a, pconj(b));
+ __m512d tmp = _mm512_mul_pd(b.v, b.v);
+ __m512d denom = padd(_mm512_permute_pd(tmp,0x55), tmp);
+ return Packet4cd(_mm512_div_pd(num.v, denom));
+}
+
+template<> EIGEN_STRONG_INLINE Packet4cd pcplxflip<Packet4cd>(const Packet4cd& x)
+{
+ return Packet4cd(_mm512_permute_pd(x.v,0x55));
+}
+
+EIGEN_DEVICE_FUNC inline void
+ptranspose(PacketBlock<Packet8cf,4>& kernel) {
+ PacketBlock<Packet8d,4> pb;
+
+ pb.packet[0] = _mm512_castps_pd(kernel.packet[0].v);
+ pb.packet[1] = _mm512_castps_pd(kernel.packet[1].v);
+ pb.packet[2] = _mm512_castps_pd(kernel.packet[2].v);
+ pb.packet[3] = _mm512_castps_pd(kernel.packet[3].v);
+ ptranspose(pb);
+ kernel.packet[0].v = _mm512_castpd_ps(pb.packet[0]);
+ kernel.packet[1].v = _mm512_castpd_ps(pb.packet[1]);
+ kernel.packet[2].v = _mm512_castpd_ps(pb.packet[2]);
+ kernel.packet[3].v = _mm512_castpd_ps(pb.packet[3]);
+}
+
+EIGEN_DEVICE_FUNC inline void
+ptranspose(PacketBlock<Packet8cf,8>& kernel) {
+ PacketBlock<Packet8d,8> pb;
+
+ pb.packet[0] = _mm512_castps_pd(kernel.packet[0].v);
+ pb.packet[1] = _mm512_castps_pd(kernel.packet[1].v);
+ pb.packet[2] = _mm512_castps_pd(kernel.packet[2].v);
+ pb.packet[3] = _mm512_castps_pd(kernel.packet[3].v);
+ pb.packet[4] = _mm512_castps_pd(kernel.packet[4].v);
+ pb.packet[5] = _mm512_castps_pd(kernel.packet[5].v);
+ pb.packet[6] = _mm512_castps_pd(kernel.packet[6].v);
+ pb.packet[7] = _mm512_castps_pd(kernel.packet[7].v);
+ ptranspose(pb);
+ kernel.packet[0].v = _mm512_castpd_ps(pb.packet[0]);
+ kernel.packet[1].v = _mm512_castpd_ps(pb.packet[1]);
+ kernel.packet[2].v = _mm512_castpd_ps(pb.packet[2]);
+ kernel.packet[3].v = _mm512_castpd_ps(pb.packet[3]);
+ kernel.packet[4].v = _mm512_castpd_ps(pb.packet[4]);
+ kernel.packet[5].v = _mm512_castpd_ps(pb.packet[5]);
+ kernel.packet[6].v = _mm512_castpd_ps(pb.packet[6]);
+ kernel.packet[7].v = _mm512_castpd_ps(pb.packet[7]);
+}
+
+EIGEN_DEVICE_FUNC inline void
+ptranspose(PacketBlock<Packet4cd,4>& kernel) {
+ __m512d T0 = _mm512_shuffle_f64x2(kernel.packet[0].v, kernel.packet[1].v, (shuffle_mask<0,1,0,1>::mask)); // [a0 a1 b0 b1]
+ __m512d T1 = _mm512_shuffle_f64x2(kernel.packet[0].v, kernel.packet[1].v, (shuffle_mask<2,3,2,3>::mask)); // [a2 a3 b2 b3]
+ __m512d T2 = _mm512_shuffle_f64x2(kernel.packet[2].v, kernel.packet[3].v, (shuffle_mask<0,1,0,1>::mask)); // [c0 c1 d0 d1]
+ __m512d T3 = _mm512_shuffle_f64x2(kernel.packet[2].v, kernel.packet[3].v, (shuffle_mask<2,3,2,3>::mask)); // [c2 c3 d2 d3]
+
+ kernel.packet[3] = Packet4cd(_mm512_shuffle_f64x2(T1, T3, (shuffle_mask<1,3,1,3>::mask))); // [a3 b3 c3 d3]
+ kernel.packet[2] = Packet4cd(_mm512_shuffle_f64x2(T1, T3, (shuffle_mask<0,2,0,2>::mask))); // [a2 b2 c2 d2]
+ kernel.packet[1] = Packet4cd(_mm512_shuffle_f64x2(T0, T2, (shuffle_mask<1,3,1,3>::mask))); // [a1 b1 c1 d1]
+ kernel.packet[0] = Packet4cd(_mm512_shuffle_f64x2(T0, T2, (shuffle_mask<0,2,0,2>::mask))); // [a0 b0 c0 d0]
+}
+
+template<> EIGEN_STRONG_INLINE Packet4cd psqrt<Packet4cd>(const Packet4cd& a) {
+ return psqrt_complex<Packet4cd>(a);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8cf psqrt<Packet8cf>(const Packet8cf& a) {
+ return psqrt_complex<Packet8cf>(a);
+}
+
+} // end namespace internal
+} // end namespace Eigen
+
+#endif // EIGEN_COMPLEX_AVX512_H
diff --git a/Eigen/src/Core/arch/AVX512/MathFunctions.h b/Eigen/src/Core/arch/AVX512/MathFunctions.h
index 399be0ee4..6fd726d29 100644
--- a/Eigen/src/Core/arch/AVX512/MathFunctions.h
+++ b/Eigen/src/Core/arch/AVX512/MathFunctions.h
@@ -15,13 +15,13 @@ namespace Eigen {
namespace internal {
// Disable the code for older versions of gcc that don't support many of the required avx512 instrinsics.
-#if EIGEN_GNUC_AT_LEAST(5, 3)
+#if EIGEN_GNUC_AT_LEAST(5, 3) || EIGEN_COMP_CLANG || EIGEN_COMP_MSVC >= 1923
#define _EIGEN_DECLARE_CONST_Packet16f(NAME, X) \
const Packet16f p16f_##NAME = pset1<Packet16f>(X)
#define _EIGEN_DECLARE_CONST_Packet16f_FROM_INT(NAME, X) \
- const Packet16f p16f_##NAME = (__m512)pset1<Packet16i>(X)
+ const Packet16f p16f_##NAME = preinterpret<Packet16f,Packet16i>(pset1<Packet16i>(X))
#define _EIGEN_DECLARE_CONST_Packet8d(NAME, X) \
const Packet8d p8d_##NAME = pset1<Packet8d>(X)
@@ -29,100 +29,41 @@ namespace internal {
#define _EIGEN_DECLARE_CONST_Packet8d_FROM_INT64(NAME, X) \
const Packet8d p8d_##NAME = _mm512_castsi512_pd(_mm512_set1_epi64(X))
-// Natural logarithm
-// Computes log(x) as log(2^e * m) = C*e + log(m), where the constant C =log(2)
-// and m is in the range [sqrt(1/2),sqrt(2)). In this range, the logarithm can
-// be easily approximated by a polynomial centered on m=1 for stability.
-#if defined(EIGEN_VECTORIZE_AVX512DQ)
+#define _EIGEN_DECLARE_CONST_Packet16bf(NAME, X) \
+ const Packet16bf p16bf_##NAME = pset1<Packet16bf>(X)
+
+#define _EIGEN_DECLARE_CONST_Packet16bf_FROM_INT(NAME, X) \
+ const Packet16bf p16bf_##NAME = preinterpret<Packet16bf,Packet16i>(pset1<Packet16i>(X))
+
template <>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet16f
plog<Packet16f>(const Packet16f& _x) {
- Packet16f x = _x;
- _EIGEN_DECLARE_CONST_Packet16f(1, 1.0f);
- _EIGEN_DECLARE_CONST_Packet16f(half, 0.5f);
- _EIGEN_DECLARE_CONST_Packet16f(126f, 126.0f);
-
- _EIGEN_DECLARE_CONST_Packet16f_FROM_INT(inv_mant_mask, ~0x7f800000);
-
- // The smallest non denormalized float number.
- _EIGEN_DECLARE_CONST_Packet16f_FROM_INT(min_norm_pos, 0x00800000);
- _EIGEN_DECLARE_CONST_Packet16f_FROM_INT(minus_inf, 0xff800000);
- _EIGEN_DECLARE_CONST_Packet16f_FROM_INT(nan, 0x7fc00000);
-
- // Polynomial coefficients.
- _EIGEN_DECLARE_CONST_Packet16f(cephes_SQRTHF, 0.707106781186547524f);
- _EIGEN_DECLARE_CONST_Packet16f(cephes_log_p0, 7.0376836292E-2f);
- _EIGEN_DECLARE_CONST_Packet16f(cephes_log_p1, -1.1514610310E-1f);
- _EIGEN_DECLARE_CONST_Packet16f(cephes_log_p2, 1.1676998740E-1f);
- _EIGEN_DECLARE_CONST_Packet16f(cephes_log_p3, -1.2420140846E-1f);
- _EIGEN_DECLARE_CONST_Packet16f(cephes_log_p4, +1.4249322787E-1f);
- _EIGEN_DECLARE_CONST_Packet16f(cephes_log_p5, -1.6668057665E-1f);
- _EIGEN_DECLARE_CONST_Packet16f(cephes_log_p6, +2.0000714765E-1f);
- _EIGEN_DECLARE_CONST_Packet16f(cephes_log_p7, -2.4999993993E-1f);
- _EIGEN_DECLARE_CONST_Packet16f(cephes_log_p8, +3.3333331174E-1f);
- _EIGEN_DECLARE_CONST_Packet16f(cephes_log_q1, -2.12194440e-4f);
- _EIGEN_DECLARE_CONST_Packet16f(cephes_log_q2, 0.693359375f);
-
- // invalid_mask is set to true when x is NaN
- __mmask16 invalid_mask =
- _mm512_cmp_ps_mask(x, _mm512_setzero_ps(), _CMP_NGE_UQ);
- __mmask16 iszero_mask =
- _mm512_cmp_ps_mask(x, _mm512_setzero_ps(), _CMP_EQ_UQ);
-
- // Truncate input values to the minimum positive normal.
- x = pmax(x, p16f_min_norm_pos);
-
- // Extract the shifted exponents.
- Packet16f emm0 = _mm512_cvtepi32_ps(_mm512_srli_epi32((__m512i)x, 23));
- Packet16f e = _mm512_sub_ps(emm0, p16f_126f);
-
- // Set the exponents to -1, i.e. x are in the range [0.5,1).
- x = _mm512_and_ps(x, p16f_inv_mant_mask);
- x = _mm512_or_ps(x, p16f_half);
-
- // part2: Shift the inputs from the range [0.5,1) to [sqrt(1/2),sqrt(2))
- // and shift by -1. The values are then centered around 0, which improves
- // the stability of the polynomial evaluation.
- // if( x < SQRTHF ) {
- // e -= 1;
- // x = x + x - 1.0;
- // } else { x = x - 1.0; }
- __mmask16 mask = _mm512_cmp_ps_mask(x, p16f_cephes_SQRTHF, _CMP_LT_OQ);
- Packet16f tmp = _mm512_mask_blend_ps(mask, x, _mm512_setzero_ps());
- x = psub(x, p16f_1);
- e = psub(e, _mm512_mask_blend_ps(mask, p16f_1, _mm512_setzero_ps()));
- x = padd(x, tmp);
-
- Packet16f x2 = pmul(x, x);
- Packet16f x3 = pmul(x2, x);
-
- // Evaluate the polynomial approximant of degree 8 in three parts, probably
- // to improve instruction-level parallelism.
- Packet16f y, y1, y2;
- y = pmadd(p16f_cephes_log_p0, x, p16f_cephes_log_p1);
- y1 = pmadd(p16f_cephes_log_p3, x, p16f_cephes_log_p4);
- y2 = pmadd(p16f_cephes_log_p6, x, p16f_cephes_log_p7);
- y = pmadd(y, x, p16f_cephes_log_p2);
- y1 = pmadd(y1, x, p16f_cephes_log_p5);
- y2 = pmadd(y2, x, p16f_cephes_log_p8);
- y = pmadd(y, x3, y1);
- y = pmadd(y, x3, y2);
- y = pmul(y, x3);
-
- // Add the logarithm of the exponent back to the result of the interpolation.
- y1 = pmul(e, p16f_cephes_log_q1);
- tmp = pmul(x2, p16f_half);
- y = padd(y, y1);
- x = psub(x, tmp);
- y2 = pmul(e, p16f_cephes_log_q2);
- x = padd(x, y);
- x = padd(x, y2);
-
- // Filter out invalid inputs, i.e. negative arg will be NAN, 0 will be -INF.
- return _mm512_mask_blend_ps(iszero_mask, p16f_minus_inf,
- _mm512_mask_blend_ps(invalid_mask, p16f_nan, x));
+ return plog_float(_x);
}
-#endif
+
+template <>
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet8d
+plog<Packet8d>(const Packet8d& _x) {
+ return plog_double(_x);
+}
+
+F16_PACKET_FUNCTION(Packet16f, Packet16h, plog)
+BF16_PACKET_FUNCTION(Packet16f, Packet16bf, plog)
+
+template <>
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet16f
+plog2<Packet16f>(const Packet16f& _x) {
+ return plog2_float(_x);
+}
+
+template <>
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet8d
+plog2<Packet8d>(const Packet8d& _x) {
+ return plog2_double(_x);
+}
+
+F16_PACKET_FUNCTION(Packet16f, Packet16h, plog2)
+BF16_PACKET_FUNCTION(Packet16f, Packet16bf, plog2)
// Exponential function. Works by writing "x = m*log(2) + r" where
// "m = floor(x/log(2)+1/2)" and "r" is the remainder. The result is then
@@ -158,17 +99,17 @@ pexp<Packet16f>(const Packet16f& _x) {
_EIGEN_DECLARE_CONST_Packet16f(nln2, -0.6931471805599453f);
Packet16f r = _mm512_fmadd_ps(m, p16f_nln2, x);
Packet16f r2 = pmul(r, r);
+ Packet16f r3 = pmul(r2, r);
- // TODO(gonnet): Split into odd/even polynomials and try to exploit
- // instruction-level parallelism.
- Packet16f y = p16f_cephes_exp_p0;
- y = pmadd(y, r, p16f_cephes_exp_p1);
- y = pmadd(y, r, p16f_cephes_exp_p2);
- y = pmadd(y, r, p16f_cephes_exp_p3);
- y = pmadd(y, r, p16f_cephes_exp_p4);
- y = pmadd(y, r, p16f_cephes_exp_p5);
- y = pmadd(y, r2, r);
- y = padd(y, p16f_1);
+ // Evaluate the polynomial approximant,improved by instruction-level parallelism.
+ Packet16f y, y1, y2;
+ y = pmadd(p16f_cephes_exp_p0, r, p16f_cephes_exp_p1);
+ y1 = pmadd(p16f_cephes_exp_p3, r, p16f_cephes_exp_p4);
+ y2 = padd(r, p16f_1);
+ y = pmadd(y, r, p16f_cephes_exp_p2);
+ y1 = pmadd(y1, r, p16f_cephes_exp_p5);
+ y = pmadd(y, r3, y1);
+ y = pmadd(y, r2, y2);
// Build emm0 = 2^m.
Packet16i emm0 = _mm512_cvttps_epi32(padd(m, p16f_127));
@@ -178,74 +119,40 @@ pexp<Packet16f>(const Packet16f& _x) {
return pmax(pmul(y, _mm512_castsi512_ps(emm0)), _x);
}
-/*template <>
+template <>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet8d
pexp<Packet8d>(const Packet8d& _x) {
- Packet8d x = _x;
-
- _EIGEN_DECLARE_CONST_Packet8d(1, 1.0);
- _EIGEN_DECLARE_CONST_Packet8d(2, 2.0);
-
- _EIGEN_DECLARE_CONST_Packet8d(exp_hi, 709.437);
- _EIGEN_DECLARE_CONST_Packet8d(exp_lo, -709.436139303);
-
- _EIGEN_DECLARE_CONST_Packet8d(cephes_LOG2EF, 1.4426950408889634073599);
-
- _EIGEN_DECLARE_CONST_Packet8d(cephes_exp_p0, 1.26177193074810590878e-4);
- _EIGEN_DECLARE_CONST_Packet8d(cephes_exp_p1, 3.02994407707441961300e-2);
- _EIGEN_DECLARE_CONST_Packet8d(cephes_exp_p2, 9.99999999999999999910e-1);
-
- _EIGEN_DECLARE_CONST_Packet8d(cephes_exp_q0, 3.00198505138664455042e-6);
- _EIGEN_DECLARE_CONST_Packet8d(cephes_exp_q1, 2.52448340349684104192e-3);
- _EIGEN_DECLARE_CONST_Packet8d(cephes_exp_q2, 2.27265548208155028766e-1);
- _EIGEN_DECLARE_CONST_Packet8d(cephes_exp_q3, 2.00000000000000000009e0);
-
- _EIGEN_DECLARE_CONST_Packet8d(cephes_exp_C1, 0.693145751953125);
- _EIGEN_DECLARE_CONST_Packet8d(cephes_exp_C2, 1.42860682030941723212e-6);
-
- // clamp x
- x = pmax(pmin(x, p8d_exp_hi), p8d_exp_lo);
-
- // Express exp(x) as exp(g + n*log(2)).
- const Packet8d n =
- _mm512_mul_round_pd(p8d_cephes_LOG2EF, x, _MM_FROUND_TO_NEAREST_INT);
-
- // Get the remainder modulo log(2), i.e. the "g" described above. Subtract
- // n*log(2) out in two steps, i.e. n*C1 + n*C2, C1+C2=log2 to get the last
- // digits right.
- const Packet8d nC1 = pmul(n, p8d_cephes_exp_C1);
- const Packet8d nC2 = pmul(n, p8d_cephes_exp_C2);
- x = psub(x, nC1);
- x = psub(x, nC2);
-
- const Packet8d x2 = pmul(x, x);
+ return pexp_double(_x);
+}
- // Evaluate the numerator polynomial of the rational interpolant.
- Packet8d px = p8d_cephes_exp_p0;
- px = pmadd(px, x2, p8d_cephes_exp_p1);
- px = pmadd(px, x2, p8d_cephes_exp_p2);
- px = pmul(px, x);
+F16_PACKET_FUNCTION(Packet16f, Packet16h, pexp)
+BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pexp)
- // Evaluate the denominator polynomial of the rational interpolant.
- Packet8d qx = p8d_cephes_exp_q0;
- qx = pmadd(qx, x2, p8d_cephes_exp_q1);
- qx = pmadd(qx, x2, p8d_cephes_exp_q2);
- qx = pmadd(qx, x2, p8d_cephes_exp_q3);
+template <>
+EIGEN_STRONG_INLINE Packet16h pfrexp(const Packet16h& a, Packet16h& exponent) {
+ Packet16f fexponent;
+ const Packet16h out = float2half(pfrexp<Packet16f>(half2float(a), fexponent));
+ exponent = float2half(fexponent);
+ return out;
+}
- // I don't really get this bit, copied from the SSE2 routines, so...
- // TODO(gonnet): Figure out what is going on here, perhaps find a better
- // rational interpolant?
- x = _mm512_div_pd(px, psub(qx, px));
- x = pmadd(p8d_2, x, p8d_1);
+template <>
+EIGEN_STRONG_INLINE Packet16h pldexp(const Packet16h& a, const Packet16h& exponent) {
+ return float2half(pldexp<Packet16f>(half2float(a), half2float(exponent)));
+}
- // Build e=2^n.
- const Packet8d e = _mm512_castsi512_pd(_mm512_slli_epi64(
- _mm512_add_epi64(_mm512_cvtpd_epi64(n), _mm512_set1_epi64(1023)), 52));
+template <>
+EIGEN_STRONG_INLINE Packet16bf pfrexp(const Packet16bf& a, Packet16bf& exponent) {
+ Packet16f fexponent;
+ const Packet16bf out = F32ToBf16(pfrexp<Packet16f>(Bf16ToF32(a), fexponent));
+ exponent = F32ToBf16(fexponent);
+ return out;
+}
- // Construct the result 2^n * exp(g) = e * x. The max is used to catch
- // non-finite values in the input.
- return pmax(pmul(x, e), _x);
- }*/
+template <>
+EIGEN_STRONG_INLINE Packet16bf pldexp(const Packet16bf& a, const Packet16bf& exponent) {
+ return F32ToBf16(pldexp<Packet16f>(Bf16ToF32(a), Bf16ToF32(exponent)));
+}
// Functions for sqrt.
// The EIGEN_FAST_MATH version uses the _mm_rsqrt_ps approximation and one step
@@ -257,138 +164,197 @@ pexp<Packet8d>(const Packet8d& _x) {
template <>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet16f
psqrt<Packet16f>(const Packet16f& _x) {
- _EIGEN_DECLARE_CONST_Packet16f(one_point_five, 1.5f);
- _EIGEN_DECLARE_CONST_Packet16f(minus_half, -0.5f);
- _EIGEN_DECLARE_CONST_Packet16f_FROM_INT(flt_min, 0x00800000);
-
- Packet16f neg_half = pmul(_x, p16f_minus_half);
+ Packet16f neg_half = pmul(_x, pset1<Packet16f>(-.5f));
+ __mmask16 denormal_mask = _mm512_kand(
+ _mm512_cmp_ps_mask(_x, pset1<Packet16f>((std::numeric_limits<float>::min)()),
+ _CMP_LT_OQ),
+ _mm512_cmp_ps_mask(_x, _mm512_setzero_ps(), _CMP_GE_OQ));
- // select only the inverse sqrt of positive normal inputs (denormals are
- // flushed to zero and cause infs as well).
- __mmask16 non_zero_mask = _mm512_cmp_ps_mask(_x, p16f_flt_min, _CMP_GE_OQ);
- Packet16f x = _mm512_mask_blend_ps(non_zero_mask, _mm512_rsqrt14_ps(_x),
- _mm512_setzero_ps());
+ Packet16f x = _mm512_rsqrt14_ps(_x);
// Do a single step of Newton's iteration.
- x = pmul(x, pmadd(neg_half, pmul(x, x), p16f_one_point_five));
+ x = pmul(x, pmadd(neg_half, pmul(x, x), pset1<Packet16f>(1.5f)));
- // Multiply the original _x by it's reciprocal square root to extract the
- // square root.
- return pmul(_x, x);
+ // Flush results for denormals to zero.
+ return _mm512_mask_blend_ps(denormal_mask, pmul(_x,x), _mm512_setzero_ps());
}
template <>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet8d
psqrt<Packet8d>(const Packet8d& _x) {
- _EIGEN_DECLARE_CONST_Packet8d(one_point_five, 1.5);
- _EIGEN_DECLARE_CONST_Packet8d(minus_half, -0.5);
- _EIGEN_DECLARE_CONST_Packet8d_FROM_INT64(dbl_min, 0x0010000000000000LL);
-
- Packet8d neg_half = pmul(_x, p8d_minus_half);
+ Packet8d neg_half = pmul(_x, pset1<Packet8d>(-.5));
+ __mmask16 denormal_mask = _mm512_kand(
+ _mm512_cmp_pd_mask(_x, pset1<Packet8d>((std::numeric_limits<double>::min)()),
+ _CMP_LT_OQ),
+ _mm512_cmp_pd_mask(_x, _mm512_setzero_pd(), _CMP_GE_OQ));
- // select only the inverse sqrt of positive normal inputs (denormals are
- // flushed to zero and cause infs as well).
- __mmask8 non_zero_mask = _mm512_cmp_pd_mask(_x, p8d_dbl_min, _CMP_GE_OQ);
- Packet8d x = _mm512_mask_blend_pd(non_zero_mask, _mm512_rsqrt14_pd(_x),
- _mm512_setzero_pd());
+ Packet8d x = _mm512_rsqrt14_pd(_x);
- // Do a first step of Newton's iteration.
- x = pmul(x, pmadd(neg_half, pmul(x, x), p8d_one_point_five));
+ // Do a single step of Newton's iteration.
+ x = pmul(x, pmadd(neg_half, pmul(x, x), pset1<Packet8d>(1.5)));
// Do a second step of Newton's iteration.
- x = pmul(x, pmadd(neg_half, pmul(x, x), p8d_one_point_five));
+ x = pmul(x, pmadd(neg_half, pmul(x, x), pset1<Packet8d>(1.5)));
- // Multiply the original _x by it's reciprocal square root to extract the
- // square root.
- return pmul(_x, x);
+ return _mm512_mask_blend_pd(denormal_mask, pmul(_x,x), _mm512_setzero_pd());
}
#else
template <>
EIGEN_STRONG_INLINE Packet16f psqrt<Packet16f>(const Packet16f& x) {
return _mm512_sqrt_ps(x);
}
+
template <>
EIGEN_STRONG_INLINE Packet8d psqrt<Packet8d>(const Packet8d& x) {
return _mm512_sqrt_pd(x);
}
#endif
-// Functions for rsqrt.
-// Almost identical to the sqrt routine, just leave out the last multiplication
-// and fill in NaN/Inf where needed. Note that this function only exists as an
-// iterative version for doubles since there is no instruction for diretly
-// computing the reciprocal square root in AVX-512.
-#ifdef EIGEN_FAST_MATH
+F16_PACKET_FUNCTION(Packet16f, Packet16h, psqrt)
+BF16_PACKET_FUNCTION(Packet16f, Packet16bf, psqrt)
+
+// prsqrt for float.
+#if defined(EIGEN_VECTORIZE_AVX512ER)
+
+template <>
+EIGEN_STRONG_INLINE Packet16f prsqrt<Packet16f>(const Packet16f& x) {
+ return _mm512_rsqrt28_ps(x);
+}
+#elif EIGEN_FAST_MATH
+
template <>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet16f
prsqrt<Packet16f>(const Packet16f& _x) {
_EIGEN_DECLARE_CONST_Packet16f_FROM_INT(inf, 0x7f800000);
- _EIGEN_DECLARE_CONST_Packet16f_FROM_INT(nan, 0x7fc00000);
_EIGEN_DECLARE_CONST_Packet16f(one_point_five, 1.5f);
_EIGEN_DECLARE_CONST_Packet16f(minus_half, -0.5f);
- _EIGEN_DECLARE_CONST_Packet16f_FROM_INT(flt_min, 0x00800000);
Packet16f neg_half = pmul(_x, p16f_minus_half);
- // select only the inverse sqrt of positive normal inputs (denormals are
- // flushed to zero and cause infs as well).
- __mmask16 le_zero_mask = _mm512_cmp_ps_mask(_x, p16f_flt_min, _CMP_LT_OQ);
- Packet16f x = _mm512_mask_blend_ps(le_zero_mask, _mm512_setzero_ps(),
- _mm512_rsqrt14_ps(_x));
-
- // Fill in NaNs and Infs for the negative/zero entries.
- __mmask16 neg_mask = _mm512_cmp_ps_mask(_x, _mm512_setzero_ps(), _CMP_LT_OQ);
- Packet16f infs_and_nans = _mm512_mask_blend_ps(
- neg_mask, p16f_nan,
- _mm512_mask_blend_ps(le_zero_mask, p16f_inf, _mm512_setzero_ps()));
-
- // Do a single step of Newton's iteration.
- x = pmul(x, pmadd(neg_half, pmul(x, x), p16f_one_point_five));
+ // Identity infinite, negative and denormal arguments.
+ __mmask16 inf_mask = _mm512_cmp_ps_mask(_x, p16f_inf, _CMP_EQ_OQ);
+ __mmask16 not_pos_mask = _mm512_cmp_ps_mask(_x, _mm512_setzero_ps(), _CMP_LE_OQ);
+ __mmask16 not_finite_pos_mask = not_pos_mask | inf_mask;
+
+ // Compute an approximate result using the rsqrt intrinsic, forcing +inf
+ // for denormals for consistency with AVX and SSE implementations.
+ Packet16f y_approx = _mm512_rsqrt14_ps(_x);
+
+ // Do a single step of Newton-Raphson iteration to improve the approximation.
+ // This uses the formula y_{n+1} = y_n * (1.5 - y_n * (0.5 * x) * y_n).
+ // It is essential to evaluate the inner term like this because forming
+ // y_n^2 may over- or underflow.
+ Packet16f y_newton = pmul(y_approx, pmadd(y_approx, pmul(neg_half, y_approx), p16f_one_point_five));
+
+ // Select the result of the Newton-Raphson step for positive finite arguments.
+ // For other arguments, choose the output of the intrinsic. This will
+ // return rsqrt(+inf) = 0, rsqrt(x) = NaN if x < 0, and rsqrt(0) = +inf.
+ return _mm512_mask_blend_ps(not_finite_pos_mask, y_newton, y_approx);
+}
+#else
- // Insert NaNs and Infs in all the right places.
- return _mm512_mask_blend_ps(le_zero_mask, infs_and_nans, x);
+template <>
+EIGEN_STRONG_INLINE Packet16f prsqrt<Packet16f>(const Packet16f& x) {
+ _EIGEN_DECLARE_CONST_Packet16f(one, 1.0f);
+ return _mm512_div_ps(p16f_one, _mm512_sqrt_ps(x));
}
+#endif
+
+F16_PACKET_FUNCTION(Packet16f, Packet16h, prsqrt)
+BF16_PACKET_FUNCTION(Packet16f, Packet16bf, prsqrt)
+// prsqrt for double.
+#if EIGEN_FAST_MATH
template <>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet8d
prsqrt<Packet8d>(const Packet8d& _x) {
- _EIGEN_DECLARE_CONST_Packet8d_FROM_INT64(inf, 0x7ff0000000000000LL);
- _EIGEN_DECLARE_CONST_Packet8d_FROM_INT64(nan, 0x7ff1000000000000LL);
_EIGEN_DECLARE_CONST_Packet8d(one_point_five, 1.5);
_EIGEN_DECLARE_CONST_Packet8d(minus_half, -0.5);
- _EIGEN_DECLARE_CONST_Packet8d_FROM_INT64(dbl_min, 0x0010000000000000LL);
+ _EIGEN_DECLARE_CONST_Packet8d_FROM_INT64(inf, 0x7ff0000000000000LL);
Packet8d neg_half = pmul(_x, p8d_minus_half);
- // select only the inverse sqrt of positive normal inputs (denormals are
- // flushed to zero and cause infs as well).
- __mmask8 le_zero_mask = _mm512_cmp_pd_mask(_x, p8d_dbl_min, _CMP_LT_OQ);
- Packet8d x = _mm512_mask_blend_pd(le_zero_mask, _mm512_setzero_pd(),
- _mm512_rsqrt14_pd(_x));
+ // Identity infinite, negative and denormal arguments.
+ __mmask8 inf_mask = _mm512_cmp_pd_mask(_x, p8d_inf, _CMP_EQ_OQ);
+ __mmask8 not_pos_mask = _mm512_cmp_pd_mask(_x, _mm512_setzero_pd(), _CMP_LE_OQ);
+ __mmask8 not_finite_pos_mask = not_pos_mask | inf_mask;
- // Fill in NaNs and Infs for the negative/zero entries.
- __mmask8 neg_mask = _mm512_cmp_pd_mask(_x, _mm512_setzero_pd(), _CMP_LT_OQ);
- Packet8d infs_and_nans = _mm512_mask_blend_pd(
- neg_mask, p8d_nan,
- _mm512_mask_blend_pd(le_zero_mask, p8d_inf, _mm512_setzero_pd()));
-
- // Do a first step of Newton's iteration.
- x = pmul(x, pmadd(neg_half, pmul(x, x), p8d_one_point_five));
-
- // Do a second step of Newton's iteration.
- x = pmul(x, pmadd(neg_half, pmul(x, x), p8d_one_point_five));
-
- // Insert NaNs and Infs in all the right places.
- return _mm512_mask_blend_pd(le_zero_mask, infs_and_nans, x);
+ // Compute an approximate result using the rsqrt intrinsic, forcing +inf
+ // for denormals for consistency with AVX and SSE implementations.
+#if defined(EIGEN_VECTORIZE_AVX512ER)
+ Packet8d y_approx = _mm512_rsqrt28_pd(_x);
+#else
+ Packet8d y_approx = _mm512_rsqrt14_pd(_x);
+#endif
+ // Do one or two steps of Newton-Raphson's to improve the approximation, depending on the
+ // starting accuracy (either 2^-14 or 2^-28, depending on whether AVX512ER is available).
+ // The Newton-Raphson algorithm has quadratic convergence and roughly doubles the number
+ // of correct digits for each step.
+ // This uses the formula y_{n+1} = y_n * (1.5 - y_n * (0.5 * x) * y_n).
+ // It is essential to evaluate the inner term like this because forming
+ // y_n^2 may over- or underflow.
+ Packet8d y_newton = pmul(y_approx, pmadd(neg_half, pmul(y_approx, y_approx), p8d_one_point_five));
+#if !defined(EIGEN_VECTORIZE_AVX512ER)
+ y_newton = pmul(y_newton, pmadd(y_newton, pmul(neg_half, y_newton), p8d_one_point_five));
+#endif
+ // Select the result of the Newton-Raphson step for positive finite arguments.
+ // For other arguments, choose the output of the intrinsic. This will
+ // return rsqrt(+inf) = 0, rsqrt(x) = NaN if x < 0, and rsqrt(0) = +inf.
+ return _mm512_mask_blend_pd(not_finite_pos_mask, y_newton, y_approx);
}
#else
template <>
-EIGEN_STRONG_INLINE Packet16f prsqrt<Packet16f>(const Packet16f& x) {
- return _mm512_rsqrt28_ps(x);
+EIGEN_STRONG_INLINE Packet8d prsqrt<Packet8d>(const Packet8d& x) {
+ _EIGEN_DECLARE_CONST_Packet8d(one, 1.0f);
+ return _mm512_div_pd(p8d_one, _mm512_sqrt_pd(x));
}
#endif
+
+template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
+Packet16f plog1p<Packet16f>(const Packet16f& _x) {
+ return generic_plog1p(_x);
+}
+
+F16_PACKET_FUNCTION(Packet16f, Packet16h, plog1p)
+BF16_PACKET_FUNCTION(Packet16f, Packet16bf, plog1p)
+
+template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
+Packet16f pexpm1<Packet16f>(const Packet16f& _x) {
+ return generic_expm1(_x);
+}
+
+F16_PACKET_FUNCTION(Packet16f, Packet16h, pexpm1)
+BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pexpm1)
+
#endif
+
+template <>
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet16f
+psin<Packet16f>(const Packet16f& _x) {
+ return psin_float(_x);
+}
+
+template <>
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet16f
+pcos<Packet16f>(const Packet16f& _x) {
+ return pcos_float(_x);
+}
+
+template <>
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet16f
+ptanh<Packet16f>(const Packet16f& _x) {
+ return internal::generic_fast_tanh_float(_x);
+}
+
+F16_PACKET_FUNCTION(Packet16f, Packet16h, psin)
+F16_PACKET_FUNCTION(Packet16f, Packet16h, pcos)
+F16_PACKET_FUNCTION(Packet16f, Packet16h, ptanh)
+
+BF16_PACKET_FUNCTION(Packet16f, Packet16bf, psin)
+BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pcos)
+BF16_PACKET_FUNCTION(Packet16f, Packet16bf, ptanh)
+
} // end namespace internal
} // end namespace Eigen
diff --git a/Eigen/src/Core/arch/AVX512/PacketMath.h b/Eigen/src/Core/arch/AVX512/PacketMath.h
index f6500a16e..34d49ab66 100644
--- a/Eigen/src/Core/arch/AVX512/PacketMath.h
+++ b/Eigen/src/Core/arch/AVX512/PacketMath.h
@@ -19,10 +19,10 @@ namespace internal {
#endif
#ifndef EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS
-#define EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS (2*sizeof(void*))
+#define EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS 32
#endif
-#ifdef __FMA__
+#ifdef EIGEN_VECTORIZE_FMA
#ifndef EIGEN_HAS_SINGLE_INSTRUCTION_MADD
#define EIGEN_HAS_SINGLE_INSTRUCTION_MADD
#endif
@@ -31,6 +31,8 @@ namespace internal {
typedef __m512 Packet16f;
typedef __m512i Packet16i;
typedef __m512d Packet8d;
+typedef eigen_packet_wrapper<__m256i, 1> Packet16h;
+typedef eigen_packet_wrapper<__m256i, 2> Packet16bf;
template <>
struct is_arithmetic<__m512> {
@@ -45,6 +47,51 @@ struct is_arithmetic<__m512d> {
enum { value = true };
};
+template<> struct is_arithmetic<Packet16h> { enum { value = true }; };
+
+template <>
+struct packet_traits<half> : default_packet_traits {
+ typedef Packet16h type;
+ // There is no half-size packet for Packet16h.
+ typedef Packet16h half;
+ enum {
+ Vectorizable = 1,
+ AlignedOnScalar = 1,
+ size = 16,
+ HasHalfPacket = 1,
+
+ HasCmp = 1,
+ HasAdd = 1,
+ HasSub = 1,
+ HasMul = 1,
+ HasDiv = 1,
+ HasNegate = 1,
+ HasAbs = 1,
+ HasAbs2 = 0,
+ HasMin = 1,
+ HasMax = 1,
+ HasConj = 1,
+ HasSetLinear = 0,
+ HasLog = 1,
+ HasLog1p = 1,
+ HasExpm1 = 1,
+ HasExp = 1,
+ HasSqrt = 1,
+ HasRsqrt = 1,
+ HasSin = EIGEN_FAST_MATH,
+ HasCos = EIGEN_FAST_MATH,
+ HasTanh = EIGEN_FAST_MATH,
+ HasErf = EIGEN_FAST_MATH,
+ HasBlend = 0,
+ HasRound = 1,
+ HasFloor = 1,
+ HasCeil = 1,
+ HasRint = 1,
+ HasBessel = 1,
+ HasNdtri = 1
+ };
+};
+
template<> struct packet_traits<float> : default_packet_traits
{
typedef Packet16f type;
@@ -54,15 +101,32 @@ template<> struct packet_traits<float> : default_packet_traits
AlignedOnScalar = 1,
size = 16,
HasHalfPacket = 1,
-#if EIGEN_GNUC_AT_LEAST(5, 3)
-#ifdef EIGEN_VECTORIZE_AVX512DQ
+
+ HasAbs = 1,
+ HasMin = 1,
+ HasMax = 1,
+ HasConj = 1,
+ HasBlend = 0,
+ HasSin = EIGEN_FAST_MATH,
+ HasCos = EIGEN_FAST_MATH,
+#if EIGEN_GNUC_AT_LEAST(5, 3) || (!EIGEN_COMP_GNUC_STRICT)
HasLog = 1,
-#endif
+ HasLog1p = 1,
+ HasExpm1 = 1,
+ HasNdtri = 1,
+ HasBessel = 1,
HasExp = 1,
- HasSqrt = 1,
- HasRsqrt = 1,
+ HasSqrt = EIGEN_FAST_MATH,
+ HasRsqrt = EIGEN_FAST_MATH,
+ HasTanh = EIGEN_FAST_MATH,
+ HasErf = EIGEN_FAST_MATH,
#endif
- HasDiv = 1
+ HasCmp = 1,
+ HasDiv = 1,
+ HasRound = 1,
+ HasFloor = 1,
+ HasCeil = 1,
+ HasRint = 1
};
};
template<> struct packet_traits<double> : default_packet_traits
@@ -74,11 +138,18 @@ template<> struct packet_traits<double> : default_packet_traits
AlignedOnScalar = 1,
size = 8,
HasHalfPacket = 1,
-#if EIGEN_GNUC_AT_LEAST(5, 3)
- HasSqrt = 1,
+#if EIGEN_GNUC_AT_LEAST(5, 3) || (!EIGEN_COMP_GNUC_STRICT)
+ HasLog = 1,
+ HasExp = 1,
+ HasSqrt = EIGEN_FAST_MATH,
HasRsqrt = EIGEN_FAST_MATH,
#endif
- HasDiv = 1
+ HasCmp = 1,
+ HasDiv = 1,
+ HasRound = 1,
+ HasFloor = 1,
+ HasCeil = 1,
+ HasRint = 1
};
};
@@ -98,19 +169,28 @@ template <>
struct unpacket_traits<Packet16f> {
typedef float type;
typedef Packet8f half;
- enum { size = 16, alignment=Aligned64 };
+ typedef Packet16i integer_packet;
+ typedef uint16_t mask_t;
+ enum { size = 16, alignment=Aligned64, vectorizable=true, masked_load_available=true, masked_store_available=true };
};
template <>
struct unpacket_traits<Packet8d> {
typedef double type;
typedef Packet4d half;
- enum { size = 8, alignment=Aligned64 };
+ enum { size = 8, alignment=Aligned64, vectorizable=true, masked_load_available=false, masked_store_available=false };
};
template <>
struct unpacket_traits<Packet16i> {
typedef int type;
typedef Packet8i half;
- enum { size = 16, alignment=Aligned64 };
+ enum { size = 16, alignment=Aligned64, vectorizable=false, masked_load_available=false, masked_store_available=false };
+};
+
+template<>
+struct unpacket_traits<Packet16h> {
+ typedef Eigen::half type;
+ typedef Packet8h half;
+ enum {size=16, alignment=Aligned32, vectorizable=true, masked_load_available=false, masked_store_available=false};
};
template <>
@@ -127,12 +207,39 @@ EIGEN_STRONG_INLINE Packet16i pset1<Packet16i>(const int& from) {
}
template <>
+EIGEN_STRONG_INLINE Packet16f pset1frombits<Packet16f>(unsigned int from) {
+ return _mm512_castsi512_ps(_mm512_set1_epi32(from));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8d pset1frombits<Packet8d>(const numext::uint64_t from) {
+ return _mm512_castsi512_pd(_mm512_set1_epi64(from));
+}
+
+template<> EIGEN_STRONG_INLINE Packet16f pzero(const Packet16f& /*a*/) { return _mm512_setzero_ps(); }
+template<> EIGEN_STRONG_INLINE Packet8d pzero(const Packet8d& /*a*/) { return _mm512_setzero_pd(); }
+template<> EIGEN_STRONG_INLINE Packet16i pzero(const Packet16i& /*a*/) { return _mm512_setzero_si512(); }
+
+template<> EIGEN_STRONG_INLINE Packet16f peven_mask(const Packet16f& /*a*/) {
+ return _mm512_castsi512_ps(_mm512_set_epi32(0, -1, 0, -1, 0, -1, 0, -1,
+ 0, -1, 0, -1, 0, -1, 0, -1));
+}
+template<> EIGEN_STRONG_INLINE Packet16i peven_mask(const Packet16i& /*a*/) {
+ return _mm512_set_epi32(0, -1, 0, -1, 0, -1, 0, -1,
+ 0, -1, 0, -1, 0, -1, 0, -1);
+}
+template<> EIGEN_STRONG_INLINE Packet8d peven_mask(const Packet8d& /*a*/) {
+ return _mm512_castsi512_pd(_mm512_set_epi32(0, 0, -1, -1, 0, 0, -1, -1,
+ 0, 0, -1, -1, 0, 0, -1, -1));
+}
+
+template <>
EIGEN_STRONG_INLINE Packet16f pload1<Packet16f>(const float* from) {
return _mm512_broadcastss_ps(_mm_load_ps1(from));
}
template <>
EIGEN_STRONG_INLINE Packet8d pload1<Packet8d>(const double* from) {
- return _mm512_broadcastsd_pd(_mm_load_pd1(from));
+ return _mm512_set1_pd(*from);
}
template <>
@@ -158,6 +265,11 @@ EIGEN_STRONG_INLINE Packet8d padd<Packet8d>(const Packet8d& a,
const Packet8d& b) {
return _mm512_add_pd(a, b);
}
+template <>
+EIGEN_STRONG_INLINE Packet16i padd<Packet16i>(const Packet16i& a,
+ const Packet16i& b) {
+ return _mm512_add_epi32(a, b);
+}
template <>
EIGEN_STRONG_INLINE Packet16f psub<Packet16f>(const Packet16f& a,
@@ -169,6 +281,11 @@ EIGEN_STRONG_INLINE Packet8d psub<Packet8d>(const Packet8d& a,
const Packet8d& b) {
return _mm512_sub_pd(a, b);
}
+template <>
+EIGEN_STRONG_INLINE Packet16i psub<Packet16i>(const Packet16i& a,
+ const Packet16i& b) {
+ return _mm512_sub_epi32(a, b);
+}
template <>
EIGEN_STRONG_INLINE Packet16f pnegate(const Packet16f& a) {
@@ -202,6 +319,11 @@ EIGEN_STRONG_INLINE Packet8d pmul<Packet8d>(const Packet8d& a,
const Packet8d& b) {
return _mm512_mul_pd(a, b);
}
+template <>
+EIGEN_STRONG_INLINE Packet16i pmul<Packet16i>(const Packet16i& a,
+ const Packet16i& b) {
+ return _mm512_mullo_epi32(a, b);
+}
template <>
EIGEN_STRONG_INLINE Packet16f pdiv<Packet16f>(const Packet16f& a,
@@ -214,7 +336,7 @@ EIGEN_STRONG_INLINE Packet8d pdiv<Packet8d>(const Packet8d& a,
return _mm512_div_pd(a, b);
}
-#ifdef __FMA__
+#ifdef EIGEN_VECTORIZE_FMA
template <>
EIGEN_STRONG_INLINE Packet16f pmadd(const Packet16f& a, const Packet16f& b,
const Packet16f& c) {
@@ -228,51 +350,216 @@ EIGEN_STRONG_INLINE Packet8d pmadd(const Packet8d& a, const Packet8d& b,
#endif
template <>
+EIGEN_DEVICE_FUNC inline Packet16f pselect(const Packet16f& mask,
+ const Packet16f& a,
+ const Packet16f& b) {
+ __mmask16 mask16 = _mm512_cmp_epi32_mask(
+ _mm512_castps_si512(mask), _mm512_setzero_epi32(), _MM_CMPINT_EQ);
+ return _mm512_mask_blend_ps(mask16, a, b);
+}
+
+template <>
+EIGEN_DEVICE_FUNC inline Packet8d pselect(const Packet8d& mask,
+ const Packet8d& a,
+ const Packet8d& b) {
+ __mmask8 mask8 = _mm512_cmp_epi64_mask(_mm512_castpd_si512(mask),
+ _mm512_setzero_epi32(), _MM_CMPINT_EQ);
+ return _mm512_mask_blend_pd(mask8, a, b);
+}
+
+template <>
EIGEN_STRONG_INLINE Packet16f pmin<Packet16f>(const Packet16f& a,
const Packet16f& b) {
- return _mm512_min_ps(a, b);
+ // Arguments are reversed to match NaN propagation behavior of std::min.
+ return _mm512_min_ps(b, a);
}
template <>
EIGEN_STRONG_INLINE Packet8d pmin<Packet8d>(const Packet8d& a,
const Packet8d& b) {
- return _mm512_min_pd(a, b);
+ // Arguments are reversed to match NaN propagation behavior of std::min.
+ return _mm512_min_pd(b, a);
}
template <>
EIGEN_STRONG_INLINE Packet16f pmax<Packet16f>(const Packet16f& a,
const Packet16f& b) {
- return _mm512_max_ps(a, b);
+ // Arguments are reversed to match NaN propagation behavior of std::max.
+ return _mm512_max_ps(b, a);
}
template <>
EIGEN_STRONG_INLINE Packet8d pmax<Packet8d>(const Packet8d& a,
const Packet8d& b) {
- return _mm512_max_pd(a, b);
+ // Arguments are reversed to match NaN propagation behavior of std::max.
+ return _mm512_max_pd(b, a);
}
-template <>
-EIGEN_STRONG_INLINE Packet16f pand<Packet16f>(const Packet16f& a,
- const Packet16f& b) {
+// Add specializations for min/max with prescribed NaN progation.
+template<>
+EIGEN_STRONG_INLINE Packet16f pmin<PropagateNumbers, Packet16f>(const Packet16f& a, const Packet16f& b) {
+ return pminmax_propagate_numbers(a, b, pmin<Packet16f>);
+}
+template<>
+EIGEN_STRONG_INLINE Packet8d pmin<PropagateNumbers, Packet8d>(const Packet8d& a, const Packet8d& b) {
+ return pminmax_propagate_numbers(a, b, pmin<Packet8d>);
+}
+template<>
+EIGEN_STRONG_INLINE Packet16f pmax<PropagateNumbers, Packet16f>(const Packet16f& a, const Packet16f& b) {
+ return pminmax_propagate_numbers(a, b, pmax<Packet16f>);
+}
+template<>
+EIGEN_STRONG_INLINE Packet8d pmax<PropagateNumbers, Packet8d>(const Packet8d& a, const Packet8d& b) {
+ return pminmax_propagate_numbers(a, b, pmax<Packet8d>);
+}
+template<>
+EIGEN_STRONG_INLINE Packet16f pmin<PropagateNaN, Packet16f>(const Packet16f& a, const Packet16f& b) {
+ return pminmax_propagate_nan(a, b, pmin<Packet16f>);
+}
+template<>
+EIGEN_STRONG_INLINE Packet8d pmin<PropagateNaN, Packet8d>(const Packet8d& a, const Packet8d& b) {
+ return pminmax_propagate_nan(a, b, pmin<Packet8d>);
+}
+template<>
+EIGEN_STRONG_INLINE Packet16f pmax<PropagateNaN, Packet16f>(const Packet16f& a, const Packet16f& b) {
+ return pminmax_propagate_nan(a, b, pmax<Packet16f>);
+}
+template<>
+EIGEN_STRONG_INLINE Packet8d pmax<PropagateNaN, Packet8d>(const Packet8d& a, const Packet8d& b) {
+ return pminmax_propagate_nan(a, b, pmax<Packet8d>);
+}
+
+
#ifdef EIGEN_VECTORIZE_AVX512DQ
- return _mm512_and_ps(a, b);
+template<int I_> EIGEN_STRONG_INLINE Packet8f extract256(Packet16f x) { return _mm512_extractf32x8_ps(x,I_); }
+template<int I_> EIGEN_STRONG_INLINE Packet2d extract128(Packet8d x) { return _mm512_extractf64x2_pd(x,I_); }
+EIGEN_STRONG_INLINE Packet16f cat256(Packet8f a, Packet8f b) { return _mm512_insertf32x8(_mm512_castps256_ps512(a),b,1); }
#else
- Packet16f res = _mm512_undefined_ps();
- Packet4f lane0_a = _mm512_extractf32x4_ps(a, 0);
- Packet4f lane0_b = _mm512_extractf32x4_ps(b, 0);
- res = _mm512_insertf32x4(res, _mm_and_ps(lane0_a, lane0_b), 0);
+// AVX512F does not define _mm512_extractf32x8_ps to extract _m256 from _m512
+template<int I_> EIGEN_STRONG_INLINE Packet8f extract256(Packet16f x) {
+ return _mm256_castsi256_ps(_mm512_extracti64x4_epi64( _mm512_castps_si512(x),I_));
+}
- Packet4f lane1_a = _mm512_extractf32x4_ps(a, 1);
- Packet4f lane1_b = _mm512_extractf32x4_ps(b, 1);
- res = _mm512_insertf32x4(res, _mm_and_ps(lane1_a, lane1_b), 1);
+// AVX512F does not define _mm512_extractf64x2_pd to extract _m128 from _m512
+template<int I_> EIGEN_STRONG_INLINE Packet2d extract128(Packet8d x) {
+ return _mm_castsi128_pd(_mm512_extracti32x4_epi32( _mm512_castpd_si512(x),I_));
+}
+
+EIGEN_STRONG_INLINE Packet16f cat256(Packet8f a, Packet8f b) {
+ return _mm512_castsi512_ps(_mm512_inserti64x4(_mm512_castsi256_si512(_mm256_castps_si256(a)),
+ _mm256_castps_si256(b),1));
+}
+#endif
+
+// Helper function for bit packing snippet of low precision comparison.
+// It packs the flags from 32x16 to 16x16.
+EIGEN_STRONG_INLINE __m256i Pack32To16(Packet16f rf) {
+ // Split data into small pieces and handle with AVX instructions
+ // to guarantee internal order of vector.
+ // Operation:
+ // dst[15:0] := Saturate16(rf[31:0])
+ // dst[31:16] := Saturate16(rf[63:32])
+ // ...
+ // dst[255:240] := Saturate16(rf[255:224])
+ __m256i lo = _mm256_castps_si256(extract256<0>(rf));
+ __m256i hi = _mm256_castps_si256(extract256<1>(rf));
+ __m128i result_lo = _mm_packs_epi32(_mm256_extractf128_si256(lo, 0),
+ _mm256_extractf128_si256(lo, 1));
+ __m128i result_hi = _mm_packs_epi32(_mm256_extractf128_si256(hi, 0),
+ _mm256_extractf128_si256(hi, 1));
+ return _mm256_insertf128_si256(_mm256_castsi128_si256(result_lo), result_hi, 1);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet16f pcmp_eq(const Packet16f& a, const Packet16f& b) {
+ __mmask16 mask = _mm512_cmp_ps_mask(a, b, _CMP_EQ_OQ);
+ return _mm512_castsi512_ps(
+ _mm512_mask_set1_epi32(_mm512_set1_epi32(0), mask, 0xffffffffu));
+}
+template<> EIGEN_STRONG_INLINE Packet16f pcmp_le(const Packet16f& a, const Packet16f& b) {
+ __mmask16 mask = _mm512_cmp_ps_mask(a, b, _CMP_LE_OQ);
+ return _mm512_castsi512_ps(
+ _mm512_mask_set1_epi32(_mm512_set1_epi32(0), mask, 0xffffffffu));
+}
+
+template<> EIGEN_STRONG_INLINE Packet16f pcmp_lt(const Packet16f& a, const Packet16f& b) {
+ __mmask16 mask = _mm512_cmp_ps_mask(a, b, _CMP_LT_OQ);
+ return _mm512_castsi512_ps(
+ _mm512_mask_set1_epi32(_mm512_set1_epi32(0), mask, 0xffffffffu));
+}
+
+template<> EIGEN_STRONG_INLINE Packet16f pcmp_lt_or_nan(const Packet16f& a, const Packet16f& b) {
+ __mmask16 mask = _mm512_cmp_ps_mask(a, b, _CMP_NGE_UQ);
+ return _mm512_castsi512_ps(
+ _mm512_mask_set1_epi32(_mm512_set1_epi32(0), mask, 0xffffffffu));
+}
+
+template<> EIGEN_STRONG_INLINE Packet16i pcmp_eq(const Packet16i& a, const Packet16i& b) {
+ __mmask16 mask = _mm512_cmp_epi32_mask(a, b, _CMP_EQ_OQ);
+ return _mm512_mask_set1_epi32(_mm512_set1_epi32(0), mask, 0xffffffffu);
+}
- Packet4f lane2_a = _mm512_extractf32x4_ps(a, 2);
- Packet4f lane2_b = _mm512_extractf32x4_ps(b, 2);
- res = _mm512_insertf32x4(res, _mm_and_ps(lane2_a, lane2_b), 2);
- Packet4f lane3_a = _mm512_extractf32x4_ps(a, 3);
- Packet4f lane3_b = _mm512_extractf32x4_ps(b, 3);
- res = _mm512_insertf32x4(res, _mm_and_ps(lane3_a, lane3_b), 3);
+template <>
+EIGEN_STRONG_INLINE Packet8d pcmp_eq(const Packet8d& a, const Packet8d& b) {
+ __mmask8 mask = _mm512_cmp_pd_mask(a, b, _CMP_EQ_OQ);
+ return _mm512_castsi512_pd(
+ _mm512_mask_set1_epi64(_mm512_set1_epi64(0), mask, 0xffffffffffffffffu));
+}
+template <>
+EIGEN_STRONG_INLINE Packet8d pcmp_le(const Packet8d& a, const Packet8d& b) {
+ __mmask8 mask = _mm512_cmp_pd_mask(a, b, _CMP_LE_OQ);
+ return _mm512_castsi512_pd(
+ _mm512_mask_set1_epi64(_mm512_set1_epi64(0), mask, 0xffffffffffffffffu));
+}
+template <>
+EIGEN_STRONG_INLINE Packet8d pcmp_lt(const Packet8d& a, const Packet8d& b) {
+ __mmask8 mask = _mm512_cmp_pd_mask(a, b, _CMP_LT_OQ);
+ return _mm512_castsi512_pd(
+ _mm512_mask_set1_epi64(_mm512_set1_epi64(0), mask, 0xffffffffffffffffu));
+}
+template <>
+EIGEN_STRONG_INLINE Packet8d pcmp_lt_or_nan(const Packet8d& a, const Packet8d& b) {
+ __mmask8 mask = _mm512_cmp_pd_mask(a, b, _CMP_NGE_UQ);
+ return _mm512_castsi512_pd(
+ _mm512_mask_set1_epi64(_mm512_set1_epi64(0), mask, 0xffffffffffffffffu));
+}
- return res;
+template<> EIGEN_STRONG_INLINE Packet16f print<Packet16f>(const Packet16f& a) { return _mm512_roundscale_ps(a, _MM_FROUND_CUR_DIRECTION); }
+template<> EIGEN_STRONG_INLINE Packet8d print<Packet8d>(const Packet8d& a) { return _mm512_roundscale_pd(a, _MM_FROUND_CUR_DIRECTION); }
+
+template<> EIGEN_STRONG_INLINE Packet16f pceil<Packet16f>(const Packet16f& a) { return _mm512_roundscale_ps(a, _MM_FROUND_TO_POS_INF); }
+template<> EIGEN_STRONG_INLINE Packet8d pceil<Packet8d>(const Packet8d& a) { return _mm512_roundscale_pd(a, _MM_FROUND_TO_POS_INF); }
+
+template<> EIGEN_STRONG_INLINE Packet16f pfloor<Packet16f>(const Packet16f& a) { return _mm512_roundscale_ps(a, _MM_FROUND_TO_NEG_INF); }
+template<> EIGEN_STRONG_INLINE Packet8d pfloor<Packet8d>(const Packet8d& a) { return _mm512_roundscale_pd(a, _MM_FROUND_TO_NEG_INF); }
+
+template <>
+EIGEN_STRONG_INLINE Packet16i ptrue<Packet16i>(const Packet16i& /*a*/) {
+ return _mm512_set1_epi32(0xffffffffu);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet16f ptrue<Packet16f>(const Packet16f& a) {
+ return _mm512_castsi512_ps(ptrue<Packet16i>(_mm512_castps_si512(a)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8d ptrue<Packet8d>(const Packet8d& a) {
+ return _mm512_castsi512_pd(ptrue<Packet16i>(_mm512_castpd_si512(a)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet16i pand<Packet16i>(const Packet16i& a,
+ const Packet16i& b) {
+ return _mm512_and_si512(a,b);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet16f pand<Packet16f>(const Packet16f& a,
+ const Packet16f& b) {
+#ifdef EIGEN_VECTORIZE_AVX512DQ
+ return _mm512_and_ps(a, b);
+#else
+ return _mm512_castsi512_ps(pand(_mm512_castps_si512(a),_mm512_castps_si512(b)));
#endif
}
template <>
@@ -288,35 +575,21 @@ EIGEN_STRONG_INLINE Packet8d pand<Packet8d>(const Packet8d& a,
Packet4d lane1_a = _mm512_extractf64x4_pd(a, 1);
Packet4d lane1_b = _mm512_extractf64x4_pd(b, 1);
- res = _mm512_insertf64x4(res, _mm256_and_pd(lane1_a, lane1_b), 1);
-
- return res;
+ return _mm512_insertf64x4(res, _mm256_and_pd(lane1_a, lane1_b), 1);
#endif
}
+
+template <>
+EIGEN_STRONG_INLINE Packet16i por<Packet16i>(const Packet16i& a, const Packet16i& b) {
+ return _mm512_or_si512(a, b);
+}
+
template <>
-EIGEN_STRONG_INLINE Packet16f por<Packet16f>(const Packet16f& a,
- const Packet16f& b) {
+EIGEN_STRONG_INLINE Packet16f por<Packet16f>(const Packet16f& a, const Packet16f& b) {
#ifdef EIGEN_VECTORIZE_AVX512DQ
return _mm512_or_ps(a, b);
#else
- Packet16f res = _mm512_undefined_ps();
- Packet4f lane0_a = _mm512_extractf32x4_ps(a, 0);
- Packet4f lane0_b = _mm512_extractf32x4_ps(b, 0);
- res = _mm512_insertf32x4(res, _mm_or_ps(lane0_a, lane0_b), 0);
-
- Packet4f lane1_a = _mm512_extractf32x4_ps(a, 1);
- Packet4f lane1_b = _mm512_extractf32x4_ps(b, 1);
- res = _mm512_insertf32x4(res, _mm_or_ps(lane1_a, lane1_b), 1);
-
- Packet4f lane2_a = _mm512_extractf32x4_ps(a, 2);
- Packet4f lane2_b = _mm512_extractf32x4_ps(b, 2);
- res = _mm512_insertf32x4(res, _mm_or_ps(lane2_a, lane2_b), 2);
-
- Packet4f lane3_a = _mm512_extractf32x4_ps(a, 3);
- Packet4f lane3_b = _mm512_extractf32x4_ps(b, 3);
- res = _mm512_insertf32x4(res, _mm_or_ps(lane3_a, lane3_b), 3);
-
- return res;
+ return _mm512_castsi512_ps(por(_mm512_castps_si512(a),_mm512_castps_si512(b)));
#endif
}
@@ -326,107 +599,80 @@ EIGEN_STRONG_INLINE Packet8d por<Packet8d>(const Packet8d& a,
#ifdef EIGEN_VECTORIZE_AVX512DQ
return _mm512_or_pd(a, b);
#else
- Packet8d res = _mm512_undefined_pd();
- Packet4d lane0_a = _mm512_extractf64x4_pd(a, 0);
- Packet4d lane0_b = _mm512_extractf64x4_pd(b, 0);
- res = _mm512_insertf64x4(res, _mm256_or_pd(lane0_a, lane0_b), 0);
-
- Packet4d lane1_a = _mm512_extractf64x4_pd(a, 1);
- Packet4d lane1_b = _mm512_extractf64x4_pd(b, 1);
- res = _mm512_insertf64x4(res, _mm256_or_pd(lane1_a, lane1_b), 1);
-
- return res;
+ return _mm512_castsi512_pd(por(_mm512_castpd_si512(a),_mm512_castpd_si512(b)));
#endif
}
template <>
-EIGEN_STRONG_INLINE Packet16f pxor<Packet16f>(const Packet16f& a,
- const Packet16f& b) {
+EIGEN_STRONG_INLINE Packet16i pxor<Packet16i>(const Packet16i& a, const Packet16i& b) {
+ return _mm512_xor_si512(a, b);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet16f pxor<Packet16f>(const Packet16f& a, const Packet16f& b) {
#ifdef EIGEN_VECTORIZE_AVX512DQ
return _mm512_xor_ps(a, b);
#else
- Packet16f res = _mm512_undefined_ps();
- Packet4f lane0_a = _mm512_extractf32x4_ps(a, 0);
- Packet4f lane0_b = _mm512_extractf32x4_ps(b, 0);
- res = _mm512_insertf32x4(res, _mm_xor_ps(lane0_a, lane0_b), 0);
-
- Packet4f lane1_a = _mm512_extractf32x4_ps(a, 1);
- Packet4f lane1_b = _mm512_extractf32x4_ps(b, 1);
- res = _mm512_insertf32x4(res, _mm_xor_ps(lane1_a, lane1_b), 1);
-
- Packet4f lane2_a = _mm512_extractf32x4_ps(a, 2);
- Packet4f lane2_b = _mm512_extractf32x4_ps(b, 2);
- res = _mm512_insertf32x4(res, _mm_xor_ps(lane2_a, lane2_b), 2);
-
- Packet4f lane3_a = _mm512_extractf32x4_ps(a, 3);
- Packet4f lane3_b = _mm512_extractf32x4_ps(b, 3);
- res = _mm512_insertf32x4(res, _mm_xor_ps(lane3_a, lane3_b), 3);
-
- return res;
+ return _mm512_castsi512_ps(pxor(_mm512_castps_si512(a),_mm512_castps_si512(b)));
#endif
}
+
template <>
-EIGEN_STRONG_INLINE Packet8d pxor<Packet8d>(const Packet8d& a,
- const Packet8d& b) {
+EIGEN_STRONG_INLINE Packet8d pxor<Packet8d>(const Packet8d& a, const Packet8d& b) {
#ifdef EIGEN_VECTORIZE_AVX512DQ
return _mm512_xor_pd(a, b);
#else
- Packet8d res = _mm512_undefined_pd();
- Packet4d lane0_a = _mm512_extractf64x4_pd(a, 0);
- Packet4d lane0_b = _mm512_extractf64x4_pd(b, 0);
- res = _mm512_insertf64x4(res, _mm256_xor_pd(lane0_a, lane0_b), 0);
-
- Packet4d lane1_a = _mm512_extractf64x4_pd(a, 1);
- Packet4d lane1_b = _mm512_extractf64x4_pd(b, 1);
- res = _mm512_insertf64x4(res, _mm256_xor_pd(lane1_a, lane1_b), 1);
-
- return res;
+ return _mm512_castsi512_pd(pxor(_mm512_castpd_si512(a),_mm512_castpd_si512(b)));
#endif
}
template <>
-EIGEN_STRONG_INLINE Packet16f pandnot<Packet16f>(const Packet16f& a,
- const Packet16f& b) {
+EIGEN_STRONG_INLINE Packet16i pandnot<Packet16i>(const Packet16i& a, const Packet16i& b) {
+ return _mm512_andnot_si512(b, a);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet16f pandnot<Packet16f>(const Packet16f& a, const Packet16f& b) {
#ifdef EIGEN_VECTORIZE_AVX512DQ
- return _mm512_andnot_ps(a, b);
+ return _mm512_andnot_ps(b, a);
#else
- Packet16f res = _mm512_undefined_ps();
- Packet4f lane0_a = _mm512_extractf32x4_ps(a, 0);
- Packet4f lane0_b = _mm512_extractf32x4_ps(b, 0);
- res = _mm512_insertf32x4(res, _mm_andnot_ps(lane0_a, lane0_b), 0);
-
- Packet4f lane1_a = _mm512_extractf32x4_ps(a, 1);
- Packet4f lane1_b = _mm512_extractf32x4_ps(b, 1);
- res = _mm512_insertf32x4(res, _mm_andnot_ps(lane1_a, lane1_b), 1);
-
- Packet4f lane2_a = _mm512_extractf32x4_ps(a, 2);
- Packet4f lane2_b = _mm512_extractf32x4_ps(b, 2);
- res = _mm512_insertf32x4(res, _mm_andnot_ps(lane2_a, lane2_b), 2);
-
- Packet4f lane3_a = _mm512_extractf32x4_ps(a, 3);
- Packet4f lane3_b = _mm512_extractf32x4_ps(b, 3);
- res = _mm512_insertf32x4(res, _mm_andnot_ps(lane3_a, lane3_b), 3);
-
- return res;
+ return _mm512_castsi512_ps(pandnot(_mm512_castps_si512(a),_mm512_castps_si512(b)));
#endif
}
template <>
-EIGEN_STRONG_INLINE Packet8d pandnot<Packet8d>(const Packet8d& a,
- const Packet8d& b) {
+EIGEN_STRONG_INLINE Packet8d pandnot<Packet8d>(const Packet8d& a,const Packet8d& b) {
#ifdef EIGEN_VECTORIZE_AVX512DQ
- return _mm512_andnot_pd(a, b);
+ return _mm512_andnot_pd(b, a);
#else
- Packet8d res = _mm512_undefined_pd();
- Packet4d lane0_a = _mm512_extractf64x4_pd(a, 0);
- Packet4d lane0_b = _mm512_extractf64x4_pd(b, 0);
- res = _mm512_insertf64x4(res, _mm256_andnot_pd(lane0_a, lane0_b), 0);
+ return _mm512_castsi512_pd(pandnot(_mm512_castpd_si512(a),_mm512_castpd_si512(b)));
+#endif
+}
- Packet4d lane1_a = _mm512_extractf64x4_pd(a, 1);
- Packet4d lane1_b = _mm512_extractf64x4_pd(b, 1);
- res = _mm512_insertf64x4(res, _mm256_andnot_pd(lane1_a, lane1_b), 1);
+template<> EIGEN_STRONG_INLINE Packet16f pround<Packet16f>(const Packet16f& a)
+{
+ // Work-around for default std::round rounding mode.
+ const Packet16f mask = pset1frombits<Packet16f>(static_cast<numext::uint32_t>(0x80000000u));
+ const Packet16f prev0dot5 = pset1frombits<Packet16f>(static_cast<numext::uint32_t>(0x3EFFFFFFu));
+ return _mm512_roundscale_ps(padd(por(pand(a, mask), prev0dot5), a), _MM_FROUND_TO_ZERO);
+}
+template<> EIGEN_STRONG_INLINE Packet8d pround<Packet8d>(const Packet8d& a)
+{
+ // Work-around for default std::round rounding mode.
+ const Packet8d mask = pset1frombits<Packet8d>(static_cast<numext::uint64_t>(0x8000000000000000ull));
+ const Packet8d prev0dot5 = pset1frombits<Packet8d>(static_cast<numext::uint64_t>(0x3FDFFFFFFFFFFFFFull));
+ return _mm512_roundscale_pd(padd(por(pand(a, mask), prev0dot5), a), _MM_FROUND_TO_ZERO);
+}
- return res;
-#endif
+template<int N> EIGEN_STRONG_INLINE Packet16i parithmetic_shift_right(Packet16i a) {
+ return _mm512_srai_epi32(a, N);
+}
+
+template<int N> EIGEN_STRONG_INLINE Packet16i plogical_shift_right(Packet16i a) {
+ return _mm512_srli_epi32(a, N);
+}
+
+template<int N> EIGEN_STRONG_INLINE Packet16i plogical_shift_left(Packet16i a) {
+ return _mm512_slli_epi32(a, N);
}
template <>
@@ -457,79 +703,65 @@ EIGEN_STRONG_INLINE Packet16i ploadu<Packet16i>(const int* from) {
reinterpret_cast<const __m512i*>(from));
}
+template <>
+EIGEN_STRONG_INLINE Packet16f ploadu<Packet16f>(const float* from, uint16_t umask) {
+ __mmask16 mask = static_cast<__mmask16>(umask);
+ EIGEN_DEBUG_UNALIGNED_LOAD return _mm512_maskz_loadu_ps(mask, from);
+}
+
// Loads 8 floats from memory a returns the packet
// {a0, a0 a1, a1, a2, a2, a3, a3, a4, a4, a5, a5, a6, a6, a7, a7}
template <>
EIGEN_STRONG_INLINE Packet16f ploaddup<Packet16f>(const float* from) {
- Packet8f lane0 = _mm256_broadcast_ps((const __m128*)(const void*)from);
- // mimic an "inplace" permutation of the lower 128bits using a blend
- lane0 = _mm256_blend_ps(
- lane0, _mm256_castps128_ps256(_mm_permute_ps(
- _mm256_castps256_ps128(lane0), _MM_SHUFFLE(1, 0, 1, 0))),
- 15);
- // then we can perform a consistent permutation on the global register to get
- // everything in shape:
- lane0 = _mm256_permute_ps(lane0, _MM_SHUFFLE(3, 3, 2, 2));
-
- Packet8f lane1 = _mm256_broadcast_ps((const __m128*)(const void*)(from + 4));
- // mimic an "inplace" permutation of the lower 128bits using a blend
- lane1 = _mm256_blend_ps(
- lane1, _mm256_castps128_ps256(_mm_permute_ps(
- _mm256_castps256_ps128(lane1), _MM_SHUFFLE(1, 0, 1, 0))),
- 15);
- // then we can perform a consistent permutation on the global register to get
- // everything in shape:
- lane1 = _mm256_permute_ps(lane1, _MM_SHUFFLE(3, 3, 2, 2));
+ // an unaligned load is required here as there is no requirement
+ // on the alignment of input pointer 'from'
+ __m256i low_half = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(from));
+ __m512 even_elements = _mm512_castsi512_ps(_mm512_cvtepu32_epi64(low_half));
+ __m512 pairs = _mm512_permute_ps(even_elements, _MM_SHUFFLE(2, 2, 0, 0));
+ return pairs;
+}
#ifdef EIGEN_VECTORIZE_AVX512DQ
- Packet16f res = _mm512_undefined_ps();
- return _mm512_insertf32x8(res, lane0, 0);
- return _mm512_insertf32x8(res, lane1, 1);
- return res;
-#else
- Packet16f res = _mm512_undefined_ps();
- res = _mm512_insertf32x4(res, _mm256_extractf128_ps(lane0, 0), 0);
- res = _mm512_insertf32x4(res, _mm256_extractf128_ps(lane0, 1), 1);
- res = _mm512_insertf32x4(res, _mm256_extractf128_ps(lane1, 0), 2);
- res = _mm512_insertf32x4(res, _mm256_extractf128_ps(lane1, 1), 3);
- return res;
-#endif
-}
+// FIXME: this does not look optimal, better load a Packet4d and shuffle...
// Loads 4 doubles from memory a returns the packet {a0, a0 a1, a1, a2, a2, a3,
// a3}
template <>
EIGEN_STRONG_INLINE Packet8d ploaddup<Packet8d>(const double* from) {
- Packet4d lane0 = _mm256_broadcast_pd((const __m128d*)(const void*)from);
- lane0 = _mm256_permute_pd(lane0, 3 << 2);
-
- Packet4d lane1 = _mm256_broadcast_pd((const __m128d*)(const void*)(from + 2));
- lane1 = _mm256_permute_pd(lane1, 3 << 2);
-
- Packet8d res = _mm512_undefined_pd();
- res = _mm512_insertf64x4(res, lane0, 0);
- return _mm512_insertf64x4(res, lane1, 1);
+ __m512d x = _mm512_setzero_pd();
+ x = _mm512_insertf64x2(x, _mm_loaddup_pd(&from[0]), 0);
+ x = _mm512_insertf64x2(x, _mm_loaddup_pd(&from[1]), 1);
+ x = _mm512_insertf64x2(x, _mm_loaddup_pd(&from[2]), 2);
+ x = _mm512_insertf64x2(x, _mm_loaddup_pd(&from[3]), 3);
+ return x;
+}
+#else
+template <>
+EIGEN_STRONG_INLINE Packet8d ploaddup<Packet8d>(const double* from) {
+ __m512d x = _mm512_setzero_pd();
+ x = _mm512_mask_broadcastsd_pd(x, 0x3<<0, _mm_load_sd(from+0));
+ x = _mm512_mask_broadcastsd_pd(x, 0x3<<2, _mm_load_sd(from+1));
+ x = _mm512_mask_broadcastsd_pd(x, 0x3<<4, _mm_load_sd(from+2));
+ x = _mm512_mask_broadcastsd_pd(x, 0x3<<6, _mm_load_sd(from+3));
+ return x;
}
+#endif
// Loads 4 floats from memory a returns the packet
// {a0, a0 a0, a0, a1, a1, a1, a1, a2, a2, a2, a2, a3, a3, a3, a3}
template <>
EIGEN_STRONG_INLINE Packet16f ploadquad<Packet16f>(const float* from) {
- Packet16f tmp = _mm512_undefined_ps();
- tmp = _mm512_insertf32x4(tmp, _mm_load_ps1(from), 0);
- tmp = _mm512_insertf32x4(tmp, _mm_load_ps1(from + 1), 1);
- tmp = _mm512_insertf32x4(tmp, _mm_load_ps1(from + 2), 2);
- tmp = _mm512_insertf32x4(tmp, _mm_load_ps1(from + 3), 3);
- return tmp;
+ Packet16f tmp = _mm512_castps128_ps512(ploadu<Packet4f>(from));
+ const Packet16i scatter_mask = _mm512_set_epi32(3,3,3,3, 2,2,2,2, 1,1,1,1, 0,0,0,0);
+ return _mm512_permutexvar_ps(scatter_mask, tmp);
}
+
// Loads 2 doubles from memory a returns the packet
// {a0, a0 a0, a0, a1, a1, a1, a1}
template <>
EIGEN_STRONG_INLINE Packet8d ploadquad<Packet8d>(const double* from) {
- Packet8d tmp = _mm512_undefined_pd();
- Packet2d tmp0 = _mm_load_pd1(from);
- Packet2d tmp1 = _mm_load_pd1(from + 1);
- Packet4d lane0 = _mm256_broadcastsd_pd(tmp0);
- Packet4d lane1 = _mm256_broadcastsd_pd(tmp1);
+ __m256d lane0 = _mm256_set1_pd(*from);
+ __m256d lane1 = _mm256_set1_pd(*(from+1));
+ __m512d tmp = _mm512_undefined_pd();
tmp = _mm512_insertf64x4(tmp, lane0, 0);
return _mm512_insertf64x4(tmp, lane1, 1);
}
@@ -561,11 +793,16 @@ EIGEN_STRONG_INLINE void pstoreu<int>(int* to, const Packet16i& from) {
EIGEN_DEBUG_UNALIGNED_STORE _mm512_storeu_si512(
reinterpret_cast<__m512i*>(to), from);
}
+template <>
+EIGEN_STRONG_INLINE void pstoreu<float>(float* to, const Packet16f& from, uint16_t umask) {
+ __mmask16 mask = static_cast<__mmask16>(umask);
+ EIGEN_DEBUG_UNALIGNED_STORE return _mm512_mask_storeu_ps(to, mask, from);
+}
template <>
EIGEN_DEVICE_FUNC inline Packet16f pgather<float, Packet16f>(const float* from,
Index stride) {
- Packet16i stride_vector = _mm512_set1_epi32(stride);
+ Packet16i stride_vector = _mm512_set1_epi32(convert_index<int>(stride));
Packet16i stride_multiplier =
_mm512_set_epi32(15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0);
Packet16i indices = _mm512_mullo_epi32(stride_vector, stride_multiplier);
@@ -575,7 +812,7 @@ EIGEN_DEVICE_FUNC inline Packet16f pgather<float, Packet16f>(const float* from,
template <>
EIGEN_DEVICE_FUNC inline Packet8d pgather<double, Packet8d>(const double* from,
Index stride) {
- Packet8i stride_vector = _mm256_set1_epi32(stride);
+ Packet8i stride_vector = _mm256_set1_epi32(convert_index<int>(stride));
Packet8i stride_multiplier = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0);
Packet8i indices = _mm256_mullo_epi32(stride_vector, stride_multiplier);
@@ -586,7 +823,7 @@ template <>
EIGEN_DEVICE_FUNC inline void pscatter<float, Packet16f>(float* to,
const Packet16f& from,
Index stride) {
- Packet16i stride_vector = _mm512_set1_epi32(stride);
+ Packet16i stride_vector = _mm512_set1_epi32(convert_index<int>(stride));
Packet16i stride_multiplier =
_mm512_set_epi32(15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0);
Packet16i indices = _mm512_mullo_epi32(stride_vector, stride_multiplier);
@@ -596,7 +833,7 @@ template <>
EIGEN_DEVICE_FUNC inline void pscatter<double, Packet8d>(double* to,
const Packet8d& from,
Index stride) {
- Packet8i stride_vector = _mm256_set1_epi32(stride);
+ Packet8i stride_vector = _mm256_set1_epi32(convert_index<int>(stride));
Packet8i stride_multiplier = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0);
Packet8i indices = _mm256_mullo_epi32(stride_vector, stride_multiplier);
_mm512_i32scatter_pd(to, indices, from, 8);
@@ -618,9 +855,9 @@ EIGEN_STRONG_INLINE void pstore1<Packet16i>(int* to, const int& a) {
pstore(to, pa);
}
-template<> EIGEN_STRONG_INLINE void prefetch<float>(const float* addr) { _mm_prefetch((const char*)(addr), _MM_HINT_T0); }
-template<> EIGEN_STRONG_INLINE void prefetch<double>(const double* addr) { _mm_prefetch((const char*)(addr), _MM_HINT_T0); }
-template<> EIGEN_STRONG_INLINE void prefetch<int>(const int* addr) { _mm_prefetch((const char*)(addr), _MM_HINT_T0); }
+template<> EIGEN_STRONG_INLINE void prefetch<float>(const float* addr) { _mm_prefetch((SsePrefetchPtrType)(addr), _MM_HINT_T0); }
+template<> EIGEN_STRONG_INLINE void prefetch<double>(const double* addr) { _mm_prefetch((SsePrefetchPtrType)(addr), _MM_HINT_T0); }
+template<> EIGEN_STRONG_INLINE void prefetch<int>(const int* addr) { _mm_prefetch((SsePrefetchPtrType)(addr), _MM_HINT_T0); }
template <>
EIGEN_STRONG_INLINE float pfirst<Packet16f>(const Packet16f& a) {
@@ -648,20 +885,73 @@ template<> EIGEN_STRONG_INLINE Packet8d preverse(const Packet8d& a)
template<> EIGEN_STRONG_INLINE Packet16f pabs(const Packet16f& a)
{
// _mm512_abs_ps intrinsic not found, so hack around it
- return (__m512)_mm512_and_si512((__m512i)a, _mm512_set1_epi32(0x7fffffff));
+ return _mm512_castsi512_ps(_mm512_and_si512(_mm512_castps_si512(a), _mm512_set1_epi32(0x7fffffff)));
}
template <>
EIGEN_STRONG_INLINE Packet8d pabs(const Packet8d& a) {
// _mm512_abs_ps intrinsic not found, so hack around it
- return (__m512d)_mm512_and_si512((__m512i)a,
- _mm512_set1_epi64(0x7fffffffffffffff));
+ return _mm512_castsi512_pd(_mm512_and_si512(_mm512_castpd_si512(a),
+ _mm512_set1_epi64(0x7fffffffffffffff)));
+}
+
+template<>
+EIGEN_STRONG_INLINE Packet16f pfrexp<Packet16f>(const Packet16f& a, Packet16f& exponent){
+ return pfrexp_generic(a, exponent);
+}
+
+// Extract exponent without existence of Packet8l.
+template<>
+EIGEN_STRONG_INLINE
+Packet8d pfrexp_generic_get_biased_exponent(const Packet8d& a) {
+ const Packet8d cst_exp_mask = pset1frombits<Packet8d>(static_cast<uint64_t>(0x7ff0000000000000ull));
+ #ifdef EIGEN_VECTORIZE_AVX512DQ
+ return _mm512_cvtepi64_pd(_mm512_srli_epi64(_mm512_castpd_si512(pand(a, cst_exp_mask)), 52));
+ #else
+ return _mm512_cvtepi32_pd(_mm512_cvtepi64_epi32(_mm512_srli_epi64(_mm512_castpd_si512(pand(a, cst_exp_mask)), 52)));
+ #endif
+}
+
+template<>
+EIGEN_STRONG_INLINE Packet8d pfrexp<Packet8d>(const Packet8d& a, Packet8d& exponent) {
+ return pfrexp_generic(a, exponent);
+}
+
+template<> EIGEN_STRONG_INLINE Packet16f pldexp<Packet16f>(const Packet16f& a, const Packet16f& exponent) {
+ return pldexp_generic(a, exponent);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8d pldexp<Packet8d>(const Packet8d& a, const Packet8d& exponent) {
+ // Clamp exponent to [-2099, 2099]
+ const Packet8d max_exponent = pset1<Packet8d>(2099.0);
+ const Packet8i e = _mm512_cvtpd_epi32(pmin(pmax(exponent, pnegate(max_exponent)), max_exponent));
+
+ // Split 2^e into four factors and multiply.
+ const Packet8i bias = pset1<Packet8i>(1023);
+ Packet8i b = parithmetic_shift_right<2>(e); // floor(e/4)
+
+ // 2^b
+ const Packet8i permute_idx = _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7);
+ Packet8i hi = _mm256_permutevar8x32_epi32(padd(b, bias), permute_idx);
+ Packet8i lo = _mm256_slli_epi64(hi, 52);
+ hi = _mm256_slli_epi64(_mm256_srli_epi64(hi, 32), 52);
+ Packet8d c = _mm512_castsi512_pd(_mm512_inserti64x4(_mm512_castsi256_si512(lo), hi, 1));
+ Packet8d out = pmul(pmul(pmul(a, c), c), c); // a * 2^(3b)
+
+ // 2^(e - 3b)
+ b = psub(psub(psub(e, b), b), b); // e - 3b
+ hi = _mm256_permutevar8x32_epi32(padd(b, bias), permute_idx);
+ lo = _mm256_slli_epi64(hi, 52);
+ hi = _mm256_slli_epi64(_mm256_srli_epi64(hi, 32), 52);
+ c = _mm512_castsi512_pd(_mm512_inserti64x4(_mm512_castsi256_si512(lo), hi, 1));
+ out = pmul(out, c); // a * 2^e
+ return out;
}
#ifdef EIGEN_VECTORIZE_AVX512DQ
// AVX512F does not define _mm512_extractf32x8_ps to extract _m256 from _m512
#define EIGEN_EXTRACT_8f_FROM_16f(INPUT, OUTPUT) \
- __m256 OUTPUT##_0 = _mm512_extractf32x8_ps(INPUT, 0) __m256 OUTPUT##_1 = \
- _mm512_extractf32x8_ps(INPUT, 1)
+ __m256 OUTPUT##_0 = _mm512_extractf32x8_ps(INPUT, 0); \
+ __m256 OUTPUT##_1 = _mm512_extractf32x8_ps(INPUT, 1)
#else
#define EIGEN_EXTRACT_8f_FROM_16f(INPUT, OUTPUT) \
__m256 OUTPUT##_0 = _mm256_insertf128_ps( \
@@ -674,258 +964,64 @@ EIGEN_STRONG_INLINE Packet8d pabs(const Packet8d& a) {
#ifdef EIGEN_VECTORIZE_AVX512DQ
#define EIGEN_INSERT_8f_INTO_16f(OUTPUT, INPUTA, INPUTB) \
- OUTPUT = _mm512_insertf32x8(OUTPUT, INPUTA, 0); \
- OUTPUT = _mm512_insertf32x8(OUTPUT, INPUTB, 1);
+ OUTPUT = _mm512_insertf32x8(_mm512_castps256_ps512(INPUTA), INPUTB, 1);
#else
#define EIGEN_INSERT_8f_INTO_16f(OUTPUT, INPUTA, INPUTB) \
+ OUTPUT = _mm512_undefined_ps(); \
OUTPUT = _mm512_insertf32x4(OUTPUT, _mm256_extractf128_ps(INPUTA, 0), 0); \
OUTPUT = _mm512_insertf32x4(OUTPUT, _mm256_extractf128_ps(INPUTA, 1), 1); \
OUTPUT = _mm512_insertf32x4(OUTPUT, _mm256_extractf128_ps(INPUTB, 0), 2); \
OUTPUT = _mm512_insertf32x4(OUTPUT, _mm256_extractf128_ps(INPUTB, 1), 3);
#endif
-template<> EIGEN_STRONG_INLINE Packet16f preduxp<Packet16f>(const Packet16f*
-vecs)
-{
- EIGEN_EXTRACT_8f_FROM_16f(vecs[0], vecs0);
- EIGEN_EXTRACT_8f_FROM_16f(vecs[1], vecs1);
- EIGEN_EXTRACT_8f_FROM_16f(vecs[2], vecs2);
- EIGEN_EXTRACT_8f_FROM_16f(vecs[3], vecs3);
- EIGEN_EXTRACT_8f_FROM_16f(vecs[4], vecs4);
- EIGEN_EXTRACT_8f_FROM_16f(vecs[5], vecs5);
- EIGEN_EXTRACT_8f_FROM_16f(vecs[6], vecs6);
- EIGEN_EXTRACT_8f_FROM_16f(vecs[7], vecs7);
- EIGEN_EXTRACT_8f_FROM_16f(vecs[8], vecs8);
- EIGEN_EXTRACT_8f_FROM_16f(vecs[9], vecs9);
- EIGEN_EXTRACT_8f_FROM_16f(vecs[10], vecs10);
- EIGEN_EXTRACT_8f_FROM_16f(vecs[11], vecs11);
- EIGEN_EXTRACT_8f_FROM_16f(vecs[12], vecs12);
- EIGEN_EXTRACT_8f_FROM_16f(vecs[13], vecs13);
- EIGEN_EXTRACT_8f_FROM_16f(vecs[14], vecs14);
- EIGEN_EXTRACT_8f_FROM_16f(vecs[15], vecs15);
-
- __m256 hsum1 = _mm256_hadd_ps(vecs0_0, vecs1_0);
- __m256 hsum2 = _mm256_hadd_ps(vecs2_0, vecs3_0);
- __m256 hsum3 = _mm256_hadd_ps(vecs4_0, vecs5_0);
- __m256 hsum4 = _mm256_hadd_ps(vecs6_0, vecs7_0);
-
- __m256 hsum5 = _mm256_hadd_ps(hsum1, hsum1);
- __m256 hsum6 = _mm256_hadd_ps(hsum2, hsum2);
- __m256 hsum7 = _mm256_hadd_ps(hsum3, hsum3);
- __m256 hsum8 = _mm256_hadd_ps(hsum4, hsum4);
-
- __m256 perm1 = _mm256_permute2f128_ps(hsum5, hsum5, 0x23);
- __m256 perm2 = _mm256_permute2f128_ps(hsum6, hsum6, 0x23);
- __m256 perm3 = _mm256_permute2f128_ps(hsum7, hsum7, 0x23);
- __m256 perm4 = _mm256_permute2f128_ps(hsum8, hsum8, 0x23);
-
- __m256 sum1 = _mm256_add_ps(perm1, hsum5);
- __m256 sum2 = _mm256_add_ps(perm2, hsum6);
- __m256 sum3 = _mm256_add_ps(perm3, hsum7);
- __m256 sum4 = _mm256_add_ps(perm4, hsum8);
-
- __m256 blend1 = _mm256_blend_ps(sum1, sum2, 0xcc);
- __m256 blend2 = _mm256_blend_ps(sum3, sum4, 0xcc);
-
- __m256 final = _mm256_blend_ps(blend1, blend2, 0xf0);
-
- hsum1 = _mm256_hadd_ps(vecs0_1, vecs1_1);
- hsum2 = _mm256_hadd_ps(vecs2_1, vecs3_1);
- hsum3 = _mm256_hadd_ps(vecs4_1, vecs5_1);
- hsum4 = _mm256_hadd_ps(vecs6_1, vecs7_1);
-
- hsum5 = _mm256_hadd_ps(hsum1, hsum1);
- hsum6 = _mm256_hadd_ps(hsum2, hsum2);
- hsum7 = _mm256_hadd_ps(hsum3, hsum3);
- hsum8 = _mm256_hadd_ps(hsum4, hsum4);
-
- perm1 = _mm256_permute2f128_ps(hsum5, hsum5, 0x23);
- perm2 = _mm256_permute2f128_ps(hsum6, hsum6, 0x23);
- perm3 = _mm256_permute2f128_ps(hsum7, hsum7, 0x23);
- perm4 = _mm256_permute2f128_ps(hsum8, hsum8, 0x23);
-
- sum1 = _mm256_add_ps(perm1, hsum5);
- sum2 = _mm256_add_ps(perm2, hsum6);
- sum3 = _mm256_add_ps(perm3, hsum7);
- sum4 = _mm256_add_ps(perm4, hsum8);
-
- blend1 = _mm256_blend_ps(sum1, sum2, 0xcc);
- blend2 = _mm256_blend_ps(sum3, sum4, 0xcc);
-
- final = padd(final, _mm256_blend_ps(blend1, blend2, 0xf0));
-
- hsum1 = _mm256_hadd_ps(vecs8_0, vecs9_0);
- hsum2 = _mm256_hadd_ps(vecs10_0, vecs11_0);
- hsum3 = _mm256_hadd_ps(vecs12_0, vecs13_0);
- hsum4 = _mm256_hadd_ps(vecs14_0, vecs15_0);
-
- hsum5 = _mm256_hadd_ps(hsum1, hsum1);
- hsum6 = _mm256_hadd_ps(hsum2, hsum2);
- hsum7 = _mm256_hadd_ps(hsum3, hsum3);
- hsum8 = _mm256_hadd_ps(hsum4, hsum4);
-
- perm1 = _mm256_permute2f128_ps(hsum5, hsum5, 0x23);
- perm2 = _mm256_permute2f128_ps(hsum6, hsum6, 0x23);
- perm3 = _mm256_permute2f128_ps(hsum7, hsum7, 0x23);
- perm4 = _mm256_permute2f128_ps(hsum8, hsum8, 0x23);
-
- sum1 = _mm256_add_ps(perm1, hsum5);
- sum2 = _mm256_add_ps(perm2, hsum6);
- sum3 = _mm256_add_ps(perm3, hsum7);
- sum4 = _mm256_add_ps(perm4, hsum8);
-
- blend1 = _mm256_blend_ps(sum1, sum2, 0xcc);
- blend2 = _mm256_blend_ps(sum3, sum4, 0xcc);
-
- __m256 final_1 = _mm256_blend_ps(blend1, blend2, 0xf0);
-
- hsum1 = _mm256_hadd_ps(vecs8_1, vecs9_1);
- hsum2 = _mm256_hadd_ps(vecs10_1, vecs11_1);
- hsum3 = _mm256_hadd_ps(vecs12_1, vecs13_1);
- hsum4 = _mm256_hadd_ps(vecs14_1, vecs15_1);
-
- hsum5 = _mm256_hadd_ps(hsum1, hsum1);
- hsum6 = _mm256_hadd_ps(hsum2, hsum2);
- hsum7 = _mm256_hadd_ps(hsum3, hsum3);
- hsum8 = _mm256_hadd_ps(hsum4, hsum4);
-
- perm1 = _mm256_permute2f128_ps(hsum5, hsum5, 0x23);
- perm2 = _mm256_permute2f128_ps(hsum6, hsum6, 0x23);
- perm3 = _mm256_permute2f128_ps(hsum7, hsum7, 0x23);
- perm4 = _mm256_permute2f128_ps(hsum8, hsum8, 0x23);
-
- sum1 = _mm256_add_ps(perm1, hsum5);
- sum2 = _mm256_add_ps(perm2, hsum6);
- sum3 = _mm256_add_ps(perm3, hsum7);
- sum4 = _mm256_add_ps(perm4, hsum8);
-
- blend1 = _mm256_blend_ps(sum1, sum2, 0xcc);
- blend2 = _mm256_blend_ps(sum3, sum4, 0xcc);
-
- final_1 = padd(final_1, _mm256_blend_ps(blend1, blend2, 0xf0));
-
- __m512 final_output;
-
- EIGEN_INSERT_8f_INTO_16f(final_output, final, final_1);
- return final_output;
-}
-
-template<> EIGEN_STRONG_INLINE Packet8d preduxp<Packet8d>(const Packet8d* vecs)
-{
- Packet4d vecs0_0 = _mm512_extractf64x4_pd(vecs[0], 0);
- Packet4d vecs0_1 = _mm512_extractf64x4_pd(vecs[0], 1);
-
- Packet4d vecs1_0 = _mm512_extractf64x4_pd(vecs[1], 0);
- Packet4d vecs1_1 = _mm512_extractf64x4_pd(vecs[1], 1);
-
- Packet4d vecs2_0 = _mm512_extractf64x4_pd(vecs[2], 0);
- Packet4d vecs2_1 = _mm512_extractf64x4_pd(vecs[2], 1);
-
- Packet4d vecs3_0 = _mm512_extractf64x4_pd(vecs[3], 0);
- Packet4d vecs3_1 = _mm512_extractf64x4_pd(vecs[3], 1);
-
- Packet4d vecs4_0 = _mm512_extractf64x4_pd(vecs[4], 0);
- Packet4d vecs4_1 = _mm512_extractf64x4_pd(vecs[4], 1);
-
- Packet4d vecs5_0 = _mm512_extractf64x4_pd(vecs[5], 0);
- Packet4d vecs5_1 = _mm512_extractf64x4_pd(vecs[5], 1);
-
- Packet4d vecs6_0 = _mm512_extractf64x4_pd(vecs[6], 0);
- Packet4d vecs6_1 = _mm512_extractf64x4_pd(vecs[6], 1);
-
- Packet4d vecs7_0 = _mm512_extractf64x4_pd(vecs[7], 0);
- Packet4d vecs7_1 = _mm512_extractf64x4_pd(vecs[7], 1);
-
- Packet4d tmp0, tmp1;
-
- tmp0 = _mm256_hadd_pd(vecs0_0, vecs1_0);
- tmp0 = _mm256_add_pd(tmp0, _mm256_permute2f128_pd(tmp0, tmp0, 1));
-
- tmp1 = _mm256_hadd_pd(vecs2_0, vecs3_0);
- tmp1 = _mm256_add_pd(tmp1, _mm256_permute2f128_pd(tmp1, tmp1, 1));
-
- __m256d final_0 = _mm256_blend_pd(tmp0, tmp1, 0xC);
-
- tmp0 = _mm256_hadd_pd(vecs0_1, vecs1_1);
- tmp0 = _mm256_add_pd(tmp0, _mm256_permute2f128_pd(tmp0, tmp0, 1));
-
- tmp1 = _mm256_hadd_pd(vecs2_1, vecs3_1);
- tmp1 = _mm256_add_pd(tmp1, _mm256_permute2f128_pd(tmp1, tmp1, 1));
-
- final_0 = padd(final_0, _mm256_blend_pd(tmp0, tmp1, 0xC));
-
- tmp0 = _mm256_hadd_pd(vecs4_0, vecs5_0);
- tmp0 = _mm256_add_pd(tmp0, _mm256_permute2f128_pd(tmp0, tmp0, 1));
-
- tmp1 = _mm256_hadd_pd(vecs6_0, vecs7_0);
- tmp1 = _mm256_add_pd(tmp1, _mm256_permute2f128_pd(tmp1, tmp1, 1));
-
- __m256d final_1 = _mm256_blend_pd(tmp0, tmp1, 0xC);
-
- tmp0 = _mm256_hadd_pd(vecs4_1, vecs5_1);
- tmp0 = _mm256_add_pd(tmp0, _mm256_permute2f128_pd(tmp0, tmp0, 1));
-
- tmp1 = _mm256_hadd_pd(vecs6_1, vecs7_1);
- tmp1 = _mm256_add_pd(tmp1, _mm256_permute2f128_pd(tmp1, tmp1, 1));
-
- final_1 = padd(final_1, _mm256_blend_pd(tmp0, tmp1, 0xC));
-
- __m512d final_output = _mm512_insertf64x4(final_output, final_0, 0);
-
- return _mm512_insertf64x4(final_output, final_1, 1);
-}
template <>
EIGEN_STRONG_INLINE float predux<Packet16f>(const Packet16f& a) {
- //#ifdef EIGEN_VECTORIZE_AVX512DQ
-#if 0
- Packet8f lane0 = _mm512_extractf32x8_ps(a, 0);
- Packet8f lane1 = _mm512_extractf32x8_ps(a, 1);
- Packet8f sum = padd(lane0, lane1);
- Packet8f tmp0 = _mm256_hadd_ps(sum, _mm256_permute2f128_ps(a, a, 1));
- tmp0 = _mm256_hadd_ps(tmp0, tmp0);
- return pfirst(_mm256_hadd_ps(tmp0, tmp0));
+#ifdef EIGEN_VECTORIZE_AVX512DQ
+ __m256 lane0 = _mm512_extractf32x8_ps(a, 0);
+ __m256 lane1 = _mm512_extractf32x8_ps(a, 1);
+ Packet8f x = _mm256_add_ps(lane0, lane1);
+ return predux<Packet8f>(x);
#else
- Packet4f lane0 = _mm512_extractf32x4_ps(a, 0);
- Packet4f lane1 = _mm512_extractf32x4_ps(a, 1);
- Packet4f lane2 = _mm512_extractf32x4_ps(a, 2);
- Packet4f lane3 = _mm512_extractf32x4_ps(a, 3);
- Packet4f sum = padd(padd(lane0, lane1), padd(lane2, lane3));
+ __m128 lane0 = _mm512_extractf32x4_ps(a, 0);
+ __m128 lane1 = _mm512_extractf32x4_ps(a, 1);
+ __m128 lane2 = _mm512_extractf32x4_ps(a, 2);
+ __m128 lane3 = _mm512_extractf32x4_ps(a, 3);
+ __m128 sum = _mm_add_ps(_mm_add_ps(lane0, lane1), _mm_add_ps(lane2, lane3));
sum = _mm_hadd_ps(sum, sum);
sum = _mm_hadd_ps(sum, _mm_permute_ps(sum, 1));
- return pfirst(sum);
+ return _mm_cvtss_f32(sum);
#endif
}
template <>
EIGEN_STRONG_INLINE double predux<Packet8d>(const Packet8d& a) {
- Packet4d lane0 = _mm512_extractf64x4_pd(a, 0);
- Packet4d lane1 = _mm512_extractf64x4_pd(a, 1);
- Packet4d sum = padd(lane0, lane1);
- Packet4d tmp0 = _mm256_hadd_pd(sum, _mm256_permute2f128_pd(sum, sum, 1));
- return pfirst(_mm256_hadd_pd(tmp0, tmp0));
+ __m256d lane0 = _mm512_extractf64x4_pd(a, 0);
+ __m256d lane1 = _mm512_extractf64x4_pd(a, 1);
+ __m256d sum = _mm256_add_pd(lane0, lane1);
+ __m256d tmp0 = _mm256_hadd_pd(sum, _mm256_permute2f128_pd(sum, sum, 1));
+ return _mm_cvtsd_f64(_mm256_castpd256_pd128(_mm256_hadd_pd(tmp0, tmp0)));
}
template <>
-EIGEN_STRONG_INLINE Packet8f predux_downto4<Packet16f>(const Packet16f& a) {
+EIGEN_STRONG_INLINE Packet8f predux_half_dowto4<Packet16f>(const Packet16f& a) {
#ifdef EIGEN_VECTORIZE_AVX512DQ
- Packet8f lane0 = _mm512_extractf32x8_ps(a, 0);
- Packet8f lane1 = _mm512_extractf32x8_ps(a, 1);
- return padd(lane0, lane1);
+ __m256 lane0 = _mm512_extractf32x8_ps(a, 0);
+ __m256 lane1 = _mm512_extractf32x8_ps(a, 1);
+ return _mm256_add_ps(lane0, lane1);
#else
- Packet4f lane0 = _mm512_extractf32x4_ps(a, 0);
- Packet4f lane1 = _mm512_extractf32x4_ps(a, 1);
- Packet4f lane2 = _mm512_extractf32x4_ps(a, 2);
- Packet4f lane3 = _mm512_extractf32x4_ps(a, 3);
- Packet4f sum0 = padd(lane0, lane2);
- Packet4f sum1 = padd(lane1, lane3);
+ __m128 lane0 = _mm512_extractf32x4_ps(a, 0);
+ __m128 lane1 = _mm512_extractf32x4_ps(a, 1);
+ __m128 lane2 = _mm512_extractf32x4_ps(a, 2);
+ __m128 lane3 = _mm512_extractf32x4_ps(a, 3);
+ __m128 sum0 = _mm_add_ps(lane0, lane2);
+ __m128 sum1 = _mm_add_ps(lane1, lane3);
return _mm256_insertf128_ps(_mm256_castps128_ps256(sum0), sum1, 1);
#endif
}
template <>
-EIGEN_STRONG_INLINE Packet4d predux_downto4<Packet8d>(const Packet8d& a) {
- Packet4d lane0 = _mm512_extractf64x4_pd(a, 0);
- Packet4d lane1 = _mm512_extractf64x4_pd(a, 1);
- Packet4d res = padd(lane0, lane1);
- return res;
+EIGEN_STRONG_INLINE Packet4d predux_half_dowto4<Packet8d>(const Packet8d& a) {
+ __m256d lane0 = _mm512_extractf64x4_pd(a, 0);
+ __m256d lane1 = _mm512_extractf64x4_pd(a, 1);
+ return _mm256_add_pd(lane0, lane1);
}
template <>
@@ -939,108 +1035,70 @@ EIGEN_STRONG_INLINE float predux_mul<Packet16f>(const Packet16f& a) {
res = pmul(res, _mm_permute_ps(res, _MM_SHUFFLE(0, 0, 3, 2)));
return pfirst(pmul(res, _mm_permute_ps(res, _MM_SHUFFLE(0, 0, 0, 1))));
#else
- Packet4f lane0 = _mm512_extractf32x4_ps(a, 0);
- Packet4f lane1 = _mm512_extractf32x4_ps(a, 1);
- Packet4f lane2 = _mm512_extractf32x4_ps(a, 2);
- Packet4f lane3 = _mm512_extractf32x4_ps(a, 3);
- Packet4f res = pmul(pmul(lane0, lane1), pmul(lane2, lane3));
+ __m128 lane0 = _mm512_extractf32x4_ps(a, 0);
+ __m128 lane1 = _mm512_extractf32x4_ps(a, 1);
+ __m128 lane2 = _mm512_extractf32x4_ps(a, 2);
+ __m128 lane3 = _mm512_extractf32x4_ps(a, 3);
+ __m128 res = pmul(pmul(lane0, lane1), pmul(lane2, lane3));
res = pmul(res, _mm_permute_ps(res, _MM_SHUFFLE(0, 0, 3, 2)));
return pfirst(pmul(res, _mm_permute_ps(res, _MM_SHUFFLE(0, 0, 0, 1))));
#endif
}
template <>
EIGEN_STRONG_INLINE double predux_mul<Packet8d>(const Packet8d& a) {
- Packet4d lane0 = _mm512_extractf64x4_pd(a, 0);
- Packet4d lane1 = _mm512_extractf64x4_pd(a, 1);
- Packet4d res = pmul(lane0, lane1);
+ __m256d lane0 = _mm512_extractf64x4_pd(a, 0);
+ __m256d lane1 = _mm512_extractf64x4_pd(a, 1);
+ __m256d res = pmul(lane0, lane1);
res = pmul(res, _mm256_permute2f128_pd(res, res, 1));
return pfirst(pmul(res, _mm256_shuffle_pd(res, res, 1)));
}
template <>
EIGEN_STRONG_INLINE float predux_min<Packet16f>(const Packet16f& a) {
- Packet4f lane0 = _mm512_extractf32x4_ps(a, 0);
- Packet4f lane1 = _mm512_extractf32x4_ps(a, 1);
- Packet4f lane2 = _mm512_extractf32x4_ps(a, 2);
- Packet4f lane3 = _mm512_extractf32x4_ps(a, 3);
- Packet4f res = _mm_min_ps(_mm_min_ps(lane0, lane1), _mm_min_ps(lane2, lane3));
+ __m128 lane0 = _mm512_extractf32x4_ps(a, 0);
+ __m128 lane1 = _mm512_extractf32x4_ps(a, 1);
+ __m128 lane2 = _mm512_extractf32x4_ps(a, 2);
+ __m128 lane3 = _mm512_extractf32x4_ps(a, 3);
+ __m128 res = _mm_min_ps(_mm_min_ps(lane0, lane1), _mm_min_ps(lane2, lane3));
res = _mm_min_ps(res, _mm_permute_ps(res, _MM_SHUFFLE(0, 0, 3, 2)));
return pfirst(_mm_min_ps(res, _mm_permute_ps(res, _MM_SHUFFLE(0, 0, 0, 1))));
}
template <>
EIGEN_STRONG_INLINE double predux_min<Packet8d>(const Packet8d& a) {
- Packet4d lane0 = _mm512_extractf64x4_pd(a, 0);
- Packet4d lane1 = _mm512_extractf64x4_pd(a, 1);
- Packet4d res = _mm256_min_pd(lane0, lane1);
+ __m256d lane0 = _mm512_extractf64x4_pd(a, 0);
+ __m256d lane1 = _mm512_extractf64x4_pd(a, 1);
+ __m256d res = _mm256_min_pd(lane0, lane1);
res = _mm256_min_pd(res, _mm256_permute2f128_pd(res, res, 1));
return pfirst(_mm256_min_pd(res, _mm256_shuffle_pd(res, res, 1)));
}
template <>
EIGEN_STRONG_INLINE float predux_max<Packet16f>(const Packet16f& a) {
- Packet4f lane0 = _mm512_extractf32x4_ps(a, 0);
- Packet4f lane1 = _mm512_extractf32x4_ps(a, 1);
- Packet4f lane2 = _mm512_extractf32x4_ps(a, 2);
- Packet4f lane3 = _mm512_extractf32x4_ps(a, 3);
- Packet4f res = _mm_max_ps(_mm_max_ps(lane0, lane1), _mm_max_ps(lane2, lane3));
+ __m128 lane0 = _mm512_extractf32x4_ps(a, 0);
+ __m128 lane1 = _mm512_extractf32x4_ps(a, 1);
+ __m128 lane2 = _mm512_extractf32x4_ps(a, 2);
+ __m128 lane3 = _mm512_extractf32x4_ps(a, 3);
+ __m128 res = _mm_max_ps(_mm_max_ps(lane0, lane1), _mm_max_ps(lane2, lane3));
res = _mm_max_ps(res, _mm_permute_ps(res, _MM_SHUFFLE(0, 0, 3, 2)));
return pfirst(_mm_max_ps(res, _mm_permute_ps(res, _MM_SHUFFLE(0, 0, 0, 1))));
}
+
template <>
EIGEN_STRONG_INLINE double predux_max<Packet8d>(const Packet8d& a) {
- Packet4d lane0 = _mm512_extractf64x4_pd(a, 0);
- Packet4d lane1 = _mm512_extractf64x4_pd(a, 1);
- Packet4d res = _mm256_max_pd(lane0, lane1);
+ __m256d lane0 = _mm512_extractf64x4_pd(a, 0);
+ __m256d lane1 = _mm512_extractf64x4_pd(a, 1);
+ __m256d res = _mm256_max_pd(lane0, lane1);
res = _mm256_max_pd(res, _mm256_permute2f128_pd(res, res, 1));
return pfirst(_mm256_max_pd(res, _mm256_shuffle_pd(res, res, 1)));
}
-template <int Offset>
-struct palign_impl<Offset, Packet16f> {
- static EIGEN_STRONG_INLINE void run(Packet16f& first,
- const Packet16f& second) {
- if (Offset != 0) {
- __m512i first_idx = _mm512_set_epi32(
- Offset + 15, Offset + 14, Offset + 13, Offset + 12, Offset + 11,
- Offset + 10, Offset + 9, Offset + 8, Offset + 7, Offset + 6,
- Offset + 5, Offset + 4, Offset + 3, Offset + 2, Offset + 1, Offset);
-
- __m512i second_idx =
- _mm512_set_epi32(Offset - 1, Offset - 2, Offset - 3, Offset - 4,
- Offset - 5, Offset - 6, Offset - 7, Offset - 8,
- Offset - 9, Offset - 10, Offset - 11, Offset - 12,
- Offset - 13, Offset - 14, Offset - 15, Offset - 16);
-
- unsigned short mask = 0xFFFF;
- mask <<= (16 - Offset);
-
- first = _mm512_permutexvar_ps(first_idx, first);
- Packet16f tmp = _mm512_permutexvar_ps(second_idx, second);
- first = _mm512_mask_blend_ps(mask, first, tmp);
- }
- }
-};
-template <int Offset>
-struct palign_impl<Offset, Packet8d> {
- static EIGEN_STRONG_INLINE void run(Packet8d& first, const Packet8d& second) {
- if (Offset != 0) {
- __m512i first_idx = _mm512_set_epi32(
- 0, Offset + 7, 0, Offset + 6, 0, Offset + 5, 0, Offset + 4, 0,
- Offset + 3, 0, Offset + 2, 0, Offset + 1, 0, Offset);
-
- __m512i second_idx = _mm512_set_epi32(
- 0, Offset - 1, 0, Offset - 2, 0, Offset - 3, 0, Offset - 4, 0,
- Offset - 5, 0, Offset - 6, 0, Offset - 7, 0, Offset - 8);
-
- unsigned char mask = 0xFF;
- mask <<= (8 - Offset);
-
- first = _mm512_permutexvar_pd(first_idx, first);
- Packet8d tmp = _mm512_permutexvar_pd(second_idx, second);
- first = _mm512_mask_blend_pd(mask, first, tmp);
- }
- }
-};
+template<> EIGEN_STRONG_INLINE bool predux_any(const Packet16f& x)
+{
+ Packet16i xi = _mm512_castps_si512(x);
+ __mmask16 tmp = _mm512_test_epi32_mask(xi,xi);
+ return !_mm512_kortestz(tmp,tmp);
+}
+
#define PACK_OUTPUT(OUTPUT, INPUT, INDEX, STRIDE) \
@@ -1302,11 +1360,940 @@ EIGEN_STRONG_INLINE Packet16f pblend(const Selector<16>& /*ifPacket*/,
return Packet16f();
}
template <>
-EIGEN_STRONG_INLINE Packet8d pblend(const Selector<8>& /*ifPacket*/,
- const Packet8d& /*thenPacket*/,
- const Packet8d& /*elsePacket*/) {
- assert(false && "To be implemented");
- return Packet8d();
+EIGEN_STRONG_INLINE Packet8d pblend(const Selector<8>& ifPacket,
+ const Packet8d& thenPacket,
+ const Packet8d& elsePacket) {
+ __mmask8 m = (ifPacket.select[0] )
+ | (ifPacket.select[1]<<1)
+ | (ifPacket.select[2]<<2)
+ | (ifPacket.select[3]<<3)
+ | (ifPacket.select[4]<<4)
+ | (ifPacket.select[5]<<5)
+ | (ifPacket.select[6]<<6)
+ | (ifPacket.select[7]<<7);
+ return _mm512_mask_blend_pd(m, elsePacket, thenPacket);
+}
+
+// Packet math for Eigen::half
+template<> EIGEN_STRONG_INLINE Packet16h pset1<Packet16h>(const Eigen::half& from) {
+ return _mm256_set1_epi16(from.x);
+}
+
+template<> EIGEN_STRONG_INLINE Eigen::half pfirst<Packet16h>(const Packet16h& from) {
+ return half_impl::raw_uint16_to_half(static_cast<unsigned short>(_mm256_extract_epi16(from, 0)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet16h pload<Packet16h>(const Eigen::half* from) {
+ return _mm256_load_si256(reinterpret_cast<const __m256i*>(from));
+}
+
+template<> EIGEN_STRONG_INLINE Packet16h ploadu<Packet16h>(const Eigen::half* from) {
+ return _mm256_loadu_si256(reinterpret_cast<const __m256i*>(from));
+}
+
+template<> EIGEN_STRONG_INLINE void pstore<half>(Eigen::half* to, const Packet16h& from) {
+ // (void*) -> workaround clang warning:
+ // cast from 'Eigen::half *' to '__m256i *' increases required alignment from 2 to 32
+ _mm256_store_si256((__m256i*)(void*)to, from);
+}
+
+template<> EIGEN_STRONG_INLINE void pstoreu<half>(Eigen::half* to, const Packet16h& from) {
+ // (void*) -> workaround clang warning:
+ // cast from 'Eigen::half *' to '__m256i *' increases required alignment from 2 to 32
+ _mm256_storeu_si256((__m256i*)(void*)to, from);
+}
+
+template<> EIGEN_STRONG_INLINE Packet16h
+ploaddup<Packet16h>(const Eigen::half* from) {
+ unsigned short a = from[0].x;
+ unsigned short b = from[1].x;
+ unsigned short c = from[2].x;
+ unsigned short d = from[3].x;
+ unsigned short e = from[4].x;
+ unsigned short f = from[5].x;
+ unsigned short g = from[6].x;
+ unsigned short h = from[7].x;
+ return _mm256_set_epi16(h, h, g, g, f, f, e, e, d, d, c, c, b, b, a, a);
+}
+
+template<> EIGEN_STRONG_INLINE Packet16h
+ploadquad(const Eigen::half* from) {
+ unsigned short a = from[0].x;
+ unsigned short b = from[1].x;
+ unsigned short c = from[2].x;
+ unsigned short d = from[3].x;
+ return _mm256_set_epi16(d, d, d, d, c, c, c, c, b, b, b, b, a, a, a, a);
+}
+
+EIGEN_STRONG_INLINE Packet16f half2float(const Packet16h& a) {
+#ifdef EIGEN_HAS_FP16_C
+ return _mm512_cvtph_ps(a);
+#else
+ EIGEN_ALIGN64 half aux[16];
+ pstore(aux, a);
+ float f0(aux[0]);
+ float f1(aux[1]);
+ float f2(aux[2]);
+ float f3(aux[3]);
+ float f4(aux[4]);
+ float f5(aux[5]);
+ float f6(aux[6]);
+ float f7(aux[7]);
+ float f8(aux[8]);
+ float f9(aux[9]);
+ float fa(aux[10]);
+ float fb(aux[11]);
+ float fc(aux[12]);
+ float fd(aux[13]);
+ float fe(aux[14]);
+ float ff(aux[15]);
+
+ return _mm512_set_ps(
+ ff, fe, fd, fc, fb, fa, f9, f8, f7, f6, f5, f4, f3, f2, f1, f0);
+#endif
+}
+
+EIGEN_STRONG_INLINE Packet16h float2half(const Packet16f& a) {
+#ifdef EIGEN_HAS_FP16_C
+ return _mm512_cvtps_ph(a, _MM_FROUND_TO_NEAREST_INT|_MM_FROUND_NO_EXC);
+#else
+ EIGEN_ALIGN64 float aux[16];
+ pstore(aux, a);
+ half h0(aux[0]);
+ half h1(aux[1]);
+ half h2(aux[2]);
+ half h3(aux[3]);
+ half h4(aux[4]);
+ half h5(aux[5]);
+ half h6(aux[6]);
+ half h7(aux[7]);
+ half h8(aux[8]);
+ half h9(aux[9]);
+ half ha(aux[10]);
+ half hb(aux[11]);
+ half hc(aux[12]);
+ half hd(aux[13]);
+ half he(aux[14]);
+ half hf(aux[15]);
+
+ return _mm256_set_epi16(
+ hf.x, he.x, hd.x, hc.x, hb.x, ha.x, h9.x, h8.x,
+ h7.x, h6.x, h5.x, h4.x, h3.x, h2.x, h1.x, h0.x);
+#endif
+}
+
+template<> EIGEN_STRONG_INLINE Packet16h ptrue(const Packet16h& a) {
+ return ptrue(Packet8i(a));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet16h pabs(const Packet16h& a) {
+ const __m256i sign_mask = _mm256_set1_epi16(static_cast<numext::uint16_t>(0x8000));
+ return _mm256_andnot_si256(sign_mask, a);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet16h pmin<Packet16h>(const Packet16h& a,
+ const Packet16h& b) {
+ return float2half(pmin<Packet16f>(half2float(a), half2float(b)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet16h pmax<Packet16h>(const Packet16h& a,
+ const Packet16h& b) {
+ return float2half(pmax<Packet16f>(half2float(a), half2float(b)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet16h plset<Packet16h>(const half& a) {
+ return float2half(plset<Packet16f>(static_cast<float>(a)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet16h por(const Packet16h& a,const Packet16h& b) {
+ // in some cases Packet8i is a wrapper around __m256i, so we need to
+ // cast to Packet8i to call the correct overload.
+ return por(Packet8i(a),Packet8i(b));
+}
+template<> EIGEN_STRONG_INLINE Packet16h pxor(const Packet16h& a,const Packet16h& b) {
+ return pxor(Packet8i(a),Packet8i(b));
+}
+template<> EIGEN_STRONG_INLINE Packet16h pand(const Packet16h& a,const Packet16h& b) {
+ return pand(Packet8i(a),Packet8i(b));
+}
+template<> EIGEN_STRONG_INLINE Packet16h pandnot(const Packet16h& a,const Packet16h& b) {
+ return pandnot(Packet8i(a),Packet8i(b));
+}
+
+template<> EIGEN_STRONG_INLINE Packet16h pselect(const Packet16h& mask, const Packet16h& a, const Packet16h& b) {
+ return _mm256_blendv_epi8(b, a, mask);
+}
+
+template<> EIGEN_STRONG_INLINE Packet16h pround<Packet16h>(const Packet16h& a) {
+ return float2half(pround<Packet16f>(half2float(a)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet16h print<Packet16h>(const Packet16h& a) {
+ return float2half(print<Packet16f>(half2float(a)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet16h pceil<Packet16h>(const Packet16h& a) {
+ return float2half(pceil<Packet16f>(half2float(a)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet16h pfloor<Packet16h>(const Packet16h& a) {
+ return float2half(pfloor<Packet16f>(half2float(a)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet16h pcmp_eq(const Packet16h& a,const Packet16h& b) {
+ Packet16f af = half2float(a);
+ Packet16f bf = half2float(b);
+ return Pack32To16(pcmp_eq(af, bf));
+}
+
+template<> EIGEN_STRONG_INLINE Packet16h pcmp_le(const Packet16h& a,const Packet16h& b) {
+ return Pack32To16(pcmp_le(half2float(a), half2float(b)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet16h pcmp_lt(const Packet16h& a,const Packet16h& b) {
+ return Pack32To16(pcmp_lt(half2float(a), half2float(b)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet16h pcmp_lt_or_nan(const Packet16h& a,const Packet16h& b) {
+ return Pack32To16(pcmp_lt_or_nan(half2float(a), half2float(b)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet16h pconj(const Packet16h& a) { return a; }
+
+template<> EIGEN_STRONG_INLINE Packet16h pnegate(const Packet16h& a) {
+ Packet16h sign_mask = _mm256_set1_epi16(static_cast<unsigned short>(0x8000));
+ return _mm256_xor_si256(a, sign_mask);
+}
+
+template<> EIGEN_STRONG_INLINE Packet16h padd<Packet16h>(const Packet16h& a, const Packet16h& b) {
+ Packet16f af = half2float(a);
+ Packet16f bf = half2float(b);
+ Packet16f rf = padd(af, bf);
+ return float2half(rf);
+}
+
+template<> EIGEN_STRONG_INLINE Packet16h psub<Packet16h>(const Packet16h& a, const Packet16h& b) {
+ Packet16f af = half2float(a);
+ Packet16f bf = half2float(b);
+ Packet16f rf = psub(af, bf);
+ return float2half(rf);
+}
+
+template<> EIGEN_STRONG_INLINE Packet16h pmul<Packet16h>(const Packet16h& a, const Packet16h& b) {
+ Packet16f af = half2float(a);
+ Packet16f bf = half2float(b);
+ Packet16f rf = pmul(af, bf);
+ return float2half(rf);
+}
+
+template<> EIGEN_STRONG_INLINE Packet16h pdiv<Packet16h>(const Packet16h& a, const Packet16h& b) {
+ Packet16f af = half2float(a);
+ Packet16f bf = half2float(b);
+ Packet16f rf = pdiv(af, bf);
+ return float2half(rf);
+}
+
+template<> EIGEN_STRONG_INLINE half predux<Packet16h>(const Packet16h& from) {
+ Packet16f from_float = half2float(from);
+ return half(predux(from_float));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8h predux_half_dowto4<Packet16h>(const Packet16h& a) {
+ Packet8h lane0 = _mm256_extractf128_si256(a, 0);
+ Packet8h lane1 = _mm256_extractf128_si256(a, 1);
+ return padd<Packet8h>(lane0, lane1);
+}
+
+template<> EIGEN_STRONG_INLINE Eigen::half predux_max<Packet16h>(const Packet16h& a) {
+ Packet16f af = half2float(a);
+ float reduced = predux_max<Packet16f>(af);
+ return Eigen::half(reduced);
+}
+
+template<> EIGEN_STRONG_INLINE Eigen::half predux_min<Packet16h>(const Packet16h& a) {
+ Packet16f af = half2float(a);
+ float reduced = predux_min<Packet16f>(af);
+ return Eigen::half(reduced);
+}
+
+template<> EIGEN_STRONG_INLINE half predux_mul<Packet16h>(const Packet16h& from) {
+ Packet16f from_float = half2float(from);
+ return half(predux_mul(from_float));
+}
+
+template<> EIGEN_STRONG_INLINE Packet16h preverse(const Packet16h& a)
+{
+ __m128i m = _mm_setr_epi8(14,15,12,13,10,11,8,9,6,7,4,5,2,3,0,1);
+ return _mm256_insertf128_si256(
+ _mm256_castsi128_si256(_mm_shuffle_epi8(_mm256_extractf128_si256(a,1),m)),
+ _mm_shuffle_epi8(_mm256_extractf128_si256(a,0),m), 1);
+}
+
+template<> EIGEN_STRONG_INLINE Packet16h pgather<Eigen::half, Packet16h>(const Eigen::half* from, Index stride)
+{
+ return _mm256_set_epi16(
+ from[15*stride].x, from[14*stride].x, from[13*stride].x, from[12*stride].x,
+ from[11*stride].x, from[10*stride].x, from[9*stride].x, from[8*stride].x,
+ from[7*stride].x, from[6*stride].x, from[5*stride].x, from[4*stride].x,
+ from[3*stride].x, from[2*stride].x, from[1*stride].x, from[0*stride].x);
+}
+
+template<> EIGEN_STRONG_INLINE void pscatter<half, Packet16h>(half* to, const Packet16h& from, Index stride)
+{
+ EIGEN_ALIGN64 half aux[16];
+ pstore(aux, from);
+ to[stride*0] = aux[0];
+ to[stride*1] = aux[1];
+ to[stride*2] = aux[2];
+ to[stride*3] = aux[3];
+ to[stride*4] = aux[4];
+ to[stride*5] = aux[5];
+ to[stride*6] = aux[6];
+ to[stride*7] = aux[7];
+ to[stride*8] = aux[8];
+ to[stride*9] = aux[9];
+ to[stride*10] = aux[10];
+ to[stride*11] = aux[11];
+ to[stride*12] = aux[12];
+ to[stride*13] = aux[13];
+ to[stride*14] = aux[14];
+ to[stride*15] = aux[15];
+}
+
+EIGEN_STRONG_INLINE void
+ptranspose(PacketBlock<Packet16h,16>& kernel) {
+ __m256i a = kernel.packet[0];
+ __m256i b = kernel.packet[1];
+ __m256i c = kernel.packet[2];
+ __m256i d = kernel.packet[3];
+ __m256i e = kernel.packet[4];
+ __m256i f = kernel.packet[5];
+ __m256i g = kernel.packet[6];
+ __m256i h = kernel.packet[7];
+ __m256i i = kernel.packet[8];
+ __m256i j = kernel.packet[9];
+ __m256i k = kernel.packet[10];
+ __m256i l = kernel.packet[11];
+ __m256i m = kernel.packet[12];
+ __m256i n = kernel.packet[13];
+ __m256i o = kernel.packet[14];
+ __m256i p = kernel.packet[15];
+
+ __m256i ab_07 = _mm256_unpacklo_epi16(a, b);
+ __m256i cd_07 = _mm256_unpacklo_epi16(c, d);
+ __m256i ef_07 = _mm256_unpacklo_epi16(e, f);
+ __m256i gh_07 = _mm256_unpacklo_epi16(g, h);
+ __m256i ij_07 = _mm256_unpacklo_epi16(i, j);
+ __m256i kl_07 = _mm256_unpacklo_epi16(k, l);
+ __m256i mn_07 = _mm256_unpacklo_epi16(m, n);
+ __m256i op_07 = _mm256_unpacklo_epi16(o, p);
+
+ __m256i ab_8f = _mm256_unpackhi_epi16(a, b);
+ __m256i cd_8f = _mm256_unpackhi_epi16(c, d);
+ __m256i ef_8f = _mm256_unpackhi_epi16(e, f);
+ __m256i gh_8f = _mm256_unpackhi_epi16(g, h);
+ __m256i ij_8f = _mm256_unpackhi_epi16(i, j);
+ __m256i kl_8f = _mm256_unpackhi_epi16(k, l);
+ __m256i mn_8f = _mm256_unpackhi_epi16(m, n);
+ __m256i op_8f = _mm256_unpackhi_epi16(o, p);
+
+ __m256i abcd_03 = _mm256_unpacklo_epi32(ab_07, cd_07);
+ __m256i abcd_47 = _mm256_unpackhi_epi32(ab_07, cd_07);
+ __m256i efgh_03 = _mm256_unpacklo_epi32(ef_07, gh_07);
+ __m256i efgh_47 = _mm256_unpackhi_epi32(ef_07, gh_07);
+ __m256i ijkl_03 = _mm256_unpacklo_epi32(ij_07, kl_07);
+ __m256i ijkl_47 = _mm256_unpackhi_epi32(ij_07, kl_07);
+ __m256i mnop_03 = _mm256_unpacklo_epi32(mn_07, op_07);
+ __m256i mnop_47 = _mm256_unpackhi_epi32(mn_07, op_07);
+
+ __m256i abcd_8b = _mm256_unpacklo_epi32(ab_8f, cd_8f);
+ __m256i abcd_cf = _mm256_unpackhi_epi32(ab_8f, cd_8f);
+ __m256i efgh_8b = _mm256_unpacklo_epi32(ef_8f, gh_8f);
+ __m256i efgh_cf = _mm256_unpackhi_epi32(ef_8f, gh_8f);
+ __m256i ijkl_8b = _mm256_unpacklo_epi32(ij_8f, kl_8f);
+ __m256i ijkl_cf = _mm256_unpackhi_epi32(ij_8f, kl_8f);
+ __m256i mnop_8b = _mm256_unpacklo_epi32(mn_8f, op_8f);
+ __m256i mnop_cf = _mm256_unpackhi_epi32(mn_8f, op_8f);
+
+ __m256i abcdefgh_01 = _mm256_unpacklo_epi64(abcd_03, efgh_03);
+ __m256i abcdefgh_23 = _mm256_unpackhi_epi64(abcd_03, efgh_03);
+ __m256i ijklmnop_01 = _mm256_unpacklo_epi64(ijkl_03, mnop_03);
+ __m256i ijklmnop_23 = _mm256_unpackhi_epi64(ijkl_03, mnop_03);
+ __m256i abcdefgh_45 = _mm256_unpacklo_epi64(abcd_47, efgh_47);
+ __m256i abcdefgh_67 = _mm256_unpackhi_epi64(abcd_47, efgh_47);
+ __m256i ijklmnop_45 = _mm256_unpacklo_epi64(ijkl_47, mnop_47);
+ __m256i ijklmnop_67 = _mm256_unpackhi_epi64(ijkl_47, mnop_47);
+ __m256i abcdefgh_89 = _mm256_unpacklo_epi64(abcd_8b, efgh_8b);
+ __m256i abcdefgh_ab = _mm256_unpackhi_epi64(abcd_8b, efgh_8b);
+ __m256i ijklmnop_89 = _mm256_unpacklo_epi64(ijkl_8b, mnop_8b);
+ __m256i ijklmnop_ab = _mm256_unpackhi_epi64(ijkl_8b, mnop_8b);
+ __m256i abcdefgh_cd = _mm256_unpacklo_epi64(abcd_cf, efgh_cf);
+ __m256i abcdefgh_ef = _mm256_unpackhi_epi64(abcd_cf, efgh_cf);
+ __m256i ijklmnop_cd = _mm256_unpacklo_epi64(ijkl_cf, mnop_cf);
+ __m256i ijklmnop_ef = _mm256_unpackhi_epi64(ijkl_cf, mnop_cf);
+
+ // NOTE: no unpacklo/hi instr in this case, so using permute instr.
+ __m256i a_p_0 = _mm256_permute2x128_si256(abcdefgh_01, ijklmnop_01, 0x20);
+ __m256i a_p_1 = _mm256_permute2x128_si256(abcdefgh_23, ijklmnop_23, 0x20);
+ __m256i a_p_2 = _mm256_permute2x128_si256(abcdefgh_45, ijklmnop_45, 0x20);
+ __m256i a_p_3 = _mm256_permute2x128_si256(abcdefgh_67, ijklmnop_67, 0x20);
+ __m256i a_p_4 = _mm256_permute2x128_si256(abcdefgh_89, ijklmnop_89, 0x20);
+ __m256i a_p_5 = _mm256_permute2x128_si256(abcdefgh_ab, ijklmnop_ab, 0x20);
+ __m256i a_p_6 = _mm256_permute2x128_si256(abcdefgh_cd, ijklmnop_cd, 0x20);
+ __m256i a_p_7 = _mm256_permute2x128_si256(abcdefgh_ef, ijklmnop_ef, 0x20);
+ __m256i a_p_8 = _mm256_permute2x128_si256(abcdefgh_01, ijklmnop_01, 0x31);
+ __m256i a_p_9 = _mm256_permute2x128_si256(abcdefgh_23, ijklmnop_23, 0x31);
+ __m256i a_p_a = _mm256_permute2x128_si256(abcdefgh_45, ijklmnop_45, 0x31);
+ __m256i a_p_b = _mm256_permute2x128_si256(abcdefgh_67, ijklmnop_67, 0x31);
+ __m256i a_p_c = _mm256_permute2x128_si256(abcdefgh_89, ijklmnop_89, 0x31);
+ __m256i a_p_d = _mm256_permute2x128_si256(abcdefgh_ab, ijklmnop_ab, 0x31);
+ __m256i a_p_e = _mm256_permute2x128_si256(abcdefgh_cd, ijklmnop_cd, 0x31);
+ __m256i a_p_f = _mm256_permute2x128_si256(abcdefgh_ef, ijklmnop_ef, 0x31);
+
+ kernel.packet[0] = a_p_0;
+ kernel.packet[1] = a_p_1;
+ kernel.packet[2] = a_p_2;
+ kernel.packet[3] = a_p_3;
+ kernel.packet[4] = a_p_4;
+ kernel.packet[5] = a_p_5;
+ kernel.packet[6] = a_p_6;
+ kernel.packet[7] = a_p_7;
+ kernel.packet[8] = a_p_8;
+ kernel.packet[9] = a_p_9;
+ kernel.packet[10] = a_p_a;
+ kernel.packet[11] = a_p_b;
+ kernel.packet[12] = a_p_c;
+ kernel.packet[13] = a_p_d;
+ kernel.packet[14] = a_p_e;
+ kernel.packet[15] = a_p_f;
+}
+
+EIGEN_STRONG_INLINE void
+ptranspose(PacketBlock<Packet16h,8>& kernel) {
+ EIGEN_ALIGN64 half in[8][16];
+ pstore<half>(in[0], kernel.packet[0]);
+ pstore<half>(in[1], kernel.packet[1]);
+ pstore<half>(in[2], kernel.packet[2]);
+ pstore<half>(in[3], kernel.packet[3]);
+ pstore<half>(in[4], kernel.packet[4]);
+ pstore<half>(in[5], kernel.packet[5]);
+ pstore<half>(in[6], kernel.packet[6]);
+ pstore<half>(in[7], kernel.packet[7]);
+
+ EIGEN_ALIGN64 half out[8][16];
+
+ for (int i = 0; i < 8; ++i) {
+ for (int j = 0; j < 8; ++j) {
+ out[i][j] = in[j][2*i];
+ }
+ for (int j = 0; j < 8; ++j) {
+ out[i][j+8] = in[j][2*i+1];
+ }
+ }
+
+ kernel.packet[0] = pload<Packet16h>(out[0]);
+ kernel.packet[1] = pload<Packet16h>(out[1]);
+ kernel.packet[2] = pload<Packet16h>(out[2]);
+ kernel.packet[3] = pload<Packet16h>(out[3]);
+ kernel.packet[4] = pload<Packet16h>(out[4]);
+ kernel.packet[5] = pload<Packet16h>(out[5]);
+ kernel.packet[6] = pload<Packet16h>(out[6]);
+ kernel.packet[7] = pload<Packet16h>(out[7]);
+}
+
+EIGEN_STRONG_INLINE void
+ptranspose(PacketBlock<Packet16h,4>& kernel) {
+ EIGEN_ALIGN64 half in[4][16];
+ pstore<half>(in[0], kernel.packet[0]);
+ pstore<half>(in[1], kernel.packet[1]);
+ pstore<half>(in[2], kernel.packet[2]);
+ pstore<half>(in[3], kernel.packet[3]);
+
+ EIGEN_ALIGN64 half out[4][16];
+
+ for (int i = 0; i < 4; ++i) {
+ for (int j = 0; j < 4; ++j) {
+ out[i][j] = in[j][4*i];
+ }
+ for (int j = 0; j < 4; ++j) {
+ out[i][j+4] = in[j][4*i+1];
+ }
+ for (int j = 0; j < 4; ++j) {
+ out[i][j+8] = in[j][4*i+2];
+ }
+ for (int j = 0; j < 4; ++j) {
+ out[i][j+12] = in[j][4*i+3];
+ }
+ }
+
+ kernel.packet[0] = pload<Packet16h>(out[0]);
+ kernel.packet[1] = pload<Packet16h>(out[1]);
+ kernel.packet[2] = pload<Packet16h>(out[2]);
+ kernel.packet[3] = pload<Packet16h>(out[3]);
+}
+
+template <> struct is_arithmetic<Packet16bf> { enum { value = true }; };
+
+template <>
+struct packet_traits<bfloat16> : default_packet_traits {
+ typedef Packet16bf type;
+ typedef Packet8bf half;
+ enum {
+ Vectorizable = 1,
+ AlignedOnScalar = 1,
+ size = 16,
+ HasHalfPacket = 1,
+ HasBlend = 0,
+ HasInsert = 1,
+ HasSin = EIGEN_FAST_MATH,
+ HasCos = EIGEN_FAST_MATH,
+#if EIGEN_GNUC_AT_LEAST(5, 3) || (!EIGEN_COMP_GNUC_STRICT)
+#ifdef EIGEN_VECTORIZE_AVX512DQ
+ HasLog = 1, // Currently fails test with bad accuracy.
+ HasLog1p = 1,
+ HasExpm1 = 1,
+ HasNdtri = 1,
+ HasBessel = 1,
+#endif
+ HasExp = 1,
+ HasSqrt = EIGEN_FAST_MATH,
+ HasRsqrt = EIGEN_FAST_MATH,
+ HasTanh = EIGEN_FAST_MATH,
+ HasErf = EIGEN_FAST_MATH,
+#endif
+ HasCmp = 1,
+ HasDiv = 1
+ };
+};
+
+template <>
+struct unpacket_traits<Packet16bf>
+{
+ typedef bfloat16 type;
+ enum {size=16, alignment=Aligned32, vectorizable=true, masked_load_available=false, masked_store_available=false};
+ typedef Packet8bf half;
+};
+
+template <>
+EIGEN_STRONG_INLINE Packet16bf pset1<Packet16bf>(const bfloat16& from) {
+ return _mm256_set1_epi16(from.value);
+}
+
+template <>
+EIGEN_STRONG_INLINE bfloat16 pfirst<Packet16bf>(const Packet16bf& from) {
+ bfloat16 t;
+ t.value = static_cast<unsigned short>(_mm256_extract_epi16(from, 0));
+ return t;
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet16bf pload<Packet16bf>(const bfloat16* from) {
+ return _mm256_load_si256(reinterpret_cast<const __m256i*>(from));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet16bf ploadu<Packet16bf>(const bfloat16* from) {
+ return _mm256_loadu_si256(reinterpret_cast<const __m256i*>(from));
+}
+
+template <>
+EIGEN_STRONG_INLINE void pstore<bfloat16>(bfloat16* to,
+ const Packet16bf& from) {
+ _mm256_store_si256(reinterpret_cast<__m256i*>(to), from);
+}
+
+template <>
+EIGEN_STRONG_INLINE void pstoreu<bfloat16>(bfloat16* to,
+ const Packet16bf& from) {
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(to), from);
+}
+
+template<> EIGEN_STRONG_INLINE Packet16bf
+ploaddup<Packet16bf>(const bfloat16* from) {
+ Packet16bf r;
+ unsigned short a = from[0].value;
+ unsigned short b = from[1].value;
+ unsigned short c = from[2].value;
+ unsigned short d = from[3].value;
+ unsigned short e = from[4].value;
+ unsigned short f = from[5].value;
+ unsigned short g = from[6].value;
+ unsigned short h = from[7].value;
+ return _mm256_set_epi16(h, h, g, g, f, f, e, e, d, d, c, c, b, b, a, a);
+}
+
+template<> EIGEN_STRONG_INLINE Packet16bf
+ploadquad(const bfloat16* from) {
+ Packet16bf r;
+ unsigned short a = from[0].value;
+ unsigned short b = from[1].value;
+ unsigned short c = from[2].value;
+ unsigned short d = from[3].value;
+ return _mm256_set_epi16(d, d, d, d, c, c, c, c, b, b, b, b, a, a, a, a);
+}
+
+EIGEN_STRONG_INLINE Packet16f Bf16ToF32(const Packet16bf& a) {
+ return _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(a), 16));
+}
+
+// Convert float to bfloat16 according to round-to-nearest-even/denormals algorithm.
+EIGEN_STRONG_INLINE Packet16bf F32ToBf16(const Packet16f& a) {
+ Packet16bf r;
+
+#if defined(EIGEN_VECTORIZE_AVX512BF16) && EIGEN_GNUC_AT_LEAST(10, 1)
+ // Since GCC 10.1 supports avx512bf16 and C style explicit cast
+ // (C++ static_cast is not supported yet), do converion via intrinsic
+ // and register path for performance.
+ r = (__m256i)(_mm512_cvtneps_pbh(a));
+
+#else
+ __m512i t;
+ __m512i input = _mm512_castps_si512(a);
+ __m512i nan = _mm512_set1_epi32(0x7fc0);
+
+ // uint32_t lsb = (input >> 16) & 1;
+ t = _mm512_and_si512(_mm512_srli_epi32(input, 16), _mm512_set1_epi32(1));
+ // uint32_t rounding_bias = 0x7fff + lsb;
+ t = _mm512_add_epi32(t, _mm512_set1_epi32(0x7fff));
+ // input += rounding_bias;
+ t = _mm512_add_epi32(t, input);
+ // input = input >> 16;
+ t = _mm512_srli_epi32(t, 16);
+
+ // Check NaN before converting back to bf16
+ __mmask16 mask = _mm512_cmp_ps_mask(a, a, _CMP_ORD_Q);
+
+ t = _mm512_mask_blend_epi32(mask, nan, t);
+ // output.value = static_cast<uint16_t>(input);
+ r = _mm512_cvtepi32_epi16(t);
+#endif // EIGEN_VECTORIZE_AVX512BF16
+
+ return r;
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet16bf ptrue(const Packet16bf& a) {
+ return ptrue<Packet8i>(a);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet16bf por(const Packet16bf& a, const Packet16bf& b) {
+ return por<Packet8i>(a, b);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet16bf pxor(const Packet16bf& a, const Packet16bf& b) {
+ return pxor<Packet8i>(a, b);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet16bf pand(const Packet16bf& a, const Packet16bf& b) {
+ return pand<Packet8i>(a, b);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet16bf pandnot(const Packet16bf& a,
+ const Packet16bf& b) {
+ return pandnot<Packet8i>(a, b);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet16bf pselect(const Packet16bf& mask,
+ const Packet16bf& a,
+ const Packet16bf& b) {
+ // Input mask is expected to be all 0/1, handle it with 8-bit
+ // intrinsic for performance.
+ return _mm256_blendv_epi8(b, a, mask);
+}
+
+template<> EIGEN_STRONG_INLINE Packet16bf pround<Packet16bf>(const Packet16bf& a)
+{
+ return F32ToBf16(pround<Packet16f>(Bf16ToF32(a)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet16bf print<Packet16bf>(const Packet16bf& a) {
+ return F32ToBf16(print<Packet16f>(Bf16ToF32(a)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet16bf pceil<Packet16bf>(const Packet16bf& a) {
+ return F32ToBf16(pceil<Packet16f>(Bf16ToF32(a)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet16bf pfloor<Packet16bf>(const Packet16bf& a) {
+ return F32ToBf16(pfloor<Packet16f>(Bf16ToF32(a)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet16bf pcmp_eq(const Packet16bf& a,
+ const Packet16bf& b) {
+ return Pack32To16(pcmp_eq(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet16bf pcmp_le(const Packet16bf& a,
+ const Packet16bf& b) {
+ return Pack32To16(pcmp_le(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet16bf pcmp_lt(const Packet16bf& a,
+ const Packet16bf& b) {
+ return Pack32To16(pcmp_lt(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet16bf pcmp_lt_or_nan(const Packet16bf& a,
+ const Packet16bf& b) {
+ return Pack32To16(pcmp_lt_or_nan(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet16bf pnegate(const Packet16bf& a) {
+ Packet16bf sign_mask = _mm256_set1_epi16(static_cast<unsigned short>(0x8000));
+ return _mm256_xor_si256(a, sign_mask);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet16bf pconj(const Packet16bf& a) {
+ return a;
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet16bf pabs(const Packet16bf& a) {
+ const __m256i sign_mask = _mm256_set1_epi16(static_cast<numext::uint16_t>(0x8000));
+ return _mm256_andnot_si256(sign_mask, a);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet16bf padd<Packet16bf>(const Packet16bf& a,
+ const Packet16bf& b) {
+ return F32ToBf16(padd<Packet16f>(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet16bf psub<Packet16bf>(const Packet16bf& a,
+ const Packet16bf& b) {
+ return F32ToBf16(psub<Packet16f>(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet16bf pmul<Packet16bf>(const Packet16bf& a,
+ const Packet16bf& b) {
+ return F32ToBf16(pmul<Packet16f>(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet16bf pdiv<Packet16bf>(const Packet16bf& a,
+ const Packet16bf& b) {
+ return F32ToBf16(pdiv<Packet16f>(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet16bf pmin<Packet16bf>(const Packet16bf& a,
+ const Packet16bf& b) {
+ return F32ToBf16(pmin<Packet16f>(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet16bf pmax<Packet16bf>(const Packet16bf& a,
+ const Packet16bf& b) {
+ return F32ToBf16(pmax<Packet16f>(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet16bf plset<Packet16bf>(const bfloat16& a) {
+ return F32ToBf16(plset<Packet16f>(static_cast<float>(a)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8bf predux_half_dowto4<Packet16bf>(const Packet16bf& a) {
+ Packet8bf lane0 = _mm256_extractf128_si256(a, 0);
+ Packet8bf lane1 = _mm256_extractf128_si256(a, 1);
+ return padd<Packet8bf>(lane0, lane1);
+}
+
+template <>
+EIGEN_STRONG_INLINE bfloat16 predux<Packet16bf>(const Packet16bf& p) {
+ return static_cast<bfloat16>(predux<Packet16f>(Bf16ToF32(p)));
+}
+
+template <>
+EIGEN_STRONG_INLINE bfloat16 predux_mul<Packet16bf>(const Packet16bf& from) {
+ return static_cast<bfloat16>(predux_mul<Packet16f>(Bf16ToF32(from)));
+}
+
+template <>
+EIGEN_STRONG_INLINE bfloat16 predux_min<Packet16bf>(const Packet16bf& from) {
+ return static_cast<bfloat16>(predux_min<Packet16f>(Bf16ToF32(from)));
+}
+
+template <>
+EIGEN_STRONG_INLINE bfloat16 predux_max<Packet16bf>(const Packet16bf& from) {
+ return static_cast<bfloat16>(predux_max<Packet16f>(Bf16ToF32(from)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet16bf preverse(const Packet16bf& a) {
+ __m256i m = _mm256_setr_epi8(14,15,12,13,10,11,8,9,6,7,4,5,2,3,0,1,
+ 14,15,12,13,10,11,8,9,6,7,4,5,2,3,0,1);
+
+ Packet16bf res;
+ // Swap hi and lo first because shuffle is in 128-bit lanes.
+ res = _mm256_permute2x128_si256(a, a, 1);
+ // Shuffle 8-bit values in src within 2*128-bit lanes.
+ return _mm256_shuffle_epi8(res, m);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet16bf pgather<bfloat16, Packet16bf>(const bfloat16* from,
+ Index stride) {
+ return _mm256_set_epi16(
+ from[15*stride].value, from[14*stride].value, from[13*stride].value, from[12*stride].value,
+ from[11*stride].value, from[10*stride].value, from[9*stride].value, from[8*stride].value,
+ from[7*stride].value, from[6*stride].value, from[5*stride].value, from[4*stride].value,
+ from[3*stride].value, from[2*stride].value, from[1*stride].value, from[0*stride].value);
+}
+
+template <>
+EIGEN_STRONG_INLINE void pscatter<bfloat16, Packet16bf>(bfloat16* to,
+ const Packet16bf& from,
+ Index stride) {
+ EIGEN_ALIGN64 bfloat16 aux[16];
+ pstore(aux, from);
+ to[stride*0] = aux[0];
+ to[stride*1] = aux[1];
+ to[stride*2] = aux[2];
+ to[stride*3] = aux[3];
+ to[stride*4] = aux[4];
+ to[stride*5] = aux[5];
+ to[stride*6] = aux[6];
+ to[stride*7] = aux[7];
+ to[stride*8] = aux[8];
+ to[stride*9] = aux[9];
+ to[stride*10] = aux[10];
+ to[stride*11] = aux[11];
+ to[stride*12] = aux[12];
+ to[stride*13] = aux[13];
+ to[stride*14] = aux[14];
+ to[stride*15] = aux[15];
+}
+
+EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet16bf,16>& kernel) {
+ __m256i a = kernel.packet[0];
+ __m256i b = kernel.packet[1];
+ __m256i c = kernel.packet[2];
+ __m256i d = kernel.packet[3];
+ __m256i e = kernel.packet[4];
+ __m256i f = kernel.packet[5];
+ __m256i g = kernel.packet[6];
+ __m256i h = kernel.packet[7];
+ __m256i i = kernel.packet[8];
+ __m256i j = kernel.packet[9];
+ __m256i k = kernel.packet[10];
+ __m256i l = kernel.packet[11];
+ __m256i m = kernel.packet[12];
+ __m256i n = kernel.packet[13];
+ __m256i o = kernel.packet[14];
+ __m256i p = kernel.packet[15];
+
+ __m256i ab_07 = _mm256_unpacklo_epi16(a, b);
+ __m256i cd_07 = _mm256_unpacklo_epi16(c, d);
+ __m256i ef_07 = _mm256_unpacklo_epi16(e, f);
+ __m256i gh_07 = _mm256_unpacklo_epi16(g, h);
+ __m256i ij_07 = _mm256_unpacklo_epi16(i, j);
+ __m256i kl_07 = _mm256_unpacklo_epi16(k, l);
+ __m256i mn_07 = _mm256_unpacklo_epi16(m, n);
+ __m256i op_07 = _mm256_unpacklo_epi16(o, p);
+
+ __m256i ab_8f = _mm256_unpackhi_epi16(a, b);
+ __m256i cd_8f = _mm256_unpackhi_epi16(c, d);
+ __m256i ef_8f = _mm256_unpackhi_epi16(e, f);
+ __m256i gh_8f = _mm256_unpackhi_epi16(g, h);
+ __m256i ij_8f = _mm256_unpackhi_epi16(i, j);
+ __m256i kl_8f = _mm256_unpackhi_epi16(k, l);
+ __m256i mn_8f = _mm256_unpackhi_epi16(m, n);
+ __m256i op_8f = _mm256_unpackhi_epi16(o, p);
+
+ __m256i abcd_03 = _mm256_unpacklo_epi32(ab_07, cd_07);
+ __m256i abcd_47 = _mm256_unpackhi_epi32(ab_07, cd_07);
+ __m256i efgh_03 = _mm256_unpacklo_epi32(ef_07, gh_07);
+ __m256i efgh_47 = _mm256_unpackhi_epi32(ef_07, gh_07);
+ __m256i ijkl_03 = _mm256_unpacklo_epi32(ij_07, kl_07);
+ __m256i ijkl_47 = _mm256_unpackhi_epi32(ij_07, kl_07);
+ __m256i mnop_03 = _mm256_unpacklo_epi32(mn_07, op_07);
+ __m256i mnop_47 = _mm256_unpackhi_epi32(mn_07, op_07);
+
+ __m256i abcd_8b = _mm256_unpacklo_epi32(ab_8f, cd_8f);
+ __m256i abcd_cf = _mm256_unpackhi_epi32(ab_8f, cd_8f);
+ __m256i efgh_8b = _mm256_unpacklo_epi32(ef_8f, gh_8f);
+ __m256i efgh_cf = _mm256_unpackhi_epi32(ef_8f, gh_8f);
+ __m256i ijkl_8b = _mm256_unpacklo_epi32(ij_8f, kl_8f);
+ __m256i ijkl_cf = _mm256_unpackhi_epi32(ij_8f, kl_8f);
+ __m256i mnop_8b = _mm256_unpacklo_epi32(mn_8f, op_8f);
+ __m256i mnop_cf = _mm256_unpackhi_epi32(mn_8f, op_8f);
+
+ __m256i abcdefgh_01 = _mm256_unpacklo_epi64(abcd_03, efgh_03);
+ __m256i abcdefgh_23 = _mm256_unpackhi_epi64(abcd_03, efgh_03);
+ __m256i ijklmnop_01 = _mm256_unpacklo_epi64(ijkl_03, mnop_03);
+ __m256i ijklmnop_23 = _mm256_unpackhi_epi64(ijkl_03, mnop_03);
+ __m256i abcdefgh_45 = _mm256_unpacklo_epi64(abcd_47, efgh_47);
+ __m256i abcdefgh_67 = _mm256_unpackhi_epi64(abcd_47, efgh_47);
+ __m256i ijklmnop_45 = _mm256_unpacklo_epi64(ijkl_47, mnop_47);
+ __m256i ijklmnop_67 = _mm256_unpackhi_epi64(ijkl_47, mnop_47);
+ __m256i abcdefgh_89 = _mm256_unpacklo_epi64(abcd_8b, efgh_8b);
+ __m256i abcdefgh_ab = _mm256_unpackhi_epi64(abcd_8b, efgh_8b);
+ __m256i ijklmnop_89 = _mm256_unpacklo_epi64(ijkl_8b, mnop_8b);
+ __m256i ijklmnop_ab = _mm256_unpackhi_epi64(ijkl_8b, mnop_8b);
+ __m256i abcdefgh_cd = _mm256_unpacklo_epi64(abcd_cf, efgh_cf);
+ __m256i abcdefgh_ef = _mm256_unpackhi_epi64(abcd_cf, efgh_cf);
+ __m256i ijklmnop_cd = _mm256_unpacklo_epi64(ijkl_cf, mnop_cf);
+ __m256i ijklmnop_ef = _mm256_unpackhi_epi64(ijkl_cf, mnop_cf);
+
+ // NOTE: no unpacklo/hi instr in this case, so using permute instr.
+ kernel.packet[0] = _mm256_permute2x128_si256(abcdefgh_01, ijklmnop_01, 0x20);
+ kernel.packet[1] = _mm256_permute2x128_si256(abcdefgh_23, ijklmnop_23, 0x20);
+ kernel.packet[2] = _mm256_permute2x128_si256(abcdefgh_45, ijklmnop_45, 0x20);
+ kernel.packet[3] = _mm256_permute2x128_si256(abcdefgh_67, ijklmnop_67, 0x20);
+ kernel.packet[4] = _mm256_permute2x128_si256(abcdefgh_89, ijklmnop_89, 0x20);
+ kernel.packet[5] = _mm256_permute2x128_si256(abcdefgh_ab, ijklmnop_ab, 0x20);
+ kernel.packet[6] = _mm256_permute2x128_si256(abcdefgh_cd, ijklmnop_cd, 0x20);
+ kernel.packet[7] = _mm256_permute2x128_si256(abcdefgh_ef, ijklmnop_ef, 0x20);
+ kernel.packet[8] = _mm256_permute2x128_si256(abcdefgh_01, ijklmnop_01, 0x31);
+ kernel.packet[9] = _mm256_permute2x128_si256(abcdefgh_23, ijklmnop_23, 0x31);
+ kernel.packet[10] = _mm256_permute2x128_si256(abcdefgh_45, ijklmnop_45, 0x31);
+ kernel.packet[11] = _mm256_permute2x128_si256(abcdefgh_67, ijklmnop_67, 0x31);
+ kernel.packet[12] = _mm256_permute2x128_si256(abcdefgh_89, ijklmnop_89, 0x31);
+ kernel.packet[13] = _mm256_permute2x128_si256(abcdefgh_ab, ijklmnop_ab, 0x31);
+ kernel.packet[14] = _mm256_permute2x128_si256(abcdefgh_cd, ijklmnop_cd, 0x31);
+ kernel.packet[15] = _mm256_permute2x128_si256(abcdefgh_ef, ijklmnop_ef, 0x31);
+}
+
+EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet16bf,4>& kernel) {
+ __m256i a = kernel.packet[0];
+ __m256i b = kernel.packet[1];
+ __m256i c = kernel.packet[2];
+ __m256i d = kernel.packet[3];
+
+ __m256i ab_07 = _mm256_unpacklo_epi16(a, b);
+ __m256i cd_07 = _mm256_unpacklo_epi16(c, d);
+ __m256i ab_8f = _mm256_unpackhi_epi16(a, b);
+ __m256i cd_8f = _mm256_unpackhi_epi16(c, d);
+
+ __m256i abcd_03 = _mm256_unpacklo_epi32(ab_07, cd_07);
+ __m256i abcd_47 = _mm256_unpackhi_epi32(ab_07, cd_07);
+ __m256i abcd_8b = _mm256_unpacklo_epi32(ab_8f, cd_8f);
+ __m256i abcd_cf = _mm256_unpackhi_epi32(ab_8f, cd_8f);
+
+ // NOTE: no unpacklo/hi instr in this case, so using permute instr.
+ kernel.packet[0] = _mm256_permute2x128_si256(abcd_03, abcd_47, 0x20);
+ kernel.packet[1] = _mm256_permute2x128_si256(abcd_8b, abcd_cf, 0x20);
+ kernel.packet[2] = _mm256_permute2x128_si256(abcd_03, abcd_47, 0x31);
+ kernel.packet[3] = _mm256_permute2x128_si256(abcd_8b, abcd_cf, 0x31);
}
} // end namespace internal
diff --git a/Eigen/src/Core/arch/AVX512/TypeCasting.h b/Eigen/src/Core/arch/AVX512/TypeCasting.h
new file mode 100644
index 000000000..330412729
--- /dev/null
+++ b/Eigen/src/Core/arch/AVX512/TypeCasting.h
@@ -0,0 +1,89 @@
+// This file is part of Eigen, a lightweight C++ template library
+// for linear algebra.
+//
+// Copyright (C) 2019 Rasmus Munk Larsen <rmlarsen@google.com>
+//
+// This Source Code Form is subject to the terms of the Mozilla
+// Public License v. 2.0. If a copy of the MPL was not distributed
+// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+#ifndef EIGEN_TYPE_CASTING_AVX512_H
+#define EIGEN_TYPE_CASTING_AVX512_H
+
+namespace Eigen {
+
+namespace internal {
+
+template<> EIGEN_STRONG_INLINE Packet16i pcast<Packet16f, Packet16i>(const Packet16f& a) {
+ return _mm512_cvttps_epi32(a);
+}
+
+template<> EIGEN_STRONG_INLINE Packet16f pcast<Packet16i, Packet16f>(const Packet16i& a) {
+ return _mm512_cvtepi32_ps(a);
+}
+
+template<> EIGEN_STRONG_INLINE Packet16i preinterpret<Packet16i, Packet16f>(const Packet16f& a) {
+ return _mm512_castps_si512(a);
+}
+
+template<> EIGEN_STRONG_INLINE Packet16f preinterpret<Packet16f, Packet16i>(const Packet16i& a) {
+ return _mm512_castsi512_ps(a);
+}
+
+template <>
+struct type_casting_traits<half, float> {
+ enum {
+ VectorizedCast = 1,
+ SrcCoeffRatio = 1,
+ TgtCoeffRatio = 1
+ };
+};
+
+template<> EIGEN_STRONG_INLINE Packet16f pcast<Packet16h, Packet16f>(const Packet16h& a) {
+ return half2float(a);
+}
+
+template <>
+struct type_casting_traits<float, half> {
+ enum {
+ VectorizedCast = 1,
+ SrcCoeffRatio = 1,
+ TgtCoeffRatio = 1
+ };
+};
+
+template<> EIGEN_STRONG_INLINE Packet16h pcast<Packet16f, Packet16h>(const Packet16f& a) {
+ return float2half(a);
+}
+
+template <>
+struct type_casting_traits<bfloat16, float> {
+ enum {
+ VectorizedCast = 1,
+ SrcCoeffRatio = 1,
+ TgtCoeffRatio = 1
+ };
+};
+
+template<> EIGEN_STRONG_INLINE Packet16f pcast<Packet16bf, Packet16f>(const Packet16bf& a) {
+ return Bf16ToF32(a);
+}
+
+template <>
+struct type_casting_traits<float, bfloat16> {
+ enum {
+ VectorizedCast = 1,
+ SrcCoeffRatio = 1,
+ TgtCoeffRatio = 1
+ };
+};
+
+template<> EIGEN_STRONG_INLINE Packet16bf pcast<Packet16f, Packet16bf>(const Packet16f& a) {
+ return F32ToBf16(a);
+}
+
+} // end namespace internal
+
+} // end namespace Eigen
+
+#endif // EIGEN_TYPE_CASTING_AVX512_H
diff --git a/Eigen/src/Core/arch/AltiVec/Complex.h b/Eigen/src/Core/arch/AltiVec/Complex.h
index 67db2f8ee..f424f11cf 100644
--- a/Eigen/src/Core/arch/AltiVec/Complex.h
+++ b/Eigen/src/Core/arch/AltiVec/Complex.h
@@ -29,8 +29,54 @@ static Packet2ul p2ul_CONJ_XOR2 = (Packet2ul) vec_sld((Packet4ui) p2d_MZERO, (P
//---------- float ----------
struct Packet2cf
{
- EIGEN_STRONG_INLINE explicit Packet2cf() : v(p4f_ZERO) {}
+ EIGEN_STRONG_INLINE explicit Packet2cf() {}
EIGEN_STRONG_INLINE explicit Packet2cf(const Packet4f& a) : v(a) {}
+
+ EIGEN_STRONG_INLINE Packet2cf pmul(const Packet2cf& a, const Packet2cf& b)
+ {
+ Packet4f v1, v2;
+
+ // Permute and multiply the real parts of a and b
+ v1 = vec_perm(a.v, a.v, p16uc_PSET32_WODD);
+ // Get the imaginary parts of a
+ v2 = vec_perm(a.v, a.v, p16uc_PSET32_WEVEN);
+ // multiply a_re * b
+ v1 = vec_madd(v1, b.v, p4f_ZERO);
+ // multiply a_im * b and get the conjugate result
+ v2 = vec_madd(v2, b.v, p4f_ZERO);
+ v2 = reinterpret_cast<Packet4f>(pxor(v2, reinterpret_cast<Packet4f>(p4ui_CONJ_XOR)));
+ // permute back to a proper order
+ v2 = vec_perm(v2, v2, p16uc_COMPLEX32_REV);
+
+ return Packet2cf(padd<Packet4f>(v1, v2));
+ }
+
+ EIGEN_STRONG_INLINE Packet2cf& operator*=(const Packet2cf& b) {
+ v = pmul(Packet2cf(*this), b).v;
+ return *this;
+ }
+ EIGEN_STRONG_INLINE Packet2cf operator*(const Packet2cf& b) const {
+ return Packet2cf(*this) *= b;
+ }
+
+ EIGEN_STRONG_INLINE Packet2cf& operator+=(const Packet2cf& b) {
+ v = padd(v, b.v);
+ return *this;
+ }
+ EIGEN_STRONG_INLINE Packet2cf operator+(const Packet2cf& b) const {
+ return Packet2cf(*this) += b;
+ }
+ EIGEN_STRONG_INLINE Packet2cf& operator-=(const Packet2cf& b) {
+ v = psub(v, b.v);
+ return *this;
+ }
+ EIGEN_STRONG_INLINE Packet2cf operator-(const Packet2cf& b) const {
+ return Packet2cf(*this) -= b;
+ }
+ EIGEN_STRONG_INLINE Packet2cf operator-(void) const {
+ return Packet2cf(-v);
+ }
+
Packet4f v;
};
@@ -38,6 +84,7 @@ template<> struct packet_traits<std::complex<float> > : default_packet_traits
{
typedef Packet2cf type;
typedef Packet2cf half;
+ typedef Packet4f as_real;
enum {
Vectorizable = 1,
AlignedOnScalar = 1,
@@ -60,7 +107,7 @@ template<> struct packet_traits<std::complex<float> > : default_packet_traits
};
};
-template<> struct unpacket_traits<Packet2cf> { typedef std::complex<float> type; enum {size=2, alignment=Aligned16}; typedef Packet2cf half; };
+template<> struct unpacket_traits<Packet2cf> { typedef std::complex<float> type; enum {size=2, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false}; typedef Packet2cf half; typedef Packet4f as_real; };
template<> EIGEN_STRONG_INLINE Packet2cf pset1<Packet2cf>(const std::complex<float>& from)
{
@@ -80,16 +127,35 @@ template<> EIGEN_STRONG_INLINE Packet2cf ploaddup<Packet2cf>(const std::complex<
template<> EIGEN_STRONG_INLINE void pstore <std::complex<float> >(std::complex<float> * to, const Packet2cf& from) { pstore((float*)to, from.v); }
template<> EIGEN_STRONG_INLINE void pstoreu<std::complex<float> >(std::complex<float> * to, const Packet2cf& from) { pstoreu((float*)to, from.v); }
+EIGEN_STRONG_INLINE Packet2cf pload2(const std::complex<float>* from0, const std::complex<float>* from1)
+{
+ Packet4f res0, res1;
+#ifdef __VSX__
+ __asm__ ("lxsdx %x0,%y1" : "=wa" (res0) : "Z" (*from0));
+ __asm__ ("lxsdx %x0,%y1" : "=wa" (res1) : "Z" (*from1));
+#ifdef _BIG_ENDIAN
+ __asm__ ("xxpermdi %x0, %x1, %x2, 0" : "=wa" (res0) : "wa" (res0), "wa" (res1));
+#else
+ __asm__ ("xxpermdi %x0, %x2, %x1, 0" : "=wa" (res0) : "wa" (res0), "wa" (res1));
+#endif
+#else
+ *reinterpret_cast<std::complex<float> *>(&res0) = *from0;
+ *reinterpret_cast<std::complex<float> *>(&res1) = *from1;
+ res0 = vec_perm(res0, res1, p16uc_TRANSPOSE64_HI);
+#endif
+ return Packet2cf(res0);
+}
+
template<> EIGEN_DEVICE_FUNC inline Packet2cf pgather<std::complex<float>, Packet2cf>(const std::complex<float>* from, Index stride)
{
- std::complex<float> EIGEN_ALIGN16 af[2];
+ EIGEN_ALIGN16 std::complex<float> af[2];
af[0] = from[0*stride];
af[1] = from[1*stride];
return pload<Packet2cf>(af);
}
template<> EIGEN_DEVICE_FUNC inline void pscatter<std::complex<float>, Packet2cf>(std::complex<float>* to, const Packet2cf& from, Index stride)
{
- std::complex<float> EIGEN_ALIGN16 af[2];
+ EIGEN_ALIGN16 std::complex<float> af[2];
pstore<std::complex<float> >((std::complex<float> *) af, from);
to[0*stride] = af[0];
to[1*stride] = af[1];
@@ -100,25 +166,6 @@ template<> EIGEN_STRONG_INLINE Packet2cf psub<Packet2cf>(const Packet2cf& a, con
template<> EIGEN_STRONG_INLINE Packet2cf pnegate(const Packet2cf& a) { return Packet2cf(pnegate(a.v)); }
template<> EIGEN_STRONG_INLINE Packet2cf pconj(const Packet2cf& a) { return Packet2cf(pxor<Packet4f>(a.v, reinterpret_cast<Packet4f>(p4ui_CONJ_XOR))); }
-template<> EIGEN_STRONG_INLINE Packet2cf pmul<Packet2cf>(const Packet2cf& a, const Packet2cf& b)
-{
- Packet4f v1, v2;
-
- // Permute and multiply the real parts of a and b
- v1 = vec_perm(a.v, a.v, p16uc_PSET32_WODD);
- // Get the imaginary parts of a
- v2 = vec_perm(a.v, a.v, p16uc_PSET32_WEVEN);
- // multiply a_re * b
- v1 = vec_madd(v1, b.v, p4f_ZERO);
- // multiply a_im * b and get the conjugate result
- v2 = vec_madd(v2, b.v, p4f_ZERO);
- v2 = reinterpret_cast<Packet4f>(pxor(v2, reinterpret_cast<Packet4f>(p4ui_CONJ_XOR)));
- // permute back to a proper order
- v2 = vec_perm(v2, v2, p16uc_COMPLEX32_REV);
-
- return Packet2cf(padd<Packet4f>(v1, v2));
-}
-
template<> EIGEN_STRONG_INLINE Packet2cf pand <Packet2cf>(const Packet2cf& a, const Packet2cf& b) { return Packet2cf(pand<Packet4f>(a.v, b.v)); }
template<> EIGEN_STRONG_INLINE Packet2cf por <Packet2cf>(const Packet2cf& a, const Packet2cf& b) { return Packet2cf(por<Packet4f>(a.v, b.v)); }
template<> EIGEN_STRONG_INLINE Packet2cf pxor <Packet2cf>(const Packet2cf& a, const Packet2cf& b) { return Packet2cf(pxor<Packet4f>(a.v, b.v)); }
@@ -128,7 +175,7 @@ template<> EIGEN_STRONG_INLINE void prefetch<std::complex<float> >(const std::co
template<> EIGEN_STRONG_INLINE std::complex<float> pfirst<Packet2cf>(const Packet2cf& a)
{
- std::complex<float> EIGEN_ALIGN16 res[2];
+ EIGEN_ALIGN16 std::complex<float> res[2];
pstore((float *)&res, a.v);
return res[0];
@@ -149,22 +196,6 @@ template<> EIGEN_STRONG_INLINE std::complex<float> predux<Packet2cf>(const Packe
return pfirst<Packet2cf>(Packet2cf(b));
}
-template<> EIGEN_STRONG_INLINE Packet2cf preduxp<Packet2cf>(const Packet2cf* vecs)
-{
- Packet4f b1, b2;
-#ifdef _BIG_ENDIAN
- b1 = vec_sld(vecs[0].v, vecs[1].v, 8);
- b2 = vec_sld(vecs[1].v, vecs[0].v, 8);
-#else
- b1 = vec_sld(vecs[1].v, vecs[0].v, 8);
- b2 = vec_sld(vecs[0].v, vecs[1].v, 8);
-#endif
- b2 = vec_sld(b2, b2, 8);
- b2 = padd<Packet4f>(b1, b2);
-
- return Packet2cf(b2);
-}
-
template<> EIGEN_STRONG_INLINE std::complex<float> predux_mul<Packet2cf>(const Packet2cf& a)
{
Packet4f b;
@@ -175,77 +206,12 @@ template<> EIGEN_STRONG_INLINE std::complex<float> predux_mul<Packet2cf>(const P
return pfirst<Packet2cf>(prod);
}
-template<int Offset>
-struct palign_impl<Offset,Packet2cf>
-{
- static EIGEN_STRONG_INLINE void run(Packet2cf& first, const Packet2cf& second)
- {
- if (Offset==1)
- {
-#ifdef _BIG_ENDIAN
- first.v = vec_sld(first.v, second.v, 8);
-#else
- first.v = vec_sld(second.v, first.v, 8);
-#endif
- }
- }
-};
-
-template<> struct conj_helper<Packet2cf, Packet2cf, false,true>
-{
- EIGEN_STRONG_INLINE Packet2cf pmadd(const Packet2cf& x, const Packet2cf& y, const Packet2cf& c) const
- { return padd(pmul(x,y),c); }
-
- EIGEN_STRONG_INLINE Packet2cf pmul(const Packet2cf& a, const Packet2cf& b) const
- {
- return internal::pmul(a, pconj(b));
- }
-};
-
-template<> struct conj_helper<Packet2cf, Packet2cf, true,false>
-{
- EIGEN_STRONG_INLINE Packet2cf pmadd(const Packet2cf& x, const Packet2cf& y, const Packet2cf& c) const
- { return padd(pmul(x,y),c); }
-
- EIGEN_STRONG_INLINE Packet2cf pmul(const Packet2cf& a, const Packet2cf& b) const
- {
- return internal::pmul(pconj(a), b);
- }
-};
-
-template<> struct conj_helper<Packet2cf, Packet2cf, true,true>
-{
- EIGEN_STRONG_INLINE Packet2cf pmadd(const Packet2cf& x, const Packet2cf& y, const Packet2cf& c) const
- { return padd(pmul(x,y),c); }
-
- EIGEN_STRONG_INLINE Packet2cf pmul(const Packet2cf& a, const Packet2cf& b) const
- {
- return pconj(internal::pmul(a, b));
- }
-};
-
-template<> struct conj_helper<Packet4f, Packet2cf, false,false>
-{
- EIGEN_STRONG_INLINE Packet2cf pmadd(const Packet4f& x, const Packet2cf& y, const Packet2cf& c) const
- { return padd(c, pmul(x,y)); }
-
- EIGEN_STRONG_INLINE Packet2cf pmul(const Packet4f& x, const Packet2cf& y) const
- { return Packet2cf(internal::pmul<Packet4f>(x, y.v)); }
-};
-
-template<> struct conj_helper<Packet2cf, Packet4f, false,false>
-{
- EIGEN_STRONG_INLINE Packet2cf pmadd(const Packet2cf& x, const Packet4f& y, const Packet2cf& c) const
- { return padd(c, pmul(x,y)); }
-
- EIGEN_STRONG_INLINE Packet2cf pmul(const Packet2cf& x, const Packet4f& y) const
- { return Packet2cf(internal::pmul<Packet4f>(x.v, y)); }
-};
+EIGEN_MAKE_CONJ_HELPER_CPLX_REAL(Packet2cf,Packet4f)
template<> EIGEN_STRONG_INLINE Packet2cf pdiv<Packet2cf>(const Packet2cf& a, const Packet2cf& b)
{
// TODO optimize it for AltiVec
- Packet2cf res = conj_helper<Packet2cf,Packet2cf,false,true>().pmul(a, b);
+ Packet2cf res = pmul(a, pconj(b));
Packet4f s = pmul<Packet4f>(b.v, b.v);
return Packet2cf(pdiv(res.v, padd<Packet4f>(s, vec_perm(s, s, p16uc_COMPLEX32_REV))));
}
@@ -262,6 +228,11 @@ EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet2cf,2>& kernel)
kernel.packet[0].v = tmp;
}
+template<> EIGEN_STRONG_INLINE Packet2cf pcmp_eq(const Packet2cf& a, const Packet2cf& b) {
+ Packet4f eq = reinterpret_cast<Packet4f>(vec_cmpeq(a.v,b.v));
+ return Packet2cf(vec_and(eq, vec_perm(eq, eq, p16uc_COMPLEX32_REV)));
+}
+
#ifdef __VSX__
template<> EIGEN_STRONG_INLINE Packet2cf pblend(const Selector<2>& ifPacket, const Packet2cf& thenPacket, const Packet2cf& elsePacket) {
Packet2cf result;
@@ -270,12 +241,62 @@ template<> EIGEN_STRONG_INLINE Packet2cf pblend(const Selector<2>& ifPacket, con
}
#endif
+template<> EIGEN_STRONG_INLINE Packet2cf psqrt<Packet2cf>(const Packet2cf& a)
+{
+ return psqrt_complex<Packet2cf>(a);
+}
+
//---------- double ----------
#ifdef __VSX__
struct Packet1cd
{
EIGEN_STRONG_INLINE Packet1cd() {}
EIGEN_STRONG_INLINE explicit Packet1cd(const Packet2d& a) : v(a) {}
+
+ EIGEN_STRONG_INLINE Packet1cd pmul(const Packet1cd& a, const Packet1cd& b)
+ {
+ Packet2d a_re, a_im, v1, v2;
+
+ // Permute and multiply the real parts of a and b
+ a_re = vec_perm(a.v, a.v, p16uc_PSET64_HI);
+ // Get the imaginary parts of a
+ a_im = vec_perm(a.v, a.v, p16uc_PSET64_LO);
+ // multiply a_re * b
+ v1 = vec_madd(a_re, b.v, p2d_ZERO);
+ // multiply a_im * b and get the conjugate result
+ v2 = vec_madd(a_im, b.v, p2d_ZERO);
+ v2 = reinterpret_cast<Packet2d>(vec_sld(reinterpret_cast<Packet4ui>(v2), reinterpret_cast<Packet4ui>(v2), 8));
+ v2 = pxor(v2, reinterpret_cast<Packet2d>(p2ul_CONJ_XOR1));
+
+ return Packet1cd(padd<Packet2d>(v1, v2));
+ }
+
+ EIGEN_STRONG_INLINE Packet1cd& operator*=(const Packet1cd& b) {
+ v = pmul(Packet1cd(*this), b).v;
+ return *this;
+ }
+ EIGEN_STRONG_INLINE Packet1cd operator*(const Packet1cd& b) const {
+ return Packet1cd(*this) *= b;
+ }
+
+ EIGEN_STRONG_INLINE Packet1cd& operator+=(const Packet1cd& b) {
+ v = padd(v, b.v);
+ return *this;
+ }
+ EIGEN_STRONG_INLINE Packet1cd operator+(const Packet1cd& b) const {
+ return Packet1cd(*this) += b;
+ }
+ EIGEN_STRONG_INLINE Packet1cd& operator-=(const Packet1cd& b) {
+ v = psub(v, b.v);
+ return *this;
+ }
+ EIGEN_STRONG_INLINE Packet1cd operator-(const Packet1cd& b) const {
+ return Packet1cd(*this) -= b;
+ }
+ EIGEN_STRONG_INLINE Packet1cd operator-(void) const {
+ return Packet1cd(-v);
+ }
+
Packet2d v;
};
@@ -283,6 +304,7 @@ template<> struct packet_traits<std::complex<double> > : default_packet_traits
{
typedef Packet1cd type;
typedef Packet1cd half;
+ typedef Packet2d as_real;
enum {
Vectorizable = 1,
AlignedOnScalar = 0,
@@ -302,7 +324,7 @@ template<> struct packet_traits<std::complex<double> > : default_packet_traits
};
};
-template<> struct unpacket_traits<Packet1cd> { typedef std::complex<double> type; enum {size=1, alignment=Aligned16}; typedef Packet1cd half; };
+template<> struct unpacket_traits<Packet1cd> { typedef std::complex<double> type; enum {size=1, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false}; typedef Packet1cd half; typedef Packet2d as_real; };
template<> EIGEN_STRONG_INLINE Packet1cd pload <Packet1cd>(const std::complex<double>* from) { return Packet1cd(pload<Packet2d>((const double*)from)); }
template<> EIGEN_STRONG_INLINE Packet1cd ploadu<Packet1cd>(const std::complex<double>* from) { return Packet1cd(ploadu<Packet2d>((const double*)from)); }
@@ -312,19 +334,13 @@ template<> EIGEN_STRONG_INLINE void pstoreu<std::complex<double> >(std::complex<
template<> EIGEN_STRONG_INLINE Packet1cd pset1<Packet1cd>(const std::complex<double>& from)
{ /* here we really have to use unaligned loads :( */ return ploadu<Packet1cd>(&from); }
-template<> EIGEN_DEVICE_FUNC inline Packet1cd pgather<std::complex<double>, Packet1cd>(const std::complex<double>* from, Index stride)
+template<> EIGEN_DEVICE_FUNC inline Packet1cd pgather<std::complex<double>, Packet1cd>(const std::complex<double>* from, Index)
{
- std::complex<double> EIGEN_ALIGN16 af[2];
- af[0] = from[0*stride];
- af[1] = from[1*stride];
- return pload<Packet1cd>(af);
+ return pload<Packet1cd>(from);
}
-template<> EIGEN_DEVICE_FUNC inline void pscatter<std::complex<double>, Packet1cd>(std::complex<double>* to, const Packet1cd& from, Index stride)
+template<> EIGEN_DEVICE_FUNC inline void pscatter<std::complex<double>, Packet1cd>(std::complex<double>* to, const Packet1cd& from, Index)
{
- std::complex<double> EIGEN_ALIGN16 af[2];
- pstore<std::complex<double> >(af, from);
- to[0*stride] = af[0];
- to[1*stride] = af[1];
+ pstore<std::complex<double> >(to, from);
}
template<> EIGEN_STRONG_INLINE Packet1cd padd<Packet1cd>(const Packet1cd& a, const Packet1cd& b) { return Packet1cd(a.v + b.v); }
@@ -332,24 +348,6 @@ template<> EIGEN_STRONG_INLINE Packet1cd psub<Packet1cd>(const Packet1cd& a, con
template<> EIGEN_STRONG_INLINE Packet1cd pnegate(const Packet1cd& a) { return Packet1cd(pnegate(Packet2d(a.v))); }
template<> EIGEN_STRONG_INLINE Packet1cd pconj(const Packet1cd& a) { return Packet1cd(pxor(a.v, reinterpret_cast<Packet2d>(p2ul_CONJ_XOR2))); }
-template<> EIGEN_STRONG_INLINE Packet1cd pmul<Packet1cd>(const Packet1cd& a, const Packet1cd& b)
-{
- Packet2d a_re, a_im, v1, v2;
-
- // Permute and multiply the real parts of a and b
- a_re = vec_perm(a.v, a.v, p16uc_PSET64_HI);
- // Get the imaginary parts of a
- a_im = vec_perm(a.v, a.v, p16uc_PSET64_LO);
- // multiply a_re * b
- v1 = vec_madd(a_re, b.v, p2d_ZERO);
- // multiply a_im * b and get the conjugate result
- v2 = vec_madd(a_im, b.v, p2d_ZERO);
- v2 = reinterpret_cast<Packet2d>(vec_sld(reinterpret_cast<Packet4ui>(v2), reinterpret_cast<Packet4ui>(v2), 8));
- v2 = pxor(v2, reinterpret_cast<Packet2d>(p2ul_CONJ_XOR1));
-
- return Packet1cd(padd<Packet2d>(v1, v2));
-}
-
template<> EIGEN_STRONG_INLINE Packet1cd pand <Packet1cd>(const Packet1cd& a, const Packet1cd& b) { return Packet1cd(pand(a.v,b.v)); }
template<> EIGEN_STRONG_INLINE Packet1cd por <Packet1cd>(const Packet1cd& a, const Packet1cd& b) { return Packet1cd(por(a.v,b.v)); }
template<> EIGEN_STRONG_INLINE Packet1cd pxor <Packet1cd>(const Packet1cd& a, const Packet1cd& b) { return Packet1cd(pxor(a.v,b.v)); }
@@ -361,7 +359,7 @@ template<> EIGEN_STRONG_INLINE void prefetch<std::complex<double> >(const std::c
template<> EIGEN_STRONG_INLINE std::complex<double> pfirst<Packet1cd>(const Packet1cd& a)
{
- std::complex<double> EIGEN_ALIGN16 res[2];
+ EIGEN_ALIGN16 std::complex<double> res[2];
pstore<std::complex<double> >(res, a);
return res[0];
@@ -370,74 +368,15 @@ template<> EIGEN_STRONG_INLINE std::complex<double> pfirst<Packet1cd>(const Pac
template<> EIGEN_STRONG_INLINE Packet1cd preverse(const Packet1cd& a) { return a; }
template<> EIGEN_STRONG_INLINE std::complex<double> predux<Packet1cd>(const Packet1cd& a) { return pfirst(a); }
-template<> EIGEN_STRONG_INLINE Packet1cd preduxp<Packet1cd>(const Packet1cd* vecs) { return vecs[0]; }
template<> EIGEN_STRONG_INLINE std::complex<double> predux_mul<Packet1cd>(const Packet1cd& a) { return pfirst(a); }
-template<int Offset>
-struct palign_impl<Offset,Packet1cd>
-{
- static EIGEN_STRONG_INLINE void run(Packet1cd& /*first*/, const Packet1cd& /*second*/)
- {
- // FIXME is it sure we never have to align a Packet1cd?
- // Even though a std::complex<double> has 16 bytes, it is not necessarily aligned on a 16 bytes boundary...
- }
-};
-
-template<> struct conj_helper<Packet1cd, Packet1cd, false,true>
-{
- EIGEN_STRONG_INLINE Packet1cd pmadd(const Packet1cd& x, const Packet1cd& y, const Packet1cd& c) const
- { return padd(pmul(x,y),c); }
-
- EIGEN_STRONG_INLINE Packet1cd pmul(const Packet1cd& a, const Packet1cd& b) const
- {
- return internal::pmul(a, pconj(b));
- }
-};
-
-template<> struct conj_helper<Packet1cd, Packet1cd, true,false>
-{
- EIGEN_STRONG_INLINE Packet1cd pmadd(const Packet1cd& x, const Packet1cd& y, const Packet1cd& c) const
- { return padd(pmul(x,y),c); }
-
- EIGEN_STRONG_INLINE Packet1cd pmul(const Packet1cd& a, const Packet1cd& b) const
- {
- return internal::pmul(pconj(a), b);
- }
-};
-
-template<> struct conj_helper<Packet1cd, Packet1cd, true,true>
-{
- EIGEN_STRONG_INLINE Packet1cd pmadd(const Packet1cd& x, const Packet1cd& y, const Packet1cd& c) const
- { return padd(pmul(x,y),c); }
-
- EIGEN_STRONG_INLINE Packet1cd pmul(const Packet1cd& a, const Packet1cd& b) const
- {
- return pconj(internal::pmul(a, b));
- }
-};
-template<> struct conj_helper<Packet2d, Packet1cd, false,false>
-{
- EIGEN_STRONG_INLINE Packet1cd pmadd(const Packet2d& x, const Packet1cd& y, const Packet1cd& c) const
- { return padd(c, pmul(x,y)); }
-
- EIGEN_STRONG_INLINE Packet1cd pmul(const Packet2d& x, const Packet1cd& y) const
- { return Packet1cd(internal::pmul<Packet2d>(x, y.v)); }
-};
-
-template<> struct conj_helper<Packet1cd, Packet2d, false,false>
-{
- EIGEN_STRONG_INLINE Packet1cd pmadd(const Packet1cd& x, const Packet2d& y, const Packet1cd& c) const
- { return padd(c, pmul(x,y)); }
-
- EIGEN_STRONG_INLINE Packet1cd pmul(const Packet1cd& x, const Packet2d& y) const
- { return Packet1cd(internal::pmul<Packet2d>(x.v, y)); }
-};
+EIGEN_MAKE_CONJ_HELPER_CPLX_REAL(Packet1cd,Packet2d)
template<> EIGEN_STRONG_INLINE Packet1cd pdiv<Packet1cd>(const Packet1cd& a, const Packet1cd& b)
{
// TODO optimize it for AltiVec
- Packet1cd res = conj_helper<Packet1cd,Packet1cd,false,true>().pmul(a,b);
+ Packet1cd res = pmul(a,pconj(b));
Packet2d s = pmul<Packet2d>(b.v, b.v);
return Packet1cd(pdiv(res.v, padd<Packet2d>(s, vec_perm(s, s, p16uc_REVERSE64))));
}
@@ -453,6 +392,23 @@ EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet1cd,2>& kernel)
kernel.packet[1].v = vec_perm(kernel.packet[0].v, kernel.packet[1].v, p16uc_TRANSPOSE64_LO);
kernel.packet[0].v = tmp;
}
+
+template<> EIGEN_STRONG_INLINE Packet1cd pcmp_eq(const Packet1cd& a, const Packet1cd& b) {
+ // Compare real and imaginary parts of a and b to get the mask vector:
+ // [re(a)==re(b), im(a)==im(b)]
+ Packet2d eq = reinterpret_cast<Packet2d>(vec_cmpeq(a.v,b.v));
+ // Swap real/imag elements in the mask in to get:
+ // [im(a)==im(b), re(a)==re(b)]
+ Packet2d eq_swapped = reinterpret_cast<Packet2d>(vec_sld(reinterpret_cast<Packet4ui>(eq), reinterpret_cast<Packet4ui>(eq), 8));
+ // Return re(a)==re(b) & im(a)==im(b) by computing bitwise AND of eq and eq_swapped
+ return Packet1cd(vec_and(eq, eq_swapped));
+}
+
+template<> EIGEN_STRONG_INLINE Packet1cd psqrt<Packet1cd>(const Packet1cd& a)
+{
+ return psqrt_complex<Packet1cd>(a);
+}
+
#endif // __VSX__
} // end namespace internal
diff --git a/Eigen/src/Core/arch/AltiVec/MathFunctions.h b/Eigen/src/Core/arch/AltiVec/MathFunctions.h
index c5e4bede7..3a7a32936 100644
--- a/Eigen/src/Core/arch/AltiVec/MathFunctions.h
+++ b/Eigen/src/Core/arch/AltiVec/MathFunctions.h
@@ -9,10 +9,6 @@
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
-/* The sin, cos, exp, and log functions of this file come from
- * Julien Pommier's sse math library: http://gruntthepeon.free.fr/ssemath/
- */
-
#ifndef EIGEN_MATH_FUNCTIONS_ALTIVEC_H
#define EIGEN_MATH_FUNCTIONS_ALTIVEC_H
@@ -20,180 +16,28 @@ namespace Eigen {
namespace internal {
-static _EIGEN_DECLARE_CONST_Packet4f(1 , 1.0f);
-static _EIGEN_DECLARE_CONST_Packet4f(half, 0.5f);
-static _EIGEN_DECLARE_CONST_Packet4i(0x7f, 0x7f);
-static _EIGEN_DECLARE_CONST_Packet4i(23, 23);
-
-static _EIGEN_DECLARE_CONST_Packet4f_FROM_INT(inv_mant_mask, ~0x7f800000);
-
-/* the smallest non denormalized float number */
-static _EIGEN_DECLARE_CONST_Packet4f_FROM_INT(min_norm_pos, 0x00800000);
-static _EIGEN_DECLARE_CONST_Packet4f_FROM_INT(minus_inf, 0xff800000); // -1.f/0.f
-static _EIGEN_DECLARE_CONST_Packet4f_FROM_INT(minus_nan, 0xffffffff);
-
-/* natural logarithm computed for 4 simultaneous float
- return NaN for x <= 0
-*/
-static _EIGEN_DECLARE_CONST_Packet4f(cephes_SQRTHF, 0.707106781186547524f);
-static _EIGEN_DECLARE_CONST_Packet4f(cephes_log_p0, 7.0376836292E-2f);
-static _EIGEN_DECLARE_CONST_Packet4f(cephes_log_p1, - 1.1514610310E-1f);
-static _EIGEN_DECLARE_CONST_Packet4f(cephes_log_p2, 1.1676998740E-1f);
-static _EIGEN_DECLARE_CONST_Packet4f(cephes_log_p3, - 1.2420140846E-1f);
-static _EIGEN_DECLARE_CONST_Packet4f(cephes_log_p4, + 1.4249322787E-1f);
-static _EIGEN_DECLARE_CONST_Packet4f(cephes_log_p5, - 1.6668057665E-1f);
-static _EIGEN_DECLARE_CONST_Packet4f(cephes_log_p6, + 2.0000714765E-1f);
-static _EIGEN_DECLARE_CONST_Packet4f(cephes_log_p7, - 2.4999993993E-1f);
-static _EIGEN_DECLARE_CONST_Packet4f(cephes_log_p8, + 3.3333331174E-1f);
-static _EIGEN_DECLARE_CONST_Packet4f(cephes_log_q1, -2.12194440e-4f);
-static _EIGEN_DECLARE_CONST_Packet4f(cephes_log_q2, 0.693359375f);
-
-static _EIGEN_DECLARE_CONST_Packet4f(exp_hi, 88.3762626647950f);
-static _EIGEN_DECLARE_CONST_Packet4f(exp_lo, -88.3762626647949f);
-
-static _EIGEN_DECLARE_CONST_Packet4f(cephes_LOG2EF, 1.44269504088896341f);
-static _EIGEN_DECLARE_CONST_Packet4f(cephes_exp_C1, 0.693359375f);
-static _EIGEN_DECLARE_CONST_Packet4f(cephes_exp_C2, -2.12194440e-4f);
-
-static _EIGEN_DECLARE_CONST_Packet4f(cephes_exp_p0, 1.9875691500E-4f);
-static _EIGEN_DECLARE_CONST_Packet4f(cephes_exp_p1, 1.3981999507E-3f);
-static _EIGEN_DECLARE_CONST_Packet4f(cephes_exp_p2, 8.3334519073E-3f);
-static _EIGEN_DECLARE_CONST_Packet4f(cephes_exp_p3, 4.1665795894E-2f);
-static _EIGEN_DECLARE_CONST_Packet4f(cephes_exp_p4, 1.6666665459E-1f);
-static _EIGEN_DECLARE_CONST_Packet4f(cephes_exp_p5, 5.0000001201E-1f);
-
-#ifdef __VSX__
-static _EIGEN_DECLARE_CONST_Packet2d(1 , 1.0);
-static _EIGEN_DECLARE_CONST_Packet2d(2 , 2.0);
-static _EIGEN_DECLARE_CONST_Packet2d(half, 0.5);
-
-static _EIGEN_DECLARE_CONST_Packet2d(exp_hi, 709.437);
-static _EIGEN_DECLARE_CONST_Packet2d(exp_lo, -709.436139303);
-
-static _EIGEN_DECLARE_CONST_Packet2d(cephes_LOG2EF, 1.4426950408889634073599);
-
-static _EIGEN_DECLARE_CONST_Packet2d(cephes_exp_p0, 1.26177193074810590878e-4);
-static _EIGEN_DECLARE_CONST_Packet2d(cephes_exp_p1, 3.02994407707441961300e-2);
-static _EIGEN_DECLARE_CONST_Packet2d(cephes_exp_p2, 9.99999999999999999910e-1);
-
-static _EIGEN_DECLARE_CONST_Packet2d(cephes_exp_q0, 3.00198505138664455042e-6);
-static _EIGEN_DECLARE_CONST_Packet2d(cephes_exp_q1, 2.52448340349684104192e-3);
-static _EIGEN_DECLARE_CONST_Packet2d(cephes_exp_q2, 2.27265548208155028766e-1);
-static _EIGEN_DECLARE_CONST_Packet2d(cephes_exp_q3, 2.00000000000000000009e0);
-
-static _EIGEN_DECLARE_CONST_Packet2d(cephes_exp_C1, 0.693145751953125);
-static _EIGEN_DECLARE_CONST_Packet2d(cephes_exp_C2, 1.42860682030941723212e-6);
-
-#ifdef __POWER8_VECTOR__
-static Packet2l p2l_1023 = { 1023, 1023 };
-static Packet2ul p2ul_52 = { 52, 52 };
-#endif
-
-#endif
-
template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
Packet4f plog<Packet4f>(const Packet4f& _x)
{
- Packet4f x = _x;
-
- Packet4i emm0;
-
- /* isvalid_mask is 0 if x < 0 or x is NaN. */
- Packet4ui isvalid_mask = reinterpret_cast<Packet4ui>(vec_cmpge(x, p4f_ZERO));
- Packet4ui iszero_mask = reinterpret_cast<Packet4ui>(vec_cmpeq(x, p4f_ZERO));
-
- x = pmax(x, p4f_min_norm_pos); /* cut off denormalized stuff */
- emm0 = vec_sr(reinterpret_cast<Packet4i>(x),
- reinterpret_cast<Packet4ui>(p4i_23));
-
- /* keep only the fractional part */
- x = pand(x, p4f_inv_mant_mask);
- x = por(x, p4f_half);
-
- emm0 = psub(emm0, p4i_0x7f);
- Packet4f e = padd(vec_ctf(emm0, 0), p4f_1);
-
- /* part2:
- if( x < SQRTHF ) {
- e -= 1;
- x = x + x - 1.0;
- } else { x = x - 1.0; }
- */
- Packet4f mask = reinterpret_cast<Packet4f>(vec_cmplt(x, p4f_cephes_SQRTHF));
- Packet4f tmp = pand(x, mask);
- x = psub(x, p4f_1);
- e = psub(e, pand(p4f_1, mask));
- x = padd(x, tmp);
-
- Packet4f x2 = pmul(x,x);
- Packet4f x3 = pmul(x2,x);
-
- Packet4f y, y1, y2;
- y = pmadd(p4f_cephes_log_p0, x, p4f_cephes_log_p1);
- y1 = pmadd(p4f_cephes_log_p3, x, p4f_cephes_log_p4);
- y2 = pmadd(p4f_cephes_log_p6, x, p4f_cephes_log_p7);
- y = pmadd(y , x, p4f_cephes_log_p2);
- y1 = pmadd(y1, x, p4f_cephes_log_p5);
- y2 = pmadd(y2, x, p4f_cephes_log_p8);
- y = pmadd(y, x3, y1);
- y = pmadd(y, x3, y2);
- y = pmul(y, x3);
-
- y1 = pmul(e, p4f_cephes_log_q1);
- tmp = pmul(x2, p4f_half);
- y = padd(y, y1);
- x = psub(x, tmp);
- y2 = pmul(e, p4f_cephes_log_q2);
- x = padd(x, y);
- x = padd(x, y2);
- // negative arg will be NAN, 0 will be -INF
- x = vec_sel(x, p4f_minus_inf, iszero_mask);
- x = vec_sel(p4f_minus_nan, x, isvalid_mask);
- return x;
+ return plog_float(_x);
}
template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
Packet4f pexp<Packet4f>(const Packet4f& _x)
{
- Packet4f x = _x;
-
- Packet4f tmp, fx;
- Packet4i emm0;
-
- // clamp x
- x = pmax(pmin(x, p4f_exp_hi), p4f_exp_lo);
-
- // express exp(x) as exp(g + n*log(2))
- fx = pmadd(x, p4f_cephes_LOG2EF, p4f_half);
-
- fx = pfloor(fx);
-
- tmp = pmul(fx, p4f_cephes_exp_C1);
- Packet4f z = pmul(fx, p4f_cephes_exp_C2);
- x = psub(x, tmp);
- x = psub(x, z);
-
- z = pmul(x,x);
-
- Packet4f y = p4f_cephes_exp_p0;
- y = pmadd(y, x, p4f_cephes_exp_p1);
- y = pmadd(y, x, p4f_cephes_exp_p2);
- y = pmadd(y, x, p4f_cephes_exp_p3);
- y = pmadd(y, x, p4f_cephes_exp_p4);
- y = pmadd(y, x, p4f_cephes_exp_p5);
- y = pmadd(y, z, x);
- y = padd(y, p4f_1);
+ return pexp_float(_x);
+}
- // build 2^n
- emm0 = vec_cts(fx, 0);
- emm0 = vec_add(emm0, p4i_0x7f);
- emm0 = vec_sl(emm0, reinterpret_cast<Packet4ui>(p4i_23));
+template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
+Packet4f psin<Packet4f>(const Packet4f& _x)
+{
+ return psin_float(_x);
+}
- // Altivec's max & min operators just drop silent NaNs. Check NaNs in
- // inputs and return them unmodified.
- Packet4ui isnumber_mask = reinterpret_cast<Packet4ui>(vec_cmpeq(_x, _x));
- return vec_sel(_x, pmax(pmul(y, reinterpret_cast<Packet4f>(emm0)), _x),
- isnumber_mask);
+template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
+Packet4f pcos<Packet4f>(const Packet4f& _x)
+{
+ return pcos_float(_x);
}
#ifndef EIGEN_COMP_CLANG
@@ -225,95 +69,19 @@ Packet2d psqrt<Packet2d>(const Packet2d& x)
return vec_sqrt(x);
}
-// VSX support varies between different compilers and even different
-// versions of the same compiler. For gcc version >= 4.9.3, we can use
-// vec_cts to efficiently convert Packet2d to Packet2l. Otherwise, use
-// a slow version that works with older compilers.
-// Update: apparently vec_cts/vec_ctf intrinsics for 64-bit doubles
-// are buggy, https://gcc.gnu.org/bugzilla/show_bug.cgi?id=70963
-static inline Packet2l ConvertToPacket2l(const Packet2d& x) {
-#if EIGEN_GNUC_AT_LEAST(5, 4) || \
- (EIGEN_GNUC_AT(6, 1) && __GNUC_PATCHLEVEL__ >= 1)
- return vec_cts(x, 0); // TODO: check clang version.
-#else
- double tmp[2];
- memcpy(tmp, &x, sizeof(tmp));
- Packet2l l = { static_cast<long long>(tmp[0]),
- static_cast<long long>(tmp[1]) };
- return l;
-#endif
-}
-
template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
Packet2d pexp<Packet2d>(const Packet2d& _x)
{
- Packet2d x = _x;
-
- Packet2d tmp, fx;
- Packet2l emm0;
-
- // clamp x
- x = pmax(pmin(x, p2d_exp_hi), p2d_exp_lo);
-
- /* express exp(x) as exp(g + n*log(2)) */
- fx = pmadd(x, p2d_cephes_LOG2EF, p2d_half);
-
- fx = pfloor(fx);
-
- tmp = pmul(fx, p2d_cephes_exp_C1);
- Packet2d z = pmul(fx, p2d_cephes_exp_C2);
- x = psub(x, tmp);
- x = psub(x, z);
-
- Packet2d x2 = pmul(x,x);
-
- Packet2d px = p2d_cephes_exp_p0;
- px = pmadd(px, x2, p2d_cephes_exp_p1);
- px = pmadd(px, x2, p2d_cephes_exp_p2);
- px = pmul (px, x);
-
- Packet2d qx = p2d_cephes_exp_q0;
- qx = pmadd(qx, x2, p2d_cephes_exp_q1);
- qx = pmadd(qx, x2, p2d_cephes_exp_q2);
- qx = pmadd(qx, x2, p2d_cephes_exp_q3);
-
- x = pdiv(px,psub(qx,px));
- x = pmadd(p2d_2,x,p2d_1);
-
- // build 2^n
- emm0 = ConvertToPacket2l(fx);
-
-#ifdef __POWER8_VECTOR__
- emm0 = vec_add(emm0, p2l_1023);
- emm0 = vec_sl(emm0, p2ul_52);
-#else
- // Code is a bit complex for POWER7. There is actually a
- // vec_xxsldi intrinsic but it is not supported by some gcc versions.
- // So we shift (52-32) bits and do a word swap with zeros.
- _EIGEN_DECLARE_CONST_Packet4i(1023, 1023);
- _EIGEN_DECLARE_CONST_Packet4i(20, 20); // 52 - 32
-
- Packet4i emm04i = reinterpret_cast<Packet4i>(emm0);
- emm04i = vec_add(emm04i, p4i_1023);
- emm04i = vec_sl(emm04i, reinterpret_cast<Packet4ui>(p4i_20));
- static const Packet16uc perm = {
- 0x14, 0x15, 0x16, 0x17, 0x00, 0x01, 0x02, 0x03,
- 0x1c, 0x1d, 0x1e, 0x1f, 0x08, 0x09, 0x0a, 0x0b };
-#ifdef _BIG_ENDIAN
- emm0 = reinterpret_cast<Packet2l>(vec_perm(p4i_ZERO, emm04i, perm));
-#else
- emm0 = reinterpret_cast<Packet2l>(vec_perm(emm04i, p4i_ZERO, perm));
-#endif
-
+ return pexp_double(_x);
+}
#endif
- // Altivec's max & min operators just drop silent NaNs. Check NaNs in
- // inputs and return them unmodified.
- Packet2ul isnumber_mask = reinterpret_cast<Packet2ul>(vec_cmpeq(_x, _x));
- return vec_sel(_x, pmax(pmul(x, reinterpret_cast<Packet2d>(emm0)), _x),
- isnumber_mask);
+// Hyperbolic Tangent function.
+template <>
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet4f
+ptanh<Packet4f>(const Packet4f& x) {
+ return internal::generic_fast_tanh_float(x);
}
-#endif
} // end namespace internal
diff --git a/Eigen/src/Core/arch/AltiVec/MatrixProduct.h b/Eigen/src/Core/arch/AltiVec/MatrixProduct.h
new file mode 100644
index 000000000..3f79b97df
--- /dev/null
+++ b/Eigen/src/Core/arch/AltiVec/MatrixProduct.h
@@ -0,0 +1,2937 @@
+// This file is part of Eigen, a lightweight C++ template library
+// for linear algebra.
+//
+// Copyright (C) 2020 Everton Constantino (everton.constantino@ibm.com)
+// Copyright (C) 2021 Chip Kerchner (chip.kerchner@ibm.com)
+//
+// This Source Code Form is subject to the terms of the Mozilla
+// Public License v. 2.0. If a copy of the MPL was not distributed
+// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+#ifndef EIGEN_MATRIX_PRODUCT_ALTIVEC_H
+#define EIGEN_MATRIX_PRODUCT_ALTIVEC_H
+
+#ifndef EIGEN_ALTIVEC_USE_CUSTOM_PACK
+#define EIGEN_ALTIVEC_USE_CUSTOM_PACK 1
+#endif
+
+#include "MatrixProductCommon.h"
+
+// Since LLVM doesn't support dynamic dispatching, force either always MMA or VSX
+#if EIGEN_COMP_LLVM
+#if !defined(EIGEN_ALTIVEC_DISABLE_MMA) && !defined(EIGEN_ALTIVEC_MMA_ONLY)
+#ifdef __MMA__
+#define EIGEN_ALTIVEC_MMA_ONLY
+#else
+#define EIGEN_ALTIVEC_DISABLE_MMA
+#endif
+#endif
+#endif
+
+#ifdef __has_builtin
+#if __has_builtin(__builtin_mma_assemble_acc)
+ #define ALTIVEC_MMA_SUPPORT
+#endif
+#endif
+
+#if defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA)
+ #include "MatrixProductMMA.h"
+#endif
+
+/**************************************************************************************************
+ * TODO *
+ * - Check StorageOrder on dhs_pack (the innermost second loop seems unvectorized when it could). *
+ * - Check the possibility of transposing as GETREAL and GETIMAG when needed. *
+ **************************************************************************************************/
+namespace Eigen {
+
+namespace internal {
+
+/**************************
+ * Constants and typedefs *
+ **************************/
+template<typename Scalar>
+struct quad_traits
+{
+ typedef typename packet_traits<Scalar>::type vectortype;
+ typedef PacketBlock<vectortype,4> type;
+ typedef vectortype rhstype;
+ enum
+ {
+ vectorsize = packet_traits<Scalar>::size,
+ size = 4,
+ rows = 4
+ };
+};
+
+template<>
+struct quad_traits<double>
+{
+ typedef Packet2d vectortype;
+ typedef PacketBlock<vectortype,4> type;
+ typedef PacketBlock<Packet2d,2> rhstype;
+ enum
+ {
+ vectorsize = packet_traits<double>::size,
+ size = 2,
+ rows = 4
+ };
+};
+
+// MatrixProduct decomposes real/imaginary vectors into a real vector and an imaginary vector, this turned out
+// to be faster than Eigen's usual approach of having real/imaginary pairs on a single vector. This constants then
+// are responsible to extract from convert between Eigen's and MatrixProduct approach.
+
+const static Packet16uc p16uc_GETREAL32 = { 0, 1, 2, 3,
+ 8, 9, 10, 11,
+ 16, 17, 18, 19,
+ 24, 25, 26, 27};
+
+const static Packet16uc p16uc_GETIMAG32 = { 4, 5, 6, 7,
+ 12, 13, 14, 15,
+ 20, 21, 22, 23,
+ 28, 29, 30, 31};
+const static Packet16uc p16uc_GETREAL64 = { 0, 1, 2, 3, 4, 5, 6, 7,
+ 16, 17, 18, 19, 20, 21, 22, 23};
+
+//[a,ai],[b,bi] = [ai,bi]
+const static Packet16uc p16uc_GETIMAG64 = { 8, 9, 10, 11, 12, 13, 14, 15,
+ 24, 25, 26, 27, 28, 29, 30, 31};
+
+/*********************************************
+ * Single precision real and complex packing *
+ * *******************************************/
+
+/**
+ * Symm packing is related to packing of symmetric adjoint blocks, as expected the packing leaves
+ * the diagonal real, whatever is below it is copied from the respective upper diagonal element and
+ * conjugated. There's no PanelMode available for symm packing.
+ *
+ * Packing in general is supposed to leave the lhs block and the rhs block easy to be read by gemm using
+ * its respective rank-update instructions. The float32/64 versions are different because at this moment
+ * the size of the accumulator is fixed at 512-bits so you can't have a 4x4 accumulator of 64-bit elements.
+ *
+ * As mentioned earlier MatrixProduct breaks complex numbers into a real vector and a complex vector so packing has
+ * to take that into account, at the moment, we run pack the real part and then the imaginary part, this is the main
+ * reason why packing for complex is broken down into several different parts, also the reason why we endup having a
+ * float32/64 and complex float32/64 version.
+ **/
+template<typename Scalar, typename Index, int StorageOrder>
+EIGEN_ALWAYS_INLINE std::complex<Scalar> getAdjointVal(Index i, Index j, const_blas_data_mapper<std::complex<Scalar>, Index, StorageOrder>& dt)
+{
+ std::complex<Scalar> v;
+ if(i < j)
+ {
+ v.real( dt(j,i).real());
+ v.imag(-dt(j,i).imag());
+ } else if(i > j)
+ {
+ v.real( dt(i,j).real());
+ v.imag( dt(i,j).imag());
+ } else {
+ v.real( dt(i,j).real());
+ v.imag((Scalar)0.0);
+ }
+ return v;
+}
+
+template<typename Scalar, typename Index, int StorageOrder, int N>
+EIGEN_STRONG_INLINE void symm_pack_complex_rhs_helper(std::complex<Scalar>* blockB, const std::complex<Scalar>* _rhs, Index rhsStride, Index rows, Index cols, Index k2)
+{
+ const Index depth = k2 + rows;
+ const_blas_data_mapper<std::complex<Scalar>, Index, StorageOrder> rhs(_rhs, rhsStride);
+ const Index vectorSize = N*quad_traits<Scalar>::vectorsize;
+ const Index vectorDelta = vectorSize * rows;
+ Scalar* blockBf = reinterpret_cast<Scalar *>(blockB);
+
+ Index rir = 0, rii, j = 0;
+ for(; j + vectorSize <= cols; j+=vectorSize)
+ {
+ rii = rir + vectorDelta;
+
+ for(Index i = k2; i < depth; i++)
+ {
+ for(Index k = 0; k < vectorSize; k++)
+ {
+ std::complex<Scalar> v = getAdjointVal<Scalar, Index, StorageOrder>(i, j + k, rhs);
+
+ blockBf[rir + k] = v.real();
+ blockBf[rii + k] = v.imag();
+ }
+ rir += vectorSize;
+ rii += vectorSize;
+ }
+
+ rir += vectorDelta;
+ }
+ if (j < cols)
+ {
+ rii = rir + ((cols - j) * rows);
+
+ for(Index i = k2; i < depth; i++)
+ {
+ Index k = j;
+ for(; k < cols; k++)
+ {
+ std::complex<Scalar> v = getAdjointVal<Scalar, Index, StorageOrder>(i, k, rhs);
+
+ blockBf[rir] = v.real();
+ blockBf[rii] = v.imag();
+
+ rir += 1;
+ rii += 1;
+ }
+ }
+ }
+}
+
+template<typename Scalar, typename Index, int StorageOrder>
+EIGEN_STRONG_INLINE void symm_pack_complex_lhs_helper(std::complex<Scalar>* blockA, const std::complex<Scalar>* _lhs, Index lhsStride, Index cols, Index rows)
+{
+ const Index depth = cols;
+ const_blas_data_mapper<std::complex<Scalar>, Index, StorageOrder> lhs(_lhs, lhsStride);
+ const Index vectorSize = quad_traits<Scalar>::vectorsize;
+ const Index vectorDelta = vectorSize * depth;
+ Scalar* blockAf = (Scalar *)(blockA);
+
+ Index rir = 0, rii, j = 0;
+ for(; j + vectorSize <= rows; j+=vectorSize)
+ {
+ rii = rir + vectorDelta;
+
+ for(Index i = 0; i < depth; i++)
+ {
+ for(Index k = 0; k < vectorSize; k++)
+ {
+ std::complex<Scalar> v = getAdjointVal<Scalar, Index, StorageOrder>(j+k, i, lhs);
+
+ blockAf[rir + k] = v.real();
+ blockAf[rii + k] = v.imag();
+ }
+ rir += vectorSize;
+ rii += vectorSize;
+ }
+
+ rir += vectorDelta;
+ }
+
+ if (j < rows)
+ {
+ rii = rir + ((rows - j) * depth);
+
+ for(Index i = 0; i < depth; i++)
+ {
+ Index k = j;
+ for(; k < rows; k++)
+ {
+ std::complex<Scalar> v = getAdjointVal<Scalar, Index, StorageOrder>(k, i, lhs);
+
+ blockAf[rir] = v.real();
+ blockAf[rii] = v.imag();
+
+ rir += 1;
+ rii += 1;
+ }
+ }
+ }
+}
+
+template<typename Scalar, typename Index, int StorageOrder, int N>
+EIGEN_STRONG_INLINE void symm_pack_rhs_helper(Scalar* blockB, const Scalar* _rhs, Index rhsStride, Index rows, Index cols, Index k2)
+{
+ const Index depth = k2 + rows;
+ const_blas_data_mapper<Scalar, Index, StorageOrder> rhs(_rhs, rhsStride);
+ const Index vectorSize = quad_traits<Scalar>::vectorsize;
+
+ Index ri = 0, j = 0;
+ for(; j + N*vectorSize <= cols; j+=N*vectorSize)
+ {
+ Index i = k2;
+ for(; i < depth; i++)
+ {
+ for(Index k = 0; k < N*vectorSize; k++)
+ {
+ if(i <= j+k)
+ blockB[ri + k] = rhs(j+k, i);
+ else
+ blockB[ri + k] = rhs(i, j+k);
+ }
+ ri += N*vectorSize;
+ }
+ }
+
+ if (j < cols)
+ {
+ for(Index i = k2; i < depth; i++)
+ {
+ Index k = j;
+ for(; k < cols; k++)
+ {
+ if(k <= i)
+ blockB[ri] = rhs(i, k);
+ else
+ blockB[ri] = rhs(k, i);
+ ri += 1;
+ }
+ }
+ }
+}
+
+template<typename Scalar, typename Index, int StorageOrder>
+EIGEN_STRONG_INLINE void symm_pack_lhs_helper(Scalar* blockA, const Scalar* _lhs, Index lhsStride, Index cols, Index rows)
+{
+ const Index depth = cols;
+ const_blas_data_mapper<Scalar, Index, StorageOrder> lhs(_lhs, lhsStride);
+ const Index vectorSize = quad_traits<Scalar>::vectorsize;
+
+ Index ri = 0, j = 0;
+ for(; j + vectorSize <= rows; j+=vectorSize)
+ {
+ Index i = 0;
+
+ for(; i < depth; i++)
+ {
+ for(Index k = 0; k < vectorSize; k++)
+ {
+ if(i <= j+k)
+ blockA[ri + k] = lhs(j+k, i);
+ else
+ blockA[ri + k] = lhs(i, j+k);
+ }
+ ri += vectorSize;
+ }
+ }
+
+ if (j < rows)
+ {
+ for(Index i = 0; i < depth; i++)
+ {
+ Index k = j;
+ for(; k < rows; k++)
+ {
+ if(i <= k)
+ blockA[ri] = lhs(k, i);
+ else
+ blockA[ri] = lhs(i, k);
+ ri += 1;
+ }
+ }
+ }
+}
+
+template<typename Index, int nr, int StorageOrder>
+struct symm_pack_rhs<std::complex<float>, Index, nr, StorageOrder>
+{
+ void operator()(std::complex<float>* blockB, const std::complex<float>* _rhs, Index rhsStride, Index rows, Index cols, Index k2)
+ {
+ symm_pack_complex_rhs_helper<float, Index, StorageOrder, 1>(blockB, _rhs, rhsStride, rows, cols, k2);
+ }
+};
+
+template<typename Index, int Pack1, int Pack2_dummy, int StorageOrder>
+struct symm_pack_lhs<std::complex<float>, Index, Pack1, Pack2_dummy, StorageOrder>
+{
+ void operator()(std::complex<float>* blockA, const std::complex<float>* _lhs, Index lhsStride, Index cols, Index rows)
+ {
+ symm_pack_complex_lhs_helper<float, Index, StorageOrder>(blockA, _lhs, lhsStride, cols, rows);
+ }
+};
+
+// *********** symm_pack std::complex<float64> ***********
+
+template<typename Index, int nr, int StorageOrder>
+struct symm_pack_rhs<std::complex<double>, Index, nr, StorageOrder>
+{
+ void operator()(std::complex<double>* blockB, const std::complex<double>* _rhs, Index rhsStride, Index rows, Index cols, Index k2)
+ {
+ symm_pack_complex_rhs_helper<double, Index, StorageOrder, 2>(blockB, _rhs, rhsStride, rows, cols, k2);
+ }
+};
+
+template<typename Index, int Pack1, int Pack2_dummy, int StorageOrder>
+struct symm_pack_lhs<std::complex<double>, Index, Pack1, Pack2_dummy, StorageOrder>
+{
+ void operator()(std::complex<double>* blockA, const std::complex<double>* _lhs, Index lhsStride, Index cols, Index rows)
+ {
+ symm_pack_complex_lhs_helper<double, Index, StorageOrder>(blockA, _lhs, lhsStride, cols, rows);
+ }
+};
+
+// *********** symm_pack float32 ***********
+template<typename Index, int nr, int StorageOrder>
+struct symm_pack_rhs<float, Index, nr, StorageOrder>
+{
+ void operator()(float* blockB, const float* _rhs, Index rhsStride, Index rows, Index cols, Index k2)
+ {
+ symm_pack_rhs_helper<float, Index, StorageOrder, 1>(blockB, _rhs, rhsStride, rows, cols, k2);
+ }
+};
+
+template<typename Index, int Pack1, int Pack2_dummy, int StorageOrder>
+struct symm_pack_lhs<float, Index, Pack1, Pack2_dummy, StorageOrder>
+{
+ void operator()(float* blockA, const float* _lhs, Index lhsStride, Index cols, Index rows)
+ {
+ symm_pack_lhs_helper<float, Index, StorageOrder>(blockA, _lhs, lhsStride, cols, rows);
+ }
+};
+
+// *********** symm_pack float64 ***********
+template<typename Index, int nr, int StorageOrder>
+struct symm_pack_rhs<double, Index, nr, StorageOrder>
+{
+ void operator()(double* blockB, const double* _rhs, Index rhsStride, Index rows, Index cols, Index k2)
+ {
+ symm_pack_rhs_helper<double, Index, StorageOrder, 2>(blockB, _rhs, rhsStride, rows, cols, k2);
+ }
+};
+
+template<typename Index, int Pack1, int Pack2_dummy, int StorageOrder>
+struct symm_pack_lhs<double, Index, Pack1, Pack2_dummy, StorageOrder>
+{
+ void operator()(double* blockA, const double* _lhs, Index lhsStride, Index cols, Index rows)
+ {
+ symm_pack_lhs_helper<double, Index, StorageOrder>(blockA, _lhs, lhsStride, cols, rows);
+ }
+};
+
+/**
+ * PanelMode
+ * Packing might be called several times before being multiplied by gebp_kernel, this happens because
+ * on special occasions it fills part of block with other parts of the matrix. Two variables control
+ * how PanelMode should behave: offset and stride. The idea is that those variables represent whatever
+ * is going to be the real offset and stride in the future and this is what you should obey. The process
+ * is to behave as you would with normal packing but leave the start of each part with the correct offset
+ * and the end as well respecting the real stride the block will have. Gebp is aware of both blocks stride
+ * and offset and behaves accordingly.
+ **/
+
+template<typename Scalar, typename Packet, typename Index>
+EIGEN_ALWAYS_INLINE void storeBlock(Scalar* to, PacketBlock<Packet,4>& block)
+{
+ const Index size = 16 / sizeof(Scalar);
+ pstore<Scalar>(to + (0 * size), block.packet[0]);
+ pstore<Scalar>(to + (1 * size), block.packet[1]);
+ pstore<Scalar>(to + (2 * size), block.packet[2]);
+ pstore<Scalar>(to + (3 * size), block.packet[3]);
+}
+
+template<typename Scalar, typename Packet, typename Index>
+EIGEN_ALWAYS_INLINE void storeBlock(Scalar* to, PacketBlock<Packet,2>& block)
+{
+ const Index size = 16 / sizeof(Scalar);
+ pstore<Scalar>(to + (0 * size), block.packet[0]);
+ pstore<Scalar>(to + (1 * size), block.packet[1]);
+}
+
+// General template for lhs & rhs complex packing.
+template<typename Scalar, typename Index, typename DataMapper, typename Packet, typename PacketC, int StorageOrder, bool Conjugate, bool PanelMode, bool UseLhs>
+struct dhs_cpack {
+ EIGEN_STRONG_INLINE void operator()(std::complex<Scalar>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
+ {
+ const Index vectorSize = quad_traits<Scalar>::vectorsize;
+ const Index vectorDelta = vectorSize * ((PanelMode) ? stride : depth);
+ Index rir = ((PanelMode) ? (vectorSize*offset) : 0), rii;
+ Scalar* blockAt = reinterpret_cast<Scalar *>(blockA);
+ Index j = 0;
+
+ for(; j + vectorSize <= rows; j+=vectorSize)
+ {
+ Index i = 0;
+
+ rii = rir + vectorDelta;
+
+ for(; i + vectorSize <= depth; i+=vectorSize)
+ {
+ PacketBlock<Packet,4> blockr, blocki;
+ PacketBlock<PacketC,8> cblock;
+
+ if (UseLhs) {
+ bload<DataMapper, PacketC, Index, 2, 0, StorageOrder>(cblock, lhs, j, i);
+ } else {
+ bload<DataMapper, PacketC, Index, 2, 0, StorageOrder>(cblock, lhs, i, j);
+ }
+
+ blockr.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[4].v, p16uc_GETREAL32);
+ blockr.packet[1] = vec_perm(cblock.packet[1].v, cblock.packet[5].v, p16uc_GETREAL32);
+ blockr.packet[2] = vec_perm(cblock.packet[2].v, cblock.packet[6].v, p16uc_GETREAL32);
+ blockr.packet[3] = vec_perm(cblock.packet[3].v, cblock.packet[7].v, p16uc_GETREAL32);
+
+ blocki.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[4].v, p16uc_GETIMAG32);
+ blocki.packet[1] = vec_perm(cblock.packet[1].v, cblock.packet[5].v, p16uc_GETIMAG32);
+ blocki.packet[2] = vec_perm(cblock.packet[2].v, cblock.packet[6].v, p16uc_GETIMAG32);
+ blocki.packet[3] = vec_perm(cblock.packet[3].v, cblock.packet[7].v, p16uc_GETIMAG32);
+
+ if(Conjugate)
+ {
+ blocki.packet[0] = -blocki.packet[0];
+ blocki.packet[1] = -blocki.packet[1];
+ blocki.packet[2] = -blocki.packet[2];
+ blocki.packet[3] = -blocki.packet[3];
+ }
+
+ if(((StorageOrder == RowMajor) && UseLhs) || (((StorageOrder == ColMajor) && !UseLhs)))
+ {
+ ptranspose(blockr);
+ ptranspose(blocki);
+ }
+
+ storeBlock<Scalar, Packet, Index>(blockAt + rir, blockr);
+ storeBlock<Scalar, Packet, Index>(blockAt + rii, blocki);
+
+ rir += 4*vectorSize;
+ rii += 4*vectorSize;
+ }
+ for(; i < depth; i++)
+ {
+ PacketBlock<Packet,1> blockr, blocki;
+ PacketBlock<PacketC,2> cblock;
+
+ if(((StorageOrder == ColMajor) && UseLhs) || (((StorageOrder == RowMajor) && !UseLhs)))
+ {
+ if (UseLhs) {
+ cblock.packet[0] = lhs.template loadPacket<PacketC>(j + 0, i);
+ cblock.packet[1] = lhs.template loadPacket<PacketC>(j + 2, i);
+ } else {
+ cblock.packet[0] = lhs.template loadPacket<PacketC>(i, j + 0);
+ cblock.packet[1] = lhs.template loadPacket<PacketC>(i, j + 2);
+ }
+ } else {
+ std::complex<Scalar> lhs0, lhs1;
+ if (UseLhs) {
+ lhs0 = lhs(j + 0, i);
+ lhs1 = lhs(j + 1, i);
+ cblock.packet[0] = pload2(&lhs0, &lhs1);
+ lhs0 = lhs(j + 2, i);
+ lhs1 = lhs(j + 3, i);
+ cblock.packet[1] = pload2(&lhs0, &lhs1);
+ } else {
+ lhs0 = lhs(i, j + 0);
+ lhs1 = lhs(i, j + 1);
+ cblock.packet[0] = pload2(&lhs0, &lhs1);
+ lhs0 = lhs(i, j + 2);
+ lhs1 = lhs(i, j + 3);
+ cblock.packet[1] = pload2(&lhs0, &lhs1);
+ }
+ }
+
+ blockr.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETREAL32);
+ blocki.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETIMAG32);
+
+ if(Conjugate)
+ {
+ blocki.packet[0] = -blocki.packet[0];
+ }
+
+ pstore<Scalar>(blockAt + rir, blockr.packet[0]);
+ pstore<Scalar>(blockAt + rii, blocki.packet[0]);
+
+ rir += vectorSize;
+ rii += vectorSize;
+ }
+
+ rir += ((PanelMode) ? (vectorSize*(2*stride - depth)) : vectorDelta);
+ }
+
+ if (j < rows)
+ {
+ if(PanelMode) rir += (offset*(rows - j - vectorSize));
+ rii = rir + (((PanelMode) ? stride : depth) * (rows - j));
+
+ for(Index i = 0; i < depth; i++)
+ {
+ Index k = j;
+ for(; k < rows; k++)
+ {
+ if (UseLhs) {
+ blockAt[rir] = lhs(k, i).real();
+
+ if(Conjugate)
+ blockAt[rii] = -lhs(k, i).imag();
+ else
+ blockAt[rii] = lhs(k, i).imag();
+ } else {
+ blockAt[rir] = lhs(i, k).real();
+
+ if(Conjugate)
+ blockAt[rii] = -lhs(i, k).imag();
+ else
+ blockAt[rii] = lhs(i, k).imag();
+ }
+
+ rir += 1;
+ rii += 1;
+ }
+ }
+ }
+ }
+};
+
+// General template for lhs & rhs packing.
+template<typename Scalar, typename Index, typename DataMapper, typename Packet, int StorageOrder, bool PanelMode, bool UseLhs>
+struct dhs_pack{
+ EIGEN_STRONG_INLINE void operator()(Scalar* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
+ {
+ const Index vectorSize = quad_traits<Scalar>::vectorsize;
+ Index ri = 0, j = 0;
+
+ for(; j + vectorSize <= rows; j+=vectorSize)
+ {
+ Index i = 0;
+
+ if(PanelMode) ri += vectorSize*offset;
+
+ for(; i + vectorSize <= depth; i+=vectorSize)
+ {
+ PacketBlock<Packet,4> block;
+
+ if (UseLhs) {
+ bload<DataMapper, Packet, Index, 4, 0, StorageOrder>(block, lhs, j, i);
+ } else {
+ bload<DataMapper, Packet, Index, 4, 0, StorageOrder>(block, lhs, i, j);
+ }
+ if(((StorageOrder == RowMajor) && UseLhs) || ((StorageOrder == ColMajor) && !UseLhs))
+ {
+ ptranspose(block);
+ }
+
+ storeBlock<Scalar, Packet, Index>(blockA + ri, block);
+
+ ri += 4*vectorSize;
+ }
+ for(; i < depth; i++)
+ {
+ if(((StorageOrder == RowMajor) && UseLhs) || ((StorageOrder == ColMajor) && !UseLhs))
+ {
+ if (UseLhs) {
+ blockA[ri+0] = lhs(j+0, i);
+ blockA[ri+1] = lhs(j+1, i);
+ blockA[ri+2] = lhs(j+2, i);
+ blockA[ri+3] = lhs(j+3, i);
+ } else {
+ blockA[ri+0] = lhs(i, j+0);
+ blockA[ri+1] = lhs(i, j+1);
+ blockA[ri+2] = lhs(i, j+2);
+ blockA[ri+3] = lhs(i, j+3);
+ }
+ } else {
+ Packet lhsV;
+ if (UseLhs) {
+ lhsV = lhs.template loadPacket<Packet>(j, i);
+ } else {
+ lhsV = lhs.template loadPacket<Packet>(i, j);
+ }
+ pstore<Scalar>(blockA + ri, lhsV);
+ }
+
+ ri += vectorSize;
+ }
+
+ if(PanelMode) ri += vectorSize*(stride - offset - depth);
+ }
+
+ if (j < rows)
+ {
+ if(PanelMode) ri += offset*(rows - j);
+
+ for(Index i = 0; i < depth; i++)
+ {
+ Index k = j;
+ for(; k < rows; k++)
+ {
+ if (UseLhs) {
+ blockA[ri] = lhs(k, i);
+ } else {
+ blockA[ri] = lhs(i, k);
+ }
+ ri += 1;
+ }
+ }
+ }
+ }
+};
+
+// General template for lhs packing, float64 specialization.
+template<typename Index, typename DataMapper, int StorageOrder, bool PanelMode>
+struct dhs_pack<double, Index, DataMapper, Packet2d, StorageOrder, PanelMode, true>
+{
+ EIGEN_STRONG_INLINE void operator()(double* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
+ {
+ const Index vectorSize = quad_traits<double>::vectorsize;
+ Index ri = 0, j = 0;
+
+ for(; j + vectorSize <= rows; j+=vectorSize)
+ {
+ Index i = 0;
+
+ if(PanelMode) ri += vectorSize*offset;
+
+ for(; i + vectorSize <= depth; i+=vectorSize)
+ {
+ PacketBlock<Packet2d,2> block;
+ if(StorageOrder == RowMajor)
+ {
+ block.packet[0] = lhs.template loadPacket<Packet2d>(j + 0, i);
+ block.packet[1] = lhs.template loadPacket<Packet2d>(j + 1, i);
+
+ ptranspose(block);
+ } else {
+ block.packet[0] = lhs.template loadPacket<Packet2d>(j, i + 0);
+ block.packet[1] = lhs.template loadPacket<Packet2d>(j, i + 1);
+ }
+
+ storeBlock<double, Packet2d, Index>(blockA + ri, block);
+
+ ri += 2*vectorSize;
+ }
+ for(; i < depth; i++)
+ {
+ if(StorageOrder == RowMajor)
+ {
+ blockA[ri+0] = lhs(j+0, i);
+ blockA[ri+1] = lhs(j+1, i);
+ } else {
+ Packet2d lhsV = lhs.template loadPacket<Packet2d>(j, i);
+ pstore<double>(blockA + ri, lhsV);
+ }
+
+ ri += vectorSize;
+ }
+
+ if(PanelMode) ri += vectorSize*(stride - offset - depth);
+ }
+
+ if (j < rows)
+ {
+ if(PanelMode) ri += offset*(rows - j);
+
+ for(Index i = 0; i < depth; i++)
+ {
+ Index k = j;
+ for(; k < rows; k++)
+ {
+ blockA[ri] = lhs(k, i);
+ ri += 1;
+ }
+ }
+ }
+ }
+};
+
+// General template for rhs packing, float64 specialization.
+template<typename Index, typename DataMapper, int StorageOrder, bool PanelMode>
+struct dhs_pack<double, Index, DataMapper, Packet2d, StorageOrder, PanelMode, false>
+{
+ EIGEN_STRONG_INLINE void operator()(double* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
+ {
+ const Index vectorSize = quad_traits<double>::vectorsize;
+ Index ri = 0, j = 0;
+
+ for(; j + 2*vectorSize <= cols; j+=2*vectorSize)
+ {
+ Index i = 0;
+
+ if(PanelMode) ri += offset*(2*vectorSize);
+
+ for(; i + vectorSize <= depth; i+=vectorSize)
+ {
+ PacketBlock<Packet2d,4> block;
+ if(StorageOrder == ColMajor)
+ {
+ PacketBlock<Packet2d,2> block1, block2;
+ block1.packet[0] = rhs.template loadPacket<Packet2d>(i, j + 0);
+ block1.packet[1] = rhs.template loadPacket<Packet2d>(i, j + 1);
+ block2.packet[0] = rhs.template loadPacket<Packet2d>(i, j + 2);
+ block2.packet[1] = rhs.template loadPacket<Packet2d>(i, j + 3);
+
+ ptranspose(block1);
+ ptranspose(block2);
+
+ pstore<double>(blockB + ri , block1.packet[0]);
+ pstore<double>(blockB + ri + 2, block2.packet[0]);
+ pstore<double>(blockB + ri + 4, block1.packet[1]);
+ pstore<double>(blockB + ri + 6, block2.packet[1]);
+ } else {
+ block.packet[0] = rhs.template loadPacket<Packet2d>(i + 0, j + 0); //[a1 a2]
+ block.packet[1] = rhs.template loadPacket<Packet2d>(i + 0, j + 2); //[a3 a4]
+ block.packet[2] = rhs.template loadPacket<Packet2d>(i + 1, j + 0); //[b1 b2]
+ block.packet[3] = rhs.template loadPacket<Packet2d>(i + 1, j + 2); //[b3 b4]
+
+ storeBlock<double, Packet2d, Index>(blockB + ri, block);
+ }
+
+ ri += 4*vectorSize;
+ }
+ for(; i < depth; i++)
+ {
+ if(StorageOrder == ColMajor)
+ {
+ blockB[ri+0] = rhs(i, j+0);
+ blockB[ri+1] = rhs(i, j+1);
+
+ ri += vectorSize;
+
+ blockB[ri+0] = rhs(i, j+2);
+ blockB[ri+1] = rhs(i, j+3);
+ } else {
+ Packet2d rhsV = rhs.template loadPacket<Packet2d>(i, j);
+ pstore<double>(blockB + ri, rhsV);
+
+ ri += vectorSize;
+
+ rhsV = rhs.template loadPacket<Packet2d>(i, j + 2);
+ pstore<double>(blockB + ri, rhsV);
+ }
+ ri += vectorSize;
+ }
+
+ if(PanelMode) ri += (2*vectorSize)*(stride - offset - depth);
+ }
+
+ if (j < cols)
+ {
+ if(PanelMode) ri += offset*(cols - j);
+
+ for(Index i = 0; i < depth; i++)
+ {
+ Index k = j;
+ for(; k < cols; k++)
+ {
+ blockB[ri] = rhs(i, k);
+ ri += 1;
+ }
+ }
+ }
+ }
+};
+
+// General template for lhs complex packing, float64 specialization.
+template<typename Index, typename DataMapper, typename Packet, typename PacketC, int StorageOrder, bool Conjugate, bool PanelMode>
+struct dhs_cpack<double, Index, DataMapper, Packet, PacketC, StorageOrder, Conjugate, PanelMode, true>
+{
+ EIGEN_STRONG_INLINE void operator()(std::complex<double>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
+ {
+ const Index vectorSize = quad_traits<double>::vectorsize;
+ const Index vectorDelta = vectorSize * ((PanelMode) ? stride : depth);
+ Index rir = ((PanelMode) ? (vectorSize*offset) : 0), rii;
+ double* blockAt = reinterpret_cast<double *>(blockA);
+ Index j = 0;
+
+ for(; j + vectorSize <= rows; j+=vectorSize)
+ {
+ Index i = 0;
+
+ rii = rir + vectorDelta;
+
+ for(; i + vectorSize <= depth; i+=vectorSize)
+ {
+ PacketBlock<Packet,2> blockr, blocki;
+ PacketBlock<PacketC,4> cblock;
+
+ if(StorageOrder == ColMajor)
+ {
+ cblock.packet[0] = lhs.template loadPacket<PacketC>(j, i + 0); //[a1 a1i]
+ cblock.packet[1] = lhs.template loadPacket<PacketC>(j, i + 1); //[b1 b1i]
+
+ cblock.packet[2] = lhs.template loadPacket<PacketC>(j + 1, i + 0); //[a2 a2i]
+ cblock.packet[3] = lhs.template loadPacket<PacketC>(j + 1, i + 1); //[b2 b2i]
+
+ blockr.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[2].v, p16uc_GETREAL64); //[a1 a2]
+ blockr.packet[1] = vec_perm(cblock.packet[1].v, cblock.packet[3].v, p16uc_GETREAL64); //[b1 b2]
+
+ blocki.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[2].v, p16uc_GETIMAG64);
+ blocki.packet[1] = vec_perm(cblock.packet[1].v, cblock.packet[3].v, p16uc_GETIMAG64);
+ } else {
+ cblock.packet[0] = lhs.template loadPacket<PacketC>(j + 0, i); //[a1 a1i]
+ cblock.packet[1] = lhs.template loadPacket<PacketC>(j + 1, i); //[a2 a2i]
+
+ cblock.packet[2] = lhs.template loadPacket<PacketC>(j + 0, i + 1); //[b1 b1i]
+ cblock.packet[3] = lhs.template loadPacket<PacketC>(j + 1, i + 1); //[b2 b2i
+
+ blockr.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETREAL64); //[a1 a2]
+ blockr.packet[1] = vec_perm(cblock.packet[2].v, cblock.packet[3].v, p16uc_GETREAL64); //[b1 b2]
+
+ blocki.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETIMAG64);
+ blocki.packet[1] = vec_perm(cblock.packet[2].v, cblock.packet[3].v, p16uc_GETIMAG64);
+ }
+
+ if(Conjugate)
+ {
+ blocki.packet[0] = -blocki.packet[0];
+ blocki.packet[1] = -blocki.packet[1];
+ }
+
+ storeBlock<double, Packet, Index>(blockAt + rir, blockr);
+ storeBlock<double, Packet, Index>(blockAt + rii, blocki);
+
+ rir += 2*vectorSize;
+ rii += 2*vectorSize;
+ }
+ for(; i < depth; i++)
+ {
+ PacketBlock<Packet,1> blockr, blocki;
+ PacketBlock<PacketC,2> cblock;
+
+ cblock.packet[0] = lhs.template loadPacket<PacketC>(j + 0, i);
+ cblock.packet[1] = lhs.template loadPacket<PacketC>(j + 1, i);
+
+ blockr.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETREAL64);
+ blocki.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETIMAG64);
+
+ if(Conjugate)
+ {
+ blocki.packet[0] = -blocki.packet[0];
+ }
+
+ pstore<double>(blockAt + rir, blockr.packet[0]);
+ pstore<double>(blockAt + rii, blocki.packet[0]);
+
+ rir += vectorSize;
+ rii += vectorSize;
+ }
+
+ rir += ((PanelMode) ? (vectorSize*(2*stride - depth)) : vectorDelta);
+ }
+
+ if (j < rows)
+ {
+ if(PanelMode) rir += (offset*(rows - j - vectorSize));
+ rii = rir + (((PanelMode) ? stride : depth) * (rows - j));
+
+ for(Index i = 0; i < depth; i++)
+ {
+ Index k = j;
+ for(; k < rows; k++)
+ {
+ blockAt[rir] = lhs(k, i).real();
+
+ if(Conjugate)
+ blockAt[rii] = -lhs(k, i).imag();
+ else
+ blockAt[rii] = lhs(k, i).imag();
+
+ rir += 1;
+ rii += 1;
+ }
+ }
+ }
+ }
+};
+
+// General template for rhs complex packing, float64 specialization.
+template<typename Index, typename DataMapper, typename Packet, typename PacketC, int StorageOrder, bool Conjugate, bool PanelMode>
+struct dhs_cpack<double, Index, DataMapper, Packet, PacketC, StorageOrder, Conjugate, PanelMode, false>
+{
+ EIGEN_STRONG_INLINE void operator()(std::complex<double>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
+ {
+ const Index vectorSize = quad_traits<double>::vectorsize;
+ const Index vectorDelta = 2*vectorSize * ((PanelMode) ? stride : depth);
+ Index rir = ((PanelMode) ? (2*vectorSize*offset) : 0), rii;
+ double* blockBt = reinterpret_cast<double *>(blockB);
+ Index j = 0;
+
+ for(; j + 2*vectorSize <= cols; j+=2*vectorSize)
+ {
+ Index i = 0;
+
+ rii = rir + vectorDelta;
+
+ for(; i < depth; i++)
+ {
+ PacketBlock<PacketC,4> cblock;
+ PacketBlock<Packet,2> blockr, blocki;
+
+ bload<DataMapper, PacketC, Index, 2, 0, ColMajor>(cblock, rhs, i, j);
+
+ blockr.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETREAL64);
+ blockr.packet[1] = vec_perm(cblock.packet[2].v, cblock.packet[3].v, p16uc_GETREAL64);
+
+ blocki.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETIMAG64);
+ blocki.packet[1] = vec_perm(cblock.packet[2].v, cblock.packet[3].v, p16uc_GETIMAG64);
+
+ if(Conjugate)
+ {
+ blocki.packet[0] = -blocki.packet[0];
+ blocki.packet[1] = -blocki.packet[1];
+ }
+
+ storeBlock<double, Packet, Index>(blockBt + rir, blockr);
+ storeBlock<double, Packet, Index>(blockBt + rii, blocki);
+
+ rir += 2*vectorSize;
+ rii += 2*vectorSize;
+ }
+
+ rir += ((PanelMode) ? (2*vectorSize*(2*stride - depth)) : vectorDelta);
+ }
+
+ if (j < cols)
+ {
+ if(PanelMode) rir += (offset*(cols - j - 2*vectorSize));
+ rii = rir + (((PanelMode) ? stride : depth) * (cols - j));
+
+ for(Index i = 0; i < depth; i++)
+ {
+ Index k = j;
+ for(; k < cols; k++)
+ {
+ blockBt[rir] = rhs(i, k).real();
+
+ if(Conjugate)
+ blockBt[rii] = -rhs(i, k).imag();
+ else
+ blockBt[rii] = rhs(i, k).imag();
+
+ rir += 1;
+ rii += 1;
+ }
+ }
+ }
+ }
+};
+
+/**************
+ * GEMM utils *
+ **************/
+
+// 512-bits rank1-update of acc. It can either positive or negative accumulate (useful for complex gemm).
+template<typename Packet, bool NegativeAccumulate>
+EIGEN_ALWAYS_INLINE void pger_common(PacketBlock<Packet,4>* acc, const Packet& lhsV, const Packet* rhsV)
+{
+ if(NegativeAccumulate)
+ {
+ acc->packet[0] = vec_nmsub(lhsV, rhsV[0], acc->packet[0]);
+ acc->packet[1] = vec_nmsub(lhsV, rhsV[1], acc->packet[1]);
+ acc->packet[2] = vec_nmsub(lhsV, rhsV[2], acc->packet[2]);
+ acc->packet[3] = vec_nmsub(lhsV, rhsV[3], acc->packet[3]);
+ } else {
+ acc->packet[0] = vec_madd(lhsV, rhsV[0], acc->packet[0]);
+ acc->packet[1] = vec_madd(lhsV, rhsV[1], acc->packet[1]);
+ acc->packet[2] = vec_madd(lhsV, rhsV[2], acc->packet[2]);
+ acc->packet[3] = vec_madd(lhsV, rhsV[3], acc->packet[3]);
+ }
+}
+
+template<typename Packet, bool NegativeAccumulate>
+EIGEN_ALWAYS_INLINE void pger_common(PacketBlock<Packet,1>* acc, const Packet& lhsV, const Packet* rhsV)
+{
+ if(NegativeAccumulate)
+ {
+ acc->packet[0] = vec_nmsub(lhsV, rhsV[0], acc->packet[0]);
+ } else {
+ acc->packet[0] = vec_madd(lhsV, rhsV[0], acc->packet[0]);
+ }
+}
+
+template<int N, typename Scalar, typename Packet, bool NegativeAccumulate>
+EIGEN_ALWAYS_INLINE void pger(PacketBlock<Packet,N>* acc, const Scalar* lhs, const Packet* rhsV)
+{
+ Packet lhsV = pload<Packet>(lhs);
+
+ pger_common<Packet, NegativeAccumulate>(acc, lhsV, rhsV);
+}
+
+template<typename Scalar, typename Packet, typename Index>
+EIGEN_ALWAYS_INLINE void loadPacketRemaining(const Scalar* lhs, Packet &lhsV, Index remaining_rows)
+{
+#ifdef _ARCH_PWR9
+ lhsV = vec_xl_len((Scalar *)lhs, remaining_rows * sizeof(Scalar));
+#else
+ Index i = 0;
+ do {
+ lhsV[i] = lhs[i];
+ } while (++i < remaining_rows);
+#endif
+}
+
+template<int N, typename Scalar, typename Packet, typename Index, bool NegativeAccumulate>
+EIGEN_ALWAYS_INLINE void pger(PacketBlock<Packet,N>* acc, const Scalar* lhs, const Packet* rhsV, Index remaining_rows)
+{
+ Packet lhsV;
+ loadPacketRemaining<Scalar, Packet, Index>(lhs, lhsV, remaining_rows);
+
+ pger_common<Packet, NegativeAccumulate>(acc, lhsV, rhsV);
+}
+
+// 512-bits rank1-update of complex acc. It takes decoupled accumulators as entries. It also takes cares of mixed types real * complex and complex * real.
+template<int N, typename Packet, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
+EIGEN_ALWAYS_INLINE void pgerc_common(PacketBlock<Packet,N>* accReal, PacketBlock<Packet,N>* accImag, const Packet &lhsV, const Packet &lhsVi, const Packet* rhsV, const Packet* rhsVi)
+{
+ pger_common<Packet, false>(accReal, lhsV, rhsV);
+ if(LhsIsReal)
+ {
+ pger_common<Packet, ConjugateRhs>(accImag, lhsV, rhsVi);
+ EIGEN_UNUSED_VARIABLE(lhsVi);
+ } else {
+ if (!RhsIsReal) {
+ pger_common<Packet, ConjugateLhs == ConjugateRhs>(accReal, lhsVi, rhsVi);
+ pger_common<Packet, ConjugateRhs>(accImag, lhsV, rhsVi);
+ } else {
+ EIGEN_UNUSED_VARIABLE(rhsVi);
+ }
+ pger_common<Packet, ConjugateLhs>(accImag, lhsVi, rhsV);
+ }
+}
+
+template<int N, typename Scalar, typename Packet, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
+EIGEN_ALWAYS_INLINE void pgerc(PacketBlock<Packet,N>* accReal, PacketBlock<Packet,N>* accImag, const Scalar* lhs_ptr, const Scalar* lhs_ptr_imag, const Packet* rhsV, const Packet* rhsVi)
+{
+ Packet lhsV = ploadLhs<Scalar, Packet>(lhs_ptr);
+ Packet lhsVi;
+ if(!LhsIsReal) lhsVi = ploadLhs<Scalar, Packet>(lhs_ptr_imag);
+ else EIGEN_UNUSED_VARIABLE(lhs_ptr_imag);
+
+ pgerc_common<N, Packet, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(accReal, accImag, lhsV, lhsVi, rhsV, rhsVi);
+}
+
+template<typename Scalar, typename Packet, typename Index, bool LhsIsReal>
+EIGEN_ALWAYS_INLINE void loadPacketRemaining(const Scalar* lhs_ptr, const Scalar* lhs_ptr_imag, Packet &lhsV, Packet &lhsVi, Index remaining_rows)
+{
+#ifdef _ARCH_PWR9
+ lhsV = vec_xl_len((Scalar *)lhs_ptr, remaining_rows * sizeof(Scalar));
+ if(!LhsIsReal) lhsVi = vec_xl_len((Scalar *)lhs_ptr_imag, remaining_rows * sizeof(Scalar));
+ else EIGEN_UNUSED_VARIABLE(lhs_ptr_imag);
+#else
+ Index i = 0;
+ do {
+ lhsV[i] = lhs_ptr[i];
+ if(!LhsIsReal) lhsVi[i] = lhs_ptr_imag[i];
+ } while (++i < remaining_rows);
+ if(LhsIsReal) EIGEN_UNUSED_VARIABLE(lhs_ptr_imag);
+#endif
+}
+
+template<int N, typename Scalar, typename Packet, typename Index, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
+EIGEN_ALWAYS_INLINE void pgerc(PacketBlock<Packet,N>* accReal, PacketBlock<Packet,N>* accImag, const Scalar* lhs_ptr, const Scalar* lhs_ptr_imag, const Packet* rhsV, const Packet* rhsVi, Index remaining_rows)
+{
+ Packet lhsV, lhsVi;
+ loadPacketRemaining<Scalar, Packet, Index, LhsIsReal>(lhs_ptr, lhs_ptr_imag, lhsV, lhsVi, remaining_rows);
+
+ pgerc_common<N, Packet, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(accReal, accImag, lhsV, lhsVi, rhsV, rhsVi);
+}
+
+template<typename Scalar, typename Packet>
+EIGEN_ALWAYS_INLINE Packet ploadLhs(const Scalar* lhs)
+{
+ return ploadu<Packet>(lhs);
+}
+
+// Zero the accumulator on PacketBlock.
+template<typename Scalar, typename Packet>
+EIGEN_ALWAYS_INLINE void bsetzero(PacketBlock<Packet,4>& acc)
+{
+ acc.packet[0] = pset1<Packet>((Scalar)0);
+ acc.packet[1] = pset1<Packet>((Scalar)0);
+ acc.packet[2] = pset1<Packet>((Scalar)0);
+ acc.packet[3] = pset1<Packet>((Scalar)0);
+}
+
+template<typename Scalar, typename Packet>
+EIGEN_ALWAYS_INLINE void bsetzero(PacketBlock<Packet,1>& acc)
+{
+ acc.packet[0] = pset1<Packet>((Scalar)0);
+}
+
+// Scale the PacketBlock vectors by alpha.
+template<typename Packet>
+EIGEN_ALWAYS_INLINE void bscale(PacketBlock<Packet,4>& acc, PacketBlock<Packet,4>& accZ, const Packet& pAlpha)
+{
+ acc.packet[0] = pmadd(pAlpha, accZ.packet[0], acc.packet[0]);
+ acc.packet[1] = pmadd(pAlpha, accZ.packet[1], acc.packet[1]);
+ acc.packet[2] = pmadd(pAlpha, accZ.packet[2], acc.packet[2]);
+ acc.packet[3] = pmadd(pAlpha, accZ.packet[3], acc.packet[3]);
+}
+
+template<typename Packet>
+EIGEN_ALWAYS_INLINE void bscale(PacketBlock<Packet,1>& acc, PacketBlock<Packet,1>& accZ, const Packet& pAlpha)
+{
+ acc.packet[0] = pmadd(pAlpha, accZ.packet[0], acc.packet[0]);
+}
+
+template<typename Packet>
+EIGEN_ALWAYS_INLINE void bscalec_common(PacketBlock<Packet,4>& acc, PacketBlock<Packet,4>& accZ, const Packet& pAlpha)
+{
+ acc.packet[0] = pmul<Packet>(accZ.packet[0], pAlpha);
+ acc.packet[1] = pmul<Packet>(accZ.packet[1], pAlpha);
+ acc.packet[2] = pmul<Packet>(accZ.packet[2], pAlpha);
+ acc.packet[3] = pmul<Packet>(accZ.packet[3], pAlpha);
+}
+
+template<typename Packet>
+EIGEN_ALWAYS_INLINE void bscalec_common(PacketBlock<Packet,1>& acc, PacketBlock<Packet,1>& accZ, const Packet& pAlpha)
+{
+ acc.packet[0] = pmul<Packet>(accZ.packet[0], pAlpha);
+}
+
+// Complex version of PacketBlock scaling.
+template<typename Packet, int N>
+EIGEN_ALWAYS_INLINE void bscalec(PacketBlock<Packet,N>& aReal, PacketBlock<Packet,N>& aImag, const Packet& bReal, const Packet& bImag, PacketBlock<Packet,N>& cReal, PacketBlock<Packet,N>& cImag)
+{
+ bscalec_common<Packet>(cReal, aReal, bReal);
+
+ bscalec_common<Packet>(cImag, aImag, bReal);
+
+ pger_common<Packet, true>(&cReal, bImag, aImag.packet);
+
+ pger_common<Packet, false>(&cImag, bImag, aReal.packet);
+}
+
+template<typename Packet>
+EIGEN_ALWAYS_INLINE void band(PacketBlock<Packet,4>& acc, const Packet& pMask)
+{
+ acc.packet[0] = pand(acc.packet[0], pMask);
+ acc.packet[1] = pand(acc.packet[1], pMask);
+ acc.packet[2] = pand(acc.packet[2], pMask);
+ acc.packet[3] = pand(acc.packet[3], pMask);
+}
+
+template<typename Packet>
+EIGEN_ALWAYS_INLINE void bscalec(PacketBlock<Packet,4>& aReal, PacketBlock<Packet,4>& aImag, const Packet& bReal, const Packet& bImag, PacketBlock<Packet,4>& cReal, PacketBlock<Packet,4>& cImag, const Packet& pMask)
+{
+ band<Packet>(aReal, pMask);
+ band<Packet>(aImag, pMask);
+
+ bscalec<Packet,4>(aReal, aImag, bReal, bImag, cReal, cImag);
+}
+
+// Load a PacketBlock, the N parameters make tunning gemm easier so we can add more accumulators as needed.
+template<typename DataMapper, typename Packet, typename Index, const Index accCols, int N, int StorageOrder>
+EIGEN_ALWAYS_INLINE void bload(PacketBlock<Packet,4>& acc, const DataMapper& res, Index row, Index col)
+{
+ if (StorageOrder == RowMajor) {
+ acc.packet[0] = res.template loadPacket<Packet>(row + 0, col + N*accCols);
+ acc.packet[1] = res.template loadPacket<Packet>(row + 1, col + N*accCols);
+ acc.packet[2] = res.template loadPacket<Packet>(row + 2, col + N*accCols);
+ acc.packet[3] = res.template loadPacket<Packet>(row + 3, col + N*accCols);
+ } else {
+ acc.packet[0] = res.template loadPacket<Packet>(row + N*accCols, col + 0);
+ acc.packet[1] = res.template loadPacket<Packet>(row + N*accCols, col + 1);
+ acc.packet[2] = res.template loadPacket<Packet>(row + N*accCols, col + 2);
+ acc.packet[3] = res.template loadPacket<Packet>(row + N*accCols, col + 3);
+ }
+}
+
+// An overload of bload when you have a PacketBLock with 8 vectors.
+template<typename DataMapper, typename Packet, typename Index, const Index accCols, int N, int StorageOrder>
+EIGEN_ALWAYS_INLINE void bload(PacketBlock<Packet,8>& acc, const DataMapper& res, Index row, Index col)
+{
+ if (StorageOrder == RowMajor) {
+ acc.packet[0] = res.template loadPacket<Packet>(row + 0, col + N*accCols);
+ acc.packet[1] = res.template loadPacket<Packet>(row + 1, col + N*accCols);
+ acc.packet[2] = res.template loadPacket<Packet>(row + 2, col + N*accCols);
+ acc.packet[3] = res.template loadPacket<Packet>(row + 3, col + N*accCols);
+ acc.packet[4] = res.template loadPacket<Packet>(row + 0, col + (N+1)*accCols);
+ acc.packet[5] = res.template loadPacket<Packet>(row + 1, col + (N+1)*accCols);
+ acc.packet[6] = res.template loadPacket<Packet>(row + 2, col + (N+1)*accCols);
+ acc.packet[7] = res.template loadPacket<Packet>(row + 3, col + (N+1)*accCols);
+ } else {
+ acc.packet[0] = res.template loadPacket<Packet>(row + N*accCols, col + 0);
+ acc.packet[1] = res.template loadPacket<Packet>(row + N*accCols, col + 1);
+ acc.packet[2] = res.template loadPacket<Packet>(row + N*accCols, col + 2);
+ acc.packet[3] = res.template loadPacket<Packet>(row + N*accCols, col + 3);
+ acc.packet[4] = res.template loadPacket<Packet>(row + (N+1)*accCols, col + 0);
+ acc.packet[5] = res.template loadPacket<Packet>(row + (N+1)*accCols, col + 1);
+ acc.packet[6] = res.template loadPacket<Packet>(row + (N+1)*accCols, col + 2);
+ acc.packet[7] = res.template loadPacket<Packet>(row + (N+1)*accCols, col + 3);
+ }
+}
+
+template<typename DataMapper, typename Packet, typename Index, const Index accCols, int N, int StorageOrder>
+EIGEN_ALWAYS_INLINE void bload(PacketBlock<Packet,2>& acc, const DataMapper& res, Index row, Index col)
+{
+ acc.packet[0] = res.template loadPacket<Packet>(row + N*accCols, col + 0);
+ acc.packet[1] = res.template loadPacket<Packet>(row + (N+1)*accCols, col + 0);
+}
+
+const static Packet4i mask41 = { -1, 0, 0, 0 };
+const static Packet4i mask42 = { -1, -1, 0, 0 };
+const static Packet4i mask43 = { -1, -1, -1, 0 };
+
+const static Packet2l mask21 = { -1, 0 };
+
+template<typename Packet>
+EIGEN_ALWAYS_INLINE Packet bmask(const int remaining_rows)
+{
+ if (remaining_rows == 0) {
+ return pset1<Packet>(float(0.0)); // Not used
+ } else {
+ switch (remaining_rows) {
+ case 1: return Packet(mask41);
+ case 2: return Packet(mask42);
+ default: return Packet(mask43);
+ }
+ }
+}
+
+template<>
+EIGEN_ALWAYS_INLINE Packet2d bmask<Packet2d>(const int remaining_rows)
+{
+ if (remaining_rows == 0) {
+ return pset1<Packet2d>(double(0.0)); // Not used
+ } else {
+ return Packet2d(mask21);
+ }
+}
+
+template<typename Packet>
+EIGEN_ALWAYS_INLINE void bscale(PacketBlock<Packet,4>& acc, PacketBlock<Packet,4>& accZ, const Packet& pAlpha, const Packet& pMask)
+{
+ band<Packet>(accZ, pMask);
+
+ bscale<Packet>(acc, accZ, pAlpha);
+}
+
+template<typename Packet>
+EIGEN_ALWAYS_INLINE void pbroadcast4_old(const __UNPACK_TYPE__(Packet)* a, Packet& a0, Packet& a1, Packet& a2, Packet& a3)
+{
+ pbroadcast4<Packet>(a, a0, a1, a2, a3);
+}
+
+template<>
+EIGEN_ALWAYS_INLINE void pbroadcast4_old<Packet2d>(const double* a, Packet2d& a0, Packet2d& a1, Packet2d& a2, Packet2d& a3)
+{
+ a1 = pload<Packet2d>(a);
+ a3 = pload<Packet2d>(a + 2);
+ a0 = vec_splat(a1, 0);
+ a1 = vec_splat(a1, 1);
+ a2 = vec_splat(a3, 0);
+ a3 = vec_splat(a3, 1);
+}
+
+// PEEL loop factor.
+#define PEEL 7
+
+template<typename Scalar, typename Packet, typename Index>
+EIGEN_ALWAYS_INLINE void MICRO_EXTRA_COL(
+ const Scalar* &lhs_ptr,
+ const Scalar* &rhs_ptr,
+ PacketBlock<Packet,1> &accZero,
+ Index remaining_rows,
+ Index remaining_cols)
+{
+ Packet rhsV[1];
+ rhsV[0] = pset1<Packet>(rhs_ptr[0]);
+ pger<1,Scalar, Packet, false>(&accZero, lhs_ptr, rhsV);
+ lhs_ptr += remaining_rows;
+ rhs_ptr += remaining_cols;
+}
+
+template<typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accRows>
+EIGEN_STRONG_INLINE void gemm_extra_col(
+ const DataMapper& res,
+ const Scalar* lhs_base,
+ const Scalar* rhs_base,
+ Index depth,
+ Index strideA,
+ Index offsetA,
+ Index row,
+ Index col,
+ Index remaining_rows,
+ Index remaining_cols,
+ const Packet& pAlpha)
+{
+ const Scalar* rhs_ptr = rhs_base;
+ const Scalar* lhs_ptr = lhs_base + row*strideA + remaining_rows*offsetA;
+ PacketBlock<Packet,1> accZero;
+
+ bsetzero<Scalar, Packet>(accZero);
+
+ Index remaining_depth = (depth & -accRows);
+ Index k = 0;
+ for(; k + PEEL <= remaining_depth; k+= PEEL)
+ {
+ EIGEN_POWER_PREFETCH(rhs_ptr);
+ EIGEN_POWER_PREFETCH(lhs_ptr);
+ for (int l = 0; l < PEEL; l++) {
+ MICRO_EXTRA_COL<Scalar, Packet, Index>(lhs_ptr, rhs_ptr, accZero, remaining_rows, remaining_cols);
+ }
+ }
+ for(; k < remaining_depth; k++)
+ {
+ MICRO_EXTRA_COL<Scalar, Packet, Index>(lhs_ptr, rhs_ptr, accZero, remaining_rows, remaining_cols);
+ }
+ for(; k < depth; k++)
+ {
+ Packet rhsV[1];
+ rhsV[0] = pset1<Packet>(rhs_ptr[0]);
+ pger<1, Scalar, Packet, Index, false>(&accZero, lhs_ptr, rhsV, remaining_rows);
+ lhs_ptr += remaining_rows;
+ rhs_ptr += remaining_cols;
+ }
+
+ accZero.packet[0] = vec_mul(pAlpha, accZero.packet[0]);
+ for(Index i = 0; i < remaining_rows; i++) {
+ res(row + i, col) += accZero.packet[0][i];
+ }
+}
+
+template<typename Scalar, typename Packet, typename Index, const Index accRows>
+EIGEN_ALWAYS_INLINE void MICRO_EXTRA_ROW(
+ const Scalar* &lhs_ptr,
+ const Scalar* &rhs_ptr,
+ PacketBlock<Packet,4> &accZero,
+ Index remaining_rows)
+{
+ Packet rhsV[4];
+ pbroadcast4<Packet>(rhs_ptr, rhsV[0], rhsV[1], rhsV[2], rhsV[3]);
+ pger<4, Scalar, Packet, false>(&accZero, lhs_ptr, rhsV);
+ lhs_ptr += remaining_rows;
+ rhs_ptr += accRows;
+}
+
+template<typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accRows, const Index accCols>
+EIGEN_STRONG_INLINE void gemm_extra_row(
+ const DataMapper& res,
+ const Scalar* lhs_base,
+ const Scalar* rhs_base,
+ Index depth,
+ Index strideA,
+ Index offsetA,
+ Index row,
+ Index col,
+ Index rows,
+ Index cols,
+ Index remaining_rows,
+ const Packet& pAlpha,
+ const Packet& pMask)
+{
+ const Scalar* rhs_ptr = rhs_base;
+ const Scalar* lhs_ptr = lhs_base + row*strideA + remaining_rows*offsetA;
+ PacketBlock<Packet,4> accZero, acc;
+
+ bsetzero<Scalar, Packet>(accZero);
+
+ Index remaining_depth = (col + accRows < cols) ? depth : (depth & -accRows);
+ Index k = 0;
+ for(; k + PEEL <= remaining_depth; k+= PEEL)
+ {
+ EIGEN_POWER_PREFETCH(rhs_ptr);
+ EIGEN_POWER_PREFETCH(lhs_ptr);
+ for (int l = 0; l < PEEL; l++) {
+ MICRO_EXTRA_ROW<Scalar, Packet, Index, accRows>(lhs_ptr, rhs_ptr, accZero, remaining_rows);
+ }
+ }
+ for(; k < remaining_depth; k++)
+ {
+ MICRO_EXTRA_ROW<Scalar, Packet, Index, accRows>(lhs_ptr, rhs_ptr, accZero, remaining_rows);
+ }
+
+ if ((remaining_depth == depth) && (rows >= accCols))
+ {
+ for(Index j = 0; j < 4; j++) {
+ acc.packet[j] = res.template loadPacket<Packet>(row, col + j);
+ }
+ bscale<Packet>(acc, accZero, pAlpha, pMask);
+ res.template storePacketBlock<Packet,4>(row, col, acc);
+ } else {
+ for(; k < depth; k++)
+ {
+ Packet rhsV[4];
+ pbroadcast4<Packet>(rhs_ptr, rhsV[0], rhsV[1], rhsV[2], rhsV[3]);
+ pger<4, Scalar, Packet, Index, false>(&accZero, lhs_ptr, rhsV, remaining_rows);
+ lhs_ptr += remaining_rows;
+ rhs_ptr += accRows;
+ }
+
+ for(Index j = 0; j < 4; j++) {
+ accZero.packet[j] = vec_mul(pAlpha, accZero.packet[j]);
+ }
+ for(Index j = 0; j < 4; j++) {
+ for(Index i = 0; i < remaining_rows; i++) {
+ res(row + i, col + j) += accZero.packet[j][i];
+ }
+ }
+ }
+}
+
+#define MICRO_UNROLL(func) \
+ func(0) func(1) func(2) func(3) func(4) func(5) func(6) func(7)
+
+#define MICRO_UNROLL_WORK(func, func2, peel) \
+ MICRO_UNROLL(func2); \
+ func(0,peel) func(1,peel) func(2,peel) func(3,peel) \
+ func(4,peel) func(5,peel) func(6,peel) func(7,peel)
+
+#define MICRO_LOAD_ONE(iter) \
+ if (unroll_factor > iter) { \
+ lhsV##iter = ploadLhs<Scalar, Packet>(lhs_ptr##iter); \
+ lhs_ptr##iter += accCols; \
+ } else { \
+ EIGEN_UNUSED_VARIABLE(lhsV##iter); \
+ }
+
+#define MICRO_WORK_ONE(iter, peel) \
+ if (unroll_factor > iter) { \
+ pger_common<Packet, false>(&accZero##iter, lhsV##iter, rhsV##peel); \
+ }
+
+#define MICRO_TYPE_PEEL4(func, func2, peel) \
+ if (PEEL > peel) { \
+ Packet lhsV0, lhsV1, lhsV2, lhsV3, lhsV4, lhsV5, lhsV6, lhsV7; \
+ pbroadcast4<Packet>(rhs_ptr + (accRows * peel), rhsV##peel[0], rhsV##peel[1], rhsV##peel[2], rhsV##peel[3]); \
+ MICRO_UNROLL_WORK(func, func2, peel) \
+ } else { \
+ EIGEN_UNUSED_VARIABLE(rhsV##peel); \
+ }
+
+#define MICRO_TYPE_PEEL1(func, func2, peel) \
+ if (PEEL > peel) { \
+ Packet lhsV0, lhsV1, lhsV2, lhsV3, lhsV4, lhsV5, lhsV6, lhsV7; \
+ rhsV##peel[0] = pset1<Packet>(rhs_ptr[remaining_cols * peel]); \
+ MICRO_UNROLL_WORK(func, func2, peel) \
+ } else { \
+ EIGEN_UNUSED_VARIABLE(rhsV##peel); \
+ }
+
+#define MICRO_UNROLL_TYPE_PEEL(M, func, func1, func2) \
+ Packet rhsV0[M], rhsV1[M], rhsV2[M], rhsV3[M], rhsV4[M], rhsV5[M], rhsV6[M], rhsV7[M], rhsV8[M], rhsV9[M]; \
+ func(func1,func2,0); func(func1,func2,1); \
+ func(func1,func2,2); func(func1,func2,3); \
+ func(func1,func2,4); func(func1,func2,5); \
+ func(func1,func2,6); func(func1,func2,7); \
+ func(func1,func2,8); func(func1,func2,9);
+
+#define MICRO_UNROLL_TYPE_ONE(M, func, func1, func2) \
+ Packet rhsV0[M]; \
+ func(func1,func2,0);
+
+#define MICRO_ONE_PEEL4 \
+ MICRO_UNROLL_TYPE_PEEL(4, MICRO_TYPE_PEEL4, MICRO_WORK_ONE, MICRO_LOAD_ONE); \
+ rhs_ptr += (accRows * PEEL);
+
+#define MICRO_ONE4 \
+ MICRO_UNROLL_TYPE_ONE(4, MICRO_TYPE_PEEL4, MICRO_WORK_ONE, MICRO_LOAD_ONE); \
+ rhs_ptr += accRows;
+
+#define MICRO_ONE_PEEL1 \
+ MICRO_UNROLL_TYPE_PEEL(1, MICRO_TYPE_PEEL1, MICRO_WORK_ONE, MICRO_LOAD_ONE); \
+ rhs_ptr += (remaining_cols * PEEL);
+
+#define MICRO_ONE1 \
+ MICRO_UNROLL_TYPE_ONE(1, MICRO_TYPE_PEEL1, MICRO_WORK_ONE, MICRO_LOAD_ONE); \
+ rhs_ptr += remaining_cols;
+
+#define MICRO_DST_PTR_ONE(iter) \
+ if (unroll_factor > iter) { \
+ bsetzero<Scalar, Packet>(accZero##iter); \
+ } else { \
+ EIGEN_UNUSED_VARIABLE(accZero##iter); \
+ }
+
+#define MICRO_DST_PTR MICRO_UNROLL(MICRO_DST_PTR_ONE)
+
+#define MICRO_SRC_PTR_ONE(iter) \
+ if (unroll_factor > iter) { \
+ lhs_ptr##iter = lhs_base + ( (row/accCols) + iter )*strideA*accCols + accCols*offsetA; \
+ } else { \
+ EIGEN_UNUSED_VARIABLE(lhs_ptr##iter); \
+ }
+
+#define MICRO_SRC_PTR MICRO_UNROLL(MICRO_SRC_PTR_ONE)
+
+#define MICRO_PREFETCH_ONE(iter) \
+ if (unroll_factor > iter) { \
+ EIGEN_POWER_PREFETCH(lhs_ptr##iter); \
+ }
+
+#define MICRO_PREFETCH MICRO_UNROLL(MICRO_PREFETCH_ONE)
+
+#define MICRO_STORE_ONE(iter) \
+ if (unroll_factor > iter) { \
+ acc.packet[0] = res.template loadPacket<Packet>(row + iter*accCols, col + 0); \
+ acc.packet[1] = res.template loadPacket<Packet>(row + iter*accCols, col + 1); \
+ acc.packet[2] = res.template loadPacket<Packet>(row + iter*accCols, col + 2); \
+ acc.packet[3] = res.template loadPacket<Packet>(row + iter*accCols, col + 3); \
+ bscale<Packet>(acc, accZero##iter, pAlpha); \
+ res.template storePacketBlock<Packet,4>(row + iter*accCols, col, acc); \
+ }
+
+#define MICRO_STORE MICRO_UNROLL(MICRO_STORE_ONE)
+
+#define MICRO_COL_STORE_ONE(iter) \
+ if (unroll_factor > iter) { \
+ acc.packet[0] = res.template loadPacket<Packet>(row + iter*accCols, col + 0); \
+ bscale<Packet>(acc, accZero##iter, pAlpha); \
+ res.template storePacketBlock<Packet,1>(row + iter*accCols, col, acc); \
+ }
+
+#define MICRO_COL_STORE MICRO_UNROLL(MICRO_COL_STORE_ONE)
+
+template<int unroll_factor, typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accRows, const Index accCols>
+EIGEN_STRONG_INLINE void gemm_unrolled_iteration(
+ const DataMapper& res,
+ const Scalar* lhs_base,
+ const Scalar* rhs_base,
+ Index depth,
+ Index strideA,
+ Index offsetA,
+ Index& row,
+ Index col,
+ const Packet& pAlpha)
+{
+ const Scalar* rhs_ptr = rhs_base;
+ const Scalar* lhs_ptr0 = NULL, * lhs_ptr1 = NULL, * lhs_ptr2 = NULL, * lhs_ptr3 = NULL, * lhs_ptr4 = NULL, * lhs_ptr5 = NULL, * lhs_ptr6 = NULL, * lhs_ptr7 = NULL;
+ PacketBlock<Packet,4> accZero0, accZero1, accZero2, accZero3, accZero4, accZero5, accZero6, accZero7;
+ PacketBlock<Packet,4> acc;
+
+ MICRO_SRC_PTR
+ MICRO_DST_PTR
+
+ Index k = 0;
+ for(; k + PEEL <= depth; k+= PEEL)
+ {
+ EIGEN_POWER_PREFETCH(rhs_ptr);
+ MICRO_PREFETCH
+ MICRO_ONE_PEEL4
+ }
+ for(; k < depth; k++)
+ {
+ MICRO_ONE4
+ }
+ MICRO_STORE
+
+ row += unroll_factor*accCols;
+}
+
+template<int unroll_factor, typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accCols>
+EIGEN_STRONG_INLINE void gemm_unrolled_col_iteration(
+ const DataMapper& res,
+ const Scalar* lhs_base,
+ const Scalar* rhs_base,
+ Index depth,
+ Index strideA,
+ Index offsetA,
+ Index& row,
+ Index col,
+ Index remaining_cols,
+ const Packet& pAlpha)
+{
+ const Scalar* rhs_ptr = rhs_base;
+ const Scalar* lhs_ptr0 = NULL, * lhs_ptr1 = NULL, * lhs_ptr2 = NULL, * lhs_ptr3 = NULL, * lhs_ptr4 = NULL, * lhs_ptr5 = NULL, * lhs_ptr6 = NULL, *lhs_ptr7 = NULL;
+ PacketBlock<Packet,1> accZero0, accZero1, accZero2, accZero3, accZero4, accZero5, accZero6, accZero7;
+ PacketBlock<Packet,1> acc;
+
+ MICRO_SRC_PTR
+ MICRO_DST_PTR
+
+ Index k = 0;
+ for(; k + PEEL <= depth; k+= PEEL)
+ {
+ EIGEN_POWER_PREFETCH(rhs_ptr);
+ MICRO_PREFETCH
+ MICRO_ONE_PEEL1
+ }
+ for(; k < depth; k++)
+ {
+ MICRO_ONE1
+ }
+ MICRO_COL_STORE
+
+ row += unroll_factor*accCols;
+}
+
+template<typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accCols>
+EIGEN_STRONG_INLINE void gemm_unrolled_col(
+ const DataMapper& res,
+ const Scalar* lhs_base,
+ const Scalar* rhs_base,
+ Index depth,
+ Index strideA,
+ Index offsetA,
+ Index& row,
+ Index rows,
+ Index col,
+ Index remaining_cols,
+ const Packet& pAlpha)
+{
+#define MAX_UNROLL 6
+ while(row + MAX_UNROLL*accCols <= rows) {
+ gemm_unrolled_col_iteration<MAX_UNROLL, Scalar, Packet, DataMapper, Index, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_cols, pAlpha);
+ }
+ switch( (rows-row)/accCols ) {
+#if MAX_UNROLL > 7
+ case 7:
+ gemm_unrolled_col_iteration<7, Scalar, Packet, DataMapper, Index, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_cols, pAlpha);
+ break;
+#endif
+#if MAX_UNROLL > 6
+ case 6:
+ gemm_unrolled_col_iteration<6, Scalar, Packet, DataMapper, Index, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_cols, pAlpha);
+ break;
+#endif
+#if MAX_UNROLL > 5
+ case 5:
+ gemm_unrolled_col_iteration<5, Scalar, Packet, DataMapper, Index, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_cols, pAlpha);
+ break;
+#endif
+#if MAX_UNROLL > 4
+ case 4:
+ gemm_unrolled_col_iteration<4, Scalar, Packet, DataMapper, Index, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_cols, pAlpha);
+ break;
+#endif
+#if MAX_UNROLL > 3
+ case 3:
+ gemm_unrolled_col_iteration<3, Scalar, Packet, DataMapper, Index, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_cols, pAlpha);
+ break;
+#endif
+#if MAX_UNROLL > 2
+ case 2:
+ gemm_unrolled_col_iteration<2, Scalar, Packet, DataMapper, Index, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_cols, pAlpha);
+ break;
+#endif
+#if MAX_UNROLL > 1
+ case 1:
+ gemm_unrolled_col_iteration<1, Scalar, Packet, DataMapper, Index, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_cols, pAlpha);
+ break;
+#endif
+ default:
+ break;
+ }
+#undef MAX_UNROLL
+}
+
+/****************
+ * GEMM kernels *
+ * **************/
+template<typename Scalar, typename Index, typename Packet, typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols>
+EIGEN_STRONG_INLINE void gemm(const DataMapper& res, const Scalar* blockA, const Scalar* blockB, Index rows, Index depth, Index cols, Scalar alpha, Index strideA, Index strideB, Index offsetA, Index offsetB)
+{
+ const Index remaining_rows = rows % accCols;
+ const Index remaining_cols = cols % accRows;
+
+ if( strideA == -1 ) strideA = depth;
+ if( strideB == -1 ) strideB = depth;
+
+ const Packet pAlpha = pset1<Packet>(alpha);
+ const Packet pMask = bmask<Packet>((const int)(remaining_rows));
+
+ Index col = 0;
+ for(; col + accRows <= cols; col += accRows)
+ {
+ const Scalar* rhs_base = blockB + col*strideB + accRows*offsetB;
+ const Scalar* lhs_base = blockA;
+ Index row = 0;
+
+#define MAX_UNROLL 6
+ while(row + MAX_UNROLL*accCols <= rows) {
+ gemm_unrolled_iteration<MAX_UNROLL, Scalar, Packet, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
+ }
+ switch( (rows-row)/accCols ) {
+#if MAX_UNROLL > 7
+ case 7:
+ gemm_unrolled_iteration<7, Scalar, Packet, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
+ break;
+#endif
+#if MAX_UNROLL > 6
+ case 6:
+ gemm_unrolled_iteration<6, Scalar, Packet, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
+ break;
+#endif
+#if MAX_UNROLL > 5
+ case 5:
+ gemm_unrolled_iteration<5, Scalar, Packet, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
+ break;
+#endif
+#if MAX_UNROLL > 4
+ case 4:
+ gemm_unrolled_iteration<4, Scalar, Packet, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
+ break;
+#endif
+#if MAX_UNROLL > 3
+ case 3:
+ gemm_unrolled_iteration<3, Scalar, Packet, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
+ break;
+#endif
+#if MAX_UNROLL > 2
+ case 2:
+ gemm_unrolled_iteration<2, Scalar, Packet, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
+ break;
+#endif
+#if MAX_UNROLL > 1
+ case 1:
+ gemm_unrolled_iteration<1, Scalar, Packet, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
+ break;
+#endif
+ default:
+ break;
+ }
+#undef MAX_UNROLL
+
+ if(remaining_rows > 0)
+ {
+ gemm_extra_row<Scalar, Packet, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, rows, cols, remaining_rows, pAlpha, pMask);
+ }
+ }
+
+ if(remaining_cols > 0)
+ {
+ const Scalar* rhs_base = blockB + col*strideB + remaining_cols*offsetB;
+ const Scalar* lhs_base = blockA;
+
+ for(; col < cols; col++)
+ {
+ Index row = 0;
+
+ gemm_unrolled_col<Scalar, Packet, DataMapper, Index, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, rows, col, remaining_cols, pAlpha);
+
+ if (remaining_rows > 0)
+ {
+ gemm_extra_col<Scalar, Packet, DataMapper, Index, accRows>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_rows, remaining_cols, pAlpha);
+ }
+ rhs_base++;
+ }
+ }
+}
+
+#define accColsC (accCols / 2)
+#define advanceRows ((LhsIsReal) ? 1 : 2)
+#define advanceCols ((RhsIsReal) ? 1 : 2)
+
+// PEEL_COMPLEX loop factor.
+#define PEEL_COMPLEX 3
+
+template<typename Scalar, typename Packet, typename Index, const Index accRows, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
+EIGEN_ALWAYS_INLINE void MICRO_COMPLEX_EXTRA_COL(
+ const Scalar* &lhs_ptr_real, const Scalar* &lhs_ptr_imag,
+ const Scalar* &rhs_ptr_real, const Scalar* &rhs_ptr_imag,
+ PacketBlock<Packet,1> &accReal, PacketBlock<Packet,1> &accImag,
+ Index remaining_rows,
+ Index remaining_cols)
+{
+ Packet rhsV[1], rhsVi[1];
+ rhsV[0] = pset1<Packet>(rhs_ptr_real[0]);
+ if(!RhsIsReal) rhsVi[0] = pset1<Packet>(rhs_ptr_imag[0]);
+ pgerc<1, Scalar, Packet, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal, &accImag, lhs_ptr_real, lhs_ptr_imag, rhsV, rhsVi);
+ lhs_ptr_real += remaining_rows;
+ if(!LhsIsReal) lhs_ptr_imag += remaining_rows;
+ else EIGEN_UNUSED_VARIABLE(lhs_ptr_imag);
+ rhs_ptr_real += remaining_cols;
+ if(!RhsIsReal) rhs_ptr_imag += remaining_cols;
+ else EIGEN_UNUSED_VARIABLE(rhs_ptr_imag);
+}
+
+template<typename Scalar, typename Packet, typename Packetc, typename DataMapper, typename Index, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
+EIGEN_STRONG_INLINE void gemm_complex_extra_col(
+ const DataMapper& res,
+ const Scalar* lhs_base,
+ const Scalar* rhs_base,
+ Index depth,
+ Index strideA,
+ Index offsetA,
+ Index strideB,
+ Index row,
+ Index col,
+ Index remaining_rows,
+ Index remaining_cols,
+ const Packet& pAlphaReal,
+ const Packet& pAlphaImag)
+{
+ const Scalar* rhs_ptr_real = rhs_base;
+ const Scalar* rhs_ptr_imag;
+ if(!RhsIsReal) rhs_ptr_imag = rhs_base + remaining_cols*strideB;
+ else EIGEN_UNUSED_VARIABLE(rhs_ptr_imag);
+ const Scalar* lhs_ptr_real = lhs_base + advanceRows*row*strideA + remaining_rows*offsetA;
+ const Scalar* lhs_ptr_imag;
+ if(!LhsIsReal) lhs_ptr_imag = lhs_ptr_real + remaining_rows*strideA;
+ else EIGEN_UNUSED_VARIABLE(lhs_ptr_imag);
+ PacketBlock<Packet,1> accReal, accImag;
+ PacketBlock<Packet,1> taccReal, taccImag;
+ PacketBlock<Packetc,1> acc0, acc1;
+
+ bsetzero<Scalar, Packet>(accReal);
+ bsetzero<Scalar, Packet>(accImag);
+
+ Index remaining_depth = (depth & -accRows);
+ Index k = 0;
+ for(; k + PEEL_COMPLEX <= remaining_depth; k+= PEEL_COMPLEX)
+ {
+ EIGEN_POWER_PREFETCH(rhs_ptr_real);
+ if(!RhsIsReal) {
+ EIGEN_POWER_PREFETCH(rhs_ptr_imag);
+ }
+ EIGEN_POWER_PREFETCH(lhs_ptr_real);
+ if(!LhsIsReal) {
+ EIGEN_POWER_PREFETCH(lhs_ptr_imag);
+ }
+ for (int l = 0; l < PEEL_COMPLEX; l++) {
+ MICRO_COMPLEX_EXTRA_COL<Scalar, Packet, Index, accRows, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(lhs_ptr_real, lhs_ptr_imag, rhs_ptr_real, rhs_ptr_imag, accReal, accImag, remaining_rows, remaining_cols);
+ }
+ }
+ for(; k < remaining_depth; k++)
+ {
+ MICRO_COMPLEX_EXTRA_COL<Scalar, Packet, Index, accRows, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(lhs_ptr_real, lhs_ptr_imag, rhs_ptr_real, rhs_ptr_imag, accReal, accImag, remaining_rows, remaining_cols);
+ }
+
+ for(; k < depth; k++)
+ {
+ Packet rhsV[1], rhsVi[1];
+ rhsV[0] = pset1<Packet>(rhs_ptr_real[0]);
+ if(!RhsIsReal) rhsVi[0] = pset1<Packet>(rhs_ptr_imag[0]);
+ pgerc<1, Scalar, Packet, Index, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal, &accImag, lhs_ptr_real, lhs_ptr_imag, rhsV, rhsVi, remaining_rows);
+ lhs_ptr_real += remaining_rows;
+ if(!LhsIsReal) lhs_ptr_imag += remaining_rows;
+ rhs_ptr_real += remaining_cols;
+ if(!RhsIsReal) rhs_ptr_imag += remaining_cols;
+ }
+
+ bscalec<Packet,1>(accReal, accImag, pAlphaReal, pAlphaImag, taccReal, taccImag);
+ bcouple_common<Packet, Packetc>(taccReal, taccImag, acc0, acc1);
+
+ if ((sizeof(Scalar) == sizeof(float)) && (remaining_rows == 1))
+ {
+ res(row + 0, col + 0) += pfirst<Packetc>(acc0.packet[0]);
+ } else {
+ acc0.packet[0] += res.template loadPacket<Packetc>(row + 0, col + 0);
+ res.template storePacketBlock<Packetc,1>(row + 0, col + 0, acc0);
+ if(remaining_rows > accColsC) {
+ res(row + accColsC, col + 0) += pfirst<Packetc>(acc1.packet[0]);
+ }
+ }
+}
+
+template<typename Scalar, typename Packet, typename Index, const Index accRows, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
+EIGEN_ALWAYS_INLINE void MICRO_COMPLEX_EXTRA_ROW(
+ const Scalar* &lhs_ptr_real, const Scalar* &lhs_ptr_imag,
+ const Scalar* &rhs_ptr_real, const Scalar* &rhs_ptr_imag,
+ PacketBlock<Packet,4> &accReal, PacketBlock<Packet,4> &accImag,
+ Index remaining_rows)
+{
+ Packet rhsV[4], rhsVi[4];
+ pbroadcast4_old<Packet>(rhs_ptr_real, rhsV[0], rhsV[1], rhsV[2], rhsV[3]);
+ if(!RhsIsReal) pbroadcast4_old<Packet>(rhs_ptr_imag, rhsVi[0], rhsVi[1], rhsVi[2], rhsVi[3]);
+ pgerc<4, Scalar, Packet, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal, &accImag, lhs_ptr_real, lhs_ptr_imag, rhsV, rhsVi);
+ lhs_ptr_real += remaining_rows;
+ if(!LhsIsReal) lhs_ptr_imag += remaining_rows;
+ else EIGEN_UNUSED_VARIABLE(lhs_ptr_imag);
+ rhs_ptr_real += accRows;
+ if(!RhsIsReal) rhs_ptr_imag += accRows;
+ else EIGEN_UNUSED_VARIABLE(rhs_ptr_imag);
+}
+
+template<typename Scalar, typename Packet, typename Packetc, typename DataMapper, typename Index, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
+EIGEN_STRONG_INLINE void gemm_complex_extra_row(
+ const DataMapper& res,
+ const Scalar* lhs_base,
+ const Scalar* rhs_base,
+ Index depth,
+ Index strideA,
+ Index offsetA,
+ Index strideB,
+ Index row,
+ Index col,
+ Index rows,
+ Index cols,
+ Index remaining_rows,
+ const Packet& pAlphaReal,
+ const Packet& pAlphaImag,
+ const Packet& pMask)
+{
+ const Scalar* rhs_ptr_real = rhs_base;
+ const Scalar* rhs_ptr_imag;
+ if(!RhsIsReal) rhs_ptr_imag = rhs_base + accRows*strideB;
+ else EIGEN_UNUSED_VARIABLE(rhs_ptr_imag);
+ const Scalar* lhs_ptr_real = lhs_base + advanceRows*row*strideA + remaining_rows*offsetA;
+ const Scalar* lhs_ptr_imag;
+ if(!LhsIsReal) lhs_ptr_imag = lhs_ptr_real + remaining_rows*strideA;
+ else EIGEN_UNUSED_VARIABLE(lhs_ptr_imag);
+ PacketBlock<Packet,4> accReal, accImag;
+ PacketBlock<Packet,4> taccReal, taccImag;
+ PacketBlock<Packetc,4> acc0, acc1;
+ PacketBlock<Packetc,8> tRes;
+
+ bsetzero<Scalar, Packet>(accReal);
+ bsetzero<Scalar, Packet>(accImag);
+
+ Index remaining_depth = (col + accRows < cols) ? depth : (depth & -accRows);
+ Index k = 0;
+ for(; k + PEEL_COMPLEX <= remaining_depth; k+= PEEL_COMPLEX)
+ {
+ EIGEN_POWER_PREFETCH(rhs_ptr_real);
+ if(!RhsIsReal) {
+ EIGEN_POWER_PREFETCH(rhs_ptr_imag);
+ }
+ EIGEN_POWER_PREFETCH(lhs_ptr_real);
+ if(!LhsIsReal) {
+ EIGEN_POWER_PREFETCH(lhs_ptr_imag);
+ }
+ for (int l = 0; l < PEEL_COMPLEX; l++) {
+ MICRO_COMPLEX_EXTRA_ROW<Scalar, Packet, Index, accRows, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(lhs_ptr_real, lhs_ptr_imag, rhs_ptr_real, rhs_ptr_imag, accReal, accImag, remaining_rows);
+ }
+ }
+ for(; k < remaining_depth; k++)
+ {
+ MICRO_COMPLEX_EXTRA_ROW<Scalar, Packet, Index, accRows, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(lhs_ptr_real, lhs_ptr_imag, rhs_ptr_real, rhs_ptr_imag, accReal, accImag, remaining_rows);
+ }
+
+ if ((remaining_depth == depth) && (rows >= accCols))
+ {
+ bload<DataMapper, Packetc, Index, accColsC, 0, ColMajor>(tRes, res, row, col);
+ bscalec<Packet>(accReal, accImag, pAlphaReal, pAlphaImag, taccReal, taccImag, pMask);
+ bcouple<Packet, Packetc>(taccReal, taccImag, tRes, acc0, acc1);
+ res.template storePacketBlock<Packetc,4>(row + 0, col, acc0);
+ res.template storePacketBlock<Packetc,4>(row + accColsC, col, acc1);
+ } else {
+ for(; k < depth; k++)
+ {
+ Packet rhsV[4], rhsVi[4];
+ pbroadcast4_old<Packet>(rhs_ptr_real, rhsV[0], rhsV[1], rhsV[2], rhsV[3]);
+ if(!RhsIsReal) pbroadcast4_old<Packet>(rhs_ptr_imag, rhsVi[0], rhsVi[1], rhsVi[2], rhsVi[3]);
+ pgerc<4, Scalar, Packet, Index, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal, &accImag, lhs_ptr_real, lhs_ptr_imag, rhsV, rhsVi, remaining_rows);
+ lhs_ptr_real += remaining_rows;
+ if(!LhsIsReal) lhs_ptr_imag += remaining_rows;
+ rhs_ptr_real += accRows;
+ if(!RhsIsReal) rhs_ptr_imag += accRows;
+ }
+
+ bscalec<Packet,4>(accReal, accImag, pAlphaReal, pAlphaImag, taccReal, taccImag);
+ bcouple_common<Packet, Packetc>(taccReal, taccImag, acc0, acc1);
+
+ if ((sizeof(Scalar) == sizeof(float)) && (remaining_rows == 1))
+ {
+ for(Index j = 0; j < 4; j++) {
+ res(row + 0, col + j) += pfirst<Packetc>(acc0.packet[j]);
+ }
+ } else {
+ for(Index j = 0; j < 4; j++) {
+ PacketBlock<Packetc,1> acc2;
+ acc2.packet[0] = res.template loadPacket<Packetc>(row + 0, col + j) + acc0.packet[j];
+ res.template storePacketBlock<Packetc,1>(row + 0, col + j, acc2);
+ if(remaining_rows > accColsC) {
+ res(row + accColsC, col + j) += pfirst<Packetc>(acc1.packet[j]);
+ }
+ }
+ }
+ }
+}
+
+#define MICRO_COMPLEX_UNROLL(func) \
+ func(0) func(1) func(2) func(3) func(4)
+
+#define MICRO_COMPLEX_UNROLL_WORK(func, func2, peel) \
+ MICRO_COMPLEX_UNROLL(func2); \
+ func(0,peel) func(1,peel) func(2,peel) func(3,peel) func(4,peel)
+
+#define MICRO_COMPLEX_LOAD_ONE(iter) \
+ if (unroll_factor > iter) { \
+ lhsV##iter = ploadLhs<Scalar, Packet>(lhs_ptr_real##iter); \
+ lhs_ptr_real##iter += accCols; \
+ if(!LhsIsReal) { \
+ lhsVi##iter = ploadLhs<Scalar, Packet>(lhs_ptr_imag##iter); \
+ lhs_ptr_imag##iter += accCols; \
+ } else { \
+ EIGEN_UNUSED_VARIABLE(lhsVi##iter); \
+ } \
+ } else { \
+ EIGEN_UNUSED_VARIABLE(lhsV##iter); \
+ EIGEN_UNUSED_VARIABLE(lhsVi##iter); \
+ }
+
+#define MICRO_COMPLEX_WORK_ONE4(iter, peel) \
+ if (unroll_factor > iter) { \
+ pgerc_common<4, Packet, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal##iter, &accImag##iter, lhsV##iter, lhsVi##iter, rhsV##peel, rhsVi##peel); \
+ }
+
+#define MICRO_COMPLEX_WORK_ONE1(iter, peel) \
+ if (unroll_factor > iter) { \
+ pgerc_common<1, Packet, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal##iter, &accImag##iter, lhsV##iter, lhsVi##iter, rhsV##peel, rhsVi##peel); \
+ }
+
+#define MICRO_COMPLEX_TYPE_PEEL4(func, func2, peel) \
+ if (PEEL_COMPLEX > peel) { \
+ Packet lhsV0, lhsV1, lhsV2, lhsV3, lhsV4; \
+ Packet lhsVi0, lhsVi1, lhsVi2, lhsVi3, lhsVi4; \
+ pbroadcast4_old<Packet>(rhs_ptr_real + (accRows * peel), rhsV##peel[0], rhsV##peel[1], rhsV##peel[2], rhsV##peel[3]); \
+ if(!RhsIsReal) { \
+ pbroadcast4_old<Packet>(rhs_ptr_imag + (accRows * peel), rhsVi##peel[0], rhsVi##peel[1], rhsVi##peel[2], rhsVi##peel[3]); \
+ } else { \
+ EIGEN_UNUSED_VARIABLE(rhsVi##peel); \
+ } \
+ MICRO_COMPLEX_UNROLL_WORK(func, func2, peel) \
+ } else { \
+ EIGEN_UNUSED_VARIABLE(rhsV##peel); \
+ EIGEN_UNUSED_VARIABLE(rhsVi##peel); \
+ }
+
+#define MICRO_COMPLEX_TYPE_PEEL1(func, func2, peel) \
+ if (PEEL_COMPLEX > peel) { \
+ Packet lhsV0, lhsV1, lhsV2, lhsV3, lhsV4; \
+ Packet lhsVi0, lhsVi1, lhsVi2, lhsVi3, lhsVi4; \
+ rhsV##peel[0] = pset1<Packet>(rhs_ptr_real[remaining_cols * peel]); \
+ if(!RhsIsReal) { \
+ rhsVi##peel[0] = pset1<Packet>(rhs_ptr_imag[remaining_cols * peel]); \
+ } else { \
+ EIGEN_UNUSED_VARIABLE(rhsVi##peel); \
+ } \
+ MICRO_COMPLEX_UNROLL_WORK(func, func2, peel) \
+ } else { \
+ EIGEN_UNUSED_VARIABLE(rhsV##peel); \
+ EIGEN_UNUSED_VARIABLE(rhsVi##peel); \
+ }
+
+#define MICRO_COMPLEX_UNROLL_TYPE_PEEL(M, func, func1, func2) \
+ Packet rhsV0[M], rhsV1[M], rhsV2[M], rhsV3[M], rhsV4[M], rhsV5[M], rhsV6[M], rhsV7[M], rhsV8[M], rhsV9[M]; \
+ Packet rhsVi0[M], rhsVi1[M], rhsVi2[M], rhsVi3[M], rhsVi4[M], rhsVi5[M], rhsVi6[M], rhsVi7[M], rhsVi8[M], rhsVi9[M]; \
+ func(func1,func2,0); func(func1,func2,1); \
+ func(func1,func2,2); func(func1,func2,3); \
+ func(func1,func2,4); func(func1,func2,5); \
+ func(func1,func2,6); func(func1,func2,7); \
+ func(func1,func2,8); func(func1,func2,9);
+
+#define MICRO_COMPLEX_UNROLL_TYPE_ONE(M, func, func1, func2) \
+ Packet rhsV0[M], rhsVi0[M];\
+ func(func1,func2,0);
+
+#define MICRO_COMPLEX_ONE_PEEL4 \
+ MICRO_COMPLEX_UNROLL_TYPE_PEEL(4, MICRO_COMPLEX_TYPE_PEEL4, MICRO_COMPLEX_WORK_ONE4, MICRO_COMPLEX_LOAD_ONE); \
+ rhs_ptr_real += (accRows * PEEL_COMPLEX); \
+ if(!RhsIsReal) rhs_ptr_imag += (accRows * PEEL_COMPLEX);
+
+#define MICRO_COMPLEX_ONE4 \
+ MICRO_COMPLEX_UNROLL_TYPE_ONE(4, MICRO_COMPLEX_TYPE_PEEL4, MICRO_COMPLEX_WORK_ONE4, MICRO_COMPLEX_LOAD_ONE); \
+ rhs_ptr_real += accRows; \
+ if(!RhsIsReal) rhs_ptr_imag += accRows;
+
+#define MICRO_COMPLEX_ONE_PEEL1 \
+ MICRO_COMPLEX_UNROLL_TYPE_PEEL(1, MICRO_COMPLEX_TYPE_PEEL1, MICRO_COMPLEX_WORK_ONE1, MICRO_COMPLEX_LOAD_ONE); \
+ rhs_ptr_real += (remaining_cols * PEEL_COMPLEX); \
+ if(!RhsIsReal) rhs_ptr_imag += (remaining_cols * PEEL_COMPLEX);
+
+#define MICRO_COMPLEX_ONE1 \
+ MICRO_COMPLEX_UNROLL_TYPE_ONE(1, MICRO_COMPLEX_TYPE_PEEL1, MICRO_COMPLEX_WORK_ONE1, MICRO_COMPLEX_LOAD_ONE); \
+ rhs_ptr_real += remaining_cols; \
+ if(!RhsIsReal) rhs_ptr_imag += remaining_cols;
+
+#define MICRO_COMPLEX_DST_PTR_ONE(iter) \
+ if (unroll_factor > iter) { \
+ bsetzero<Scalar, Packet>(accReal##iter); \
+ bsetzero<Scalar, Packet>(accImag##iter); \
+ } else { \
+ EIGEN_UNUSED_VARIABLE(accReal##iter); \
+ EIGEN_UNUSED_VARIABLE(accImag##iter); \
+ }
+
+#define MICRO_COMPLEX_DST_PTR MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_DST_PTR_ONE)
+
+#define MICRO_COMPLEX_SRC_PTR_ONE(iter) \
+ if (unroll_factor > iter) { \
+ lhs_ptr_real##iter = lhs_base + ( ((advanceRows*row)/accCols) + iter*advanceRows )*strideA*accCols + accCols*offsetA; \
+ if(!LhsIsReal) { \
+ lhs_ptr_imag##iter = lhs_ptr_real##iter + accCols*strideA; \
+ } else { \
+ EIGEN_UNUSED_VARIABLE(lhs_ptr_imag##iter); \
+ } \
+ } else { \
+ EIGEN_UNUSED_VARIABLE(lhs_ptr_real##iter); \
+ EIGEN_UNUSED_VARIABLE(lhs_ptr_imag##iter); \
+ }
+
+#define MICRO_COMPLEX_SRC_PTR MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_SRC_PTR_ONE)
+
+#define MICRO_COMPLEX_PREFETCH_ONE(iter) \
+ if (unroll_factor > iter) { \
+ EIGEN_POWER_PREFETCH(lhs_ptr_real##iter); \
+ if(!LhsIsReal) { \
+ EIGEN_POWER_PREFETCH(lhs_ptr_imag##iter); \
+ } \
+ }
+
+#define MICRO_COMPLEX_PREFETCH MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_PREFETCH_ONE)
+
+#define MICRO_COMPLEX_STORE_ONE(iter) \
+ if (unroll_factor > iter) { \
+ bload<DataMapper, Packetc, Index, accColsC, 0, ColMajor>(tRes, res, row + iter*accCols, col); \
+ bscalec<Packet,4>(accReal##iter, accImag##iter, pAlphaReal, pAlphaImag, taccReal, taccImag); \
+ bcouple<Packet, Packetc>(taccReal, taccImag, tRes, acc0, acc1); \
+ res.template storePacketBlock<Packetc,4>(row + iter*accCols + 0, col, acc0); \
+ res.template storePacketBlock<Packetc,4>(row + iter*accCols + accColsC, col, acc1); \
+ }
+
+#define MICRO_COMPLEX_STORE MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_STORE_ONE)
+
+#define MICRO_COMPLEX_COL_STORE_ONE(iter) \
+ if (unroll_factor > iter) { \
+ bload<DataMapper, Packetc, Index, accColsC, 0, ColMajor>(tRes, res, row + iter*accCols, col); \
+ bscalec<Packet,1>(accReal##iter, accImag##iter, pAlphaReal, pAlphaImag, taccReal, taccImag); \
+ bcouple<Packet, Packetc>(taccReal, taccImag, tRes, acc0, acc1); \
+ res.template storePacketBlock<Packetc,1>(row + iter*accCols + 0, col, acc0); \
+ res.template storePacketBlock<Packetc,1>(row + iter*accCols + accColsC, col, acc1); \
+ }
+
+#define MICRO_COMPLEX_COL_STORE MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_COL_STORE_ONE)
+
+template<int unroll_factor, typename Scalar, typename Packet, typename Packetc, typename DataMapper, typename Index, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
+EIGEN_STRONG_INLINE void gemm_complex_unrolled_iteration(
+ const DataMapper& res,
+ const Scalar* lhs_base,
+ const Scalar* rhs_base,
+ Index depth,
+ Index strideA,
+ Index offsetA,
+ Index strideB,
+ Index& row,
+ Index col,
+ const Packet& pAlphaReal,
+ const Packet& pAlphaImag)
+{
+ const Scalar* rhs_ptr_real = rhs_base;
+ const Scalar* rhs_ptr_imag;
+ if(!RhsIsReal) {
+ rhs_ptr_imag = rhs_base + accRows*strideB;
+ } else {
+ EIGEN_UNUSED_VARIABLE(rhs_ptr_imag);
+ }
+ const Scalar* lhs_ptr_real0 = NULL, * lhs_ptr_imag0 = NULL, * lhs_ptr_real1 = NULL, * lhs_ptr_imag1 = NULL;
+ const Scalar* lhs_ptr_real2 = NULL, * lhs_ptr_imag2 = NULL, * lhs_ptr_real3 = NULL, * lhs_ptr_imag3 = NULL;
+ const Scalar* lhs_ptr_real4 = NULL, * lhs_ptr_imag4 = NULL;
+ PacketBlock<Packet,4> accReal0, accImag0, accReal1, accImag1;
+ PacketBlock<Packet,4> accReal2, accImag2, accReal3, accImag3;
+ PacketBlock<Packet,4> accReal4, accImag4;
+ PacketBlock<Packet,4> taccReal, taccImag;
+ PacketBlock<Packetc,4> acc0, acc1;
+ PacketBlock<Packetc,8> tRes;
+
+ MICRO_COMPLEX_SRC_PTR
+ MICRO_COMPLEX_DST_PTR
+
+ Index k = 0;
+ for(; k + PEEL_COMPLEX <= depth; k+= PEEL_COMPLEX)
+ {
+ EIGEN_POWER_PREFETCH(rhs_ptr_real);
+ if(!RhsIsReal) {
+ EIGEN_POWER_PREFETCH(rhs_ptr_imag);
+ }
+ MICRO_COMPLEX_PREFETCH
+ MICRO_COMPLEX_ONE_PEEL4
+ }
+ for(; k < depth; k++)
+ {
+ MICRO_COMPLEX_ONE4
+ }
+ MICRO_COMPLEX_STORE
+
+ row += unroll_factor*accCols;
+}
+
+template<int unroll_factor, typename Scalar, typename Packet, typename Packetc, typename DataMapper, typename Index, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
+EIGEN_STRONG_INLINE void gemm_complex_unrolled_col_iteration(
+ const DataMapper& res,
+ const Scalar* lhs_base,
+ const Scalar* rhs_base,
+ Index depth,
+ Index strideA,
+ Index offsetA,
+ Index strideB,
+ Index& row,
+ Index col,
+ Index remaining_cols,
+ const Packet& pAlphaReal,
+ const Packet& pAlphaImag)
+{
+ const Scalar* rhs_ptr_real = rhs_base;
+ const Scalar* rhs_ptr_imag;
+ if(!RhsIsReal) {
+ rhs_ptr_imag = rhs_base + remaining_cols*strideB;
+ } else {
+ EIGEN_UNUSED_VARIABLE(rhs_ptr_imag);
+ }
+ const Scalar* lhs_ptr_real0 = NULL, * lhs_ptr_imag0 = NULL, * lhs_ptr_real1 = NULL, * lhs_ptr_imag1 = NULL;
+ const Scalar* lhs_ptr_real2 = NULL, * lhs_ptr_imag2 = NULL, * lhs_ptr_real3 = NULL, * lhs_ptr_imag3 = NULL;
+ const Scalar* lhs_ptr_real4 = NULL, * lhs_ptr_imag4 = NULL;
+ PacketBlock<Packet,1> accReal0, accImag0, accReal1, accImag1;
+ PacketBlock<Packet,1> accReal2, accImag2, accReal3, accImag3;
+ PacketBlock<Packet,1> accReal4, accImag4;
+ PacketBlock<Packet,1> taccReal, taccImag;
+ PacketBlock<Packetc,1> acc0, acc1;
+ PacketBlock<Packetc,2> tRes;
+
+ MICRO_COMPLEX_SRC_PTR
+ MICRO_COMPLEX_DST_PTR
+
+ Index k = 0;
+ for(; k + PEEL_COMPLEX <= depth; k+= PEEL_COMPLEX)
+ {
+ EIGEN_POWER_PREFETCH(rhs_ptr_real);
+ if(!RhsIsReal) {
+ EIGEN_POWER_PREFETCH(rhs_ptr_imag);
+ }
+ MICRO_COMPLEX_PREFETCH
+ MICRO_COMPLEX_ONE_PEEL1
+ }
+ for(; k < depth; k++)
+ {
+ MICRO_COMPLEX_ONE1
+ }
+ MICRO_COMPLEX_COL_STORE
+
+ row += unroll_factor*accCols;
+}
+
+template<typename Scalar, typename Packet, typename Packetc, typename DataMapper, typename Index, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
+EIGEN_STRONG_INLINE void gemm_complex_unrolled_col(
+ const DataMapper& res,
+ const Scalar* lhs_base,
+ const Scalar* rhs_base,
+ Index depth,
+ Index strideA,
+ Index offsetA,
+ Index strideB,
+ Index& row,
+ Index rows,
+ Index col,
+ Index remaining_cols,
+ const Packet& pAlphaReal,
+ const Packet& pAlphaImag)
+{
+#define MAX_COMPLEX_UNROLL 3
+ while(row + MAX_COMPLEX_UNROLL*accCols <= rows) {
+ gemm_complex_unrolled_col_iteration<MAX_COMPLEX_UNROLL, Scalar, Packet, Packetc, DataMapper, Index, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, remaining_cols, pAlphaReal, pAlphaImag);
+ }
+ switch( (rows-row)/accCols ) {
+#if MAX_COMPLEX_UNROLL > 4
+ case 4:
+ gemm_complex_unrolled_col_iteration<4, Scalar, Packet, Packetc, DataMapper, Index, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, remaining_cols, pAlphaReal, pAlphaImag);
+ break;
+#endif
+#if MAX_COMPLEX_UNROLL > 3
+ case 3:
+ gemm_complex_unrolled_col_iteration<3, Scalar, Packet, Packetc, DataMapper, Index, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, remaining_cols, pAlphaReal, pAlphaImag);
+ break;
+#endif
+#if MAX_COMPLEX_UNROLL > 2
+ case 2:
+ gemm_complex_unrolled_col_iteration<2, Scalar, Packet, Packetc, DataMapper, Index, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, remaining_cols, pAlphaReal, pAlphaImag);
+ break;
+#endif
+#if MAX_COMPLEX_UNROLL > 1
+ case 1:
+ gemm_complex_unrolled_col_iteration<1, Scalar, Packet, Packetc, DataMapper, Index, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, remaining_cols, pAlphaReal, pAlphaImag);
+ break;
+#endif
+ default:
+ break;
+ }
+#undef MAX_COMPLEX_UNROLL
+}
+
+template<typename LhsScalar, typename RhsScalar, typename Scalarc, typename Scalar, typename Index, typename Packet, typename Packetc, typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
+EIGEN_STRONG_INLINE void gemm_complex(const DataMapper& res, const LhsScalar* blockAc, const RhsScalar* blockBc, Index rows, Index depth, Index cols, Scalarc alpha, Index strideA, Index strideB, Index offsetA, Index offsetB)
+{
+ const Index remaining_rows = rows % accCols;
+ const Index remaining_cols = cols % accRows;
+
+ if( strideA == -1 ) strideA = depth;
+ if( strideB == -1 ) strideB = depth;
+
+ const Packet pAlphaReal = pset1<Packet>(alpha.real());
+ const Packet pAlphaImag = pset1<Packet>(alpha.imag());
+ const Packet pMask = bmask<Packet>((const int)(remaining_rows));
+
+ const Scalar* blockA = (Scalar *) blockAc;
+ const Scalar* blockB = (Scalar *) blockBc;
+
+ Index col = 0;
+ for(; col + accRows <= cols; col += accRows)
+ {
+ const Scalar* rhs_base = blockB + advanceCols*col*strideB + accRows*offsetB;
+ const Scalar* lhs_base = blockA;
+ Index row = 0;
+
+#define MAX_COMPLEX_UNROLL 3
+ while(row + MAX_COMPLEX_UNROLL*accCols <= rows) {
+ gemm_complex_unrolled_iteration<MAX_COMPLEX_UNROLL, Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag);
+ }
+ switch( (rows-row)/accCols ) {
+#if MAX_COMPLEX_UNROLL > 4
+ case 4:
+ gemm_complex_unrolled_iteration<4, Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag);
+ break;
+#endif
+#if MAX_COMPLEX_UNROLL > 3
+ case 3:
+ gemm_complex_unrolled_iteration<3, Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag);
+ break;
+#endif
+#if MAX_COMPLEX_UNROLL > 2
+ case 2:
+ gemm_complex_unrolled_iteration<2, Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag);
+ break;
+#endif
+#if MAX_COMPLEX_UNROLL > 1
+ case 1:
+ gemm_complex_unrolled_iteration<1, Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag);
+ break;
+#endif
+ default:
+ break;
+ }
+#undef MAX_COMPLEX_UNROLL
+
+ if(remaining_rows > 0)
+ {
+ gemm_complex_extra_row<Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, rows, cols, remaining_rows, pAlphaReal, pAlphaImag, pMask);
+ }
+ }
+
+ if(remaining_cols > 0)
+ {
+ const Scalar* rhs_base = blockB + advanceCols*col*strideB + remaining_cols*offsetB;
+ const Scalar* lhs_base = blockA;
+
+ for(; col < cols; col++)
+ {
+ Index row = 0;
+
+ gemm_complex_unrolled_col<Scalar, Packet, Packetc, DataMapper, Index, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, rows, col, remaining_cols, pAlphaReal, pAlphaImag);
+
+ if (remaining_rows > 0)
+ {
+ gemm_complex_extra_col<Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, remaining_rows, remaining_cols, pAlphaReal, pAlphaImag);
+ }
+ rhs_base++;
+ }
+ }
+}
+
+#undef accColsC
+#undef advanceCols
+#undef advanceRows
+
+/************************************
+ * ppc64le template specializations *
+ * **********************************/
+template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
+struct gemm_pack_lhs<double, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode>
+{
+ void operator()(double* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride=0, Index offset=0);
+};
+
+template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
+void gemm_pack_lhs<double, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode>
+ ::operator()(double* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
+{
+ dhs_pack<double, Index, DataMapper, Packet2d, ColMajor, PanelMode, true> pack;
+ pack(blockA, lhs, depth, rows, stride, offset);
+}
+
+template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
+struct gemm_pack_lhs<double, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode>
+{
+ void operator()(double* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride=0, Index offset=0);
+};
+
+template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
+void gemm_pack_lhs<double, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode>
+ ::operator()(double* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
+{
+ dhs_pack<double, Index, DataMapper, Packet2d, RowMajor, PanelMode, true> pack;
+ pack(blockA, lhs, depth, rows, stride, offset);
+}
+
+#if EIGEN_ALTIVEC_USE_CUSTOM_PACK
+template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
+struct gemm_pack_rhs<double, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>
+{
+ void operator()(double* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0);
+};
+
+template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
+void gemm_pack_rhs<double, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>
+ ::operator()(double* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
+{
+ dhs_pack<double, Index, DataMapper, Packet2d, ColMajor, PanelMode, false> pack;
+ pack(blockB, rhs, depth, cols, stride, offset);
+}
+
+template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
+struct gemm_pack_rhs<double, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>
+{
+ void operator()(double* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0);
+};
+
+template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
+void gemm_pack_rhs<double, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>
+ ::operator()(double* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
+{
+ dhs_pack<double, Index, DataMapper, Packet2d, RowMajor, PanelMode, false> pack;
+ pack(blockB, rhs, depth, cols, stride, offset);
+}
+#endif
+
+template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
+struct gemm_pack_lhs<float, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode>
+{
+ void operator()(float* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride=0, Index offset=0);
+};
+
+template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
+void gemm_pack_lhs<float, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode>
+ ::operator()(float* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
+{
+ dhs_pack<float, Index, DataMapper, Packet4f, RowMajor, PanelMode, true> pack;
+ pack(blockA, lhs, depth, rows, stride, offset);
+}
+
+template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
+struct gemm_pack_lhs<float, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode>
+{
+ void operator()(float* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride=0, Index offset=0);
+};
+
+template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
+void gemm_pack_lhs<float, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode>
+ ::operator()(float* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
+{
+ dhs_pack<float, Index, DataMapper, Packet4f, ColMajor, PanelMode, true> pack;
+ pack(blockA, lhs, depth, rows, stride, offset);
+}
+
+template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
+struct gemm_pack_lhs<std::complex<float>, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode>
+{
+ void operator()(std::complex<float>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride=0, Index offset=0);
+};
+
+template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
+void gemm_pack_lhs<std::complex<float>, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode>
+ ::operator()(std::complex<float>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
+{
+ dhs_cpack<float, Index, DataMapper, Packet4f, Packet2cf, RowMajor, Conjugate, PanelMode, true> pack;
+ pack(blockA, lhs, depth, rows, stride, offset);
+}
+
+template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
+struct gemm_pack_lhs<std::complex<float>, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode>
+{
+ void operator()(std::complex<float>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride=0, Index offset=0);
+};
+
+template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
+void gemm_pack_lhs<std::complex<float>, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode>
+ ::operator()(std::complex<float>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
+{
+ dhs_cpack<float, Index, DataMapper, Packet4f, Packet2cf, ColMajor, Conjugate, PanelMode, true> pack;
+ pack(blockA, lhs, depth, rows, stride, offset);
+}
+
+#if EIGEN_ALTIVEC_USE_CUSTOM_PACK
+template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
+struct gemm_pack_rhs<float, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>
+{
+ void operator()(float* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0);
+};
+
+template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
+void gemm_pack_rhs<float, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>
+ ::operator()(float* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
+{
+ dhs_pack<float, Index, DataMapper, Packet4f, ColMajor, PanelMode, false> pack;
+ pack(blockB, rhs, depth, cols, stride, offset);
+}
+
+template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
+struct gemm_pack_rhs<float, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>
+{
+ void operator()(float* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0);
+};
+
+template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
+void gemm_pack_rhs<float, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>
+ ::operator()(float* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
+{
+ dhs_pack<float, Index, DataMapper, Packet4f, RowMajor, PanelMode, false> pack;
+ pack(blockB, rhs, depth, cols, stride, offset);
+}
+#endif
+
+template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
+struct gemm_pack_rhs<std::complex<float>, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>
+{
+ void operator()(std::complex<float>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0);
+};
+
+template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
+void gemm_pack_rhs<std::complex<float>, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>
+ ::operator()(std::complex<float>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
+{
+ dhs_cpack<float, Index, DataMapper, Packet4f, Packet2cf, ColMajor, Conjugate, PanelMode, false> pack;
+ pack(blockB, rhs, depth, cols, stride, offset);
+}
+
+template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
+struct gemm_pack_rhs<std::complex<float>, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>
+{
+ void operator()(std::complex<float>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0);
+};
+
+template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
+void gemm_pack_rhs<std::complex<float>, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>
+ ::operator()(std::complex<float>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
+{
+ dhs_cpack<float, Index, DataMapper, Packet4f, Packet2cf, RowMajor, Conjugate, PanelMode, false> pack;
+ pack(blockB, rhs, depth, cols, stride, offset);
+}
+
+template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
+struct gemm_pack_lhs<std::complex<double>, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode>
+{
+ void operator()(std::complex<double>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride=0, Index offset=0);
+};
+
+template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
+void gemm_pack_lhs<std::complex<double>, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode>
+ ::operator()(std::complex<double>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
+{
+ dhs_cpack<double, Index, DataMapper, Packet2d, Packet1cd, RowMajor, Conjugate, PanelMode, true> pack;
+ pack(blockA, lhs, depth, rows, stride, offset);
+}
+
+template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
+struct gemm_pack_lhs<std::complex<double>, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode>
+{
+ void operator()(std::complex<double>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride=0, Index offset=0);
+};
+
+template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
+void gemm_pack_lhs<std::complex<double>, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode>
+ ::operator()(std::complex<double>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
+{
+ dhs_cpack<double, Index, DataMapper, Packet2d, Packet1cd, ColMajor, Conjugate, PanelMode, true> pack;
+ pack(blockA, lhs, depth, rows, stride, offset);
+}
+
+template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
+struct gemm_pack_rhs<std::complex<double>, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>
+{
+ void operator()(std::complex<double>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0);
+};
+
+template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
+void gemm_pack_rhs<std::complex<double>, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>
+ ::operator()(std::complex<double>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
+{
+ dhs_cpack<double, Index, DataMapper, Packet2d, Packet1cd, ColMajor, Conjugate, PanelMode, false> pack;
+ pack(blockB, rhs, depth, cols, stride, offset);
+}
+
+template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
+struct gemm_pack_rhs<std::complex<double>, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>
+{
+ void operator()(std::complex<double>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0);
+};
+
+template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
+void gemm_pack_rhs<std::complex<double>, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>
+ ::operator()(std::complex<double>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
+{
+ dhs_cpack<double, Index, DataMapper, Packet2d, Packet1cd, RowMajor, Conjugate, PanelMode, false> pack;
+ pack(blockB, rhs, depth, cols, stride, offset);
+}
+
+// ********* gebp specializations *********
+template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
+struct gebp_kernel<float, float, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
+{
+ typedef typename quad_traits<float>::vectortype Packet;
+ typedef typename quad_traits<float>::rhstype RhsPacket;
+
+ void operator()(const DataMapper& res, const float* blockA, const float* blockB,
+ Index rows, Index depth, Index cols, float alpha,
+ Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0);
+};
+
+template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
+void gebp_kernel<float, float, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
+ ::operator()(const DataMapper& res, const float* blockA, const float* blockB,
+ Index rows, Index depth, Index cols, float alpha,
+ Index strideA, Index strideB, Index offsetA, Index offsetB)
+ {
+ const Index accRows = quad_traits<float>::rows;
+ const Index accCols = quad_traits<float>::size;
+ void (*gemm_function)(const DataMapper&, const float*, const float*, Index, Index, Index, float, Index, Index, Index, Index);
+
+ #ifdef EIGEN_ALTIVEC_MMA_ONLY
+ //generate with MMA only
+ gemm_function = &Eigen::internal::gemmMMA<float, Index, Packet, RhsPacket, DataMapper, accRows, accCols>;
+ #elif defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA)
+ if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){
+ gemm_function = &Eigen::internal::gemmMMA<float, Index, Packet, RhsPacket, DataMapper, accRows, accCols>;
+ }
+ else{
+ gemm_function = &Eigen::internal::gemm<float, Index, Packet, RhsPacket, DataMapper, accRows, accCols>;
+ }
+ #else
+ gemm_function = &Eigen::internal::gemm<float, Index, Packet, RhsPacket, DataMapper, accRows, accCols>;
+ #endif
+ gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
+ }
+
+template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
+struct gebp_kernel<std::complex<float>, std::complex<float>, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
+{
+ typedef Packet4f Packet;
+ typedef Packet2cf Packetc;
+ typedef Packet4f RhsPacket;
+
+ void operator()(const DataMapper& res, const std::complex<float>* blockA, const std::complex<float>* blockB,
+ Index rows, Index depth, Index cols, std::complex<float> alpha,
+ Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0);
+};
+
+template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
+void gebp_kernel<std::complex<float>, std::complex<float>, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
+ ::operator()(const DataMapper& res, const std::complex<float>* blockA, const std::complex<float>* blockB,
+ Index rows, Index depth, Index cols, std::complex<float> alpha,
+ Index strideA, Index strideB, Index offsetA, Index offsetB)
+ {
+ const Index accRows = quad_traits<float>::rows;
+ const Index accCols = quad_traits<float>::size;
+ void (*gemm_function)(const DataMapper&, const std::complex<float>*, const std::complex<float>*,
+ Index, Index, Index, std::complex<float>, Index, Index, Index, Index);
+
+ #ifdef EIGEN_ALTIVEC_MMA_ONLY
+ //generate with MMA only
+ gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<float>, std::complex<float>, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>;
+ #elif defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA)
+ if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){
+ gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<float>, std::complex<float>, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>;
+ }
+ else{
+ gemm_function = &Eigen::internal::gemm_complex<std::complex<float>, std::complex<float>, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>;
+ }
+ #else
+ gemm_function = &Eigen::internal::gemm_complex<std::complex<float>, std::complex<float>, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>;
+ #endif
+ gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
+ }
+
+template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
+struct gebp_kernel<float, std::complex<float>, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
+{
+ typedef Packet4f Packet;
+ typedef Packet2cf Packetc;
+ typedef Packet4f RhsPacket;
+
+ void operator()(const DataMapper& res, const float* blockA, const std::complex<float>* blockB,
+ Index rows, Index depth, Index cols, std::complex<float> alpha,
+ Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0);
+};
+
+template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
+void gebp_kernel<float, std::complex<float>, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
+ ::operator()(const DataMapper& res, const float* blockA, const std::complex<float>* blockB,
+ Index rows, Index depth, Index cols, std::complex<float> alpha,
+ Index strideA, Index strideB, Index offsetA, Index offsetB)
+ {
+ const Index accRows = quad_traits<float>::rows;
+ const Index accCols = quad_traits<float>::size;
+ void (*gemm_function)(const DataMapper&, const float*, const std::complex<float>*,
+ Index, Index, Index, std::complex<float>, Index, Index, Index, Index);
+ #ifdef EIGEN_ALTIVEC_MMA_ONLY
+ //generate with MMA only
+ gemm_function = &Eigen::internal::gemm_complexMMA<float, std::complex<float>, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>;
+ #elif defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA)
+ if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){
+ gemm_function = &Eigen::internal::gemm_complexMMA<float, std::complex<float>, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>;
+ }
+ else{
+ gemm_function = &Eigen::internal::gemm_complex<float, std::complex<float>, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>;
+ }
+ #else
+ gemm_function = &Eigen::internal::gemm_complex<float, std::complex<float>, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>;
+ #endif
+ gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
+ }
+
+template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
+struct gebp_kernel<std::complex<float>, float, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
+{
+ typedef Packet4f Packet;
+ typedef Packet2cf Packetc;
+ typedef Packet4f RhsPacket;
+
+ void operator()(const DataMapper& res, const std::complex<float>* blockA, const float* blockB,
+ Index rows, Index depth, Index cols, std::complex<float> alpha,
+ Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0);
+};
+
+template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
+void gebp_kernel<std::complex<float>, float, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
+ ::operator()(const DataMapper& res, const std::complex<float>* blockA, const float* blockB,
+ Index rows, Index depth, Index cols, std::complex<float> alpha,
+ Index strideA, Index strideB, Index offsetA, Index offsetB)
+ {
+ const Index accRows = quad_traits<float>::rows;
+ const Index accCols = quad_traits<float>::size;
+ void (*gemm_function)(const DataMapper&, const std::complex<float>*, const float*,
+ Index, Index, Index, std::complex<float>, Index, Index, Index, Index);
+ #ifdef EIGEN_ALTIVEC_MMA_ONLY
+ //generate with MMA only
+ gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<float>, float, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>;
+ #elif defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA)
+ if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){
+ gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<float>, float, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>;
+ }
+ else{
+ gemm_function = &Eigen::internal::gemm_complex<std::complex<float>, float, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>;
+ }
+ #else
+ gemm_function = &Eigen::internal::gemm_complex<std::complex<float>, float, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>;
+ #endif
+ gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
+ }
+
+template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
+struct gebp_kernel<double, double, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
+{
+ typedef typename quad_traits<double>::vectortype Packet;
+ typedef typename quad_traits<double>::rhstype RhsPacket;
+
+ void operator()(const DataMapper& res, const double* blockA, const double* blockB,
+ Index rows, Index depth, Index cols, double alpha,
+ Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0);
+};
+
+template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
+void gebp_kernel<double, double, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
+ ::operator()(const DataMapper& res, const double* blockA, const double* blockB,
+ Index rows, Index depth, Index cols, double alpha,
+ Index strideA, Index strideB, Index offsetA, Index offsetB)
+ {
+ const Index accRows = quad_traits<double>::rows;
+ const Index accCols = quad_traits<double>::size;
+ void (*gemm_function)(const DataMapper&, const double*, const double*, Index, Index, Index, double, Index, Index, Index, Index);
+
+ #ifdef EIGEN_ALTIVEC_MMA_ONLY
+ //generate with MMA only
+ gemm_function = &Eigen::internal::gemmMMA<double, Index, Packet, RhsPacket, DataMapper, accRows, accCols>;
+ #elif defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA)
+ if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){
+ gemm_function = &Eigen::internal::gemmMMA<double, Index, Packet, RhsPacket, DataMapper, accRows, accCols>;
+ }
+ else{
+ gemm_function = &Eigen::internal::gemm<double, Index, Packet, RhsPacket, DataMapper, accRows, accCols>;
+ }
+ #else
+ gemm_function = &Eigen::internal::gemm<double, Index, Packet, RhsPacket, DataMapper, accRows, accCols>;
+ #endif
+ gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
+ }
+
+template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
+struct gebp_kernel<std::complex<double>, std::complex<double>, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
+{
+ typedef quad_traits<double>::vectortype Packet;
+ typedef Packet1cd Packetc;
+ typedef quad_traits<double>::rhstype RhsPacket;
+
+ void operator()(const DataMapper& res, const std::complex<double>* blockA, const std::complex<double>* blockB,
+ Index rows, Index depth, Index cols, std::complex<double> alpha,
+ Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0);
+};
+
+template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
+void gebp_kernel<std::complex<double>, std::complex<double>, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
+ ::operator()(const DataMapper& res, const std::complex<double>* blockA, const std::complex<double>* blockB,
+ Index rows, Index depth, Index cols, std::complex<double> alpha,
+ Index strideA, Index strideB, Index offsetA, Index offsetB)
+ {
+ const Index accRows = quad_traits<double>::rows;
+ const Index accCols = quad_traits<double>::size;
+ void (*gemm_function)(const DataMapper&, const std::complex<double>*, const std::complex<double>*,
+ Index, Index, Index, std::complex<double>, Index, Index, Index, Index);
+ #ifdef EIGEN_ALTIVEC_MMA_ONLY
+ //generate with MMA only
+ gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<double>, std::complex<double>, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>;
+ #elif defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA)
+ if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){
+ gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<double>, std::complex<double>, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>;
+ }
+ else{
+ gemm_function = &Eigen::internal::gemm_complex<std::complex<double>, std::complex<double>, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>;
+ }
+ #else
+ gemm_function = &Eigen::internal::gemm_complex<std::complex<double>, std::complex<double>, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>;
+ #endif
+ gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
+ }
+
+template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
+struct gebp_kernel<std::complex<double>, double, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
+{
+ typedef quad_traits<double>::vectortype Packet;
+ typedef Packet1cd Packetc;
+ typedef quad_traits<double>::rhstype RhsPacket;
+
+ void operator()(const DataMapper& res, const std::complex<double>* blockA, const double* blockB,
+ Index rows, Index depth, Index cols, std::complex<double> alpha,
+ Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0);
+};
+
+template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
+void gebp_kernel<std::complex<double>, double, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
+ ::operator()(const DataMapper& res, const std::complex<double>* blockA, const double* blockB,
+ Index rows, Index depth, Index cols, std::complex<double> alpha,
+ Index strideA, Index strideB, Index offsetA, Index offsetB)
+ {
+ const Index accRows = quad_traits<double>::rows;
+ const Index accCols = quad_traits<double>::size;
+ void (*gemm_function)(const DataMapper&, const std::complex<double>*, const double*,
+ Index, Index, Index, std::complex<double>, Index, Index, Index, Index);
+ #ifdef EIGEN_ALTIVEC_MMA_ONLY
+ //generate with MMA only
+ gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<double>, double, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>;
+ #elif defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA)
+ if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){
+ gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<double>, double, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>;
+ }
+ else{
+ gemm_function = &Eigen::internal::gemm_complex<std::complex<double>, double, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>;
+ }
+ #else
+ gemm_function = &Eigen::internal::gemm_complex<std::complex<double>, double, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>;
+ #endif
+ gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
+ }
+
+template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
+struct gebp_kernel<double, std::complex<double>, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
+{
+ typedef quad_traits<double>::vectortype Packet;
+ typedef Packet1cd Packetc;
+ typedef quad_traits<double>::rhstype RhsPacket;
+
+ void operator()(const DataMapper& res, const double* blockA, const std::complex<double>* blockB,
+ Index rows, Index depth, Index cols, std::complex<double> alpha,
+ Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0);
+};
+
+template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
+void gebp_kernel<double, std::complex<double>, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
+ ::operator()(const DataMapper& res, const double* blockA, const std::complex<double>* blockB,
+ Index rows, Index depth, Index cols, std::complex<double> alpha,
+ Index strideA, Index strideB, Index offsetA, Index offsetB)
+ {
+ const Index accRows = quad_traits<double>::rows;
+ const Index accCols = quad_traits<double>::size;
+ void (*gemm_function)(const DataMapper&, const double*, const std::complex<double>*,
+ Index, Index, Index, std::complex<double>, Index, Index, Index, Index);
+ #ifdef EIGEN_ALTIVEC_MMA_ONLY
+ //generate with MMA only
+ gemm_function = &Eigen::internal::gemm_complexMMA<double, std::complex<double>, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>;
+ #elif defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA)
+ if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){
+ gemm_function = &Eigen::internal::gemm_complexMMA<double, std::complex<double>, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>;
+ }
+ else{
+ gemm_function = &Eigen::internal::gemm_complex<double, std::complex<double>, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>;
+ }
+ #else
+ gemm_function = &Eigen::internal::gemm_complex<double, std::complex<double>, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>;
+ #endif
+ gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
+ }
+} // end namespace internal
+
+} // end namespace Eigen
+
+#endif // EIGEN_MATRIX_PRODUCT_ALTIVEC_H
diff --git a/Eigen/src/Core/arch/AltiVec/MatrixProductCommon.h b/Eigen/src/Core/arch/AltiVec/MatrixProductCommon.h
new file mode 100644
index 000000000..33d543494
--- /dev/null
+++ b/Eigen/src/Core/arch/AltiVec/MatrixProductCommon.h
@@ -0,0 +1,221 @@
+//#define EIGEN_POWER_USE_PREFETCH // Use prefetching in gemm routines
+#ifdef EIGEN_POWER_USE_PREFETCH
+#define EIGEN_POWER_PREFETCH(p) prefetch(p)
+#else
+#define EIGEN_POWER_PREFETCH(p)
+#endif
+
+namespace Eigen {
+
+namespace internal {
+
+template<typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accRows>
+EIGEN_STRONG_INLINE void gemm_extra_col(
+ const DataMapper& res,
+ const Scalar* lhs_base,
+ const Scalar* rhs_base,
+ Index depth,
+ Index strideA,
+ Index offsetA,
+ Index row,
+ Index col,
+ Index remaining_rows,
+ Index remaining_cols,
+ const Packet& pAlpha);
+
+template<typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accRows, const Index accCols>
+EIGEN_STRONG_INLINE void gemm_extra_row(
+ const DataMapper& res,
+ const Scalar* lhs_base,
+ const Scalar* rhs_base,
+ Index depth,
+ Index strideA,
+ Index offsetA,
+ Index row,
+ Index col,
+ Index rows,
+ Index cols,
+ Index remaining_rows,
+ const Packet& pAlpha,
+ const Packet& pMask);
+
+template<typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accCols>
+EIGEN_STRONG_INLINE void gemm_unrolled_col(
+ const DataMapper& res,
+ const Scalar* lhs_base,
+ const Scalar* rhs_base,
+ Index depth,
+ Index strideA,
+ Index offsetA,
+ Index& row,
+ Index rows,
+ Index col,
+ Index remaining_cols,
+ const Packet& pAlpha);
+
+template<typename Packet>
+EIGEN_ALWAYS_INLINE Packet bmask(const int remaining_rows);
+
+template<typename Scalar, typename Packet, typename Packetc, typename DataMapper, typename Index, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
+EIGEN_STRONG_INLINE void gemm_complex_extra_col(
+ const DataMapper& res,
+ const Scalar* lhs_base,
+ const Scalar* rhs_base,
+ Index depth,
+ Index strideA,
+ Index offsetA,
+ Index strideB,
+ Index row,
+ Index col,
+ Index remaining_rows,
+ Index remaining_cols,
+ const Packet& pAlphaReal,
+ const Packet& pAlphaImag);
+
+template<typename Scalar, typename Packet, typename Packetc, typename DataMapper, typename Index, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
+EIGEN_STRONG_INLINE void gemm_complex_extra_row(
+ const DataMapper& res,
+ const Scalar* lhs_base,
+ const Scalar* rhs_base,
+ Index depth,
+ Index strideA,
+ Index offsetA,
+ Index strideB,
+ Index row,
+ Index col,
+ Index rows,
+ Index cols,
+ Index remaining_rows,
+ const Packet& pAlphaReal,
+ const Packet& pAlphaImag,
+ const Packet& pMask);
+
+template<typename Scalar, typename Packet, typename Packetc, typename DataMapper, typename Index, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
+EIGEN_STRONG_INLINE void gemm_complex_unrolled_col(
+ const DataMapper& res,
+ const Scalar* lhs_base,
+ const Scalar* rhs_base,
+ Index depth,
+ Index strideA,
+ Index offsetA,
+ Index strideB,
+ Index& row,
+ Index rows,
+ Index col,
+ Index remaining_cols,
+ const Packet& pAlphaReal,
+ const Packet& pAlphaImag);
+
+template<typename Scalar, typename Packet>
+EIGEN_ALWAYS_INLINE Packet ploadLhs(const Scalar* lhs);
+
+template<typename DataMapper, typename Packet, typename Index, const Index accCols, int N, int StorageOrder>
+EIGEN_ALWAYS_INLINE void bload(PacketBlock<Packet,4>& acc, const DataMapper& res, Index row, Index col);
+
+template<typename DataMapper, typename Packet, typename Index, const Index accCols, int N, int StorageOrder>
+EIGEN_ALWAYS_INLINE void bload(PacketBlock<Packet,8>& acc, const DataMapper& res, Index row, Index col);
+
+template<typename Packet>
+EIGEN_ALWAYS_INLINE void bscale(PacketBlock<Packet,4>& acc, PacketBlock<Packet,4>& accZ, const Packet& pAlpha);
+
+template<typename Packet, int N>
+EIGEN_ALWAYS_INLINE void bscalec(PacketBlock<Packet,N>& aReal, PacketBlock<Packet,N>& aImag, const Packet& bReal, const Packet& bImag, PacketBlock<Packet,N>& cReal, PacketBlock<Packet,N>& cImag);
+
+const static Packet16uc p16uc_SETCOMPLEX32_FIRST = { 0, 1, 2, 3,
+ 16, 17, 18, 19,
+ 4, 5, 6, 7,
+ 20, 21, 22, 23};
+
+const static Packet16uc p16uc_SETCOMPLEX32_SECOND = { 8, 9, 10, 11,
+ 24, 25, 26, 27,
+ 12, 13, 14, 15,
+ 28, 29, 30, 31};
+//[a,b],[ai,bi] = [a,ai] - This is equivalent to p16uc_GETREAL64
+const static Packet16uc p16uc_SETCOMPLEX64_FIRST = { 0, 1, 2, 3, 4, 5, 6, 7,
+ 16, 17, 18, 19, 20, 21, 22, 23};
+
+//[a,b],[ai,bi] = [b,bi] - This is equivalent to p16uc_GETIMAG64
+const static Packet16uc p16uc_SETCOMPLEX64_SECOND = { 8, 9, 10, 11, 12, 13, 14, 15,
+ 24, 25, 26, 27, 28, 29, 30, 31};
+
+
+// Grab two decouples real/imaginary PacketBlocks and return two coupled (real/imaginary pairs) PacketBlocks.
+template<typename Packet, typename Packetc>
+EIGEN_ALWAYS_INLINE void bcouple_common(PacketBlock<Packet,4>& taccReal, PacketBlock<Packet,4>& taccImag, PacketBlock<Packetc, 4>& acc1, PacketBlock<Packetc, 4>& acc2)
+{
+ acc1.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], p16uc_SETCOMPLEX32_FIRST);
+ acc1.packet[1].v = vec_perm(taccReal.packet[1], taccImag.packet[1], p16uc_SETCOMPLEX32_FIRST);
+ acc1.packet[2].v = vec_perm(taccReal.packet[2], taccImag.packet[2], p16uc_SETCOMPLEX32_FIRST);
+ acc1.packet[3].v = vec_perm(taccReal.packet[3], taccImag.packet[3], p16uc_SETCOMPLEX32_FIRST);
+
+ acc2.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], p16uc_SETCOMPLEX32_SECOND);
+ acc2.packet[1].v = vec_perm(taccReal.packet[1], taccImag.packet[1], p16uc_SETCOMPLEX32_SECOND);
+ acc2.packet[2].v = vec_perm(taccReal.packet[2], taccImag.packet[2], p16uc_SETCOMPLEX32_SECOND);
+ acc2.packet[3].v = vec_perm(taccReal.packet[3], taccImag.packet[3], p16uc_SETCOMPLEX32_SECOND);
+}
+
+template<typename Packet, typename Packetc>
+EIGEN_ALWAYS_INLINE void bcouple(PacketBlock<Packet,4>& taccReal, PacketBlock<Packet,4>& taccImag, PacketBlock<Packetc,8>& tRes, PacketBlock<Packetc, 4>& acc1, PacketBlock<Packetc, 4>& acc2)
+{
+ bcouple_common<Packet, Packetc>(taccReal, taccImag, acc1, acc2);
+
+ acc1.packet[0] = padd<Packetc>(tRes.packet[0], acc1.packet[0]);
+ acc1.packet[1] = padd<Packetc>(tRes.packet[1], acc1.packet[1]);
+ acc1.packet[2] = padd<Packetc>(tRes.packet[2], acc1.packet[2]);
+ acc1.packet[3] = padd<Packetc>(tRes.packet[3], acc1.packet[3]);
+
+ acc2.packet[0] = padd<Packetc>(tRes.packet[4], acc2.packet[0]);
+ acc2.packet[1] = padd<Packetc>(tRes.packet[5], acc2.packet[1]);
+ acc2.packet[2] = padd<Packetc>(tRes.packet[6], acc2.packet[2]);
+ acc2.packet[3] = padd<Packetc>(tRes.packet[7], acc2.packet[3]);
+}
+
+template<typename Packet, typename Packetc>
+EIGEN_ALWAYS_INLINE void bcouple_common(PacketBlock<Packet,1>& taccReal, PacketBlock<Packet,1>& taccImag, PacketBlock<Packetc, 1>& acc1, PacketBlock<Packetc, 1>& acc2)
+{
+ acc1.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], p16uc_SETCOMPLEX32_FIRST);
+
+ acc2.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], p16uc_SETCOMPLEX32_SECOND);
+}
+
+template<typename Packet, typename Packetc>
+EIGEN_ALWAYS_INLINE void bcouple(PacketBlock<Packet,1>& taccReal, PacketBlock<Packet,1>& taccImag, PacketBlock<Packetc,2>& tRes, PacketBlock<Packetc, 1>& acc1, PacketBlock<Packetc, 1>& acc2)
+{
+ bcouple_common<Packet, Packetc>(taccReal, taccImag, acc1, acc2);
+
+ acc1.packet[0] = padd<Packetc>(tRes.packet[0], acc1.packet[0]);
+
+ acc2.packet[0] = padd<Packetc>(tRes.packet[1], acc2.packet[0]);
+}
+
+template<>
+EIGEN_ALWAYS_INLINE void bcouple_common<Packet2d, Packet1cd>(PacketBlock<Packet2d,4>& taccReal, PacketBlock<Packet2d,4>& taccImag, PacketBlock<Packet1cd, 4>& acc1, PacketBlock<Packet1cd, 4>& acc2)
+{
+ acc1.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], p16uc_SETCOMPLEX64_FIRST);
+ acc1.packet[1].v = vec_perm(taccReal.packet[1], taccImag.packet[1], p16uc_SETCOMPLEX64_FIRST);
+ acc1.packet[2].v = vec_perm(taccReal.packet[2], taccImag.packet[2], p16uc_SETCOMPLEX64_FIRST);
+ acc1.packet[3].v = vec_perm(taccReal.packet[3], taccImag.packet[3], p16uc_SETCOMPLEX64_FIRST);
+
+ acc2.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], p16uc_SETCOMPLEX64_SECOND);
+ acc2.packet[1].v = vec_perm(taccReal.packet[1], taccImag.packet[1], p16uc_SETCOMPLEX64_SECOND);
+ acc2.packet[2].v = vec_perm(taccReal.packet[2], taccImag.packet[2], p16uc_SETCOMPLEX64_SECOND);
+ acc2.packet[3].v = vec_perm(taccReal.packet[3], taccImag.packet[3], p16uc_SETCOMPLEX64_SECOND);
+}
+
+template<>
+EIGEN_ALWAYS_INLINE void bcouple_common<Packet2d, Packet1cd>(PacketBlock<Packet2d,1>& taccReal, PacketBlock<Packet2d,1>& taccImag, PacketBlock<Packet1cd, 1>& acc1, PacketBlock<Packet1cd, 1>& acc2)
+{
+ acc1.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], p16uc_SETCOMPLEX64_FIRST);
+
+ acc2.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], p16uc_SETCOMPLEX64_SECOND);
+}
+
+// This is necessary because ploadRhs for double returns a pair of vectors when MMA is enabled.
+template<typename Scalar, typename Packet>
+EIGEN_ALWAYS_INLINE Packet ploadRhs(const Scalar* rhs)
+{
+ return ploadu<Packet>(rhs);
+}
+
+} // end namespace internal
+} // end namespace Eigen
diff --git a/Eigen/src/Core/arch/AltiVec/MatrixProductMMA.h b/Eigen/src/Core/arch/AltiVec/MatrixProductMMA.h
new file mode 100644
index 000000000..6540c6fa6
--- /dev/null
+++ b/Eigen/src/Core/arch/AltiVec/MatrixProductMMA.h
@@ -0,0 +1,629 @@
+// This file is part of Eigen, a lightweight C++ template library
+// for linear algebra.
+//
+// Copyright (C) 2020 Everton Constantino (everton.constantino@ibm.com)
+// Copyright (C) 2021 Chip Kerchner (chip.kerchner@ibm.com)
+//
+// This Source Code Form is subject to the terms of the Mozilla
+// Public License v. 2.0. If a copy of the MPL was not distributed
+// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+#ifndef EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H
+#define EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H
+
+#pragma GCC target("cpu=power10")
+
+#ifdef __has_builtin
+#if !__has_builtin(__builtin_vsx_assemble_pair)
+#define __builtin_vsx_assemble_pair __builtin_mma_assemble_pair
+#endif
+#endif
+
+namespace Eigen {
+
+namespace internal {
+
+template<typename Scalar, typename Packet>
+EIGEN_ALWAYS_INLINE void bsetzeroMMA(__vector_quad* acc)
+{
+ __builtin_mma_xxsetaccz(acc);
+}
+
+template<typename DataMapper, typename Index, typename Packet, const Index accCols>
+EIGEN_ALWAYS_INLINE void storeAccumulator(Index i, Index j, const DataMapper& data, const Packet& alpha, __vector_quad* acc)
+{
+ PacketBlock<Packet, 4> result;
+ __builtin_mma_disassemble_acc(&result.packet, acc);
+
+ PacketBlock<Packet, 4> tRes;
+ bload<DataMapper, Packet, Index, accCols, 0, ColMajor>(tRes, data, i, j);
+
+ bscale<Packet>(tRes, result, alpha);
+
+ data.template storePacketBlock<Packet, 4>(i, j, tRes);
+}
+
+template<typename DataMapper, typename Index, typename Packet, typename Packetc, const Index accColsC, int N>
+EIGEN_ALWAYS_INLINE void storeComplexAccumulator(Index i, Index j, const DataMapper& data, const Packet& alphaReal, const Packet& alphaImag, __vector_quad* accReal, __vector_quad* accImag)
+{
+ PacketBlock<Packet, 4> resultReal, resultImag;
+ __builtin_mma_disassemble_acc(&resultReal.packet, accReal);
+ __builtin_mma_disassemble_acc(&resultImag.packet, accImag);
+
+ PacketBlock<Packetc, 8> tRes;
+ bload<DataMapper, Packetc, Index, accColsC, N, ColMajor>(tRes, data, i, j);
+
+ PacketBlock<Packet,4> taccReal, taccImag;
+ bscalec<Packet,4>(resultReal, resultImag, alphaReal, alphaImag, taccReal, taccImag);
+
+ PacketBlock<Packetc, 4> acc1, acc2;
+ bcouple<Packet, Packetc>(taccReal, taccImag, tRes, acc1, acc2);
+
+ data.template storePacketBlock<Packetc, 4>(i + N*accColsC, j, acc1);
+ data.template storePacketBlock<Packetc, 4>(i + (N+1)*accColsC, j, acc2);
+}
+
+// Defaults to float32, since Eigen still supports C++03 we can't use default template arguments
+template<typename LhsPacket, typename RhsPacket, bool NegativeAccumulate>
+EIGEN_ALWAYS_INLINE void pgerMMA(__vector_quad* acc, const RhsPacket& a, const LhsPacket& b)
+{
+ if(NegativeAccumulate)
+ {
+ __builtin_mma_xvf32gernp(acc, (__vector unsigned char)a, (__vector unsigned char)b);
+ } else {
+ __builtin_mma_xvf32gerpp(acc, (__vector unsigned char)a, (__vector unsigned char)b);
+ }
+}
+
+template<typename LhsPacket, typename RhsPacket, bool NegativeAccumulate>
+EIGEN_ALWAYS_INLINE void pgerMMA(__vector_quad* acc, const PacketBlock<Packet2d,2>& a, const Packet2d& b)
+{
+ __vector_pair* a0 = (__vector_pair *)(&a.packet[0]);
+ if(NegativeAccumulate)
+ {
+ __builtin_mma_xvf64gernp(acc, *a0, (__vector unsigned char)b);
+ } else {
+ __builtin_mma_xvf64gerpp(acc, *a0, (__vector unsigned char)b);
+ }
+}
+
+template<typename LhsPacket, typename RhsPacket, bool NegativeAccumulate>
+EIGEN_ALWAYS_INLINE void pgerMMA(__vector_quad* acc, const __vector_pair& a, const Packet2d& b)
+{
+ if(NegativeAccumulate)
+ {
+ __builtin_mma_xvf64gernp(acc, (__vector_pair)a, (__vector unsigned char)b);
+ } else {
+ __builtin_mma_xvf64gerpp(acc, (__vector_pair)a, (__vector unsigned char)b);
+ }
+}
+
+template<typename LhsPacket, typename RhsPacket, bool NegativeAccumulate>
+EIGEN_ALWAYS_INLINE void pgerMMA(__vector_quad*, const __vector_pair&, const Packet4f&)
+{
+ // Just for compilation
+}
+
+template<typename Scalar, typename Packet, typename RhsPacket, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
+EIGEN_ALWAYS_INLINE void pgercMMA(__vector_quad* accReal, __vector_quad* accImag, const Packet& lhsV, const Packet& lhsVi, const RhsPacket& rhsV, const RhsPacket& rhsVi)
+{
+ pgerMMA<Packet, RhsPacket, false>(accReal, rhsV, lhsV);
+ if(LhsIsReal) {
+ pgerMMA<Packet, RhsPacket, ConjugateRhs>(accImag, rhsVi, lhsV);
+ } else {
+ if(!RhsIsReal) {
+ pgerMMA<Packet, RhsPacket, ConjugateLhs == ConjugateRhs>(accReal, rhsVi, lhsVi);
+ pgerMMA<Packet, RhsPacket, ConjugateRhs>(accImag, rhsVi, lhsV);
+ } else {
+ EIGEN_UNUSED_VARIABLE(rhsVi);
+ }
+ pgerMMA<Packet, RhsPacket, ConjugateLhs>(accImag, rhsV, lhsVi);
+ }
+}
+
+// This is necessary because ploadRhs for double returns a pair of vectors when MMA is enabled.
+template<typename Scalar, typename Packet>
+EIGEN_ALWAYS_INLINE void ploadRhsMMA(const Scalar* rhs, Packet& rhsV)
+{
+ rhsV = ploadRhs<Scalar, Packet>((const Scalar*)(rhs));
+}
+
+template<>
+EIGEN_ALWAYS_INLINE void ploadRhsMMA<double, PacketBlock<Packet2d, 2> >(const double* rhs, PacketBlock<Packet2d, 2>& rhsV)
+{
+ rhsV.packet[0] = ploadRhs<double, Packet2d>((const double *)((Packet2d *)rhs ));
+ rhsV.packet[1] = ploadRhs<double, Packet2d>((const double *)(((Packet2d *)rhs) + 1));
+}
+
+template<>
+EIGEN_ALWAYS_INLINE void ploadRhsMMA<double, __vector_pair>(const double* rhs, __vector_pair& rhsV)
+{
+#if EIGEN_COMP_LLVM
+ __builtin_vsx_assemble_pair(&rhsV,
+ (__vector unsigned char)(ploadRhs<double, Packet2d>((const double *)(((Packet2d *)rhs) + 1))),
+ (__vector unsigned char)(ploadRhs<double, Packet2d>((const double *)((Packet2d *)rhs ))));
+#else
+ __asm__ ("lxvp %x0,%1" : "=wa" (rhsV) : "Y" (*rhs));
+#endif
+}
+
+template<>
+EIGEN_ALWAYS_INLINE void ploadRhsMMA(const float*, __vector_pair&)
+{
+ // Just for compilation
+}
+
+// PEEL_MMA loop factor.
+#define PEEL_MMA 7
+
+#define MICRO_MMA_UNROLL(func) \
+ func(0) func(1) func(2) func(3) func(4) func(5) func(6) func(7)
+
+#define MICRO_MMA_LOAD_ONE(iter) \
+ if (unroll_factor > iter) { \
+ lhsV##iter = ploadLhs<Scalar, Packet>(lhs_ptr##iter); \
+ lhs_ptr##iter += accCols; \
+ } else { \
+ EIGEN_UNUSED_VARIABLE(lhsV##iter); \
+ }
+
+#define MICRO_MMA_WORK_ONE(iter, type, peel) \
+ if (unroll_factor > iter) { \
+ pgerMMA<Packet, type, false>(&accZero##iter, rhsV##peel, lhsV##iter); \
+ }
+
+#define MICRO_MMA_TYPE_PEEL(func, func2, type, peel) \
+ if (PEEL_MMA > peel) { \
+ Packet lhsV0, lhsV1, lhsV2, lhsV3, lhsV4, lhsV5, lhsV6, lhsV7; \
+ ploadRhsMMA<Scalar, type>(rhs_ptr + (accRows * peel), rhsV##peel); \
+ MICRO_MMA_UNROLL(func2); \
+ func(0,type,peel) func(1,type,peel) func(2,type,peel) func(3,type,peel) \
+ func(4,type,peel) func(5,type,peel) func(6,type,peel) func(7,type,peel) \
+ } else { \
+ EIGEN_UNUSED_VARIABLE(rhsV##peel); \
+ }
+
+#define MICRO_MMA_UNROLL_TYPE_PEEL(func, func2, type) \
+ type rhsV0, rhsV1, rhsV2, rhsV3, rhsV4, rhsV5, rhsV6, rhsV7, rhsV8, rhsV9; \
+ MICRO_MMA_TYPE_PEEL(func,func2,type,0); MICRO_MMA_TYPE_PEEL(func,func2,type,1); \
+ MICRO_MMA_TYPE_PEEL(func,func2,type,2); MICRO_MMA_TYPE_PEEL(func,func2,type,3); \
+ MICRO_MMA_TYPE_PEEL(func,func2,type,4); MICRO_MMA_TYPE_PEEL(func,func2,type,5); \
+ MICRO_MMA_TYPE_PEEL(func,func2,type,6); MICRO_MMA_TYPE_PEEL(func,func2,type,7); \
+ MICRO_MMA_TYPE_PEEL(func,func2,type,8); MICRO_MMA_TYPE_PEEL(func,func2,type,9);
+
+#define MICRO_MMA_UNROLL_TYPE_ONE(func, func2, type) \
+ type rhsV0; \
+ MICRO_MMA_TYPE_PEEL(func,func2,type,0);
+
+#define MICRO_MMA_ONE_PEEL \
+ if (sizeof(Scalar) == sizeof(float)) { \
+ MICRO_MMA_UNROLL_TYPE_PEEL(MICRO_MMA_WORK_ONE, MICRO_MMA_LOAD_ONE, RhsPacket); \
+ } else { \
+ MICRO_MMA_UNROLL_TYPE_PEEL(MICRO_MMA_WORK_ONE, MICRO_MMA_LOAD_ONE, __vector_pair); \
+ } \
+ rhs_ptr += (accRows * PEEL_MMA);
+
+#define MICRO_MMA_ONE \
+ if (sizeof(Scalar) == sizeof(float)) { \
+ MICRO_MMA_UNROLL_TYPE_ONE(MICRO_MMA_WORK_ONE, MICRO_MMA_LOAD_ONE, RhsPacket); \
+ } else { \
+ MICRO_MMA_UNROLL_TYPE_ONE(MICRO_MMA_WORK_ONE, MICRO_MMA_LOAD_ONE, __vector_pair); \
+ } \
+ rhs_ptr += accRows;
+
+#define MICRO_MMA_DST_PTR_ONE(iter) \
+ if (unroll_factor > iter) { \
+ bsetzeroMMA<Scalar, Packet>(&accZero##iter); \
+ } else { \
+ EIGEN_UNUSED_VARIABLE(accZero##iter); \
+ }
+
+#define MICRO_MMA_DST_PTR MICRO_MMA_UNROLL(MICRO_MMA_DST_PTR_ONE)
+
+#define MICRO_MMA_SRC_PTR_ONE(iter) \
+ if (unroll_factor > iter) { \
+ lhs_ptr##iter = lhs_base + ( (row/accCols) + iter )*strideA*accCols + accCols*offsetA; \
+ } else { \
+ EIGEN_UNUSED_VARIABLE(lhs_ptr##iter); \
+ }
+
+#define MICRO_MMA_SRC_PTR MICRO_MMA_UNROLL(MICRO_MMA_SRC_PTR_ONE)
+
+#define MICRO_MMA_PREFETCH_ONE(iter) \
+ if (unroll_factor > iter) { \
+ EIGEN_POWER_PREFETCH(lhs_ptr##iter); \
+ }
+
+#define MICRO_MMA_PREFETCH MICRO_MMA_UNROLL(MICRO_MMA_PREFETCH_ONE)
+
+#define MICRO_MMA_STORE_ONE(iter) \
+ if (unroll_factor > iter) { \
+ storeAccumulator<DataMapper, Index, Packet, accCols>(row + iter*accCols, col, res, pAlpha, &accZero##iter); \
+ }
+
+#define MICRO_MMA_STORE MICRO_MMA_UNROLL(MICRO_MMA_STORE_ONE)
+
+template<int unroll_factor, typename Scalar, typename Packet, typename RhsPacket, typename DataMapper, typename Index, const Index accRows, const Index accCols>
+EIGEN_STRONG_INLINE void gemm_unrolled_MMA_iteration(
+ const DataMapper& res,
+ const Scalar* lhs_base,
+ const Scalar* rhs_base,
+ Index depth,
+ Index strideA,
+ Index offsetA,
+ Index& row,
+ Index col,
+ const Packet& pAlpha)
+{
+ const Scalar* rhs_ptr = rhs_base;
+ const Scalar* lhs_ptr0 = NULL, * lhs_ptr1 = NULL, * lhs_ptr2 = NULL, * lhs_ptr3 = NULL, * lhs_ptr4 = NULL, * lhs_ptr5 = NULL, * lhs_ptr6 = NULL, * lhs_ptr7 = NULL;
+ __vector_quad accZero0, accZero1, accZero2, accZero3, accZero4, accZero5, accZero6, accZero7;
+
+ MICRO_MMA_SRC_PTR
+ MICRO_MMA_DST_PTR
+
+ Index k = 0;
+ for(; k + PEEL_MMA <= depth; k+= PEEL_MMA)
+ {
+ EIGEN_POWER_PREFETCH(rhs_ptr);
+ MICRO_MMA_PREFETCH
+ MICRO_MMA_ONE_PEEL
+ }
+ for(; k < depth; k++)
+ {
+ MICRO_MMA_ONE
+ }
+ MICRO_MMA_STORE
+
+ row += unroll_factor*accCols;
+}
+
+template<typename Scalar, typename Index, typename Packet, typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols>
+void gemmMMA(const DataMapper& res, const Scalar* blockA, const Scalar* blockB, Index rows, Index depth, Index cols, Scalar alpha, Index strideA, Index strideB, Index offsetA, Index offsetB)
+{
+ const Index remaining_rows = rows % accCols;
+ const Index remaining_cols = cols % accRows;
+
+ if( strideA == -1 ) strideA = depth;
+ if( strideB == -1 ) strideB = depth;
+
+ const Packet pAlpha = pset1<Packet>(alpha);
+ const Packet pMask = bmask<Packet>((const int)(remaining_rows));
+
+ Index col = 0;
+ for(; col + accRows <= cols; col += accRows)
+ {
+ const Scalar* rhs_base = blockB + col*strideB + accRows*offsetB;
+ const Scalar* lhs_base = blockA;
+
+ Index row = 0;
+#define MAX_MMA_UNROLL 7
+ while(row + MAX_MMA_UNROLL*accCols <= rows) {
+ gemm_unrolled_MMA_iteration<MAX_MMA_UNROLL, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
+ }
+ switch( (rows-row)/accCols ) {
+#if MAX_MMA_UNROLL > 7
+ case 7:
+ gemm_unrolled_MMA_iteration<7, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
+ break;
+#endif
+#if MAX_MMA_UNROLL > 6
+ case 6:
+ gemm_unrolled_MMA_iteration<6, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
+ break;
+#endif
+#if MAX_MMA_UNROLL > 5
+ case 5:
+ gemm_unrolled_MMA_iteration<5, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
+ break;
+#endif
+#if MAX_MMA_UNROLL > 4
+ case 4:
+ gemm_unrolled_MMA_iteration<4, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
+ break;
+#endif
+#if MAX_MMA_UNROLL > 3
+ case 3:
+ gemm_unrolled_MMA_iteration<3, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
+ break;
+#endif
+#if MAX_MMA_UNROLL > 2
+ case 2:
+ gemm_unrolled_MMA_iteration<2, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
+ break;
+#endif
+#if MAX_MMA_UNROLL > 1
+ case 1:
+ gemm_unrolled_MMA_iteration<1, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
+ break;
+#endif
+ default:
+ break;
+ }
+#undef MAX_MMA_UNROLL
+
+ if(remaining_rows > 0)
+ {
+ gemm_extra_row<Scalar, Packet, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, rows, cols, remaining_rows, pAlpha, pMask);
+ }
+ }
+
+ if(remaining_cols > 0)
+ {
+ const Scalar* rhs_base = blockB + col*strideB + remaining_cols*offsetB;
+ const Scalar* lhs_base = blockA;
+
+ for(; col < cols; col++)
+ {
+ Index row = 0;
+
+ gemm_unrolled_col<Scalar, Packet, DataMapper, Index, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, rows, col, remaining_cols, pAlpha);
+
+ if (remaining_rows > 0)
+ {
+ gemm_extra_col<Scalar, Packet, DataMapper, Index, accRows>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_rows, remaining_cols, pAlpha);
+ }
+ rhs_base++;
+ }
+ }
+}
+
+#define accColsC (accCols / 2)
+#define advanceRows ((LhsIsReal) ? 1 : 2)
+#define advanceCols ((RhsIsReal) ? 1 : 2)
+
+// PEEL_COMPLEX_MMA loop factor.
+#define PEEL_COMPLEX_MMA 7
+
+#define MICRO_COMPLEX_MMA_UNROLL(func) \
+ func(0) func(1) func(2) func(3) func(4)
+
+#define MICRO_COMPLEX_MMA_LOAD_ONE(iter) \
+ if (unroll_factor > iter) { \
+ lhsV##iter = ploadLhs<Scalar, Packet>(lhs_ptr_real##iter); \
+ lhs_ptr_real##iter += accCols; \
+ if(!LhsIsReal) { \
+ lhsVi##iter = ploadLhs<Scalar, Packet>(lhs_ptr_imag##iter); \
+ lhs_ptr_imag##iter += accCols; \
+ } else { \
+ EIGEN_UNUSED_VARIABLE(lhsVi##iter); \
+ } \
+ } else { \
+ EIGEN_UNUSED_VARIABLE(lhsV##iter); \
+ EIGEN_UNUSED_VARIABLE(lhsVi##iter); \
+ }
+
+#define MICRO_COMPLEX_MMA_WORK_ONE(iter, type, peel) \
+ if (unroll_factor > iter) { \
+ pgercMMA<Scalar, Packet, type, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal##iter, &accImag##iter, lhsV##iter, lhsVi##iter, rhsV##peel, rhsVi##peel); \
+ }
+
+#define MICRO_COMPLEX_MMA_TYPE_PEEL(func, func2, type, peel) \
+ if (PEEL_COMPLEX_MMA > peel) { \
+ Packet lhsV0, lhsV1, lhsV2, lhsV3, lhsV4; \
+ Packet lhsVi0, lhsVi1, lhsVi2, lhsVi3, lhsVi4; \
+ ploadRhsMMA<Scalar, type>(rhs_ptr_real + (accRows * peel), rhsV##peel); \
+ if(!RhsIsReal) { \
+ ploadRhsMMA<Scalar, type>(rhs_ptr_imag + (accRows * peel), rhsVi##peel); \
+ } else { \
+ EIGEN_UNUSED_VARIABLE(rhsVi##peel); \
+ } \
+ MICRO_COMPLEX_MMA_UNROLL(func2); \
+ func(0,type,peel) func(1,type,peel) func(2,type,peel) func(3,type,peel) func(4,type,peel) \
+ } else { \
+ EIGEN_UNUSED_VARIABLE(rhsV##peel); \
+ EIGEN_UNUSED_VARIABLE(rhsVi##peel); \
+ }
+
+#define MICRO_COMPLEX_MMA_UNROLL_TYPE_PEEL(func, func2, type) \
+ type rhsV0, rhsV1, rhsV2, rhsV3, rhsV4, rhsV5, rhsV6, rhsV7, rhsV8, rhsV9; \
+ type rhsVi0, rhsVi1, rhsVi2, rhsVi3, rhsVi4, rhsVi5, rhsVi6, rhsVi7, rhsVi8, rhsVi9; \
+ MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,0); MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,1); \
+ MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,2); MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,3); \
+ MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,4); MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,5); \
+ MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,6); MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,7); \
+ MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,8); MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,9);
+
+#define MICRO_COMPLEX_MMA_UNROLL_TYPE_ONE(func, func2, type) \
+ type rhsV0, rhsVi0; \
+ MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,0);
+
+#define MICRO_COMPLEX_MMA_ONE_PEEL \
+ if (sizeof(Scalar) == sizeof(float)) { \
+ MICRO_COMPLEX_MMA_UNROLL_TYPE_PEEL(MICRO_COMPLEX_MMA_WORK_ONE, MICRO_COMPLEX_MMA_LOAD_ONE, RhsPacket); \
+ } else { \
+ MICRO_COMPLEX_MMA_UNROLL_TYPE_PEEL(MICRO_COMPLEX_MMA_WORK_ONE, MICRO_COMPLEX_MMA_LOAD_ONE, __vector_pair); \
+ } \
+ rhs_ptr_real += (accRows * PEEL_COMPLEX_MMA); \
+ if(!RhsIsReal) rhs_ptr_imag += (accRows * PEEL_COMPLEX_MMA);
+
+#define MICRO_COMPLEX_MMA_ONE \
+ if (sizeof(Scalar) == sizeof(float)) { \
+ MICRO_COMPLEX_MMA_UNROLL_TYPE_ONE(MICRO_COMPLEX_MMA_WORK_ONE, MICRO_COMPLEX_MMA_LOAD_ONE, RhsPacket); \
+ } else { \
+ MICRO_COMPLEX_MMA_UNROLL_TYPE_ONE(MICRO_COMPLEX_MMA_WORK_ONE, MICRO_COMPLEX_MMA_LOAD_ONE, __vector_pair); \
+ } \
+ rhs_ptr_real += accRows; \
+ if(!RhsIsReal) rhs_ptr_imag += accRows;
+
+#define MICRO_COMPLEX_MMA_DST_PTR_ONE(iter) \
+ if (unroll_factor > iter) { \
+ bsetzeroMMA<Scalar, Packet>(&accReal##iter); \
+ bsetzeroMMA<Scalar, Packet>(&accImag##iter); \
+ } else { \
+ EIGEN_UNUSED_VARIABLE(accReal##iter); \
+ EIGEN_UNUSED_VARIABLE(accImag##iter); \
+ }
+
+#define MICRO_COMPLEX_MMA_DST_PTR MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_MMA_DST_PTR_ONE)
+
+#define MICRO_COMPLEX_MMA_SRC_PTR_ONE(iter) \
+ if (unroll_factor > iter) { \
+ lhs_ptr_real##iter = lhs_base + ( ((advanceRows*row)/accCols) + iter*advanceRows )*strideA*accCols + accCols*offsetA; \
+ if(!LhsIsReal) { \
+ lhs_ptr_imag##iter = lhs_ptr_real##iter + accCols*strideA; \
+ } else { \
+ EIGEN_UNUSED_VARIABLE(lhs_ptr_imag##iter); \
+ } \
+ } else { \
+ EIGEN_UNUSED_VARIABLE(lhs_ptr_real##iter); \
+ EIGEN_UNUSED_VARIABLE(lhs_ptr_imag##iter); \
+ }
+
+#define MICRO_COMPLEX_MMA_SRC_PTR MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_MMA_SRC_PTR_ONE)
+
+#define MICRO_COMPLEX_MMA_PREFETCH_ONE(iter) \
+ if (unroll_factor > iter) { \
+ EIGEN_POWER_PREFETCH(lhs_ptr_real##iter); \
+ if(!LhsIsReal) { \
+ EIGEN_POWER_PREFETCH(lhs_ptr_imag##iter); \
+ } \
+ }
+
+#define MICRO_COMPLEX_MMA_PREFETCH MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_MMA_PREFETCH_ONE)
+
+#define MICRO_COMPLEX_MMA_STORE_ONE(iter) \
+ if (unroll_factor > iter) { \
+ storeComplexAccumulator<DataMapper, Index, Packet, Packetc, accColsC, 0>(row + iter*accCols, col, res, pAlphaReal, pAlphaImag, &accReal##iter, &accImag##iter); \
+ }
+
+#define MICRO_COMPLEX_MMA_STORE MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_MMA_STORE_ONE)
+
+template<int unroll_factor, typename Scalar, typename Packet, typename Packetc, typename RhsPacket, typename DataMapper, typename Index, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
+EIGEN_STRONG_INLINE void gemm_complex_unrolled_MMA_iteration(
+ const DataMapper& res,
+ const Scalar* lhs_base,
+ const Scalar* rhs_base,
+ Index depth,
+ Index strideA,
+ Index offsetA,
+ Index strideB,
+ Index& row,
+ Index col,
+ const Packet& pAlphaReal,
+ const Packet& pAlphaImag)
+{
+ const Scalar* rhs_ptr_real = rhs_base;
+ const Scalar* rhs_ptr_imag;
+ if(!RhsIsReal) {
+ rhs_ptr_imag = rhs_base + accRows*strideB;
+ } else {
+ EIGEN_UNUSED_VARIABLE(rhs_ptr_imag);
+ }
+ const Scalar* lhs_ptr_real0 = NULL, * lhs_ptr_imag0 = NULL, * lhs_ptr_real1 = NULL, * lhs_ptr_imag1 = NULL;
+ const Scalar* lhs_ptr_real2 = NULL, * lhs_ptr_imag2 = NULL, * lhs_ptr_real3 = NULL, * lhs_ptr_imag3 = NULL;
+ const Scalar* lhs_ptr_real4 = NULL, * lhs_ptr_imag4 = NULL;
+ __vector_quad accReal0, accImag0, accReal1, accImag1, accReal2, accImag2, accReal3, accImag3, accReal4, accImag4;
+
+ MICRO_COMPLEX_MMA_SRC_PTR
+ MICRO_COMPLEX_MMA_DST_PTR
+
+ Index k = 0;
+ for(; k + PEEL_COMPLEX_MMA <= depth; k+= PEEL_COMPLEX_MMA)
+ {
+ EIGEN_POWER_PREFETCH(rhs_ptr_real);
+ if(!RhsIsReal) {
+ EIGEN_POWER_PREFETCH(rhs_ptr_imag);
+ }
+ MICRO_COMPLEX_MMA_PREFETCH
+ MICRO_COMPLEX_MMA_ONE_PEEL
+ }
+ for(; k < depth; k++)
+ {
+ MICRO_COMPLEX_MMA_ONE
+ }
+ MICRO_COMPLEX_MMA_STORE
+
+ row += unroll_factor*accCols;
+}
+
+template<typename LhsScalar, typename RhsScalar, typename Scalarc, typename Scalar, typename Index, typename Packet, typename Packetc, typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
+void gemm_complexMMA(const DataMapper& res, const LhsScalar* blockAc, const RhsScalar* blockBc, Index rows, Index depth, Index cols, Scalarc alpha, Index strideA, Index strideB, Index offsetA, Index offsetB)
+{
+ const Index remaining_rows = rows % accCols;
+ const Index remaining_cols = cols % accRows;
+
+ if( strideA == -1 ) strideA = depth;
+ if( strideB == -1 ) strideB = depth;
+
+ const Packet pAlphaReal = pset1<Packet>(alpha.real());
+ const Packet pAlphaImag = pset1<Packet>(alpha.imag());
+ const Packet pMask = bmask<Packet>((const int)(remaining_rows));
+
+ const Scalar* blockA = (Scalar *) blockAc;
+ const Scalar* blockB = (Scalar *) blockBc;
+
+ Index col = 0;
+ for(; col + accRows <= cols; col += accRows)
+ {
+ const Scalar* rhs_base = blockB + advanceCols*col*strideB + accRows*offsetB;
+ const Scalar* lhs_base = blockA;
+ Index row = 0;
+
+#define MAX_COMPLEX_MMA_UNROLL 4
+ while(row + MAX_COMPLEX_MMA_UNROLL*accCols <= rows) {
+ gemm_complex_unrolled_MMA_iteration<MAX_COMPLEX_MMA_UNROLL, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag);
+ }
+ switch( (rows-row)/accCols ) {
+#if MAX_COMPLEX_MMA_UNROLL > 4
+ case 4:
+ gemm_complex_unrolled_MMA_iteration<4, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag);
+ break;
+#endif
+#if MAX_COMPLEX_MMA_UNROLL > 3
+ case 3:
+ gemm_complex_unrolled_MMA_iteration<3, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag);
+ break;
+#endif
+#if MAX_COMPLEX_MMA_UNROLL > 2
+ case 2:
+ gemm_complex_unrolled_MMA_iteration<2, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag);
+ break;
+#endif
+#if MAX_COMPLEX_MMA_UNROLL > 1
+ case 1:
+ gemm_complex_unrolled_MMA_iteration<1, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag);
+ break;
+#endif
+ default:
+ break;
+ }
+#undef MAX_COMPLEX_MMA_UNROLL
+
+ if(remaining_rows > 0)
+ {
+ gemm_complex_extra_row<Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, rows, cols, remaining_rows, pAlphaReal, pAlphaImag, pMask);
+ }
+ }
+
+ if(remaining_cols > 0)
+ {
+ const Scalar* rhs_base = blockB + advanceCols*col*strideB + remaining_cols*offsetB;
+ const Scalar* lhs_base = blockA;
+
+ for(; col < cols; col++)
+ {
+ Index row = 0;
+
+ gemm_complex_unrolled_col<Scalar, Packet, Packetc, DataMapper, Index, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, rows, col, remaining_cols, pAlphaReal, pAlphaImag);
+
+ if (remaining_rows > 0)
+ {
+ gemm_complex_extra_col<Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, remaining_rows, remaining_cols, pAlphaReal, pAlphaImag);
+ }
+ rhs_base++;
+ }
+ }
+}
+
+#undef accColsC
+#undef advanceRows
+#undef advanceCols
+
+#pragma GCC reset_options
+} // end namespace internal
+
+} // end namespace Eigen
+
+#endif // EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H
+
diff --git a/Eigen/src/Core/arch/AltiVec/PacketMath.h b/Eigen/src/Core/arch/AltiVec/PacketMath.h
index b3f1ea199..2a440545b 100755
--- a/Eigen/src/Core/arch/AltiVec/PacketMath.h
+++ b/Eigen/src/Core/arch/AltiVec/PacketMath.h
@@ -22,31 +22,38 @@ namespace internal {
#define EIGEN_HAS_SINGLE_INSTRUCTION_MADD
#endif
-#ifndef EIGEN_HAS_SINGLE_INSTRUCTION_CJMADD
-#define EIGEN_HAS_SINGLE_INSTRUCTION_CJMADD
-#endif
-
// NOTE Altivec has 32 registers, but Eigen only accepts a value of 8 or 16
#ifndef EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS
#define EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS 32
#endif
-typedef __vector float Packet4f;
-typedef __vector int Packet4i;
-typedef __vector unsigned int Packet4ui;
-typedef __vector __bool int Packet4bi;
-typedef __vector short int Packet8i;
-typedef __vector unsigned char Packet16uc;
+typedef __vector float Packet4f;
+typedef __vector int Packet4i;
+typedef __vector unsigned int Packet4ui;
+typedef __vector __bool int Packet4bi;
+typedef __vector short int Packet8s;
+typedef __vector unsigned short int Packet8us;
+typedef __vector signed char Packet16c;
+typedef __vector unsigned char Packet16uc;
+typedef eigen_packet_wrapper<__vector unsigned short int,0> Packet8bf;
// We don't want to write the same code all the time, but we need to reuse the constants
// and it doesn't really work to declare them global, so we define macros instead
-
#define _EIGEN_DECLARE_CONST_FAST_Packet4f(NAME,X) \
- Packet4f p4f_##NAME = reinterpret_cast<Packet4f>(vec_splat_s32(X))
+ Packet4f p4f_##NAME = {X, X, X, X}
#define _EIGEN_DECLARE_CONST_FAST_Packet4i(NAME,X) \
Packet4i p4i_##NAME = vec_splat_s32(X)
+#define _EIGEN_DECLARE_CONST_FAST_Packet4ui(NAME,X) \
+ Packet4ui p4ui_##NAME = {X, X, X, X}
+
+#define _EIGEN_DECLARE_CONST_FAST_Packet8us(NAME,X) \
+ Packet8us p8us_##NAME = {X, X, X, X, X, X, X, X}
+
+#define _EIGEN_DECLARE_CONST_FAST_Packet16uc(NAME,X) \
+ Packet16uc p16uc_##NAME = {X, X, X, X, X, X, X, X, X, X, X, X, X, X, X, X}
+
#define _EIGEN_DECLARE_CONST_Packet4f(NAME,X) \
Packet4f p4f_##NAME = pset1<Packet4f>(X)
@@ -64,7 +71,7 @@ typedef __vector unsigned char Packet16uc;
#define DST_CHAN 1
#define DST_CTRL(size, count, stride) (((size) << 24) | ((count) << 16) | (stride))
-
+#define __UNPACK_TYPE__(PACKETNAME) typename unpacket_traits<PACKETNAME>::type
// These constants are endian-agnostic
static _EIGEN_DECLARE_CONST_FAST_Packet4f(ZERO, 0); //{ 0.0, 0.0, 0.0, 0.0}
@@ -72,25 +79,36 @@ static _EIGEN_DECLARE_CONST_FAST_Packet4i(ZERO, 0); //{ 0, 0, 0, 0,}
static _EIGEN_DECLARE_CONST_FAST_Packet4i(ONE,1); //{ 1, 1, 1, 1}
static _EIGEN_DECLARE_CONST_FAST_Packet4i(MINUS16,-16); //{ -16, -16, -16, -16}
static _EIGEN_DECLARE_CONST_FAST_Packet4i(MINUS1,-1); //{ -1, -1, -1, -1}
+static _EIGEN_DECLARE_CONST_FAST_Packet4ui(SIGN, 0x80000000u);
+static _EIGEN_DECLARE_CONST_FAST_Packet4ui(PREV0DOT5, 0x3EFFFFFFu);
+static _EIGEN_DECLARE_CONST_FAST_Packet8us(ONE,1); //{ 1, 1, 1, 1, 1, 1, 1, 1}
+static _EIGEN_DECLARE_CONST_FAST_Packet16uc(ONE,1);
static Packet4f p4f_MZERO = (Packet4f) vec_sl((Packet4ui)p4i_MINUS1, (Packet4ui)p4i_MINUS1); //{ 0x80000000, 0x80000000, 0x80000000, 0x80000000}
#ifndef __VSX__
static Packet4f p4f_ONE = vec_ctf(p4i_ONE, 0); //{ 1.0, 1.0, 1.0, 1.0}
#endif
-static Packet4f p4f_COUNTDOWN = { 0.0, 1.0, 2.0, 3.0 };
-static Packet4i p4i_COUNTDOWN = { 0, 1, 2, 3 };
+static Packet4f p4f_COUNTDOWN = { 0.0, 1.0, 2.0, 3.0 };
+static Packet4i p4i_COUNTDOWN = { 0, 1, 2, 3 };
+static Packet8s p8s_COUNTDOWN = { 0, 1, 2, 3, 4, 5, 6, 7 };
+static Packet8us p8us_COUNTDOWN = { 0, 1, 2, 3, 4, 5, 6, 7 };
+
+static Packet16c p16c_COUNTDOWN = { 0, 1, 2, 3, 4, 5, 6, 7,
+ 8, 9, 10, 11, 12, 13, 14, 15};
+static Packet16uc p16uc_COUNTDOWN = { 0, 1, 2, 3, 4, 5, 6, 7,
+ 8, 9, 10, 11, 12, 13, 14, 15};
static Packet16uc p16uc_REVERSE32 = { 12,13,14,15, 8,9,10,11, 4,5,6,7, 0,1,2,3 };
-static Packet16uc p16uc_DUPLICATE32_HI = { 0,1,2,3, 0,1,2,3, 4,5,6,7, 4,5,6,7 };
+static Packet16uc p16uc_REVERSE16 = { 14,15, 12,13, 10,11, 8,9, 6,7, 4,5, 2,3, 0,1 };
+static Packet16uc p16uc_REVERSE8 = { 15,14,13,12,11,10,9,8,7,6,5,4,3,2,1,0 };
-// Mask alignment
-#ifdef __PPC64__
-#define _EIGEN_MASK_ALIGNMENT 0xfffffffffffffff0
-#else
-#define _EIGEN_MASK_ALIGNMENT 0xfffffff0
-#endif
+static Packet16uc p16uc_DUPLICATE32_HI = { 0,1,2,3, 0,1,2,3, 4,5,6,7, 4,5,6,7 };
+static Packet16uc p16uc_DUPLICATE16_HI = { 0,1,0,1, 2,3,2,3, 4,5,4,5, 6,7,6,7 };
+static Packet16uc p16uc_DUPLICATE8_HI = { 0,0, 1,1, 2,2, 3,3, 4,4, 5,5, 6,6, 7,7 };
+static const Packet16uc p16uc_DUPLICATE16_EVEN= { 0,1 ,0,1, 4,5, 4,5, 8,9, 8,9, 12,13, 12,13 };
+static const Packet16uc p16uc_DUPLICATE16_ODD = { 2,3 ,2,3, 6,7, 6,7, 10,11, 10,11, 14,15, 14,15 };
-#define _EIGEN_ALIGNED_PTR(x) ((std::ptrdiff_t)(x) & _EIGEN_MASK_ALIGNMENT)
+static Packet16uc p16uc_QUADRUPLICATE16_HI = { 0,1,0,1,0,1,0,1, 2,3,2,3,2,3,2,3 };
// Handle endianness properly while loading constants
// Define global static constants:
@@ -103,7 +121,7 @@ static Packet16uc p16uc_PSET32_WODD = vec_sld((Packet16uc) vec_splat((Packet4u
static Packet16uc p16uc_PSET32_WEVEN = vec_sld(p16uc_DUPLICATE32_HI, (Packet16uc) vec_splat((Packet4ui)p16uc_FORWARD, 3), 8);//{ 4,5,6,7, 4,5,6,7, 12,13,14,15, 12,13,14,15 };
static Packet16uc p16uc_HALF64_0_16 = vec_sld((Packet16uc)p4i_ZERO, vec_splat((Packet16uc) vec_abs(p4i_MINUS16), 3), 8); //{ 0,0,0,0, 0,0,0,0, 16,16,16,16, 16,16,16,16};
#else
-static Packet16uc p16uc_FORWARD = p16uc_REVERSE32;
+static Packet16uc p16uc_FORWARD = p16uc_REVERSE32;
static Packet16uc p16uc_REVERSE64 = { 8,9,10,11, 12,13,14,15, 0,1,2,3, 4,5,6,7 };
static Packet16uc p16uc_PSET32_WODD = vec_sld((Packet16uc) vec_splat((Packet4ui)p16uc_FORWARD, 1), (Packet16uc) vec_splat((Packet4ui)p16uc_FORWARD, 3), 8);//{ 0,1,2,3, 0,1,2,3, 8,9,10,11, 8,9,10,11 };
static Packet16uc p16uc_PSET32_WEVEN = vec_sld((Packet16uc) vec_splat((Packet4ui)p16uc_FORWARD, 0), (Packet16uc) vec_splat((Packet4ui)p16uc_FORWARD, 2), 8);//{ 4,5,6,7, 4,5,6,7, 12,13,14,15, 12,13,14,15 };
@@ -129,27 +147,27 @@ static Packet16uc p16uc_COMPLEX32_REV2 = vec_sld(p16uc_PSET64_HI, p16uc_PSET64_L
#define EIGEN_PPC_PREFETCH(ADDR) asm( " dcbt [%[addr]]\n" :: [addr] "r" (ADDR) : "cc" );
#endif
-template<> struct packet_traits<float> : default_packet_traits
-{
+template <>
+struct packet_traits<float> : default_packet_traits {
typedef Packet4f type;
typedef Packet4f half;
enum {
Vectorizable = 1,
AlignedOnScalar = 1,
- size=4,
+ size = 4,
HasHalfPacket = 1,
- HasAdd = 1,
- HasSub = 1,
- HasMul = 1,
- HasDiv = 1,
- HasMin = 1,
- HasMax = 1,
- HasAbs = 1,
- HasSin = 0,
- HasCos = 0,
- HasLog = 0,
- HasExp = 1,
+ HasAdd = 1,
+ HasSub = 1,
+ HasMul = 1,
+ HasDiv = 1,
+ HasMin = 1,
+ HasMax = 1,
+ HasAbs = 1,
+ HasSin = EIGEN_FAST_MATH,
+ HasCos = EIGEN_FAST_MATH,
+ HasLog = 1,
+ HasExp = 1,
#ifdef __VSX__
HasSqrt = 1,
#if !EIGEN_COMP_CLANG
@@ -160,16 +178,62 @@ template<> struct packet_traits<float> : default_packet_traits
#else
HasSqrt = 0,
HasRsqrt = 0,
+ HasTanh = EIGEN_FAST_MATH,
+ HasErf = EIGEN_FAST_MATH,
#endif
HasRound = 1,
HasFloor = 1,
HasCeil = 1,
+ HasRint = 1,
HasNegate = 1,
HasBlend = 1
};
};
-template<> struct packet_traits<int> : default_packet_traits
-{
+template <>
+struct packet_traits<bfloat16> : default_packet_traits {
+ typedef Packet8bf type;
+ typedef Packet8bf half;
+ enum {
+ Vectorizable = 1,
+ AlignedOnScalar = 1,
+ size = 8,
+ HasHalfPacket = 0,
+
+ HasAdd = 1,
+ HasSub = 1,
+ HasMul = 1,
+ HasDiv = 1,
+ HasMin = 1,
+ HasMax = 1,
+ HasAbs = 1,
+ HasSin = EIGEN_FAST_MATH,
+ HasCos = EIGEN_FAST_MATH,
+ HasLog = 1,
+ HasExp = 1,
+#ifdef __VSX__
+ HasSqrt = 1,
+#if !EIGEN_COMP_CLANG
+ HasRsqrt = 1,
+#else
+ HasRsqrt = 0,
+#endif
+#else
+ HasSqrt = 0,
+ HasRsqrt = 0,
+ HasTanh = EIGEN_FAST_MATH,
+ HasErf = EIGEN_FAST_MATH,
+#endif
+ HasRound = 1,
+ HasFloor = 1,
+ HasCeil = 1,
+ HasRint = 1,
+ HasNegate = 1,
+ HasBlend = 1
+ };
+};
+
+template <>
+struct packet_traits<int> : default_packet_traits {
typedef Packet4i type;
typedef Packet4i half;
enum {
@@ -178,6 +242,25 @@ template<> struct packet_traits<int> : default_packet_traits
size = 4,
HasHalfPacket = 0,
+ HasAdd = 1,
+ HasSub = 1,
+ HasShift = 1,
+ HasMul = 1,
+ HasDiv = 0,
+ HasBlend = 1
+ };
+};
+
+template <>
+struct packet_traits<short int> : default_packet_traits {
+ typedef Packet8s type;
+ typedef Packet8s half;
+ enum {
+ Vectorizable = 1,
+ AlignedOnScalar = 1,
+ size = 8,
+ HasHalfPacket = 0,
+
HasAdd = 1,
HasSub = 1,
HasMul = 1,
@@ -186,9 +269,116 @@ template<> struct packet_traits<int> : default_packet_traits
};
};
+template <>
+struct packet_traits<unsigned short int> : default_packet_traits {
+ typedef Packet8us type;
+ typedef Packet8us half;
+ enum {
+ Vectorizable = 1,
+ AlignedOnScalar = 1,
+ size = 8,
+ HasHalfPacket = 0,
+
+ HasAdd = 1,
+ HasSub = 1,
+ HasMul = 1,
+ HasDiv = 0,
+ HasBlend = 1
+ };
+};
+
+template <>
+struct packet_traits<signed char> : default_packet_traits {
+ typedef Packet16c type;
+ typedef Packet16c half;
+ enum {
+ Vectorizable = 1,
+ AlignedOnScalar = 1,
+ size = 16,
+ HasHalfPacket = 0,
+
+ HasAdd = 1,
+ HasSub = 1,
+ HasMul = 1,
+ HasDiv = 0,
+ HasBlend = 1
+ };
+};
+
+template <>
+struct packet_traits<unsigned char> : default_packet_traits {
+ typedef Packet16uc type;
+ typedef Packet16uc half;
+ enum {
+ Vectorizable = 1,
+ AlignedOnScalar = 1,
+ size = 16,
+ HasHalfPacket = 0,
+
+ HasAdd = 1,
+ HasSub = 1,
+ HasMul = 1,
+ HasDiv = 0,
+ HasBlend = 1
+ };
+};
+
+template<> struct unpacket_traits<Packet4f>
+{
+ typedef float type;
+ typedef Packet4f half;
+ typedef Packet4i integer_packet;
+ enum {size=4, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false};
+};
+template<> struct unpacket_traits<Packet4i>
+{
+ typedef int type;
+ typedef Packet4i half;
+ enum {size=4, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false};
+};
+template<> struct unpacket_traits<Packet8s>
+{
+ typedef short int type;
+ typedef Packet8s half;
+ enum {size=8, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false};
+};
+template<> struct unpacket_traits<Packet8us>
+{
+ typedef unsigned short int type;
+ typedef Packet8us half;
+ enum {size=8, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false};
+};
+
+template<> struct unpacket_traits<Packet16c>
+{
+ typedef signed char type;
+ typedef Packet16c half;
+ enum {size=16, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false};
+};
+template<> struct unpacket_traits<Packet16uc>
+{
+ typedef unsigned char type;
+ typedef Packet16uc half;
+ enum {size=16, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false};
+};
-template<> struct unpacket_traits<Packet4f> { typedef float type; enum {size=4, alignment=Aligned16}; typedef Packet4f half; };
-template<> struct unpacket_traits<Packet4i> { typedef int type; enum {size=4, alignment=Aligned16}; typedef Packet4i half; };
+template<> struct unpacket_traits<Packet8bf>
+{
+ typedef bfloat16 type;
+ typedef Packet8bf half;
+ enum {size=8, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false};
+};
+inline std::ostream & operator <<(std::ostream & s, const Packet16c & v)
+{
+ union {
+ Packet16c v;
+ signed char n[16];
+ } vt;
+ vt.v = v;
+ for (int i=0; i< 16; i++)
+ s << vt.n[i] << ", ";
+ return s;
+}
inline std::ostream & operator <<(std::ostream & s, const Packet16uc & v)
{
@@ -198,7 +388,7 @@ inline std::ostream & operator <<(std::ostream & s, const Packet16uc & v)
} vt;
vt.v = v;
for (int i=0; i< 16; i++)
- s << (int)vt.n[i] << ", ";
+ s << vt.n[i] << ", ";
return s;
}
@@ -235,122 +425,366 @@ inline std::ostream & operator <<(std::ostream & s, const Packet4ui & v)
return s;
}
-// Need to define them first or we get specialization after instantiation errors
-template<> EIGEN_STRONG_INLINE Packet4f pload<Packet4f>(const float* from)
+template <typename Packet>
+EIGEN_STRONG_INLINE Packet pload_common(const __UNPACK_TYPE__(Packet)* from)
{
+ // some versions of GCC throw "unused-but-set-parameter".
+ // ignoring these warnings for now.
+ EIGEN_UNUSED_VARIABLE(from);
EIGEN_DEBUG_ALIGNED_LOAD
#ifdef __VSX__
- return vec_vsx_ld(0, from);
+ return vec_xl(0, const_cast<__UNPACK_TYPE__(Packet)*>(from));
#else
return vec_ld(0, from);
#endif
}
+// Need to define them first or we get specialization after instantiation errors
+template<> EIGEN_STRONG_INLINE Packet4f pload<Packet4f>(const float* from)
+{
+ return pload_common<Packet4f>(from);
+}
+
template<> EIGEN_STRONG_INLINE Packet4i pload<Packet4i>(const int* from)
{
- EIGEN_DEBUG_ALIGNED_LOAD
-#ifdef __VSX__
- return vec_vsx_ld(0, from);
-#else
- return vec_ld(0, from);
-#endif
+ return pload_common<Packet4i>(from);
}
-template<> EIGEN_STRONG_INLINE void pstore<float>(float* to, const Packet4f& from)
+template<> EIGEN_STRONG_INLINE Packet8s pload<Packet8s>(const short int* from)
+{
+ return pload_common<Packet8s>(from);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8us pload<Packet8us>(const unsigned short int* from)
{
+ return pload_common<Packet8us>(from);
+}
+
+template<> EIGEN_STRONG_INLINE Packet16c pload<Packet16c>(const signed char* from)
+{
+ return pload_common<Packet16c>(from);
+}
+
+template<> EIGEN_STRONG_INLINE Packet16uc pload<Packet16uc>(const unsigned char* from)
+{
+ return pload_common<Packet16uc>(from);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf pload<Packet8bf>(const bfloat16* from)
+{
+ return pload_common<Packet8us>(reinterpret_cast<const unsigned short int*>(from));
+}
+
+template <typename Packet>
+EIGEN_STRONG_INLINE void pstore_common(__UNPACK_TYPE__(Packet)* to, const Packet& from){
+ // some versions of GCC throw "unused-but-set-parameter" (float *to).
+ // ignoring these warnings for now.
+ EIGEN_UNUSED_VARIABLE(to);
EIGEN_DEBUG_ALIGNED_STORE
#ifdef __VSX__
- vec_vsx_st(from, 0, to);
+ vec_xst(from, 0, to);
#else
vec_st(from, 0, to);
#endif
}
+template<> EIGEN_STRONG_INLINE void pstore<float>(float* to, const Packet4f& from)
+{
+ pstore_common<Packet4f>(to, from);
+}
+
template<> EIGEN_STRONG_INLINE void pstore<int>(int* to, const Packet4i& from)
{
- EIGEN_DEBUG_ALIGNED_STORE
-#ifdef __VSX__
- vec_vsx_st(from, 0, to);
-#else
- vec_st(from, 0, to);
-#endif
+ pstore_common<Packet4i>(to, from);
}
-template<> EIGEN_STRONG_INLINE Packet4f pset1<Packet4f>(const float& from) {
- Packet4f v = {from, from, from, from};
+template<> EIGEN_STRONG_INLINE void pstore<short int>(short int* to, const Packet8s& from)
+{
+ pstore_common<Packet8s>(to, from);
+}
+
+template<> EIGEN_STRONG_INLINE void pstore<unsigned short int>(unsigned short int* to, const Packet8us& from)
+{
+ pstore_common<Packet8us>(to, from);
+}
+
+template<> EIGEN_STRONG_INLINE void pstore<bfloat16>(bfloat16* to, const Packet8bf& from)
+{
+ pstore_common<Packet8us>(reinterpret_cast<unsigned short int*>(to), from);
+}
+
+template<> EIGEN_STRONG_INLINE void pstore<signed char>(signed char* to, const Packet16c& from)
+{
+ pstore_common<Packet16c>(to, from);
+}
+
+template<> EIGEN_STRONG_INLINE void pstore<unsigned char>(unsigned char* to, const Packet16uc& from)
+{
+ pstore_common<Packet16uc>(to, from);
+}
+
+template<typename Packet>
+EIGEN_STRONG_INLINE Packet pset1_size4(const __UNPACK_TYPE__(Packet)& from)
+{
+ Packet v = {from, from, from, from};
return v;
}
-template<> EIGEN_STRONG_INLINE Packet4i pset1<Packet4i>(const int& from) {
- Packet4i v = {from, from, from, from};
+template<typename Packet>
+EIGEN_STRONG_INLINE Packet pset1_size8(const __UNPACK_TYPE__(Packet)& from)
+{
+ Packet v = {from, from, from, from, from, from, from, from};
return v;
}
-template<> EIGEN_STRONG_INLINE void
-pbroadcast4<Packet4f>(const float *a,
- Packet4f& a0, Packet4f& a1, Packet4f& a2, Packet4f& a3)
+
+template<typename Packet>
+EIGEN_STRONG_INLINE Packet pset1_size16(const __UNPACK_TYPE__(Packet)& from)
+{
+ Packet v = {from, from, from, from, from, from, from, from, from, from, from, from, from, from, from, from};
+ return v;
+}
+
+template<> EIGEN_STRONG_INLINE Packet4f pset1<Packet4f>(const float& from) {
+ return pset1_size4<Packet4f>(from);
+}
+
+template<> EIGEN_STRONG_INLINE Packet4i pset1<Packet4i>(const int& from) {
+ return pset1_size4<Packet4i>(from);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8s pset1<Packet8s>(const short int& from) {
+ return pset1_size8<Packet8s>(from);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8us pset1<Packet8us>(const unsigned short int& from) {
+ return pset1_size8<Packet8us>(from);
+}
+
+template<> EIGEN_STRONG_INLINE Packet16c pset1<Packet16c>(const signed char& from) {
+ return pset1_size16<Packet16c>(from);
+}
+
+template<> EIGEN_STRONG_INLINE Packet16uc pset1<Packet16uc>(const unsigned char& from) {
+ return pset1_size16<Packet16uc>(from);
+}
+
+template<> EIGEN_STRONG_INLINE Packet4f pset1frombits<Packet4f>(unsigned int from) {
+ return reinterpret_cast<Packet4f>(pset1<Packet4i>(from));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf pset1<Packet8bf>(const bfloat16& from) {
+ return pset1_size8<Packet8us>(reinterpret_cast<const unsigned short int&>(from));
+}
+
+template<typename Packet> EIGEN_STRONG_INLINE void
+pbroadcast4_common(const __UNPACK_TYPE__(Packet) *a,
+ Packet& a0, Packet& a1, Packet& a2, Packet& a3)
{
- a3 = pload<Packet4f>(a);
+ a3 = pload<Packet>(a);
a0 = vec_splat(a3, 0);
a1 = vec_splat(a3, 1);
a2 = vec_splat(a3, 2);
a3 = vec_splat(a3, 3);
}
+
+template<> EIGEN_STRONG_INLINE void
+pbroadcast4<Packet4f>(const float *a,
+ Packet4f& a0, Packet4f& a1, Packet4f& a2, Packet4f& a3)
+{
+ pbroadcast4_common<Packet4f>(a, a0, a1, a2, a3);
+}
template<> EIGEN_STRONG_INLINE void
pbroadcast4<Packet4i>(const int *a,
Packet4i& a0, Packet4i& a1, Packet4i& a2, Packet4i& a3)
{
- a3 = pload<Packet4i>(a);
- a0 = vec_splat(a3, 0);
- a1 = vec_splat(a3, 1);
- a2 = vec_splat(a3, 2);
- a3 = vec_splat(a3, 3);
+ pbroadcast4_common<Packet4i>(a, a0, a1, a2, a3);
+}
+
+template<typename Packet> EIGEN_DEVICE_FUNC inline Packet pgather_common(const __UNPACK_TYPE__(Packet)* from, Index stride)
+{
+ EIGEN_ALIGN16 __UNPACK_TYPE__(Packet) a[4];
+ a[0] = from[0*stride];
+ a[1] = from[1*stride];
+ a[2] = from[2*stride];
+ a[3] = from[3*stride];
+ return pload<Packet>(a);
}
template<> EIGEN_DEVICE_FUNC inline Packet4f pgather<float, Packet4f>(const float* from, Index stride)
{
- float EIGEN_ALIGN16 af[4];
- af[0] = from[0*stride];
- af[1] = from[1*stride];
- af[2] = from[2*stride];
- af[3] = from[3*stride];
- return pload<Packet4f>(af);
+ return pgather_common<Packet4f>(from, stride);
}
+
template<> EIGEN_DEVICE_FUNC inline Packet4i pgather<int, Packet4i>(const int* from, Index stride)
{
- int EIGEN_ALIGN16 ai[4];
- ai[0] = from[0*stride];
- ai[1] = from[1*stride];
- ai[2] = from[2*stride];
- ai[3] = from[3*stride];
- return pload<Packet4i>(ai);
+ return pgather_common<Packet4i>(from, stride);
+}
+
+template<typename Packet> EIGEN_DEVICE_FUNC inline Packet pgather_size8(const __UNPACK_TYPE__(Packet)* from, Index stride)
+{
+ EIGEN_ALIGN16 __UNPACK_TYPE__(Packet) a[8];
+ a[0] = from[0*stride];
+ a[1] = from[1*stride];
+ a[2] = from[2*stride];
+ a[3] = from[3*stride];
+ a[4] = from[4*stride];
+ a[5] = from[5*stride];
+ a[6] = from[6*stride];
+ a[7] = from[7*stride];
+ return pload<Packet>(a);
+}
+
+template<> EIGEN_DEVICE_FUNC inline Packet8s pgather<short int, Packet8s>(const short int* from, Index stride)
+{
+ return pgather_size8<Packet8s>(from, stride);
+}
+
+template<> EIGEN_DEVICE_FUNC inline Packet8us pgather<unsigned short int, Packet8us>(const unsigned short int* from, Index stride)
+{
+ return pgather_size8<Packet8us>(from, stride);
+}
+
+template<> EIGEN_DEVICE_FUNC inline Packet8bf pgather<bfloat16, Packet8bf>(const bfloat16* from, Index stride)
+{
+ return pgather_size8<Packet8bf>(from, stride);
}
+
+template<typename Packet> EIGEN_DEVICE_FUNC inline Packet pgather_size16(const __UNPACK_TYPE__(Packet)* from, Index stride)
+{
+ EIGEN_ALIGN16 __UNPACK_TYPE__(Packet) a[16];
+ a[0] = from[0*stride];
+ a[1] = from[1*stride];
+ a[2] = from[2*stride];
+ a[3] = from[3*stride];
+ a[4] = from[4*stride];
+ a[5] = from[5*stride];
+ a[6] = from[6*stride];
+ a[7] = from[7*stride];
+ a[8] = from[8*stride];
+ a[9] = from[9*stride];
+ a[10] = from[10*stride];
+ a[11] = from[11*stride];
+ a[12] = from[12*stride];
+ a[13] = from[13*stride];
+ a[14] = from[14*stride];
+ a[15] = from[15*stride];
+ return pload<Packet>(a);
+}
+
+
+template<> EIGEN_DEVICE_FUNC inline Packet16c pgather<signed char, Packet16c>(const signed char* from, Index stride)
+{
+ return pgather_size16<Packet16c>(from, stride);
+}
+
+template<> EIGEN_DEVICE_FUNC inline Packet16uc pgather<unsigned char, Packet16uc>(const unsigned char* from, Index stride)
+{
+ return pgather_size16<Packet16uc>(from, stride);
+}
+
+template<typename Packet> EIGEN_DEVICE_FUNC inline void pscatter_size4(__UNPACK_TYPE__(Packet)* to, const Packet& from, Index stride)
+{
+ EIGEN_ALIGN16 __UNPACK_TYPE__(Packet) a[4];
+ pstore<__UNPACK_TYPE__(Packet)>(a, from);
+ to[0*stride] = a[0];
+ to[1*stride] = a[1];
+ to[2*stride] = a[2];
+ to[3*stride] = a[3];
+}
+
template<> EIGEN_DEVICE_FUNC inline void pscatter<float, Packet4f>(float* to, const Packet4f& from, Index stride)
{
- float EIGEN_ALIGN16 af[4];
- pstore<float>(af, from);
- to[0*stride] = af[0];
- to[1*stride] = af[1];
- to[2*stride] = af[2];
- to[3*stride] = af[3];
+ pscatter_size4<Packet4f>(to, from, stride);
}
+
template<> EIGEN_DEVICE_FUNC inline void pscatter<int, Packet4i>(int* to, const Packet4i& from, Index stride)
{
- int EIGEN_ALIGN16 ai[4];
- pstore<int>((int *)ai, from);
- to[0*stride] = ai[0];
- to[1*stride] = ai[1];
- to[2*stride] = ai[2];
- to[3*stride] = ai[3];
+ pscatter_size4<Packet4i>(to, from, stride);
+}
+
+template<typename Packet> EIGEN_DEVICE_FUNC inline void pscatter_size8(__UNPACK_TYPE__(Packet)* to, const Packet& from, Index stride)
+{
+ EIGEN_ALIGN16 __UNPACK_TYPE__(Packet) a[8];
+ pstore<__UNPACK_TYPE__(Packet)>(a, from);
+ to[0*stride] = a[0];
+ to[1*stride] = a[1];
+ to[2*stride] = a[2];
+ to[3*stride] = a[3];
+ to[4*stride] = a[4];
+ to[5*stride] = a[5];
+ to[6*stride] = a[6];
+ to[7*stride] = a[7];
}
-template<> EIGEN_STRONG_INLINE Packet4f plset<Packet4f>(const float& a) { return pset1<Packet4f>(a) + p4f_COUNTDOWN; }
-template<> EIGEN_STRONG_INLINE Packet4i plset<Packet4i>(const int& a) { return pset1<Packet4i>(a) + p4i_COUNTDOWN; }
-template<> EIGEN_STRONG_INLINE Packet4f padd<Packet4f>(const Packet4f& a, const Packet4f& b) { return a + b; }
-template<> EIGEN_STRONG_INLINE Packet4i padd<Packet4i>(const Packet4i& a, const Packet4i& b) { return a + b; }
+template<> EIGEN_DEVICE_FUNC inline void pscatter<short int, Packet8s>(short int* to, const Packet8s& from, Index stride)
+{
+ pscatter_size8<Packet8s>(to, from, stride);
+}
-template<> EIGEN_STRONG_INLINE Packet4f psub<Packet4f>(const Packet4f& a, const Packet4f& b) { return a - b; }
-template<> EIGEN_STRONG_INLINE Packet4i psub<Packet4i>(const Packet4i& a, const Packet4i& b) { return a - b; }
+template<> EIGEN_DEVICE_FUNC inline void pscatter<unsigned short int, Packet8us>(unsigned short int* to, const Packet8us& from, Index stride)
+{
+ pscatter_size8<Packet8us>(to, from, stride);
+}
+
+template<> EIGEN_DEVICE_FUNC inline void pscatter<bfloat16, Packet8bf>(bfloat16* to, const Packet8bf& from, Index stride)
+{
+ pscatter_size8<Packet8bf>(to, from, stride);
+}
+
+template<typename Packet> EIGEN_DEVICE_FUNC inline void pscatter_size16(__UNPACK_TYPE__(Packet)* to, const Packet& from, Index stride)
+{
+ EIGEN_ALIGN16 __UNPACK_TYPE__(Packet) a[16];
+ pstore<__UNPACK_TYPE__(Packet)>(a, from);
+ to[0*stride] = a[0];
+ to[1*stride] = a[1];
+ to[2*stride] = a[2];
+ to[3*stride] = a[3];
+ to[4*stride] = a[4];
+ to[5*stride] = a[5];
+ to[6*stride] = a[6];
+ to[7*stride] = a[7];
+ to[8*stride] = a[8];
+ to[9*stride] = a[9];
+ to[10*stride] = a[10];
+ to[11*stride] = a[11];
+ to[12*stride] = a[12];
+ to[13*stride] = a[13];
+ to[14*stride] = a[14];
+ to[15*stride] = a[15];
+}
+
+template<> EIGEN_DEVICE_FUNC inline void pscatter<signed char, Packet16c>(signed char* to, const Packet16c& from, Index stride)
+{
+ pscatter_size16<Packet16c>(to, from, stride);
+}
+
+template<> EIGEN_DEVICE_FUNC inline void pscatter<unsigned char, Packet16uc>(unsigned char* to, const Packet16uc& from, Index stride)
+{
+ pscatter_size16<Packet16uc>(to, from, stride);
+}
+
+template<> EIGEN_STRONG_INLINE Packet4f plset<Packet4f>(const float& a) { return pset1<Packet4f>(a) + p4f_COUNTDOWN; }
+template<> EIGEN_STRONG_INLINE Packet4i plset<Packet4i>(const int& a) { return pset1<Packet4i>(a) + p4i_COUNTDOWN; }
+template<> EIGEN_STRONG_INLINE Packet8s plset<Packet8s>(const short int& a) { return pset1<Packet8s>(a) + p8s_COUNTDOWN; }
+template<> EIGEN_STRONG_INLINE Packet8us plset<Packet8us>(const unsigned short int& a) { return pset1<Packet8us>(a) + p8us_COUNTDOWN; }
+template<> EIGEN_STRONG_INLINE Packet16c plset<Packet16c>(const signed char& a) { return pset1<Packet16c>(a) + p16c_COUNTDOWN; }
+template<> EIGEN_STRONG_INLINE Packet16uc plset<Packet16uc>(const unsigned char& a) { return pset1<Packet16uc>(a) + p16uc_COUNTDOWN; }
+
+template<> EIGEN_STRONG_INLINE Packet4f padd<Packet4f> (const Packet4f& a, const Packet4f& b) { return a + b; }
+template<> EIGEN_STRONG_INLINE Packet4i padd<Packet4i> (const Packet4i& a, const Packet4i& b) { return a + b; }
+template<> EIGEN_STRONG_INLINE Packet4ui padd<Packet4ui> (const Packet4ui& a, const Packet4ui& b) { return a + b; }
+template<> EIGEN_STRONG_INLINE Packet8s padd<Packet8s> (const Packet8s& a, const Packet8s& b) { return a + b; }
+template<> EIGEN_STRONG_INLINE Packet8us padd<Packet8us> (const Packet8us& a, const Packet8us& b) { return a + b; }
+template<> EIGEN_STRONG_INLINE Packet16c padd<Packet16c> (const Packet16c& a, const Packet16c& b) { return a + b; }
+template<> EIGEN_STRONG_INLINE Packet16uc padd<Packet16uc>(const Packet16uc& a, const Packet16uc& b) { return a + b; }
+
+template<> EIGEN_STRONG_INLINE Packet4f psub<Packet4f> (const Packet4f& a, const Packet4f& b) { return a - b; }
+template<> EIGEN_STRONG_INLINE Packet4i psub<Packet4i> (const Packet4i& a, const Packet4i& b) { return a - b; }
+template<> EIGEN_STRONG_INLINE Packet8s psub<Packet8s> (const Packet8s& a, const Packet8s& b) { return a - b; }
+template<> EIGEN_STRONG_INLINE Packet8us psub<Packet8us> (const Packet8us& a, const Packet8us& b) { return a - b; }
+template<> EIGEN_STRONG_INLINE Packet16c psub<Packet16c> (const Packet16c& a, const Packet16c& b) { return a - b; }
+template<> EIGEN_STRONG_INLINE Packet16uc psub<Packet16uc>(const Packet16uc& a, const Packet16uc& b) { return a - b; }
template<> EIGEN_STRONG_INLINE Packet4f pnegate(const Packet4f& a) { return p4f_ZERO - a; }
template<> EIGEN_STRONG_INLINE Packet4i pnegate(const Packet4i& a) { return p4i_ZERO - a; }
@@ -358,8 +792,13 @@ template<> EIGEN_STRONG_INLINE Packet4i pnegate(const Packet4i& a) { return p4i_
template<> EIGEN_STRONG_INLINE Packet4f pconj(const Packet4f& a) { return a; }
template<> EIGEN_STRONG_INLINE Packet4i pconj(const Packet4i& a) { return a; }
-template<> EIGEN_STRONG_INLINE Packet4f pmul<Packet4f>(const Packet4f& a, const Packet4f& b) { return vec_madd(a,b, p4f_MZERO); }
-template<> EIGEN_STRONG_INLINE Packet4i pmul<Packet4i>(const Packet4i& a, const Packet4i& b) { return a * b; }
+template<> EIGEN_STRONG_INLINE Packet4f pmul<Packet4f> (const Packet4f& a, const Packet4f& b) { return vec_madd(a,b, p4f_MZERO); }
+template<> EIGEN_STRONG_INLINE Packet4i pmul<Packet4i> (const Packet4i& a, const Packet4i& b) { return a * b; }
+template<> EIGEN_STRONG_INLINE Packet8s pmul<Packet8s> (const Packet8s& a, const Packet8s& b) { return vec_mul(a,b); }
+template<> EIGEN_STRONG_INLINE Packet8us pmul<Packet8us> (const Packet8us& a, const Packet8us& b) { return vec_mul(a,b); }
+template<> EIGEN_STRONG_INLINE Packet16c pmul<Packet16c> (const Packet16c& a, const Packet16c& b) { return vec_mul(a,b); }
+template<> EIGEN_STRONG_INLINE Packet16uc pmul<Packet16uc>(const Packet16uc& a, const Packet16uc& b) { return vec_mul(a,b); }
+
template<> EIGEN_STRONG_INLINE Packet4f pdiv<Packet4f>(const Packet4f& a, const Packet4f& b)
{
@@ -387,85 +826,247 @@ template<> EIGEN_STRONG_INLINE Packet4i pdiv<Packet4i>(const Packet4i& /*a*/, co
// for some weird raisons, it has to be overloaded for packet of integers
template<> EIGEN_STRONG_INLINE Packet4f pmadd(const Packet4f& a, const Packet4f& b, const Packet4f& c) { return vec_madd(a,b,c); }
template<> EIGEN_STRONG_INLINE Packet4i pmadd(const Packet4i& a, const Packet4i& b, const Packet4i& c) { return a*b + c; }
+template<> EIGEN_STRONG_INLINE Packet8s pmadd(const Packet8s& a, const Packet8s& b, const Packet8s& c) { return vec_madd(a,b,c); }
+template<> EIGEN_STRONG_INLINE Packet8us pmadd(const Packet8us& a, const Packet8us& b, const Packet8us& c) { return vec_madd(a,b,c); }
-template<> EIGEN_STRONG_INLINE Packet4f pmin<Packet4f>(const Packet4f& a, const Packet4f& b) { return vec_min(a, b); }
+template<> EIGEN_STRONG_INLINE Packet4f pmin<Packet4f>(const Packet4f& a, const Packet4f& b)
+{
+ #ifdef __VSX__
+ // NOTE: about 10% slower than vec_min, but consistent with std::min and SSE regarding NaN
+ Packet4f ret;
+ __asm__ ("xvcmpgesp %x0,%x1,%x2\n\txxsel %x0,%x1,%x2,%x0" : "=&wa" (ret) : "wa" (a), "wa" (b));
+ return ret;
+ #else
+ return vec_min(a, b);
+ #endif
+}
template<> EIGEN_STRONG_INLINE Packet4i pmin<Packet4i>(const Packet4i& a, const Packet4i& b) { return vec_min(a, b); }
+template<> EIGEN_STRONG_INLINE Packet8s pmin<Packet8s>(const Packet8s& a, const Packet8s& b) { return vec_min(a, b); }
+template<> EIGEN_STRONG_INLINE Packet8us pmin<Packet8us>(const Packet8us& a, const Packet8us& b) { return vec_min(a, b); }
+template<> EIGEN_STRONG_INLINE Packet16c pmin<Packet16c>(const Packet16c& a, const Packet16c& b) { return vec_min(a, b); }
+template<> EIGEN_STRONG_INLINE Packet16uc pmin<Packet16uc>(const Packet16uc& a, const Packet16uc& b) { return vec_min(a, b); }
+
-template<> EIGEN_STRONG_INLINE Packet4f pmax<Packet4f>(const Packet4f& a, const Packet4f& b) { return vec_max(a, b); }
+template<> EIGEN_STRONG_INLINE Packet4f pmax<Packet4f>(const Packet4f& a, const Packet4f& b)
+{
+ #ifdef __VSX__
+ // NOTE: about 10% slower than vec_max, but consistent with std::max and SSE regarding NaN
+ Packet4f ret;
+ __asm__ ("xvcmpgtsp %x0,%x2,%x1\n\txxsel %x0,%x1,%x2,%x0" : "=&wa" (ret) : "wa" (a), "wa" (b));
+ return ret;
+ #else
+ return vec_max(a, b);
+ #endif
+}
template<> EIGEN_STRONG_INLINE Packet4i pmax<Packet4i>(const Packet4i& a, const Packet4i& b) { return vec_max(a, b); }
+template<> EIGEN_STRONG_INLINE Packet8s pmax<Packet8s>(const Packet8s& a, const Packet8s& b) { return vec_max(a, b); }
+template<> EIGEN_STRONG_INLINE Packet8us pmax<Packet8us>(const Packet8us& a, const Packet8us& b) { return vec_max(a, b); }
+template<> EIGEN_STRONG_INLINE Packet16c pmax<Packet16c>(const Packet16c& a, const Packet16c& b) { return vec_max(a, b); }
+template<> EIGEN_STRONG_INLINE Packet16uc pmax<Packet16uc>(const Packet16uc& a, const Packet16uc& b) { return vec_max(a, b); }
+
+template<> EIGEN_STRONG_INLINE Packet4f pcmp_le(const Packet4f& a, const Packet4f& b) { return reinterpret_cast<Packet4f>(vec_cmple(a,b)); }
+template<> EIGEN_STRONG_INLINE Packet4f pcmp_lt(const Packet4f& a, const Packet4f& b) { return reinterpret_cast<Packet4f>(vec_cmplt(a,b)); }
+template<> EIGEN_STRONG_INLINE Packet4f pcmp_eq(const Packet4f& a, const Packet4f& b) { return reinterpret_cast<Packet4f>(vec_cmpeq(a,b)); }
+template<> EIGEN_STRONG_INLINE Packet4f pcmp_lt_or_nan(const Packet4f& a, const Packet4f& b) {
+ Packet4f c = reinterpret_cast<Packet4f>(vec_cmpge(a,b));
+ return vec_nor(c,c);
+}
+
+template<> EIGEN_STRONG_INLINE Packet4i pcmp_le(const Packet4i& a, const Packet4i& b) { return reinterpret_cast<Packet4i>(vec_cmple(a,b)); }
+template<> EIGEN_STRONG_INLINE Packet4i pcmp_lt(const Packet4i& a, const Packet4i& b) { return reinterpret_cast<Packet4i>(vec_cmplt(a,b)); }
+template<> EIGEN_STRONG_INLINE Packet4i pcmp_eq(const Packet4i& a, const Packet4i& b) { return reinterpret_cast<Packet4i>(vec_cmpeq(a,b)); }
+template<> EIGEN_STRONG_INLINE Packet8s pcmp_le(const Packet8s& a, const Packet8s& b) { return reinterpret_cast<Packet8s>(vec_cmple(a,b)); }
+template<> EIGEN_STRONG_INLINE Packet8s pcmp_lt(const Packet8s& a, const Packet8s& b) { return reinterpret_cast<Packet8s>(vec_cmplt(a,b)); }
+template<> EIGEN_STRONG_INLINE Packet8s pcmp_eq(const Packet8s& a, const Packet8s& b) { return reinterpret_cast<Packet8s>(vec_cmpeq(a,b)); }
+template<> EIGEN_STRONG_INLINE Packet8us pcmp_le(const Packet8us& a, const Packet8us& b) { return reinterpret_cast<Packet8us>(vec_cmple(a,b)); }
+template<> EIGEN_STRONG_INLINE Packet8us pcmp_lt(const Packet8us& a, const Packet8us& b) { return reinterpret_cast<Packet8us>(vec_cmplt(a,b)); }
+template<> EIGEN_STRONG_INLINE Packet8us pcmp_eq(const Packet8us& a, const Packet8us& b) { return reinterpret_cast<Packet8us>(vec_cmpeq(a,b)); }
+template<> EIGEN_STRONG_INLINE Packet16c pcmp_le(const Packet16c& a, const Packet16c& b) { return reinterpret_cast<Packet16c>(vec_cmple(a,b)); }
+template<> EIGEN_STRONG_INLINE Packet16c pcmp_lt(const Packet16c& a, const Packet16c& b) { return reinterpret_cast<Packet16c>(vec_cmplt(a,b)); }
+template<> EIGEN_STRONG_INLINE Packet16c pcmp_eq(const Packet16c& a, const Packet16c& b) { return reinterpret_cast<Packet16c>(vec_cmpeq(a,b)); }
+template<> EIGEN_STRONG_INLINE Packet16uc pcmp_le(const Packet16uc& a, const Packet16uc& b) { return reinterpret_cast<Packet16uc>(vec_cmple(a,b)); }
+template<> EIGEN_STRONG_INLINE Packet16uc pcmp_lt(const Packet16uc& a, const Packet16uc& b) { return reinterpret_cast<Packet16uc>(vec_cmplt(a,b)); }
+template<> EIGEN_STRONG_INLINE Packet16uc pcmp_eq(const Packet16uc& a, const Packet16uc& b) { return reinterpret_cast<Packet16uc>(vec_cmpeq(a,b)); }
template<> EIGEN_STRONG_INLINE Packet4f pand<Packet4f>(const Packet4f& a, const Packet4f& b) { return vec_and(a, b); }
template<> EIGEN_STRONG_INLINE Packet4i pand<Packet4i>(const Packet4i& a, const Packet4i& b) { return vec_and(a, b); }
+template<> EIGEN_STRONG_INLINE Packet4ui pand<Packet4ui>(const Packet4ui& a, const Packet4ui& b) { return vec_and(a, b); }
+template<> EIGEN_STRONG_INLINE Packet8us pand<Packet8us>(const Packet8us& a, const Packet8us& b) { return vec_and(a, b); }
+template<> EIGEN_STRONG_INLINE Packet8bf pand<Packet8bf>(const Packet8bf& a, const Packet8bf& b) {
+ return pand<Packet8us>(a, b);
+}
+
template<> EIGEN_STRONG_INLINE Packet4f por<Packet4f>(const Packet4f& a, const Packet4f& b) { return vec_or(a, b); }
template<> EIGEN_STRONG_INLINE Packet4i por<Packet4i>(const Packet4i& a, const Packet4i& b) { return vec_or(a, b); }
+template<> EIGEN_STRONG_INLINE Packet8s por<Packet8s>(const Packet8s& a, const Packet8s& b) { return vec_or(a, b); }
+template<> EIGEN_STRONG_INLINE Packet8us por<Packet8us>(const Packet8us& a, const Packet8us& b) { return vec_or(a, b); }
+template<> EIGEN_STRONG_INLINE Packet8bf por<Packet8bf>(const Packet8bf& a, const Packet8bf& b) {
+ return por<Packet8us>(a, b);
+}
template<> EIGEN_STRONG_INLINE Packet4f pxor<Packet4f>(const Packet4f& a, const Packet4f& b) { return vec_xor(a, b); }
template<> EIGEN_STRONG_INLINE Packet4i pxor<Packet4i>(const Packet4i& a, const Packet4i& b) { return vec_xor(a, b); }
+template<> EIGEN_STRONG_INLINE Packet8bf pxor<Packet8bf>(const Packet8bf& a, const Packet8bf& b) {
+ return pxor<Packet8us>(a, b);
+}
-template<> EIGEN_STRONG_INLINE Packet4f pandnot<Packet4f>(const Packet4f& a, const Packet4f& b) { return vec_and(a, vec_nor(b, b)); }
-template<> EIGEN_STRONG_INLINE Packet4i pandnot<Packet4i>(const Packet4i& a, const Packet4i& b) { return vec_and(a, vec_nor(b, b)); }
+template<> EIGEN_STRONG_INLINE Packet4f pandnot<Packet4f>(const Packet4f& a, const Packet4f& b) { return vec_andc(a, b); }
+template<> EIGEN_STRONG_INLINE Packet4i pandnot<Packet4i>(const Packet4i& a, const Packet4i& b) { return vec_andc(a, b); }
-template<> EIGEN_STRONG_INLINE Packet4f pround<Packet4f>(const Packet4f& a) { return vec_round(a); }
+template<> EIGEN_STRONG_INLINE Packet4f pselect(const Packet4f& mask, const Packet4f& a, const Packet4f& b) {
+ return vec_sel(b, a, reinterpret_cast<Packet4ui>(mask));
+}
+
+template<> EIGEN_STRONG_INLINE Packet4f pround<Packet4f>(const Packet4f& a)
+{
+ Packet4f t = vec_add(reinterpret_cast<Packet4f>(vec_or(vec_and(reinterpret_cast<Packet4ui>(a), p4ui_SIGN), p4ui_PREV0DOT5)), a);
+ Packet4f res;
+
+#ifdef __VSX__
+ __asm__("xvrspiz %x0, %x1\n\t"
+ : "=&wa" (res)
+ : "wa" (t));
+#else
+ __asm__("vrfiz %0, %1\n\t"
+ : "=v" (res)
+ : "v" (t));
+#endif
+
+ return res;
+}
template<> EIGEN_STRONG_INLINE Packet4f pceil<Packet4f>(const Packet4f& a) { return vec_ceil(a); }
template<> EIGEN_STRONG_INLINE Packet4f pfloor<Packet4f>(const Packet4f& a) { return vec_floor(a); }
+template<> EIGEN_STRONG_INLINE Packet4f print<Packet4f>(const Packet4f& a)
+{
+ Packet4f res;
-#ifdef _BIG_ENDIAN
-template<> EIGEN_STRONG_INLINE Packet4f ploadu<Packet4f>(const float* from)
+ __asm__("xvrspic %x0, %x1\n\t"
+ : "=&wa" (res)
+ : "wa" (a));
+
+ return res;
+}
+
+template<typename Packet> EIGEN_STRONG_INLINE Packet ploadu_common(const __UNPACK_TYPE__(Packet)* from)
{
EIGEN_DEBUG_ALIGNED_LOAD
+#ifdef _BIG_ENDIAN
Packet16uc MSQ, LSQ;
Packet16uc mask;
MSQ = vec_ld(0, (unsigned char *)from); // most significant quadword
LSQ = vec_ld(15, (unsigned char *)from); // least significant quadword
mask = vec_lvsl(0, from); // create the permute mask
- return (Packet4f) vec_perm(MSQ, LSQ, mask); // align the data
+ //TODO: Add static_cast here
+ return (Packet) vec_perm(MSQ, LSQ, mask); // align the data
+#else
+ EIGEN_DEBUG_UNALIGNED_LOAD
+ return vec_xl(0, const_cast<__UNPACK_TYPE__(Packet)*>(from));
+#endif
+}
+template<> EIGEN_STRONG_INLINE Packet4f ploadu<Packet4f>(const float* from)
+{
+ return ploadu_common<Packet4f>(from);
}
template<> EIGEN_STRONG_INLINE Packet4i ploadu<Packet4i>(const int* from)
{
- EIGEN_DEBUG_ALIGNED_LOAD
- // Taken from http://developer.apple.com/hardwaredrivers/ve/alignment.html
- Packet16uc MSQ, LSQ;
- Packet16uc mask;
- MSQ = vec_ld(0, (unsigned char *)from); // most significant quadword
- LSQ = vec_ld(15, (unsigned char *)from); // least significant quadword
- mask = vec_lvsl(0, from); // create the permute mask
- return (Packet4i) vec_perm(MSQ, LSQ, mask); // align the data
+ return ploadu_common<Packet4i>(from);
}
-#else
-// We also need ot redefine little endian loading of Packet4i/Packet4f using VSX
-template<> EIGEN_STRONG_INLINE Packet4i ploadu<Packet4i>(const int* from)
+template<> EIGEN_STRONG_INLINE Packet8s ploadu<Packet8s>(const short int* from)
{
- EIGEN_DEBUG_UNALIGNED_LOAD
- return (Packet4i) vec_vsx_ld((long)from & 15, (const int*) _EIGEN_ALIGNED_PTR(from));
+ return ploadu_common<Packet8s>(from);
}
-template<> EIGEN_STRONG_INLINE Packet4f ploadu<Packet4f>(const float* from)
+template<> EIGEN_STRONG_INLINE Packet8us ploadu<Packet8us>(const unsigned short int* from)
{
- EIGEN_DEBUG_UNALIGNED_LOAD
- return (Packet4f) vec_vsx_ld((long)from & 15, (const float*) _EIGEN_ALIGNED_PTR(from));
+ return ploadu_common<Packet8us>(from);
+}
+template<> EIGEN_STRONG_INLINE Packet8bf ploadu<Packet8bf>(const bfloat16* from)
+{
+ return ploadu_common<Packet8us>(reinterpret_cast<const unsigned short int*>(from));
+}
+template<> EIGEN_STRONG_INLINE Packet16c ploadu<Packet16c>(const signed char* from)
+{
+ return ploadu_common<Packet16c>(from);
+}
+template<> EIGEN_STRONG_INLINE Packet16uc ploadu<Packet16uc>(const unsigned char* from)
+{
+ return ploadu_common<Packet16uc>(from);
}
-#endif
-template<> EIGEN_STRONG_INLINE Packet4f ploaddup<Packet4f>(const float* from)
+template<typename Packet> EIGEN_STRONG_INLINE Packet ploaddup_common(const __UNPACK_TYPE__(Packet)* from)
{
- Packet4f p;
- if((std::ptrdiff_t(from) % 16) == 0) p = pload<Packet4f>(from);
- else p = ploadu<Packet4f>(from);
+ Packet p;
+ if((std::ptrdiff_t(from) % 16) == 0) p = pload<Packet>(from);
+ else p = ploadu<Packet>(from);
return vec_perm(p, p, p16uc_DUPLICATE32_HI);
}
+template<> EIGEN_STRONG_INLINE Packet4f ploaddup<Packet4f>(const float* from)
+{
+ return ploaddup_common<Packet4f>(from);
+}
template<> EIGEN_STRONG_INLINE Packet4i ploaddup<Packet4i>(const int* from)
{
- Packet4i p;
- if((std::ptrdiff_t(from) % 16) == 0) p = pload<Packet4i>(from);
- else p = ploadu<Packet4i>(from);
- return vec_perm(p, p, p16uc_DUPLICATE32_HI);
+ return ploaddup_common<Packet4i>(from);
}
-#ifdef _BIG_ENDIAN
-template<> EIGEN_STRONG_INLINE void pstoreu<float>(float* to, const Packet4f& from)
+template<> EIGEN_STRONG_INLINE Packet8s ploaddup<Packet8s>(const short int* from)
+{
+ Packet8s p;
+ if((std::ptrdiff_t(from) % 16) == 0) p = pload<Packet8s>(from);
+ else p = ploadu<Packet8s>(from);
+ return vec_perm(p, p, p16uc_DUPLICATE16_HI);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8us ploaddup<Packet8us>(const unsigned short int* from)
+{
+ Packet8us p;
+ if((std::ptrdiff_t(from) % 16) == 0) p = pload<Packet8us>(from);
+ else p = ploadu<Packet8us>(from);
+ return vec_perm(p, p, p16uc_DUPLICATE16_HI);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8s ploadquad<Packet8s>(const short int* from)
+{
+ Packet8s p;
+ if((std::ptrdiff_t(from) % 16) == 0) p = pload<Packet8s>(from);
+ else p = ploadu<Packet8s>(from);
+ return vec_perm(p, p, p16uc_QUADRUPLICATE16_HI);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8us ploadquad<Packet8us>(const unsigned short int* from)
+{
+ Packet8us p;
+ if((std::ptrdiff_t(from) % 16) == 0) p = pload<Packet8us>(from);
+ else p = ploadu<Packet8us>(from);
+ return vec_perm(p, p, p16uc_QUADRUPLICATE16_HI);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf ploadquad<Packet8bf>(const bfloat16* from)
+{
+ return ploadquad<Packet8us>(reinterpret_cast<const unsigned short int*>(from));
+}
+
+template<> EIGEN_STRONG_INLINE Packet16c ploaddup<Packet16c>(const signed char* from)
+{
+ Packet16c p;
+ if((std::ptrdiff_t(from) % 16) == 0) p = pload<Packet16c>(from);
+ else p = ploadu<Packet16c>(from);
+ return vec_perm(p, p, p16uc_DUPLICATE8_HI);
+}
+
+template<> EIGEN_STRONG_INLINE Packet16uc ploaddup<Packet16uc>(const unsigned char* from)
+{
+ Packet16uc p;
+ if((std::ptrdiff_t(from) % 16) == 0) p = pload<Packet16uc>(from);
+ else p = ploadu<Packet16uc>(from);
+ return vec_perm(p, p, p16uc_DUPLICATE8_HI);
+}
+
+template<typename Packet> EIGEN_STRONG_INLINE void pstoreu_common(__UNPACK_TYPE__(Packet)* to, const Packet& from)
{
EIGEN_DEBUG_UNALIGNED_STORE
+#ifdef _BIG_ENDIAN
// Taken from http://developer.apple.com/hardwaredrivers/ve/alignment.html
// Warning: not thread safe!
Packet16uc MSQ, LSQ, edges;
@@ -479,45 +1080,69 @@ template<> EIGEN_STRONG_INLINE void pstoreu<float>(float* to, const Packet4f& f
MSQ = vec_perm(edges,(Packet16uc)from,align); // misalign the data (MSQ)
LSQ = vec_perm((Packet16uc)from,edges,align); // misalign the data (LSQ)
vec_st( LSQ, 15, (unsigned char *)to ); // Store the LSQ part first
- vec_st( MSQ, 0, (unsigned char *)to ); // Store the MSQ part
+ vec_st( MSQ, 0, (unsigned char *)to ); // Store the MSQ part second
+#else
+ vec_xst(from, 0, to);
+#endif
+}
+template<> EIGEN_STRONG_INLINE void pstoreu<float>(float* to, const Packet4f& from)
+{
+ pstoreu_common<Packet4f>(to, from);
}
template<> EIGEN_STRONG_INLINE void pstoreu<int>(int* to, const Packet4i& from)
{
- EIGEN_DEBUG_UNALIGNED_STORE
- // Taken from http://developer.apple.com/hardwaredrivers/ve/alignment.html
- // Warning: not thread safe!
- Packet16uc MSQ, LSQ, edges;
- Packet16uc edgeAlign, align;
-
- MSQ = vec_ld(0, (unsigned char *)to); // most significant quadword
- LSQ = vec_ld(15, (unsigned char *)to); // least significant quadword
- edgeAlign = vec_lvsl(0, to); // permute map to extract edges
- edges=vec_perm(LSQ, MSQ, edgeAlign); // extract the edges
- align = vec_lvsr( 0, to ); // permute map to misalign data
- MSQ = vec_perm(edges, (Packet16uc) from, align); // misalign the data (MSQ)
- LSQ = vec_perm((Packet16uc) from, edges, align); // misalign the data (LSQ)
- vec_st( LSQ, 15, (unsigned char *)to ); // Store the LSQ part first
- vec_st( MSQ, 0, (unsigned char *)to ); // Store the MSQ part
+ pstoreu_common<Packet4i>(to, from);
}
-#else
-// We also need ot redefine little endian loading of Packet4i/Packet4f using VSX
-template<> EIGEN_STRONG_INLINE void pstoreu<int>(int* to, const Packet4i& from)
+template<> EIGEN_STRONG_INLINE void pstoreu<short int>(short int* to, const Packet8s& from)
{
- EIGEN_DEBUG_ALIGNED_STORE
- vec_vsx_st(from, (long)to & 15, (int*) _EIGEN_ALIGNED_PTR(to));
+ pstoreu_common<Packet8s>(to, from);
}
-template<> EIGEN_STRONG_INLINE void pstoreu<float>(float* to, const Packet4f& from)
+template<> EIGEN_STRONG_INLINE void pstoreu<unsigned short int>(unsigned short int* to, const Packet8us& from)
{
- EIGEN_DEBUG_ALIGNED_STORE
- vec_vsx_st(from, (long)to & 15, (float*) _EIGEN_ALIGNED_PTR(to));
+ pstoreu_common<Packet8us>(to, from);
+}
+template<> EIGEN_STRONG_INLINE void pstoreu<bfloat16>(bfloat16* to, const Packet8bf& from)
+{
+ pstoreu_common<Packet8us>(reinterpret_cast<unsigned short int*>(to), from);
+}
+template<> EIGEN_STRONG_INLINE void pstoreu<signed char>(signed char* to, const Packet16c& from)
+{
+ pstoreu_common<Packet16c>(to, from);
+}
+template<> EIGEN_STRONG_INLINE void pstoreu<unsigned char>(unsigned char* to, const Packet16uc& from)
+{
+ pstoreu_common<Packet16uc>(to, from);
}
-#endif
template<> EIGEN_STRONG_INLINE void prefetch<float>(const float* addr) { EIGEN_PPC_PREFETCH(addr); }
template<> EIGEN_STRONG_INLINE void prefetch<int>(const int* addr) { EIGEN_PPC_PREFETCH(addr); }
-template<> EIGEN_STRONG_INLINE float pfirst<Packet4f>(const Packet4f& a) { float EIGEN_ALIGN16 x; vec_ste(a, 0, &x); return x; }
-template<> EIGEN_STRONG_INLINE int pfirst<Packet4i>(const Packet4i& a) { int EIGEN_ALIGN16 x; vec_ste(a, 0, &x); return x; }
+template<> EIGEN_STRONG_INLINE float pfirst<Packet4f>(const Packet4f& a) { EIGEN_ALIGN16 float x; vec_ste(a, 0, &x); return x; }
+template<> EIGEN_STRONG_INLINE int pfirst<Packet4i>(const Packet4i& a) { EIGEN_ALIGN16 int x; vec_ste(a, 0, &x); return x; }
+
+template<typename Packet> EIGEN_STRONG_INLINE __UNPACK_TYPE__(Packet) pfirst_common(const Packet& a) {
+ EIGEN_ALIGN16 __UNPACK_TYPE__(Packet) x;
+ vec_ste(a, 0, &x);
+ return x;
+}
+
+template<> EIGEN_STRONG_INLINE short int pfirst<Packet8s>(const Packet8s& a) {
+ return pfirst_common<Packet8s>(a);
+}
+
+template<> EIGEN_STRONG_INLINE unsigned short int pfirst<Packet8us>(const Packet8us& a) {
+ return pfirst_common<Packet8us>(a);
+}
+
+template<> EIGEN_STRONG_INLINE signed char pfirst<Packet16c>(const Packet16c& a)
+{
+ return pfirst_common<Packet16c>(a);
+}
+
+template<> EIGEN_STRONG_INLINE unsigned char pfirst<Packet16uc>(const Packet16uc& a)
+{
+ return pfirst_common<Packet16uc>(a);
+}
template<> EIGEN_STRONG_INLINE Packet4f preverse(const Packet4f& a)
{
@@ -525,10 +1150,296 @@ template<> EIGEN_STRONG_INLINE Packet4f preverse(const Packet4f& a)
}
template<> EIGEN_STRONG_INLINE Packet4i preverse(const Packet4i& a)
{
- return reinterpret_cast<Packet4i>(vec_perm(reinterpret_cast<Packet16uc>(a), reinterpret_cast<Packet16uc>(a), p16uc_REVERSE32)); }
+ return reinterpret_cast<Packet4i>(vec_perm(reinterpret_cast<Packet16uc>(a), reinterpret_cast<Packet16uc>(a), p16uc_REVERSE32));
+}
+template<> EIGEN_STRONG_INLINE Packet8s preverse(const Packet8s& a)
+{
+ return reinterpret_cast<Packet8s>(vec_perm(reinterpret_cast<Packet16uc>(a), reinterpret_cast<Packet16uc>(a), p16uc_REVERSE16));
+}
+template<> EIGEN_STRONG_INLINE Packet8us preverse(const Packet8us& a)
+{
+ return reinterpret_cast<Packet8us>(vec_perm(reinterpret_cast<Packet16uc>(a), reinterpret_cast<Packet16uc>(a), p16uc_REVERSE16));
+}
+template<> EIGEN_STRONG_INLINE Packet16c preverse(const Packet16c& a)
+{
+ return vec_perm(a, a, p16uc_REVERSE8);
+}
+template<> EIGEN_STRONG_INLINE Packet16uc preverse(const Packet16uc& a)
+{
+ return vec_perm(a, a, p16uc_REVERSE8);
+}
+template<> EIGEN_STRONG_INLINE Packet8bf preverse(const Packet8bf& a)
+{
+ return preverse<Packet8us>(a);
+}
template<> EIGEN_STRONG_INLINE Packet4f pabs(const Packet4f& a) { return vec_abs(a); }
template<> EIGEN_STRONG_INLINE Packet4i pabs(const Packet4i& a) { return vec_abs(a); }
+template<> EIGEN_STRONG_INLINE Packet8s pabs(const Packet8s& a) { return vec_abs(a); }
+template<> EIGEN_STRONG_INLINE Packet8us pabs(const Packet8us& a) { return a; }
+template<> EIGEN_STRONG_INLINE Packet16c pabs(const Packet16c& a) { return vec_abs(a); }
+template<> EIGEN_STRONG_INLINE Packet16uc pabs(const Packet16uc& a) { return a; }
+template<> EIGEN_STRONG_INLINE Packet8bf pabs(const Packet8bf& a) {
+ _EIGEN_DECLARE_CONST_FAST_Packet8us(abs_mask,0x7FFF);
+ return pand<Packet8us>(p8us_abs_mask, a);
+}
+
+template<int N> EIGEN_STRONG_INLINE Packet4i parithmetic_shift_right(const Packet4i& a)
+{ return vec_sra(a,reinterpret_cast<Packet4ui>(pset1<Packet4i>(N))); }
+template<int N> EIGEN_STRONG_INLINE Packet4i plogical_shift_right(const Packet4i& a)
+{ return vec_sr(a,reinterpret_cast<Packet4ui>(pset1<Packet4i>(N))); }
+template<int N> EIGEN_STRONG_INLINE Packet4i plogical_shift_left(const Packet4i& a)
+{ return vec_sl(a,reinterpret_cast<Packet4ui>(pset1<Packet4i>(N))); }
+template<int N> EIGEN_STRONG_INLINE Packet4f plogical_shift_left(const Packet4f& a)
+{
+ const _EIGEN_DECLARE_CONST_FAST_Packet4ui(mask, N);
+ Packet4ui r = vec_sl(reinterpret_cast<Packet4ui>(a), p4ui_mask);
+ return reinterpret_cast<Packet4f>(r);
+}
+
+template<int N> EIGEN_STRONG_INLINE Packet4f plogical_shift_right(const Packet4f& a)
+{
+ const _EIGEN_DECLARE_CONST_FAST_Packet4ui(mask, N);
+ Packet4ui r = vec_sr(reinterpret_cast<Packet4ui>(a), p4ui_mask);
+ return reinterpret_cast<Packet4f>(r);
+}
+
+template<int N> EIGEN_STRONG_INLINE Packet4ui plogical_shift_right(const Packet4ui& a)
+{
+ const _EIGEN_DECLARE_CONST_FAST_Packet4ui(mask, N);
+ return vec_sr(a, p4ui_mask);
+}
+
+template<int N> EIGEN_STRONG_INLINE Packet4ui plogical_shift_left(const Packet4ui& a)
+{
+ const _EIGEN_DECLARE_CONST_FAST_Packet4ui(mask, N);
+ return vec_sl(a, p4ui_mask);
+}
+
+template<int N> EIGEN_STRONG_INLINE Packet8us plogical_shift_left(const Packet8us& a)
+{
+ const _EIGEN_DECLARE_CONST_FAST_Packet8us(mask, N);
+ return vec_sl(a, p8us_mask);
+}
+template<int N> EIGEN_STRONG_INLINE Packet8us plogical_shift_right(const Packet8us& a)
+{
+ const _EIGEN_DECLARE_CONST_FAST_Packet8us(mask, N);
+ return vec_sr(a, p8us_mask);
+}
+
+EIGEN_STRONG_INLINE Packet4f Bf16ToF32Even(const Packet8bf& bf){
+ return plogical_shift_left<16>(reinterpret_cast<Packet4f>(bf.m_val));
+}
+
+EIGEN_STRONG_INLINE Packet4f Bf16ToF32Odd(const Packet8bf& bf){
+ const _EIGEN_DECLARE_CONST_FAST_Packet4ui(high_mask, 0xFFFF0000);
+ return pand<Packet4f>(
+ reinterpret_cast<Packet4f>(bf.m_val),
+ reinterpret_cast<Packet4f>(p4ui_high_mask)
+ );
+}
+
+// Simple interleaving of bool masks, prevents true values from being
+// converted to NaNs.
+EIGEN_STRONG_INLINE Packet8bf F32ToBf16Bool(Packet4f even, Packet4f odd) {
+ const _EIGEN_DECLARE_CONST_FAST_Packet4ui(high_mask, 0xFFFF0000);
+ Packet4f bf_odd, bf_even;
+ bf_odd = pand(reinterpret_cast<Packet4f>(p4ui_high_mask), odd);
+ bf_even = plogical_shift_right<16>(even);
+ return reinterpret_cast<Packet8us>(por<Packet4f>(bf_even, bf_odd));
+}
+
+EIGEN_STRONG_INLINE Packet8bf F32ToBf16(Packet4f p4f){
+ Packet4ui input = reinterpret_cast<Packet4ui>(p4f);
+ Packet4ui lsb = plogical_shift_right<16>(input);
+ lsb = pand<Packet4ui>(lsb, reinterpret_cast<Packet4ui>(p4i_ONE));
+
+ _EIGEN_DECLARE_CONST_FAST_Packet4ui(BIAS,0x7FFFu);
+ Packet4ui rounding_bias = padd<Packet4ui>(lsb, p4ui_BIAS);
+ input = padd<Packet4ui>(input, rounding_bias);
+
+ //Test NaN and Subnormal - Begin
+ const _EIGEN_DECLARE_CONST_FAST_Packet4ui(exp_mask, 0x7F800000);
+ Packet4ui exp = pand<Packet4ui>(p4ui_exp_mask, reinterpret_cast<Packet4ui>(p4f));
+
+ const _EIGEN_DECLARE_CONST_FAST_Packet4ui(mantissa_mask, 0x7FFFFF);
+ Packet4ui mantissa = pand<Packet4ui>(p4ui_mantissa_mask, reinterpret_cast<Packet4ui>(p4f));
+
+ const _EIGEN_DECLARE_CONST_FAST_Packet4ui(max_exp, 0x7F800000);
+ Packet4bi is_max_exp = vec_cmpeq(exp, p4ui_max_exp);
+ Packet4bi is_zero_exp = vec_cmpeq(exp, reinterpret_cast<Packet4ui>(p4i_ZERO));
+
+ Packet4bi is_mant_zero = vec_cmpeq(mantissa, reinterpret_cast<Packet4ui>(p4i_ZERO));
+ Packet4ui nan_selector = pandnot<Packet4ui>(
+ reinterpret_cast<Packet4ui>(is_max_exp),
+ reinterpret_cast<Packet4ui>(is_mant_zero)
+ );
+
+ Packet4ui subnormal_selector = pandnot<Packet4ui>(
+ reinterpret_cast<Packet4ui>(is_zero_exp),
+ reinterpret_cast<Packet4ui>(is_mant_zero)
+ );
+
+ const _EIGEN_DECLARE_CONST_FAST_Packet4ui(nan, 0x7FC00000);
+ input = vec_sel(input, p4ui_nan, nan_selector);
+ input = vec_sel(input, reinterpret_cast<Packet4ui>(p4f), subnormal_selector);
+ //Test NaN and Subnormal - End
+
+ input = plogical_shift_right<16>(input);
+ return reinterpret_cast<Packet8us>(input);
+}
+
+EIGEN_STRONG_INLINE Packet8bf F32ToBf16(Packet4f even, Packet4f odd){
+ Packet4f bf_odd, bf_even;
+ bf_odd = reinterpret_cast<Packet4f>(F32ToBf16(odd).m_val);
+ bf_odd = plogical_shift_left<16>(bf_odd);
+ bf_even = reinterpret_cast<Packet4f>(F32ToBf16(even).m_val);
+ return reinterpret_cast<Packet8us>(por<Packet4f>(bf_even, bf_odd));
+}
+#define BF16_TO_F32_UNARY_OP_WRAPPER(OP, A) \
+ Packet4f a_even = Bf16ToF32Even(A);\
+ Packet4f a_odd = Bf16ToF32Odd(A);\
+ Packet4f op_even = OP(a_even);\
+ Packet4f op_odd = OP(a_odd);\
+ return F32ToBf16(op_even, op_odd);\
+
+#define BF16_TO_F32_BINARY_OP_WRAPPER(OP, A, B) \
+ Packet4f a_even = Bf16ToF32Even(A);\
+ Packet4f a_odd = Bf16ToF32Odd(A);\
+ Packet4f b_even = Bf16ToF32Even(B);\
+ Packet4f b_odd = Bf16ToF32Odd(B);\
+ Packet4f op_even = OP(a_even, b_even);\
+ Packet4f op_odd = OP(a_odd, b_odd);\
+ return F32ToBf16(op_even, op_odd);\
+
+#define BF16_TO_F32_BINARY_OP_WRAPPER_BOOL(OP, A, B) \
+ Packet4f a_even = Bf16ToF32Even(A);\
+ Packet4f a_odd = Bf16ToF32Odd(A);\
+ Packet4f b_even = Bf16ToF32Even(B);\
+ Packet4f b_odd = Bf16ToF32Odd(B);\
+ Packet4f op_even = OP(a_even, b_even);\
+ Packet4f op_odd = OP(a_odd, b_odd);\
+ return F32ToBf16Bool(op_even, op_odd);\
+
+template<> EIGEN_STRONG_INLINE Packet8bf padd<Packet8bf>(const Packet8bf& a, const Packet8bf& b) {
+ BF16_TO_F32_BINARY_OP_WRAPPER(padd<Packet4f>, a, b);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf pmul<Packet8bf>(const Packet8bf& a, const Packet8bf& b) {
+ BF16_TO_F32_BINARY_OP_WRAPPER(pmul<Packet4f>, a, b);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf pdiv<Packet8bf>(const Packet8bf& a, const Packet8bf& b) {
+ BF16_TO_F32_BINARY_OP_WRAPPER(pdiv<Packet4f>, a, b);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf pnegate<Packet8bf>(const Packet8bf& a) {
+ BF16_TO_F32_UNARY_OP_WRAPPER(pnegate<Packet4f>, a);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf psub<Packet8bf>(const Packet8bf& a, const Packet8bf& b) {
+ BF16_TO_F32_BINARY_OP_WRAPPER(psub<Packet4f>, a, b);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf psqrt<Packet8bf> (const Packet8bf& a){
+ BF16_TO_F32_UNARY_OP_WRAPPER(vec_sqrt, a);
+}
+template<> EIGEN_STRONG_INLINE Packet8bf prsqrt<Packet8bf> (const Packet8bf& a){
+ BF16_TO_F32_UNARY_OP_WRAPPER(prsqrt<Packet4f>, a);
+}
+template<> EIGEN_STRONG_INLINE Packet8bf pexp<Packet8bf> (const Packet8bf& a){
+ BF16_TO_F32_UNARY_OP_WRAPPER(pexp_float, a);
+}
+
+template<> EIGEN_STRONG_INLINE Packet4f pldexp<Packet4f>(const Packet4f& a, const Packet4f& exponent) {
+ return pldexp_generic(a,exponent);
+}
+template<> EIGEN_STRONG_INLINE Packet8bf pldexp<Packet8bf> (const Packet8bf& a, const Packet8bf& exponent){
+ BF16_TO_F32_BINARY_OP_WRAPPER(pldexp<Packet4f>, a, exponent);
+}
+
+template<> EIGEN_STRONG_INLINE Packet4f pfrexp<Packet4f>(const Packet4f& a, Packet4f& exponent) {
+ return pfrexp_generic(a,exponent);
+}
+template<> EIGEN_STRONG_INLINE Packet8bf pfrexp<Packet8bf> (const Packet8bf& a, Packet8bf& e){
+ Packet4f a_even = Bf16ToF32Even(a);
+ Packet4f a_odd = Bf16ToF32Odd(a);
+ Packet4f e_even;
+ Packet4f e_odd;
+ Packet4f op_even = pfrexp<Packet4f>(a_even, e_even);
+ Packet4f op_odd = pfrexp<Packet4f>(a_odd, e_odd);
+ e = F32ToBf16(e_even, e_odd);
+ return F32ToBf16(op_even, op_odd);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf psin<Packet8bf> (const Packet8bf& a){
+ BF16_TO_F32_UNARY_OP_WRAPPER(psin_float, a);
+}
+template<> EIGEN_STRONG_INLINE Packet8bf pcos<Packet8bf> (const Packet8bf& a){
+ BF16_TO_F32_UNARY_OP_WRAPPER(pcos_float, a);
+}
+template<> EIGEN_STRONG_INLINE Packet8bf plog<Packet8bf> (const Packet8bf& a){
+ BF16_TO_F32_UNARY_OP_WRAPPER(plog_float, a);
+}
+template<> EIGEN_STRONG_INLINE Packet8bf pfloor<Packet8bf> (const Packet8bf& a){
+ BF16_TO_F32_UNARY_OP_WRAPPER(pfloor<Packet4f>, a);
+}
+template<> EIGEN_STRONG_INLINE Packet8bf pceil<Packet8bf> (const Packet8bf& a){
+ BF16_TO_F32_UNARY_OP_WRAPPER(pceil<Packet4f>, a);
+}
+template<> EIGEN_STRONG_INLINE Packet8bf pround<Packet8bf> (const Packet8bf& a){
+ BF16_TO_F32_UNARY_OP_WRAPPER(pround<Packet4f>, a);
+}
+template<> EIGEN_STRONG_INLINE Packet8bf print<Packet8bf> (const Packet8bf& a){
+ BF16_TO_F32_UNARY_OP_WRAPPER(print<Packet4f>, a);
+}
+template<> EIGEN_STRONG_INLINE Packet8bf pmadd(const Packet8bf& a, const Packet8bf& b, const Packet8bf& c) {
+ Packet4f a_even = Bf16ToF32Even(a);
+ Packet4f a_odd = Bf16ToF32Odd(a);
+ Packet4f b_even = Bf16ToF32Even(b);
+ Packet4f b_odd = Bf16ToF32Odd(b);
+ Packet4f c_even = Bf16ToF32Even(c);
+ Packet4f c_odd = Bf16ToF32Odd(c);
+ Packet4f pmadd_even = pmadd<Packet4f>(a_even, b_even, c_even);
+ Packet4f pmadd_odd = pmadd<Packet4f>(a_odd, b_odd, c_odd);
+ return F32ToBf16(pmadd_even, pmadd_odd);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf pmin<Packet8bf>(const Packet8bf& a, const Packet8bf& b) {
+ BF16_TO_F32_BINARY_OP_WRAPPER(pmin<Packet4f>, a, b);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf pmax<Packet8bf>(const Packet8bf& a, const Packet8bf& b) {
+ BF16_TO_F32_BINARY_OP_WRAPPER(pmax<Packet4f>, a, b);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf pcmp_lt(const Packet8bf& a, const Packet8bf& b) {
+ BF16_TO_F32_BINARY_OP_WRAPPER_BOOL(pcmp_lt<Packet4f>, a, b);
+}
+template<> EIGEN_STRONG_INLINE Packet8bf pcmp_lt_or_nan(const Packet8bf& a, const Packet8bf& b) {
+ BF16_TO_F32_BINARY_OP_WRAPPER_BOOL(pcmp_lt_or_nan<Packet4f>, a, b);
+}
+template<> EIGEN_STRONG_INLINE Packet8bf pcmp_le(const Packet8bf& a, const Packet8bf& b) {
+ BF16_TO_F32_BINARY_OP_WRAPPER_BOOL(pcmp_le<Packet4f>, a, b);
+}
+template<> EIGEN_STRONG_INLINE Packet8bf pcmp_eq(const Packet8bf& a, const Packet8bf& b) {
+ BF16_TO_F32_BINARY_OP_WRAPPER_BOOL(pcmp_eq<Packet4f>, a, b);
+}
+
+template<> EIGEN_STRONG_INLINE bfloat16 pfirst(const Packet8bf& a) {
+ return Eigen::bfloat16_impl::raw_uint16_to_bfloat16((pfirst<Packet8us>(a)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf ploaddup<Packet8bf>(const bfloat16* from)
+{
+ return ploaddup<Packet8us>(reinterpret_cast<const unsigned short int*>(from));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf plset<Packet8bf>(const bfloat16& a) {
+ bfloat16 countdown[8] = { bfloat16(0), bfloat16(1), bfloat16(2), bfloat16(3),
+ bfloat16(4), bfloat16(5), bfloat16(6), bfloat16(7) };
+ return padd<Packet8bf>(pset1<Packet8bf>(a), pload<Packet8bf>(countdown));
+}
template<> EIGEN_STRONG_INLINE float predux<Packet4f>(const Packet4f& a)
{
@@ -540,34 +1451,6 @@ template<> EIGEN_STRONG_INLINE float predux<Packet4f>(const Packet4f& a)
return pfirst(sum);
}
-template<> EIGEN_STRONG_INLINE Packet4f preduxp<Packet4f>(const Packet4f* vecs)
-{
- Packet4f v[4], sum[4];
-
- // It's easier and faster to transpose then add as columns
- // Check: http://www.freevec.org/function/matrix_4x4_transpose_floats for explanation
- // Do the transpose, first set of moves
- v[0] = vec_mergeh(vecs[0], vecs[2]);
- v[1] = vec_mergel(vecs[0], vecs[2]);
- v[2] = vec_mergeh(vecs[1], vecs[3]);
- v[3] = vec_mergel(vecs[1], vecs[3]);
- // Get the resulting vectors
- sum[0] = vec_mergeh(v[0], v[2]);
- sum[1] = vec_mergel(v[0], v[2]);
- sum[2] = vec_mergeh(v[1], v[3]);
- sum[3] = vec_mergel(v[1], v[3]);
-
- // Now do the summation:
- // Lines 0+1
- sum[0] = sum[0] + sum[1];
- // Lines 2+3
- sum[1] = sum[2] + sum[3];
- // Add the results
- sum[0] = sum[0] + sum[1];
-
- return sum[0];
-}
-
template<> EIGEN_STRONG_INLINE int predux<Packet4i>(const Packet4i& a)
{
Packet4i sum;
@@ -580,32 +1463,69 @@ template<> EIGEN_STRONG_INLINE int predux<Packet4i>(const Packet4i& a)
return pfirst(sum);
}
-template<> EIGEN_STRONG_INLINE Packet4i preduxp<Packet4i>(const Packet4i* vecs)
+template<> EIGEN_STRONG_INLINE bfloat16 predux<Packet8bf>(const Packet8bf& a)
{
- Packet4i v[4], sum[4];
+ float redux_even = predux<Packet4f>(Bf16ToF32Even(a));
+ float redux_odd = predux<Packet4f>(Bf16ToF32Odd(a));
+ float f32_result = redux_even + redux_odd;
+ return bfloat16(f32_result);
+}
+template<typename Packet> EIGEN_STRONG_INLINE __UNPACK_TYPE__(Packet) predux_size8(const Packet& a)
+{
+ union{
+ Packet v;
+ __UNPACK_TYPE__(Packet) n[8];
+ } vt;
+ vt.v = a;
+
+ EIGEN_ALIGN16 int first_loader[4] = { vt.n[0], vt.n[1], vt.n[2], vt.n[3] };
+ EIGEN_ALIGN16 int second_loader[4] = { vt.n[4], vt.n[5], vt.n[6], vt.n[7] };
+ Packet4i first_half = pload<Packet4i>(first_loader);
+ Packet4i second_half = pload<Packet4i>(second_loader);
+
+ return static_cast<__UNPACK_TYPE__(Packet)>(predux(first_half) + predux(second_half));
+}
+
+template<> EIGEN_STRONG_INLINE short int predux<Packet8s>(const Packet8s& a)
+{
+ return predux_size8<Packet8s>(a);
+}
+
+template<> EIGEN_STRONG_INLINE unsigned short int predux<Packet8us>(const Packet8us& a)
+{
+ return predux_size8<Packet8us>(a);
+}
+
+template<typename Packet> EIGEN_STRONG_INLINE __UNPACK_TYPE__(Packet) predux_size16(const Packet& a)
+{
+ union{
+ Packet v;
+ __UNPACK_TYPE__(Packet) n[16];
+ } vt;
+ vt.v = a;
+
+ EIGEN_ALIGN16 int first_loader[4] = { vt.n[0], vt.n[1], vt.n[2], vt.n[3] };
+ EIGEN_ALIGN16 int second_loader[4] = { vt.n[4], vt.n[5], vt.n[6], vt.n[7] };
+ EIGEN_ALIGN16 int third_loader[4] = { vt.n[8], vt.n[9], vt.n[10], vt.n[11] };
+ EIGEN_ALIGN16 int fourth_loader[4] = { vt.n[12], vt.n[13], vt.n[14], vt.n[15] };
- // It's easier and faster to transpose then add as columns
- // Check: http://www.freevec.org/function/matrix_4x4_transpose_floats for explanation
- // Do the transpose, first set of moves
- v[0] = vec_mergeh(vecs[0], vecs[2]);
- v[1] = vec_mergel(vecs[0], vecs[2]);
- v[2] = vec_mergeh(vecs[1], vecs[3]);
- v[3] = vec_mergel(vecs[1], vecs[3]);
- // Get the resulting vectors
- sum[0] = vec_mergeh(v[0], v[2]);
- sum[1] = vec_mergel(v[0], v[2]);
- sum[2] = vec_mergeh(v[1], v[3]);
- sum[3] = vec_mergel(v[1], v[3]);
+ Packet4i first_quarter = pload<Packet4i>(first_loader);
+ Packet4i second_quarter = pload<Packet4i>(second_loader);
+ Packet4i third_quarter = pload<Packet4i>(third_loader);
+ Packet4i fourth_quarter = pload<Packet4i>(fourth_loader);
- // Now do the summation:
- // Lines 0+1
- sum[0] = sum[0] + sum[1];
- // Lines 2+3
- sum[1] = sum[2] + sum[3];
- // Add the results
- sum[0] = sum[0] + sum[1];
+ return static_cast<__UNPACK_TYPE__(Packet)>(predux(first_quarter) + predux(second_quarter)
+ + predux(third_quarter) + predux(fourth_quarter));
+}
+
+template<> EIGEN_STRONG_INLINE signed char predux<Packet16c>(const Packet16c& a)
+{
+ return predux_size16<Packet16c>(a);
+}
- return sum[0];
+template<> EIGEN_STRONG_INLINE unsigned char predux<Packet16uc>(const Packet16uc& a)
+{
+ return predux_size16<Packet16uc>(a);
}
// Other reduction functions:
@@ -624,97 +1544,255 @@ template<> EIGEN_STRONG_INLINE int predux_mul<Packet4i>(const Packet4i& a)
return aux[0] * aux[1] * aux[2] * aux[3];
}
+template<> EIGEN_STRONG_INLINE short int predux_mul<Packet8s>(const Packet8s& a)
+{
+ Packet8s pair, quad, octo;
+
+ pair = vec_mul(a, vec_sld(a, a, 8));
+ quad = vec_mul(pair, vec_sld(pair, pair, 4));
+ octo = vec_mul(quad, vec_sld(quad, quad, 2));
+
+ return pfirst(octo);
+}
+
+template<> EIGEN_STRONG_INLINE unsigned short int predux_mul<Packet8us>(const Packet8us& a)
+{
+ Packet8us pair, quad, octo;
+
+ pair = vec_mul(a, vec_sld(a, a, 8));
+ quad = vec_mul(pair, vec_sld(pair, pair, 4));
+ octo = vec_mul(quad, vec_sld(quad, quad, 2));
+
+ return pfirst(octo);
+}
+
+template<> EIGEN_STRONG_INLINE bfloat16 predux_mul<Packet8bf>(const Packet8bf& a)
+{
+ float redux_even = predux_mul<Packet4f>(Bf16ToF32Even(a));
+ float redux_odd = predux_mul<Packet4f>(Bf16ToF32Odd(a));
+ float f32_result = redux_even * redux_odd;
+ return bfloat16(f32_result);
+}
+
+
+template<> EIGEN_STRONG_INLINE signed char predux_mul<Packet16c>(const Packet16c& a)
+{
+ Packet16c pair, quad, octo, result;
+
+ pair = vec_mul(a, vec_sld(a, a, 8));
+ quad = vec_mul(pair, vec_sld(pair, pair, 4));
+ octo = vec_mul(quad, vec_sld(quad, quad, 2));
+ result = vec_mul(octo, vec_sld(octo, octo, 1));
+
+ return pfirst(result);
+}
+
+template<> EIGEN_STRONG_INLINE unsigned char predux_mul<Packet16uc>(const Packet16uc& a)
+{
+ Packet16uc pair, quad, octo, result;
+
+ pair = vec_mul(a, vec_sld(a, a, 8));
+ quad = vec_mul(pair, vec_sld(pair, pair, 4));
+ octo = vec_mul(quad, vec_sld(quad, quad, 2));
+ result = vec_mul(octo, vec_sld(octo, octo, 1));
+
+ return pfirst(result);
+}
+
// min
-template<> EIGEN_STRONG_INLINE float predux_min<Packet4f>(const Packet4f& a)
+template<typename Packet> EIGEN_STRONG_INLINE
+__UNPACK_TYPE__(Packet) predux_min4(const Packet& a)
{
- Packet4f b, res;
+ Packet b, res;
b = vec_min(a, vec_sld(a, a, 8));
res = vec_min(b, vec_sld(b, b, 4));
return pfirst(res);
}
+
+template<> EIGEN_STRONG_INLINE float predux_min<Packet4f>(const Packet4f& a)
+{
+ return predux_min4<Packet4f>(a);
+}
+
template<> EIGEN_STRONG_INLINE int predux_min<Packet4i>(const Packet4i& a)
{
- Packet4i b, res;
- b = vec_min(a, vec_sld(a, a, 8));
- res = vec_min(b, vec_sld(b, b, 4));
- return pfirst(res);
+ return predux_min4<Packet4i>(a);
+}
+
+template<> EIGEN_STRONG_INLINE bfloat16 predux_min<Packet8bf>(const Packet8bf& a)
+{
+ float redux_even = predux_min<Packet4f>(Bf16ToF32Even(a));
+ float redux_odd = predux_min<Packet4f>(Bf16ToF32Odd(a));
+ float f32_result = (std::min)(redux_even, redux_odd);
+ return bfloat16(f32_result);
}
+template<> EIGEN_STRONG_INLINE short int predux_min<Packet8s>(const Packet8s& a)
+{
+ Packet8s pair, quad, octo;
+
+ //pair = { Min(a0,a4), Min(a1,a5), Min(a2,a6), Min(a3,a7) }
+ pair = vec_min(a, vec_sld(a, a, 8));
+
+ //quad = { Min(a0, a4, a2, a6), Min(a1, a5, a3, a7) }
+ quad = vec_min(pair, vec_sld(pair, pair, 4));
+
+ //octo = { Min(a0, a4, a2, a6, a1, a5, a3, a7) }
+ octo = vec_min(quad, vec_sld(quad, quad, 2));
+ return pfirst(octo);
+}
+
+template<> EIGEN_STRONG_INLINE unsigned short int predux_min<Packet8us>(const Packet8us& a)
+{
+ Packet8us pair, quad, octo;
+
+ //pair = { Min(a0,a4), Min(a1,a5), Min(a2,a6), Min(a3,a7) }
+ pair = vec_min(a, vec_sld(a, a, 8));
+
+ //quad = { Min(a0, a4, a2, a6), Min(a1, a5, a3, a7) }
+ quad = vec_min(pair, vec_sld(pair, pair, 4));
+
+ //octo = { Min(a0, a4, a2, a6, a1, a5, a3, a7) }
+ octo = vec_min(quad, vec_sld(quad, quad, 2));
+ return pfirst(octo);
+}
+
+template<> EIGEN_STRONG_INLINE signed char predux_min<Packet16c>(const Packet16c& a)
+{
+ Packet16c pair, quad, octo, result;
+
+ pair = vec_min(a, vec_sld(a, a, 8));
+ quad = vec_min(pair, vec_sld(pair, pair, 4));
+ octo = vec_min(quad, vec_sld(quad, quad, 2));
+ result = vec_min(octo, vec_sld(octo, octo, 1));
+
+ return pfirst(result);
+}
+
+template<> EIGEN_STRONG_INLINE unsigned char predux_min<Packet16uc>(const Packet16uc& a)
+{
+ Packet16uc pair, quad, octo, result;
+
+ pair = vec_min(a, vec_sld(a, a, 8));
+ quad = vec_min(pair, vec_sld(pair, pair, 4));
+ octo = vec_min(quad, vec_sld(quad, quad, 2));
+ result = vec_min(octo, vec_sld(octo, octo, 1));
+
+ return pfirst(result);
+}
// max
-template<> EIGEN_STRONG_INLINE float predux_max<Packet4f>(const Packet4f& a)
+template<typename Packet> EIGEN_STRONG_INLINE __UNPACK_TYPE__(Packet) predux_max4(const Packet& a)
{
- Packet4f b, res;
+ Packet b, res;
b = vec_max(a, vec_sld(a, a, 8));
res = vec_max(b, vec_sld(b, b, 4));
return pfirst(res);
}
+template<> EIGEN_STRONG_INLINE float predux_max<Packet4f>(const Packet4f& a)
+{
+ return predux_max4<Packet4f>(a);
+}
+
template<> EIGEN_STRONG_INLINE int predux_max<Packet4i>(const Packet4i& a)
{
- Packet4i b, res;
- b = vec_max(a, vec_sld(a, a, 8));
- res = vec_max(b, vec_sld(b, b, 4));
- return pfirst(res);
+ return predux_max4<Packet4i>(a);
}
-template<int Offset>
-struct palign_impl<Offset,Packet4f>
+template<> EIGEN_STRONG_INLINE bfloat16 predux_max<Packet8bf>(const Packet8bf& a)
{
- static EIGEN_STRONG_INLINE void run(Packet4f& first, const Packet4f& second)
- {
-#ifdef _BIG_ENDIAN
- switch (Offset % 4) {
- case 1:
- first = vec_sld(first, second, 4); break;
- case 2:
- first = vec_sld(first, second, 8); break;
- case 3:
- first = vec_sld(first, second, 12); break;
- }
-#else
- switch (Offset % 4) {
- case 1:
- first = vec_sld(second, first, 12); break;
- case 2:
- first = vec_sld(second, first, 8); break;
- case 3:
- first = vec_sld(second, first, 4); break;
- }
-#endif
- }
-};
+ float redux_even = predux_max<Packet4f>(Bf16ToF32Even(a));
+ float redux_odd = predux_max<Packet4f>(Bf16ToF32Odd(a));
+ float f32_result = (std::max)(redux_even, redux_odd);
+ return bfloat16(f32_result);
+}
-template<int Offset>
-struct palign_impl<Offset,Packet4i>
+template<> EIGEN_STRONG_INLINE short int predux_max<Packet8s>(const Packet8s& a)
{
- static EIGEN_STRONG_INLINE void run(Packet4i& first, const Packet4i& second)
- {
-#ifdef _BIG_ENDIAN
- switch (Offset % 4) {
- case 1:
- first = vec_sld(first, second, 4); break;
- case 2:
- first = vec_sld(first, second, 8); break;
- case 3:
- first = vec_sld(first, second, 12); break;
- }
-#else
- switch (Offset % 4) {
- case 1:
- first = vec_sld(second, first, 12); break;
- case 2:
- first = vec_sld(second, first, 8); break;
- case 3:
- first = vec_sld(second, first, 4); break;
- }
-#endif
- }
-};
+ Packet8s pair, quad, octo;
+
+ //pair = { Max(a0,a4), Max(a1,a5), Max(a2,a6), Max(a3,a7) }
+ pair = vec_max(a, vec_sld(a, a, 8));
+
+ //quad = { Max(a0, a4, a2, a6), Max(a1, a5, a3, a7) }
+ quad = vec_max(pair, vec_sld(pair, pair, 4));
+
+ //octo = { Max(a0, a4, a2, a6, a1, a5, a3, a7) }
+ octo = vec_max(quad, vec_sld(quad, quad, 2));
+ return pfirst(octo);
+}
+
+template<> EIGEN_STRONG_INLINE unsigned short int predux_max<Packet8us>(const Packet8us& a)
+{
+ Packet8us pair, quad, octo;
+
+ //pair = { Max(a0,a4), Max(a1,a5), Max(a2,a6), Max(a3,a7) }
+ pair = vec_max(a, vec_sld(a, a, 8));
+
+ //quad = { Max(a0, a4, a2, a6), Max(a1, a5, a3, a7) }
+ quad = vec_max(pair, vec_sld(pair, pair, 4));
+
+ //octo = { Max(a0, a4, a2, a6, a1, a5, a3, a7) }
+ octo = vec_max(quad, vec_sld(quad, quad, 2));
+ return pfirst(octo);
+}
+
+template<> EIGEN_STRONG_INLINE signed char predux_max<Packet16c>(const Packet16c& a)
+{
+ Packet16c pair, quad, octo, result;
+
+ pair = vec_max(a, vec_sld(a, a, 8));
+ quad = vec_max(pair, vec_sld(pair, pair, 4));
+ octo = vec_max(quad, vec_sld(quad, quad, 2));
+ result = vec_max(octo, vec_sld(octo, octo, 1));
+
+ return pfirst(result);
+}
+
+template<> EIGEN_STRONG_INLINE unsigned char predux_max<Packet16uc>(const Packet16uc& a)
+{
+ Packet16uc pair, quad, octo, result;
+
+ pair = vec_max(a, vec_sld(a, a, 8));
+ quad = vec_max(pair, vec_sld(pair, pair, 4));
+ octo = vec_max(quad, vec_sld(quad, quad, 2));
+ result = vec_max(octo, vec_sld(octo, octo, 1));
+
+ return pfirst(result);
+}
+
+template<> EIGEN_STRONG_INLINE bool predux_any(const Packet4f& x)
+{
+ return vec_any_ne(x, pzero(x));
+}
+
+template <typename T> EIGEN_DEVICE_FUNC inline void
+ptranpose_common(PacketBlock<T,4>& kernel){
+ T t0, t1, t2, t3;
+ t0 = vec_mergeh(kernel.packet[0], kernel.packet[2]);
+ t1 = vec_mergel(kernel.packet[0], kernel.packet[2]);
+ t2 = vec_mergeh(kernel.packet[1], kernel.packet[3]);
+ t3 = vec_mergel(kernel.packet[1], kernel.packet[3]);
+ kernel.packet[0] = vec_mergeh(t0, t2);
+ kernel.packet[1] = vec_mergel(t0, t2);
+ kernel.packet[2] = vec_mergeh(t1, t3);
+ kernel.packet[3] = vec_mergel(t1, t3);
+}
EIGEN_DEVICE_FUNC inline void
ptranspose(PacketBlock<Packet4f,4>& kernel) {
- Packet4f t0, t1, t2, t3;
+ ptranpose_common<Packet4f>(kernel);
+}
+
+EIGEN_DEVICE_FUNC inline void
+ptranspose(PacketBlock<Packet4i,4>& kernel) {
+ ptranpose_common<Packet4i>(kernel);
+}
+
+EIGEN_DEVICE_FUNC inline void
+ptranspose(PacketBlock<Packet8s,4>& kernel) {
+ Packet8s t0, t1, t2, t3;
t0 = vec_mergeh(kernel.packet[0], kernel.packet[2]);
t1 = vec_mergel(kernel.packet[0], kernel.packet[2]);
t2 = vec_mergeh(kernel.packet[1], kernel.packet[3]);
@@ -726,8 +1804,8 @@ ptranspose(PacketBlock<Packet4f,4>& kernel) {
}
EIGEN_DEVICE_FUNC inline void
-ptranspose(PacketBlock<Packet4i,4>& kernel) {
- Packet4i t0, t1, t2, t3;
+ptranspose(PacketBlock<Packet8us,4>& kernel) {
+ Packet8us t0, t1, t2, t3;
t0 = vec_mergeh(kernel.packet[0], kernel.packet[2]);
t1 = vec_mergel(kernel.packet[0], kernel.packet[2]);
t2 = vec_mergeh(kernel.packet[1], kernel.packet[3]);
@@ -738,18 +1816,440 @@ ptranspose(PacketBlock<Packet4i,4>& kernel) {
kernel.packet[3] = vec_mergel(t1, t3);
}
-template<> EIGEN_STRONG_INLINE Packet4i pblend(const Selector<4>& ifPacket, const Packet4i& thenPacket, const Packet4i& elsePacket) {
+
+EIGEN_DEVICE_FUNC inline void
+ptranspose(PacketBlock<Packet8bf,4>& kernel) {
+ Packet8us t0, t1, t2, t3;
+
+ t0 = vec_mergeh(kernel.packet[0].m_val, kernel.packet[2].m_val);
+ t1 = vec_mergel(kernel.packet[0].m_val, kernel.packet[2].m_val);
+ t2 = vec_mergeh(kernel.packet[1].m_val, kernel.packet[3].m_val);
+ t3 = vec_mergel(kernel.packet[1].m_val, kernel.packet[3].m_val);
+ kernel.packet[0] = vec_mergeh(t0, t2);
+ kernel.packet[1] = vec_mergel(t0, t2);
+ kernel.packet[2] = vec_mergeh(t1, t3);
+ kernel.packet[3] = vec_mergel(t1, t3);
+}
+
+EIGEN_DEVICE_FUNC inline void
+ptranspose(PacketBlock<Packet16c,4>& kernel) {
+ Packet16c t0, t1, t2, t3;
+ t0 = vec_mergeh(kernel.packet[0], kernel.packet[2]);
+ t1 = vec_mergel(kernel.packet[0], kernel.packet[2]);
+ t2 = vec_mergeh(kernel.packet[1], kernel.packet[3]);
+ t3 = vec_mergel(kernel.packet[1], kernel.packet[3]);
+ kernel.packet[0] = vec_mergeh(t0, t2);
+ kernel.packet[1] = vec_mergel(t0, t2);
+ kernel.packet[2] = vec_mergeh(t1, t3);
+ kernel.packet[3] = vec_mergel(t1, t3);
+}
+
+
+EIGEN_DEVICE_FUNC inline void
+ptranspose(PacketBlock<Packet16uc,4>& kernel) {
+ Packet16uc t0, t1, t2, t3;
+ t0 = vec_mergeh(kernel.packet[0], kernel.packet[2]);
+ t1 = vec_mergel(kernel.packet[0], kernel.packet[2]);
+ t2 = vec_mergeh(kernel.packet[1], kernel.packet[3]);
+ t3 = vec_mergel(kernel.packet[1], kernel.packet[3]);
+ kernel.packet[0] = vec_mergeh(t0, t2);
+ kernel.packet[1] = vec_mergel(t0, t2);
+ kernel.packet[2] = vec_mergeh(t1, t3);
+ kernel.packet[3] = vec_mergel(t1, t3);
+}
+
+EIGEN_DEVICE_FUNC inline void
+ptranspose(PacketBlock<Packet8s,8>& kernel) {
+ Packet8s v[8], sum[8];
+
+ v[0] = vec_mergeh(kernel.packet[0], kernel.packet[4]);
+ v[1] = vec_mergel(kernel.packet[0], kernel.packet[4]);
+ v[2] = vec_mergeh(kernel.packet[1], kernel.packet[5]);
+ v[3] = vec_mergel(kernel.packet[1], kernel.packet[5]);
+ v[4] = vec_mergeh(kernel.packet[2], kernel.packet[6]);
+ v[5] = vec_mergel(kernel.packet[2], kernel.packet[6]);
+ v[6] = vec_mergeh(kernel.packet[3], kernel.packet[7]);
+ v[7] = vec_mergel(kernel.packet[3], kernel.packet[7]);
+ sum[0] = vec_mergeh(v[0], v[4]);
+ sum[1] = vec_mergel(v[0], v[4]);
+ sum[2] = vec_mergeh(v[1], v[5]);
+ sum[3] = vec_mergel(v[1], v[5]);
+ sum[4] = vec_mergeh(v[2], v[6]);
+ sum[5] = vec_mergel(v[2], v[6]);
+ sum[6] = vec_mergeh(v[3], v[7]);
+ sum[7] = vec_mergel(v[3], v[7]);
+
+ kernel.packet[0] = vec_mergeh(sum[0], sum[4]);
+ kernel.packet[1] = vec_mergel(sum[0], sum[4]);
+ kernel.packet[2] = vec_mergeh(sum[1], sum[5]);
+ kernel.packet[3] = vec_mergel(sum[1], sum[5]);
+ kernel.packet[4] = vec_mergeh(sum[2], sum[6]);
+ kernel.packet[5] = vec_mergel(sum[2], sum[6]);
+ kernel.packet[6] = vec_mergeh(sum[3], sum[7]);
+ kernel.packet[7] = vec_mergel(sum[3], sum[7]);
+}
+
+EIGEN_DEVICE_FUNC inline void
+ptranspose(PacketBlock<Packet8us,8>& kernel) {
+ Packet8us v[8], sum[8];
+
+ v[0] = vec_mergeh(kernel.packet[0], kernel.packet[4]);
+ v[1] = vec_mergel(kernel.packet[0], kernel.packet[4]);
+ v[2] = vec_mergeh(kernel.packet[1], kernel.packet[5]);
+ v[3] = vec_mergel(kernel.packet[1], kernel.packet[5]);
+ v[4] = vec_mergeh(kernel.packet[2], kernel.packet[6]);
+ v[5] = vec_mergel(kernel.packet[2], kernel.packet[6]);
+ v[6] = vec_mergeh(kernel.packet[3], kernel.packet[7]);
+ v[7] = vec_mergel(kernel.packet[3], kernel.packet[7]);
+ sum[0] = vec_mergeh(v[0], v[4]);
+ sum[1] = vec_mergel(v[0], v[4]);
+ sum[2] = vec_mergeh(v[1], v[5]);
+ sum[3] = vec_mergel(v[1], v[5]);
+ sum[4] = vec_mergeh(v[2], v[6]);
+ sum[5] = vec_mergel(v[2], v[6]);
+ sum[6] = vec_mergeh(v[3], v[7]);
+ sum[7] = vec_mergel(v[3], v[7]);
+
+ kernel.packet[0] = vec_mergeh(sum[0], sum[4]);
+ kernel.packet[1] = vec_mergel(sum[0], sum[4]);
+ kernel.packet[2] = vec_mergeh(sum[1], sum[5]);
+ kernel.packet[3] = vec_mergel(sum[1], sum[5]);
+ kernel.packet[4] = vec_mergeh(sum[2], sum[6]);
+ kernel.packet[5] = vec_mergel(sum[2], sum[6]);
+ kernel.packet[6] = vec_mergeh(sum[3], sum[7]);
+ kernel.packet[7] = vec_mergel(sum[3], sum[7]);
+}
+
+EIGEN_DEVICE_FUNC inline void
+ptranspose(PacketBlock<Packet8bf,8>& kernel) {
+ Packet8bf v[8], sum[8];
+
+ v[0] = vec_mergeh(kernel.packet[0].m_val, kernel.packet[4].m_val);
+ v[1] = vec_mergel(kernel.packet[0].m_val, kernel.packet[4].m_val);
+ v[2] = vec_mergeh(kernel.packet[1].m_val, kernel.packet[5].m_val);
+ v[3] = vec_mergel(kernel.packet[1].m_val, kernel.packet[5].m_val);
+ v[4] = vec_mergeh(kernel.packet[2].m_val, kernel.packet[6].m_val);
+ v[5] = vec_mergel(kernel.packet[2].m_val, kernel.packet[6].m_val);
+ v[6] = vec_mergeh(kernel.packet[3].m_val, kernel.packet[7].m_val);
+ v[7] = vec_mergel(kernel.packet[3].m_val, kernel.packet[7].m_val);
+ sum[0] = vec_mergeh(v[0].m_val, v[4].m_val);
+ sum[1] = vec_mergel(v[0].m_val, v[4].m_val);
+ sum[2] = vec_mergeh(v[1].m_val, v[5].m_val);
+ sum[3] = vec_mergel(v[1].m_val, v[5].m_val);
+ sum[4] = vec_mergeh(v[2].m_val, v[6].m_val);
+ sum[5] = vec_mergel(v[2].m_val, v[6].m_val);
+ sum[6] = vec_mergeh(v[3].m_val, v[7].m_val);
+ sum[7] = vec_mergel(v[3].m_val, v[7].m_val);
+
+ kernel.packet[0] = vec_mergeh(sum[0].m_val, sum[4].m_val);
+ kernel.packet[1] = vec_mergel(sum[0].m_val, sum[4].m_val);
+ kernel.packet[2] = vec_mergeh(sum[1].m_val, sum[5].m_val);
+ kernel.packet[3] = vec_mergel(sum[1].m_val, sum[5].m_val);
+ kernel.packet[4] = vec_mergeh(sum[2].m_val, sum[6].m_val);
+ kernel.packet[5] = vec_mergel(sum[2].m_val, sum[6].m_val);
+ kernel.packet[6] = vec_mergeh(sum[3].m_val, sum[7].m_val);
+ kernel.packet[7] = vec_mergel(sum[3].m_val, sum[7].m_val);
+}
+
+EIGEN_DEVICE_FUNC inline void
+ptranspose(PacketBlock<Packet16c,16>& kernel) {
+ Packet16c step1[16], step2[16], step3[16];
+
+ step1[0] = vec_mergeh(kernel.packet[0], kernel.packet[8]);
+ step1[1] = vec_mergel(kernel.packet[0], kernel.packet[8]);
+ step1[2] = vec_mergeh(kernel.packet[1], kernel.packet[9]);
+ step1[3] = vec_mergel(kernel.packet[1], kernel.packet[9]);
+ step1[4] = vec_mergeh(kernel.packet[2], kernel.packet[10]);
+ step1[5] = vec_mergel(kernel.packet[2], kernel.packet[10]);
+ step1[6] = vec_mergeh(kernel.packet[3], kernel.packet[11]);
+ step1[7] = vec_mergel(kernel.packet[3], kernel.packet[11]);
+ step1[8] = vec_mergeh(kernel.packet[4], kernel.packet[12]);
+ step1[9] = vec_mergel(kernel.packet[4], kernel.packet[12]);
+ step1[10] = vec_mergeh(kernel.packet[5], kernel.packet[13]);
+ step1[11] = vec_mergel(kernel.packet[5], kernel.packet[13]);
+ step1[12] = vec_mergeh(kernel.packet[6], kernel.packet[14]);
+ step1[13] = vec_mergel(kernel.packet[6], kernel.packet[14]);
+ step1[14] = vec_mergeh(kernel.packet[7], kernel.packet[15]);
+ step1[15] = vec_mergel(kernel.packet[7], kernel.packet[15]);
+
+ step2[0] = vec_mergeh(step1[0], step1[8]);
+ step2[1] = vec_mergel(step1[0], step1[8]);
+ step2[2] = vec_mergeh(step1[1], step1[9]);
+ step2[3] = vec_mergel(step1[1], step1[9]);
+ step2[4] = vec_mergeh(step1[2], step1[10]);
+ step2[5] = vec_mergel(step1[2], step1[10]);
+ step2[6] = vec_mergeh(step1[3], step1[11]);
+ step2[7] = vec_mergel(step1[3], step1[11]);
+ step2[8] = vec_mergeh(step1[4], step1[12]);
+ step2[9] = vec_mergel(step1[4], step1[12]);
+ step2[10] = vec_mergeh(step1[5], step1[13]);
+ step2[11] = vec_mergel(step1[5], step1[13]);
+ step2[12] = vec_mergeh(step1[6], step1[14]);
+ step2[13] = vec_mergel(step1[6], step1[14]);
+ step2[14] = vec_mergeh(step1[7], step1[15]);
+ step2[15] = vec_mergel(step1[7], step1[15]);
+
+ step3[0] = vec_mergeh(step2[0], step2[8]);
+ step3[1] = vec_mergel(step2[0], step2[8]);
+ step3[2] = vec_mergeh(step2[1], step2[9]);
+ step3[3] = vec_mergel(step2[1], step2[9]);
+ step3[4] = vec_mergeh(step2[2], step2[10]);
+ step3[5] = vec_mergel(step2[2], step2[10]);
+ step3[6] = vec_mergeh(step2[3], step2[11]);
+ step3[7] = vec_mergel(step2[3], step2[11]);
+ step3[8] = vec_mergeh(step2[4], step2[12]);
+ step3[9] = vec_mergel(step2[4], step2[12]);
+ step3[10] = vec_mergeh(step2[5], step2[13]);
+ step3[11] = vec_mergel(step2[5], step2[13]);
+ step3[12] = vec_mergeh(step2[6], step2[14]);
+ step3[13] = vec_mergel(step2[6], step2[14]);
+ step3[14] = vec_mergeh(step2[7], step2[15]);
+ step3[15] = vec_mergel(step2[7], step2[15]);
+
+ kernel.packet[0] = vec_mergeh(step3[0], step3[8]);
+ kernel.packet[1] = vec_mergel(step3[0], step3[8]);
+ kernel.packet[2] = vec_mergeh(step3[1], step3[9]);
+ kernel.packet[3] = vec_mergel(step3[1], step3[9]);
+ kernel.packet[4] = vec_mergeh(step3[2], step3[10]);
+ kernel.packet[5] = vec_mergel(step3[2], step3[10]);
+ kernel.packet[6] = vec_mergeh(step3[3], step3[11]);
+ kernel.packet[7] = vec_mergel(step3[3], step3[11]);
+ kernel.packet[8] = vec_mergeh(step3[4], step3[12]);
+ kernel.packet[9] = vec_mergel(step3[4], step3[12]);
+ kernel.packet[10] = vec_mergeh(step3[5], step3[13]);
+ kernel.packet[11] = vec_mergel(step3[5], step3[13]);
+ kernel.packet[12] = vec_mergeh(step3[6], step3[14]);
+ kernel.packet[13] = vec_mergel(step3[6], step3[14]);
+ kernel.packet[14] = vec_mergeh(step3[7], step3[15]);
+ kernel.packet[15] = vec_mergel(step3[7], step3[15]);
+}
+
+EIGEN_DEVICE_FUNC inline void
+ptranspose(PacketBlock<Packet16uc,16>& kernel) {
+ Packet16uc step1[16], step2[16], step3[16];
+
+ step1[0] = vec_mergeh(kernel.packet[0], kernel.packet[8]);
+ step1[1] = vec_mergel(kernel.packet[0], kernel.packet[8]);
+ step1[2] = vec_mergeh(kernel.packet[1], kernel.packet[9]);
+ step1[3] = vec_mergel(kernel.packet[1], kernel.packet[9]);
+ step1[4] = vec_mergeh(kernel.packet[2], kernel.packet[10]);
+ step1[5] = vec_mergel(kernel.packet[2], kernel.packet[10]);
+ step1[6] = vec_mergeh(kernel.packet[3], kernel.packet[11]);
+ step1[7] = vec_mergel(kernel.packet[3], kernel.packet[11]);
+ step1[8] = vec_mergeh(kernel.packet[4], kernel.packet[12]);
+ step1[9] = vec_mergel(kernel.packet[4], kernel.packet[12]);
+ step1[10] = vec_mergeh(kernel.packet[5], kernel.packet[13]);
+ step1[11] = vec_mergel(kernel.packet[5], kernel.packet[13]);
+ step1[12] = vec_mergeh(kernel.packet[6], kernel.packet[14]);
+ step1[13] = vec_mergel(kernel.packet[6], kernel.packet[14]);
+ step1[14] = vec_mergeh(kernel.packet[7], kernel.packet[15]);
+ step1[15] = vec_mergel(kernel.packet[7], kernel.packet[15]);
+
+ step2[0] = vec_mergeh(step1[0], step1[8]);
+ step2[1] = vec_mergel(step1[0], step1[8]);
+ step2[2] = vec_mergeh(step1[1], step1[9]);
+ step2[3] = vec_mergel(step1[1], step1[9]);
+ step2[4] = vec_mergeh(step1[2], step1[10]);
+ step2[5] = vec_mergel(step1[2], step1[10]);
+ step2[6] = vec_mergeh(step1[3], step1[11]);
+ step2[7] = vec_mergel(step1[3], step1[11]);
+ step2[8] = vec_mergeh(step1[4], step1[12]);
+ step2[9] = vec_mergel(step1[4], step1[12]);
+ step2[10] = vec_mergeh(step1[5], step1[13]);
+ step2[11] = vec_mergel(step1[5], step1[13]);
+ step2[12] = vec_mergeh(step1[6], step1[14]);
+ step2[13] = vec_mergel(step1[6], step1[14]);
+ step2[14] = vec_mergeh(step1[7], step1[15]);
+ step2[15] = vec_mergel(step1[7], step1[15]);
+
+ step3[0] = vec_mergeh(step2[0], step2[8]);
+ step3[1] = vec_mergel(step2[0], step2[8]);
+ step3[2] = vec_mergeh(step2[1], step2[9]);
+ step3[3] = vec_mergel(step2[1], step2[9]);
+ step3[4] = vec_mergeh(step2[2], step2[10]);
+ step3[5] = vec_mergel(step2[2], step2[10]);
+ step3[6] = vec_mergeh(step2[3], step2[11]);
+ step3[7] = vec_mergel(step2[3], step2[11]);
+ step3[8] = vec_mergeh(step2[4], step2[12]);
+ step3[9] = vec_mergel(step2[4], step2[12]);
+ step3[10] = vec_mergeh(step2[5], step2[13]);
+ step3[11] = vec_mergel(step2[5], step2[13]);
+ step3[12] = vec_mergeh(step2[6], step2[14]);
+ step3[13] = vec_mergel(step2[6], step2[14]);
+ step3[14] = vec_mergeh(step2[7], step2[15]);
+ step3[15] = vec_mergel(step2[7], step2[15]);
+
+ kernel.packet[0] = vec_mergeh(step3[0], step3[8]);
+ kernel.packet[1] = vec_mergel(step3[0], step3[8]);
+ kernel.packet[2] = vec_mergeh(step3[1], step3[9]);
+ kernel.packet[3] = vec_mergel(step3[1], step3[9]);
+ kernel.packet[4] = vec_mergeh(step3[2], step3[10]);
+ kernel.packet[5] = vec_mergel(step3[2], step3[10]);
+ kernel.packet[6] = vec_mergeh(step3[3], step3[11]);
+ kernel.packet[7] = vec_mergel(step3[3], step3[11]);
+ kernel.packet[8] = vec_mergeh(step3[4], step3[12]);
+ kernel.packet[9] = vec_mergel(step3[4], step3[12]);
+ kernel.packet[10] = vec_mergeh(step3[5], step3[13]);
+ kernel.packet[11] = vec_mergel(step3[5], step3[13]);
+ kernel.packet[12] = vec_mergeh(step3[6], step3[14]);
+ kernel.packet[13] = vec_mergel(step3[6], step3[14]);
+ kernel.packet[14] = vec_mergeh(step3[7], step3[15]);
+ kernel.packet[15] = vec_mergel(step3[7], step3[15]);
+}
+
+template<typename Packet> EIGEN_STRONG_INLINE
+Packet pblend4(const Selector<4>& ifPacket, const Packet& thenPacket, const Packet& elsePacket) {
Packet4ui select = { ifPacket.select[0], ifPacket.select[1], ifPacket.select[2], ifPacket.select[3] };
Packet4ui mask = reinterpret_cast<Packet4ui>(vec_cmpeq(reinterpret_cast<Packet4ui>(select), reinterpret_cast<Packet4ui>(p4i_ONE)));
return vec_sel(elsePacket, thenPacket, mask);
}
+template<> EIGEN_STRONG_INLINE Packet4i pblend(const Selector<4>& ifPacket, const Packet4i& thenPacket, const Packet4i& elsePacket) {
+ return pblend4<Packet4i>(ifPacket, thenPacket, elsePacket);
+}
+
template<> EIGEN_STRONG_INLINE Packet4f pblend(const Selector<4>& ifPacket, const Packet4f& thenPacket, const Packet4f& elsePacket) {
- Packet4ui select = { ifPacket.select[0], ifPacket.select[1], ifPacket.select[2], ifPacket.select[3] };
- Packet4ui mask = reinterpret_cast<Packet4ui>(vec_cmpeq(reinterpret_cast<Packet4ui>(select), reinterpret_cast<Packet4ui>(p4i_ONE)));
+ return pblend4<Packet4f>(ifPacket, thenPacket, elsePacket);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8s pblend(const Selector<8>& ifPacket, const Packet8s& thenPacket, const Packet8s& elsePacket) {
+ Packet8us select = { ifPacket.select[0], ifPacket.select[1], ifPacket.select[2], ifPacket.select[3],
+ ifPacket.select[4], ifPacket.select[5], ifPacket.select[6], ifPacket.select[7] };
+ Packet8us mask = reinterpret_cast<Packet8us>(vec_cmpeq(select, p8us_ONE));
+ Packet8s result = vec_sel(elsePacket, thenPacket, mask);
+ return result;
+}
+
+template<> EIGEN_STRONG_INLINE Packet8us pblend(const Selector<8>& ifPacket, const Packet8us& thenPacket, const Packet8us& elsePacket) {
+ Packet8us select = { ifPacket.select[0], ifPacket.select[1], ifPacket.select[2], ifPacket.select[3],
+ ifPacket.select[4], ifPacket.select[5], ifPacket.select[6], ifPacket.select[7] };
+ Packet8us mask = reinterpret_cast<Packet8us>(vec_cmpeq(reinterpret_cast<Packet8us>(select), p8us_ONE));
+ return vec_sel(elsePacket, thenPacket, mask);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf pblend(const Selector<8>& ifPacket, const Packet8bf& thenPacket, const Packet8bf& elsePacket) {
+ return pblend<Packet8us>(ifPacket, thenPacket, elsePacket);
+}
+
+template<> EIGEN_STRONG_INLINE Packet16c pblend(const Selector<16>& ifPacket, const Packet16c& thenPacket, const Packet16c& elsePacket) {
+ Packet16uc select = { ifPacket.select[0], ifPacket.select[1], ifPacket.select[2], ifPacket.select[3],
+ ifPacket.select[4], ifPacket.select[5], ifPacket.select[6], ifPacket.select[7],
+ ifPacket.select[8], ifPacket.select[9], ifPacket.select[10], ifPacket.select[11],
+ ifPacket.select[12], ifPacket.select[13], ifPacket.select[14], ifPacket.select[15] };
+
+ Packet16uc mask = reinterpret_cast<Packet16uc>(vec_cmpeq(reinterpret_cast<Packet16uc>(select), p16uc_ONE));
+ return vec_sel(elsePacket, thenPacket, mask);
+}
+
+template<> EIGEN_STRONG_INLINE Packet16uc pblend(const Selector<16>& ifPacket, const Packet16uc& thenPacket, const Packet16uc& elsePacket) {
+ Packet16uc select = { ifPacket.select[0], ifPacket.select[1], ifPacket.select[2], ifPacket.select[3],
+ ifPacket.select[4], ifPacket.select[5], ifPacket.select[6], ifPacket.select[7],
+ ifPacket.select[8], ifPacket.select[9], ifPacket.select[10], ifPacket.select[11],
+ ifPacket.select[12], ifPacket.select[13], ifPacket.select[14], ifPacket.select[15] };
+
+ Packet16uc mask = reinterpret_cast<Packet16uc>(vec_cmpeq(reinterpret_cast<Packet16uc>(select), p16uc_ONE));
return vec_sel(elsePacket, thenPacket, mask);
}
+template <>
+struct type_casting_traits<float, int> {
+ enum {
+ VectorizedCast = 1,
+ SrcCoeffRatio = 1,
+ TgtCoeffRatio = 1
+ };
+};
+
+template <>
+struct type_casting_traits<int, float> {
+ enum {
+ VectorizedCast = 1,
+ SrcCoeffRatio = 1,
+ TgtCoeffRatio = 1
+ };
+};
+
+template <>
+struct type_casting_traits<bfloat16, unsigned short int> {
+ enum {
+ VectorizedCast = 1,
+ SrcCoeffRatio = 1,
+ TgtCoeffRatio = 1
+ };
+};
+
+template <>
+struct type_casting_traits<unsigned short int, bfloat16> {
+ enum {
+ VectorizedCast = 1,
+ SrcCoeffRatio = 1,
+ TgtCoeffRatio = 1
+ };
+};
+
+template<> EIGEN_STRONG_INLINE Packet4i pcast<Packet4f, Packet4i>(const Packet4f& a) {
+ return vec_cts(a,0);
+}
+
+template<> EIGEN_STRONG_INLINE Packet4ui pcast<Packet4f, Packet4ui>(const Packet4f& a) {
+ return vec_ctu(a,0);
+}
+
+template<> EIGEN_STRONG_INLINE Packet4f pcast<Packet4i, Packet4f>(const Packet4i& a) {
+ return vec_ctf(a,0);
+}
+
+template<> EIGEN_STRONG_INLINE Packet4f pcast<Packet4ui, Packet4f>(const Packet4ui& a) {
+ return vec_ctf(a,0);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8us pcast<Packet8bf, Packet8us>(const Packet8bf& a) {
+ Packet4f float_even = Bf16ToF32Even(a);
+ Packet4f float_odd = Bf16ToF32Odd(a);
+ Packet4ui int_even = pcast<Packet4f, Packet4ui>(float_even);
+ Packet4ui int_odd = pcast<Packet4f, Packet4ui>(float_odd);
+ const _EIGEN_DECLARE_CONST_FAST_Packet4ui(low_mask, 0x0000FFFF);
+ Packet4ui low_even = pand<Packet4ui>(int_even, p4ui_low_mask);
+ Packet4ui low_odd = pand<Packet4ui>(int_odd, p4ui_low_mask);
+
+ //Check values that are bigger than USHRT_MAX (0xFFFF)
+ Packet4bi overflow_selector;
+ if(vec_any_gt(int_even, p4ui_low_mask)){
+ overflow_selector = vec_cmpgt(int_even, p4ui_low_mask);
+ low_even = vec_sel(low_even, p4ui_low_mask, overflow_selector);
+ }
+ if(vec_any_gt(int_odd, p4ui_low_mask)){
+ overflow_selector = vec_cmpgt(int_odd, p4ui_low_mask);
+ low_odd = vec_sel(low_even, p4ui_low_mask, overflow_selector);
+ }
+
+ low_odd = plogical_shift_left<16>(low_odd);
+
+ Packet4ui int_final = por<Packet4ui>(low_even, low_odd);
+ return reinterpret_cast<Packet8us>(int_final);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf pcast<Packet8us, Packet8bf>(const Packet8us& a) {
+ //short -> int -> float -> bfloat16
+ const _EIGEN_DECLARE_CONST_FAST_Packet4ui(low_mask, 0x0000FFFF);
+ Packet4ui int_cast = reinterpret_cast<Packet4ui>(a);
+ Packet4ui int_even = pand<Packet4ui>(int_cast, p4ui_low_mask);
+ Packet4ui int_odd = plogical_shift_right<16>(int_cast);
+ Packet4f float_even = pcast<Packet4ui, Packet4f>(int_even);
+ Packet4f float_odd = pcast<Packet4ui, Packet4f>(int_odd);
+ return F32ToBf16(float_even, float_odd);
+}
+
+
+template<> EIGEN_STRONG_INLINE Packet4i preinterpret<Packet4i,Packet4f>(const Packet4f& a) {
+ return reinterpret_cast<Packet4i>(a);
+}
+
+template<> EIGEN_STRONG_INLINE Packet4f preinterpret<Packet4f,Packet4i>(const Packet4i& a) {
+ return reinterpret_cast<Packet4f>(a);
+}
+
+
//---------- double ----------
#ifdef __VSX__
@@ -764,9 +2264,12 @@ typedef __vector __bool long Packet2bl;
static Packet2l p2l_ONE = { 1, 1 };
static Packet2l p2l_ZERO = reinterpret_cast<Packet2l>(p4i_ZERO);
-static Packet2d p2d_ONE = { 1.0, 1.0 };
+static Packet2ul p2ul_SIGN = { 0x8000000000000000ull, 0x8000000000000000ull };
+static Packet2ul p2ul_PREV0DOT5 = { 0x3FDFFFFFFFFFFFFFull, 0x3FDFFFFFFFFFFFFFull };
+static Packet2d p2d_ONE = { 1.0, 1.0 };
static Packet2d p2d_ZERO = reinterpret_cast<Packet2d>(p4f_ZERO);
-static Packet2d p2d_MZERO = { -0.0, -0.0 };
+static Packet2d p2d_MZERO = { numext::bit_cast<double>(0x8000000000000000ull),
+ numext::bit_cast<double>(0x8000000000000000ull) };
#ifdef _BIG_ENDIAN
static Packet2d p2d_COUNTDOWN = reinterpret_cast<Packet2d>(vec_sld(reinterpret_cast<Packet4f>(p2d_ZERO), reinterpret_cast<Packet4f>(p2d_ONE), 8));
@@ -774,16 +2277,9 @@ static Packet2d p2d_COUNTDOWN = reinterpret_cast<Packet2d>(vec_sld(reinterpret_c
static Packet2d p2d_COUNTDOWN = reinterpret_cast<Packet2d>(vec_sld(reinterpret_cast<Packet4f>(p2d_ONE), reinterpret_cast<Packet4f>(p2d_ZERO), 8));
#endif
-template<int index> Packet2d vec_splat_dbl(Packet2d& a);
-
-template<> EIGEN_STRONG_INLINE Packet2d vec_splat_dbl<0>(Packet2d& a)
-{
- return reinterpret_cast<Packet2d>(vec_perm(a, a, p16uc_PSET64_HI));
-}
-
-template<> EIGEN_STRONG_INLINE Packet2d vec_splat_dbl<1>(Packet2d& a)
+template<int index> Packet2d vec_splat_dbl(Packet2d& a)
{
- return reinterpret_cast<Packet2d>(vec_perm(a, a, p16uc_PSET64_LO));
+ return vec_splat(a, index);
}
template<> struct packet_traits<double> : default_packet_traits
@@ -812,12 +2308,13 @@ template<> struct packet_traits<double> : default_packet_traits
HasRound = 1,
HasFloor = 1,
HasCeil = 1,
+ HasRint = 1,
HasNegate = 1,
HasBlend = 1
};
};
-template<> struct unpacket_traits<Packet2d> { typedef double type; enum {size=2, alignment=Aligned16}; typedef Packet2d half; };
+template<> struct unpacket_traits<Packet2d> { typedef double type; enum {size=2, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false}; typedef Packet2d half; };
inline std::ostream & operator <<(std::ostream & s, const Packet2l & v)
{
@@ -845,21 +2342,13 @@ inline std::ostream & operator <<(std::ostream & s, const Packet2d & v)
template<> EIGEN_STRONG_INLINE Packet2d pload<Packet2d>(const double* from)
{
EIGEN_DEBUG_ALIGNED_LOAD
-#ifdef __VSX__
- return vec_vsx_ld(0, from);
-#else
- return vec_ld(0, from);
-#endif
+ return vec_xl(0, const_cast<double *>(from)); // cast needed by Clang
}
template<> EIGEN_STRONG_INLINE void pstore<double>(double* to, const Packet2d& from)
{
EIGEN_DEBUG_ALIGNED_STORE
-#ifdef __VSX__
- vec_vsx_st(from, 0, to);
-#else
- vec_st(from, 0, to);
-#endif
+ vec_xst(from, 0, to);
}
template<> EIGEN_STRONG_INLINE Packet2d pset1<Packet2d>(const double& from) {
@@ -867,28 +2356,32 @@ template<> EIGEN_STRONG_INLINE Packet2d pset1<Packet2d>(const double& from) {
return v;
}
+template<> EIGEN_STRONG_INLINE Packet2d pset1frombits<Packet2d>(unsigned long from) {
+ Packet2l v = {static_cast<long long>(from), static_cast<long long>(from)};
+ return reinterpret_cast<Packet2d>(v);
+}
+
template<> EIGEN_STRONG_INLINE void
pbroadcast4<Packet2d>(const double *a,
Packet2d& a0, Packet2d& a1, Packet2d& a2, Packet2d& a3)
{
- a1 = pload<Packet2d>(a);
- a0 = vec_splat_dbl<0>(a1);
- a1 = vec_splat_dbl<1>(a1);
- a3 = pload<Packet2d>(a+2);
- a2 = vec_splat_dbl<0>(a3);
- a3 = vec_splat_dbl<1>(a3);
+ //This way is faster than vec_splat (at least for doubles in Power 9)
+ a0 = pset1<Packet2d>(a[0]);
+ a1 = pset1<Packet2d>(a[1]);
+ a2 = pset1<Packet2d>(a[2]);
+ a3 = pset1<Packet2d>(a[3]);
}
template<> EIGEN_DEVICE_FUNC inline Packet2d pgather<double, Packet2d>(const double* from, Index stride)
{
- double EIGEN_ALIGN16 af[2];
+ EIGEN_ALIGN16 double af[2];
af[0] = from[0*stride];
af[1] = from[1*stride];
return pload<Packet2d>(af);
}
template<> EIGEN_DEVICE_FUNC inline void pscatter<double, Packet2d>(double* to, const Packet2d& from, Index stride)
{
- double EIGEN_ALIGN16 af[2];
+ EIGEN_ALIGN16 double af[2];
pstore<double>(af, from);
to[0*stride] = af[0];
to[1*stride] = af[1];
@@ -910,9 +2403,29 @@ template<> EIGEN_STRONG_INLINE Packet2d pdiv<Packet2d>(const Packet2d& a, const
// for some weird raisons, it has to be overloaded for packet of integers
template<> EIGEN_STRONG_INLINE Packet2d pmadd(const Packet2d& a, const Packet2d& b, const Packet2d& c) { return vec_madd(a, b, c); }
-template<> EIGEN_STRONG_INLINE Packet2d pmin<Packet2d>(const Packet2d& a, const Packet2d& b) { return vec_min(a, b); }
+template<> EIGEN_STRONG_INLINE Packet2d pmin<Packet2d>(const Packet2d& a, const Packet2d& b)
+{
+ // NOTE: about 10% slower than vec_min, but consistent with std::min and SSE regarding NaN
+ Packet2d ret;
+ __asm__ ("xvcmpgedp %x0,%x1,%x2\n\txxsel %x0,%x1,%x2,%x0" : "=&wa" (ret) : "wa" (a), "wa" (b));
+ return ret;
+ }
-template<> EIGEN_STRONG_INLINE Packet2d pmax<Packet2d>(const Packet2d& a, const Packet2d& b) { return vec_max(a, b); }
+template<> EIGEN_STRONG_INLINE Packet2d pmax<Packet2d>(const Packet2d& a, const Packet2d& b)
+{
+ // NOTE: about 10% slower than vec_max, but consistent with std::max and SSE regarding NaN
+ Packet2d ret;
+ __asm__ ("xvcmpgtdp %x0,%x2,%x1\n\txxsel %x0,%x1,%x2,%x0" : "=&wa" (ret) : "wa" (a), "wa" (b));
+ return ret;
+}
+
+template<> EIGEN_STRONG_INLINE Packet2d pcmp_le(const Packet2d& a, const Packet2d& b) { return reinterpret_cast<Packet2d>(vec_cmple(a,b)); }
+template<> EIGEN_STRONG_INLINE Packet2d pcmp_lt(const Packet2d& a, const Packet2d& b) { return reinterpret_cast<Packet2d>(vec_cmplt(a,b)); }
+template<> EIGEN_STRONG_INLINE Packet2d pcmp_eq(const Packet2d& a, const Packet2d& b) { return reinterpret_cast<Packet2d>(vec_cmpeq(a,b)); }
+template<> EIGEN_STRONG_INLINE Packet2d pcmp_lt_or_nan(const Packet2d& a, const Packet2d& b) {
+ Packet2d c = reinterpret_cast<Packet2d>(vec_cmpge(a,b));
+ return vec_nor(c,c);
+}
template<> EIGEN_STRONG_INLINE Packet2d pand<Packet2d>(const Packet2d& a, const Packet2d& b) { return vec_and(a, b); }
@@ -922,14 +2435,34 @@ template<> EIGEN_STRONG_INLINE Packet2d pxor<Packet2d>(const Packet2d& a, const
template<> EIGEN_STRONG_INLINE Packet2d pandnot<Packet2d>(const Packet2d& a, const Packet2d& b) { return vec_and(a, vec_nor(b, b)); }
-template<> EIGEN_STRONG_INLINE Packet2d pround<Packet2d>(const Packet2d& a) { return vec_round(a); }
+template<> EIGEN_STRONG_INLINE Packet2d pround<Packet2d>(const Packet2d& a)
+{
+ Packet2d t = vec_add(reinterpret_cast<Packet2d>(vec_or(vec_and(reinterpret_cast<Packet2ul>(a), p2ul_SIGN), p2ul_PREV0DOT5)), a);
+ Packet2d res;
+
+ __asm__("xvrdpiz %x0, %x1\n\t"
+ : "=&wa" (res)
+ : "wa" (t));
+
+ return res;
+}
template<> EIGEN_STRONG_INLINE Packet2d pceil<Packet2d>(const Packet2d& a) { return vec_ceil(a); }
template<> EIGEN_STRONG_INLINE Packet2d pfloor<Packet2d>(const Packet2d& a) { return vec_floor(a); }
+template<> EIGEN_STRONG_INLINE Packet2d print<Packet2d>(const Packet2d& a)
+{
+ Packet2d res;
+
+ __asm__("xvrdpic %x0, %x1\n\t"
+ : "=&wa" (res)
+ : "wa" (a));
+
+ return res;
+}
template<> EIGEN_STRONG_INLINE Packet2d ploadu<Packet2d>(const double* from)
{
- EIGEN_DEBUG_ALIGNED_LOAD
- return (Packet2d) vec_vsx_ld((long)from & 15, (const double*) _EIGEN_ALIGNED_PTR(from));
+ EIGEN_DEBUG_UNALIGNED_LOAD
+ return vec_xl(0, const_cast<double*>(from));
}
template<> EIGEN_STRONG_INLINE Packet2d ploaddup<Packet2d>(const double* from)
@@ -942,13 +2475,13 @@ template<> EIGEN_STRONG_INLINE Packet2d ploaddup<Packet2d>(const double* from)
template<> EIGEN_STRONG_INLINE void pstoreu<double>(double* to, const Packet2d& from)
{
- EIGEN_DEBUG_ALIGNED_STORE
- vec_vsx_st((Packet4f)from, (long)to & 15, (float*) _EIGEN_ALIGNED_PTR(to));
+ EIGEN_DEBUG_UNALIGNED_STORE
+ vec_xst(from, 0, to);
}
template<> EIGEN_STRONG_INLINE void prefetch<double>(const double* addr) { EIGEN_PPC_PREFETCH(addr); }
-template<> EIGEN_STRONG_INLINE double pfirst<Packet2d>(const Packet2d& a) { double EIGEN_ALIGN16 x[2]; pstore<double>(x, a); return x[0]; }
+template<> EIGEN_STRONG_INLINE double pfirst<Packet2d>(const Packet2d& a) { EIGEN_ALIGN16 double x[2]; pstore<double>(x, a); return x[0]; }
template<> EIGEN_STRONG_INLINE Packet2d preverse(const Packet2d& a)
{
@@ -956,6 +2489,177 @@ template<> EIGEN_STRONG_INLINE Packet2d preverse(const Packet2d& a)
}
template<> EIGEN_STRONG_INLINE Packet2d pabs(const Packet2d& a) { return vec_abs(a); }
+// VSX support varies between different compilers and even different
+// versions of the same compiler. For gcc version >= 4.9.3, we can use
+// vec_cts to efficiently convert Packet2d to Packet2l. Otherwise, use
+// a slow version that works with older compilers.
+// Update: apparently vec_cts/vec_ctf intrinsics for 64-bit doubles
+// are buggy, https://gcc.gnu.org/bugzilla/show_bug.cgi?id=70963
+template<>
+inline Packet2l pcast<Packet2d, Packet2l>(const Packet2d& x) {
+#if EIGEN_GNUC_AT_LEAST(5, 4) || \
+ (EIGEN_GNUC_AT(6, 1) && __GNUC_PATCHLEVEL__ >= 1)
+ return vec_cts(x, 0); // TODO: check clang version.
+#else
+ double tmp[2];
+ memcpy(tmp, &x, sizeof(tmp));
+ Packet2l l = { static_cast<long long>(tmp[0]),
+ static_cast<long long>(tmp[1]) };
+ return l;
+#endif
+}
+
+template<>
+inline Packet2d pcast<Packet2l, Packet2d>(const Packet2l& x) {
+ unsigned long long tmp[2];
+ memcpy(tmp, &x, sizeof(tmp));
+ Packet2d d = { static_cast<double>(tmp[0]),
+ static_cast<double>(tmp[1]) };
+ return d;
+}
+
+
+// Packet2l shifts.
+// For POWER8 we simply use vec_sr/l.
+//
+// Things are more complicated for POWER7. There is actually a
+// vec_xxsxdi intrinsic but it is not supported by some gcc versions.
+// So we need to shift by N % 32 and rearrage bytes.
+#ifdef __POWER8_VECTOR__
+
+template<int N>
+EIGEN_STRONG_INLINE Packet2l plogical_shift_left(const Packet2l& a) {
+ const Packet2ul shift = { N, N };
+ return vec_sl(a, shift);
+}
+
+template<int N>
+EIGEN_STRONG_INLINE Packet2l plogical_shift_right(const Packet2l& a) {
+ const Packet2ul shift = { N, N };
+ return vec_sr(a, shift);
+}
+
+#else
+
+// Shifts [A, B, C, D] to [B, 0, D, 0].
+// Used to implement left shifts for Packet2l.
+EIGEN_ALWAYS_INLINE Packet4i shift_even_left(const Packet4i& a) {
+ static const Packet16uc perm = {
+ 0x14, 0x15, 0x16, 0x17, 0x00, 0x01, 0x02, 0x03,
+ 0x1c, 0x1d, 0x1e, 0x1f, 0x08, 0x09, 0x0a, 0x0b };
+ #ifdef _BIG_ENDIAN
+ return vec_perm(p4i_ZERO, a, perm);
+ #else
+ return vec_perm(a, p4i_ZERO, perm);
+ #endif
+}
+
+// Shifts [A, B, C, D] to [0, A, 0, C].
+// Used to implement right shifts for Packet2l.
+EIGEN_ALWAYS_INLINE Packet4i shift_odd_right(const Packet4i& a) {
+ static const Packet16uc perm = {
+ 0x04, 0x05, 0x06, 0x07, 0x10, 0x11, 0x12, 0x13,
+ 0x0c, 0x0d, 0x0e, 0x0f, 0x18, 0x19, 0x1a, 0x1b };
+ #ifdef _BIG_ENDIAN
+ return vec_perm(p4i_ZERO, a, perm);
+ #else
+ return vec_perm(a, p4i_ZERO, perm);
+ #endif
+}
+
+template<int N, typename EnableIf = void>
+struct plogical_shift_left_impl;
+
+template<int N>
+struct plogical_shift_left_impl<N, typename enable_if<(N < 32) && (N >= 0)>::type> {
+ static EIGEN_STRONG_INLINE Packet2l run(const Packet2l& a) {
+ static const unsigned n = static_cast<unsigned>(N);
+ const Packet4ui shift = {n, n, n, n};
+ const Packet4i ai = reinterpret_cast<Packet4i>(a);
+ static const unsigned m = static_cast<unsigned>(32 - N);
+ const Packet4ui shift_right = {m, m, m, m};
+ const Packet4i out_hi = vec_sl(ai, shift);
+ const Packet4i out_lo = shift_even_left(vec_sr(ai, shift_right));
+ return reinterpret_cast<Packet2l>(por<Packet4i>(out_hi, out_lo));
+ }
+};
+
+template<int N>
+struct plogical_shift_left_impl<N, typename enable_if<(N >= 32)>::type> {
+ static EIGEN_STRONG_INLINE Packet2l run(const Packet2l& a) {
+ static const unsigned m = static_cast<unsigned>(N - 32);
+ const Packet4ui shift = {m, m, m, m};
+ const Packet4i ai = reinterpret_cast<Packet4i>(a);
+ return reinterpret_cast<Packet2l>(shift_even_left(vec_sl(ai, shift)));
+ }
+};
+
+template<int N>
+EIGEN_STRONG_INLINE Packet2l plogical_shift_left(const Packet2l& a) {
+ return plogical_shift_left_impl<N>::run(a);
+}
+
+template<int N, typename EnableIf = void>
+struct plogical_shift_right_impl;
+
+template<int N>
+struct plogical_shift_right_impl<N, typename enable_if<(N < 32) && (N >= 0)>::type> {
+ static EIGEN_STRONG_INLINE Packet2l run(const Packet2l& a) {
+ static const unsigned n = static_cast<unsigned>(N);
+ const Packet4ui shift = {n, n, n, n};
+ const Packet4i ai = reinterpret_cast<Packet4i>(a);
+ static const unsigned m = static_cast<unsigned>(32 - N);
+ const Packet4ui shift_left = {m, m, m, m};
+ const Packet4i out_lo = vec_sr(ai, shift);
+ const Packet4i out_hi = shift_odd_right(vec_sl(ai, shift_left));
+ return reinterpret_cast<Packet2l>(por<Packet4i>(out_hi, out_lo));
+ }
+};
+
+template<int N>
+struct plogical_shift_right_impl<N, typename enable_if<(N >= 32)>::type> {
+ static EIGEN_STRONG_INLINE Packet2l run(const Packet2l& a) {
+ static const unsigned m = static_cast<unsigned>(N - 32);
+ const Packet4ui shift = {m, m, m, m};
+ const Packet4i ai = reinterpret_cast<Packet4i>(a);
+ return reinterpret_cast<Packet2l>(shift_odd_right(vec_sr(ai, shift)));
+ }
+};
+
+template<int N>
+EIGEN_STRONG_INLINE Packet2l plogical_shift_right(const Packet2l& a) {
+ return plogical_shift_right_impl<N>::run(a);
+}
+#endif
+
+template<> EIGEN_STRONG_INLINE Packet2d pldexp<Packet2d>(const Packet2d& a, const Packet2d& exponent) {
+ // Clamp exponent to [-2099, 2099]
+ const Packet2d max_exponent = pset1<Packet2d>(2099.0);
+ const Packet2l e = pcast<Packet2d, Packet2l>(pmin(pmax(exponent, pnegate(max_exponent)), max_exponent));
+
+ // Split 2^e into four factors and multiply:
+ const Packet2l bias = { 1023, 1023 };
+ Packet2l b = plogical_shift_right<2>(e); // floor(e/4)
+ Packet2d c = reinterpret_cast<Packet2d>(plogical_shift_left<52>(b + bias));
+ Packet2d out = pmul(pmul(pmul(a, c), c), c); // a * 2^(3b)
+ b = psub(psub(psub(e, b), b), b); // e - 3b
+ c = reinterpret_cast<Packet2d>(plogical_shift_left<52>(b + bias)); // 2^(e - 3b)
+ out = pmul(out, c); // a * 2^e
+ return out;
+}
+
+
+// Extract exponent without existence of Packet2l.
+template<>
+EIGEN_STRONG_INLINE
+Packet2d pfrexp_generic_get_biased_exponent(const Packet2d& a) {
+ return pcast<Packet2l, Packet2d>(plogical_shift_right<52>(reinterpret_cast<Packet2l>(pabs(a))));
+}
+
+template<> EIGEN_STRONG_INLINE Packet2d pfrexp<Packet2d> (const Packet2d& a, Packet2d& exponent) {
+ return pfrexp_generic(a, exponent);
+}
+
template<> EIGEN_STRONG_INLINE double predux<Packet2d>(const Packet2d& a)
{
Packet2d b, sum;
@@ -964,20 +2668,6 @@ template<> EIGEN_STRONG_INLINE double predux<Packet2d>(const Packet2d& a)
return pfirst<Packet2d>(sum);
}
-template<> EIGEN_STRONG_INLINE Packet2d preduxp<Packet2d>(const Packet2d* vecs)
-{
- Packet2d v[2], sum;
- v[0] = vecs[0] + reinterpret_cast<Packet2d>(vec_sld(reinterpret_cast<Packet4f>(vecs[0]), reinterpret_cast<Packet4f>(vecs[0]), 8));
- v[1] = vecs[1] + reinterpret_cast<Packet2d>(vec_sld(reinterpret_cast<Packet4f>(vecs[1]), reinterpret_cast<Packet4f>(vecs[1]), 8));
-
-#ifdef _BIG_ENDIAN
- sum = reinterpret_cast<Packet2d>(vec_sld(reinterpret_cast<Packet4f>(v[0]), reinterpret_cast<Packet4f>(v[1]), 8));
-#else
- sum = reinterpret_cast<Packet2d>(vec_sld(reinterpret_cast<Packet4f>(v[1]), reinterpret_cast<Packet4f>(v[0]), 8));
-#endif
-
- return sum;
-}
// Other reduction functions:
// mul
template<> EIGEN_STRONG_INLINE double predux_mul<Packet2d>(const Packet2d& a)
@@ -997,20 +2687,6 @@ template<> EIGEN_STRONG_INLINE double predux_max<Packet2d>(const Packet2d& a)
return pfirst(pmax(a, reinterpret_cast<Packet2d>(vec_sld(reinterpret_cast<Packet4ui>(a), reinterpret_cast<Packet4ui>(a), 8))));
}
-template<int Offset>
-struct palign_impl<Offset,Packet2d>
-{
- static EIGEN_STRONG_INLINE void run(Packet2d& first, const Packet2d& second)
- {
- if (Offset == 1)
-#ifdef _BIG_ENDIAN
- first = reinterpret_cast<Packet2d>(vec_sld(reinterpret_cast<Packet4ui>(first), reinterpret_cast<Packet4ui>(second), 8));
-#else
- first = reinterpret_cast<Packet2d>(vec_sld(reinterpret_cast<Packet4ui>(second), reinterpret_cast<Packet4ui>(first), 8));
-#endif
- }
-};
-
EIGEN_DEVICE_FUNC inline void
ptranspose(PacketBlock<Packet2d,2>& kernel) {
Packet2d t0, t1;
@@ -1022,9 +2698,11 @@ ptranspose(PacketBlock<Packet2d,2>& kernel) {
template<> EIGEN_STRONG_INLINE Packet2d pblend(const Selector<2>& ifPacket, const Packet2d& thenPacket, const Packet2d& elsePacket) {
Packet2l select = { ifPacket.select[0], ifPacket.select[1] };
- Packet2bl mask = vec_cmpeq(reinterpret_cast<Packet2d>(select), reinterpret_cast<Packet2d>(p2l_ONE));
+ Packet2bl mask = reinterpret_cast<Packet2bl>( vec_cmpeq(reinterpret_cast<Packet2d>(select), reinterpret_cast<Packet2d>(p2l_ONE)) );
return vec_sel(elsePacket, thenPacket, mask);
}
+
+
#endif // __VSX__
} // end namespace internal
diff --git a/Eigen/src/Core/arch/CUDA/Complex.h b/Eigen/src/Core/arch/CUDA/Complex.h
index 9c2536509..deb4c8694 100644
--- a/Eigen/src/Core/arch/CUDA/Complex.h
+++ b/Eigen/src/Core/arch/CUDA/Complex.h
@@ -2,6 +2,7 @@
// for linear algebra.
//
// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
+// Copyright (C) 2021 C. Antonio Sanchez <cantonios@google.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
@@ -11,93 +12,247 @@
#define EIGEN_COMPLEX_CUDA_H
// clang-format off
+// Many std::complex methods such as operator+, operator-, operator* and
+// operator/ are not constexpr. Due to this, GCC and older versions of clang do
+// not treat them as device functions and thus Eigen functors making use of
+// these operators fail to compile. Here, we manually specialize these
+// operators and functors for complex types when building for CUDA to enable
+// their use on-device.
+
+#if defined(EIGEN_CUDACC) && defined(EIGEN_GPU_COMPILE_PHASE)
+
+// ICC already specializes std::complex<float> and std::complex<double>
+// operators, preventing us from making them device functions here.
+// This will lead to silent runtime errors if the operators are used on device.
+//
+// To allow std::complex operator use on device, define _OVERRIDE_COMPLEX_SPECIALIZATION_
+// prior to first inclusion of <complex>. This prevents ICC from adding
+// its own specializations, so our custom ones below can be used instead.
+#if !(defined(EIGEN_COMP_ICC) && defined(_USE_COMPLEX_SPECIALIZATION_))
+
+// Import Eigen's internal operator specializations.
+#define EIGEN_USING_STD_COMPLEX_OPERATORS \
+ using Eigen::complex_operator_detail::operator+; \
+ using Eigen::complex_operator_detail::operator-; \
+ using Eigen::complex_operator_detail::operator*; \
+ using Eigen::complex_operator_detail::operator/; \
+ using Eigen::complex_operator_detail::operator+=; \
+ using Eigen::complex_operator_detail::operator-=; \
+ using Eigen::complex_operator_detail::operator*=; \
+ using Eigen::complex_operator_detail::operator/=; \
+ using Eigen::complex_operator_detail::operator==; \
+ using Eigen::complex_operator_detail::operator!=;
namespace Eigen {
-namespace internal {
+// Specialized std::complex overloads.
+namespace complex_operator_detail {
-#if defined(__CUDACC__) && defined(EIGEN_USE_GPU)
+template<typename T>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+std::complex<T> complex_multiply(const std::complex<T>& a, const std::complex<T>& b) {
+ const T a_real = numext::real(a);
+ const T a_imag = numext::imag(a);
+ const T b_real = numext::real(b);
+ const T b_imag = numext::imag(b);
+ return std::complex<T>(
+ a_real * b_real - a_imag * b_imag,
+ a_imag * b_real + a_real * b_imag);
+}
-// Many std::complex methods such as operator+, operator-, operator* and
-// operator/ are not constexpr. Due to this, clang does not treat them as device
-// functions and thus Eigen functors making use of these operators fail to
-// compile. Here, we manually specialize these functors for complex types when
-// building for CUDA to avoid non-constexpr methods.
-
-// Sum
-template<typename T> struct scalar_sum_op<const std::complex<T>, const std::complex<T> > : binary_op_base<const std::complex<T>, const std::complex<T> > {
- typedef typename std::complex<T> result_type;
-
- EIGEN_EMPTY_STRUCT_CTOR(scalar_sum_op)
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::complex<T> operator() (const std::complex<T>& a, const std::complex<T>& b) const {
- return std::complex<T>(numext::real(a) + numext::real(b),
- numext::imag(a) + numext::imag(b));
- }
-};
-
-template<typename T> struct scalar_sum_op<std::complex<T>, std::complex<T> > : scalar_sum_op<const std::complex<T>, const std::complex<T> > {};
-
-
-// Difference
-template<typename T> struct scalar_difference_op<const std::complex<T>, const std::complex<T> > : binary_op_base<const std::complex<T>, const std::complex<T> > {
- typedef typename std::complex<T> result_type;
-
- EIGEN_EMPTY_STRUCT_CTOR(scalar_difference_op)
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::complex<T> operator() (const std::complex<T>& a, const std::complex<T>& b) const {
- return std::complex<T>(numext::real(a) - numext::real(b),
- numext::imag(a) - numext::imag(b));
- }
-};
-
-template<typename T> struct scalar_difference_op<std::complex<T>, std::complex<T> > : scalar_difference_op<const std::complex<T>, const std::complex<T> > {};
-
-
-// Product
-template<typename T> struct scalar_product_op<const std::complex<T>, const std::complex<T> > : binary_op_base<const std::complex<T>, const std::complex<T> > {
- enum {
- Vectorizable = packet_traits<std::complex<T>>::HasMul
- };
- typedef typename std::complex<T> result_type;
-
- EIGEN_EMPTY_STRUCT_CTOR(scalar_product_op)
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::complex<T> operator() (const std::complex<T>& a, const std::complex<T>& b) const {
- const T a_real = numext::real(a);
- const T a_imag = numext::imag(a);
- const T b_real = numext::real(b);
- const T b_imag = numext::imag(b);
- return std::complex<T>(a_real * b_real - a_imag * b_imag,
- a_real * b_imag + a_imag * b_real);
- }
-};
-
-template<typename T> struct scalar_product_op<std::complex<T>, std::complex<T> > : scalar_product_op<const std::complex<T>, const std::complex<T> > {};
-
-
-// Quotient
-template<typename T> struct scalar_quotient_op<const std::complex<T>, const std::complex<T> > : binary_op_base<const std::complex<T>, const std::complex<T> > {
- enum {
- Vectorizable = packet_traits<std::complex<T>>::HasDiv
- };
- typedef typename std::complex<T> result_type;
-
- EIGEN_EMPTY_STRUCT_CTOR(scalar_quotient_op)
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::complex<T> operator() (const std::complex<T>& a, const std::complex<T>& b) const {
- const T a_real = numext::real(a);
- const T a_imag = numext::imag(a);
- const T b_real = numext::real(b);
- const T b_imag = numext::imag(b);
- const T norm = T(1) / (b_real * b_real + b_imag * b_imag);
- return std::complex<T>((a_real * b_real + a_imag * b_imag) * norm,
- (a_imag * b_real - a_real * b_imag) * norm);
- }
-};
-
-template<typename T> struct scalar_quotient_op<std::complex<T>, std::complex<T> > : scalar_quotient_op<const std::complex<T>, const std::complex<T> > {};
+template<typename T>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+std::complex<T> complex_divide_fast(const std::complex<T>& a, const std::complex<T>& b) {
+ const T a_real = numext::real(a);
+ const T a_imag = numext::imag(a);
+ const T b_real = numext::real(b);
+ const T b_imag = numext::imag(b);
+ const T norm = (b_real * b_real + b_imag * b_imag);
+ return std::complex<T>((a_real * b_real + a_imag * b_imag) / norm,
+ (a_imag * b_real - a_real * b_imag) / norm);
+}
+template<typename T>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+std::complex<T> complex_divide_stable(const std::complex<T>& a, const std::complex<T>& b) {
+ const T a_real = numext::real(a);
+ const T a_imag = numext::imag(a);
+ const T b_real = numext::real(b);
+ const T b_imag = numext::imag(b);
+ // Smith's complex division (https://arxiv.org/pdf/1210.4539.pdf),
+ // guards against over/under-flow.
+ const bool scale_imag = numext::abs(b_imag) <= numext::abs(b_real);
+ const T rscale = scale_imag ? T(1) : b_real / b_imag;
+ const T iscale = scale_imag ? b_imag / b_real : T(1);
+ const T denominator = b_real * rscale + b_imag * iscale;
+ return std::complex<T>((a_real * rscale + a_imag * iscale) / denominator,
+ (a_imag * rscale - a_real * iscale) / denominator);
+}
+
+template<typename T>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+std::complex<T> complex_divide(const std::complex<T>& a, const std::complex<T>& b) {
+#if EIGEN_FAST_MATH
+ return complex_divide_fast(a, b);
+#else
+ return complex_divide_stable(a, b);
#endif
+}
+
+// NOTE: We cannot specialize compound assignment operators with Scalar T,
+// (i.e. operator@=(const T&), for @=+,-,*,/)
+// since they are already specialized for float/double/long double within
+// the standard <complex> header. We also do not specialize the stream
+// operators.
+#define EIGEN_CREATE_STD_COMPLEX_OPERATOR_SPECIALIZATIONS(T) \
+ \
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE \
+std::complex<T> operator+(const std::complex<T>& a) { return a; } \
+ \
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE \
+std::complex<T> operator-(const std::complex<T>& a) { \
+ return std::complex<T>(-numext::real(a), -numext::imag(a)); \
+} \
+ \
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE \
+std::complex<T> operator+(const std::complex<T>& a, const std::complex<T>& b) { \
+ return std::complex<T>(numext::real(a) + numext::real(b), numext::imag(a) + numext::imag(b)); \
+} \
+ \
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE \
+std::complex<T> operator+(const std::complex<T>& a, const T& b) { \
+ return std::complex<T>(numext::real(a) + b, numext::imag(a)); \
+} \
+ \
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE \
+std::complex<T> operator+(const T& a, const std::complex<T>& b) { \
+ return std::complex<T>(a + numext::real(b), numext::imag(b)); \
+} \
+ \
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE \
+std::complex<T> operator-(const std::complex<T>& a, const std::complex<T>& b) { \
+ return std::complex<T>(numext::real(a) - numext::real(b), numext::imag(a) - numext::imag(b)); \
+} \
+ \
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE \
+std::complex<T> operator-(const std::complex<T>& a, const T& b) { \
+ return std::complex<T>(numext::real(a) - b, numext::imag(a)); \
+} \
+ \
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE \
+std::complex<T> operator-(const T& a, const std::complex<T>& b) { \
+ return std::complex<T>(a - numext::real(b), -numext::imag(b)); \
+} \
+ \
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE \
+std::complex<T> operator*(const std::complex<T>& a, const std::complex<T>& b) { \
+ return complex_multiply(a, b); \
+} \
+ \
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE \
+std::complex<T> operator*(const std::complex<T>& a, const T& b) { \
+ return std::complex<T>(numext::real(a) * b, numext::imag(a) * b); \
+} \
+ \
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE \
+std::complex<T> operator*(const T& a, const std::complex<T>& b) { \
+ return std::complex<T>(a * numext::real(b), a * numext::imag(b)); \
+} \
+ \
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE \
+std::complex<T> operator/(const std::complex<T>& a, const std::complex<T>& b) { \
+ return complex_divide(a, b); \
+} \
+ \
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE \
+std::complex<T> operator/(const std::complex<T>& a, const T& b) { \
+ return std::complex<T>(numext::real(a) / b, numext::imag(a) / b); \
+} \
+ \
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE \
+std::complex<T> operator/(const T& a, const std::complex<T>& b) { \
+ return complex_divide(std::complex<T>(a, 0), b); \
+} \
+ \
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE \
+std::complex<T>& operator+=(std::complex<T>& a, const std::complex<T>& b) { \
+ numext::real_ref(a) += numext::real(b); \
+ numext::imag_ref(a) += numext::imag(b); \
+ return a; \
+} \
+ \
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE \
+std::complex<T>& operator-=(std::complex<T>& a, const std::complex<T>& b) { \
+ numext::real_ref(a) -= numext::real(b); \
+ numext::imag_ref(a) -= numext::imag(b); \
+ return a; \
+} \
+ \
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE \
+std::complex<T>& operator*=(std::complex<T>& a, const std::complex<T>& b) { \
+ a = complex_multiply(a, b); \
+ return a; \
+} \
+ \
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE \
+std::complex<T>& operator/=(std::complex<T>& a, const std::complex<T>& b) { \
+ a = complex_divide(a, b); \
+ return a; \
+} \
+ \
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE \
+bool operator==(const std::complex<T>& a, const std::complex<T>& b) { \
+ return numext::real(a) == numext::real(b) && numext::imag(a) == numext::imag(b); \
+} \
+ \
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE \
+bool operator==(const std::complex<T>& a, const T& b) { \
+ return numext::real(a) == b && numext::imag(a) == 0; \
+} \
+ \
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE \
+bool operator==(const T& a, const std::complex<T>& b) { \
+ return a == numext::real(b) && 0 == numext::imag(b); \
+} \
+ \
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE \
+bool operator!=(const std::complex<T>& a, const std::complex<T>& b) { \
+ return !(a == b); \
+} \
+ \
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE \
+bool operator!=(const std::complex<T>& a, const T& b) { \
+ return !(a == b); \
+} \
+ \
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE \
+bool operator!=(const T& a, const std::complex<T>& b) { \
+ return !(a == b); \
+}
+
+// Do not specialize for long double, since that reduces to double on device.
+EIGEN_CREATE_STD_COMPLEX_OPERATOR_SPECIALIZATIONS(float)
+EIGEN_CREATE_STD_COMPLEX_OPERATOR_SPECIALIZATIONS(double)
+
+#undef EIGEN_CREATE_STD_COMPLEX_OPERATOR_SPECIALIZATIONS
+
+
+} // namespace complex_operator_detail
+
+EIGEN_USING_STD_COMPLEX_OPERATORS
+
+namespace numext {
+EIGEN_USING_STD_COMPLEX_OPERATORS
+} // namespace numext
+
+namespace internal {
+EIGEN_USING_STD_COMPLEX_OPERATORS
+
+} // namespace internal
+} // namespace Eigen
-} // end namespace internal
+#endif // !(EIGEN_COMP_ICC && _USE_COMPLEX_SPECIALIZATION_)
-} // end namespace Eigen
+#endif // EIGEN_CUDACC && EIGEN_GPU_COMPILE_PHASE
-#endif // EIGEN_COMPLEX_CUDA_H
+#endif // EIGEN_COMPLEX_CUDA_H
diff --git a/Eigen/src/Core/arch/CUDA/Half.h b/Eigen/src/Core/arch/CUDA/Half.h
deleted file mode 100644
index 88dd385a5..000000000
--- a/Eigen/src/Core/arch/CUDA/Half.h
+++ /dev/null
@@ -1,636 +0,0 @@
-// This file is part of Eigen, a lightweight C++ template library
-// for linear algebra.
-//
-// This Source Code Form is subject to the terms of the Mozilla
-// Public License v. 2.0. If a copy of the MPL was not distributed
-// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
-//
-// The conversion routines are Copyright (c) Fabian Giesen, 2016.
-// The original license follows:
-//
-// Copyright (c) Fabian Giesen, 2016
-// All rights reserved.
-// Redistribution and use in source and binary forms, with or without
-// modification, are permitted.
-// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
-// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
-// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
-// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
-// HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
-// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
-// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
-// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
-// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-
-// Standard 16-bit float type, mostly useful for GPUs. Defines a new
-// type Eigen::half (inheriting from CUDA's __half struct) with
-// operator overloads such that it behaves basically as an arithmetic
-// type. It will be quite slow on CPUs (so it is recommended to stay
-// in fp32 for CPUs, except for simple parameter conversions, I/O
-// to disk and the likes), but fast on GPUs.
-
-
-#ifndef EIGEN_HALF_CUDA_H
-#define EIGEN_HALF_CUDA_H
-
-#if __cplusplus > 199711L
-#define EIGEN_EXPLICIT_CAST(tgt_type) explicit operator tgt_type()
-#else
-#define EIGEN_EXPLICIT_CAST(tgt_type) operator tgt_type()
-#endif
-
-#include <sstream>
-
-namespace Eigen {
-
-struct half;
-
-namespace half_impl {
-
-#if !defined(EIGEN_HAS_CUDA_FP16)
-
-// Make our own __half definition that is similar to CUDA's.
-struct __half {
- EIGEN_DEVICE_FUNC __half() {}
- explicit EIGEN_DEVICE_FUNC __half(unsigned short raw) : x(raw) {}
- unsigned short x;
-};
-
-#endif
-
-EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __half raw_uint16_to_half(unsigned short x);
-EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __half float_to_half_rtne(float ff);
-EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC float half_to_float(__half h);
-
-struct half_base : public __half {
- EIGEN_DEVICE_FUNC half_base() {}
- EIGEN_DEVICE_FUNC half_base(const half_base& h) : __half(h) {}
- EIGEN_DEVICE_FUNC half_base(const __half& h) : __half(h) {}
-};
-
-} // namespace half_impl
-
-// Class definition.
-struct half : public half_impl::half_base {
- #if !defined(EIGEN_HAS_CUDA_FP16)
- typedef half_impl::__half __half;
- #endif
-
- EIGEN_DEVICE_FUNC half() {}
-
- EIGEN_DEVICE_FUNC half(const __half& h) : half_impl::half_base(h) {}
- EIGEN_DEVICE_FUNC half(const half& h) : half_impl::half_base(h) {}
-
- explicit EIGEN_DEVICE_FUNC half(bool b)
- : half_impl::half_base(half_impl::raw_uint16_to_half(b ? 0x3c00 : 0)) {}
- template<class T>
- explicit EIGEN_DEVICE_FUNC half(const T& val)
- : half_impl::half_base(half_impl::float_to_half_rtne(static_cast<float>(val))) {}
- explicit EIGEN_DEVICE_FUNC half(float f)
- : half_impl::half_base(half_impl::float_to_half_rtne(f)) {}
-
- EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(bool) const {
- // +0.0 and -0.0 become false, everything else becomes true.
- return (x & 0x7fff) != 0;
- }
- EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(signed char) const {
- return static_cast<signed char>(half_impl::half_to_float(*this));
- }
- EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(unsigned char) const {
- return static_cast<unsigned char>(half_impl::half_to_float(*this));
- }
- EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(short) const {
- return static_cast<short>(half_impl::half_to_float(*this));
- }
- EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(unsigned short) const {
- return static_cast<unsigned short>(half_impl::half_to_float(*this));
- }
- EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(int) const {
- return static_cast<int>(half_impl::half_to_float(*this));
- }
- EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(unsigned int) const {
- return static_cast<unsigned int>(half_impl::half_to_float(*this));
- }
- EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(long) const {
- return static_cast<long>(half_impl::half_to_float(*this));
- }
- EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(unsigned long) const {
- return static_cast<unsigned long>(half_impl::half_to_float(*this));
- }
- EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(long long) const {
- return static_cast<long long>(half_impl::half_to_float(*this));
- }
- EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(unsigned long long) const {
- return static_cast<unsigned long long>(half_to_float(*this));
- }
- EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(float) const {
- return half_impl::half_to_float(*this);
- }
- EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(double) const {
- return static_cast<double>(half_impl::half_to_float(*this));
- }
-
- EIGEN_DEVICE_FUNC half& operator=(const half& other) {
- x = other.x;
- return *this;
- }
-};
-
-namespace half_impl {
-
-#if defined(EIGEN_HAS_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
-
-// Intrinsics for native fp16 support. Note that on current hardware,
-// these are no faster than fp32 arithmetic (you need to use the half2
-// versions to get the ALU speed increased), but you do save the
-// conversion steps back and forth.
-
-__device__ half operator + (const half& a, const half& b) {
- return __hadd(a, b);
-}
-__device__ half operator * (const half& a, const half& b) {
- return __hmul(a, b);
-}
-__device__ half operator - (const half& a, const half& b) {
- return __hsub(a, b);
-}
-__device__ half operator / (const half& a, const half& b) {
- float num = __half2float(a);
- float denom = __half2float(b);
- return __float2half(num / denom);
-}
-__device__ half operator - (const half& a) {
- return __hneg(a);
-}
-__device__ half& operator += (half& a, const half& b) {
- a = a + b;
- return a;
-}
-__device__ half& operator *= (half& a, const half& b) {
- a = a * b;
- return a;
-}
-__device__ half& operator -= (half& a, const half& b) {
- a = a - b;
- return a;
-}
-__device__ half& operator /= (half& a, const half& b) {
- a = a / b;
- return a;
-}
-__device__ bool operator == (const half& a, const half& b) {
- return __heq(a, b);
-}
-__device__ bool operator != (const half& a, const half& b) {
- return __hne(a, b);
-}
-__device__ bool operator < (const half& a, const half& b) {
- return __hlt(a, b);
-}
-__device__ bool operator <= (const half& a, const half& b) {
- return __hle(a, b);
-}
-__device__ bool operator > (const half& a, const half& b) {
- return __hgt(a, b);
-}
-__device__ bool operator >= (const half& a, const half& b) {
- return __hge(a, b);
-}
-
-#else // Emulate support for half floats
-
-// Definitions for CPUs and older CUDA, mostly working through conversion
-// to/from fp32.
-
-EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator + (const half& a, const half& b) {
- return half(float(a) + float(b));
-}
-EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator * (const half& a, const half& b) {
- return half(float(a) * float(b));
-}
-EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator - (const half& a, const half& b) {
- return half(float(a) - float(b));
-}
-EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator / (const half& a, const half& b) {
- return half(float(a) / float(b));
-}
-EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator - (const half& a) {
- half result;
- result.x = a.x ^ 0x8000;
- return result;
-}
-EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half& operator += (half& a, const half& b) {
- a = half(float(a) + float(b));
- return a;
-}
-EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half& operator *= (half& a, const half& b) {
- a = half(float(a) * float(b));
- return a;
-}
-EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half& operator -= (half& a, const half& b) {
- a = half(float(a) - float(b));
- return a;
-}
-EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half& operator /= (half& a, const half& b) {
- a = half(float(a) / float(b));
- return a;
-}
-EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator == (const half& a, const half& b) {
- return float(a) == float(b);
-}
-EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator != (const half& a, const half& b) {
- return float(a) != float(b);
-}
-EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator < (const half& a, const half& b) {
- return float(a) < float(b);
-}
-EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator <= (const half& a, const half& b) {
- return float(a) <= float(b);
-}
-EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator > (const half& a, const half& b) {
- return float(a) > float(b);
-}
-EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator >= (const half& a, const half& b) {
- return float(a) >= float(b);
-}
-
-#endif // Emulate support for half floats
-
-// Division by an index. Do it in full float precision to avoid accuracy
-// issues in converting the denominator to half.
-EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator / (const half& a, Index b) {
- return half(static_cast<float>(a) / static_cast<float>(b));
-}
-
-// Conversion routines, including fallbacks for the host or older CUDA.
-// Note that newer Intel CPUs (Haswell or newer) have vectorized versions of
-// these in hardware. If we need more performance on older/other CPUs, they are
-// also possible to vectorize directly.
-
-EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __half raw_uint16_to_half(unsigned short x) {
- __half h;
- h.x = x;
- return h;
-}
-
-union FP32 {
- unsigned int u;
- float f;
-};
-
-EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __half float_to_half_rtne(float ff) {
-#if defined(EIGEN_HAS_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300
- return __float2half(ff);
-
-#elif defined(EIGEN_HAS_FP16_C)
- __half h;
- h.x = _cvtss_sh(ff, 0);
- return h;
-
-#else
- FP32 f; f.f = ff;
-
- const FP32 f32infty = { 255 << 23 };
- const FP32 f16max = { (127 + 16) << 23 };
- const FP32 denorm_magic = { ((127 - 15) + (23 - 10) + 1) << 23 };
- unsigned int sign_mask = 0x80000000u;
- __half o;
- o.x = static_cast<unsigned short>(0x0u);
-
- unsigned int sign = f.u & sign_mask;
- f.u ^= sign;
-
- // NOTE all the integer compares in this function can be safely
- // compiled into signed compares since all operands are below
- // 0x80000000. Important if you want fast straight SSE2 code
- // (since there's no unsigned PCMPGTD).
-
- if (f.u >= f16max.u) { // result is Inf or NaN (all exponent bits set)
- o.x = (f.u > f32infty.u) ? 0x7e00 : 0x7c00; // NaN->qNaN and Inf->Inf
- } else { // (De)normalized number or zero
- if (f.u < (113 << 23)) { // resulting FP16 is subnormal or zero
- // use a magic value to align our 10 mantissa bits at the bottom of
- // the float. as long as FP addition is round-to-nearest-even this
- // just works.
- f.f += denorm_magic.f;
-
- // and one integer subtract of the bias later, we have our final float!
- o.x = static_cast<unsigned short>(f.u - denorm_magic.u);
- } else {
- unsigned int mant_odd = (f.u >> 13) & 1; // resulting mantissa is odd
-
- // update exponent, rounding bias part 1
- f.u += ((unsigned int)(15 - 127) << 23) + 0xfff;
- // rounding bias part 2
- f.u += mant_odd;
- // take the bits!
- o.x = static_cast<unsigned short>(f.u >> 13);
- }
- }
-
- o.x |= static_cast<unsigned short>(sign >> 16);
- return o;
-#endif
-}
-
-EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC float half_to_float(__half h) {
-#if defined(EIGEN_HAS_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300
- return __half2float(h);
-
-#elif defined(EIGEN_HAS_FP16_C)
- return _cvtsh_ss(h.x);
-
-#else
- const FP32 magic = { 113 << 23 };
- const unsigned int shifted_exp = 0x7c00 << 13; // exponent mask after shift
- FP32 o;
-
- o.u = (h.x & 0x7fff) << 13; // exponent/mantissa bits
- unsigned int exp = shifted_exp & o.u; // just the exponent
- o.u += (127 - 15) << 23; // exponent adjust
-
- // handle exponent special cases
- if (exp == shifted_exp) { // Inf/NaN?
- o.u += (128 - 16) << 23; // extra exp adjust
- } else if (exp == 0) { // Zero/Denormal?
- o.u += 1 << 23; // extra exp adjust
- o.f -= magic.f; // renormalize
- }
-
- o.u |= (h.x & 0x8000) << 16; // sign bit
- return o.f;
-#endif
-}
-
-// --- standard functions ---
-
-EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool (isinf)(const half& a) {
- return (a.x & 0x7fff) == 0x7c00;
-}
-EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool (isnan)(const half& a) {
-#if defined(EIGEN_HAS_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
- return __hisnan(a);
-#else
- return (a.x & 0x7fff) > 0x7c00;
-#endif
-}
-EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool (isfinite)(const half& a) {
- return !(isinf EIGEN_NOT_A_MACRO (a)) && !(isnan EIGEN_NOT_A_MACRO (a));
-}
-
-EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half abs(const half& a) {
- half result;
- result.x = a.x & 0x7FFF;
- return result;
-}
-EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half exp(const half& a) {
- return half(::expf(float(a)));
-}
-EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half log(const half& a) {
-#if defined(EIGEN_HAS_CUDA_FP16) && defined __CUDACC_VER__ && __CUDACC_VER__ >= 80000 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
- return Eigen::half(::hlog(a));
-#else
- return half(::logf(float(a)));
-#endif
-}
-EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half log1p(const half& a) {
- return half(numext::log1p(float(a)));
-}
-EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half log10(const half& a) {
- return half(::log10f(float(a)));
-}
-EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half sqrt(const half& a) {
- return half(::sqrtf(float(a)));
-}
-EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half pow(const half& a, const half& b) {
- return half(::powf(float(a), float(b)));
-}
-EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half sin(const half& a) {
- return half(::sinf(float(a)));
-}
-EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half cos(const half& a) {
- return half(::cosf(float(a)));
-}
-EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half tan(const half& a) {
- return half(::tanf(float(a)));
-}
-EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half tanh(const half& a) {
- return half(::tanhf(float(a)));
-}
-EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half floor(const half& a) {
- return half(::floorf(float(a)));
-}
-EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half ceil(const half& a) {
- return half(::ceilf(float(a)));
-}
-
-EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half (min)(const half& a, const half& b) {
-#if defined(EIGEN_HAS_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
- return __hlt(b, a) ? b : a;
-#else
- const float f1 = static_cast<float>(a);
- const float f2 = static_cast<float>(b);
- return f2 < f1 ? b : a;
-#endif
-}
-EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half (max)(const half& a, const half& b) {
-#if defined(EIGEN_HAS_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
- return __hlt(a, b) ? b : a;
-#else
- const float f1 = static_cast<float>(a);
- const float f2 = static_cast<float>(b);
- return f1 < f2 ? b : a;
-#endif
-}
-
-EIGEN_ALWAYS_INLINE std::ostream& operator << (std::ostream& os, const half& v) {
- os << static_cast<float>(v);
- return os;
-}
-
-} // end namespace half_impl
-
-// import Eigen::half_impl::half into Eigen namespace
-// using half_impl::half;
-
-namespace internal {
-
-template<>
-struct random_default_impl<half, false, false>
-{
- static inline half run(const half& x, const half& y)
- {
- return x + (y-x) * half(float(std::rand()) / float(RAND_MAX));
- }
- static inline half run()
- {
- return run(half(-1.f), half(1.f));
- }
-};
-
-template<> struct is_arithmetic<half> { enum { value = true }; };
-
-} // end namespace internal
-
-} // end namespace Eigen
-
-namespace std {
-template<>
-struct numeric_limits<Eigen::half> {
- static const bool is_specialized = true;
- static const bool is_signed = true;
- static const bool is_integer = false;
- static const bool is_exact = false;
- static const bool has_infinity = true;
- static const bool has_quiet_NaN = true;
- static const bool has_signaling_NaN = true;
- static const float_denorm_style has_denorm = denorm_present;
- static const bool has_denorm_loss = false;
- static const std::float_round_style round_style = std::round_to_nearest;
- static const bool is_iec559 = false;
- static const bool is_bounded = false;
- static const bool is_modulo = false;
- static const int digits = 11;
- static const int digits10 = 2;
- //static const int max_digits10 = ;
- static const int radix = 2;
- static const int min_exponent = -13;
- static const int min_exponent10 = -4;
- static const int max_exponent = 16;
- static const int max_exponent10 = 4;
- static const bool traps = true;
- static const bool tinyness_before = false;
-
- static Eigen::half (min)() { return Eigen::half_impl::raw_uint16_to_half(0x400); }
- static Eigen::half lowest() { return Eigen::half_impl::raw_uint16_to_half(0xfbff); }
- static Eigen::half (max)() { return Eigen::half_impl::raw_uint16_to_half(0x7bff); }
- static Eigen::half epsilon() { return Eigen::half_impl::raw_uint16_to_half(0x0800); }
- static Eigen::half round_error() { return Eigen::half(0.5); }
- static Eigen::half infinity() { return Eigen::half_impl::raw_uint16_to_half(0x7c00); }
- static Eigen::half quiet_NaN() { return Eigen::half_impl::raw_uint16_to_half(0x7e00); }
- static Eigen::half signaling_NaN() { return Eigen::half_impl::raw_uint16_to_half(0x7e00); }
- static Eigen::half denorm_min() { return Eigen::half_impl::raw_uint16_to_half(0x1); }
-};
-}
-
-namespace Eigen {
-
-template<> struct NumTraits<Eigen::half>
- : GenericNumTraits<Eigen::half>
-{
- enum {
- IsSigned = true,
- IsInteger = false,
- IsComplex = false,
- RequireInitialization = false
- };
-
- EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE Eigen::half epsilon() {
- return half_impl::raw_uint16_to_half(0x0800);
- }
- EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE Eigen::half dummy_precision() { return Eigen::half(1e-2f); }
- EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE Eigen::half highest() {
- return half_impl::raw_uint16_to_half(0x7bff);
- }
- EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE Eigen::half lowest() {
- return half_impl::raw_uint16_to_half(0xfbff);
- }
- EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE Eigen::half infinity() {
- return half_impl::raw_uint16_to_half(0x7c00);
- }
- EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE Eigen::half quiet_NaN() {
- return half_impl::raw_uint16_to_half(0x7c01);
- }
-};
-
-} // end namespace Eigen
-
-// C-like standard mathematical functions and trancendentals.
-EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half fabsh(const Eigen::half& a) {
- Eigen::half result;
- result.x = a.x & 0x7FFF;
- return result;
-}
-EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half exph(const Eigen::half& a) {
- return Eigen::half(::expf(float(a)));
-}
-EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half logh(const Eigen::half& a) {
-#if defined __CUDACC_VER__ && __CUDACC_VER__ >= 80000 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
- return Eigen::half(::hlog(a));
-#else
- return Eigen::half(::logf(float(a)));
-#endif
-}
-EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half sqrth(const Eigen::half& a) {
- return Eigen::half(::sqrtf(float(a)));
-}
-EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half powh(const Eigen::half& a, const Eigen::half& b) {
- return Eigen::half(::powf(float(a), float(b)));
-}
-EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half floorh(const Eigen::half& a) {
- return Eigen::half(::floorf(float(a)));
-}
-EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half ceilh(const Eigen::half& a) {
- return Eigen::half(::ceilf(float(a)));
-}
-
-namespace std {
-
-#if __cplusplus > 199711L
-template <>
-struct hash<Eigen::half> {
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::size_t operator()(const Eigen::half& a) const {
- return static_cast<std::size_t>(a.x);
- }
-};
-#endif
-
-} // end namespace std
-
-
-// Add the missing shfl_xor intrinsic
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300
-__device__ EIGEN_STRONG_INLINE Eigen::half __shfl_xor(Eigen::half var, int laneMask, int width=warpSize) {
- return static_cast<Eigen::half>(__shfl_xor(static_cast<float>(var), laneMask, width));
-}
-#endif
-
-// ldg() has an overload for __half, but we also need one for Eigen::half.
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 350
-EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half __ldg(const Eigen::half* ptr) {
- return Eigen::half_impl::raw_uint16_to_half(
- __ldg(reinterpret_cast<const unsigned short*>(ptr)));
-}
-#endif
-
-
-#if defined(__CUDA_ARCH__)
-namespace Eigen {
-namespace numext {
-
-template<>
-EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
-bool (isnan)(const Eigen::half& h) {
- return (half_impl::isnan)(h);
-}
-
-template<>
-EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
-bool (isinf)(const Eigen::half& h) {
- return (half_impl::isinf)(h);
-}
-
-template<>
-EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
-bool (isfinite)(const Eigen::half& h) {
- return (half_impl::isfinite)(h);
-}
-
-} // namespace Eigen
-} // namespace numext
-#endif
-
-#endif // EIGEN_HALF_CUDA_H
diff --git a/Eigen/src/Core/arch/CUDA/PacketMath.h b/Eigen/src/Core/arch/CUDA/PacketMath.h
deleted file mode 100644
index 4dda63188..000000000
--- a/Eigen/src/Core/arch/CUDA/PacketMath.h
+++ /dev/null
@@ -1,333 +0,0 @@
-// This file is part of Eigen, a lightweight C++ template library
-// for linear algebra.
-//
-// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
-//
-// This Source Code Form is subject to the terms of the Mozilla
-// Public License v. 2.0. If a copy of the MPL was not distributed
-// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
-
-#ifndef EIGEN_PACKET_MATH_CUDA_H
-#define EIGEN_PACKET_MATH_CUDA_H
-
-namespace Eigen {
-
-namespace internal {
-
-// Make sure this is only available when targeting a GPU: we don't want to
-// introduce conflicts between these packet_traits definitions and the ones
-// we'll use on the host side (SSE, AVX, ...)
-#if defined(__CUDACC__) && defined(EIGEN_USE_GPU)
-template<> struct is_arithmetic<float4> { enum { value = true }; };
-template<> struct is_arithmetic<double2> { enum { value = true }; };
-
-template<> struct packet_traits<float> : default_packet_traits
-{
- typedef float4 type;
- typedef float4 half;
- enum {
- Vectorizable = 1,
- AlignedOnScalar = 1,
- size=4,
- HasHalfPacket = 0,
-
- HasDiv = 1,
- HasSin = 0,
- HasCos = 0,
- HasLog = 1,
- HasExp = 1,
- HasSqrt = 1,
- HasRsqrt = 1,
- HasLGamma = 1,
- HasDiGamma = 1,
- HasZeta = 1,
- HasPolygamma = 1,
- HasErf = 1,
- HasErfc = 1,
- HasIGamma = 1,
- HasIGammac = 1,
- HasBetaInc = 1,
-
- HasBlend = 0,
- };
-};
-
-template<> struct packet_traits<double> : default_packet_traits
-{
- typedef double2 type;
- typedef double2 half;
- enum {
- Vectorizable = 1,
- AlignedOnScalar = 1,
- size=2,
- HasHalfPacket = 0,
-
- HasDiv = 1,
- HasLog = 1,
- HasExp = 1,
- HasSqrt = 1,
- HasRsqrt = 1,
- HasLGamma = 1,
- HasDiGamma = 1,
- HasZeta = 1,
- HasPolygamma = 1,
- HasErf = 1,
- HasErfc = 1,
- HasIGamma = 1,
- HasIGammac = 1,
- HasBetaInc = 1,
-
- HasBlend = 0,
- };
-};
-
-
-template<> struct unpacket_traits<float4> { typedef float type; enum {size=4, alignment=Aligned16}; typedef float4 half; };
-template<> struct unpacket_traits<double2> { typedef double type; enum {size=2, alignment=Aligned16}; typedef double2 half; };
-
-template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 pset1<float4>(const float& from) {
- return make_float4(from, from, from, from);
-}
-template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double2 pset1<double2>(const double& from) {
- return make_double2(from, from);
-}
-
-
-template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 plset<float4>(const float& a) {
- return make_float4(a, a+1, a+2, a+3);
-}
-template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double2 plset<double2>(const double& a) {
- return make_double2(a, a+1);
-}
-
-template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 padd<float4>(const float4& a, const float4& b) {
- return make_float4(a.x+b.x, a.y+b.y, a.z+b.z, a.w+b.w);
-}
-template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double2 padd<double2>(const double2& a, const double2& b) {
- return make_double2(a.x+b.x, a.y+b.y);
-}
-
-template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 psub<float4>(const float4& a, const float4& b) {
- return make_float4(a.x-b.x, a.y-b.y, a.z-b.z, a.w-b.w);
-}
-template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double2 psub<double2>(const double2& a, const double2& b) {
- return make_double2(a.x-b.x, a.y-b.y);
-}
-
-template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 pnegate(const float4& a) {
- return make_float4(-a.x, -a.y, -a.z, -a.w);
-}
-template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double2 pnegate(const double2& a) {
- return make_double2(-a.x, -a.y);
-}
-
-template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 pconj(const float4& a) { return a; }
-template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double2 pconj(const double2& a) { return a; }
-
-template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 pmul<float4>(const float4& a, const float4& b) {
- return make_float4(a.x*b.x, a.y*b.y, a.z*b.z, a.w*b.w);
-}
-template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double2 pmul<double2>(const double2& a, const double2& b) {
- return make_double2(a.x*b.x, a.y*b.y);
-}
-
-template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 pdiv<float4>(const float4& a, const float4& b) {
- return make_float4(a.x/b.x, a.y/b.y, a.z/b.z, a.w/b.w);
-}
-template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double2 pdiv<double2>(const double2& a, const double2& b) {
- return make_double2(a.x/b.x, a.y/b.y);
-}
-
-template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 pmin<float4>(const float4& a, const float4& b) {
- return make_float4(fminf(a.x, b.x), fminf(a.y, b.y), fminf(a.z, b.z), fminf(a.w, b.w));
-}
-template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double2 pmin<double2>(const double2& a, const double2& b) {
- return make_double2(fmin(a.x, b.x), fmin(a.y, b.y));
-}
-
-template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 pmax<float4>(const float4& a, const float4& b) {
- return make_float4(fmaxf(a.x, b.x), fmaxf(a.y, b.y), fmaxf(a.z, b.z), fmaxf(a.w, b.w));
-}
-template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double2 pmax<double2>(const double2& a, const double2& b) {
- return make_double2(fmax(a.x, b.x), fmax(a.y, b.y));
-}
-
-template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 pload<float4>(const float* from) {
- return *reinterpret_cast<const float4*>(from);
-}
-
-template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double2 pload<double2>(const double* from) {
- return *reinterpret_cast<const double2*>(from);
-}
-
-template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 ploadu<float4>(const float* from) {
- return make_float4(from[0], from[1], from[2], from[3]);
-}
-template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double2 ploadu<double2>(const double* from) {
- return make_double2(from[0], from[1]);
-}
-
-template<> EIGEN_STRONG_INLINE float4 ploaddup<float4>(const float* from) {
- return make_float4(from[0], from[0], from[1], from[1]);
-}
-template<> EIGEN_STRONG_INLINE double2 ploaddup<double2>(const double* from) {
- return make_double2(from[0], from[0]);
-}
-
-template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pstore<float>(float* to, const float4& from) {
- *reinterpret_cast<float4*>(to) = from;
-}
-
-template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pstore<double>(double* to, const double2& from) {
- *reinterpret_cast<double2*>(to) = from;
-}
-
-template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pstoreu<float>(float* to, const float4& from) {
- to[0] = from.x;
- to[1] = from.y;
- to[2] = from.z;
- to[3] = from.w;
-}
-
-template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pstoreu<double>(double* to, const double2& from) {
- to[0] = from.x;
- to[1] = from.y;
-}
-
-template<>
-EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE float4 ploadt_ro<float4, Aligned>(const float* from) {
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 350
- return __ldg((const float4*)from);
-#else
- return make_float4(from[0], from[1], from[2], from[3]);
-#endif
-}
-template<>
-EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE double2 ploadt_ro<double2, Aligned>(const double* from) {
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 350
- return __ldg((const double2*)from);
-#else
- return make_double2(from[0], from[1]);
-#endif
-}
-
-template<>
-EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE float4 ploadt_ro<float4, Unaligned>(const float* from) {
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 350
- return make_float4(__ldg(from+0), __ldg(from+1), __ldg(from+2), __ldg(from+3));
-#else
- return make_float4(from[0], from[1], from[2], from[3]);
-#endif
-}
-template<>
-EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE double2 ploadt_ro<double2, Unaligned>(const double* from) {
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 350
- return make_double2(__ldg(from+0), __ldg(from+1));
-#else
- return make_double2(from[0], from[1]);
-#endif
-}
-
-template<> EIGEN_DEVICE_FUNC inline float4 pgather<float, float4>(const float* from, Index stride) {
- return make_float4(from[0*stride], from[1*stride], from[2*stride], from[3*stride]);
-}
-
-template<> EIGEN_DEVICE_FUNC inline double2 pgather<double, double2>(const double* from, Index stride) {
- return make_double2(from[0*stride], from[1*stride]);
-}
-
-template<> EIGEN_DEVICE_FUNC inline void pscatter<float, float4>(float* to, const float4& from, Index stride) {
- to[stride*0] = from.x;
- to[stride*1] = from.y;
- to[stride*2] = from.z;
- to[stride*3] = from.w;
-}
-template<> EIGEN_DEVICE_FUNC inline void pscatter<double, double2>(double* to, const double2& from, Index stride) {
- to[stride*0] = from.x;
- to[stride*1] = from.y;
-}
-
-template<> EIGEN_DEVICE_FUNC inline float pfirst<float4>(const float4& a) {
- return a.x;
-}
-template<> EIGEN_DEVICE_FUNC inline double pfirst<double2>(const double2& a) {
- return a.x;
-}
-
-template<> EIGEN_DEVICE_FUNC inline float predux<float4>(const float4& a) {
- return a.x + a.y + a.z + a.w;
-}
-template<> EIGEN_DEVICE_FUNC inline double predux<double2>(const double2& a) {
- return a.x + a.y;
-}
-
-template<> EIGEN_DEVICE_FUNC inline float predux_max<float4>(const float4& a) {
- return fmaxf(fmaxf(a.x, a.y), fmaxf(a.z, a.w));
-}
-template<> EIGEN_DEVICE_FUNC inline double predux_max<double2>(const double2& a) {
- return fmax(a.x, a.y);
-}
-
-template<> EIGEN_DEVICE_FUNC inline float predux_min<float4>(const float4& a) {
- return fminf(fminf(a.x, a.y), fminf(a.z, a.w));
-}
-template<> EIGEN_DEVICE_FUNC inline double predux_min<double2>(const double2& a) {
- return fmin(a.x, a.y);
-}
-
-template<> EIGEN_DEVICE_FUNC inline float predux_mul<float4>(const float4& a) {
- return a.x * a.y * a.z * a.w;
-}
-template<> EIGEN_DEVICE_FUNC inline double predux_mul<double2>(const double2& a) {
- return a.x * a.y;
-}
-
-template<> EIGEN_DEVICE_FUNC inline float4 pabs<float4>(const float4& a) {
- return make_float4(fabsf(a.x), fabsf(a.y), fabsf(a.z), fabsf(a.w));
-}
-template<> EIGEN_DEVICE_FUNC inline double2 pabs<double2>(const double2& a) {
- return make_double2(fabs(a.x), fabs(a.y));
-}
-
-EIGEN_DEVICE_FUNC inline void
-ptranspose(PacketBlock<float4,4>& kernel) {
- float tmp = kernel.packet[0].y;
- kernel.packet[0].y = kernel.packet[1].x;
- kernel.packet[1].x = tmp;
-
- tmp = kernel.packet[0].z;
- kernel.packet[0].z = kernel.packet[2].x;
- kernel.packet[2].x = tmp;
-
- tmp = kernel.packet[0].w;
- kernel.packet[0].w = kernel.packet[3].x;
- kernel.packet[3].x = tmp;
-
- tmp = kernel.packet[1].z;
- kernel.packet[1].z = kernel.packet[2].y;
- kernel.packet[2].y = tmp;
-
- tmp = kernel.packet[1].w;
- kernel.packet[1].w = kernel.packet[3].y;
- kernel.packet[3].y = tmp;
-
- tmp = kernel.packet[2].w;
- kernel.packet[2].w = kernel.packet[3].z;
- kernel.packet[3].z = tmp;
-}
-
-EIGEN_DEVICE_FUNC inline void
-ptranspose(PacketBlock<double2,2>& kernel) {
- double tmp = kernel.packet[0].y;
- kernel.packet[0].y = kernel.packet[1].x;
- kernel.packet[1].x = tmp;
-}
-
-#endif
-
-} // end namespace internal
-
-} // end namespace Eigen
-
-
-#endif // EIGEN_PACKET_MATH_CUDA_H
diff --git a/Eigen/src/Core/arch/CUDA/PacketMathHalf.h b/Eigen/src/Core/arch/CUDA/PacketMathHalf.h
deleted file mode 100644
index ae54225f8..000000000
--- a/Eigen/src/Core/arch/CUDA/PacketMathHalf.h
+++ /dev/null
@@ -1,1123 +0,0 @@
-// This file is part of Eigen, a lightweight C++ template library
-// for linear algebra.
-//
-// Copyright (C) 2016 Benoit Steiner <benoit.steiner.goog@gmail.com>
-//
-// This Source Code Form is subject to the terms of the Mozilla
-// Public License v. 2.0. If a copy of the MPL was not distributed
-// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
-
-#ifndef EIGEN_PACKET_MATH_HALF_CUDA_H
-#define EIGEN_PACKET_MATH_HALF_CUDA_H
-
-
-namespace Eigen {
-namespace internal {
-
-// Most of the following operations require arch >= 3.0
-#if defined(EIGEN_HAS_CUDA_FP16) && defined(__CUDACC__) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300
-
-template<> struct is_arithmetic<half2> { enum { value = true }; };
-
-template<> struct packet_traits<Eigen::half> : default_packet_traits
-{
- typedef half2 type;
- typedef half2 half;
- enum {
- Vectorizable = 1,
- AlignedOnScalar = 1,
- size=2,
- HasHalfPacket = 0,
- HasAdd = 1,
- HasMul = 1,
- HasDiv = 1,
- HasSqrt = 1,
- HasRsqrt = 1,
- HasExp = 1,
- HasLog = 1,
- HasLog1p = 1
- };
-};
-
-template<> struct unpacket_traits<half2> { typedef Eigen::half type; enum {size=2, alignment=Aligned16}; typedef half2 half; };
-
-template<> __device__ EIGEN_STRONG_INLINE half2 pset1<half2>(const Eigen::half& from) {
- return __half2half2(from);
-}
-
-template<> __device__ EIGEN_STRONG_INLINE half2 pload<half2>(const Eigen::half* from) {
- return *reinterpret_cast<const half2*>(from);
-}
-
-template<> __device__ EIGEN_STRONG_INLINE half2 ploadu<half2>(const Eigen::half* from) {
- return __halves2half2(from[0], from[1]);
-}
-
-template<> EIGEN_STRONG_INLINE half2 ploaddup<half2>(const Eigen::half* from) {
- return __halves2half2(from[0], from[0]);
-}
-
-template<> __device__ EIGEN_STRONG_INLINE void pstore<Eigen::half>(Eigen::half* to, const half2& from) {
- *reinterpret_cast<half2*>(to) = from;
-}
-
-template<> __device__ EIGEN_STRONG_INLINE void pstoreu<Eigen::half>(Eigen::half* to, const half2& from) {
- to[0] = __low2half(from);
- to[1] = __high2half(from);
-}
-
-template<>
- __device__ EIGEN_ALWAYS_INLINE half2 ploadt_ro<half2, Aligned>(const Eigen::half* from) {
-#if __CUDA_ARCH__ >= 350
- return __ldg((const half2*)from);
-#else
- return __halves2half2(*(from+0), *(from+1));
-#endif
-}
-
-template<>
-__device__ EIGEN_ALWAYS_INLINE half2 ploadt_ro<half2, Unaligned>(const Eigen::half* from) {
-#if __CUDA_ARCH__ >= 350
- return __halves2half2(__ldg(from+0), __ldg(from+1));
-#else
- return __halves2half2(*(from+0), *(from+1));
-#endif
-}
-
-template<> __device__ EIGEN_STRONG_INLINE half2 pgather<Eigen::half, half2>(const Eigen::half* from, Index stride) {
- return __halves2half2(from[0*stride], from[1*stride]);
-}
-
-template<> __device__ EIGEN_STRONG_INLINE void pscatter<Eigen::half, half2>(Eigen::half* to, const half2& from, Index stride) {
- to[stride*0] = __low2half(from);
- to[stride*1] = __high2half(from);
-}
-
-template<> __device__ EIGEN_STRONG_INLINE Eigen::half pfirst<half2>(const half2& a) {
- return __low2half(a);
-}
-
-template<> __device__ EIGEN_STRONG_INLINE half2 pabs<half2>(const half2& a) {
- half2 result;
- result.x = a.x & 0x7FFF7FFF;
- return result;
-}
-
-
-__device__ EIGEN_STRONG_INLINE void
-ptranspose(PacketBlock<half2,2>& kernel) {
- __half a1 = __low2half(kernel.packet[0]);
- __half a2 = __high2half(kernel.packet[0]);
- __half b1 = __low2half(kernel.packet[1]);
- __half b2 = __high2half(kernel.packet[1]);
- kernel.packet[0] = __halves2half2(a1, b1);
- kernel.packet[1] = __halves2half2(a2, b2);
-}
-
-template<> __device__ EIGEN_STRONG_INLINE half2 plset<half2>(const Eigen::half& a) {
-#if __CUDA_ARCH__ >= 530
- return __halves2half2(a, __hadd(a, __float2half(1.0f)));
-#else
- float f = __half2float(a) + 1.0f;
- return __halves2half2(a, __float2half(f));
-#endif
-}
-
-template<> __device__ EIGEN_STRONG_INLINE half2 padd<half2>(const half2& a, const half2& b) {
-#if __CUDA_ARCH__ >= 530
- return __hadd2(a, b);
-#else
- float a1 = __low2float(a);
- float a2 = __high2float(a);
- float b1 = __low2float(b);
- float b2 = __high2float(b);
- float r1 = a1 + b1;
- float r2 = a2 + b2;
- return __floats2half2_rn(r1, r2);
-#endif
-}
-
-template<> __device__ EIGEN_STRONG_INLINE half2 psub<half2>(const half2& a, const half2& b) {
-#if __CUDA_ARCH__ >= 530
- return __hsub2(a, b);
-#else
- float a1 = __low2float(a);
- float a2 = __high2float(a);
- float b1 = __low2float(b);
- float b2 = __high2float(b);
- float r1 = a1 - b1;
- float r2 = a2 - b2;
- return __floats2half2_rn(r1, r2);
-#endif
-}
-
-template<> __device__ EIGEN_STRONG_INLINE half2 pnegate(const half2& a) {
-#if __CUDA_ARCH__ >= 530
- return __hneg2(a);
-#else
- float a1 = __low2float(a);
- float a2 = __high2float(a);
- return __floats2half2_rn(-a1, -a2);
-#endif
-}
-
-template<> __device__ EIGEN_STRONG_INLINE half2 pconj(const half2& a) { return a; }
-
-template<> __device__ EIGEN_STRONG_INLINE half2 pmul<half2>(const half2& a, const half2& b) {
-#if __CUDA_ARCH__ >= 530
- return __hmul2(a, b);
-#else
- float a1 = __low2float(a);
- float a2 = __high2float(a);
- float b1 = __low2float(b);
- float b2 = __high2float(b);
- float r1 = a1 * b1;
- float r2 = a2 * b2;
- return __floats2half2_rn(r1, r2);
-#endif
-}
-
-template<> __device__ EIGEN_STRONG_INLINE half2 pmadd<half2>(const half2& a, const half2& b, const half2& c) {
-#if __CUDA_ARCH__ >= 530
- return __hfma2(a, b, c);
-#else
- float a1 = __low2float(a);
- float a2 = __high2float(a);
- float b1 = __low2float(b);
- float b2 = __high2float(b);
- float c1 = __low2float(c);
- float c2 = __high2float(c);
- float r1 = a1 * b1 + c1;
- float r2 = a2 * b2 + c2;
- return __floats2half2_rn(r1, r2);
-#endif
-}
-
-template<> __device__ EIGEN_STRONG_INLINE half2 pdiv<half2>(const half2& a, const half2& b) {
- float a1 = __low2float(a);
- float a2 = __high2float(a);
- float b1 = __low2float(b);
- float b2 = __high2float(b);
- float r1 = a1 / b1;
- float r2 = a2 / b2;
- return __floats2half2_rn(r1, r2);
-}
-
-template<> __device__ EIGEN_STRONG_INLINE half2 pmin<half2>(const half2& a, const half2& b) {
- float a1 = __low2float(a);
- float a2 = __high2float(a);
- float b1 = __low2float(b);
- float b2 = __high2float(b);
- __half r1 = a1 < b1 ? __low2half(a) : __low2half(b);
- __half r2 = a2 < b2 ? __high2half(a) : __high2half(b);
- return __halves2half2(r1, r2);
-}
-
-template<> __device__ EIGEN_STRONG_INLINE half2 pmax<half2>(const half2& a, const half2& b) {
- float a1 = __low2float(a);
- float a2 = __high2float(a);
- float b1 = __low2float(b);
- float b2 = __high2float(b);
- __half r1 = a1 > b1 ? __low2half(a) : __low2half(b);
- __half r2 = a2 > b2 ? __high2half(a) : __high2half(b);
- return __halves2half2(r1, r2);
-}
-
-template<> __device__ EIGEN_STRONG_INLINE Eigen::half predux<half2>(const half2& a) {
-#if __CUDA_ARCH__ >= 530
- return __hadd(__low2half(a), __high2half(a));
-#else
- float a1 = __low2float(a);
- float a2 = __high2float(a);
- return Eigen::half(half_impl::raw_uint16_to_half(__float2half_rn(a1 + a2)));
-#endif
-}
-
-template<> __device__ EIGEN_STRONG_INLINE Eigen::half predux_max<half2>(const half2& a) {
-#if __CUDA_ARCH__ >= 530
- __half first = __low2half(a);
- __half second = __high2half(a);
- return __hgt(first, second) ? first : second;
-#else
- float a1 = __low2float(a);
- float a2 = __high2float(a);
- return a1 > a2 ? __low2half(a) : __high2half(a);
-#endif
-}
-
-template<> __device__ EIGEN_STRONG_INLINE Eigen::half predux_min<half2>(const half2& a) {
-#if __CUDA_ARCH__ >= 530
- __half first = __low2half(a);
- __half second = __high2half(a);
- return __hlt(first, second) ? first : second;
-#else
- float a1 = __low2float(a);
- float a2 = __high2float(a);
- return a1 < a2 ? __low2half(a) : __high2half(a);
-#endif
-}
-
-template<> __device__ EIGEN_STRONG_INLINE Eigen::half predux_mul<half2>(const half2& a) {
-#if __CUDA_ARCH__ >= 530
- return __hmul(__low2half(a), __high2half(a));
-#else
- float a1 = __low2float(a);
- float a2 = __high2float(a);
- return Eigen::half(half_impl::raw_uint16_to_half(__float2half_rn(a1 * a2)));
-#endif
-}
-
-template<> __device__ EIGEN_STRONG_INLINE half2 plog1p<half2>(const half2& a) {
- float a1 = __low2float(a);
- float a2 = __high2float(a);
- float r1 = log1pf(a1);
- float r2 = log1pf(a2);
- return __floats2half2_rn(r1, r2);
-}
-
-#if defined __CUDACC_VER__ && __CUDACC_VER__ >= 80000 && defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 530
-
-template<> __device__ EIGEN_STRONG_INLINE
-half2 plog<half2>(const half2& a) {
- return h2log(a);
-}
-
-template<> __device__ EIGEN_STRONG_INLINE
-half2 pexp<half2>(const half2& a) {
- return h2exp(a);
-}
-
-template<> __device__ EIGEN_STRONG_INLINE
-half2 psqrt<half2>(const half2& a) {
- return h2sqrt(a);
-}
-
-template<> __device__ EIGEN_STRONG_INLINE
-half2 prsqrt<half2>(const half2& a) {
- return h2rsqrt(a);
-}
-
-#else
-
-template<> __device__ EIGEN_STRONG_INLINE half2 plog<half2>(const half2& a) {
- float a1 = __low2float(a);
- float a2 = __high2float(a);
- float r1 = logf(a1);
- float r2 = logf(a2);
- return __floats2half2_rn(r1, r2);
-}
-
-template<> __device__ EIGEN_STRONG_INLINE half2 pexp<half2>(const half2& a) {
- float a1 = __low2float(a);
- float a2 = __high2float(a);
- float r1 = expf(a1);
- float r2 = expf(a2);
- return __floats2half2_rn(r1, r2);
-}
-
-template<> __device__ EIGEN_STRONG_INLINE half2 psqrt<half2>(const half2& a) {
- float a1 = __low2float(a);
- float a2 = __high2float(a);
- float r1 = sqrtf(a1);
- float r2 = sqrtf(a2);
- return __floats2half2_rn(r1, r2);
-}
-
-template<> __device__ EIGEN_STRONG_INLINE half2 prsqrt<half2>(const half2& a) {
- float a1 = __low2float(a);
- float a2 = __high2float(a);
- float r1 = rsqrtf(a1);
- float r2 = rsqrtf(a2);
- return __floats2half2_rn(r1, r2);
-}
-
-#endif
-
-#elif defined EIGEN_VECTORIZE_AVX512
-
-typedef struct {
- __m256i x;
-} Packet16h;
-
-
-template<> struct is_arithmetic<Packet16h> { enum { value = true }; };
-
-template <>
-struct packet_traits<half> : default_packet_traits {
- typedef Packet16h type;
- // There is no half-size packet for Packet16h.
- typedef Packet16h half;
- enum {
- Vectorizable = 1,
- AlignedOnScalar = 1,
- size = 16,
- HasHalfPacket = 0,
- HasAdd = 0,
- HasSub = 0,
- HasMul = 0,
- HasNegate = 0,
- HasAbs = 0,
- HasAbs2 = 0,
- HasMin = 0,
- HasMax = 0,
- HasConj = 0,
- HasSetLinear = 0,
- HasDiv = 0,
- HasSqrt = 0,
- HasRsqrt = 0,
- HasExp = 0,
- HasLog = 0,
- HasBlend = 0
- };
-};
-
-
-template<> struct unpacket_traits<Packet16h> { typedef Eigen::half type; enum {size=16, alignment=Aligned32}; typedef Packet16h half; };
-
-template<> EIGEN_STRONG_INLINE Packet16h pset1<Packet16h>(const Eigen::half& from) {
- Packet16h result;
- result.x = _mm256_set1_epi16(from.x);
- return result;
-}
-
-template<> EIGEN_STRONG_INLINE Eigen::half pfirst<Packet16h>(const Packet16h& from) {
- return half_impl::raw_uint16_to_half(static_cast<unsigned short>(_mm256_extract_epi16(from.x, 0)));
-}
-
-template<> EIGEN_STRONG_INLINE Packet16h pload<Packet16h>(const Eigen::half* from) {
- Packet16h result;
- result.x = _mm256_load_si256(reinterpret_cast<const __m256i*>(from));
- return result;
-}
-
-template<> EIGEN_STRONG_INLINE Packet16h ploadu<Packet16h>(const Eigen::half* from) {
- Packet16h result;
- result.x = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(from));
- return result;
-}
-
-template<> EIGEN_STRONG_INLINE void pstore<half>(Eigen::half* to, const Packet16h& from) {
- _mm256_store_si256((__m256i*)to, from.x);
-}
-
-template<> EIGEN_STRONG_INLINE void pstoreu<half>(Eigen::half* to, const Packet16h& from) {
- _mm256_storeu_si256((__m256i*)to, from.x);
-}
-
-template<> EIGEN_STRONG_INLINE Packet16h
-ploadquad(const Eigen::half* from) {
- Packet16h result;
- unsigned short a = from[0].x;
- unsigned short b = from[1].x;
- unsigned short c = from[2].x;
- unsigned short d = from[3].x;
- result.x = _mm256_set_epi16(d, d, d, d, c, c, c, c, b, b, b, b, a, a, a, a);
- return result;
-}
-
-EIGEN_STRONG_INLINE Packet16f half2float(const Packet16h& a) {
-#ifdef EIGEN_HAS_FP16_C
- return _mm512_cvtph_ps(a.x);
-#else
- EIGEN_ALIGN64 half aux[16];
- pstore(aux, a);
- float f0(aux[0]);
- float f1(aux[1]);
- float f2(aux[2]);
- float f3(aux[3]);
- float f4(aux[4]);
- float f5(aux[5]);
- float f6(aux[6]);
- float f7(aux[7]);
- float f8(aux[8]);
- float f9(aux[9]);
- float fa(aux[10]);
- float fb(aux[11]);
- float fc(aux[12]);
- float fd(aux[13]);
- float fe(aux[14]);
- float ff(aux[15]);
-
- return _mm512_set_ps(
- ff, fe, fd, fc, fb, fa, f9, f8, f7, f6, f5, f4, f3, f2, f1, f0);
-#endif
-}
-
-EIGEN_STRONG_INLINE Packet16h float2half(const Packet16f& a) {
-#ifdef EIGEN_HAS_FP16_C
- Packet16h result;
- result.x = _mm512_cvtps_ph(a, _MM_FROUND_TO_NEAREST_INT|_MM_FROUND_NO_EXC);
- return result;
-#else
- EIGEN_ALIGN64 float aux[16];
- pstore(aux, a);
- half h0(aux[0]);
- half h1(aux[1]);
- half h2(aux[2]);
- half h3(aux[3]);
- half h4(aux[4]);
- half h5(aux[5]);
- half h6(aux[6]);
- half h7(aux[7]);
- half h8(aux[8]);
- half h9(aux[9]);
- half ha(aux[10]);
- half hb(aux[11]);
- half hc(aux[12]);
- half hd(aux[13]);
- half he(aux[14]);
- half hf(aux[15]);
-
- Packet16h result;
- result.x = _mm256_set_epi16(
- hf.x, he.x, hd.x, hc.x, hb.x, ha.x, h9.x, h8.x,
- h7.x, h6.x, h5.x, h4.x, h3.x, h2.x, h1.x, h0.x);
- return result;
-#endif
-}
-
-template<> EIGEN_STRONG_INLINE Packet16h padd<Packet16h>(const Packet16h& a, const Packet16h& b) {
- Packet16f af = half2float(a);
- Packet16f bf = half2float(b);
- Packet16f rf = padd(af, bf);
- return float2half(rf);
-}
-
-template<> EIGEN_STRONG_INLINE Packet16h pmul<Packet16h>(const Packet16h& a, const Packet16h& b) {
- Packet16f af = half2float(a);
- Packet16f bf = half2float(b);
- Packet16f rf = pmul(af, bf);
- return float2half(rf);
-}
-
-template<> EIGEN_STRONG_INLINE half predux<Packet16h>(const Packet16h& from) {
- Packet16f from_float = half2float(from);
- return half(predux(from_float));
-}
-
-template<> EIGEN_STRONG_INLINE Packet16h pgather<Eigen::half, Packet16h>(const Eigen::half* from, Index stride)
-{
- Packet16h result;
- result.x = _mm256_set_epi16(
- from[15*stride].x, from[14*stride].x, from[13*stride].x, from[12*stride].x,
- from[11*stride].x, from[10*stride].x, from[9*stride].x, from[8*stride].x,
- from[7*stride].x, from[6*stride].x, from[5*stride].x, from[4*stride].x,
- from[3*stride].x, from[2*stride].x, from[1*stride].x, from[0*stride].x);
- return result;
-}
-
-template<> EIGEN_STRONG_INLINE void pscatter<half, Packet16h>(half* to, const Packet16h& from, Index stride)
-{
- EIGEN_ALIGN64 half aux[16];
- pstore(aux, from);
- to[stride*0].x = aux[0].x;
- to[stride*1].x = aux[1].x;
- to[stride*2].x = aux[2].x;
- to[stride*3].x = aux[3].x;
- to[stride*4].x = aux[4].x;
- to[stride*5].x = aux[5].x;
- to[stride*6].x = aux[6].x;
- to[stride*7].x = aux[7].x;
- to[stride*8].x = aux[8].x;
- to[stride*9].x = aux[9].x;
- to[stride*10].x = aux[10].x;
- to[stride*11].x = aux[11].x;
- to[stride*12].x = aux[12].x;
- to[stride*13].x = aux[13].x;
- to[stride*14].x = aux[14].x;
- to[stride*15].x = aux[15].x;
-}
-
-EIGEN_STRONG_INLINE void
-ptranspose(PacketBlock<Packet16h,16>& kernel) {
- __m256i a = kernel.packet[0].x;
- __m256i b = kernel.packet[1].x;
- __m256i c = kernel.packet[2].x;
- __m256i d = kernel.packet[3].x;
- __m256i e = kernel.packet[4].x;
- __m256i f = kernel.packet[5].x;
- __m256i g = kernel.packet[6].x;
- __m256i h = kernel.packet[7].x;
- __m256i i = kernel.packet[8].x;
- __m256i j = kernel.packet[9].x;
- __m256i k = kernel.packet[10].x;
- __m256i l = kernel.packet[11].x;
- __m256i m = kernel.packet[12].x;
- __m256i n = kernel.packet[13].x;
- __m256i o = kernel.packet[14].x;
- __m256i p = kernel.packet[15].x;
-
- __m256i ab_07 = _mm256_unpacklo_epi16(a, b);
- __m256i cd_07 = _mm256_unpacklo_epi16(c, d);
- __m256i ef_07 = _mm256_unpacklo_epi16(e, f);
- __m256i gh_07 = _mm256_unpacklo_epi16(g, h);
- __m256i ij_07 = _mm256_unpacklo_epi16(i, j);
- __m256i kl_07 = _mm256_unpacklo_epi16(k, l);
- __m256i mn_07 = _mm256_unpacklo_epi16(m, n);
- __m256i op_07 = _mm256_unpacklo_epi16(o, p);
-
- __m256i ab_8f = _mm256_unpackhi_epi16(a, b);
- __m256i cd_8f = _mm256_unpackhi_epi16(c, d);
- __m256i ef_8f = _mm256_unpackhi_epi16(e, f);
- __m256i gh_8f = _mm256_unpackhi_epi16(g, h);
- __m256i ij_8f = _mm256_unpackhi_epi16(i, j);
- __m256i kl_8f = _mm256_unpackhi_epi16(k, l);
- __m256i mn_8f = _mm256_unpackhi_epi16(m, n);
- __m256i op_8f = _mm256_unpackhi_epi16(o, p);
-
- __m256i abcd_03 = _mm256_unpacklo_epi32(ab_07, cd_07);
- __m256i abcd_47 = _mm256_unpackhi_epi32(ab_07, cd_07);
- __m256i efgh_03 = _mm256_unpacklo_epi32(ef_07, gh_07);
- __m256i efgh_47 = _mm256_unpackhi_epi32(ef_07, gh_07);
- __m256i ijkl_03 = _mm256_unpacklo_epi32(ij_07, kl_07);
- __m256i ijkl_47 = _mm256_unpackhi_epi32(ij_07, kl_07);
- __m256i mnop_03 = _mm256_unpacklo_epi32(mn_07, op_07);
- __m256i mnop_47 = _mm256_unpackhi_epi32(mn_07, op_07);
-
- __m256i abcd_8b = _mm256_unpacklo_epi32(ab_8f, cd_8f);
- __m256i abcd_cf = _mm256_unpackhi_epi32(ab_8f, cd_8f);
- __m256i efgh_8b = _mm256_unpacklo_epi32(ef_8f, gh_8f);
- __m256i efgh_cf = _mm256_unpackhi_epi32(ef_8f, gh_8f);
- __m256i ijkl_8b = _mm256_unpacklo_epi32(ij_8f, kl_8f);
- __m256i ijkl_cf = _mm256_unpackhi_epi32(ij_8f, kl_8f);
- __m256i mnop_8b = _mm256_unpacklo_epi32(mn_8f, op_8f);
- __m256i mnop_cf = _mm256_unpackhi_epi32(mn_8f, op_8f);
-
- __m256i abcdefgh_01 = _mm256_unpacklo_epi64(abcd_03, efgh_03);
- __m256i abcdefgh_23 = _mm256_unpackhi_epi64(abcd_03, efgh_03);
- __m256i ijklmnop_01 = _mm256_unpacklo_epi64(ijkl_03, mnop_03);
- __m256i ijklmnop_23 = _mm256_unpackhi_epi64(ijkl_03, mnop_03);
- __m256i abcdefgh_45 = _mm256_unpacklo_epi64(abcd_47, efgh_47);
- __m256i abcdefgh_67 = _mm256_unpackhi_epi64(abcd_47, efgh_47);
- __m256i ijklmnop_45 = _mm256_unpacklo_epi64(ijkl_47, mnop_47);
- __m256i ijklmnop_67 = _mm256_unpackhi_epi64(ijkl_47, mnop_47);
- __m256i abcdefgh_89 = _mm256_unpacklo_epi64(abcd_8b, efgh_8b);
- __m256i abcdefgh_ab = _mm256_unpackhi_epi64(abcd_8b, efgh_8b);
- __m256i ijklmnop_89 = _mm256_unpacklo_epi64(ijkl_8b, mnop_8b);
- __m256i ijklmnop_ab = _mm256_unpackhi_epi64(ijkl_8b, mnop_8b);
- __m256i abcdefgh_cd = _mm256_unpacklo_epi64(abcd_cf, efgh_cf);
- __m256i abcdefgh_ef = _mm256_unpackhi_epi64(abcd_cf, efgh_cf);
- __m256i ijklmnop_cd = _mm256_unpacklo_epi64(ijkl_cf, mnop_cf);
- __m256i ijklmnop_ef = _mm256_unpackhi_epi64(ijkl_cf, mnop_cf);
-
- // NOTE: no unpacklo/hi instr in this case, so using permute instr.
- __m256i a_p_0 = _mm256_permute2x128_si256(abcdefgh_01, ijklmnop_01, 0x20);
- __m256i a_p_1 = _mm256_permute2x128_si256(abcdefgh_01, ijklmnop_01, 0x31);
- __m256i a_p_2 = _mm256_permute2x128_si256(abcdefgh_23, ijklmnop_23, 0x20);
- __m256i a_p_3 = _mm256_permute2x128_si256(abcdefgh_23, ijklmnop_23, 0x31);
- __m256i a_p_4 = _mm256_permute2x128_si256(abcdefgh_45, ijklmnop_45, 0x20);
- __m256i a_p_5 = _mm256_permute2x128_si256(abcdefgh_45, ijklmnop_45, 0x31);
- __m256i a_p_6 = _mm256_permute2x128_si256(abcdefgh_67, ijklmnop_67, 0x20);
- __m256i a_p_7 = _mm256_permute2x128_si256(abcdefgh_67, ijklmnop_67, 0x31);
- __m256i a_p_8 = _mm256_permute2x128_si256(abcdefgh_89, ijklmnop_89, 0x20);
- __m256i a_p_9 = _mm256_permute2x128_si256(abcdefgh_89, ijklmnop_89, 0x31);
- __m256i a_p_a = _mm256_permute2x128_si256(abcdefgh_ab, ijklmnop_ab, 0x20);
- __m256i a_p_b = _mm256_permute2x128_si256(abcdefgh_ab, ijklmnop_ab, 0x31);
- __m256i a_p_c = _mm256_permute2x128_si256(abcdefgh_cd, ijklmnop_cd, 0x20);
- __m256i a_p_d = _mm256_permute2x128_si256(abcdefgh_cd, ijklmnop_cd, 0x31);
- __m256i a_p_e = _mm256_permute2x128_si256(abcdefgh_ef, ijklmnop_ef, 0x20);
- __m256i a_p_f = _mm256_permute2x128_si256(abcdefgh_ef, ijklmnop_ef, 0x31);
-
- kernel.packet[0].x = a_p_0;
- kernel.packet[1].x = a_p_1;
- kernel.packet[2].x = a_p_2;
- kernel.packet[3].x = a_p_3;
- kernel.packet[4].x = a_p_4;
- kernel.packet[5].x = a_p_5;
- kernel.packet[6].x = a_p_6;
- kernel.packet[7].x = a_p_7;
- kernel.packet[8].x = a_p_8;
- kernel.packet[9].x = a_p_9;
- kernel.packet[10].x = a_p_a;
- kernel.packet[11].x = a_p_b;
- kernel.packet[12].x = a_p_c;
- kernel.packet[13].x = a_p_d;
- kernel.packet[14].x = a_p_e;
- kernel.packet[15].x = a_p_f;
-}
-
-EIGEN_STRONG_INLINE void
-ptranspose(PacketBlock<Packet16h,8>& kernel) {
- EIGEN_ALIGN64 half in[8][16];
- pstore<half>(in[0], kernel.packet[0]);
- pstore<half>(in[1], kernel.packet[1]);
- pstore<half>(in[2], kernel.packet[2]);
- pstore<half>(in[3], kernel.packet[3]);
- pstore<half>(in[4], kernel.packet[4]);
- pstore<half>(in[5], kernel.packet[5]);
- pstore<half>(in[6], kernel.packet[6]);
- pstore<half>(in[7], kernel.packet[7]);
-
- EIGEN_ALIGN64 half out[8][16];
-
- for (int i = 0; i < 8; ++i) {
- for (int j = 0; j < 8; ++j) {
- out[i][j] = in[j][2*i];
- }
- for (int j = 0; j < 8; ++j) {
- out[i][j+8] = in[j][2*i+1];
- }
- }
-
- kernel.packet[0] = pload<Packet16h>(out[0]);
- kernel.packet[1] = pload<Packet16h>(out[1]);
- kernel.packet[2] = pload<Packet16h>(out[2]);
- kernel.packet[3] = pload<Packet16h>(out[3]);
- kernel.packet[4] = pload<Packet16h>(out[4]);
- kernel.packet[5] = pload<Packet16h>(out[5]);
- kernel.packet[6] = pload<Packet16h>(out[6]);
- kernel.packet[7] = pload<Packet16h>(out[7]);
-}
-
-EIGEN_STRONG_INLINE void
-ptranspose(PacketBlock<Packet16h,4>& kernel) {
- EIGEN_ALIGN64 half in[4][16];
- pstore<half>(in[0], kernel.packet[0]);
- pstore<half>(in[1], kernel.packet[1]);
- pstore<half>(in[2], kernel.packet[2]);
- pstore<half>(in[3], kernel.packet[3]);
-
- EIGEN_ALIGN64 half out[4][16];
-
- for (int i = 0; i < 4; ++i) {
- for (int j = 0; j < 4; ++j) {
- out[i][j] = in[j][4*i];
- }
- for (int j = 0; j < 4; ++j) {
- out[i][j+4] = in[j][4*i+1];
- }
- for (int j = 0; j < 4; ++j) {
- out[i][j+8] = in[j][4*i+2];
- }
- for (int j = 0; j < 4; ++j) {
- out[i][j+12] = in[j][4*i+3];
- }
- }
-
- kernel.packet[0] = pload<Packet16h>(out[0]);
- kernel.packet[1] = pload<Packet16h>(out[1]);
- kernel.packet[2] = pload<Packet16h>(out[2]);
- kernel.packet[3] = pload<Packet16h>(out[3]);
-}
-
-
-#elif defined EIGEN_VECTORIZE_AVX
-
-typedef struct {
- __m128i x;
-} Packet8h;
-
-
-template<> struct is_arithmetic<Packet8h> { enum { value = true }; };
-
-template <>
-struct packet_traits<Eigen::half> : default_packet_traits {
- typedef Packet8h type;
- // There is no half-size packet for Packet8h.
- typedef Packet8h half;
- enum {
- Vectorizable = 1,
- AlignedOnScalar = 1,
- size = 8,
- HasHalfPacket = 0,
- HasAdd = 0,
- HasSub = 0,
- HasMul = 0,
- HasNegate = 0,
- HasAbs = 0,
- HasAbs2 = 0,
- HasMin = 0,
- HasMax = 0,
- HasConj = 0,
- HasSetLinear = 0,
- HasDiv = 0,
- HasSqrt = 0,
- HasRsqrt = 0,
- HasExp = 0,
- HasLog = 0,
- HasBlend = 0
- };
-};
-
-
-template<> struct unpacket_traits<Packet8h> { typedef Eigen::half type; enum {size=8, alignment=Aligned16}; typedef Packet8h half; };
-
-template<> EIGEN_STRONG_INLINE Packet8h pset1<Packet8h>(const Eigen::half& from) {
- Packet8h result;
- result.x = _mm_set1_epi16(from.x);
- return result;
-}
-
-template<> EIGEN_STRONG_INLINE Eigen::half pfirst<Packet8h>(const Packet8h& from) {
- return half_impl::raw_uint16_to_half(static_cast<unsigned short>(_mm_extract_epi16(from.x, 0)));
-}
-
-template<> EIGEN_STRONG_INLINE Packet8h pload<Packet8h>(const Eigen::half* from) {
- Packet8h result;
- result.x = _mm_load_si128(reinterpret_cast<const __m128i*>(from));
- return result;
-}
-
-template<> EIGEN_STRONG_INLINE Packet8h ploadu<Packet8h>(const Eigen::half* from) {
- Packet8h result;
- result.x = _mm_loadu_si128(reinterpret_cast<const __m128i*>(from));
- return result;
-}
-
-template<> EIGEN_STRONG_INLINE void pstore<Eigen::half>(Eigen::half* to, const Packet8h& from) {
- _mm_store_si128(reinterpret_cast<__m128i*>(to), from.x);
-}
-
-template<> EIGEN_STRONG_INLINE void pstoreu<Eigen::half>(Eigen::half* to, const Packet8h& from) {
- _mm_storeu_si128(reinterpret_cast<__m128i*>(to), from.x);
-}
-
-template<> EIGEN_STRONG_INLINE Packet8h
-ploadquad<Packet8h>(const Eigen::half* from) {
- Packet8h result;
- unsigned short a = from[0].x;
- unsigned short b = from[1].x;
- result.x = _mm_set_epi16(b, b, b, b, a, a, a, a);
- return result;
-}
-
-EIGEN_STRONG_INLINE Packet8f half2float(const Packet8h& a) {
-#ifdef EIGEN_HAS_FP16_C
- return _mm256_cvtph_ps(a.x);
-#else
- EIGEN_ALIGN32 Eigen::half aux[8];
- pstore(aux, a);
- float f0(aux[0]);
- float f1(aux[1]);
- float f2(aux[2]);
- float f3(aux[3]);
- float f4(aux[4]);
- float f5(aux[5]);
- float f6(aux[6]);
- float f7(aux[7]);
-
- return _mm256_set_ps(f7, f6, f5, f4, f3, f2, f1, f0);
-#endif
-}
-
-EIGEN_STRONG_INLINE Packet8h float2half(const Packet8f& a) {
-#ifdef EIGEN_HAS_FP16_C
- Packet8h result;
- result.x = _mm256_cvtps_ph(a, _MM_FROUND_TO_NEAREST_INT|_MM_FROUND_NO_EXC);
- return result;
-#else
- EIGEN_ALIGN32 float aux[8];
- pstore(aux, a);
- Eigen::half h0(aux[0]);
- Eigen::half h1(aux[1]);
- Eigen::half h2(aux[2]);
- Eigen::half h3(aux[3]);
- Eigen::half h4(aux[4]);
- Eigen::half h5(aux[5]);
- Eigen::half h6(aux[6]);
- Eigen::half h7(aux[7]);
-
- Packet8h result;
- result.x = _mm_set_epi16(h7.x, h6.x, h5.x, h4.x, h3.x, h2.x, h1.x, h0.x);
- return result;
-#endif
-}
-
-template<> EIGEN_STRONG_INLINE Packet8h pconj(const Packet8h& a) { return a; }
-
-template<> EIGEN_STRONG_INLINE Packet8h padd<Packet8h>(const Packet8h& a, const Packet8h& b) {
- Packet8f af = half2float(a);
- Packet8f bf = half2float(b);
- Packet8f rf = padd(af, bf);
- return float2half(rf);
-}
-
-template<> EIGEN_STRONG_INLINE Packet8h pmul<Packet8h>(const Packet8h& a, const Packet8h& b) {
- Packet8f af = half2float(a);
- Packet8f bf = half2float(b);
- Packet8f rf = pmul(af, bf);
- return float2half(rf);
-}
-
-template<> EIGEN_STRONG_INLINE Packet8h pgather<Eigen::half, Packet8h>(const Eigen::half* from, Index stride)
-{
- Packet8h result;
- result.x = _mm_set_epi16(from[7*stride].x, from[6*stride].x, from[5*stride].x, from[4*stride].x, from[3*stride].x, from[2*stride].x, from[1*stride].x, from[0*stride].x);
- return result;
-}
-
-template<> EIGEN_STRONG_INLINE void pscatter<Eigen::half, Packet8h>(Eigen::half* to, const Packet8h& from, Index stride)
-{
- EIGEN_ALIGN32 Eigen::half aux[8];
- pstore(aux, from);
- to[stride*0].x = aux[0].x;
- to[stride*1].x = aux[1].x;
- to[stride*2].x = aux[2].x;
- to[stride*3].x = aux[3].x;
- to[stride*4].x = aux[4].x;
- to[stride*5].x = aux[5].x;
- to[stride*6].x = aux[6].x;
- to[stride*7].x = aux[7].x;
-}
-
-template<> EIGEN_STRONG_INLINE Eigen::half predux<Packet8h>(const Packet8h& a) {
- Packet8f af = half2float(a);
- float reduced = predux<Packet8f>(af);
- return Eigen::half(reduced);
-}
-
-template<> EIGEN_STRONG_INLINE Eigen::half predux_max<Packet8h>(const Packet8h& a) {
- Packet8f af = half2float(a);
- float reduced = predux_max<Packet8f>(af);
- return Eigen::half(reduced);
-}
-
-template<> EIGEN_STRONG_INLINE Eigen::half predux_min<Packet8h>(const Packet8h& a) {
- Packet8f af = half2float(a);
- float reduced = predux_min<Packet8f>(af);
- return Eigen::half(reduced);
-}
-
-template<> EIGEN_STRONG_INLINE Eigen::half predux_mul<Packet8h>(const Packet8h& a) {
- Packet8f af = half2float(a);
- float reduced = predux_mul<Packet8f>(af);
- return Eigen::half(reduced);
-}
-
-EIGEN_STRONG_INLINE void
-ptranspose(PacketBlock<Packet8h,8>& kernel) {
- __m128i a = kernel.packet[0].x;
- __m128i b = kernel.packet[1].x;
- __m128i c = kernel.packet[2].x;
- __m128i d = kernel.packet[3].x;
- __m128i e = kernel.packet[4].x;
- __m128i f = kernel.packet[5].x;
- __m128i g = kernel.packet[6].x;
- __m128i h = kernel.packet[7].x;
-
- __m128i a03b03 = _mm_unpacklo_epi16(a, b);
- __m128i c03d03 = _mm_unpacklo_epi16(c, d);
- __m128i e03f03 = _mm_unpacklo_epi16(e, f);
- __m128i g03h03 = _mm_unpacklo_epi16(g, h);
- __m128i a47b47 = _mm_unpackhi_epi16(a, b);
- __m128i c47d47 = _mm_unpackhi_epi16(c, d);
- __m128i e47f47 = _mm_unpackhi_epi16(e, f);
- __m128i g47h47 = _mm_unpackhi_epi16(g, h);
-
- __m128i a01b01c01d01 = _mm_unpacklo_epi32(a03b03, c03d03);
- __m128i a23b23c23d23 = _mm_unpackhi_epi32(a03b03, c03d03);
- __m128i e01f01g01h01 = _mm_unpacklo_epi32(e03f03, g03h03);
- __m128i e23f23g23h23 = _mm_unpackhi_epi32(e03f03, g03h03);
- __m128i a45b45c45d45 = _mm_unpacklo_epi32(a47b47, c47d47);
- __m128i a67b67c67d67 = _mm_unpackhi_epi32(a47b47, c47d47);
- __m128i e45f45g45h45 = _mm_unpacklo_epi32(e47f47, g47h47);
- __m128i e67f67g67h67 = _mm_unpackhi_epi32(e47f47, g47h47);
-
- __m128i a0b0c0d0e0f0g0h0 = _mm_unpacklo_epi64(a01b01c01d01, e01f01g01h01);
- __m128i a1b1c1d1e1f1g1h1 = _mm_unpackhi_epi64(a01b01c01d01, e01f01g01h01);
- __m128i a2b2c2d2e2f2g2h2 = _mm_unpacklo_epi64(a23b23c23d23, e23f23g23h23);
- __m128i a3b3c3d3e3f3g3h3 = _mm_unpackhi_epi64(a23b23c23d23, e23f23g23h23);
- __m128i a4b4c4d4e4f4g4h4 = _mm_unpacklo_epi64(a45b45c45d45, e45f45g45h45);
- __m128i a5b5c5d5e5f5g5h5 = _mm_unpackhi_epi64(a45b45c45d45, e45f45g45h45);
- __m128i a6b6c6d6e6f6g6h6 = _mm_unpacklo_epi64(a67b67c67d67, e67f67g67h67);
- __m128i a7b7c7d7e7f7g7h7 = _mm_unpackhi_epi64(a67b67c67d67, e67f67g67h67);
-
- kernel.packet[0].x = a0b0c0d0e0f0g0h0;
- kernel.packet[1].x = a1b1c1d1e1f1g1h1;
- kernel.packet[2].x = a2b2c2d2e2f2g2h2;
- kernel.packet[3].x = a3b3c3d3e3f3g3h3;
- kernel.packet[4].x = a4b4c4d4e4f4g4h4;
- kernel.packet[5].x = a5b5c5d5e5f5g5h5;
- kernel.packet[6].x = a6b6c6d6e6f6g6h6;
- kernel.packet[7].x = a7b7c7d7e7f7g7h7;
-}
-
-EIGEN_STRONG_INLINE void
-ptranspose(PacketBlock<Packet8h,4>& kernel) {
- EIGEN_ALIGN32 Eigen::half in[4][8];
- pstore<Eigen::half>(in[0], kernel.packet[0]);
- pstore<Eigen::half>(in[1], kernel.packet[1]);
- pstore<Eigen::half>(in[2], kernel.packet[2]);
- pstore<Eigen::half>(in[3], kernel.packet[3]);
-
- EIGEN_ALIGN32 Eigen::half out[4][8];
-
- for (int i = 0; i < 4; ++i) {
- for (int j = 0; j < 4; ++j) {
- out[i][j] = in[j][2*i];
- }
- for (int j = 0; j < 4; ++j) {
- out[i][j+4] = in[j][2*i+1];
- }
- }
-
- kernel.packet[0] = pload<Packet8h>(out[0]);
- kernel.packet[1] = pload<Packet8h>(out[1]);
- kernel.packet[2] = pload<Packet8h>(out[2]);
- kernel.packet[3] = pload<Packet8h>(out[3]);
-}
-
-
-// Disable the following code since it's broken on too many platforms / compilers.
-//#elif defined(EIGEN_VECTORIZE_SSE) && (!EIGEN_ARCH_x86_64) && (!EIGEN_COMP_MSVC)
-#elif 0
-
-typedef struct {
- __m64 x;
-} Packet4h;
-
-
-template<> struct is_arithmetic<Packet4h> { enum { value = true }; };
-
-template <>
-struct packet_traits<Eigen::half> : default_packet_traits {
- typedef Packet4h type;
- // There is no half-size packet for Packet4h.
- typedef Packet4h half;
- enum {
- Vectorizable = 1,
- AlignedOnScalar = 1,
- size = 4,
- HasHalfPacket = 0,
- HasAdd = 0,
- HasSub = 0,
- HasMul = 0,
- HasNegate = 0,
- HasAbs = 0,
- HasAbs2 = 0,
- HasMin = 0,
- HasMax = 0,
- HasConj = 0,
- HasSetLinear = 0,
- HasDiv = 0,
- HasSqrt = 0,
- HasRsqrt = 0,
- HasExp = 0,
- HasLog = 0,
- HasBlend = 0
- };
-};
-
-
-template<> struct unpacket_traits<Packet4h> { typedef Eigen::half type; enum {size=4, alignment=Aligned16}; typedef Packet4h half; };
-
-template<> EIGEN_STRONG_INLINE Packet4h pset1<Packet4h>(const Eigen::half& from) {
- Packet4h result;
- result.x = _mm_set1_pi16(from.x);
- return result;
-}
-
-template<> EIGEN_STRONG_INLINE Eigen::half pfirst<Packet4h>(const Packet4h& from) {
- return half_impl::raw_uint16_to_half(static_cast<unsigned short>(_mm_cvtsi64_si32(from.x)));
-}
-
-template<> EIGEN_STRONG_INLINE Packet4h pconj(const Packet4h& a) { return a; }
-
-template<> EIGEN_STRONG_INLINE Packet4h padd<Packet4h>(const Packet4h& a, const Packet4h& b) {
- __int64_t a64 = _mm_cvtm64_si64(a.x);
- __int64_t b64 = _mm_cvtm64_si64(b.x);
-
- Eigen::half h[4];
-
- Eigen::half ha = half_impl::raw_uint16_to_half(static_cast<unsigned short>(a64));
- Eigen::half hb = half_impl::raw_uint16_to_half(static_cast<unsigned short>(b64));
- h[0] = ha + hb;
- ha = half_impl::raw_uint16_to_half(static_cast<unsigned short>(a64 >> 16));
- hb = half_impl::raw_uint16_to_half(static_cast<unsigned short>(b64 >> 16));
- h[1] = ha + hb;
- ha = half_impl::raw_uint16_to_half(static_cast<unsigned short>(a64 >> 32));
- hb = half_impl::raw_uint16_to_half(static_cast<unsigned short>(b64 >> 32));
- h[2] = ha + hb;
- ha = half_impl::raw_uint16_to_half(static_cast<unsigned short>(a64 >> 48));
- hb = half_impl::raw_uint16_to_half(static_cast<unsigned short>(b64 >> 48));
- h[3] = ha + hb;
- Packet4h result;
- result.x = _mm_set_pi16(h[3].x, h[2].x, h[1].x, h[0].x);
- return result;
-}
-
-template<> EIGEN_STRONG_INLINE Packet4h pmul<Packet4h>(const Packet4h& a, const Packet4h& b) {
- __int64_t a64 = _mm_cvtm64_si64(a.x);
- __int64_t b64 = _mm_cvtm64_si64(b.x);
-
- Eigen::half h[4];
-
- Eigen::half ha = half_impl::raw_uint16_to_half(static_cast<unsigned short>(a64));
- Eigen::half hb = half_impl::raw_uint16_to_half(static_cast<unsigned short>(b64));
- h[0] = ha * hb;
- ha = half_impl::raw_uint16_to_half(static_cast<unsigned short>(a64 >> 16));
- hb = half_impl::raw_uint16_to_half(static_cast<unsigned short>(b64 >> 16));
- h[1] = ha * hb;
- ha = half_impl::raw_uint16_to_half(static_cast<unsigned short>(a64 >> 32));
- hb = half_impl::raw_uint16_to_half(static_cast<unsigned short>(b64 >> 32));
- h[2] = ha * hb;
- ha = half_impl::raw_uint16_to_half(static_cast<unsigned short>(a64 >> 48));
- hb = half_impl::raw_uint16_to_half(static_cast<unsigned short>(b64 >> 48));
- h[3] = ha * hb;
- Packet4h result;
- result.x = _mm_set_pi16(h[3].x, h[2].x, h[1].x, h[0].x);
- return result;
-}
-
-template<> EIGEN_STRONG_INLINE Packet4h pload<Packet4h>(const Eigen::half* from) {
- Packet4h result;
- result.x = _mm_cvtsi64_m64(*reinterpret_cast<const __int64_t*>(from));
- return result;
-}
-
-template<> EIGEN_STRONG_INLINE Packet4h ploadu<Packet4h>(const Eigen::half* from) {
- Packet4h result;
- result.x = _mm_cvtsi64_m64(*reinterpret_cast<const __int64_t*>(from));
- return result;
-}
-
-template<> EIGEN_STRONG_INLINE void pstore<Eigen::half>(Eigen::half* to, const Packet4h& from) {
- __int64_t r = _mm_cvtm64_si64(from.x);
- *(reinterpret_cast<__int64_t*>(to)) = r;
-}
-
-template<> EIGEN_STRONG_INLINE void pstoreu<Eigen::half>(Eigen::half* to, const Packet4h& from) {
- __int64_t r = _mm_cvtm64_si64(from.x);
- *(reinterpret_cast<__int64_t*>(to)) = r;
-}
-
-template<> EIGEN_STRONG_INLINE Packet4h
-ploadquad<Packet4h>(const Eigen::half* from) {
- return pset1<Packet4h>(*from);
-}
-
-template<> EIGEN_STRONG_INLINE Packet4h pgather<Eigen::half, Packet4h>(const Eigen::half* from, Index stride)
-{
- Packet4h result;
- result.x = _mm_set_pi16(from[3*stride].x, from[2*stride].x, from[1*stride].x, from[0*stride].x);
- return result;
-}
-
-template<> EIGEN_STRONG_INLINE void pscatter<Eigen::half, Packet4h>(Eigen::half* to, const Packet4h& from, Index stride)
-{
- __int64_t a = _mm_cvtm64_si64(from.x);
- to[stride*0].x = static_cast<unsigned short>(a);
- to[stride*1].x = static_cast<unsigned short>(a >> 16);
- to[stride*2].x = static_cast<unsigned short>(a >> 32);
- to[stride*3].x = static_cast<unsigned short>(a >> 48);
-}
-
-EIGEN_STRONG_INLINE void
-ptranspose(PacketBlock<Packet4h,4>& kernel) {
- __m64 T0 = _mm_unpacklo_pi16(kernel.packet[0].x, kernel.packet[1].x);
- __m64 T1 = _mm_unpacklo_pi16(kernel.packet[2].x, kernel.packet[3].x);
- __m64 T2 = _mm_unpackhi_pi16(kernel.packet[0].x, kernel.packet[1].x);
- __m64 T3 = _mm_unpackhi_pi16(kernel.packet[2].x, kernel.packet[3].x);
-
- kernel.packet[0].x = _mm_unpacklo_pi32(T0, T1);
- kernel.packet[1].x = _mm_unpackhi_pi32(T0, T1);
- kernel.packet[2].x = _mm_unpacklo_pi32(T2, T3);
- kernel.packet[3].x = _mm_unpackhi_pi32(T2, T3);
-}
-
-#endif
-
-}
-}
-
-#endif // EIGEN_PACKET_MATH_HALF_CUDA_H
diff --git a/Eigen/src/Core/arch/CUDA/TypeCasting.h b/Eigen/src/Core/arch/CUDA/TypeCasting.h
deleted file mode 100644
index aa5fbce8e..000000000
--- a/Eigen/src/Core/arch/CUDA/TypeCasting.h
+++ /dev/null
@@ -1,212 +0,0 @@
-// This file is part of Eigen, a lightweight C++ template library
-// for linear algebra.
-//
-// Copyright (C) 2016 Benoit Steiner <benoit.steiner.goog@gmail.com>
-//
-// This Source Code Form is subject to the terms of the Mozilla
-// Public License v. 2.0. If a copy of the MPL was not distributed
-// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
-
-#ifndef EIGEN_TYPE_CASTING_CUDA_H
-#define EIGEN_TYPE_CASTING_CUDA_H
-
-namespace Eigen {
-
-namespace internal {
-
-template<>
-struct scalar_cast_op<float, Eigen::half> {
- EIGEN_EMPTY_STRUCT_CTOR(scalar_cast_op)
- typedef Eigen::half result_type;
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::half operator() (const float& a) const {
- #if defined(EIGEN_HAS_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300
- return __float2half(a);
- #else
- return Eigen::half(a);
- #endif
- }
-};
-
-template<>
-struct functor_traits<scalar_cast_op<float, Eigen::half> >
-{ enum { Cost = NumTraits<float>::AddCost, PacketAccess = false }; };
-
-
-template<>
-struct scalar_cast_op<int, Eigen::half> {
- EIGEN_EMPTY_STRUCT_CTOR(scalar_cast_op)
- typedef Eigen::half result_type;
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::half operator() (const int& a) const {
- #if defined(EIGEN_HAS_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300
- return __float2half(static_cast<float>(a));
- #else
- return Eigen::half(static_cast<float>(a));
- #endif
- }
-};
-
-template<>
-struct functor_traits<scalar_cast_op<int, Eigen::half> >
-{ enum { Cost = NumTraits<float>::AddCost, PacketAccess = false }; };
-
-
-template<>
-struct scalar_cast_op<Eigen::half, float> {
- EIGEN_EMPTY_STRUCT_CTOR(scalar_cast_op)
- typedef float result_type;
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float operator() (const Eigen::half& a) const {
- #if defined(EIGEN_HAS_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300
- return __half2float(a);
- #else
- return static_cast<float>(a);
- #endif
- }
-};
-
-template<>
-struct functor_traits<scalar_cast_op<Eigen::half, float> >
-{ enum { Cost = NumTraits<float>::AddCost, PacketAccess = false }; };
-
-
-
-#if defined(EIGEN_HAS_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300
-
-template <>
-struct type_casting_traits<Eigen::half, float> {
- enum {
- VectorizedCast = 1,
- SrcCoeffRatio = 2,
- TgtCoeffRatio = 1
- };
-};
-
-template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 pcast<half2, float4>(const half2& a, const half2& b) {
- float2 r1 = __half22float2(a);
- float2 r2 = __half22float2(b);
- return make_float4(r1.x, r1.y, r2.x, r2.y);
-}
-
-template <>
-struct type_casting_traits<float, Eigen::half> {
- enum {
- VectorizedCast = 1,
- SrcCoeffRatio = 1,
- TgtCoeffRatio = 2
- };
-};
-
-template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pcast<float4, half2>(const float4& a) {
- // Simply discard the second half of the input
- return __floats2half2_rn(a.x, a.y);
-}
-
-#elif defined EIGEN_VECTORIZE_AVX512
-template <>
-struct type_casting_traits<half, float> {
- enum {
- VectorizedCast = 1,
- SrcCoeffRatio = 1,
- TgtCoeffRatio = 1
- };
-};
-
-template<> EIGEN_STRONG_INLINE Packet16f pcast<Packet16h, Packet16f>(const Packet16h& a) {
- return half2float(a);
-}
-
-template <>
-struct type_casting_traits<float, half> {
- enum {
- VectorizedCast = 1,
- SrcCoeffRatio = 1,
- TgtCoeffRatio = 1
- };
-};
-
-template<> EIGEN_STRONG_INLINE Packet16h pcast<Packet16f, Packet16h>(const Packet16f& a) {
- return float2half(a);
-}
-
-#elif defined EIGEN_VECTORIZE_AVX
-
-template <>
-struct type_casting_traits<Eigen::half, float> {
- enum {
- VectorizedCast = 1,
- SrcCoeffRatio = 1,
- TgtCoeffRatio = 1
- };
-};
-
-template<> EIGEN_STRONG_INLINE Packet8f pcast<Packet8h, Packet8f>(const Packet8h& a) {
- return half2float(a);
-}
-
-template <>
-struct type_casting_traits<float, Eigen::half> {
- enum {
- VectorizedCast = 1,
- SrcCoeffRatio = 1,
- TgtCoeffRatio = 1
- };
-};
-
-template<> EIGEN_STRONG_INLINE Packet8h pcast<Packet8f, Packet8h>(const Packet8f& a) {
- return float2half(a);
-}
-
-// Disable the following code since it's broken on too many platforms / compilers.
-//#elif defined(EIGEN_VECTORIZE_SSE) && (!EIGEN_ARCH_x86_64) && (!EIGEN_COMP_MSVC)
-#elif 0
-
-template <>
-struct type_casting_traits<Eigen::half, float> {
- enum {
- VectorizedCast = 1,
- SrcCoeffRatio = 1,
- TgtCoeffRatio = 1
- };
-};
-
-template<> EIGEN_STRONG_INLINE Packet4f pcast<Packet4h, Packet4f>(const Packet4h& a) {
- __int64_t a64 = _mm_cvtm64_si64(a.x);
- Eigen::half h = raw_uint16_to_half(static_cast<unsigned short>(a64));
- float f1 = static_cast<float>(h);
- h = raw_uint16_to_half(static_cast<unsigned short>(a64 >> 16));
- float f2 = static_cast<float>(h);
- h = raw_uint16_to_half(static_cast<unsigned short>(a64 >> 32));
- float f3 = static_cast<float>(h);
- h = raw_uint16_to_half(static_cast<unsigned short>(a64 >> 48));
- float f4 = static_cast<float>(h);
- return _mm_set_ps(f4, f3, f2, f1);
-}
-
-template <>
-struct type_casting_traits<float, Eigen::half> {
- enum {
- VectorizedCast = 1,
- SrcCoeffRatio = 1,
- TgtCoeffRatio = 1
- };
-};
-
-template<> EIGEN_STRONG_INLINE Packet4h pcast<Packet4f, Packet4h>(const Packet4f& a) {
- EIGEN_ALIGN16 float aux[4];
- pstore(aux, a);
- Eigen::half h0(aux[0]);
- Eigen::half h1(aux[1]);
- Eigen::half h2(aux[2]);
- Eigen::half h3(aux[3]);
-
- Packet4h result;
- result.x = _mm_set_pi16(h3.x, h2.x, h1.x, h0.x);
- return result;
-}
-
-#endif
-
-} // end namespace internal
-
-} // end namespace Eigen
-
-#endif // EIGEN_TYPE_CASTING_CUDA_H
diff --git a/Eigen/src/Core/arch/Default/BFloat16.h b/Eigen/src/Core/arch/Default/BFloat16.h
new file mode 100644
index 000000000..1c28f4f95
--- /dev/null
+++ b/Eigen/src/Core/arch/Default/BFloat16.h
@@ -0,0 +1,700 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef EIGEN_BFLOAT16_H
+#define EIGEN_BFLOAT16_H
+
+#define BF16_PACKET_FUNCTION(PACKET_F, PACKET_BF16, METHOD) \
+ template <> \
+ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED \
+ PACKET_BF16 METHOD<PACKET_BF16>(const PACKET_BF16& _x) { \
+ return F32ToBf16(METHOD<PACKET_F>(Bf16ToF32(_x))); \
+ }
+
+namespace Eigen {
+
+struct bfloat16;
+
+namespace bfloat16_impl {
+
+// Make our own __bfloat16_raw definition.
+struct __bfloat16_raw {
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw() : value(0) {}
+ explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw(unsigned short raw) : value(raw) {}
+ unsigned short value;
+};
+
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw raw_uint16_to_bfloat16(unsigned short value);
+template <bool AssumeArgumentIsNormalOrInfinityOrZero>
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne(float ff);
+// Forward declarations of template specializations, to avoid Visual C++ 2019 errors, saying:
+// > error C2908: explicit specialization; 'float_to_bfloat16_rtne' has already been instantiated
+template <>
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne<false>(float ff);
+template <>
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne<true>(float ff);
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC float bfloat16_to_float(__bfloat16_raw h);
+
+struct bfloat16_base : public __bfloat16_raw {
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16_base() {}
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16_base(const __bfloat16_raw& h) : __bfloat16_raw(h) {}
+};
+
+} // namespace bfloat16_impl
+
+// Class definition.
+struct bfloat16 : public bfloat16_impl::bfloat16_base {
+
+ typedef bfloat16_impl::__bfloat16_raw __bfloat16_raw;
+
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16() {}
+
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16(const __bfloat16_raw& h) : bfloat16_impl::bfloat16_base(h) {}
+
+ explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16(bool b)
+ : bfloat16_impl::bfloat16_base(bfloat16_impl::raw_uint16_to_bfloat16(b ? 0x3f80 : 0)) {}
+
+ template<class T>
+ explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16(T val)
+ : bfloat16_impl::bfloat16_base(bfloat16_impl::float_to_bfloat16_rtne<internal::is_integral<T>::value>(static_cast<float>(val))) {}
+
+ explicit EIGEN_DEVICE_FUNC bfloat16(float f)
+ : bfloat16_impl::bfloat16_base(bfloat16_impl::float_to_bfloat16_rtne<false>(f)) {}
+
+ // Following the convention of numpy, converting between complex and
+ // float will lead to loss of imag value.
+ template<typename RealScalar>
+ explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16(const std::complex<RealScalar>& val)
+ : bfloat16_impl::bfloat16_base(bfloat16_impl::float_to_bfloat16_rtne<false>(static_cast<float>(val.real()))) {}
+
+ EIGEN_DEVICE_FUNC operator float() const { // NOLINT: Allow implicit conversion to float, because it is lossless.
+ return bfloat16_impl::bfloat16_to_float(*this);
+ }
+};
+} // namespace Eigen
+
+namespace std {
+template<>
+struct numeric_limits<Eigen::bfloat16> {
+ static const bool is_specialized = true;
+ static const bool is_signed = true;
+ static const bool is_integer = false;
+ static const bool is_exact = false;
+ static const bool has_infinity = true;
+ static const bool has_quiet_NaN = true;
+ static const bool has_signaling_NaN = true;
+ static const float_denorm_style has_denorm = std::denorm_absent;
+ static const bool has_denorm_loss = false;
+ static const std::float_round_style round_style = numeric_limits<float>::round_style;
+ static const bool is_iec559 = false;
+ static const bool is_bounded = true;
+ static const bool is_modulo = false;
+ static const int digits = 8;
+ static const int digits10 = 2;
+ static const int max_digits10 = 4;
+ static const int radix = 2;
+ static const int min_exponent = numeric_limits<float>::min_exponent;
+ static const int min_exponent10 = numeric_limits<float>::min_exponent10;
+ static const int max_exponent = numeric_limits<float>::max_exponent;
+ static const int max_exponent10 = numeric_limits<float>::max_exponent10;
+ static const bool traps = numeric_limits<float>::traps;
+ static const bool tinyness_before = numeric_limits<float>::tinyness_before;
+
+ static Eigen::bfloat16 (min)() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x0080); }
+ static Eigen::bfloat16 lowest() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0xff7f); }
+ static Eigen::bfloat16 (max)() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x7f7f); }
+ static Eigen::bfloat16 epsilon() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x3c00); }
+ static Eigen::bfloat16 round_error() { return Eigen::bfloat16(0x3f00); }
+ static Eigen::bfloat16 infinity() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x7f80); }
+ static Eigen::bfloat16 quiet_NaN() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x7fc0); }
+ static Eigen::bfloat16 signaling_NaN() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x7f81); }
+ static Eigen::bfloat16 denorm_min() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x0001); }
+};
+
+// If std::numeric_limits<T> is specialized, should also specialize
+// std::numeric_limits<const T>, std::numeric_limits<volatile T>, and
+// std::numeric_limits<const volatile T>
+// https://stackoverflow.com/a/16519653/
+template<>
+struct numeric_limits<const Eigen::bfloat16> : numeric_limits<Eigen::bfloat16> {};
+template<>
+struct numeric_limits<volatile Eigen::bfloat16> : numeric_limits<Eigen::bfloat16> {};
+template<>
+struct numeric_limits<const volatile Eigen::bfloat16> : numeric_limits<Eigen::bfloat16> {};
+} // namespace std
+
+namespace Eigen {
+
+namespace bfloat16_impl {
+
+// We need to distinguish ‘clang as the CUDA compiler’ from ‘clang as the host compiler,
+// invoked by NVCC’ (e.g. on MacOS). The former needs to see both host and device implementation
+// of the functions, while the latter can only deal with one of them.
+#if !defined(EIGEN_HAS_NATIVE_BF16) || (EIGEN_COMP_CLANG && !EIGEN_COMP_NVCC) // Emulate support for bfloat16 floats
+
+#if EIGEN_COMP_CLANG && defined(EIGEN_CUDACC)
+// We need to provide emulated *host-side* BF16 operators for clang.
+#pragma push_macro("EIGEN_DEVICE_FUNC")
+#undef EIGEN_DEVICE_FUNC
+#if defined(EIGEN_HAS_CUDA_BF16) && defined(EIGEN_HAS_NATIVE_BF16)
+#define EIGEN_DEVICE_FUNC __host__
+#else // both host and device need emulated ops.
+#define EIGEN_DEVICE_FUNC __host__ __device__
+#endif
+#endif
+
+// Definitions for CPUs, mostly working through conversion
+// to/from fp32.
+
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator + (const bfloat16& a, const bfloat16& b) {
+ return bfloat16(float(a) + float(b));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator + (const bfloat16& a, const int& b) {
+ return bfloat16(float(a) + static_cast<float>(b));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator + (const int& a, const bfloat16& b) {
+ return bfloat16(static_cast<float>(a) + float(b));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator * (const bfloat16& a, const bfloat16& b) {
+ return bfloat16(float(a) * float(b));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator - (const bfloat16& a, const bfloat16& b) {
+ return bfloat16(float(a) - float(b));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator / (const bfloat16& a, const bfloat16& b) {
+ return bfloat16(float(a) / float(b));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator - (const bfloat16& a) {
+ bfloat16 result;
+ result.value = a.value ^ 0x8000;
+ return result;
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16& operator += (bfloat16& a, const bfloat16& b) {
+ a = bfloat16(float(a) + float(b));
+ return a;
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16& operator *= (bfloat16& a, const bfloat16& b) {
+ a = bfloat16(float(a) * float(b));
+ return a;
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16& operator -= (bfloat16& a, const bfloat16& b) {
+ a = bfloat16(float(a) - float(b));
+ return a;
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16& operator /= (bfloat16& a, const bfloat16& b) {
+ a = bfloat16(float(a) / float(b));
+ return a;
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator++(bfloat16& a) {
+ a += bfloat16(1);
+ return a;
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator--(bfloat16& a) {
+ a -= bfloat16(1);
+ return a;
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator++(bfloat16& a, int) {
+ bfloat16 original_value = a;
+ ++a;
+ return original_value;
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator--(bfloat16& a, int) {
+ bfloat16 original_value = a;
+ --a;
+ return original_value;
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator == (const bfloat16& a, const bfloat16& b) {
+ return numext::equal_strict(float(a),float(b));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator != (const bfloat16& a, const bfloat16& b) {
+ return numext::not_equal_strict(float(a), float(b));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator < (const bfloat16& a, const bfloat16& b) {
+ return float(a) < float(b);
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator <= (const bfloat16& a, const bfloat16& b) {
+ return float(a) <= float(b);
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator > (const bfloat16& a, const bfloat16& b) {
+ return float(a) > float(b);
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator >= (const bfloat16& a, const bfloat16& b) {
+ return float(a) >= float(b);
+}
+
+#if EIGEN_COMP_CLANG && defined(EIGEN_CUDACC)
+#pragma pop_macro("EIGEN_DEVICE_FUNC")
+#endif
+#endif // Emulate support for bfloat16 floats
+
+// Division by an index. Do it in full float precision to avoid accuracy
+// issues in converting the denominator to bfloat16.
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator / (const bfloat16& a, Index b) {
+ return bfloat16(static_cast<float>(a) / static_cast<float>(b));
+}
+
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw truncate_to_bfloat16(const float v) {
+ __bfloat16_raw output;
+ if (Eigen::numext::isnan EIGEN_NOT_A_MACRO(v)) {
+ output.value = std::signbit(v) ? 0xFFC0: 0x7FC0;
+ return output;
+ }
+ const uint16_t* p = reinterpret_cast<const uint16_t*>(&v);
+#if defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
+ output.value = p[0];
+#else
+ output.value = p[1];
+#endif
+ return output;
+}
+
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw raw_uint16_to_bfloat16(numext::uint16_t value) {
+ return __bfloat16_raw(value);
+}
+
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR numext::uint16_t raw_bfloat16_as_uint16(const __bfloat16_raw& bf) {
+ return bf.value;
+}
+
+// float_to_bfloat16_rtne template specialization that does not make any
+// assumption about the value of its function argument (ff).
+template <>
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne<false>(float ff) {
+#if (defined(EIGEN_HAS_CUDA_BF16) && defined(EIGEN_HAS_HIP_BF16))
+ // Nothing to do here
+#else
+ __bfloat16_raw output;
+
+ if (Eigen::numext::isnan EIGEN_NOT_A_MACRO(ff)) {
+ // If the value is a NaN, squash it to a qNaN with msb of fraction set,
+ // this makes sure after truncation we don't end up with an inf.
+ //
+ // qNaN magic: All exponent bits set + most significant bit of fraction
+ // set.
+ output.value = std::signbit(ff) ? 0xFFC0: 0x7FC0;
+ } else {
+ // Fast rounding algorithm that rounds a half value to nearest even. This
+ // reduces expected error when we convert a large number of floats. Here
+ // is how it works:
+ //
+ // Definitions:
+ // To convert a float 32 to bfloat16, a float 32 can be viewed as 32 bits
+ // with the following tags:
+ //
+ // Sign | Exp (8 bits) | Frac (23 bits)
+ // S EEEEEEEE FFFFFFLRTTTTTTTTTTTTTTT
+ //
+ // S: Sign bit.
+ // E: Exponent bits.
+ // F: First 6 bits of fraction.
+ // L: Least significant bit of resulting bfloat16 if we truncate away the
+ // rest of the float32. This is also the 7th bit of fraction
+ // R: Rounding bit, 8th bit of fraction.
+ // T: Sticky bits, rest of fraction, 15 bits.
+ //
+ // To round half to nearest even, there are 3 cases where we want to round
+ // down (simply truncate the result of the bits away, which consists of
+ // rounding bit and sticky bits) and two cases where we want to round up
+ // (truncate then add one to the result).
+ //
+ // The fast converting algorithm simply adds lsb (L) to 0x7fff (15 bits of
+ // 1s) as the rounding bias, adds the rounding bias to the input, then
+ // truncates the last 16 bits away.
+ //
+ // To understand how it works, we can analyze this algorithm case by case:
+ //
+ // 1. L = 0, R = 0:
+ // Expect: round down, this is less than half value.
+ //
+ // Algorithm:
+ // - Rounding bias: 0x7fff + 0 = 0x7fff
+ // - Adding rounding bias to input may create any carry, depending on
+ // whether there is any value set to 1 in T bits.
+ // - R may be set to 1 if there is a carry.
+ // - L remains 0.
+ // - Note that this case also handles Inf and -Inf, where all fraction
+ // bits, including L, R and Ts are all 0. The output remains Inf after
+ // this algorithm.
+ //
+ // 2. L = 1, R = 0:
+ // Expect: round down, this is less than half value.
+ //
+ // Algorithm:
+ // - Rounding bias: 0x7fff + 1 = 0x8000
+ // - Adding rounding bias to input doesn't change sticky bits but
+ // adds 1 to rounding bit.
+ // - L remains 1.
+ //
+ // 3. L = 0, R = 1, all of T are 0:
+ // Expect: round down, this is exactly at half, the result is already
+ // even (L=0).
+ //
+ // Algorithm:
+ // - Rounding bias: 0x7fff + 0 = 0x7fff
+ // - Adding rounding bias to input sets all sticky bits to 1, but
+ // doesn't create a carry.
+ // - R remains 1.
+ // - L remains 0.
+ //
+ // 4. L = 1, R = 1:
+ // Expect: round up, this is exactly at half, the result needs to be
+ // round to the next even number.
+ //
+ // Algorithm:
+ // - Rounding bias: 0x7fff + 1 = 0x8000
+ // - Adding rounding bias to input doesn't change sticky bits, but
+ // creates a carry from rounding bit.
+ // - The carry sets L to 0, creates another carry bit and propagate
+ // forward to F bits.
+ // - If all the F bits are 1, a carry then propagates to the exponent
+ // bits, which then creates the minimum value with the next exponent
+ // value. Note that we won't have the case where exponents are all 1,
+ // since that's either a NaN (handled in the other if condition) or inf
+ // (handled in case 1).
+ //
+ // 5. L = 0, R = 1, any of T is 1:
+ // Expect: round up, this is greater than half.
+ //
+ // Algorithm:
+ // - Rounding bias: 0x7fff + 0 = 0x7fff
+ // - Adding rounding bias to input creates a carry from sticky bits,
+ // sets rounding bit to 0, then create another carry.
+ // - The second carry sets L to 1.
+ //
+ // Examples:
+ //
+ // Exact half value that is already even:
+ // Input:
+ // Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
+ // S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT
+ // 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 1000000000000000
+ //
+ // This falls into case 3. We truncate the rest of 16 bits and no
+ // carry is created into F and L:
+ //
+ // Output:
+ // Sign | Exp (8 bit) | Frac (first 7 bit)
+ // S E E E E E E E E F F F F F F L
+ // 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0
+ //
+ // Exact half value, round to next even number:
+ // Input:
+ // Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
+ // S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT
+ // 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1000000000000000
+ //
+ // This falls into case 4. We create a carry from R and T,
+ // which then propagates into L and F:
+ //
+ // Output:
+ // Sign | Exp (8 bit) | Frac (first 7 bit)
+ // S E E E E E E E E F F F F F F L
+ // 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0
+ //
+ //
+ // Max denormal value round to min normal value:
+ // Input:
+ // Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
+ // S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT
+ // 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1111111111111111
+ //
+ // This falls into case 4. We create a carry from R and T,
+ // propagate into L and F, which then propagates into exponent
+ // bits:
+ //
+ // Output:
+ // Sign | Exp (8 bit) | Frac (first 7 bit)
+ // S E E E E E E E E F F F F F F L
+ // 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0
+ //
+ // Max normal value round to Inf:
+ // Input:
+ // Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
+ // S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT
+ // 0 1 1 1 1 1 1 1 0 1 1 1 1 1 1 1 1111111111111111
+ //
+ // This falls into case 4. We create a carry from R and T,
+ // propagate into L and F, which then propagates into exponent
+ // bits:
+ //
+ // Sign | Exp (8 bit) | Frac (first 7 bit)
+ // S E E E E E E E E F F F F F F L
+ // 0 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0
+
+ // At this point, ff must be either a normal float, or +/-infinity.
+ output = float_to_bfloat16_rtne<true>(ff);
+ }
+ return output;
+#endif
+}
+
+// float_to_bfloat16_rtne template specialization that assumes that its function
+// argument (ff) is either a normal floating point number, or +/-infinity, or
+// zero. Used to improve the runtime performance of conversion from an integer
+// type to bfloat16.
+template <>
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne<true>(float ff) {
+#if (defined(EIGEN_HAS_CUDA_BF16) && defined(EIGEN_HAS_HIP_BF16))
+ // Nothing to do here
+#else
+ numext::uint32_t input = numext::bit_cast<numext::uint32_t>(ff);
+ __bfloat16_raw output;
+
+ // Least significant bit of resulting bfloat.
+ numext::uint32_t lsb = (input >> 16) & 1;
+ numext::uint32_t rounding_bias = 0x7fff + lsb;
+ input += rounding_bias;
+ output.value = static_cast<numext::uint16_t>(input >> 16);
+ return output;
+#endif
+}
+
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC float bfloat16_to_float(__bfloat16_raw h) {
+ float result = 0;
+ unsigned short* q = reinterpret_cast<unsigned short*>(&result);
+#if defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
+ q[0] = h.value;
+#else
+ q[1] = h.value;
+#endif
+ return result;
+}
+// --- standard functions ---
+
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool (isinf)(const bfloat16& a) {
+ EIGEN_USING_STD(isinf);
+ return (isinf)(float(a));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool (isnan)(const bfloat16& a) {
+ EIGEN_USING_STD(isnan);
+ return (isnan)(float(a));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool (isfinite)(const bfloat16& a) {
+ return !(isinf EIGEN_NOT_A_MACRO (a)) && !(isnan EIGEN_NOT_A_MACRO (a));
+}
+
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 abs(const bfloat16& a) {
+ bfloat16 result;
+ result.value = a.value & 0x7FFF;
+ return result;
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 exp(const bfloat16& a) {
+ return bfloat16(::expf(float(a)));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 expm1(const bfloat16& a) {
+ return bfloat16(numext::expm1(float(a)));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 log(const bfloat16& a) {
+ return bfloat16(::logf(float(a)));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 log1p(const bfloat16& a) {
+ return bfloat16(numext::log1p(float(a)));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 log10(const bfloat16& a) {
+ return bfloat16(::log10f(float(a)));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 log2(const bfloat16& a) {
+ return bfloat16(static_cast<float>(EIGEN_LOG2E) * ::logf(float(a)));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 sqrt(const bfloat16& a) {
+ return bfloat16(::sqrtf(float(a)));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 pow(const bfloat16& a, const bfloat16& b) {
+ return bfloat16(::powf(float(a), float(b)));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 sin(const bfloat16& a) {
+ return bfloat16(::sinf(float(a)));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 cos(const bfloat16& a) {
+ return bfloat16(::cosf(float(a)));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 tan(const bfloat16& a) {
+ return bfloat16(::tanf(float(a)));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 asin(const bfloat16& a) {
+ return bfloat16(::asinf(float(a)));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 acos(const bfloat16& a) {
+ return bfloat16(::acosf(float(a)));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 atan(const bfloat16& a) {
+ return bfloat16(::atanf(float(a)));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 sinh(const bfloat16& a) {
+ return bfloat16(::sinhf(float(a)));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 cosh(const bfloat16& a) {
+ return bfloat16(::coshf(float(a)));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 tanh(const bfloat16& a) {
+ return bfloat16(::tanhf(float(a)));
+}
+#if EIGEN_HAS_CXX11_MATH
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 asinh(const bfloat16& a) {
+ return bfloat16(::asinhf(float(a)));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 acosh(const bfloat16& a) {
+ return bfloat16(::acoshf(float(a)));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 atanh(const bfloat16& a) {
+ return bfloat16(::atanhf(float(a)));
+}
+#endif
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 floor(const bfloat16& a) {
+ return bfloat16(::floorf(float(a)));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 ceil(const bfloat16& a) {
+ return bfloat16(::ceilf(float(a)));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 rint(const bfloat16& a) {
+ return bfloat16(::rintf(float(a)));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 round(const bfloat16& a) {
+ return bfloat16(::roundf(float(a)));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 fmod(const bfloat16& a, const bfloat16& b) {
+ return bfloat16(::fmodf(float(a), float(b)));
+}
+
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 (min)(const bfloat16& a, const bfloat16& b) {
+ const float f1 = static_cast<float>(a);
+ const float f2 = static_cast<float>(b);
+ return f2 < f1 ? b : a;
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 (max)(const bfloat16& a, const bfloat16& b) {
+ const float f1 = static_cast<float>(a);
+ const float f2 = static_cast<float>(b);
+ return f1 < f2 ? b : a;
+}
+
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 fmin(const bfloat16& a, const bfloat16& b) {
+ const float f1 = static_cast<float>(a);
+ const float f2 = static_cast<float>(b);
+ return bfloat16(::fminf(f1, f2));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 fmax(const bfloat16& a, const bfloat16& b) {
+ const float f1 = static_cast<float>(a);
+ const float f2 = static_cast<float>(b);
+ return bfloat16(::fmaxf(f1, f2));
+}
+
+#ifndef EIGEN_NO_IO
+EIGEN_ALWAYS_INLINE std::ostream& operator << (std::ostream& os, const bfloat16& v) {
+ os << static_cast<float>(v);
+ return os;
+}
+#endif
+
+} // namespace bfloat16_impl
+
+namespace internal {
+
+template<>
+struct random_default_impl<bfloat16, false, false>
+{
+ static inline bfloat16 run(const bfloat16& x, const bfloat16& y)
+ {
+ return x + (y-x) * bfloat16(float(std::rand()) / float(RAND_MAX));
+ }
+ static inline bfloat16 run()
+ {
+ return run(bfloat16(-1.f), bfloat16(1.f));
+ }
+};
+
+template<> struct is_arithmetic<bfloat16> { enum { value = true }; };
+
+} // namespace internal
+
+template<> struct NumTraits<Eigen::bfloat16>
+ : GenericNumTraits<Eigen::bfloat16>
+{
+ enum {
+ IsSigned = true,
+ IsInteger = false,
+ IsComplex = false,
+ RequireInitialization = false
+ };
+
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::bfloat16 epsilon() {
+ return bfloat16_impl::raw_uint16_to_bfloat16(0x3c00);
+ }
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::bfloat16 dummy_precision() {
+ return bfloat16_impl::raw_uint16_to_bfloat16(0x3D4D); // bfloat16(5e-2f);
+
+ }
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::bfloat16 highest() {
+ return bfloat16_impl::raw_uint16_to_bfloat16(0x7F7F);
+ }
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::bfloat16 lowest() {
+ return bfloat16_impl::raw_uint16_to_bfloat16(0xFF7F);
+ }
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::bfloat16 infinity() {
+ return bfloat16_impl::raw_uint16_to_bfloat16(0x7f80);
+ }
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::bfloat16 quiet_NaN() {
+ return bfloat16_impl::raw_uint16_to_bfloat16(0x7fc0);
+ }
+};
+
+} // namespace Eigen
+
+namespace Eigen {
+namespace numext {
+
+template<>
+EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
+bool (isnan)(const Eigen::bfloat16& h) {
+ return (bfloat16_impl::isnan)(h);
+}
+
+template<>
+EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
+bool (isinf)(const Eigen::bfloat16& h) {
+ return (bfloat16_impl::isinf)(h);
+}
+
+template<>
+EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
+bool (isfinite)(const Eigen::bfloat16& h) {
+ return (bfloat16_impl::isfinite)(h);
+}
+
+template <>
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 bit_cast<Eigen::bfloat16, uint16_t>(const uint16_t& src) {
+ return Eigen::bfloat16(Eigen::bfloat16_impl::raw_uint16_to_bfloat16(src));
+}
+
+template <>
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC uint16_t bit_cast<uint16_t, Eigen::bfloat16>(const Eigen::bfloat16& src) {
+ return Eigen::bfloat16_impl::raw_bfloat16_as_uint16(src);
+}
+
+} // namespace numext
+} // namespace Eigen
+
+#if EIGEN_HAS_STD_HASH
+namespace std {
+template <>
+struct hash<Eigen::bfloat16> {
+ EIGEN_STRONG_INLINE std::size_t operator()(const Eigen::bfloat16& a) const {
+ return static_cast<std::size_t>(Eigen::numext::bit_cast<Eigen::numext::uint16_t>(a));
+ }
+};
+} // namespace std
+#endif
+
+
+#endif // EIGEN_BFLOAT16_H
diff --git a/Eigen/src/Core/arch/Default/ConjHelper.h b/Eigen/src/Core/arch/Default/ConjHelper.h
new file mode 100644
index 000000000..53830b5a2
--- /dev/null
+++ b/Eigen/src/Core/arch/Default/ConjHelper.h
@@ -0,0 +1,117 @@
+
+// This file is part of Eigen, a lightweight C++ template library
+// for linear algebra.
+//
+// Copyright (C) 2017 Gael Guennebaud <gael.guennebaud@inria.fr>
+//
+// This Source Code Form is subject to the terms of the Mozilla
+// Public License v. 2.0. If a copy of the MPL was not distributed
+// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+#ifndef EIGEN_ARCH_CONJ_HELPER_H
+#define EIGEN_ARCH_CONJ_HELPER_H
+
+#define EIGEN_MAKE_CONJ_HELPER_CPLX_REAL(PACKET_CPLX, PACKET_REAL) \
+ template <> \
+ struct conj_helper<PACKET_REAL, PACKET_CPLX, false, false> { \
+ EIGEN_STRONG_INLINE PACKET_CPLX pmadd(const PACKET_REAL& x, \
+ const PACKET_CPLX& y, \
+ const PACKET_CPLX& c) const { \
+ return padd(c, this->pmul(x, y)); \
+ } \
+ EIGEN_STRONG_INLINE PACKET_CPLX pmul(const PACKET_REAL& x, \
+ const PACKET_CPLX& y) const { \
+ return PACKET_CPLX(Eigen::internal::pmul<PACKET_REAL>(x, y.v)); \
+ } \
+ }; \
+ \
+ template <> \
+ struct conj_helper<PACKET_CPLX, PACKET_REAL, false, false> { \
+ EIGEN_STRONG_INLINE PACKET_CPLX pmadd(const PACKET_CPLX& x, \
+ const PACKET_REAL& y, \
+ const PACKET_CPLX& c) const { \
+ return padd(c, this->pmul(x, y)); \
+ } \
+ EIGEN_STRONG_INLINE PACKET_CPLX pmul(const PACKET_CPLX& x, \
+ const PACKET_REAL& y) const { \
+ return PACKET_CPLX(Eigen::internal::pmul<PACKET_REAL>(x.v, y)); \
+ } \
+ };
+
+namespace Eigen {
+namespace internal {
+
+template<bool Conjugate> struct conj_if;
+
+template<> struct conj_if<true> {
+ template<typename T>
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x) const { return numext::conj(x); }
+ template<typename T>
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T pconj(const T& x) const { return internal::pconj(x); }
+};
+
+template<> struct conj_if<false> {
+ template<typename T>
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T& operator()(const T& x) const { return x; }
+ template<typename T>
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T& pconj(const T& x) const { return x; }
+};
+
+// Generic Implementation, assume scalars since the packet-version is
+// specialized below.
+template<typename LhsType, typename RhsType, bool ConjLhs, bool ConjRhs>
+struct conj_helper {
+ typedef typename ScalarBinaryOpTraits<LhsType, RhsType>::ReturnType ResultType;
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ResultType
+ pmadd(const LhsType& x, const RhsType& y, const ResultType& c) const
+ { return this->pmul(x, y) + c; }
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ResultType
+ pmul(const LhsType& x, const RhsType& y) const
+ { return conj_if<ConjLhs>()(x) * conj_if<ConjRhs>()(y); }
+};
+
+template<typename LhsScalar, typename RhsScalar>
+struct conj_helper<LhsScalar, RhsScalar, true, true> {
+ typedef typename ScalarBinaryOpTraits<LhsScalar,RhsScalar>::ReturnType ResultType;
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ResultType
+ pmadd(const LhsScalar& x, const RhsScalar& y, const ResultType& c) const
+ { return this->pmul(x, y) + c; }
+
+ // We save a conjuation by using the identity conj(a)*conj(b) = conj(a*b).
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ResultType
+ pmul(const LhsScalar& x, const RhsScalar& y) const
+ { return numext::conj(x * y); }
+};
+
+// Implementation with equal type, use packet operations.
+template<typename Packet, bool ConjLhs, bool ConjRhs>
+struct conj_helper<Packet, Packet, ConjLhs, ConjRhs>
+{
+ typedef Packet ResultType;
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet pmadd(const Packet& x, const Packet& y, const Packet& c) const
+ { return Eigen::internal::pmadd(conj_if<ConjLhs>().pconj(x), conj_if<ConjRhs>().pconj(y), c); }
+
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet pmul(const Packet& x, const Packet& y) const
+ { return Eigen::internal::pmul(conj_if<ConjLhs>().pconj(x), conj_if<ConjRhs>().pconj(y)); }
+};
+
+template<typename Packet>
+struct conj_helper<Packet, Packet, true, true>
+{
+ typedef Packet ResultType;
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet pmadd(const Packet& x, const Packet& y, const Packet& c) const
+ { return Eigen::internal::pmadd(pconj(x), pconj(y), c); }
+ // We save a conjuation by using the identity conj(a)*conj(b) = conj(a*b).
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet pmul(const Packet& x, const Packet& y) const
+ { return pconj(Eigen::internal::pmul(x, y)); }
+};
+
+} // namespace internal
+} // namespace Eigen
+
+#endif // EIGEN_ARCH_CONJ_HELPER_H
diff --git a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h
new file mode 100644
index 000000000..c9fbaf68b
--- /dev/null
+++ b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h
@@ -0,0 +1,1649 @@
+// This file is part of Eigen, a lightweight C++ template library
+// for linear algebra.
+//
+// Copyright (C) 2007 Julien Pommier
+// Copyright (C) 2014 Pedro Gonnet (pedro.gonnet@gmail.com)
+// Copyright (C) 2009-2019 Gael Guennebaud <gael.guennebaud@inria.fr>
+//
+// This Source Code Form is subject to the terms of the Mozilla
+// Public License v. 2.0. If a copy of the MPL was not distributed
+// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+/* The exp and log functions of this file initially come from
+ * Julien Pommier's sse math library: http://gruntthepeon.free.fr/ssemath/
+ */
+
+#ifndef EIGEN_ARCH_GENERIC_PACKET_MATH_FUNCTIONS_H
+#define EIGEN_ARCH_GENERIC_PACKET_MATH_FUNCTIONS_H
+
+namespace Eigen {
+namespace internal {
+
+// Creates a Scalar integer type with same bit-width.
+template<typename T> struct make_integer;
+template<> struct make_integer<float> { typedef numext::int32_t type; };
+template<> struct make_integer<double> { typedef numext::int64_t type; };
+template<> struct make_integer<half> { typedef numext::int16_t type; };
+template<> struct make_integer<bfloat16> { typedef numext::int16_t type; };
+
+template<typename Packet> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
+Packet pfrexp_generic_get_biased_exponent(const Packet& a) {
+ typedef typename unpacket_traits<Packet>::type Scalar;
+ typedef typename unpacket_traits<Packet>::integer_packet PacketI;
+ enum { mantissa_bits = numext::numeric_limits<Scalar>::digits - 1};
+ return pcast<PacketI, Packet>(plogical_shift_right<mantissa_bits>(preinterpret<PacketI>(pabs(a))));
+}
+
+// Safely applies frexp, correctly handles denormals.
+// Assumes IEEE floating point format.
+template<typename Packet> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
+Packet pfrexp_generic(const Packet& a, Packet& exponent) {
+ typedef typename unpacket_traits<Packet>::type Scalar;
+ typedef typename make_unsigned<typename make_integer<Scalar>::type>::type ScalarUI;
+ enum {
+ TotalBits = sizeof(Scalar) * CHAR_BIT,
+ MantissaBits = numext::numeric_limits<Scalar>::digits - 1,
+ ExponentBits = int(TotalBits) - int(MantissaBits) - 1
+ };
+
+ EIGEN_CONSTEXPR ScalarUI scalar_sign_mantissa_mask =
+ ~(((ScalarUI(1) << int(ExponentBits)) - ScalarUI(1)) << int(MantissaBits)); // ~0x7f800000
+ const Packet sign_mantissa_mask = pset1frombits<Packet>(static_cast<ScalarUI>(scalar_sign_mantissa_mask));
+ const Packet half = pset1<Packet>(Scalar(0.5));
+ const Packet zero = pzero(a);
+ const Packet normal_min = pset1<Packet>((numext::numeric_limits<Scalar>::min)()); // Minimum normal value, 2^-126
+
+ // To handle denormals, normalize by multiplying by 2^(int(MantissaBits)+1).
+ const Packet is_denormal = pcmp_lt(pabs(a), normal_min);
+ EIGEN_CONSTEXPR ScalarUI scalar_normalization_offset = ScalarUI(int(MantissaBits) + 1); // 24
+ // The following cannot be constexpr because bfloat16(uint16_t) is not constexpr.
+ const Scalar scalar_normalization_factor = Scalar(ScalarUI(1) << int(scalar_normalization_offset)); // 2^24
+ const Packet normalization_factor = pset1<Packet>(scalar_normalization_factor);
+ const Packet normalized_a = pselect(is_denormal, pmul(a, normalization_factor), a);
+
+ // Determine exponent offset: -126 if normal, -126-24 if denormal
+ const Scalar scalar_exponent_offset = -Scalar((ScalarUI(1)<<(int(ExponentBits)-1)) - ScalarUI(2)); // -126
+ Packet exponent_offset = pset1<Packet>(scalar_exponent_offset);
+ const Packet normalization_offset = pset1<Packet>(-Scalar(scalar_normalization_offset)); // -24
+ exponent_offset = pselect(is_denormal, padd(exponent_offset, normalization_offset), exponent_offset);
+
+ // Determine exponent and mantissa from normalized_a.
+ exponent = pfrexp_generic_get_biased_exponent(normalized_a);
+ // Zero, Inf and NaN return 'a' unmodified, exponent is zero
+ // (technically the exponent is unspecified for inf/NaN, but GCC/Clang set it to zero)
+ const Scalar scalar_non_finite_exponent = Scalar((ScalarUI(1) << int(ExponentBits)) - ScalarUI(1)); // 255
+ const Packet non_finite_exponent = pset1<Packet>(scalar_non_finite_exponent);
+ const Packet is_zero_or_not_finite = por(pcmp_eq(a, zero), pcmp_eq(exponent, non_finite_exponent));
+ const Packet m = pselect(is_zero_or_not_finite, a, por(pand(normalized_a, sign_mantissa_mask), half));
+ exponent = pselect(is_zero_or_not_finite, zero, padd(exponent, exponent_offset));
+ return m;
+}
+
+// Safely applies ldexp, correctly handles overflows, underflows and denormals.
+// Assumes IEEE floating point format.
+template<typename Packet> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
+Packet pldexp_generic(const Packet& a, const Packet& exponent) {
+ // We want to return a * 2^exponent, allowing for all possible integer
+ // exponents without overflowing or underflowing in intermediate
+ // computations.
+ //
+ // Since 'a' and the output can be denormal, the maximum range of 'exponent'
+ // to consider for a float is:
+ // -255-23 -> 255+23
+ // Below -278 any finite float 'a' will become zero, and above +278 any
+ // finite float will become inf, including when 'a' is the smallest possible
+ // denormal.
+ //
+ // Unfortunately, 2^(278) cannot be represented using either one or two
+ // finite normal floats, so we must split the scale factor into at least
+ // three parts. It turns out to be faster to split 'exponent' into four
+ // factors, since [exponent>>2] is much faster to compute that [exponent/3].
+ //
+ // Set e = min(max(exponent, -278), 278);
+ // b = floor(e/4);
+ // out = ((((a * 2^(b)) * 2^(b)) * 2^(b)) * 2^(e-3*b))
+ //
+ // This will avoid any intermediate overflows and correctly handle 0, inf,
+ // NaN cases.
+ typedef typename unpacket_traits<Packet>::integer_packet PacketI;
+ typedef typename unpacket_traits<Packet>::type Scalar;
+ typedef typename unpacket_traits<PacketI>::type ScalarI;
+ enum {
+ TotalBits = sizeof(Scalar) * CHAR_BIT,
+ MantissaBits = numext::numeric_limits<Scalar>::digits - 1,
+ ExponentBits = int(TotalBits) - int(MantissaBits) - 1
+ };
+
+ const Packet max_exponent = pset1<Packet>(Scalar((ScalarI(1)<<int(ExponentBits)) + ScalarI(int(MantissaBits) - 1))); // 278
+ const PacketI bias = pset1<PacketI>((ScalarI(1)<<(int(ExponentBits)-1)) - ScalarI(1)); // 127
+ const PacketI e = pcast<Packet, PacketI>(pmin(pmax(exponent, pnegate(max_exponent)), max_exponent));
+ PacketI b = parithmetic_shift_right<2>(e); // floor(e/4);
+ Packet c = preinterpret<Packet>(plogical_shift_left<int(MantissaBits)>(padd(b, bias))); // 2^b
+ Packet out = pmul(pmul(pmul(a, c), c), c); // a * 2^(3b)
+ b = psub(psub(psub(e, b), b), b); // e - 3b
+ c = preinterpret<Packet>(plogical_shift_left<int(MantissaBits)>(padd(b, bias))); // 2^(e-3*b)
+ out = pmul(out, c);
+ return out;
+}
+
+// Explicitly multiplies
+// a * (2^e)
+// clamping e to the range
+// [NumTraits<Scalar>::min_exponent()-2, NumTraits<Scalar>::max_exponent()]
+//
+// This is approx 7x faster than pldexp_impl, but will prematurely over/underflow
+// if 2^e doesn't fit into a normal floating-point Scalar.
+//
+// Assumes IEEE floating point format
+template<typename Packet>
+struct pldexp_fast_impl {
+ typedef typename unpacket_traits<Packet>::integer_packet PacketI;
+ typedef typename unpacket_traits<Packet>::type Scalar;
+ typedef typename unpacket_traits<PacketI>::type ScalarI;
+ enum {
+ TotalBits = sizeof(Scalar) * CHAR_BIT,
+ MantissaBits = numext::numeric_limits<Scalar>::digits - 1,
+ ExponentBits = int(TotalBits) - int(MantissaBits) - 1
+ };
+
+ static EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
+ Packet run(const Packet& a, const Packet& exponent) {
+ const Packet bias = pset1<Packet>(Scalar((ScalarI(1)<<(int(ExponentBits)-1)) - ScalarI(1))); // 127
+ const Packet limit = pset1<Packet>(Scalar((ScalarI(1)<<int(ExponentBits)) - ScalarI(1))); // 255
+ // restrict biased exponent between 0 and 255 for float.
+ const PacketI e = pcast<Packet, PacketI>(pmin(pmax(padd(exponent, bias), pzero(limit)), limit)); // exponent + 127
+ // return a * (2^e)
+ return pmul(a, preinterpret<Packet>(plogical_shift_left<int(MantissaBits)>(e)));
+ }
+};
+
+// Natural or base 2 logarithm.
+// Computes log(x) as log(2^e * m) = C*e + log(m), where the constant C =log(2)
+// and m is in the range [sqrt(1/2),sqrt(2)). In this range, the logarithm can
+// be easily approximated by a polynomial centered on m=1 for stability.
+// TODO(gonnet): Further reduce the interval allowing for lower-degree
+// polynomial interpolants -> ... -> profit!
+template <typename Packet, bool base2>
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
+EIGEN_UNUSED
+Packet plog_impl_float(const Packet _x)
+{
+ Packet x = _x;
+
+ const Packet cst_1 = pset1<Packet>(1.0f);
+ const Packet cst_neg_half = pset1<Packet>(-0.5f);
+ // The smallest non denormalized float number.
+ const Packet cst_min_norm_pos = pset1frombits<Packet>( 0x00800000u);
+ const Packet cst_minus_inf = pset1frombits<Packet>( 0xff800000u);
+ const Packet cst_pos_inf = pset1frombits<Packet>( 0x7f800000u);
+
+ // Polynomial coefficients.
+ const Packet cst_cephes_SQRTHF = pset1<Packet>(0.707106781186547524f);
+ const Packet cst_cephes_log_p0 = pset1<Packet>(7.0376836292E-2f);
+ const Packet cst_cephes_log_p1 = pset1<Packet>(-1.1514610310E-1f);
+ const Packet cst_cephes_log_p2 = pset1<Packet>(1.1676998740E-1f);
+ const Packet cst_cephes_log_p3 = pset1<Packet>(-1.2420140846E-1f);
+ const Packet cst_cephes_log_p4 = pset1<Packet>(+1.4249322787E-1f);
+ const Packet cst_cephes_log_p5 = pset1<Packet>(-1.6668057665E-1f);
+ const Packet cst_cephes_log_p6 = pset1<Packet>(+2.0000714765E-1f);
+ const Packet cst_cephes_log_p7 = pset1<Packet>(-2.4999993993E-1f);
+ const Packet cst_cephes_log_p8 = pset1<Packet>(+3.3333331174E-1f);
+
+ // Truncate input values to the minimum positive normal.
+ x = pmax(x, cst_min_norm_pos);
+
+ Packet e;
+ // extract significant in the range [0.5,1) and exponent
+ x = pfrexp(x,e);
+
+ // part2: Shift the inputs from the range [0.5,1) to [sqrt(1/2),sqrt(2))
+ // and shift by -1. The values are then centered around 0, which improves
+ // the stability of the polynomial evaluation.
+ // if( x < SQRTHF ) {
+ // e -= 1;
+ // x = x + x - 1.0;
+ // } else { x = x - 1.0; }
+ Packet mask = pcmp_lt(x, cst_cephes_SQRTHF);
+ Packet tmp = pand(x, mask);
+ x = psub(x, cst_1);
+ e = psub(e, pand(cst_1, mask));
+ x = padd(x, tmp);
+
+ Packet x2 = pmul(x, x);
+ Packet x3 = pmul(x2, x);
+
+ // Evaluate the polynomial approximant of degree 8 in three parts, probably
+ // to improve instruction-level parallelism.
+ Packet y, y1, y2;
+ y = pmadd(cst_cephes_log_p0, x, cst_cephes_log_p1);
+ y1 = pmadd(cst_cephes_log_p3, x, cst_cephes_log_p4);
+ y2 = pmadd(cst_cephes_log_p6, x, cst_cephes_log_p7);
+ y = pmadd(y, x, cst_cephes_log_p2);
+ y1 = pmadd(y1, x, cst_cephes_log_p5);
+ y2 = pmadd(y2, x, cst_cephes_log_p8);
+ y = pmadd(y, x3, y1);
+ y = pmadd(y, x3, y2);
+ y = pmul(y, x3);
+
+ y = pmadd(cst_neg_half, x2, y);
+ x = padd(x, y);
+
+ // Add the logarithm of the exponent back to the result of the interpolation.
+ if (base2) {
+ const Packet cst_log2e = pset1<Packet>(static_cast<float>(EIGEN_LOG2E));
+ x = pmadd(x, cst_log2e, e);
+ } else {
+ const Packet cst_ln2 = pset1<Packet>(static_cast<float>(EIGEN_LN2));
+ x = pmadd(e, cst_ln2, x);
+ }
+
+ Packet invalid_mask = pcmp_lt_or_nan(_x, pzero(_x));
+ Packet iszero_mask = pcmp_eq(_x,pzero(_x));
+ Packet pos_inf_mask = pcmp_eq(_x,cst_pos_inf);
+ // Filter out invalid inputs, i.e.:
+ // - negative arg will be NAN
+ // - 0 will be -INF
+ // - +INF will be +INF
+ return pselect(iszero_mask, cst_minus_inf,
+ por(pselect(pos_inf_mask,cst_pos_inf,x), invalid_mask));
+}
+
+template <typename Packet>
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
+EIGEN_UNUSED
+Packet plog_float(const Packet _x)
+{
+ return plog_impl_float<Packet, /* base2 */ false>(_x);
+}
+
+template <typename Packet>
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
+EIGEN_UNUSED
+Packet plog2_float(const Packet _x)
+{
+ return plog_impl_float<Packet, /* base2 */ true>(_x);
+}
+
+/* Returns the base e (2.718...) or base 2 logarithm of x.
+ * The argument is separated into its exponent and fractional parts.
+ * The logarithm of the fraction in the interval [sqrt(1/2), sqrt(2)],
+ * is approximated by
+ *
+ * log(1+x) = x - 0.5 x**2 + x**3 P(x)/Q(x).
+ *
+ * for more detail see: http://www.netlib.org/cephes/
+ */
+template <typename Packet, bool base2>
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
+EIGEN_UNUSED
+Packet plog_impl_double(const Packet _x)
+{
+ Packet x = _x;
+
+ const Packet cst_1 = pset1<Packet>(1.0);
+ const Packet cst_neg_half = pset1<Packet>(-0.5);
+ // The smallest non denormalized double.
+ const Packet cst_min_norm_pos = pset1frombits<Packet>( static_cast<uint64_t>(0x0010000000000000ull));
+ const Packet cst_minus_inf = pset1frombits<Packet>( static_cast<uint64_t>(0xfff0000000000000ull));
+ const Packet cst_pos_inf = pset1frombits<Packet>( static_cast<uint64_t>(0x7ff0000000000000ull));
+
+
+ // Polynomial Coefficients for log(1+x) = x - x**2/2 + x**3 P(x)/Q(x)
+ // 1/sqrt(2) <= x < sqrt(2)
+ const Packet cst_cephes_SQRTHF = pset1<Packet>(0.70710678118654752440E0);
+ const Packet cst_cephes_log_p0 = pset1<Packet>(1.01875663804580931796E-4);
+ const Packet cst_cephes_log_p1 = pset1<Packet>(4.97494994976747001425E-1);
+ const Packet cst_cephes_log_p2 = pset1<Packet>(4.70579119878881725854E0);
+ const Packet cst_cephes_log_p3 = pset1<Packet>(1.44989225341610930846E1);
+ const Packet cst_cephes_log_p4 = pset1<Packet>(1.79368678507819816313E1);
+ const Packet cst_cephes_log_p5 = pset1<Packet>(7.70838733755885391666E0);
+
+ const Packet cst_cephes_log_q0 = pset1<Packet>(1.0);
+ const Packet cst_cephes_log_q1 = pset1<Packet>(1.12873587189167450590E1);
+ const Packet cst_cephes_log_q2 = pset1<Packet>(4.52279145837532221105E1);
+ const Packet cst_cephes_log_q3 = pset1<Packet>(8.29875266912776603211E1);
+ const Packet cst_cephes_log_q4 = pset1<Packet>(7.11544750618563894466E1);
+ const Packet cst_cephes_log_q5 = pset1<Packet>(2.31251620126765340583E1);
+
+ // Truncate input values to the minimum positive normal.
+ x = pmax(x, cst_min_norm_pos);
+
+ Packet e;
+ // extract significant in the range [0.5,1) and exponent
+ x = pfrexp(x,e);
+
+ // Shift the inputs from the range [0.5,1) to [sqrt(1/2),sqrt(2))
+ // and shift by -1. The values are then centered around 0, which improves
+ // the stability of the polynomial evaluation.
+ // if( x < SQRTHF ) {
+ // e -= 1;
+ // x = x + x - 1.0;
+ // } else { x = x - 1.0; }
+ Packet mask = pcmp_lt(x, cst_cephes_SQRTHF);
+ Packet tmp = pand(x, mask);
+ x = psub(x, cst_1);
+ e = psub(e, pand(cst_1, mask));
+ x = padd(x, tmp);
+
+ Packet x2 = pmul(x, x);
+ Packet x3 = pmul(x2, x);
+
+ // Evaluate the polynomial approximant , probably to improve instruction-level parallelism.
+ // y = x - 0.5*x^2 + x^3 * polevl( x, P, 5 ) / p1evl( x, Q, 5 ) );
+ Packet y, y1, y_;
+ y = pmadd(cst_cephes_log_p0, x, cst_cephes_log_p1);
+ y1 = pmadd(cst_cephes_log_p3, x, cst_cephes_log_p4);
+ y = pmadd(y, x, cst_cephes_log_p2);
+ y1 = pmadd(y1, x, cst_cephes_log_p5);
+ y_ = pmadd(y, x3, y1);
+
+ y = pmadd(cst_cephes_log_q0, x, cst_cephes_log_q1);
+ y1 = pmadd(cst_cephes_log_q3, x, cst_cephes_log_q4);
+ y = pmadd(y, x, cst_cephes_log_q2);
+ y1 = pmadd(y1, x, cst_cephes_log_q5);
+ y = pmadd(y, x3, y1);
+
+ y_ = pmul(y_, x3);
+ y = pdiv(y_, y);
+
+ y = pmadd(cst_neg_half, x2, y);
+ x = padd(x, y);
+
+ // Add the logarithm of the exponent back to the result of the interpolation.
+ if (base2) {
+ const Packet cst_log2e = pset1<Packet>(static_cast<double>(EIGEN_LOG2E));
+ x = pmadd(x, cst_log2e, e);
+ } else {
+ const Packet cst_ln2 = pset1<Packet>(static_cast<double>(EIGEN_LN2));
+ x = pmadd(e, cst_ln2, x);
+ }
+
+ Packet invalid_mask = pcmp_lt_or_nan(_x, pzero(_x));
+ Packet iszero_mask = pcmp_eq(_x,pzero(_x));
+ Packet pos_inf_mask = pcmp_eq(_x,cst_pos_inf);
+ // Filter out invalid inputs, i.e.:
+ // - negative arg will be NAN
+ // - 0 will be -INF
+ // - +INF will be +INF
+ return pselect(iszero_mask, cst_minus_inf,
+ por(pselect(pos_inf_mask,cst_pos_inf,x), invalid_mask));
+}
+
+template <typename Packet>
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
+EIGEN_UNUSED
+Packet plog_double(const Packet _x)
+{
+ return plog_impl_double<Packet, /* base2 */ false>(_x);
+}
+
+template <typename Packet>
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
+EIGEN_UNUSED
+Packet plog2_double(const Packet _x)
+{
+ return plog_impl_double<Packet, /* base2 */ true>(_x);
+}
+
+/** \internal \returns log(1 + x) computed using W. Kahan's formula.
+ See: http://www.plunk.org/~hatch/rightway.php
+ */
+template<typename Packet>
+Packet generic_plog1p(const Packet& x)
+{
+ typedef typename unpacket_traits<Packet>::type ScalarType;
+ const Packet one = pset1<Packet>(ScalarType(1));
+ Packet xp1 = padd(x, one);
+ Packet small_mask = pcmp_eq(xp1, one);
+ Packet log1 = plog(xp1);
+ Packet inf_mask = pcmp_eq(xp1, log1);
+ Packet log_large = pmul(x, pdiv(log1, psub(xp1, one)));
+ return pselect(por(small_mask, inf_mask), x, log_large);
+}
+
+/** \internal \returns exp(x)-1 computed using W. Kahan's formula.
+ See: http://www.plunk.org/~hatch/rightway.php
+ */
+template<typename Packet>
+Packet generic_expm1(const Packet& x)
+{
+ typedef typename unpacket_traits<Packet>::type ScalarType;
+ const Packet one = pset1<Packet>(ScalarType(1));
+ const Packet neg_one = pset1<Packet>(ScalarType(-1));
+ Packet u = pexp(x);
+ Packet one_mask = pcmp_eq(u, one);
+ Packet u_minus_one = psub(u, one);
+ Packet neg_one_mask = pcmp_eq(u_minus_one, neg_one);
+ Packet logu = plog(u);
+ // The following comparison is to catch the case where
+ // exp(x) = +inf. It is written in this way to avoid having
+ // to form the constant +inf, which depends on the packet
+ // type.
+ Packet pos_inf_mask = pcmp_eq(logu, u);
+ Packet expm1 = pmul(u_minus_one, pdiv(x, logu));
+ expm1 = pselect(pos_inf_mask, u, expm1);
+ return pselect(one_mask,
+ x,
+ pselect(neg_one_mask,
+ neg_one,
+ expm1));
+}
+
+
+// Exponential function. Works by writing "x = m*log(2) + r" where
+// "m = floor(x/log(2)+1/2)" and "r" is the remainder. The result is then
+// "exp(x) = 2^m*exp(r)" where exp(r) is in the range [-1,1).
+template <typename Packet>
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
+EIGEN_UNUSED
+Packet pexp_float(const Packet _x)
+{
+ const Packet cst_1 = pset1<Packet>(1.0f);
+ const Packet cst_half = pset1<Packet>(0.5f);
+ const Packet cst_exp_hi = pset1<Packet>( 88.723f);
+ const Packet cst_exp_lo = pset1<Packet>(-88.723f);
+
+ const Packet cst_cephes_LOG2EF = pset1<Packet>(1.44269504088896341f);
+ const Packet cst_cephes_exp_p0 = pset1<Packet>(1.9875691500E-4f);
+ const Packet cst_cephes_exp_p1 = pset1<Packet>(1.3981999507E-3f);
+ const Packet cst_cephes_exp_p2 = pset1<Packet>(8.3334519073E-3f);
+ const Packet cst_cephes_exp_p3 = pset1<Packet>(4.1665795894E-2f);
+ const Packet cst_cephes_exp_p4 = pset1<Packet>(1.6666665459E-1f);
+ const Packet cst_cephes_exp_p5 = pset1<Packet>(5.0000001201E-1f);
+
+ // Clamp x.
+ Packet x = pmax(pmin(_x, cst_exp_hi), cst_exp_lo);
+
+ // Express exp(x) as exp(m*ln(2) + r), start by extracting
+ // m = floor(x/ln(2) + 0.5).
+ Packet m = pfloor(pmadd(x, cst_cephes_LOG2EF, cst_half));
+
+ // Get r = x - m*ln(2). If no FMA instructions are available, m*ln(2) is
+ // subtracted out in two parts, m*C1+m*C2 = m*ln(2), to avoid accumulating
+ // truncation errors.
+ const Packet cst_cephes_exp_C1 = pset1<Packet>(-0.693359375f);
+ const Packet cst_cephes_exp_C2 = pset1<Packet>(2.12194440e-4f);
+ Packet r = pmadd(m, cst_cephes_exp_C1, x);
+ r = pmadd(m, cst_cephes_exp_C2, r);
+
+ Packet r2 = pmul(r, r);
+ Packet r3 = pmul(r2, r);
+
+ // Evaluate the polynomial approximant,improved by instruction-level parallelism.
+ Packet y, y1, y2;
+ y = pmadd(cst_cephes_exp_p0, r, cst_cephes_exp_p1);
+ y1 = pmadd(cst_cephes_exp_p3, r, cst_cephes_exp_p4);
+ y2 = padd(r, cst_1);
+ y = pmadd(y, r, cst_cephes_exp_p2);
+ y1 = pmadd(y1, r, cst_cephes_exp_p5);
+ y = pmadd(y, r3, y1);
+ y = pmadd(y, r2, y2);
+
+ // Return 2^m * exp(r).
+ // TODO: replace pldexp with faster implementation since y in [-1, 1).
+ return pmax(pldexp(y,m), _x);
+}
+
+template <typename Packet>
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
+EIGEN_UNUSED
+Packet pexp_double(const Packet _x)
+{
+ Packet x = _x;
+
+ const Packet cst_1 = pset1<Packet>(1.0);
+ const Packet cst_2 = pset1<Packet>(2.0);
+ const Packet cst_half = pset1<Packet>(0.5);
+
+ const Packet cst_exp_hi = pset1<Packet>(709.784);
+ const Packet cst_exp_lo = pset1<Packet>(-709.784);
+
+ const Packet cst_cephes_LOG2EF = pset1<Packet>(1.4426950408889634073599);
+ const Packet cst_cephes_exp_p0 = pset1<Packet>(1.26177193074810590878e-4);
+ const Packet cst_cephes_exp_p1 = pset1<Packet>(3.02994407707441961300e-2);
+ const Packet cst_cephes_exp_p2 = pset1<Packet>(9.99999999999999999910e-1);
+ const Packet cst_cephes_exp_q0 = pset1<Packet>(3.00198505138664455042e-6);
+ const Packet cst_cephes_exp_q1 = pset1<Packet>(2.52448340349684104192e-3);
+ const Packet cst_cephes_exp_q2 = pset1<Packet>(2.27265548208155028766e-1);
+ const Packet cst_cephes_exp_q3 = pset1<Packet>(2.00000000000000000009e0);
+ const Packet cst_cephes_exp_C1 = pset1<Packet>(0.693145751953125);
+ const Packet cst_cephes_exp_C2 = pset1<Packet>(1.42860682030941723212e-6);
+
+ Packet tmp, fx;
+
+ // clamp x
+ x = pmax(pmin(x, cst_exp_hi), cst_exp_lo);
+ // Express exp(x) as exp(g + n*log(2)).
+ fx = pmadd(cst_cephes_LOG2EF, x, cst_half);
+
+ // Get the integer modulus of log(2), i.e. the "n" described above.
+ fx = pfloor(fx);
+
+ // Get the remainder modulo log(2), i.e. the "g" described above. Subtract
+ // n*log(2) out in two steps, i.e. n*C1 + n*C2, C1+C2=log2 to get the last
+ // digits right.
+ tmp = pmul(fx, cst_cephes_exp_C1);
+ Packet z = pmul(fx, cst_cephes_exp_C2);
+ x = psub(x, tmp);
+ x = psub(x, z);
+
+ Packet x2 = pmul(x, x);
+
+ // Evaluate the numerator polynomial of the rational interpolant.
+ Packet px = cst_cephes_exp_p0;
+ px = pmadd(px, x2, cst_cephes_exp_p1);
+ px = pmadd(px, x2, cst_cephes_exp_p2);
+ px = pmul(px, x);
+
+ // Evaluate the denominator polynomial of the rational interpolant.
+ Packet qx = cst_cephes_exp_q0;
+ qx = pmadd(qx, x2, cst_cephes_exp_q1);
+ qx = pmadd(qx, x2, cst_cephes_exp_q2);
+ qx = pmadd(qx, x2, cst_cephes_exp_q3);
+
+ // I don't really get this bit, copied from the SSE2 routines, so...
+ // TODO(gonnet): Figure out what is going on here, perhaps find a better
+ // rational interpolant?
+ x = pdiv(px, psub(qx, px));
+ x = pmadd(cst_2, x, cst_1);
+
+ // Construct the result 2^n * exp(g) = e * x. The max is used to catch
+ // non-finite values in the input.
+ // TODO: replace pldexp with faster implementation since x in [-1, 1).
+ return pmax(pldexp(x,fx), _x);
+}
+
+// The following code is inspired by the following stack-overflow answer:
+// https://stackoverflow.com/questions/30463616/payne-hanek-algorithm-implementation-in-c/30465751#30465751
+// It has been largely optimized:
+// - By-pass calls to frexp.
+// - Aligned loads of required 96 bits of 2/pi. This is accomplished by
+// (1) balancing the mantissa and exponent to the required bits of 2/pi are
+// aligned on 8-bits, and (2) replicating the storage of the bits of 2/pi.
+// - Avoid a branch in rounding and extraction of the remaining fractional part.
+// Overall, I measured a speed up higher than x2 on x86-64.
+inline float trig_reduce_huge (float xf, int *quadrant)
+{
+ using Eigen::numext::int32_t;
+ using Eigen::numext::uint32_t;
+ using Eigen::numext::int64_t;
+ using Eigen::numext::uint64_t;
+
+ const double pio2_62 = 3.4061215800865545e-19; // pi/2 * 2^-62
+ const uint64_t zero_dot_five = uint64_t(1) << 61; // 0.5 in 2.62-bit fixed-point foramt
+
+ // 192 bits of 2/pi for Payne-Hanek reduction
+ // Bits are introduced by packet of 8 to enable aligned reads.
+ static const uint32_t two_over_pi [] =
+ {
+ 0x00000028, 0x000028be, 0x0028be60, 0x28be60db,
+ 0xbe60db93, 0x60db9391, 0xdb939105, 0x9391054a,
+ 0x91054a7f, 0x054a7f09, 0x4a7f09d5, 0x7f09d5f4,
+ 0x09d5f47d, 0xd5f47d4d, 0xf47d4d37, 0x7d4d3770,
+ 0x4d377036, 0x377036d8, 0x7036d8a5, 0x36d8a566,
+ 0xd8a5664f, 0xa5664f10, 0x664f10e4, 0x4f10e410,
+ 0x10e41000, 0xe4100000
+ };
+
+ uint32_t xi = numext::bit_cast<uint32_t>(xf);
+ // Below, -118 = -126 + 8.
+ // -126 is to get the exponent,
+ // +8 is to enable alignment of 2/pi's bits on 8 bits.
+ // This is possible because the fractional part of x as only 24 meaningful bits.
+ uint32_t e = (xi >> 23) - 118;
+ // Extract the mantissa and shift it to align it wrt the exponent
+ xi = ((xi & 0x007fffffu)| 0x00800000u) << (e & 0x7);
+
+ uint32_t i = e >> 3;
+ uint32_t twoopi_1 = two_over_pi[i-1];
+ uint32_t twoopi_2 = two_over_pi[i+3];
+ uint32_t twoopi_3 = two_over_pi[i+7];
+
+ // Compute x * 2/pi in 2.62-bit fixed-point format.
+ uint64_t p;
+ p = uint64_t(xi) * twoopi_3;
+ p = uint64_t(xi) * twoopi_2 + (p >> 32);
+ p = (uint64_t(xi * twoopi_1) << 32) + p;
+
+ // Round to nearest: add 0.5 and extract integral part.
+ uint64_t q = (p + zero_dot_five) >> 62;
+ *quadrant = int(q);
+ // Now it remains to compute "r = x - q*pi/2" with high accuracy,
+ // since we have p=x/(pi/2) with high accuracy, we can more efficiently compute r as:
+ // r = (p-q)*pi/2,
+ // where the product can be be carried out with sufficient accuracy using double precision.
+ p -= q<<62;
+ return float(double(int64_t(p)) * pio2_62);
+}
+
+template<bool ComputeSine,typename Packet>
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
+EIGEN_UNUSED
+#if EIGEN_GNUC_AT_LEAST(4,4) && EIGEN_COMP_GNUC_STRICT
+__attribute__((optimize("-fno-unsafe-math-optimizations")))
+#endif
+Packet psincos_float(const Packet& _x)
+{
+ typedef typename unpacket_traits<Packet>::integer_packet PacketI;
+
+ const Packet cst_2oPI = pset1<Packet>(0.636619746685028076171875f); // 2/PI
+ const Packet cst_rounding_magic = pset1<Packet>(12582912); // 2^23 for rounding
+ const PacketI csti_1 = pset1<PacketI>(1);
+ const Packet cst_sign_mask = pset1frombits<Packet>(0x80000000u);
+
+ Packet x = pabs(_x);
+
+ // Scale x by 2/Pi to find x's octant.
+ Packet y = pmul(x, cst_2oPI);
+
+ // Rounding trick:
+ Packet y_round = padd(y, cst_rounding_magic);
+ EIGEN_OPTIMIZATION_BARRIER(y_round)
+ PacketI y_int = preinterpret<PacketI>(y_round); // last 23 digits represent integer (if abs(x)<2^24)
+ y = psub(y_round, cst_rounding_magic); // nearest integer to x*4/pi
+
+ // Reduce x by y octants to get: -Pi/4 <= x <= +Pi/4
+ // using "Extended precision modular arithmetic"
+ #if defined(EIGEN_HAS_SINGLE_INSTRUCTION_MADD)
+ // This version requires true FMA for high accuracy
+ // It provides a max error of 1ULP up to (with absolute_error < 5.9605e-08):
+ const float huge_th = ComputeSine ? 117435.992f : 71476.0625f;
+ x = pmadd(y, pset1<Packet>(-1.57079601287841796875f), x);
+ x = pmadd(y, pset1<Packet>(-3.1391647326017846353352069854736328125e-07f), x);
+ x = pmadd(y, pset1<Packet>(-5.390302529957764765544681040410068817436695098876953125e-15f), x);
+ #else
+ // Without true FMA, the previous set of coefficients maintain 1ULP accuracy
+ // up to x<15.7 (for sin), but accuracy is immediately lost for x>15.7.
+ // We thus use one more iteration to maintain 2ULPs up to reasonably large inputs.
+
+ // The following set of coefficients maintain 1ULP up to 9.43 and 14.16 for sin and cos respectively.
+ // and 2 ULP up to:
+ const float huge_th = ComputeSine ? 25966.f : 18838.f;
+ x = pmadd(y, pset1<Packet>(-1.5703125), x); // = 0xbfc90000
+ EIGEN_OPTIMIZATION_BARRIER(x)
+ x = pmadd(y, pset1<Packet>(-0.000483989715576171875), x); // = 0xb9fdc000
+ EIGEN_OPTIMIZATION_BARRIER(x)
+ x = pmadd(y, pset1<Packet>(1.62865035235881805419921875e-07), x); // = 0x342ee000
+ x = pmadd(y, pset1<Packet>(5.5644315544167710640977020375430583953857421875e-11), x); // = 0x2e74b9ee
+
+ // For the record, the following set of coefficients maintain 2ULP up
+ // to a slightly larger range:
+ // const float huge_th = ComputeSine ? 51981.f : 39086.125f;
+ // but it slightly fails to maintain 1ULP for two values of sin below pi.
+ // x = pmadd(y, pset1<Packet>(-3.140625/2.), x);
+ // x = pmadd(y, pset1<Packet>(-0.00048351287841796875), x);
+ // x = pmadd(y, pset1<Packet>(-3.13855707645416259765625e-07), x);
+ // x = pmadd(y, pset1<Packet>(-6.0771006282767103812147979624569416046142578125e-11), x);
+
+ // For the record, with only 3 iterations it is possible to maintain
+ // 1 ULP up to 3PI (maybe more) and 2ULP up to 255.
+ // The coefficients are: 0xbfc90f80, 0xb7354480, 0x2e74b9ee
+ #endif
+
+ if(predux_any(pcmp_le(pset1<Packet>(huge_th),pabs(_x))))
+ {
+ const int PacketSize = unpacket_traits<Packet>::size;
+ EIGEN_ALIGN_TO_BOUNDARY(sizeof(Packet)) float vals[PacketSize];
+ EIGEN_ALIGN_TO_BOUNDARY(sizeof(Packet)) float x_cpy[PacketSize];
+ EIGEN_ALIGN_TO_BOUNDARY(sizeof(Packet)) int y_int2[PacketSize];
+ pstoreu(vals, pabs(_x));
+ pstoreu(x_cpy, x);
+ pstoreu(y_int2, y_int);
+ for(int k=0; k<PacketSize;++k)
+ {
+ float val = vals[k];
+ if(val>=huge_th && (numext::isfinite)(val))
+ x_cpy[k] = trig_reduce_huge(val,&y_int2[k]);
+ }
+ x = ploadu<Packet>(x_cpy);
+ y_int = ploadu<PacketI>(y_int2);
+ }
+
+ // Compute the sign to apply to the polynomial.
+ // sin: sign = second_bit(y_int) xor signbit(_x)
+ // cos: sign = second_bit(y_int+1)
+ Packet sign_bit = ComputeSine ? pxor(_x, preinterpret<Packet>(plogical_shift_left<30>(y_int)))
+ : preinterpret<Packet>(plogical_shift_left<30>(padd(y_int,csti_1)));
+ sign_bit = pand(sign_bit, cst_sign_mask); // clear all but left most bit
+
+ // Get the polynomial selection mask from the second bit of y_int
+ // We'll calculate both (sin and cos) polynomials and then select from the two.
+ Packet poly_mask = preinterpret<Packet>(pcmp_eq(pand(y_int, csti_1), pzero(y_int)));
+
+ Packet x2 = pmul(x,x);
+
+ // Evaluate the cos(x) polynomial. (-Pi/4 <= x <= Pi/4)
+ Packet y1 = pset1<Packet>(2.4372266125283204019069671630859375e-05f);
+ y1 = pmadd(y1, x2, pset1<Packet>(-0.00138865201734006404876708984375f ));
+ y1 = pmadd(y1, x2, pset1<Packet>(0.041666619479656219482421875f ));
+ y1 = pmadd(y1, x2, pset1<Packet>(-0.5f));
+ y1 = pmadd(y1, x2, pset1<Packet>(1.f));
+
+ // Evaluate the sin(x) polynomial. (Pi/4 <= x <= Pi/4)
+ // octave/matlab code to compute those coefficients:
+ // x = (0:0.0001:pi/4)';
+ // A = [x.^3 x.^5 x.^7];
+ // w = ((1.-(x/(pi/4)).^2).^5)*2000+1; # weights trading relative accuracy
+ // c = (A'*diag(w)*A)\(A'*diag(w)*(sin(x)-x)); # weighted LS, linear coeff forced to 1
+ // printf('%.64f\n %.64f\n%.64f\n', c(3), c(2), c(1))
+ //
+ Packet y2 = pset1<Packet>(-0.0001959234114083702898469196984621021329076029360294342041015625f);
+ y2 = pmadd(y2, x2, pset1<Packet>( 0.0083326873655616851693794799871284340042620897293090820312500000f));
+ y2 = pmadd(y2, x2, pset1<Packet>(-0.1666666203982298255503735617821803316473960876464843750000000000f));
+ y2 = pmul(y2, x2);
+ y2 = pmadd(y2, x, x);
+
+ // Select the correct result from the two polynomials.
+ y = ComputeSine ? pselect(poly_mask,y2,y1)
+ : pselect(poly_mask,y1,y2);
+
+ // Update the sign and filter huge inputs
+ return pxor(y, sign_bit);
+}
+
+template<typename Packet>
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
+EIGEN_UNUSED
+Packet psin_float(const Packet& x)
+{
+ return psincos_float<true>(x);
+}
+
+template<typename Packet>
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
+EIGEN_UNUSED
+Packet pcos_float(const Packet& x)
+{
+ return psincos_float<false>(x);
+}
+
+
+template<typename Packet>
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
+EIGEN_UNUSED
+Packet psqrt_complex(const Packet& a) {
+ typedef typename unpacket_traits<Packet>::type Scalar;
+ typedef typename Scalar::value_type RealScalar;
+ typedef typename unpacket_traits<Packet>::as_real RealPacket;
+
+ // Computes the principal sqrt of the complex numbers in the input.
+ //
+ // For example, for packets containing 2 complex numbers stored in interleaved format
+ // a = [a0, a1] = [x0, y0, x1, y1],
+ // where x0 = real(a0), y0 = imag(a0) etc., this function returns
+ // b = [b0, b1] = [u0, v0, u1, v1],
+ // such that b0^2 = a0, b1^2 = a1.
+ //
+ // To derive the formula for the complex square roots, let's consider the equation for
+ // a single complex square root of the number x + i*y. We want to find real numbers
+ // u and v such that
+ // (u + i*v)^2 = x + i*y <=>
+ // u^2 - v^2 + i*2*u*v = x + i*v.
+ // By equating the real and imaginary parts we get:
+ // u^2 - v^2 = x
+ // 2*u*v = y.
+ //
+ // For x >= 0, this has the numerically stable solution
+ // u = sqrt(0.5 * (x + sqrt(x^2 + y^2)))
+ // v = 0.5 * (y / u)
+ // and for x < 0,
+ // v = sign(y) * sqrt(0.5 * (-x + sqrt(x^2 + y^2)))
+ // u = 0.5 * (y / v)
+ //
+ // To avoid unnecessary over- and underflow, we compute sqrt(x^2 + y^2) as
+ // l = max(|x|, |y|) * sqrt(1 + (min(|x|, |y|) / max(|x|, |y|))^2) ,
+
+ // In the following, without lack of generality, we have annotated the code, assuming
+ // that the input is a packet of 2 complex numbers.
+ //
+ // Step 1. Compute l = [l0, l0, l1, l1], where
+ // l0 = sqrt(x0^2 + y0^2), l1 = sqrt(x1^2 + y1^2)
+ // To avoid over- and underflow, we use the stable formula for each hypotenuse
+ // l0 = (min0 == 0 ? max0 : max0 * sqrt(1 + (min0/max0)**2)),
+ // where max0 = max(|x0|, |y0|), min0 = min(|x0|, |y0|), and similarly for l1.
+
+ RealPacket a_abs = pabs(a.v); // [|x0|, |y0|, |x1|, |y1|]
+ RealPacket a_abs_flip = pcplxflip(Packet(a_abs)).v; // [|y0|, |x0|, |y1|, |x1|]
+ RealPacket a_max = pmax(a_abs, a_abs_flip);
+ RealPacket a_min = pmin(a_abs, a_abs_flip);
+ RealPacket a_min_zero_mask = pcmp_eq(a_min, pzero(a_min));
+ RealPacket a_max_zero_mask = pcmp_eq(a_max, pzero(a_max));
+ RealPacket r = pdiv(a_min, a_max);
+ const RealPacket cst_one = pset1<RealPacket>(RealScalar(1));
+ RealPacket l = pmul(a_max, psqrt(padd(cst_one, pmul(r, r)))); // [l0, l0, l1, l1]
+ // Set l to a_max if a_min is zero.
+ l = pselect(a_min_zero_mask, a_max, l);
+
+ // Step 2. Compute [rho0, *, rho1, *], where
+ // rho0 = sqrt(0.5 * (l0 + |x0|)), rho1 = sqrt(0.5 * (l1 + |x1|))
+ // We don't care about the imaginary parts computed here. They will be overwritten later.
+ const RealPacket cst_half = pset1<RealPacket>(RealScalar(0.5));
+ Packet rho;
+ rho.v = psqrt(pmul(cst_half, padd(a_abs, l)));
+
+ // Step 3. Compute [rho0, eta0, rho1, eta1], where
+ // eta0 = (y0 / l0) / 2, and eta1 = (y1 / l1) / 2.
+ // set eta = 0 of input is 0 + i0.
+ RealPacket eta = pandnot(pmul(cst_half, pdiv(a.v, pcplxflip(rho).v)), a_max_zero_mask);
+ RealPacket real_mask = peven_mask(a.v);
+ Packet positive_real_result;
+ // Compute result for inputs with positive real part.
+ positive_real_result.v = pselect(real_mask, rho.v, eta);
+
+ // Step 4. Compute solution for inputs with negative real part:
+ // [|eta0|, sign(y0)*rho0, |eta1|, sign(y1)*rho1]
+ const RealScalar neg_zero = RealScalar(numext::bit_cast<float>(0x80000000u));
+ const RealPacket cst_imag_sign_mask = pset1<Packet>(Scalar(RealScalar(0.0), neg_zero)).v;
+ RealPacket imag_signs = pand(a.v, cst_imag_sign_mask);
+ Packet negative_real_result;
+ // Notice that rho is positive, so taking it's absolute value is a noop.
+ negative_real_result.v = por(pabs(pcplxflip(positive_real_result).v), imag_signs);
+
+ // Step 5. Select solution branch based on the sign of the real parts.
+ Packet negative_real_mask;
+ negative_real_mask.v = pcmp_lt(pand(real_mask, a.v), pzero(a.v));
+ negative_real_mask.v = por(negative_real_mask.v, pcplxflip(negative_real_mask).v);
+ Packet result = pselect(negative_real_mask, negative_real_result, positive_real_result);
+
+ // Step 6. Handle special cases for infinities:
+ // * If z is (x,+∞), the result is (+∞,+∞) even if x is NaN
+ // * If z is (x,-∞), the result is (+∞,-∞) even if x is NaN
+ // * If z is (-∞,y), the result is (0*|y|,+∞) for finite or NaN y
+ // * If z is (+∞,y), the result is (+∞,0*|y|) for finite or NaN y
+ const RealPacket cst_pos_inf = pset1<RealPacket>(NumTraits<RealScalar>::infinity());
+ Packet is_inf;
+ is_inf.v = pcmp_eq(a_abs, cst_pos_inf);
+ Packet is_real_inf;
+ is_real_inf.v = pand(is_inf.v, real_mask);
+ is_real_inf = por(is_real_inf, pcplxflip(is_real_inf));
+ // prepare packet of (+∞,0*|y|) or (0*|y|,+∞), depending on the sign of the infinite real part.
+ Packet real_inf_result;
+ real_inf_result.v = pmul(a_abs, pset1<Packet>(Scalar(RealScalar(1.0), RealScalar(0.0))).v);
+ real_inf_result.v = pselect(negative_real_mask.v, pcplxflip(real_inf_result).v, real_inf_result.v);
+ // prepare packet of (+∞,+∞) or (+∞,-∞), depending on the sign of the infinite imaginary part.
+ Packet is_imag_inf;
+ is_imag_inf.v = pandnot(is_inf.v, real_mask);
+ is_imag_inf = por(is_imag_inf, pcplxflip(is_imag_inf));
+ Packet imag_inf_result;
+ imag_inf_result.v = por(pand(cst_pos_inf, real_mask), pandnot(a.v, real_mask));
+
+ return pselect(is_imag_inf, imag_inf_result,
+ pselect(is_real_inf, real_inf_result,result));
+}
+
+// TODO(rmlarsen): The following set of utilities for double word arithmetic
+// should perhaps be refactored as a separate file, since it would be generally
+// useful for special function implementation etc. Writing the algorithms in
+// terms if a double word type would also make the code more readable.
+
+// This function splits x into the nearest integer n and fractional part r,
+// such that x = n + r holds exactly.
+template<typename Packet>
+EIGEN_STRONG_INLINE
+void absolute_split(const Packet& x, Packet& n, Packet& r) {
+ n = pround(x);
+ r = psub(x, n);
+}
+
+// This function computes the sum {s, r}, such that x + y = s_hi + s_lo
+// holds exactly, and s_hi = fl(x+y), if |x| >= |y|.
+template<typename Packet>
+EIGEN_STRONG_INLINE
+void fast_twosum(const Packet& x, const Packet& y, Packet& s_hi, Packet& s_lo) {
+ s_hi = padd(x, y);
+ const Packet t = psub(s_hi, x);
+ s_lo = psub(y, t);
+}
+
+#ifdef EIGEN_HAS_SINGLE_INSTRUCTION_MADD
+// This function implements the extended precision product of
+// a pair of floating point numbers. Given {x, y}, it computes the pair
+// {p_hi, p_lo} such that x * y = p_hi + p_lo holds exactly and
+// p_hi = fl(x * y).
+template<typename Packet>
+EIGEN_STRONG_INLINE
+void twoprod(const Packet& x, const Packet& y,
+ Packet& p_hi, Packet& p_lo) {
+ p_hi = pmul(x, y);
+ p_lo = pmadd(x, y, pnegate(p_hi));
+}
+
+#else
+
+// This function implements the Veltkamp splitting. Given a floating point
+// number x it returns the pair {x_hi, x_lo} such that x_hi + x_lo = x holds
+// exactly and that half of the significant of x fits in x_hi.
+// This is Algorithm 3 from Jean-Michel Muller, "Elementary Functions",
+// 3rd edition, Birkh\"auser, 2016.
+template<typename Packet>
+EIGEN_STRONG_INLINE
+void veltkamp_splitting(const Packet& x, Packet& x_hi, Packet& x_lo) {
+ typedef typename unpacket_traits<Packet>::type Scalar;
+ EIGEN_CONSTEXPR int shift = (NumTraits<Scalar>::digits() + 1) / 2;
+ const Scalar shift_scale = Scalar(uint64_t(1) << shift); // Scalar constructor not necessarily constexpr.
+ const Packet gamma = pmul(pset1<Packet>(shift_scale + Scalar(1)), x);
+ Packet rho = psub(x, gamma);
+ x_hi = padd(rho, gamma);
+ x_lo = psub(x, x_hi);
+}
+
+// This function implements Dekker's algorithm for products x * y.
+// Given floating point numbers {x, y} computes the pair
+// {p_hi, p_lo} such that x * y = p_hi + p_lo holds exactly and
+// p_hi = fl(x * y).
+template<typename Packet>
+EIGEN_STRONG_INLINE
+void twoprod(const Packet& x, const Packet& y,
+ Packet& p_hi, Packet& p_lo) {
+ Packet x_hi, x_lo, y_hi, y_lo;
+ veltkamp_splitting(x, x_hi, x_lo);
+ veltkamp_splitting(y, y_hi, y_lo);
+
+ p_hi = pmul(x, y);
+ p_lo = pmadd(x_hi, y_hi, pnegate(p_hi));
+ p_lo = pmadd(x_hi, y_lo, p_lo);
+ p_lo = pmadd(x_lo, y_hi, p_lo);
+ p_lo = pmadd(x_lo, y_lo, p_lo);
+}
+
+#endif // EIGEN_HAS_SINGLE_INSTRUCTION_MADD
+
+
+// This function implements Dekker's algorithm for the addition
+// of two double word numbers represented by {x_hi, x_lo} and {y_hi, y_lo}.
+// It returns the result as a pair {s_hi, s_lo} such that
+// x_hi + x_lo + y_hi + y_lo = s_hi + s_lo holds exactly.
+// This is Algorithm 5 from Jean-Michel Muller, "Elementary Functions",
+// 3rd edition, Birkh\"auser, 2016.
+template<typename Packet>
+EIGEN_STRONG_INLINE
+ void twosum(const Packet& x_hi, const Packet& x_lo,
+ const Packet& y_hi, const Packet& y_lo,
+ Packet& s_hi, Packet& s_lo) {
+ const Packet x_greater_mask = pcmp_lt(pabs(y_hi), pabs(x_hi));
+ Packet r_hi_1, r_lo_1;
+ fast_twosum(x_hi, y_hi,r_hi_1, r_lo_1);
+ Packet r_hi_2, r_lo_2;
+ fast_twosum(y_hi, x_hi,r_hi_2, r_lo_2);
+ const Packet r_hi = pselect(x_greater_mask, r_hi_1, r_hi_2);
+
+ const Packet s1 = padd(padd(y_lo, r_lo_1), x_lo);
+ const Packet s2 = padd(padd(x_lo, r_lo_2), y_lo);
+ const Packet s = pselect(x_greater_mask, s1, s2);
+
+ fast_twosum(r_hi, s, s_hi, s_lo);
+}
+
+// This is a version of twosum for double word numbers,
+// which assumes that |x_hi| >= |y_hi|.
+template<typename Packet>
+EIGEN_STRONG_INLINE
+ void fast_twosum(const Packet& x_hi, const Packet& x_lo,
+ const Packet& y_hi, const Packet& y_lo,
+ Packet& s_hi, Packet& s_lo) {
+ Packet r_hi, r_lo;
+ fast_twosum(x_hi, y_hi, r_hi, r_lo);
+ const Packet s = padd(padd(y_lo, r_lo), x_lo);
+ fast_twosum(r_hi, s, s_hi, s_lo);
+}
+
+// This is a version of twosum for adding a floating point number x to
+// double word number {y_hi, y_lo} number, with the assumption
+// that |x| >= |y_hi|.
+template<typename Packet>
+EIGEN_STRONG_INLINE
+void fast_twosum(const Packet& x,
+ const Packet& y_hi, const Packet& y_lo,
+ Packet& s_hi, Packet& s_lo) {
+ Packet r_hi, r_lo;
+ fast_twosum(x, y_hi, r_hi, r_lo);
+ const Packet s = padd(y_lo, r_lo);
+ fast_twosum(r_hi, s, s_hi, s_lo);
+}
+
+// This function implements the multiplication of a double word
+// number represented by {x_hi, x_lo} by a floating point number y.
+// It returns the result as a pair {p_hi, p_lo} such that
+// (x_hi + x_lo) * y = p_hi + p_lo hold with a relative error
+// of less than 2*2^{-2p}, where p is the number of significand bit
+// in the floating point type.
+// This is Algorithm 7 from Jean-Michel Muller, "Elementary Functions",
+// 3rd edition, Birkh\"auser, 2016.
+template<typename Packet>
+EIGEN_STRONG_INLINE
+void twoprod(const Packet& x_hi, const Packet& x_lo, const Packet& y,
+ Packet& p_hi, Packet& p_lo) {
+ Packet c_hi, c_lo1;
+ twoprod(x_hi, y, c_hi, c_lo1);
+ const Packet c_lo2 = pmul(x_lo, y);
+ Packet t_hi, t_lo1;
+ fast_twosum(c_hi, c_lo2, t_hi, t_lo1);
+ const Packet t_lo2 = padd(t_lo1, c_lo1);
+ fast_twosum(t_hi, t_lo2, p_hi, p_lo);
+}
+
+// This function implements the multiplication of two double word
+// numbers represented by {x_hi, x_lo} and {y_hi, y_lo}.
+// It returns the result as a pair {p_hi, p_lo} such that
+// (x_hi + x_lo) * (y_hi + y_lo) = p_hi + p_lo holds with a relative error
+// of less than 2*2^{-2p}, where p is the number of significand bit
+// in the floating point type.
+template<typename Packet>
+EIGEN_STRONG_INLINE
+void twoprod(const Packet& x_hi, const Packet& x_lo,
+ const Packet& y_hi, const Packet& y_lo,
+ Packet& p_hi, Packet& p_lo) {
+ Packet p_hi_hi, p_hi_lo;
+ twoprod(x_hi, x_lo, y_hi, p_hi_hi, p_hi_lo);
+ Packet p_lo_hi, p_lo_lo;
+ twoprod(x_hi, x_lo, y_lo, p_lo_hi, p_lo_lo);
+ fast_twosum(p_hi_hi, p_hi_lo, p_lo_hi, p_lo_lo, p_hi, p_lo);
+}
+
+// This function computes the reciprocal of a floating point number
+// with extra precision and returns the result as a double word.
+template <typename Packet>
+void doubleword_reciprocal(const Packet& x, Packet& recip_hi, Packet& recip_lo) {
+ typedef typename unpacket_traits<Packet>::type Scalar;
+ // 1. Approximate the reciprocal as the reciprocal of the high order element.
+ Packet approx_recip = prsqrt(x);
+ approx_recip = pmul(approx_recip, approx_recip);
+
+ // 2. Run one step of Newton-Raphson iteration in double word arithmetic
+ // to get the bottom half. The NR iteration for reciprocal of 'a' is
+ // x_{i+1} = x_i * (2 - a * x_i)
+
+ // -a*x_i
+ Packet t1_hi, t1_lo;
+ twoprod(pnegate(x), approx_recip, t1_hi, t1_lo);
+ // 2 - a*x_i
+ Packet t2_hi, t2_lo;
+ fast_twosum(pset1<Packet>(Scalar(2)), t1_hi, t2_hi, t2_lo);
+ Packet t3_hi, t3_lo;
+ fast_twosum(t2_hi, padd(t2_lo, t1_lo), t3_hi, t3_lo);
+ // x_i * (2 - a * x_i)
+ twoprod(t3_hi, t3_lo, approx_recip, recip_hi, recip_lo);
+}
+
+
+// This function computes log2(x) and returns the result as a double word.
+template <typename Scalar>
+struct accurate_log2 {
+ template <typename Packet>
+ EIGEN_STRONG_INLINE
+ void operator()(const Packet& x, Packet& log2_x_hi, Packet& log2_x_lo) {
+ log2_x_hi = plog2(x);
+ log2_x_lo = pzero(x);
+ }
+};
+
+// This specialization uses a more accurate algorithm to compute log2(x) for
+// floats in [1/sqrt(2);sqrt(2)] with a relative accuracy of ~6.42e-10.
+// This additional accuracy is needed to counter the error-magnification
+// inherent in multiplying by a potentially large exponent in pow(x,y).
+// The minimax polynomial used was calculated using the Sollya tool.
+// See sollya.org.
+template <>
+struct accurate_log2<float> {
+ template <typename Packet>
+ EIGEN_STRONG_INLINE
+ void operator()(const Packet& z, Packet& log2_x_hi, Packet& log2_x_lo) {
+ // The function log(1+x)/x is approximated in the interval
+ // [1/sqrt(2)-1;sqrt(2)-1] by a degree 10 polynomial of the form
+ // Q(x) = (C0 + x * (C1 + x * (C2 + x * (C3 + x * P(x))))),
+ // where the degree 6 polynomial P(x) is evaluated in single precision,
+ // while the remaining 4 terms of Q(x), as well as the final multiplication by x
+ // to reconstruct log(1+x) are evaluated in extra precision using
+ // double word arithmetic. C0 through C3 are extra precise constants
+ // stored as double words.
+ //
+ // The polynomial coefficients were calculated using Sollya commands:
+ // > n = 10;
+ // > f = log2(1+x)/x;
+ // > interval = [sqrt(0.5)-1;sqrt(2)-1];
+ // > p = fpminimax(f,n,[|double,double,double,double,single...|],interval,relative,floating);
+
+ const Packet p6 = pset1<Packet>( 9.703654795885e-2f);
+ const Packet p5 = pset1<Packet>(-0.1690667718648f);
+ const Packet p4 = pset1<Packet>( 0.1720575392246f);
+ const Packet p3 = pset1<Packet>(-0.1789081543684f);
+ const Packet p2 = pset1<Packet>( 0.2050433009862f);
+ const Packet p1 = pset1<Packet>(-0.2404672354459f);
+ const Packet p0 = pset1<Packet>( 0.2885761857032f);
+
+ const Packet C3_hi = pset1<Packet>(-0.360674142838f);
+ const Packet C3_lo = pset1<Packet>(-6.13283912543e-09f);
+ const Packet C2_hi = pset1<Packet>(0.480897903442f);
+ const Packet C2_lo = pset1<Packet>(-1.44861207474e-08f);
+ const Packet C1_hi = pset1<Packet>(-0.721347510815f);
+ const Packet C1_lo = pset1<Packet>(-4.84483164698e-09f);
+ const Packet C0_hi = pset1<Packet>(1.44269502163f);
+ const Packet C0_lo = pset1<Packet>(2.01711713999e-08f);
+ const Packet one = pset1<Packet>(1.0f);
+
+ const Packet x = psub(z, one);
+ // Evaluate P(x) in working precision.
+ // We evaluate it in multiple parts to improve instruction level
+ // parallelism.
+ Packet x2 = pmul(x,x);
+ Packet p_even = pmadd(p6, x2, p4);
+ p_even = pmadd(p_even, x2, p2);
+ p_even = pmadd(p_even, x2, p0);
+ Packet p_odd = pmadd(p5, x2, p3);
+ p_odd = pmadd(p_odd, x2, p1);
+ Packet p = pmadd(p_odd, x, p_even);
+
+ // Now evaluate the low-order tems of Q(x) in double word precision.
+ // In the following, due to the alternating signs and the fact that
+ // |x| < sqrt(2)-1, we can assume that |C*_hi| >= q_i, and use
+ // fast_twosum instead of the slower twosum.
+ Packet q_hi, q_lo;
+ Packet t_hi, t_lo;
+ // C3 + x * p(x)
+ twoprod(p, x, t_hi, t_lo);
+ fast_twosum(C3_hi, C3_lo, t_hi, t_lo, q_hi, q_lo);
+ // C2 + x * p(x)
+ twoprod(q_hi, q_lo, x, t_hi, t_lo);
+ fast_twosum(C2_hi, C2_lo, t_hi, t_lo, q_hi, q_lo);
+ // C1 + x * p(x)
+ twoprod(q_hi, q_lo, x, t_hi, t_lo);
+ fast_twosum(C1_hi, C1_lo, t_hi, t_lo, q_hi, q_lo);
+ // C0 + x * p(x)
+ twoprod(q_hi, q_lo, x, t_hi, t_lo);
+ fast_twosum(C0_hi, C0_lo, t_hi, t_lo, q_hi, q_lo);
+
+ // log(z) ~= x * Q(x)
+ twoprod(q_hi, q_lo, x, log2_x_hi, log2_x_lo);
+ }
+};
+
+// This specialization uses a more accurate algorithm to compute log2(x) for
+// floats in [1/sqrt(2);sqrt(2)] with a relative accuracy of ~1.27e-18.
+// This additional accuracy is needed to counter the error-magnification
+// inherent in multiplying by a potentially large exponent in pow(x,y).
+// The minimax polynomial used was calculated using the Sollya tool.
+// See sollya.org.
+
+template <>
+struct accurate_log2<double> {
+ template <typename Packet>
+ EIGEN_STRONG_INLINE
+ void operator()(const Packet& x, Packet& log2_x_hi, Packet& log2_x_lo) {
+ // We use a transformation of variables:
+ // r = c * (x-1) / (x+1),
+ // such that
+ // log2(x) = log2((1 + r/c) / (1 - r/c)) = f(r).
+ // The function f(r) can be approximated well using an odd polynomial
+ // of the form
+ // P(r) = ((Q(r^2) * r^2 + C) * r^2 + 1) * r,
+ // For the implementation of log2<double> here, Q is of degree 6 with
+ // coefficient represented in working precision (double), while C is a
+ // constant represented in extra precision as a double word to achieve
+ // full accuracy.
+ //
+ // The polynomial coefficients were computed by the Sollya script:
+ //
+ // c = 2 / log(2);
+ // trans = c * (x-1)/(x+1);
+ // itrans = (1+x/c)/(1-x/c);
+ // interval=[trans(sqrt(0.5)); trans(sqrt(2))];
+ // print(interval);
+ // f = log2(itrans(x));
+ // p=fpminimax(f,[|1,3,5,7,9,11,13,15,17|],[|1,DD,double...|],interval,relative,floating);
+ const Packet q12 = pset1<Packet>(2.87074255468000586e-9);
+ const Packet q10 = pset1<Packet>(2.38957980901884082e-8);
+ const Packet q8 = pset1<Packet>(2.31032094540014656e-7);
+ const Packet q6 = pset1<Packet>(2.27279857398537278e-6);
+ const Packet q4 = pset1<Packet>(2.31271023278625638e-5);
+ const Packet q2 = pset1<Packet>(2.47556738444535513e-4);
+ const Packet q0 = pset1<Packet>(2.88543873228900172e-3);
+ const Packet C_hi = pset1<Packet>(0.0400377511598501157);
+ const Packet C_lo = pset1<Packet>(-4.77726582251425391e-19);
+ const Packet one = pset1<Packet>(1.0);
+
+ const Packet cst_2_log2e_hi = pset1<Packet>(2.88539008177792677);
+ const Packet cst_2_log2e_lo = pset1<Packet>(4.07660016854549667e-17);
+ // c * (x - 1)
+ Packet num_hi, num_lo;
+ twoprod(cst_2_log2e_hi, cst_2_log2e_lo, psub(x, one), num_hi, num_lo);
+ // TODO(rmlarsen): Investigate if using the division algorithm by
+ // Muller et al. is faster/more accurate.
+ // 1 / (x + 1)
+ Packet denom_hi, denom_lo;
+ doubleword_reciprocal(padd(x, one), denom_hi, denom_lo);
+ // r = c * (x-1) / (x+1),
+ Packet r_hi, r_lo;
+ twoprod(num_hi, num_lo, denom_hi, denom_lo, r_hi, r_lo);
+ // r2 = r * r
+ Packet r2_hi, r2_lo;
+ twoprod(r_hi, r_lo, r_hi, r_lo, r2_hi, r2_lo);
+ // r4 = r2 * r2
+ Packet r4_hi, r4_lo;
+ twoprod(r2_hi, r2_lo, r2_hi, r2_lo, r4_hi, r4_lo);
+
+ // Evaluate Q(r^2) in working precision. We evaluate it in two parts
+ // (even and odd in r^2) to improve instruction level parallelism.
+ Packet q_even = pmadd(q12, r4_hi, q8);
+ Packet q_odd = pmadd(q10, r4_hi, q6);
+ q_even = pmadd(q_even, r4_hi, q4);
+ q_odd = pmadd(q_odd, r4_hi, q2);
+ q_even = pmadd(q_even, r4_hi, q0);
+ Packet q = pmadd(q_odd, r2_hi, q_even);
+
+ // Now evaluate the low order terms of P(x) in double word precision.
+ // In the following, due to the increasing magnitude of the coefficients
+ // and r being constrained to [-0.5, 0.5] we can use fast_twosum instead
+ // of the slower twosum.
+ // Q(r^2) * r^2
+ Packet p_hi, p_lo;
+ twoprod(r2_hi, r2_lo, q, p_hi, p_lo);
+ // Q(r^2) * r^2 + C
+ Packet p1_hi, p1_lo;
+ fast_twosum(C_hi, C_lo, p_hi, p_lo, p1_hi, p1_lo);
+ // (Q(r^2) * r^2 + C) * r^2
+ Packet p2_hi, p2_lo;
+ twoprod(r2_hi, r2_lo, p1_hi, p1_lo, p2_hi, p2_lo);
+ // ((Q(r^2) * r^2 + C) * r^2 + 1)
+ Packet p3_hi, p3_lo;
+ fast_twosum(one, p2_hi, p2_lo, p3_hi, p3_lo);
+
+ // log(z) ~= ((Q(r^2) * r^2 + C) * r^2 + 1) * r
+ twoprod(p3_hi, p3_lo, r_hi, r_lo, log2_x_hi, log2_x_lo);
+ }
+};
+
+// This function computes exp2(x) (i.e. 2**x).
+template <typename Scalar>
+struct fast_accurate_exp2 {
+ template <typename Packet>
+ EIGEN_STRONG_INLINE
+ Packet operator()(const Packet& x) {
+ // TODO(rmlarsen): Add a pexp2 packetop.
+ return pexp(pmul(pset1<Packet>(Scalar(EIGEN_LN2)), x));
+ }
+};
+
+// This specialization uses a faster algorithm to compute exp2(x) for floats
+// in [-0.5;0.5] with a relative accuracy of 1 ulp.
+// The minimax polynomial used was calculated using the Sollya tool.
+// See sollya.org.
+template <>
+struct fast_accurate_exp2<float> {
+ template <typename Packet>
+ EIGEN_STRONG_INLINE
+ Packet operator()(const Packet& x) {
+ // This function approximates exp2(x) by a degree 6 polynomial of the form
+ // Q(x) = 1 + x * (C + x * P(x)), where the degree 4 polynomial P(x) is evaluated in
+ // single precision, and the remaining steps are evaluated with extra precision using
+ // double word arithmetic. C is an extra precise constant stored as a double word.
+ //
+ // The polynomial coefficients were calculated using Sollya commands:
+ // > n = 6;
+ // > f = 2^x;
+ // > interval = [-0.5;0.5];
+ // > p = fpminimax(f,n,[|1,double,single...|],interval,relative,floating);
+
+ const Packet p4 = pset1<Packet>(1.539513905e-4f);
+ const Packet p3 = pset1<Packet>(1.340007293e-3f);
+ const Packet p2 = pset1<Packet>(9.618283249e-3f);
+ const Packet p1 = pset1<Packet>(5.550328270e-2f);
+ const Packet p0 = pset1<Packet>(0.2402264923f);
+
+ const Packet C_hi = pset1<Packet>(0.6931471825f);
+ const Packet C_lo = pset1<Packet>(2.36836577e-08f);
+ const Packet one = pset1<Packet>(1.0f);
+
+ // Evaluate P(x) in working precision.
+ // We evaluate even and odd parts of the polynomial separately
+ // to gain some instruction level parallelism.
+ Packet x2 = pmul(x,x);
+ Packet p_even = pmadd(p4, x2, p2);
+ Packet p_odd = pmadd(p3, x2, p1);
+ p_even = pmadd(p_even, x2, p0);
+ Packet p = pmadd(p_odd, x, p_even);
+
+ // Evaluate the remaining terms of Q(x) with extra precision using
+ // double word arithmetic.
+ Packet p_hi, p_lo;
+ // x * p(x)
+ twoprod(p, x, p_hi, p_lo);
+ // C + x * p(x)
+ Packet q1_hi, q1_lo;
+ twosum(p_hi, p_lo, C_hi, C_lo, q1_hi, q1_lo);
+ // x * (C + x * p(x))
+ Packet q2_hi, q2_lo;
+ twoprod(q1_hi, q1_lo, x, q2_hi, q2_lo);
+ // 1 + x * (C + x * p(x))
+ Packet q3_hi, q3_lo;
+ // Since |q2_hi| <= sqrt(2)-1 < 1, we can use fast_twosum
+ // for adding it to unity here.
+ fast_twosum(one, q2_hi, q3_hi, q3_lo);
+ return padd(q3_hi, padd(q2_lo, q3_lo));
+ }
+};
+
+// in [-0.5;0.5] with a relative accuracy of 1 ulp.
+// The minimax polynomial used was calculated using the Sollya tool.
+// See sollya.org.
+template <>
+struct fast_accurate_exp2<double> {
+ template <typename Packet>
+ EIGEN_STRONG_INLINE
+ Packet operator()(const Packet& x) {
+ // This function approximates exp2(x) by a degree 10 polynomial of the form
+ // Q(x) = 1 + x * (C + x * P(x)), where the degree 8 polynomial P(x) is evaluated in
+ // single precision, and the remaining steps are evaluated with extra precision using
+ // double word arithmetic. C is an extra precise constant stored as a double word.
+ //
+ // The polynomial coefficients were calculated using Sollya commands:
+ // > n = 11;
+ // > f = 2^x;
+ // > interval = [-0.5;0.5];
+ // > p = fpminimax(f,n,[|1,DD,double...|],interval,relative,floating);
+
+ const Packet p9 = pset1<Packet>(4.431642109085495276e-10);
+ const Packet p8 = pset1<Packet>(7.073829923303358410e-9);
+ const Packet p7 = pset1<Packet>(1.017822306737031311e-7);
+ const Packet p6 = pset1<Packet>(1.321543498017646657e-6);
+ const Packet p5 = pset1<Packet>(1.525273342728892877e-5);
+ const Packet p4 = pset1<Packet>(1.540353045780084423e-4);
+ const Packet p3 = pset1<Packet>(1.333355814685869807e-3);
+ const Packet p2 = pset1<Packet>(9.618129107593478832e-3);
+ const Packet p1 = pset1<Packet>(5.550410866481961247e-2);
+ const Packet p0 = pset1<Packet>(0.240226506959101332);
+ const Packet C_hi = pset1<Packet>(0.693147180559945286);
+ const Packet C_lo = pset1<Packet>(4.81927865669806721e-17);
+ const Packet one = pset1<Packet>(1.0);
+
+ // Evaluate P(x) in working precision.
+ // We evaluate even and odd parts of the polynomial separately
+ // to gain some instruction level parallelism.
+ Packet x2 = pmul(x,x);
+ Packet p_even = pmadd(p8, x2, p6);
+ Packet p_odd = pmadd(p9, x2, p7);
+ p_even = pmadd(p_even, x2, p4);
+ p_odd = pmadd(p_odd, x2, p5);
+ p_even = pmadd(p_even, x2, p2);
+ p_odd = pmadd(p_odd, x2, p3);
+ p_even = pmadd(p_even, x2, p0);
+ p_odd = pmadd(p_odd, x2, p1);
+ Packet p = pmadd(p_odd, x, p_even);
+
+ // Evaluate the remaining terms of Q(x) with extra precision using
+ // double word arithmetic.
+ Packet p_hi, p_lo;
+ // x * p(x)
+ twoprod(p, x, p_hi, p_lo);
+ // C + x * p(x)
+ Packet q1_hi, q1_lo;
+ twosum(p_hi, p_lo, C_hi, C_lo, q1_hi, q1_lo);
+ // x * (C + x * p(x))
+ Packet q2_hi, q2_lo;
+ twoprod(q1_hi, q1_lo, x, q2_hi, q2_lo);
+ // 1 + x * (C + x * p(x))
+ Packet q3_hi, q3_lo;
+ // Since |q2_hi| <= sqrt(2)-1 < 1, we can use fast_twosum
+ // for adding it to unity here.
+ fast_twosum(one, q2_hi, q3_hi, q3_lo);
+ return padd(q3_hi, padd(q2_lo, q3_lo));
+ }
+};
+
+// This function implements the non-trivial case of pow(x,y) where x is
+// positive and y is (possibly) non-integer.
+// Formally, pow(x,y) = exp2(y * log2(x)), where exp2(x) is shorthand for 2^x.
+// TODO(rmlarsen): We should probably add this as a packet up 'ppow', to make it
+// easier to specialize or turn off for specific types and/or backends.x
+template <typename Packet>
+EIGEN_STRONG_INLINE Packet generic_pow_impl(const Packet& x, const Packet& y) {
+ typedef typename unpacket_traits<Packet>::type Scalar;
+ // Split x into exponent e_x and mantissa m_x.
+ Packet e_x;
+ Packet m_x = pfrexp(x, e_x);
+
+ // Adjust m_x to lie in [1/sqrt(2):sqrt(2)] to minimize absolute error in log2(m_x).
+ EIGEN_CONSTEXPR Scalar sqrt_half = Scalar(0.70710678118654752440);
+ const Packet m_x_scale_mask = pcmp_lt(m_x, pset1<Packet>(sqrt_half));
+ m_x = pselect(m_x_scale_mask, pmul(pset1<Packet>(Scalar(2)), m_x), m_x);
+ e_x = pselect(m_x_scale_mask, psub(e_x, pset1<Packet>(Scalar(1))), e_x);
+
+ // Compute log2(m_x) with 6 extra bits of accuracy.
+ Packet rx_hi, rx_lo;
+ accurate_log2<Scalar>()(m_x, rx_hi, rx_lo);
+
+ // Compute the two terms {y * e_x, y * r_x} in f = y * log2(x) with doubled
+ // precision using double word arithmetic.
+ Packet f1_hi, f1_lo, f2_hi, f2_lo;
+ twoprod(e_x, y, f1_hi, f1_lo);
+ twoprod(rx_hi, rx_lo, y, f2_hi, f2_lo);
+ // Sum the two terms in f using double word arithmetic. We know
+ // that |e_x| > |log2(m_x)|, except for the case where e_x==0.
+ // This means that we can use fast_twosum(f1,f2).
+ // In the case e_x == 0, e_x * y = f1 = 0, so we don't lose any
+ // accuracy by violating the assumption of fast_twosum, because
+ // it's a no-op.
+ Packet f_hi, f_lo;
+ fast_twosum(f1_hi, f1_lo, f2_hi, f2_lo, f_hi, f_lo);
+
+ // Split f into integer and fractional parts.
+ Packet n_z, r_z;
+ absolute_split(f_hi, n_z, r_z);
+ r_z = padd(r_z, f_lo);
+ Packet n_r;
+ absolute_split(r_z, n_r, r_z);
+ n_z = padd(n_z, n_r);
+
+ // We now have an accurate split of f = n_z + r_z and can compute
+ // x^y = 2**{n_z + r_z) = exp2(r_z) * 2**{n_z}.
+ // Since r_z is in [-0.5;0.5], we compute the first factor to high accuracy
+ // using a specialized algorithm. Multiplication by the second factor can
+ // be done exactly using pldexp(), since it is an integer power of 2.
+ const Packet e_r = fast_accurate_exp2<Scalar>()(r_z);
+ return pldexp(e_r, n_z);
+}
+
+// Generic implementation of pow(x,y).
+template<typename Packet>
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
+EIGEN_UNUSED
+Packet generic_pow(const Packet& x, const Packet& y) {
+ typedef typename unpacket_traits<Packet>::type Scalar;
+
+ const Packet cst_pos_inf = pset1<Packet>(NumTraits<Scalar>::infinity());
+ const Packet cst_zero = pset1<Packet>(Scalar(0));
+ const Packet cst_one = pset1<Packet>(Scalar(1));
+ const Packet cst_nan = pset1<Packet>(NumTraits<Scalar>::quiet_NaN());
+
+ const Packet abs_x = pabs(x);
+ // Predicates for sign and magnitude of x.
+ const Packet x_is_zero = pcmp_eq(x, cst_zero);
+ const Packet x_is_neg = pcmp_lt(x, cst_zero);
+ const Packet abs_x_is_inf = pcmp_eq(abs_x, cst_pos_inf);
+ const Packet abs_x_is_one = pcmp_eq(abs_x, cst_one);
+ const Packet abs_x_is_gt_one = pcmp_lt(cst_one, abs_x);
+ const Packet abs_x_is_lt_one = pcmp_lt(abs_x, cst_one);
+ const Packet x_is_one = pandnot(abs_x_is_one, x_is_neg);
+ const Packet x_is_neg_one = pand(abs_x_is_one, x_is_neg);
+ const Packet x_is_nan = pandnot(ptrue(x), pcmp_eq(x, x));
+
+ // Predicates for sign and magnitude of y.
+ const Packet y_is_one = pcmp_eq(y, cst_one);
+ const Packet y_is_zero = pcmp_eq(y, cst_zero);
+ const Packet y_is_neg = pcmp_lt(y, cst_zero);
+ const Packet y_is_pos = pandnot(ptrue(y), por(y_is_zero, y_is_neg));
+ const Packet y_is_nan = pandnot(ptrue(y), pcmp_eq(y, y));
+ const Packet abs_y_is_inf = pcmp_eq(pabs(y), cst_pos_inf);
+ EIGEN_CONSTEXPR Scalar huge_exponent =
+ (NumTraits<Scalar>::max_exponent() * Scalar(EIGEN_LN2)) /
+ NumTraits<Scalar>::epsilon();
+ const Packet abs_y_is_huge = pcmp_le(pset1<Packet>(huge_exponent), pabs(y));
+
+ // Predicates for whether y is integer and/or even.
+ const Packet y_is_int = pcmp_eq(pfloor(y), y);
+ const Packet y_div_2 = pmul(y, pset1<Packet>(Scalar(0.5)));
+ const Packet y_is_even = pcmp_eq(pround(y_div_2), y_div_2);
+
+ // Predicates encoding special cases for the value of pow(x,y)
+ const Packet invalid_negative_x = pandnot(pandnot(pandnot(x_is_neg, abs_x_is_inf),
+ y_is_int),
+ abs_y_is_inf);
+ const Packet pow_is_one = por(por(x_is_one, y_is_zero),
+ pand(x_is_neg_one,
+ por(abs_y_is_inf, pandnot(y_is_even, invalid_negative_x))));
+ const Packet pow_is_nan = por(invalid_negative_x, por(x_is_nan, y_is_nan));
+ const Packet pow_is_zero = por(por(por(pand(x_is_zero, y_is_pos),
+ pand(abs_x_is_inf, y_is_neg)),
+ pand(pand(abs_x_is_lt_one, abs_y_is_huge),
+ y_is_pos)),
+ pand(pand(abs_x_is_gt_one, abs_y_is_huge),
+ y_is_neg));
+ const Packet pow_is_inf = por(por(por(pand(x_is_zero, y_is_neg),
+ pand(abs_x_is_inf, y_is_pos)),
+ pand(pand(abs_x_is_lt_one, abs_y_is_huge),
+ y_is_neg)),
+ pand(pand(abs_x_is_gt_one, abs_y_is_huge),
+ y_is_pos));
+
+ // General computation of pow(x,y) for positive x or negative x and integer y.
+ const Packet negate_pow_abs = pandnot(x_is_neg, y_is_even);
+ const Packet pow_abs = generic_pow_impl(abs_x, y);
+ return pselect(y_is_one, x,
+ pselect(pow_is_one, cst_one,
+ pselect(pow_is_nan, cst_nan,
+ pselect(pow_is_inf, cst_pos_inf,
+ pselect(pow_is_zero, cst_zero,
+ pselect(negate_pow_abs, pnegate(pow_abs), pow_abs))))));
+}
+
+
+
+/* polevl (modified for Eigen)
+ *
+ * Evaluate polynomial
+ *
+ *
+ *
+ * SYNOPSIS:
+ *
+ * int N;
+ * Scalar x, y, coef[N+1];
+ *
+ * y = polevl<decltype(x), N>( x, coef);
+ *
+ *
+ *
+ * DESCRIPTION:
+ *
+ * Evaluates polynomial of degree N:
+ *
+ * 2 N
+ * y = C + C x + C x +...+ C x
+ * 0 1 2 N
+ *
+ * Coefficients are stored in reverse order:
+ *
+ * coef[0] = C , ..., coef[N] = C .
+ * N 0
+ *
+ * The function p1evl() assumes that coef[N] = 1.0 and is
+ * omitted from the array. Its calling arguments are
+ * otherwise the same as polevl().
+ *
+ *
+ * The Eigen implementation is templatized. For best speed, store
+ * coef as a const array (constexpr), e.g.
+ *
+ * const double coef[] = {1.0, 2.0, 3.0, ...};
+ *
+ */
+template <typename Packet, int N>
+struct ppolevl {
+ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet& x, const typename unpacket_traits<Packet>::type coeff[]) {
+ EIGEN_STATIC_ASSERT((N > 0), YOU_MADE_A_PROGRAMMING_MISTAKE);
+ return pmadd(ppolevl<Packet, N-1>::run(x, coeff), x, pset1<Packet>(coeff[N]));
+ }
+};
+
+template <typename Packet>
+struct ppolevl<Packet, 0> {
+ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet& x, const typename unpacket_traits<Packet>::type coeff[]) {
+ EIGEN_UNUSED_VARIABLE(x);
+ return pset1<Packet>(coeff[0]);
+ }
+};
+
+/* chbevl (modified for Eigen)
+ *
+ * Evaluate Chebyshev series
+ *
+ *
+ *
+ * SYNOPSIS:
+ *
+ * int N;
+ * Scalar x, y, coef[N], chebevl();
+ *
+ * y = chbevl( x, coef, N );
+ *
+ *
+ *
+ * DESCRIPTION:
+ *
+ * Evaluates the series
+ *
+ * N-1
+ * - '
+ * y = > coef[i] T (x/2)
+ * - i
+ * i=0
+ *
+ * of Chebyshev polynomials Ti at argument x/2.
+ *
+ * Coefficients are stored in reverse order, i.e. the zero
+ * order term is last in the array. Note N is the number of
+ * coefficients, not the order.
+ *
+ * If coefficients are for the interval a to b, x must
+ * have been transformed to x -> 2(2x - b - a)/(b-a) before
+ * entering the routine. This maps x from (a, b) to (-1, 1),
+ * over which the Chebyshev polynomials are defined.
+ *
+ * If the coefficients are for the inverted interval, in
+ * which (a, b) is mapped to (1/b, 1/a), the transformation
+ * required is x -> 2(2ab/x - b - a)/(b-a). If b is infinity,
+ * this becomes x -> 4a/x - 1.
+ *
+ *
+ *
+ * SPEED:
+ *
+ * Taking advantage of the recurrence properties of the
+ * Chebyshev polynomials, the routine requires one more
+ * addition per loop than evaluating a nested polynomial of
+ * the same degree.
+ *
+ */
+
+template <typename Packet, int N>
+struct pchebevl {
+ EIGEN_DEVICE_FUNC
+ static EIGEN_STRONG_INLINE Packet run(Packet x, const typename unpacket_traits<Packet>::type coef[]) {
+ typedef typename unpacket_traits<Packet>::type Scalar;
+ Packet b0 = pset1<Packet>(coef[0]);
+ Packet b1 = pset1<Packet>(static_cast<Scalar>(0.f));
+ Packet b2;
+
+ for (int i = 1; i < N; i++) {
+ b2 = b1;
+ b1 = b0;
+ b0 = psub(pmadd(x, b1, pset1<Packet>(coef[i])), b2);
+ }
+
+ return pmul(pset1<Packet>(static_cast<Scalar>(0.5f)), psub(b0, b2));
+ }
+};
+
+} // end namespace internal
+} // end namespace Eigen
+
+#endif // EIGEN_ARCH_GENERIC_PACKET_MATH_FUNCTIONS_H
diff --git a/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h b/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h
new file mode 100644
index 000000000..177a04e93
--- /dev/null
+++ b/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h
@@ -0,0 +1,110 @@
+// This file is part of Eigen, a lightweight C++ template library
+// for linear algebra.
+//
+// Copyright (C) 2019 Gael Guennebaud <gael.guennebaud@inria.fr>
+//
+// This Source Code Form is subject to the terms of the Mozilla
+// Public License v. 2.0. If a copy of the MPL was not distributed
+// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+#ifndef EIGEN_ARCH_GENERIC_PACKET_MATH_FUNCTIONS_FWD_H
+#define EIGEN_ARCH_GENERIC_PACKET_MATH_FUNCTIONS_FWD_H
+
+namespace Eigen {
+namespace internal {
+
+// Forward declarations of the generic math functions
+// implemented in GenericPacketMathFunctions.h
+// This is needed to workaround a circular dependency.
+
+/***************************************************************************
+ * Some generic implementations to be used by implementors
+***************************************************************************/
+
+/** Default implementation of pfrexp.
+ * It is expected to be called by implementers of template<> pfrexp.
+ */
+template<typename Packet> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
+Packet pfrexp_generic(const Packet& a, Packet& exponent);
+
+// Extracts the biased exponent value from Packet p, and casts the results to
+// a floating-point Packet type. Used by pfrexp_generic. Override this if
+// there is no unpacket_traits<Packet>::integer_packet.
+template<typename Packet> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
+Packet pfrexp_generic_get_biased_exponent(const Packet& p);
+
+/** Default implementation of pldexp.
+ * It is expected to be called by implementers of template<> pldexp.
+ */
+template<typename Packet> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
+Packet pldexp_generic(const Packet& a, const Packet& exponent);
+
+/** \internal \returns log(x) for single precision float */
+template <typename Packet>
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
+EIGEN_UNUSED
+Packet plog_float(const Packet _x);
+
+/** \internal \returns log2(x) for single precision float */
+template <typename Packet>
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
+EIGEN_UNUSED
+Packet plog2_float(const Packet _x);
+
+/** \internal \returns log(x) for single precision float */
+template <typename Packet>
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
+EIGEN_UNUSED
+Packet plog_double(const Packet _x);
+
+/** \internal \returns log2(x) for single precision float */
+template <typename Packet>
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
+EIGEN_UNUSED
+Packet plog2_double(const Packet _x);
+
+/** \internal \returns log(1 + x) */
+template<typename Packet>
+Packet generic_plog1p(const Packet& x);
+
+/** \internal \returns exp(x)-1 */
+template<typename Packet>
+Packet generic_expm1(const Packet& x);
+
+/** \internal \returns exp(x) for single precision float */
+template <typename Packet>
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
+EIGEN_UNUSED
+Packet pexp_float(const Packet _x);
+
+/** \internal \returns exp(x) for double precision real numbers */
+template <typename Packet>
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
+EIGEN_UNUSED
+Packet pexp_double(const Packet _x);
+
+/** \internal \returns sin(x) for single precision float */
+template<typename Packet>
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
+EIGEN_UNUSED
+Packet psin_float(const Packet& x);
+
+/** \internal \returns cos(x) for single precision float */
+template<typename Packet>
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
+EIGEN_UNUSED
+Packet pcos_float(const Packet& x);
+
+/** \internal \returns sqrt(x) for complex types */
+template<typename Packet>
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
+EIGEN_UNUSED
+Packet psqrt_complex(const Packet& a);
+
+template <typename Packet, int N> struct ppolevl;
+
+
+} // end namespace internal
+} // end namespace Eigen
+
+#endif // EIGEN_ARCH_GENERIC_PACKET_MATH_FUNCTIONS_FWD_H
diff --git a/Eigen/src/Core/arch/Default/Half.h b/Eigen/src/Core/arch/Default/Half.h
new file mode 100644
index 000000000..9f8e8cc1e
--- /dev/null
+++ b/Eigen/src/Core/arch/Default/Half.h
@@ -0,0 +1,942 @@
+// This file is part of Eigen, a lightweight C++ template library
+// for linear algebra.
+//
+// This Source Code Form is subject to the terms of the Mozilla
+// Public License v. 2.0. If a copy of the MPL was not distributed
+// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
+//
+// The conversion routines are Copyright (c) Fabian Giesen, 2016.
+// The original license follows:
+//
+// Copyright (c) Fabian Giesen, 2016
+// All rights reserved.
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted.
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+
+// Standard 16-bit float type, mostly useful for GPUs. Defines a new
+// type Eigen::half (inheriting either from CUDA's or HIP's __half struct) with
+// operator overloads such that it behaves basically as an arithmetic
+// type. It will be quite slow on CPUs (so it is recommended to stay
+// in fp32 for CPUs, except for simple parameter conversions, I/O
+// to disk and the likes), but fast on GPUs.
+
+
+#ifndef EIGEN_HALF_H
+#define EIGEN_HALF_H
+
+#include <sstream>
+
+#if defined(EIGEN_HAS_GPU_FP16) || defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
+// When compiling with GPU support, the "__half_raw" base class as well as
+// some other routines are defined in the GPU compiler header files
+// (cuda_fp16.h, hip_fp16.h), and they are not tagged constexpr
+// As a consequence, we get compile failures when compiling Eigen with
+// GPU support. Hence the need to disable EIGEN_CONSTEXPR when building
+// Eigen with GPU support
+ #pragma push_macro("EIGEN_CONSTEXPR")
+ #undef EIGEN_CONSTEXPR
+ #define EIGEN_CONSTEXPR
+#endif
+
+#define F16_PACKET_FUNCTION(PACKET_F, PACKET_F16, METHOD) \
+ template <> \
+ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_UNUSED \
+ PACKET_F16 METHOD<PACKET_F16>(const PACKET_F16& _x) { \
+ return float2half(METHOD<PACKET_F>(half2float(_x))); \
+ }
+
+namespace Eigen {
+
+struct half;
+
+namespace half_impl {
+
+// We want to use the __half_raw struct from the HIP header file only during the device compile phase.
+// This is required because of a quirk in the way TensorFlow GPU builds are done.
+// When compiling TensorFlow source code with GPU support, files that
+// * contain GPU kernels (i.e. *.cu.cc files) are compiled via hipcc
+// * do not contain GPU kernels ( i.e. *.cc files) are compiled via gcc (typically)
+//
+// Tensorflow uses the Eigen::half type as its FP16 type, and there are functions that
+// * are defined in a file that gets compiled via hipcc AND
+// * have Eigen::half as a pass-by-value argument AND
+// * are called in a file that gets compiled via gcc
+//
+// In the scenario described above the caller and callee will see different versions
+// of the Eigen::half base class __half_raw, and they will be compiled by different compilers
+//
+// There appears to be an ABI mismatch between gcc and clang (which is called by hipcc) that results in
+// the callee getting corrupted values for the Eigen::half argument.
+//
+// Making the host side compile phase of hipcc use the same Eigen::half impl, as the gcc compile, resolves
+// this error, and hence the following convoluted #if condition
+#if !defined(EIGEN_HAS_GPU_FP16) || !defined(EIGEN_GPU_COMPILE_PHASE)
+// Make our own __half_raw definition that is similar to CUDA's.
+struct __half_raw {
+#if (defined(EIGEN_HAS_GPU_FP16) && !defined(EIGEN_GPU_COMPILE_PHASE))
+ // Eigen::half can be used as the datatype for shared memory declarations (in Eigen and TF)
+ // The element type for shared memory cannot have non-trivial constructors
+ // and hence the following special casing (which skips the zero-initilization).
+ // Note that this check gets done even in the host compilation phase, and
+ // hence the need for this
+ EIGEN_DEVICE_FUNC __half_raw() {}
+#else
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __half_raw() : x(0) {}
+#endif
+#if defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
+ explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __half_raw(numext::uint16_t raw) : x(numext::bit_cast<__fp16>(raw)) {
+ }
+ __fp16 x;
+#else
+ explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __half_raw(numext::uint16_t raw) : x(raw) {}
+ numext::uint16_t x;
+#endif
+};
+
+#elif defined(EIGEN_HAS_HIP_FP16)
+ // Nothing to do here
+ // HIP fp16 header file has a definition for __half_raw
+#elif defined(EIGEN_HAS_CUDA_FP16)
+ #if EIGEN_CUDA_SDK_VER < 90000
+ // In CUDA < 9.0, __half is the equivalent of CUDA 9's __half_raw
+ typedef __half __half_raw;
+ #endif // defined(EIGEN_HAS_CUDA_FP16)
+#elif defined(SYCL_DEVICE_ONLY)
+ typedef cl::sycl::half __half_raw;
+#endif
+
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __half_raw raw_uint16_to_half(numext::uint16_t x);
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __half_raw float_to_half_rtne(float ff);
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC float half_to_float(__half_raw h);
+
+struct half_base : public __half_raw {
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR half_base() {}
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR half_base(const __half_raw& h) : __half_raw(h) {}
+
+#if defined(EIGEN_HAS_GPU_FP16)
+ #if defined(EIGEN_HAS_HIP_FP16)
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR half_base(const __half& h) { x = __half_as_ushort(h); }
+ #elif defined(EIGEN_HAS_CUDA_FP16)
+ #if EIGEN_CUDA_SDK_VER >= 90000
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR half_base(const __half& h) : __half_raw(*(__half_raw*)&h) {}
+ #endif
+ #endif
+#endif
+};
+
+} // namespace half_impl
+
+// Class definition.
+struct half : public half_impl::half_base {
+
+ // Writing this out as separate #if-else blocks to make the code easier to follow
+ // The same applies to most #if-else blocks in this file
+#if !defined(EIGEN_HAS_GPU_FP16) || !defined(EIGEN_GPU_COMPILE_PHASE)
+ // Use the same base class for the following two scenarios
+ // * when compiling without GPU support enabled
+ // * during host compile phase when compiling with GPU support enabled
+ typedef half_impl::__half_raw __half_raw;
+#elif defined(EIGEN_HAS_HIP_FP16)
+ // Nothing to do here
+ // HIP fp16 header file has a definition for __half_raw
+#elif defined(EIGEN_HAS_CUDA_FP16)
+ // Note that EIGEN_CUDA_SDK_VER is set to 0 even when compiling with HIP, so
+ // (EIGEN_CUDA_SDK_VER < 90000) is true even for HIP! So keeping this within
+ // #if defined(EIGEN_HAS_CUDA_FP16) is needed
+ #if defined(EIGEN_CUDA_SDK_VER) && EIGEN_CUDA_SDK_VER < 90000
+ typedef half_impl::__half_raw __half_raw;
+ #endif
+#endif
+
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR half() {}
+
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR half(const __half_raw& h) : half_impl::half_base(h) {}
+
+#if defined(EIGEN_HAS_GPU_FP16)
+ #if defined(EIGEN_HAS_HIP_FP16)
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR half(const __half& h) : half_impl::half_base(h) {}
+ #elif defined(EIGEN_HAS_CUDA_FP16)
+ #if defined(EIGEN_CUDA_SDK_VER) && EIGEN_CUDA_SDK_VER >= 90000
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR half(const __half& h) : half_impl::half_base(h) {}
+ #endif
+ #endif
+#endif
+
+
+ explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR half(bool b)
+ : half_impl::half_base(half_impl::raw_uint16_to_half(b ? 0x3c00 : 0)) {}
+ template<class T>
+ explicit EIGEN_DEVICE_FUNC half(T val)
+ : half_impl::half_base(half_impl::float_to_half_rtne(static_cast<float>(val))) {}
+ explicit EIGEN_DEVICE_FUNC half(float f)
+ : half_impl::half_base(half_impl::float_to_half_rtne(f)) {}
+
+ // Following the convention of numpy, converting between complex and
+ // float will lead to loss of imag value.
+ template<typename RealScalar>
+ explicit EIGEN_DEVICE_FUNC half(std::complex<RealScalar> c)
+ : half_impl::half_base(half_impl::float_to_half_rtne(static_cast<float>(c.real()))) {}
+
+ EIGEN_DEVICE_FUNC operator float() const { // NOLINT: Allow implicit conversion to float, because it is lossless.
+ return half_impl::half_to_float(*this);
+ }
+
+#if defined(EIGEN_HAS_GPU_FP16) && !defined(EIGEN_GPU_COMPILE_PHASE)
+ EIGEN_DEVICE_FUNC operator __half() const {
+ ::__half_raw hr;
+ hr.x = x;
+ return __half(hr);
+ }
+#endif
+};
+
+} // end namespace Eigen
+
+namespace std {
+template<>
+struct numeric_limits<Eigen::half> {
+ static const bool is_specialized = true;
+ static const bool is_signed = true;
+ static const bool is_integer = false;
+ static const bool is_exact = false;
+ static const bool has_infinity = true;
+ static const bool has_quiet_NaN = true;
+ static const bool has_signaling_NaN = true;
+ static const float_denorm_style has_denorm = denorm_present;
+ static const bool has_denorm_loss = false;
+ static const std::float_round_style round_style = std::round_to_nearest;
+ static const bool is_iec559 = false;
+ static const bool is_bounded = false;
+ static const bool is_modulo = false;
+ static const int digits = 11;
+ static const int digits10 = 3; // according to http://half.sourceforge.net/structstd_1_1numeric__limits_3_01half__float_1_1half_01_4.html
+ static const int max_digits10 = 5; // according to http://half.sourceforge.net/structstd_1_1numeric__limits_3_01half__float_1_1half_01_4.html
+ static const int radix = 2;
+ static const int min_exponent = -13;
+ static const int min_exponent10 = -4;
+ static const int max_exponent = 16;
+ static const int max_exponent10 = 4;
+ static const bool traps = true;
+ static const bool tinyness_before = false;
+
+ static Eigen::half (min)() { return Eigen::half_impl::raw_uint16_to_half(0x400); }
+ static Eigen::half lowest() { return Eigen::half_impl::raw_uint16_to_half(0xfbff); }
+ static Eigen::half (max)() { return Eigen::half_impl::raw_uint16_to_half(0x7bff); }
+ static Eigen::half epsilon() { return Eigen::half_impl::raw_uint16_to_half(0x0800); }
+ static Eigen::half round_error() { return Eigen::half(0.5); }
+ static Eigen::half infinity() { return Eigen::half_impl::raw_uint16_to_half(0x7c00); }
+ static Eigen::half quiet_NaN() { return Eigen::half_impl::raw_uint16_to_half(0x7e00); }
+ static Eigen::half signaling_NaN() { return Eigen::half_impl::raw_uint16_to_half(0x7d00); }
+ static Eigen::half denorm_min() { return Eigen::half_impl::raw_uint16_to_half(0x1); }
+};
+
+// If std::numeric_limits<T> is specialized, should also specialize
+// std::numeric_limits<const T>, std::numeric_limits<volatile T>, and
+// std::numeric_limits<const volatile T>
+// https://stackoverflow.com/a/16519653/
+template<>
+struct numeric_limits<const Eigen::half> : numeric_limits<Eigen::half> {};
+template<>
+struct numeric_limits<volatile Eigen::half> : numeric_limits<Eigen::half> {};
+template<>
+struct numeric_limits<const volatile Eigen::half> : numeric_limits<Eigen::half> {};
+} // end namespace std
+
+namespace Eigen {
+
+namespace half_impl {
+
+#if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && \
+ EIGEN_CUDA_ARCH >= 530) || \
+ (defined(EIGEN_HAS_HIP_FP16) && defined(HIP_DEVICE_COMPILE))
+// Note: We deliberatly do *not* define this to 1 even if we have Arm's native
+// fp16 type since GPU halfs are rather different from native CPU halfs.
+// TODO: Rename to something like EIGEN_HAS_NATIVE_GPU_FP16
+#define EIGEN_HAS_NATIVE_FP16
+#endif
+
+// Intrinsics for native fp16 support. Note that on current hardware,
+// these are no faster than fp32 arithmetic (you need to use the half2
+// versions to get the ALU speed increased), but you do save the
+// conversion steps back and forth.
+
+#if defined(EIGEN_HAS_NATIVE_FP16)
+EIGEN_STRONG_INLINE __device__ half operator + (const half& a, const half& b) {
+#if defined(EIGEN_CUDA_SDK_VER) && EIGEN_CUDA_SDK_VER >= 90000
+ return __hadd(::__half(a), ::__half(b));
+#else
+ return __hadd(a, b);
+#endif
+}
+EIGEN_STRONG_INLINE __device__ half operator * (const half& a, const half& b) {
+ return __hmul(a, b);
+}
+EIGEN_STRONG_INLINE __device__ half operator - (const half& a, const half& b) {
+ return __hsub(a, b);
+}
+EIGEN_STRONG_INLINE __device__ half operator / (const half& a, const half& b) {
+#if defined(EIGEN_CUDA_SDK_VER) && EIGEN_CUDA_SDK_VER >= 90000
+ return __hdiv(a, b);
+#else
+ float num = __half2float(a);
+ float denom = __half2float(b);
+ return __float2half(num / denom);
+#endif
+}
+EIGEN_STRONG_INLINE __device__ half operator - (const half& a) {
+ return __hneg(a);
+}
+EIGEN_STRONG_INLINE __device__ half& operator += (half& a, const half& b) {
+ a = a + b;
+ return a;
+}
+EIGEN_STRONG_INLINE __device__ half& operator *= (half& a, const half& b) {
+ a = a * b;
+ return a;
+}
+EIGEN_STRONG_INLINE __device__ half& operator -= (half& a, const half& b) {
+ a = a - b;
+ return a;
+}
+EIGEN_STRONG_INLINE __device__ half& operator /= (half& a, const half& b) {
+ a = a / b;
+ return a;
+}
+EIGEN_STRONG_INLINE __device__ bool operator == (const half& a, const half& b) {
+ return __heq(a, b);
+}
+EIGEN_STRONG_INLINE __device__ bool operator != (const half& a, const half& b) {
+ return __hne(a, b);
+}
+EIGEN_STRONG_INLINE __device__ bool operator < (const half& a, const half& b) {
+ return __hlt(a, b);
+}
+EIGEN_STRONG_INLINE __device__ bool operator <= (const half& a, const half& b) {
+ return __hle(a, b);
+}
+EIGEN_STRONG_INLINE __device__ bool operator > (const half& a, const half& b) {
+ return __hgt(a, b);
+}
+EIGEN_STRONG_INLINE __device__ bool operator >= (const half& a, const half& b) {
+ return __hge(a, b);
+}
+#endif
+
+#if defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator + (const half& a, const half& b) {
+ return half(vaddh_f16(a.x, b.x));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator * (const half& a, const half& b) {
+ return half(vmulh_f16(a.x, b.x));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator - (const half& a, const half& b) {
+ return half(vsubh_f16(a.x, b.x));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator / (const half& a, const half& b) {
+ return half(vdivh_f16(a.x, b.x));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator - (const half& a) {
+ return half(vnegh_f16(a.x));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half& operator += (half& a, const half& b) {
+ a = half(vaddh_f16(a.x, b.x));
+ return a;
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half& operator *= (half& a, const half& b) {
+ a = half(vmulh_f16(a.x, b.x));
+ return a;
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half& operator -= (half& a, const half& b) {
+ a = half(vsubh_f16(a.x, b.x));
+ return a;
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half& operator /= (half& a, const half& b) {
+ a = half(vdivh_f16(a.x, b.x));
+ return a;
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator == (const half& a, const half& b) {
+ return vceqh_f16(a.x, b.x);
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator != (const half& a, const half& b) {
+ return !vceqh_f16(a.x, b.x);
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator < (const half& a, const half& b) {
+ return vclth_f16(a.x, b.x);
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator <= (const half& a, const half& b) {
+ return vcleh_f16(a.x, b.x);
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator > (const half& a, const half& b) {
+ return vcgth_f16(a.x, b.x);
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator >= (const half& a, const half& b) {
+ return vcgeh_f16(a.x, b.x);
+}
+// We need to distinguish ‘clang as the CUDA compiler’ from ‘clang as the host compiler,
+// invoked by NVCC’ (e.g. on MacOS). The former needs to see both host and device implementation
+// of the functions, while the latter can only deal with one of them.
+#elif !defined(EIGEN_HAS_NATIVE_FP16) || (EIGEN_COMP_CLANG && !EIGEN_COMP_NVCC) // Emulate support for half floats
+
+#if EIGEN_COMP_CLANG && defined(EIGEN_CUDACC)
+// We need to provide emulated *host-side* FP16 operators for clang.
+#pragma push_macro("EIGEN_DEVICE_FUNC")
+#undef EIGEN_DEVICE_FUNC
+#if defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_HAS_NATIVE_FP16)
+#define EIGEN_DEVICE_FUNC __host__
+#else // both host and device need emulated ops.
+#define EIGEN_DEVICE_FUNC __host__ __device__
+#endif
+#endif
+
+// Definitions for CPUs and older HIP+CUDA, mostly working through conversion
+// to/from fp32.
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator + (const half& a, const half& b) {
+ return half(float(a) + float(b));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator * (const half& a, const half& b) {
+ return half(float(a) * float(b));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator - (const half& a, const half& b) {
+ return half(float(a) - float(b));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator / (const half& a, const half& b) {
+ return half(float(a) / float(b));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator - (const half& a) {
+ half result;
+ result.x = a.x ^ 0x8000;
+ return result;
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half& operator += (half& a, const half& b) {
+ a = half(float(a) + float(b));
+ return a;
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half& operator *= (half& a, const half& b) {
+ a = half(float(a) * float(b));
+ return a;
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half& operator -= (half& a, const half& b) {
+ a = half(float(a) - float(b));
+ return a;
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half& operator /= (half& a, const half& b) {
+ a = half(float(a) / float(b));
+ return a;
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator == (const half& a, const half& b) {
+ return numext::equal_strict(float(a),float(b));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator != (const half& a, const half& b) {
+ return numext::not_equal_strict(float(a), float(b));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator < (const half& a, const half& b) {
+ return float(a) < float(b);
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator <= (const half& a, const half& b) {
+ return float(a) <= float(b);
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator > (const half& a, const half& b) {
+ return float(a) > float(b);
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator >= (const half& a, const half& b) {
+ return float(a) >= float(b);
+}
+
+#if defined(__clang__) && defined(__CUDA__)
+#pragma pop_macro("EIGEN_DEVICE_FUNC")
+#endif
+#endif // Emulate support for half floats
+
+// Division by an index. Do it in full float precision to avoid accuracy
+// issues in converting the denominator to half.
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator / (const half& a, Index b) {
+ return half(static_cast<float>(a) / static_cast<float>(b));
+}
+
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator++(half& a) {
+ a += half(1);
+ return a;
+}
+
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator--(half& a) {
+ a -= half(1);
+ return a;
+}
+
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator++(half& a, int) {
+ half original_value = a;
+ ++a;
+ return original_value;
+}
+
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator--(half& a, int) {
+ half original_value = a;
+ --a;
+ return original_value;
+}
+
+// Conversion routines, including fallbacks for the host or older CUDA.
+// Note that newer Intel CPUs (Haswell or newer) have vectorized versions of
+// these in hardware. If we need more performance on older/other CPUs, they are
+// also possible to vectorize directly.
+
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __half_raw raw_uint16_to_half(numext::uint16_t x) {
+ // We cannot simply do a "return __half_raw(x)" here, because __half_raw is union type
+ // in the hip_fp16 header file, and that will trigger a compile error
+ // On the other hand, having anything but a return statement also triggers a compile error
+ // because this is constexpr function.
+ // Fortunately, since we need to disable EIGEN_CONSTEXPR for GPU anyway, we can get out
+ // of this catch22 by having separate bodies for GPU / non GPU
+#if defined(EIGEN_HAS_GPU_FP16)
+ __half_raw h;
+ h.x = x;
+ return h;
+#else
+ return __half_raw(x);
+#endif
+}
+
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC numext::uint16_t raw_half_as_uint16(const __half_raw& h) {
+ // HIP/CUDA/Default have a member 'x' of type uint16_t.
+ // For ARM64 native half, the member 'x' is of type __fp16, so we need to bit-cast.
+ // For SYCL, cl::sycl::half is _Float16, so cast directly.
+#if defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
+ return numext::bit_cast<numext::uint16_t>(h.x);
+#elif defined(SYCL_DEVICE_ONLY)
+ return numext::bit_cast<numext::uint16_t>(h);
+#else
+ return h.x;
+#endif
+}
+
+union float32_bits {
+ unsigned int u;
+ float f;
+};
+
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __half_raw float_to_half_rtne(float ff) {
+#if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 300) || \
+ (defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE))
+ __half tmp_ff = __float2half(ff);
+ return *(__half_raw*)&tmp_ff;
+
+#elif defined(EIGEN_HAS_FP16_C)
+ __half_raw h;
+ h.x = _cvtss_sh(ff, 0);
+ return h;
+
+#elif defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
+ __half_raw h;
+ h.x = static_cast<__fp16>(ff);
+ return h;
+
+#else
+ float32_bits f; f.f = ff;
+
+ const float32_bits f32infty = { 255 << 23 };
+ const float32_bits f16max = { (127 + 16) << 23 };
+ const float32_bits denorm_magic = { ((127 - 15) + (23 - 10) + 1) << 23 };
+ unsigned int sign_mask = 0x80000000u;
+ __half_raw o;
+ o.x = static_cast<numext::uint16_t>(0x0u);
+
+ unsigned int sign = f.u & sign_mask;
+ f.u ^= sign;
+
+ // NOTE all the integer compares in this function can be safely
+ // compiled into signed compares since all operands are below
+ // 0x80000000. Important if you want fast straight SSE2 code
+ // (since there's no unsigned PCMPGTD).
+
+ if (f.u >= f16max.u) { // result is Inf or NaN (all exponent bits set)
+ o.x = (f.u > f32infty.u) ? 0x7e00 : 0x7c00; // NaN->qNaN and Inf->Inf
+ } else { // (De)normalized number or zero
+ if (f.u < (113 << 23)) { // resulting FP16 is subnormal or zero
+ // use a magic value to align our 10 mantissa bits at the bottom of
+ // the float. as long as FP addition is round-to-nearest-even this
+ // just works.
+ f.f += denorm_magic.f;
+
+ // and one integer subtract of the bias later, we have our final float!
+ o.x = static_cast<numext::uint16_t>(f.u - denorm_magic.u);
+ } else {
+ unsigned int mant_odd = (f.u >> 13) & 1; // resulting mantissa is odd
+
+ // update exponent, rounding bias part 1
+ // Equivalent to `f.u += ((unsigned int)(15 - 127) << 23) + 0xfff`, but
+ // without arithmetic overflow.
+ f.u += 0xc8000fffU;
+ // rounding bias part 2
+ f.u += mant_odd;
+ // take the bits!
+ o.x = static_cast<numext::uint16_t>(f.u >> 13);
+ }
+ }
+
+ o.x |= static_cast<numext::uint16_t>(sign >> 16);
+ return o;
+#endif
+}
+
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC float half_to_float(__half_raw h) {
+#if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 300) || \
+ (defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE))
+ return __half2float(h);
+#elif defined(EIGEN_HAS_FP16_C)
+ return _cvtsh_ss(h.x);
+#elif defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
+ return static_cast<float>(h.x);
+#else
+ const float32_bits magic = { 113 << 23 };
+ const unsigned int shifted_exp = 0x7c00 << 13; // exponent mask after shift
+ float32_bits o;
+
+ o.u = (h.x & 0x7fff) << 13; // exponent/mantissa bits
+ unsigned int exp = shifted_exp & o.u; // just the exponent
+ o.u += (127 - 15) << 23; // exponent adjust
+
+ // handle exponent special cases
+ if (exp == shifted_exp) { // Inf/NaN?
+ o.u += (128 - 16) << 23; // extra exp adjust
+ } else if (exp == 0) { // Zero/Denormal?
+ o.u += 1 << 23; // extra exp adjust
+ o.f -= magic.f; // renormalize
+ }
+
+ o.u |= (h.x & 0x8000) << 16; // sign bit
+ return o.f;
+#endif
+}
+
+// --- standard functions ---
+
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool (isinf)(const half& a) {
+#ifdef EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC
+ return (numext::bit_cast<numext::uint16_t>(a.x) & 0x7fff) == 0x7c00;
+#else
+ return (a.x & 0x7fff) == 0x7c00;
+#endif
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool (isnan)(const half& a) {
+#if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 530) || \
+ (defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE))
+ return __hisnan(a);
+#elif defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
+ return (numext::bit_cast<numext::uint16_t>(a.x) & 0x7fff) > 0x7c00;
+#else
+ return (a.x & 0x7fff) > 0x7c00;
+#endif
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool (isfinite)(const half& a) {
+ return !(isinf EIGEN_NOT_A_MACRO (a)) && !(isnan EIGEN_NOT_A_MACRO (a));
+}
+
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half abs(const half& a) {
+#if defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
+ return half(vabsh_f16(a.x));
+#else
+ half result;
+ result.x = a.x & 0x7FFF;
+ return result;
+#endif
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half exp(const half& a) {
+#if (EIGEN_CUDA_SDK_VER >= 80000 && defined EIGEN_CUDA_ARCH && EIGEN_CUDA_ARCH >= 530) || \
+ defined(EIGEN_HIP_DEVICE_COMPILE)
+ return half(hexp(a));
+#else
+ return half(::expf(float(a)));
+#endif
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half expm1(const half& a) {
+ return half(numext::expm1(float(a)));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half log(const half& a) {
+#if (defined(EIGEN_HAS_CUDA_FP16) && EIGEN_CUDA_SDK_VER >= 80000 && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 530) || \
+ (defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE))
+ return half(::hlog(a));
+#else
+ return half(::logf(float(a)));
+#endif
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half log1p(const half& a) {
+ return half(numext::log1p(float(a)));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half log10(const half& a) {
+ return half(::log10f(float(a)));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half log2(const half& a) {
+ return half(static_cast<float>(EIGEN_LOG2E) * ::logf(float(a)));
+}
+
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half sqrt(const half& a) {
+#if (EIGEN_CUDA_SDK_VER >= 80000 && defined EIGEN_CUDA_ARCH && EIGEN_CUDA_ARCH >= 530) || \
+ defined(EIGEN_HIP_DEVICE_COMPILE)
+ return half(hsqrt(a));
+#else
+ return half(::sqrtf(float(a)));
+#endif
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half pow(const half& a, const half& b) {
+ return half(::powf(float(a), float(b)));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half sin(const half& a) {
+ return half(::sinf(float(a)));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half cos(const half& a) {
+ return half(::cosf(float(a)));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half tan(const half& a) {
+ return half(::tanf(float(a)));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half tanh(const half& a) {
+ return half(::tanhf(float(a)));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half asin(const half& a) {
+ return half(::asinf(float(a)));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half acos(const half& a) {
+ return half(::acosf(float(a)));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half floor(const half& a) {
+#if (EIGEN_CUDA_SDK_VER >= 80000 && defined EIGEN_CUDA_ARCH && EIGEN_CUDA_ARCH >= 300) || \
+ defined(EIGEN_HIP_DEVICE_COMPILE)
+ return half(hfloor(a));
+#else
+ return half(::floorf(float(a)));
+#endif
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half ceil(const half& a) {
+#if (EIGEN_CUDA_SDK_VER >= 80000 && defined EIGEN_CUDA_ARCH && EIGEN_CUDA_ARCH >= 300) || \
+ defined(EIGEN_HIP_DEVICE_COMPILE)
+ return half(hceil(a));
+#else
+ return half(::ceilf(float(a)));
+#endif
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half rint(const half& a) {
+ return half(::rintf(float(a)));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half round(const half& a) {
+ return half(::roundf(float(a)));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half fmod(const half& a, const half& b) {
+ return half(::fmodf(float(a), float(b)));
+}
+
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half (min)(const half& a, const half& b) {
+#if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 530) || \
+ (defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE))
+ return __hlt(b, a) ? b : a;
+#else
+ const float f1 = static_cast<float>(a);
+ const float f2 = static_cast<float>(b);
+ return f2 < f1 ? b : a;
+#endif
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half (max)(const half& a, const half& b) {
+#if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 530) || \
+ (defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE))
+ return __hlt(a, b) ? b : a;
+#else
+ const float f1 = static_cast<float>(a);
+ const float f2 = static_cast<float>(b);
+ return f1 < f2 ? b : a;
+#endif
+}
+
+#ifndef EIGEN_NO_IO
+EIGEN_ALWAYS_INLINE std::ostream& operator << (std::ostream& os, const half& v) {
+ os << static_cast<float>(v);
+ return os;
+}
+#endif
+
+} // end namespace half_impl
+
+// import Eigen::half_impl::half into Eigen namespace
+// using half_impl::half;
+
+namespace internal {
+
+template<>
+struct random_default_impl<half, false, false>
+{
+ static inline half run(const half& x, const half& y)
+ {
+ return x + (y-x) * half(float(std::rand()) / float(RAND_MAX));
+ }
+ static inline half run()
+ {
+ return run(half(-1.f), half(1.f));
+ }
+};
+
+template<> struct is_arithmetic<half> { enum { value = true }; };
+
+} // end namespace internal
+
+template<> struct NumTraits<Eigen::half>
+ : GenericNumTraits<Eigen::half>
+{
+ enum {
+ IsSigned = true,
+ IsInteger = false,
+ IsComplex = false,
+ RequireInitialization = false
+ };
+
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::half epsilon() {
+ return half_impl::raw_uint16_to_half(0x0800);
+ }
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::half dummy_precision() {
+ return half_impl::raw_uint16_to_half(0x211f); // Eigen::half(1e-2f);
+ }
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::half highest() {
+ return half_impl::raw_uint16_to_half(0x7bff);
+ }
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::half lowest() {
+ return half_impl::raw_uint16_to_half(0xfbff);
+ }
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::half infinity() {
+ return half_impl::raw_uint16_to_half(0x7c00);
+ }
+ EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::half quiet_NaN() {
+ return half_impl::raw_uint16_to_half(0x7e00);
+ }
+};
+
+} // end namespace Eigen
+
+#if defined(EIGEN_HAS_GPU_FP16) || defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
+ #pragma pop_macro("EIGEN_CONSTEXPR")
+#endif
+
+namespace Eigen {
+namespace numext {
+
+#if defined(EIGEN_GPU_COMPILE_PHASE)
+
+template <>
+EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool(isnan)(const Eigen::half& h) {
+ return (half_impl::isnan)(h);
+}
+
+template <>
+EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool(isinf)(const Eigen::half& h) {
+ return (half_impl::isinf)(h);
+}
+
+template <>
+EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool(isfinite)(const Eigen::half& h) {
+ return (half_impl::isfinite)(h);
+}
+
+#endif
+
+template <>
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half bit_cast<Eigen::half, uint16_t>(const uint16_t& src) {
+ return Eigen::half(Eigen::half_impl::raw_uint16_to_half(src));
+}
+
+template <>
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC uint16_t bit_cast<uint16_t, Eigen::half>(const Eigen::half& src) {
+ return Eigen::half_impl::raw_half_as_uint16(src);
+}
+
+} // namespace numext
+} // namespace Eigen
+
+// Add the missing shfl* intrinsics.
+// The __shfl* functions are only valid on HIP or _CUDA_ARCH_ >= 300.
+// CUDA defines them for (__CUDA_ARCH__ >= 300 || !defined(__CUDA_ARCH__))
+//
+// HIP and CUDA prior to SDK 9.0 define
+// __shfl, __shfl_up, __shfl_down, __shfl_xor for int and float
+// CUDA since 9.0 deprecates those and instead defines
+// __shfl_sync, __shfl_up_sync, __shfl_down_sync, __shfl_xor_sync,
+// with native support for __half and __nv_bfloat16
+//
+// Note that the following are __device__ - only functions.
+#if (defined(EIGEN_CUDACC) && (!defined(EIGEN_CUDA_ARCH) || EIGEN_CUDA_ARCH >= 300)) \
+ || defined(EIGEN_HIPCC)
+
+#if defined(EIGEN_HAS_CUDA_FP16) && EIGEN_CUDA_SDK_VER >= 90000
+
+__device__ EIGEN_STRONG_INLINE Eigen::half __shfl_sync(unsigned mask, Eigen::half var, int srcLane, int width=warpSize) {
+ const __half h = var;
+ return static_cast<Eigen::half>(__shfl_sync(mask, h, srcLane, width));
+}
+
+__device__ EIGEN_STRONG_INLINE Eigen::half __shfl_up_sync(unsigned mask, Eigen::half var, unsigned int delta, int width=warpSize) {
+ const __half h = var;
+ return static_cast<Eigen::half>(__shfl_up_sync(mask, h, delta, width));
+}
+
+__device__ EIGEN_STRONG_INLINE Eigen::half __shfl_down_sync(unsigned mask, Eigen::half var, unsigned int delta, int width=warpSize) {
+ const __half h = var;
+ return static_cast<Eigen::half>(__shfl_down_sync(mask, h, delta, width));
+}
+
+__device__ EIGEN_STRONG_INLINE Eigen::half __shfl_xor_sync(unsigned mask, Eigen::half var, int laneMask, int width=warpSize) {
+ const __half h = var;
+ return static_cast<Eigen::half>(__shfl_xor_sync(mask, h, laneMask, width));
+}
+
+#else // HIP or CUDA SDK < 9.0
+
+__device__ EIGEN_STRONG_INLINE Eigen::half __shfl(Eigen::half var, int srcLane, int width=warpSize) {
+ const int ivar = static_cast<int>(Eigen::numext::bit_cast<Eigen::numext::uint16_t>(var));
+ return Eigen::numext::bit_cast<Eigen::half>(static_cast<Eigen::numext::uint16_t>(__shfl(ivar, srcLane, width)));
+}
+
+__device__ EIGEN_STRONG_INLINE Eigen::half __shfl_up(Eigen::half var, unsigned int delta, int width=warpSize) {
+ const int ivar = static_cast<int>(Eigen::numext::bit_cast<Eigen::numext::uint16_t>(var));
+ return Eigen::numext::bit_cast<Eigen::half>(static_cast<Eigen::numext::uint16_t>(__shfl_up(ivar, delta, width)));
+}
+
+__device__ EIGEN_STRONG_INLINE Eigen::half __shfl_down(Eigen::half var, unsigned int delta, int width=warpSize) {
+ const int ivar = static_cast<int>(Eigen::numext::bit_cast<Eigen::numext::uint16_t>(var));
+ return Eigen::numext::bit_cast<Eigen::half>(static_cast<Eigen::numext::uint16_t>(__shfl_down(ivar, delta, width)));
+}
+
+__device__ EIGEN_STRONG_INLINE Eigen::half __shfl_xor(Eigen::half var, int laneMask, int width=warpSize) {
+ const int ivar = static_cast<int>(Eigen::numext::bit_cast<Eigen::numext::uint16_t>(var));
+ return Eigen::numext::bit_cast<Eigen::half>(static_cast<Eigen::numext::uint16_t>(__shfl_xor(ivar, laneMask, width)));
+}
+
+#endif // HIP vs CUDA
+#endif // __shfl*
+
+// ldg() has an overload for __half_raw, but we also need one for Eigen::half.
+#if (defined(EIGEN_CUDACC) && (!defined(EIGEN_CUDA_ARCH) || EIGEN_CUDA_ARCH >= 350)) \
+ || defined(EIGEN_HIPCC)
+EIGEN_STRONG_INLINE __device__ Eigen::half __ldg(const Eigen::half* ptr) {
+ return Eigen::half_impl::raw_uint16_to_half(__ldg(reinterpret_cast<const Eigen::numext::uint16_t*>(ptr)));
+}
+#endif // __ldg
+
+#if EIGEN_HAS_STD_HASH
+namespace std {
+template <>
+struct hash<Eigen::half> {
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::size_t operator()(const Eigen::half& a) const {
+ return static_cast<std::size_t>(Eigen::numext::bit_cast<Eigen::numext::uint16_t>(a));
+ }
+};
+} // end namespace std
+#endif
+
+#endif // EIGEN_HALF_H
diff --git a/Eigen/src/Core/arch/Default/Settings.h b/Eigen/src/Core/arch/Default/Settings.h
index 097373c84..a5c3ada4c 100644
--- a/Eigen/src/Core/arch/Default/Settings.h
+++ b/Eigen/src/Core/arch/Default/Settings.h
@@ -21,7 +21,7 @@
* it does not correspond to the number of iterations or the number of instructions
*/
#ifndef EIGEN_UNROLLING_LIMIT
-#define EIGEN_UNROLLING_LIMIT 100
+#define EIGEN_UNROLLING_LIMIT 110
#endif
/** Defines the threshold between a "small" and a "large" matrix.
diff --git a/Eigen/src/Core/arch/Default/TypeCasting.h b/Eigen/src/Core/arch/Default/TypeCasting.h
new file mode 100644
index 000000000..fb8183b78
--- /dev/null
+++ b/Eigen/src/Core/arch/Default/TypeCasting.h
@@ -0,0 +1,120 @@
+// This file is part of Eigen, a lightweight C++ template library
+// for linear algebra.
+//
+// Copyright (C) 2016 Benoit Steiner <benoit.steiner.goog@gmail.com>
+// Copyright (C) 2019 Rasmus Munk Larsen <rmlarsen@google.com>
+//
+// This Source Code Form is subject to the terms of the Mozilla
+// Public License v. 2.0. If a copy of the MPL was not distributed
+// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+#ifndef EIGEN_GENERIC_TYPE_CASTING_H
+#define EIGEN_GENERIC_TYPE_CASTING_H
+
+namespace Eigen {
+
+namespace internal {
+
+template<>
+struct scalar_cast_op<float, Eigen::half> {
+ EIGEN_EMPTY_STRUCT_CTOR(scalar_cast_op)
+ typedef Eigen::half result_type;
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::half operator() (const float& a) const {
+ #if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 300) || \
+ (defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE))
+ return __float2half(a);
+ #else
+ return Eigen::half(a);
+ #endif
+ }
+};
+
+template<>
+struct functor_traits<scalar_cast_op<float, Eigen::half> >
+{ enum { Cost = NumTraits<float>::AddCost, PacketAccess = false }; };
+
+
+template<>
+struct scalar_cast_op<int, Eigen::half> {
+ EIGEN_EMPTY_STRUCT_CTOR(scalar_cast_op)
+ typedef Eigen::half result_type;
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::half operator() (const int& a) const {
+ #if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 300) || \
+ (defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE))
+ return __float2half(static_cast<float>(a));
+ #else
+ return Eigen::half(static_cast<float>(a));
+ #endif
+ }
+};
+
+template<>
+struct functor_traits<scalar_cast_op<int, Eigen::half> >
+{ enum { Cost = NumTraits<float>::AddCost, PacketAccess = false }; };
+
+
+template<>
+struct scalar_cast_op<Eigen::half, float> {
+ EIGEN_EMPTY_STRUCT_CTOR(scalar_cast_op)
+ typedef float result_type;
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float operator() (const Eigen::half& a) const {
+ #if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 300) || \
+ (defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE))
+ return __half2float(a);
+ #else
+ return static_cast<float>(a);
+ #endif
+ }
+};
+
+template<>
+struct functor_traits<scalar_cast_op<Eigen::half, float> >
+{ enum { Cost = NumTraits<float>::AddCost, PacketAccess = false }; };
+
+
+template<>
+struct scalar_cast_op<float, Eigen::bfloat16> {
+ EIGEN_EMPTY_STRUCT_CTOR(scalar_cast_op)
+ typedef Eigen::bfloat16 result_type;
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::bfloat16 operator() (const float& a) const {
+ return Eigen::bfloat16(a);
+ }
+};
+
+template<>
+struct functor_traits<scalar_cast_op<float, Eigen::bfloat16> >
+{ enum { Cost = NumTraits<float>::AddCost, PacketAccess = false }; };
+
+
+template<>
+struct scalar_cast_op<int, Eigen::bfloat16> {
+ EIGEN_EMPTY_STRUCT_CTOR(scalar_cast_op)
+ typedef Eigen::bfloat16 result_type;
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::bfloat16 operator() (const int& a) const {
+ return Eigen::bfloat16(static_cast<float>(a));
+ }
+};
+
+template<>
+struct functor_traits<scalar_cast_op<int, Eigen::bfloat16> >
+{ enum { Cost = NumTraits<float>::AddCost, PacketAccess = false }; };
+
+
+template<>
+struct scalar_cast_op<Eigen::bfloat16, float> {
+ EIGEN_EMPTY_STRUCT_CTOR(scalar_cast_op)
+ typedef float result_type;
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float operator() (const Eigen::bfloat16& a) const {
+ return static_cast<float>(a);
+ }
+};
+
+template<>
+struct functor_traits<scalar_cast_op<Eigen::bfloat16, float> >
+{ enum { Cost = NumTraits<float>::AddCost, PacketAccess = false }; };
+
+
+}
+}
+
+#endif // EIGEN_GENERIC_TYPE_CASTING_H
diff --git a/Eigen/src/Core/arch/CUDA/MathFunctions.h b/Eigen/src/Core/arch/GPU/MathFunctions.h
index 0348b41db..d2b3a2568 100644
--- a/Eigen/src/Core/arch/CUDA/MathFunctions.h
+++ b/Eigen/src/Core/arch/GPU/MathFunctions.h
@@ -7,8 +7,8 @@
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
-#ifndef EIGEN_MATH_FUNCTIONS_CUDA_H
-#define EIGEN_MATH_FUNCTIONS_CUDA_H
+#ifndef EIGEN_MATH_FUNCTIONS_GPU_H
+#define EIGEN_MATH_FUNCTIONS_GPU_H
namespace Eigen {
@@ -17,7 +17,7 @@ namespace internal {
// Make sure this is only available when targeting a GPU: we don't want to
// introduce conflicts between these packet_traits definitions and the ones
// we'll use on the host side (SSE, AVX, ...)
-#if defined(__CUDACC__) && defined(EIGEN_USE_GPU)
+#if defined(EIGEN_GPUCC) && defined(EIGEN_USE_GPU)
template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
float4 plog<float4>(const float4& a)
{
@@ -57,6 +57,18 @@ double2 pexp<double2>(const double2& a)
}
template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+float4 pexpm1<float4>(const float4& a)
+{
+ return make_float4(expm1f(a.x), expm1f(a.y), expm1f(a.z), expm1f(a.w));
+}
+
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+double2 pexpm1<double2>(const double2& a)
+{
+ return make_double2(expm1(a.x), expm1(a.y));
+}
+
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
float4 psqrt<float4>(const float4& a)
{
return make_float4(sqrtf(a.x), sqrtf(a.y), sqrtf(a.z), sqrtf(a.w));
@@ -88,4 +100,4 @@ double2 prsqrt<double2>(const double2& a)
} // end namespace Eigen
-#endif // EIGEN_MATH_FUNCTIONS_CUDA_H
+#endif // EIGEN_MATH_FUNCTIONS_GPU_H
diff --git a/Eigen/src/Core/arch/GPU/PacketMath.h b/Eigen/src/Core/arch/GPU/PacketMath.h
new file mode 100644
index 000000000..689110ded
--- /dev/null
+++ b/Eigen/src/Core/arch/GPU/PacketMath.h
@@ -0,0 +1,1685 @@
+// This file is part of Eigen, a lightweight C++ template library
+// for linear algebra.
+//
+// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
+//
+// This Source Code Form is subject to the terms of the Mozilla
+// Public License v. 2.0. If a copy of the MPL was not distributed
+// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+#ifndef EIGEN_PACKET_MATH_GPU_H
+#define EIGEN_PACKET_MATH_GPU_H
+
+namespace Eigen {
+
+namespace internal {
+
+// Read-only data cached load available.
+#if defined(EIGEN_HIP_DEVICE_COMPILE) || (defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 350)
+#define EIGEN_GPU_HAS_LDG 1
+#endif
+
+// FP16 math available.
+#if (defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 530)
+#define EIGEN_CUDA_HAS_FP16_ARITHMETIC 1
+#endif
+
+#if defined(EIGEN_HIP_DEVICE_COMPILE) || defined(EIGEN_CUDA_HAS_FP16_ARITHMETIC)
+#define EIGEN_GPU_HAS_FP16_ARITHMETIC 1
+#endif
+
+// Make sure this is only available when targeting a GPU: we don't want to
+// introduce conflicts between these packet_traits definitions and the ones
+// we'll use on the host side (SSE, AVX, ...)
+#if defined(EIGEN_GPUCC) && defined(EIGEN_USE_GPU)
+
+template<> struct is_arithmetic<float4> { enum { value = true }; };
+template<> struct is_arithmetic<double2> { enum { value = true }; };
+
+template<> struct packet_traits<float> : default_packet_traits
+{
+ typedef float4 type;
+ typedef float4 half;
+ enum {
+ Vectorizable = 1,
+ AlignedOnScalar = 1,
+ size=4,
+ HasHalfPacket = 0,
+
+ HasDiv = 1,
+ HasSin = 0,
+ HasCos = 0,
+ HasLog = 1,
+ HasExp = 1,
+ HasSqrt = 1,
+ HasRsqrt = 1,
+ HasLGamma = 1,
+ HasDiGamma = 1,
+ HasZeta = 1,
+ HasPolygamma = 1,
+ HasErf = 1,
+ HasErfc = 1,
+ HasNdtri = 1,
+ HasBessel = 1,
+ HasIGamma = 1,
+ HasIGammaDerA = 1,
+ HasGammaSampleDerAlpha = 1,
+ HasIGammac = 1,
+ HasBetaInc = 1,
+
+ HasBlend = 0,
+ HasFloor = 1,
+ };
+};
+
+template<> struct packet_traits<double> : default_packet_traits
+{
+ typedef double2 type;
+ typedef double2 half;
+ enum {
+ Vectorizable = 1,
+ AlignedOnScalar = 1,
+ size=2,
+ HasHalfPacket = 0,
+
+ HasDiv = 1,
+ HasLog = 1,
+ HasExp = 1,
+ HasSqrt = 1,
+ HasRsqrt = 1,
+ HasLGamma = 1,
+ HasDiGamma = 1,
+ HasZeta = 1,
+ HasPolygamma = 1,
+ HasErf = 1,
+ HasErfc = 1,
+ HasNdtri = 1,
+ HasBessel = 1,
+ HasIGamma = 1,
+ HasIGammaDerA = 1,
+ HasGammaSampleDerAlpha = 1,
+ HasIGammac = 1,
+ HasBetaInc = 1,
+
+ HasBlend = 0,
+ HasFloor = 1,
+ };
+};
+
+
+template<> struct unpacket_traits<float4> { typedef float type; enum {size=4, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false}; typedef float4 half; };
+template<> struct unpacket_traits<double2> { typedef double type; enum {size=2, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false}; typedef double2 half; };
+
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 pset1<float4>(const float& from) {
+ return make_float4(from, from, from, from);
+}
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double2 pset1<double2>(const double& from) {
+ return make_double2(from, from);
+}
+
+// We need to distinguish ‘clang as the CUDA compiler’ from ‘clang as the host compiler,
+// invoked by NVCC’ (e.g. on MacOS). The former needs to see both host and device implementation
+// of the functions, while the latter can only deal with one of them.
+#if defined(EIGEN_CUDA_ARCH) || defined(EIGEN_HIPCC) || (defined(EIGEN_CUDACC) && EIGEN_COMP_CLANG && !EIGEN_COMP_NVCC)
+namespace {
+
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float bitwise_and(const float& a,
+ const float& b) {
+ return __int_as_float(__float_as_int(a) & __float_as_int(b));
+}
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double bitwise_and(const double& a,
+ const double& b) {
+ return __longlong_as_double(__double_as_longlong(a) &
+ __double_as_longlong(b));
+}
+
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float bitwise_or(const float& a,
+ const float& b) {
+ return __int_as_float(__float_as_int(a) | __float_as_int(b));
+}
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double bitwise_or(const double& a,
+ const double& b) {
+ return __longlong_as_double(__double_as_longlong(a) |
+ __double_as_longlong(b));
+}
+
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float bitwise_xor(const float& a,
+ const float& b) {
+ return __int_as_float(__float_as_int(a) ^ __float_as_int(b));
+}
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double bitwise_xor(const double& a,
+ const double& b) {
+ return __longlong_as_double(__double_as_longlong(a) ^
+ __double_as_longlong(b));
+}
+
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float bitwise_andnot(const float& a,
+ const float& b) {
+ return __int_as_float(__float_as_int(a) & ~__float_as_int(b));
+}
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double bitwise_andnot(const double& a,
+ const double& b) {
+ return __longlong_as_double(__double_as_longlong(a) &
+ ~__double_as_longlong(b));
+}
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float eq_mask(const float& a,
+ const float& b) {
+ return __int_as_float(a == b ? 0xffffffffu : 0u);
+}
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double eq_mask(const double& a,
+ const double& b) {
+ return __longlong_as_double(a == b ? 0xffffffffffffffffull : 0ull);
+}
+
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float lt_mask(const float& a,
+ const float& b) {
+ return __int_as_float(a < b ? 0xffffffffu : 0u);
+}
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double lt_mask(const double& a,
+ const double& b) {
+ return __longlong_as_double(a < b ? 0xffffffffffffffffull : 0ull);
+}
+
+} // namespace
+
+template <>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 pand<float4>(const float4& a,
+ const float4& b) {
+ return make_float4(bitwise_and(a.x, b.x), bitwise_and(a.y, b.y),
+ bitwise_and(a.z, b.z), bitwise_and(a.w, b.w));
+}
+template <>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double2 pand<double2>(const double2& a,
+ const double2& b) {
+ return make_double2(bitwise_and(a.x, b.x), bitwise_and(a.y, b.y));
+}
+
+template <>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 por<float4>(const float4& a,
+ const float4& b) {
+ return make_float4(bitwise_or(a.x, b.x), bitwise_or(a.y, b.y),
+ bitwise_or(a.z, b.z), bitwise_or(a.w, b.w));
+}
+template <>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double2 por<double2>(const double2& a,
+ const double2& b) {
+ return make_double2(bitwise_or(a.x, b.x), bitwise_or(a.y, b.y));
+}
+
+template <>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 pxor<float4>(const float4& a,
+ const float4& b) {
+ return make_float4(bitwise_xor(a.x, b.x), bitwise_xor(a.y, b.y),
+ bitwise_xor(a.z, b.z), bitwise_xor(a.w, b.w));
+}
+template <>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double2 pxor<double2>(const double2& a,
+ const double2& b) {
+ return make_double2(bitwise_xor(a.x, b.x), bitwise_xor(a.y, b.y));
+}
+
+template <>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 pandnot<float4>(const float4& a,
+ const float4& b) {
+ return make_float4(bitwise_andnot(a.x, b.x), bitwise_andnot(a.y, b.y),
+ bitwise_andnot(a.z, b.z), bitwise_andnot(a.w, b.w));
+}
+template <>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double2
+pandnot<double2>(const double2& a, const double2& b) {
+ return make_double2(bitwise_andnot(a.x, b.x), bitwise_andnot(a.y, b.y));
+}
+
+template <>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 pcmp_eq<float4>(const float4& a,
+ const float4& b) {
+ return make_float4(eq_mask(a.x, b.x), eq_mask(a.y, b.y), eq_mask(a.z, b.z),
+ eq_mask(a.w, b.w));
+}
+template <>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 pcmp_lt<float4>(const float4& a,
+ const float4& b) {
+ return make_float4(lt_mask(a.x, b.x), lt_mask(a.y, b.y), lt_mask(a.z, b.z),
+ lt_mask(a.w, b.w));
+}
+template <>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double2
+pcmp_eq<double2>(const double2& a, const double2& b) {
+ return make_double2(eq_mask(a.x, b.x), eq_mask(a.y, b.y));
+}
+template <>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double2
+pcmp_lt<double2>(const double2& a, const double2& b) {
+ return make_double2(lt_mask(a.x, b.x), lt_mask(a.y, b.y));
+}
+#endif // defined(EIGEN_CUDA_ARCH) || defined(EIGEN_HIPCC) || (defined(EIGEN_CUDACC) && EIGEN_COMP_CLANG && !EIGEN_COMP_NVCC)
+
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 plset<float4>(const float& a) {
+ return make_float4(a, a+1, a+2, a+3);
+}
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double2 plset<double2>(const double& a) {
+ return make_double2(a, a+1);
+}
+
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 padd<float4>(const float4& a, const float4& b) {
+ return make_float4(a.x+b.x, a.y+b.y, a.z+b.z, a.w+b.w);
+}
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double2 padd<double2>(const double2& a, const double2& b) {
+ return make_double2(a.x+b.x, a.y+b.y);
+}
+
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 psub<float4>(const float4& a, const float4& b) {
+ return make_float4(a.x-b.x, a.y-b.y, a.z-b.z, a.w-b.w);
+}
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double2 psub<double2>(const double2& a, const double2& b) {
+ return make_double2(a.x-b.x, a.y-b.y);
+}
+
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 pnegate(const float4& a) {
+ return make_float4(-a.x, -a.y, -a.z, -a.w);
+}
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double2 pnegate(const double2& a) {
+ return make_double2(-a.x, -a.y);
+}
+
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 pconj(const float4& a) { return a; }
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double2 pconj(const double2& a) { return a; }
+
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 pmul<float4>(const float4& a, const float4& b) {
+ return make_float4(a.x*b.x, a.y*b.y, a.z*b.z, a.w*b.w);
+}
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double2 pmul<double2>(const double2& a, const double2& b) {
+ return make_double2(a.x*b.x, a.y*b.y);
+}
+
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 pdiv<float4>(const float4& a, const float4& b) {
+ return make_float4(a.x/b.x, a.y/b.y, a.z/b.z, a.w/b.w);
+}
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double2 pdiv<double2>(const double2& a, const double2& b) {
+ return make_double2(a.x/b.x, a.y/b.y);
+}
+
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 pmin<float4>(const float4& a, const float4& b) {
+ return make_float4(fminf(a.x, b.x), fminf(a.y, b.y), fminf(a.z, b.z), fminf(a.w, b.w));
+}
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double2 pmin<double2>(const double2& a, const double2& b) {
+ return make_double2(fmin(a.x, b.x), fmin(a.y, b.y));
+}
+
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 pmax<float4>(const float4& a, const float4& b) {
+ return make_float4(fmaxf(a.x, b.x), fmaxf(a.y, b.y), fmaxf(a.z, b.z), fmaxf(a.w, b.w));
+}
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double2 pmax<double2>(const double2& a, const double2& b) {
+ return make_double2(fmax(a.x, b.x), fmax(a.y, b.y));
+}
+
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 pload<float4>(const float* from) {
+ return *reinterpret_cast<const float4*>(from);
+}
+
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double2 pload<double2>(const double* from) {
+ return *reinterpret_cast<const double2*>(from);
+}
+
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 ploadu<float4>(const float* from) {
+ return make_float4(from[0], from[1], from[2], from[3]);
+}
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double2 ploadu<double2>(const double* from) {
+ return make_double2(from[0], from[1]);
+}
+
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 ploaddup<float4>(const float* from) {
+ return make_float4(from[0], from[0], from[1], from[1]);
+}
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double2 ploaddup<double2>(const double* from) {
+ return make_double2(from[0], from[0]);
+}
+
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pstore<float>(float* to, const float4& from) {
+ *reinterpret_cast<float4*>(to) = from;
+}
+
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pstore<double>(double* to, const double2& from) {
+ *reinterpret_cast<double2*>(to) = from;
+}
+
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pstoreu<float>(float* to, const float4& from) {
+ to[0] = from.x;
+ to[1] = from.y;
+ to[2] = from.z;
+ to[3] = from.w;
+}
+
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pstoreu<double>(double* to, const double2& from) {
+ to[0] = from.x;
+ to[1] = from.y;
+}
+
+template<>
+EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE float4 ploadt_ro<float4, Aligned>(const float* from) {
+#if defined(EIGEN_GPU_HAS_LDG)
+ return __ldg((const float4*)from);
+#else
+ return make_float4(from[0], from[1], from[2], from[3]);
+#endif
+}
+template<>
+EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE double2 ploadt_ro<double2, Aligned>(const double* from) {
+#if defined(EIGEN_GPU_HAS_LDG)
+ return __ldg((const double2*)from);
+#else
+ return make_double2(from[0], from[1]);
+#endif
+}
+
+template<>
+EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE float4 ploadt_ro<float4, Unaligned>(const float* from) {
+#if defined(EIGEN_GPU_HAS_LDG)
+ return make_float4(__ldg(from+0), __ldg(from+1), __ldg(from+2), __ldg(from+3));
+#else
+ return make_float4(from[0], from[1], from[2], from[3]);
+#endif
+}
+template<>
+EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE double2 ploadt_ro<double2, Unaligned>(const double* from) {
+#if defined(EIGEN_GPU_HAS_LDG)
+ return make_double2(__ldg(from+0), __ldg(from+1));
+#else
+ return make_double2(from[0], from[1]);
+#endif
+}
+
+template<> EIGEN_DEVICE_FUNC inline float4 pgather<float, float4>(const float* from, Index stride) {
+ return make_float4(from[0*stride], from[1*stride], from[2*stride], from[3*stride]);
+}
+
+template<> EIGEN_DEVICE_FUNC inline double2 pgather<double, double2>(const double* from, Index stride) {
+ return make_double2(from[0*stride], from[1*stride]);
+}
+
+template<> EIGEN_DEVICE_FUNC inline void pscatter<float, float4>(float* to, const float4& from, Index stride) {
+ to[stride*0] = from.x;
+ to[stride*1] = from.y;
+ to[stride*2] = from.z;
+ to[stride*3] = from.w;
+}
+template<> EIGEN_DEVICE_FUNC inline void pscatter<double, double2>(double* to, const double2& from, Index stride) {
+ to[stride*0] = from.x;
+ to[stride*1] = from.y;
+}
+
+template<> EIGEN_DEVICE_FUNC inline float pfirst<float4>(const float4& a) {
+ return a.x;
+}
+template<> EIGEN_DEVICE_FUNC inline double pfirst<double2>(const double2& a) {
+ return a.x;
+}
+
+template<> EIGEN_DEVICE_FUNC inline float predux<float4>(const float4& a) {
+ return a.x + a.y + a.z + a.w;
+}
+template<> EIGEN_DEVICE_FUNC inline double predux<double2>(const double2& a) {
+ return a.x + a.y;
+}
+
+template<> EIGEN_DEVICE_FUNC inline float predux_max<float4>(const float4& a) {
+ return fmaxf(fmaxf(a.x, a.y), fmaxf(a.z, a.w));
+}
+template<> EIGEN_DEVICE_FUNC inline double predux_max<double2>(const double2& a) {
+ return fmax(a.x, a.y);
+}
+
+template<> EIGEN_DEVICE_FUNC inline float predux_min<float4>(const float4& a) {
+ return fminf(fminf(a.x, a.y), fminf(a.z, a.w));
+}
+template<> EIGEN_DEVICE_FUNC inline double predux_min<double2>(const double2& a) {
+ return fmin(a.x, a.y);
+}
+
+template<> EIGEN_DEVICE_FUNC inline float predux_mul<float4>(const float4& a) {
+ return a.x * a.y * a.z * a.w;
+}
+template<> EIGEN_DEVICE_FUNC inline double predux_mul<double2>(const double2& a) {
+ return a.x * a.y;
+}
+
+template<> EIGEN_DEVICE_FUNC inline float4 pabs<float4>(const float4& a) {
+ return make_float4(fabsf(a.x), fabsf(a.y), fabsf(a.z), fabsf(a.w));
+}
+template<> EIGEN_DEVICE_FUNC inline double2 pabs<double2>(const double2& a) {
+ return make_double2(fabs(a.x), fabs(a.y));
+}
+
+template<> EIGEN_DEVICE_FUNC inline float4 pfloor<float4>(const float4& a) {
+ return make_float4(floorf(a.x), floorf(a.y), floorf(a.z), floorf(a.w));
+}
+template<> EIGEN_DEVICE_FUNC inline double2 pfloor<double2>(const double2& a) {
+ return make_double2(floor(a.x), floor(a.y));
+}
+
+EIGEN_DEVICE_FUNC inline void
+ptranspose(PacketBlock<float4,4>& kernel) {
+ float tmp = kernel.packet[0].y;
+ kernel.packet[0].y = kernel.packet[1].x;
+ kernel.packet[1].x = tmp;
+
+ tmp = kernel.packet[0].z;
+ kernel.packet[0].z = kernel.packet[2].x;
+ kernel.packet[2].x = tmp;
+
+ tmp = kernel.packet[0].w;
+ kernel.packet[0].w = kernel.packet[3].x;
+ kernel.packet[3].x = tmp;
+
+ tmp = kernel.packet[1].z;
+ kernel.packet[1].z = kernel.packet[2].y;
+ kernel.packet[2].y = tmp;
+
+ tmp = kernel.packet[1].w;
+ kernel.packet[1].w = kernel.packet[3].y;
+ kernel.packet[3].y = tmp;
+
+ tmp = kernel.packet[2].w;
+ kernel.packet[2].w = kernel.packet[3].z;
+ kernel.packet[3].z = tmp;
+}
+
+EIGEN_DEVICE_FUNC inline void
+ptranspose(PacketBlock<double2,2>& kernel) {
+ double tmp = kernel.packet[0].y;
+ kernel.packet[0].y = kernel.packet[1].x;
+ kernel.packet[1].x = tmp;
+}
+
+#endif // defined(EIGEN_GPUCC) && defined(EIGEN_USE_GPU)
+
+// Packet4h2 must be defined in the macro without EIGEN_CUDA_ARCH, meaning
+// its corresponding packet_traits<Eigen::half> must be visible on host.
+#if defined(EIGEN_HAS_CUDA_FP16) || defined(EIGEN_HAS_HIP_FP16)
+
+typedef ulonglong2 Packet4h2;
+template<> struct unpacket_traits<Packet4h2> { typedef Eigen::half type; enum {size=8, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false}; typedef Packet4h2 half; };
+template<> struct is_arithmetic<Packet4h2> { enum { value = true }; };
+
+template<> struct unpacket_traits<half2> { typedef Eigen::half type; enum {size=2, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false}; typedef half2 half; };
+template<> struct is_arithmetic<half2> { enum { value = true }; };
+
+template<> struct packet_traits<Eigen::half> : default_packet_traits
+{
+ typedef Packet4h2 type;
+ typedef Packet4h2 half;
+ enum {
+ Vectorizable = 1,
+ AlignedOnScalar = 1,
+ size=8,
+ HasHalfPacket = 0,
+ HasAdd = 1,
+ HasSub = 1,
+ HasMul = 1,
+ HasDiv = 1,
+ HasSqrt = 1,
+ HasRsqrt = 1,
+ HasExp = 1,
+ HasExpm1 = 1,
+ HasLog = 1,
+ HasLog1p = 1
+ };
+};
+
+namespace {
+// This is equivalent to make_half2, which is undocumented and doesn't seem to always exist.
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 combine_half(const __half& a, const __half& b) {
+#if defined(EIGEN_GPU_COMPILE_PHASE)
+ return __halves2half2(a, b);
+#else
+ // Round-about way since __halves2half2 is a __device__ function.
+ return __floats2half2_rn(__half2float(a), __half2float(b));
+#endif
+}
+
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE __half get_half2_low(const half2& a) {
+#if defined(EIGEN_GPU_COMPILE_PHASE)
+ return __low2half(a);
+#else
+ return __float2half(__low2float(a));
+#endif
+}
+
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE __half get_half2_high(const half2& a) {
+#if defined(EIGEN_GPU_COMPILE_PHASE)
+ return __high2half(a);
+#else
+ return __float2half(__high2float(a));
+#endif
+}
+} // namespace
+
+template<>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pset1<half2>(const Eigen::half& from) {
+#if defined(EIGEN_GPU_COMPILE_PHASE)
+ return __half2half2(from);
+#else
+ const float f = __half2float(from);
+ return __floats2half2_rn(f, f);
+#endif
+}
+
+template <>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2
+pset1<Packet4h2>(const Eigen::half& from) {
+ Packet4h2 r;
+ half2* p_alias = reinterpret_cast<half2*>(&r);
+ p_alias[0] = pset1<half2>(from);
+ p_alias[1] = pset1<half2>(from);
+ p_alias[2] = pset1<half2>(from);
+ p_alias[3] = pset1<half2>(from);
+ return r;
+}
+
+// We now need this visible on both host and device.
+// #if defined(EIGEN_CUDA_ARCH) || defined(EIGEN_HIPCC) || (defined(EIGEN_CUDACC) && EIGEN_COMP_CLANG && !EIGEN_COMP_NVCC)
+namespace {
+
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pload(const Eigen::half* from) {
+ return *reinterpret_cast<const half2*>(from);
+}
+
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 ploadu(const Eigen::half* from) {
+ return combine_half(from[0], from[1]);
+}
+
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 ploaddup(const Eigen::half* from) {
+ return combine_half(from[0], from[0]);
+}
+
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pstore(Eigen::half* to,
+ const half2& from) {
+ *reinterpret_cast<half2*>(to) = from;
+}
+
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pstoreu(Eigen::half* to,
+ const half2& from) {
+ to[0] = get_half2_low(from);
+ to[1] = get_half2_high(from);
+}
+
+
+EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE half2 ploadt_ro_aligned(
+ const Eigen::half* from) {
+#if defined(EIGEN_GPU_HAS_LDG)
+ // Input is guaranteed to be properly aligned.
+ return __ldg(reinterpret_cast<const half2*>(from));
+#else
+ return combine_half(*(from+0), *(from+1));
+#endif
+}
+
+EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE half2 ploadt_ro_unaligned(
+ const Eigen::half* from) {
+#if defined(EIGEN_GPU_HAS_LDG)
+ return __halves2half2(__ldg(from+0), __ldg(from+1));
+#else
+ return combine_half(*(from+0), *(from+1));
+#endif
+}
+
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pgather(const Eigen::half* from,
+ Index stride) {
+ return combine_half(from[0*stride], from[1*stride]);
+}
+
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pscatter(
+ Eigen::half* to, const half2& from, Index stride) {
+ to[stride*0] = get_half2_low(from);
+ to[stride*1] = get_half2_high(from);
+}
+
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::half pfirst(const half2& a) {
+ return get_half2_low(a);
+}
+
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pabs(const half2& a) {
+ half a1 = get_half2_low(a);
+ half a2 = get_half2_high(a);
+ half result1 = half_impl::raw_uint16_to_half(a1.x & 0x7FFF);
+ half result2 = half_impl::raw_uint16_to_half(a2.x & 0x7FFF);
+ return combine_half(result1, result2);
+}
+
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 ptrue(const half2& /*a*/) {
+ half true_half = half_impl::raw_uint16_to_half(0xffffu);
+ return pset1<half2>(true_half);
+}
+
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pzero(const half2& /*a*/) {
+ half false_half = half_impl::raw_uint16_to_half(0x0000u);
+ return pset1<half2>(false_half);
+}
+
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void
+ptranspose(PacketBlock<half2,2>& kernel) {
+ __half a1 = get_half2_low(kernel.packet[0]);
+ __half a2 = get_half2_high(kernel.packet[0]);
+ __half b1 = get_half2_low(kernel.packet[1]);
+ __half b2 = get_half2_high(kernel.packet[1]);
+ kernel.packet[0] = combine_half(a1, b1);
+ kernel.packet[1] = combine_half(a2, b2);
+}
+
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 plset(const Eigen::half& a) {
+#if defined(EIGEN_GPU_HAS_FP16_ARITHMETIC)
+ return __halves2half2(a, __hadd(a, __float2half(1.0f)));
+#else
+ float f = __half2float(a) + 1.0f;
+ return combine_half(a, __float2half(f));
+#endif
+}
+
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pselect(const half2& mask,
+ const half2& a,
+ const half2& b) {
+ half mask_low = get_half2_low(mask);
+ half mask_high = get_half2_high(mask);
+ half result_low = mask_low == half(0) ? get_half2_low(b) : get_half2_low(a);
+ half result_high = mask_high == half(0) ? get_half2_high(b) : get_half2_high(a);
+ return combine_half(result_low, result_high);
+}
+
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pcmp_eq(const half2& a,
+ const half2& b) {
+ half true_half = half_impl::raw_uint16_to_half(0xffffu);
+ half false_half = half_impl::raw_uint16_to_half(0x0000u);
+ half a1 = get_half2_low(a);
+ half a2 = get_half2_high(a);
+ half b1 = get_half2_low(b);
+ half b2 = get_half2_high(b);
+ half eq1 = __half2float(a1) == __half2float(b1) ? true_half : false_half;
+ half eq2 = __half2float(a2) == __half2float(b2) ? true_half : false_half;
+ return combine_half(eq1, eq2);
+}
+
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pcmp_lt(const half2& a,
+ const half2& b) {
+ half true_half = half_impl::raw_uint16_to_half(0xffffu);
+ half false_half = half_impl::raw_uint16_to_half(0x0000u);
+ half a1 = get_half2_low(a);
+ half a2 = get_half2_high(a);
+ half b1 = get_half2_low(b);
+ half b2 = get_half2_high(b);
+ half eq1 = __half2float(a1) < __half2float(b1) ? true_half : false_half;
+ half eq2 = __half2float(a2) < __half2float(b2) ? true_half : false_half;
+ return combine_half(eq1, eq2);
+}
+
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pand(const half2& a,
+ const half2& b) {
+ half a1 = get_half2_low(a);
+ half a2 = get_half2_high(a);
+ half b1 = get_half2_low(b);
+ half b2 = get_half2_high(b);
+ half result1 = half_impl::raw_uint16_to_half(a1.x & b1.x);
+ half result2 = half_impl::raw_uint16_to_half(a2.x & b2.x);
+ return combine_half(result1, result2);
+}
+
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 por(const half2& a,
+ const half2& b) {
+ half a1 = get_half2_low(a);
+ half a2 = get_half2_high(a);
+ half b1 = get_half2_low(b);
+ half b2 = get_half2_high(b);
+ half result1 = half_impl::raw_uint16_to_half(a1.x | b1.x);
+ half result2 = half_impl::raw_uint16_to_half(a2.x | b2.x);
+ return combine_half(result1, result2);
+}
+
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pxor(const half2& a,
+ const half2& b) {
+ half a1 = get_half2_low(a);
+ half a2 = get_half2_high(a);
+ half b1 = get_half2_low(b);
+ half b2 = get_half2_high(b);
+ half result1 = half_impl::raw_uint16_to_half(a1.x ^ b1.x);
+ half result2 = half_impl::raw_uint16_to_half(a2.x ^ b2.x);
+ return combine_half(result1, result2);
+}
+
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pandnot(const half2& a,
+ const half2& b) {
+ half a1 = get_half2_low(a);
+ half a2 = get_half2_high(a);
+ half b1 = get_half2_low(b);
+ half b2 = get_half2_high(b);
+ half result1 = half_impl::raw_uint16_to_half(a1.x & ~b1.x);
+ half result2 = half_impl::raw_uint16_to_half(a2.x & ~b2.x);
+ return combine_half(result1, result2);
+}
+
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 padd(const half2& a,
+ const half2& b) {
+#if defined(EIGEN_GPU_HAS_FP16_ARITHMETIC)
+ return __hadd2(a, b);
+#else
+ float a1 = __low2float(a);
+ float a2 = __high2float(a);
+ float b1 = __low2float(b);
+ float b2 = __high2float(b);
+ float r1 = a1 + b1;
+ float r2 = a2 + b2;
+ return __floats2half2_rn(r1, r2);
+#endif
+}
+
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 psub(const half2& a,
+ const half2& b) {
+#if defined(EIGEN_GPU_HAS_FP16_ARITHMETIC)
+ return __hsub2(a, b);
+#else
+ float a1 = __low2float(a);
+ float a2 = __high2float(a);
+ float b1 = __low2float(b);
+ float b2 = __high2float(b);
+ float r1 = a1 - b1;
+ float r2 = a2 - b2;
+ return __floats2half2_rn(r1, r2);
+#endif
+}
+
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pnegate(const half2& a) {
+#if defined(EIGEN_GPU_HAS_FP16_ARITHMETIC)
+ return __hneg2(a);
+#else
+ float a1 = __low2float(a);
+ float a2 = __high2float(a);
+ return __floats2half2_rn(-a1, -a2);
+#endif
+}
+
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pconj(const half2& a) { return a; }
+
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pmul(const half2& a,
+ const half2& b) {
+#if defined(EIGEN_GPU_HAS_FP16_ARITHMETIC)
+ return __hmul2(a, b);
+#else
+ float a1 = __low2float(a);
+ float a2 = __high2float(a);
+ float b1 = __low2float(b);
+ float b2 = __high2float(b);
+ float r1 = a1 * b1;
+ float r2 = a2 * b2;
+ return __floats2half2_rn(r1, r2);
+#endif
+}
+
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pmadd(const half2& a,
+ const half2& b,
+ const half2& c) {
+#if defined(EIGEN_GPU_HAS_FP16_ARITHMETIC)
+ return __hfma2(a, b, c);
+#else
+ float a1 = __low2float(a);
+ float a2 = __high2float(a);
+ float b1 = __low2float(b);
+ float b2 = __high2float(b);
+ float c1 = __low2float(c);
+ float c2 = __high2float(c);
+ float r1 = a1 * b1 + c1;
+ float r2 = a2 * b2 + c2;
+ return __floats2half2_rn(r1, r2);
+#endif
+}
+
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pdiv(const half2& a,
+ const half2& b) {
+#if defined(EIGEN_GPU_HAS_FP16_ARITHMETIC)
+ return __h2div(a, b);
+#else
+ float a1 = __low2float(a);
+ float a2 = __high2float(a);
+ float b1 = __low2float(b);
+ float b2 = __high2float(b);
+ float r1 = a1 / b1;
+ float r2 = a2 / b2;
+ return __floats2half2_rn(r1, r2);
+#endif
+}
+
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pmin(const half2& a,
+ const half2& b) {
+ float a1 = __low2float(a);
+ float a2 = __high2float(a);
+ float b1 = __low2float(b);
+ float b2 = __high2float(b);
+ __half r1 = a1 < b1 ? get_half2_low(a) : get_half2_low(b);
+ __half r2 = a2 < b2 ? get_half2_high(a) : get_half2_high(b);
+ return combine_half(r1, r2);
+}
+
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pmax(const half2& a,
+ const half2& b) {
+ float a1 = __low2float(a);
+ float a2 = __high2float(a);
+ float b1 = __low2float(b);
+ float b2 = __high2float(b);
+ __half r1 = a1 > b1 ? get_half2_low(a) : get_half2_low(b);
+ __half r2 = a2 > b2 ? get_half2_high(a) : get_half2_high(b);
+ return combine_half(r1, r2);
+}
+
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::half predux(const half2& a) {
+#if defined(EIGEN_GPU_HAS_FP16_ARITHMETIC)
+ return __hadd(__low2half(a), __high2half(a));
+#else
+ float a1 = __low2float(a);
+ float a2 = __high2float(a);
+ return Eigen::half(__float2half(a1 + a2));
+#endif
+}
+
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::half predux_max(const half2& a) {
+#if defined(EIGEN_GPU_HAS_FP16_ARITHMETIC)
+ __half first = __low2half(a);
+ __half second = __high2half(a);
+ return __hgt(first, second) ? first : second;
+#else
+ float a1 = __low2float(a);
+ float a2 = __high2float(a);
+ return a1 > a2 ? get_half2_low(a) : get_half2_high(a);
+#endif
+}
+
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::half predux_min(const half2& a) {
+#if defined(EIGEN_GPU_HAS_FP16_ARITHMETIC)
+ __half first = __low2half(a);
+ __half second = __high2half(a);
+ return __hlt(first, second) ? first : second;
+#else
+ float a1 = __low2float(a);
+ float a2 = __high2float(a);
+ return a1 < a2 ? get_half2_low(a) : get_half2_high(a);
+#endif
+}
+
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::half predux_mul(const half2& a) {
+#if defined(EIGEN_GPU_HAS_FP16_ARITHMETIC)
+ return __hmul(__low2half(a), __high2half(a));
+#else
+ float a1 = __low2float(a);
+ float a2 = __high2float(a);
+ return Eigen::half(__float2half(a1 * a2));
+#endif
+}
+
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 plog1p(const half2& a) {
+ float a1 = __low2float(a);
+ float a2 = __high2float(a);
+ float r1 = log1pf(a1);
+ float r2 = log1pf(a2);
+ return __floats2half2_rn(r1, r2);
+}
+
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pexpm1(const half2& a) {
+ float a1 = __low2float(a);
+ float a2 = __high2float(a);
+ float r1 = expm1f(a1);
+ float r2 = expm1f(a2);
+ return __floats2half2_rn(r1, r2);
+}
+
+#if (EIGEN_CUDA_SDK_VER >= 80000 && defined(EIGEN_CUDA_HAS_FP16_ARITHMETIC)) || \
+ defined(EIGEN_HIP_DEVICE_COMPILE)
+
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+half2 plog(const half2& a) {
+ return h2log(a);
+}
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+half2 pexp(const half2& a) {
+ return h2exp(a);
+}
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+half2 psqrt(const half2& a) {
+ return h2sqrt(a);
+}
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+half2 prsqrt(const half2& a) {
+ return h2rsqrt(a);
+}
+
+#else
+
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 plog(const half2& a) {
+ float a1 = __low2float(a);
+ float a2 = __high2float(a);
+ float r1 = logf(a1);
+ float r2 = logf(a2);
+ return __floats2half2_rn(r1, r2);
+}
+
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pexp(const half2& a) {
+ float a1 = __low2float(a);
+ float a2 = __high2float(a);
+ float r1 = expf(a1);
+ float r2 = expf(a2);
+ return __floats2half2_rn(r1, r2);
+}
+
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 psqrt(const half2& a) {
+ float a1 = __low2float(a);
+ float a2 = __high2float(a);
+ float r1 = sqrtf(a1);
+ float r2 = sqrtf(a2);
+ return __floats2half2_rn(r1, r2);
+}
+
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 prsqrt(const half2& a) {
+ float a1 = __low2float(a);
+ float a2 = __high2float(a);
+ float r1 = rsqrtf(a1);
+ float r2 = rsqrtf(a2);
+ return __floats2half2_rn(r1, r2);
+}
+#endif
+} // namespace
+
+template <>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2
+pload<Packet4h2>(const Eigen::half* from) {
+ return *reinterpret_cast<const Packet4h2*>(from);
+}
+
+// unaligned load;
+template <>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2
+ploadu<Packet4h2>(const Eigen::half* from) {
+ Packet4h2 r;
+ half2* p_alias = reinterpret_cast<half2*>(&r);
+ p_alias[0] = ploadu(from + 0);
+ p_alias[1] = ploadu(from + 2);
+ p_alias[2] = ploadu(from + 4);
+ p_alias[3] = ploadu(from + 6);
+ return r;
+}
+
+template <>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2
+ploaddup<Packet4h2>(const Eigen::half* from) {
+ Packet4h2 r;
+ half2* p_alias = reinterpret_cast<half2*>(&r);
+ p_alias[0] = ploaddup(from + 0);
+ p_alias[1] = ploaddup(from + 1);
+ p_alias[2] = ploaddup(from + 2);
+ p_alias[3] = ploaddup(from + 3);
+ return r;
+}
+
+template <>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pstore<Eigen::half>(
+ Eigen::half* to, const Packet4h2& from) {
+ *reinterpret_cast<Packet4h2*>(to) = from;
+}
+
+template <>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pstoreu<Eigen::half>(
+ Eigen::half* to, const Packet4h2& from) {
+ const half2* from_alias = reinterpret_cast<const half2*>(&from);
+ pstoreu(to + 0,from_alias[0]);
+ pstoreu(to + 2,from_alias[1]);
+ pstoreu(to + 4,from_alias[2]);
+ pstoreu(to + 6,from_alias[3]);
+}
+
+template <>
+EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet4h2
+ploadt_ro<Packet4h2, Aligned>(const Eigen::half* from) {
+#if defined(EIGEN_GPU_HAS_LDG)
+ Packet4h2 r;
+ r = __ldg(reinterpret_cast<const Packet4h2*>(from));
+ return r;
+#else
+ Packet4h2 r;
+ half2* r_alias = reinterpret_cast<half2*>(&r);
+ r_alias[0] = ploadt_ro_aligned(from + 0);
+ r_alias[1] = ploadt_ro_aligned(from + 2);
+ r_alias[2] = ploadt_ro_aligned(from + 4);
+ r_alias[3] = ploadt_ro_aligned(from + 6);
+ return r;
+#endif
+}
+
+template <>
+EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet4h2
+ploadt_ro<Packet4h2, Unaligned>(const Eigen::half* from) {
+ Packet4h2 r;
+ half2* r_alias = reinterpret_cast<half2*>(&r);
+ r_alias[0] = ploadt_ro_unaligned(from + 0);
+ r_alias[1] = ploadt_ro_unaligned(from + 2);
+ r_alias[2] = ploadt_ro_unaligned(from + 4);
+ r_alias[3] = ploadt_ro_unaligned(from + 6);
+ return r;
+}
+
+template <>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2
+pgather<Eigen::half, Packet4h2>(const Eigen::half* from, Index stride) {
+ Packet4h2 r;
+ half2* p_alias = reinterpret_cast<half2*>(&r);
+ p_alias[0] = combine_half(from[0 * stride], from[1 * stride]);
+ p_alias[1] = combine_half(from[2 * stride], from[3 * stride]);
+ p_alias[2] = combine_half(from[4 * stride], from[5 * stride]);
+ p_alias[3] = combine_half(from[6 * stride], from[7 * stride]);
+ return r;
+}
+
+template <>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pscatter<Eigen::half, Packet4h2>(
+ Eigen::half* to, const Packet4h2& from, Index stride) {
+ const half2* from_alias = reinterpret_cast<const half2*>(&from);
+ pscatter(to + stride * 0, from_alias[0], stride);
+ pscatter(to + stride * 2, from_alias[1], stride);
+ pscatter(to + stride * 4, from_alias[2], stride);
+ pscatter(to + stride * 6, from_alias[3], stride);
+}
+
+template <>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::half pfirst<Packet4h2>(
+ const Packet4h2& a) {
+ return pfirst(*(reinterpret_cast<const half2*>(&a)));
+}
+
+template <>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 pabs<Packet4h2>(
+ const Packet4h2& a) {
+ Packet4h2 r;
+ half2* p_alias = reinterpret_cast<half2*>(&r);
+ const half2* a_alias = reinterpret_cast<const half2*>(&a);
+ p_alias[0] = pabs(a_alias[0]);
+ p_alias[1] = pabs(a_alias[1]);
+ p_alias[2] = pabs(a_alias[2]);
+ p_alias[3] = pabs(a_alias[3]);
+ return r;
+}
+
+template <>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 ptrue<Packet4h2>(
+ const Packet4h2& /*a*/) {
+ half true_half = half_impl::raw_uint16_to_half(0xffffu);
+ return pset1<Packet4h2>(true_half);
+}
+
+template <>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 pzero<Packet4h2>(const Packet4h2& /*a*/) {
+ half false_half = half_impl::raw_uint16_to_half(0x0000u);
+ return pset1<Packet4h2>(false_half);
+}
+
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose_double(
+ double* d_row0, double* d_row1, double* d_row2, double* d_row3,
+ double* d_row4, double* d_row5, double* d_row6, double* d_row7) {
+ double d_tmp;
+ d_tmp = d_row0[1];
+ d_row0[1] = d_row4[0];
+ d_row4[0] = d_tmp;
+
+ d_tmp = d_row1[1];
+ d_row1[1] = d_row5[0];
+ d_row5[0] = d_tmp;
+
+ d_tmp = d_row2[1];
+ d_row2[1] = d_row6[0];
+ d_row6[0] = d_tmp;
+
+ d_tmp = d_row3[1];
+ d_row3[1] = d_row7[0];
+ d_row7[0] = d_tmp;
+}
+
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose_half2(
+ half2* f_row0, half2* f_row1, half2* f_row2, half2* f_row3) {
+ half2 f_tmp;
+ f_tmp = f_row0[1];
+ f_row0[1] = f_row2[0];
+ f_row2[0] = f_tmp;
+
+ f_tmp = f_row1[1];
+ f_row1[1] = f_row3[0];
+ f_row3[0] = f_tmp;
+}
+
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void
+ptranspose_half(half2& f0, half2& f1) {
+ __half a1 = get_half2_low(f0);
+ __half a2 = get_half2_high(f0);
+ __half b1 = get_half2_low(f1);
+ __half b2 = get_half2_high(f1);
+ f0 = combine_half(a1, b1);
+ f1 = combine_half(a2, b2);
+}
+
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void
+ptranspose(PacketBlock<Packet4h2,8>& kernel) {
+ double* d_row0 = reinterpret_cast<double*>(&kernel.packet[0]);
+ double* d_row1 = reinterpret_cast<double*>(&kernel.packet[1]);
+ double* d_row2 = reinterpret_cast<double*>(&kernel.packet[2]);
+ double* d_row3 = reinterpret_cast<double*>(&kernel.packet[3]);
+ double* d_row4 = reinterpret_cast<double*>(&kernel.packet[4]);
+ double* d_row5 = reinterpret_cast<double*>(&kernel.packet[5]);
+ double* d_row6 = reinterpret_cast<double*>(&kernel.packet[6]);
+ double* d_row7 = reinterpret_cast<double*>(&kernel.packet[7]);
+ ptranspose_double(d_row0, d_row1, d_row2, d_row3,
+ d_row4, d_row5, d_row6, d_row7);
+
+
+ half2* f_row0 = reinterpret_cast<half2*>(d_row0);
+ half2* f_row1 = reinterpret_cast<half2*>(d_row1);
+ half2* f_row2 = reinterpret_cast<half2*>(d_row2);
+ half2* f_row3 = reinterpret_cast<half2*>(d_row3);
+ ptranspose_half2(f_row0, f_row1, f_row2, f_row3);
+ ptranspose_half(f_row0[0], f_row1[0]);
+ ptranspose_half(f_row0[1], f_row1[1]);
+ ptranspose_half(f_row2[0], f_row3[0]);
+ ptranspose_half(f_row2[1], f_row3[1]);
+
+ f_row0 = reinterpret_cast<half2*>(d_row0 + 1);
+ f_row1 = reinterpret_cast<half2*>(d_row1 + 1);
+ f_row2 = reinterpret_cast<half2*>(d_row2 + 1);
+ f_row3 = reinterpret_cast<half2*>(d_row3 + 1);
+ ptranspose_half2(f_row0, f_row1, f_row2, f_row3);
+ ptranspose_half(f_row0[0], f_row1[0]);
+ ptranspose_half(f_row0[1], f_row1[1]);
+ ptranspose_half(f_row2[0], f_row3[0]);
+ ptranspose_half(f_row2[1], f_row3[1]);
+
+ f_row0 = reinterpret_cast<half2*>(d_row4);
+ f_row1 = reinterpret_cast<half2*>(d_row5);
+ f_row2 = reinterpret_cast<half2*>(d_row6);
+ f_row3 = reinterpret_cast<half2*>(d_row7);
+ ptranspose_half2(f_row0, f_row1, f_row2, f_row3);
+ ptranspose_half(f_row0[0], f_row1[0]);
+ ptranspose_half(f_row0[1], f_row1[1]);
+ ptranspose_half(f_row2[0], f_row3[0]);
+ ptranspose_half(f_row2[1], f_row3[1]);
+
+ f_row0 = reinterpret_cast<half2*>(d_row4 + 1);
+ f_row1 = reinterpret_cast<half2*>(d_row5 + 1);
+ f_row2 = reinterpret_cast<half2*>(d_row6 + 1);
+ f_row3 = reinterpret_cast<half2*>(d_row7 + 1);
+ ptranspose_half2(f_row0, f_row1, f_row2, f_row3);
+ ptranspose_half(f_row0[0], f_row1[0]);
+ ptranspose_half(f_row0[1], f_row1[1]);
+ ptranspose_half(f_row2[0], f_row3[0]);
+ ptranspose_half(f_row2[1], f_row3[1]);
+
+}
+
+template <>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2
+plset<Packet4h2>(const Eigen::half& a) {
+#if defined(EIGEN_HIP_DEVICE_COMPILE)
+
+ Packet4h2 r;
+ half2* p_alias = reinterpret_cast<half2*>(&r);
+ p_alias[0] = __halves2half2(a, __hadd(a, __float2half(1.0f)));
+ p_alias[1] = __halves2half2(__hadd(a, __float2half(2.0f)),
+ __hadd(a, __float2half(3.0f)));
+ p_alias[2] = __halves2half2(__hadd(a, __float2half(4.0f)),
+ __hadd(a, __float2half(5.0f)));
+ p_alias[3] = __halves2half2(__hadd(a, __float2half(6.0f)),
+ __hadd(a, __float2half(7.0f)));
+ return r;
+#elif defined(EIGEN_CUDA_HAS_FP16_ARITHMETIC)
+ Packet4h2 r;
+ half2* r_alias = reinterpret_cast<half2*>(&r);
+
+ half2 b = pset1<half2>(a);
+ half2 c;
+ half2 half_offset0 = __halves2half2(__float2half(0.0f),__float2half(2.0f));
+ half2 half_offset1 = __halves2half2(__float2half(4.0f),__float2half(6.0f));
+
+ c = __hadd2(b, half_offset0);
+ r_alias[0] = plset(__low2half(c));
+ r_alias[1] = plset(__high2half(c));
+
+ c = __hadd2(b, half_offset1);
+ r_alias[2] = plset(__low2half(c));
+ r_alias[3] = plset(__high2half(c));
+
+ return r;
+
+#else
+ float f = __half2float(a);
+ Packet4h2 r;
+ half2* p_alias = reinterpret_cast<half2*>(&r);
+ p_alias[0] = combine_half(a, __float2half(f + 1.0f));
+ p_alias[1] = combine_half(__float2half(f + 2.0f), __float2half(f + 3.0f));
+ p_alias[2] = combine_half(__float2half(f + 4.0f), __float2half(f + 5.0f));
+ p_alias[3] = combine_half(__float2half(f + 6.0f), __float2half(f + 7.0f));
+ return r;
+#endif
+}
+
+template <>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2
+pselect<Packet4h2>(const Packet4h2& mask, const Packet4h2& a,
+ const Packet4h2& b) {
+ Packet4h2 r;
+ half2* r_alias = reinterpret_cast<half2*>(&r);
+ const half2* mask_alias = reinterpret_cast<const half2*>(&mask);
+ const half2* a_alias = reinterpret_cast<const half2*>(&a);
+ const half2* b_alias = reinterpret_cast<const half2*>(&b);
+ r_alias[0] = pselect(mask_alias[0], a_alias[0], b_alias[0]);
+ r_alias[1] = pselect(mask_alias[1], a_alias[1], b_alias[1]);
+ r_alias[2] = pselect(mask_alias[2], a_alias[2], b_alias[2]);
+ r_alias[3] = pselect(mask_alias[3], a_alias[3], b_alias[3]);
+ return r;
+}
+
+template <>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2
+pcmp_eq<Packet4h2>(const Packet4h2& a, const Packet4h2& b) {
+ Packet4h2 r;
+ half2* r_alias = reinterpret_cast<half2*>(&r);
+ const half2* a_alias = reinterpret_cast<const half2*>(&a);
+ const half2* b_alias = reinterpret_cast<const half2*>(&b);
+ r_alias[0] = pcmp_eq(a_alias[0], b_alias[0]);
+ r_alias[1] = pcmp_eq(a_alias[1], b_alias[1]);
+ r_alias[2] = pcmp_eq(a_alias[2], b_alias[2]);
+ r_alias[3] = pcmp_eq(a_alias[3], b_alias[3]);
+ return r;
+}
+
+template <>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 pand<Packet4h2>(
+ const Packet4h2& a, const Packet4h2& b) {
+ Packet4h2 r;
+ half2* r_alias = reinterpret_cast<half2*>(&r);
+ const half2* a_alias = reinterpret_cast<const half2*>(&a);
+ const half2* b_alias = reinterpret_cast<const half2*>(&b);
+ r_alias[0] = pand(a_alias[0], b_alias[0]);
+ r_alias[1] = pand(a_alias[1], b_alias[1]);
+ r_alias[2] = pand(a_alias[2], b_alias[2]);
+ r_alias[3] = pand(a_alias[3], b_alias[3]);
+ return r;
+}
+
+template <>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 por<Packet4h2>(
+ const Packet4h2& a, const Packet4h2& b) {
+ Packet4h2 r;
+ half2* r_alias = reinterpret_cast<half2*>(&r);
+ const half2* a_alias = reinterpret_cast<const half2*>(&a);
+ const half2* b_alias = reinterpret_cast<const half2*>(&b);
+ r_alias[0] = por(a_alias[0], b_alias[0]);
+ r_alias[1] = por(a_alias[1], b_alias[1]);
+ r_alias[2] = por(a_alias[2], b_alias[2]);
+ r_alias[3] = por(a_alias[3], b_alias[3]);
+ return r;
+}
+
+template <>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 pxor<Packet4h2>(
+ const Packet4h2& a, const Packet4h2& b) {
+ Packet4h2 r;
+ half2* r_alias = reinterpret_cast<half2*>(&r);
+ const half2* a_alias = reinterpret_cast<const half2*>(&a);
+ const half2* b_alias = reinterpret_cast<const half2*>(&b);
+ r_alias[0] = pxor(a_alias[0], b_alias[0]);
+ r_alias[1] = pxor(a_alias[1], b_alias[1]);
+ r_alias[2] = pxor(a_alias[2], b_alias[2]);
+ r_alias[3] = pxor(a_alias[3], b_alias[3]);
+ return r;
+}
+
+template <>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2
+pandnot<Packet4h2>(const Packet4h2& a, const Packet4h2& b) {
+ Packet4h2 r;
+ half2* r_alias = reinterpret_cast<half2*>(&r);
+ const half2* a_alias = reinterpret_cast<const half2*>(&a);
+ const half2* b_alias = reinterpret_cast<const half2*>(&b);
+ r_alias[0] = pandnot(a_alias[0], b_alias[0]);
+ r_alias[1] = pandnot(a_alias[1], b_alias[1]);
+ r_alias[2] = pandnot(a_alias[2], b_alias[2]);
+ r_alias[3] = pandnot(a_alias[3], b_alias[3]);
+ return r;
+}
+
+template <>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 padd<Packet4h2>(
+ const Packet4h2& a, const Packet4h2& b) {
+ Packet4h2 r;
+ half2* r_alias = reinterpret_cast<half2*>(&r);
+ const half2* a_alias = reinterpret_cast<const half2*>(&a);
+ const half2* b_alias = reinterpret_cast<const half2*>(&b);
+ r_alias[0] = padd(a_alias[0], b_alias[0]);
+ r_alias[1] = padd(a_alias[1], b_alias[1]);
+ r_alias[2] = padd(a_alias[2], b_alias[2]);
+ r_alias[3] = padd(a_alias[3], b_alias[3]);
+ return r;
+}
+
+template <>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 psub<Packet4h2>(
+ const Packet4h2& a, const Packet4h2& b) {
+ Packet4h2 r;
+ half2* r_alias = reinterpret_cast<half2*>(&r);
+ const half2* a_alias = reinterpret_cast<const half2*>(&a);
+ const half2* b_alias = reinterpret_cast<const half2*>(&b);
+ r_alias[0] = psub(a_alias[0], b_alias[0]);
+ r_alias[1] = psub(a_alias[1], b_alias[1]);
+ r_alias[2] = psub(a_alias[2], b_alias[2]);
+ r_alias[3] = psub(a_alias[3], b_alias[3]);
+ return r;
+}
+
+template <>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 pnegate(const Packet4h2& a) {
+ Packet4h2 r;
+ half2* r_alias = reinterpret_cast<half2*>(&r);
+ const half2* a_alias = reinterpret_cast<const half2*>(&a);
+ r_alias[0] = pnegate(a_alias[0]);
+ r_alias[1] = pnegate(a_alias[1]);
+ r_alias[2] = pnegate(a_alias[2]);
+ r_alias[3] = pnegate(a_alias[3]);
+ return r;
+}
+
+template <>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 pconj(const Packet4h2& a) {
+ return a;
+}
+
+template <>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 pmul<Packet4h2>(
+ const Packet4h2& a, const Packet4h2& b) {
+ Packet4h2 r;
+ half2* r_alias = reinterpret_cast<half2*>(&r);
+ const half2* a_alias = reinterpret_cast<const half2*>(&a);
+ const half2* b_alias = reinterpret_cast<const half2*>(&b);
+ r_alias[0] = pmul(a_alias[0], b_alias[0]);
+ r_alias[1] = pmul(a_alias[1], b_alias[1]);
+ r_alias[2] = pmul(a_alias[2], b_alias[2]);
+ r_alias[3] = pmul(a_alias[3], b_alias[3]);
+ return r;
+}
+
+template <>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 pmadd<Packet4h2>(
+ const Packet4h2& a, const Packet4h2& b, const Packet4h2& c) {
+ Packet4h2 r;
+ half2* r_alias = reinterpret_cast<half2*>(&r);
+ const half2* a_alias = reinterpret_cast<const half2*>(&a);
+ const half2* b_alias = reinterpret_cast<const half2*>(&b);
+ const half2* c_alias = reinterpret_cast<const half2*>(&c);
+ r_alias[0] = pmadd(a_alias[0], b_alias[0], c_alias[0]);
+ r_alias[1] = pmadd(a_alias[1], b_alias[1], c_alias[1]);
+ r_alias[2] = pmadd(a_alias[2], b_alias[2], c_alias[2]);
+ r_alias[3] = pmadd(a_alias[3], b_alias[3], c_alias[3]);
+ return r;
+}
+
+template <>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 pdiv<Packet4h2>(
+ const Packet4h2& a, const Packet4h2& b) {
+ Packet4h2 r;
+ half2* r_alias = reinterpret_cast<half2*>(&r);
+ const half2* a_alias = reinterpret_cast<const half2*>(&a);
+ const half2* b_alias = reinterpret_cast<const half2*>(&b);
+ r_alias[0] = pdiv(a_alias[0], b_alias[0]);
+ r_alias[1] = pdiv(a_alias[1], b_alias[1]);
+ r_alias[2] = pdiv(a_alias[2], b_alias[2]);
+ r_alias[3] = pdiv(a_alias[3], b_alias[3]);
+ return r;
+}
+
+template <>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 pmin<Packet4h2>(
+ const Packet4h2& a, const Packet4h2& b) {
+ Packet4h2 r;
+ half2* r_alias = reinterpret_cast<half2*>(&r);
+ const half2* a_alias = reinterpret_cast<const half2*>(&a);
+ const half2* b_alias = reinterpret_cast<const half2*>(&b);
+ r_alias[0] = pmin(a_alias[0], b_alias[0]);
+ r_alias[1] = pmin(a_alias[1], b_alias[1]);
+ r_alias[2] = pmin(a_alias[2], b_alias[2]);
+ r_alias[3] = pmin(a_alias[3], b_alias[3]);
+ return r;
+}
+
+template <>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 pmax<Packet4h2>(
+ const Packet4h2& a, const Packet4h2& b) {
+ Packet4h2 r;
+ half2* r_alias = reinterpret_cast<half2*>(&r);
+ const half2* a_alias = reinterpret_cast<const half2*>(&a);
+ const half2* b_alias = reinterpret_cast<const half2*>(&b);
+ r_alias[0] = pmax(a_alias[0], b_alias[0]);
+ r_alias[1] = pmax(a_alias[1], b_alias[1]);
+ r_alias[2] = pmax(a_alias[2], b_alias[2]);
+ r_alias[3] = pmax(a_alias[3], b_alias[3]);
+ return r;
+}
+
+template <>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::half predux<Packet4h2>(
+ const Packet4h2& a) {
+ const half2* a_alias = reinterpret_cast<const half2*>(&a);
+
+ return predux(a_alias[0]) + predux(a_alias[1]) +
+ predux(a_alias[2]) + predux(a_alias[3]);
+}
+
+template <>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::half predux_max<Packet4h2>(
+ const Packet4h2& a) {
+ const half2* a_alias = reinterpret_cast<const half2*>(&a);
+ half2 m0 = combine_half(predux_max(a_alias[0]),
+ predux_max(a_alias[1]));
+ half2 m1 = combine_half(predux_max(a_alias[2]),
+ predux_max(a_alias[3]));
+ __half first = predux_max(m0);
+ __half second = predux_max(m1);
+#if defined(EIGEN_CUDA_HAS_FP16_ARITHMETIC)
+ return (__hgt(first, second) ? first : second);
+#else
+ float ffirst = __half2float(first);
+ float fsecond = __half2float(second);
+ return (ffirst > fsecond)? first: second;
+#endif
+}
+
+template <>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::half predux_min<Packet4h2>(
+ const Packet4h2& a) {
+ const half2* a_alias = reinterpret_cast<const half2*>(&a);
+ half2 m0 = combine_half(predux_min(a_alias[0]),
+ predux_min(a_alias[1]));
+ half2 m1 = combine_half(predux_min(a_alias[2]),
+ predux_min(a_alias[3]));
+ __half first = predux_min(m0);
+ __half second = predux_min(m1);
+#if defined(EIGEN_CUDA_HAS_FP16_ARITHMETIC)
+ return (__hlt(first, second) ? first : second);
+#else
+ float ffirst = __half2float(first);
+ float fsecond = __half2float(second);
+ return (ffirst < fsecond)? first: second;
+#endif
+}
+
+// likely overflow/underflow
+template <>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::half predux_mul<Packet4h2>(
+ const Packet4h2& a) {
+ const half2* a_alias = reinterpret_cast<const half2*>(&a);
+ return predux_mul(pmul(pmul(a_alias[0], a_alias[1]),
+ pmul(a_alias[2], a_alias[3])));
+}
+
+template <>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2
+plog1p<Packet4h2>(const Packet4h2& a) {
+ Packet4h2 r;
+ half2* r_alias = reinterpret_cast<half2*>(&r);
+ const half2* a_alias = reinterpret_cast<const half2*>(&a);
+ r_alias[0] = plog1p(a_alias[0]);
+ r_alias[1] = plog1p(a_alias[1]);
+ r_alias[2] = plog1p(a_alias[2]);
+ r_alias[3] = plog1p(a_alias[3]);
+ return r;
+}
+
+template <>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2
+pexpm1<Packet4h2>(const Packet4h2& a) {
+ Packet4h2 r;
+ half2* r_alias = reinterpret_cast<half2*>(&r);
+ const half2* a_alias = reinterpret_cast<const half2*>(&a);
+ r_alias[0] = pexpm1(a_alias[0]);
+ r_alias[1] = pexpm1(a_alias[1]);
+ r_alias[2] = pexpm1(a_alias[2]);
+ r_alias[3] = pexpm1(a_alias[3]);
+ return r;
+}
+
+template <>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 plog<Packet4h2>(const Packet4h2& a) {
+ Packet4h2 r;
+ half2* r_alias = reinterpret_cast<half2*>(&r);
+ const half2* a_alias = reinterpret_cast<const half2*>(&a);
+ r_alias[0] = plog(a_alias[0]);
+ r_alias[1] = plog(a_alias[1]);
+ r_alias[2] = plog(a_alias[2]);
+ r_alias[3] = plog(a_alias[3]);
+ return r;
+}
+
+template <>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 pexp<Packet4h2>(const Packet4h2& a) {
+ Packet4h2 r;
+ half2* r_alias = reinterpret_cast<half2*>(&r);
+ const half2* a_alias = reinterpret_cast<const half2*>(&a);
+ r_alias[0] = pexp(a_alias[0]);
+ r_alias[1] = pexp(a_alias[1]);
+ r_alias[2] = pexp(a_alias[2]);
+ r_alias[3] = pexp(a_alias[3]);
+ return r;
+}
+
+template <>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 psqrt<Packet4h2>(const Packet4h2& a) {
+ Packet4h2 r;
+ half2* r_alias = reinterpret_cast<half2*>(&r);
+ const half2* a_alias = reinterpret_cast<const half2*>(&a);
+ r_alias[0] = psqrt(a_alias[0]);
+ r_alias[1] = psqrt(a_alias[1]);
+ r_alias[2] = psqrt(a_alias[2]);
+ r_alias[3] = psqrt(a_alias[3]);
+ return r;
+}
+
+template <>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2
+prsqrt<Packet4h2>(const Packet4h2& a) {
+ Packet4h2 r;
+ half2* r_alias = reinterpret_cast<half2*>(&r);
+ const half2* a_alias = reinterpret_cast<const half2*>(&a);
+ r_alias[0] = prsqrt(a_alias[0]);
+ r_alias[1] = prsqrt(a_alias[1]);
+ r_alias[2] = prsqrt(a_alias[2]);
+ r_alias[3] = prsqrt(a_alias[3]);
+ return r;
+}
+
+// The following specialized padd, pmul, pdiv, pmin, pmax, pset1 are needed for
+// the implementation of GPU half reduction.
+template<>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 padd<half2>(const half2& a,
+ const half2& b) {
+#if defined(EIGEN_GPU_HAS_FP16_ARITHMETIC)
+ return __hadd2(a, b);
+#else
+ float a1 = __low2float(a);
+ float a2 = __high2float(a);
+ float b1 = __low2float(b);
+ float b2 = __high2float(b);
+ float r1 = a1 + b1;
+ float r2 = a2 + b2;
+ return __floats2half2_rn(r1, r2);
+#endif
+}
+
+template<>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pmul<half2>(const half2& a,
+ const half2& b) {
+#if defined(EIGEN_GPU_HAS_FP16_ARITHMETIC)
+ return __hmul2(a, b);
+#else
+ float a1 = __low2float(a);
+ float a2 = __high2float(a);
+ float b1 = __low2float(b);
+ float b2 = __high2float(b);
+ float r1 = a1 * b1;
+ float r2 = a2 * b2;
+ return __floats2half2_rn(r1, r2);
+#endif
+}
+
+template<>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pdiv<half2>(const half2& a,
+ const half2& b) {
+#if defined(EIGEN_GPU_HAS_FP16_ARITHMETIC)
+ return __h2div(a, b);
+#else
+ float a1 = __low2float(a);
+ float a2 = __high2float(a);
+ float b1 = __low2float(b);
+ float b2 = __high2float(b);
+ float r1 = a1 / b1;
+ float r2 = a2 / b2;
+ return __floats2half2_rn(r1, r2);
+#endif
+}
+
+template<>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pmin<half2>(const half2& a,
+ const half2& b) {
+ float a1 = __low2float(a);
+ float a2 = __high2float(a);
+ float b1 = __low2float(b);
+ float b2 = __high2float(b);
+ __half r1 = a1 < b1 ? get_half2_low(a) : get_half2_low(b);
+ __half r2 = a2 < b2 ? get_half2_high(a) : get_half2_high(b);
+ return combine_half(r1, r2);
+}
+
+template<>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pmax<half2>(const half2& a,
+ const half2& b) {
+ float a1 = __low2float(a);
+ float a2 = __high2float(a);
+ float b1 = __low2float(b);
+ float b2 = __high2float(b);
+ __half r1 = a1 > b1 ? get_half2_low(a) : get_half2_low(b);
+ __half r2 = a2 > b2 ? get_half2_high(a) : get_half2_high(b);
+ return combine_half(r1, r2);
+}
+
+// #endif // defined(EIGEN_CUDA_ARCH) || defined(EIGEN_HIPCC) || (defined(EIGEN_CUDACC) && EIGEN_COMP_CLANG && !EIGEN_COMP_NVCC)
+
+#endif // defined(EIGEN_HAS_CUDA_FP16) || defined(EIGEN_HAS_HIP_FP16)
+
+#undef EIGEN_GPU_HAS_LDG
+#undef EIGEN_CUDA_HAS_FP16_ARITHMETIC
+#undef EIGEN_GPU_HAS_FP16_ARITHMETIC
+
+} // end namespace internal
+
+} // end namespace Eigen
+
+
+#endif // EIGEN_PACKET_MATH_GPU_H
diff --git a/Eigen/src/Core/arch/GPU/TypeCasting.h b/Eigen/src/Core/arch/GPU/TypeCasting.h
new file mode 100644
index 000000000..754546225
--- /dev/null
+++ b/Eigen/src/Core/arch/GPU/TypeCasting.h
@@ -0,0 +1,80 @@
+// This file is part of Eigen, a lightweight C++ template library
+// for linear algebra.
+//
+// Copyright (C) 2016 Benoit Steiner <benoit.steiner.goog@gmail.com>
+//
+// This Source Code Form is subject to the terms of the Mozilla
+// Public License v. 2.0. If a copy of the MPL was not distributed
+// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+#ifndef EIGEN_TYPE_CASTING_GPU_H
+#define EIGEN_TYPE_CASTING_GPU_H
+
+namespace Eigen {
+
+namespace internal {
+
+#if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 300) || \
+ (defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE))
+
+
+template <>
+struct type_casting_traits<Eigen::half, float> {
+ enum {
+ VectorizedCast = 1,
+ SrcCoeffRatio = 1,
+ TgtCoeffRatio = 2
+ };
+};
+
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 pcast<half2, float4>(const half2& a, const half2& b) {
+ float2 r1 = __half22float2(a);
+ float2 r2 = __half22float2(b);
+ return make_float4(r1.x, r1.y, r2.x, r2.y);
+}
+
+
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 pcast<float4, Packet4h2>(const float4& a, const float4& b) {
+ Packet4h2 r;
+ half2* r_alias=reinterpret_cast<half2*>(&r);
+ r_alias[0]=__floats2half2_rn(a.x,a.y);
+ r_alias[1]=__floats2half2_rn(a.z,a.w);
+ r_alias[2]=__floats2half2_rn(b.x,b.y);
+ r_alias[3]=__floats2half2_rn(b.z,b.w);
+ return r;
+}
+
+template <>
+struct type_casting_traits<float, Eigen::half> {
+ enum {
+ VectorizedCast = 1,
+ SrcCoeffRatio = 2,
+ TgtCoeffRatio = 1
+ };
+};
+
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 pcast<Packet4h2, float4>(const Packet4h2& a) {
+ // Simply discard the second half of the input
+ float4 r;
+ const half2* a_alias=reinterpret_cast<const half2*>(&a);
+ float2 r1 = __half22float2(a_alias[0]);
+ float2 r2 = __half22float2(a_alias[1]);
+ r.x=static_cast<float>(r1.x);
+ r.y=static_cast<float>(r1.y);
+ r.z=static_cast<float>(r2.x);
+ r.w=static_cast<float>(r2.y);
+ return r;
+}
+
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pcast<float4, half2>(const float4& a) {
+ // Simply discard the second half of the input
+ return __floats2half2_rn(a.x, a.y);
+}
+
+#endif
+
+} // end namespace internal
+
+} // end namespace Eigen
+
+#endif // EIGEN_TYPE_CASTING_GPU_H
diff --git a/Eigen/src/Core/arch/HIP/hcc/math_constants.h b/Eigen/src/Core/arch/HIP/hcc/math_constants.h
new file mode 100644
index 000000000..25375a0a4
--- /dev/null
+++ b/Eigen/src/Core/arch/HIP/hcc/math_constants.h
@@ -0,0 +1,23 @@
+/*
+ * math_constants.h -
+ * HIP equivalent of the CUDA header of the same name
+ */
+
+#ifndef __MATH_CONSTANTS_H__
+#define __MATH_CONSTANTS_H__
+
+/* single precision constants */
+
+#define HIPRT_INF_F __int_as_float(0x7f800000)
+#define HIPRT_NAN_F __int_as_float(0x7fffffff)
+#define HIPRT_MIN_DENORM_F __int_as_float(0x00000001)
+#define HIPRT_MAX_NORMAL_F __int_as_float(0x7f7fffff)
+#define HIPRT_NEG_ZERO_F __int_as_float(0x80000000)
+#define HIPRT_ZERO_F 0.0f
+#define HIPRT_ONE_F 1.0f
+
+/* double precision constants */
+#define HIPRT_INF __hiloint2double(0x7ff00000, 0x00000000)
+#define HIPRT_NAN __hiloint2double(0xfff80000, 0x00000000)
+
+#endif
diff --git a/Eigen/src/Core/arch/MSA/Complex.h b/Eigen/src/Core/arch/MSA/Complex.h
new file mode 100644
index 000000000..53dacfa43
--- /dev/null
+++ b/Eigen/src/Core/arch/MSA/Complex.h
@@ -0,0 +1,648 @@
+// This file is part of Eigen, a lightweight C++ template library
+// for linear algebra.
+//
+// Copyright (C) 2018 Wave Computing, Inc.
+// Written by:
+// Chris Larsen
+// Alexey Frunze (afrunze@wavecomp.com)
+//
+// This Source Code Form is subject to the terms of the Mozilla
+// Public License v. 2.0. If a copy of the MPL was not distributed
+// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+#ifndef EIGEN_COMPLEX_MSA_H
+#define EIGEN_COMPLEX_MSA_H
+
+#include <iostream>
+
+namespace Eigen {
+
+namespace internal {
+
+//---------- float ----------
+struct Packet2cf {
+ EIGEN_STRONG_INLINE Packet2cf() {
+ }
+ EIGEN_STRONG_INLINE explicit Packet2cf(const std::complex<float>& a,
+ const std::complex<float>& b) {
+ Packet4f t = { std::real(a), std::imag(a), std::real(b), std::imag(b) };
+ v = t;
+ }
+ EIGEN_STRONG_INLINE explicit Packet2cf(const Packet4f& a) : v(a) {
+ }
+ EIGEN_STRONG_INLINE Packet2cf(const Packet2cf& a) : v(a.v) {
+ }
+ EIGEN_STRONG_INLINE Packet2cf& operator=(const Packet2cf& b) {
+ v = b.v;
+ return *this;
+ }
+ EIGEN_STRONG_INLINE Packet2cf conjugate(void) const {
+ return Packet2cf((Packet4f)__builtin_msa_bnegi_d((v2u64)v, 63));
+ }
+ EIGEN_STRONG_INLINE Packet2cf& operator*=(const Packet2cf& b) {
+ Packet4f v1, v2;
+
+ // Get the real values of a | a1_re | a1_re | a2_re | a2_re |
+ v1 = (Packet4f)__builtin_msa_ilvev_w((v4i32)v, (v4i32)v);
+ // Get the imag values of a | a1_im | a1_im | a2_im | a2_im |
+ v2 = (Packet4f)__builtin_msa_ilvod_w((v4i32)v, (v4i32)v);
+ // Multiply the real a with b
+ v1 = pmul(v1, b.v);
+ // Multiply the imag a with b
+ v2 = pmul(v2, b.v);
+ // Conjugate v2
+ v2 = Packet2cf(v2).conjugate().v;
+ // Swap real/imag elements in v2.
+ v2 = (Packet4f)__builtin_msa_shf_w((v4i32)v2, EIGEN_MSA_SHF_I8(1, 0, 3, 2));
+ // Add and return the result
+ v = padd(v1, v2);
+ return *this;
+ }
+ EIGEN_STRONG_INLINE Packet2cf operator*(const Packet2cf& b) const {
+ return Packet2cf(*this) *= b;
+ }
+ EIGEN_STRONG_INLINE Packet2cf& operator+=(const Packet2cf& b) {
+ v = padd(v, b.v);
+ return *this;
+ }
+ EIGEN_STRONG_INLINE Packet2cf operator+(const Packet2cf& b) const {
+ return Packet2cf(*this) += b;
+ }
+ EIGEN_STRONG_INLINE Packet2cf& operator-=(const Packet2cf& b) {
+ v = psub(v, b.v);
+ return *this;
+ }
+ EIGEN_STRONG_INLINE Packet2cf operator-(const Packet2cf& b) const {
+ return Packet2cf(*this) -= b;
+ }
+ EIGEN_STRONG_INLINE Packet2cf& operator/=(const Packet2cf& b) {
+ *this *= b.conjugate();
+ Packet4f s = pmul<Packet4f>(b.v, b.v);
+ s = padd(s, (Packet4f)__builtin_msa_shf_w((v4i32)s, EIGEN_MSA_SHF_I8(1, 0, 3, 2)));
+ v = pdiv(v, s);
+ return *this;
+ }
+ EIGEN_STRONG_INLINE Packet2cf operator/(const Packet2cf& b) const {
+ return Packet2cf(*this) /= b;
+ }
+ EIGEN_STRONG_INLINE Packet2cf operator-(void) const {
+ return Packet2cf(pnegate(v));
+ }
+
+ Packet4f v;
+};
+
+inline std::ostream& operator<<(std::ostream& os, const Packet2cf& value) {
+ os << "[ (" << value.v[0] << ", " << value.v[1]
+ << "i),"
+ " ("
+ << value.v[2] << ", " << value.v[3] << "i) ]";
+ return os;
+}
+
+template <>
+struct packet_traits<std::complex<float> > : default_packet_traits {
+ typedef Packet2cf type;
+ typedef Packet2cf half;
+ enum {
+ Vectorizable = 1,
+ AlignedOnScalar = 1,
+ size = 2,
+ HasHalfPacket = 0,
+
+ HasAdd = 1,
+ HasSub = 1,
+ HasMul = 1,
+ HasDiv = 1,
+ HasNegate = 1,
+ HasAbs = 0,
+ HasAbs2 = 0,
+ HasMin = 0,
+ HasMax = 0,
+ HasSetLinear = 0,
+ HasBlend = 1
+ };
+};
+
+template <>
+struct unpacket_traits<Packet2cf> {
+ typedef std::complex<float> type;
+ enum { size = 2, alignment = Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false };
+ typedef Packet2cf half;
+};
+
+template <>
+EIGEN_STRONG_INLINE Packet2cf pset1<Packet2cf>(const std::complex<float>& from) {
+ EIGEN_MSA_DEBUG;
+
+ float f0 = from.real(), f1 = from.imag();
+ Packet4f v0 = { f0, f0, f0, f0 };
+ Packet4f v1 = { f1, f1, f1, f1 };
+ return Packet2cf((Packet4f)__builtin_msa_ilvr_w((Packet4i)v1, (Packet4i)v0));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet2cf padd<Packet2cf>(const Packet2cf& a, const Packet2cf& b) {
+ EIGEN_MSA_DEBUG;
+
+ return a + b;
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet2cf psub<Packet2cf>(const Packet2cf& a, const Packet2cf& b) {
+ EIGEN_MSA_DEBUG;
+
+ return a - b;
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet2cf pnegate(const Packet2cf& a) {
+ EIGEN_MSA_DEBUG;
+
+ return -a;
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet2cf pconj(const Packet2cf& a) {
+ EIGEN_MSA_DEBUG;
+
+ return a.conjugate();
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet2cf pmul<Packet2cf>(const Packet2cf& a, const Packet2cf& b) {
+ EIGEN_MSA_DEBUG;
+
+ return a * b;
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet2cf pand<Packet2cf>(const Packet2cf& a, const Packet2cf& b) {
+ EIGEN_MSA_DEBUG;
+
+ return Packet2cf(pand(a.v, b.v));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet2cf por<Packet2cf>(const Packet2cf& a, const Packet2cf& b) {
+ EIGEN_MSA_DEBUG;
+
+ return Packet2cf(por(a.v, b.v));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet2cf pxor<Packet2cf>(const Packet2cf& a, const Packet2cf& b) {
+ EIGEN_MSA_DEBUG;
+
+ return Packet2cf(pxor(a.v, b.v));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet2cf pandnot<Packet2cf>(const Packet2cf& a, const Packet2cf& b) {
+ EIGEN_MSA_DEBUG;
+
+ return Packet2cf(pandnot(a.v, b.v));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet2cf pload<Packet2cf>(const std::complex<float>* from) {
+ EIGEN_MSA_DEBUG;
+
+ EIGEN_DEBUG_ALIGNED_LOAD return Packet2cf(pload<Packet4f>((const float*)from));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet2cf ploadu<Packet2cf>(const std::complex<float>* from) {
+ EIGEN_MSA_DEBUG;
+
+ EIGEN_DEBUG_UNALIGNED_LOAD return Packet2cf(ploadu<Packet4f>((const float*)from));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet2cf ploaddup<Packet2cf>(const std::complex<float>* from) {
+ EIGEN_MSA_DEBUG;
+
+ return pset1<Packet2cf>(*from);
+}
+
+template <>
+EIGEN_STRONG_INLINE void pstore<std::complex<float> >(std::complex<float>* to,
+ const Packet2cf& from) {
+ EIGEN_MSA_DEBUG;
+
+ EIGEN_DEBUG_ALIGNED_STORE pstore<float>((float*)to, from.v);
+}
+
+template <>
+EIGEN_STRONG_INLINE void pstoreu<std::complex<float> >(std::complex<float>* to,
+ const Packet2cf& from) {
+ EIGEN_MSA_DEBUG;
+
+ EIGEN_DEBUG_UNALIGNED_STORE pstoreu<float>((float*)to, from.v);
+}
+
+template <>
+EIGEN_DEVICE_FUNC inline Packet2cf pgather<std::complex<float>, Packet2cf>(
+ const std::complex<float>* from, Index stride) {
+ EIGEN_MSA_DEBUG;
+
+ return Packet2cf(from[0 * stride], from[1 * stride]);
+}
+
+template <>
+EIGEN_DEVICE_FUNC inline void pscatter<std::complex<float>, Packet2cf>(std::complex<float>* to,
+ const Packet2cf& from,
+ Index stride) {
+ EIGEN_MSA_DEBUG;
+
+ *to = std::complex<float>(from.v[0], from.v[1]);
+ to += stride;
+ *to = std::complex<float>(from.v[2], from.v[3]);
+}
+
+template <>
+EIGEN_STRONG_INLINE void prefetch<std::complex<float> >(const std::complex<float>* addr) {
+ EIGEN_MSA_DEBUG;
+
+ prefetch(reinterpret_cast<const float*>(addr));
+}
+
+template <>
+EIGEN_STRONG_INLINE std::complex<float> pfirst<Packet2cf>(const Packet2cf& a) {
+ EIGEN_MSA_DEBUG;
+
+ return std::complex<float>(a.v[0], a.v[1]);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet2cf preverse(const Packet2cf& a) {
+ EIGEN_MSA_DEBUG;
+
+ return Packet2cf((Packet4f)__builtin_msa_shf_w((v4i32)a.v, EIGEN_MSA_SHF_I8(2, 3, 0, 1)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet2cf pcplxflip<Packet2cf>(const Packet2cf& a) {
+ EIGEN_MSA_DEBUG;
+
+ return Packet2cf((Packet4f)__builtin_msa_shf_w((v4i32)a.v, EIGEN_MSA_SHF_I8(1, 0, 3, 2)));
+}
+
+template <>
+EIGEN_STRONG_INLINE std::complex<float> predux<Packet2cf>(const Packet2cf& a) {
+ EIGEN_MSA_DEBUG;
+
+ Packet4f value = (Packet4f)preverse((Packet2d)a.v);
+ value += a.v;
+ return std::complex<float>(value[0], value[1]);
+}
+
+template <>
+EIGEN_STRONG_INLINE std::complex<float> predux_mul<Packet2cf>(const Packet2cf& a) {
+ EIGEN_MSA_DEBUG;
+
+ return std::complex<float>((a.v[0] * a.v[2]) - (a.v[1] * a.v[3]),
+ (a.v[0] * a.v[3]) + (a.v[1] * a.v[2]));
+}
+
+EIGEN_MAKE_CONJ_HELPER_CPLX_REAL(Packet2cf, Packet4f)
+
+template <>
+EIGEN_STRONG_INLINE Packet2cf pdiv<Packet2cf>(const Packet2cf& a, const Packet2cf& b) {
+ EIGEN_MSA_DEBUG;
+
+ return a / b;
+}
+
+inline std::ostream& operator<<(std::ostream& os, const PacketBlock<Packet2cf, 2>& value) {
+ os << "[ " << value.packet[0] << ", " << std::endl << " " << value.packet[1] << " ]";
+ return os;
+}
+
+EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet2cf, 2>& kernel) {
+ EIGEN_MSA_DEBUG;
+
+ Packet4f tmp =
+ (Packet4f)__builtin_msa_ilvl_d((v2i64)kernel.packet[1].v, (v2i64)kernel.packet[0].v);
+ kernel.packet[0].v =
+ (Packet4f)__builtin_msa_ilvr_d((v2i64)kernel.packet[1].v, (v2i64)kernel.packet[0].v);
+ kernel.packet[1].v = tmp;
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet2cf pblend(const Selector<2>& ifPacket, const Packet2cf& thenPacket,
+ const Packet2cf& elsePacket) {
+ return (Packet2cf)(Packet4f)pblend<Packet2d>(ifPacket, (Packet2d)thenPacket.v,
+ (Packet2d)elsePacket.v);
+}
+
+//---------- double ----------
+
+struct Packet1cd {
+ EIGEN_STRONG_INLINE Packet1cd() {
+ }
+ EIGEN_STRONG_INLINE explicit Packet1cd(const std::complex<double>& a) {
+ v[0] = std::real(a);
+ v[1] = std::imag(a);
+ }
+ EIGEN_STRONG_INLINE explicit Packet1cd(const Packet2d& a) : v(a) {
+ }
+ EIGEN_STRONG_INLINE Packet1cd(const Packet1cd& a) : v(a.v) {
+ }
+ EIGEN_STRONG_INLINE Packet1cd& operator=(const Packet1cd& b) {
+ v = b.v;
+ return *this;
+ }
+ EIGEN_STRONG_INLINE Packet1cd conjugate(void) const {
+ static const v2u64 p2ul_CONJ_XOR = { 0x0, 0x8000000000000000 };
+ return (Packet1cd)pxor(v, (Packet2d)p2ul_CONJ_XOR);
+ }
+ EIGEN_STRONG_INLINE Packet1cd& operator*=(const Packet1cd& b) {
+ Packet2d v1, v2;
+
+ // Get the real values of a | a1_re | a1_re
+ v1 = (Packet2d)__builtin_msa_ilvev_d((v2i64)v, (v2i64)v);
+ // Get the imag values of a | a1_im | a1_im
+ v2 = (Packet2d)__builtin_msa_ilvod_d((v2i64)v, (v2i64)v);
+ // Multiply the real a with b
+ v1 = pmul(v1, b.v);
+ // Multiply the imag a with b
+ v2 = pmul(v2, b.v);
+ // Conjugate v2
+ v2 = Packet1cd(v2).conjugate().v;
+ // Swap real/imag elements in v2.
+ v2 = (Packet2d)__builtin_msa_shf_w((v4i32)v2, EIGEN_MSA_SHF_I8(2, 3, 0, 1));
+ // Add and return the result
+ v = padd(v1, v2);
+ return *this;
+ }
+ EIGEN_STRONG_INLINE Packet1cd operator*(const Packet1cd& b) const {
+ return Packet1cd(*this) *= b;
+ }
+ EIGEN_STRONG_INLINE Packet1cd& operator+=(const Packet1cd& b) {
+ v = padd(v, b.v);
+ return *this;
+ }
+ EIGEN_STRONG_INLINE Packet1cd operator+(const Packet1cd& b) const {
+ return Packet1cd(*this) += b;
+ }
+ EIGEN_STRONG_INLINE Packet1cd& operator-=(const Packet1cd& b) {
+ v = psub(v, b.v);
+ return *this;
+ }
+ EIGEN_STRONG_INLINE Packet1cd operator-(const Packet1cd& b) const {
+ return Packet1cd(*this) -= b;
+ }
+ EIGEN_STRONG_INLINE Packet1cd& operator/=(const Packet1cd& b) {
+ *this *= b.conjugate();
+ Packet2d s = pmul<Packet2d>(b.v, b.v);
+ s = padd(s, preverse<Packet2d>(s));
+ v = pdiv(v, s);
+ return *this;
+ }
+ EIGEN_STRONG_INLINE Packet1cd operator/(const Packet1cd& b) const {
+ return Packet1cd(*this) /= b;
+ }
+ EIGEN_STRONG_INLINE Packet1cd operator-(void) const {
+ return Packet1cd(pnegate(v));
+ }
+
+ Packet2d v;
+};
+
+inline std::ostream& operator<<(std::ostream& os, const Packet1cd& value) {
+ os << "[ (" << value.v[0] << ", " << value.v[1] << "i) ]";
+ return os;
+}
+
+template <>
+struct packet_traits<std::complex<double> > : default_packet_traits {
+ typedef Packet1cd type;
+ typedef Packet1cd half;
+ enum {
+ Vectorizable = 1,
+ AlignedOnScalar = 0,
+ size = 1,
+ HasHalfPacket = 0,
+
+ HasAdd = 1,
+ HasSub = 1,
+ HasMul = 1,
+ HasDiv = 1,
+ HasNegate = 1,
+ HasAbs = 0,
+ HasAbs2 = 0,
+ HasMin = 0,
+ HasMax = 0,
+ HasSetLinear = 0
+ };
+};
+
+template <>
+struct unpacket_traits<Packet1cd> {
+ typedef std::complex<double> type;
+ enum { size = 1, alignment = Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false };
+ typedef Packet1cd half;
+};
+
+template <>
+EIGEN_STRONG_INLINE Packet1cd pload<Packet1cd>(const std::complex<double>* from) {
+ EIGEN_MSA_DEBUG;
+
+ EIGEN_DEBUG_ALIGNED_LOAD return Packet1cd(pload<Packet2d>((const double*)from));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet1cd ploadu<Packet1cd>(const std::complex<double>* from) {
+ EIGEN_MSA_DEBUG;
+
+ EIGEN_DEBUG_UNALIGNED_LOAD return Packet1cd(ploadu<Packet2d>((const double*)from));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet1cd pset1<Packet1cd>(const std::complex<double>& from) {
+ EIGEN_MSA_DEBUG;
+
+ return Packet1cd(from);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet1cd padd<Packet1cd>(const Packet1cd& a, const Packet1cd& b) {
+ EIGEN_MSA_DEBUG;
+
+ return a + b;
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet1cd psub<Packet1cd>(const Packet1cd& a, const Packet1cd& b) {
+ EIGEN_MSA_DEBUG;
+
+ return a - b;
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet1cd pnegate(const Packet1cd& a) {
+ EIGEN_MSA_DEBUG;
+
+ return -a;
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet1cd pconj(const Packet1cd& a) {
+ EIGEN_MSA_DEBUG;
+
+ return a.conjugate();
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet1cd pmul<Packet1cd>(const Packet1cd& a, const Packet1cd& b) {
+ EIGEN_MSA_DEBUG;
+
+ return a * b;
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet1cd pand<Packet1cd>(const Packet1cd& a, const Packet1cd& b) {
+ EIGEN_MSA_DEBUG;
+
+ return Packet1cd(pand(a.v, b.v));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet1cd por<Packet1cd>(const Packet1cd& a, const Packet1cd& b) {
+ EIGEN_MSA_DEBUG;
+
+ return Packet1cd(por(a.v, b.v));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet1cd pxor<Packet1cd>(const Packet1cd& a, const Packet1cd& b) {
+ EIGEN_MSA_DEBUG;
+
+ return Packet1cd(pxor(a.v, b.v));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet1cd pandnot<Packet1cd>(const Packet1cd& a, const Packet1cd& b) {
+ EIGEN_MSA_DEBUG;
+
+ return Packet1cd(pandnot(a.v, b.v));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet1cd ploaddup<Packet1cd>(const std::complex<double>* from) {
+ EIGEN_MSA_DEBUG;
+
+ return pset1<Packet1cd>(*from);
+}
+
+template <>
+EIGEN_STRONG_INLINE void pstore<std::complex<double> >(std::complex<double>* to,
+ const Packet1cd& from) {
+ EIGEN_MSA_DEBUG;
+
+ EIGEN_DEBUG_ALIGNED_STORE pstore<double>((double*)to, from.v);
+}
+
+template <>
+EIGEN_STRONG_INLINE void pstoreu<std::complex<double> >(std::complex<double>* to,
+ const Packet1cd& from) {
+ EIGEN_MSA_DEBUG;
+
+ EIGEN_DEBUG_UNALIGNED_STORE pstoreu<double>((double*)to, from.v);
+}
+
+template <>
+EIGEN_STRONG_INLINE void prefetch<std::complex<double> >(const std::complex<double>* addr) {
+ EIGEN_MSA_DEBUG;
+
+ prefetch(reinterpret_cast<const double*>(addr));
+}
+
+template <>
+EIGEN_DEVICE_FUNC inline Packet1cd pgather<std::complex<double>, Packet1cd>(
+ const std::complex<double>* from, Index stride __attribute__((unused))) {
+ EIGEN_MSA_DEBUG;
+
+ Packet1cd res;
+ res.v[0] = std::real(from[0]);
+ res.v[1] = std::imag(from[0]);
+ return res;
+}
+
+template <>
+EIGEN_DEVICE_FUNC inline void pscatter<std::complex<double>, Packet1cd>(std::complex<double>* to,
+ const Packet1cd& from,
+ Index stride
+ __attribute__((unused))) {
+ EIGEN_MSA_DEBUG;
+
+ pstore(to, from);
+}
+
+template <>
+EIGEN_STRONG_INLINE std::complex<double> pfirst<Packet1cd>(const Packet1cd& a) {
+ EIGEN_MSA_DEBUG;
+
+ return std::complex<double>(a.v[0], a.v[1]);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet1cd preverse(const Packet1cd& a) {
+ EIGEN_MSA_DEBUG;
+
+ return a;
+}
+
+template <>
+EIGEN_STRONG_INLINE std::complex<double> predux<Packet1cd>(const Packet1cd& a) {
+ EIGEN_MSA_DEBUG;
+
+ return pfirst(a);
+}
+
+template <>
+EIGEN_STRONG_INLINE std::complex<double> predux_mul<Packet1cd>(const Packet1cd& a) {
+ EIGEN_MSA_DEBUG;
+
+ return pfirst(a);
+}
+
+EIGEN_MAKE_CONJ_HELPER_CPLX_REAL(Packet1cd, Packet2d)
+
+template <>
+EIGEN_STRONG_INLINE Packet1cd pdiv<Packet1cd>(const Packet1cd& a, const Packet1cd& b) {
+ EIGEN_MSA_DEBUG;
+
+ return a / b;
+}
+
+EIGEN_STRONG_INLINE Packet1cd pcplxflip /*<Packet1cd>*/ (const Packet1cd& x) {
+ EIGEN_MSA_DEBUG;
+
+ return Packet1cd(preverse(Packet2d(x.v)));
+}
+
+inline std::ostream& operator<<(std::ostream& os, const PacketBlock<Packet1cd, 2>& value) {
+ os << "[ " << value.packet[0] << ", " << std::endl << " " << value.packet[1] << " ]";
+ return os;
+}
+
+EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet1cd, 2>& kernel) {
+ EIGEN_MSA_DEBUG;
+
+ Packet2d v1, v2;
+
+ v1 = (Packet2d)__builtin_msa_ilvev_d((v2i64)kernel.packet[0].v, (v2i64)kernel.packet[1].v);
+ // Get the imag values of a
+ v2 = (Packet2d)__builtin_msa_ilvod_d((v2i64)kernel.packet[0].v, (v2i64)kernel.packet[1].v);
+
+ kernel.packet[0].v = v1;
+ kernel.packet[1].v = v2;
+}
+
+} // end namespace internal
+
+} // end namespace Eigen
+
+#endif // EIGEN_COMPLEX_MSA_H
diff --git a/Eigen/src/Core/arch/MSA/MathFunctions.h b/Eigen/src/Core/arch/MSA/MathFunctions.h
new file mode 100644
index 000000000..f5181b90e
--- /dev/null
+++ b/Eigen/src/Core/arch/MSA/MathFunctions.h
@@ -0,0 +1,387 @@
+// This file is part of Eigen, a lightweight C++ template library
+// for linear algebra.
+//
+// Copyright (C) 2007 Julien Pommier
+// Copyright (C) 2014 Pedro Gonnet (pedro.gonnet@gmail.com)
+// Copyright (C) 2016 Gael Guennebaud <gael.guennebaud@inria.fr>
+//
+// Copyright (C) 2018 Wave Computing, Inc.
+// Written by:
+// Chris Larsen
+// Alexey Frunze (afrunze@wavecomp.com)
+//
+// This Source Code Form is subject to the terms of the Mozilla
+// Public License v. 2.0. If a copy of the MPL was not distributed
+// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+/* The sin, cos, exp, and log functions of this file come from
+ * Julien Pommier's sse math library: http://gruntthepeon.free.fr/ssemath/
+ */
+
+/* The tanh function of this file is an adaptation of
+ * template<typename T> T generic_fast_tanh_float(const T&)
+ * from MathFunctionsImpl.h.
+ */
+
+#ifndef EIGEN_MATH_FUNCTIONS_MSA_H
+#define EIGEN_MATH_FUNCTIONS_MSA_H
+
+namespace Eigen {
+
+namespace internal {
+
+template <>
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet4f
+plog<Packet4f>(const Packet4f& _x) {
+ static _EIGEN_DECLARE_CONST_Packet4f(cephes_SQRTHF, 0.707106781186547524f);
+ static _EIGEN_DECLARE_CONST_Packet4f(cephes_log_p0, 7.0376836292e-2f);
+ static _EIGEN_DECLARE_CONST_Packet4f(cephes_log_p1, -1.1514610310e-1f);
+ static _EIGEN_DECLARE_CONST_Packet4f(cephes_log_p2, 1.1676998740e-1f);
+ static _EIGEN_DECLARE_CONST_Packet4f(cephes_log_p3, -1.2420140846e-1f);
+ static _EIGEN_DECLARE_CONST_Packet4f(cephes_log_p4, +1.4249322787e-1f);
+ static _EIGEN_DECLARE_CONST_Packet4f(cephes_log_p5, -1.6668057665e-1f);
+ static _EIGEN_DECLARE_CONST_Packet4f(cephes_log_p6, +2.0000714765e-1f);
+ static _EIGEN_DECLARE_CONST_Packet4f(cephes_log_p7, -2.4999993993e-1f);
+ static _EIGEN_DECLARE_CONST_Packet4f(cephes_log_p8, +3.3333331174e-1f);
+ static _EIGEN_DECLARE_CONST_Packet4f(cephes_log_q1, -2.12194440e-4f);
+ static _EIGEN_DECLARE_CONST_Packet4f(cephes_log_q2, 0.693359375f);
+ static _EIGEN_DECLARE_CONST_Packet4f(half, 0.5f);
+ static _EIGEN_DECLARE_CONST_Packet4f(1, 1.0f);
+
+ // Convert negative argument into NAN (quiet negative, to be specific).
+ Packet4f zero = (Packet4f)__builtin_msa_ldi_w(0);
+ Packet4i neg_mask = __builtin_msa_fclt_w(_x, zero);
+ Packet4i zero_mask = __builtin_msa_fceq_w(_x, zero);
+ Packet4f non_neg_x_or_nan = padd(_x, (Packet4f)neg_mask); // Add 0.0 or NAN.
+ Packet4f x = non_neg_x_or_nan;
+
+ // Extract exponent from x = mantissa * 2**exponent, where 1.0 <= mantissa < 2.0.
+ // N.B. the exponent is one less of what frexpf() would return.
+ Packet4i e_int = __builtin_msa_ftint_s_w(__builtin_msa_flog2_w(x));
+ // Multiply x by 2**(-exponent-1) to get 0.5 <= x < 1.0 as from frexpf().
+ x = __builtin_msa_fexp2_w(x, (Packet4i)__builtin_msa_nori_b((v16u8)e_int, 0));
+
+ /*
+ if (x < SQRTHF) {
+ x = x + x - 1.0;
+ } else {
+ e += 1;
+ x = x - 1.0;
+ }
+ */
+ Packet4f xx = padd(x, x);
+ Packet4i ge_mask = __builtin_msa_fcle_w(p4f_cephes_SQRTHF, x);
+ e_int = psub(e_int, ge_mask);
+ x = (Packet4f)__builtin_msa_bsel_v((v16u8)ge_mask, (v16u8)xx, (v16u8)x);
+ x = psub(x, p4f_1);
+ Packet4f e = __builtin_msa_ffint_s_w(e_int);
+
+ Packet4f x2 = pmul(x, x);
+ Packet4f x3 = pmul(x2, x);
+
+ Packet4f y, y1, y2;
+ y = pmadd(p4f_cephes_log_p0, x, p4f_cephes_log_p1);
+ y1 = pmadd(p4f_cephes_log_p3, x, p4f_cephes_log_p4);
+ y2 = pmadd(p4f_cephes_log_p6, x, p4f_cephes_log_p7);
+ y = pmadd(y, x, p4f_cephes_log_p2);
+ y1 = pmadd(y1, x, p4f_cephes_log_p5);
+ y2 = pmadd(y2, x, p4f_cephes_log_p8);
+ y = pmadd(y, x3, y1);
+ y = pmadd(y, x3, y2);
+ y = pmul(y, x3);
+
+ y = pmadd(e, p4f_cephes_log_q1, y);
+ x = __builtin_msa_fmsub_w(x, x2, p4f_half);
+ x = padd(x, y);
+ x = pmadd(e, p4f_cephes_log_q2, x);
+
+ // x is now the logarithm result candidate. We still need to handle the
+ // extreme arguments of zero and positive infinity, though.
+ // N.B. if the argument is +INFINITY, x is NAN because the polynomial terms
+ // contain infinities of both signs (see the coefficients and code above).
+ // INFINITY - INFINITY is NAN.
+
+ // If the argument is +INFINITY, make it the new result candidate.
+ // To achieve that we choose the smaller of the result candidate and the
+ // argument.
+ // This is correct for all finite pairs of values (the logarithm is smaller
+ // than the argument).
+ // This is also correct in the special case when the argument is +INFINITY
+ // and the result candidate is NAN. This is because the fmin.df instruction
+ // prefers non-NANs to NANs.
+ x = __builtin_msa_fmin_w(x, non_neg_x_or_nan);
+
+ // If the argument is zero (including -0.0), the result becomes -INFINITY.
+ Packet4i neg_infs = __builtin_msa_slli_w(zero_mask, 23);
+ x = (Packet4f)__builtin_msa_bsel_v((v16u8)zero_mask, (v16u8)x, (v16u8)neg_infs);
+
+ return x;
+}
+
+template <>
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet4f
+pexp<Packet4f>(const Packet4f& _x) {
+ // Limiting single-precision pexp's argument to [-128, +128] lets pexp
+ // reach 0 and INFINITY naturally.
+ static _EIGEN_DECLARE_CONST_Packet4f(exp_lo, -128.0f);
+ static _EIGEN_DECLARE_CONST_Packet4f(exp_hi, +128.0f);
+ static _EIGEN_DECLARE_CONST_Packet4f(cephes_LOG2EF, 1.44269504088896341f);
+ static _EIGEN_DECLARE_CONST_Packet4f(cephes_exp_C1, 0.693359375f);
+ static _EIGEN_DECLARE_CONST_Packet4f(cephes_exp_C2, -2.12194440e-4f);
+ static _EIGEN_DECLARE_CONST_Packet4f(cephes_exp_p0, 1.9875691500e-4f);
+ static _EIGEN_DECLARE_CONST_Packet4f(cephes_exp_p1, 1.3981999507e-3f);
+ static _EIGEN_DECLARE_CONST_Packet4f(cephes_exp_p2, 8.3334519073e-3f);
+ static _EIGEN_DECLARE_CONST_Packet4f(cephes_exp_p3, 4.1665795894e-2f);
+ static _EIGEN_DECLARE_CONST_Packet4f(cephes_exp_p4, 1.6666665459e-1f);
+ static _EIGEN_DECLARE_CONST_Packet4f(cephes_exp_p5, 5.0000001201e-1f);
+ static _EIGEN_DECLARE_CONST_Packet4f(half, 0.5f);
+ static _EIGEN_DECLARE_CONST_Packet4f(1, 1.0f);
+
+ Packet4f x = _x;
+
+ // Clamp x.
+ x = (Packet4f)__builtin_msa_bsel_v((v16u8)__builtin_msa_fclt_w(x, p4f_exp_lo), (v16u8)x,
+ (v16u8)p4f_exp_lo);
+ x = (Packet4f)__builtin_msa_bsel_v((v16u8)__builtin_msa_fclt_w(p4f_exp_hi, x), (v16u8)x,
+ (v16u8)p4f_exp_hi);
+
+ // Round to nearest integer by adding 0.5 (with x's sign) and truncating.
+ Packet4f x2_add = (Packet4f)__builtin_msa_binsli_w((v4u32)p4f_half, (v4u32)x, 0);
+ Packet4f x2 = pmadd(x, p4f_cephes_LOG2EF, x2_add);
+ Packet4i x2_int = __builtin_msa_ftrunc_s_w(x2);
+ Packet4f x2_int_f = __builtin_msa_ffint_s_w(x2_int);
+
+ x = __builtin_msa_fmsub_w(x, x2_int_f, p4f_cephes_exp_C1);
+ x = __builtin_msa_fmsub_w(x, x2_int_f, p4f_cephes_exp_C2);
+
+ Packet4f z = pmul(x, x);
+
+ Packet4f y = p4f_cephes_exp_p0;
+ y = pmadd(y, x, p4f_cephes_exp_p1);
+ y = pmadd(y, x, p4f_cephes_exp_p2);
+ y = pmadd(y, x, p4f_cephes_exp_p3);
+ y = pmadd(y, x, p4f_cephes_exp_p4);
+ y = pmadd(y, x, p4f_cephes_exp_p5);
+ y = pmadd(y, z, x);
+ y = padd(y, p4f_1);
+
+ // y *= 2**exponent.
+ y = __builtin_msa_fexp2_w(y, x2_int);
+
+ return y;
+}
+
+template <>
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet4f
+ptanh<Packet4f>(const Packet4f& _x) {
+ static _EIGEN_DECLARE_CONST_Packet4f(tanh_tiny, 1e-4f);
+ static _EIGEN_DECLARE_CONST_Packet4f(tanh_hi, 9.0f);
+ // The monomial coefficients of the numerator polynomial (odd).
+ static _EIGEN_DECLARE_CONST_Packet4f(alpha_1, 4.89352455891786e-3f);
+ static _EIGEN_DECLARE_CONST_Packet4f(alpha_3, 6.37261928875436e-4f);
+ static _EIGEN_DECLARE_CONST_Packet4f(alpha_5, 1.48572235717979e-5f);
+ static _EIGEN_DECLARE_CONST_Packet4f(alpha_7, 5.12229709037114e-8f);
+ static _EIGEN_DECLARE_CONST_Packet4f(alpha_9, -8.60467152213735e-11f);
+ static _EIGEN_DECLARE_CONST_Packet4f(alpha_11, 2.00018790482477e-13f);
+ static _EIGEN_DECLARE_CONST_Packet4f(alpha_13, -2.76076847742355e-16f);
+ // The monomial coefficients of the denominator polynomial (even).
+ static _EIGEN_DECLARE_CONST_Packet4f(beta_0, 4.89352518554385e-3f);
+ static _EIGEN_DECLARE_CONST_Packet4f(beta_2, 2.26843463243900e-3f);
+ static _EIGEN_DECLARE_CONST_Packet4f(beta_4, 1.18534705686654e-4f);
+ static _EIGEN_DECLARE_CONST_Packet4f(beta_6, 1.19825839466702e-6f);
+
+ Packet4f x = pabs(_x);
+ Packet4i tiny_mask = __builtin_msa_fclt_w(x, p4f_tanh_tiny);
+
+ // Clamp the inputs to the range [-9, 9] since anything outside
+ // this range is -/+1.0f in single-precision.
+ x = (Packet4f)__builtin_msa_bsel_v((v16u8)__builtin_msa_fclt_w(p4f_tanh_hi, x), (v16u8)x,
+ (v16u8)p4f_tanh_hi);
+
+ // Since the polynomials are odd/even, we need x**2.
+ Packet4f x2 = pmul(x, x);
+
+ // Evaluate the numerator polynomial p.
+ Packet4f p = pmadd(x2, p4f_alpha_13, p4f_alpha_11);
+ p = pmadd(x2, p, p4f_alpha_9);
+ p = pmadd(x2, p, p4f_alpha_7);
+ p = pmadd(x2, p, p4f_alpha_5);
+ p = pmadd(x2, p, p4f_alpha_3);
+ p = pmadd(x2, p, p4f_alpha_1);
+ p = pmul(x, p);
+
+ // Evaluate the denominator polynomial q.
+ Packet4f q = pmadd(x2, p4f_beta_6, p4f_beta_4);
+ q = pmadd(x2, q, p4f_beta_2);
+ q = pmadd(x2, q, p4f_beta_0);
+
+ // Divide the numerator by the denominator.
+ p = pdiv(p, q);
+
+ // Reinstate the sign.
+ p = (Packet4f)__builtin_msa_binsli_w((v4u32)p, (v4u32)_x, 0);
+
+ // When the argument is very small in magnitude it's more accurate to just return it.
+ p = (Packet4f)__builtin_msa_bsel_v((v16u8)tiny_mask, (v16u8)p, (v16u8)_x);
+
+ return p;
+}
+
+template <bool sine>
+Packet4f psincos_inner_msa_float(const Packet4f& _x) {
+ static _EIGEN_DECLARE_CONST_Packet4f(sincos_max_arg, 13176795.0f); // Approx. (2**24) / (4/Pi).
+ static _EIGEN_DECLARE_CONST_Packet4f(minus_cephes_DP1, -0.78515625f);
+ static _EIGEN_DECLARE_CONST_Packet4f(minus_cephes_DP2, -2.4187564849853515625e-4f);
+ static _EIGEN_DECLARE_CONST_Packet4f(minus_cephes_DP3, -3.77489497744594108e-8f);
+ static _EIGEN_DECLARE_CONST_Packet4f(sincof_p0, -1.9515295891e-4f);
+ static _EIGEN_DECLARE_CONST_Packet4f(sincof_p1, 8.3321608736e-3f);
+ static _EIGEN_DECLARE_CONST_Packet4f(sincof_p2, -1.6666654611e-1f);
+ static _EIGEN_DECLARE_CONST_Packet4f(coscof_p0, 2.443315711809948e-5f);
+ static _EIGEN_DECLARE_CONST_Packet4f(coscof_p1, -1.388731625493765e-3f);
+ static _EIGEN_DECLARE_CONST_Packet4f(coscof_p2, 4.166664568298827e-2f);
+ static _EIGEN_DECLARE_CONST_Packet4f(cephes_FOPI, 1.27323954473516f); // 4/Pi.
+ static _EIGEN_DECLARE_CONST_Packet4f(half, 0.5f);
+ static _EIGEN_DECLARE_CONST_Packet4f(1, 1.0f);
+
+ Packet4f x = pabs(_x);
+
+ // Translate infinite arguments into NANs.
+ Packet4f zero_or_nan_if_inf = psub(_x, _x);
+ x = padd(x, zero_or_nan_if_inf);
+ // Prevent sin/cos from generating values larger than 1.0 in magnitude
+ // for very large arguments by setting x to 0.0.
+ Packet4i small_or_nan_mask = __builtin_msa_fcult_w(x, p4f_sincos_max_arg);
+ x = pand(x, (Packet4f)small_or_nan_mask);
+
+ // Scale x by 4/Pi to find x's octant.
+ Packet4f y = pmul(x, p4f_cephes_FOPI);
+ // Get the octant. We'll reduce x by this number of octants or by one more than it.
+ Packet4i y_int = __builtin_msa_ftrunc_s_w(y);
+ // x's from even-numbered octants will translate to octant 0: [0, +Pi/4].
+ // x's from odd-numbered octants will translate to octant -1: [-Pi/4, 0].
+ // Adjustment for odd-numbered octants: octant = (octant + 1) & (~1).
+ Packet4i y_int1 = __builtin_msa_addvi_w(y_int, 1);
+ Packet4i y_int2 = (Packet4i)__builtin_msa_bclri_w((Packet4ui)y_int1, 0); // bclri = bit-clear
+ y = __builtin_msa_ffint_s_w(y_int2);
+
+ // Compute the sign to apply to the polynomial.
+ Packet4i sign_mask = sine ? pxor(__builtin_msa_slli_w(y_int1, 29), (Packet4i)_x)
+ : __builtin_msa_slli_w(__builtin_msa_addvi_w(y_int, 3), 29);
+
+ // Get the polynomial selection mask.
+ // We'll calculate both (sin and cos) polynomials and then select from the two.
+ Packet4i poly_mask = __builtin_msa_ceqi_w(__builtin_msa_slli_w(y_int2, 30), 0);
+
+ // Reduce x by y octants to get: -Pi/4 <= x <= +Pi/4.
+ // The magic pass: "Extended precision modular arithmetic"
+ // x = ((x - y * DP1) - y * DP2) - y * DP3
+ Packet4f tmp1 = pmul(y, p4f_minus_cephes_DP1);
+ Packet4f tmp2 = pmul(y, p4f_minus_cephes_DP2);
+ Packet4f tmp3 = pmul(y, p4f_minus_cephes_DP3);
+ x = padd(x, tmp1);
+ x = padd(x, tmp2);
+ x = padd(x, tmp3);
+
+ // Evaluate the cos(x) polynomial.
+ y = p4f_coscof_p0;
+ Packet4f z = pmul(x, x);
+ y = pmadd(y, z, p4f_coscof_p1);
+ y = pmadd(y, z, p4f_coscof_p2);
+ y = pmul(y, z);
+ y = pmul(y, z);
+ y = __builtin_msa_fmsub_w(y, z, p4f_half);
+ y = padd(y, p4f_1);
+
+ // Evaluate the sin(x) polynomial.
+ Packet4f y2 = p4f_sincof_p0;
+ y2 = pmadd(y2, z, p4f_sincof_p1);
+ y2 = pmadd(y2, z, p4f_sincof_p2);
+ y2 = pmul(y2, z);
+ y2 = pmadd(y2, x, x);
+
+ // Select the correct result from the two polynomials.
+ y = sine ? (Packet4f)__builtin_msa_bsel_v((v16u8)poly_mask, (v16u8)y, (v16u8)y2)
+ : (Packet4f)__builtin_msa_bsel_v((v16u8)poly_mask, (v16u8)y2, (v16u8)y);
+
+ // Update the sign.
+ sign_mask = pxor(sign_mask, (Packet4i)y);
+ y = (Packet4f)__builtin_msa_binsli_w((v4u32)y, (v4u32)sign_mask, 0); // binsli = bit-insert-left
+ return y;
+}
+
+template <>
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet4f
+psin<Packet4f>(const Packet4f& x) {
+ return psincos_inner_msa_float</* sine */ true>(x);
+}
+
+template <>
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet4f
+pcos<Packet4f>(const Packet4f& x) {
+ return psincos_inner_msa_float</* sine */ false>(x);
+}
+
+template <>
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet2d
+pexp<Packet2d>(const Packet2d& _x) {
+ // Limiting double-precision pexp's argument to [-1024, +1024] lets pexp
+ // reach 0 and INFINITY naturally.
+ static _EIGEN_DECLARE_CONST_Packet2d(exp_lo, -1024.0);
+ static _EIGEN_DECLARE_CONST_Packet2d(exp_hi, +1024.0);
+ static _EIGEN_DECLARE_CONST_Packet2d(cephes_LOG2EF, 1.4426950408889634073599);
+ static _EIGEN_DECLARE_CONST_Packet2d(cephes_exp_C1, 0.693145751953125);
+ static _EIGEN_DECLARE_CONST_Packet2d(cephes_exp_C2, 1.42860682030941723212e-6);
+ static _EIGEN_DECLARE_CONST_Packet2d(cephes_exp_p0, 1.26177193074810590878e-4);
+ static _EIGEN_DECLARE_CONST_Packet2d(cephes_exp_p1, 3.02994407707441961300e-2);
+ static _EIGEN_DECLARE_CONST_Packet2d(cephes_exp_p2, 9.99999999999999999910e-1);
+ static _EIGEN_DECLARE_CONST_Packet2d(cephes_exp_q0, 3.00198505138664455042e-6);
+ static _EIGEN_DECLARE_CONST_Packet2d(cephes_exp_q1, 2.52448340349684104192e-3);
+ static _EIGEN_DECLARE_CONST_Packet2d(cephes_exp_q2, 2.27265548208155028766e-1);
+ static _EIGEN_DECLARE_CONST_Packet2d(cephes_exp_q3, 2.00000000000000000009e0);
+ static _EIGEN_DECLARE_CONST_Packet2d(half, 0.5);
+ static _EIGEN_DECLARE_CONST_Packet2d(1, 1.0);
+ static _EIGEN_DECLARE_CONST_Packet2d(2, 2.0);
+
+ Packet2d x = _x;
+
+ // Clamp x.
+ x = (Packet2d)__builtin_msa_bsel_v((v16u8)__builtin_msa_fclt_d(x, p2d_exp_lo), (v16u8)x,
+ (v16u8)p2d_exp_lo);
+ x = (Packet2d)__builtin_msa_bsel_v((v16u8)__builtin_msa_fclt_d(p2d_exp_hi, x), (v16u8)x,
+ (v16u8)p2d_exp_hi);
+
+ // Round to nearest integer by adding 0.5 (with x's sign) and truncating.
+ Packet2d x2_add = (Packet2d)__builtin_msa_binsli_d((v2u64)p2d_half, (v2u64)x, 0);
+ Packet2d x2 = pmadd(x, p2d_cephes_LOG2EF, x2_add);
+ Packet2l x2_long = __builtin_msa_ftrunc_s_d(x2);
+ Packet2d x2_long_d = __builtin_msa_ffint_s_d(x2_long);
+
+ x = __builtin_msa_fmsub_d(x, x2_long_d, p2d_cephes_exp_C1);
+ x = __builtin_msa_fmsub_d(x, x2_long_d, p2d_cephes_exp_C2);
+
+ x2 = pmul(x, x);
+
+ Packet2d px = p2d_cephes_exp_p0;
+ px = pmadd(px, x2, p2d_cephes_exp_p1);
+ px = pmadd(px, x2, p2d_cephes_exp_p2);
+ px = pmul(px, x);
+
+ Packet2d qx = p2d_cephes_exp_q0;
+ qx = pmadd(qx, x2, p2d_cephes_exp_q1);
+ qx = pmadd(qx, x2, p2d_cephes_exp_q2);
+ qx = pmadd(qx, x2, p2d_cephes_exp_q3);
+
+ x = pdiv(px, psub(qx, px));
+ x = pmadd(p2d_2, x, p2d_1);
+
+ // x *= 2**exponent.
+ x = __builtin_msa_fexp2_d(x, x2_long);
+
+ return x;
+}
+
+} // end namespace internal
+
+} // end namespace Eigen
+
+#endif // EIGEN_MATH_FUNCTIONS_MSA_H
diff --git a/Eigen/src/Core/arch/MSA/PacketMath.h b/Eigen/src/Core/arch/MSA/PacketMath.h
new file mode 100644
index 000000000..afe8f3375
--- /dev/null
+++ b/Eigen/src/Core/arch/MSA/PacketMath.h
@@ -0,0 +1,1233 @@
+// This file is part of Eigen, a lightweight C++ template library
+// for linear algebra.
+//
+// Copyright (C) 2018 Wave Computing, Inc.
+// Written by:
+// Chris Larsen
+// Alexey Frunze (afrunze@wavecomp.com)
+//
+// This Source Code Form is subject to the terms of the Mozilla
+// Public License v. 2.0. If a copy of the MPL was not distributed
+// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+#ifndef EIGEN_PACKET_MATH_MSA_H
+#define EIGEN_PACKET_MATH_MSA_H
+
+#include <iostream>
+#include <string>
+
+namespace Eigen {
+
+namespace internal {
+
+#ifndef EIGEN_CACHEFRIENDLY_PRODUCT_THRESHOLD
+#define EIGEN_CACHEFRIENDLY_PRODUCT_THRESHOLD 8
+#endif
+
+#ifndef EIGEN_HAS_SINGLE_INSTRUCTION_MADD
+#define EIGEN_HAS_SINGLE_INSTRUCTION_MADD
+#endif
+
+#ifndef EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS
+#define EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS 32
+#endif
+
+#if 0
+#define EIGEN_MSA_DEBUG \
+ static bool firstTime = true; \
+ do { \
+ if (firstTime) { \
+ std::cout << __FILE__ << ':' << __LINE__ << ':' << __FUNCTION__ << std::endl; \
+ firstTime = false; \
+ } \
+ } while (0)
+#else
+#define EIGEN_MSA_DEBUG
+#endif
+
+#define EIGEN_MSA_SHF_I8(a, b, c, d) (((d) << 6) | ((c) << 4) | ((b) << 2) | (a))
+
+typedef v4f32 Packet4f;
+typedef v4i32 Packet4i;
+typedef v4u32 Packet4ui;
+
+#define _EIGEN_DECLARE_CONST_Packet4f(NAME, X) const Packet4f p4f_##NAME = { X, X, X, X }
+#define _EIGEN_DECLARE_CONST_Packet4i(NAME, X) const Packet4i p4i_##NAME = { X, X, X, X }
+#define _EIGEN_DECLARE_CONST_Packet4ui(NAME, X) const Packet4ui p4ui_##NAME = { X, X, X, X }
+
+inline std::ostream& operator<<(std::ostream& os, const Packet4f& value) {
+ os << "[ " << value[0] << ", " << value[1] << ", " << value[2] << ", " << value[3] << " ]";
+ return os;
+}
+
+inline std::ostream& operator<<(std::ostream& os, const Packet4i& value) {
+ os << "[ " << value[0] << ", " << value[1] << ", " << value[2] << ", " << value[3] << " ]";
+ return os;
+}
+
+inline std::ostream& operator<<(std::ostream& os, const Packet4ui& value) {
+ os << "[ " << value[0] << ", " << value[1] << ", " << value[2] << ", " << value[3] << " ]";
+ return os;
+}
+
+template <>
+struct packet_traits<float> : default_packet_traits {
+ typedef Packet4f type;
+ typedef Packet4f half; // Packet2f intrinsics not implemented yet
+ enum {
+ Vectorizable = 1,
+ AlignedOnScalar = 1,
+ size = 4,
+ HasHalfPacket = 0, // Packet2f intrinsics not implemented yet
+ // FIXME check the Has*
+ HasDiv = 1,
+ HasSin = EIGEN_FAST_MATH,
+ HasCos = EIGEN_FAST_MATH,
+ HasTanh = EIGEN_FAST_MATH,
+ HasErf = EIGEN_FAST_MATH,
+ HasLog = 1,
+ HasExp = 1,
+ HasSqrt = 1,
+ HasRsqrt = 1,
+ HasRound = 1,
+ HasFloor = 1,
+ HasCeil = 1,
+ HasBlend = 1
+ };
+};
+
+template <>
+struct packet_traits<int32_t> : default_packet_traits {
+ typedef Packet4i type;
+ typedef Packet4i half; // Packet2i intrinsics not implemented yet
+ enum {
+ Vectorizable = 1,
+ AlignedOnScalar = 1,
+ size = 4,
+ HasHalfPacket = 0, // Packet2i intrinsics not implemented yet
+ // FIXME check the Has*
+ HasDiv = 1,
+ HasBlend = 1
+ };
+};
+
+template <>
+struct unpacket_traits<Packet4f> {
+ typedef float type;
+ enum { size = 4, alignment = Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false };
+ typedef Packet4f half;
+};
+
+template <>
+struct unpacket_traits<Packet4i> {
+ typedef int32_t type;
+ enum { size = 4, alignment = Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false };
+ typedef Packet4i half;
+};
+
+template <>
+EIGEN_STRONG_INLINE Packet4f pset1<Packet4f>(const float& from) {
+ EIGEN_MSA_DEBUG;
+
+ Packet4f v = { from, from, from, from };
+ return v;
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4i pset1<Packet4i>(const int32_t& from) {
+ EIGEN_MSA_DEBUG;
+
+ return __builtin_msa_fill_w(from);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4f pload1<Packet4f>(const float* from) {
+ EIGEN_MSA_DEBUG;
+
+ float f = *from;
+ Packet4f v = { f, f, f, f };
+ return v;
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4i pload1<Packet4i>(const int32_t* from) {
+ EIGEN_MSA_DEBUG;
+
+ return __builtin_msa_fill_w(*from);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4f padd<Packet4f>(const Packet4f& a, const Packet4f& b) {
+ EIGEN_MSA_DEBUG;
+
+ return __builtin_msa_fadd_w(a, b);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4i padd<Packet4i>(const Packet4i& a, const Packet4i& b) {
+ EIGEN_MSA_DEBUG;
+
+ return __builtin_msa_addv_w(a, b);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4f plset<Packet4f>(const float& a) {
+ EIGEN_MSA_DEBUG;
+
+ static const Packet4f countdown = { 0.0f, 1.0f, 2.0f, 3.0f };
+ return padd(pset1<Packet4f>(a), countdown);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4i plset<Packet4i>(const int32_t& a) {
+ EIGEN_MSA_DEBUG;
+
+ static const Packet4i countdown = { 0, 1, 2, 3 };
+ return padd(pset1<Packet4i>(a), countdown);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4f psub<Packet4f>(const Packet4f& a, const Packet4f& b) {
+ EIGEN_MSA_DEBUG;
+
+ return __builtin_msa_fsub_w(a, b);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4i psub<Packet4i>(const Packet4i& a, const Packet4i& b) {
+ EIGEN_MSA_DEBUG;
+
+ return __builtin_msa_subv_w(a, b);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4f pnegate(const Packet4f& a) {
+ EIGEN_MSA_DEBUG;
+
+ return (Packet4f)__builtin_msa_bnegi_w((v4u32)a, 31);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4i pnegate(const Packet4i& a) {
+ EIGEN_MSA_DEBUG;
+
+ return __builtin_msa_addvi_w((v4i32)__builtin_msa_nori_b((v16u8)a, 0), 1);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4f pconj(const Packet4f& a) {
+ EIGEN_MSA_DEBUG;
+
+ return a;
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4i pconj(const Packet4i& a) {
+ EIGEN_MSA_DEBUG;
+
+ return a;
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4f pmul<Packet4f>(const Packet4f& a, const Packet4f& b) {
+ EIGEN_MSA_DEBUG;
+
+ return __builtin_msa_fmul_w(a, b);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4i pmul<Packet4i>(const Packet4i& a, const Packet4i& b) {
+ EIGEN_MSA_DEBUG;
+
+ return __builtin_msa_mulv_w(a, b);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4f pdiv<Packet4f>(const Packet4f& a, const Packet4f& b) {
+ EIGEN_MSA_DEBUG;
+
+ return __builtin_msa_fdiv_w(a, b);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4i pdiv<Packet4i>(const Packet4i& a, const Packet4i& b) {
+ EIGEN_MSA_DEBUG;
+
+ return __builtin_msa_div_s_w(a, b);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4f pmadd(const Packet4f& a, const Packet4f& b, const Packet4f& c) {
+ EIGEN_MSA_DEBUG;
+
+ return __builtin_msa_fmadd_w(c, a, b);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4i pmadd(const Packet4i& a, const Packet4i& b, const Packet4i& c) {
+ EIGEN_MSA_DEBUG;
+
+ // Use "asm" construct to avoid __builtin_msa_maddv_w GNU C bug.
+ Packet4i value = c;
+ __asm__("maddv.w %w[value], %w[a], %w[b]\n"
+ // Outputs
+ : [value] "+f"(value)
+ // Inputs
+ : [a] "f"(a), [b] "f"(b));
+ return value;
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4f pand<Packet4f>(const Packet4f& a, const Packet4f& b) {
+ EIGEN_MSA_DEBUG;
+
+ return (Packet4f)__builtin_msa_and_v((v16u8)a, (v16u8)b);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4i pand<Packet4i>(const Packet4i& a, const Packet4i& b) {
+ EIGEN_MSA_DEBUG;
+
+ return (Packet4i)__builtin_msa_and_v((v16u8)a, (v16u8)b);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4f por<Packet4f>(const Packet4f& a, const Packet4f& b) {
+ EIGEN_MSA_DEBUG;
+
+ return (Packet4f)__builtin_msa_or_v((v16u8)a, (v16u8)b);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4i por<Packet4i>(const Packet4i& a, const Packet4i& b) {
+ EIGEN_MSA_DEBUG;
+
+ return (Packet4i)__builtin_msa_or_v((v16u8)a, (v16u8)b);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4f pxor<Packet4f>(const Packet4f& a, const Packet4f& b) {
+ EIGEN_MSA_DEBUG;
+
+ return (Packet4f)__builtin_msa_xor_v((v16u8)a, (v16u8)b);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4i pxor<Packet4i>(const Packet4i& a, const Packet4i& b) {
+ EIGEN_MSA_DEBUG;
+
+ return (Packet4i)__builtin_msa_xor_v((v16u8)a, (v16u8)b);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4f pandnot<Packet4f>(const Packet4f& a, const Packet4f& b) {
+ EIGEN_MSA_DEBUG;
+
+ return pand(a, (Packet4f)__builtin_msa_xori_b((v16u8)b, 255));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4i pandnot<Packet4i>(const Packet4i& a, const Packet4i& b) {
+ EIGEN_MSA_DEBUG;
+
+ return pand(a, (Packet4i)__builtin_msa_xori_b((v16u8)b, 255));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4f pmin<Packet4f>(const Packet4f& a, const Packet4f& b) {
+ EIGEN_MSA_DEBUG;
+
+#if EIGEN_FAST_MATH
+ // This prefers numbers to NaNs.
+ return __builtin_msa_fmin_w(a, b);
+#else
+ // This prefers NaNs to numbers.
+ Packet4i aNaN = __builtin_msa_fcun_w(a, a);
+ Packet4i aMinOrNaN = por(__builtin_msa_fclt_w(a, b), aNaN);
+ return (Packet4f)__builtin_msa_bsel_v((v16u8)aMinOrNaN, (v16u8)b, (v16u8)a);
+#endif
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4i pmin<Packet4i>(const Packet4i& a, const Packet4i& b) {
+ EIGEN_MSA_DEBUG;
+
+ return __builtin_msa_min_s_w(a, b);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4f pmax<Packet4f>(const Packet4f& a, const Packet4f& b) {
+ EIGEN_MSA_DEBUG;
+
+#if EIGEN_FAST_MATH
+ // This prefers numbers to NaNs.
+ return __builtin_msa_fmax_w(a, b);
+#else
+ // This prefers NaNs to numbers.
+ Packet4i aNaN = __builtin_msa_fcun_w(a, a);
+ Packet4i aMaxOrNaN = por(__builtin_msa_fclt_w(b, a), aNaN);
+ return (Packet4f)__builtin_msa_bsel_v((v16u8)aMaxOrNaN, (v16u8)b, (v16u8)a);
+#endif
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4i pmax<Packet4i>(const Packet4i& a, const Packet4i& b) {
+ EIGEN_MSA_DEBUG;
+
+ return __builtin_msa_max_s_w(a, b);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4f pload<Packet4f>(const float* from) {
+ EIGEN_MSA_DEBUG;
+
+ EIGEN_DEBUG_ALIGNED_LOAD return (Packet4f)__builtin_msa_ld_w(const_cast<float*>(from), 0);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4i pload<Packet4i>(const int32_t* from) {
+ EIGEN_MSA_DEBUG;
+
+ EIGEN_DEBUG_ALIGNED_LOAD return __builtin_msa_ld_w(const_cast<int32_t*>(from), 0);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4f ploadu<Packet4f>(const float* from) {
+ EIGEN_MSA_DEBUG;
+
+ EIGEN_DEBUG_UNALIGNED_LOAD return (Packet4f)__builtin_msa_ld_w(const_cast<float*>(from), 0);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4i ploadu<Packet4i>(const int32_t* from) {
+ EIGEN_MSA_DEBUG;
+
+ EIGEN_DEBUG_UNALIGNED_LOAD return (Packet4i)__builtin_msa_ld_w(const_cast<int32_t*>(from), 0);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4f ploaddup<Packet4f>(const float* from) {
+ EIGEN_MSA_DEBUG;
+
+ float f0 = from[0], f1 = from[1];
+ Packet4f v0 = { f0, f0, f0, f0 };
+ Packet4f v1 = { f1, f1, f1, f1 };
+ return (Packet4f)__builtin_msa_ilvr_d((v2i64)v1, (v2i64)v0);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4i ploaddup<Packet4i>(const int32_t* from) {
+ EIGEN_MSA_DEBUG;
+
+ int32_t i0 = from[0], i1 = from[1];
+ Packet4i v0 = { i0, i0, i0, i0 };
+ Packet4i v1 = { i1, i1, i1, i1 };
+ return (Packet4i)__builtin_msa_ilvr_d((v2i64)v1, (v2i64)v0);
+}
+
+template <>
+EIGEN_STRONG_INLINE void pstore<float>(float* to, const Packet4f& from) {
+ EIGEN_MSA_DEBUG;
+
+ EIGEN_DEBUG_ALIGNED_STORE __builtin_msa_st_w((Packet4i)from, to, 0);
+}
+
+template <>
+EIGEN_STRONG_INLINE void pstore<int32_t>(int32_t* to, const Packet4i& from) {
+ EIGEN_MSA_DEBUG;
+
+ EIGEN_DEBUG_ALIGNED_STORE __builtin_msa_st_w(from, to, 0);
+}
+
+template <>
+EIGEN_STRONG_INLINE void pstoreu<float>(float* to, const Packet4f& from) {
+ EIGEN_MSA_DEBUG;
+
+ EIGEN_DEBUG_UNALIGNED_STORE __builtin_msa_st_w((Packet4i)from, to, 0);
+}
+
+template <>
+EIGEN_STRONG_INLINE void pstoreu<int32_t>(int32_t* to, const Packet4i& from) {
+ EIGEN_MSA_DEBUG;
+
+ EIGEN_DEBUG_UNALIGNED_STORE __builtin_msa_st_w(from, to, 0);
+}
+
+template <>
+EIGEN_DEVICE_FUNC inline Packet4f pgather<float, Packet4f>(const float* from, Index stride) {
+ EIGEN_MSA_DEBUG;
+
+ float f = *from;
+ Packet4f v = { f, f, f, f };
+ v[1] = from[stride];
+ v[2] = from[2 * stride];
+ v[3] = from[3 * stride];
+ return v;
+}
+
+template <>
+EIGEN_DEVICE_FUNC inline Packet4i pgather<int32_t, Packet4i>(const int32_t* from, Index stride) {
+ EIGEN_MSA_DEBUG;
+
+ int32_t i = *from;
+ Packet4i v = { i, i, i, i };
+ v[1] = from[stride];
+ v[2] = from[2 * stride];
+ v[3] = from[3 * stride];
+ return v;
+}
+
+template <>
+EIGEN_DEVICE_FUNC inline void pscatter<float, Packet4f>(float* to, const Packet4f& from,
+ Index stride) {
+ EIGEN_MSA_DEBUG;
+
+ *to = from[0];
+ to += stride;
+ *to = from[1];
+ to += stride;
+ *to = from[2];
+ to += stride;
+ *to = from[3];
+}
+
+template <>
+EIGEN_DEVICE_FUNC inline void pscatter<int32_t, Packet4i>(int32_t* to, const Packet4i& from,
+ Index stride) {
+ EIGEN_MSA_DEBUG;
+
+ *to = from[0];
+ to += stride;
+ *to = from[1];
+ to += stride;
+ *to = from[2];
+ to += stride;
+ *to = from[3];
+}
+
+template <>
+EIGEN_STRONG_INLINE void prefetch<float>(const float* addr) {
+ EIGEN_MSA_DEBUG;
+
+ __builtin_prefetch(addr);
+}
+
+template <>
+EIGEN_STRONG_INLINE void prefetch<int32_t>(const int32_t* addr) {
+ EIGEN_MSA_DEBUG;
+
+ __builtin_prefetch(addr);
+}
+
+template <>
+EIGEN_STRONG_INLINE float pfirst<Packet4f>(const Packet4f& a) {
+ EIGEN_MSA_DEBUG;
+
+ return a[0];
+}
+
+template <>
+EIGEN_STRONG_INLINE int32_t pfirst<Packet4i>(const Packet4i& a) {
+ EIGEN_MSA_DEBUG;
+
+ return a[0];
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4f preverse(const Packet4f& a) {
+ EIGEN_MSA_DEBUG;
+
+ return (Packet4f)__builtin_msa_shf_w((v4i32)a, EIGEN_MSA_SHF_I8(3, 2, 1, 0));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4i preverse(const Packet4i& a) {
+ EIGEN_MSA_DEBUG;
+
+ return __builtin_msa_shf_w(a, EIGEN_MSA_SHF_I8(3, 2, 1, 0));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4f pabs(const Packet4f& a) {
+ EIGEN_MSA_DEBUG;
+
+ return (Packet4f)__builtin_msa_bclri_w((v4u32)a, 31);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4i pabs(const Packet4i& a) {
+ EIGEN_MSA_DEBUG;
+
+ Packet4i zero = __builtin_msa_ldi_w(0);
+ return __builtin_msa_add_a_w(zero, a);
+}
+
+template <>
+EIGEN_STRONG_INLINE float predux<Packet4f>(const Packet4f& a) {
+ EIGEN_MSA_DEBUG;
+
+ Packet4f s = padd(a, (Packet4f)__builtin_msa_shf_w((v4i32)a, EIGEN_MSA_SHF_I8(2, 3, 0, 1)));
+ s = padd(s, (Packet4f)__builtin_msa_shf_w((v4i32)s, EIGEN_MSA_SHF_I8(1, 0, 3, 2)));
+ return s[0];
+}
+
+
+template <>
+EIGEN_STRONG_INLINE int32_t predux<Packet4i>(const Packet4i& a) {
+ EIGEN_MSA_DEBUG;
+
+ Packet4i s = padd(a, __builtin_msa_shf_w(a, EIGEN_MSA_SHF_I8(2, 3, 0, 1)));
+ s = padd(s, __builtin_msa_shf_w(s, EIGEN_MSA_SHF_I8(1, 0, 3, 2)));
+ return s[0];
+}
+
+// Other reduction functions:
+// mul
+template <>
+EIGEN_STRONG_INLINE float predux_mul<Packet4f>(const Packet4f& a) {
+ EIGEN_MSA_DEBUG;
+
+ Packet4f p = pmul(a, (Packet4f)__builtin_msa_shf_w((v4i32)a, EIGEN_MSA_SHF_I8(2, 3, 0, 1)));
+ p = pmul(p, (Packet4f)__builtin_msa_shf_w((v4i32)p, EIGEN_MSA_SHF_I8(1, 0, 3, 2)));
+ return p[0];
+}
+
+template <>
+EIGEN_STRONG_INLINE int32_t predux_mul<Packet4i>(const Packet4i& a) {
+ EIGEN_MSA_DEBUG;
+
+ Packet4i p = pmul(a, __builtin_msa_shf_w(a, EIGEN_MSA_SHF_I8(2, 3, 0, 1)));
+ p = pmul(p, __builtin_msa_shf_w(p, EIGEN_MSA_SHF_I8(1, 0, 3, 2)));
+ return p[0];
+}
+
+// min
+template <>
+EIGEN_STRONG_INLINE float predux_min<Packet4f>(const Packet4f& a) {
+ EIGEN_MSA_DEBUG;
+
+ // Swap 64-bit halves of a.
+ Packet4f swapped = (Packet4f)__builtin_msa_shf_w((Packet4i)a, EIGEN_MSA_SHF_I8(2, 3, 0, 1));
+#if !EIGEN_FAST_MATH
+ // Detect presence of NaNs from pairs a[0]-a[2] and a[1]-a[3] as two 32-bit
+ // masks of all zeroes/ones in low 64 bits.
+ v16u8 unord = (v16u8)__builtin_msa_fcun_w(a, swapped);
+ // Combine the two masks into one: 64 ones if no NaNs, otherwise 64 zeroes.
+ unord = (v16u8)__builtin_msa_ceqi_d((v2i64)unord, 0);
+#endif
+ // Continue with min computation.
+ Packet4f v = __builtin_msa_fmin_w(a, swapped);
+ v = __builtin_msa_fmin_w(
+ v, (Packet4f)__builtin_msa_shf_w((Packet4i)v, EIGEN_MSA_SHF_I8(1, 0, 3, 2)));
+#if !EIGEN_FAST_MATH
+ // Based on the mask select between v and 4 qNaNs.
+ v16u8 qnans = (v16u8)__builtin_msa_fill_w(0x7FC00000);
+ v = (Packet4f)__builtin_msa_bsel_v(unord, qnans, (v16u8)v);
+#endif
+ return v[0];
+}
+
+template <>
+EIGEN_STRONG_INLINE int32_t predux_min<Packet4i>(const Packet4i& a) {
+ EIGEN_MSA_DEBUG;
+
+ Packet4i m = pmin(a, __builtin_msa_shf_w(a, EIGEN_MSA_SHF_I8(2, 3, 0, 1)));
+ m = pmin(m, __builtin_msa_shf_w(m, EIGEN_MSA_SHF_I8(1, 0, 3, 2)));
+ return m[0];
+}
+
+// max
+template <>
+EIGEN_STRONG_INLINE float predux_max<Packet4f>(const Packet4f& a) {
+ EIGEN_MSA_DEBUG;
+
+ // Swap 64-bit halves of a.
+ Packet4f swapped = (Packet4f)__builtin_msa_shf_w((Packet4i)a, EIGEN_MSA_SHF_I8(2, 3, 0, 1));
+#if !EIGEN_FAST_MATH
+ // Detect presence of NaNs from pairs a[0]-a[2] and a[1]-a[3] as two 32-bit
+ // masks of all zeroes/ones in low 64 bits.
+ v16u8 unord = (v16u8)__builtin_msa_fcun_w(a, swapped);
+ // Combine the two masks into one: 64 ones if no NaNs, otherwise 64 zeroes.
+ unord = (v16u8)__builtin_msa_ceqi_d((v2i64)unord, 0);
+#endif
+ // Continue with max computation.
+ Packet4f v = __builtin_msa_fmax_w(a, swapped);
+ v = __builtin_msa_fmax_w(
+ v, (Packet4f)__builtin_msa_shf_w((Packet4i)v, EIGEN_MSA_SHF_I8(1, 0, 3, 2)));
+#if !EIGEN_FAST_MATH
+ // Based on the mask select between v and 4 qNaNs.
+ v16u8 qnans = (v16u8)__builtin_msa_fill_w(0x7FC00000);
+ v = (Packet4f)__builtin_msa_bsel_v(unord, qnans, (v16u8)v);
+#endif
+ return v[0];
+}
+
+template <>
+EIGEN_STRONG_INLINE int32_t predux_max<Packet4i>(const Packet4i& a) {
+ EIGEN_MSA_DEBUG;
+
+ Packet4i m = pmax(a, __builtin_msa_shf_w(a, EIGEN_MSA_SHF_I8(2, 3, 0, 1)));
+ m = pmax(m, __builtin_msa_shf_w(m, EIGEN_MSA_SHF_I8(1, 0, 3, 2)));
+ return m[0];
+}
+
+inline std::ostream& operator<<(std::ostream& os, const PacketBlock<Packet4f, 4>& value) {
+ os << "[ " << value.packet[0] << "," << std::endl
+ << " " << value.packet[1] << "," << std::endl
+ << " " << value.packet[2] << "," << std::endl
+ << " " << value.packet[3] << " ]";
+ return os;
+}
+
+EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet4f, 4>& kernel) {
+ EIGEN_MSA_DEBUG;
+
+ v4i32 tmp1, tmp2, tmp3, tmp4;
+
+ tmp1 = __builtin_msa_ilvr_w((v4i32)kernel.packet[1], (v4i32)kernel.packet[0]);
+ tmp2 = __builtin_msa_ilvr_w((v4i32)kernel.packet[3], (v4i32)kernel.packet[2]);
+ tmp3 = __builtin_msa_ilvl_w((v4i32)kernel.packet[1], (v4i32)kernel.packet[0]);
+ tmp4 = __builtin_msa_ilvl_w((v4i32)kernel.packet[3], (v4i32)kernel.packet[2]);
+
+ kernel.packet[0] = (Packet4f)__builtin_msa_ilvr_d((v2i64)tmp2, (v2i64)tmp1);
+ kernel.packet[1] = (Packet4f)__builtin_msa_ilvod_d((v2i64)tmp2, (v2i64)tmp1);
+ kernel.packet[2] = (Packet4f)__builtin_msa_ilvr_d((v2i64)tmp4, (v2i64)tmp3);
+ kernel.packet[3] = (Packet4f)__builtin_msa_ilvod_d((v2i64)tmp4, (v2i64)tmp3);
+}
+
+inline std::ostream& operator<<(std::ostream& os, const PacketBlock<Packet4i, 4>& value) {
+ os << "[ " << value.packet[0] << "," << std::endl
+ << " " << value.packet[1] << "," << std::endl
+ << " " << value.packet[2] << "," << std::endl
+ << " " << value.packet[3] << " ]";
+ return os;
+}
+
+EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet4i, 4>& kernel) {
+ EIGEN_MSA_DEBUG;
+
+ v4i32 tmp1, tmp2, tmp3, tmp4;
+
+ tmp1 = __builtin_msa_ilvr_w(kernel.packet[1], kernel.packet[0]);
+ tmp2 = __builtin_msa_ilvr_w(kernel.packet[3], kernel.packet[2]);
+ tmp3 = __builtin_msa_ilvl_w(kernel.packet[1], kernel.packet[0]);
+ tmp4 = __builtin_msa_ilvl_w(kernel.packet[3], kernel.packet[2]);
+
+ kernel.packet[0] = (Packet4i)__builtin_msa_ilvr_d((v2i64)tmp2, (v2i64)tmp1);
+ kernel.packet[1] = (Packet4i)__builtin_msa_ilvod_d((v2i64)tmp2, (v2i64)tmp1);
+ kernel.packet[2] = (Packet4i)__builtin_msa_ilvr_d((v2i64)tmp4, (v2i64)tmp3);
+ kernel.packet[3] = (Packet4i)__builtin_msa_ilvod_d((v2i64)tmp4, (v2i64)tmp3);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4f psqrt(const Packet4f& a) {
+ EIGEN_MSA_DEBUG;
+
+ return __builtin_msa_fsqrt_w(a);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4f prsqrt(const Packet4f& a) {
+ EIGEN_MSA_DEBUG;
+
+#if EIGEN_FAST_MATH
+ return __builtin_msa_frsqrt_w(a);
+#else
+ Packet4f ones = __builtin_msa_ffint_s_w(__builtin_msa_ldi_w(1));
+ return pdiv(ones, psqrt(a));
+#endif
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4f pfloor<Packet4f>(const Packet4f& a) {
+ Packet4f v = a;
+ int32_t old_mode, new_mode;
+ asm volatile(
+ "cfcmsa %[old_mode], $1\n"
+ "ori %[new_mode], %[old_mode], 3\n" // 3 = round towards -INFINITY.
+ "ctcmsa $1, %[new_mode]\n"
+ "frint.w %w[v], %w[v]\n"
+ "ctcmsa $1, %[old_mode]\n"
+ : // outputs
+ [old_mode] "=r"(old_mode), [new_mode] "=r"(new_mode),
+ [v] "+f"(v)
+ : // inputs
+ : // clobbers
+ );
+ return v;
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4f pceil<Packet4f>(const Packet4f& a) {
+ Packet4f v = a;
+ int32_t old_mode, new_mode;
+ asm volatile(
+ "cfcmsa %[old_mode], $1\n"
+ "ori %[new_mode], %[old_mode], 3\n"
+ "xori %[new_mode], %[new_mode], 1\n" // 2 = round towards +INFINITY.
+ "ctcmsa $1, %[new_mode]\n"
+ "frint.w %w[v], %w[v]\n"
+ "ctcmsa $1, %[old_mode]\n"
+ : // outputs
+ [old_mode] "=r"(old_mode), [new_mode] "=r"(new_mode),
+ [v] "+f"(v)
+ : // inputs
+ : // clobbers
+ );
+ return v;
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4f pround<Packet4f>(const Packet4f& a) {
+ Packet4f v = a;
+ int32_t old_mode, new_mode;
+ asm volatile(
+ "cfcmsa %[old_mode], $1\n"
+ "ori %[new_mode], %[old_mode], 3\n"
+ "xori %[new_mode], %[new_mode], 3\n" // 0 = round to nearest, ties to even.
+ "ctcmsa $1, %[new_mode]\n"
+ "frint.w %w[v], %w[v]\n"
+ "ctcmsa $1, %[old_mode]\n"
+ : // outputs
+ [old_mode] "=r"(old_mode), [new_mode] "=r"(new_mode),
+ [v] "+f"(v)
+ : // inputs
+ : // clobbers
+ );
+ return v;
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4f pblend(const Selector<4>& ifPacket, const Packet4f& thenPacket,
+ const Packet4f& elsePacket) {
+ Packet4ui select = { ifPacket.select[0], ifPacket.select[1], ifPacket.select[2],
+ ifPacket.select[3] };
+ Packet4i mask = __builtin_msa_ceqi_w((Packet4i)select, 0);
+ return (Packet4f)__builtin_msa_bsel_v((v16u8)mask, (v16u8)thenPacket, (v16u8)elsePacket);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4i pblend(const Selector<4>& ifPacket, const Packet4i& thenPacket,
+ const Packet4i& elsePacket) {
+ Packet4ui select = { ifPacket.select[0], ifPacket.select[1], ifPacket.select[2],
+ ifPacket.select[3] };
+ Packet4i mask = __builtin_msa_ceqi_w((Packet4i)select, 0);
+ return (Packet4i)__builtin_msa_bsel_v((v16u8)mask, (v16u8)thenPacket, (v16u8)elsePacket);
+}
+
+//---------- double ----------
+
+typedef v2f64 Packet2d;
+typedef v2i64 Packet2l;
+typedef v2u64 Packet2ul;
+
+#define _EIGEN_DECLARE_CONST_Packet2d(NAME, X) const Packet2d p2d_##NAME = { X, X }
+#define _EIGEN_DECLARE_CONST_Packet2l(NAME, X) const Packet2l p2l_##NAME = { X, X }
+#define _EIGEN_DECLARE_CONST_Packet2ul(NAME, X) const Packet2ul p2ul_##NAME = { X, X }
+
+inline std::ostream& operator<<(std::ostream& os, const Packet2d& value) {
+ os << "[ " << value[0] << ", " << value[1] << " ]";
+ return os;
+}
+
+inline std::ostream& operator<<(std::ostream& os, const Packet2l& value) {
+ os << "[ " << value[0] << ", " << value[1] << " ]";
+ return os;
+}
+
+inline std::ostream& operator<<(std::ostream& os, const Packet2ul& value) {
+ os << "[ " << value[0] << ", " << value[1] << " ]";
+ return os;
+}
+
+template <>
+struct packet_traits<double> : default_packet_traits {
+ typedef Packet2d type;
+ typedef Packet2d half;
+ enum {
+ Vectorizable = 1,
+ AlignedOnScalar = 1,
+ size = 2,
+ HasHalfPacket = 0,
+ // FIXME check the Has*
+ HasDiv = 1,
+ HasExp = 1,
+ HasSqrt = 1,
+ HasRsqrt = 1,
+ HasRound = 1,
+ HasFloor = 1,
+ HasCeil = 1,
+ HasBlend = 1
+ };
+};
+
+template <>
+struct unpacket_traits<Packet2d> {
+ typedef double type;
+ enum { size = 2, alignment = Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false };
+ typedef Packet2d half;
+};
+
+template <>
+EIGEN_STRONG_INLINE Packet2d pset1<Packet2d>(const double& from) {
+ EIGEN_MSA_DEBUG;
+
+ Packet2d value = { from, from };
+ return value;
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet2d padd<Packet2d>(const Packet2d& a, const Packet2d& b) {
+ EIGEN_MSA_DEBUG;
+
+ return __builtin_msa_fadd_d(a, b);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet2d plset<Packet2d>(const double& a) {
+ EIGEN_MSA_DEBUG;
+
+ static const Packet2d countdown = { 0.0, 1.0 };
+ return padd(pset1<Packet2d>(a), countdown);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet2d psub<Packet2d>(const Packet2d& a, const Packet2d& b) {
+ EIGEN_MSA_DEBUG;
+
+ return __builtin_msa_fsub_d(a, b);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet2d pnegate(const Packet2d& a) {
+ EIGEN_MSA_DEBUG;
+
+ return (Packet2d)__builtin_msa_bnegi_d((v2u64)a, 63);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet2d pconj(const Packet2d& a) {
+ EIGEN_MSA_DEBUG;
+
+ return a;
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet2d pmul<Packet2d>(const Packet2d& a, const Packet2d& b) {
+ EIGEN_MSA_DEBUG;
+
+ return __builtin_msa_fmul_d(a, b);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet2d pdiv<Packet2d>(const Packet2d& a, const Packet2d& b) {
+ EIGEN_MSA_DEBUG;
+
+ return __builtin_msa_fdiv_d(a, b);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet2d pmadd(const Packet2d& a, const Packet2d& b, const Packet2d& c) {
+ EIGEN_MSA_DEBUG;
+
+ return __builtin_msa_fmadd_d(c, a, b);
+}
+
+// Logical Operations are not supported for float, so we have to reinterpret casts using MSA
+// intrinsics
+template <>
+EIGEN_STRONG_INLINE Packet2d pand<Packet2d>(const Packet2d& a, const Packet2d& b) {
+ EIGEN_MSA_DEBUG;
+
+ return (Packet2d)__builtin_msa_and_v((v16u8)a, (v16u8)b);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet2d por<Packet2d>(const Packet2d& a, const Packet2d& b) {
+ EIGEN_MSA_DEBUG;
+
+ return (Packet2d)__builtin_msa_or_v((v16u8)a, (v16u8)b);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet2d pxor<Packet2d>(const Packet2d& a, const Packet2d& b) {
+ EIGEN_MSA_DEBUG;
+
+ return (Packet2d)__builtin_msa_xor_v((v16u8)a, (v16u8)b);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet2d pandnot<Packet2d>(const Packet2d& a, const Packet2d& b) {
+ EIGEN_MSA_DEBUG;
+
+ return pand(a, (Packet2d)__builtin_msa_xori_b((v16u8)b, 255));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet2d pload<Packet2d>(const double* from) {
+ EIGEN_MSA_DEBUG;
+
+ EIGEN_DEBUG_UNALIGNED_LOAD return (Packet2d)__builtin_msa_ld_d(const_cast<double*>(from), 0);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet2d pmin<Packet2d>(const Packet2d& a, const Packet2d& b) {
+ EIGEN_MSA_DEBUG;
+
+#if EIGEN_FAST_MATH
+ // This prefers numbers to NaNs.
+ return __builtin_msa_fmin_d(a, b);
+#else
+ // This prefers NaNs to numbers.
+ v2i64 aNaN = __builtin_msa_fcun_d(a, a);
+ v2i64 aMinOrNaN = por(__builtin_msa_fclt_d(a, b), aNaN);
+ return (Packet2d)__builtin_msa_bsel_v((v16u8)aMinOrNaN, (v16u8)b, (v16u8)a);
+#endif
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet2d pmax<Packet2d>(const Packet2d& a, const Packet2d& b) {
+ EIGEN_MSA_DEBUG;
+
+#if EIGEN_FAST_MATH
+ // This prefers numbers to NaNs.
+ return __builtin_msa_fmax_d(a, b);
+#else
+ // This prefers NaNs to numbers.
+ v2i64 aNaN = __builtin_msa_fcun_d(a, a);
+ v2i64 aMaxOrNaN = por(__builtin_msa_fclt_d(b, a), aNaN);
+ return (Packet2d)__builtin_msa_bsel_v((v16u8)aMaxOrNaN, (v16u8)b, (v16u8)a);
+#endif
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet2d ploadu<Packet2d>(const double* from) {
+ EIGEN_MSA_DEBUG;
+
+ EIGEN_DEBUG_UNALIGNED_LOAD return (Packet2d)__builtin_msa_ld_d(const_cast<double*>(from), 0);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet2d ploaddup<Packet2d>(const double* from) {
+ EIGEN_MSA_DEBUG;
+
+ Packet2d value = { *from, *from };
+ return value;
+}
+
+template <>
+EIGEN_STRONG_INLINE void pstore<double>(double* to, const Packet2d& from) {
+ EIGEN_MSA_DEBUG;
+
+ EIGEN_DEBUG_ALIGNED_STORE __builtin_msa_st_d((v2i64)from, to, 0);
+}
+
+template <>
+EIGEN_STRONG_INLINE void pstoreu<double>(double* to, const Packet2d& from) {
+ EIGEN_MSA_DEBUG;
+
+ EIGEN_DEBUG_UNALIGNED_STORE __builtin_msa_st_d((v2i64)from, to, 0);
+}
+
+template <>
+EIGEN_DEVICE_FUNC inline Packet2d pgather<double, Packet2d>(const double* from, Index stride) {
+ EIGEN_MSA_DEBUG;
+
+ Packet2d value;
+ value[0] = *from;
+ from += stride;
+ value[1] = *from;
+ return value;
+}
+
+template <>
+EIGEN_DEVICE_FUNC inline void pscatter<double, Packet2d>(double* to, const Packet2d& from,
+ Index stride) {
+ EIGEN_MSA_DEBUG;
+
+ *to = from[0];
+ to += stride;
+ *to = from[1];
+}
+
+template <>
+EIGEN_STRONG_INLINE void prefetch<double>(const double* addr) {
+ EIGEN_MSA_DEBUG;
+
+ __builtin_prefetch(addr);
+}
+
+template <>
+EIGEN_STRONG_INLINE double pfirst<Packet2d>(const Packet2d& a) {
+ EIGEN_MSA_DEBUG;
+
+ return a[0];
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet2d preverse(const Packet2d& a) {
+ EIGEN_MSA_DEBUG;
+
+ return (Packet2d)__builtin_msa_shf_w((v4i32)a, EIGEN_MSA_SHF_I8(2, 3, 0, 1));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet2d pabs(const Packet2d& a) {
+ EIGEN_MSA_DEBUG;
+
+ return (Packet2d)__builtin_msa_bclri_d((v2u64)a, 63);
+}
+
+template <>
+EIGEN_STRONG_INLINE double predux<Packet2d>(const Packet2d& a) {
+ EIGEN_MSA_DEBUG;
+
+ Packet2d s = padd(a, preverse(a));
+ return s[0];
+}
+
+// Other reduction functions:
+// mul
+template <>
+EIGEN_STRONG_INLINE double predux_mul<Packet2d>(const Packet2d& a) {
+ EIGEN_MSA_DEBUG;
+
+ Packet2d p = pmul(a, preverse(a));
+ return p[0];
+}
+
+// min
+template <>
+EIGEN_STRONG_INLINE double predux_min<Packet2d>(const Packet2d& a) {
+ EIGEN_MSA_DEBUG;
+
+#if EIGEN_FAST_MATH
+ Packet2d swapped = (Packet2d)__builtin_msa_shf_w((Packet4i)a, EIGEN_MSA_SHF_I8(2, 3, 0, 1));
+ Packet2d v = __builtin_msa_fmin_d(a, swapped);
+ return v[0];
+#else
+ double a0 = a[0], a1 = a[1];
+ return ((numext::isnan)(a0) || a0 < a1) ? a0 : a1;
+#endif
+}
+
+// max
+template <>
+EIGEN_STRONG_INLINE double predux_max<Packet2d>(const Packet2d& a) {
+ EIGEN_MSA_DEBUG;
+
+#if EIGEN_FAST_MATH
+ Packet2d swapped = (Packet2d)__builtin_msa_shf_w((Packet4i)a, EIGEN_MSA_SHF_I8(2, 3, 0, 1));
+ Packet2d v = __builtin_msa_fmax_d(a, swapped);
+ return v[0];
+#else
+ double a0 = a[0], a1 = a[1];
+ return ((numext::isnan)(a0) || a0 > a1) ? a0 : a1;
+#endif
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet2d psqrt(const Packet2d& a) {
+ EIGEN_MSA_DEBUG;
+
+ return __builtin_msa_fsqrt_d(a);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet2d prsqrt(const Packet2d& a) {
+ EIGEN_MSA_DEBUG;
+
+#if EIGEN_FAST_MATH
+ return __builtin_msa_frsqrt_d(a);
+#else
+ Packet2d ones = __builtin_msa_ffint_s_d(__builtin_msa_ldi_d(1));
+ return pdiv(ones, psqrt(a));
+#endif
+}
+
+inline std::ostream& operator<<(std::ostream& os, const PacketBlock<Packet2d, 2>& value) {
+ os << "[ " << value.packet[0] << "," << std::endl << " " << value.packet[1] << " ]";
+ return os;
+}
+
+EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet2d, 2>& kernel) {
+ EIGEN_MSA_DEBUG;
+
+ Packet2d trn1 = (Packet2d)__builtin_msa_ilvev_d((v2i64)kernel.packet[1], (v2i64)kernel.packet[0]);
+ Packet2d trn2 = (Packet2d)__builtin_msa_ilvod_d((v2i64)kernel.packet[1], (v2i64)kernel.packet[0]);
+ kernel.packet[0] = trn1;
+ kernel.packet[1] = trn2;
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet2d pfloor<Packet2d>(const Packet2d& a) {
+ Packet2d v = a;
+ int32_t old_mode, new_mode;
+ asm volatile(
+ "cfcmsa %[old_mode], $1\n"
+ "ori %[new_mode], %[old_mode], 3\n" // 3 = round towards -INFINITY.
+ "ctcmsa $1, %[new_mode]\n"
+ "frint.d %w[v], %w[v]\n"
+ "ctcmsa $1, %[old_mode]\n"
+ : // outputs
+ [old_mode] "=r"(old_mode), [new_mode] "=r"(new_mode),
+ [v] "+f"(v)
+ : // inputs
+ : // clobbers
+ );
+ return v;
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet2d pceil<Packet2d>(const Packet2d& a) {
+ Packet2d v = a;
+ int32_t old_mode, new_mode;
+ asm volatile(
+ "cfcmsa %[old_mode], $1\n"
+ "ori %[new_mode], %[old_mode], 3\n"
+ "xori %[new_mode], %[new_mode], 1\n" // 2 = round towards +INFINITY.
+ "ctcmsa $1, %[new_mode]\n"
+ "frint.d %w[v], %w[v]\n"
+ "ctcmsa $1, %[old_mode]\n"
+ : // outputs
+ [old_mode] "=r"(old_mode), [new_mode] "=r"(new_mode),
+ [v] "+f"(v)
+ : // inputs
+ : // clobbers
+ );
+ return v;
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet2d pround<Packet2d>(const Packet2d& a) {
+ Packet2d v = a;
+ int32_t old_mode, new_mode;
+ asm volatile(
+ "cfcmsa %[old_mode], $1\n"
+ "ori %[new_mode], %[old_mode], 3\n"
+ "xori %[new_mode], %[new_mode], 3\n" // 0 = round to nearest, ties to even.
+ "ctcmsa $1, %[new_mode]\n"
+ "frint.d %w[v], %w[v]\n"
+ "ctcmsa $1, %[old_mode]\n"
+ : // outputs
+ [old_mode] "=r"(old_mode), [new_mode] "=r"(new_mode),
+ [v] "+f"(v)
+ : // inputs
+ : // clobbers
+ );
+ return v;
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet2d pblend(const Selector<2>& ifPacket, const Packet2d& thenPacket,
+ const Packet2d& elsePacket) {
+ Packet2ul select = { ifPacket.select[0], ifPacket.select[1] };
+ Packet2l mask = __builtin_msa_ceqi_d((Packet2l)select, 0);
+ return (Packet2d)__builtin_msa_bsel_v((v16u8)mask, (v16u8)thenPacket, (v16u8)elsePacket);
+}
+
+} // end namespace internal
+
+} // end namespace Eigen
+
+#endif // EIGEN_PACKET_MATH_MSA_H
diff --git a/Eigen/src/Core/arch/NEON/Complex.h b/Eigen/src/Core/arch/NEON/Complex.h
index 57e9b431f..f40af7f87 100644
--- a/Eigen/src/Core/arch/NEON/Complex.h
+++ b/Eigen/src/Core/arch/NEON/Complex.h
@@ -15,9 +15,10 @@ namespace Eigen {
namespace internal {
-inline uint32x4_t p4ui_CONJ_XOR() {
+inline uint32x4_t p4ui_CONJ_XOR()
+{
// See bug 1325, clang fails to call vld1q_u64.
-#if EIGEN_COMP_CLANG
+#if EIGEN_COMP_CLANG || EIGEN_COMP_CASTXML
uint32x4_t ret = { 0x00000000, 0x80000000, 0x00000000, 0x80000000 };
return ret;
#else
@@ -26,61 +27,136 @@ inline uint32x4_t p4ui_CONJ_XOR() {
#endif
}
-inline uint32x2_t p2ui_CONJ_XOR() {
+inline uint32x2_t p2ui_CONJ_XOR()
+{
static const uint32_t conj_XOR_DATA[] = { 0x00000000, 0x80000000 };
return vld1_u32( conj_XOR_DATA );
}
//---------- float ----------
+
+struct Packet1cf
+{
+ EIGEN_STRONG_INLINE Packet1cf() {}
+ EIGEN_STRONG_INLINE explicit Packet1cf(const Packet2f& a) : v(a) {}
+ Packet2f v;
+};
struct Packet2cf
{
EIGEN_STRONG_INLINE Packet2cf() {}
EIGEN_STRONG_INLINE explicit Packet2cf(const Packet4f& a) : v(a) {}
- Packet4f v;
+ Packet4f v;
};
-template<> struct packet_traits<std::complex<float> > : default_packet_traits
+template<> struct packet_traits<std::complex<float> > : default_packet_traits
{
typedef Packet2cf type;
- typedef Packet2cf half;
- enum {
+ typedef Packet1cf half;
+ enum
+ {
Vectorizable = 1,
AlignedOnScalar = 1,
size = 2,
- HasHalfPacket = 0,
-
- HasAdd = 1,
- HasSub = 1,
- HasMul = 1,
- HasDiv = 1,
- HasNegate = 1,
- HasAbs = 0,
- HasAbs2 = 0,
- HasMin = 0,
- HasMax = 0,
+ HasHalfPacket = 1,
+
+ HasAdd = 1,
+ HasSub = 1,
+ HasMul = 1,
+ HasDiv = 1,
+ HasNegate = 1,
+ HasAbs = 0,
+ HasAbs2 = 0,
+ HasMin = 0,
+ HasMax = 0,
HasSetLinear = 0
};
};
-template<> struct unpacket_traits<Packet2cf> { typedef std::complex<float> type; enum {size=2, alignment=Aligned16}; typedef Packet2cf half; };
-
-template<> EIGEN_STRONG_INLINE Packet2cf pset1<Packet2cf>(const std::complex<float>& from)
+template<> struct unpacket_traits<Packet1cf>
{
- float32x2_t r64;
- r64 = vld1_f32((float *)&from);
+ typedef std::complex<float> type;
+ typedef Packet1cf half;
+ typedef Packet2f as_real;
+ enum
+ {
+ size = 1,
+ alignment = Aligned16,
+ vectorizable = true,
+ masked_load_available = false,
+ masked_store_available = false
+ };
+};
+template<> struct unpacket_traits<Packet2cf>
+{
+ typedef std::complex<float> type;
+ typedef Packet1cf half;
+ typedef Packet4f as_real;
+ enum
+ {
+ size = 2,
+ alignment = Aligned16,
+ vectorizable = true,
+ masked_load_available = false,
+ masked_store_available = false
+ };
+};
+template<> EIGEN_STRONG_INLINE Packet1cf pcast<float,Packet1cf>(const float& a)
+{ return Packet1cf(vset_lane_f32(a, vdup_n_f32(0.f), 0)); }
+template<> EIGEN_STRONG_INLINE Packet2cf pcast<Packet2f,Packet2cf>(const Packet2f& a)
+{ return Packet2cf(vreinterpretq_f32_u64(vmovl_u32(vreinterpret_u32_f32(a)))); }
+
+template<> EIGEN_STRONG_INLINE Packet1cf pset1<Packet1cf>(const std::complex<float>& from)
+{ return Packet1cf(vld1_f32(reinterpret_cast<const float*>(&from))); }
+template<> EIGEN_STRONG_INLINE Packet2cf pset1<Packet2cf>(const std::complex<float>& from)
+{
+ const float32x2_t r64 = vld1_f32(reinterpret_cast<const float*>(&from));
return Packet2cf(vcombine_f32(r64, r64));
}
-template<> EIGEN_STRONG_INLINE Packet2cf padd<Packet2cf>(const Packet2cf& a, const Packet2cf& b) { return Packet2cf(padd<Packet4f>(a.v,b.v)); }
-template<> EIGEN_STRONG_INLINE Packet2cf psub<Packet2cf>(const Packet2cf& a, const Packet2cf& b) { return Packet2cf(psub<Packet4f>(a.v,b.v)); }
+template<> EIGEN_STRONG_INLINE Packet1cf padd<Packet1cf>(const Packet1cf& a, const Packet1cf& b)
+{ return Packet1cf(padd<Packet2f>(a.v, b.v)); }
+template<> EIGEN_STRONG_INLINE Packet2cf padd<Packet2cf>(const Packet2cf& a, const Packet2cf& b)
+{ return Packet2cf(padd<Packet4f>(a.v, b.v)); }
+
+template<> EIGEN_STRONG_INLINE Packet1cf psub<Packet1cf>(const Packet1cf& a, const Packet1cf& b)
+{ return Packet1cf(psub<Packet2f>(a.v, b.v)); }
+template<> EIGEN_STRONG_INLINE Packet2cf psub<Packet2cf>(const Packet2cf& a, const Packet2cf& b)
+{ return Packet2cf(psub<Packet4f>(a.v, b.v)); }
+
+template<> EIGEN_STRONG_INLINE Packet1cf pnegate(const Packet1cf& a) { return Packet1cf(pnegate<Packet2f>(a.v)); }
template<> EIGEN_STRONG_INLINE Packet2cf pnegate(const Packet2cf& a) { return Packet2cf(pnegate<Packet4f>(a.v)); }
+
+template<> EIGEN_STRONG_INLINE Packet1cf pconj(const Packet1cf& a)
+{
+ const Packet2ui b = vreinterpret_u32_f32(a.v);
+ return Packet1cf(vreinterpret_f32_u32(veor_u32(b, p2ui_CONJ_XOR())));
+}
template<> EIGEN_STRONG_INLINE Packet2cf pconj(const Packet2cf& a)
{
- Packet4ui b = vreinterpretq_u32_f32(a.v);
+ const Packet4ui b = vreinterpretq_u32_f32(a.v);
return Packet2cf(vreinterpretq_f32_u32(veorq_u32(b, p4ui_CONJ_XOR())));
}
+template<> EIGEN_STRONG_INLINE Packet1cf pmul<Packet1cf>(const Packet1cf& a, const Packet1cf& b)
+{
+ Packet2f v1, v2;
+
+ // Get the real values of a | a1_re | a1_re |
+ v1 = vdup_lane_f32(a.v, 0);
+ // Get the imag values of a | a1_im | a1_im |
+ v2 = vdup_lane_f32(a.v, 1);
+ // Multiply the real a with b
+ v1 = vmul_f32(v1, b.v);
+ // Multiply the imag a with b
+ v2 = vmul_f32(v2, b.v);
+ // Conjugate v2
+ v2 = vreinterpret_f32_u32(veor_u32(vreinterpret_u32_f32(v2), p2ui_CONJ_XOR()));
+ // Swap real/imag elements in v2.
+ v2 = vrev64_f32(v2);
+ // Add and return the result
+ return Packet1cf(vadd_f32(v1, v2));
+}
template<> EIGEN_STRONG_INLINE Packet2cf pmul<Packet2cf>(const Packet2cf& a, const Packet2cf& b)
{
Packet4f v1, v2;
@@ -93,7 +169,7 @@ template<> EIGEN_STRONG_INLINE Packet2cf pmul<Packet2cf>(const Packet2cf& a, con
v1 = vmulq_f32(v1, b.v);
// Multiply the imag a with b
v2 = vmulq_f32(v2, b.v);
- // Conjugate v2
+ // Conjugate v2
v2 = vreinterpretq_f32_u32(veorq_u32(vreinterpretq_u32_f32(v2), p4ui_CONJ_XOR()));
// Swap real/imag elements in v2.
v2 = vrev64q_f32(v2);
@@ -101,98 +177,144 @@ template<> EIGEN_STRONG_INLINE Packet2cf pmul<Packet2cf>(const Packet2cf& a, con
return Packet2cf(vaddq_f32(v1, v2));
}
-template<> EIGEN_STRONG_INLINE Packet2cf pand <Packet2cf>(const Packet2cf& a, const Packet2cf& b)
-{
- return Packet2cf(vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(a.v),vreinterpretq_u32_f32(b.v))));
-}
-template<> EIGEN_STRONG_INLINE Packet2cf por <Packet2cf>(const Packet2cf& a, const Packet2cf& b)
+template<> EIGEN_STRONG_INLINE Packet1cf pcmp_eq(const Packet1cf& a, const Packet1cf& b)
{
- return Packet2cf(vreinterpretq_f32_u32(vorrq_u32(vreinterpretq_u32_f32(a.v),vreinterpretq_u32_f32(b.v))));
+ // Compare real and imaginary parts of a and b to get the mask vector:
+ // [re(a[0])==re(b[0]), im(a[0])==im(b[0])]
+ Packet2f eq = pcmp_eq<Packet2f>(a.v, b.v);
+ // Swap real/imag elements in the mask in to get:
+ // [im(a[0])==im(b[0]), re(a[0])==re(b[0])]
+ Packet2f eq_swapped = vrev64_f32(eq);
+ // Return re(a)==re(b) && im(a)==im(b) by computing bitwise AND of eq and eq_swapped
+ return Packet1cf(pand<Packet2f>(eq, eq_swapped));
}
-template<> EIGEN_STRONG_INLINE Packet2cf pxor <Packet2cf>(const Packet2cf& a, const Packet2cf& b)
-{
- return Packet2cf(vreinterpretq_f32_u32(veorq_u32(vreinterpretq_u32_f32(a.v),vreinterpretq_u32_f32(b.v))));
-}
-template<> EIGEN_STRONG_INLINE Packet2cf pandnot<Packet2cf>(const Packet2cf& a, const Packet2cf& b)
-{
- return Packet2cf(vreinterpretq_f32_u32(vbicq_u32(vreinterpretq_u32_f32(a.v),vreinterpretq_u32_f32(b.v))));
+template<> EIGEN_STRONG_INLINE Packet2cf pcmp_eq(const Packet2cf& a, const Packet2cf& b)
+{
+ // Compare real and imaginary parts of a and b to get the mask vector:
+ // [re(a[0])==re(b[0]), im(a[0])==im(b[0]), re(a[1])==re(b[1]), im(a[1])==im(b[1])]
+ Packet4f eq = pcmp_eq<Packet4f>(a.v, b.v);
+ // Swap real/imag elements in the mask in to get:
+ // [im(a[0])==im(b[0]), re(a[0])==re(b[0]), im(a[1])==im(b[1]), re(a[1])==re(b[1])]
+ Packet4f eq_swapped = vrev64q_f32(eq);
+ // Return re(a)==re(b) && im(a)==im(b) by computing bitwise AND of eq and eq_swapped
+ return Packet2cf(pand<Packet4f>(eq, eq_swapped));
}
-template<> EIGEN_STRONG_INLINE Packet2cf pload<Packet2cf>(const std::complex<float>* from) { EIGEN_DEBUG_ALIGNED_LOAD return Packet2cf(pload<Packet4f>((const float*)from)); }
-template<> EIGEN_STRONG_INLINE Packet2cf ploadu<Packet2cf>(const std::complex<float>* from) { EIGEN_DEBUG_UNALIGNED_LOAD return Packet2cf(ploadu<Packet4f>((const float*)from)); }
+template<> EIGEN_STRONG_INLINE Packet1cf pand<Packet1cf>(const Packet1cf& a, const Packet1cf& b)
+{ return Packet1cf(vreinterpret_f32_u32(vand_u32(vreinterpret_u32_f32(a.v), vreinterpret_u32_f32(b.v)))); }
+template<> EIGEN_STRONG_INLINE Packet2cf pand<Packet2cf>(const Packet2cf& a, const Packet2cf& b)
+{ return Packet2cf(vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(a.v), vreinterpretq_u32_f32(b.v)))); }
-template<> EIGEN_STRONG_INLINE Packet2cf ploaddup<Packet2cf>(const std::complex<float>* from) { return pset1<Packet2cf>(*from); }
+template<> EIGEN_STRONG_INLINE Packet1cf por<Packet1cf>(const Packet1cf& a, const Packet1cf& b)
+{ return Packet1cf(vreinterpret_f32_u32(vorr_u32(vreinterpret_u32_f32(a.v), vreinterpret_u32_f32(b.v)))); }
+template<> EIGEN_STRONG_INLINE Packet2cf por<Packet2cf>(const Packet2cf& a, const Packet2cf& b)
+{ return Packet2cf(vreinterpretq_f32_u32(vorrq_u32(vreinterpretq_u32_f32(a.v), vreinterpretq_u32_f32(b.v)))); }
-template<> EIGEN_STRONG_INLINE void pstore <std::complex<float> >(std::complex<float> * to, const Packet2cf& from) { EIGEN_DEBUG_ALIGNED_STORE pstore((float*)to, from.v); }
-template<> EIGEN_STRONG_INLINE void pstoreu<std::complex<float> >(std::complex<float> * to, const Packet2cf& from) { EIGEN_DEBUG_UNALIGNED_STORE pstoreu((float*)to, from.v); }
+template<> EIGEN_STRONG_INLINE Packet1cf pxor<Packet1cf>(const Packet1cf& a, const Packet1cf& b)
+{ return Packet1cf(vreinterpret_f32_u32(veor_u32(vreinterpret_u32_f32(a.v), vreinterpret_u32_f32(b.v)))); }
+template<> EIGEN_STRONG_INLINE Packet2cf pxor<Packet2cf>(const Packet2cf& a, const Packet2cf& b)
+{ return Packet2cf(vreinterpretq_f32_u32(veorq_u32(vreinterpretq_u32_f32(a.v), vreinterpretq_u32_f32(b.v)))); }
-template<> EIGEN_DEVICE_FUNC inline Packet2cf pgather<std::complex<float>, Packet2cf>(const std::complex<float>* from, Index stride)
+template<> EIGEN_STRONG_INLINE Packet1cf pandnot<Packet1cf>(const Packet1cf& a, const Packet1cf& b)
+{ return Packet1cf(vreinterpret_f32_u32(vbic_u32(vreinterpret_u32_f32(a.v), vreinterpret_u32_f32(b.v)))); }
+template<> EIGEN_STRONG_INLINE Packet2cf pandnot<Packet2cf>(const Packet2cf& a, const Packet2cf& b)
+{ return Packet2cf(vreinterpretq_f32_u32(vbicq_u32(vreinterpretq_u32_f32(a.v), vreinterpretq_u32_f32(b.v)))); }
+
+template<> EIGEN_STRONG_INLINE Packet1cf pload<Packet1cf>(const std::complex<float>* from)
+{ EIGEN_DEBUG_ALIGNED_LOAD return Packet1cf(pload<Packet2f>((const float*)from)); }
+template<> EIGEN_STRONG_INLINE Packet2cf pload<Packet2cf>(const std::complex<float>* from)
+{ EIGEN_DEBUG_ALIGNED_LOAD return Packet2cf(pload<Packet4f>(reinterpret_cast<const float*>(from))); }
+
+template<> EIGEN_STRONG_INLINE Packet1cf ploadu<Packet1cf>(const std::complex<float>* from)
+{ EIGEN_DEBUG_UNALIGNED_LOAD return Packet1cf(ploadu<Packet2f>((const float*)from)); }
+template<> EIGEN_STRONG_INLINE Packet2cf ploadu<Packet2cf>(const std::complex<float>* from)
+{ EIGEN_DEBUG_UNALIGNED_LOAD return Packet2cf(ploadu<Packet4f>(reinterpret_cast<const float*>(from))); }
+
+template<> EIGEN_STRONG_INLINE Packet1cf ploaddup<Packet1cf>(const std::complex<float>* from)
+{ return pset1<Packet1cf>(*from); }
+template<> EIGEN_STRONG_INLINE Packet2cf ploaddup<Packet2cf>(const std::complex<float>* from)
+{ return pset1<Packet2cf>(*from); }
+
+template<> EIGEN_STRONG_INLINE void pstore <std::complex<float> >(std::complex<float> *to, const Packet1cf& from)
+{ EIGEN_DEBUG_ALIGNED_STORE pstore((float*)to, from.v); }
+template<> EIGEN_STRONG_INLINE void pstore <std::complex<float> >(std::complex<float> *to, const Packet2cf& from)
+{ EIGEN_DEBUG_ALIGNED_STORE pstore(reinterpret_cast<float*>(to), from.v); }
+
+template<> EIGEN_STRONG_INLINE void pstoreu<std::complex<float> >(std::complex<float> *to, const Packet1cf& from)
+{ EIGEN_DEBUG_UNALIGNED_STORE pstoreu((float*)to, from.v); }
+template<> EIGEN_STRONG_INLINE void pstoreu<std::complex<float> >(std::complex<float> *to, const Packet2cf& from)
+{ EIGEN_DEBUG_UNALIGNED_STORE pstoreu(reinterpret_cast<float*>(to), from.v); }
+
+template<> EIGEN_DEVICE_FUNC inline Packet1cf pgather<std::complex<float>, Packet1cf>(
+ const std::complex<float>* from, Index stride)
+{
+ const Packet2f tmp = vdup_n_f32(std::real(from[0*stride]));
+ return Packet1cf(vset_lane_f32(std::imag(from[0*stride]), tmp, 1));
+}
+template<> EIGEN_DEVICE_FUNC inline Packet2cf pgather<std::complex<float>, Packet2cf>(
+ const std::complex<float>* from, Index stride)
{
- Packet4f res = pset1<Packet4f>(0.f);
- res = vsetq_lane_f32(std::real(from[0*stride]), res, 0);
+ Packet4f res = vdupq_n_f32(std::real(from[0*stride]));
res = vsetq_lane_f32(std::imag(from[0*stride]), res, 1);
res = vsetq_lane_f32(std::real(from[1*stride]), res, 2);
res = vsetq_lane_f32(std::imag(from[1*stride]), res, 3);
return Packet2cf(res);
}
-template<> EIGEN_DEVICE_FUNC inline void pscatter<std::complex<float>, Packet2cf>(std::complex<float>* to, const Packet2cf& from, Index stride)
+template<> EIGEN_DEVICE_FUNC inline void pscatter<std::complex<float>, Packet1cf>(
+ std::complex<float>* to, const Packet1cf& from, Index stride)
+{ to[stride*0] = std::complex<float>(vget_lane_f32(from.v, 0), vget_lane_f32(from.v, 1)); }
+template<> EIGEN_DEVICE_FUNC inline void pscatter<std::complex<float>, Packet2cf>(
+ std::complex<float>* to, const Packet2cf& from, Index stride)
{
to[stride*0] = std::complex<float>(vgetq_lane_f32(from.v, 0), vgetq_lane_f32(from.v, 1));
to[stride*1] = std::complex<float>(vgetq_lane_f32(from.v, 2), vgetq_lane_f32(from.v, 3));
}
-template<> EIGEN_STRONG_INLINE void prefetch<std::complex<float> >(const std::complex<float> * addr) { EIGEN_ARM_PREFETCH((float *)addr); }
+template<> EIGEN_STRONG_INLINE void prefetch<std::complex<float> >(const std::complex<float> *addr)
+{ EIGEN_ARM_PREFETCH(reinterpret_cast<const float*>(addr)); }
-template<> EIGEN_STRONG_INLINE std::complex<float> pfirst<Packet2cf>(const Packet2cf& a)
+template<> EIGEN_STRONG_INLINE std::complex<float> pfirst<Packet1cf>(const Packet1cf& a)
+{
+ EIGEN_ALIGN16 std::complex<float> x;
+ vst1_f32(reinterpret_cast<float*>(&x), a.v);
+ return x;
+}
+template<> EIGEN_STRONG_INLINE std::complex<float> pfirst<Packet2cf>(const Packet2cf& a)
{
- std::complex<float> EIGEN_ALIGN16 x[2];
- vst1q_f32((float *)x, a.v);
+ EIGEN_ALIGN16 std::complex<float> x[2];
+ vst1q_f32(reinterpret_cast<float*>(x), a.v);
return x[0];
}
+template<> EIGEN_STRONG_INLINE Packet1cf preverse(const Packet1cf& a) { return a; }
template<> EIGEN_STRONG_INLINE Packet2cf preverse(const Packet2cf& a)
-{
- float32x2_t a_lo, a_hi;
- Packet4f a_r128;
-
- a_lo = vget_low_f32(a.v);
- a_hi = vget_high_f32(a.v);
- a_r128 = vcombine_f32(a_hi, a_lo);
-
- return Packet2cf(a_r128);
-}
+{ return Packet2cf(vcombine_f32(vget_high_f32(a.v), vget_low_f32(a.v))); }
+template<> EIGEN_STRONG_INLINE Packet1cf pcplxflip<Packet1cf>(const Packet1cf& a)
+{ return Packet1cf(vrev64_f32(a.v)); }
template<> EIGEN_STRONG_INLINE Packet2cf pcplxflip<Packet2cf>(const Packet2cf& a)
+{ return Packet2cf(vrev64q_f32(a.v)); }
+
+template<> EIGEN_STRONG_INLINE std::complex<float> predux<Packet1cf>(const Packet1cf& a)
{
- return Packet2cf(vrev64q_f32(a.v));
+ std::complex<float> s;
+ vst1_f32((float *)&s, a.v);
+ return s;
}
-
template<> EIGEN_STRONG_INLINE std::complex<float> predux<Packet2cf>(const Packet2cf& a)
{
- float32x2_t a1, a2;
std::complex<float> s;
-
- a1 = vget_low_f32(a.v);
- a2 = vget_high_f32(a.v);
- a2 = vadd_f32(a1, a2);
- vst1_f32((float *)&s, a2);
-
+ vst1_f32(reinterpret_cast<float*>(&s), vadd_f32(vget_low_f32(a.v), vget_high_f32(a.v)));
return s;
}
-template<> EIGEN_STRONG_INLINE Packet2cf preduxp<Packet2cf>(const Packet2cf* vecs)
+template<> EIGEN_STRONG_INLINE std::complex<float> predux_mul<Packet1cf>(const Packet1cf& a)
{
- Packet4f sum1, sum2, sum;
-
- // Add the first two 64-bit float32x2_t of vecs[0]
- sum1 = vcombine_f32(vget_low_f32(vecs[0].v), vget_low_f32(vecs[1].v));
- sum2 = vcombine_f32(vget_high_f32(vecs[0].v), vget_high_f32(vecs[1].v));
- sum = vaddq_f32(sum1, sum2);
-
- return Packet2cf(sum);
+ std::complex<float> s;
+ vst1_f32((float *)&s, a.v);
+ return s;
}
-
template<> EIGEN_STRONG_INLINE std::complex<float> predux_mul<Packet2cf>(const Packet2cf& a)
{
float32x2_t a1, a2, v1, v2, prod;
@@ -208,88 +330,67 @@ template<> EIGEN_STRONG_INLINE std::complex<float> predux_mul<Packet2cf>(const P
v1 = vmul_f32(v1, a2);
// Multiply the imag a with b
v2 = vmul_f32(v2, a2);
- // Conjugate v2
+ // Conjugate v2
v2 = vreinterpret_f32_u32(veor_u32(vreinterpret_u32_f32(v2), p2ui_CONJ_XOR()));
// Swap real/imag elements in v2.
v2 = vrev64_f32(v2);
// Add v1, v2
prod = vadd_f32(v1, v2);
- vst1_f32((float *)&s, prod);
+ vst1_f32(reinterpret_cast<float*>(&s), prod);
return s;
}
-template<int Offset>
-struct palign_impl<Offset,Packet2cf>
-{
- EIGEN_STRONG_INLINE static void run(Packet2cf& first, const Packet2cf& second)
- {
- if (Offset==1)
- {
- first.v = vextq_f32(first.v, second.v, 2);
- }
- }
-};
-
-template<> struct conj_helper<Packet2cf, Packet2cf, false,true>
-{
- EIGEN_STRONG_INLINE Packet2cf pmadd(const Packet2cf& x, const Packet2cf& y, const Packet2cf& c) const
- { return padd(pmul(x,y),c); }
-
- EIGEN_STRONG_INLINE Packet2cf pmul(const Packet2cf& a, const Packet2cf& b) const
- {
- return internal::pmul(a, pconj(b));
- }
-};
-
-template<> struct conj_helper<Packet2cf, Packet2cf, true,false>
-{
- EIGEN_STRONG_INLINE Packet2cf pmadd(const Packet2cf& x, const Packet2cf& y, const Packet2cf& c) const
- { return padd(pmul(x,y),c); }
-
- EIGEN_STRONG_INLINE Packet2cf pmul(const Packet2cf& a, const Packet2cf& b) const
- {
- return internal::pmul(pconj(a), b);
- }
-};
+EIGEN_MAKE_CONJ_HELPER_CPLX_REAL(Packet1cf,Packet2f)
+EIGEN_MAKE_CONJ_HELPER_CPLX_REAL(Packet2cf,Packet4f)
-template<> struct conj_helper<Packet2cf, Packet2cf, true,true>
+template<> EIGEN_STRONG_INLINE Packet1cf pdiv<Packet1cf>(const Packet1cf& a, const Packet1cf& b)
{
- EIGEN_STRONG_INLINE Packet2cf pmadd(const Packet2cf& x, const Packet2cf& y, const Packet2cf& c) const
- { return padd(pmul(x,y),c); }
+ // TODO optimize it for NEON
+ Packet1cf res = pmul(a, pconj(b));
+ Packet2f s, rev_s;
- EIGEN_STRONG_INLINE Packet2cf pmul(const Packet2cf& a, const Packet2cf& b) const
- {
- return pconj(internal::pmul(a, b));
- }
-};
+ // this computes the norm
+ s = vmul_f32(b.v, b.v);
+ rev_s = vrev64_f32(s);
+ return Packet1cf(pdiv<Packet2f>(res.v, vadd_f32(s, rev_s)));
+}
template<> EIGEN_STRONG_INLINE Packet2cf pdiv<Packet2cf>(const Packet2cf& a, const Packet2cf& b)
{
// TODO optimize it for NEON
- Packet2cf res = conj_helper<Packet2cf,Packet2cf,false,true>().pmul(a,b);
+ Packet2cf res = pmul(a,pconj(b));
Packet4f s, rev_s;
// this computes the norm
s = vmulq_f32(b.v, b.v);
rev_s = vrev64q_f32(s);
- return Packet2cf(pdiv(res.v, vaddq_f32(s,rev_s)));
+ return Packet2cf(pdiv<Packet4f>(res.v, vaddq_f32(s, rev_s)));
}
-EIGEN_DEVICE_FUNC inline void
-ptranspose(PacketBlock<Packet2cf,2>& kernel) {
+EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet1cf, 1>& /*kernel*/) {}
+EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet2cf, 2>& kernel)
+{
Packet4f tmp = vcombine_f32(vget_high_f32(kernel.packet[0].v), vget_high_f32(kernel.packet[1].v));
kernel.packet[0].v = vcombine_f32(vget_low_f32(kernel.packet[0].v), vget_low_f32(kernel.packet[1].v));
kernel.packet[1].v = tmp;
}
+template<> EIGEN_STRONG_INLINE Packet1cf psqrt<Packet1cf>(const Packet1cf& a) {
+ return psqrt_complex<Packet1cf>(a);
+}
+
+template<> EIGEN_STRONG_INLINE Packet2cf psqrt<Packet2cf>(const Packet2cf& a) {
+ return psqrt_complex<Packet2cf>(a);
+}
+
//---------- double ----------
#if EIGEN_ARCH_ARM64 && !EIGEN_APPLE_DOUBLE_NEON_BUG
// See bug 1325, clang fails to call vld1q_u64.
-#if EIGEN_COMP_CLANG
+#if EIGEN_COMP_CLANG || EIGEN_COMP_CASTXML
static uint64x2_t p2ul_CONJ_XOR = {0x0, 0x8000000000000000};
#else
const uint64_t p2ul_conj_XOR_DATA[] = { 0x0, 0x8000000000000000 };
@@ -307,7 +408,8 @@ template<> struct packet_traits<std::complex<double> > : default_packet_traits
{
typedef Packet1cd type;
typedef Packet1cd half;
- enum {
+ enum
+ {
Vectorizable = 1,
AlignedOnScalar = 0,
size = 1,
@@ -326,24 +428,50 @@ template<> struct packet_traits<std::complex<double> > : default_packet_traits
};
};
-template<> struct unpacket_traits<Packet1cd> { typedef std::complex<double> type; enum {size=1, alignment=Aligned16}; typedef Packet1cd half; };
+template<> struct unpacket_traits<Packet1cd>
+{
+ typedef std::complex<double> type;
+ typedef Packet1cd half;
+ typedef Packet2d as_real;
+ enum
+ {
+ size=1,
+ alignment=Aligned16,
+ vectorizable=true,
+ masked_load_available=false,
+ masked_store_available=false
+ };
+};
+
+template<> EIGEN_STRONG_INLINE Packet1cd pload<Packet1cd>(const std::complex<double>* from)
+{ EIGEN_DEBUG_ALIGNED_LOAD return Packet1cd(pload<Packet2d>(reinterpret_cast<const double*>(from))); }
+
+template<> EIGEN_STRONG_INLINE Packet1cd ploadu<Packet1cd>(const std::complex<double>* from)
+{ EIGEN_DEBUG_UNALIGNED_LOAD return Packet1cd(ploadu<Packet2d>(reinterpret_cast<const double*>(from))); }
+
+template<> EIGEN_STRONG_INLINE Packet1cd pset1<Packet1cd>(const std::complex<double>& from)
+{
+ /* here we really have to use unaligned loads :( */
+ return ploadu<Packet1cd>(&from);
+}
-template<> EIGEN_STRONG_INLINE Packet1cd pload<Packet1cd>(const std::complex<double>* from) { EIGEN_DEBUG_ALIGNED_LOAD return Packet1cd(pload<Packet2d>((const double*)from)); }
-template<> EIGEN_STRONG_INLINE Packet1cd ploadu<Packet1cd>(const std::complex<double>* from) { EIGEN_DEBUG_UNALIGNED_LOAD return Packet1cd(ploadu<Packet2d>((const double*)from)); }
+template<> EIGEN_STRONG_INLINE Packet1cd padd<Packet1cd>(const Packet1cd& a, const Packet1cd& b)
+{ return Packet1cd(padd<Packet2d>(a.v, b.v)); }
-template<> EIGEN_STRONG_INLINE Packet1cd pset1<Packet1cd>(const std::complex<double>& from)
-{ /* here we really have to use unaligned loads :( */ return ploadu<Packet1cd>(&from); }
+template<> EIGEN_STRONG_INLINE Packet1cd psub<Packet1cd>(const Packet1cd& a, const Packet1cd& b)
+{ return Packet1cd(psub<Packet2d>(a.v, b.v)); }
-template<> EIGEN_STRONG_INLINE Packet1cd padd<Packet1cd>(const Packet1cd& a, const Packet1cd& b) { return Packet1cd(padd<Packet2d>(a.v,b.v)); }
-template<> EIGEN_STRONG_INLINE Packet1cd psub<Packet1cd>(const Packet1cd& a, const Packet1cd& b) { return Packet1cd(psub<Packet2d>(a.v,b.v)); }
-template<> EIGEN_STRONG_INLINE Packet1cd pnegate(const Packet1cd& a) { return Packet1cd(pnegate<Packet2d>(a.v)); }
-template<> EIGEN_STRONG_INLINE Packet1cd pconj(const Packet1cd& a) { return Packet1cd(vreinterpretq_f64_u64(veorq_u64(vreinterpretq_u64_f64(a.v), p2ul_CONJ_XOR))); }
+template<> EIGEN_STRONG_INLINE Packet1cd pnegate(const Packet1cd& a)
+{ return Packet1cd(pnegate<Packet2d>(a.v)); }
+
+template<> EIGEN_STRONG_INLINE Packet1cd pconj(const Packet1cd& a)
+{ return Packet1cd(vreinterpretq_f64_u64(veorq_u64(vreinterpretq_u64_f64(a.v), p2ul_CONJ_XOR))); }
template<> EIGEN_STRONG_INLINE Packet1cd pmul<Packet1cd>(const Packet1cd& a, const Packet1cd& b)
{
Packet2d v1, v2;
- // Get the real values of a
+ // Get the real values of a
v1 = vdupq_lane_f64(vget_low_f64(a.v), 0);
// Get the imag values of a
v2 = vdupq_lane_f64(vget_high_f64(a.v), 0);
@@ -351,7 +479,7 @@ template<> EIGEN_STRONG_INLINE Packet1cd pmul<Packet1cd>(const Packet1cd& a, con
v1 = vmulq_f64(v1, b.v);
// Multiply the imag a with b
v2 = vmulq_f64(v2, b.v);
- // Conjugate v2
+ // Conjugate v2
v2 = vreinterpretq_f64_u64(veorq_u64(vreinterpretq_u64_f64(v2), p2ul_CONJ_XOR));
// Swap real/imag elements in v2.
v2 = preverse<Packet2d>(v2);
@@ -359,31 +487,44 @@ template<> EIGEN_STRONG_INLINE Packet1cd pmul<Packet1cd>(const Packet1cd& a, con
return Packet1cd(vaddq_f64(v1, v2));
}
-template<> EIGEN_STRONG_INLINE Packet1cd pand <Packet1cd>(const Packet1cd& a, const Packet1cd& b)
-{
- return Packet1cd(vreinterpretq_f64_u64(vandq_u64(vreinterpretq_u64_f64(a.v),vreinterpretq_u64_f64(b.v))));
-}
-template<> EIGEN_STRONG_INLINE Packet1cd por <Packet1cd>(const Packet1cd& a, const Packet1cd& b)
-{
- return Packet1cd(vreinterpretq_f64_u64(vorrq_u64(vreinterpretq_u64_f64(a.v),vreinterpretq_u64_f64(b.v))));
-}
-template<> EIGEN_STRONG_INLINE Packet1cd pxor <Packet1cd>(const Packet1cd& a, const Packet1cd& b)
+template<> EIGEN_STRONG_INLINE Packet1cd pcmp_eq(const Packet1cd& a, const Packet1cd& b)
{
- return Packet1cd(vreinterpretq_f64_u64(veorq_u64(vreinterpretq_u64_f64(a.v),vreinterpretq_u64_f64(b.v))));
+ // Compare real and imaginary parts of a and b to get the mask vector:
+ // [re(a)==re(b), im(a)==im(b)]
+ Packet2d eq = pcmp_eq<Packet2d>(a.v, b.v);
+ // Swap real/imag elements in the mask in to get:
+ // [im(a)==im(b), re(a)==re(b)]
+ Packet2d eq_swapped = vreinterpretq_f64_u32(vrev64q_u32(vreinterpretq_u32_f64(eq)));
+ // Return re(a)==re(b) & im(a)==im(b) by computing bitwise AND of eq and eq_swapped
+ return Packet1cd(pand<Packet2d>(eq, eq_swapped));
}
+
+template<> EIGEN_STRONG_INLINE Packet1cd pand<Packet1cd>(const Packet1cd& a, const Packet1cd& b)
+{ return Packet1cd(vreinterpretq_f64_u64(vandq_u64(vreinterpretq_u64_f64(a.v),vreinterpretq_u64_f64(b.v)))); }
+
+template<> EIGEN_STRONG_INLINE Packet1cd por<Packet1cd>(const Packet1cd& a, const Packet1cd& b)
+{ return Packet1cd(vreinterpretq_f64_u64(vorrq_u64(vreinterpretq_u64_f64(a.v),vreinterpretq_u64_f64(b.v)))); }
+
+template<> EIGEN_STRONG_INLINE Packet1cd pxor<Packet1cd>(const Packet1cd& a, const Packet1cd& b)
+{ return Packet1cd(vreinterpretq_f64_u64(veorq_u64(vreinterpretq_u64_f64(a.v),vreinterpretq_u64_f64(b.v)))); }
+
template<> EIGEN_STRONG_INLINE Packet1cd pandnot<Packet1cd>(const Packet1cd& a, const Packet1cd& b)
-{
- return Packet1cd(vreinterpretq_f64_u64(vbicq_u64(vreinterpretq_u64_f64(a.v),vreinterpretq_u64_f64(b.v))));
-}
+{ return Packet1cd(vreinterpretq_f64_u64(vbicq_u64(vreinterpretq_u64_f64(a.v),vreinterpretq_u64_f64(b.v)))); }
-template<> EIGEN_STRONG_INLINE Packet1cd ploaddup<Packet1cd>(const std::complex<double>* from) { return pset1<Packet1cd>(*from); }
+template<> EIGEN_STRONG_INLINE Packet1cd ploaddup<Packet1cd>(const std::complex<double>* from)
+{ return pset1<Packet1cd>(*from); }
-template<> EIGEN_STRONG_INLINE void pstore <std::complex<double> >(std::complex<double> * to, const Packet1cd& from) { EIGEN_DEBUG_ALIGNED_STORE pstore((double*)to, from.v); }
-template<> EIGEN_STRONG_INLINE void pstoreu<std::complex<double> >(std::complex<double> * to, const Packet1cd& from) { EIGEN_DEBUG_UNALIGNED_STORE pstoreu((double*)to, from.v); }
+template<> EIGEN_STRONG_INLINE void pstore <std::complex<double> >(std::complex<double> *to, const Packet1cd& from)
+{ EIGEN_DEBUG_ALIGNED_STORE pstore(reinterpret_cast<double*>(to), from.v); }
-template<> EIGEN_STRONG_INLINE void prefetch<std::complex<double> >(const std::complex<double> * addr) { EIGEN_ARM_PREFETCH((double *)addr); }
+template<> EIGEN_STRONG_INLINE void pstoreu<std::complex<double> >(std::complex<double> *to, const Packet1cd& from)
+{ EIGEN_DEBUG_UNALIGNED_STORE pstoreu(reinterpret_cast<double*>(to), from.v); }
-template<> EIGEN_DEVICE_FUNC inline Packet1cd pgather<std::complex<double>, Packet1cd>(const std::complex<double>* from, Index stride)
+template<> EIGEN_STRONG_INLINE void prefetch<std::complex<double> >(const std::complex<double> *addr)
+{ EIGEN_ARM_PREFETCH(reinterpret_cast<const double*>(addr)); }
+
+template<> EIGEN_DEVICE_FUNC inline Packet1cd pgather<std::complex<double>, Packet1cd>(
+ const std::complex<double>* from, Index stride)
{
Packet2d res = pset1<Packet2d>(0.0);
res = vsetq_lane_f64(std::real(from[0*stride]), res, 0);
@@ -391,17 +532,14 @@ template<> EIGEN_DEVICE_FUNC inline Packet1cd pgather<std::complex<double>, Pack
return Packet1cd(res);
}
-template<> EIGEN_DEVICE_FUNC inline void pscatter<std::complex<double>, Packet1cd>(std::complex<double>* to, const Packet1cd& from, Index stride)
-{
- to[stride*0] = std::complex<double>(vgetq_lane_f64(from.v, 0), vgetq_lane_f64(from.v, 1));
-}
-
+template<> EIGEN_DEVICE_FUNC inline void pscatter<std::complex<double>, Packet1cd>(
+ std::complex<double>* to, const Packet1cd& from, Index stride)
+{ to[stride*0] = std::complex<double>(vgetq_lane_f64(from.v, 0), vgetq_lane_f64(from.v, 1)); }
-template<> EIGEN_STRONG_INLINE std::complex<double> pfirst<Packet1cd>(const Packet1cd& a)
+template<> EIGEN_STRONG_INLINE std::complex<double> pfirst<Packet1cd>(const Packet1cd& a)
{
- std::complex<double> EIGEN_ALIGN16 res;
+ EIGEN_ALIGN16 std::complex<double> res;
pstore<std::complex<double> >(&res, a);
-
return res;
}
@@ -409,57 +547,14 @@ template<> EIGEN_STRONG_INLINE Packet1cd preverse(const Packet1cd& a) { return a
template<> EIGEN_STRONG_INLINE std::complex<double> predux<Packet1cd>(const Packet1cd& a) { return pfirst(a); }
-template<> EIGEN_STRONG_INLINE Packet1cd preduxp<Packet1cd>(const Packet1cd* vecs) { return vecs[0]; }
-
template<> EIGEN_STRONG_INLINE std::complex<double> predux_mul<Packet1cd>(const Packet1cd& a) { return pfirst(a); }
-template<int Offset>
-struct palign_impl<Offset,Packet1cd>
-{
- static EIGEN_STRONG_INLINE void run(Packet1cd& /*first*/, const Packet1cd& /*second*/)
- {
- // FIXME is it sure we never have to align a Packet1cd?
- // Even though a std::complex<double> has 16 bytes, it is not necessarily aligned on a 16 bytes boundary...
- }
-};
-
-template<> struct conj_helper<Packet1cd, Packet1cd, false,true>
-{
- EIGEN_STRONG_INLINE Packet1cd pmadd(const Packet1cd& x, const Packet1cd& y, const Packet1cd& c) const
- { return padd(pmul(x,y),c); }
-
- EIGEN_STRONG_INLINE Packet1cd pmul(const Packet1cd& a, const Packet1cd& b) const
- {
- return internal::pmul(a, pconj(b));
- }
-};
-
-template<> struct conj_helper<Packet1cd, Packet1cd, true,false>
-{
- EIGEN_STRONG_INLINE Packet1cd pmadd(const Packet1cd& x, const Packet1cd& y, const Packet1cd& c) const
- { return padd(pmul(x,y),c); }
-
- EIGEN_STRONG_INLINE Packet1cd pmul(const Packet1cd& a, const Packet1cd& b) const
- {
- return internal::pmul(pconj(a), b);
- }
-};
-
-template<> struct conj_helper<Packet1cd, Packet1cd, true,true>
-{
- EIGEN_STRONG_INLINE Packet1cd pmadd(const Packet1cd& x, const Packet1cd& y, const Packet1cd& c) const
- { return padd(pmul(x,y),c); }
-
- EIGEN_STRONG_INLINE Packet1cd pmul(const Packet1cd& a, const Packet1cd& b) const
- {
- return pconj(internal::pmul(a, b));
- }
-};
+EIGEN_MAKE_CONJ_HELPER_CPLX_REAL(Packet1cd,Packet2d)
template<> EIGEN_STRONG_INLINE Packet1cd pdiv<Packet1cd>(const Packet1cd& a, const Packet1cd& b)
{
// TODO optimize it for NEON
- Packet1cd res = conj_helper<Packet1cd,Packet1cd,false,true>().pmul(a,b);
+ Packet1cd res = pmul(a,pconj(b));
Packet2d s = pmul<Packet2d>(b.v, b.v);
Packet2d rev_s = preverse<Packet2d>(s);
@@ -467,9 +562,7 @@ template<> EIGEN_STRONG_INLINE Packet1cd pdiv<Packet1cd>(const Packet1cd& a, con
}
EIGEN_STRONG_INLINE Packet1cd pcplxflip/*<Packet1cd>*/(const Packet1cd& x)
-{
- return Packet1cd(preverse(Packet2d(x.v)));
-}
+{ return Packet1cd(preverse(Packet2d(x.v))); }
EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet1cd,2>& kernel)
{
@@ -477,6 +570,11 @@ EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet1cd,2>& kernel)
kernel.packet[0].v = vcombine_f64(vget_low_f64(kernel.packet[0].v), vget_low_f64(kernel.packet[1].v));
kernel.packet[1].v = tmp;
}
+
+template<> EIGEN_STRONG_INLINE Packet1cd psqrt<Packet1cd>(const Packet1cd& a) {
+ return psqrt_complex<Packet1cd>(a);
+}
+
#endif // EIGEN_ARCH_ARM64
} // end namespace internal
diff --git a/Eigen/src/Core/arch/NEON/GeneralBlockPanelKernel.h b/Eigen/src/Core/arch/NEON/GeneralBlockPanelKernel.h
new file mode 100644
index 000000000..ee8089997
--- /dev/null
+++ b/Eigen/src/Core/arch/NEON/GeneralBlockPanelKernel.h
@@ -0,0 +1,183 @@
+namespace Eigen {
+namespace internal {
+
+#if EIGEN_ARCH_ARM && EIGEN_COMP_CLANG
+
+// Clang seems to excessively spill registers in the GEBP kernel on 32-bit arm.
+// Here we specialize gebp_traits to eliminate these register spills.
+// See #2138.
+template<>
+struct gebp_traits <float,float,false,false,Architecture::NEON,GEBPPacketFull>
+ : gebp_traits<float,float,false,false,Architecture::Generic,GEBPPacketFull>
+{
+ EIGEN_STRONG_INLINE void acc(const AccPacket& c, const ResPacket& alpha, ResPacket& r) const
+ {
+ // This volatile inline ASM both acts as a barrier to prevent reordering,
+ // as well as enforces strict register use.
+ asm volatile(
+ "vmla.f32 %q[r], %q[c], %q[alpha]"
+ : [r] "+w" (r)
+ : [c] "w" (c),
+ [alpha] "w" (alpha)
+ : );
+ }
+
+ template <typename LaneIdType>
+ EIGEN_STRONG_INLINE void madd(const Packet4f& a, const Packet4f& b,
+ Packet4f& c, [[maybe_unused]] Packet4f& tmp,
+ const LaneIdType&) const {
+ acc(a, b, c);
+ }
+
+ template <typename LaneIdType>
+ EIGEN_STRONG_INLINE void madd(const Packet4f& a, const QuadPacket<Packet4f>& b,
+ Packet4f& c, Packet4f& tmp,
+ const LaneIdType& lane) const {
+ madd(a, b.get(lane), c, tmp, lane);
+ }
+};
+
+#endif // EIGEN_ARCH_ARM && EIGEN_COMP_CLANG
+
+#if EIGEN_ARCH_ARM64
+
+template<>
+struct gebp_traits <float,float,false,false,Architecture::NEON,GEBPPacketFull>
+ : gebp_traits<float,float,false,false,Architecture::Generic,GEBPPacketFull>
+{
+ typedef float RhsPacket;
+ typedef float32x4_t RhsPacketx4;
+
+ EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, RhsPacket& dest) const
+ {
+ dest = *b;
+ }
+
+ EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, RhsPacketx4& dest) const
+ {
+ dest = vld1q_f32(b);
+ }
+
+ EIGEN_STRONG_INLINE void updateRhs(const RhsScalar* b, RhsPacket& dest) const
+ {
+ dest = *b;
+ }
+
+ EIGEN_STRONG_INLINE void updateRhs(const RhsScalar*, RhsPacketx4&) const
+ {}
+
+ EIGEN_STRONG_INLINE void loadRhsQuad(const RhsScalar* b, RhsPacket& dest) const
+ {
+ loadRhs(b,dest);
+ }
+
+ EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacket& b, AccPacket& c, RhsPacket& /*tmp*/, const FixedInt<0>&) const
+ {
+ c = vfmaq_n_f32(c, a, b);
+ }
+
+ // NOTE: Template parameter inference failed when compiled with Android NDK:
+ // "candidate template ignored: could not match 'FixedInt<N>' against 'Eigen::internal::FixedInt<0>".
+
+ EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacketx4& b, AccPacket& c, RhsPacket& /*tmp*/, const FixedInt<0>&) const
+ { madd_helper<0>(a, b, c); }
+ EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacketx4& b, AccPacket& c, RhsPacket& /*tmp*/, const FixedInt<1>&) const
+ { madd_helper<1>(a, b, c); }
+ EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacketx4& b, AccPacket& c, RhsPacket& /*tmp*/, const FixedInt<2>&) const
+ { madd_helper<2>(a, b, c); }
+ EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacketx4& b, AccPacket& c, RhsPacket& /*tmp*/, const FixedInt<3>&) const
+ { madd_helper<3>(a, b, c); }
+
+ private:
+ template<int LaneID>
+ EIGEN_STRONG_INLINE void madd_helper(const LhsPacket& a, const RhsPacketx4& b, AccPacket& c) const
+ {
+ #if EIGEN_COMP_GNUC_STRICT && !(EIGEN_GNUC_AT_LEAST(9,0))
+ // workaround gcc issue https://gcc.gnu.org/bugzilla/show_bug.cgi?id=89101
+ // vfmaq_laneq_f32 is implemented through a costly dup
+ if(LaneID==0) asm("fmla %0.4s, %1.4s, %2.s[0]\n" : "+w" (c) : "w" (a), "w" (b) : );
+ else if(LaneID==1) asm("fmla %0.4s, %1.4s, %2.s[1]\n" : "+w" (c) : "w" (a), "w" (b) : );
+ else if(LaneID==2) asm("fmla %0.4s, %1.4s, %2.s[2]\n" : "+w" (c) : "w" (a), "w" (b) : );
+ else if(LaneID==3) asm("fmla %0.4s, %1.4s, %2.s[3]\n" : "+w" (c) : "w" (a), "w" (b) : );
+ #else
+ c = vfmaq_laneq_f32(c, a, b, LaneID);
+ #endif
+ }
+};
+
+
+template<>
+struct gebp_traits <double,double,false,false,Architecture::NEON>
+ : gebp_traits<double,double,false,false,Architecture::Generic>
+{
+ typedef double RhsPacket;
+
+ struct RhsPacketx4 {
+ float64x2_t B_0, B_1;
+ };
+
+ EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, RhsPacket& dest) const
+ {
+ dest = *b;
+ }
+
+ EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, RhsPacketx4& dest) const
+ {
+ dest.B_0 = vld1q_f64(b);
+ dest.B_1 = vld1q_f64(b+2);
+ }
+
+ EIGEN_STRONG_INLINE void updateRhs(const RhsScalar* b, RhsPacket& dest) const
+ {
+ loadRhs(b,dest);
+ }
+
+ EIGEN_STRONG_INLINE void updateRhs(const RhsScalar*, RhsPacketx4&) const
+ {}
+
+ EIGEN_STRONG_INLINE void loadRhsQuad(const RhsScalar* b, RhsPacket& dest) const
+ {
+ loadRhs(b,dest);
+ }
+
+ EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacket& b, AccPacket& c, RhsPacket& /*tmp*/, const FixedInt<0>&) const
+ {
+ c = vfmaq_n_f64(c, a, b);
+ }
+
+ // NOTE: Template parameter inference failed when compiled with Android NDK:
+ // "candidate template ignored: could not match 'FixedInt<N>' against 'Eigen::internal::FixedInt<0>".
+
+ EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacketx4& b, AccPacket& c, RhsPacket& /*tmp*/, const FixedInt<0>&) const
+ { madd_helper<0>(a, b, c); }
+ EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacketx4& b, AccPacket& c, RhsPacket& /*tmp*/, const FixedInt<1>&) const
+ { madd_helper<1>(a, b, c); }
+ EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacketx4& b, AccPacket& c, RhsPacket& /*tmp*/, const FixedInt<2>&) const
+ { madd_helper<2>(a, b, c); }
+ EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacketx4& b, AccPacket& c, RhsPacket& /*tmp*/, const FixedInt<3>&) const
+ { madd_helper<3>(a, b, c); }
+
+ private:
+ template <int LaneID>
+ EIGEN_STRONG_INLINE void madd_helper(const LhsPacket& a, const RhsPacketx4& b, AccPacket& c) const
+ {
+ #if EIGEN_COMP_GNUC_STRICT && !(EIGEN_GNUC_AT_LEAST(9,0))
+ // workaround gcc issue https://gcc.gnu.org/bugzilla/show_bug.cgi?id=89101
+ // vfmaq_laneq_f64 is implemented through a costly dup
+ if(LaneID==0) asm("fmla %0.2d, %1.2d, %2.d[0]\n" : "+w" (c) : "w" (a), "w" (b.B_0) : );
+ else if(LaneID==1) asm("fmla %0.2d, %1.2d, %2.d[1]\n" : "+w" (c) : "w" (a), "w" (b.B_0) : );
+ else if(LaneID==2) asm("fmla %0.2d, %1.2d, %2.d[0]\n" : "+w" (c) : "w" (a), "w" (b.B_1) : );
+ else if(LaneID==3) asm("fmla %0.2d, %1.2d, %2.d[1]\n" : "+w" (c) : "w" (a), "w" (b.B_1) : );
+ #else
+ if(LaneID==0) c = vfmaq_laneq_f64(c, a, b.B_0, 0);
+ else if(LaneID==1) c = vfmaq_laneq_f64(c, a, b.B_0, 1);
+ else if(LaneID==2) c = vfmaq_laneq_f64(c, a, b.B_1, 0);
+ else if(LaneID==3) c = vfmaq_laneq_f64(c, a, b.B_1, 1);
+ #endif
+ }
+};
+
+#endif // EIGEN_ARCH_ARM64
+
+} // namespace internal
+} // namespace Eigen
diff --git a/Eigen/src/Core/arch/NEON/MathFunctions.h b/Eigen/src/Core/arch/NEON/MathFunctions.h
index 6bb05bb92..fa6615a85 100644
--- a/Eigen/src/Core/arch/NEON/MathFunctions.h
+++ b/Eigen/src/Core/arch/NEON/MathFunctions.h
@@ -5,10 +5,6 @@
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
-/* The sin, cos, exp, and log functions of this file come from
- * Julien Pommier's sse math library: http://gruntthepeon.free.fr/ssemath/
- */
-
#ifndef EIGEN_MATH_FUNCTIONS_NEON_H
#define EIGEN_MATH_FUNCTIONS_NEON_H
@@ -16,74 +12,62 @@ namespace Eigen {
namespace internal {
-template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
-Packet4f pexp<Packet4f>(const Packet4f& _x)
-{
- Packet4f x = _x;
- Packet4f tmp, fx;
-
- _EIGEN_DECLARE_CONST_Packet4f(1 , 1.0f);
- _EIGEN_DECLARE_CONST_Packet4f(half, 0.5f);
- _EIGEN_DECLARE_CONST_Packet4i(0x7f, 0x7f);
- _EIGEN_DECLARE_CONST_Packet4f(exp_hi, 88.3762626647950f);
- _EIGEN_DECLARE_CONST_Packet4f(exp_lo, -88.3762626647949f);
- _EIGEN_DECLARE_CONST_Packet4f(cephes_LOG2EF, 1.44269504088896341f);
- _EIGEN_DECLARE_CONST_Packet4f(cephes_exp_C1, 0.693359375f);
- _EIGEN_DECLARE_CONST_Packet4f(cephes_exp_C2, -2.12194440e-4f);
- _EIGEN_DECLARE_CONST_Packet4f(cephes_exp_p0, 1.9875691500E-4f);
- _EIGEN_DECLARE_CONST_Packet4f(cephes_exp_p1, 1.3981999507E-3f);
- _EIGEN_DECLARE_CONST_Packet4f(cephes_exp_p2, 8.3334519073E-3f);
- _EIGEN_DECLARE_CONST_Packet4f(cephes_exp_p3, 4.1665795894E-2f);
- _EIGEN_DECLARE_CONST_Packet4f(cephes_exp_p4, 1.6666665459E-1f);
- _EIGEN_DECLARE_CONST_Packet4f(cephes_exp_p5, 5.0000001201E-1f);
-
- x = vminq_f32(x, p4f_exp_hi);
- x = vmaxq_f32(x, p4f_exp_lo);
-
- /* express exp(x) as exp(g + n*log(2)) */
- fx = vmlaq_f32(p4f_half, x, p4f_cephes_LOG2EF);
-
- /* perform a floorf */
- tmp = vcvtq_f32_s32(vcvtq_s32_f32(fx));
-
- /* if greater, substract 1 */
- Packet4ui mask = vcgtq_f32(tmp, fx);
- mask = vandq_u32(mask, vreinterpretq_u32_f32(p4f_1));
-
- fx = vsubq_f32(tmp, vreinterpretq_f32_u32(mask));
-
- tmp = vmulq_f32(fx, p4f_cephes_exp_C1);
- Packet4f z = vmulq_f32(fx, p4f_cephes_exp_C2);
- x = vsubq_f32(x, tmp);
- x = vsubq_f32(x, z);
-
- Packet4f y = vmulq_f32(p4f_cephes_exp_p0, x);
- z = vmulq_f32(x, x);
- y = vaddq_f32(y, p4f_cephes_exp_p1);
- y = vmulq_f32(y, x);
- y = vaddq_f32(y, p4f_cephes_exp_p2);
- y = vmulq_f32(y, x);
- y = vaddq_f32(y, p4f_cephes_exp_p3);
- y = vmulq_f32(y, x);
- y = vaddq_f32(y, p4f_cephes_exp_p4);
- y = vmulq_f32(y, x);
- y = vaddq_f32(y, p4f_cephes_exp_p5);
-
- y = vmulq_f32(y, z);
- y = vaddq_f32(y, x);
- y = vaddq_f32(y, p4f_1);
-
- /* build 2^n */
- int32x4_t mm;
- mm = vcvtq_s32_f32(fx);
- mm = vaddq_s32(mm, p4i_0x7f);
- mm = vshlq_n_s32(mm, 23);
- Packet4f pow2n = vreinterpretq_f32_s32(mm);
-
- y = vmulq_f32(y, pow2n);
- return y;
+template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet2f pexp<Packet2f>(const Packet2f& x)
+{ return pexp_float(x); }
+template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet4f pexp<Packet4f>(const Packet4f& x)
+{ return pexp_float(x); }
+
+template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet2f plog<Packet2f>(const Packet2f& x)
+{ return plog_float(x); }
+template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet4f plog<Packet4f>(const Packet4f& x)
+{ return plog_float(x); }
+
+template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet2f psin<Packet2f>(const Packet2f& x)
+{ return psin_float(x); }
+template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet4f psin<Packet4f>(const Packet4f& x)
+{ return psin_float(x); }
+
+template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet2f pcos<Packet2f>(const Packet2f& x)
+{ return pcos_float(x); }
+template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet4f pcos<Packet4f>(const Packet4f& x)
+{ return pcos_float(x); }
+
+// Hyperbolic Tangent function.
+template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet2f ptanh<Packet2f>(const Packet2f& x)
+{ return internal::generic_fast_tanh_float(x); }
+template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet4f ptanh<Packet4f>(const Packet4f& x)
+{ return internal::generic_fast_tanh_float(x); }
+
+BF16_PACKET_FUNCTION(Packet4f, Packet4bf, psin)
+BF16_PACKET_FUNCTION(Packet4f, Packet4bf, pcos)
+BF16_PACKET_FUNCTION(Packet4f, Packet4bf, plog)
+BF16_PACKET_FUNCTION(Packet4f, Packet4bf, pexp)
+BF16_PACKET_FUNCTION(Packet4f, Packet4bf, ptanh)
+
+template <>
+EIGEN_STRONG_INLINE Packet4bf pfrexp(const Packet4bf& a, Packet4bf& exponent) {
+ Packet4f fexponent;
+ const Packet4bf out = F32ToBf16(pfrexp<Packet4f>(Bf16ToF32(a), fexponent));
+ exponent = F32ToBf16(fexponent);
+ return out;
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4bf pldexp(const Packet4bf& a, const Packet4bf& exponent) {
+ return F32ToBf16(pldexp<Packet4f>(Bf16ToF32(a), Bf16ToF32(exponent)));
}
+//---------- double ----------
+
+#if EIGEN_ARCH_ARM64 && !EIGEN_APPLE_DOUBLE_NEON_BUG
+template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet2d pexp<Packet2d>(const Packet2d& x)
+{ return pexp_double(x); }
+
+template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet2d plog<Packet2d>(const Packet2d& x)
+{ return plog_double(x); }
+
+#endif
+
} // end namespace internal
} // end namespace Eigen
diff --git a/Eigen/src/Core/arch/NEON/PacketMath.h b/Eigen/src/Core/arch/NEON/PacketMath.h
index 836fbc0dd..d2aeef430 100644
--- a/Eigen/src/Core/arch/NEON/PacketMath.h
+++ b/Eigen/src/Core/arch/NEON/PacketMath.h
@@ -24,23 +24,118 @@ namespace internal {
#define EIGEN_HAS_SINGLE_INSTRUCTION_MADD
#endif
-#ifndef EIGEN_HAS_SINGLE_INSTRUCTION_CJMADD
-#define EIGEN_HAS_SINGLE_INSTRUCTION_CJMADD
-#endif
-
#ifndef EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS
#if EIGEN_ARCH_ARM64
#define EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS 32
#else
-#define EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS 16
+#define EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS 16
#endif
#endif
-typedef float32x2_t Packet2f;
-typedef float32x4_t Packet4f;
-typedef int32x4_t Packet4i;
-typedef int32x2_t Packet2i;
-typedef uint32x4_t Packet4ui;
+#if EIGEN_COMP_MSVC_STRICT
+
+// In MSVC's arm_neon.h header file, all NEON vector types
+// are aliases to the same underlying type __n128.
+// We thus have to wrap them to make them different C++ types.
+// (See also bug 1428)
+typedef eigen_packet_wrapper<float32x2_t,0> Packet2f;
+typedef eigen_packet_wrapper<float32x4_t,1> Packet4f;
+typedef eigen_packet_wrapper<int32_t ,2> Packet4c;
+typedef eigen_packet_wrapper<int8x8_t ,3> Packet8c;
+typedef eigen_packet_wrapper<int8x16_t ,4> Packet16c;
+typedef eigen_packet_wrapper<uint32_t ,5> Packet4uc;
+typedef eigen_packet_wrapper<uint8x8_t ,6> Packet8uc;
+typedef eigen_packet_wrapper<uint8x16_t ,7> Packet16uc;
+typedef eigen_packet_wrapper<int16x4_t ,8> Packet4s;
+typedef eigen_packet_wrapper<int16x8_t ,9> Packet8s;
+typedef eigen_packet_wrapper<uint16x4_t ,10> Packet4us;
+typedef eigen_packet_wrapper<uint16x8_t ,11> Packet8us;
+typedef eigen_packet_wrapper<int32x2_t ,12> Packet2i;
+typedef eigen_packet_wrapper<int32x4_t ,13> Packet4i;
+typedef eigen_packet_wrapper<uint32x2_t ,14> Packet2ui;
+typedef eigen_packet_wrapper<uint32x4_t ,15> Packet4ui;
+typedef eigen_packet_wrapper<int64x2_t ,16> Packet2l;
+typedef eigen_packet_wrapper<uint64x2_t ,17> Packet2ul;
+
+#else
+
+typedef float32x2_t Packet2f;
+typedef float32x4_t Packet4f;
+typedef eigen_packet_wrapper<int32_t ,2> Packet4c;
+typedef int8x8_t Packet8c;
+typedef int8x16_t Packet16c;
+typedef eigen_packet_wrapper<uint32_t ,5> Packet4uc;
+typedef uint8x8_t Packet8uc;
+typedef uint8x16_t Packet16uc;
+typedef int16x4_t Packet4s;
+typedef int16x8_t Packet8s;
+typedef uint16x4_t Packet4us;
+typedef uint16x8_t Packet8us;
+typedef int32x2_t Packet2i;
+typedef int32x4_t Packet4i;
+typedef uint32x2_t Packet2ui;
+typedef uint32x4_t Packet4ui;
+typedef int64x2_t Packet2l;
+typedef uint64x2_t Packet2ul;
+
+#endif // EIGEN_COMP_MSVC_STRICT
+
+EIGEN_STRONG_INLINE Packet4f shuffle1(const Packet4f& m, int mask){
+ const float* a = reinterpret_cast<const float*>(&m);
+ Packet4f res = {*(a + (mask & 3)), *(a + ((mask >> 2) & 3)), *(a + ((mask >> 4) & 3 )), *(a + ((mask >> 6) & 3))};
+ return res;
+}
+
+// fuctionally equivalent to _mm_shuffle_ps in SSE when interleave
+// == false (i.e. shuffle<false>(m, n, mask) equals _mm_shuffle_ps(m, n, mask)),
+// interleave m and n when interleave == true. Currently used in LU/arch/InverseSize4.h
+// to enable a shared implementation for fast inversion of matrices of size 4.
+template<bool interleave>
+EIGEN_STRONG_INLINE Packet4f shuffle2(const Packet4f &m, const Packet4f &n, int mask)
+{
+ const float* a = reinterpret_cast<const float*>(&m);
+ const float* b = reinterpret_cast<const float*>(&n);
+ Packet4f res = {*(a + (mask & 3)), *(a + ((mask >> 2) & 3)), *(b + ((mask >> 4) & 3)), *(b + ((mask >> 6) & 3))};
+ return res;
+}
+
+template<>
+EIGEN_STRONG_INLINE Packet4f shuffle2<true>(const Packet4f &m, const Packet4f &n, int mask)
+{
+ const float* a = reinterpret_cast<const float*>(&m);
+ const float* b = reinterpret_cast<const float*>(&n);
+ Packet4f res = {*(a + (mask & 3)), *(b + ((mask >> 2) & 3)), *(a + ((mask >> 4) & 3)), *(b + ((mask >> 6) & 3))};
+ return res;
+}
+
+EIGEN_STRONG_INLINE static int eigen_neon_shuffle_mask(int p, int q, int r, int s) {return ((s)<<6|(r)<<4|(q)<<2|(p));}
+
+EIGEN_STRONG_INLINE Packet4f vec4f_swizzle1(const Packet4f& a, int p, int q, int r, int s)
+{
+ return shuffle1(a, eigen_neon_shuffle_mask(p, q, r, s));
+}
+EIGEN_STRONG_INLINE Packet4f vec4f_swizzle2(const Packet4f& a, const Packet4f& b, int p, int q, int r, int s)
+{
+ return shuffle2<false>(a,b,eigen_neon_shuffle_mask(p, q, r, s));
+}
+EIGEN_STRONG_INLINE Packet4f vec4f_movelh(const Packet4f& a, const Packet4f& b)
+{
+ return shuffle2<false>(a,b,eigen_neon_shuffle_mask(0, 1, 0, 1));
+}
+EIGEN_STRONG_INLINE Packet4f vec4f_movehl(const Packet4f& a, const Packet4f& b)
+{
+ return shuffle2<false>(b,a,eigen_neon_shuffle_mask(2, 3, 2, 3));
+}
+EIGEN_STRONG_INLINE Packet4f vec4f_unpacklo(const Packet4f& a, const Packet4f& b)
+{
+ return shuffle2<true>(a,b,eigen_neon_shuffle_mask(0, 0, 1, 1));
+}
+EIGEN_STRONG_INLINE Packet4f vec4f_unpackhi(const Packet4f& a, const Packet4f& b)
+{
+ return shuffle2<true>(a,b,eigen_neon_shuffle_mask(2, 2, 3, 3));
+}
+#define vec4f_duplane(a, p) \
+ vdupq_lane_f32(vget_low_f32(a), p)
#define _EIGEN_DECLARE_CONST_Packet4f(NAME,X) \
const Packet4f p4f_##NAME = pset1<Packet4f>(X)
@@ -67,81 +162,816 @@ typedef uint32x4_t Packet4ui;
#define EIGEN_ARM_PREFETCH(ADDR)
#endif
-template<> struct packet_traits<float> : default_packet_traits
+template <>
+struct packet_traits<float> : default_packet_traits
{
typedef Packet4f type;
- typedef Packet4f half; // Packet2f intrinsics not implemented yet
- enum {
+ typedef Packet2f half;
+ enum
+ {
Vectorizable = 1,
AlignedOnScalar = 1,
size = 4,
- HasHalfPacket=0, // Packet2f intrinsics not implemented yet
-
- HasDiv = 1,
- // FIXME check the Has*
- HasSin = 0,
- HasCos = 0,
- HasLog = 0,
+ HasHalfPacket = 1,
+
+ HasAdd = 1,
+ HasSub = 1,
+ HasShift = 1,
+ HasMul = 1,
+ HasNegate = 1,
+ HasAbs = 1,
+ HasArg = 0,
+ HasAbs2 = 1,
+ HasAbsDiff = 1,
+ HasMin = 1,
+ HasMax = 1,
+ HasConj = 1,
+ HasSetLinear = 0,
+ HasBlend = 0,
+
+ HasDiv = 1,
+ HasFloor = 1,
+ HasCeil = 1,
+ HasRint = 1,
+
+ HasSin = EIGEN_FAST_MATH,
+ HasCos = EIGEN_FAST_MATH,
+ HasLog = 1,
HasExp = 1,
- HasSqrt = 0
+ HasSqrt = 1,
+ HasRsqrt = 1,
+ HasTanh = EIGEN_FAST_MATH,
+ HasErf = EIGEN_FAST_MATH,
+ HasBessel = 0, // Issues with accuracy.
+ HasNdtri = 0
};
};
-template<> struct packet_traits<int32_t> : default_packet_traits
+
+template <>
+struct packet_traits<int8_t> : default_packet_traits
+{
+ typedef Packet16c type;
+ typedef Packet8c half;
+ enum
+ {
+ Vectorizable = 1,
+ AlignedOnScalar = 1,
+ size = 16,
+ HasHalfPacket = 1,
+
+ HasAdd = 1,
+ HasSub = 1,
+ HasShift = 1,
+ HasMul = 1,
+ HasNegate = 1,
+ HasAbs = 1,
+ HasAbsDiff = 1,
+ HasArg = 0,
+ HasAbs2 = 1,
+ HasMin = 1,
+ HasMax = 1,
+ HasConj = 1,
+ HasSetLinear = 0,
+ HasBlend = 0
+ };
+};
+
+template <>
+struct packet_traits<uint8_t> : default_packet_traits
+{
+ typedef Packet16uc type;
+ typedef Packet8uc half;
+ enum
+ {
+ Vectorizable = 1,
+ AlignedOnScalar = 1,
+ size = 16,
+ HasHalfPacket = 1,
+
+ HasAdd = 1,
+ HasSub = 1,
+ HasShift = 1,
+ HasMul = 1,
+ HasNegate = 0,
+ HasAbs = 1,
+ HasAbsDiff = 1,
+ HasArg = 0,
+ HasAbs2 = 1,
+ HasMin = 1,
+ HasMax = 1,
+ HasConj = 1,
+ HasSetLinear = 0,
+ HasBlend = 0,
+
+ HasSqrt = 1
+ };
+};
+
+template <>
+struct packet_traits<int16_t> : default_packet_traits
+{
+ typedef Packet8s type;
+ typedef Packet4s half;
+ enum
+ {
+ Vectorizable = 1,
+ AlignedOnScalar = 1,
+ size = 8,
+ HasHalfPacket = 1,
+
+ HasAdd = 1,
+ HasSub = 1,
+ HasShift = 1,
+ HasMul = 1,
+ HasNegate = 1,
+ HasAbs = 1,
+ HasAbsDiff = 1,
+ HasArg = 0,
+ HasAbs2 = 1,
+ HasMin = 1,
+ HasMax = 1,
+ HasConj = 1,
+ HasSetLinear = 0,
+ HasBlend = 0
+ };
+};
+
+template <>
+struct packet_traits<uint16_t> : default_packet_traits
+{
+ typedef Packet8us type;
+ typedef Packet4us half;
+ enum
+ {
+ Vectorizable = 1,
+ AlignedOnScalar = 1,
+ size = 8,
+ HasHalfPacket = 1,
+
+ HasAdd = 1,
+ HasSub = 1,
+ HasShift = 1,
+ HasMul = 1,
+ HasNegate = 0,
+ HasAbs = 0,
+ HasAbsDiff = 1,
+ HasArg = 0,
+ HasAbs2 = 1,
+ HasMin = 1,
+ HasMax = 1,
+ HasConj = 1,
+ HasSetLinear = 0,
+ HasBlend = 0,
+ HasSqrt = 1
+ };
+};
+
+template <>
+struct packet_traits<int32_t> : default_packet_traits
{
typedef Packet4i type;
- typedef Packet4i half; // Packet2i intrinsics not implemented yet
- enum {
+ typedef Packet2i half;
+ enum
+ {
+ Vectorizable = 1,
+ AlignedOnScalar = 1,
+ size = 4,
+ HasHalfPacket = 1,
+
+ HasAdd = 1,
+ HasSub = 1,
+ HasShift = 1,
+ HasMul = 1,
+ HasNegate = 1,
+ HasAbs = 1,
+ HasArg = 0,
+ HasAbs2 = 1,
+ HasAbsDiff = 1,
+ HasMin = 1,
+ HasMax = 1,
+ HasConj = 1,
+ HasSetLinear = 0,
+ HasBlend = 0
+ };
+};
+
+template <>
+struct packet_traits<uint32_t> : default_packet_traits
+{
+ typedef Packet4ui type;
+ typedef Packet2ui half;
+ enum
+ {
+ Vectorizable = 1,
+ AlignedOnScalar = 1,
+ size = 4,
+ HasHalfPacket = 1,
+
+ HasAdd = 1,
+ HasSub = 1,
+ HasShift = 1,
+ HasMul = 1,
+ HasNegate = 0,
+ HasAbs = 0,
+ HasArg = 0,
+ HasAbs2 = 1,
+ HasAbsDiff = 1,
+ HasMin = 1,
+ HasMax = 1,
+ HasConj = 1,
+ HasSetLinear = 0,
+ HasBlend = 0,
+
+ HasSqrt = 1
+ };
+};
+
+template <>
+struct packet_traits<int64_t> : default_packet_traits
+{
+ typedef Packet2l type;
+ typedef Packet2l half;
+ enum
+ {
+ Vectorizable = 1,
+ AlignedOnScalar = 1,
+ size = 2,
+ HasHalfPacket = 0,
+
+ HasCmp = 1,
+ HasAdd = 1,
+ HasSub = 1,
+ HasShift = 1,
+ HasMul = 1,
+ HasNegate = 1,
+ HasAbs = 1,
+ HasArg = 0,
+ HasAbs2 = 1,
+ HasAbsDiff = 1,
+ HasMin = 1,
+ HasMax = 1,
+ HasConj = 1,
+ HasSetLinear = 0,
+ HasBlend = 0
+ };
+};
+
+template <>
+struct packet_traits<uint64_t> : default_packet_traits
+{
+ typedef Packet2ul type;
+ typedef Packet2ul half;
+ enum
+ {
Vectorizable = 1,
AlignedOnScalar = 1,
- size=4,
- HasHalfPacket=0 // Packet2i intrinsics not implemented yet
- // FIXME check the Has*
+ size = 2,
+ HasHalfPacket = 0,
+
+ HasCmp = 1,
+ HasAdd = 1,
+ HasSub = 1,
+ HasShift = 1,
+ HasMul = 1,
+ HasNegate = 0,
+ HasAbs = 0,
+ HasArg = 0,
+ HasAbs2 = 1,
+ HasAbsDiff = 1,
+ HasMin = 1,
+ HasMax = 1,
+ HasConj = 1,
+ HasSetLinear = 0,
+ HasBlend = 0
};
};
-#if EIGEN_GNUC_AT_MOST(4,4) && !EIGEN_COMP_LLVM
-// workaround gcc 4.2, 4.3 and 4.4 compilatin issue
+#if EIGEN_GNUC_AT_MOST(4, 4) && !EIGEN_COMP_LLVM
+// workaround gcc 4.2, 4.3 and 4.4 compilation issue
EIGEN_STRONG_INLINE float32x4_t vld1q_f32(const float* x) { return ::vld1q_f32((const float32_t*)x); }
-EIGEN_STRONG_INLINE float32x2_t vld1_f32 (const float* x) { return ::vld1_f32 ((const float32_t*)x); }
-EIGEN_STRONG_INLINE float32x2_t vld1_dup_f32 (const float* x) { return ::vld1_dup_f32 ((const float32_t*)x); }
-EIGEN_STRONG_INLINE void vst1q_f32(float* to, float32x4_t from) { ::vst1q_f32((float32_t*)to,from); }
-EIGEN_STRONG_INLINE void vst1_f32 (float* to, float32x2_t from) { ::vst1_f32 ((float32_t*)to,from); }
+EIGEN_STRONG_INLINE float32x2_t vld1_f32(const float* x) { return ::vld1_f32 ((const float32_t*)x); }
+EIGEN_STRONG_INLINE float32x2_t vld1_dup_f32(const float* x) { return ::vld1_dup_f32 ((const float32_t*)x); }
+EIGEN_STRONG_INLINE void vst1q_f32(float* to, float32x4_t from) { ::vst1q_f32((float32_t*)to,from); }
+EIGEN_STRONG_INLINE void vst1_f32 (float* to, float32x2_t from) { ::vst1_f32 ((float32_t*)to,from); }
#endif
-template<> struct unpacket_traits<Packet4f> { typedef float type; enum {size=4, alignment=Aligned16}; typedef Packet4f half; };
-template<> struct unpacket_traits<Packet4i> { typedef int32_t type; enum {size=4, alignment=Aligned16}; typedef Packet4i half; };
+template<> struct unpacket_traits<Packet2f>
+{
+ typedef float type;
+ typedef Packet2f half;
+ typedef Packet2i integer_packet;
+ enum
+ {
+ size = 2,
+ alignment = Aligned16,
+ vectorizable = true,
+ masked_load_available = false,
+ masked_store_available = false
+ };
+};
+template<> struct unpacket_traits<Packet4f>
+{
+ typedef float type;
+ typedef Packet2f half;
+ typedef Packet4i integer_packet;
+ enum
+ {
+ size = 4,
+ alignment = Aligned16,
+ vectorizable = true,
+ masked_load_available = false,
+ masked_store_available = false
+ };
+};
+template<> struct unpacket_traits<Packet4c>
+{
+ typedef int8_t type;
+ typedef Packet4c half;
+ enum
+ {
+ size = 4,
+ alignment = Unaligned,
+ vectorizable = true,
+ masked_load_available = false,
+ masked_store_available = false
+ };
+};
+template<> struct unpacket_traits<Packet8c>
+{
+ typedef int8_t type;
+ typedef Packet4c half;
+ enum
+ {
+ size = 8,
+ alignment = Aligned16,
+ vectorizable = true,
+ masked_load_available = false,
+ masked_store_available = false
+ };
+};
+template<> struct unpacket_traits<Packet16c>
+{
+ typedef int8_t type;
+ typedef Packet8c half;
+ enum
+ {
+ size = 16,
+ alignment = Aligned16,
+ vectorizable = true,
+ masked_load_available = false,
+ masked_store_available = false
+ };
+};
+template<> struct unpacket_traits<Packet4uc>
+{
+ typedef uint8_t type;
+ typedef Packet4uc half;
+ enum
+ {
+ size = 4,
+ alignment = Unaligned,
+ vectorizable = true,
+ masked_load_available = false,
+ masked_store_available = false
+ };
+};
+template<> struct unpacket_traits<Packet8uc>
+{
+ typedef uint8_t type;
+ typedef Packet4uc half;
+ enum
+ {
+ size = 8,
+ alignment = Aligned16,
+ vectorizable = true,
+ masked_load_available = false,
+ masked_store_available = false
+ };
+};
+template<> struct unpacket_traits<Packet16uc>
+{
+ typedef uint8_t type;
+ typedef Packet8uc half;
+ enum
+ {
+ size = 16,
+ alignment = Aligned16,
+ vectorizable = true,
+ masked_load_available = false,
+ masked_store_available = false};
+};
+template<> struct unpacket_traits<Packet4s>
+{
+ typedef int16_t type;
+ typedef Packet4s half;
+ enum
+ {
+ size = 4,
+ alignment = Aligned16,
+ vectorizable = true,
+ masked_load_available = false,
+ masked_store_available = false
+ };
+};
+template<> struct unpacket_traits<Packet8s>
+{
+ typedef int16_t type;
+ typedef Packet4s half;
+ enum
+ {
+ size = 8,
+ alignment = Aligned16,
+ vectorizable = true,
+ masked_load_available = false,
+ masked_store_available = false
+ };
+};
+template<> struct unpacket_traits<Packet4us>
+{
+ typedef uint16_t type;
+ typedef Packet4us half;
+ enum
+ {
+ size = 4,
+ alignment = Aligned16,
+ vectorizable = true,
+ masked_load_available = false,
+ masked_store_available = false
+ };
+};
+template<> struct unpacket_traits<Packet8us>
+{
+ typedef uint16_t type;
+ typedef Packet4us half;
+ enum
+ {
+ size = 8,
+ alignment = Aligned16,
+ vectorizable = true,
+ masked_load_available = false,
+ masked_store_available = false
+ };
+};
+template<> struct unpacket_traits<Packet2i>
+{
+ typedef int32_t type;
+ typedef Packet2i half;
+ enum
+ {
+ size = 2,
+ alignment = Aligned16,
+ vectorizable = true,
+ masked_load_available = false,
+ masked_store_available = false
+ };
+};
+template<> struct unpacket_traits<Packet4i>
+{
+ typedef int32_t type;
+ typedef Packet2i half;
+ enum
+ {
+ size = 4,
+ alignment = Aligned16,
+ vectorizable = true,
+ masked_load_available = false,
+ masked_store_available = false
+ };
+};
+template<> struct unpacket_traits<Packet2ui>
+{
+ typedef uint32_t type;
+ typedef Packet2ui half;
+ enum
+ {
+ size = 2,
+ alignment = Aligned16,
+ vectorizable = true,
+ masked_load_available = false,
+ masked_store_available = false
+ };
+};
+template<> struct unpacket_traits<Packet4ui>
+{
+ typedef uint32_t type;
+ typedef Packet2ui half;
+ enum
+ {
+ size = 4,
+ alignment = Aligned16,
+ vectorizable = true,
+ masked_load_available = false,
+ masked_store_available = false
+ };
+};
+template<> struct unpacket_traits<Packet2l>
+{
+ typedef int64_t type;
+ typedef Packet2l half;
+ enum
+ {
+ size = 2,
+ alignment = Aligned16,
+ vectorizable = true,
+ masked_load_available = false,
+ masked_store_available = false
+ };
+};
+template<> struct unpacket_traits<Packet2ul>
+{
+ typedef uint64_t type;
+ typedef Packet2ul half;
+ enum
+ {
+ size = 2,
+ alignment = Aligned16,
+ vectorizable = true,
+ masked_load_available = false,
+ masked_store_available = false
+ };
+};
+
+template<> EIGEN_STRONG_INLINE Packet2f pset1<Packet2f>(const float& from) { return vdup_n_f32(from); }
+template<> EIGEN_STRONG_INLINE Packet4f pset1<Packet4f>(const float& from) { return vdupq_n_f32(from); }
+template<> EIGEN_STRONG_INLINE Packet4c pset1<Packet4c>(const int8_t& from)
+{ return vget_lane_s32(vreinterpret_s32_s8(vdup_n_s8(from)), 0); }
+template<> EIGEN_STRONG_INLINE Packet8c pset1<Packet8c>(const int8_t& from) { return vdup_n_s8(from); }
+template<> EIGEN_STRONG_INLINE Packet16c pset1<Packet16c>(const int8_t& from) { return vdupq_n_s8(from); }
+template<> EIGEN_STRONG_INLINE Packet4uc pset1<Packet4uc>(const uint8_t& from)
+{ return vget_lane_u32(vreinterpret_u32_u8(vdup_n_u8(from)), 0); }
+template<> EIGEN_STRONG_INLINE Packet8uc pset1<Packet8uc>(const uint8_t& from) { return vdup_n_u8(from); }
+template<> EIGEN_STRONG_INLINE Packet16uc pset1<Packet16uc>(const uint8_t& from) { return vdupq_n_u8(from); }
+template<> EIGEN_STRONG_INLINE Packet4s pset1<Packet4s>(const int16_t& from) { return vdup_n_s16(from); }
+template<> EIGEN_STRONG_INLINE Packet8s pset1<Packet8s>(const int16_t& from) { return vdupq_n_s16(from); }
+template<> EIGEN_STRONG_INLINE Packet4us pset1<Packet4us>(const uint16_t& from) { return vdup_n_u16(from); }
+template<> EIGEN_STRONG_INLINE Packet8us pset1<Packet8us>(const uint16_t& from) { return vdupq_n_u16(from); }
+template<> EIGEN_STRONG_INLINE Packet2i pset1<Packet2i>(const int32_t& from) { return vdup_n_s32(from); }
+template<> EIGEN_STRONG_INLINE Packet4i pset1<Packet4i>(const int32_t& from) { return vdupq_n_s32(from); }
+template<> EIGEN_STRONG_INLINE Packet2ui pset1<Packet2ui>(const uint32_t& from) { return vdup_n_u32(from); }
+template<> EIGEN_STRONG_INLINE Packet4ui pset1<Packet4ui>(const uint32_t& from) { return vdupq_n_u32(from); }
+template<> EIGEN_STRONG_INLINE Packet2l pset1<Packet2l>(const int64_t& from) { return vdupq_n_s64(from); }
+template<> EIGEN_STRONG_INLINE Packet2ul pset1<Packet2ul>(const uint64_t& from) { return vdupq_n_u64(from); }
-template<> EIGEN_STRONG_INLINE Packet4f pset1<Packet4f>(const float& from) { return vdupq_n_f32(from); }
-template<> EIGEN_STRONG_INLINE Packet4i pset1<Packet4i>(const int32_t& from) { return vdupq_n_s32(from); }
+template<> EIGEN_STRONG_INLINE Packet2f pset1frombits<Packet2f>(unsigned int from)
+{ return vreinterpret_f32_u32(vdup_n_u32(from)); }
+template<> EIGEN_STRONG_INLINE Packet4f pset1frombits<Packet4f>(unsigned int from)
+{ return vreinterpretq_f32_u32(vdupq_n_u32(from)); }
+template<> EIGEN_STRONG_INLINE Packet2f plset<Packet2f>(const float& a)
+{
+ const float c[] = {0.0f,1.0f};
+ return vadd_f32(pset1<Packet2f>(a), vld1_f32(c));
+}
template<> EIGEN_STRONG_INLINE Packet4f plset<Packet4f>(const float& a)
{
- const float f[] = {0, 1, 2, 3};
- Packet4f countdown = vld1q_f32(f);
- return vaddq_f32(pset1<Packet4f>(a), countdown);
+ const float c[] = {0.0f,1.0f,2.0f,3.0f};
+ return vaddq_f32(pset1<Packet4f>(a), vld1q_f32(c));
+}
+template<> EIGEN_STRONG_INLINE Packet4c plset<Packet4c>(const int8_t& a)
+{ return vget_lane_s32(vreinterpret_s32_s8(vadd_s8(vreinterpret_s8_u32(vdup_n_u32(0x03020100)), vdup_n_s8(a))), 0); }
+template<> EIGEN_STRONG_INLINE Packet8c plset<Packet8c>(const int8_t& a)
+{
+ const int8_t c[] = {0,1,2,3,4,5,6,7};
+ return vadd_s8(pset1<Packet8c>(a), vld1_s8(c));
+}
+template<> EIGEN_STRONG_INLINE Packet16c plset<Packet16c>(const int8_t& a)
+{
+ const int8_t c[] = {0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15};
+ return vaddq_s8(pset1<Packet16c>(a), vld1q_s8(c));
+}
+template<> EIGEN_STRONG_INLINE Packet4uc plset<Packet4uc>(const uint8_t& a)
+{ return vget_lane_u32(vreinterpret_u32_u8(vadd_u8(vreinterpret_u8_u32(vdup_n_u32(0x03020100)), vdup_n_u8(a))), 0); }
+template<> EIGEN_STRONG_INLINE Packet8uc plset<Packet8uc>(const uint8_t& a)
+{
+ const uint8_t c[] = {0,1,2,3,4,5,6,7};
+ return vadd_u8(pset1<Packet8uc>(a), vld1_u8(c));
+}
+template<> EIGEN_STRONG_INLINE Packet16uc plset<Packet16uc>(const uint8_t& a)
+{
+ const uint8_t c[] = {0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15};
+ return vaddq_u8(pset1<Packet16uc>(a), vld1q_u8(c));
+}
+template<> EIGEN_STRONG_INLINE Packet4s plset<Packet4s>(const int16_t& a)
+{
+ const int16_t c[] = {0,1,2,3};
+ return vadd_s16(pset1<Packet4s>(a), vld1_s16(c));
+}
+template<> EIGEN_STRONG_INLINE Packet4us plset<Packet4us>(const uint16_t& a)
+{
+ const uint16_t c[] = {0,1,2,3};
+ return vadd_u16(pset1<Packet4us>(a), vld1_u16(c));
+}
+template<> EIGEN_STRONG_INLINE Packet8s plset<Packet8s>(const int16_t& a)
+{
+ const int16_t c[] = {0,1,2,3,4,5,6,7};
+ return vaddq_s16(pset1<Packet8s>(a), vld1q_s16(c));
+}
+template<> EIGEN_STRONG_INLINE Packet8us plset<Packet8us>(const uint16_t& a)
+{
+ const uint16_t c[] = {0,1,2,3,4,5,6,7};
+ return vaddq_u16(pset1<Packet8us>(a), vld1q_u16(c));
+}
+template<> EIGEN_STRONG_INLINE Packet2i plset<Packet2i>(const int32_t& a)
+{
+ const int32_t c[] = {0,1};
+ return vadd_s32(pset1<Packet2i>(a), vld1_s32(c));
}
template<> EIGEN_STRONG_INLINE Packet4i plset<Packet4i>(const int32_t& a)
{
- const int32_t i[] = {0, 1, 2, 3};
- Packet4i countdown = vld1q_s32(i);
- return vaddq_s32(pset1<Packet4i>(a), countdown);
+ const int32_t c[] = {0,1,2,3};
+ return vaddq_s32(pset1<Packet4i>(a), vld1q_s32(c));
+}
+template<> EIGEN_STRONG_INLINE Packet2ui plset<Packet2ui>(const uint32_t& a)
+{
+ const uint32_t c[] = {0,1};
+ return vadd_u32(pset1<Packet2ui>(a), vld1_u32(c));
+}
+template<> EIGEN_STRONG_INLINE Packet4ui plset<Packet4ui>(const uint32_t& a)
+{
+ const uint32_t c[] = {0,1,2,3};
+ return vaddq_u32(pset1<Packet4ui>(a), vld1q_u32(c));
+}
+template<> EIGEN_STRONG_INLINE Packet2l plset<Packet2l>(const int64_t& a)
+{
+ const int64_t c[] = {0,1};
+ return vaddq_s64(pset1<Packet2l>(a), vld1q_s64(c));
+}
+template<> EIGEN_STRONG_INLINE Packet2ul plset<Packet2ul>(const uint64_t& a)
+{
+ const uint64_t c[] = {0,1};
+ return vaddq_u64(pset1<Packet2ul>(a), vld1q_u64(c));
}
+template<> EIGEN_STRONG_INLINE Packet2f padd<Packet2f>(const Packet2f& a, const Packet2f& b) { return vadd_f32(a,b); }
template<> EIGEN_STRONG_INLINE Packet4f padd<Packet4f>(const Packet4f& a, const Packet4f& b) { return vaddq_f32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4c padd<Packet4c>(const Packet4c& a, const Packet4c& b)
+{
+ return vget_lane_s32(vreinterpret_s32_s8(vadd_s8(
+ vreinterpret_s8_s32(vdup_n_s32(a)),
+ vreinterpret_s8_s32(vdup_n_s32(b)))), 0);
+}
+template<> EIGEN_STRONG_INLINE Packet8c padd<Packet8c>(const Packet8c& a, const Packet8c& b) { return vadd_s8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet16c padd<Packet16c>(const Packet16c& a, const Packet16c& b) { return vaddq_s8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4uc padd<Packet4uc>(const Packet4uc& a, const Packet4uc& b)
+{
+ return vget_lane_u32(vreinterpret_u32_u8(vadd_u8(
+ vreinterpret_u8_u32(vdup_n_u32(a)),
+ vreinterpret_u8_u32(vdup_n_u32(b)))), 0);
+}
+template<> EIGEN_STRONG_INLINE Packet8uc padd<Packet8uc>(const Packet8uc& a, const Packet8uc& b) { return vadd_u8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet16uc padd<Packet16uc>(const Packet16uc& a, const Packet16uc& b) { return vaddq_u8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4s padd<Packet4s>(const Packet4s& a, const Packet4s& b) { return vadd_s16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet8s padd<Packet8s>(const Packet8s& a, const Packet8s& b) { return vaddq_s16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4us padd<Packet4us>(const Packet4us& a, const Packet4us& b) { return vadd_u16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet8us padd<Packet8us>(const Packet8us& a, const Packet8us& b) { return vaddq_u16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2i padd<Packet2i>(const Packet2i& a, const Packet2i& b) { return vadd_s32(a,b); }
template<> EIGEN_STRONG_INLINE Packet4i padd<Packet4i>(const Packet4i& a, const Packet4i& b) { return vaddq_s32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2ui padd<Packet2ui>(const Packet2ui& a, const Packet2ui& b) { return vadd_u32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4ui padd<Packet4ui>(const Packet4ui& a, const Packet4ui& b) { return vaddq_u32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2l padd<Packet2l>(const Packet2l& a, const Packet2l& b) { return vaddq_s64(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2ul padd<Packet2ul>(const Packet2ul& a, const Packet2ul& b) { return vaddq_u64(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2f psub<Packet2f>(const Packet2f& a, const Packet2f& b) { return vsub_f32(a,b); }
template<> EIGEN_STRONG_INLINE Packet4f psub<Packet4f>(const Packet4f& a, const Packet4f& b) { return vsubq_f32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4c psub<Packet4c>(const Packet4c& a, const Packet4c& b)
+{
+ return vget_lane_s32(vreinterpret_s32_s8(vsub_s8(
+ vreinterpret_s8_s32(vdup_n_s32(a)),
+ vreinterpret_s8_s32(vdup_n_s32(b)))), 0);
+}
+template<> EIGEN_STRONG_INLINE Packet8c psub<Packet8c>(const Packet8c& a, const Packet8c& b) { return vsub_s8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet16c psub<Packet16c>(const Packet16c& a, const Packet16c& b) { return vsubq_s8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4uc psub<Packet4uc>(const Packet4uc& a, const Packet4uc& b)
+{
+ return vget_lane_u32(vreinterpret_u32_u8(vsub_u8(
+ vreinterpret_u8_u32(vdup_n_u32(a)),
+ vreinterpret_u8_u32(vdup_n_u32(b)))), 0);
+}
+template<> EIGEN_STRONG_INLINE Packet8uc psub<Packet8uc>(const Packet8uc& a, const Packet8uc& b) { return vsub_u8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet16uc psub<Packet16uc>(const Packet16uc& a, const Packet16uc& b) { return vsubq_u8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4s psub<Packet4s>(const Packet4s& a, const Packet4s& b) { return vsub_s16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet8s psub<Packet8s>(const Packet8s& a, const Packet8s& b) { return vsubq_s16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4us psub<Packet4us>(const Packet4us& a, const Packet4us& b) { return vsub_u16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet8us psub<Packet8us>(const Packet8us& a, const Packet8us& b) { return vsubq_u16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2i psub<Packet2i>(const Packet2i& a, const Packet2i& b) { return vsub_s32(a,b); }
template<> EIGEN_STRONG_INLINE Packet4i psub<Packet4i>(const Packet4i& a, const Packet4i& b) { return vsubq_s32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2ui psub<Packet2ui>(const Packet2ui& a, const Packet2ui& b) { return vsub_u32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4ui psub<Packet4ui>(const Packet4ui& a, const Packet4ui& b) { return vsubq_u32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2l psub<Packet2l>(const Packet2l& a, const Packet2l& b) { return vsubq_s64(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2ul psub<Packet2ul>(const Packet2ul& a, const Packet2ul& b) { return vsubq_u64(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2f pxor<Packet2f>(const Packet2f& a, const Packet2f& b);
+template<> EIGEN_STRONG_INLINE Packet2f paddsub<Packet2f>(const Packet2f& a, const Packet2f & b) {
+ Packet2f mask = {numext::bit_cast<float>(0x80000000u), 0.0f};
+ return padd(a, pxor(mask, b));
+}
+template<> EIGEN_STRONG_INLINE Packet4f pxor<Packet4f>(const Packet4f& a, const Packet4f& b);
+template<> EIGEN_STRONG_INLINE Packet4f paddsub<Packet4f>(const Packet4f& a, const Packet4f& b) {
+ Packet4f mask = {numext::bit_cast<float>(0x80000000u), 0.0f, numext::bit_cast<float>(0x80000000u), 0.0f};
+ return padd(a, pxor(mask, b));
+}
+
+template<> EIGEN_STRONG_INLINE Packet2f pnegate(const Packet2f& a) { return vneg_f32(a); }
template<> EIGEN_STRONG_INLINE Packet4f pnegate(const Packet4f& a) { return vnegq_f32(a); }
+template<> EIGEN_STRONG_INLINE Packet4c pnegate(const Packet4c& a)
+{ return vget_lane_s32(vreinterpret_s32_s8(vneg_s8(vreinterpret_s8_s32(vdup_n_s32(a)))), 0); }
+template<> EIGEN_STRONG_INLINE Packet8c pnegate(const Packet8c& a) { return vneg_s8(a); }
+template<> EIGEN_STRONG_INLINE Packet16c pnegate(const Packet16c& a) { return vnegq_s8(a); }
+template<> EIGEN_STRONG_INLINE Packet4s pnegate(const Packet4s& a) { return vneg_s16(a); }
+template<> EIGEN_STRONG_INLINE Packet8s pnegate(const Packet8s& a) { return vnegq_s16(a); }
+template<> EIGEN_STRONG_INLINE Packet2i pnegate(const Packet2i& a) { return vneg_s32(a); }
template<> EIGEN_STRONG_INLINE Packet4i pnegate(const Packet4i& a) { return vnegq_s32(a); }
+template<> EIGEN_STRONG_INLINE Packet2l pnegate(const Packet2l& a) {
+#if EIGEN_ARCH_ARM64
+ return vnegq_s64(a);
+#else
+ return vcombine_s64(
+ vdup_n_s64(-vgetq_lane_s64(a, 0)),
+ vdup_n_s64(-vgetq_lane_s64(a, 1)));
+#endif
+}
+template<> EIGEN_STRONG_INLINE Packet2f pconj(const Packet2f& a) { return a; }
template<> EIGEN_STRONG_INLINE Packet4f pconj(const Packet4f& a) { return a; }
+template<> EIGEN_STRONG_INLINE Packet4c pconj(const Packet4c& a) { return a; }
+template<> EIGEN_STRONG_INLINE Packet8c pconj(const Packet8c& a) { return a; }
+template<> EIGEN_STRONG_INLINE Packet16c pconj(const Packet16c& a) { return a; }
+template<> EIGEN_STRONG_INLINE Packet4uc pconj(const Packet4uc& a) { return a; }
+template<> EIGEN_STRONG_INLINE Packet8uc pconj(const Packet8uc& a) { return a; }
+template<> EIGEN_STRONG_INLINE Packet16uc pconj(const Packet16uc& a) { return a; }
+template<> EIGEN_STRONG_INLINE Packet4s pconj(const Packet4s& a) { return a; }
+template<> EIGEN_STRONG_INLINE Packet8s pconj(const Packet8s& a) { return a; }
+template<> EIGEN_STRONG_INLINE Packet4us pconj(const Packet4us& a) { return a; }
+template<> EIGEN_STRONG_INLINE Packet8us pconj(const Packet8us& a) { return a; }
+template<> EIGEN_STRONG_INLINE Packet2i pconj(const Packet2i& a) { return a; }
template<> EIGEN_STRONG_INLINE Packet4i pconj(const Packet4i& a) { return a; }
+template<> EIGEN_STRONG_INLINE Packet2ui pconj(const Packet2ui& a) { return a; }
+template<> EIGEN_STRONG_INLINE Packet4ui pconj(const Packet4ui& a) { return a; }
+template<> EIGEN_STRONG_INLINE Packet2l pconj(const Packet2l& a) { return a; }
+template<> EIGEN_STRONG_INLINE Packet2ul pconj(const Packet2ul& a) { return a; }
+template<> EIGEN_STRONG_INLINE Packet2f pmul<Packet2f>(const Packet2f& a, const Packet2f& b) { return vmul_f32(a,b); }
template<> EIGEN_STRONG_INLINE Packet4f pmul<Packet4f>(const Packet4f& a, const Packet4f& b) { return vmulq_f32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4c pmul<Packet4c>(const Packet4c& a, const Packet4c& b)
+{
+ return vget_lane_s32(vreinterpret_s32_s8(vmul_s8(
+ vreinterpret_s8_s32(vdup_n_s32(a)),
+ vreinterpret_s8_s32(vdup_n_s32(b)))), 0);
+}
+template<> EIGEN_STRONG_INLINE Packet8c pmul<Packet8c>(const Packet8c& a, const Packet8c& b) { return vmul_s8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet16c pmul<Packet16c>(const Packet16c& a, const Packet16c& b) { return vmulq_s8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4uc pmul<Packet4uc>(const Packet4uc& a, const Packet4uc& b)
+{
+ return vget_lane_u32(vreinterpret_u32_u8(vmul_u8(
+ vreinterpret_u8_u32(vdup_n_u32(a)),
+ vreinterpret_u8_u32(vdup_n_u32(b)))), 0);
+}
+template<> EIGEN_STRONG_INLINE Packet8uc pmul<Packet8uc>(const Packet8uc& a, const Packet8uc& b) { return vmul_u8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet16uc pmul<Packet16uc>(const Packet16uc& a, const Packet16uc& b) { return vmulq_u8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4s pmul<Packet4s>(const Packet4s& a, const Packet4s& b) { return vmul_s16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet8s pmul<Packet8s>(const Packet8s& a, const Packet8s& b) { return vmulq_s16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4us pmul<Packet4us>(const Packet4us& a, const Packet4us& b) { return vmul_u16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet8us pmul<Packet8us>(const Packet8us& a, const Packet8us& b) { return vmulq_u16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2i pmul<Packet2i>(const Packet2i& a, const Packet2i& b) { return vmul_s32(a,b); }
template<> EIGEN_STRONG_INLINE Packet4i pmul<Packet4i>(const Packet4i& a, const Packet4i& b) { return vmulq_s32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2ui pmul<Packet2ui>(const Packet2ui& a, const Packet2ui& b) { return vmul_u32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4ui pmul<Packet4ui>(const Packet4ui& a, const Packet4ui& b) { return vmulq_u32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2l pmul<Packet2l>(const Packet2l& a, const Packet2l& b) {
+ return vcombine_s64(
+ vdup_n_s64(vgetq_lane_s64(a, 0)*vgetq_lane_s64(b, 0)),
+ vdup_n_s64(vgetq_lane_s64(a, 1)*vgetq_lane_s64(b, 1)));
+}
+template<> EIGEN_STRONG_INLINE Packet2ul pmul<Packet2ul>(const Packet2ul& a, const Packet2ul& b) {
+ return vcombine_u64(
+ vdup_n_u64(vgetq_lane_u64(a, 0)*vgetq_lane_u64(b, 0)),
+ vdup_n_u64(vgetq_lane_u64(a, 1)*vgetq_lane_u64(b, 1)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet2f pdiv<Packet2f>(const Packet2f& a, const Packet2f& b)
+{
+#if EIGEN_ARCH_ARM64
+ return vdiv_f32(a,b);
+#else
+ Packet2f inv, restep, div;
+ // NEON does not offer a divide instruction, we have to do a reciprocal approximation
+ // However NEON in contrast to other SIMD engines (AltiVec/SSE), offers
+ // a reciprocal estimate AND a reciprocal step -which saves a few instructions
+ // vrecpeq_f32() returns an estimate to 1/b, which we will finetune with
+ // Newton-Raphson and vrecpsq_f32()
+ inv = vrecpe_f32(b);
+
+ // This returns a differential, by which we will have to multiply inv to get a better
+ // approximation of 1/b.
+ restep = vrecps_f32(b, inv);
+ inv = vmul_f32(restep, inv);
+
+ // Finally, multiply a by 1/b and get the wanted result of the division.
+ div = vmul_f32(a, inv);
+
+ return div;
+#endif
+}
template<> EIGEN_STRONG_INLINE Packet4f pdiv<Packet4f>(const Packet4f& a, const Packet4f& b)
{
#if EIGEN_ARCH_ARM64
@@ -168,357 +998,2629 @@ template<> EIGEN_STRONG_INLINE Packet4f pdiv<Packet4f>(const Packet4f& a, const
#endif
}
+template<> EIGEN_STRONG_INLINE Packet4c pdiv<Packet4c>(const Packet4c& /*a*/, const Packet4c& /*b*/)
+{
+ eigen_assert(false && "packet integer division are not supported by NEON");
+ return pset1<Packet4c>(0);
+}
+template<> EIGEN_STRONG_INLINE Packet8c pdiv<Packet8c>(const Packet8c& /*a*/, const Packet8c& /*b*/)
+{
+ eigen_assert(false && "packet integer division are not supported by NEON");
+ return pset1<Packet8c>(0);
+}
+template<> EIGEN_STRONG_INLINE Packet16c pdiv<Packet16c>(const Packet16c& /*a*/, const Packet16c& /*b*/)
+{
+ eigen_assert(false && "packet integer division are not supported by NEON");
+ return pset1<Packet16c>(0);
+}
+template<> EIGEN_STRONG_INLINE Packet4uc pdiv<Packet4uc>(const Packet4uc& /*a*/, const Packet4uc& /*b*/)
+{
+ eigen_assert(false && "packet integer division are not supported by NEON");
+ return pset1<Packet4uc>(0);
+}
+template<> EIGEN_STRONG_INLINE Packet8uc pdiv<Packet8uc>(const Packet8uc& /*a*/, const Packet8uc& /*b*/)
+{
+ eigen_assert(false && "packet integer division are not supported by NEON");
+ return pset1<Packet8uc>(0);
+}
+template<> EIGEN_STRONG_INLINE Packet16uc pdiv<Packet16uc>(const Packet16uc& /*a*/, const Packet16uc& /*b*/)
+{
+ eigen_assert(false && "packet integer division are not supported by NEON");
+ return pset1<Packet16uc>(0);
+}
+template<> EIGEN_STRONG_INLINE Packet4s pdiv<Packet4s>(const Packet4s& /*a*/, const Packet4s& /*b*/)
+{
+ eigen_assert(false && "packet integer division are not supported by NEON");
+ return pset1<Packet4s>(0);
+}
+template<> EIGEN_STRONG_INLINE Packet8s pdiv<Packet8s>(const Packet8s& /*a*/, const Packet8s& /*b*/)
+{
+ eigen_assert(false && "packet integer division are not supported by NEON");
+ return pset1<Packet8s>(0);
+}
+template<> EIGEN_STRONG_INLINE Packet4us pdiv<Packet4us>(const Packet4us& /*a*/, const Packet4us& /*b*/)
+{
+ eigen_assert(false && "packet integer division are not supported by NEON");
+ return pset1<Packet4us>(0);
+}
+template<> EIGEN_STRONG_INLINE Packet8us pdiv<Packet8us>(const Packet8us& /*a*/, const Packet8us& /*b*/)
+{
+ eigen_assert(false && "packet integer division are not supported by NEON");
+ return pset1<Packet8us>(0);
+}
+template<> EIGEN_STRONG_INLINE Packet2i pdiv<Packet2i>(const Packet2i& /*a*/, const Packet2i& /*b*/)
+{
+ eigen_assert(false && "packet integer division are not supported by NEON");
+ return pset1<Packet2i>(0);
+}
template<> EIGEN_STRONG_INLINE Packet4i pdiv<Packet4i>(const Packet4i& /*a*/, const Packet4i& /*b*/)
-{ eigen_assert(false && "packet integer division are not supported by NEON");
+{
+ eigen_assert(false && "packet integer division are not supported by NEON");
return pset1<Packet4i>(0);
}
+template<> EIGEN_STRONG_INLINE Packet2ui pdiv<Packet2ui>(const Packet2ui& /*a*/, const Packet2ui& /*b*/)
+{
+ eigen_assert(false && "packet integer division are not supported by NEON");
+ return pset1<Packet2ui>(0);
+}
+template<> EIGEN_STRONG_INLINE Packet4ui pdiv<Packet4ui>(const Packet4ui& /*a*/, const Packet4ui& /*b*/)
+{
+ eigen_assert(false && "packet integer division are not supported by NEON");
+ return pset1<Packet4ui>(0);
+}
+template<> EIGEN_STRONG_INLINE Packet2l pdiv<Packet2l>(const Packet2l& /*a*/, const Packet2l& /*b*/)
+{
+ eigen_assert(false && "packet integer division are not supported by NEON");
+ return pset1<Packet2l>(0LL);
+}
+template<> EIGEN_STRONG_INLINE Packet2ul pdiv<Packet2ul>(const Packet2ul& /*a*/, const Packet2ul& /*b*/)
+{
+ eigen_assert(false && "packet integer division are not supported by NEON");
+ return pset1<Packet2ul>(0ULL);
+}
-// Clang/ARM wrongly advertises __ARM_FEATURE_FMA even when it's not available,
-// then implements a slow software scalar fallback calling fmaf()!
-// Filed LLVM bug:
-// https://llvm.org/bugs/show_bug.cgi?id=27216
-#if (defined __ARM_FEATURE_FMA) && !(EIGEN_COMP_CLANG && EIGEN_ARCH_ARM)
-// See bug 936.
-// FMA is available on VFPv4 i.e. when compiling with -mfpu=neon-vfpv4.
-// FMA is a true fused multiply-add i.e. only 1 rounding at the end, no intermediate rounding.
-// MLA is not fused i.e. does 2 roundings.
-// In addition to giving better accuracy, FMA also gives better performance here on a Krait (Nexus 4):
-// MLA: 10 GFlop/s ; FMA: 12 GFlops/s.
-template<> EIGEN_STRONG_INLINE Packet4f pmadd(const Packet4f& a, const Packet4f& b, const Packet4f& c) { return vfmaq_f32(c,a,b); }
-#else
-template<> EIGEN_STRONG_INLINE Packet4f pmadd(const Packet4f& a, const Packet4f& b, const Packet4f& c) {
-#if EIGEN_COMP_CLANG && EIGEN_ARCH_ARM
- // Clang/ARM will replace VMLA by VMUL+VADD at least for some values of -mcpu,
- // at least -mcpu=cortex-a8 and -mcpu=cortex-a7. Since the former is the default on
- // -march=armv7-a, that is a very common case.
- // See e.g. this thread:
- // http://lists.llvm.org/pipermail/llvm-dev/2013-December/068806.html
- // Filed LLVM bug:
- // https://llvm.org/bugs/show_bug.cgi?id=27219
- Packet4f r = c;
- asm volatile(
- "vmla.f32 %q[r], %q[a], %q[b]"
- : [r] "+w" (r)
- : [a] "w" (a),
- [b] "w" (b)
- : );
- return r;
+
+#ifdef __ARM_FEATURE_FMA
+template<> EIGEN_STRONG_INLINE Packet4f pmadd(const Packet4f& a, const Packet4f& b, const Packet4f& c)
+{ return vfmaq_f32(c,a,b); }
+template<> EIGEN_STRONG_INLINE Packet2f pmadd(const Packet2f& a, const Packet2f& b, const Packet2f& c)
+{ return vfma_f32(c,a,b); }
#else
+template<> EIGEN_STRONG_INLINE Packet4f pmadd(const Packet4f& a, const Packet4f& b, const Packet4f& c)
+{
return vmlaq_f32(c,a,b);
-#endif
+}
+template<> EIGEN_STRONG_INLINE Packet2f pmadd(const Packet2f& a, const Packet2f& b, const Packet2f& c)
+{
+ return vmla_f32(c,a,b);
}
#endif
// No FMA instruction for int, so use MLA unconditionally.
-template<> EIGEN_STRONG_INLINE Packet4i pmadd(const Packet4i& a, const Packet4i& b, const Packet4i& c) { return vmlaq_s32(c,a,b); }
+template<> EIGEN_STRONG_INLINE Packet4c pmadd(const Packet4c& a, const Packet4c& b, const Packet4c& c)
+{
+ return vget_lane_s32(vreinterpret_s32_s8(vmla_s8(
+ vreinterpret_s8_s32(vdup_n_s32(c)),
+ vreinterpret_s8_s32(vdup_n_s32(a)),
+ vreinterpret_s8_s32(vdup_n_s32(b)))), 0);
+}
+template<> EIGEN_STRONG_INLINE Packet8c pmadd(const Packet8c& a, const Packet8c& b, const Packet8c& c)
+{ return vmla_s8(c,a,b); }
+template<> EIGEN_STRONG_INLINE Packet16c pmadd(const Packet16c& a, const Packet16c& b, const Packet16c& c)
+{ return vmlaq_s8(c,a,b); }
+template<> EIGEN_STRONG_INLINE Packet4uc pmadd(const Packet4uc& a, const Packet4uc& b, const Packet4uc& c)
+{
+ return vget_lane_u32(vreinterpret_u32_u8(vmla_u8(
+ vreinterpret_u8_u32(vdup_n_u32(c)),
+ vreinterpret_u8_u32(vdup_n_u32(a)),
+ vreinterpret_u8_u32(vdup_n_u32(b)))), 0);
+}
+template<> EIGEN_STRONG_INLINE Packet8uc pmadd(const Packet8uc& a, const Packet8uc& b, const Packet8uc& c)
+{ return vmla_u8(c,a,b); }
+template<> EIGEN_STRONG_INLINE Packet16uc pmadd(const Packet16uc& a, const Packet16uc& b, const Packet16uc& c)
+{ return vmlaq_u8(c,a,b); }
+template<> EIGEN_STRONG_INLINE Packet4s pmadd(const Packet4s& a, const Packet4s& b, const Packet4s& c)
+{ return vmla_s16(c,a,b); }
+template<> EIGEN_STRONG_INLINE Packet8s pmadd(const Packet8s& a, const Packet8s& b, const Packet8s& c)
+{ return vmlaq_s16(c,a,b); }
+template<> EIGEN_STRONG_INLINE Packet4us pmadd(const Packet4us& a, const Packet4us& b, const Packet4us& c)
+{ return vmla_u16(c,a,b); }
+template<> EIGEN_STRONG_INLINE Packet8us pmadd(const Packet8us& a, const Packet8us& b, const Packet8us& c)
+{ return vmlaq_u16(c,a,b); }
+template<> EIGEN_STRONG_INLINE Packet2i pmadd(const Packet2i& a, const Packet2i& b, const Packet2i& c)
+{ return vmla_s32(c,a,b); }
+template<> EIGEN_STRONG_INLINE Packet4i pmadd(const Packet4i& a, const Packet4i& b, const Packet4i& c)
+{ return vmlaq_s32(c,a,b); }
+template<> EIGEN_STRONG_INLINE Packet2ui pmadd(const Packet2ui& a, const Packet2ui& b, const Packet2ui& c)
+{ return vmla_u32(c,a,b); }
+template<> EIGEN_STRONG_INLINE Packet4ui pmadd(const Packet4ui& a, const Packet4ui& b, const Packet4ui& c)
+{ return vmlaq_u32(c,a,b); }
+template<> EIGEN_STRONG_INLINE Packet2f pabsdiff<Packet2f>(const Packet2f& a, const Packet2f& b)
+{ return vabd_f32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4f pabsdiff<Packet4f>(const Packet4f& a, const Packet4f& b)
+{ return vabdq_f32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4c pabsdiff<Packet4c>(const Packet4c& a, const Packet4c& b)
+{
+ return vget_lane_s32(vreinterpret_s32_s8(vabd_s8(
+ vreinterpret_s8_s32(vdup_n_s32(a)),
+ vreinterpret_s8_s32(vdup_n_s32(b)))), 0);
+}
+template<> EIGEN_STRONG_INLINE Packet8c pabsdiff<Packet8c>(const Packet8c& a, const Packet8c& b)
+{ return vabd_s8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet16c pabsdiff<Packet16c>(const Packet16c& a, const Packet16c& b)
+{ return vabdq_s8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4uc pabsdiff<Packet4uc>(const Packet4uc& a, const Packet4uc& b)
+{
+ return vget_lane_u32(vreinterpret_u32_u8(vabd_u8(
+ vreinterpret_u8_u32(vdup_n_u32(a)),
+ vreinterpret_u8_u32(vdup_n_u32(b)))), 0);
+}
+template<> EIGEN_STRONG_INLINE Packet8uc pabsdiff<Packet8uc>(const Packet8uc& a, const Packet8uc& b)
+{ return vabd_u8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet16uc pabsdiff<Packet16uc>(const Packet16uc& a, const Packet16uc& b)
+{ return vabdq_u8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4s pabsdiff<Packet4s>(const Packet4s& a, const Packet4s& b)
+{ return vabd_s16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet8s pabsdiff<Packet8s>(const Packet8s& a, const Packet8s& b)
+{ return vabdq_s16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4us pabsdiff<Packet4us>(const Packet4us& a, const Packet4us& b)
+{ return vabd_u16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet8us pabsdiff<Packet8us>(const Packet8us& a, const Packet8us& b)
+{ return vabdq_u16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2i pabsdiff<Packet2i>(const Packet2i& a, const Packet2i& b)
+{ return vabd_s32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4i pabsdiff<Packet4i>(const Packet4i& a, const Packet4i& b)
+{ return vabdq_s32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2ui pabsdiff<Packet2ui>(const Packet2ui& a, const Packet2ui& b)
+{ return vabd_u32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4ui pabsdiff<Packet4ui>(const Packet4ui& a, const Packet4ui& b)
+{ return vabdq_u32(a,b); }
+
+template<> EIGEN_STRONG_INLINE Packet2f pmin<Packet2f>(const Packet2f& a, const Packet2f& b) { return vmin_f32(a,b); }
template<> EIGEN_STRONG_INLINE Packet4f pmin<Packet4f>(const Packet4f& a, const Packet4f& b) { return vminq_f32(a,b); }
+
+#ifdef __ARM_FEATURE_NUMERIC_MAXMIN
+// numeric max and min are only available if ARM_FEATURE_NUMERIC_MAXMIN is defined (which can only be the case for Armv8 systems).
+template<> EIGEN_STRONG_INLINE Packet4f pmin<PropagateNumbers, Packet4f>(const Packet4f& a, const Packet4f& b) { return vminnmq_f32(a, b); }
+template<> EIGEN_STRONG_INLINE Packet2f pmin<PropagateNumbers, Packet2f>(const Packet2f& a, const Packet2f& b) { return vminnm_f32(a, b); }
+#endif
+
+template<> EIGEN_STRONG_INLINE Packet4f pmin<PropagateNaN, Packet4f>(const Packet4f& a, const Packet4f& b) { return pmin<Packet4f>(a, b); }
+
+template<> EIGEN_STRONG_INLINE Packet2f pmin<PropagateNaN, Packet2f>(const Packet2f& a, const Packet2f& b) { return pmin<Packet2f>(a, b); }
+
+template<> EIGEN_STRONG_INLINE Packet4c pmin<Packet4c>(const Packet4c& a, const Packet4c& b)
+{
+ return vget_lane_s32(vreinterpret_s32_s8(vmin_s8(
+ vreinterpret_s8_s32(vdup_n_s32(a)),
+ vreinterpret_s8_s32(vdup_n_s32(b)))), 0);
+}
+template<> EIGEN_STRONG_INLINE Packet8c pmin<Packet8c>(const Packet8c& a, const Packet8c& b) { return vmin_s8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet16c pmin<Packet16c>(const Packet16c& a, const Packet16c& b) { return vminq_s8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4uc pmin<Packet4uc>(const Packet4uc& a, const Packet4uc& b)
+{
+ return vget_lane_u32(vreinterpret_u32_u8(vmin_u8(
+ vreinterpret_u8_u32(vdup_n_u32(a)),
+ vreinterpret_u8_u32(vdup_n_u32(b)))), 0);
+}
+template<> EIGEN_STRONG_INLINE Packet8uc pmin<Packet8uc>(const Packet8uc& a, const Packet8uc& b) { return vmin_u8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet16uc pmin<Packet16uc>(const Packet16uc& a, const Packet16uc& b) { return vminq_u8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4s pmin<Packet4s>(const Packet4s& a, const Packet4s& b) { return vmin_s16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet8s pmin<Packet8s>(const Packet8s& a, const Packet8s& b) { return vminq_s16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4us pmin<Packet4us>(const Packet4us& a, const Packet4us& b) { return vmin_u16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet8us pmin<Packet8us>(const Packet8us& a, const Packet8us& b) { return vminq_u16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2i pmin<Packet2i>(const Packet2i& a, const Packet2i& b) { return vmin_s32(a,b); }
template<> EIGEN_STRONG_INLINE Packet4i pmin<Packet4i>(const Packet4i& a, const Packet4i& b) { return vminq_s32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2ui pmin<Packet2ui>(const Packet2ui& a, const Packet2ui& b) { return vmin_u32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4ui pmin<Packet4ui>(const Packet4ui& a, const Packet4ui& b) { return vminq_u32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2l pmin<Packet2l>(const Packet2l& a, const Packet2l& b) {
+ return vcombine_s64(
+ vdup_n_s64((std::min)(vgetq_lane_s64(a, 0), vgetq_lane_s64(b, 0))),
+ vdup_n_s64((std::min)(vgetq_lane_s64(a, 1), vgetq_lane_s64(b, 1))));
+}
+template<> EIGEN_STRONG_INLINE Packet2ul pmin<Packet2ul>(const Packet2ul& a, const Packet2ul& b) {
+ return vcombine_u64(
+ vdup_n_u64((std::min)(vgetq_lane_u64(a, 0), vgetq_lane_u64(b, 0))),
+ vdup_n_u64((std::min)(vgetq_lane_u64(a, 1), vgetq_lane_u64(b, 1))));
+}
+template<> EIGEN_STRONG_INLINE Packet2f pmax<Packet2f>(const Packet2f& a, const Packet2f& b) { return vmax_f32(a,b); }
template<> EIGEN_STRONG_INLINE Packet4f pmax<Packet4f>(const Packet4f& a, const Packet4f& b) { return vmaxq_f32(a,b); }
+
+#ifdef __ARM_FEATURE_NUMERIC_MAXMIN
+// numeric max and min are only available if ARM_FEATURE_NUMERIC_MAXMIN is defined (which can only be the case for Armv8 systems).
+template<> EIGEN_STRONG_INLINE Packet4f pmax<PropagateNumbers, Packet4f>(const Packet4f& a, const Packet4f& b) { return vmaxnmq_f32(a, b); }
+template<> EIGEN_STRONG_INLINE Packet2f pmax<PropagateNumbers, Packet2f>(const Packet2f& a, const Packet2f& b) { return vmaxnm_f32(a, b); }
+#endif
+
+template<> EIGEN_STRONG_INLINE Packet4f pmax<PropagateNaN, Packet4f>(const Packet4f& a, const Packet4f& b) { return pmax<Packet4f>(a, b); }
+
+template<> EIGEN_STRONG_INLINE Packet2f pmax<PropagateNaN, Packet2f>(const Packet2f& a, const Packet2f& b) { return pmax<Packet2f>(a, b); }
+
+template<> EIGEN_STRONG_INLINE Packet4c pmax<Packet4c>(const Packet4c& a, const Packet4c& b)
+{
+ return vget_lane_s32(vreinterpret_s32_s8(vmax_s8(
+ vreinterpret_s8_s32(vdup_n_s32(a)),
+ vreinterpret_s8_s32(vdup_n_s32(b)))), 0);
+}
+template<> EIGEN_STRONG_INLINE Packet8c pmax<Packet8c>(const Packet8c& a, const Packet8c& b) { return vmax_s8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet16c pmax<Packet16c>(const Packet16c& a, const Packet16c& b) { return vmaxq_s8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4uc pmax<Packet4uc>(const Packet4uc& a, const Packet4uc& b)
+{
+ return vget_lane_u32(vreinterpret_u32_u8(vmax_u8(
+ vreinterpret_u8_u32(vdup_n_u32(a)),
+ vreinterpret_u8_u32(vdup_n_u32(b)))), 0);
+}
+template<> EIGEN_STRONG_INLINE Packet8uc pmax<Packet8uc>(const Packet8uc& a, const Packet8uc& b) { return vmax_u8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet16uc pmax<Packet16uc>(const Packet16uc& a, const Packet16uc& b) { return vmaxq_u8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4s pmax<Packet4s>(const Packet4s& a, const Packet4s& b) { return vmax_s16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet8s pmax<Packet8s>(const Packet8s& a, const Packet8s& b) { return vmaxq_s16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4us pmax<Packet4us>(const Packet4us& a, const Packet4us& b) { return vmax_u16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet8us pmax<Packet8us>(const Packet8us& a, const Packet8us& b) { return vmaxq_u16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2i pmax<Packet2i>(const Packet2i& a, const Packet2i& b) { return vmax_s32(a,b); }
template<> EIGEN_STRONG_INLINE Packet4i pmax<Packet4i>(const Packet4i& a, const Packet4i& b) { return vmaxq_s32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2ui pmax<Packet2ui>(const Packet2ui& a, const Packet2ui& b) { return vmax_u32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4ui pmax<Packet4ui>(const Packet4ui& a, const Packet4ui& b) { return vmaxq_u32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2l pmax<Packet2l>(const Packet2l& a, const Packet2l& b) {
+ return vcombine_s64(
+ vdup_n_s64((std::max)(vgetq_lane_s64(a, 0), vgetq_lane_s64(b, 0))),
+ vdup_n_s64((std::max)(vgetq_lane_s64(a, 1), vgetq_lane_s64(b, 1))));
+}
+template<> EIGEN_STRONG_INLINE Packet2ul pmax<Packet2ul>(const Packet2ul& a, const Packet2ul& b) {
+ return vcombine_u64(
+ vdup_n_u64((std::max)(vgetq_lane_u64(a, 0), vgetq_lane_u64(b, 0))),
+ vdup_n_u64((std::max)(vgetq_lane_u64(a, 1), vgetq_lane_u64(b, 1))));
+}
-// Logical Operations are not supported for float, so we have to reinterpret casts using NEON intrinsics
-template<> EIGEN_STRONG_INLINE Packet4f pand<Packet4f>(const Packet4f& a, const Packet4f& b)
+template<> EIGEN_STRONG_INLINE Packet2f pcmp_le<Packet2f>(const Packet2f& a, const Packet2f& b)
+{ return vreinterpret_f32_u32(vcle_f32(a,b)); }
+template<> EIGEN_STRONG_INLINE Packet4f pcmp_le<Packet4f>(const Packet4f& a, const Packet4f& b)
+{ return vreinterpretq_f32_u32(vcleq_f32(a,b)); }
+template<> EIGEN_STRONG_INLINE Packet4c pcmp_le<Packet4c>(const Packet4c& a, const Packet4c& b)
{
- return vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(a),vreinterpretq_u32_f32(b)));
+ return vget_lane_s32(vreinterpret_s32_u8(vcle_s8(
+ vreinterpret_s8_s32(vdup_n_s32(a)),
+ vreinterpret_s8_s32(vdup_n_s32(b)))), 0);
+}
+template<> EIGEN_STRONG_INLINE Packet8c pcmp_le<Packet8c>(const Packet8c& a, const Packet8c& b)
+{ return vreinterpret_s8_u8(vcle_s8(a,b)); }
+template<> EIGEN_STRONG_INLINE Packet16c pcmp_le<Packet16c>(const Packet16c& a, const Packet16c& b)
+{ return vreinterpretq_s8_u8(vcleq_s8(a,b)); }
+template<> EIGEN_STRONG_INLINE Packet4uc pcmp_le<Packet4uc>(const Packet4uc& a, const Packet4uc& b)
+{
+ return vget_lane_u32(vreinterpret_u32_u8(vcle_u8(
+ vreinterpret_u8_u32(vdup_n_u32(a)),
+ vreinterpret_u8_u32(vdup_n_u32(b)))), 0);
+}
+template<> EIGEN_STRONG_INLINE Packet8uc pcmp_le<Packet8uc>(const Packet8uc& a, const Packet8uc& b)
+{ return vcle_u8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet16uc pcmp_le<Packet16uc>(const Packet16uc& a, const Packet16uc& b)
+{ return vcleq_u8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4s pcmp_le<Packet4s>(const Packet4s& a, const Packet4s& b)
+{ return vreinterpret_s16_u16(vcle_s16(a,b)); }
+template<> EIGEN_STRONG_INLINE Packet8s pcmp_le<Packet8s>(const Packet8s& a, const Packet8s& b)
+{ return vreinterpretq_s16_u16(vcleq_s16(a,b)); }
+template<> EIGEN_STRONG_INLINE Packet4us pcmp_le<Packet4us>(const Packet4us& a, const Packet4us& b)
+{ return vcle_u16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet8us pcmp_le<Packet8us>(const Packet8us& a, const Packet8us& b)
+{ return vcleq_u16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2i pcmp_le<Packet2i>(const Packet2i& a, const Packet2i& b)
+{ return vreinterpret_s32_u32(vcle_s32(a,b)); }
+template<> EIGEN_STRONG_INLINE Packet4i pcmp_le<Packet4i>(const Packet4i& a, const Packet4i& b)
+{ return vreinterpretq_s32_u32(vcleq_s32(a,b)); }
+template<> EIGEN_STRONG_INLINE Packet2ui pcmp_le<Packet2ui>(const Packet2ui& a, const Packet2ui& b)
+{ return vcle_u32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4ui pcmp_le<Packet4ui>(const Packet4ui& a, const Packet4ui& b)
+{ return vcleq_u32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2l pcmp_le<Packet2l>(const Packet2l& a, const Packet2l& b)
+{
+#if EIGEN_ARCH_ARM64
+ return vreinterpretq_s64_u64(vcleq_s64(a,b));
+#else
+ return vcombine_s64(
+ vdup_n_s64(vgetq_lane_s64(a, 0) <= vgetq_lane_s64(b, 0) ? numext::int64_t(-1) : 0),
+ vdup_n_s64(vgetq_lane_s64(a, 1) <= vgetq_lane_s64(b, 1) ? numext::int64_t(-1) : 0));
+#endif
+}
+template<> EIGEN_STRONG_INLINE Packet2ul pcmp_le<Packet2ul>(const Packet2ul& a, const Packet2ul& b)
+{
+#if EIGEN_ARCH_ARM64
+ return vcleq_u64(a,b);
+#else
+ return vcombine_u64(
+ vdup_n_u64(vgetq_lane_u64(a, 0) <= vgetq_lane_u64(b, 0) ? numext::uint64_t(-1) : 0),
+ vdup_n_u64(vgetq_lane_u64(a, 1) <= vgetq_lane_u64(b, 1) ? numext::uint64_t(-1) : 0));
+#endif
}
-template<> EIGEN_STRONG_INLINE Packet4i pand<Packet4i>(const Packet4i& a, const Packet4i& b) { return vandq_s32(a,b); }
-template<> EIGEN_STRONG_INLINE Packet4f por<Packet4f>(const Packet4f& a, const Packet4f& b)
+template<> EIGEN_STRONG_INLINE Packet2f pcmp_lt<Packet2f>(const Packet2f& a, const Packet2f& b)
+{ return vreinterpret_f32_u32(vclt_f32(a,b)); }
+template<> EIGEN_STRONG_INLINE Packet4f pcmp_lt<Packet4f>(const Packet4f& a, const Packet4f& b)
+{ return vreinterpretq_f32_u32(vcltq_f32(a,b)); }
+template<> EIGEN_STRONG_INLINE Packet4c pcmp_lt<Packet4c>(const Packet4c& a, const Packet4c& b)
{
- return vreinterpretq_f32_u32(vorrq_u32(vreinterpretq_u32_f32(a),vreinterpretq_u32_f32(b)));
+ return vget_lane_s32(vreinterpret_s32_u8(vclt_s8(
+ vreinterpret_s8_s32(vdup_n_s32(a)),
+ vreinterpret_s8_s32(vdup_n_s32(b)))), 0);
+}
+template<> EIGEN_STRONG_INLINE Packet8c pcmp_lt<Packet8c>(const Packet8c& a, const Packet8c& b)
+{ return vreinterpret_s8_u8(vclt_s8(a,b)); }
+template<> EIGEN_STRONG_INLINE Packet16c pcmp_lt<Packet16c>(const Packet16c& a, const Packet16c& b)
+{ return vreinterpretq_s8_u8(vcltq_s8(a,b)); }
+template<> EIGEN_STRONG_INLINE Packet4uc pcmp_lt<Packet4uc>(const Packet4uc& a, const Packet4uc& b)
+{
+ return vget_lane_u32(vreinterpret_u32_u8(vclt_u8(
+ vreinterpret_u8_u32(vdup_n_u32(a)),
+ vreinterpret_u8_u32(vdup_n_u32(b)))), 0);
+}
+template<> EIGEN_STRONG_INLINE Packet8uc pcmp_lt<Packet8uc>(const Packet8uc& a, const Packet8uc& b)
+{ return vclt_u8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet16uc pcmp_lt<Packet16uc>(const Packet16uc& a, const Packet16uc& b)
+{ return vcltq_u8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4s pcmp_lt<Packet4s>(const Packet4s& a, const Packet4s& b)
+{ return vreinterpret_s16_u16(vclt_s16(a,b)); }
+template<> EIGEN_STRONG_INLINE Packet8s pcmp_lt<Packet8s>(const Packet8s& a, const Packet8s& b)
+{ return vreinterpretq_s16_u16(vcltq_s16(a,b)); }
+template<> EIGEN_STRONG_INLINE Packet4us pcmp_lt<Packet4us>(const Packet4us& a, const Packet4us& b)
+{ return vclt_u16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet8us pcmp_lt<Packet8us>(const Packet8us& a, const Packet8us& b)
+{ return vcltq_u16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2i pcmp_lt<Packet2i>(const Packet2i& a, const Packet2i& b)
+{ return vreinterpret_s32_u32(vclt_s32(a,b)); }
+template<> EIGEN_STRONG_INLINE Packet4i pcmp_lt<Packet4i>(const Packet4i& a, const Packet4i& b)
+{ return vreinterpretq_s32_u32(vcltq_s32(a,b)); }
+template<> EIGEN_STRONG_INLINE Packet2ui pcmp_lt<Packet2ui>(const Packet2ui& a, const Packet2ui& b)
+{ return vclt_u32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4ui pcmp_lt<Packet4ui>(const Packet4ui& a, const Packet4ui& b)
+{ return vcltq_u32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2l pcmp_lt<Packet2l>(const Packet2l& a, const Packet2l& b)
+{
+#if EIGEN_ARCH_ARM64
+ return vreinterpretq_s64_u64(vcltq_s64(a,b));
+#else
+ return vcombine_s64(
+ vdup_n_s64(vgetq_lane_s64(a, 0) < vgetq_lane_s64(b, 0) ? numext::int64_t(-1) : 0),
+ vdup_n_s64(vgetq_lane_s64(a, 1) < vgetq_lane_s64(b, 1) ? numext::int64_t(-1) : 0));
+#endif
+}
+template<> EIGEN_STRONG_INLINE Packet2ul pcmp_lt<Packet2ul>(const Packet2ul& a, const Packet2ul& b)
+{
+#if EIGEN_ARCH_ARM64
+ return vcltq_u64(a,b);
+#else
+ return vcombine_u64(
+ vdup_n_u64(vgetq_lane_u64(a, 0) < vgetq_lane_u64(b, 0) ? numext::uint64_t(-1) : 0),
+ vdup_n_u64(vgetq_lane_u64(a, 1) < vgetq_lane_u64(b, 1) ? numext::uint64_t(-1) : 0));
+#endif
}
-template<> EIGEN_STRONG_INLINE Packet4i por<Packet4i>(const Packet4i& a, const Packet4i& b) { return vorrq_s32(a,b); }
-template<> EIGEN_STRONG_INLINE Packet4f pxor<Packet4f>(const Packet4f& a, const Packet4f& b)
+template<> EIGEN_STRONG_INLINE Packet2f pcmp_eq<Packet2f>(const Packet2f& a, const Packet2f& b)
+{ return vreinterpret_f32_u32(vceq_f32(a,b)); }
+template<> EIGEN_STRONG_INLINE Packet4f pcmp_eq<Packet4f>(const Packet4f& a, const Packet4f& b)
+{ return vreinterpretq_f32_u32(vceqq_f32(a,b)); }
+template<> EIGEN_STRONG_INLINE Packet4c pcmp_eq<Packet4c>(const Packet4c& a, const Packet4c& b)
+{
+ return vget_lane_s32(vreinterpret_s32_u8(vceq_s8(
+ vreinterpret_s8_s32(vdup_n_s32(a)),
+ vreinterpret_s8_s32(vdup_n_s32(b)))), 0);
+}
+template<> EIGEN_STRONG_INLINE Packet8c pcmp_eq<Packet8c>(const Packet8c& a, const Packet8c& b)
+{ return vreinterpret_s8_u8(vceq_s8(a,b)); }
+template<> EIGEN_STRONG_INLINE Packet16c pcmp_eq<Packet16c>(const Packet16c& a, const Packet16c& b)
+{ return vreinterpretq_s8_u8(vceqq_s8(a,b)); }
+template<> EIGEN_STRONG_INLINE Packet4uc pcmp_eq<Packet4uc>(const Packet4uc& a, const Packet4uc& b)
{
- return vreinterpretq_f32_u32(veorq_u32(vreinterpretq_u32_f32(a),vreinterpretq_u32_f32(b)));
+ return vget_lane_u32(vreinterpret_u32_u8(vceq_u8(
+ vreinterpret_u8_u32(vdup_n_u32(a)),
+ vreinterpret_u8_u32(vdup_n_u32(b)))), 0);
}
+template<> EIGEN_STRONG_INLINE Packet8uc pcmp_eq<Packet8uc>(const Packet8uc& a, const Packet8uc& b)
+{ return vceq_u8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet16uc pcmp_eq<Packet16uc>(const Packet16uc& a, const Packet16uc& b)
+{ return vceqq_u8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4s pcmp_eq<Packet4s>(const Packet4s& a, const Packet4s& b)
+{ return vreinterpret_s16_u16(vceq_s16(a,b)); }
+template<> EIGEN_STRONG_INLINE Packet8s pcmp_eq<Packet8s>(const Packet8s& a, const Packet8s& b)
+{ return vreinterpretq_s16_u16(vceqq_s16(a,b)); }
+template<> EIGEN_STRONG_INLINE Packet4us pcmp_eq<Packet4us>(const Packet4us& a, const Packet4us& b)
+{ return vceq_u16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet8us pcmp_eq<Packet8us>(const Packet8us& a, const Packet8us& b)
+{ return vceqq_u16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2i pcmp_eq<Packet2i>(const Packet2i& a, const Packet2i& b)
+{ return vreinterpret_s32_u32(vceq_s32(a,b)); }
+template<> EIGEN_STRONG_INLINE Packet4i pcmp_eq<Packet4i>(const Packet4i& a, const Packet4i& b)
+{ return vreinterpretq_s32_u32(vceqq_s32(a,b)); }
+template<> EIGEN_STRONG_INLINE Packet2ui pcmp_eq<Packet2ui>(const Packet2ui& a, const Packet2ui& b)
+{ return vceq_u32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4ui pcmp_eq<Packet4ui>(const Packet4ui& a, const Packet4ui& b)
+{ return vceqq_u32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2l pcmp_eq<Packet2l>(const Packet2l& a, const Packet2l& b)
+{
+#if EIGEN_ARCH_ARM64
+ return vreinterpretq_s64_u64(vceqq_s64(a,b));
+#else
+ return vcombine_s64(
+ vdup_n_s64(vgetq_lane_s64(a, 0) == vgetq_lane_s64(b, 0) ? numext::int64_t(-1) : 0),
+ vdup_n_s64(vgetq_lane_s64(a, 1) == vgetq_lane_s64(b, 1) ? numext::int64_t(-1) : 0));
+#endif
+}
+template<> EIGEN_STRONG_INLINE Packet2ul pcmp_eq<Packet2ul>(const Packet2ul& a, const Packet2ul& b)
+{
+#if EIGEN_ARCH_ARM64
+ return vceqq_u64(a,b);
+#else
+ return vcombine_u64(
+ vdup_n_u64(vgetq_lane_u64(a, 0) == vgetq_lane_u64(b, 0) ? numext::uint64_t(-1) : 0),
+ vdup_n_u64(vgetq_lane_u64(a, 1) == vgetq_lane_u64(b, 1) ? numext::uint64_t(-1) : 0));
+#endif
+}
+
+template<> EIGEN_STRONG_INLINE Packet2f pcmp_lt_or_nan<Packet2f>(const Packet2f& a, const Packet2f& b)
+{ return vreinterpret_f32_u32(vmvn_u32(vcge_f32(a,b))); }
+template<> EIGEN_STRONG_INLINE Packet4f pcmp_lt_or_nan<Packet4f>(const Packet4f& a, const Packet4f& b)
+{ return vreinterpretq_f32_u32(vmvnq_u32(vcgeq_f32(a,b))); }
+
+// Logical Operations are not supported for float, so we have to reinterpret casts using NEON intrinsics
+template<> EIGEN_STRONG_INLINE Packet2f pand<Packet2f>(const Packet2f& a, const Packet2f& b)
+{ return vreinterpret_f32_u32(vand_u32(vreinterpret_u32_f32(a),vreinterpret_u32_f32(b))); }
+template<> EIGEN_STRONG_INLINE Packet4f pand<Packet4f>(const Packet4f& a, const Packet4f& b)
+{ return vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(a),vreinterpretq_u32_f32(b))); }
+template<> EIGEN_STRONG_INLINE Packet4c pand<Packet4c>(const Packet4c& a, const Packet4c& b)
+{ return a & b; }
+template<> EIGEN_STRONG_INLINE Packet8c pand<Packet8c>(const Packet8c& a, const Packet8c& b)
+{ return vand_s8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet16c pand<Packet16c>(const Packet16c& a, const Packet16c& b)
+{ return vandq_s8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4uc pand<Packet4uc>(const Packet4uc& a, const Packet4uc& b)
+{ return a & b; }
+template<> EIGEN_STRONG_INLINE Packet8uc pand<Packet8uc>(const Packet8uc& a, const Packet8uc& b)
+{ return vand_u8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet16uc pand<Packet16uc>(const Packet16uc& a, const Packet16uc& b)
+{ return vandq_u8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4s pand<Packet4s>(const Packet4s& a, const Packet4s& b) { return vand_s16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet8s pand<Packet8s>(const Packet8s& a, const Packet8s& b) { return vandq_s16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4us pand<Packet4us>(const Packet4us& a, const Packet4us& b)
+{ return vand_u16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet8us pand<Packet8us>(const Packet8us& a, const Packet8us& b)
+{ return vandq_u16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2i pand<Packet2i>(const Packet2i& a, const Packet2i& b) { return vand_s32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4i pand<Packet4i>(const Packet4i& a, const Packet4i& b) { return vandq_s32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2ui pand<Packet2ui>(const Packet2ui& a, const Packet2ui& b)
+{ return vand_u32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4ui pand<Packet4ui>(const Packet4ui& a, const Packet4ui& b)
+{ return vandq_u32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2l pand<Packet2l>(const Packet2l& a, const Packet2l& b) { return vandq_s64(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2ul pand<Packet2ul>(const Packet2ul& a, const Packet2ul& b)
+{ return vandq_u64(a,b); }
+
+template<> EIGEN_STRONG_INLINE Packet2f por<Packet2f>(const Packet2f& a, const Packet2f& b)
+{ return vreinterpret_f32_u32(vorr_u32(vreinterpret_u32_f32(a),vreinterpret_u32_f32(b))); }
+template<> EIGEN_STRONG_INLINE Packet4f por<Packet4f>(const Packet4f& a, const Packet4f& b)
+{ return vreinterpretq_f32_u32(vorrq_u32(vreinterpretq_u32_f32(a),vreinterpretq_u32_f32(b))); }
+template<> EIGEN_STRONG_INLINE Packet4c por<Packet4c>(const Packet4c& a, const Packet4c& b)
+{ return a | b; }
+template<> EIGEN_STRONG_INLINE Packet8c por<Packet8c>(const Packet8c& a, const Packet8c& b) { return vorr_s8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet16c por<Packet16c>(const Packet16c& a, const Packet16c& b)
+{ return vorrq_s8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4uc por<Packet4uc>(const Packet4uc& a, const Packet4uc& b)
+{ return a | b; }
+template<> EIGEN_STRONG_INLINE Packet8uc por<Packet8uc>(const Packet8uc& a, const Packet8uc& b)
+{ return vorr_u8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet16uc por<Packet16uc>(const Packet16uc& a, const Packet16uc& b)
+{ return vorrq_u8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4s por<Packet4s>(const Packet4s& a, const Packet4s& b)
+{ return vorr_s16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet8s por<Packet8s>(const Packet8s& a, const Packet8s& b)
+{ return vorrq_s16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4us por<Packet4us>(const Packet4us& a, const Packet4us& b)
+{ return vorr_u16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet8us por<Packet8us>(const Packet8us& a, const Packet8us& b)
+{ return vorrq_u16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2i por<Packet2i>(const Packet2i& a, const Packet2i& b) { return vorr_s32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4i por<Packet4i>(const Packet4i& a, const Packet4i& b) { return vorrq_s32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2ui por<Packet2ui>(const Packet2ui& a, const Packet2ui& b)
+{ return vorr_u32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4ui por<Packet4ui>(const Packet4ui& a, const Packet4ui& b)
+{ return vorrq_u32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2l por<Packet2l>(const Packet2l& a, const Packet2l& b)
+{ return vorrq_s64(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2ul por<Packet2ul>(const Packet2ul& a, const Packet2ul& b)
+{ return vorrq_u64(a,b); }
+
+template<> EIGEN_STRONG_INLINE Packet2f pxor<Packet2f>(const Packet2f& a, const Packet2f& b)
+{ return vreinterpret_f32_u32(veor_u32(vreinterpret_u32_f32(a),vreinterpret_u32_f32(b))); }
+template<> EIGEN_STRONG_INLINE Packet4f pxor<Packet4f>(const Packet4f& a, const Packet4f& b)
+{ return vreinterpretq_f32_u32(veorq_u32(vreinterpretq_u32_f32(a),vreinterpretq_u32_f32(b))); }
+template<> EIGEN_STRONG_INLINE Packet4c pxor<Packet4c>(const Packet4c& a, const Packet4c& b)
+{ return a ^ b; }
+template<> EIGEN_STRONG_INLINE Packet8c pxor<Packet8c>(const Packet8c& a, const Packet8c& b)
+{ return veor_s8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet16c pxor<Packet16c>(const Packet16c& a, const Packet16c& b)
+{ return veorq_s8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4uc pxor<Packet4uc>(const Packet4uc& a, const Packet4uc& b)
+{ return a ^ b; }
+template<> EIGEN_STRONG_INLINE Packet8uc pxor<Packet8uc>(const Packet8uc& a, const Packet8uc& b)
+{ return veor_u8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet16uc pxor<Packet16uc>(const Packet16uc& a, const Packet16uc& b)
+{ return veorq_u8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4s pxor<Packet4s>(const Packet4s& a, const Packet4s& b) { return veor_s16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet8s pxor<Packet8s>(const Packet8s& a, const Packet8s& b) { return veorq_s16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4us pxor<Packet4us>(const Packet4us& a, const Packet4us& b)
+{ return veor_u16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet8us pxor<Packet8us>(const Packet8us& a, const Packet8us& b)
+{ return veorq_u16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2i pxor<Packet2i>(const Packet2i& a, const Packet2i& b) { return veor_s32(a,b); }
template<> EIGEN_STRONG_INLINE Packet4i pxor<Packet4i>(const Packet4i& a, const Packet4i& b) { return veorq_s32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2ui pxor<Packet2ui>(const Packet2ui& a, const Packet2ui& b)
+{ return veor_u32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4ui pxor<Packet4ui>(const Packet4ui& a, const Packet4ui& b)
+{ return veorq_u32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2l pxor<Packet2l>(const Packet2l& a, const Packet2l& b)
+{ return veorq_s64(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2ul pxor<Packet2ul>(const Packet2ul& a, const Packet2ul& b)
+{ return veorq_u64(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2f pandnot<Packet2f>(const Packet2f& a, const Packet2f& b)
+{ return vreinterpret_f32_u32(vbic_u32(vreinterpret_u32_f32(a),vreinterpret_u32_f32(b))); }
template<> EIGEN_STRONG_INLINE Packet4f pandnot<Packet4f>(const Packet4f& a, const Packet4f& b)
+{ return vreinterpretq_f32_u32(vbicq_u32(vreinterpretq_u32_f32(a),vreinterpretq_u32_f32(b))); }
+template<> EIGEN_STRONG_INLINE Packet4c pandnot<Packet4c>(const Packet4c& a, const Packet4c& b)
+{ return a & ~b; }
+template<> EIGEN_STRONG_INLINE Packet8c pandnot<Packet8c>(const Packet8c& a, const Packet8c& b) { return vbic_s8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet16c pandnot<Packet16c>(const Packet16c& a, const Packet16c& b) { return vbicq_s8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4uc pandnot<Packet4uc>(const Packet4uc& a, const Packet4uc& b)
+{ return a & ~b; }
+template<> EIGEN_STRONG_INLINE Packet8uc pandnot<Packet8uc>(const Packet8uc& a, const Packet8uc& b)
+{ return vbic_u8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet16uc pandnot<Packet16uc>(const Packet16uc& a, const Packet16uc& b)
+{ return vbicq_u8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4s pandnot<Packet4s>(const Packet4s& a, const Packet4s& b)
+{ return vbic_s16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet8s pandnot<Packet8s>(const Packet8s& a, const Packet8s& b)
+{ return vbicq_s16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4us pandnot<Packet4us>(const Packet4us& a, const Packet4us& b)
+{ return vbic_u16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet8us pandnot<Packet8us>(const Packet8us& a, const Packet8us& b)
+{ return vbicq_u16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2i pandnot<Packet2i>(const Packet2i& a, const Packet2i& b)
+{ return vbic_s32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4i pandnot<Packet4i>(const Packet4i& a, const Packet4i& b)
+{ return vbicq_s32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2ui pandnot<Packet2ui>(const Packet2ui& a, const Packet2ui& b)
+{ return vbic_u32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4ui pandnot<Packet4ui>(const Packet4ui& a, const Packet4ui& b)
+{ return vbicq_u32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2l pandnot<Packet2l>(const Packet2l& a, const Packet2l& b)
+{ return vbicq_s64(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2ul pandnot<Packet2ul>(const Packet2ul& a, const Packet2ul& b)
+{ return vbicq_u64(a,b); }
+
+
+template<int N> EIGEN_STRONG_INLINE Packet4c parithmetic_shift_right(Packet4c& a)
+{ return vget_lane_s32(vreinterpret_s32_s8(vshr_n_s8(vreinterpret_s8_s32(vdup_n_s32(a)), N)), 0); }
+template<int N> EIGEN_STRONG_INLINE Packet8c parithmetic_shift_right(Packet8c a) { return vshr_n_s8(a,N); }
+template<int N> EIGEN_STRONG_INLINE Packet16c parithmetic_shift_right(Packet16c a) { return vshrq_n_s8(a,N); }
+template<int N> EIGEN_STRONG_INLINE Packet4uc parithmetic_shift_right(Packet4uc& a)
+{ return vget_lane_u32(vreinterpret_u32_u8(vshr_n_u8(vreinterpret_u8_u32(vdup_n_u32(a)), N)), 0); }
+template<int N> EIGEN_STRONG_INLINE Packet8uc parithmetic_shift_right(Packet8uc a) { return vshr_n_u8(a,N); }
+template<int N> EIGEN_STRONG_INLINE Packet16uc parithmetic_shift_right(Packet16uc a) { return vshrq_n_u8(a,N); }
+template<int N> EIGEN_STRONG_INLINE Packet4s parithmetic_shift_right(Packet4s a) { return vshr_n_s16(a,N); }
+template<int N> EIGEN_STRONG_INLINE Packet8s parithmetic_shift_right(Packet8s a) { return vshrq_n_s16(a,N); }
+template<int N> EIGEN_STRONG_INLINE Packet4us parithmetic_shift_right(Packet4us a) { return vshr_n_u16(a,N); }
+template<int N> EIGEN_STRONG_INLINE Packet8us parithmetic_shift_right(Packet8us a) { return vshrq_n_u16(a,N); }
+template<int N> EIGEN_STRONG_INLINE Packet2i parithmetic_shift_right(Packet2i a) { return vshr_n_s32(a,N); }
+template<int N> EIGEN_STRONG_INLINE Packet4i parithmetic_shift_right(Packet4i a) { return vshrq_n_s32(a,N); }
+template<int N> EIGEN_STRONG_INLINE Packet2ui parithmetic_shift_right(Packet2ui a) { return vshr_n_u32(a,N); }
+template<int N> EIGEN_STRONG_INLINE Packet4ui parithmetic_shift_right(Packet4ui a) { return vshrq_n_u32(a,N); }
+template<int N> EIGEN_STRONG_INLINE Packet2l parithmetic_shift_right(Packet2l a) { return vshrq_n_s64(a,N); }
+template<int N> EIGEN_STRONG_INLINE Packet2ul parithmetic_shift_right(Packet2ul a) { return vshrq_n_u64(a,N); }
+
+template<int N> EIGEN_STRONG_INLINE Packet4c plogical_shift_right(Packet4c& a)
+{ return vget_lane_s32(vreinterpret_s32_u8(vshr_n_u8(vreinterpret_u8_s32(vdup_n_s32(a)), N)), 0); }
+template<int N> EIGEN_STRONG_INLINE Packet8c plogical_shift_right(Packet8c a)
+{ return vreinterpret_s8_u8(vshr_n_u8(vreinterpret_u8_s8(a),N)); }
+template<int N> EIGEN_STRONG_INLINE Packet16c plogical_shift_right(Packet16c a)
+{ return vreinterpretq_s8_u8(vshrq_n_u8(vreinterpretq_u8_s8(a),N)); }
+template<int N> EIGEN_STRONG_INLINE Packet4uc plogical_shift_right(Packet4uc& a)
+{ return vget_lane_u32(vreinterpret_u32_s8(vshr_n_s8(vreinterpret_s8_u32(vdup_n_u32(a)), N)), 0); }
+template<int N> EIGEN_STRONG_INLINE Packet8uc plogical_shift_right(Packet8uc a) { return vshr_n_u8(a,N); }
+template<int N> EIGEN_STRONG_INLINE Packet16uc plogical_shift_right(Packet16uc a) { return vshrq_n_u8(a,N); }
+template<int N> EIGEN_STRONG_INLINE Packet4s plogical_shift_right(Packet4s a)
+{ return vreinterpret_s16_u16(vshr_n_u16(vreinterpret_u16_s16(a),N)); }
+template<int N> EIGEN_STRONG_INLINE Packet8s plogical_shift_right(Packet8s a)
+{ return vreinterpretq_s16_u16(vshrq_n_u16(vreinterpretq_u16_s16(a),N)); }
+template<int N> EIGEN_STRONG_INLINE Packet4us plogical_shift_right(Packet4us a) { return vshr_n_u16(a,N); }
+template<int N> EIGEN_STRONG_INLINE Packet8us plogical_shift_right(Packet8us a) { return vshrq_n_u16(a,N); }
+template<int N> EIGEN_STRONG_INLINE Packet2i plogical_shift_right(Packet2i a)
+{ return vreinterpret_s32_u32(vshr_n_u32(vreinterpret_u32_s32(a),N)); }
+template<int N> EIGEN_STRONG_INLINE Packet4i plogical_shift_right(Packet4i a)
+{ return vreinterpretq_s32_u32(vshrq_n_u32(vreinterpretq_u32_s32(a),N)); }
+template<int N> EIGEN_STRONG_INLINE Packet2ui plogical_shift_right(Packet2ui a) { return vshr_n_u32(a,N); }
+template<int N> EIGEN_STRONG_INLINE Packet4ui plogical_shift_right(Packet4ui a) { return vshrq_n_u32(a,N); }
+template<int N> EIGEN_STRONG_INLINE Packet2l plogical_shift_right(Packet2l a)
+{ return vreinterpretq_s64_u64(vshrq_n_u64(vreinterpretq_u64_s64(a),N)); }
+template<int N> EIGEN_STRONG_INLINE Packet2ul plogical_shift_right(Packet2ul a) { return vshrq_n_u64(a,N); }
+
+template<int N> EIGEN_STRONG_INLINE Packet4c plogical_shift_left(Packet4c& a)
+{ return vget_lane_s32(vreinterpret_s32_s8(vshl_n_s8(vreinterpret_s8_s32(vdup_n_s32(a)), N)), 0); }
+template<int N> EIGEN_STRONG_INLINE Packet8c plogical_shift_left(Packet8c a) { return vshl_n_s8(a,N); }
+template<int N> EIGEN_STRONG_INLINE Packet16c plogical_shift_left(Packet16c a) { return vshlq_n_s8(a,N); }
+template<int N> EIGEN_STRONG_INLINE Packet4uc plogical_shift_left(Packet4uc& a)
+{ return vget_lane_u32(vreinterpret_u32_u8(vshl_n_u8(vreinterpret_u8_u32(vdup_n_u32(a)), N)), 0); }
+template<int N> EIGEN_STRONG_INLINE Packet8uc plogical_shift_left(Packet8uc a) { return vshl_n_u8(a,N); }
+template<int N> EIGEN_STRONG_INLINE Packet16uc plogical_shift_left(Packet16uc a) { return vshlq_n_u8(a,N); }
+template<int N> EIGEN_STRONG_INLINE Packet4s plogical_shift_left(Packet4s a) { return vshl_n_s16(a,N); }
+template<int N> EIGEN_STRONG_INLINE Packet8s plogical_shift_left(Packet8s a) { return vshlq_n_s16(a,N); }
+template<int N> EIGEN_STRONG_INLINE Packet4us plogical_shift_left(Packet4us a) { return vshl_n_u16(a,N); }
+template<int N> EIGEN_STRONG_INLINE Packet8us plogical_shift_left(Packet8us a) { return vshlq_n_u16(a,N); }
+template<int N> EIGEN_STRONG_INLINE Packet2i plogical_shift_left(Packet2i a) { return vshl_n_s32(a,N); }
+template<int N> EIGEN_STRONG_INLINE Packet4i plogical_shift_left(Packet4i a) { return vshlq_n_s32(a,N); }
+template<int N> EIGEN_STRONG_INLINE Packet2ui plogical_shift_left(Packet2ui a) { return vshl_n_u32(a,N); }
+template<int N> EIGEN_STRONG_INLINE Packet4ui plogical_shift_left(Packet4ui a) { return vshlq_n_u32(a,N); }
+template<int N> EIGEN_STRONG_INLINE Packet2l plogical_shift_left(Packet2l a) { return vshlq_n_s64(a,N); }
+template<int N> EIGEN_STRONG_INLINE Packet2ul plogical_shift_left(Packet2ul a) { return vshlq_n_u64(a,N); }
+
+template<> EIGEN_STRONG_INLINE Packet2f pload<Packet2f>(const float* from)
+{ EIGEN_DEBUG_ALIGNED_LOAD return vld1_f32(from); }
+template<> EIGEN_STRONG_INLINE Packet4f pload<Packet4f>(const float* from)
+{ EIGEN_DEBUG_ALIGNED_LOAD return vld1q_f32(from); }
+template<> EIGEN_STRONG_INLINE Packet4c pload<Packet4c>(const int8_t* from)
+{
+ Packet4c res;
+ memcpy(&res, from, sizeof(Packet4c));
+ return res;
+}
+template<> EIGEN_STRONG_INLINE Packet8c pload<Packet8c>(const int8_t* from)
+{ EIGEN_DEBUG_ALIGNED_LOAD return vld1_s8(from); }
+template<> EIGEN_STRONG_INLINE Packet16c pload<Packet16c>(const int8_t* from)
+{ EIGEN_DEBUG_ALIGNED_LOAD return vld1q_s8(from); }
+template<> EIGEN_STRONG_INLINE Packet4uc pload<Packet4uc>(const uint8_t* from)
{
- return vreinterpretq_f32_u32(vbicq_u32(vreinterpretq_u32_f32(a),vreinterpretq_u32_f32(b)));
+ Packet4uc res;
+ memcpy(&res, from, sizeof(Packet4uc));
+ return res;
}
-template<> EIGEN_STRONG_INLINE Packet4i pandnot<Packet4i>(const Packet4i& a, const Packet4i& b) { return vbicq_s32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet8uc pload<Packet8uc>(const uint8_t* from)
+{ EIGEN_DEBUG_ALIGNED_LOAD return vld1_u8(from); }
+template<> EIGEN_STRONG_INLINE Packet16uc pload<Packet16uc>(const uint8_t* from)
+{ EIGEN_DEBUG_ALIGNED_LOAD return vld1q_u8(from); }
+template<> EIGEN_STRONG_INLINE Packet4s pload<Packet4s>(const int16_t* from)
+{ EIGEN_DEBUG_ALIGNED_LOAD return vld1_s16(from); }
+template<> EIGEN_STRONG_INLINE Packet8s pload<Packet8s>(const int16_t* from)
+{ EIGEN_DEBUG_ALIGNED_LOAD return vld1q_s16(from); }
+template<> EIGEN_STRONG_INLINE Packet4us pload<Packet4us>(const uint16_t* from)
+{ EIGEN_DEBUG_ALIGNED_LOAD return vld1_u16(from); }
+template<> EIGEN_STRONG_INLINE Packet8us pload<Packet8us>(const uint16_t* from)
+{ EIGEN_DEBUG_ALIGNED_LOAD return vld1q_u16(from); }
+template<> EIGEN_STRONG_INLINE Packet2i pload<Packet2i>(const int32_t* from)
+{ EIGEN_DEBUG_ALIGNED_LOAD return vld1_s32(from); }
+template<> EIGEN_STRONG_INLINE Packet4i pload<Packet4i>(const int32_t* from)
+{ EIGEN_DEBUG_ALIGNED_LOAD return vld1q_s32(from); }
+template<> EIGEN_STRONG_INLINE Packet2ui pload<Packet2ui>(const uint32_t* from)
+{ EIGEN_DEBUG_ALIGNED_LOAD return vld1_u32(from); }
+template<> EIGEN_STRONG_INLINE Packet4ui pload<Packet4ui>(const uint32_t* from)
+{ EIGEN_DEBUG_ALIGNED_LOAD return vld1q_u32(from); }
+template<> EIGEN_STRONG_INLINE Packet2l pload<Packet2l>(const int64_t* from)
+{ EIGEN_DEBUG_ALIGNED_LOAD return vld1q_s64(from); }
+template<> EIGEN_STRONG_INLINE Packet2ul pload<Packet2ul>(const uint64_t* from)
+{ EIGEN_DEBUG_ALIGNED_LOAD return vld1q_u64(from); }
-template<> EIGEN_STRONG_INLINE Packet4f pload<Packet4f>(const float* from) { EIGEN_DEBUG_ALIGNED_LOAD return vld1q_f32(from); }
-template<> EIGEN_STRONG_INLINE Packet4i pload<Packet4i>(const int32_t* from) { EIGEN_DEBUG_ALIGNED_LOAD return vld1q_s32(from); }
-
-template<> EIGEN_STRONG_INLINE Packet4f ploadu<Packet4f>(const float* from) { EIGEN_DEBUG_UNALIGNED_LOAD return vld1q_f32(from); }
-template<> EIGEN_STRONG_INLINE Packet4i ploadu<Packet4i>(const int32_t* from) { EIGEN_DEBUG_UNALIGNED_LOAD return vld1q_s32(from); }
+template<> EIGEN_STRONG_INLINE Packet2f ploadu<Packet2f>(const float* from)
+{ EIGEN_DEBUG_UNALIGNED_LOAD return vld1_f32(from); }
+template<> EIGEN_STRONG_INLINE Packet4f ploadu<Packet4f>(const float* from)
+{ EIGEN_DEBUG_UNALIGNED_LOAD return vld1q_f32(from); }
+template<> EIGEN_STRONG_INLINE Packet4c ploadu<Packet4c>(const int8_t* from)
+{
+ Packet4c res;
+ memcpy(&res, from, sizeof(Packet4c));
+ return res;
+}
+template<> EIGEN_STRONG_INLINE Packet8c ploadu<Packet8c>(const int8_t* from)
+{ EIGEN_DEBUG_UNALIGNED_LOAD return vld1_s8(from); }
+template<> EIGEN_STRONG_INLINE Packet16c ploadu<Packet16c>(const int8_t* from)
+{ EIGEN_DEBUG_UNALIGNED_LOAD return vld1q_s8(from); }
+template<> EIGEN_STRONG_INLINE Packet4uc ploadu<Packet4uc>(const uint8_t* from)
+{
+ Packet4uc res;
+ memcpy(&res, from, sizeof(Packet4uc));
+ return res;
+}
+template<> EIGEN_STRONG_INLINE Packet8uc ploadu<Packet8uc>(const uint8_t* from)
+{ EIGEN_DEBUG_UNALIGNED_LOAD return vld1_u8(from); }
+template<> EIGEN_STRONG_INLINE Packet16uc ploadu<Packet16uc>(const uint8_t* from)
+{ EIGEN_DEBUG_UNALIGNED_LOAD return vld1q_u8(from); }
+template<> EIGEN_STRONG_INLINE Packet4s ploadu<Packet4s>(const int16_t* from)
+{ EIGEN_DEBUG_UNALIGNED_LOAD return vld1_s16(from); }
+template<> EIGEN_STRONG_INLINE Packet8s ploadu<Packet8s>(const int16_t* from)
+{ EIGEN_DEBUG_UNALIGNED_LOAD return vld1q_s16(from); }
+template<> EIGEN_STRONG_INLINE Packet4us ploadu<Packet4us>(const uint16_t* from)
+{ EIGEN_DEBUG_UNALIGNED_LOAD return vld1_u16(from); }
+template<> EIGEN_STRONG_INLINE Packet8us ploadu<Packet8us>(const uint16_t* from)
+{ EIGEN_DEBUG_UNALIGNED_LOAD return vld1q_u16(from); }
+template<> EIGEN_STRONG_INLINE Packet2i ploadu<Packet2i>(const int32_t* from)
+{ EIGEN_DEBUG_UNALIGNED_LOAD return vld1_s32(from); }
+template<> EIGEN_STRONG_INLINE Packet4i ploadu<Packet4i>(const int32_t* from)
+{ EIGEN_DEBUG_UNALIGNED_LOAD return vld1q_s32(from); }
+template<> EIGEN_STRONG_INLINE Packet2ui ploadu<Packet2ui>(const uint32_t* from)
+{ EIGEN_DEBUG_UNALIGNED_LOAD return vld1_u32(from); }
+template<> EIGEN_STRONG_INLINE Packet4ui ploadu<Packet4ui>(const uint32_t* from)
+{ EIGEN_DEBUG_UNALIGNED_LOAD return vld1q_u32(from); }
+template<> EIGEN_STRONG_INLINE Packet2l ploadu<Packet2l>(const int64_t* from)
+{ EIGEN_DEBUG_UNALIGNED_LOAD return vld1q_s64(from); }
+template<> EIGEN_STRONG_INLINE Packet2ul ploadu<Packet2ul>(const uint64_t* from)
+{ EIGEN_DEBUG_UNALIGNED_LOAD return vld1q_u64(from); }
+template<> EIGEN_STRONG_INLINE Packet2f ploaddup<Packet2f>(const float* from)
+{ return vld1_dup_f32(from); }
template<> EIGEN_STRONG_INLINE Packet4f ploaddup<Packet4f>(const float* from)
+{ return vcombine_f32(vld1_dup_f32(from), vld1_dup_f32(from+1)); }
+template<> EIGEN_STRONG_INLINE Packet4c ploaddup<Packet4c>(const int8_t* from)
+{
+ const int8x8_t a = vreinterpret_s8_s32(vdup_n_s32(pload<Packet4c>(from)));
+ return vget_lane_s32(vreinterpret_s32_s8(vzip_s8(a,a).val[0]), 0);
+}
+template<> EIGEN_STRONG_INLINE Packet8c ploaddup<Packet8c>(const int8_t* from)
+{
+ const int8x8_t a = vld1_s8(from);
+ return vzip_s8(a,a).val[0];
+}
+template<> EIGEN_STRONG_INLINE Packet16c ploaddup<Packet16c>(const int8_t* from)
+{
+ const int8x8_t a = vld1_s8(from);
+ const int8x8x2_t b = vzip_s8(a,a);
+ return vcombine_s8(b.val[0], b.val[1]);
+}
+template<> EIGEN_STRONG_INLINE Packet4uc ploaddup<Packet4uc>(const uint8_t* from)
+{
+ const uint8x8_t a = vreinterpret_u8_u32(vdup_n_u32(pload<Packet4uc>(from)));
+ return vget_lane_u32(vreinterpret_u32_u8(vzip_u8(a,a).val[0]), 0);
+}
+template<> EIGEN_STRONG_INLINE Packet8uc ploaddup<Packet8uc>(const uint8_t* from)
{
- float32x2_t lo, hi;
- lo = vld1_dup_f32(from);
- hi = vld1_dup_f32(from+1);
- return vcombine_f32(lo, hi);
+ const uint8x8_t a = vld1_u8(from);
+ return vzip_u8(a,a).val[0];
}
+template<> EIGEN_STRONG_INLINE Packet16uc ploaddup<Packet16uc>(const uint8_t* from)
+{
+ const uint8x8_t a = vld1_u8(from);
+ const uint8x8x2_t b = vzip_u8(a,a);
+ return vcombine_u8(b.val[0], b.val[1]);
+}
+template<> EIGEN_STRONG_INLINE Packet4s ploaddup<Packet4s>(const int16_t* from)
+{
+ return vreinterpret_s16_u32(vzip_u32(vreinterpret_u32_s16(vld1_dup_s16(from)),
+ vreinterpret_u32_s16(vld1_dup_s16(from+1))).val[0]);
+}
+template<> EIGEN_STRONG_INLINE Packet8s ploaddup<Packet8s>(const int16_t* from)
+{
+ const int16x4_t a = vld1_s16(from);
+ const int16x4x2_t b = vzip_s16(a,a);
+ return vcombine_s16(b.val[0], b.val[1]);
+}
+template<> EIGEN_STRONG_INLINE Packet4us ploaddup<Packet4us>(const uint16_t* from)
+{
+ return vreinterpret_u16_u32(vzip_u32(vreinterpret_u32_u16(vld1_dup_u16(from)),
+ vreinterpret_u32_u16(vld1_dup_u16(from+1))).val[0]);
+}
+template<> EIGEN_STRONG_INLINE Packet8us ploaddup<Packet8us>(const uint16_t* from)
+{
+ const uint16x4_t a = vld1_u16(from);
+ const uint16x4x2_t b = vzip_u16(a,a);
+ return vcombine_u16(b.val[0], b.val[1]);
+}
+template<> EIGEN_STRONG_INLINE Packet2i ploaddup<Packet2i>(const int32_t* from)
+{ return vld1_dup_s32(from); }
template<> EIGEN_STRONG_INLINE Packet4i ploaddup<Packet4i>(const int32_t* from)
+{ return vcombine_s32(vld1_dup_s32(from), vld1_dup_s32(from+1)); }
+template<> EIGEN_STRONG_INLINE Packet2ui ploaddup<Packet2ui>(const uint32_t* from)
+{ return vld1_dup_u32(from); }
+template<> EIGEN_STRONG_INLINE Packet4ui ploaddup<Packet4ui>(const uint32_t* from)
+{ return vcombine_u32(vld1_dup_u32(from), vld1_dup_u32(from+1)); }
+template<> EIGEN_STRONG_INLINE Packet2l ploaddup<Packet2l>(const int64_t* from)
+{ return vld1q_dup_s64(from); }
+template<> EIGEN_STRONG_INLINE Packet2ul ploaddup<Packet2ul>(const uint64_t* from)
+{ return vld1q_dup_u64(from); }
+
+template<> EIGEN_STRONG_INLINE Packet4f ploadquad<Packet4f>(const float* from) { return vld1q_dup_f32(from); }
+template<> EIGEN_STRONG_INLINE Packet4c ploadquad<Packet4c>(const int8_t* from)
+{ return vget_lane_s32(vreinterpret_s32_s8(vld1_dup_s8(from)), 0); }
+template<> EIGEN_STRONG_INLINE Packet8c ploadquad<Packet8c>(const int8_t* from)
{
- int32x2_t lo, hi;
- lo = vld1_dup_s32(from);
- hi = vld1_dup_s32(from+1);
- return vcombine_s32(lo, hi);
+ return vreinterpret_s8_u32(vzip_u32(
+ vreinterpret_u32_s8(vld1_dup_s8(from)),
+ vreinterpret_u32_s8(vld1_dup_s8(from+1))).val[0]);
}
+template<> EIGEN_STRONG_INLINE Packet16c ploadquad<Packet16c>(const int8_t* from)
+{
+ const int8x8_t a = vreinterpret_s8_u32(vzip_u32(
+ vreinterpret_u32_s8(vld1_dup_s8(from)),
+ vreinterpret_u32_s8(vld1_dup_s8(from+1))).val[0]);
+ const int8x8_t b = vreinterpret_s8_u32(vzip_u32(
+ vreinterpret_u32_s8(vld1_dup_s8(from+2)),
+ vreinterpret_u32_s8(vld1_dup_s8(from+3))).val[0]);
+ return vcombine_s8(a,b);
+}
+template<> EIGEN_STRONG_INLINE Packet4uc ploadquad<Packet4uc>(const uint8_t* from)
+{ return vget_lane_u32(vreinterpret_u32_u8(vld1_dup_u8(from)), 0); }
+template<> EIGEN_STRONG_INLINE Packet8uc ploadquad<Packet8uc>(const uint8_t* from)
+{
+ return vreinterpret_u8_u32(vzip_u32(
+ vreinterpret_u32_u8(vld1_dup_u8(from)),
+ vreinterpret_u32_u8(vld1_dup_u8(from+1))).val[0]);
+}
+template<> EIGEN_STRONG_INLINE Packet16uc ploadquad<Packet16uc>(const uint8_t* from)
+{
+ const uint8x8_t a = vreinterpret_u8_u32(vzip_u32(
+ vreinterpret_u32_u8(vld1_dup_u8(from)),
+ vreinterpret_u32_u8(vld1_dup_u8(from+1))).val[0]);
+ const uint8x8_t b = vreinterpret_u8_u32(vzip_u32(
+ vreinterpret_u32_u8(vld1_dup_u8(from+2)),
+ vreinterpret_u32_u8(vld1_dup_u8(from+3))).val[0]);
+ return vcombine_u8(a,b);
+}
+template<> EIGEN_STRONG_INLINE Packet8s ploadquad<Packet8s>(const int16_t* from)
+{ return vcombine_s16(vld1_dup_s16(from), vld1_dup_s16(from+1)); }
+template<> EIGEN_STRONG_INLINE Packet8us ploadquad<Packet8us>(const uint16_t* from)
+{ return vcombine_u16(vld1_dup_u16(from), vld1_dup_u16(from+1)); }
+template<> EIGEN_STRONG_INLINE Packet4i ploadquad<Packet4i>(const int32_t* from) { return vld1q_dup_s32(from); }
+template<> EIGEN_STRONG_INLINE Packet4ui ploadquad<Packet4ui>(const uint32_t* from) { return vld1q_dup_u32(from); }
-template<> EIGEN_STRONG_INLINE void pstore<float> (float* to, const Packet4f& from) { EIGEN_DEBUG_ALIGNED_STORE vst1q_f32(to, from); }
-template<> EIGEN_STRONG_INLINE void pstore<int32_t>(int32_t* to, const Packet4i& from) { EIGEN_DEBUG_ALIGNED_STORE vst1q_s32(to, from); }
+template<> EIGEN_STRONG_INLINE void pstore<float>(float* to, const Packet2f& from)
+{ EIGEN_DEBUG_ALIGNED_STORE vst1_f32(to,from); }
+template<> EIGEN_STRONG_INLINE void pstore<float>(float* to, const Packet4f& from)
+{ EIGEN_DEBUG_ALIGNED_STORE vst1q_f32(to,from); }
+template<> EIGEN_STRONG_INLINE void pstore<int8_t>(int8_t* to, const Packet4c& from)
+{ memcpy(to, &from, sizeof(from)); }
+template<> EIGEN_STRONG_INLINE void pstore<int8_t>(int8_t* to, const Packet8c& from)
+{ EIGEN_DEBUG_ALIGNED_STORE vst1_s8(to,from); }
+template<> EIGEN_STRONG_INLINE void pstore<int8_t>(int8_t* to, const Packet16c& from)
+{ EIGEN_DEBUG_ALIGNED_STORE vst1q_s8(to,from); }
+template<> EIGEN_STRONG_INLINE void pstore<uint8_t>(uint8_t* to, const Packet4uc& from)
+{ memcpy(to, &from, sizeof(from)); }
+template<> EIGEN_STRONG_INLINE void pstore<uint8_t>(uint8_t* to, const Packet8uc& from)
+{ EIGEN_DEBUG_ALIGNED_STORE vst1_u8(to,from); }
+template<> EIGEN_STRONG_INLINE void pstore<uint8_t>(uint8_t* to, const Packet16uc& from)
+{ EIGEN_DEBUG_ALIGNED_STORE vst1q_u8(to,from); }
+template<> EIGEN_STRONG_INLINE void pstore<int16_t>(int16_t* to, const Packet4s& from)
+{ EIGEN_DEBUG_ALIGNED_STORE vst1_s16(to,from); }
+template<> EIGEN_STRONG_INLINE void pstore<int16_t>(int16_t* to, const Packet8s& from)
+{ EIGEN_DEBUG_ALIGNED_STORE vst1q_s16(to,from); }
+template<> EIGEN_STRONG_INLINE void pstore<uint16_t>(uint16_t* to, const Packet4us& from)
+{ EIGEN_DEBUG_ALIGNED_STORE vst1_u16(to,from); }
+template<> EIGEN_STRONG_INLINE void pstore<uint16_t>(uint16_t* to, const Packet8us& from)
+{ EIGEN_DEBUG_ALIGNED_STORE vst1q_u16(to,from); }
+template<> EIGEN_STRONG_INLINE void pstore<int32_t>(int32_t* to, const Packet2i& from)
+{ EIGEN_DEBUG_ALIGNED_STORE vst1_s32(to,from); }
+template<> EIGEN_STRONG_INLINE void pstore<int32_t>(int32_t* to, const Packet4i& from)
+{ EIGEN_DEBUG_ALIGNED_STORE vst1q_s32(to,from); }
+template<> EIGEN_STRONG_INLINE void pstore<uint32_t>(uint32_t* to, const Packet2ui& from)
+{ EIGEN_DEBUG_ALIGNED_STORE vst1_u32(to,from); }
+template<> EIGEN_STRONG_INLINE void pstore<uint32_t>(uint32_t* to, const Packet4ui& from)
+{ EIGEN_DEBUG_ALIGNED_STORE vst1q_u32(to,from); }
+template<> EIGEN_STRONG_INLINE void pstore<int64_t>(int64_t* to, const Packet2l& from)
+{ EIGEN_DEBUG_ALIGNED_STORE vst1q_s64(to,from); }
+template<> EIGEN_STRONG_INLINE void pstore<uint64_t>(uint64_t* to, const Packet2ul& from)
+{ EIGEN_DEBUG_ALIGNED_STORE vst1q_u64(to,from); }
-template<> EIGEN_STRONG_INLINE void pstoreu<float> (float* to, const Packet4f& from) { EIGEN_DEBUG_UNALIGNED_STORE vst1q_f32(to, from); }
-template<> EIGEN_STRONG_INLINE void pstoreu<int32_t>(int32_t* to, const Packet4i& from) { EIGEN_DEBUG_UNALIGNED_STORE vst1q_s32(to, from); }
+template<> EIGEN_STRONG_INLINE void pstoreu<float>(float* to, const Packet2f& from)
+{ EIGEN_DEBUG_UNALIGNED_STORE vst1_f32(to,from); }
+template<> EIGEN_STRONG_INLINE void pstoreu<float>(float* to, const Packet4f& from)
+{ EIGEN_DEBUG_UNALIGNED_STORE vst1q_f32(to,from); }
+template<> EIGEN_STRONG_INLINE void pstoreu<int8_t>(int8_t* to, const Packet4c& from)
+{ memcpy(to, &from, sizeof(from)); }
+template<> EIGEN_STRONG_INLINE void pstoreu<int8_t>(int8_t* to, const Packet8c& from)
+{ EIGEN_DEBUG_UNALIGNED_STORE vst1_s8(to,from); }
+template<> EIGEN_STRONG_INLINE void pstoreu<int8_t>(int8_t* to, const Packet16c& from)
+{ EIGEN_DEBUG_UNALIGNED_STORE vst1q_s8(to,from); }
+template<> EIGEN_STRONG_INLINE void pstoreu<uint8_t>(uint8_t* to, const Packet4uc& from)
+{ memcpy(to, &from, sizeof(from)); }
+template<> EIGEN_STRONG_INLINE void pstoreu<uint8_t>(uint8_t* to, const Packet8uc& from)
+{ EIGEN_DEBUG_UNALIGNED_STORE vst1_u8(to,from); }
+template<> EIGEN_STRONG_INLINE void pstoreu<uint8_t>(uint8_t* to, const Packet16uc& from)
+{ EIGEN_DEBUG_UNALIGNED_STORE vst1q_u8(to,from); }
+template<> EIGEN_STRONG_INLINE void pstoreu<int16_t>(int16_t* to, const Packet4s& from)
+{ EIGEN_DEBUG_UNALIGNED_STORE vst1_s16(to,from); }
+template<> EIGEN_STRONG_INLINE void pstoreu<int16_t>(int16_t* to, const Packet8s& from)
+{ EIGEN_DEBUG_UNALIGNED_STORE vst1q_s16(to,from); }
+template<> EIGEN_STRONG_INLINE void pstoreu<uint16_t>(uint16_t* to, const Packet4us& from)
+{ EIGEN_DEBUG_UNALIGNED_STORE vst1_u16(to,from); }
+template<> EIGEN_STRONG_INLINE void pstoreu<uint16_t>(uint16_t* to, const Packet8us& from)
+{ EIGEN_DEBUG_UNALIGNED_STORE vst1q_u16(to,from); }
+template<> EIGEN_STRONG_INLINE void pstoreu<int32_t>(int32_t* to, const Packet2i& from)
+{ EIGEN_DEBUG_UNALIGNED_STORE vst1_s32(to,from); }
+template<> EIGEN_STRONG_INLINE void pstoreu<int32_t>(int32_t* to, const Packet4i& from)
+{ EIGEN_DEBUG_UNALIGNED_STORE vst1q_s32(to,from); }
+template<> EIGEN_STRONG_INLINE void pstoreu<uint32_t>(uint32_t* to, const Packet2ui& from)
+{ EIGEN_DEBUG_UNALIGNED_STORE vst1_u32(to,from); }
+template<> EIGEN_STRONG_INLINE void pstoreu<uint32_t>(uint32_t* to, const Packet4ui& from)
+{ EIGEN_DEBUG_UNALIGNED_STORE vst1q_u32(to,from); }
+template<> EIGEN_STRONG_INLINE void pstoreu<int64_t>(int64_t* to, const Packet2l& from)
+{ EIGEN_DEBUG_UNALIGNED_STORE vst1q_s64(to,from); }
+template<> EIGEN_STRONG_INLINE void pstoreu<uint64_t>(uint64_t* to, const Packet2ul& from)
+{ EIGEN_DEBUG_UNALIGNED_STORE vst1q_u64(to,from); }
-template<> EIGEN_DEVICE_FUNC inline Packet4f pgather<float, Packet4f>(const float* from, Index stride)
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet2f pgather<float, Packet2f>(const float* from, Index stride)
{
- Packet4f res = pset1<Packet4f>(0.f);
- res = vsetq_lane_f32(from[0*stride], res, 0);
- res = vsetq_lane_f32(from[1*stride], res, 1);
- res = vsetq_lane_f32(from[2*stride], res, 2);
- res = vsetq_lane_f32(from[3*stride], res, 3);
+ Packet2f res = vld1_dup_f32(from);
+ res = vld1_lane_f32(from + 1*stride, res, 1);
return res;
}
-template<> EIGEN_DEVICE_FUNC inline Packet4i pgather<int32_t, Packet4i>(const int32_t* from, Index stride)
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4f pgather<float, Packet4f>(const float* from, Index stride)
{
- Packet4i res = pset1<Packet4i>(0);
- res = vsetq_lane_s32(from[0*stride], res, 0);
- res = vsetq_lane_s32(from[1*stride], res, 1);
- res = vsetq_lane_s32(from[2*stride], res, 2);
- res = vsetq_lane_s32(from[3*stride], res, 3);
+ Packet4f res = vld1q_dup_f32(from);
+ res = vld1q_lane_f32(from + 1*stride, res, 1);
+ res = vld1q_lane_f32(from + 2*stride, res, 2);
+ res = vld1q_lane_f32(from + 3*stride, res, 3);
return res;
}
-
-template<> EIGEN_DEVICE_FUNC inline void pscatter<float, Packet4f>(float* to, const Packet4f& from, Index stride)
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4c pgather<int8_t, Packet4c>(const int8_t* from, Index stride)
{
- to[stride*0] = vgetq_lane_f32(from, 0);
- to[stride*1] = vgetq_lane_f32(from, 1);
- to[stride*2] = vgetq_lane_f32(from, 2);
- to[stride*3] = vgetq_lane_f32(from, 3);
+ Packet4c res;
+ for (int i = 0; i != 4; i++)
+ reinterpret_cast<int8_t*>(&res)[i] = *(from + i * stride);
+ return res;
+}
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet8c pgather<int8_t, Packet8c>(const int8_t* from, Index stride)
+{
+ Packet8c res = vld1_dup_s8(from);
+ res = vld1_lane_s8(from + 1*stride, res, 1);
+ res = vld1_lane_s8(from + 2*stride, res, 2);
+ res = vld1_lane_s8(from + 3*stride, res, 3);
+ res = vld1_lane_s8(from + 4*stride, res, 4);
+ res = vld1_lane_s8(from + 5*stride, res, 5);
+ res = vld1_lane_s8(from + 6*stride, res, 6);
+ res = vld1_lane_s8(from + 7*stride, res, 7);
+ return res;
+}
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet16c pgather<int8_t, Packet16c>(const int8_t* from, Index stride)
+{
+ Packet16c res = vld1q_dup_s8(from);
+ res = vld1q_lane_s8(from + 1*stride, res, 1);
+ res = vld1q_lane_s8(from + 2*stride, res, 2);
+ res = vld1q_lane_s8(from + 3*stride, res, 3);
+ res = vld1q_lane_s8(from + 4*stride, res, 4);
+ res = vld1q_lane_s8(from + 5*stride, res, 5);
+ res = vld1q_lane_s8(from + 6*stride, res, 6);
+ res = vld1q_lane_s8(from + 7*stride, res, 7);
+ res = vld1q_lane_s8(from + 8*stride, res, 8);
+ res = vld1q_lane_s8(from + 9*stride, res, 9);
+ res = vld1q_lane_s8(from + 10*stride, res, 10);
+ res = vld1q_lane_s8(from + 11*stride, res, 11);
+ res = vld1q_lane_s8(from + 12*stride, res, 12);
+ res = vld1q_lane_s8(from + 13*stride, res, 13);
+ res = vld1q_lane_s8(from + 14*stride, res, 14);
+ res = vld1q_lane_s8(from + 15*stride, res, 15);
+ return res;
+}
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4uc pgather<uint8_t, Packet4uc>(const uint8_t* from, Index stride)
+{
+ Packet4uc res;
+ for (int i = 0; i != 4; i++)
+ reinterpret_cast<uint8_t*>(&res)[i] = *(from + i * stride);
+ return res;
}
-template<> EIGEN_DEVICE_FUNC inline void pscatter<int32_t, Packet4i>(int32_t* to, const Packet4i& from, Index stride)
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet8uc pgather<uint8_t, Packet8uc>(const uint8_t* from, Index stride)
{
- to[stride*0] = vgetq_lane_s32(from, 0);
- to[stride*1] = vgetq_lane_s32(from, 1);
- to[stride*2] = vgetq_lane_s32(from, 2);
- to[stride*3] = vgetq_lane_s32(from, 3);
+ Packet8uc res = vld1_dup_u8(from);
+ res = vld1_lane_u8(from + 1*stride, res, 1);
+ res = vld1_lane_u8(from + 2*stride, res, 2);
+ res = vld1_lane_u8(from + 3*stride, res, 3);
+ res = vld1_lane_u8(from + 4*stride, res, 4);
+ res = vld1_lane_u8(from + 5*stride, res, 5);
+ res = vld1_lane_u8(from + 6*stride, res, 6);
+ res = vld1_lane_u8(from + 7*stride, res, 7);
+ return res;
+}
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet16uc pgather<uint8_t, Packet16uc>(const uint8_t* from, Index stride)
+{
+ Packet16uc res = vld1q_dup_u8(from);
+ res = vld1q_lane_u8(from + 1*stride, res, 1);
+ res = vld1q_lane_u8(from + 2*stride, res, 2);
+ res = vld1q_lane_u8(from + 3*stride, res, 3);
+ res = vld1q_lane_u8(from + 4*stride, res, 4);
+ res = vld1q_lane_u8(from + 5*stride, res, 5);
+ res = vld1q_lane_u8(from + 6*stride, res, 6);
+ res = vld1q_lane_u8(from + 7*stride, res, 7);
+ res = vld1q_lane_u8(from + 8*stride, res, 8);
+ res = vld1q_lane_u8(from + 9*stride, res, 9);
+ res = vld1q_lane_u8(from + 10*stride, res, 10);
+ res = vld1q_lane_u8(from + 11*stride, res, 11);
+ res = vld1q_lane_u8(from + 12*stride, res, 12);
+ res = vld1q_lane_u8(from + 13*stride, res, 13);
+ res = vld1q_lane_u8(from + 14*stride, res, 14);
+ res = vld1q_lane_u8(from + 15*stride, res, 15);
+ return res;
+}
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4s pgather<int16_t, Packet4s>(const int16_t* from, Index stride)
+{
+ Packet4s res = vld1_dup_s16(from);
+ res = vld1_lane_s16(from + 1*stride, res, 1);
+ res = vld1_lane_s16(from + 2*stride, res, 2);
+ res = vld1_lane_s16(from + 3*stride, res, 3);
+ return res;
+}
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet8s pgather<int16_t, Packet8s>(const int16_t* from, Index stride)
+{
+ Packet8s res = vld1q_dup_s16(from);
+ res = vld1q_lane_s16(from + 1*stride, res, 1);
+ res = vld1q_lane_s16(from + 2*stride, res, 2);
+ res = vld1q_lane_s16(from + 3*stride, res, 3);
+ res = vld1q_lane_s16(from + 4*stride, res, 4);
+ res = vld1q_lane_s16(from + 5*stride, res, 5);
+ res = vld1q_lane_s16(from + 6*stride, res, 6);
+ res = vld1q_lane_s16(from + 7*stride, res, 7);
+ return res;
+}
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4us pgather<uint16_t, Packet4us>(const uint16_t* from, Index stride)
+{
+ Packet4us res = vld1_dup_u16(from);
+ res = vld1_lane_u16(from + 1*stride, res, 1);
+ res = vld1_lane_u16(from + 2*stride, res, 2);
+ res = vld1_lane_u16(from + 3*stride, res, 3);
+ return res;
+}
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet8us pgather<uint16_t, Packet8us>(const uint16_t* from, Index stride)
+{
+ Packet8us res = vld1q_dup_u16(from);
+ res = vld1q_lane_u16(from + 1*stride, res, 1);
+ res = vld1q_lane_u16(from + 2*stride, res, 2);
+ res = vld1q_lane_u16(from + 3*stride, res, 3);
+ res = vld1q_lane_u16(from + 4*stride, res, 4);
+ res = vld1q_lane_u16(from + 5*stride, res, 5);
+ res = vld1q_lane_u16(from + 6*stride, res, 6);
+ res = vld1q_lane_u16(from + 7*stride, res, 7);
+ return res;
+}
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet2i pgather<int32_t, Packet2i>(const int32_t* from, Index stride)
+{
+ Packet2i res = vld1_dup_s32(from);
+ res = vld1_lane_s32(from + 1*stride, res, 1);
+ return res;
+}
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4i pgather<int32_t, Packet4i>(const int32_t* from, Index stride)
+{
+ Packet4i res = vld1q_dup_s32(from);
+ res = vld1q_lane_s32(from + 1*stride, res, 1);
+ res = vld1q_lane_s32(from + 2*stride, res, 2);
+ res = vld1q_lane_s32(from + 3*stride, res, 3);
+ return res;
+}
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet2ui pgather<uint32_t, Packet2ui>(const uint32_t* from, Index stride)
+{
+ Packet2ui res = vld1_dup_u32(from);
+ res = vld1_lane_u32(from + 1*stride, res, 1);
+ return res;
+}
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4ui pgather<uint32_t, Packet4ui>(const uint32_t* from, Index stride)
+{
+ Packet4ui res = vld1q_dup_u32(from);
+ res = vld1q_lane_u32(from + 1*stride, res, 1);
+ res = vld1q_lane_u32(from + 2*stride, res, 2);
+ res = vld1q_lane_u32(from + 3*stride, res, 3);
+ return res;
+}
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet2l pgather<int64_t, Packet2l>(const int64_t* from, Index stride)
+{
+ Packet2l res = vld1q_dup_s64(from);
+ res = vld1q_lane_s64(from + 1*stride, res, 1);
+ return res;
+}
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet2ul pgather<uint64_t, Packet2ul>(const uint64_t* from, Index stride)
+{
+ Packet2ul res = vld1q_dup_u64(from);
+ res = vld1q_lane_u64(from + 1*stride, res, 1);
+ return res;
}
-template<> EIGEN_STRONG_INLINE void prefetch<float> (const float* addr) { EIGEN_ARM_PREFETCH(addr); }
-template<> EIGEN_STRONG_INLINE void prefetch<int32_t>(const int32_t* addr) { EIGEN_ARM_PREFETCH(addr); }
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pscatter<float, Packet2f>(float* to, const Packet2f& from, Index stride)
+{
+ vst1_lane_f32(to + stride*0, from, 0);
+ vst1_lane_f32(to + stride*1, from, 1);
+}
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pscatter<float, Packet4f>(float* to, const Packet4f& from, Index stride)
+{
+ vst1q_lane_f32(to + stride*0, from, 0);
+ vst1q_lane_f32(to + stride*1, from, 1);
+ vst1q_lane_f32(to + stride*2, from, 2);
+ vst1q_lane_f32(to + stride*3, from, 3);
+}
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pscatter<int8_t, Packet4c>(int8_t* to, const Packet4c& from, Index stride)
+{
+ for (int i = 0; i != 4; i++)
+ *(to + i * stride) = reinterpret_cast<const int8_t*>(&from)[i];
+}
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pscatter<int8_t, Packet8c>(int8_t* to, const Packet8c& from, Index stride)
+{
+ vst1_lane_s8(to + stride*0, from, 0);
+ vst1_lane_s8(to + stride*1, from, 1);
+ vst1_lane_s8(to + stride*2, from, 2);
+ vst1_lane_s8(to + stride*3, from, 3);
+ vst1_lane_s8(to + stride*4, from, 4);
+ vst1_lane_s8(to + stride*5, from, 5);
+ vst1_lane_s8(to + stride*6, from, 6);
+ vst1_lane_s8(to + stride*7, from, 7);
+}
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pscatter<int8_t, Packet16c>(int8_t* to, const Packet16c& from, Index stride)
+{
+ vst1q_lane_s8(to + stride*0, from, 0);
+ vst1q_lane_s8(to + stride*1, from, 1);
+ vst1q_lane_s8(to + stride*2, from, 2);
+ vst1q_lane_s8(to + stride*3, from, 3);
+ vst1q_lane_s8(to + stride*4, from, 4);
+ vst1q_lane_s8(to + stride*5, from, 5);
+ vst1q_lane_s8(to + stride*6, from, 6);
+ vst1q_lane_s8(to + stride*7, from, 7);
+ vst1q_lane_s8(to + stride*8, from, 8);
+ vst1q_lane_s8(to + stride*9, from, 9);
+ vst1q_lane_s8(to + stride*10, from, 10);
+ vst1q_lane_s8(to + stride*11, from, 11);
+ vst1q_lane_s8(to + stride*12, from, 12);
+ vst1q_lane_s8(to + stride*13, from, 13);
+ vst1q_lane_s8(to + stride*14, from, 14);
+ vst1q_lane_s8(to + stride*15, from, 15);
+}
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pscatter<uint8_t, Packet4uc>(uint8_t* to, const Packet4uc& from, Index stride)
+{
+ for (int i = 0; i != 4; i++)
+ *(to + i * stride) = reinterpret_cast<const uint8_t*>(&from)[i];
+}
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pscatter<uint8_t, Packet8uc>(uint8_t* to, const Packet8uc& from, Index stride)
+{
+ vst1_lane_u8(to + stride*0, from, 0);
+ vst1_lane_u8(to + stride*1, from, 1);
+ vst1_lane_u8(to + stride*2, from, 2);
+ vst1_lane_u8(to + stride*3, from, 3);
+ vst1_lane_u8(to + stride*4, from, 4);
+ vst1_lane_u8(to + stride*5, from, 5);
+ vst1_lane_u8(to + stride*6, from, 6);
+ vst1_lane_u8(to + stride*7, from, 7);
+}
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pscatter<uint8_t, Packet16uc>(uint8_t* to, const Packet16uc& from, Index stride)
+{
+ vst1q_lane_u8(to + stride*0, from, 0);
+ vst1q_lane_u8(to + stride*1, from, 1);
+ vst1q_lane_u8(to + stride*2, from, 2);
+ vst1q_lane_u8(to + stride*3, from, 3);
+ vst1q_lane_u8(to + stride*4, from, 4);
+ vst1q_lane_u8(to + stride*5, from, 5);
+ vst1q_lane_u8(to + stride*6, from, 6);
+ vst1q_lane_u8(to + stride*7, from, 7);
+ vst1q_lane_u8(to + stride*8, from, 8);
+ vst1q_lane_u8(to + stride*9, from, 9);
+ vst1q_lane_u8(to + stride*10, from, 10);
+ vst1q_lane_u8(to + stride*11, from, 11);
+ vst1q_lane_u8(to + stride*12, from, 12);
+ vst1q_lane_u8(to + stride*13, from, 13);
+ vst1q_lane_u8(to + stride*14, from, 14);
+ vst1q_lane_u8(to + stride*15, from, 15);
+}
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pscatter<int16_t, Packet4s>(int16_t* to, const Packet4s& from, Index stride)
+{
+ vst1_lane_s16(to + stride*0, from, 0);
+ vst1_lane_s16(to + stride*1, from, 1);
+ vst1_lane_s16(to + stride*2, from, 2);
+ vst1_lane_s16(to + stride*3, from, 3);
+}
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pscatter<int16_t, Packet8s>(int16_t* to, const Packet8s& from, Index stride)
+{
+ vst1q_lane_s16(to + stride*0, from, 0);
+ vst1q_lane_s16(to + stride*1, from, 1);
+ vst1q_lane_s16(to + stride*2, from, 2);
+ vst1q_lane_s16(to + stride*3, from, 3);
+ vst1q_lane_s16(to + stride*4, from, 4);
+ vst1q_lane_s16(to + stride*5, from, 5);
+ vst1q_lane_s16(to + stride*6, from, 6);
+ vst1q_lane_s16(to + stride*7, from, 7);
+}
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pscatter<uint16_t, Packet4us>(uint16_t* to, const Packet4us& from, Index stride)
+{
+ vst1_lane_u16(to + stride*0, from, 0);
+ vst1_lane_u16(to + stride*1, from, 1);
+ vst1_lane_u16(to + stride*2, from, 2);
+ vst1_lane_u16(to + stride*3, from, 3);
+}
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pscatter<uint16_t, Packet8us>(uint16_t* to, const Packet8us& from, Index stride)
+{
+ vst1q_lane_u16(to + stride*0, from, 0);
+ vst1q_lane_u16(to + stride*1, from, 1);
+ vst1q_lane_u16(to + stride*2, from, 2);
+ vst1q_lane_u16(to + stride*3, from, 3);
+ vst1q_lane_u16(to + stride*4, from, 4);
+ vst1q_lane_u16(to + stride*5, from, 5);
+ vst1q_lane_u16(to + stride*6, from, 6);
+ vst1q_lane_u16(to + stride*7, from, 7);
+}
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pscatter<int32_t, Packet2i>(int32_t* to, const Packet2i& from, Index stride)
+{
+ vst1_lane_s32(to + stride*0, from, 0);
+ vst1_lane_s32(to + stride*1, from, 1);
+}
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pscatter<int32_t, Packet4i>(int32_t* to, const Packet4i& from, Index stride)
+{
+ vst1q_lane_s32(to + stride*0, from, 0);
+ vst1q_lane_s32(to + stride*1, from, 1);
+ vst1q_lane_s32(to + stride*2, from, 2);
+ vst1q_lane_s32(to + stride*3, from, 3);
+}
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pscatter<uint32_t, Packet2ui>(uint32_t* to, const Packet2ui& from, Index stride)
+{
+ vst1_lane_u32(to + stride*0, from, 0);
+ vst1_lane_u32(to + stride*1, from, 1);
+}
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pscatter<uint32_t, Packet4ui>(uint32_t* to, const Packet4ui& from, Index stride)
+{
+ vst1q_lane_u32(to + stride*0, from, 0);
+ vst1q_lane_u32(to + stride*1, from, 1);
+ vst1q_lane_u32(to + stride*2, from, 2);
+ vst1q_lane_u32(to + stride*3, from, 3);
+}
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pscatter<int64_t, Packet2l>(int64_t* to, const Packet2l& from, Index stride)
+{
+ vst1q_lane_s64(to + stride*0, from, 0);
+ vst1q_lane_s64(to + stride*1, from, 1);
+}
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pscatter<uint64_t, Packet2ul>(uint64_t* to, const Packet2ul& from, Index stride)
+{
+ vst1q_lane_u64(to + stride*0, from, 0);
+ vst1q_lane_u64(to + stride*1, from, 1);
+}
-// FIXME only store the 2 first elements ?
-template<> EIGEN_STRONG_INLINE float pfirst<Packet4f>(const Packet4f& a) { float EIGEN_ALIGN16 x[4]; vst1q_f32(x, a); return x[0]; }
-template<> EIGEN_STRONG_INLINE int32_t pfirst<Packet4i>(const Packet4i& a) { int32_t EIGEN_ALIGN16 x[4]; vst1q_s32(x, a); return x[0]; }
+template<> EIGEN_STRONG_INLINE void prefetch<float>(const float* addr) { EIGEN_ARM_PREFETCH(addr); }
+template<> EIGEN_STRONG_INLINE void prefetch<int8_t>(const int8_t* addr) { EIGEN_ARM_PREFETCH(addr); }
+template<> EIGEN_STRONG_INLINE void prefetch<uint8_t>(const uint8_t* addr) { EIGEN_ARM_PREFETCH(addr); }
+template<> EIGEN_STRONG_INLINE void prefetch<int16_t>(const int16_t* addr) { EIGEN_ARM_PREFETCH(addr); }
+template<> EIGEN_STRONG_INLINE void prefetch<uint16_t>(const uint16_t* addr) { EIGEN_ARM_PREFETCH(addr); }
+template<> EIGEN_STRONG_INLINE void prefetch<int32_t>(const int32_t* addr) { EIGEN_ARM_PREFETCH(addr); }
+template<> EIGEN_STRONG_INLINE void prefetch<uint32_t>(const uint32_t* addr) { EIGEN_ARM_PREFETCH(addr); }
+template<> EIGEN_STRONG_INLINE void prefetch<int64_t>(const int64_t* addr) { EIGEN_ARM_PREFETCH(addr); }
+template<> EIGEN_STRONG_INLINE void prefetch<uint64_t>(const uint64_t* addr) { EIGEN_ARM_PREFETCH(addr); }
-template<> EIGEN_STRONG_INLINE Packet4f preverse(const Packet4f& a) {
- float32x2_t a_lo, a_hi;
- Packet4f a_r64;
+template<> EIGEN_STRONG_INLINE float pfirst<Packet2f>(const Packet2f& a) { return vget_lane_f32(a,0); }
+template<> EIGEN_STRONG_INLINE float pfirst<Packet4f>(const Packet4f& a) { return vgetq_lane_f32(a,0); }
+template<> EIGEN_STRONG_INLINE int8_t pfirst<Packet4c>(const Packet4c& a) { return static_cast<int8_t>(a & 0xff); }
+template<> EIGEN_STRONG_INLINE int8_t pfirst<Packet8c>(const Packet8c& a) { return vget_lane_s8(a,0); }
+template<> EIGEN_STRONG_INLINE int8_t pfirst<Packet16c>(const Packet16c& a) { return vgetq_lane_s8(a,0); }
+template<> EIGEN_STRONG_INLINE uint8_t pfirst<Packet4uc>(const Packet4uc& a) { return static_cast<uint8_t>(a & 0xff); }
+template<> EIGEN_STRONG_INLINE uint8_t pfirst<Packet8uc>(const Packet8uc& a) { return vget_lane_u8(a,0); }
+template<> EIGEN_STRONG_INLINE uint8_t pfirst<Packet16uc>(const Packet16uc& a) { return vgetq_lane_u8(a,0); }
+template<> EIGEN_STRONG_INLINE int16_t pfirst<Packet4s>(const Packet4s& a) { return vget_lane_s16(a,0); }
+template<> EIGEN_STRONG_INLINE int16_t pfirst<Packet8s>(const Packet8s& a) { return vgetq_lane_s16(a,0); }
+template<> EIGEN_STRONG_INLINE uint16_t pfirst<Packet4us>(const Packet4us& a) { return vget_lane_u16(a,0); }
+template<> EIGEN_STRONG_INLINE uint16_t pfirst<Packet8us>(const Packet8us& a) { return vgetq_lane_u16(a,0); }
+template<> EIGEN_STRONG_INLINE int32_t pfirst<Packet2i>(const Packet2i& a) { return vget_lane_s32(a,0); }
+template<> EIGEN_STRONG_INLINE int32_t pfirst<Packet4i>(const Packet4i& a) { return vgetq_lane_s32(a,0); }
+template<> EIGEN_STRONG_INLINE uint32_t pfirst<Packet2ui>(const Packet2ui& a) { return vget_lane_u32(a,0); }
+template<> EIGEN_STRONG_INLINE uint32_t pfirst<Packet4ui>(const Packet4ui& a) { return vgetq_lane_u32(a,0); }
+template<> EIGEN_STRONG_INLINE int64_t pfirst<Packet2l>(const Packet2l& a) { return vgetq_lane_s64(a,0); }
+template<> EIGEN_STRONG_INLINE uint64_t pfirst<Packet2ul>(const Packet2ul& a) { return vgetq_lane_u64(a,0); }
- a_r64 = vrev64q_f32(a);
- a_lo = vget_low_f32(a_r64);
- a_hi = vget_high_f32(a_r64);
- return vcombine_f32(a_hi, a_lo);
+template<> EIGEN_STRONG_INLINE Packet2f preverse(const Packet2f& a) { return vrev64_f32(a); }
+template<> EIGEN_STRONG_INLINE Packet4f preverse(const Packet4f& a)
+{
+ const float32x4_t a_r64 = vrev64q_f32(a);
+ return vcombine_f32(vget_high_f32(a_r64), vget_low_f32(a_r64));
}
-template<> EIGEN_STRONG_INLINE Packet4i preverse(const Packet4i& a) {
- int32x2_t a_lo, a_hi;
- Packet4i a_r64;
-
- a_r64 = vrev64q_s32(a);
- a_lo = vget_low_s32(a_r64);
- a_hi = vget_high_s32(a_r64);
- return vcombine_s32(a_hi, a_lo);
+template<> EIGEN_STRONG_INLINE Packet4c preverse(const Packet4c& a)
+{ return vget_lane_s32(vreinterpret_s32_s8(vrev64_s8(vreinterpret_s8_s32(vdup_n_s32(a)))), 0); }
+template<> EIGEN_STRONG_INLINE Packet8c preverse(const Packet8c& a) { return vrev64_s8(a); }
+template<> EIGEN_STRONG_INLINE Packet16c preverse(const Packet16c& a)
+{
+ const int8x16_t a_r64 = vrev64q_s8(a);
+ return vcombine_s8(vget_high_s8(a_r64), vget_low_s8(a_r64));
}
+template<> EIGEN_STRONG_INLINE Packet4uc preverse(const Packet4uc& a)
+{ return vget_lane_u32(vreinterpret_u32_u8(vrev64_u8(vreinterpret_u8_u32(vdup_n_u32(a)))), 0); }
+template<> EIGEN_STRONG_INLINE Packet8uc preverse(const Packet8uc& a) { return vrev64_u8(a); }
+template<> EIGEN_STRONG_INLINE Packet16uc preverse(const Packet16uc& a)
+{
+ const uint8x16_t a_r64 = vrev64q_u8(a);
+ return vcombine_u8(vget_high_u8(a_r64), vget_low_u8(a_r64));
+}
+template<> EIGEN_STRONG_INLINE Packet4s preverse(const Packet4s& a) { return vrev64_s16(a); }
+template<> EIGEN_STRONG_INLINE Packet8s preverse(const Packet8s& a)
+{
+ const int16x8_t a_r64 = vrev64q_s16(a);
+ return vcombine_s16(vget_high_s16(a_r64), vget_low_s16(a_r64));
+}
+template<> EIGEN_STRONG_INLINE Packet4us preverse(const Packet4us& a) { return vrev64_u16(a); }
+template<> EIGEN_STRONG_INLINE Packet8us preverse(const Packet8us& a)
+{
+ const uint16x8_t a_r64 = vrev64q_u16(a);
+ return vcombine_u16(vget_high_u16(a_r64), vget_low_u16(a_r64));
+}
+template<> EIGEN_STRONG_INLINE Packet2i preverse(const Packet2i& a) { return vrev64_s32(a); }
+template<> EIGEN_STRONG_INLINE Packet4i preverse(const Packet4i& a)
+{
+ const int32x4_t a_r64 = vrev64q_s32(a);
+ return vcombine_s32(vget_high_s32(a_r64), vget_low_s32(a_r64));
+}
+template<> EIGEN_STRONG_INLINE Packet2ui preverse(const Packet2ui& a) { return vrev64_u32(a); }
+template<> EIGEN_STRONG_INLINE Packet4ui preverse(const Packet4ui& a)
+{
+ const uint32x4_t a_r64 = vrev64q_u32(a);
+ return vcombine_u32(vget_high_u32(a_r64), vget_low_u32(a_r64));
+}
+template<> EIGEN_STRONG_INLINE Packet2l preverse(const Packet2l& a)
+{ return vcombine_s64(vget_high_s64(a), vget_low_s64(a)); }
+template<> EIGEN_STRONG_INLINE Packet2ul preverse(const Packet2ul& a)
+{ return vcombine_u64(vget_high_u64(a), vget_low_u64(a)); }
+template<> EIGEN_STRONG_INLINE Packet2f pabs(const Packet2f& a) { return vabs_f32(a); }
template<> EIGEN_STRONG_INLINE Packet4f pabs(const Packet4f& a) { return vabsq_f32(a); }
+template<> EIGEN_STRONG_INLINE Packet4c pabs<Packet4c>(const Packet4c& a)
+{ return vget_lane_s32(vreinterpret_s32_s8(vabs_s8(vreinterpret_s8_s32(vdup_n_s32(a)))), 0); }
+template<> EIGEN_STRONG_INLINE Packet8c pabs(const Packet8c& a) { return vabs_s8(a); }
+template<> EIGEN_STRONG_INLINE Packet16c pabs(const Packet16c& a) { return vabsq_s8(a); }
+template<> EIGEN_STRONG_INLINE Packet4uc pabs(const Packet4uc& a) { return a; }
+template<> EIGEN_STRONG_INLINE Packet8uc pabs(const Packet8uc& a) { return a; }
+template<> EIGEN_STRONG_INLINE Packet16uc pabs(const Packet16uc& a) { return a; }
+template<> EIGEN_STRONG_INLINE Packet4s pabs(const Packet4s& a) { return vabs_s16(a); }
+template<> EIGEN_STRONG_INLINE Packet8s pabs(const Packet8s& a) { return vabsq_s16(a); }
+template<> EIGEN_STRONG_INLINE Packet4us pabs(const Packet4us& a) { return a; }
+template<> EIGEN_STRONG_INLINE Packet8us pabs(const Packet8us& a) { return a; }
+template<> EIGEN_STRONG_INLINE Packet2i pabs(const Packet2i& a) { return vabs_s32(a); }
template<> EIGEN_STRONG_INLINE Packet4i pabs(const Packet4i& a) { return vabsq_s32(a); }
+template<> EIGEN_STRONG_INLINE Packet2ui pabs(const Packet2ui& a) { return a; }
+template<> EIGEN_STRONG_INLINE Packet4ui pabs(const Packet4ui& a) { return a; }
+template<> EIGEN_STRONG_INLINE Packet2l pabs(const Packet2l& a) {
+#if EIGEN_ARCH_ARM64
+ return vabsq_s64(a);
+#else
+ return vcombine_s64(
+ vdup_n_s64((std::abs)(vgetq_lane_s64(a, 0))),
+ vdup_n_s64((std::abs)(vgetq_lane_s64(a, 1))));
+#endif
+}
+template<> EIGEN_STRONG_INLINE Packet2ul pabs(const Packet2ul& a) { return a; }
+template<> EIGEN_STRONG_INLINE Packet2f pfrexp<Packet2f>(const Packet2f& a, Packet2f& exponent)
+{ return pfrexp_generic(a,exponent); }
+template<> EIGEN_STRONG_INLINE Packet4f pfrexp<Packet4f>(const Packet4f& a, Packet4f& exponent)
+{ return pfrexp_generic(a,exponent); }
+
+template<> EIGEN_STRONG_INLINE Packet2f pldexp<Packet2f>(const Packet2f& a, const Packet2f& exponent)
+{ return pldexp_generic(a,exponent); }
+template<> EIGEN_STRONG_INLINE Packet4f pldexp<Packet4f>(const Packet4f& a, const Packet4f& exponent)
+{ return pldexp_generic(a,exponent); }
+
+template<> EIGEN_STRONG_INLINE float predux<Packet2f>(const Packet2f& a) { return vget_lane_f32(vpadd_f32(a,a), 0); }
template<> EIGEN_STRONG_INLINE float predux<Packet4f>(const Packet4f& a)
{
- float32x2_t a_lo, a_hi, sum;
+ const float32x2_t sum = vadd_f32(vget_low_f32(a), vget_high_f32(a));
+ return vget_lane_f32(vpadd_f32(sum, sum), 0);
+}
+template<> EIGEN_STRONG_INLINE int8_t predux<Packet4c>(const Packet4c& a)
+{
+ const int8x8_t a_dup = vreinterpret_s8_s32(vdup_n_s32(a));
+ int8x8_t sum = vpadd_s8(a_dup, a_dup);
+ sum = vpadd_s8(sum, sum);
+ return vget_lane_s8(sum, 0);
+}
+template<> EIGEN_STRONG_INLINE int8_t predux<Packet8c>(const Packet8c& a)
+{
+ int8x8_t sum = vpadd_s8(a,a);
+ sum = vpadd_s8(sum, sum);
+ sum = vpadd_s8(sum, sum);
+ return vget_lane_s8(sum, 0);
+}
+template<> EIGEN_STRONG_INLINE int8_t predux<Packet16c>(const Packet16c& a)
+{
+ int8x8_t sum = vadd_s8(vget_low_s8(a), vget_high_s8(a));
+ sum = vpadd_s8(sum, sum);
+ sum = vpadd_s8(sum, sum);
+ sum = vpadd_s8(sum, sum);
+ return vget_lane_s8(sum, 0);
+}
+template<> EIGEN_STRONG_INLINE uint8_t predux<Packet4uc>(const Packet4uc& a)
+{
+ const uint8x8_t a_dup = vreinterpret_u8_u32(vdup_n_u32(a));
+ uint8x8_t sum = vpadd_u8(a_dup, a_dup);
+ sum = vpadd_u8(sum, sum);
+ return vget_lane_u8(sum, 0);
+}
+template<> EIGEN_STRONG_INLINE uint8_t predux<Packet8uc>(const Packet8uc& a)
+{
+ uint8x8_t sum = vpadd_u8(a,a);
+ sum = vpadd_u8(sum, sum);
+ sum = vpadd_u8(sum, sum);
+ return vget_lane_u8(sum, 0);
+}
+template<> EIGEN_STRONG_INLINE uint8_t predux<Packet16uc>(const Packet16uc& a)
+{
+ uint8x8_t sum = vadd_u8(vget_low_u8(a), vget_high_u8(a));
+ sum = vpadd_u8(sum, sum);
+ sum = vpadd_u8(sum, sum);
+ sum = vpadd_u8(sum, sum);
+ return vget_lane_u8(sum, 0);
+}
+template<> EIGEN_STRONG_INLINE int16_t predux<Packet4s>(const Packet4s& a)
+{
+ const int16x4_t sum = vpadd_s16(a,a);
+ return vget_lane_s16(vpadd_s16(sum, sum), 0);
+}
+template<> EIGEN_STRONG_INLINE int16_t predux<Packet8s>(const Packet8s& a)
+{
+ int16x4_t sum = vadd_s16(vget_low_s16(a), vget_high_s16(a));
+ sum = vpadd_s16(sum, sum);
+ sum = vpadd_s16(sum, sum);
+ return vget_lane_s16(sum, 0);
+}
+template<> EIGEN_STRONG_INLINE uint16_t predux<Packet4us>(const Packet4us& a)
+{
+ const uint16x4_t sum = vpadd_u16(a,a);
+ return vget_lane_u16(vpadd_u16(sum, sum), 0);
+}
+template<> EIGEN_STRONG_INLINE uint16_t predux<Packet8us>(const Packet8us& a)
+{
+ uint16x4_t sum = vadd_u16(vget_low_u16(a), vget_high_u16(a));
+ sum = vpadd_u16(sum, sum);
+ sum = vpadd_u16(sum, sum);
+ return vget_lane_u16(sum, 0);
+}
+template<> EIGEN_STRONG_INLINE int32_t predux<Packet2i>(const Packet2i& a) { return vget_lane_s32(vpadd_s32(a,a), 0); }
+template<> EIGEN_STRONG_INLINE int32_t predux<Packet4i>(const Packet4i& a)
+{
+ const int32x2_t sum = vadd_s32(vget_low_s32(a), vget_high_s32(a));
+ return vget_lane_s32(vpadd_s32(sum, sum), 0);
+}
+template<> EIGEN_STRONG_INLINE uint32_t predux<Packet2ui>(const Packet2ui& a) { return vget_lane_u32(vpadd_u32(a,a), 0); }
+template<> EIGEN_STRONG_INLINE uint32_t predux<Packet4ui>(const Packet4ui& a)
+{
+ const uint32x2_t sum = vadd_u32(vget_low_u32(a), vget_high_u32(a));
+ return vget_lane_u32(vpadd_u32(sum, sum), 0);
+}
+template<> EIGEN_STRONG_INLINE int64_t predux<Packet2l>(const Packet2l& a)
+{ return vgetq_lane_s64(a, 0) + vgetq_lane_s64(a, 1); }
+template<> EIGEN_STRONG_INLINE uint64_t predux<Packet2ul>(const Packet2ul& a)
+{ return vgetq_lane_u64(a, 0) + vgetq_lane_u64(a, 1); }
+
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4c predux_half_dowto4(const Packet8c& a)
+{
+ return vget_lane_s32(vreinterpret_s32_s8(vadd_s8(a,
+ vreinterpret_s8_s32(vrev64_s32(vreinterpret_s32_s8(a))))), 0);
+}
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet8c predux_half_dowto4(const Packet16c& a)
+{ return vadd_s8(vget_high_s8(a), vget_low_s8(a)); }
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4uc predux_half_dowto4(const Packet8uc& a)
+{
+ return vget_lane_u32(vreinterpret_u32_u8(vadd_u8(a,
+ vreinterpret_u8_u32(vrev64_u32(vreinterpret_u32_u8(a))))), 0);
+}
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet8uc predux_half_dowto4(const Packet16uc& a)
+{ return vadd_u8(vget_high_u8(a), vget_low_u8(a)); }
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4s predux_half_dowto4(const Packet8s& a)
+{ return vadd_s16(vget_high_s16(a), vget_low_s16(a)); }
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4us predux_half_dowto4(const Packet8us& a)
+{ return vadd_u16(vget_high_u16(a), vget_low_u16(a)); }
+
+// Other reduction functions:
+// mul
+template<> EIGEN_STRONG_INLINE float predux_mul<Packet2f>(const Packet2f& a)
+{ return vget_lane_f32(a, 0) * vget_lane_f32(a, 1); }
+template<> EIGEN_STRONG_INLINE float predux_mul<Packet4f>(const Packet4f& a)
+{ return predux_mul(vmul_f32(vget_low_f32(a), vget_high_f32(a))); }
+template<> EIGEN_STRONG_INLINE int8_t predux_mul<Packet4c>(const Packet4c& a)
+{
+ int8x8_t prod = vreinterpret_s8_s32(vdup_n_s32(a));
+ prod = vmul_s8(prod, vrev16_s8(prod));
+ return vget_lane_s8(prod, 0) * vget_lane_s8(prod, 2);
+}
+template<> EIGEN_STRONG_INLINE int8_t predux_mul<Packet8c>(const Packet8c& a)
+{
+ int8x8_t prod = vmul_s8(a, vrev16_s8(a));
+ prod = vmul_s8(prod, vrev32_s8(prod));
+ return vget_lane_s8(prod, 0) * vget_lane_s8(prod, 4);
+}
+template<> EIGEN_STRONG_INLINE int8_t predux_mul<Packet16c>(const Packet16c& a)
+{ return predux_mul(vmul_s8(vget_low_s8(a), vget_high_s8(a))); }
+template<> EIGEN_STRONG_INLINE uint8_t predux_mul<Packet4uc>(const Packet4uc& a)
+{
+ uint8x8_t prod = vreinterpret_u8_u32(vdup_n_u32(a));
+ prod = vmul_u8(prod, vrev16_u8(prod));
+ return vget_lane_u8(prod, 0) * vget_lane_u8(prod, 2);
+}
+template<> EIGEN_STRONG_INLINE uint8_t predux_mul<Packet8uc>(const Packet8uc& a)
+{
+ uint8x8_t prod = vmul_u8(a, vrev16_u8(a));
+ prod = vmul_u8(prod, vrev32_u8(prod));
+ return vget_lane_u8(prod, 0) * vget_lane_u8(prod, 4);
+}
+template<> EIGEN_STRONG_INLINE uint8_t predux_mul<Packet16uc>(const Packet16uc& a)
+{ return predux_mul(vmul_u8(vget_low_u8(a), vget_high_u8(a))); }
+template<> EIGEN_STRONG_INLINE int16_t predux_mul<Packet4s>(const Packet4s& a)
+{
+ const int16x4_t prod = vmul_s16(a, vrev32_s16(a));
+ return vget_lane_s16(prod, 0) * vget_lane_s16(prod, 2);
+}
+template<> EIGEN_STRONG_INLINE int16_t predux_mul<Packet8s>(const Packet8s& a)
+{
+ int16x4_t prod;
+
+ // Get the product of a_lo * a_hi -> |a1*a5|a2*a6|a3*a7|a4*a8|
+ prod = vmul_s16(vget_low_s16(a), vget_high_s16(a));
+ // Swap and multiply |a1*a5*a2*a6|a3*a7*a4*a8|
+ prod = vmul_s16(prod, vrev32_s16(prod));
+ // Multiply |a1*a5*a2*a6*a3*a7*a4*a8|
+ return vget_lane_s16(prod, 0) * vget_lane_s16(prod, 2);
+}
+template<> EIGEN_STRONG_INLINE uint16_t predux_mul<Packet4us>(const Packet4us& a)
+{
+ const uint16x4_t prod = vmul_u16(a, vrev32_u16(a));
+ return vget_lane_u16(prod, 0) * vget_lane_u16(prod, 2);
+}
+template<> EIGEN_STRONG_INLINE uint16_t predux_mul<Packet8us>(const Packet8us& a)
+{
+ uint16x4_t prod;
+
+ // Get the product of a_lo * a_hi -> |a1*a5|a2*a6|a3*a7|a4*a8|
+ prod = vmul_u16(vget_low_u16(a), vget_high_u16(a));
+ // Swap and multiply |a1*a5*a2*a6|a3*a7*a4*a8|
+ prod = vmul_u16(prod, vrev32_u16(prod));
+ // Multiply |a1*a5*a2*a6*a3*a7*a4*a8|
+ return vget_lane_u16(prod, 0) * vget_lane_u16(prod, 2);
+}
+template<> EIGEN_STRONG_INLINE int32_t predux_mul<Packet2i>(const Packet2i& a)
+{ return vget_lane_s32(a, 0) * vget_lane_s32(a, 1); }
+template<> EIGEN_STRONG_INLINE int32_t predux_mul<Packet4i>(const Packet4i& a)
+{ return predux_mul(vmul_s32(vget_low_s32(a), vget_high_s32(a))); }
+template<> EIGEN_STRONG_INLINE uint32_t predux_mul<Packet2ui>(const Packet2ui& a)
+{ return vget_lane_u32(a, 0) * vget_lane_u32(a, 1); }
+template<> EIGEN_STRONG_INLINE uint32_t predux_mul<Packet4ui>(const Packet4ui& a)
+{ return predux_mul(vmul_u32(vget_low_u32(a), vget_high_u32(a))); }
+template<> EIGEN_STRONG_INLINE int64_t predux_mul<Packet2l>(const Packet2l& a)
+{ return vgetq_lane_s64(a, 0) * vgetq_lane_s64(a, 1); }
+template<> EIGEN_STRONG_INLINE uint64_t predux_mul<Packet2ul>(const Packet2ul& a)
+{ return vgetq_lane_u64(a, 0) * vgetq_lane_u64(a, 1); }
+
+// min
+template<> EIGEN_STRONG_INLINE float predux_min<Packet2f>(const Packet2f& a)
+{ return vget_lane_f32(vpmin_f32(a,a), 0); }
+template<> EIGEN_STRONG_INLINE float predux_min<Packet4f>(const Packet4f& a)
+{
+ const float32x2_t min = vmin_f32(vget_low_f32(a), vget_high_f32(a));
+ return vget_lane_f32(vpmin_f32(min, min), 0);
+}
+template<> EIGEN_STRONG_INLINE int8_t predux_min<Packet4c>(const Packet4c& a)
+{
+ const int8x8_t a_dup = vreinterpret_s8_s32(vdup_n_s32(a));
+ int8x8_t min = vpmin_s8(a_dup, a_dup);
+ min = vpmin_s8(min, min);
+ return vget_lane_s8(min, 0);
+}
+template<> EIGEN_STRONG_INLINE int8_t predux_min<Packet8c>(const Packet8c& a)
+{
+ int8x8_t min = vpmin_s8(a,a);
+ min = vpmin_s8(min, min);
+ min = vpmin_s8(min, min);
+ return vget_lane_s8(min, 0);
+}
+template<> EIGEN_STRONG_INLINE int8_t predux_min<Packet16c>(const Packet16c& a)
+{
+ int8x8_t min = vmin_s8(vget_low_s8(a), vget_high_s8(a));
+ min = vpmin_s8(min, min);
+ min = vpmin_s8(min, min);
+ min = vpmin_s8(min, min);
+ return vget_lane_s8(min, 0);
+}
+template<> EIGEN_STRONG_INLINE uint8_t predux_min<Packet4uc>(const Packet4uc& a)
+{
+ const uint8x8_t a_dup = vreinterpret_u8_u32(vdup_n_u32(a));
+ uint8x8_t min = vpmin_u8(a_dup, a_dup);
+ min = vpmin_u8(min, min);
+ return vget_lane_u8(min, 0);
+}
+template<> EIGEN_STRONG_INLINE uint8_t predux_min<Packet8uc>(const Packet8uc& a)
+{
+ uint8x8_t min = vpmin_u8(a,a);
+ min = vpmin_u8(min, min);
+ min = vpmin_u8(min, min);
+ return vget_lane_u8(min, 0);
+}
+template<> EIGEN_STRONG_INLINE uint8_t predux_min<Packet16uc>(const Packet16uc& a)
+{
+ uint8x8_t min = vmin_u8(vget_low_u8(a), vget_high_u8(a));
+ min = vpmin_u8(min, min);
+ min = vpmin_u8(min, min);
+ min = vpmin_u8(min, min);
+ return vget_lane_u8(min, 0);
+}
+template<> EIGEN_STRONG_INLINE int16_t predux_min<Packet4s>(const Packet4s& a)
+{
+ const int16x4_t min = vpmin_s16(a,a);
+ return vget_lane_s16(vpmin_s16(min, min), 0);
+}
+template<> EIGEN_STRONG_INLINE int16_t predux_min<Packet8s>(const Packet8s& a)
+{
+ int16x4_t min = vmin_s16(vget_low_s16(a), vget_high_s16(a));
+ min = vpmin_s16(min, min);
+ min = vpmin_s16(min, min);
+ return vget_lane_s16(min, 0);
+}
+template<> EIGEN_STRONG_INLINE uint16_t predux_min<Packet4us>(const Packet4us& a)
+{
+ const uint16x4_t min = vpmin_u16(a,a);
+ return vget_lane_u16(vpmin_u16(min, min), 0);
+}
+template<> EIGEN_STRONG_INLINE uint16_t predux_min<Packet8us>(const Packet8us& a)
+{
+ uint16x4_t min = vmin_u16(vget_low_u16(a), vget_high_u16(a));
+ min = vpmin_u16(min, min);
+ min = vpmin_u16(min, min);
+ return vget_lane_u16(min, 0);
+}
+template<> EIGEN_STRONG_INLINE int32_t predux_min<Packet2i>(const Packet2i& a)
+{ return vget_lane_s32(vpmin_s32(a,a), 0); }
+template<> EIGEN_STRONG_INLINE int32_t predux_min<Packet4i>(const Packet4i& a)
+{
+ const int32x2_t min = vmin_s32(vget_low_s32(a), vget_high_s32(a));
+ return vget_lane_s32(vpmin_s32(min, min), 0);
+}
+template<> EIGEN_STRONG_INLINE uint32_t predux_min<Packet2ui>(const Packet2ui& a)
+{ return vget_lane_u32(vpmin_u32(a,a), 0); }
+template<> EIGEN_STRONG_INLINE uint32_t predux_min<Packet4ui>(const Packet4ui& a)
+{
+ const uint32x2_t min = vmin_u32(vget_low_u32(a), vget_high_u32(a));
+ return vget_lane_u32(vpmin_u32(min, min), 0);
+}
+template<> EIGEN_STRONG_INLINE int64_t predux_min<Packet2l>(const Packet2l& a)
+{ return (std::min)(vgetq_lane_s64(a, 0), vgetq_lane_s64(a, 1)); }
+template<> EIGEN_STRONG_INLINE uint64_t predux_min<Packet2ul>(const Packet2ul& a)
+{ return (std::min)(vgetq_lane_u64(a, 0), vgetq_lane_u64(a, 1)); }
- a_lo = vget_low_f32(a);
- a_hi = vget_high_f32(a);
- sum = vpadd_f32(a_lo, a_hi);
- sum = vpadd_f32(sum, sum);
- return vget_lane_f32(sum, 0);
+// max
+template<> EIGEN_STRONG_INLINE float predux_max<Packet2f>(const Packet2f& a)
+{ return vget_lane_f32(vpmax_f32(a,a), 0); }
+template<> EIGEN_STRONG_INLINE float predux_max<Packet4f>(const Packet4f& a)
+{
+ const float32x2_t max = vmax_f32(vget_low_f32(a), vget_high_f32(a));
+ return vget_lane_f32(vpmax_f32(max, max), 0);
+}
+template<> EIGEN_STRONG_INLINE int8_t predux_max<Packet4c>(const Packet4c& a)
+{
+ const int8x8_t a_dup = vreinterpret_s8_s32(vdup_n_s32(a));
+ int8x8_t max = vpmax_s8(a_dup, a_dup);
+ max = vpmax_s8(max, max);
+ return vget_lane_s8(max, 0);
+}
+template<> EIGEN_STRONG_INLINE int8_t predux_max<Packet8c>(const Packet8c& a)
+{
+ int8x8_t max = vpmax_s8(a,a);
+ max = vpmax_s8(max, max);
+ max = vpmax_s8(max, max);
+ return vget_lane_s8(max, 0);
+}
+template<> EIGEN_STRONG_INLINE int8_t predux_max<Packet16c>(const Packet16c& a)
+{
+ int8x8_t max = vmax_s8(vget_low_s8(a), vget_high_s8(a));
+ max = vpmax_s8(max, max);
+ max = vpmax_s8(max, max);
+ max = vpmax_s8(max, max);
+ return vget_lane_s8(max, 0);
+}
+template<> EIGEN_STRONG_INLINE uint8_t predux_max<Packet4uc>(const Packet4uc& a)
+{
+ const uint8x8_t a_dup = vreinterpret_u8_u32(vdup_n_u32(a));
+ uint8x8_t max = vpmax_u8(a_dup, a_dup);
+ max = vpmax_u8(max, max);
+ return vget_lane_u8(max, 0);
+}
+template<> EIGEN_STRONG_INLINE uint8_t predux_max<Packet8uc>(const Packet8uc& a)
+{
+ uint8x8_t max = vpmax_u8(a,a);
+ max = vpmax_u8(max, max);
+ max = vpmax_u8(max, max);
+ return vget_lane_u8(max, 0);
+}
+template<> EIGEN_STRONG_INLINE uint8_t predux_max<Packet16uc>(const Packet16uc& a)
+{
+ uint8x8_t max = vmax_u8(vget_low_u8(a), vget_high_u8(a));
+ max = vpmax_u8(max, max);
+ max = vpmax_u8(max, max);
+ max = vpmax_u8(max, max);
+ return vget_lane_u8(max, 0);
+}
+template<> EIGEN_STRONG_INLINE int16_t predux_max<Packet4s>(const Packet4s& a)
+{
+ const int16x4_t max = vpmax_s16(a,a);
+ return vget_lane_s16(vpmax_s16(max, max), 0);
+}
+template<> EIGEN_STRONG_INLINE int16_t predux_max<Packet8s>(const Packet8s& a)
+{
+ int16x4_t max = vmax_s16(vget_low_s16(a), vget_high_s16(a));
+ max = vpmax_s16(max, max);
+ max = vpmax_s16(max, max);
+ return vget_lane_s16(max, 0);
+}
+template<> EIGEN_STRONG_INLINE uint16_t predux_max<Packet4us>(const Packet4us& a)
+{
+ const uint16x4_t max = vpmax_u16(a,a);
+ return vget_lane_u16(vpmax_u16(max, max), 0);
+}
+template<> EIGEN_STRONG_INLINE uint16_t predux_max<Packet8us>(const Packet8us& a)
+{
+ uint16x4_t max = vmax_u16(vget_low_u16(a), vget_high_u16(a));
+ max = vpmax_u16(max, max);
+ max = vpmax_u16(max, max);
+ return vget_lane_u16(max, 0);
+}
+template<> EIGEN_STRONG_INLINE int32_t predux_max<Packet2i>(const Packet2i& a)
+{ return vget_lane_s32(vpmax_s32(a,a), 0); }
+template<> EIGEN_STRONG_INLINE int32_t predux_max<Packet4i>(const Packet4i& a)
+{
+ const int32x2_t max = vmax_s32(vget_low_s32(a), vget_high_s32(a));
+ return vget_lane_s32(vpmax_s32(max, max), 0);
+}
+template<> EIGEN_STRONG_INLINE uint32_t predux_max<Packet2ui>(const Packet2ui& a)
+{ return vget_lane_u32(vpmax_u32(a,a), 0); }
+template<> EIGEN_STRONG_INLINE uint32_t predux_max<Packet4ui>(const Packet4ui& a)
+{
+ const uint32x2_t max = vmax_u32(vget_low_u32(a), vget_high_u32(a));
+ return vget_lane_u32(vpmax_u32(max, max), 0);
}
+template<> EIGEN_STRONG_INLINE int64_t predux_max<Packet2l>(const Packet2l& a)
+{ return (std::max)(vgetq_lane_s64(a, 0), vgetq_lane_s64(a, 1)); }
+template<> EIGEN_STRONG_INLINE uint64_t predux_max<Packet2ul>(const Packet2ul& a)
+{ return (std::max)(vgetq_lane_u64(a, 0), vgetq_lane_u64(a, 1)); }
-template<> EIGEN_STRONG_INLINE Packet4f preduxp<Packet4f>(const Packet4f* vecs)
+template<> EIGEN_STRONG_INLINE bool predux_any(const Packet4f& x)
{
- float32x4x2_t vtrn1, vtrn2, res1, res2;
- Packet4f sum1, sum2, sum;
+ uint32x2_t tmp = vorr_u32(vget_low_u32( vreinterpretq_u32_f32(x)),
+ vget_high_u32(vreinterpretq_u32_f32(x)));
+ return vget_lane_u32(vpmax_u32(tmp, tmp), 0);
+}
- // NEON zip performs interleaving of the supplied vectors.
- // We perform two interleaves in a row to acquire the transposed vector
- vtrn1 = vzipq_f32(vecs[0], vecs[2]);
- vtrn2 = vzipq_f32(vecs[1], vecs[3]);
- res1 = vzipq_f32(vtrn1.val[0], vtrn2.val[0]);
- res2 = vzipq_f32(vtrn1.val[1], vtrn2.val[1]);
+// Helpers for ptranspose.
+namespace detail {
+
+template<typename Packet>
+void zip_in_place(Packet& p1, Packet& p2);
- // Do the addition of the resulting vectors
- sum1 = vaddq_f32(res1.val[0], res1.val[1]);
- sum2 = vaddq_f32(res2.val[0], res2.val[1]);
- sum = vaddq_f32(sum1, sum2);
+template<>
+EIGEN_ALWAYS_INLINE void zip_in_place<Packet2f>(Packet2f& p1, Packet2f& p2) {
+ const float32x2x2_t tmp = vzip_f32(p1, p2);
+ p1 = tmp.val[0];
+ p2 = tmp.val[1];
+}
- return sum;
+template<>
+EIGEN_ALWAYS_INLINE void zip_in_place<Packet4f>(Packet4f& p1, Packet4f& p2) {
+ const float32x4x2_t tmp = vzipq_f32(p1, p2);
+ p1 = tmp.val[0];
+ p2 = tmp.val[1];
}
-template<> EIGEN_STRONG_INLINE int32_t predux<Packet4i>(const Packet4i& a)
+template<>
+EIGEN_ALWAYS_INLINE void zip_in_place<Packet8c>(Packet8c& p1, Packet8c& p2) {
+ const int8x8x2_t tmp = vzip_s8(p1, p2);
+ p1 = tmp.val[0];
+ p2 = tmp.val[1];
+}
+
+template<>
+EIGEN_ALWAYS_INLINE void zip_in_place<Packet16c>(Packet16c& p1, Packet16c& p2) {
+ const int8x16x2_t tmp = vzipq_s8(p1, p2);
+ p1 = tmp.val[0];
+ p2 = tmp.val[1];
+}
+
+template<>
+EIGEN_ALWAYS_INLINE void zip_in_place<Packet8uc>(Packet8uc& p1, Packet8uc& p2) {
+ const uint8x8x2_t tmp = vzip_u8(p1, p2);
+ p1 = tmp.val[0];
+ p2 = tmp.val[1];
+}
+
+template<>
+EIGEN_ALWAYS_INLINE void zip_in_place<Packet16uc>(Packet16uc& p1, Packet16uc& p2) {
+ const uint8x16x2_t tmp = vzipq_u8(p1, p2);
+ p1 = tmp.val[0];
+ p2 = tmp.val[1];
+}
+
+template<>
+EIGEN_ALWAYS_INLINE void zip_in_place<Packet2i>(Packet2i& p1, Packet2i& p2) {
+ const int32x2x2_t tmp = vzip_s32(p1, p2);
+ p1 = tmp.val[0];
+ p2 = tmp.val[1];
+}
+
+template<>
+EIGEN_ALWAYS_INLINE void zip_in_place<Packet4i>(Packet4i& p1, Packet4i& p2) {
+ const int32x4x2_t tmp = vzipq_s32(p1, p2);
+ p1 = tmp.val[0];
+ p2 = tmp.val[1];
+}
+
+template<>
+EIGEN_ALWAYS_INLINE void zip_in_place<Packet2ui>(Packet2ui& p1, Packet2ui& p2) {
+ const uint32x2x2_t tmp = vzip_u32(p1, p2);
+ p1 = tmp.val[0];
+ p2 = tmp.val[1];
+}
+
+template<>
+EIGEN_ALWAYS_INLINE void zip_in_place<Packet4ui>(Packet4ui& p1, Packet4ui& p2) {
+ const uint32x4x2_t tmp = vzipq_u32(p1, p2);
+ p1 = tmp.val[0];
+ p2 = tmp.val[1];
+}
+
+template<>
+EIGEN_ALWAYS_INLINE void zip_in_place<Packet4s>(Packet4s& p1, Packet4s& p2) {
+ const int16x4x2_t tmp = vzip_s16(p1, p2);
+ p1 = tmp.val[0];
+ p2 = tmp.val[1];
+}
+
+template<>
+EIGEN_ALWAYS_INLINE void zip_in_place<Packet8s>(Packet8s& p1, Packet8s& p2) {
+ const int16x8x2_t tmp = vzipq_s16(p1, p2);
+ p1 = tmp.val[0];
+ p2 = tmp.val[1];
+}
+
+template<>
+EIGEN_ALWAYS_INLINE void zip_in_place<Packet4us>(Packet4us& p1, Packet4us& p2) {
+ const uint16x4x2_t tmp = vzip_u16(p1, p2);
+ p1 = tmp.val[0];
+ p2 = tmp.val[1];
+}
+
+template<>
+EIGEN_ALWAYS_INLINE void zip_in_place<Packet8us>(Packet8us& p1, Packet8us& p2) {
+ const uint16x8x2_t tmp = vzipq_u16(p1, p2);
+ p1 = tmp.val[0];
+ p2 = tmp.val[1];
+}
+
+template<typename Packet>
+EIGEN_ALWAYS_INLINE void ptranspose_impl(PacketBlock<Packet, 2>& kernel) {
+ zip_in_place(kernel.packet[0], kernel.packet[1]);
+}
+
+template<typename Packet>
+EIGEN_ALWAYS_INLINE void ptranspose_impl(PacketBlock<Packet, 4>& kernel) {
+ zip_in_place(kernel.packet[0], kernel.packet[2]);
+ zip_in_place(kernel.packet[1], kernel.packet[3]);
+ zip_in_place(kernel.packet[0], kernel.packet[1]);
+ zip_in_place(kernel.packet[2], kernel.packet[3]);
+}
+
+template<typename Packet>
+EIGEN_ALWAYS_INLINE void ptranspose_impl(PacketBlock<Packet, 8>& kernel) {
+ zip_in_place(kernel.packet[0], kernel.packet[4]);
+ zip_in_place(kernel.packet[1], kernel.packet[5]);
+ zip_in_place(kernel.packet[2], kernel.packet[6]);
+ zip_in_place(kernel.packet[3], kernel.packet[7]);
+
+ zip_in_place(kernel.packet[0], kernel.packet[2]);
+ zip_in_place(kernel.packet[1], kernel.packet[3]);
+ zip_in_place(kernel.packet[4], kernel.packet[6]);
+ zip_in_place(kernel.packet[5], kernel.packet[7]);
+
+ zip_in_place(kernel.packet[0], kernel.packet[1]);
+ zip_in_place(kernel.packet[2], kernel.packet[3]);
+ zip_in_place(kernel.packet[4], kernel.packet[5]);
+ zip_in_place(kernel.packet[6], kernel.packet[7]);
+}
+
+template<typename Packet>
+EIGEN_ALWAYS_INLINE void ptranspose_impl(PacketBlock<Packet, 16>& kernel) {
+ EIGEN_UNROLL_LOOP
+ for (int i=0; i<4; ++i) {
+ const int m = (1 << i);
+ EIGEN_UNROLL_LOOP
+ for (int j=0; j<m; ++j) {
+ const int n = (1 << (3-i));
+ EIGEN_UNROLL_LOOP
+ for (int k=0; k<n; ++k) {
+ const int idx = 2*j*n+k;
+ zip_in_place(kernel.packet[idx], kernel.packet[idx + n]);
+ }
+ }
+ }
+}
+
+} // namespace detail
+
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet2f, 2>& kernel) {
+ detail::ptranspose_impl(kernel);
+}
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet4f, 4>& kernel) {
+ detail::ptranspose_impl(kernel);
+}
+
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet4c, 4>& kernel)
{
- int32x2_t a_lo, a_hi, sum;
+ const int8x8_t a = vreinterpret_s8_s32(vset_lane_s32(kernel.packet[2], vdup_n_s32(kernel.packet[0]), 1));
+ const int8x8_t b = vreinterpret_s8_s32(vset_lane_s32(kernel.packet[3], vdup_n_s32(kernel.packet[1]), 1));
- a_lo = vget_low_s32(a);
- a_hi = vget_high_s32(a);
- sum = vpadd_s32(a_lo, a_hi);
- sum = vpadd_s32(sum, sum);
- return vget_lane_s32(sum, 0);
+ const int8x8x2_t zip8 = vzip_s8(a,b);
+ const int16x4x2_t zip16 = vzip_s16(vreinterpret_s16_s8(zip8.val[0]), vreinterpret_s16_s8(zip8.val[1]));
+
+ kernel.packet[0] = vget_lane_s32(vreinterpret_s32_s16(zip16.val[0]), 0);
+ kernel.packet[1] = vget_lane_s32(vreinterpret_s32_s16(zip16.val[0]), 1);
+ kernel.packet[2] = vget_lane_s32(vreinterpret_s32_s16(zip16.val[1]), 0);
+ kernel.packet[3] = vget_lane_s32(vreinterpret_s32_s16(zip16.val[1]), 1);
+}
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet8c, 8>& kernel) {
+ detail::ptranspose_impl(kernel);
+}
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet8c, 4>& kernel) {
+ detail::ptranspose_impl(kernel);
+}
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet16c, 16>& kernel) {
+ detail::ptranspose_impl(kernel);
+}
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet16c, 8>& kernel) {
+ detail::ptranspose_impl(kernel);
+}
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet16c, 4>& kernel) {
+ detail::ptranspose_impl(kernel);
}
-template<> EIGEN_STRONG_INLINE Packet4i preduxp<Packet4i>(const Packet4i* vecs)
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet4uc, 4>& kernel)
{
- int32x4x2_t vtrn1, vtrn2, res1, res2;
- Packet4i sum1, sum2, sum;
+ const uint8x8_t a = vreinterpret_u8_u32(vset_lane_u32(kernel.packet[2], vdup_n_u32(kernel.packet[0]), 1));
+ const uint8x8_t b = vreinterpret_u8_u32(vset_lane_u32(kernel.packet[3], vdup_n_u32(kernel.packet[1]), 1));
+
+ const uint8x8x2_t zip8 = vzip_u8(a,b);
+ const uint16x4x2_t zip16 = vzip_u16(vreinterpret_u16_u8(zip8.val[0]), vreinterpret_u16_u8(zip8.val[1]));
+
+ kernel.packet[0] = vget_lane_u32(vreinterpret_u32_u16(zip16.val[0]), 0);
+ kernel.packet[1] = vget_lane_u32(vreinterpret_u32_u16(zip16.val[0]), 1);
+ kernel.packet[2] = vget_lane_u32(vreinterpret_u32_u16(zip16.val[1]), 0);
+ kernel.packet[3] = vget_lane_u32(vreinterpret_u32_u16(zip16.val[1]), 1);
+}
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet8uc, 8>& kernel) {
+ detail::ptranspose_impl(kernel);
+}
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet8uc, 4>& kernel) {
+ detail::ptranspose_impl(kernel);
+}
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet16uc, 16>& kernel) {
+ detail::ptranspose_impl(kernel);
+}
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet16uc, 8>& kernel) {
+ detail::ptranspose_impl(kernel);
+}
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet16uc, 4>& kernel) {
+ detail::ptranspose_impl(kernel);
+}
- // NEON zip performs interleaving of the supplied vectors.
- // We perform two interleaves in a row to acquire the transposed vector
- vtrn1 = vzipq_s32(vecs[0], vecs[2]);
- vtrn2 = vzipq_s32(vecs[1], vecs[3]);
- res1 = vzipq_s32(vtrn1.val[0], vtrn2.val[0]);
- res2 = vzipq_s32(vtrn1.val[1], vtrn2.val[1]);
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet4s, 4>& kernel) {
+ detail::ptranspose_impl(kernel);
+}
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet8s, 8>& kernel) {
+ detail::ptranspose_impl(kernel);
+}
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet8s, 4>& kernel) {
+ detail::ptranspose_impl(kernel);
+}
- // Do the addition of the resulting vectors
- sum1 = vaddq_s32(res1.val[0], res1.val[1]);
- sum2 = vaddq_s32(res2.val[0], res2.val[1]);
- sum = vaddq_s32(sum1, sum2);
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet4us, 4>& kernel) {
+ detail::ptranspose_impl(kernel);
+}
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet8us, 8>& kernel) {
+ detail::ptranspose_impl(kernel);
+}
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet8us, 4>& kernel) {
+ detail::ptranspose_impl(kernel);
+}
- return sum;
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet2i, 2>& kernel) {
+ detail::ptranspose_impl(kernel);
+}
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet4i, 4>& kernel) {
+ detail::ptranspose_impl(kernel);
+}
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet2ui, 2>& kernel) {
+ detail::zip_in_place(kernel.packet[0], kernel.packet[1]);
+}
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet4ui, 4>& kernel) {
+ detail::ptranspose_impl(kernel);
}
-// Other reduction functions:
-// mul
-template<> EIGEN_STRONG_INLINE float predux_mul<Packet4f>(const Packet4f& a)
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void
+ptranspose(PacketBlock<Packet2l, 2>& kernel)
{
- float32x2_t a_lo, a_hi, prod;
+#if EIGEN_ARCH_ARM64
+ const int64x2_t tmp1 = vzip1q_s64(kernel.packet[0], kernel.packet[1]);
+ kernel.packet[1] = vzip2q_s64(kernel.packet[0], kernel.packet[1]);
+ kernel.packet[0] = tmp1;
+#else
+ const int64x1_t tmp[2][2] = {
+ { vget_low_s64(kernel.packet[0]), vget_high_s64(kernel.packet[0]) },
+ { vget_low_s64(kernel.packet[1]), vget_high_s64(kernel.packet[1]) }
+ };
- // Get a_lo = |a1|a2| and a_hi = |a3|a4|
- a_lo = vget_low_f32(a);
- a_hi = vget_high_f32(a);
- // Get the product of a_lo * a_hi -> |a1*a3|a2*a4|
- prod = vmul_f32(a_lo, a_hi);
- // Multiply prod with its swapped value |a2*a4|a1*a3|
- prod = vmul_f32(prod, vrev64_f32(prod));
+ kernel.packet[0] = vcombine_s64(tmp[0][0], tmp[1][0]);
+ kernel.packet[1] = vcombine_s64(tmp[0][1], tmp[1][1]);
+#endif
+}
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void
+ptranspose(PacketBlock<Packet2ul, 2>& kernel)
+{
+#if EIGEN_ARCH_ARM64
+ const uint64x2_t tmp1 = vzip1q_u64(kernel.packet[0], kernel.packet[1]);
+ kernel.packet[1] = vzip2q_u64(kernel.packet[0], kernel.packet[1]);
+ kernel.packet[0] = tmp1;
+#else
+ const uint64x1_t tmp[2][2] = {
+ { vget_low_u64(kernel.packet[0]), vget_high_u64(kernel.packet[0]) },
+ { vget_low_u64(kernel.packet[1]), vget_high_u64(kernel.packet[1]) }
+ };
- return vget_lane_f32(prod, 0);
+ kernel.packet[0] = vcombine_u64(tmp[0][0], tmp[1][0]);
+ kernel.packet[1] = vcombine_u64(tmp[0][1], tmp[1][1]);
+#endif
}
-template<> EIGEN_STRONG_INLINE int32_t predux_mul<Packet4i>(const Packet4i& a)
+
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet2f pselect( const Packet2f& mask, const Packet2f& a, const Packet2f& b)
+{ return vbsl_f32(vreinterpret_u32_f32(mask), a, b); }
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4f pselect(const Packet4f& mask, const Packet4f& a, const Packet4f& b)
+{ return vbslq_f32(vreinterpretq_u32_f32(mask), a, b); }
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet8c pselect(const Packet8c& mask, const Packet8c& a, const Packet8c& b)
+{ return vbsl_s8(vreinterpret_u8_s8(mask), a, b); }
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet16c pselect(const Packet16c& mask, const Packet16c& a, const Packet16c& b)
+{ return vbslq_s8(vreinterpretq_u8_s8(mask), a, b); }
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet8uc pselect(const Packet8uc& mask, const Packet8uc& a, const Packet8uc& b)
+{ return vbsl_u8(mask, a, b); }
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet16uc pselect(const Packet16uc& mask, const Packet16uc& a, const Packet16uc& b)
+{ return vbslq_u8(mask, a, b); }
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4s pselect(const Packet4s& mask, const Packet4s& a, const Packet4s& b)
+{ return vbsl_s16(vreinterpret_u16_s16(mask), a, b); }
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet8s pselect(const Packet8s& mask, const Packet8s& a, const Packet8s& b)
+{ return vbslq_s16(vreinterpretq_u16_s16(mask), a, b); }
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4us pselect(const Packet4us& mask, const Packet4us& a, const Packet4us& b)
+{ return vbsl_u16(mask, a, b); }
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet8us pselect(const Packet8us& mask, const Packet8us& a, const Packet8us& b)
+{ return vbslq_u16(mask, a, b); }
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet2i pselect(const Packet2i& mask, const Packet2i& a, const Packet2i& b)
+{ return vbsl_s32(vreinterpret_u32_s32(mask), a, b); }
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4i pselect(const Packet4i& mask, const Packet4i& a, const Packet4i& b)
+{ return vbslq_s32(vreinterpretq_u32_s32(mask), a, b); }
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet2ui pselect(const Packet2ui& mask, const Packet2ui& a, const Packet2ui& b)
+{ return vbsl_u32(mask, a, b); }
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4ui pselect(const Packet4ui& mask, const Packet4ui& a, const Packet4ui& b)
+{ return vbslq_u32(mask, a, b); }
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet2l pselect(const Packet2l& mask, const Packet2l& a, const Packet2l& b)
+{ return vbslq_s64(vreinterpretq_u64_s64(mask), a, b); }
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet2ul pselect(const Packet2ul& mask, const Packet2ul& a, const Packet2ul& b)
+{ return vbslq_u64(mask, a, b); }
+
+// Use armv8 rounding intinsics if available.
+#if EIGEN_ARCH_ARMV8
+template<> EIGEN_STRONG_INLINE Packet2f print<Packet2f>(const Packet2f& a)
+{ return vrndn_f32(a); }
+
+template<> EIGEN_STRONG_INLINE Packet4f print<Packet4f>(const Packet4f& a)
+{ return vrndnq_f32(a); }
+
+template<> EIGEN_STRONG_INLINE Packet2f pfloor<Packet2f>(const Packet2f& a)
+{ return vrndm_f32(a); }
+
+template<> EIGEN_STRONG_INLINE Packet4f pfloor<Packet4f>(const Packet4f& a)
+{ return vrndmq_f32(a); }
+
+template<> EIGEN_STRONG_INLINE Packet2f pceil<Packet2f>(const Packet2f& a)
+{ return vrndp_f32(a); }
+
+template<> EIGEN_STRONG_INLINE Packet4f pceil<Packet4f>(const Packet4f& a)
+{ return vrndpq_f32(a); }
+
+#else
+
+template<> EIGEN_STRONG_INLINE Packet4f print(const Packet4f& a) {
+ // Adds and subtracts signum(a) * 2^23 to force rounding.
+ const Packet4f limit = pset1<Packet4f>(static_cast<float>(1<<23));
+ const Packet4f abs_a = pabs(a);
+ Packet4f r = padd(abs_a, limit);
+ // Don't compile-away addition and subtraction.
+ EIGEN_OPTIMIZATION_BARRIER(r);
+ r = psub(r, limit);
+ // If greater than limit, simply return a. Otherwise, account for sign.
+ r = pselect(pcmp_lt(abs_a, limit),
+ pselect(pcmp_lt(a, pzero(a)), pnegate(r), r), a);
+ return r;
+}
+
+template<> EIGEN_STRONG_INLINE Packet2f print(const Packet2f& a) {
+ // Adds and subtracts signum(a) * 2^23 to force rounding.
+ const Packet2f limit = pset1<Packet2f>(static_cast<float>(1<<23));
+ const Packet2f abs_a = pabs(a);
+ Packet2f r = padd(abs_a, limit);
+ // Don't compile-away addition and subtraction.
+ EIGEN_OPTIMIZATION_BARRIER(r);
+ r = psub(r, limit);
+ // If greater than limit, simply return a. Otherwise, account for sign.
+ r = pselect(pcmp_lt(abs_a, limit),
+ pselect(pcmp_lt(a, pzero(a)), pnegate(r), r), a);
+ return r;
+}
+
+template<> EIGEN_STRONG_INLINE Packet4f pfloor<Packet4f>(const Packet4f& a)
{
- int32x2_t a_lo, a_hi, prod;
+ const Packet4f cst_1 = pset1<Packet4f>(1.0f);
+ Packet4f tmp = print<Packet4f>(a);
+ // If greater, subtract one.
+ Packet4f mask = pcmp_lt(a, tmp);
+ mask = pand(mask, cst_1);
+ return psub(tmp, mask);
+}
- // Get a_lo = |a1|a2| and a_hi = |a3|a4|
- a_lo = vget_low_s32(a);
- a_hi = vget_high_s32(a);
- // Get the product of a_lo * a_hi -> |a1*a3|a2*a4|
- prod = vmul_s32(a_lo, a_hi);
- // Multiply prod with its swapped value |a2*a4|a1*a3|
- prod = vmul_s32(prod, vrev64_s32(prod));
+template<> EIGEN_STRONG_INLINE Packet2f pfloor<Packet2f>(const Packet2f& a)
+{
+ const Packet2f cst_1 = pset1<Packet2f>(1.0f);
+ Packet2f tmp = print<Packet2f>(a);
+ // If greater, subtract one.
+ Packet2f mask = pcmp_lt(a, tmp);
+ mask = pand(mask, cst_1);
+ return psub(tmp, mask);
+}
- return vget_lane_s32(prod, 0);
+template<> EIGEN_STRONG_INLINE Packet4f pceil<Packet4f>(const Packet4f& a)
+{
+ const Packet4f cst_1 = pset1<Packet4f>(1.0f);
+ Packet4f tmp = print<Packet4f>(a);
+ // If smaller, add one.
+ Packet4f mask = pcmp_lt(tmp, a);
+ mask = pand(mask, cst_1);
+ return padd(tmp, mask);
}
-// min
-template<> EIGEN_STRONG_INLINE float predux_min<Packet4f>(const Packet4f& a)
+template<> EIGEN_STRONG_INLINE Packet2f pceil<Packet2f>(const Packet2f& a)
+{
+ const Packet2f cst_1 = pset1<Packet2f>(1.0);
+ Packet2f tmp = print<Packet2f>(a);
+ // If smaller, add one.
+ Packet2f mask = pcmp_lt(tmp, a);
+ mask = pand(mask, cst_1);
+ return padd(tmp, mask);
+}
+
+#endif
+
+/**
+ * Computes the integer square root
+ * @remarks The calculation is performed using an algorithm which iterates through each binary digit of the result
+ * and tests whether setting that digit to 1 would cause the square of the value to be greater than the argument
+ * value. The algorithm is described in detail here: http://ww1.microchip.com/downloads/en/AppNotes/91040a.pdf .
+ */
+template<> EIGEN_STRONG_INLINE Packet4uc psqrt(const Packet4uc& a) {
+ uint8x8_t x = vreinterpret_u8_u32(vdup_n_u32(a));
+ uint8x8_t res = vdup_n_u8(0);
+ uint8x8_t add = vdup_n_u8(0x8);
+ for (int i = 0; i < 4; i++)
+ {
+ const uint8x8_t temp = vorr_u8(res, add);
+ res = vbsl_u8(vcge_u8(x, vmul_u8(temp, temp)), temp, res);
+ add = vshr_n_u8(add, 1);
+ }
+ return vget_lane_u32(vreinterpret_u32_u8(res), 0);
+}
+/// @copydoc Eigen::internal::psqrt(const Packet4uc& a)
+template<> EIGEN_STRONG_INLINE Packet8uc psqrt(const Packet8uc& a) {
+ uint8x8_t res = vdup_n_u8(0);
+ uint8x8_t add = vdup_n_u8(0x8);
+ for (int i = 0; i < 4; i++)
+ {
+ const uint8x8_t temp = vorr_u8(res, add);
+ res = vbsl_u8(vcge_u8(a, vmul_u8(temp, temp)), temp, res);
+ add = vshr_n_u8(add, 1);
+ }
+ return res;
+}
+/// @copydoc Eigen::internal::psqrt(const Packet4uc& a)
+template<> EIGEN_STRONG_INLINE Packet16uc psqrt(const Packet16uc& a) {
+ uint8x16_t res = vdupq_n_u8(0);
+ uint8x16_t add = vdupq_n_u8(0x8);
+ for (int i = 0; i < 4; i++)
+ {
+ const uint8x16_t temp = vorrq_u8(res, add);
+ res = vbslq_u8(vcgeq_u8(a, vmulq_u8(temp, temp)), temp, res);
+ add = vshrq_n_u8(add, 1);
+ }
+ return res;
+}
+/// @copydoc Eigen::internal::psqrt(const Packet4uc& a)
+template<> EIGEN_STRONG_INLINE Packet4us psqrt(const Packet4us& a) {
+ uint16x4_t res = vdup_n_u16(0);
+ uint16x4_t add = vdup_n_u16(0x80);
+ for (int i = 0; i < 8; i++)
+ {
+ const uint16x4_t temp = vorr_u16(res, add);
+ res = vbsl_u16(vcge_u16(a, vmul_u16(temp, temp)), temp, res);
+ add = vshr_n_u16(add, 1);
+ }
+ return res;
+}
+/// @copydoc Eigen::internal::psqrt(const Packet4uc& a)
+template<> EIGEN_STRONG_INLINE Packet8us psqrt(const Packet8us& a) {
+ uint16x8_t res = vdupq_n_u16(0);
+ uint16x8_t add = vdupq_n_u16(0x80);
+ for (int i = 0; i < 8; i++)
+ {
+ const uint16x8_t temp = vorrq_u16(res, add);
+ res = vbslq_u16(vcgeq_u16(a, vmulq_u16(temp, temp)), temp, res);
+ add = vshrq_n_u16(add, 1);
+ }
+ return res;
+}
+/// @copydoc Eigen::internal::psqrt(const Packet4uc& a)
+template<> EIGEN_STRONG_INLINE Packet2ui psqrt(const Packet2ui& a) {
+ uint32x2_t res = vdup_n_u32(0);
+ uint32x2_t add = vdup_n_u32(0x8000);
+ for (int i = 0; i < 16; i++)
+ {
+ const uint32x2_t temp = vorr_u32(res, add);
+ res = vbsl_u32(vcge_u32(a, vmul_u32(temp, temp)), temp, res);
+ add = vshr_n_u32(add, 1);
+ }
+ return res;
+}
+/// @copydoc Eigen::internal::psqrt(const Packet4uc& a)
+template<> EIGEN_STRONG_INLINE Packet4ui psqrt(const Packet4ui& a) {
+ uint32x4_t res = vdupq_n_u32(0);
+ uint32x4_t add = vdupq_n_u32(0x8000);
+ for (int i = 0; i < 16; i++)
+ {
+ const uint32x4_t temp = vorrq_u32(res, add);
+ res = vbslq_u32(vcgeq_u32(a, vmulq_u32(temp, temp)), temp, res);
+ add = vshrq_n_u32(add, 1);
+ }
+ return res;
+}
+
+template<> EIGEN_STRONG_INLINE Packet4f prsqrt(const Packet4f& a) {
+ // Compute approximate reciprocal sqrt.
+ Packet4f x = vrsqrteq_f32(a);
+ // Do Newton iterations for 1/sqrt(x).
+ x = vmulq_f32(vrsqrtsq_f32(vmulq_f32(a, x), x), x);
+ x = vmulq_f32(vrsqrtsq_f32(vmulq_f32(a, x), x), x);
+ const Packet4f infinity = pset1<Packet4f>(NumTraits<float>::infinity());
+ return pselect(pcmp_eq(a, pzero(a)), infinity, x);
+}
+
+template<> EIGEN_STRONG_INLINE Packet2f prsqrt(const Packet2f& a) {
+ // Compute approximate reciprocal sqrt.
+ Packet2f x = vrsqrte_f32(a);
+ // Do Newton iterations for 1/sqrt(x).
+ x = vmul_f32(vrsqrts_f32(vmul_f32(a, x), x), x);
+ x = vmul_f32(vrsqrts_f32(vmul_f32(a, x), x), x);
+ const Packet2f infinity = pset1<Packet2f>(NumTraits<float>::infinity());
+ return pselect(pcmp_eq(a, pzero(a)), infinity, x);
+}
+
+// Unfortunately vsqrt_f32 is only available for A64.
+#if EIGEN_ARCH_ARM64
+template<> EIGEN_STRONG_INLINE Packet4f psqrt(const Packet4f& _x){return vsqrtq_f32(_x);}
+template<> EIGEN_STRONG_INLINE Packet2f psqrt(const Packet2f& _x){return vsqrt_f32(_x); }
+#else
+template<> EIGEN_STRONG_INLINE Packet4f psqrt(const Packet4f& a) {
+ const Packet4f infinity = pset1<Packet4f>(NumTraits<float>::infinity());
+ const Packet4f is_zero_or_inf = por(pcmp_eq(a, pzero(a)), pcmp_eq(a, infinity));
+ return pselect(is_zero_or_inf, a, pmul(a, prsqrt(a)));
+}
+template<> EIGEN_STRONG_INLINE Packet2f psqrt(const Packet2f& a) {
+ const Packet2f infinity = pset1<Packet2f>(NumTraits<float>::infinity());
+ const Packet2f is_zero_or_inf = por(pcmp_eq(a, pzero(a)), pcmp_eq(a, infinity));
+ return pselect(is_zero_or_inf, a, pmul(a, prsqrt(a)));
+}
+#endif
+
+//---------- bfloat16 ----------
+// TODO: Add support for native armv8.6-a bfloat16_t
+
+// TODO: Guard if we have native bfloat16 support
+typedef eigen_packet_wrapper<uint16x4_t, 19> Packet4bf;
+
+template<> struct is_arithmetic<Packet4bf> { enum { value = true }; };
+
+template<> struct packet_traits<bfloat16> : default_packet_traits
{
- float32x2_t a_lo, a_hi, min;
+ typedef Packet4bf type;
+ typedef Packet4bf half;
+ enum
+ {
+ Vectorizable = 1,
+ AlignedOnScalar = 1,
+ size = 4,
+ HasHalfPacket = 0,
- a_lo = vget_low_f32(a);
- a_hi = vget_high_f32(a);
- min = vpmin_f32(a_lo, a_hi);
- min = vpmin_f32(min, min);
+ HasCmp = 1,
+ HasAdd = 1,
+ HasSub = 1,
+ HasShift = 1,
+ HasMul = 1,
+ HasNegate = 1,
+ HasAbs = 1,
+ HasArg = 0,
+ HasAbs2 = 1,
+ HasAbsDiff = 1,
+ HasMin = 1,
+ HasMax = 1,
+ HasConj = 1,
+ HasSetLinear = 0,
+ HasBlend = 0,
+ HasDiv = 1,
+ HasFloor = 1,
+ HasCeil = 1,
+ HasRint = 1,
- return vget_lane_f32(min, 0);
+ HasSin = EIGEN_FAST_MATH,
+ HasCos = EIGEN_FAST_MATH,
+ HasLog = 1,
+ HasExp = 1,
+ HasSqrt = 0,
+ HasTanh = EIGEN_FAST_MATH,
+ HasErf = EIGEN_FAST_MATH,
+ HasBessel = 0, // Issues with accuracy.
+ HasNdtri = 0
+ };
+};
+
+template<> struct unpacket_traits<Packet4bf>
+{
+ typedef bfloat16 type;
+ typedef Packet4bf half;
+ enum
+ {
+ size = 4,
+ alignment = Aligned16,
+ vectorizable = true,
+ masked_load_available = false,
+ masked_store_available = false
+ };
+};
+
+namespace detail {
+template<>
+EIGEN_ALWAYS_INLINE void zip_in_place<Packet4bf>(Packet4bf& p1, Packet4bf& p2) {
+ const uint16x4x2_t tmp = vzip_u16(p1, p2);
+ p1 = tmp.val[0];
+ p2 = tmp.val[1];
}
+} // namespace detail
-template<> EIGEN_STRONG_INLINE int32_t predux_min<Packet4i>(const Packet4i& a)
+EIGEN_STRONG_INLINE Packet4bf F32ToBf16(const Packet4f& p)
{
- int32x2_t a_lo, a_hi, min;
+ // See the scalar implemention in BFloat16.h for a comprehensible explanation
+ // of this fast rounding algorithm
+ Packet4ui input = reinterpret_cast<Packet4ui>(p);
- a_lo = vget_low_s32(a);
- a_hi = vget_high_s32(a);
- min = vpmin_s32(a_lo, a_hi);
- min = vpmin_s32(min, min);
-
- return vget_lane_s32(min, 0);
+ // lsb = (input >> 16) & 1
+ Packet4ui lsb = vandq_u32(vshrq_n_u32(input, 16), vdupq_n_u32(1));
+
+ // rounding_bias = 0x7fff + lsb
+ Packet4ui rounding_bias = vaddq_u32(lsb, vdupq_n_u32(0x7fff));
+
+ // input += rounding_bias
+ input = vaddq_u32(input, rounding_bias);
+
+ // input = input >> 16
+ input = vshrq_n_u32(input, 16);
+
+ // Replace float-nans by bfloat16-nans, that is 0x7fc0
+ const Packet4ui bf16_nan = vdupq_n_u32(0x7fc0);
+ const Packet4ui mask = vceqq_f32(p, p);
+ input = vbslq_u32(mask, input, bf16_nan);
+
+ // output = static_cast<uint16_t>(input)
+ return vmovn_u32(input);
}
-// max
-template<> EIGEN_STRONG_INLINE float predux_max<Packet4f>(const Packet4f& a)
+EIGEN_STRONG_INLINE Packet4f Bf16ToF32(const Packet4bf& p)
{
- float32x2_t a_lo, a_hi, max;
+ return reinterpret_cast<Packet4f>(vshlq_n_u32(vmovl_u16(p), 16));
+}
- a_lo = vget_low_f32(a);
- a_hi = vget_high_f32(a);
- max = vpmax_f32(a_lo, a_hi);
- max = vpmax_f32(max, max);
+EIGEN_STRONG_INLINE Packet4bf F32MaskToBf16Mask(const Packet4f& p) {
+ return vmovn_u32(vreinterpretq_u32_f32(p));
+}
- return vget_lane_f32(max, 0);
+template<> EIGEN_STRONG_INLINE Packet4bf pset1<Packet4bf>(const bfloat16& from) {
+ return pset1<Packet4us>(from.value);
}
-template<> EIGEN_STRONG_INLINE int32_t predux_max<Packet4i>(const Packet4i& a)
+template<> EIGEN_STRONG_INLINE bfloat16 pfirst<Packet4bf>(const Packet4bf& from) {
+ return bfloat16_impl::raw_uint16_to_bfloat16(static_cast<uint16_t>(pfirst<Packet4us>(from)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet4bf pload<Packet4bf>(const bfloat16* from)
+{
+ return pload<Packet4us>(reinterpret_cast<const uint16_t*>(from));
+}
+
+template<> EIGEN_STRONG_INLINE Packet4bf ploadu<Packet4bf>(const bfloat16* from)
+{
+ return ploadu<Packet4us>(reinterpret_cast<const uint16_t*>(from));
+}
+
+template<> EIGEN_STRONG_INLINE void pstore<bfloat16>(bfloat16* to, const Packet4bf& from)
+{
+ EIGEN_DEBUG_ALIGNED_STORE vst1_u16(reinterpret_cast<uint16_t*>(to), from);
+}
+
+template<> EIGEN_STRONG_INLINE void pstoreu<bfloat16>(bfloat16* to, const Packet4bf& from)
+{
+ EIGEN_DEBUG_UNALIGNED_STORE vst1_u16(reinterpret_cast<uint16_t*>(to), from);
+}
+
+template<> EIGEN_STRONG_INLINE Packet4bf ploaddup<Packet4bf>(const bfloat16* from)
+{
+ return ploaddup<Packet4us>(reinterpret_cast<const uint16_t*>(from));
+}
+
+template <> EIGEN_STRONG_INLINE Packet4bf pabs(const Packet4bf& a) {
+ return F32ToBf16(pabs<Packet4f>(Bf16ToF32(a)));
+}
+
+template <> EIGEN_STRONG_INLINE Packet4bf pmin<PropagateNumbers, Packet4bf>(const Packet4bf &a,
+ const Packet4bf &b)
+{
+ return F32ToBf16(pmin<PropagateNumbers, Packet4f>(Bf16ToF32(a), Bf16ToF32(b)));
+}
+template <> EIGEN_STRONG_INLINE Packet4bf pmin<PropagateNaN, Packet4bf>(const Packet4bf &a,
+ const Packet4bf &b)
+{
+ return F32ToBf16(pmin<PropagateNaN, Packet4f>(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+template <> EIGEN_STRONG_INLINE Packet4bf pmin<Packet4bf>(const Packet4bf &a,
+ const Packet4bf &b)
+{
+ return F32ToBf16(pmin<Packet4f>(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+template <> EIGEN_STRONG_INLINE Packet4bf pmax<PropagateNumbers, Packet4bf>(const Packet4bf &a,
+ const Packet4bf &b)
+{
+ return F32ToBf16(pmax<PropagateNumbers, Packet4f>(Bf16ToF32(a), Bf16ToF32(b)));
+}
+template <> EIGEN_STRONG_INLINE Packet4bf pmax<PropagateNaN, Packet4bf>(const Packet4bf &a,
+ const Packet4bf &b)
+{
+ return F32ToBf16(pmax<PropagateNaN, Packet4f>(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+template <> EIGEN_STRONG_INLINE Packet4bf pmax<Packet4bf>(const Packet4bf &a,
+ const Packet4bf &b)
+{
+ return F32ToBf16(pmax<Packet4f>(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet4bf plset<Packet4bf>(const bfloat16& a)
+{
+ return F32ToBf16(plset<Packet4f>(static_cast<float>(a)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet4bf por(const Packet4bf& a,const Packet4bf& b) {
+ return por<Packet4us>(a, b);
+}
+
+template<> EIGEN_STRONG_INLINE Packet4bf pxor(const Packet4bf& a,const Packet4bf& b) {
+ return pxor<Packet4us>(a, b);
+}
+
+template<> EIGEN_STRONG_INLINE Packet4bf pand(const Packet4bf& a,const Packet4bf& b) {
+ return pand<Packet4us>(a, b);
+}
+
+template<> EIGEN_STRONG_INLINE Packet4bf pandnot(const Packet4bf& a,const Packet4bf& b) {
+ return pandnot<Packet4us>(a, b);
+}
+
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4bf pselect(const Packet4bf& mask, const Packet4bf& a,
+ const Packet4bf& b)
+{
+ return pselect<Packet4us>(mask, a, b);
+}
+
+template<> EIGEN_STRONG_INLINE Packet4bf print<Packet4bf>(const Packet4bf& a)
+{
+ return F32ToBf16(print<Packet4f>(Bf16ToF32(a)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet4bf pfloor<Packet4bf>(const Packet4bf& a)
+{
+ return F32ToBf16(pfloor<Packet4f>(Bf16ToF32(a)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet4bf pceil<Packet4bf>(const Packet4bf& a)
+{
+ return F32ToBf16(pceil<Packet4f>(Bf16ToF32(a)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet4bf pconj(const Packet4bf& a) { return a; }
+
+template<> EIGEN_STRONG_INLINE Packet4bf padd<Packet4bf>(const Packet4bf& a, const Packet4bf& b) {
+ return F32ToBf16(padd<Packet4f>(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet4bf psub<Packet4bf>(const Packet4bf& a, const Packet4bf& b) {
+ return F32ToBf16(psub<Packet4f>(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet4bf pmul<Packet4bf>(const Packet4bf& a, const Packet4bf& b) {
+ return F32ToBf16(pmul<Packet4f>(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet4bf pdiv<Packet4bf>(const Packet4bf& a, const Packet4bf& b) {
+ return F32ToBf16(pdiv<Packet4f>(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+template<>
+EIGEN_STRONG_INLINE Packet4bf pgather<bfloat16, Packet4bf>(const bfloat16* from, Index stride)
+{
+ return pgather<uint16_t, Packet4us>(reinterpret_cast<const uint16_t*>(from), stride);
+}
+
+template<>
+EIGEN_STRONG_INLINE void pscatter<bfloat16, Packet4bf>(bfloat16* to, const Packet4bf& from, Index stride)
+{
+ pscatter<uint16_t, Packet4us>(reinterpret_cast<uint16_t*>(to), from, stride);
+}
+
+template<> EIGEN_STRONG_INLINE bfloat16 predux<Packet4bf>(const Packet4bf& a)
+{
+ return static_cast<bfloat16>(predux<Packet4f>(Bf16ToF32(a)));
+}
+
+template<> EIGEN_STRONG_INLINE bfloat16 predux_max<Packet4bf>(const Packet4bf& a)
+{
+ return static_cast<bfloat16>(predux_max<Packet4f>(Bf16ToF32(a)));
+}
+
+template<> EIGEN_STRONG_INLINE bfloat16 predux_min<Packet4bf>(const Packet4bf& a)
+{
+ return static_cast<bfloat16>(predux_min<Packet4f>(Bf16ToF32(a)));
+}
+
+template<> EIGEN_STRONG_INLINE bfloat16 predux_mul<Packet4bf>(const Packet4bf& a)
+{
+ return static_cast<bfloat16>(predux_mul<Packet4f>(Bf16ToF32(a)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet4bf preverse<Packet4bf>(const Packet4bf& a)
+{
+ return preverse<Packet4us>(a);
+}
+
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet4bf, 4>& kernel)
+{
+ detail::ptranspose_impl(kernel);
+}
+
+template<> EIGEN_STRONG_INLINE Packet4bf pabsdiff<Packet4bf>(const Packet4bf& a, const Packet4bf& b)
{
- int32x2_t a_lo, a_hi, max;
-
- a_lo = vget_low_s32(a);
- a_hi = vget_high_s32(a);
- max = vpmax_s32(a_lo, a_hi);
- max = vpmax_s32(max, max);
-
- return vget_lane_s32(max, 0);
-}
-
-// this PALIGN_NEON business is to work around a bug in LLVM Clang 3.0 causing incorrect compilation errors,
-// see bug 347 and this LLVM bug: http://llvm.org/bugs/show_bug.cgi?id=11074
-#define PALIGN_NEON(Offset,Type,Command) \
-template<>\
-struct palign_impl<Offset,Type>\
-{\
- EIGEN_STRONG_INLINE static void run(Type& first, const Type& second)\
- {\
- if (Offset!=0)\
- first = Command(first, second, Offset);\
- }\
-};\
-
-PALIGN_NEON(0,Packet4f,vextq_f32)
-PALIGN_NEON(1,Packet4f,vextq_f32)
-PALIGN_NEON(2,Packet4f,vextq_f32)
-PALIGN_NEON(3,Packet4f,vextq_f32)
-PALIGN_NEON(0,Packet4i,vextq_s32)
-PALIGN_NEON(1,Packet4i,vextq_s32)
-PALIGN_NEON(2,Packet4i,vextq_s32)
-PALIGN_NEON(3,Packet4i,vextq_s32)
-
-#undef PALIGN_NEON
-
-EIGEN_DEVICE_FUNC inline void
-ptranspose(PacketBlock<Packet4f,4>& kernel) {
- float32x4x2_t tmp1 = vzipq_f32(kernel.packet[0], kernel.packet[1]);
- float32x4x2_t tmp2 = vzipq_f32(kernel.packet[2], kernel.packet[3]);
-
- kernel.packet[0] = vcombine_f32(vget_low_f32(tmp1.val[0]), vget_low_f32(tmp2.val[0]));
- kernel.packet[1] = vcombine_f32(vget_high_f32(tmp1.val[0]), vget_high_f32(tmp2.val[0]));
- kernel.packet[2] = vcombine_f32(vget_low_f32(tmp1.val[1]), vget_low_f32(tmp2.val[1]));
- kernel.packet[3] = vcombine_f32(vget_high_f32(tmp1.val[1]), vget_high_f32(tmp2.val[1]));
-}
-
-EIGEN_DEVICE_FUNC inline void
-ptranspose(PacketBlock<Packet4i,4>& kernel) {
- int32x4x2_t tmp1 = vzipq_s32(kernel.packet[0], kernel.packet[1]);
- int32x4x2_t tmp2 = vzipq_s32(kernel.packet[2], kernel.packet[3]);
- kernel.packet[0] = vcombine_s32(vget_low_s32(tmp1.val[0]), vget_low_s32(tmp2.val[0]));
- kernel.packet[1] = vcombine_s32(vget_high_s32(tmp1.val[0]), vget_high_s32(tmp2.val[0]));
- kernel.packet[2] = vcombine_s32(vget_low_s32(tmp1.val[1]), vget_low_s32(tmp2.val[1]));
- kernel.packet[3] = vcombine_s32(vget_high_s32(tmp1.val[1]), vget_high_s32(tmp2.val[1]));
+ return F32ToBf16(pabsdiff<Packet4f>(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet4bf pcmp_eq<Packet4bf>(const Packet4bf& a, const Packet4bf& b)
+{
+ return F32MaskToBf16Mask(pcmp_eq<Packet4f>(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet4bf pcmp_lt<Packet4bf>(const Packet4bf& a, const Packet4bf& b)
+{
+ return F32MaskToBf16Mask(pcmp_lt<Packet4f>(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet4bf pcmp_lt_or_nan<Packet4bf>(const Packet4bf& a, const Packet4bf& b)
+{
+ return F32MaskToBf16Mask(pcmp_lt_or_nan<Packet4f>(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet4bf pcmp_le<Packet4bf>(const Packet4bf& a, const Packet4bf& b)
+{
+ return F32MaskToBf16Mask(pcmp_le<Packet4f>(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet4bf pnegate<Packet4bf>(const Packet4bf& a)
+{
+ return pxor<Packet4us>(a, pset1<Packet4us>(static_cast<uint16_t>(0x8000)));
}
//---------- double ----------
@@ -540,55 +3642,115 @@ ptranspose(PacketBlock<Packet4i,4>& kernel) {
// Defining these functions as templates ensures that if these intrinsics are
// already defined in arm_neon.h, then our workaround doesn't cause a conflict
// and has lower priority in overload resolution.
-template <typename T>
-uint64x2_t vreinterpretq_u64_f64(T a)
+template <typename T> uint64x2_t vreinterpretq_u64_f64(T a) { return (uint64x2_t) a; }
+
+template <typename T> float64x2_t vreinterpretq_f64_u64(T a) { return (float64x2_t) a; }
+
+typedef float64x2_t Packet2d;
+typedef float64x1_t Packet1d;
+
+// fuctionally equivalent to _mm_shuffle_pd in SSE (i.e. shuffle(m, n, mask) equals _mm_shuffle_pd(m,n,mask))
+// Currently used in LU/arch/InverseSize4.h to enable a shared implementation
+// for fast inversion of matrices of size 4.
+EIGEN_STRONG_INLINE Packet2d shuffle(const Packet2d& m, const Packet2d& n, int mask)
{
- return (uint64x2_t) a;
+ const double* a = reinterpret_cast<const double*>(&m);
+ const double* b = reinterpret_cast<const double*>(&n);
+ Packet2d res = {*(a + (mask & 1)), *(b + ((mask >> 1) & 1))};
+ return res;
}
-template <typename T>
-float64x2_t vreinterpretq_f64_u64(T a)
+EIGEN_STRONG_INLINE Packet2d vec2d_swizzle2(const Packet2d& a, const Packet2d& b, int mask)
{
- return (float64x2_t) a;
+ return shuffle(a, b, mask);
}
-
-typedef float64x2_t Packet2d;
-typedef float64x1_t Packet1d;
+EIGEN_STRONG_INLINE Packet2d vec2d_unpacklo(const Packet2d& a,const Packet2d& b)
+{
+ return shuffle(a, b, 0);
+}
+EIGEN_STRONG_INLINE Packet2d vec2d_unpackhi(const Packet2d& a,const Packet2d& b)
+{
+ return shuffle(a, b, 3);
+}
+#define vec2d_duplane(a, p) \
+ vdupq_laneq_f64(a, p)
template<> struct packet_traits<double> : default_packet_traits
{
typedef Packet2d type;
typedef Packet2d half;
- enum {
+ enum
+ {
Vectorizable = 1,
AlignedOnScalar = 1,
size = 2,
- HasHalfPacket=0,
-
- HasDiv = 1,
- // FIXME check the Has*
+ HasHalfPacket = 0,
+
+ HasCmp = 1,
+ HasAdd = 1,
+ HasSub = 1,
+ HasShift = 1,
+ HasMul = 1,
+ HasNegate = 1,
+ HasAbs = 1,
+ HasArg = 0,
+ HasAbs2 = 1,
+ HasAbsDiff = 1,
+ HasMin = 1,
+ HasMax = 1,
+ HasConj = 1,
+ HasSetLinear = 0,
+ HasBlend = 0,
+
+ HasDiv = 1,
+ HasFloor = 1,
+ HasCeil = 1,
+ HasRint = 1,
+
HasSin = 0,
HasCos = 0,
- HasLog = 0,
- HasExp = 0,
- HasSqrt = 0
+ HasLog = 1,
+ HasExp = 1,
+ HasSqrt = 1,
+ HasRsqrt = 1,
+ HasTanh = 0,
+ HasErf = 0
};
};
-template<> struct unpacket_traits<Packet2d> { typedef double type; enum {size=2, alignment=Aligned16}; typedef Packet2d half; };
+template<> struct unpacket_traits<Packet2d>
+{
+ typedef double type;
+ typedef Packet2d half;
+ typedef Packet2l integer_packet;
+ enum
+ {
+ size = 2,
+ alignment = Aligned16,
+ vectorizable = true,
+ masked_load_available = false,
+ masked_store_available = false
+ };
+};
template<> EIGEN_STRONG_INLINE Packet2d pset1<Packet2d>(const double& from) { return vdupq_n_f64(from); }
template<> EIGEN_STRONG_INLINE Packet2d plset<Packet2d>(const double& a)
{
- const double countdown_raw[] = {0.0,1.0};
- const Packet2d countdown = vld1q_f64(countdown_raw);
- return vaddq_f64(pset1<Packet2d>(a), countdown);
+ const double c[] = {0.0,1.0};
+ return vaddq_f64(pset1<Packet2d>(a), vld1q_f64(c));
}
+
template<> EIGEN_STRONG_INLINE Packet2d padd<Packet2d>(const Packet2d& a, const Packet2d& b) { return vaddq_f64(a,b); }
template<> EIGEN_STRONG_INLINE Packet2d psub<Packet2d>(const Packet2d& a, const Packet2d& b) { return vsubq_f64(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2d pxor<Packet2d>(const Packet2d& , const Packet2d& );
+template<> EIGEN_STRONG_INLINE Packet2d paddsub<Packet2d>(const Packet2d& a, const Packet2d& b){
+ const Packet2d mask = {numext::bit_cast<double>(0x8000000000000000ull),0.0};
+ return padd(a, pxor(mask, b));
+}
+
template<> EIGEN_STRONG_INLINE Packet2d pnegate(const Packet2d& a) { return vnegq_f64(a); }
template<> EIGEN_STRONG_INLINE Packet2d pconj(const Packet2d& a) { return a; }
@@ -599,128 +3761,824 @@ template<> EIGEN_STRONG_INLINE Packet2d pdiv<Packet2d>(const Packet2d& a, const
#ifdef __ARM_FEATURE_FMA
// See bug 936. See above comment about FMA for float.
-template<> EIGEN_STRONG_INLINE Packet2d pmadd(const Packet2d& a, const Packet2d& b, const Packet2d& c) { return vfmaq_f64(c,a,b); }
+template<> EIGEN_STRONG_INLINE Packet2d pmadd(const Packet2d& a, const Packet2d& b, const Packet2d& c)
+{ return vfmaq_f64(c,a,b); }
#else
-template<> EIGEN_STRONG_INLINE Packet2d pmadd(const Packet2d& a, const Packet2d& b, const Packet2d& c) { return vmlaq_f64(c,a,b); }
+template<> EIGEN_STRONG_INLINE Packet2d pmadd(const Packet2d& a, const Packet2d& b, const Packet2d& c)
+{ return vmlaq_f64(c,a,b); }
#endif
template<> EIGEN_STRONG_INLINE Packet2d pmin<Packet2d>(const Packet2d& a, const Packet2d& b) { return vminq_f64(a,b); }
+#ifdef __ARM_FEATURE_NUMERIC_MAXMIN
+// numeric max and min are only available if ARM_FEATURE_NUMERIC_MAXMIN is defined (which can only be the case for Armv8 systems).
+template<> EIGEN_STRONG_INLINE Packet2d pmin<PropagateNumbers, Packet2d>(const Packet2d& a, const Packet2d& b) { return vminnmq_f64(a, b); }
+template<> EIGEN_STRONG_INLINE Packet2d pmax<PropagateNumbers, Packet2d>(const Packet2d& a, const Packet2d& b) { return vmaxnmq_f64(a, b); }
+
+#endif
+
+template<> EIGEN_STRONG_INLINE Packet2d pmin<PropagateNaN, Packet2d>(const Packet2d& a, const Packet2d& b) { return pmin<Packet2d>(a, b); }
+
template<> EIGEN_STRONG_INLINE Packet2d pmax<Packet2d>(const Packet2d& a, const Packet2d& b) { return vmaxq_f64(a,b); }
+
+template<> EIGEN_STRONG_INLINE Packet2d pmax<PropagateNaN, Packet2d>(const Packet2d& a, const Packet2d& b) { return pmax<Packet2d>(a, b); }
+
// Logical Operations are not supported for float, so we have to reinterpret casts using NEON intrinsics
template<> EIGEN_STRONG_INLINE Packet2d pand<Packet2d>(const Packet2d& a, const Packet2d& b)
-{
- return vreinterpretq_f64_u64(vandq_u64(vreinterpretq_u64_f64(a),vreinterpretq_u64_f64(b)));
-}
+{ return vreinterpretq_f64_u64(vandq_u64(vreinterpretq_u64_f64(a),vreinterpretq_u64_f64(b))); }
template<> EIGEN_STRONG_INLINE Packet2d por<Packet2d>(const Packet2d& a, const Packet2d& b)
-{
- return vreinterpretq_f64_u64(vorrq_u64(vreinterpretq_u64_f64(a),vreinterpretq_u64_f64(b)));
-}
+{ return vreinterpretq_f64_u64(vorrq_u64(vreinterpretq_u64_f64(a),vreinterpretq_u64_f64(b))); }
template<> EIGEN_STRONG_INLINE Packet2d pxor<Packet2d>(const Packet2d& a, const Packet2d& b)
-{
- return vreinterpretq_f64_u64(veorq_u64(vreinterpretq_u64_f64(a),vreinterpretq_u64_f64(b)));
-}
+{ return vreinterpretq_f64_u64(veorq_u64(vreinterpretq_u64_f64(a),vreinterpretq_u64_f64(b))); }
template<> EIGEN_STRONG_INLINE Packet2d pandnot<Packet2d>(const Packet2d& a, const Packet2d& b)
-{
- return vreinterpretq_f64_u64(vbicq_u64(vreinterpretq_u64_f64(a),vreinterpretq_u64_f64(b)));
-}
+{ return vreinterpretq_f64_u64(vbicq_u64(vreinterpretq_u64_f64(a),vreinterpretq_u64_f64(b))); }
-template<> EIGEN_STRONG_INLINE Packet2d pload<Packet2d>(const double* from) { EIGEN_DEBUG_ALIGNED_LOAD return vld1q_f64(from); }
+template<> EIGEN_STRONG_INLINE Packet2d pcmp_le(const Packet2d& a, const Packet2d& b)
+{ return vreinterpretq_f64_u64(vcleq_f64(a,b)); }
-template<> EIGEN_STRONG_INLINE Packet2d ploadu<Packet2d>(const double* from) { EIGEN_DEBUG_UNALIGNED_LOAD return vld1q_f64(from); }
+template<> EIGEN_STRONG_INLINE Packet2d pcmp_lt(const Packet2d& a, const Packet2d& b)
+{ return vreinterpretq_f64_u64(vcltq_f64(a,b)); }
-template<> EIGEN_STRONG_INLINE Packet2d ploaddup<Packet2d>(const double* from)
-{
- return vld1q_dup_f64(from);
-}
-template<> EIGEN_STRONG_INLINE void pstore<double>(double* to, const Packet2d& from) { EIGEN_DEBUG_ALIGNED_STORE vst1q_f64(to, from); }
+template<> EIGEN_STRONG_INLINE Packet2d pcmp_lt_or_nan(const Packet2d& a, const Packet2d& b)
+{ return vreinterpretq_f64_u32(vmvnq_u32(vreinterpretq_u32_u64(vcgeq_f64(a,b)))); }
+
+template<> EIGEN_STRONG_INLINE Packet2d pcmp_eq(const Packet2d& a, const Packet2d& b)
+{ return vreinterpretq_f64_u64(vceqq_f64(a,b)); }
+
+template<> EIGEN_STRONG_INLINE Packet2d pload<Packet2d>(const double* from)
+{ EIGEN_DEBUG_ALIGNED_LOAD return vld1q_f64(from); }
-template<> EIGEN_STRONG_INLINE void pstoreu<double>(double* to, const Packet2d& from) { EIGEN_DEBUG_UNALIGNED_STORE vst1q_f64(to, from); }
+template<> EIGEN_STRONG_INLINE Packet2d ploadu<Packet2d>(const double* from)
+{ EIGEN_DEBUG_UNALIGNED_LOAD return vld1q_f64(from); }
-template<> EIGEN_DEVICE_FUNC inline Packet2d pgather<double, Packet2d>(const double* from, Index stride)
+template<> EIGEN_STRONG_INLINE Packet2d ploaddup<Packet2d>(const double* from) { return vld1q_dup_f64(from); }
+template<> EIGEN_STRONG_INLINE void pstore<double>(double* to, const Packet2d& from)
+{ EIGEN_DEBUG_ALIGNED_STORE vst1q_f64(to,from); }
+
+template<> EIGEN_STRONG_INLINE void pstoreu<double>(double* to, const Packet2d& from)
+{ EIGEN_DEBUG_UNALIGNED_STORE vst1q_f64(to,from); }
+
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet2d pgather<double, Packet2d>(const double* from, Index stride)
{
Packet2d res = pset1<Packet2d>(0.0);
- res = vsetq_lane_f64(from[0*stride], res, 0);
- res = vsetq_lane_f64(from[1*stride], res, 1);
+ res = vld1q_lane_f64(from + 0*stride, res, 0);
+ res = vld1q_lane_f64(from + 1*stride, res, 1);
return res;
}
-template<> EIGEN_DEVICE_FUNC inline void pscatter<double, Packet2d>(double* to, const Packet2d& from, Index stride)
+
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pscatter<double, Packet2d>(double* to, const Packet2d& from, Index stride)
{
- to[stride*0] = vgetq_lane_f64(from, 0);
- to[stride*1] = vgetq_lane_f64(from, 1);
+ vst1q_lane_f64(to + stride*0, from, 0);
+ vst1q_lane_f64(to + stride*1, from, 1);
}
+
template<> EIGEN_STRONG_INLINE void prefetch<double>(const double* addr) { EIGEN_ARM_PREFETCH(addr); }
// FIXME only store the 2 first elements ?
-template<> EIGEN_STRONG_INLINE double pfirst<Packet2d>(const Packet2d& a) { return vgetq_lane_f64(a, 0); }
+template<> EIGEN_STRONG_INLINE double pfirst<Packet2d>(const Packet2d& a) { return vgetq_lane_f64(a,0); }
-template<> EIGEN_STRONG_INLINE Packet2d preverse(const Packet2d& a) { return vcombine_f64(vget_high_f64(a), vget_low_f64(a)); }
+template<> EIGEN_STRONG_INLINE Packet2d preverse(const Packet2d& a)
+{ return vcombine_f64(vget_high_f64(a), vget_low_f64(a)); }
template<> EIGEN_STRONG_INLINE Packet2d pabs(const Packet2d& a) { return vabsq_f64(a); }
#if EIGEN_COMP_CLANG && defined(__apple_build_version__)
// workaround ICE, see bug 907
-template<> EIGEN_STRONG_INLINE double predux<Packet2d>(const Packet2d& a) { return (vget_low_f64(a) + vget_high_f64(a))[0]; }
+template<> EIGEN_STRONG_INLINE double predux<Packet2d>(const Packet2d& a)
+{ return (vget_low_f64(a) + vget_high_f64(a))[0]; }
#else
-template<> EIGEN_STRONG_INLINE double predux<Packet2d>(const Packet2d& a) { return vget_lane_f64(vget_low_f64(a) + vget_high_f64(a), 0); }
+template<> EIGEN_STRONG_INLINE double predux<Packet2d>(const Packet2d& a)
+{ return vget_lane_f64(vget_low_f64(a) + vget_high_f64(a), 0); }
#endif
-template<> EIGEN_STRONG_INLINE Packet2d preduxp<Packet2d>(const Packet2d* vecs)
-{
- float64x2_t trn1, trn2;
-
- // NEON zip performs interleaving of the supplied vectors.
- // We perform two interleaves in a row to acquire the transposed vector
- trn1 = vzip1q_f64(vecs[0], vecs[1]);
- trn2 = vzip2q_f64(vecs[0], vecs[1]);
-
- // Do the addition of the resulting vectors
- return vaddq_f64(trn1, trn2);
-}
// Other reduction functions:
// mul
#if EIGEN_COMP_CLANG && defined(__apple_build_version__)
-template<> EIGEN_STRONG_INLINE double predux_mul<Packet2d>(const Packet2d& a) { return (vget_low_f64(a) * vget_high_f64(a))[0]; }
+template<> EIGEN_STRONG_INLINE double predux_mul<Packet2d>(const Packet2d& a)
+{ return (vget_low_f64(a) * vget_high_f64(a))[0]; }
#else
-template<> EIGEN_STRONG_INLINE double predux_mul<Packet2d>(const Packet2d& a) { return vget_lane_f64(vget_low_f64(a) * vget_high_f64(a), 0); }
+template<> EIGEN_STRONG_INLINE double predux_mul<Packet2d>(const Packet2d& a)
+{ return vget_lane_f64(vget_low_f64(a) * vget_high_f64(a), 0); }
#endif
// min
-template<> EIGEN_STRONG_INLINE double predux_min<Packet2d>(const Packet2d& a) { return vgetq_lane_f64(vpminq_f64(a, a), 0); }
+template<> EIGEN_STRONG_INLINE double predux_min<Packet2d>(const Packet2d& a)
+{ return vgetq_lane_f64(vpminq_f64(a,a), 0); }
// max
-template<> EIGEN_STRONG_INLINE double predux_max<Packet2d>(const Packet2d& a) { return vgetq_lane_f64(vpmaxq_f64(a, a), 0); }
-
-// this PALIGN_NEON business is to work around a bug in LLVM Clang 3.0 causing incorrect compilation errors,
-// see bug 347 and this LLVM bug: http://llvm.org/bugs/show_bug.cgi?id=11074
-#define PALIGN_NEON(Offset,Type,Command) \
-template<>\
-struct palign_impl<Offset,Type>\
-{\
- EIGEN_STRONG_INLINE static void run(Type& first, const Type& second)\
- {\
- if (Offset!=0)\
- first = Command(first, second, Offset);\
- }\
-};\
-
-PALIGN_NEON(0,Packet2d,vextq_f64)
-PALIGN_NEON(1,Packet2d,vextq_f64)
-#undef PALIGN_NEON
-
-EIGEN_DEVICE_FUNC inline void
-ptranspose(PacketBlock<Packet2d,2>& kernel) {
- float64x2_t trn1 = vzip1q_f64(kernel.packet[0], kernel.packet[1]);
- float64x2_t trn2 = vzip2q_f64(kernel.packet[0], kernel.packet[1]);
-
- kernel.packet[0] = trn1;
- kernel.packet[1] = trn2;
-}
-#endif // EIGEN_ARCH_ARM64
+template<> EIGEN_STRONG_INLINE double predux_max<Packet2d>(const Packet2d& a)
+{ return vgetq_lane_f64(vpmaxq_f64(a,a), 0); }
+
+
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void
+ptranspose(PacketBlock<Packet2d, 2>& kernel)
+{
+ const float64x2_t tmp1 = vzip1q_f64(kernel.packet[0], kernel.packet[1]);
+ const float64x2_t tmp2 = vzip2q_f64(kernel.packet[0], kernel.packet[1]);
+
+ kernel.packet[0] = tmp1;
+ kernel.packet[1] = tmp2;
+}
+
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet2d pselect( const Packet2d& mask, const Packet2d& a, const Packet2d& b)
+{ return vbslq_f64(vreinterpretq_u64_f64(mask), a, b); }
+
+template<> EIGEN_STRONG_INLINE Packet2d print<Packet2d>(const Packet2d& a)
+{ return vrndnq_f64(a); }
+
+template<> EIGEN_STRONG_INLINE Packet2d pfloor<Packet2d>(const Packet2d& a)
+{ return vrndmq_f64(a); }
+
+template<> EIGEN_STRONG_INLINE Packet2d pceil<Packet2d>(const Packet2d& a)
+{ return vrndpq_f64(a); }
+
+template<> EIGEN_STRONG_INLINE Packet2d pldexp<Packet2d>(const Packet2d& a, const Packet2d& exponent)
+{ return pldexp_generic(a, exponent); }
+
+template<> EIGEN_STRONG_INLINE Packet2d pfrexp<Packet2d>(const Packet2d& a, Packet2d& exponent)
+{ return pfrexp_generic(a,exponent); }
+
+template<> EIGEN_STRONG_INLINE Packet2d pset1frombits<Packet2d>(uint64_t from)
+{ return vreinterpretq_f64_u64(vdupq_n_u64(from)); }
+
+template<> EIGEN_STRONG_INLINE Packet2d prsqrt(const Packet2d& a) {
+ // Compute approximate reciprocal sqrt.
+ Packet2d x = vrsqrteq_f64(a);
+ // Do Newton iterations for 1/sqrt(x).
+ x = vmulq_f64(vrsqrtsq_f64(vmulq_f64(a, x), x), x);
+ x = vmulq_f64(vrsqrtsq_f64(vmulq_f64(a, x), x), x);
+ x = vmulq_f64(vrsqrtsq_f64(vmulq_f64(a, x), x), x);
+ const Packet2d infinity = pset1<Packet2d>(NumTraits<double>::infinity());
+ return pselect(pcmp_eq(a, pzero(a)), infinity, x);
+}
+
+template<> EIGEN_STRONG_INLINE Packet2d psqrt(const Packet2d& _x){ return vsqrtq_f64(_x); }
+
+#endif // EIGEN_ARCH_ARM64
+
+// Do we have an fp16 types and supporting Neon intrinsics?
+#if EIGEN_HAS_ARM64_FP16_VECTOR_ARITHMETIC
+typedef float16x4_t Packet4hf;
+typedef float16x8_t Packet8hf;
+
+template <>
+struct packet_traits<Eigen::half> : default_packet_traits {
+ typedef Packet8hf type;
+ typedef Packet4hf half;
+ enum {
+ Vectorizable = 1,
+ AlignedOnScalar = 1,
+ size = 8,
+ HasHalfPacket = 1,
+
+ HasCmp = 1,
+ HasCast = 1,
+ HasAdd = 1,
+ HasSub = 1,
+ HasShift = 1,
+ HasMul = 1,
+ HasNegate = 1,
+ HasAbs = 1,
+ HasArg = 0,
+ HasAbs2 = 1,
+ HasAbsDiff = 0,
+ HasMin = 1,
+ HasMax = 1,
+ HasConj = 1,
+ HasSetLinear = 0,
+ HasBlend = 0,
+ HasInsert = 1,
+ HasReduxp = 1,
+ HasDiv = 1,
+ HasFloor = 1,
+ HasCeil = 1,
+ HasRint = 1,
+ HasSin = 0,
+ HasCos = 0,
+ HasLog = 0,
+ HasExp = 0,
+ HasSqrt = 1,
+ HasRsqrt = 1,
+ HasErf = EIGEN_FAST_MATH,
+ HasBessel = 0, // Issues with accuracy.
+ HasNdtri = 0
+ };
+};
+
+template <>
+struct unpacket_traits<Packet4hf> {
+ typedef Eigen::half type;
+ typedef Packet4hf half;
+ enum {
+ size = 4,
+ alignment = Aligned16,
+ vectorizable = true,
+ masked_load_available = false,
+ masked_store_available = false
+ };
+};
+
+template <>
+struct unpacket_traits<Packet8hf> {
+ typedef Eigen::half type;
+ typedef Packet4hf half;
+ enum {
+ size = 8,
+ alignment = Aligned16,
+ vectorizable = true,
+ masked_load_available = false,
+ masked_store_available = false
+ };
+};
+
+template<>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4hf predux_half_dowto4<Packet8hf>(const Packet8hf& a) {
+ return vadd_f16(vget_low_f16(a), vget_high_f16(a));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8hf pset1<Packet8hf>(const Eigen::half& from) {
+ return vdupq_n_f16(from.x);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4hf pset1<Packet4hf>(const Eigen::half& from) {
+ return vdup_n_f16(from.x);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8hf plset<Packet8hf>(const Eigen::half& a) {
+ const float16_t f[] = {0, 1, 2, 3, 4, 5, 6, 7};
+ Packet8hf countdown = vld1q_f16(f);
+ return vaddq_f16(pset1<Packet8hf>(a), countdown);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4hf plset<Packet4hf>(const Eigen::half& a) {
+ const float16_t f[] = {0, 1, 2, 3};
+ Packet4hf countdown = vld1_f16(f);
+ return vadd_f16(pset1<Packet4hf>(a), countdown);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8hf padd<Packet8hf>(const Packet8hf& a, const Packet8hf& b) {
+ return vaddq_f16(a, b);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4hf padd<Packet4hf>(const Packet4hf& a, const Packet4hf& b) {
+ return vadd_f16(a, b);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8hf psub<Packet8hf>(const Packet8hf& a, const Packet8hf& b) {
+ return vsubq_f16(a, b);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4hf psub<Packet4hf>(const Packet4hf& a, const Packet4hf& b) {
+ return vsub_f16(a, b);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8hf pnegate(const Packet8hf& a) {
+ return vnegq_f16(a);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4hf pnegate(const Packet4hf& a) {
+ return vneg_f16(a);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8hf pconj(const Packet8hf& a) {
+ return a;
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4hf pconj(const Packet4hf& a) {
+ return a;
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8hf pmul<Packet8hf>(const Packet8hf& a, const Packet8hf& b) {
+ return vmulq_f16(a, b);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4hf pmul<Packet4hf>(const Packet4hf& a, const Packet4hf& b) {
+ return vmul_f16(a, b);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8hf pdiv<Packet8hf>(const Packet8hf& a, const Packet8hf& b) {
+ return vdivq_f16(a, b);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4hf pdiv<Packet4hf>(const Packet4hf& a, const Packet4hf& b) {
+ return vdiv_f16(a, b);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8hf pmadd(const Packet8hf& a, const Packet8hf& b, const Packet8hf& c) {
+ return vfmaq_f16(c, a, b);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4hf pmadd(const Packet4hf& a, const Packet4hf& b, const Packet4hf& c) {
+ return vfma_f16(c, a, b);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8hf pmin<Packet8hf>(const Packet8hf& a, const Packet8hf& b) {
+ return vminq_f16(a, b);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4hf pmin<Packet4hf>(const Packet4hf& a, const Packet4hf& b) {
+ return vmin_f16(a, b);
+}
+
+#ifdef __ARM_FEATURE_NUMERIC_MAXMIN
+// numeric max and min are only available if ARM_FEATURE_NUMERIC_MAXMIN is defined (which can only be the case for Armv8 systems).
+template<> EIGEN_STRONG_INLINE Packet4hf pmin<PropagateNumbers, Packet4hf>(const Packet4hf& a, const Packet4hf& b) { return vminnm_f16(a, b); }
+template<> EIGEN_STRONG_INLINE Packet8hf pmin<PropagateNumbers, Packet8hf>(const Packet8hf& a, const Packet8hf& b) { return vminnmq_f16(a, b); }
+#endif
+
+template<> EIGEN_STRONG_INLINE Packet4hf pmin<PropagateNaN, Packet4hf>(const Packet4hf& a, const Packet4hf& b) { return pmin<Packet4hf>(a, b); }
+
+template<> EIGEN_STRONG_INLINE Packet8hf pmin<PropagateNaN, Packet8hf>(const Packet8hf& a, const Packet8hf& b) { return pmin<Packet8hf>(a, b); }
+
+template <>
+EIGEN_STRONG_INLINE Packet8hf pmax<Packet8hf>(const Packet8hf& a, const Packet8hf& b) {
+ return vmaxq_f16(a, b);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4hf pmax<Packet4hf>(const Packet4hf& a, const Packet4hf& b) {
+ return vmax_f16(a, b);
+}
+
+#ifdef __ARM_FEATURE_NUMERIC_MAXMIN
+// numeric max and min are only available if ARM_FEATURE_NUMERIC_MAXMIN is defined (which can only be the case for Armv8 systems).
+template<> EIGEN_STRONG_INLINE Packet4hf pmax<PropagateNumbers, Packet4hf>(const Packet4hf& a, const Packet4hf& b) { return vmaxnm_f16(a, b); }
+template<> EIGEN_STRONG_INLINE Packet8hf pmax<PropagateNumbers, Packet8hf>(const Packet8hf& a, const Packet8hf& b) { return vmaxnmq_f16(a, b); }
+#endif
+
+template<> EIGEN_STRONG_INLINE Packet4hf pmax<PropagateNaN, Packet4hf>(const Packet4hf& a, const Packet4hf& b) { return pmax<Packet4hf>(a, b); }
+
+template<> EIGEN_STRONG_INLINE Packet8hf pmax<PropagateNaN, Packet8hf>(const Packet8hf& a, const Packet8hf& b) { return pmax<Packet8hf>(a, b); }
+
+#define EIGEN_MAKE_ARM_FP16_CMP_8(name) \
+ template <> \
+ EIGEN_STRONG_INLINE Packet8hf pcmp_##name(const Packet8hf& a, const Packet8hf& b) { \
+ return vreinterpretq_f16_u16(vc##name##q_f16(a, b)); \
+ }
+
+#define EIGEN_MAKE_ARM_FP16_CMP_4(name) \
+ template <> \
+ EIGEN_STRONG_INLINE Packet4hf pcmp_##name(const Packet4hf& a, const Packet4hf& b) { \
+ return vreinterpret_f16_u16(vc##name##_f16(a, b)); \
+ }
+
+EIGEN_MAKE_ARM_FP16_CMP_8(eq)
+EIGEN_MAKE_ARM_FP16_CMP_8(lt)
+EIGEN_MAKE_ARM_FP16_CMP_8(le)
+
+EIGEN_MAKE_ARM_FP16_CMP_4(eq)
+EIGEN_MAKE_ARM_FP16_CMP_4(lt)
+EIGEN_MAKE_ARM_FP16_CMP_4(le)
+
+#undef EIGEN_MAKE_ARM_FP16_CMP_8
+#undef EIGEN_MAKE_ARM_FP16_CMP_4
+
+template <>
+EIGEN_STRONG_INLINE Packet8hf pcmp_lt_or_nan<Packet8hf>(const Packet8hf& a, const Packet8hf& b) {
+ return vreinterpretq_f16_u16(vmvnq_u16(vcgeq_f16(a, b)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4hf pcmp_lt_or_nan<Packet4hf>(const Packet4hf& a, const Packet4hf& b) {
+ return vreinterpret_f16_u16(vmvn_u16(vcge_f16(a, b)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8hf print<Packet8hf>(const Packet8hf& a)
+{ return vrndnq_f16(a); }
+
+template <>
+EIGEN_STRONG_INLINE Packet4hf print<Packet4hf>(const Packet4hf& a)
+{ return vrndn_f16(a); }
+
+template <>
+EIGEN_STRONG_INLINE Packet8hf pfloor<Packet8hf>(const Packet8hf& a)
+{ return vrndmq_f16(a); }
+
+template <>
+EIGEN_STRONG_INLINE Packet4hf pfloor<Packet4hf>(const Packet4hf& a)
+{ return vrndm_f16(a); }
+
+template <>
+EIGEN_STRONG_INLINE Packet8hf pceil<Packet8hf>(const Packet8hf& a)
+{ return vrndpq_f16(a); }
+
+template <>
+EIGEN_STRONG_INLINE Packet4hf pceil<Packet4hf>(const Packet4hf& a)
+{ return vrndp_f16(a); }
+
+template <>
+EIGEN_STRONG_INLINE Packet8hf psqrt<Packet8hf>(const Packet8hf& a) {
+ return vsqrtq_f16(a);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4hf psqrt<Packet4hf>(const Packet4hf& a) {
+ return vsqrt_f16(a);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8hf pand<Packet8hf>(const Packet8hf& a, const Packet8hf& b) {
+ return vreinterpretq_f16_u16(vandq_u16(vreinterpretq_u16_f16(a), vreinterpretq_u16_f16(b)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4hf pand<Packet4hf>(const Packet4hf& a, const Packet4hf& b) {
+ return vreinterpret_f16_u16(vand_u16(vreinterpret_u16_f16(a), vreinterpret_u16_f16(b)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8hf por<Packet8hf>(const Packet8hf& a, const Packet8hf& b) {
+ return vreinterpretq_f16_u16(vorrq_u16(vreinterpretq_u16_f16(a), vreinterpretq_u16_f16(b)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4hf por<Packet4hf>(const Packet4hf& a, const Packet4hf& b) {
+ return vreinterpret_f16_u16(vorr_u16(vreinterpret_u16_f16(a), vreinterpret_u16_f16(b)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8hf pxor<Packet8hf>(const Packet8hf& a, const Packet8hf& b) {
+ return vreinterpretq_f16_u16(veorq_u16(vreinterpretq_u16_f16(a), vreinterpretq_u16_f16(b)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4hf pxor<Packet4hf>(const Packet4hf& a, const Packet4hf& b) {
+ return vreinterpret_f16_u16(veor_u16(vreinterpret_u16_f16(a), vreinterpret_u16_f16(b)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8hf pandnot<Packet8hf>(const Packet8hf& a, const Packet8hf& b) {
+ return vreinterpretq_f16_u16(vbicq_u16(vreinterpretq_u16_f16(a), vreinterpretq_u16_f16(b)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4hf pandnot<Packet4hf>(const Packet4hf& a, const Packet4hf& b) {
+ return vreinterpret_f16_u16(vbic_u16(vreinterpret_u16_f16(a), vreinterpret_u16_f16(b)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8hf pload<Packet8hf>(const Eigen::half* from) {
+ EIGEN_DEBUG_ALIGNED_LOAD return vld1q_f16(reinterpret_cast<const float16_t*>(from));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4hf pload<Packet4hf>(const Eigen::half* from) {
+ EIGEN_DEBUG_ALIGNED_LOAD return vld1_f16(reinterpret_cast<const float16_t*>(from));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8hf ploadu<Packet8hf>(const Eigen::half* from) {
+ EIGEN_DEBUG_UNALIGNED_LOAD return vld1q_f16(reinterpret_cast<const float16_t*>(from));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4hf ploadu<Packet4hf>(const Eigen::half* from) {
+ EIGEN_DEBUG_UNALIGNED_LOAD return vld1_f16(reinterpret_cast<const float16_t*>(from));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8hf ploaddup<Packet8hf>(const Eigen::half* from) {
+ Packet8hf packet;
+ packet[0] = from[0].x;
+ packet[1] = from[0].x;
+ packet[2] = from[1].x;
+ packet[3] = from[1].x;
+ packet[4] = from[2].x;
+ packet[5] = from[2].x;
+ packet[6] = from[3].x;
+ packet[7] = from[3].x;
+ return packet;
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4hf ploaddup<Packet4hf>(const Eigen::half* from) {
+ float16x4_t packet;
+ float16_t* tmp;
+ tmp = (float16_t*)&packet;
+ tmp[0] = from[0].x;
+ tmp[1] = from[0].x;
+ tmp[2] = from[1].x;
+ tmp[3] = from[1].x;
+ return packet;
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8hf ploadquad<Packet8hf>(const Eigen::half* from) {
+ Packet4hf lo, hi;
+ lo = vld1_dup_f16(reinterpret_cast<const float16_t*>(from));
+ hi = vld1_dup_f16(reinterpret_cast<const float16_t*>(from+1));
+ return vcombine_f16(lo, hi);
+}
+
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet8hf pinsertfirst(const Packet8hf& a, Eigen::half b) { return vsetq_lane_f16(b.x, a, 0); }
+
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4hf pinsertfirst(const Packet4hf& a, Eigen::half b) { return vset_lane_f16(b.x, a, 0); }
+
+template <>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet8hf pselect(const Packet8hf& mask, const Packet8hf& a, const Packet8hf& b) {
+ return vbslq_f16(vreinterpretq_u16_f16(mask), a, b);
+}
+
+template <>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4hf pselect(const Packet4hf& mask, const Packet4hf& a, const Packet4hf& b) {
+ return vbsl_f16(vreinterpret_u16_f16(mask), a, b);
+}
+
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet8hf pinsertlast(const Packet8hf& a, Eigen::half b) { return vsetq_lane_f16(b.x, a, 7); }
+
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4hf pinsertlast(const Packet4hf& a, Eigen::half b) { return vset_lane_f16(b.x, a, 3); }
+
+template <>
+EIGEN_STRONG_INLINE void pstore<Eigen::half>(Eigen::half* to, const Packet8hf& from) {
+ EIGEN_DEBUG_ALIGNED_STORE vst1q_f16(reinterpret_cast<float16_t*>(to), from);
+}
+
+template <>
+EIGEN_STRONG_INLINE void pstore<Eigen::half>(Eigen::half* to, const Packet4hf& from) {
+ EIGEN_DEBUG_ALIGNED_STORE vst1_f16(reinterpret_cast<float16_t*>(to), from);
+}
+
+template <>
+EIGEN_STRONG_INLINE void pstoreu<Eigen::half>(Eigen::half* to, const Packet8hf& from) {
+ EIGEN_DEBUG_UNALIGNED_STORE vst1q_f16(reinterpret_cast<float16_t*>(to), from);
+}
+
+template <>
+EIGEN_STRONG_INLINE void pstoreu<Eigen::half>(Eigen::half* to, const Packet4hf& from) {
+ EIGEN_DEBUG_UNALIGNED_STORE vst1_f16(reinterpret_cast<float16_t*>(to), from);
+}
+
+template <>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet8hf pgather<Eigen::half, Packet8hf>(const Eigen::half* from, Index stride) {
+ Packet8hf res = pset1<Packet8hf>(Eigen::half(0.f));
+ res = vsetq_lane_f16(from[0 * stride].x, res, 0);
+ res = vsetq_lane_f16(from[1 * stride].x, res, 1);
+ res = vsetq_lane_f16(from[2 * stride].x, res, 2);
+ res = vsetq_lane_f16(from[3 * stride].x, res, 3);
+ res = vsetq_lane_f16(from[4 * stride].x, res, 4);
+ res = vsetq_lane_f16(from[5 * stride].x, res, 5);
+ res = vsetq_lane_f16(from[6 * stride].x, res, 6);
+ res = vsetq_lane_f16(from[7 * stride].x, res, 7);
+ return res;
+}
+
+template <>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4hf pgather<Eigen::half, Packet4hf>(const Eigen::half* from, Index stride) {
+ Packet4hf res = pset1<Packet4hf>(Eigen::half(0.f));
+ res = vset_lane_f16(from[0 * stride].x, res, 0);
+ res = vset_lane_f16(from[1 * stride].x, res, 1);
+ res = vset_lane_f16(from[2 * stride].x, res, 2);
+ res = vset_lane_f16(from[3 * stride].x, res, 3);
+ return res;
+}
+
+template <>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pscatter<Eigen::half, Packet8hf>(Eigen::half* to, const Packet8hf& from, Index stride) {
+ to[stride * 0].x = vgetq_lane_f16(from, 0);
+ to[stride * 1].x = vgetq_lane_f16(from, 1);
+ to[stride * 2].x = vgetq_lane_f16(from, 2);
+ to[stride * 3].x = vgetq_lane_f16(from, 3);
+ to[stride * 4].x = vgetq_lane_f16(from, 4);
+ to[stride * 5].x = vgetq_lane_f16(from, 5);
+ to[stride * 6].x = vgetq_lane_f16(from, 6);
+ to[stride * 7].x = vgetq_lane_f16(from, 7);
+}
+
+template <>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pscatter<Eigen::half, Packet4hf>(Eigen::half* to, const Packet4hf& from, Index stride) {
+ to[stride * 0].x = vget_lane_f16(from, 0);
+ to[stride * 1].x = vget_lane_f16(from, 1);
+ to[stride * 2].x = vget_lane_f16(from, 2);
+ to[stride * 3].x = vget_lane_f16(from, 3);
+}
+
+template <>
+EIGEN_STRONG_INLINE void prefetch<Eigen::half>(const Eigen::half* addr) {
+ EIGEN_ARM_PREFETCH(addr);
+}
+
+template <>
+EIGEN_STRONG_INLINE Eigen::half pfirst<Packet8hf>(const Packet8hf& a) {
+ float16_t x[8];
+ vst1q_f16(x, a);
+ Eigen::half h;
+ h.x = x[0];
+ return h;
+}
+
+template <>
+EIGEN_STRONG_INLINE Eigen::half pfirst<Packet4hf>(const Packet4hf& a) {
+ float16_t x[4];
+ vst1_f16(x, a);
+ Eigen::half h;
+ h.x = x[0];
+ return h;
+}
+
+template<> EIGEN_STRONG_INLINE Packet8hf preverse(const Packet8hf& a) {
+ float16x4_t a_lo, a_hi;
+ Packet8hf a_r64;
+
+ a_r64 = vrev64q_f16(a);
+ a_lo = vget_low_f16(a_r64);
+ a_hi = vget_high_f16(a_r64);
+ return vcombine_f16(a_hi, a_lo);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4hf preverse<Packet4hf>(const Packet4hf& a) {
+ return vrev64_f16(a);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8hf pabs<Packet8hf>(const Packet8hf& a) {
+ return vabsq_f16(a);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4hf pabs<Packet4hf>(const Packet4hf& a) {
+ return vabs_f16(a);
+}
+
+template <>
+EIGEN_STRONG_INLINE Eigen::half predux<Packet8hf>(const Packet8hf& a) {
+ float16x4_t a_lo, a_hi, sum;
+
+ a_lo = vget_low_f16(a);
+ a_hi = vget_high_f16(a);
+ sum = vpadd_f16(a_lo, a_hi);
+ sum = vpadd_f16(sum, sum);
+ sum = vpadd_f16(sum, sum);
+
+ Eigen::half h;
+ h.x = vget_lane_f16(sum, 0);
+ return h;
+}
+
+template <>
+EIGEN_STRONG_INLINE Eigen::half predux<Packet4hf>(const Packet4hf& a) {
+ float16x4_t sum;
+
+ sum = vpadd_f16(a, a);
+ sum = vpadd_f16(sum, sum);
+ Eigen::half h;
+ h.x = vget_lane_f16(sum, 0);
+ return h;
+}
+
+template <>
+EIGEN_STRONG_INLINE Eigen::half predux_mul<Packet8hf>(const Packet8hf& a) {
+ float16x4_t a_lo, a_hi, prod;
+
+ a_lo = vget_low_f16(a);
+ a_hi = vget_high_f16(a);
+ prod = vmul_f16(a_lo, a_hi);
+ prod = vmul_f16(prod, vrev64_f16(prod));
+
+ Eigen::half h;
+ h.x = vmulh_f16(vget_lane_f16(prod, 0), vget_lane_f16(prod, 1));
+ return h;
+}
+
+template <>
+EIGEN_STRONG_INLINE Eigen::half predux_mul<Packet4hf>(const Packet4hf& a) {
+ float16x4_t prod;
+ prod = vmul_f16(a, vrev64_f16(a));
+ Eigen::half h;
+ h.x = vmulh_f16(vget_lane_f16(prod, 0), vget_lane_f16(prod, 1));
+ return h;
+}
+
+template <>
+EIGEN_STRONG_INLINE Eigen::half predux_min<Packet8hf>(const Packet8hf& a) {
+ float16x4_t a_lo, a_hi, min;
+
+ a_lo = vget_low_f16(a);
+ a_hi = vget_high_f16(a);
+ min = vpmin_f16(a_lo, a_hi);
+ min = vpmin_f16(min, min);
+ min = vpmin_f16(min, min);
+
+ Eigen::half h;
+ h.x = vget_lane_f16(min, 0);
+ return h;
+}
+
+template <>
+EIGEN_STRONG_INLINE Eigen::half predux_min<Packet4hf>(const Packet4hf& a) {
+ Packet4hf tmp;
+ tmp = vpmin_f16(a, a);
+ tmp = vpmin_f16(tmp, tmp);
+ Eigen::half h;
+ h.x = vget_lane_f16(tmp, 0);
+ return h;
+}
+
+template <>
+EIGEN_STRONG_INLINE Eigen::half predux_max<Packet8hf>(const Packet8hf& a) {
+ float16x4_t a_lo, a_hi, max;
+
+ a_lo = vget_low_f16(a);
+ a_hi = vget_high_f16(a);
+ max = vpmax_f16(a_lo, a_hi);
+ max = vpmax_f16(max, max);
+ max = vpmax_f16(max, max);
+
+ Eigen::half h;
+ h.x = vget_lane_f16(max, 0);
+ return h;
+}
+
+template <>
+EIGEN_STRONG_INLINE Eigen::half predux_max<Packet4hf>(const Packet4hf& a) {
+ Packet4hf tmp;
+ tmp = vpmax_f16(a, a);
+ tmp = vpmax_f16(tmp, tmp);
+ Eigen::half h;
+ h.x = vget_lane_f16(tmp, 0);
+ return h;
+}
+
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet8hf, 4>& kernel)
+{
+ const float16x8x2_t zip16_1 = vzipq_f16(kernel.packet[0], kernel.packet[1]);
+ const float16x8x2_t zip16_2 = vzipq_f16(kernel.packet[2], kernel.packet[3]);
+
+ const float32x4x2_t zip32_1 = vzipq_f32(vreinterpretq_f32_f16(zip16_1.val[0]), vreinterpretq_f32_f16(zip16_2.val[0]));
+ const float32x4x2_t zip32_2 = vzipq_f32(vreinterpretq_f32_f16(zip16_1.val[1]), vreinterpretq_f32_f16(zip16_2.val[1]));
+
+ kernel.packet[0] = vreinterpretq_f16_f32(zip32_1.val[0]);
+ kernel.packet[1] = vreinterpretq_f16_f32(zip32_1.val[1]);
+ kernel.packet[2] = vreinterpretq_f16_f32(zip32_2.val[0]);
+ kernel.packet[3] = vreinterpretq_f16_f32(zip32_2.val[1]);
+}
+
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet4hf, 4>& kernel) {
+ EIGEN_ALIGN16 float16x4x4_t tmp_x4;
+ float16_t* tmp = (float16_t*)&kernel;
+ tmp_x4 = vld4_f16(tmp);
+
+ kernel.packet[0] = tmp_x4.val[0];
+ kernel.packet[1] = tmp_x4.val[1];
+ kernel.packet[2] = tmp_x4.val[2];
+ kernel.packet[3] = tmp_x4.val[3];
+}
+
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet8hf, 8>& kernel) {
+ float16x8x2_t T_1[4];
+
+ T_1[0] = vuzpq_f16(kernel.packet[0], kernel.packet[1]);
+ T_1[1] = vuzpq_f16(kernel.packet[2], kernel.packet[3]);
+ T_1[2] = vuzpq_f16(kernel.packet[4], kernel.packet[5]);
+ T_1[3] = vuzpq_f16(kernel.packet[6], kernel.packet[7]);
+
+ float16x8x2_t T_2[4];
+ T_2[0] = vuzpq_f16(T_1[0].val[0], T_1[1].val[0]);
+ T_2[1] = vuzpq_f16(T_1[0].val[1], T_1[1].val[1]);
+ T_2[2] = vuzpq_f16(T_1[2].val[0], T_1[3].val[0]);
+ T_2[3] = vuzpq_f16(T_1[2].val[1], T_1[3].val[1]);
+
+ float16x8x2_t T_3[4];
+ T_3[0] = vuzpq_f16(T_2[0].val[0], T_2[2].val[0]);
+ T_3[1] = vuzpq_f16(T_2[0].val[1], T_2[2].val[1]);
+ T_3[2] = vuzpq_f16(T_2[1].val[0], T_2[3].val[0]);
+ T_3[3] = vuzpq_f16(T_2[1].val[1], T_2[3].val[1]);
+
+ kernel.packet[0] = T_3[0].val[0];
+ kernel.packet[1] = T_3[2].val[0];
+ kernel.packet[2] = T_3[1].val[0];
+ kernel.packet[3] = T_3[3].val[0];
+ kernel.packet[4] = T_3[0].val[1];
+ kernel.packet[5] = T_3[2].val[1];
+ kernel.packet[6] = T_3[1].val[1];
+ kernel.packet[7] = T_3[3].val[1];
+}
+#endif // end EIGEN_HAS_ARM64_FP16_VECTOR_ARITHMETIC
} // end namespace internal
diff --git a/Eigen/src/Core/arch/NEON/TypeCasting.h b/Eigen/src/Core/arch/NEON/TypeCasting.h
new file mode 100644
index 000000000..54f97336e
--- /dev/null
+++ b/Eigen/src/Core/arch/NEON/TypeCasting.h
@@ -0,0 +1,1419 @@
+// This file is part of Eigen, a lightweight C++ template library
+// for linear algebra.
+//
+// Copyright (C) 2018 Rasmus Munk Larsen <rmlarsen@google.com>
+// Copyright (C) 2020 Antonio Sanchez <cantonios@google.com>
+//
+// This Source Code Form is subject to the terms of the Mozilla
+// Public License v. 2.0. If a copy of the MPL was not distributed
+// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+#ifndef EIGEN_TYPE_CASTING_NEON_H
+#define EIGEN_TYPE_CASTING_NEON_H
+
+namespace Eigen {
+
+namespace internal {
+
+//==============================================================================
+// pcast, SrcType = float
+//==============================================================================
+template <>
+struct type_casting_traits<float, float> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet4f pcast<Packet4f, Packet4f>(const Packet4f& a) {
+ return a;
+}
+template <>
+EIGEN_STRONG_INLINE Packet2f pcast<Packet2f, Packet2f>(const Packet2f& a) {
+ return a;
+}
+
+template <>
+struct type_casting_traits<float, numext::int64_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 2 };
+};
+template <>
+struct type_casting_traits<float, numext::uint64_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 2 };
+};
+// If float64 exists, first convert to that to keep as much precision as possible.
+#if EIGEN_ARCH_ARM64
+template <>
+EIGEN_STRONG_INLINE Packet2l pcast<Packet4f, Packet2l>(const Packet4f& a) {
+ // Discard second half of input.
+ return vcvtq_s64_f64(vcvt_f64_f32(vget_low_f32(a)));
+}
+template <>
+EIGEN_STRONG_INLINE Packet2ul pcast<Packet4f, Packet2ul>(const Packet4f& a) {
+ // Discard second half of input.
+ return vcvtq_u64_f64(vcvt_f64_f32(vget_low_f32(a)));
+}
+#else
+template <>
+EIGEN_STRONG_INLINE Packet2l pcast<Packet4f, Packet2l>(const Packet4f& a) {
+ // Discard second half of input.
+ return vmovl_s32(vget_low_s32(vcvtq_s32_f32(a)));
+}
+template <>
+EIGEN_STRONG_INLINE Packet2ul pcast<Packet4f, Packet2ul>(const Packet4f& a) {
+ // Discard second half of input.
+ return vmovl_u32(vget_low_u32(vcvtq_u32_f32(a)));
+}
+#endif // EIGEN_ARCH_ARM64
+
+template <>
+struct type_casting_traits<float, numext::int32_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet4i pcast<Packet4f, Packet4i>(const Packet4f& a) {
+ return vcvtq_s32_f32(a);
+}
+template <>
+EIGEN_STRONG_INLINE Packet2i pcast<Packet2f, Packet2i>(const Packet2f& a) {
+ return vcvt_s32_f32(a);
+}
+
+template <>
+struct type_casting_traits<float, numext::uint32_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet4ui pcast<Packet4f, Packet4ui>(const Packet4f& a) {
+ return vcvtq_u32_f32(a);
+}
+template <>
+EIGEN_STRONG_INLINE Packet2ui pcast<Packet2f, Packet2ui>(const Packet2f& a) {
+ return vcvt_u32_f32(a);
+}
+
+template <>
+struct type_casting_traits<float, numext::int16_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 2, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet8s pcast<Packet4f, Packet8s>(const Packet4f& a, const Packet4f& b) {
+ return vcombine_s16(vmovn_s32(vcvtq_s32_f32(a)), vmovn_s32(vcvtq_s32_f32(b)));
+}
+template <>
+EIGEN_STRONG_INLINE Packet4s pcast<Packet2f, Packet4s>(const Packet2f& a, const Packet2f& b) {
+ return vmovn_s32(vcombine_s32(vcvt_s32_f32(a), vcvt_s32_f32(b)));
+}
+
+template <>
+struct type_casting_traits<float, numext::uint16_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 2, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet8us pcast<Packet4f, Packet8us>(const Packet4f& a, const Packet4f& b) {
+ return vcombine_u16(vmovn_u32(vcvtq_u32_f32(a)), vmovn_u32(vcvtq_u32_f32(b)));
+}
+template <>
+EIGEN_STRONG_INLINE Packet4us pcast<Packet2f, Packet4us>(const Packet2f& a, const Packet2f& b) {
+ return vmovn_u32(vcombine_u32(vcvt_u32_f32(a), vcvt_u32_f32(b)));
+}
+
+template <>
+struct type_casting_traits<float, numext::int8_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 4, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet16c pcast<Packet4f, Packet16c>(const Packet4f& a, const Packet4f& b, const Packet4f& c,
+ const Packet4f& d) {
+ const int16x8_t ab_s16 = pcast<Packet4f, Packet8s>(a, b);
+ const int16x8_t cd_s16 = pcast<Packet4f, Packet8s>(c, d);
+ return vcombine_s8(vmovn_s16(ab_s16), vmovn_s16(cd_s16));
+}
+template <>
+EIGEN_STRONG_INLINE Packet8c pcast<Packet2f, Packet8c>(const Packet2f& a, const Packet2f& b, const Packet2f& c,
+ const Packet2f& d) {
+ const int16x4_t ab_s16 = pcast<Packet2f, Packet4s>(a, b);
+ const int16x4_t cd_s16 = pcast<Packet2f, Packet4s>(c, d);
+ return vmovn_s16(vcombine_s16(ab_s16, cd_s16));
+}
+
+template <>
+struct type_casting_traits<float, numext::uint8_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 4, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet16uc pcast<Packet4f, Packet16uc>(const Packet4f& a, const Packet4f& b, const Packet4f& c,
+ const Packet4f& d) {
+ const uint16x8_t ab_u16 = pcast<Packet4f, Packet8us>(a, b);
+ const uint16x8_t cd_u16 = pcast<Packet4f, Packet8us>(c, d);
+ return vcombine_u8(vmovn_u16(ab_u16), vmovn_u16(cd_u16));
+}
+template <>
+EIGEN_STRONG_INLINE Packet8uc pcast<Packet2f, Packet8uc>(const Packet2f& a, const Packet2f& b, const Packet2f& c,
+ const Packet2f& d) {
+ const uint16x4_t ab_u16 = pcast<Packet2f, Packet4us>(a, b);
+ const uint16x4_t cd_u16 = pcast<Packet2f, Packet4us>(c, d);
+ return vmovn_u16(vcombine_u16(ab_u16, cd_u16));
+}
+
+//==============================================================================
+// pcast, SrcType = int8_t
+//==============================================================================
+template <>
+struct type_casting_traits<numext::int8_t, float> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 4 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet4f pcast<Packet16c, Packet4f>(const Packet16c& a) {
+ // Discard all but first 4 bytes.
+ return vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(vget_low_s8(a)))));
+}
+template <>
+EIGEN_STRONG_INLINE Packet2f pcast<Packet8c, Packet2f>(const Packet8c& a) {
+ // Discard all but first 2 bytes.
+ return vcvt_f32_s32(vget_low_s32(vmovl_s16(vget_low_s16(vmovl_s8(a)))));
+}
+
+template <>
+struct type_casting_traits<numext::int8_t, numext::int64_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 8 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet2l pcast<Packet16c, Packet2l>(const Packet16c& a) {
+ // Discard all but first two bytes.
+ return vmovl_s32(vget_low_s32(vmovl_s16(vget_low_s16(vmovl_s8(vget_low_s8(a))))));
+}
+
+template <>
+struct type_casting_traits<numext::int8_t, numext::uint64_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 8 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet2ul pcast<Packet16c, Packet2ul>(const Packet16c& a) {
+ return vreinterpretq_u64_s64(pcast<Packet16c, Packet2l>(a));
+}
+
+template <>
+struct type_casting_traits<numext::int8_t, numext::int32_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 4 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet4i pcast<Packet16c, Packet4i>(const Packet16c& a) {
+ // Discard all but first 4 bytes.
+ return vmovl_s16(vget_low_s16(vmovl_s8(vget_low_s8(a))));
+}
+template <>
+EIGEN_STRONG_INLINE Packet2i pcast<Packet8c, Packet2i>(const Packet8c& a) {
+ // Discard all but first 2 bytes.
+ return vget_low_s32(vmovl_s16(vget_low_s16(vmovl_s8(a))));
+}
+
+template <>
+struct type_casting_traits<numext::int8_t, numext::uint32_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 4 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet4ui pcast<Packet16c, Packet4ui>(const Packet16c& a) {
+ return vreinterpretq_u32_s32(pcast<Packet16c, Packet4i>(a));
+}
+template <>
+EIGEN_STRONG_INLINE Packet2ui pcast<Packet8c, Packet2ui>(const Packet8c& a) {
+ return vreinterpret_u32_s32(pcast<Packet8c, Packet2i>(a));
+}
+
+template <>
+struct type_casting_traits<numext::int8_t, numext::int16_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 2 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet8s pcast<Packet16c, Packet8s>(const Packet16c& a) {
+ // Discard second half of input.
+ return vmovl_s8(vget_low_s8(a));
+}
+template <>
+EIGEN_STRONG_INLINE Packet4s pcast<Packet8c, Packet4s>(const Packet8c& a) {
+ // Discard second half of input.
+ return vget_low_s16(vmovl_s8(a));
+}
+
+template <>
+struct type_casting_traits<numext::int8_t, numext::uint16_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 2 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet8us pcast<Packet16c, Packet8us>(const Packet16c& a) {
+ return vreinterpretq_u16_s16(pcast<Packet16c, Packet8s>(a));
+}
+template <>
+EIGEN_STRONG_INLINE Packet4us pcast<Packet8c, Packet4us>(const Packet8c& a) {
+ return vreinterpret_u16_s16(pcast<Packet8c, Packet4s>(a));
+}
+
+template <>
+struct type_casting_traits<numext::int8_t, numext::int8_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet16c pcast<Packet16c, Packet16c>(const Packet16c& a) {
+ return a;
+}
+template <>
+EIGEN_STRONG_INLINE Packet8c pcast<Packet8c, Packet8c>(const Packet8c& a) {
+ return a;
+}
+template <>
+EIGEN_STRONG_INLINE Packet4c pcast<Packet4c, Packet4c>(const Packet4c& a) {
+ return a;
+}
+
+template <>
+struct type_casting_traits<numext::int8_t, numext::uint8_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet16uc pcast<Packet16c, Packet16uc>(const Packet16c& a) {
+ return vreinterpretq_u8_s8(a);
+}
+template <>
+EIGEN_STRONG_INLINE Packet8uc pcast<Packet8c, Packet8uc>(const Packet8c& a) {
+ return vreinterpret_u8_s8(a);
+}
+template <>
+EIGEN_STRONG_INLINE Packet4uc pcast<Packet4c, Packet4uc>(const Packet4c& a) {
+ return static_cast<Packet4uc>(a);
+}
+
+//==============================================================================
+// pcast, SrcType = uint8_t
+//==============================================================================
+template <>
+struct type_casting_traits<numext::uint8_t, float> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 4 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet4f pcast<Packet16uc, Packet4f>(const Packet16uc& a) {
+ // Discard all but first 4 bytes.
+ return vcvtq_f32_u32(vmovl_u16(vget_low_u16(vmovl_u8(vget_low_u8(a)))));
+}
+template <>
+EIGEN_STRONG_INLINE Packet2f pcast<Packet8uc, Packet2f>(const Packet8uc& a) {
+ // Discard all but first 2 bytes.
+ return vcvt_f32_u32(vget_low_u32(vmovl_u16(vget_low_u16(vmovl_u8(a)))));
+}
+
+template <>
+struct type_casting_traits<numext::uint8_t, numext::uint64_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 8 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet2ul pcast<Packet16uc, Packet2ul>(const Packet16uc& a) {
+ // Discard all but first two bytes.
+ return vmovl_u32(vget_low_u32(vmovl_u16(vget_low_u16(vmovl_u8(vget_low_u8(a))))));
+}
+
+template <>
+struct type_casting_traits<numext::uint8_t, numext::int64_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 8 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet2l pcast<Packet16uc, Packet2l>(const Packet16uc& a) {
+ return vreinterpretq_s64_u64(pcast<Packet16uc, Packet2ul>(a));
+}
+
+template <>
+struct type_casting_traits<numext::uint8_t, numext::uint32_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 4 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet4ui pcast<Packet16uc, Packet4ui>(const Packet16uc& a) {
+ // Discard all but first 4 bytes.
+ return vmovl_u16(vget_low_u16(vmovl_u8(vget_low_u8(a))));
+}
+template <>
+EIGEN_STRONG_INLINE Packet2ui pcast<Packet8uc, Packet2ui>(const Packet8uc& a) {
+ // Discard all but first 2 bytes.
+ return vget_low_u32(vmovl_u16(vget_low_u16(vmovl_u8(a))));
+}
+
+template <>
+struct type_casting_traits<numext::uint8_t, numext::int32_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 4 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet4i pcast<Packet16uc, Packet4i>(const Packet16uc& a) {
+ return vreinterpretq_s32_u32(pcast<Packet16uc, Packet4ui>(a));
+}
+template <>
+EIGEN_STRONG_INLINE Packet2i pcast<Packet8uc, Packet2i>(const Packet8uc& a) {
+ return vreinterpret_s32_u32(pcast<Packet8uc, Packet2ui>(a));
+}
+
+template <>
+struct type_casting_traits<numext::uint8_t, numext::uint16_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 2 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet8us pcast<Packet16uc, Packet8us>(const Packet16uc& a) {
+ // Discard second half of input.
+ return vmovl_u8(vget_low_u8(a));
+}
+template <>
+EIGEN_STRONG_INLINE Packet4us pcast<Packet8uc, Packet4us>(const Packet8uc& a) {
+ // Discard second half of input.
+ return vget_low_u16(vmovl_u8(a));
+}
+
+template <>
+struct type_casting_traits<numext::uint8_t, numext::int16_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 2 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet8s pcast<Packet16uc, Packet8s>(const Packet16uc& a) {
+ return vreinterpretq_s16_u16(pcast<Packet16uc, Packet8us>(a));
+}
+template <>
+EIGEN_STRONG_INLINE Packet4s pcast<Packet8uc, Packet4s>(const Packet8uc& a) {
+ return vreinterpret_s16_u16(pcast<Packet8uc, Packet4us>(a));
+}
+
+template <>
+struct type_casting_traits<numext::uint8_t, numext::uint8_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet16uc pcast<Packet16uc, Packet16uc>(const Packet16uc& a) {
+ return a;
+}
+template <>
+EIGEN_STRONG_INLINE Packet8uc pcast<Packet8uc, Packet8uc>(const Packet8uc& a) {
+ return a;
+}
+template <>
+EIGEN_STRONG_INLINE Packet4uc pcast<Packet4uc, Packet4uc>(const Packet4uc& a) {
+ return a;
+}
+
+template <>
+struct type_casting_traits<numext::uint8_t, numext::int8_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet16c pcast<Packet16uc, Packet16c>(const Packet16uc& a) {
+ return vreinterpretq_s8_u8(a);
+}
+template <>
+EIGEN_STRONG_INLINE Packet8c pcast<Packet8uc, Packet8c>(const Packet8uc& a) {
+ return vreinterpret_s8_u8(a);
+}
+template <>
+EIGEN_STRONG_INLINE Packet4c pcast<Packet4uc, Packet4c>(const Packet4uc& a) {
+ return static_cast<Packet4c>(a);
+}
+
+//==============================================================================
+// pcast, SrcType = int16_t
+//==============================================================================
+template <>
+struct type_casting_traits<numext::int16_t, float> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 2 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet4f pcast<Packet8s, Packet4f>(const Packet8s& a) {
+ // Discard second half of input.
+ return vcvtq_f32_s32(vmovl_s16(vget_low_s16(a)));
+}
+template <>
+EIGEN_STRONG_INLINE Packet2f pcast<Packet4s, Packet2f>(const Packet4s& a) {
+ // Discard second half of input.
+ return vcvt_f32_s32(vget_low_s32(vmovl_s16(a)));
+}
+
+template <>
+struct type_casting_traits<numext::int16_t, numext::int64_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 4 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet2l pcast<Packet8s, Packet2l>(const Packet8s& a) {
+ // Discard all but first two values.
+ return vmovl_s32(vget_low_s32(vmovl_s16(vget_low_s16(a))));
+}
+
+template <>
+struct type_casting_traits<numext::int16_t, numext::uint64_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 4 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet2ul pcast<Packet8s, Packet2ul>(const Packet8s& a) {
+ return vreinterpretq_u64_s64(pcast<Packet8s, Packet2l>(a));
+}
+
+template <>
+struct type_casting_traits<numext::int16_t, numext::int32_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 2 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet4i pcast<Packet8s, Packet4i>(const Packet8s& a) {
+ // Discard second half of input.
+ return vmovl_s16(vget_low_s16(a));
+}
+template <>
+EIGEN_STRONG_INLINE Packet2i pcast<Packet4s, Packet2i>(const Packet4s& a) {
+ // Discard second half of input.
+ return vget_low_s32(vmovl_s16(a));
+}
+
+template <>
+struct type_casting_traits<numext::int16_t, numext::uint32_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 2 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet4ui pcast<Packet8s, Packet4ui>(const Packet8s& a) {
+ return vreinterpretq_u32_s32(pcast<Packet8s, Packet4i>(a));
+}
+template <>
+EIGEN_STRONG_INLINE Packet2ui pcast<Packet4s, Packet2ui>(const Packet4s& a) {
+ return vreinterpret_u32_s32(pcast<Packet4s, Packet2i>(a));
+}
+
+template <>
+struct type_casting_traits<numext::int16_t, numext::int16_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet8s pcast<Packet8s, Packet8s>(const Packet8s& a) {
+ return a;
+}
+template <>
+EIGEN_STRONG_INLINE Packet4s pcast<Packet4s, Packet4s>(const Packet4s& a) {
+ return a;
+}
+
+template <>
+struct type_casting_traits<numext::int16_t, numext::uint16_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet8us pcast<Packet8s, Packet8us>(const Packet8s& a) {
+ return vreinterpretq_u16_s16(a);
+}
+template <>
+EIGEN_STRONG_INLINE Packet4us pcast<Packet4s, Packet4us>(const Packet4s& a) {
+ return vreinterpret_u16_s16(a);
+}
+
+template <>
+struct type_casting_traits<numext::int16_t, numext::int8_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 2, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet16c pcast<Packet8s, Packet16c>(const Packet8s& a, const Packet8s& b) {
+ return vcombine_s8(vmovn_s16(a), vmovn_s16(b));
+}
+template <>
+EIGEN_STRONG_INLINE Packet8c pcast<Packet4s, Packet8c>(const Packet4s& a, const Packet4s& b) {
+ return vmovn_s16(vcombine_s16(a, b));
+}
+
+template <>
+struct type_casting_traits<numext::int16_t, numext::uint8_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 2, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet16uc pcast<Packet8s, Packet16uc>(const Packet8s& a, const Packet8s& b) {
+ return vcombine_u8(vmovn_u16(vreinterpretq_u16_s16(a)), vmovn_u16(vreinterpretq_u16_s16(b)));
+}
+template <>
+EIGEN_STRONG_INLINE Packet8uc pcast<Packet4s, Packet8uc>(const Packet4s& a, const Packet4s& b) {
+ return vmovn_u16(vcombine_u16(vreinterpret_u16_s16(a), vreinterpret_u16_s16(b)));
+}
+
+//==============================================================================
+// pcast, SrcType = uint16_t
+//==============================================================================
+template <>
+struct type_casting_traits<numext::uint16_t, float> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 2 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet4f pcast<Packet8us, Packet4f>(const Packet8us& a) {
+ // Discard second half of input.
+ return vcvtq_f32_u32(vmovl_u16(vget_low_u16(a)));
+}
+template <>
+EIGEN_STRONG_INLINE Packet2f pcast<Packet4us, Packet2f>(const Packet4us& a) {
+ // Discard second half of input.
+ return vcvt_f32_u32(vget_low_u32(vmovl_u16(a)));
+}
+
+template <>
+struct type_casting_traits<numext::uint16_t, numext::uint64_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 4 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet2ul pcast<Packet8us, Packet2ul>(const Packet8us& a) {
+ // Discard all but first two values.
+ return vmovl_u32(vget_low_u32(vmovl_u16(vget_low_u16(a))));
+}
+
+template <>
+struct type_casting_traits<numext::uint16_t, numext::int64_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 4 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet2l pcast<Packet8us, Packet2l>(const Packet8us& a) {
+ return vreinterpretq_s64_u64(pcast<Packet8us, Packet2ul>(a));
+}
+
+template <>
+struct type_casting_traits<numext::uint16_t, numext::uint32_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 2 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet4ui pcast<Packet8us, Packet4ui>(const Packet8us& a) {
+ // Discard second half of input.
+ return vmovl_u16(vget_low_u16(a));
+}
+template <>
+EIGEN_STRONG_INLINE Packet2ui pcast<Packet4us, Packet2ui>(const Packet4us& a) {
+ // Discard second half of input.
+ return vget_low_u32(vmovl_u16(a));
+}
+
+template <>
+struct type_casting_traits<numext::uint16_t, numext::int32_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 2 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet4i pcast<Packet8us, Packet4i>(const Packet8us& a) {
+ return vreinterpretq_s32_u32(pcast<Packet8us, Packet4ui>(a));
+}
+template <>
+EIGEN_STRONG_INLINE Packet2i pcast<Packet4us, Packet2i>(const Packet4us& a) {
+ return vreinterpret_s32_u32(pcast<Packet4us, Packet2ui>(a));
+}
+
+template <>
+struct type_casting_traits<numext::uint16_t, numext::uint16_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet8us pcast<Packet8us, Packet8us>(const Packet8us& a) {
+ return a;
+}
+template <>
+EIGEN_STRONG_INLINE Packet4us pcast<Packet4us, Packet4us>(const Packet4us& a) {
+ return a;
+}
+
+template <>
+struct type_casting_traits<numext::uint16_t, numext::int16_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet8s pcast<Packet8us, Packet8s>(const Packet8us& a) {
+ return vreinterpretq_s16_u16(a);
+}
+template <>
+EIGEN_STRONG_INLINE Packet4s pcast<Packet4us, Packet4s>(const Packet4us& a) {
+ return vreinterpret_s16_u16(a);
+}
+
+template <>
+struct type_casting_traits<numext::uint16_t, numext::uint8_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 2, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet16uc pcast<Packet8us, Packet16uc>(const Packet8us& a, const Packet8us& b) {
+ return vcombine_u8(vmovn_u16(a), vmovn_u16(b));
+}
+template <>
+EIGEN_STRONG_INLINE Packet8uc pcast<Packet4us, Packet8uc>(const Packet4us& a, const Packet4us& b) {
+ return vmovn_u16(vcombine_u16(a, b));
+}
+
+template <>
+struct type_casting_traits<numext::uint16_t, numext::int8_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 2, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet16c pcast<Packet8us, Packet16c>(const Packet8us& a, const Packet8us& b) {
+ return vreinterpretq_s8_u8(pcast<Packet8us, Packet16uc>(a, b));
+}
+template <>
+EIGEN_STRONG_INLINE Packet8c pcast<Packet4us, Packet8c>(const Packet4us& a, const Packet4us& b) {
+ return vreinterpret_s8_u8(pcast<Packet4us, Packet8uc>(a, b));
+}
+
+//==============================================================================
+// pcast, SrcType = int32_t
+//==============================================================================
+template <>
+struct type_casting_traits<numext::int32_t, float> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet4f pcast<Packet4i, Packet4f>(const Packet4i& a) {
+ return vcvtq_f32_s32(a);
+}
+template <>
+EIGEN_STRONG_INLINE Packet2f pcast<Packet2i, Packet2f>(const Packet2i& a) {
+ return vcvt_f32_s32(a);
+}
+
+template <>
+struct type_casting_traits<numext::int32_t, numext::int64_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 2 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet2l pcast<Packet4i, Packet2l>(const Packet4i& a) {
+ // Discard second half of input.
+ return vmovl_s32(vget_low_s32(a));
+}
+
+template <>
+struct type_casting_traits<numext::int32_t, numext::uint64_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 2 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet2ul pcast<Packet4i, Packet2ul>(const Packet4i& a) {
+ return vreinterpretq_u64_s64(pcast<Packet4i, Packet2l>(a));
+}
+
+template <>
+struct type_casting_traits<numext::int32_t, numext::int32_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet4i pcast<Packet4i, Packet4i>(const Packet4i& a) {
+ return a;
+}
+template <>
+EIGEN_STRONG_INLINE Packet2i pcast<Packet2i, Packet2i>(const Packet2i& a) {
+ return a;
+}
+
+template <>
+struct type_casting_traits<numext::int32_t, numext::uint32_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet4ui pcast<Packet4i, Packet4ui>(const Packet4i& a) {
+ return vreinterpretq_u32_s32(a);
+}
+template <>
+EIGEN_STRONG_INLINE Packet2ui pcast<Packet2i, Packet2ui>(const Packet2i& a) {
+ return vreinterpret_u32_s32(a);
+}
+
+template <>
+struct type_casting_traits<numext::int32_t, numext::int16_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 2, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet8s pcast<Packet4i, Packet8s>(const Packet4i& a, const Packet4i& b) {
+ return vcombine_s16(vmovn_s32(a), vmovn_s32(b));
+}
+template <>
+EIGEN_STRONG_INLINE Packet4s pcast<Packet2i, Packet4s>(const Packet2i& a, const Packet2i& b) {
+ return vmovn_s32(vcombine_s32(a, b));
+}
+
+template <>
+struct type_casting_traits<numext::int32_t, numext::uint16_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 2, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet8us pcast<Packet4i, Packet8us>(const Packet4i& a, const Packet4i& b) {
+ return vcombine_u16(vmovn_u32(vreinterpretq_u32_s32(a)), vmovn_u32(vreinterpretq_u32_s32(b)));
+}
+template <>
+EIGEN_STRONG_INLINE Packet4us pcast<Packet2i, Packet4us>(const Packet2i& a, const Packet2i& b) {
+ return vmovn_u32(vreinterpretq_u32_s32(vcombine_s32(a, b)));
+}
+
+template <>
+struct type_casting_traits<numext::int32_t, numext::int8_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 4, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet16c pcast<Packet4i, Packet16c>(const Packet4i& a, const Packet4i& b, const Packet4i& c,
+ const Packet4i& d) {
+ const int16x8_t ab_s16 = pcast<Packet4i, Packet8s>(a, b);
+ const int16x8_t cd_s16 = pcast<Packet4i, Packet8s>(c, d);
+ return vcombine_s8(vmovn_s16(ab_s16), vmovn_s16(cd_s16));
+}
+template <>
+EIGEN_STRONG_INLINE Packet8c pcast<Packet2i, Packet8c>(const Packet2i& a, const Packet2i& b, const Packet2i& c,
+ const Packet2i& d) {
+ const int16x4_t ab_s16 = vmovn_s32(vcombine_s32(a, b));
+ const int16x4_t cd_s16 = vmovn_s32(vcombine_s32(c, d));
+ return vmovn_s16(vcombine_s16(ab_s16, cd_s16));
+}
+
+template <>
+struct type_casting_traits<numext::int32_t, numext::uint8_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 4, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet16uc pcast<Packet4i, Packet16uc>(const Packet4i& a, const Packet4i& b, const Packet4i& c,
+ const Packet4i& d) {
+ const uint16x8_t ab_u16 = pcast<Packet4i, Packet8us>(a, b);
+ const uint16x8_t cd_u16 = pcast<Packet4i, Packet8us>(c, d);
+ return vcombine_u8(vmovn_u16(ab_u16), vmovn_u16(cd_u16));
+}
+template <>
+EIGEN_STRONG_INLINE Packet8uc pcast<Packet2i, Packet8uc>(const Packet2i& a, const Packet2i& b, const Packet2i& c,
+ const Packet2i& d) {
+ const uint16x4_t ab_u16 = pcast<Packet2i, Packet4us>(a, b);
+ const uint16x4_t cd_u16 = pcast<Packet2i, Packet4us>(c, d);
+ return vmovn_u16(vcombine_u16(ab_u16, cd_u16));
+}
+
+//==============================================================================
+// pcast, SrcType = uint32_t
+//==============================================================================
+template <>
+struct type_casting_traits<numext::uint32_t, float> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet4f pcast<Packet4ui, Packet4f>(const Packet4ui& a) {
+ return vcvtq_f32_u32(a);
+}
+template <>
+EIGEN_STRONG_INLINE Packet2f pcast<Packet2ui, Packet2f>(const Packet2ui& a) {
+ return vcvt_f32_u32(a);
+}
+
+template <>
+struct type_casting_traits<numext::uint32_t, numext::uint64_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 2 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet2ul pcast<Packet4ui, Packet2ul>(const Packet4ui& a) {
+ // Discard second half of input.
+ return vmovl_u32(vget_low_u32(a));
+}
+
+template <>
+struct type_casting_traits<numext::uint32_t, numext::int64_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 2 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet2l pcast<Packet4ui, Packet2l>(const Packet4ui& a) {
+ return vreinterpretq_s64_u64(pcast<Packet4ui, Packet2ul>(a));
+}
+
+template <>
+struct type_casting_traits<numext::uint32_t, numext::uint32_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet4ui pcast<Packet4ui, Packet4ui>(const Packet4ui& a) {
+ return a;
+}
+template <>
+EIGEN_STRONG_INLINE Packet2ui pcast<Packet2ui, Packet2ui>(const Packet2ui& a) {
+ return a;
+}
+
+template <>
+struct type_casting_traits<numext::uint32_t, numext::int32_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet4i pcast<Packet4ui, Packet4i>(const Packet4ui& a) {
+ return vreinterpretq_s32_u32(a);
+}
+template <>
+EIGEN_STRONG_INLINE Packet2i pcast<Packet2ui, Packet2i>(const Packet2ui& a) {
+ return vreinterpret_s32_u32(a);
+}
+
+template <>
+struct type_casting_traits<numext::uint32_t, numext::uint16_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 2, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet8us pcast<Packet4ui, Packet8us>(const Packet4ui& a, const Packet4ui& b) {
+ return vcombine_u16(vmovn_u32(a), vmovn_u32(b));
+}
+template <>
+EIGEN_STRONG_INLINE Packet4us pcast<Packet2ui, Packet4us>(const Packet2ui& a, const Packet2ui& b) {
+ return vmovn_u32(vcombine_u32(a, b));
+}
+
+template <>
+struct type_casting_traits<numext::uint32_t, numext::int16_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 2, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet8s pcast<Packet4ui, Packet8s>(const Packet4ui& a, const Packet4ui& b) {
+ return vreinterpretq_s16_u16(pcast<Packet4ui, Packet8us>(a, b));
+}
+template <>
+EIGEN_STRONG_INLINE Packet4s pcast<Packet2ui, Packet4s>(const Packet2ui& a, const Packet2ui& b) {
+ return vreinterpret_s16_u16(pcast<Packet2ui, Packet4us>(a, b));
+}
+
+template <>
+struct type_casting_traits<numext::uint32_t, numext::uint8_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 4, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet16uc pcast<Packet4ui, Packet16uc>(const Packet4ui& a, const Packet4ui& b, const Packet4ui& c,
+ const Packet4ui& d) {
+ const uint16x8_t ab_u16 = vcombine_u16(vmovn_u32(a), vmovn_u32(b));
+ const uint16x8_t cd_u16 = vcombine_u16(vmovn_u32(c), vmovn_u32(d));
+ return vcombine_u8(vmovn_u16(ab_u16), vmovn_u16(cd_u16));
+}
+template <>
+EIGEN_STRONG_INLINE Packet8uc pcast<Packet2ui, Packet8uc>(const Packet2ui& a, const Packet2ui& b, const Packet2ui& c,
+ const Packet2ui& d) {
+ const uint16x4_t ab_u16 = vmovn_u32(vcombine_u32(a, b));
+ const uint16x4_t cd_u16 = vmovn_u32(vcombine_u32(c, d));
+ return vmovn_u16(vcombine_u16(ab_u16, cd_u16));
+}
+
+template <>
+struct type_casting_traits<numext::uint32_t, numext::int8_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 4, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet16c pcast<Packet4ui, Packet16c>(const Packet4ui& a, const Packet4ui& b, const Packet4ui& c,
+ const Packet4ui& d) {
+ return vreinterpretq_s8_u8(pcast<Packet4ui, Packet16uc>(a, b, c, d));
+}
+template <>
+EIGEN_STRONG_INLINE Packet8c pcast<Packet2ui, Packet8c>(const Packet2ui& a, const Packet2ui& b, const Packet2ui& c,
+ const Packet2ui& d) {
+ return vreinterpret_s8_u8(pcast<Packet2ui, Packet8uc>(a, b, c, d));
+}
+
+//==============================================================================
+// pcast, SrcType = int64_t
+//==============================================================================
+template <>
+struct type_casting_traits<numext::int64_t, float> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 2, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet4f pcast<Packet2l, Packet4f>(const Packet2l& a, const Packet2l& b) {
+ return vcvtq_f32_s32(vcombine_s32(vmovn_s64(a), vmovn_s64(b)));
+}
+
+template <>
+struct type_casting_traits<numext::int64_t, numext::int64_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet2l pcast<Packet2l, Packet2l>(const Packet2l& a) {
+ return a;
+}
+
+template <>
+struct type_casting_traits<numext::int64_t, numext::uint64_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet2ul pcast<Packet2l, Packet2ul>(const Packet2l& a) {
+ return vreinterpretq_u64_s64(a);
+}
+
+template <>
+struct type_casting_traits<numext::int64_t, numext::int32_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 2, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet4i pcast<Packet2l, Packet4i>(const Packet2l& a, const Packet2l& b) {
+ return vcombine_s32(vmovn_s64(a), vmovn_s64(b));
+}
+
+template <>
+struct type_casting_traits<numext::int64_t, numext::uint32_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 2, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet4ui pcast<Packet2l, Packet4ui>(const Packet2l& a, const Packet2l& b) {
+ return vcombine_u32(vmovn_u64(vreinterpretq_u64_s64(a)), vmovn_u64(vreinterpretq_u64_s64(b)));
+}
+
+template <>
+struct type_casting_traits<numext::int64_t, numext::int16_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 4, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet8s pcast<Packet2l, Packet8s>(const Packet2l& a, const Packet2l& b, const Packet2l& c,
+ const Packet2l& d) {
+ const int32x4_t ab_s32 = pcast<Packet2l, Packet4i>(a, b);
+ const int32x4_t cd_s32 = pcast<Packet2l, Packet4i>(c, d);
+ return vcombine_s16(vmovn_s32(ab_s32), vmovn_s32(cd_s32));
+}
+
+template <>
+struct type_casting_traits<numext::int64_t, numext::uint16_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 4, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet8us pcast<Packet2l, Packet8us>(const Packet2l& a, const Packet2l& b, const Packet2l& c,
+ const Packet2l& d) {
+ const uint32x4_t ab_u32 = pcast<Packet2l, Packet4ui>(a, b);
+ const uint32x4_t cd_u32 = pcast<Packet2l, Packet4ui>(c, d);
+ return vcombine_u16(vmovn_u32(ab_u32), vmovn_u32(cd_u32));
+}
+
+template <>
+struct type_casting_traits<numext::int64_t, numext::int8_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 8, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet16c pcast<Packet2l, Packet16c>(const Packet2l& a, const Packet2l& b, const Packet2l& c,
+ const Packet2l& d, const Packet2l& e, const Packet2l& f,
+ const Packet2l& g, const Packet2l& h) {
+ const int16x8_t abcd_s16 = pcast<Packet2l, Packet8s>(a, b, c, d);
+ const int16x8_t efgh_s16 = pcast<Packet2l, Packet8s>(e, f, g, h);
+ return vcombine_s8(vmovn_s16(abcd_s16), vmovn_s16(efgh_s16));
+}
+
+template <>
+struct type_casting_traits<numext::int64_t, numext::uint8_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 8, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet16uc pcast<Packet2l, Packet16uc>(const Packet2l& a, const Packet2l& b, const Packet2l& c,
+ const Packet2l& d, const Packet2l& e, const Packet2l& f,
+ const Packet2l& g, const Packet2l& h) {
+ const uint16x8_t abcd_u16 = pcast<Packet2l, Packet8us>(a, b, c, d);
+ const uint16x8_t efgh_u16 = pcast<Packet2l, Packet8us>(e, f, g, h);
+ return vcombine_u8(vmovn_u16(abcd_u16), vmovn_u16(efgh_u16));
+}
+
+//==============================================================================
+// pcast, SrcType = uint64_t
+//==============================================================================
+template <>
+struct type_casting_traits<numext::uint64_t, float> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 2, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet4f pcast<Packet2ul, Packet4f>(const Packet2ul& a, const Packet2ul& b) {
+ return vcvtq_f32_u32(vcombine_u32(vmovn_u64(a), vmovn_u64(b)));
+}
+
+template <>
+struct type_casting_traits<numext::uint64_t, numext::uint64_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet2ul pcast<Packet2ul, Packet2ul>(const Packet2ul& a) {
+ return a;
+}
+
+template <>
+struct type_casting_traits<numext::uint64_t, numext::int64_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet2l pcast<Packet2ul, Packet2l>(const Packet2ul& a) {
+ return vreinterpretq_s64_u64(a);
+}
+
+template <>
+struct type_casting_traits<numext::uint64_t, numext::uint32_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 2, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet4ui pcast<Packet2ul, Packet4ui>(const Packet2ul& a, const Packet2ul& b) {
+ return vcombine_u32(vmovn_u64(a), vmovn_u64(b));
+}
+
+template <>
+struct type_casting_traits<numext::uint64_t, numext::int32_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 2, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet4i pcast<Packet2ul, Packet4i>(const Packet2ul& a, const Packet2ul& b) {
+ return vreinterpretq_s32_u32(pcast<Packet2ul, Packet4ui>(a, b));
+}
+
+template <>
+struct type_casting_traits<numext::uint64_t, numext::uint16_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 4, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet8us pcast<Packet2ul, Packet8us>(const Packet2ul& a, const Packet2ul& b, const Packet2ul& c,
+ const Packet2ul& d) {
+ const uint16x4_t ab_u16 = vmovn_u32(vcombine_u32(vmovn_u64(a), vmovn_u64(b)));
+ const uint16x4_t cd_u16 = vmovn_u32(vcombine_u32(vmovn_u64(c), vmovn_u64(d)));
+ return vcombine_u16(ab_u16, cd_u16);
+}
+
+template <>
+struct type_casting_traits<numext::uint64_t, numext::int16_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 4, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet8s pcast<Packet2ul, Packet8s>(const Packet2ul& a, const Packet2ul& b, const Packet2ul& c,
+ const Packet2ul& d) {
+ return vreinterpretq_s16_u16(pcast<Packet2ul, Packet8us>(a, b, c, d));
+}
+
+template <>
+struct type_casting_traits<numext::uint64_t, numext::uint8_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 8, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet16uc pcast<Packet2ul, Packet16uc>(const Packet2ul& a, const Packet2ul& b, const Packet2ul& c,
+ const Packet2ul& d, const Packet2ul& e, const Packet2ul& f,
+ const Packet2ul& g, const Packet2ul& h) {
+ const uint16x8_t abcd_u16 = pcast<Packet2ul, Packet8us>(a, b, c, d);
+ const uint16x8_t efgh_u16 = pcast<Packet2ul, Packet8us>(e, f, g, h);
+ return vcombine_u8(vmovn_u16(abcd_u16), vmovn_u16(efgh_u16));
+}
+
+template <>
+struct type_casting_traits<numext::uint64_t, numext::int8_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 8, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet16c pcast<Packet2ul, Packet16c>(const Packet2ul& a, const Packet2ul& b, const Packet2ul& c,
+ const Packet2ul& d, const Packet2ul& e, const Packet2ul& f,
+ const Packet2ul& g, const Packet2ul& h) {
+ return vreinterpretq_s8_u8(pcast<Packet2ul, Packet16uc>(a, b, c, d, e, f, g, h));
+}
+
+//==============================================================================
+// preinterpret
+//==============================================================================
+template <>
+EIGEN_STRONG_INLINE Packet2f preinterpret<Packet2f, Packet2i>(const Packet2i& a) {
+ return vreinterpret_f32_s32(a);
+}
+template <>
+EIGEN_STRONG_INLINE Packet2f preinterpret<Packet2f, Packet2ui>(const Packet2ui& a) {
+ return vreinterpret_f32_u32(a);
+}
+template <>
+EIGEN_STRONG_INLINE Packet4f preinterpret<Packet4f, Packet4i>(const Packet4i& a) {
+ return vreinterpretq_f32_s32(a);
+}
+template <>
+EIGEN_STRONG_INLINE Packet4f preinterpret<Packet4f, Packet4ui>(const Packet4ui& a) {
+ return vreinterpretq_f32_u32(a);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4c preinterpret<Packet4c, Packet4uc>(const Packet4uc& a) {
+ return static_cast<Packet4c>(a);
+}
+template <>
+EIGEN_STRONG_INLINE Packet8c preinterpret<Packet8c, Packet8uc>(const Packet8uc& a) {
+ return vreinterpret_s8_u8(a);
+}
+template <>
+EIGEN_STRONG_INLINE Packet16c preinterpret<Packet16c, Packet16uc>(const Packet16uc& a) {
+ return vreinterpretq_s8_u8(a);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4uc preinterpret<Packet4uc, Packet4c>(const Packet4c& a) {
+ return static_cast<Packet4uc>(a);
+}
+template <>
+EIGEN_STRONG_INLINE Packet8uc preinterpret<Packet8uc, Packet8c>(const Packet8c& a) {
+ return vreinterpret_u8_s8(a);
+}
+template <>
+EIGEN_STRONG_INLINE Packet16uc preinterpret<Packet16uc, Packet16c>(const Packet16c& a) {
+ return vreinterpretq_u8_s8(a);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4s preinterpret<Packet4s, Packet4us>(const Packet4us& a) {
+ return vreinterpret_s16_u16(a);
+}
+template <>
+EIGEN_STRONG_INLINE Packet8s preinterpret<Packet8s, Packet8us>(const Packet8us& a) {
+ return vreinterpretq_s16_u16(a);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4us preinterpret<Packet4us, Packet4s>(const Packet4s& a) {
+ return vreinterpret_u16_s16(a);
+}
+template <>
+EIGEN_STRONG_INLINE Packet8us preinterpret<Packet8us, Packet8s>(const Packet8s& a) {
+ return vreinterpretq_u16_s16(a);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet2i preinterpret<Packet2i, Packet2f>(const Packet2f& a) {
+ return vreinterpret_s32_f32(a);
+}
+template <>
+EIGEN_STRONG_INLINE Packet2i preinterpret<Packet2i, Packet2ui>(const Packet2ui& a) {
+ return vreinterpret_s32_u32(a);
+}
+template <>
+EIGEN_STRONG_INLINE Packet4i preinterpret<Packet4i, Packet4f>(const Packet4f& a) {
+ return vreinterpretq_s32_f32(a);
+}
+template <>
+EIGEN_STRONG_INLINE Packet4i preinterpret<Packet4i, Packet4ui>(const Packet4ui& a) {
+ return vreinterpretq_s32_u32(a);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet2ui preinterpret<Packet2ui, Packet2f>(const Packet2f& a) {
+ return vreinterpret_u32_f32(a);
+}
+template <>
+EIGEN_STRONG_INLINE Packet2ui preinterpret<Packet2ui, Packet2i>(const Packet2i& a) {
+ return vreinterpret_u32_s32(a);
+}
+template <>
+EIGEN_STRONG_INLINE Packet4ui preinterpret<Packet4ui, Packet4f>(const Packet4f& a) {
+ return vreinterpretq_u32_f32(a);
+}
+template <>
+EIGEN_STRONG_INLINE Packet4ui preinterpret<Packet4ui, Packet4i>(const Packet4i& a) {
+ return vreinterpretq_u32_s32(a);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet2l preinterpret<Packet2l, Packet2ul>(const Packet2ul& a) {
+ return vreinterpretq_s64_u64(a);
+}
+template <>
+EIGEN_STRONG_INLINE Packet2ul preinterpret<Packet2ul, Packet2l>(const Packet2l& a) {
+ return vreinterpretq_u64_s64(a);
+}
+
+#if EIGEN_ARCH_ARM64
+
+//==============================================================================
+// pcast/preinterpret, Double
+//==============================================================================
+
+template <>
+struct type_casting_traits<double, double> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet2d pcast<Packet2d, Packet2d>(const Packet2d& a) {
+ return a;
+}
+
+template <>
+struct type_casting_traits<double, float> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 2, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet4f pcast<Packet2d, Packet4f>(const Packet2d& a, const Packet2d& b) {
+ return vcombine_f32(vcvt_f32_f64(a), vcvt_f32_f64(b));
+}
+
+template <>
+struct type_casting_traits<double, numext::int64_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet2l pcast<Packet2d, Packet2l>(const Packet2d& a) {
+ return vcvtq_s64_f64(a);
+}
+
+template <>
+struct type_casting_traits<double, numext::uint64_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet2ul pcast<Packet2d, Packet2ul>(const Packet2d& a) {
+ return vcvtq_u64_f64(a);
+}
+
+template <>
+struct type_casting_traits<double, numext::int32_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 2, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet4i pcast<Packet2d, Packet4i>(const Packet2d& a, const Packet2d& b) {
+ return vcombine_s32(vmovn_s64(vcvtq_s64_f64(a)), vmovn_s64(vcvtq_s64_f64(b)));
+}
+
+template <>
+struct type_casting_traits<double, numext::uint32_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 2, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet4ui pcast<Packet2d, Packet4ui>(const Packet2d& a, const Packet2d& b) {
+ return vcombine_u32(vmovn_u64(vcvtq_u64_f64(a)), vmovn_u64(vcvtq_u64_f64(b)));
+}
+
+template <>
+struct type_casting_traits<double, numext::int16_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 4, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet8s pcast<Packet2d, Packet8s>(const Packet2d& a, const Packet2d& b, const Packet2d& c,
+ const Packet2d& d) {
+ const int32x4_t ab_s32 = pcast<Packet2d, Packet4i>(a, b);
+ const int32x4_t cd_s32 = pcast<Packet2d, Packet4i>(c, d);
+ return vcombine_s16(vmovn_s32(ab_s32), vmovn_s32(cd_s32));
+}
+
+template <>
+struct type_casting_traits<double, numext::uint16_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 4, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet8us pcast<Packet2d, Packet8us>(const Packet2d& a, const Packet2d& b, const Packet2d& c,
+ const Packet2d& d) {
+ const uint32x4_t ab_u32 = pcast<Packet2d, Packet4ui>(a, b);
+ const uint32x4_t cd_u32 = pcast<Packet2d, Packet4ui>(c, d);
+ return vcombine_u16(vmovn_u32(ab_u32), vmovn_u32(cd_u32));
+}
+
+template <>
+struct type_casting_traits<double, numext::int8_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 8, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet16c pcast<Packet2d, Packet16c>(const Packet2d& a, const Packet2d& b, const Packet2d& c,
+ const Packet2d& d, const Packet2d& e, const Packet2d& f,
+ const Packet2d& g, const Packet2d& h) {
+ const int16x8_t abcd_s16 = pcast<Packet2d, Packet8s>(a, b, c, d);
+ const int16x8_t efgh_s16 = pcast<Packet2d, Packet8s>(e, f, g, h);
+ return vcombine_s8(vmovn_s16(abcd_s16), vmovn_s16(efgh_s16));
+}
+
+template <>
+struct type_casting_traits<double, numext::uint8_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 8, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet16uc pcast<Packet2d, Packet16uc>(const Packet2d& a, const Packet2d& b, const Packet2d& c,
+ const Packet2d& d, const Packet2d& e, const Packet2d& f,
+ const Packet2d& g, const Packet2d& h) {
+ const uint16x8_t abcd_u16 = pcast<Packet2d, Packet8us>(a, b, c, d);
+ const uint16x8_t efgh_u16 = pcast<Packet2d, Packet8us>(e, f, g, h);
+ return vcombine_u8(vmovn_u16(abcd_u16), vmovn_u16(efgh_u16));
+}
+
+template <>
+struct type_casting_traits<float, double> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 2 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet2d pcast<Packet4f, Packet2d>(const Packet4f& a) {
+ // Discard second-half of input.
+ return vcvt_f64_f32(vget_low_f32(a));
+}
+
+template <>
+struct type_casting_traits<numext::int8_t, double> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 8 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet2d pcast<Packet16c, Packet2d>(const Packet16c& a) {
+ // Discard all but first two values.
+ return vcvt_f64_f32(pcast<Packet8c, Packet2f>(vget_low_s8(a)));
+}
+
+template <>
+struct type_casting_traits<numext::uint8_t, double> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 8 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet2d pcast<Packet16uc, Packet2d>(const Packet16uc& a) {
+ // Discard all but first two values.
+ return vcvt_f64_f32(pcast<Packet8uc, Packet2f>(vget_low_u8(a)));
+}
+
+template <>
+struct type_casting_traits<numext::int16_t, double> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 4 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet2d pcast<Packet8s, Packet2d>(const Packet8s& a) {
+ // Discard all but first two values.
+ return vcvt_f64_f32(pcast<Packet4s, Packet2f>(vget_low_s16(a)));
+}
+
+template <>
+struct type_casting_traits<numext::uint16_t, double> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 4 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet2d pcast<Packet8us, Packet2d>(const Packet8us& a) {
+ // Discard all but first two values.
+ return vcvt_f64_f32(pcast<Packet4us, Packet2f>(vget_low_u16(a)));
+}
+
+template <>
+struct type_casting_traits<numext::int32_t, double> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 2 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet2d pcast<Packet4i, Packet2d>(const Packet4i& a) {
+ // Discard second half of input.
+ return vcvtq_f64_s64(vmovl_s32(vget_low_s32(a)));
+}
+
+template <>
+struct type_casting_traits<numext::uint32_t, double> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 2 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet2d pcast<Packet4ui, Packet2d>(const Packet4ui& a) {
+ // Discard second half of input.
+ return vcvtq_f64_u64(vmovl_u32(vget_low_u32(a)));
+}
+
+template <>
+struct type_casting_traits<numext::int64_t, double> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet2d pcast<Packet2l, Packet2d>(const Packet2l& a) {
+ return vcvtq_f64_s64(a);
+}
+
+template <>
+struct type_casting_traits<numext::uint64_t, double> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 };
+};
+template <>
+EIGEN_STRONG_INLINE Packet2d pcast<Packet2ul, Packet2d>(const Packet2ul& a) {
+ return vcvtq_f64_u64(a);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet2d preinterpret<Packet2d, Packet2l>(const Packet2l& a) {
+ return vreinterpretq_f64_s64(a);
+}
+template <>
+EIGEN_STRONG_INLINE Packet2d preinterpret<Packet2d, Packet2ul>(const Packet2ul& a) {
+ return vreinterpretq_f64_u64(a);
+}
+template <>
+EIGEN_STRONG_INLINE Packet2l preinterpret<Packet2l, Packet2d>(const Packet2d& a) {
+ return vreinterpretq_s64_f64(a);
+}
+template <>
+EIGEN_STRONG_INLINE Packet2ul preinterpret<Packet2ul, Packet2d>(const Packet2d& a) {
+ return vreinterpretq_u64_f64(a);
+}
+template <>
+EIGEN_STRONG_INLINE Packet2d preinterpret<Packet2d, Packet4i>(const Packet4i& a) {
+ return vreinterpretq_f64_s32(a);
+}
+template <>
+EIGEN_STRONG_INLINE Packet4i preinterpret<Packet4i, Packet2d>(const Packet2d& a) {
+ return vreinterpretq_s32_f64(a);
+}
+
+#endif // EIGEN_ARCH_ARM64
+
+} // end namespace internal
+
+} // end namespace Eigen
+
+#endif // EIGEN_TYPE_CASTING_NEON_H
diff --git a/Eigen/src/Core/arch/SSE/Complex.h b/Eigen/src/Core/arch/SSE/Complex.h
index 5607fe0ab..8fe22da46 100644
--- a/Eigen/src/Core/arch/SSE/Complex.h
+++ b/Eigen/src/Core/arch/SSE/Complex.h
@@ -19,7 +19,7 @@ struct Packet2cf
{
EIGEN_STRONG_INLINE Packet2cf() {}
EIGEN_STRONG_INLINE explicit Packet2cf(const __m128& a) : v(a) {}
- __m128 v;
+ Packet4f v;
};
// Use the packet_traits defined in AVX/PacketMath.h instead if we're going
@@ -40,20 +40,33 @@ template<> struct packet_traits<std::complex<float> > : default_packet_traits
HasMul = 1,
HasDiv = 1,
HasNegate = 1,
+ HasSqrt = 1,
HasAbs = 0,
HasAbs2 = 0,
HasMin = 0,
HasMax = 0,
HasSetLinear = 0,
- HasBlend = 1
+ HasBlend = 1
};
};
#endif
-template<> struct unpacket_traits<Packet2cf> { typedef std::complex<float> type; enum {size=2, alignment=Aligned16}; typedef Packet2cf half; };
+template<> struct unpacket_traits<Packet2cf> {
+ typedef std::complex<float> type;
+ typedef Packet2cf half;
+ typedef Packet4f as_real;
+ enum {
+ size=2,
+ alignment=Aligned16,
+ vectorizable=true,
+ masked_load_available=false,
+ masked_store_available=false
+ };
+};
template<> EIGEN_STRONG_INLINE Packet2cf padd<Packet2cf>(const Packet2cf& a, const Packet2cf& b) { return Packet2cf(_mm_add_ps(a.v,b.v)); }
template<> EIGEN_STRONG_INLINE Packet2cf psub<Packet2cf>(const Packet2cf& a, const Packet2cf& b) { return Packet2cf(_mm_sub_ps(a.v,b.v)); }
+
template<> EIGEN_STRONG_INLINE Packet2cf pnegate(const Packet2cf& a)
{
const __m128 mask = _mm_castsi128_ps(_mm_setr_epi32(0x80000000,0x80000000,0x80000000,0x80000000));
@@ -82,10 +95,11 @@ template<> EIGEN_STRONG_INLINE Packet2cf pmul<Packet2cf>(const Packet2cf& a, con
#endif
}
+template<> EIGEN_STRONG_INLINE Packet2cf ptrue <Packet2cf>(const Packet2cf& a) { return Packet2cf(ptrue(Packet4f(a.v))); }
template<> EIGEN_STRONG_INLINE Packet2cf pand <Packet2cf>(const Packet2cf& a, const Packet2cf& b) { return Packet2cf(_mm_and_ps(a.v,b.v)); }
template<> EIGEN_STRONG_INLINE Packet2cf por <Packet2cf>(const Packet2cf& a, const Packet2cf& b) { return Packet2cf(_mm_or_ps(a.v,b.v)); }
template<> EIGEN_STRONG_INLINE Packet2cf pxor <Packet2cf>(const Packet2cf& a, const Packet2cf& b) { return Packet2cf(_mm_xor_ps(a.v,b.v)); }
-template<> EIGEN_STRONG_INLINE Packet2cf pandnot<Packet2cf>(const Packet2cf& a, const Packet2cf& b) { return Packet2cf(_mm_andnot_ps(a.v,b.v)); }
+template<> EIGEN_STRONG_INLINE Packet2cf pandnot<Packet2cf>(const Packet2cf& a, const Packet2cf& b) { return Packet2cf(_mm_andnot_ps(b.v,a.v)); }
template<> EIGEN_STRONG_INLINE Packet2cf pload <Packet2cf>(const std::complex<float>* from) { EIGEN_DEBUG_ALIGNED_LOAD return Packet2cf(pload<Packet4f>(&numext::real_ref(*from))); }
template<> EIGEN_STRONG_INLINE Packet2cf ploadu<Packet2cf>(const std::complex<float>* from) { EIGEN_DEBUG_UNALIGNED_LOAD return Packet2cf(ploadu<Packet4f>(&numext::real_ref(*from))); }
@@ -93,19 +107,13 @@ template<> EIGEN_STRONG_INLINE Packet2cf ploadu<Packet2cf>(const std::complex<fl
template<> EIGEN_STRONG_INLINE Packet2cf pset1<Packet2cf>(const std::complex<float>& from)
{
Packet2cf res;
-#if EIGEN_GNUC_AT_MOST(4,2)
- // Workaround annoying "may be used uninitialized in this function" warning with gcc 4.2
- res.v = _mm_loadl_pi(_mm_set1_ps(0.0f), reinterpret_cast<const __m64*>(&from));
-#elif EIGEN_GNUC_AT_LEAST(4,6)
- // Suppress annoying "may be used uninitialized in this function" warning with gcc >= 4.6
- #pragma GCC diagnostic push
- #pragma GCC diagnostic ignored "-Wuninitialized"
- res.v = _mm_loadl_pi(res.v, (const __m64*)&from);
- #pragma GCC diagnostic pop
+#ifdef EIGEN_VECTORIZE_SSE3
+ res.v = _mm_castpd_ps(_mm_loaddup_pd(reinterpret_cast<double const*>(&from)));
#else
- res.v = _mm_loadl_pi(res.v, (const __m64*)&from);
+ res.v = _mm_castpd_ps(_mm_load_sd(reinterpret_cast<double const*>(&from)));
+ res.v = _mm_movelh_ps(res.v, res.v);
#endif
- return Packet2cf(_mm_movelh_ps(res.v,res.v));
+ return res;
}
template<> EIGEN_STRONG_INLINE Packet2cf ploaddup<Packet2cf>(const std::complex<float>* from) { return pset1<Packet2cf>(*from); }
@@ -128,7 +136,7 @@ template<> EIGEN_DEVICE_FUNC inline void pscatter<std::complex<float>, Packet2cf
_mm_cvtss_f32(_mm_shuffle_ps(from.v, from.v, 3)));
}
-template<> EIGEN_STRONG_INLINE void prefetch<std::complex<float> >(const std::complex<float> * addr) { _mm_prefetch((const char*)(addr), _MM_HINT_T0); }
+template<> EIGEN_STRONG_INLINE void prefetch<std::complex<float> >(const std::complex<float> * addr) { _mm_prefetch((SsePrefetchPtrType)(addr), _MM_HINT_T0); }
template<> EIGEN_STRONG_INLINE std::complex<float> pfirst<Packet2cf>(const Packet2cf& a)
{
@@ -152,113 +160,26 @@ template<> EIGEN_STRONG_INLINE std::complex<float> predux<Packet2cf>(const Packe
return pfirst(Packet2cf(_mm_add_ps(a.v, _mm_movehl_ps(a.v,a.v))));
}
-template<> EIGEN_STRONG_INLINE Packet2cf preduxp<Packet2cf>(const Packet2cf* vecs)
-{
- return Packet2cf(_mm_add_ps(_mm_movelh_ps(vecs[0].v,vecs[1].v), _mm_movehl_ps(vecs[1].v,vecs[0].v)));
-}
-
template<> EIGEN_STRONG_INLINE std::complex<float> predux_mul<Packet2cf>(const Packet2cf& a)
{
return pfirst(pmul(a, Packet2cf(_mm_movehl_ps(a.v,a.v))));
}
-template<int Offset>
-struct palign_impl<Offset,Packet2cf>
-{
- static EIGEN_STRONG_INLINE void run(Packet2cf& first, const Packet2cf& second)
- {
- if (Offset==1)
- {
- first.v = _mm_movehl_ps(first.v, first.v);
- first.v = _mm_movelh_ps(first.v, second.v);
- }
- }
-};
-
-template<> struct conj_helper<Packet2cf, Packet2cf, false,true>
-{
- EIGEN_STRONG_INLINE Packet2cf pmadd(const Packet2cf& x, const Packet2cf& y, const Packet2cf& c) const
- { return padd(pmul(x,y),c); }
-
- EIGEN_STRONG_INLINE Packet2cf pmul(const Packet2cf& a, const Packet2cf& b) const
- {
- #ifdef EIGEN_VECTORIZE_SSE3
- return internal::pmul(a, pconj(b));
- #else
- const __m128 mask = _mm_castsi128_ps(_mm_setr_epi32(0x00000000,0x80000000,0x00000000,0x80000000));
- return Packet2cf(_mm_add_ps(_mm_xor_ps(_mm_mul_ps(vec4f_swizzle1(a.v, 0, 0, 2, 2), b.v), mask),
- _mm_mul_ps(vec4f_swizzle1(a.v, 1, 1, 3, 3),
- vec4f_swizzle1(b.v, 1, 0, 3, 2))));
- #endif
- }
-};
-
-template<> struct conj_helper<Packet2cf, Packet2cf, true,false>
-{
- EIGEN_STRONG_INLINE Packet2cf pmadd(const Packet2cf& x, const Packet2cf& y, const Packet2cf& c) const
- { return padd(pmul(x,y),c); }
-
- EIGEN_STRONG_INLINE Packet2cf pmul(const Packet2cf& a, const Packet2cf& b) const
- {
- #ifdef EIGEN_VECTORIZE_SSE3
- return internal::pmul(pconj(a), b);
- #else
- const __m128 mask = _mm_castsi128_ps(_mm_setr_epi32(0x00000000,0x80000000,0x00000000,0x80000000));
- return Packet2cf(_mm_add_ps(_mm_mul_ps(vec4f_swizzle1(a.v, 0, 0, 2, 2), b.v),
- _mm_xor_ps(_mm_mul_ps(vec4f_swizzle1(a.v, 1, 1, 3, 3),
- vec4f_swizzle1(b.v, 1, 0, 3, 2)), mask)));
- #endif
- }
-};
-
-template<> struct conj_helper<Packet2cf, Packet2cf, true,true>
-{
- EIGEN_STRONG_INLINE Packet2cf pmadd(const Packet2cf& x, const Packet2cf& y, const Packet2cf& c) const
- { return padd(pmul(x,y),c); }
-
- EIGEN_STRONG_INLINE Packet2cf pmul(const Packet2cf& a, const Packet2cf& b) const
- {
- #ifdef EIGEN_VECTORIZE_SSE3
- return pconj(internal::pmul(a, b));
- #else
- const __m128 mask = _mm_castsi128_ps(_mm_setr_epi32(0x00000000,0x80000000,0x00000000,0x80000000));
- return Packet2cf(_mm_sub_ps(_mm_xor_ps(_mm_mul_ps(vec4f_swizzle1(a.v, 0, 0, 2, 2), b.v), mask),
- _mm_mul_ps(vec4f_swizzle1(a.v, 1, 1, 3, 3),
- vec4f_swizzle1(b.v, 1, 0, 3, 2))));
- #endif
- }
-};
-
-template<> struct conj_helper<Packet4f, Packet2cf, false,false>
-{
- EIGEN_STRONG_INLINE Packet2cf pmadd(const Packet4f& x, const Packet2cf& y, const Packet2cf& c) const
- { return padd(c, pmul(x,y)); }
-
- EIGEN_STRONG_INLINE Packet2cf pmul(const Packet4f& x, const Packet2cf& y) const
- { return Packet2cf(Eigen::internal::pmul<Packet4f>(x, y.v)); }
-};
-
-template<> struct conj_helper<Packet2cf, Packet4f, false,false>
+EIGEN_STRONG_INLINE Packet2cf pcplxflip/* <Packet2cf> */(const Packet2cf& x)
{
- EIGEN_STRONG_INLINE Packet2cf pmadd(const Packet2cf& x, const Packet4f& y, const Packet2cf& c) const
- { return padd(c, pmul(x,y)); }
+ return Packet2cf(vec4f_swizzle1(x.v, 1, 0, 3, 2));
+}
- EIGEN_STRONG_INLINE Packet2cf pmul(const Packet2cf& x, const Packet4f& y) const
- { return Packet2cf(Eigen::internal::pmul<Packet4f>(x.v, y)); }
-};
+EIGEN_MAKE_CONJ_HELPER_CPLX_REAL(Packet2cf,Packet4f)
template<> EIGEN_STRONG_INLINE Packet2cf pdiv<Packet2cf>(const Packet2cf& a, const Packet2cf& b)
{
// TODO optimize it for SSE3 and 4
- Packet2cf res = conj_helper<Packet2cf,Packet2cf,false,true>().pmul(a,b);
+ Packet2cf res = pmul(a, pconj(b));
__m128 s = _mm_mul_ps(b.v,b.v);
- return Packet2cf(_mm_div_ps(res.v,_mm_add_ps(s,_mm_castsi128_ps(_mm_shuffle_epi32(_mm_castps_si128(s), 0xb1)))));
+ return Packet2cf(_mm_div_ps(res.v,_mm_add_ps(s,vec4f_swizzle1(s, 1, 0, 3, 2))));
}
-EIGEN_STRONG_INLINE Packet2cf pcplxflip/* <Packet2cf> */(const Packet2cf& x)
-{
- return Packet2cf(vec4f_swizzle1(x.v, 1, 0, 3, 2));
-}
//---------- double ----------
@@ -266,7 +187,7 @@ struct Packet1cd
{
EIGEN_STRONG_INLINE Packet1cd() {}
EIGEN_STRONG_INLINE explicit Packet1cd(const __m128d& a) : v(a) {}
- __m128d v;
+ Packet2d v;
};
// Use the packet_traits defined in AVX/PacketMath.h instead if we're going
@@ -287,6 +208,7 @@ template<> struct packet_traits<std::complex<double> > : default_packet_traits
HasMul = 1,
HasDiv = 1,
HasNegate = 1,
+ HasSqrt = 1,
HasAbs = 0,
HasAbs2 = 0,
HasMin = 0,
@@ -296,7 +218,18 @@ template<> struct packet_traits<std::complex<double> > : default_packet_traits
};
#endif
-template<> struct unpacket_traits<Packet1cd> { typedef std::complex<double> type; enum {size=1, alignment=Aligned16}; typedef Packet1cd half; };
+template<> struct unpacket_traits<Packet1cd> {
+ typedef std::complex<double> type;
+ typedef Packet1cd half;
+ typedef Packet2d as_real;
+ enum {
+ size=1,
+ alignment=Aligned16,
+ vectorizable=true,
+ masked_load_available=false,
+ masked_store_available=false
+ };
+};
template<> EIGEN_STRONG_INLINE Packet1cd padd<Packet1cd>(const Packet1cd& a, const Packet1cd& b) { return Packet1cd(_mm_add_pd(a.v,b.v)); }
template<> EIGEN_STRONG_INLINE Packet1cd psub<Packet1cd>(const Packet1cd& a, const Packet1cd& b) { return Packet1cd(_mm_sub_pd(a.v,b.v)); }
@@ -321,10 +254,11 @@ template<> EIGEN_STRONG_INLINE Packet1cd pmul<Packet1cd>(const Packet1cd& a, con
#endif
}
+template<> EIGEN_STRONG_INLINE Packet1cd ptrue <Packet1cd>(const Packet1cd& a) { return Packet1cd(ptrue(Packet2d(a.v))); }
template<> EIGEN_STRONG_INLINE Packet1cd pand <Packet1cd>(const Packet1cd& a, const Packet1cd& b) { return Packet1cd(_mm_and_pd(a.v,b.v)); }
template<> EIGEN_STRONG_INLINE Packet1cd por <Packet1cd>(const Packet1cd& a, const Packet1cd& b) { return Packet1cd(_mm_or_pd(a.v,b.v)); }
template<> EIGEN_STRONG_INLINE Packet1cd pxor <Packet1cd>(const Packet1cd& a, const Packet1cd& b) { return Packet1cd(_mm_xor_pd(a.v,b.v)); }
-template<> EIGEN_STRONG_INLINE Packet1cd pandnot<Packet1cd>(const Packet1cd& a, const Packet1cd& b) { return Packet1cd(_mm_andnot_pd(a.v,b.v)); }
+template<> EIGEN_STRONG_INLINE Packet1cd pandnot<Packet1cd>(const Packet1cd& a, const Packet1cd& b) { return Packet1cd(_mm_andnot_pd(b.v,a.v)); }
// FIXME force unaligned load, this is a temporary fix
template<> EIGEN_STRONG_INLINE Packet1cd pload <Packet1cd>(const std::complex<double>* from)
@@ -340,7 +274,7 @@ template<> EIGEN_STRONG_INLINE Packet1cd ploaddup<Packet1cd>(const std::complex<
template<> EIGEN_STRONG_INLINE void pstore <std::complex<double> >(std::complex<double> * to, const Packet1cd& from) { EIGEN_DEBUG_ALIGNED_STORE pstore((double*)to, Packet2d(from.v)); }
template<> EIGEN_STRONG_INLINE void pstoreu<std::complex<double> >(std::complex<double> * to, const Packet1cd& from) { EIGEN_DEBUG_UNALIGNED_STORE pstoreu((double*)to, Packet2d(from.v)); }
-template<> EIGEN_STRONG_INLINE void prefetch<std::complex<double> >(const std::complex<double> * addr) { _mm_prefetch((const char*)(addr), _MM_HINT_T0); }
+template<> EIGEN_STRONG_INLINE void prefetch<std::complex<double> >(const std::complex<double> * addr) { _mm_prefetch((SsePrefetchPtrType)(addr), _MM_HINT_T0); }
template<> EIGEN_STRONG_INLINE std::complex<double> pfirst<Packet1cd>(const Packet1cd& a)
{
@@ -356,102 +290,17 @@ template<> EIGEN_STRONG_INLINE std::complex<double> predux<Packet1cd>(const Pack
return pfirst(a);
}
-template<> EIGEN_STRONG_INLINE Packet1cd preduxp<Packet1cd>(const Packet1cd* vecs)
-{
- return vecs[0];
-}
-
template<> EIGEN_STRONG_INLINE std::complex<double> predux_mul<Packet1cd>(const Packet1cd& a)
{
return pfirst(a);
}
-template<int Offset>
-struct palign_impl<Offset,Packet1cd>
-{
- static EIGEN_STRONG_INLINE void run(Packet1cd& /*first*/, const Packet1cd& /*second*/)
- {
- // FIXME is it sure we never have to align a Packet1cd?
- // Even though a std::complex<double> has 16 bytes, it is not necessarily aligned on a 16 bytes boundary...
- }
-};
-
-template<> struct conj_helper<Packet1cd, Packet1cd, false,true>
-{
- EIGEN_STRONG_INLINE Packet1cd pmadd(const Packet1cd& x, const Packet1cd& y, const Packet1cd& c) const
- { return padd(pmul(x,y),c); }
-
- EIGEN_STRONG_INLINE Packet1cd pmul(const Packet1cd& a, const Packet1cd& b) const
- {
- #ifdef EIGEN_VECTORIZE_SSE3
- return internal::pmul(a, pconj(b));
- #else
- const __m128d mask = _mm_castsi128_pd(_mm_set_epi32(0x80000000,0x0,0x0,0x0));
- return Packet1cd(_mm_add_pd(_mm_xor_pd(_mm_mul_pd(vec2d_swizzle1(a.v, 0, 0), b.v), mask),
- _mm_mul_pd(vec2d_swizzle1(a.v, 1, 1),
- vec2d_swizzle1(b.v, 1, 0))));
- #endif
- }
-};
-
-template<> struct conj_helper<Packet1cd, Packet1cd, true,false>
-{
- EIGEN_STRONG_INLINE Packet1cd pmadd(const Packet1cd& x, const Packet1cd& y, const Packet1cd& c) const
- { return padd(pmul(x,y),c); }
-
- EIGEN_STRONG_INLINE Packet1cd pmul(const Packet1cd& a, const Packet1cd& b) const
- {
- #ifdef EIGEN_VECTORIZE_SSE3
- return internal::pmul(pconj(a), b);
- #else
- const __m128d mask = _mm_castsi128_pd(_mm_set_epi32(0x80000000,0x0,0x0,0x0));
- return Packet1cd(_mm_add_pd(_mm_mul_pd(vec2d_swizzle1(a.v, 0, 0), b.v),
- _mm_xor_pd(_mm_mul_pd(vec2d_swizzle1(a.v, 1, 1),
- vec2d_swizzle1(b.v, 1, 0)), mask)));
- #endif
- }
-};
-
-template<> struct conj_helper<Packet1cd, Packet1cd, true,true>
-{
- EIGEN_STRONG_INLINE Packet1cd pmadd(const Packet1cd& x, const Packet1cd& y, const Packet1cd& c) const
- { return padd(pmul(x,y),c); }
-
- EIGEN_STRONG_INLINE Packet1cd pmul(const Packet1cd& a, const Packet1cd& b) const
- {
- #ifdef EIGEN_VECTORIZE_SSE3
- return pconj(internal::pmul(a, b));
- #else
- const __m128d mask = _mm_castsi128_pd(_mm_set_epi32(0x80000000,0x0,0x0,0x0));
- return Packet1cd(_mm_sub_pd(_mm_xor_pd(_mm_mul_pd(vec2d_swizzle1(a.v, 0, 0), b.v), mask),
- _mm_mul_pd(vec2d_swizzle1(a.v, 1, 1),
- vec2d_swizzle1(b.v, 1, 0))));
- #endif
- }
-};
-
-template<> struct conj_helper<Packet2d, Packet1cd, false,false>
-{
- EIGEN_STRONG_INLINE Packet1cd pmadd(const Packet2d& x, const Packet1cd& y, const Packet1cd& c) const
- { return padd(c, pmul(x,y)); }
-
- EIGEN_STRONG_INLINE Packet1cd pmul(const Packet2d& x, const Packet1cd& y) const
- { return Packet1cd(Eigen::internal::pmul<Packet2d>(x, y.v)); }
-};
-
-template<> struct conj_helper<Packet1cd, Packet2d, false,false>
-{
- EIGEN_STRONG_INLINE Packet1cd pmadd(const Packet1cd& x, const Packet2d& y, const Packet1cd& c) const
- { return padd(c, pmul(x,y)); }
-
- EIGEN_STRONG_INLINE Packet1cd pmul(const Packet1cd& x, const Packet2d& y) const
- { return Packet1cd(Eigen::internal::pmul<Packet2d>(x.v, y)); }
-};
+EIGEN_MAKE_CONJ_HELPER_CPLX_REAL(Packet1cd,Packet2d)
template<> EIGEN_STRONG_INLINE Packet1cd pdiv<Packet1cd>(const Packet1cd& a, const Packet1cd& b)
{
// TODO optimize it for SSE3 and 4
- Packet1cd res = conj_helper<Packet1cd,Packet1cd,false,true>().pmul(a,b);
+ Packet1cd res = pmul(a,pconj(b));
__m128d s = _mm_mul_pd(b.v,b.v);
return Packet1cd(_mm_div_pd(res.v, _mm_add_pd(s,_mm_shuffle_pd(s, s, 0x1))));
}
@@ -471,33 +320,32 @@ ptranspose(PacketBlock<Packet2cf,2>& kernel) {
kernel.packet[1].v = tmp;
}
-template<> EIGEN_STRONG_INLINE Packet2cf pblend(const Selector<2>& ifPacket, const Packet2cf& thenPacket, const Packet2cf& elsePacket) {
- __m128d result = pblend<Packet2d>(ifPacket, _mm_castps_pd(thenPacket.v), _mm_castps_pd(elsePacket.v));
- return Packet2cf(_mm_castpd_ps(result));
+template<> EIGEN_STRONG_INLINE Packet2cf pcmp_eq(const Packet2cf& a, const Packet2cf& b)
+{
+ __m128 eq = _mm_cmpeq_ps(a.v, b.v);
+ return Packet2cf(pand<Packet4f>(eq, vec4f_swizzle1(eq, 1, 0, 3, 2)));
}
-template<> EIGEN_STRONG_INLINE Packet2cf pinsertfirst(const Packet2cf& a, std::complex<float> b)
+template<> EIGEN_STRONG_INLINE Packet1cd pcmp_eq(const Packet1cd& a, const Packet1cd& b)
{
- return Packet2cf(_mm_loadl_pi(a.v, reinterpret_cast<const __m64*>(&b)));
+ __m128d eq = _mm_cmpeq_pd(a.v, b.v);
+ return Packet1cd(pand<Packet2d>(eq, vec2d_swizzle1(eq, 1, 0)));
}
-template<> EIGEN_STRONG_INLINE Packet1cd pinsertfirst(const Packet1cd&, std::complex<double> b)
-{
- return pset1<Packet1cd>(b);
+template<> EIGEN_STRONG_INLINE Packet2cf pblend(const Selector<2>& ifPacket, const Packet2cf& thenPacket, const Packet2cf& elsePacket) {
+ __m128d result = pblend<Packet2d>(ifPacket, _mm_castps_pd(thenPacket.v), _mm_castps_pd(elsePacket.v));
+ return Packet2cf(_mm_castpd_ps(result));
}
-template<> EIGEN_STRONG_INLINE Packet2cf pinsertlast(const Packet2cf& a, std::complex<float> b)
-{
- return Packet2cf(_mm_loadh_pi(a.v, reinterpret_cast<const __m64*>(&b)));
+template<> EIGEN_STRONG_INLINE Packet1cd psqrt<Packet1cd>(const Packet1cd& a) {
+ return psqrt_complex<Packet1cd>(a);
}
-template<> EIGEN_STRONG_INLINE Packet1cd pinsertlast(const Packet1cd&, std::complex<double> b)
-{
- return pset1<Packet1cd>(b);
+template<> EIGEN_STRONG_INLINE Packet2cf psqrt<Packet2cf>(const Packet2cf& a) {
+ return psqrt_complex<Packet2cf>(a);
}
} // end namespace internal
-
} // end namespace Eigen
#endif // EIGEN_COMPLEX_SSE_H
diff --git a/Eigen/src/Core/arch/SSE/MathFunctions.h b/Eigen/src/Core/arch/SSE/MathFunctions.h
index 7b5f948e1..8736d0d6b 100644
--- a/Eigen/src/Core/arch/SSE/MathFunctions.h
+++ b/Eigen/src/Core/arch/SSE/MathFunctions.h
@@ -8,7 +8,7 @@
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
-/* The sin, cos, exp, and log functions of this file come from
+/* The sin and cos and functions of this file come from
* Julien Pommier's sse math library: http://gruntthepeon.free.fr/ssemath/
*/
@@ -20,426 +20,57 @@ namespace Eigen {
namespace internal {
template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
-Packet4f plog<Packet4f>(const Packet4f& _x)
-{
- Packet4f x = _x;
- _EIGEN_DECLARE_CONST_Packet4f(1 , 1.0f);
- _EIGEN_DECLARE_CONST_Packet4f(half, 0.5f);
- _EIGEN_DECLARE_CONST_Packet4i(0x7f, 0x7f);
-
- _EIGEN_DECLARE_CONST_Packet4f_FROM_INT(inv_mant_mask, ~0x7f800000);
-
- /* the smallest non denormalized float number */
- _EIGEN_DECLARE_CONST_Packet4f_FROM_INT(min_norm_pos, 0x00800000);
- _EIGEN_DECLARE_CONST_Packet4f_FROM_INT(minus_inf, 0xff800000);//-1.f/0.f);
-
- /* natural logarithm computed for 4 simultaneous float
- return NaN for x <= 0
- */
- _EIGEN_DECLARE_CONST_Packet4f(cephes_SQRTHF, 0.707106781186547524f);
- _EIGEN_DECLARE_CONST_Packet4f(cephes_log_p0, 7.0376836292E-2f);
- _EIGEN_DECLARE_CONST_Packet4f(cephes_log_p1, - 1.1514610310E-1f);
- _EIGEN_DECLARE_CONST_Packet4f(cephes_log_p2, 1.1676998740E-1f);
- _EIGEN_DECLARE_CONST_Packet4f(cephes_log_p3, - 1.2420140846E-1f);
- _EIGEN_DECLARE_CONST_Packet4f(cephes_log_p4, + 1.4249322787E-1f);
- _EIGEN_DECLARE_CONST_Packet4f(cephes_log_p5, - 1.6668057665E-1f);
- _EIGEN_DECLARE_CONST_Packet4f(cephes_log_p6, + 2.0000714765E-1f);
- _EIGEN_DECLARE_CONST_Packet4f(cephes_log_p7, - 2.4999993993E-1f);
- _EIGEN_DECLARE_CONST_Packet4f(cephes_log_p8, + 3.3333331174E-1f);
- _EIGEN_DECLARE_CONST_Packet4f(cephes_log_q1, -2.12194440e-4f);
- _EIGEN_DECLARE_CONST_Packet4f(cephes_log_q2, 0.693359375f);
-
-
- Packet4i emm0;
-
- Packet4f invalid_mask = _mm_cmpnge_ps(x, _mm_setzero_ps()); // not greater equal is true if x is NaN
- Packet4f iszero_mask = _mm_cmpeq_ps(x, _mm_setzero_ps());
-
- x = pmax(x, p4f_min_norm_pos); /* cut off denormalized stuff */
- emm0 = _mm_srli_epi32(_mm_castps_si128(x), 23);
-
- /* keep only the fractional part */
- x = _mm_and_ps(x, p4f_inv_mant_mask);
- x = _mm_or_ps(x, p4f_half);
-
- emm0 = _mm_sub_epi32(emm0, p4i_0x7f);
- Packet4f e = padd(Packet4f(_mm_cvtepi32_ps(emm0)), p4f_1);
-
- /* part2:
- if( x < SQRTHF ) {
- e -= 1;
- x = x + x - 1.0;
- } else { x = x - 1.0; }
- */
- Packet4f mask = _mm_cmplt_ps(x, p4f_cephes_SQRTHF);
- Packet4f tmp = pand(x, mask);
- x = psub(x, p4f_1);
- e = psub(e, pand(p4f_1, mask));
- x = padd(x, tmp);
-
- Packet4f x2 = pmul(x,x);
- Packet4f x3 = pmul(x2,x);
-
- Packet4f y, y1, y2;
- y = pmadd(p4f_cephes_log_p0, x, p4f_cephes_log_p1);
- y1 = pmadd(p4f_cephes_log_p3, x, p4f_cephes_log_p4);
- y2 = pmadd(p4f_cephes_log_p6, x, p4f_cephes_log_p7);
- y = pmadd(y , x, p4f_cephes_log_p2);
- y1 = pmadd(y1, x, p4f_cephes_log_p5);
- y2 = pmadd(y2, x, p4f_cephes_log_p8);
- y = pmadd(y, x3, y1);
- y = pmadd(y, x3, y2);
- y = pmul(y, x3);
-
- y1 = pmul(e, p4f_cephes_log_q1);
- tmp = pmul(x2, p4f_half);
- y = padd(y, y1);
- x = psub(x, tmp);
- y2 = pmul(e, p4f_cephes_log_q2);
- x = padd(x, y);
- x = padd(x, y2);
- // negative arg will be NAN, 0 will be -INF
- return _mm_or_ps(_mm_andnot_ps(iszero_mask, _mm_or_ps(x, invalid_mask)),
- _mm_and_ps(iszero_mask, p4f_minus_inf));
+Packet4f plog<Packet4f>(const Packet4f& _x) {
+ return plog_float(_x);
}
template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
-Packet4f pexp<Packet4f>(const Packet4f& _x)
-{
- Packet4f x = _x;
- _EIGEN_DECLARE_CONST_Packet4f(1 , 1.0f);
- _EIGEN_DECLARE_CONST_Packet4f(half, 0.5f);
- _EIGEN_DECLARE_CONST_Packet4i(0x7f, 0x7f);
-
-
- _EIGEN_DECLARE_CONST_Packet4f(exp_hi, 88.3762626647950f);
- _EIGEN_DECLARE_CONST_Packet4f(exp_lo, -88.3762626647949f);
-
- _EIGEN_DECLARE_CONST_Packet4f(cephes_LOG2EF, 1.44269504088896341f);
- _EIGEN_DECLARE_CONST_Packet4f(cephes_exp_C1, 0.693359375f);
- _EIGEN_DECLARE_CONST_Packet4f(cephes_exp_C2, -2.12194440e-4f);
-
- _EIGEN_DECLARE_CONST_Packet4f(cephes_exp_p0, 1.9875691500E-4f);
- _EIGEN_DECLARE_CONST_Packet4f(cephes_exp_p1, 1.3981999507E-3f);
- _EIGEN_DECLARE_CONST_Packet4f(cephes_exp_p2, 8.3334519073E-3f);
- _EIGEN_DECLARE_CONST_Packet4f(cephes_exp_p3, 4.1665795894E-2f);
- _EIGEN_DECLARE_CONST_Packet4f(cephes_exp_p4, 1.6666665459E-1f);
- _EIGEN_DECLARE_CONST_Packet4f(cephes_exp_p5, 5.0000001201E-1f);
-
- Packet4f tmp, fx;
- Packet4i emm0;
+Packet2d plog<Packet2d>(const Packet2d& _x) {
+ return plog_double(_x);
+}
- // clamp x
- x = pmax(pmin(x, p4f_exp_hi), p4f_exp_lo);
+template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
+Packet4f plog2<Packet4f>(const Packet4f& _x) {
+ return plog2_float(_x);
+}
- /* express exp(x) as exp(g + n*log(2)) */
- fx = pmadd(x, p4f_cephes_LOG2EF, p4f_half);
+template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
+Packet2d plog2<Packet2d>(const Packet2d& _x) {
+ return plog2_double(_x);
+}
-#ifdef EIGEN_VECTORIZE_SSE4_1
- fx = _mm_floor_ps(fx);
-#else
- emm0 = _mm_cvttps_epi32(fx);
- tmp = _mm_cvtepi32_ps(emm0);
- /* if greater, substract 1 */
- Packet4f mask = _mm_cmpgt_ps(tmp, fx);
- mask = _mm_and_ps(mask, p4f_1);
- fx = psub(tmp, mask);
-#endif
+template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
+Packet4f plog1p<Packet4f>(const Packet4f& _x) {
+ return generic_plog1p(_x);
+}
- tmp = pmul(fx, p4f_cephes_exp_C1);
- Packet4f z = pmul(fx, p4f_cephes_exp_C2);
- x = psub(x, tmp);
- x = psub(x, z);
-
- z = pmul(x,x);
-
- Packet4f y = p4f_cephes_exp_p0;
- y = pmadd(y, x, p4f_cephes_exp_p1);
- y = pmadd(y, x, p4f_cephes_exp_p2);
- y = pmadd(y, x, p4f_cephes_exp_p3);
- y = pmadd(y, x, p4f_cephes_exp_p4);
- y = pmadd(y, x, p4f_cephes_exp_p5);
- y = pmadd(y, z, x);
- y = padd(y, p4f_1);
-
- // build 2^n
- emm0 = _mm_cvttps_epi32(fx);
- emm0 = _mm_add_epi32(emm0, p4i_0x7f);
- emm0 = _mm_slli_epi32(emm0, 23);
- return pmax(pmul(y, Packet4f(_mm_castsi128_ps(emm0))), _x);
+template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
+Packet4f pexpm1<Packet4f>(const Packet4f& _x) {
+ return generic_expm1(_x);
}
+
template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
-Packet2d pexp<Packet2d>(const Packet2d& _x)
+Packet4f pexp<Packet4f>(const Packet4f& _x)
{
- Packet2d x = _x;
-
- _EIGEN_DECLARE_CONST_Packet2d(1 , 1.0);
- _EIGEN_DECLARE_CONST_Packet2d(2 , 2.0);
- _EIGEN_DECLARE_CONST_Packet2d(half, 0.5);
-
- _EIGEN_DECLARE_CONST_Packet2d(exp_hi, 709.437);
- _EIGEN_DECLARE_CONST_Packet2d(exp_lo, -709.436139303);
-
- _EIGEN_DECLARE_CONST_Packet2d(cephes_LOG2EF, 1.4426950408889634073599);
-
- _EIGEN_DECLARE_CONST_Packet2d(cephes_exp_p0, 1.26177193074810590878e-4);
- _EIGEN_DECLARE_CONST_Packet2d(cephes_exp_p1, 3.02994407707441961300e-2);
- _EIGEN_DECLARE_CONST_Packet2d(cephes_exp_p2, 9.99999999999999999910e-1);
-
- _EIGEN_DECLARE_CONST_Packet2d(cephes_exp_q0, 3.00198505138664455042e-6);
- _EIGEN_DECLARE_CONST_Packet2d(cephes_exp_q1, 2.52448340349684104192e-3);
- _EIGEN_DECLARE_CONST_Packet2d(cephes_exp_q2, 2.27265548208155028766e-1);
- _EIGEN_DECLARE_CONST_Packet2d(cephes_exp_q3, 2.00000000000000000009e0);
-
- _EIGEN_DECLARE_CONST_Packet2d(cephes_exp_C1, 0.693145751953125);
- _EIGEN_DECLARE_CONST_Packet2d(cephes_exp_C2, 1.42860682030941723212e-6);
- static const __m128i p4i_1023_0 = _mm_setr_epi32(1023, 1023, 0, 0);
-
- Packet2d tmp, fx;
- Packet4i emm0;
-
- // clamp x
- x = pmax(pmin(x, p2d_exp_hi), p2d_exp_lo);
- /* express exp(x) as exp(g + n*log(2)) */
- fx = pmadd(p2d_cephes_LOG2EF, x, p2d_half);
-
-#ifdef EIGEN_VECTORIZE_SSE4_1
- fx = _mm_floor_pd(fx);
-#else
- emm0 = _mm_cvttpd_epi32(fx);
- tmp = _mm_cvtepi32_pd(emm0);
- /* if greater, substract 1 */
- Packet2d mask = _mm_cmpgt_pd(tmp, fx);
- mask = _mm_and_pd(mask, p2d_1);
- fx = psub(tmp, mask);
-#endif
-
- tmp = pmul(fx, p2d_cephes_exp_C1);
- Packet2d z = pmul(fx, p2d_cephes_exp_C2);
- x = psub(x, tmp);
- x = psub(x, z);
-
- Packet2d x2 = pmul(x,x);
-
- Packet2d px = p2d_cephes_exp_p0;
- px = pmadd(px, x2, p2d_cephes_exp_p1);
- px = pmadd(px, x2, p2d_cephes_exp_p2);
- px = pmul (px, x);
-
- Packet2d qx = p2d_cephes_exp_q0;
- qx = pmadd(qx, x2, p2d_cephes_exp_q1);
- qx = pmadd(qx, x2, p2d_cephes_exp_q2);
- qx = pmadd(qx, x2, p2d_cephes_exp_q3);
-
- x = pdiv(px,psub(qx,px));
- x = pmadd(p2d_2,x,p2d_1);
-
- // build 2^n
- emm0 = _mm_cvttpd_epi32(fx);
- emm0 = _mm_add_epi32(emm0, p4i_1023_0);
- emm0 = _mm_slli_epi32(emm0, 20);
- emm0 = _mm_shuffle_epi32(emm0, _MM_SHUFFLE(1,2,0,3));
- return pmax(pmul(x, Packet2d(_mm_castsi128_pd(emm0))), _x);
+ return pexp_float(_x);
}
-/* evaluation of 4 sines at onces, using SSE2 intrinsics.
-
- The code is the exact rewriting of the cephes sinf function.
- Precision is excellent as long as x < 8192 (I did not bother to
- take into account the special handling they have for greater values
- -- it does not return garbage for arguments over 8192, though, but
- the extra precision is missing).
-
- Note that it is such that sinf((float)M_PI) = 8.74e-8, which is the
- surprising but correct result.
-*/
+template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
+Packet2d pexp<Packet2d>(const Packet2d& x)
+{
+ return pexp_double(x);
+}
template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
Packet4f psin<Packet4f>(const Packet4f& _x)
{
- Packet4f x = _x;
- _EIGEN_DECLARE_CONST_Packet4f(1 , 1.0f);
- _EIGEN_DECLARE_CONST_Packet4f(half, 0.5f);
-
- _EIGEN_DECLARE_CONST_Packet4i(1, 1);
- _EIGEN_DECLARE_CONST_Packet4i(not1, ~1);
- _EIGEN_DECLARE_CONST_Packet4i(2, 2);
- _EIGEN_DECLARE_CONST_Packet4i(4, 4);
-
- _EIGEN_DECLARE_CONST_Packet4f_FROM_INT(sign_mask, 0x80000000);
-
- _EIGEN_DECLARE_CONST_Packet4f(minus_cephes_DP1,-0.78515625f);
- _EIGEN_DECLARE_CONST_Packet4f(minus_cephes_DP2, -2.4187564849853515625e-4f);
- _EIGEN_DECLARE_CONST_Packet4f(minus_cephes_DP3, -3.77489497744594108e-8f);
- _EIGEN_DECLARE_CONST_Packet4f(sincof_p0, -1.9515295891E-4f);
- _EIGEN_DECLARE_CONST_Packet4f(sincof_p1, 8.3321608736E-3f);
- _EIGEN_DECLARE_CONST_Packet4f(sincof_p2, -1.6666654611E-1f);
- _EIGEN_DECLARE_CONST_Packet4f(coscof_p0, 2.443315711809948E-005f);
- _EIGEN_DECLARE_CONST_Packet4f(coscof_p1, -1.388731625493765E-003f);
- _EIGEN_DECLARE_CONST_Packet4f(coscof_p2, 4.166664568298827E-002f);
- _EIGEN_DECLARE_CONST_Packet4f(cephes_FOPI, 1.27323954473516f); // 4 / M_PI
-
- Packet4f xmm1, xmm2, xmm3, sign_bit, y;
-
- Packet4i emm0, emm2;
- sign_bit = x;
- /* take the absolute value */
- x = pabs(x);
-
- /* take the modulo */
-
- /* extract the sign bit (upper one) */
- sign_bit = _mm_and_ps(sign_bit, p4f_sign_mask);
-
- /* scale by 4/Pi */
- y = pmul(x, p4f_cephes_FOPI);
-
- /* store the integer part of y in mm0 */
- emm2 = _mm_cvttps_epi32(y);
- /* j=(j+1) & (~1) (see the cephes sources) */
- emm2 = _mm_add_epi32(emm2, p4i_1);
- emm2 = _mm_and_si128(emm2, p4i_not1);
- y = _mm_cvtepi32_ps(emm2);
- /* get the swap sign flag */
- emm0 = _mm_and_si128(emm2, p4i_4);
- emm0 = _mm_slli_epi32(emm0, 29);
- /* get the polynom selection mask
- there is one polynom for 0 <= x <= Pi/4
- and another one for Pi/4<x<=Pi/2
-
- Both branches will be computed.
- */
- emm2 = _mm_and_si128(emm2, p4i_2);
- emm2 = _mm_cmpeq_epi32(emm2, _mm_setzero_si128());
-
- Packet4f swap_sign_bit = _mm_castsi128_ps(emm0);
- Packet4f poly_mask = _mm_castsi128_ps(emm2);
- sign_bit = _mm_xor_ps(sign_bit, swap_sign_bit);
-
- /* The magic pass: "Extended precision modular arithmetic"
- x = ((x - y * DP1) - y * DP2) - y * DP3; */
- xmm1 = pmul(y, p4f_minus_cephes_DP1);
- xmm2 = pmul(y, p4f_minus_cephes_DP2);
- xmm3 = pmul(y, p4f_minus_cephes_DP3);
- x = padd(x, xmm1);
- x = padd(x, xmm2);
- x = padd(x, xmm3);
-
- /* Evaluate the first polynom (0 <= x <= Pi/4) */
- y = p4f_coscof_p0;
- Packet4f z = _mm_mul_ps(x,x);
-
- y = pmadd(y, z, p4f_coscof_p1);
- y = pmadd(y, z, p4f_coscof_p2);
- y = pmul(y, z);
- y = pmul(y, z);
- Packet4f tmp = pmul(z, p4f_half);
- y = psub(y, tmp);
- y = padd(y, p4f_1);
-
- /* Evaluate the second polynom (Pi/4 <= x <= 0) */
-
- Packet4f y2 = p4f_sincof_p0;
- y2 = pmadd(y2, z, p4f_sincof_p1);
- y2 = pmadd(y2, z, p4f_sincof_p2);
- y2 = pmul(y2, z);
- y2 = pmul(y2, x);
- y2 = padd(y2, x);
-
- /* select the correct result from the two polynoms */
- y2 = _mm_and_ps(poly_mask, y2);
- y = _mm_andnot_ps(poly_mask, y);
- y = _mm_or_ps(y,y2);
- /* update the sign */
- return _mm_xor_ps(y, sign_bit);
+ return psin_float(_x);
}
-/* almost the same as psin */
template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
Packet4f pcos<Packet4f>(const Packet4f& _x)
{
- Packet4f x = _x;
- _EIGEN_DECLARE_CONST_Packet4f(1 , 1.0f);
- _EIGEN_DECLARE_CONST_Packet4f(half, 0.5f);
-
- _EIGEN_DECLARE_CONST_Packet4i(1, 1);
- _EIGEN_DECLARE_CONST_Packet4i(not1, ~1);
- _EIGEN_DECLARE_CONST_Packet4i(2, 2);
- _EIGEN_DECLARE_CONST_Packet4i(4, 4);
-
- _EIGEN_DECLARE_CONST_Packet4f(minus_cephes_DP1,-0.78515625f);
- _EIGEN_DECLARE_CONST_Packet4f(minus_cephes_DP2, -2.4187564849853515625e-4f);
- _EIGEN_DECLARE_CONST_Packet4f(minus_cephes_DP3, -3.77489497744594108e-8f);
- _EIGEN_DECLARE_CONST_Packet4f(sincof_p0, -1.9515295891E-4f);
- _EIGEN_DECLARE_CONST_Packet4f(sincof_p1, 8.3321608736E-3f);
- _EIGEN_DECLARE_CONST_Packet4f(sincof_p2, -1.6666654611E-1f);
- _EIGEN_DECLARE_CONST_Packet4f(coscof_p0, 2.443315711809948E-005f);
- _EIGEN_DECLARE_CONST_Packet4f(coscof_p1, -1.388731625493765E-003f);
- _EIGEN_DECLARE_CONST_Packet4f(coscof_p2, 4.166664568298827E-002f);
- _EIGEN_DECLARE_CONST_Packet4f(cephes_FOPI, 1.27323954473516f); // 4 / M_PI
-
- Packet4f xmm1, xmm2, xmm3, y;
- Packet4i emm0, emm2;
-
- x = pabs(x);
-
- /* scale by 4/Pi */
- y = pmul(x, p4f_cephes_FOPI);
-
- /* get the integer part of y */
- emm2 = _mm_cvttps_epi32(y);
- /* j=(j+1) & (~1) (see the cephes sources) */
- emm2 = _mm_add_epi32(emm2, p4i_1);
- emm2 = _mm_and_si128(emm2, p4i_not1);
- y = _mm_cvtepi32_ps(emm2);
-
- emm2 = _mm_sub_epi32(emm2, p4i_2);
-
- /* get the swap sign flag */
- emm0 = _mm_andnot_si128(emm2, p4i_4);
- emm0 = _mm_slli_epi32(emm0, 29);
- /* get the polynom selection mask */
- emm2 = _mm_and_si128(emm2, p4i_2);
- emm2 = _mm_cmpeq_epi32(emm2, _mm_setzero_si128());
-
- Packet4f sign_bit = _mm_castsi128_ps(emm0);
- Packet4f poly_mask = _mm_castsi128_ps(emm2);
-
- /* The magic pass: "Extended precision modular arithmetic"
- x = ((x - y * DP1) - y * DP2) - y * DP3; */
- xmm1 = pmul(y, p4f_minus_cephes_DP1);
- xmm2 = pmul(y, p4f_minus_cephes_DP2);
- xmm3 = pmul(y, p4f_minus_cephes_DP3);
- x = padd(x, xmm1);
- x = padd(x, xmm2);
- x = padd(x, xmm3);
-
- /* Evaluate the first polynom (0 <= x <= Pi/4) */
- y = p4f_coscof_p0;
- Packet4f z = pmul(x,x);
-
- y = pmadd(y,z,p4f_coscof_p1);
- y = pmadd(y,z,p4f_coscof_p2);
- y = pmul(y, z);
- y = pmul(y, z);
- Packet4f tmp = _mm_mul_ps(z, p4f_half);
- y = psub(y, tmp);
- y = padd(y, p4f_1);
-
- /* Evaluate the second polynom (Pi/4 <= x <= 0) */
- Packet4f y2 = p4f_sincof_p0;
- y2 = pmadd(y2, z, p4f_sincof_p1);
- y2 = pmadd(y2, z, p4f_sincof_p2);
- y2 = pmul(y2, z);
- y2 = pmadd(y2, x, x);
-
- /* select the correct result from the two polynoms */
- y2 = _mm_and_ps(poly_mask, y2);
- y = _mm_andnot_ps(poly_mask, y);
- y = _mm_or_ps(y,y2);
-
- /* update the sign */
- return _mm_xor_ps(y, sign_bit);
+ return pcos_float(_x);
}
#if EIGEN_FAST_MATH
@@ -455,17 +86,17 @@ Packet4f pcos<Packet4f>(const Packet4f& _x)
template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
Packet4f psqrt<Packet4f>(const Packet4f& _x)
{
- Packet4f half = pmul(_x, pset1<Packet4f>(.5f));
- Packet4f denormal_mask = _mm_and_ps(
- _mm_cmpge_ps(_x, _mm_setzero_ps()),
- _mm_cmplt_ps(_x, pset1<Packet4f>((std::numeric_limits<float>::min)())));
+ Packet4f minus_half_x = pmul(_x, pset1<Packet4f>(-0.5f));
+ Packet4f denormal_mask = pandnot(
+ pcmp_lt(_x, pset1<Packet4f>((std::numeric_limits<float>::min)())),
+ pcmp_lt(_x, pzero(_x)));
// Compute approximate reciprocal sqrt.
Packet4f x = _mm_rsqrt_ps(_x);
// Do a single step of Newton's iteration.
- x = pmul(x, psub(pset1<Packet4f>(1.5f), pmul(half, pmul(x,x))));
+ x = pmul(x, pmadd(minus_half_x, pmul(x,x), pset1<Packet4f>(1.5f)));
// Flush results for denormals to zero.
- return _mm_andnot_ps(denormal_mask, pmul(_x,x));
+ return pandnot(pmul(_x,x), denormal_mask);
}
#else
@@ -478,41 +109,48 @@ Packet4f psqrt<Packet4f>(const Packet4f& x) { return _mm_sqrt_ps(x); }
template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
Packet2d psqrt<Packet2d>(const Packet2d& x) { return _mm_sqrt_pd(x); }
+template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
+Packet16b psqrt<Packet16b>(const Packet16b& x) { return x; }
+
#if EIGEN_FAST_MATH
template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
Packet4f prsqrt<Packet4f>(const Packet4f& _x) {
- _EIGEN_DECLARE_CONST_Packet4f_FROM_INT(inf, 0x7f800000);
- _EIGEN_DECLARE_CONST_Packet4f_FROM_INT(nan, 0x7fc00000);
_EIGEN_DECLARE_CONST_Packet4f(one_point_five, 1.5f);
_EIGEN_DECLARE_CONST_Packet4f(minus_half, -0.5f);
- _EIGEN_DECLARE_CONST_Packet4f_FROM_INT(flt_min, 0x00800000);
+ _EIGEN_DECLARE_CONST_Packet4f_FROM_INT(inf, 0x7f800000u);
+ _EIGEN_DECLARE_CONST_Packet4f_FROM_INT(flt_min, 0x00800000u);
Packet4f neg_half = pmul(_x, p4f_minus_half);
- // select only the inverse sqrt of positive normal inputs (denormals are
- // flushed to zero and cause infs as well).
- Packet4f le_zero_mask = _mm_cmple_ps(_x, p4f_flt_min);
- Packet4f x = _mm_andnot_ps(le_zero_mask, _mm_rsqrt_ps(_x));
-
- // Fill in NaNs and Infs for the negative/zero entries.
- Packet4f neg_mask = _mm_cmplt_ps(_x, _mm_setzero_ps());
- Packet4f zero_mask = _mm_andnot_ps(neg_mask, le_zero_mask);
- Packet4f infs_and_nans = _mm_or_ps(_mm_and_ps(neg_mask, p4f_nan),
- _mm_and_ps(zero_mask, p4f_inf));
-
- // Do a single step of Newton's iteration.
- x = pmul(x, pmadd(neg_half, pmul(x, x), p4f_one_point_five));
-
- // Insert NaNs and Infs in all the right places.
- return _mm_or_ps(x, infs_and_nans);
+ // Identity infinite, zero, negative and denormal arguments.
+ Packet4f lt_min_mask = _mm_cmplt_ps(_x, p4f_flt_min);
+ Packet4f inf_mask = _mm_cmpeq_ps(_x, p4f_inf);
+ Packet4f not_normal_finite_mask = _mm_or_ps(lt_min_mask, inf_mask);
+
+ // Compute an approximate result using the rsqrt intrinsic.
+ Packet4f y_approx = _mm_rsqrt_ps(_x);
+
+ // Do a single step of Newton-Raphson iteration to improve the approximation.
+ // This uses the formula y_{n+1} = y_n * (1.5 - y_n * (0.5 * x) * y_n).
+ // It is essential to evaluate the inner term like this because forming
+ // y_n^2 may over- or underflow.
+ Packet4f y_newton = pmul(
+ y_approx, pmadd(y_approx, pmul(neg_half, y_approx), p4f_one_point_five));
+
+ // Select the result of the Newton-Raphson step for positive normal arguments.
+ // For other arguments, choose the output of the intrinsic. This will
+ // return rsqrt(+inf) = 0, rsqrt(x) = NaN if x < 0, and rsqrt(x) = +inf if
+ // x is zero or a positive denormalized float (equivalent to flushing positive
+ // denormalized inputs to zero).
+ return pselect<Packet4f>(not_normal_finite_mask, y_approx, y_newton);
}
#else
template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
Packet4f prsqrt<Packet4f>(const Packet4f& x) {
- // Unfortunately we can't use the much faster mm_rqsrt_ps since it only provides an approximation.
+ // Unfortunately we can't use the much faster mm_rsqrt_ps since it only provides an approximation.
return _mm_div_ps(pset1<Packet4f>(1.0f), _mm_sqrt_ps(x));
}
@@ -520,7 +158,6 @@ Packet4f prsqrt<Packet4f>(const Packet4f& x) {
template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
Packet2d prsqrt<Packet2d>(const Packet2d& x) {
- // Unfortunately we can't use the much faster mm_rqsrt_pd since it only provides an approximation.
return _mm_div_pd(pset1<Packet2d>(1.0), _mm_sqrt_pd(x));
}
@@ -548,7 +185,7 @@ double sqrt(const double &x)
{
#if EIGEN_COMP_GNUC_STRICT
// This works around a GCC bug generating poor code for _mm_sqrt_pd
- // See https://bitbucket.org/eigen/eigen/commits/14f468dba4d350d7c19c9b93072e19f7b3df563b
+ // See https://gitlab.com/libeigen/eigen/commit/8dca9f97e38970
return internal::pfirst(internal::Packet2d(__builtin_ia32_sqrtsd(_mm_set_sd(x))));
#else
return internal::pfirst(internal::Packet2d(_mm_sqrt_pd(_mm_set_sd(x))));
diff --git a/Eigen/src/Core/arch/SSE/PacketMath.h b/Eigen/src/Core/arch/SSE/PacketMath.h
index 3832de147..db102c73a 100755
--- a/Eigen/src/Core/arch/SSE/PacketMath.h
+++ b/Eigen/src/Core/arch/SSE/PacketMath.h
@@ -18,63 +18,93 @@ namespace internal {
#define EIGEN_CACHEFRIENDLY_PRODUCT_THRESHOLD 8
#endif
-#ifndef EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS
+#if !defined(EIGEN_VECTORIZE_AVX) && !defined(EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS)
+// 32 bits => 8 registers
+// 64 bits => 16 registers
#define EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS (2*sizeof(void*))
#endif
-#ifdef __FMA__
+#ifdef EIGEN_VECTORIZE_FMA
#ifndef EIGEN_HAS_SINGLE_INSTRUCTION_MADD
-#define EIGEN_HAS_SINGLE_INSTRUCTION_MADD 1
+#define EIGEN_HAS_SINGLE_INSTRUCTION_MADD
#endif
#endif
-#if (defined EIGEN_VECTORIZE_AVX) && (EIGEN_COMP_GNUC_STRICT || EIGEN_COMP_MINGW) && (__GXX_ABI_VERSION < 1004)
+#if ((defined EIGEN_VECTORIZE_AVX) && (EIGEN_COMP_GNUC_STRICT || EIGEN_COMP_MINGW) && (__GXX_ABI_VERSION < 1004)) || EIGEN_OS_QNX
// With GCC's default ABI version, a __m128 or __m256 are the same types and therefore we cannot
// have overloads for both types without linking error.
// One solution is to increase ABI version using -fabi-version=4 (or greater).
// Otherwise, we workaround this inconvenience by wrapping 128bit types into the following helper
// structure:
-template<typename T>
-struct eigen_packet_wrapper
-{
- EIGEN_ALWAYS_INLINE operator T&() { return m_val; }
- EIGEN_ALWAYS_INLINE operator const T&() const { return m_val; }
- EIGEN_ALWAYS_INLINE eigen_packet_wrapper() {}
- EIGEN_ALWAYS_INLINE eigen_packet_wrapper(const T &v) : m_val(v) {}
- EIGEN_ALWAYS_INLINE eigen_packet_wrapper& operator=(const T &v) {
- m_val = v;
- return *this;
- }
-
- T m_val;
-};
typedef eigen_packet_wrapper<__m128> Packet4f;
-typedef eigen_packet_wrapper<__m128i> Packet4i;
typedef eigen_packet_wrapper<__m128d> Packet2d;
#else
typedef __m128 Packet4f;
-typedef __m128i Packet4i;
typedef __m128d Packet2d;
#endif
+typedef eigen_packet_wrapper<__m128i, 0> Packet4i;
+typedef eigen_packet_wrapper<__m128i, 1> Packet16b;
+
template<> struct is_arithmetic<__m128> { enum { value = true }; };
template<> struct is_arithmetic<__m128i> { enum { value = true }; };
template<> struct is_arithmetic<__m128d> { enum { value = true }; };
+template<> struct is_arithmetic<Packet4i> { enum { value = true }; };
+template<> struct is_arithmetic<Packet16b> { enum { value = true }; };
+
+template<int p, int q, int r, int s>
+struct shuffle_mask{
+ enum { mask = (s)<<6|(r)<<4|(q)<<2|(p) };
+};
+// TODO: change the implementation of all swizzle* ops from macro to template,
#define vec4f_swizzle1(v,p,q,r,s) \
- (_mm_castsi128_ps(_mm_shuffle_epi32( _mm_castps_si128(v), ((s)<<6|(r)<<4|(q)<<2|(p)))))
+ Packet4f(_mm_castsi128_ps(_mm_shuffle_epi32( _mm_castps_si128(v), (shuffle_mask<p,q,r,s>::mask))))
#define vec4i_swizzle1(v,p,q,r,s) \
- (_mm_shuffle_epi32( v, ((s)<<6|(r)<<4|(q)<<2|(p))))
+ Packet4i(_mm_shuffle_epi32( v, (shuffle_mask<p,q,r,s>::mask)))
#define vec2d_swizzle1(v,p,q) \
- (_mm_castsi128_pd(_mm_shuffle_epi32( _mm_castpd_si128(v), ((q*2+1)<<6|(q*2)<<4|(p*2+1)<<2|(p*2)))))
-
+ Packet2d(_mm_castsi128_pd(_mm_shuffle_epi32( _mm_castpd_si128(v), (shuffle_mask<2*p,2*p+1,2*q,2*q+1>::mask))))
+
#define vec4f_swizzle2(a,b,p,q,r,s) \
- (_mm_shuffle_ps( (a), (b), ((s)<<6|(r)<<4|(q)<<2|(p))))
+ Packet4f(_mm_shuffle_ps( (a), (b), (shuffle_mask<p,q,r,s>::mask)))
#define vec4i_swizzle2(a,b,p,q,r,s) \
- (_mm_castps_si128( (_mm_shuffle_ps( _mm_castsi128_ps(a), _mm_castsi128_ps(b), ((s)<<6|(r)<<4|(q)<<2|(p))))))
+ Packet4i(_mm_castps_si128( (_mm_shuffle_ps( _mm_castsi128_ps(a), _mm_castsi128_ps(b), (shuffle_mask<p,q,r,s>::mask)))))
+
+EIGEN_STRONG_INLINE Packet4f vec4f_movelh(const Packet4f& a, const Packet4f& b)
+{
+ return Packet4f(_mm_movelh_ps(a,b));
+}
+EIGEN_STRONG_INLINE Packet4f vec4f_movehl(const Packet4f& a, const Packet4f& b)
+{
+ return Packet4f(_mm_movehl_ps(a,b));
+}
+EIGEN_STRONG_INLINE Packet4f vec4f_unpacklo(const Packet4f& a, const Packet4f& b)
+{
+ return Packet4f(_mm_unpacklo_ps(a,b));
+}
+EIGEN_STRONG_INLINE Packet4f vec4f_unpackhi(const Packet4f& a, const Packet4f& b)
+{
+ return Packet4f(_mm_unpackhi_ps(a,b));
+}
+#define vec4f_duplane(a,p) \
+ vec4f_swizzle2(a,a,p,p,p,p)
+
+#define vec2d_swizzle2(a,b,mask) \
+ Packet2d(_mm_shuffle_pd(a,b,mask))
+
+EIGEN_STRONG_INLINE Packet2d vec2d_unpacklo(const Packet2d& a, const Packet2d& b)
+{
+ return Packet2d(_mm_unpacklo_pd(a,b));
+}
+EIGEN_STRONG_INLINE Packet2d vec2d_unpackhi(const Packet2d& a, const Packet2d& b)
+{
+ return Packet2d(_mm_unpackhi_pd(a,b));
+}
+#define vec2d_duplane(a,p) \
+ vec2d_swizzle2(a,a,(p<<1)|p)
#define _EIGEN_DECLARE_CONST_Packet4f(NAME,X) \
const Packet4f p4f_##NAME = pset1<Packet4f>(X)
@@ -83,7 +113,7 @@ template<> struct is_arithmetic<__m128d> { enum { value = true }; };
const Packet2d p2d_##NAME = pset1<Packet2d>(X)
#define _EIGEN_DECLARE_CONST_Packet4f_FROM_INT(NAME,X) \
- const Packet4f p4f_##NAME = _mm_castsi128_ps(pset1<Packet4i>(X))
+ const Packet4f p4f_##NAME = pset1frombits<Packet4f>(X)
#define _EIGEN_DECLARE_CONST_Packet4i(NAME,X) \
const Packet4i p4i_##NAME = pset1<Packet4i>(X)
@@ -92,36 +122,41 @@ template<> struct is_arithmetic<__m128d> { enum { value = true }; };
// Use the packet_traits defined in AVX/PacketMath.h instead if we're going
// to leverage AVX instructions.
#ifndef EIGEN_VECTORIZE_AVX
-template<> struct packet_traits<float> : default_packet_traits
-{
+template <>
+struct packet_traits<float> : default_packet_traits {
typedef Packet4f type;
typedef Packet4f half;
enum {
Vectorizable = 1,
AlignedOnScalar = 1,
- size=4,
+ size = 4,
HasHalfPacket = 0,
- HasDiv = 1,
- HasSin = EIGEN_FAST_MATH,
- HasCos = EIGEN_FAST_MATH,
- HasLog = 1,
- HasExp = 1,
+ HasCmp = 1,
+ HasDiv = 1,
+ HasSin = EIGEN_FAST_MATH,
+ HasCos = EIGEN_FAST_MATH,
+ HasLog = 1,
+ HasLog1p = 1,
+ HasExpm1 = 1,
+ HasNdtri = 1,
+ HasExp = 1,
+ HasBessel = 1,
HasSqrt = 1,
HasRsqrt = 1,
- HasTanh = EIGEN_FAST_MATH,
- HasBlend = 1
-
+ HasTanh = EIGEN_FAST_MATH,
+ HasErf = EIGEN_FAST_MATH,
+ HasBlend = 1,
+ HasCeil = 1,
+ HasFloor = 1,
#ifdef EIGEN_VECTORIZE_SSE4_1
- ,
HasRound = 1,
- HasFloor = 1,
- HasCeil = 1
#endif
+ HasRint = 1
};
};
-template<> struct packet_traits<double> : default_packet_traits
-{
+template <>
+struct packet_traits<double> : default_packet_traits {
typedef Packet2d type;
typedef Packet2d half;
enum {
@@ -130,18 +165,19 @@ template<> struct packet_traits<double> : default_packet_traits
size=2,
HasHalfPacket = 0,
+ HasCmp = 1,
HasDiv = 1,
+ HasLog = 1,
HasExp = 1,
HasSqrt = 1,
HasRsqrt = 1,
- HasBlend = 1
-
+ HasBlend = 1,
+ HasFloor = 1,
+ HasCeil = 1,
#ifdef EIGEN_VECTORIZE_SSE4_1
- ,
HasRound = 1,
- HasFloor = 1,
- HasCeil = 1
#endif
+ HasRint = 1
};
};
#endif
@@ -154,13 +190,56 @@ template<> struct packet_traits<int> : default_packet_traits
AlignedOnScalar = 1,
size=4,
+ HasShift = 1,
HasBlend = 1
};
};
-template<> struct unpacket_traits<Packet4f> { typedef float type; enum {size=4, alignment=Aligned16}; typedef Packet4f half; };
-template<> struct unpacket_traits<Packet2d> { typedef double type; enum {size=2, alignment=Aligned16}; typedef Packet2d half; };
-template<> struct unpacket_traits<Packet4i> { typedef int type; enum {size=4, alignment=Aligned16}; typedef Packet4i half; };
+template<> struct packet_traits<bool> : default_packet_traits
+{
+ typedef Packet16b type;
+ typedef Packet16b half;
+ enum {
+ Vectorizable = 1,
+ AlignedOnScalar = 1,
+ HasHalfPacket = 0,
+ size=16,
+
+ HasAdd = 1,
+ HasSub = 1,
+ HasShift = 0,
+ HasMul = 1,
+ HasNegate = 1,
+ HasAbs = 0,
+ HasAbs2 = 0,
+ HasMin = 0,
+ HasMax = 0,
+ HasConj = 0,
+ HasSqrt = 1
+ };
+};
+
+template<> struct unpacket_traits<Packet4f> {
+ typedef float type;
+ typedef Packet4f half;
+ typedef Packet4i integer_packet;
+ enum {size=4, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false};
+};
+template<> struct unpacket_traits<Packet2d> {
+ typedef double type;
+ typedef Packet2d half;
+ enum {size=2, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false};
+};
+template<> struct unpacket_traits<Packet4i> {
+ typedef int type;
+ typedef Packet4i half;
+ enum {size=4, alignment=Aligned16, vectorizable=false, masked_load_available=false, masked_store_available=false};
+};
+template<> struct unpacket_traits<Packet16b> {
+ typedef bool type;
+ typedef Packet16b half;
+ enum {size=16, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false};
+};
#ifndef EIGEN_VECTORIZE_AVX
template<> struct scalar_div_cost<float,true> { enum { value = 7 }; };
@@ -179,6 +258,18 @@ template<> EIGEN_STRONG_INLINE Packet4f pset1<Packet4f>(const float& from) { re
template<> EIGEN_STRONG_INLINE Packet2d pset1<Packet2d>(const double& from) { return _mm_set1_pd(from); }
template<> EIGEN_STRONG_INLINE Packet4i pset1<Packet4i>(const int& from) { return _mm_set1_epi32(from); }
#endif
+template<> EIGEN_STRONG_INLINE Packet16b pset1<Packet16b>(const bool& from) { return _mm_set1_epi8(static_cast<char>(from)); }
+
+template<> EIGEN_STRONG_INLINE Packet4f pset1frombits<Packet4f>(unsigned int from) { return _mm_castsi128_ps(pset1<Packet4i>(from)); }
+template<> EIGEN_STRONG_INLINE Packet2d pset1frombits<Packet2d>(uint64_t from) { return _mm_castsi128_pd(_mm_set1_epi64x(from)); }
+
+template<> EIGEN_STRONG_INLINE Packet4f peven_mask(const Packet4f& /*a*/) { return _mm_castsi128_ps(_mm_set_epi32(0, -1, 0, -1)); }
+template<> EIGEN_STRONG_INLINE Packet4i peven_mask(const Packet4i& /*a*/) { return _mm_set_epi32(0, -1, 0, -1); }
+template<> EIGEN_STRONG_INLINE Packet2d peven_mask(const Packet2d& /*a*/) { return _mm_castsi128_pd(_mm_set_epi32(0, 0, -1, -1)); }
+
+template<> EIGEN_STRONG_INLINE Packet4f pzero(const Packet4f& /*a*/) { return _mm_setzero_ps(); }
+template<> EIGEN_STRONG_INLINE Packet2d pzero(const Packet2d& /*a*/) { return _mm_setzero_pd(); }
+template<> EIGEN_STRONG_INLINE Packet4i pzero(const Packet4i& /*a*/) { return _mm_setzero_si128(); }
// GCC generates a shufps instruction for _mm_set1_ps/_mm_load1_ps instead of the more efficient pshufd instruction.
// However, using inrinsics for pset1 makes gcc to generate crappy code in some cases (see bug 203)
@@ -190,7 +281,7 @@ template<> EIGEN_STRONG_INLINE Packet4f pload1<Packet4f>(const float *from) {
return vec4f_swizzle1(_mm_load_ss(from),0,0,0,0);
}
#endif
-
+
template<> EIGEN_STRONG_INLINE Packet4f plset<Packet4f>(const float& a) { return _mm_add_ps(pset1<Packet4f>(a), _mm_set_ps(3,2,1,0)); }
template<> EIGEN_STRONG_INLINE Packet2d plset<Packet2d>(const double& a) { return _mm_add_pd(pset1<Packet2d>(a),_mm_set_pd(1,0)); }
template<> EIGEN_STRONG_INLINE Packet4i plset<Packet4i>(const int& a) { return _mm_add_epi32(pset1<Packet4i>(a),_mm_set_epi32(3,2,1,0)); }
@@ -199,9 +290,34 @@ template<> EIGEN_STRONG_INLINE Packet4f padd<Packet4f>(const Packet4f& a, const
template<> EIGEN_STRONG_INLINE Packet2d padd<Packet2d>(const Packet2d& a, const Packet2d& b) { return _mm_add_pd(a,b); }
template<> EIGEN_STRONG_INLINE Packet4i padd<Packet4i>(const Packet4i& a, const Packet4i& b) { return _mm_add_epi32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet16b padd<Packet16b>(const Packet16b& a, const Packet16b& b) { return _mm_or_si128(a,b); }
+
template<> EIGEN_STRONG_INLINE Packet4f psub<Packet4f>(const Packet4f& a, const Packet4f& b) { return _mm_sub_ps(a,b); }
template<> EIGEN_STRONG_INLINE Packet2d psub<Packet2d>(const Packet2d& a, const Packet2d& b) { return _mm_sub_pd(a,b); }
template<> EIGEN_STRONG_INLINE Packet4i psub<Packet4i>(const Packet4i& a, const Packet4i& b) { return _mm_sub_epi32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet16b psub<Packet16b>(const Packet16b& a, const Packet16b& b) { return _mm_xor_si128(a,b); }
+
+template<> EIGEN_STRONG_INLINE Packet4f pxor<Packet4f>(const Packet4f& a, const Packet4f& b);
+template<> EIGEN_STRONG_INLINE Packet4f paddsub<Packet4f>(const Packet4f& a, const Packet4f& b)
+{
+#ifdef EIGEN_VECTORIZE_SSE3
+ return _mm_addsub_ps(a,b);
+#else
+ const Packet4f mask = _mm_castsi128_ps(_mm_setr_epi32(0x80000000,0x0,0x80000000,0x0));
+ return padd(a, pxor(mask, b));
+#endif
+}
+
+template<> EIGEN_STRONG_INLINE Packet2d pxor<Packet2d>(const Packet2d& , const Packet2d& );
+template<> EIGEN_STRONG_INLINE Packet2d paddsub<Packet2d>(const Packet2d& a, const Packet2d& b)
+{
+#ifdef EIGEN_VECTORIZE_SSE3
+ return _mm_addsub_pd(a,b);
+#else
+ const Packet2d mask = _mm_castsi128_pd(_mm_setr_epi32(0x0,0x80000000,0x0,0x0));
+ return padd(a, pxor(mask, b));
+#endif
+}
template<> EIGEN_STRONG_INLINE Packet4f pnegate(const Packet4f& a)
{
@@ -218,6 +334,11 @@ template<> EIGEN_STRONG_INLINE Packet4i pnegate(const Packet4i& a)
return psub(Packet4i(_mm_setr_epi32(0,0,0,0)), a);
}
+template<> EIGEN_STRONG_INLINE Packet16b pnegate(const Packet16b& a)
+{
+ return psub(pset1<Packet16b>(false), a);
+}
+
template<> EIGEN_STRONG_INLINE Packet4f pconj(const Packet4f& a) { return a; }
template<> EIGEN_STRONG_INLINE Packet2d pconj(const Packet2d& a) { return a; }
template<> EIGEN_STRONG_INLINE Packet4i pconj(const Packet4i& a) { return a; }
@@ -240,18 +361,126 @@ template<> EIGEN_STRONG_INLINE Packet4i pmul<Packet4i>(const Packet4i& a, const
#endif
}
+template<> EIGEN_STRONG_INLINE Packet16b pmul<Packet16b>(const Packet16b& a, const Packet16b& b) { return _mm_and_si128(a,b); }
+
template<> EIGEN_STRONG_INLINE Packet4f pdiv<Packet4f>(const Packet4f& a, const Packet4f& b) { return _mm_div_ps(a,b); }
template<> EIGEN_STRONG_INLINE Packet2d pdiv<Packet2d>(const Packet2d& a, const Packet2d& b) { return _mm_div_pd(a,b); }
// for some weird raisons, it has to be overloaded for packet of integers
template<> EIGEN_STRONG_INLINE Packet4i pmadd(const Packet4i& a, const Packet4i& b, const Packet4i& c) { return padd(pmul(a,b), c); }
-#ifdef __FMA__
+#ifdef EIGEN_VECTORIZE_FMA
template<> EIGEN_STRONG_INLINE Packet4f pmadd(const Packet4f& a, const Packet4f& b, const Packet4f& c) { return _mm_fmadd_ps(a,b,c); }
template<> EIGEN_STRONG_INLINE Packet2d pmadd(const Packet2d& a, const Packet2d& b, const Packet2d& c) { return _mm_fmadd_pd(a,b,c); }
#endif
-template<> EIGEN_STRONG_INLINE Packet4f pmin<Packet4f>(const Packet4f& a, const Packet4f& b) { return _mm_min_ps(a,b); }
-template<> EIGEN_STRONG_INLINE Packet2d pmin<Packet2d>(const Packet2d& a, const Packet2d& b) { return _mm_min_pd(a,b); }
+#ifdef EIGEN_VECTORIZE_SSE4_1
+template<> EIGEN_DEVICE_FUNC inline Packet4f pselect(const Packet4f& mask, const Packet4f& a, const Packet4f& b) {
+ return _mm_blendv_ps(b,a,mask);
+}
+
+template<> EIGEN_DEVICE_FUNC inline Packet4i pselect(const Packet4i& mask, const Packet4i& a, const Packet4i& b) {
+ return _mm_castps_si128(_mm_blendv_ps(_mm_castsi128_ps(b),_mm_castsi128_ps(a),_mm_castsi128_ps(mask)));
+}
+
+template<> EIGEN_DEVICE_FUNC inline Packet2d pselect(const Packet2d& mask, const Packet2d& a, const Packet2d& b) { return _mm_blendv_pd(b,a,mask); }
+
+template<> EIGEN_DEVICE_FUNC inline Packet16b pselect(const Packet16b& mask, const Packet16b& a, const Packet16b& b) {
+ return _mm_blendv_epi8(b,a,mask);
+}
+#else
+template<> EIGEN_DEVICE_FUNC inline Packet16b pselect(const Packet16b& mask, const Packet16b& a, const Packet16b& b) {
+ Packet16b a_part = _mm_and_si128(mask, a);
+ Packet16b b_part = _mm_andnot_si128(mask, b);
+ return _mm_or_si128(a_part, b_part);
+}
+#endif
+
+template<> EIGEN_STRONG_INLINE Packet4i ptrue<Packet4i>(const Packet4i& a) { return _mm_cmpeq_epi32(a, a); }
+template<> EIGEN_STRONG_INLINE Packet16b ptrue<Packet16b>(const Packet16b& a) { return _mm_cmpeq_epi8(a, a); }
+template<> EIGEN_STRONG_INLINE Packet4f
+ptrue<Packet4f>(const Packet4f& a) {
+ Packet4i b = _mm_castps_si128(a);
+ return _mm_castsi128_ps(_mm_cmpeq_epi32(b, b));
+}
+template<> EIGEN_STRONG_INLINE Packet2d
+ptrue<Packet2d>(const Packet2d& a) {
+ Packet4i b = _mm_castpd_si128(a);
+ return _mm_castsi128_pd(_mm_cmpeq_epi32(b, b));
+}
+
+
+template<> EIGEN_STRONG_INLINE Packet4f pand<Packet4f>(const Packet4f& a, const Packet4f& b) { return _mm_and_ps(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2d pand<Packet2d>(const Packet2d& a, const Packet2d& b) { return _mm_and_pd(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4i pand<Packet4i>(const Packet4i& a, const Packet4i& b) { return _mm_and_si128(a,b); }
+template<> EIGEN_STRONG_INLINE Packet16b pand<Packet16b>(const Packet16b& a, const Packet16b& b) { return _mm_and_si128(a,b); }
+
+template<> EIGEN_STRONG_INLINE Packet4f por<Packet4f>(const Packet4f& a, const Packet4f& b) { return _mm_or_ps(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2d por<Packet2d>(const Packet2d& a, const Packet2d& b) { return _mm_or_pd(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4i por<Packet4i>(const Packet4i& a, const Packet4i& b) { return _mm_or_si128(a,b); }
+template<> EIGEN_STRONG_INLINE Packet16b por<Packet16b>(const Packet16b& a, const Packet16b& b) { return _mm_or_si128(a,b); }
+
+template<> EIGEN_STRONG_INLINE Packet4f pxor<Packet4f>(const Packet4f& a, const Packet4f& b) { return _mm_xor_ps(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2d pxor<Packet2d>(const Packet2d& a, const Packet2d& b) { return _mm_xor_pd(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4i pxor<Packet4i>(const Packet4i& a, const Packet4i& b) { return _mm_xor_si128(a,b); }
+template<> EIGEN_STRONG_INLINE Packet16b pxor<Packet16b>(const Packet16b& a, const Packet16b& b) { return _mm_xor_si128(a,b); }
+
+template<> EIGEN_STRONG_INLINE Packet4f pandnot<Packet4f>(const Packet4f& a, const Packet4f& b) { return _mm_andnot_ps(b,a); }
+template<> EIGEN_STRONG_INLINE Packet2d pandnot<Packet2d>(const Packet2d& a, const Packet2d& b) { return _mm_andnot_pd(b,a); }
+template<> EIGEN_STRONG_INLINE Packet4i pandnot<Packet4i>(const Packet4i& a, const Packet4i& b) { return _mm_andnot_si128(b,a); }
+
+template<> EIGEN_STRONG_INLINE Packet4f pcmp_le(const Packet4f& a, const Packet4f& b) { return _mm_cmple_ps(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4f pcmp_lt(const Packet4f& a, const Packet4f& b) { return _mm_cmplt_ps(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4f pcmp_lt_or_nan(const Packet4f& a, const Packet4f& b) { return _mm_cmpnge_ps(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4f pcmp_eq(const Packet4f& a, const Packet4f& b) { return _mm_cmpeq_ps(a,b); }
+
+template<> EIGEN_STRONG_INLINE Packet2d pcmp_le(const Packet2d& a, const Packet2d& b) { return _mm_cmple_pd(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2d pcmp_lt(const Packet2d& a, const Packet2d& b) { return _mm_cmplt_pd(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2d pcmp_lt_or_nan(const Packet2d& a, const Packet2d& b) { return _mm_cmpnge_pd(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2d pcmp_eq(const Packet2d& a, const Packet2d& b) { return _mm_cmpeq_pd(a,b); }
+
+template<> EIGEN_STRONG_INLINE Packet4i pcmp_lt(const Packet4i& a, const Packet4i& b) { return _mm_cmplt_epi32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4i pcmp_eq(const Packet4i& a, const Packet4i& b) { return _mm_cmpeq_epi32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet16b pcmp_eq(const Packet16b& a, const Packet16b& b) { return _mm_cmpeq_epi8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4i pcmp_le(const Packet4i& a, const Packet4i& b) { return por(pcmp_lt(a,b), pcmp_eq(a,b)); }
+
+template<> EIGEN_STRONG_INLINE Packet4f pmin<Packet4f>(const Packet4f& a, const Packet4f& b) {
+#if EIGEN_COMP_GNUC && EIGEN_COMP_GNUC < 63
+ // There appears to be a bug in GCC, by which the optimizer may
+ // flip the argument order in calls to _mm_min_ps, so we have to
+ // resort to inline ASM here. This is supposed to be fixed in gcc6.3,
+ // see also: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=72867
+ #ifdef EIGEN_VECTORIZE_AVX
+ Packet4f res;
+ asm("vminps %[a], %[b], %[res]" : [res] "=x" (res) : [a] "x" (a), [b] "x" (b));
+ #else
+ Packet4f res = b;
+ asm("minps %[a], %[res]" : [res] "+x" (res) : [a] "x" (a));
+ #endif
+ return res;
+#else
+ // Arguments are reversed to match NaN propagation behavior of std::min.
+ return _mm_min_ps(b, a);
+#endif
+}
+template<> EIGEN_STRONG_INLINE Packet2d pmin<Packet2d>(const Packet2d& a, const Packet2d& b) {
+#if EIGEN_COMP_GNUC && EIGEN_COMP_GNUC < 63
+ // There appears to be a bug in GCC, by which the optimizer may
+ // flip the argument order in calls to _mm_min_pd, so we have to
+ // resort to inline ASM here. This is supposed to be fixed in gcc6.3,
+ // see also: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=72867
+ #ifdef EIGEN_VECTORIZE_AVX
+ Packet2d res;
+ asm("vminpd %[a], %[b], %[res]" : [res] "=x" (res) : [a] "x" (a), [b] "x" (b));
+ #else
+ Packet2d res = b;
+ asm("minpd %[a], %[res]" : [res] "+x" (res) : [a] "x" (a));
+ #endif
+ return res;
+#else
+ // Arguments are reversed to match NaN propagation behavior of std::min.
+ return _mm_min_pd(b, a);
+#endif
+}
template<> EIGEN_STRONG_INLINE Packet4i pmin<Packet4i>(const Packet4i& a, const Packet4i& b)
{
#ifdef EIGEN_VECTORIZE_SSE4_1
@@ -263,8 +492,45 @@ template<> EIGEN_STRONG_INLINE Packet4i pmin<Packet4i>(const Packet4i& a, const
#endif
}
-template<> EIGEN_STRONG_INLINE Packet4f pmax<Packet4f>(const Packet4f& a, const Packet4f& b) { return _mm_max_ps(a,b); }
-template<> EIGEN_STRONG_INLINE Packet2d pmax<Packet2d>(const Packet2d& a, const Packet2d& b) { return _mm_max_pd(a,b); }
+
+template<> EIGEN_STRONG_INLINE Packet4f pmax<Packet4f>(const Packet4f& a, const Packet4f& b) {
+#if EIGEN_COMP_GNUC && EIGEN_COMP_GNUC < 63
+ // There appears to be a bug in GCC, by which the optimizer may
+ // flip the argument order in calls to _mm_max_ps, so we have to
+ // resort to inline ASM here. This is supposed to be fixed in gcc6.3,
+ // see also: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=72867
+ #ifdef EIGEN_VECTORIZE_AVX
+ Packet4f res;
+ asm("vmaxps %[a], %[b], %[res]" : [res] "=x" (res) : [a] "x" (a), [b] "x" (b));
+ #else
+ Packet4f res = b;
+ asm("maxps %[a], %[res]" : [res] "+x" (res) : [a] "x" (a));
+ #endif
+ return res;
+#else
+ // Arguments are reversed to match NaN propagation behavior of std::max.
+ return _mm_max_ps(b, a);
+#endif
+}
+template<> EIGEN_STRONG_INLINE Packet2d pmax<Packet2d>(const Packet2d& a, const Packet2d& b) {
+#if EIGEN_COMP_GNUC && EIGEN_COMP_GNUC < 63
+ // There appears to be a bug in GCC, by which the optimizer may
+ // flip the argument order in calls to _mm_max_pd, so we have to
+ // resort to inline ASM here. This is supposed to be fixed in gcc6.3,
+ // see also: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=72867
+ #ifdef EIGEN_VECTORIZE_AVX
+ Packet2d res;
+ asm("vmaxpd %[a], %[b], %[res]" : [res] "=x" (res) : [a] "x" (a), [b] "x" (b));
+ #else
+ Packet2d res = b;
+ asm("maxpd %[a], %[res]" : [res] "+x" (res) : [a] "x" (a));
+ #endif
+ return res;
+#else
+ // Arguments are reversed to match NaN propagation behavior of std::max.
+ return _mm_max_pd(b, a);
+#endif
+}
template<> EIGEN_STRONG_INLINE Packet4i pmax<Packet4i>(const Packet4i& a, const Packet4i& b)
{
#ifdef EIGEN_VECTORIZE_SSE4_1
@@ -276,36 +542,180 @@ template<> EIGEN_STRONG_INLINE Packet4i pmax<Packet4i>(const Packet4i& a, const
#endif
}
+template <typename Packet, typename Op>
+EIGEN_STRONG_INLINE Packet pminmax_propagate_numbers(const Packet& a, const Packet& b, Op op) {
+ // In this implementation, we take advantage of the fact that pmin/pmax for SSE
+ // always return a if either a or b is NaN.
+ Packet not_nan_mask_a = pcmp_eq(a, a);
+ Packet m = op(a, b);
+ return pselect<Packet>(not_nan_mask_a, m, b);
+}
+
+template <typename Packet, typename Op>
+EIGEN_STRONG_INLINE Packet pminmax_propagate_nan(const Packet& a, const Packet& b, Op op) {
+ // In this implementation, we take advantage of the fact that pmin/pmax for SSE
+ // always return a if either a or b is NaN.
+ Packet not_nan_mask_a = pcmp_eq(a, a);
+ Packet m = op(b, a);
+ return pselect<Packet>(not_nan_mask_a, m, a);
+}
+
+// Add specializations for min/max with prescribed NaN progation.
+template<>
+EIGEN_STRONG_INLINE Packet4f pmin<PropagateNumbers, Packet4f>(const Packet4f& a, const Packet4f& b) {
+ return pminmax_propagate_numbers(a, b, pmin<Packet4f>);
+}
+template<>
+EIGEN_STRONG_INLINE Packet2d pmin<PropagateNumbers, Packet2d>(const Packet2d& a, const Packet2d& b) {
+ return pminmax_propagate_numbers(a, b, pmin<Packet2d>);
+}
+template<>
+EIGEN_STRONG_INLINE Packet4f pmax<PropagateNumbers, Packet4f>(const Packet4f& a, const Packet4f& b) {
+ return pminmax_propagate_numbers(a, b, pmax<Packet4f>);
+}
+template<>
+EIGEN_STRONG_INLINE Packet2d pmax<PropagateNumbers, Packet2d>(const Packet2d& a, const Packet2d& b) {
+ return pminmax_propagate_numbers(a, b, pmax<Packet2d>);
+}
+template<>
+EIGEN_STRONG_INLINE Packet4f pmin<PropagateNaN, Packet4f>(const Packet4f& a, const Packet4f& b) {
+ return pminmax_propagate_nan(a, b, pmin<Packet4f>);
+}
+template<>
+EIGEN_STRONG_INLINE Packet2d pmin<PropagateNaN, Packet2d>(const Packet2d& a, const Packet2d& b) {
+ return pminmax_propagate_nan(a, b, pmin<Packet2d>);
+}
+template<>
+EIGEN_STRONG_INLINE Packet4f pmax<PropagateNaN, Packet4f>(const Packet4f& a, const Packet4f& b) {
+ return pminmax_propagate_nan(a, b, pmax<Packet4f>);
+}
+template<>
+EIGEN_STRONG_INLINE Packet2d pmax<PropagateNaN, Packet2d>(const Packet2d& a, const Packet2d& b) {
+ return pminmax_propagate_nan(a, b, pmax<Packet2d>);
+}
+
+template<int N> EIGEN_STRONG_INLINE Packet4i parithmetic_shift_right(const Packet4i& a) { return _mm_srai_epi32(a,N); }
+template<int N> EIGEN_STRONG_INLINE Packet4i plogical_shift_right (const Packet4i& a) { return _mm_srli_epi32(a,N); }
+template<int N> EIGEN_STRONG_INLINE Packet4i plogical_shift_left (const Packet4i& a) { return _mm_slli_epi32(a,N); }
+
+template<> EIGEN_STRONG_INLINE Packet4f pabs(const Packet4f& a)
+{
+ const Packet4f mask = _mm_castsi128_ps(_mm_setr_epi32(0x7FFFFFFF,0x7FFFFFFF,0x7FFFFFFF,0x7FFFFFFF));
+ return _mm_and_ps(a,mask);
+}
+template<> EIGEN_STRONG_INLINE Packet2d pabs(const Packet2d& a)
+{
+ const Packet2d mask = _mm_castsi128_pd(_mm_setr_epi32(0xFFFFFFFF,0x7FFFFFFF,0xFFFFFFFF,0x7FFFFFFF));
+ return _mm_and_pd(a,mask);
+}
+template<> EIGEN_STRONG_INLINE Packet4i pabs(const Packet4i& a)
+{
+ #ifdef EIGEN_VECTORIZE_SSSE3
+ return _mm_abs_epi32(a);
+ #else
+ Packet4i aux = _mm_srai_epi32(a,31);
+ return _mm_sub_epi32(_mm_xor_si128(a,aux),aux);
+ #endif
+}
+
#ifdef EIGEN_VECTORIZE_SSE4_1
-template<> EIGEN_STRONG_INLINE Packet4f pround<Packet4f>(const Packet4f& a) { return _mm_round_ps(a, 0); }
-template<> EIGEN_STRONG_INLINE Packet2d pround<Packet2d>(const Packet2d& a) { return _mm_round_pd(a, 0); }
+template<> EIGEN_STRONG_INLINE Packet4f pround<Packet4f>(const Packet4f& a)
+{
+ // Unfortunatly _mm_round_ps doesn't have a rounding mode to implement numext::round.
+ const Packet4f mask = pset1frombits<Packet4f>(0x80000000u);
+ const Packet4f prev0dot5 = pset1frombits<Packet4f>(0x3EFFFFFFu);
+ return _mm_round_ps(padd(por(pand(a, mask), prev0dot5), a), _MM_FROUND_TO_ZERO);
+}
+
+template<> EIGEN_STRONG_INLINE Packet2d pround<Packet2d>(const Packet2d& a)
+{
+ const Packet2d mask = _mm_castsi128_pd(_mm_set_epi64x(0x8000000000000000ull, 0x8000000000000000ull));
+ const Packet2d prev0dot5 = _mm_castsi128_pd(_mm_set_epi64x(0x3FDFFFFFFFFFFFFFull, 0x3FDFFFFFFFFFFFFFull));
+ return _mm_round_pd(padd(por(pand(a, mask), prev0dot5), a), _MM_FROUND_TO_ZERO);
+}
+
+template<> EIGEN_STRONG_INLINE Packet4f print<Packet4f>(const Packet4f& a) { return _mm_round_ps(a, _MM_FROUND_CUR_DIRECTION); }
+template<> EIGEN_STRONG_INLINE Packet2d print<Packet2d>(const Packet2d& a) { return _mm_round_pd(a, _MM_FROUND_CUR_DIRECTION); }
template<> EIGEN_STRONG_INLINE Packet4f pceil<Packet4f>(const Packet4f& a) { return _mm_ceil_ps(a); }
template<> EIGEN_STRONG_INLINE Packet2d pceil<Packet2d>(const Packet2d& a) { return _mm_ceil_pd(a); }
template<> EIGEN_STRONG_INLINE Packet4f pfloor<Packet4f>(const Packet4f& a) { return _mm_floor_ps(a); }
template<> EIGEN_STRONG_INLINE Packet2d pfloor<Packet2d>(const Packet2d& a) { return _mm_floor_pd(a); }
-#endif
+#else
+template<> EIGEN_STRONG_INLINE Packet4f print(const Packet4f& a) {
+ // Adds and subtracts signum(a) * 2^23 to force rounding.
+ const Packet4f limit = pset1<Packet4f>(static_cast<float>(1<<23));
+ const Packet4f abs_a = pabs(a);
+ Packet4f r = padd(abs_a, limit);
+ // Don't compile-away addition and subtraction.
+ EIGEN_OPTIMIZATION_BARRIER(r);
+ r = psub(r, limit);
+ // If greater than limit, simply return a. Otherwise, account for sign.
+ r = pselect(pcmp_lt(abs_a, limit),
+ pselect(pcmp_lt(a, pzero(a)), pnegate(r), r), a);
+ return r;
+}
-template<> EIGEN_STRONG_INLINE Packet4f pand<Packet4f>(const Packet4f& a, const Packet4f& b) { return _mm_and_ps(a,b); }
-template<> EIGEN_STRONG_INLINE Packet2d pand<Packet2d>(const Packet2d& a, const Packet2d& b) { return _mm_and_pd(a,b); }
-template<> EIGEN_STRONG_INLINE Packet4i pand<Packet4i>(const Packet4i& a, const Packet4i& b) { return _mm_and_si128(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2d print(const Packet2d& a) {
+ // Adds and subtracts signum(a) * 2^52 to force rounding.
+ const Packet2d limit = pset1<Packet2d>(static_cast<double>(1ull<<52));
+ const Packet2d abs_a = pabs(a);
+ Packet2d r = padd(abs_a, limit);
+ // Don't compile-away addition and subtraction.
+ EIGEN_OPTIMIZATION_BARRIER(r);
+ r = psub(r, limit);
+ // If greater than limit, simply return a. Otherwise, account for sign.
+ r = pselect(pcmp_lt(abs_a, limit),
+ pselect(pcmp_lt(a, pzero(a)), pnegate(r), r), a);
+ return r;
+}
-template<> EIGEN_STRONG_INLINE Packet4f por<Packet4f>(const Packet4f& a, const Packet4f& b) { return _mm_or_ps(a,b); }
-template<> EIGEN_STRONG_INLINE Packet2d por<Packet2d>(const Packet2d& a, const Packet2d& b) { return _mm_or_pd(a,b); }
-template<> EIGEN_STRONG_INLINE Packet4i por<Packet4i>(const Packet4i& a, const Packet4i& b) { return _mm_or_si128(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4f pfloor<Packet4f>(const Packet4f& a)
+{
+ const Packet4f cst_1 = pset1<Packet4f>(1.0f);
+ Packet4f tmp = print<Packet4f>(a);
+ // If greater, subtract one.
+ Packet4f mask = _mm_cmpgt_ps(tmp, a);
+ mask = pand(mask, cst_1);
+ return psub(tmp, mask);
+}
-template<> EIGEN_STRONG_INLINE Packet4f pxor<Packet4f>(const Packet4f& a, const Packet4f& b) { return _mm_xor_ps(a,b); }
-template<> EIGEN_STRONG_INLINE Packet2d pxor<Packet2d>(const Packet2d& a, const Packet2d& b) { return _mm_xor_pd(a,b); }
-template<> EIGEN_STRONG_INLINE Packet4i pxor<Packet4i>(const Packet4i& a, const Packet4i& b) { return _mm_xor_si128(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2d pfloor<Packet2d>(const Packet2d& a)
+{
+ const Packet2d cst_1 = pset1<Packet2d>(1.0);
+ Packet2d tmp = print<Packet2d>(a);
+ // If greater, subtract one.
+ Packet2d mask = _mm_cmpgt_pd(tmp, a);
+ mask = pand(mask, cst_1);
+ return psub(tmp, mask);
+}
-template<> EIGEN_STRONG_INLINE Packet4f pandnot<Packet4f>(const Packet4f& a, const Packet4f& b) { return _mm_andnot_ps(a,b); }
-template<> EIGEN_STRONG_INLINE Packet2d pandnot<Packet2d>(const Packet2d& a, const Packet2d& b) { return _mm_andnot_pd(a,b); }
-template<> EIGEN_STRONG_INLINE Packet4i pandnot<Packet4i>(const Packet4i& a, const Packet4i& b) { return _mm_andnot_si128(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4f pceil<Packet4f>(const Packet4f& a)
+{
+ const Packet4f cst_1 = pset1<Packet4f>(1.0f);
+ Packet4f tmp = print<Packet4f>(a);
+ // If smaller, add one.
+ Packet4f mask = _mm_cmplt_ps(tmp, a);
+ mask = pand(mask, cst_1);
+ return padd(tmp, mask);
+}
+
+template<> EIGEN_STRONG_INLINE Packet2d pceil<Packet2d>(const Packet2d& a)
+{
+ const Packet2d cst_1 = pset1<Packet2d>(1.0);
+ Packet2d tmp = print<Packet2d>(a);
+ // If smaller, add one.
+ Packet2d mask = _mm_cmplt_pd(tmp, a);
+ mask = pand(mask, cst_1);
+ return padd(tmp, mask);
+}
+#endif
template<> EIGEN_STRONG_INLINE Packet4f pload<Packet4f>(const float* from) { EIGEN_DEBUG_ALIGNED_LOAD return _mm_load_ps(from); }
template<> EIGEN_STRONG_INLINE Packet2d pload<Packet2d>(const double* from) { EIGEN_DEBUG_ALIGNED_LOAD return _mm_load_pd(from); }
template<> EIGEN_STRONG_INLINE Packet4i pload<Packet4i>(const int* from) { EIGEN_DEBUG_ALIGNED_LOAD return _mm_load_si128(reinterpret_cast<const __m128i*>(from)); }
+template<> EIGEN_STRONG_INLINE Packet16b pload<Packet16b>(const bool* from) { EIGEN_DEBUG_ALIGNED_LOAD return _mm_load_si128(reinterpret_cast<const __m128i*>(from)); }
#if EIGEN_COMP_MSVC
template<> EIGEN_STRONG_INLINE Packet4f ploadu<Packet4f>(const float* from) {
@@ -340,6 +750,10 @@ template<> EIGEN_STRONG_INLINE Packet4i ploadu<Packet4i>(const int* from)
EIGEN_DEBUG_UNALIGNED_LOAD
return _mm_loadu_si128(reinterpret_cast<const __m128i*>(from));
}
+template<> EIGEN_STRONG_INLINE Packet16b ploadu<Packet16b>(const bool* from) {
+ EIGEN_DEBUG_UNALIGNED_LOAD
+ return _mm_loadu_si128(reinterpret_cast<const __m128i*>(from));
+}
template<> EIGEN_STRONG_INLINE Packet4f ploaddup<Packet4f>(const float* from)
@@ -355,13 +769,32 @@ template<> EIGEN_STRONG_INLINE Packet4i ploaddup<Packet4i>(const int* from)
return vec4i_swizzle1(tmp, 0, 0, 1, 1);
}
+// Loads 8 bools from memory and returns the packet
+// {b0, b0, b1, b1, b2, b2, b3, b3, b4, b4, b5, b5, b6, b6, b7, b7}
+template<> EIGEN_STRONG_INLINE Packet16b ploaddup<Packet16b>(const bool* from)
+{
+ __m128i tmp = _mm_castpd_si128(pload1<Packet2d>(reinterpret_cast<const double*>(from)));
+ return _mm_unpacklo_epi8(tmp, tmp);
+}
+
+// Loads 4 bools from memory and returns the packet
+// {b0, b0 b0, b0, b1, b1, b1, b1, b2, b2, b2, b2, b3, b3, b3, b3}
+template<> EIGEN_STRONG_INLINE Packet16b
+ploadquad<Packet16b>(const bool* from) {
+ __m128i tmp = _mm_castps_si128(pload1<Packet4f>(reinterpret_cast<const float*>(from)));
+ tmp = _mm_unpacklo_epi8(tmp, tmp);
+ return _mm_unpacklo_epi16(tmp, tmp);
+}
+
template<> EIGEN_STRONG_INLINE void pstore<float>(float* to, const Packet4f& from) { EIGEN_DEBUG_ALIGNED_STORE _mm_store_ps(to, from); }
template<> EIGEN_STRONG_INLINE void pstore<double>(double* to, const Packet2d& from) { EIGEN_DEBUG_ALIGNED_STORE _mm_store_pd(to, from); }
template<> EIGEN_STRONG_INLINE void pstore<int>(int* to, const Packet4i& from) { EIGEN_DEBUG_ALIGNED_STORE _mm_store_si128(reinterpret_cast<__m128i*>(to), from); }
+template<> EIGEN_STRONG_INLINE void pstore<bool>(bool* to, const Packet16b& from) { EIGEN_DEBUG_ALIGNED_STORE _mm_store_si128(reinterpret_cast<__m128i*>(to), from); }
template<> EIGEN_STRONG_INLINE void pstoreu<double>(double* to, const Packet2d& from) { EIGEN_DEBUG_UNALIGNED_STORE _mm_storeu_pd(to, from); }
template<> EIGEN_STRONG_INLINE void pstoreu<float>(float* to, const Packet4f& from) { EIGEN_DEBUG_UNALIGNED_STORE _mm_storeu_ps(to, from); }
template<> EIGEN_STRONG_INLINE void pstoreu<int>(int* to, const Packet4i& from) { EIGEN_DEBUG_UNALIGNED_STORE _mm_storeu_si128(reinterpret_cast<__m128i*>(to), from); }
+template<> EIGEN_STRONG_INLINE void pstoreu<bool>(bool* to, const Packet16b& from) { EIGEN_DEBUG_ALIGNED_STORE _mm_storeu_si128(reinterpret_cast<__m128i*>(to), from); }
template<> EIGEN_DEVICE_FUNC inline Packet4f pgather<float, Packet4f>(const float* from, Index stride)
{
@@ -374,7 +807,15 @@ template<> EIGEN_DEVICE_FUNC inline Packet2d pgather<double, Packet2d>(const dou
template<> EIGEN_DEVICE_FUNC inline Packet4i pgather<int, Packet4i>(const int* from, Index stride)
{
return _mm_set_epi32(from[3*stride], from[2*stride], from[1*stride], from[0*stride]);
- }
+}
+
+template<> EIGEN_DEVICE_FUNC inline Packet16b pgather<bool, Packet16b>(const bool* from, Index stride)
+{
+ return _mm_set_epi8(from[15*stride], from[14*stride], from[13*stride], from[12*stride],
+ from[11*stride], from[10*stride], from[9*stride], from[8*stride],
+ from[7*stride], from[6*stride], from[5*stride], from[4*stride],
+ from[3*stride], from[2*stride], from[1*stride], from[0*stride]);
+}
template<> EIGEN_DEVICE_FUNC inline void pscatter<float, Packet4f>(float* to, const Packet4f& from, Index stride)
{
@@ -395,6 +836,14 @@ template<> EIGEN_DEVICE_FUNC inline void pscatter<int, Packet4i>(int* to, const
to[stride*2] = _mm_cvtsi128_si32(_mm_shuffle_epi32(from, 2));
to[stride*3] = _mm_cvtsi128_si32(_mm_shuffle_epi32(from, 3));
}
+template<> EIGEN_DEVICE_FUNC inline void pscatter<bool, Packet16b>(bool* to, const Packet16b& from, Index stride)
+{
+ to[4*stride*0] = _mm_cvtsi128_si32(from);
+ to[4*stride*1] = _mm_cvtsi128_si32(_mm_shuffle_epi32(from, 1));
+ to[4*stride*2] = _mm_cvtsi128_si32(_mm_shuffle_epi32(from, 2));
+ to[4*stride*3] = _mm_cvtsi128_si32(_mm_shuffle_epi32(from, 3));
+}
+
// some compilers might be tempted to perform multiple moves instead of using a vector path.
template<> EIGEN_STRONG_INLINE void pstore1<Packet4f>(float* to, const float& a)
@@ -409,10 +858,16 @@ template<> EIGEN_STRONG_INLINE void pstore1<Packet2d>(double* to, const double&
pstore(to, Packet2d(vec2d_swizzle1(pa,0,0)));
}
+#if EIGEN_COMP_PGI && EIGEN_COMP_PGI < 1900
+typedef const void * SsePrefetchPtrType;
+#else
+typedef const char * SsePrefetchPtrType;
+#endif
+
#ifndef EIGEN_VECTORIZE_AVX
-template<> EIGEN_STRONG_INLINE void prefetch<float>(const float* addr) { _mm_prefetch((const char*)(addr), _MM_HINT_T0); }
-template<> EIGEN_STRONG_INLINE void prefetch<double>(const double* addr) { _mm_prefetch((const char*)(addr), _MM_HINT_T0); }
-template<> EIGEN_STRONG_INLINE void prefetch<int>(const int* addr) { _mm_prefetch((const char*)(addr), _MM_HINT_T0); }
+template<> EIGEN_STRONG_INLINE void prefetch<float>(const float* addr) { _mm_prefetch((SsePrefetchPtrType)(addr), _MM_HINT_T0); }
+template<> EIGEN_STRONG_INLINE void prefetch<double>(const double* addr) { _mm_prefetch((SsePrefetchPtrType)(addr), _MM_HINT_T0); }
+template<> EIGEN_STRONG_INLINE void prefetch<int>(const int* addr) { _mm_prefetch((SsePrefetchPtrType)(addr), _MM_HINT_T0); }
#endif
#if EIGEN_COMP_MSVC_STRICT && EIGEN_OS_WIN64
@@ -431,32 +886,62 @@ template<> EIGEN_STRONG_INLINE float pfirst<Packet4f>(const Packet4f& a) { retu
template<> EIGEN_STRONG_INLINE double pfirst<Packet2d>(const Packet2d& a) { return _mm_cvtsd_f64(a); }
template<> EIGEN_STRONG_INLINE int pfirst<Packet4i>(const Packet4i& a) { return _mm_cvtsi128_si32(a); }
#endif
+template<> EIGEN_STRONG_INLINE bool pfirst<Packet16b>(const Packet16b& a) { int x = _mm_cvtsi128_si32(a); return static_cast<bool>(x & 1); }
-template<> EIGEN_STRONG_INLINE Packet4f preverse(const Packet4f& a)
-{ return _mm_shuffle_ps(a,a,0x1B); }
-template<> EIGEN_STRONG_INLINE Packet2d preverse(const Packet2d& a)
-{ return _mm_shuffle_pd(a,a,0x1); }
-template<> EIGEN_STRONG_INLINE Packet4i preverse(const Packet4i& a)
-{ return _mm_shuffle_epi32(a,0x1B); }
+template<> EIGEN_STRONG_INLINE Packet4f preverse(const Packet4f& a) { return _mm_shuffle_ps(a,a,0x1B); }
+template<> EIGEN_STRONG_INLINE Packet2d preverse(const Packet2d& a) { return _mm_shuffle_pd(a,a,0x1); }
+template<> EIGEN_STRONG_INLINE Packet4i preverse(const Packet4i& a) { return _mm_shuffle_epi32(a,0x1B); }
+template<> EIGEN_STRONG_INLINE Packet16b preverse(const Packet16b& a) {
+#ifdef EIGEN_VECTORIZE_SSSE3
+ __m128i mask = _mm_set_epi8(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
+ return _mm_shuffle_epi8(a, mask);
+#else
+ Packet16b tmp = _mm_shuffle_epi32(a, _MM_SHUFFLE(0, 1, 2, 3));
+ tmp = _mm_shufflehi_epi16(_mm_shufflelo_epi16(tmp, _MM_SHUFFLE(2, 3, 0, 1)), _MM_SHUFFLE(2, 3, 0, 1));
+ return _mm_or_si128(_mm_slli_epi16(tmp, 8), _mm_srli_epi16(tmp, 8));
+#endif
+}
-template<> EIGEN_STRONG_INLINE Packet4f pabs(const Packet4f& a)
-{
- const Packet4f mask = _mm_castsi128_ps(_mm_setr_epi32(0x7FFFFFFF,0x7FFFFFFF,0x7FFFFFFF,0x7FFFFFFF));
- return _mm_and_ps(a,mask);
+template<> EIGEN_STRONG_INLINE Packet4f pfrexp<Packet4f>(const Packet4f& a, Packet4f& exponent) {
+ return pfrexp_generic(a,exponent);
}
-template<> EIGEN_STRONG_INLINE Packet2d pabs(const Packet2d& a)
-{
- const Packet2d mask = _mm_castsi128_pd(_mm_setr_epi32(0xFFFFFFFF,0x7FFFFFFF,0xFFFFFFFF,0x7FFFFFFF));
- return _mm_and_pd(a,mask);
+
+// Extract exponent without existence of Packet2l.
+template<>
+EIGEN_STRONG_INLINE
+Packet2d pfrexp_generic_get_biased_exponent(const Packet2d& a) {
+ const Packet2d cst_exp_mask = pset1frombits<Packet2d>(static_cast<uint64_t>(0x7ff0000000000000ull));
+ __m128i a_expo = _mm_srli_epi64(_mm_castpd_si128(pand(a, cst_exp_mask)), 52);
+ return _mm_cvtepi32_pd(vec4i_swizzle1(a_expo, 0, 2, 1, 3));
}
-template<> EIGEN_STRONG_INLINE Packet4i pabs(const Packet4i& a)
-{
- #ifdef EIGEN_VECTORIZE_SSSE3
- return _mm_abs_epi32(a);
- #else
- Packet4i aux = _mm_srai_epi32(a,31);
- return _mm_sub_epi32(_mm_xor_si128(a,aux),aux);
- #endif
+
+template<> EIGEN_STRONG_INLINE Packet2d pfrexp<Packet2d>(const Packet2d& a, Packet2d& exponent) {
+ return pfrexp_generic(a, exponent);
+}
+
+template<> EIGEN_STRONG_INLINE Packet4f pldexp<Packet4f>(const Packet4f& a, const Packet4f& exponent) {
+ return pldexp_generic(a,exponent);
+}
+
+// We specialize pldexp here, since the generic implementation uses Packet2l, which is not well
+// supported by SSE, and has more range than is needed for exponents.
+template<> EIGEN_STRONG_INLINE Packet2d pldexp<Packet2d>(const Packet2d& a, const Packet2d& exponent) {
+ // Clamp exponent to [-2099, 2099]
+ const Packet2d max_exponent = pset1<Packet2d>(2099.0);
+ const Packet2d e = pmin(pmax(exponent, pnegate(max_exponent)), max_exponent);
+
+ // Convert e to integer and swizzle to low-order bits.
+ const Packet4i ei = vec4i_swizzle1(_mm_cvtpd_epi32(e), 0, 3, 1, 3);
+
+ // Split 2^e into four factors and multiply:
+ const Packet4i bias = _mm_set_epi32(0, 1023, 0, 1023);
+ Packet4i b = parithmetic_shift_right<2>(ei); // floor(e/4)
+ Packet2d c = _mm_castsi128_pd(_mm_slli_epi64(padd(b, bias), 52)); // 2^b
+ Packet2d out = pmul(pmul(pmul(a, c), c), c); // a * 2^(3b)
+ b = psub(psub(psub(ei, b), b), b); // e - 3b
+ c = _mm_castsi128_pd(_mm_slli_epi64(padd(b, bias), 52)); // 2^(e - 3b)
+ out = pmul(out, c); // a * 2^e
+ return out;
}
// with AVX, the default implementations based on pload1 are faster
@@ -499,38 +984,6 @@ EIGEN_STRONG_INLINE void punpackp(Packet4f* vecs)
vecs[0] = _mm_castsi128_ps(_mm_shuffle_epi32(_mm_castps_si128(vecs[0]), 0x00));
}
-#ifdef EIGEN_VECTORIZE_SSE3
-template<> EIGEN_STRONG_INLINE Packet4f preduxp<Packet4f>(const Packet4f* vecs)
-{
- return _mm_hadd_ps(_mm_hadd_ps(vecs[0], vecs[1]),_mm_hadd_ps(vecs[2], vecs[3]));
-}
-
-template<> EIGEN_STRONG_INLINE Packet2d preduxp<Packet2d>(const Packet2d* vecs)
-{
- return _mm_hadd_pd(vecs[0], vecs[1]);
-}
-
-#else
-template<> EIGEN_STRONG_INLINE Packet4f preduxp<Packet4f>(const Packet4f* vecs)
-{
- Packet4f tmp0, tmp1, tmp2;
- tmp0 = _mm_unpacklo_ps(vecs[0], vecs[1]);
- tmp1 = _mm_unpackhi_ps(vecs[0], vecs[1]);
- tmp2 = _mm_unpackhi_ps(vecs[2], vecs[3]);
- tmp0 = _mm_add_ps(tmp0, tmp1);
- tmp1 = _mm_unpacklo_ps(vecs[2], vecs[3]);
- tmp1 = _mm_add_ps(tmp1, tmp2);
- tmp2 = _mm_movehl_ps(tmp1, tmp0);
- tmp0 = _mm_movelh_ps(tmp0, tmp1);
- return _mm_add_ps(tmp0, tmp2);
-}
-
-template<> EIGEN_STRONG_INLINE Packet2d preduxp<Packet2d>(const Packet2d* vecs)
-{
- return _mm_add_pd(_mm_unpacklo_pd(vecs[0], vecs[1]), _mm_unpackhi_pd(vecs[0], vecs[1]));
-}
-#endif // SSE3
-
template<> EIGEN_STRONG_INLINE float predux<Packet4f>(const Packet4f& a)
{
// Disable SSE3 _mm_hadd_pd that is extremely slow on all existing Intel's architectures
@@ -556,38 +1009,28 @@ template<> EIGEN_STRONG_INLINE double predux<Packet2d>(const Packet2d& a)
}
#ifdef EIGEN_VECTORIZE_SSSE3
-template<> EIGEN_STRONG_INLINE Packet4i preduxp<Packet4i>(const Packet4i* vecs)
-{
- return _mm_hadd_epi32(_mm_hadd_epi32(vecs[0], vecs[1]),_mm_hadd_epi32(vecs[2], vecs[3]));
-}
template<> EIGEN_STRONG_INLINE int predux<Packet4i>(const Packet4i& a)
{
Packet4i tmp0 = _mm_hadd_epi32(a,a);
return pfirst<Packet4i>(_mm_hadd_epi32(tmp0,tmp0));
}
+
#else
template<> EIGEN_STRONG_INLINE int predux<Packet4i>(const Packet4i& a)
{
Packet4i tmp = _mm_add_epi32(a, _mm_unpackhi_epi64(a,a));
return pfirst(tmp) + pfirst<Packet4i>(_mm_shuffle_epi32(tmp, 1));
}
+#endif
-template<> EIGEN_STRONG_INLINE Packet4i preduxp<Packet4i>(const Packet4i* vecs)
-{
- Packet4i tmp0, tmp1, tmp2;
- tmp0 = _mm_unpacklo_epi32(vecs[0], vecs[1]);
- tmp1 = _mm_unpackhi_epi32(vecs[0], vecs[1]);
- tmp2 = _mm_unpackhi_epi32(vecs[2], vecs[3]);
- tmp0 = _mm_add_epi32(tmp0, tmp1);
- tmp1 = _mm_unpacklo_epi32(vecs[2], vecs[3]);
- tmp1 = _mm_add_epi32(tmp1, tmp2);
- tmp2 = _mm_unpacklo_epi64(tmp0, tmp1);
- tmp0 = _mm_unpackhi_epi64(tmp0, tmp1);
- return _mm_add_epi32(tmp0, tmp2);
+template<> EIGEN_STRONG_INLINE bool predux<Packet16b>(const Packet16b& a) {
+ Packet4i tmp = _mm_or_si128(a, _mm_unpackhi_epi64(a,a));
+ return (pfirst(tmp) != 0) || (pfirst<Packet4i>(_mm_shuffle_epi32(tmp, 1)) != 0);
}
-#endif
+
// Other reduction functions:
+
// mul
template<> EIGEN_STRONG_INLINE float predux_mul<Packet4f>(const Packet4f& a)
{
@@ -605,7 +1048,13 @@ template<> EIGEN_STRONG_INLINE int predux_mul<Packet4i>(const Packet4i& a)
// TODO try to call _mm_mul_epu32 directly
EIGEN_ALIGN16 int aux[4];
pstore(aux, a);
- return (aux[0] * aux[1]) * (aux[2] * aux[3]);;
+ return (aux[0] * aux[1]) * (aux[2] * aux[3]);
+}
+
+template<> EIGEN_STRONG_INLINE bool predux_mul<Packet16b>(const Packet16b& a) {
+ Packet4i tmp = _mm_and_si128(a, _mm_unpackhi_epi64(a,a));
+ return ((pfirst<Packet4i>(tmp) == 0x01010101) &&
+ (pfirst<Packet4i>(_mm_shuffle_epi32(tmp, 1)) == 0x01010101));
}
// min
@@ -660,113 +1109,16 @@ template<> EIGEN_STRONG_INLINE int predux_max<Packet4i>(const Packet4i& a)
#endif // EIGEN_VECTORIZE_SSE4_1
}
-#if EIGEN_COMP_GNUC
-// template <> EIGEN_STRONG_INLINE Packet4f pmadd(const Packet4f& a, const Packet4f& b, const Packet4f& c)
+// not needed yet
+// template<> EIGEN_STRONG_INLINE bool predux_all(const Packet4f& x)
// {
-// Packet4f res = b;
-// asm("mulps %[a], %[b] \n\taddps %[c], %[b]" : [b] "+x" (res) : [a] "x" (a), [c] "x" (c));
-// return res;
+// return _mm_movemask_ps(x) == 0xF;
// }
-// EIGEN_STRONG_INLINE Packet4i _mm_alignr_epi8(const Packet4i& a, const Packet4i& b, const int i)
-// {
-// Packet4i res = a;
-// asm("palignr %[i], %[a], %[b] " : [b] "+x" (res) : [a] "x" (a), [i] "i" (i));
-// return res;
-// }
-#endif
-
-#ifdef EIGEN_VECTORIZE_SSSE3
-// SSSE3 versions
-template<int Offset>
-struct palign_impl<Offset,Packet4f>
-{
- static EIGEN_STRONG_INLINE void run(Packet4f& first, const Packet4f& second)
- {
- if (Offset!=0)
- first = _mm_castsi128_ps(_mm_alignr_epi8(_mm_castps_si128(second), _mm_castps_si128(first), Offset*4));
- }
-};
-
-template<int Offset>
-struct palign_impl<Offset,Packet4i>
-{
- static EIGEN_STRONG_INLINE void run(Packet4i& first, const Packet4i& second)
- {
- if (Offset!=0)
- first = _mm_alignr_epi8(second,first, Offset*4);
- }
-};
-
-template<int Offset>
-struct palign_impl<Offset,Packet2d>
-{
- static EIGEN_STRONG_INLINE void run(Packet2d& first, const Packet2d& second)
- {
- if (Offset==1)
- first = _mm_castsi128_pd(_mm_alignr_epi8(_mm_castpd_si128(second), _mm_castpd_si128(first), 8));
- }
-};
-#else
-// SSE2 versions
-template<int Offset>
-struct palign_impl<Offset,Packet4f>
-{
- static EIGEN_STRONG_INLINE void run(Packet4f& first, const Packet4f& second)
- {
- if (Offset==1)
- {
- first = _mm_move_ss(first,second);
- first = _mm_castsi128_ps(_mm_shuffle_epi32(_mm_castps_si128(first),0x39));
- }
- else if (Offset==2)
- {
- first = _mm_movehl_ps(first,first);
- first = _mm_movelh_ps(first,second);
- }
- else if (Offset==3)
- {
- first = _mm_move_ss(first,second);
- first = _mm_shuffle_ps(first,second,0x93);
- }
- }
-};
-
-template<int Offset>
-struct palign_impl<Offset,Packet4i>
-{
- static EIGEN_STRONG_INLINE void run(Packet4i& first, const Packet4i& second)
- {
- if (Offset==1)
- {
- first = _mm_castps_si128(_mm_move_ss(_mm_castsi128_ps(first),_mm_castsi128_ps(second)));
- first = _mm_shuffle_epi32(first,0x39);
- }
- else if (Offset==2)
- {
- first = _mm_castps_si128(_mm_movehl_ps(_mm_castsi128_ps(first),_mm_castsi128_ps(first)));
- first = _mm_castps_si128(_mm_movelh_ps(_mm_castsi128_ps(first),_mm_castsi128_ps(second)));
- }
- else if (Offset==3)
- {
- first = _mm_castps_si128(_mm_move_ss(_mm_castsi128_ps(first),_mm_castsi128_ps(second)));
- first = _mm_castps_si128(_mm_shuffle_ps(_mm_castsi128_ps(first),_mm_castsi128_ps(second),0x93));
- }
- }
-};
-template<int Offset>
-struct palign_impl<Offset,Packet2d>
+template<> EIGEN_STRONG_INLINE bool predux_any(const Packet4f& x)
{
- static EIGEN_STRONG_INLINE void run(Packet2d& first, const Packet2d& second)
- {
- if (Offset==1)
- {
- first = _mm_castps_pd(_mm_movehl_ps(_mm_castpd_ps(first),_mm_castpd_ps(first)));
- first = _mm_castps_pd(_mm_movelh_ps(_mm_castpd_ps(first),_mm_castpd_ps(second)));
- }
- }
-};
-#endif
+ return _mm_movemask_ps(x) != 0x0;
+}
EIGEN_DEVICE_FUNC inline void
ptranspose(PacketBlock<Packet4f,4>& kernel) {
@@ -793,6 +1145,100 @@ ptranspose(PacketBlock<Packet4i,4>& kernel) {
kernel.packet[3] = _mm_unpackhi_epi64(T2, T3);
}
+EIGEN_DEVICE_FUNC inline void
+ptranspose(PacketBlock<Packet16b,4>& kernel) {
+ __m128i T0 = _mm_unpacklo_epi8(kernel.packet[0], kernel.packet[1]);
+ __m128i T1 = _mm_unpackhi_epi8(kernel.packet[0], kernel.packet[1]);
+ __m128i T2 = _mm_unpacklo_epi8(kernel.packet[2], kernel.packet[3]);
+ __m128i T3 = _mm_unpackhi_epi8(kernel.packet[2], kernel.packet[3]);
+ kernel.packet[0] = _mm_unpacklo_epi16(T0, T2);
+ kernel.packet[1] = _mm_unpackhi_epi16(T0, T2);
+ kernel.packet[2] = _mm_unpacklo_epi16(T1, T3);
+ kernel.packet[3] = _mm_unpackhi_epi16(T1, T3);
+}
+
+EIGEN_DEVICE_FUNC inline void
+ptranspose(PacketBlock<Packet16b,16>& kernel) {
+ // If we number the elements in the input thus:
+ // kernel.packet[ 0] = {00, 01, 02, 03, 04, 05, 06, 07, 08, 09, 0a, 0b, 0c, 0d, 0e, 0f}
+ // kernel.packet[ 1] = {10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 1a, 1b, 1c, 1d, 1e, 1f}
+ // ...
+ // kernel.packet[15] = {f0, f1, f2, f3, f4, f5, f6, f7, f8, f9, fa, fb, fc, fd, fe, ff},
+ //
+ // the desired output is:
+ // kernel.packet[ 0] = {00, 10, 20, 30, 40, 50, 60, 70, 80, 90, a0, b0, c0, d0, e0, f0}
+ // kernel.packet[ 1] = {01, 11, 21, 31, 41, 51, 61, 71, 81, 91, a1, b1, c1, d1, e1, f1}
+ // ...
+ // kernel.packet[15] = {0f, 1f, 2f, 3f, 4f, 5f, 6f, 7f, 8f, 9f, af, bf, cf, df, ef, ff},
+ __m128i t0 = _mm_unpacklo_epi8(kernel.packet[0], kernel.packet[1]); // 00 10 01 11 02 12 03 13 04 14 05 15 06 16 07 17
+ __m128i t1 = _mm_unpackhi_epi8(kernel.packet[0], kernel.packet[1]); // 08 18 09 19 0a 1a 0b 1b 0c 1c 0d 1d 0e 1e 0f 1f
+ __m128i t2 = _mm_unpacklo_epi8(kernel.packet[2], kernel.packet[3]); // 20 30 21 31 22 32 ... 27 37
+ __m128i t3 = _mm_unpackhi_epi8(kernel.packet[2], kernel.packet[3]); // 28 38 29 39 2a 3a ... 2f 3f
+ __m128i t4 = _mm_unpacklo_epi8(kernel.packet[4], kernel.packet[5]); // 40 50 41 51 42 52 47 57
+ __m128i t5 = _mm_unpackhi_epi8(kernel.packet[4], kernel.packet[5]); // 48 58 49 59 4a 5a
+ __m128i t6 = _mm_unpacklo_epi8(kernel.packet[6], kernel.packet[7]);
+ __m128i t7 = _mm_unpackhi_epi8(kernel.packet[6], kernel.packet[7]);
+ __m128i t8 = _mm_unpacklo_epi8(kernel.packet[8], kernel.packet[9]);
+ __m128i t9 = _mm_unpackhi_epi8(kernel.packet[8], kernel.packet[9]);
+ __m128i ta = _mm_unpacklo_epi8(kernel.packet[10], kernel.packet[11]);
+ __m128i tb = _mm_unpackhi_epi8(kernel.packet[10], kernel.packet[11]);
+ __m128i tc = _mm_unpacklo_epi8(kernel.packet[12], kernel.packet[13]);
+ __m128i td = _mm_unpackhi_epi8(kernel.packet[12], kernel.packet[13]);
+ __m128i te = _mm_unpacklo_epi8(kernel.packet[14], kernel.packet[15]);
+ __m128i tf = _mm_unpackhi_epi8(kernel.packet[14], kernel.packet[15]);
+
+ __m128i s0 = _mm_unpacklo_epi16(t0, t2); // 00 10 20 30 01 11 21 31 02 12 22 32 03 13 23 33
+ __m128i s1 = _mm_unpackhi_epi16(t0, t2); // 04 14 24 34
+ __m128i s2 = _mm_unpacklo_epi16(t1, t3); // 08 18 28 38 ...
+ __m128i s3 = _mm_unpackhi_epi16(t1, t3); // 0c 1c 2c 3c ...
+ __m128i s4 = _mm_unpacklo_epi16(t4, t6); // 40 50 60 70 41 51 61 71 42 52 62 72 43 53 63 73
+ __m128i s5 = _mm_unpackhi_epi16(t4, t6); // 44 54 64 74 ...
+ __m128i s6 = _mm_unpacklo_epi16(t5, t7);
+ __m128i s7 = _mm_unpackhi_epi16(t5, t7);
+ __m128i s8 = _mm_unpacklo_epi16(t8, ta);
+ __m128i s9 = _mm_unpackhi_epi16(t8, ta);
+ __m128i sa = _mm_unpacklo_epi16(t9, tb);
+ __m128i sb = _mm_unpackhi_epi16(t9, tb);
+ __m128i sc = _mm_unpacklo_epi16(tc, te);
+ __m128i sd = _mm_unpackhi_epi16(tc, te);
+ __m128i se = _mm_unpacklo_epi16(td, tf);
+ __m128i sf = _mm_unpackhi_epi16(td, tf);
+
+ __m128i u0 = _mm_unpacklo_epi32(s0, s4); // 00 10 20 30 40 50 60 70 01 11 21 31 41 51 61 71
+ __m128i u1 = _mm_unpackhi_epi32(s0, s4); // 02 12 22 32 42 52 62 72 03 13 23 33 43 53 63 73
+ __m128i u2 = _mm_unpacklo_epi32(s1, s5);
+ __m128i u3 = _mm_unpackhi_epi32(s1, s5);
+ __m128i u4 = _mm_unpacklo_epi32(s2, s6);
+ __m128i u5 = _mm_unpackhi_epi32(s2, s6);
+ __m128i u6 = _mm_unpacklo_epi32(s3, s7);
+ __m128i u7 = _mm_unpackhi_epi32(s3, s7);
+ __m128i u8 = _mm_unpacklo_epi32(s8, sc);
+ __m128i u9 = _mm_unpackhi_epi32(s8, sc);
+ __m128i ua = _mm_unpacklo_epi32(s9, sd);
+ __m128i ub = _mm_unpackhi_epi32(s9, sd);
+ __m128i uc = _mm_unpacklo_epi32(sa, se);
+ __m128i ud = _mm_unpackhi_epi32(sa, se);
+ __m128i ue = _mm_unpacklo_epi32(sb, sf);
+ __m128i uf = _mm_unpackhi_epi32(sb, sf);
+
+ kernel.packet[0] = _mm_unpacklo_epi64(u0, u8);
+ kernel.packet[1] = _mm_unpackhi_epi64(u0, u8);
+ kernel.packet[2] = _mm_unpacklo_epi64(u1, u9);
+ kernel.packet[3] = _mm_unpackhi_epi64(u1, u9);
+ kernel.packet[4] = _mm_unpacklo_epi64(u2, ua);
+ kernel.packet[5] = _mm_unpackhi_epi64(u2, ua);
+ kernel.packet[6] = _mm_unpacklo_epi64(u3, ub);
+ kernel.packet[7] = _mm_unpackhi_epi64(u3, ub);
+ kernel.packet[8] = _mm_unpacklo_epi64(u4, uc);
+ kernel.packet[9] = _mm_unpackhi_epi64(u4, uc);
+ kernel.packet[10] = _mm_unpacklo_epi64(u5, ud);
+ kernel.packet[11] = _mm_unpackhi_epi64(u5, ud);
+ kernel.packet[12] = _mm_unpacklo_epi64(u6, ue);
+ kernel.packet[13] = _mm_unpackhi_epi64(u6, ue);
+ kernel.packet[14] = _mm_unpacklo_epi64(u7, uf);
+ kernel.packet[15] = _mm_unpackhi_epi64(u7, uf);
+}
+
template<> EIGEN_STRONG_INLINE Packet4i pblend(const Selector<4>& ifPacket, const Packet4i& thenPacket, const Packet4i& elsePacket) {
const __m128i zero = _mm_setzero_si128();
const __m128i select = _mm_set_epi32(ifPacket.select[3], ifPacket.select[2], ifPacket.select[1], ifPacket.select[0]);
@@ -824,56 +1270,236 @@ template<> EIGEN_STRONG_INLINE Packet2d pblend(const Selector<2>& ifPacket, cons
#endif
}
-template<> EIGEN_STRONG_INLINE Packet4f pinsertfirst(const Packet4f& a, float b)
-{
-#ifdef EIGEN_VECTORIZE_SSE4_1
- return _mm_blend_ps(a,pset1<Packet4f>(b),1);
-#else
- return _mm_move_ss(a, _mm_load_ss(&b));
+// Scalar path for pmadd with FMA to ensure consistency with vectorized path.
+#ifdef EIGEN_VECTORIZE_FMA
+template<> EIGEN_STRONG_INLINE float pmadd(const float& a, const float& b, const float& c) {
+ return ::fmaf(a,b,c);
+}
+template<> EIGEN_STRONG_INLINE double pmadd(const double& a, const double& b, const double& c) {
+ return ::fma(a,b,c);
+}
#endif
+
+
+// Packet math for Eigen::half
+// Disable the following code since it's broken on too many platforms / compilers.
+//#elif defined(EIGEN_VECTORIZE_SSE) && (!EIGEN_ARCH_x86_64) && (!EIGEN_COMP_MSVC)
+#if 0
+
+typedef struct {
+ __m64 x;
+} Packet4h;
+
+
+template<> struct is_arithmetic<Packet4h> { enum { value = true }; };
+
+template <>
+struct packet_traits<Eigen::half> : default_packet_traits {
+ typedef Packet4h type;
+ // There is no half-size packet for Packet4h.
+ typedef Packet4h half;
+ enum {
+ Vectorizable = 1,
+ AlignedOnScalar = 1,
+ size = 4,
+ HasHalfPacket = 0,
+ HasAdd = 1,
+ HasSub = 1,
+ HasMul = 1,
+ HasDiv = 1,
+ HasNegate = 0,
+ HasAbs = 0,
+ HasAbs2 = 0,
+ HasMin = 0,
+ HasMax = 0,
+ HasConj = 0,
+ HasSetLinear = 0,
+ HasSqrt = 0,
+ HasRsqrt = 0,
+ HasExp = 0,
+ HasLog = 0,
+ HasBlend = 0
+ };
+};
+
+
+template<> struct unpacket_traits<Packet4h> { typedef Eigen::half type; enum {size=4, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false}; typedef Packet4h half; };
+
+template<> EIGEN_STRONG_INLINE Packet4h pset1<Packet4h>(const Eigen::half& from) {
+ Packet4h result;
+ result.x = _mm_set1_pi16(from.x);
+ return result;
}
-template<> EIGEN_STRONG_INLINE Packet2d pinsertfirst(const Packet2d& a, double b)
-{
-#ifdef EIGEN_VECTORIZE_SSE4_1
- return _mm_blend_pd(a,pset1<Packet2d>(b),1);
-#else
- return _mm_move_sd(a, _mm_load_sd(&b));
-#endif
+template<> EIGEN_STRONG_INLINE Eigen::half pfirst<Packet4h>(const Packet4h& from) {
+ return half_impl::raw_uint16_to_half(static_cast<unsigned short>(_mm_cvtsi64_si32(from.x)));
}
-template<> EIGEN_STRONG_INLINE Packet4f pinsertlast(const Packet4f& a, float b)
-{
-#ifdef EIGEN_VECTORIZE_SSE4_1
- return _mm_blend_ps(a,pset1<Packet4f>(b),(1<<3));
-#else
- const Packet4f mask = _mm_castsi128_ps(_mm_setr_epi32(0x0,0x0,0x0,0xFFFFFFFF));
- return _mm_or_ps(_mm_andnot_ps(mask, a), _mm_and_ps(mask, pset1<Packet4f>(b)));
-#endif
+template<> EIGEN_STRONG_INLINE Packet4h pconj(const Packet4h& a) { return a; }
+
+template<> EIGEN_STRONG_INLINE Packet4h padd<Packet4h>(const Packet4h& a, const Packet4h& b) {
+ __int64_t a64 = _mm_cvtm64_si64(a.x);
+ __int64_t b64 = _mm_cvtm64_si64(b.x);
+
+ Eigen::half h[4];
+
+ Eigen::half ha = half_impl::raw_uint16_to_half(static_cast<unsigned short>(a64));
+ Eigen::half hb = half_impl::raw_uint16_to_half(static_cast<unsigned short>(b64));
+ h[0] = ha + hb;
+ ha = half_impl::raw_uint16_to_half(static_cast<unsigned short>(a64 >> 16));
+ hb = half_impl::raw_uint16_to_half(static_cast<unsigned short>(b64 >> 16));
+ h[1] = ha + hb;
+ ha = half_impl::raw_uint16_to_half(static_cast<unsigned short>(a64 >> 32));
+ hb = half_impl::raw_uint16_to_half(static_cast<unsigned short>(b64 >> 32));
+ h[2] = ha + hb;
+ ha = half_impl::raw_uint16_to_half(static_cast<unsigned short>(a64 >> 48));
+ hb = half_impl::raw_uint16_to_half(static_cast<unsigned short>(b64 >> 48));
+ h[3] = ha + hb;
+ Packet4h result;
+ result.x = _mm_set_pi16(h[3].x, h[2].x, h[1].x, h[0].x);
+ return result;
+}
+
+template<> EIGEN_STRONG_INLINE Packet4h psub<Packet4h>(const Packet4h& a, const Packet4h& b) {
+ __int64_t a64 = _mm_cvtm64_si64(a.x);
+ __int64_t b64 = _mm_cvtm64_si64(b.x);
+
+ Eigen::half h[4];
+
+ Eigen::half ha = half_impl::raw_uint16_to_half(static_cast<unsigned short>(a64));
+ Eigen::half hb = half_impl::raw_uint16_to_half(static_cast<unsigned short>(b64));
+ h[0] = ha - hb;
+ ha = half_impl::raw_uint16_to_half(static_cast<unsigned short>(a64 >> 16));
+ hb = half_impl::raw_uint16_to_half(static_cast<unsigned short>(b64 >> 16));
+ h[1] = ha - hb;
+ ha = half_impl::raw_uint16_to_half(static_cast<unsigned short>(a64 >> 32));
+ hb = half_impl::raw_uint16_to_half(static_cast<unsigned short>(b64 >> 32));
+ h[2] = ha - hb;
+ ha = half_impl::raw_uint16_to_half(static_cast<unsigned short>(a64 >> 48));
+ hb = half_impl::raw_uint16_to_half(static_cast<unsigned short>(b64 >> 48));
+ h[3] = ha - hb;
+ Packet4h result;
+ result.x = _mm_set_pi16(h[3].x, h[2].x, h[1].x, h[0].x);
+ return result;
+}
+
+template<> EIGEN_STRONG_INLINE Packet4h pmul<Packet4h>(const Packet4h& a, const Packet4h& b) {
+ __int64_t a64 = _mm_cvtm64_si64(a.x);
+ __int64_t b64 = _mm_cvtm64_si64(b.x);
+
+ Eigen::half h[4];
+
+ Eigen::half ha = half_impl::raw_uint16_to_half(static_cast<unsigned short>(a64));
+ Eigen::half hb = half_impl::raw_uint16_to_half(static_cast<unsigned short>(b64));
+ h[0] = ha * hb;
+ ha = half_impl::raw_uint16_to_half(static_cast<unsigned short>(a64 >> 16));
+ hb = half_impl::raw_uint16_to_half(static_cast<unsigned short>(b64 >> 16));
+ h[1] = ha * hb;
+ ha = half_impl::raw_uint16_to_half(static_cast<unsigned short>(a64 >> 32));
+ hb = half_impl::raw_uint16_to_half(static_cast<unsigned short>(b64 >> 32));
+ h[2] = ha * hb;
+ ha = half_impl::raw_uint16_to_half(static_cast<unsigned short>(a64 >> 48));
+ hb = half_impl::raw_uint16_to_half(static_cast<unsigned short>(b64 >> 48));
+ h[3] = ha * hb;
+ Packet4h result;
+ result.x = _mm_set_pi16(h[3].x, h[2].x, h[1].x, h[0].x);
+ return result;
}
-template<> EIGEN_STRONG_INLINE Packet2d pinsertlast(const Packet2d& a, double b)
+template<> EIGEN_STRONG_INLINE Packet4h pdiv<Packet4h>(const Packet4h& a, const Packet4h& b) {
+ __int64_t a64 = _mm_cvtm64_si64(a.x);
+ __int64_t b64 = _mm_cvtm64_si64(b.x);
+
+ Eigen::half h[4];
+
+ Eigen::half ha = half_impl::raw_uint16_to_half(static_cast<unsigned short>(a64));
+ Eigen::half hb = half_impl::raw_uint16_to_half(static_cast<unsigned short>(b64));
+ h[0] = ha / hb;
+ ha = half_impl::raw_uint16_to_half(static_cast<unsigned short>(a64 >> 16));
+ hb = half_impl::raw_uint16_to_half(static_cast<unsigned short>(b64 >> 16));
+ h[1] = ha / hb;
+ ha = half_impl::raw_uint16_to_half(static_cast<unsigned short>(a64 >> 32));
+ hb = half_impl::raw_uint16_to_half(static_cast<unsigned short>(b64 >> 32));
+ h[2] = ha / hb;
+ ha = half_impl::raw_uint16_to_half(static_cast<unsigned short>(a64 >> 48));
+ hb = half_impl::raw_uint16_to_half(static_cast<unsigned short>(b64 >> 48));
+ h[3] = ha / hb;
+ Packet4h result;
+ result.x = _mm_set_pi16(h[3].x, h[2].x, h[1].x, h[0].x);
+ return result;
+}
+
+template<> EIGEN_STRONG_INLINE Packet4h pload<Packet4h>(const Eigen::half* from) {
+ Packet4h result;
+ result.x = _mm_cvtsi64_m64(*reinterpret_cast<const __int64_t*>(from));
+ return result;
+}
+
+template<> EIGEN_STRONG_INLINE Packet4h ploadu<Packet4h>(const Eigen::half* from) {
+ Packet4h result;
+ result.x = _mm_cvtsi64_m64(*reinterpret_cast<const __int64_t*>(from));
+ return result;
+}
+
+template<> EIGEN_STRONG_INLINE void pstore<Eigen::half>(Eigen::half* to, const Packet4h& from) {
+ __int64_t r = _mm_cvtm64_si64(from.x);
+ *(reinterpret_cast<__int64_t*>(to)) = r;
+}
+
+template<> EIGEN_STRONG_INLINE void pstoreu<Eigen::half>(Eigen::half* to, const Packet4h& from) {
+ __int64_t r = _mm_cvtm64_si64(from.x);
+ *(reinterpret_cast<__int64_t*>(to)) = r;
+}
+
+template<> EIGEN_STRONG_INLINE Packet4h
+ploadquad<Packet4h>(const Eigen::half* from) {
+ return pset1<Packet4h>(*from);
+}
+
+template<> EIGEN_STRONG_INLINE Packet4h pgather<Eigen::half, Packet4h>(const Eigen::half* from, Index stride)
{
-#ifdef EIGEN_VECTORIZE_SSE4_1
- return _mm_blend_pd(a,pset1<Packet2d>(b),(1<<1));
-#else
- const Packet2d mask = _mm_castsi128_pd(_mm_setr_epi32(0x0,0x0,0xFFFFFFFF,0xFFFFFFFF));
- return _mm_or_pd(_mm_andnot_pd(mask, a), _mm_and_pd(mask, pset1<Packet2d>(b)));
-#endif
+ Packet4h result;
+ result.x = _mm_set_pi16(from[3*stride].x, from[2*stride].x, from[1*stride].x, from[0*stride].x);
+ return result;
}
-// Scalar path for pmadd with FMA to ensure consistency with vectorized path.
-#ifdef __FMA__
-template<> EIGEN_STRONG_INLINE float pmadd(const float& a, const float& b, const float& c) {
- return ::fmaf(a,b,c);
+template<> EIGEN_STRONG_INLINE void pscatter<Eigen::half, Packet4h>(Eigen::half* to, const Packet4h& from, Index stride)
+{
+ __int64_t a = _mm_cvtm64_si64(from.x);
+ to[stride*0].x = static_cast<unsigned short>(a);
+ to[stride*1].x = static_cast<unsigned short>(a >> 16);
+ to[stride*2].x = static_cast<unsigned short>(a >> 32);
+ to[stride*3].x = static_cast<unsigned short>(a >> 48);
}
-template<> EIGEN_STRONG_INLINE double pmadd(const double& a, const double& b, const double& c) {
- return ::fma(a,b,c);
+
+EIGEN_STRONG_INLINE void
+ptranspose(PacketBlock<Packet4h,4>& kernel) {
+ __m64 T0 = _mm_unpacklo_pi16(kernel.packet[0].x, kernel.packet[1].x);
+ __m64 T1 = _mm_unpacklo_pi16(kernel.packet[2].x, kernel.packet[3].x);
+ __m64 T2 = _mm_unpackhi_pi16(kernel.packet[0].x, kernel.packet[1].x);
+ __m64 T3 = _mm_unpackhi_pi16(kernel.packet[2].x, kernel.packet[3].x);
+
+ kernel.packet[0].x = _mm_unpacklo_pi32(T0, T1);
+ kernel.packet[1].x = _mm_unpackhi_pi32(T0, T1);
+ kernel.packet[2].x = _mm_unpacklo_pi32(T2, T3);
+ kernel.packet[3].x = _mm_unpackhi_pi32(T2, T3);
}
+
#endif
+
} // end namespace internal
} // end namespace Eigen
+#if EIGEN_COMP_PGI && EIGEN_COMP_PGI < 1900
+// PGI++ does not define the following intrinsics in C++ mode.
+static inline __m128 _mm_castpd_ps (__m128d x) { return reinterpret_cast<__m128&>(x); }
+static inline __m128i _mm_castpd_si128(__m128d x) { return reinterpret_cast<__m128i&>(x); }
+static inline __m128d _mm_castps_pd (__m128 x) { return reinterpret_cast<__m128d&>(x); }
+static inline __m128i _mm_castps_si128(__m128 x) { return reinterpret_cast<__m128i&>(x); }
+static inline __m128 _mm_castsi128_ps(__m128i x) { return reinterpret_cast<__m128&>(x); }
+static inline __m128d _mm_castsi128_pd(__m128i x) { return reinterpret_cast<__m128d&>(x); }
+#endif
+
#endif // EIGEN_PACKET_MATH_SSE_H
diff --git a/Eigen/src/Core/arch/SSE/TypeCasting.h b/Eigen/src/Core/arch/SSE/TypeCasting.h
index c84893230..d2a0037e0 100644
--- a/Eigen/src/Core/arch/SSE/TypeCasting.h
+++ b/Eigen/src/Core/arch/SSE/TypeCasting.h
@@ -14,6 +14,7 @@ namespace Eigen {
namespace internal {
+#ifndef EIGEN_VECTORIZE_AVX
template <>
struct type_casting_traits<float, int> {
enum {
@@ -23,11 +24,6 @@ struct type_casting_traits<float, int> {
};
};
-template<> EIGEN_STRONG_INLINE Packet4i pcast<Packet4f, Packet4i>(const Packet4f& a) {
- return _mm_cvttps_epi32(a);
-}
-
-
template <>
struct type_casting_traits<int, float> {
enum {
@@ -37,11 +33,6 @@ struct type_casting_traits<int, float> {
};
};
-template<> EIGEN_STRONG_INLINE Packet4f pcast<Packet4i, Packet4f>(const Packet4i& a) {
- return _mm_cvtepi32_ps(a);
-}
-
-
template <>
struct type_casting_traits<double, float> {
enum {
@@ -51,10 +42,6 @@ struct type_casting_traits<double, float> {
};
};
-template<> EIGEN_STRONG_INLINE Packet4f pcast<Packet2d, Packet4f>(const Packet2d& a, const Packet2d& b) {
- return _mm_shuffle_ps(_mm_cvtpd_ps(a), _mm_cvtpd_ps(b), (1 << 2) | (1 << 6));
-}
-
template <>
struct type_casting_traits<float, double> {
enum {
@@ -63,12 +50,90 @@ struct type_casting_traits<float, double> {
TgtCoeffRatio = 2
};
};
+#endif
+
+template<> EIGEN_STRONG_INLINE Packet4i pcast<Packet4f, Packet4i>(const Packet4f& a) {
+ return _mm_cvttps_epi32(a);
+}
+
+template<> EIGEN_STRONG_INLINE Packet4f pcast<Packet4i, Packet4f>(const Packet4i& a) {
+ return _mm_cvtepi32_ps(a);
+}
+
+template<> EIGEN_STRONG_INLINE Packet4f pcast<Packet2d, Packet4f>(const Packet2d& a, const Packet2d& b) {
+ return _mm_shuffle_ps(_mm_cvtpd_ps(a), _mm_cvtpd_ps(b), (1 << 2) | (1 << 6));
+}
template<> EIGEN_STRONG_INLINE Packet2d pcast<Packet4f, Packet2d>(const Packet4f& a) {
// Simply discard the second half of the input
return _mm_cvtps_pd(a);
}
+template<> EIGEN_STRONG_INLINE Packet4i preinterpret<Packet4i,Packet4f>(const Packet4f& a) {
+ return _mm_castps_si128(a);
+}
+
+template<> EIGEN_STRONG_INLINE Packet4f preinterpret<Packet4f,Packet4i>(const Packet4i& a) {
+ return _mm_castsi128_ps(a);
+}
+
+template<> EIGEN_STRONG_INLINE Packet2d preinterpret<Packet2d,Packet4i>(const Packet4i& a) {
+ return _mm_castsi128_pd(a);
+}
+
+template<> EIGEN_STRONG_INLINE Packet4i preinterpret<Packet4i,Packet2d>(const Packet2d& a) {
+ return _mm_castpd_si128(a);
+}
+
+// Disable the following code since it's broken on too many platforms / compilers.
+//#elif defined(EIGEN_VECTORIZE_SSE) && (!EIGEN_ARCH_x86_64) && (!EIGEN_COMP_MSVC)
+#if 0
+
+template <>
+struct type_casting_traits<Eigen::half, float> {
+ enum {
+ VectorizedCast = 1,
+ SrcCoeffRatio = 1,
+ TgtCoeffRatio = 1
+ };
+};
+
+template<> EIGEN_STRONG_INLINE Packet4f pcast<Packet4h, Packet4f>(const Packet4h& a) {
+ __int64_t a64 = _mm_cvtm64_si64(a.x);
+ Eigen::half h = raw_uint16_to_half(static_cast<unsigned short>(a64));
+ float f1 = static_cast<float>(h);
+ h = raw_uint16_to_half(static_cast<unsigned short>(a64 >> 16));
+ float f2 = static_cast<float>(h);
+ h = raw_uint16_to_half(static_cast<unsigned short>(a64 >> 32));
+ float f3 = static_cast<float>(h);
+ h = raw_uint16_to_half(static_cast<unsigned short>(a64 >> 48));
+ float f4 = static_cast<float>(h);
+ return _mm_set_ps(f4, f3, f2, f1);
+}
+
+template <>
+struct type_casting_traits<float, Eigen::half> {
+ enum {
+ VectorizedCast = 1,
+ SrcCoeffRatio = 1,
+ TgtCoeffRatio = 1
+ };
+};
+
+template<> EIGEN_STRONG_INLINE Packet4h pcast<Packet4f, Packet4h>(const Packet4f& a) {
+ EIGEN_ALIGN16 float aux[4];
+ pstore(aux, a);
+ Eigen::half h0(aux[0]);
+ Eigen::half h1(aux[1]);
+ Eigen::half h2(aux[2]);
+ Eigen::half h3(aux[3]);
+
+ Packet4h result;
+ result.x = _mm_set_pi16(h3.x, h2.x, h1.x, h0.x);
+ return result;
+}
+
+#endif
} // end namespace internal
diff --git a/Eigen/src/Core/arch/SVE/MathFunctions.h b/Eigen/src/Core/arch/SVE/MathFunctions.h
new file mode 100644
index 000000000..b139ea2e4
--- /dev/null
+++ b/Eigen/src/Core/arch/SVE/MathFunctions.h
@@ -0,0 +1,44 @@
+// This file is part of Eigen, a lightweight C++ template library
+// for linear algebra.
+//
+// Copyright (C) 2020, Arm Limited and Contributors
+//
+// This Source Code Form is subject to the terms of the Mozilla
+// Public License v. 2.0. If a copy of the MPL was not distributed
+// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+#ifndef EIGEN_MATH_FUNCTIONS_SVE_H
+#define EIGEN_MATH_FUNCTIONS_SVE_H
+
+namespace Eigen {
+namespace internal {
+
+template <>
+EIGEN_STRONG_INLINE EIGEN_UNUSED PacketXf pexp<PacketXf>(const PacketXf& x) {
+ return pexp_float(x);
+}
+
+template <>
+EIGEN_STRONG_INLINE EIGEN_UNUSED PacketXf plog<PacketXf>(const PacketXf& x) {
+ return plog_float(x);
+}
+
+template <>
+EIGEN_STRONG_INLINE EIGEN_UNUSED PacketXf psin<PacketXf>(const PacketXf& x) {
+ return psin_float(x);
+}
+
+template <>
+EIGEN_STRONG_INLINE EIGEN_UNUSED PacketXf pcos<PacketXf>(const PacketXf& x) {
+ return pcos_float(x);
+}
+
+// Hyperbolic Tangent function.
+template <>
+EIGEN_STRONG_INLINE EIGEN_UNUSED PacketXf ptanh<PacketXf>(const PacketXf& x) {
+ return internal::generic_fast_tanh_float(x);
+}
+} // end namespace internal
+} // end namespace Eigen
+
+#endif // EIGEN_MATH_FUNCTIONS_SVE_H
diff --git a/Eigen/src/Core/arch/SVE/PacketMath.h b/Eigen/src/Core/arch/SVE/PacketMath.h
new file mode 100644
index 000000000..9060b372f
--- /dev/null
+++ b/Eigen/src/Core/arch/SVE/PacketMath.h
@@ -0,0 +1,752 @@
+// This file is part of Eigen, a lightweight C++ template library
+// for linear algebra.
+//
+// Copyright (C) 2020, Arm Limited and Contributors
+//
+// This Source Code Form is subject to the terms of the Mozilla
+// Public License v. 2.0. If a copy of the MPL was not distributed
+// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+#ifndef EIGEN_PACKET_MATH_SVE_H
+#define EIGEN_PACKET_MATH_SVE_H
+
+namespace Eigen
+{
+namespace internal
+{
+#ifndef EIGEN_CACHEFRIENDLY_PRODUCT_THRESHOLD
+#define EIGEN_CACHEFRIENDLY_PRODUCT_THRESHOLD 8
+#endif
+
+#ifndef EIGEN_HAS_SINGLE_INSTRUCTION_MADD
+#define EIGEN_HAS_SINGLE_INSTRUCTION_MADD
+#endif
+
+#define EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS 32
+
+template <typename Scalar, int SVEVectorLength>
+struct sve_packet_size_selector {
+ enum { size = SVEVectorLength / (sizeof(Scalar) * CHAR_BIT) };
+};
+
+/********************************* int32 **************************************/
+typedef svint32_t PacketXi __attribute__((arm_sve_vector_bits(EIGEN_ARM64_SVE_VL)));
+
+template <>
+struct packet_traits<numext::int32_t> : default_packet_traits {
+ typedef PacketXi type;
+ typedef PacketXi half; // Half not implemented yet
+ enum {
+ Vectorizable = 1,
+ AlignedOnScalar = 1,
+ size = sve_packet_size_selector<numext::int32_t, EIGEN_ARM64_SVE_VL>::size,
+ HasHalfPacket = 0,
+
+ HasAdd = 1,
+ HasSub = 1,
+ HasShift = 1,
+ HasMul = 1,
+ HasNegate = 1,
+ HasAbs = 1,
+ HasArg = 0,
+ HasAbs2 = 1,
+ HasMin = 1,
+ HasMax = 1,
+ HasConj = 1,
+ HasSetLinear = 0,
+ HasBlend = 0,
+ HasReduxp = 0 // Not implemented in SVE
+ };
+};
+
+template <>
+struct unpacket_traits<PacketXi> {
+ typedef numext::int32_t type;
+ typedef PacketXi half; // Half not yet implemented
+ enum {
+ size = sve_packet_size_selector<numext::int32_t, EIGEN_ARM64_SVE_VL>::size,
+ alignment = Aligned64,
+ vectorizable = true,
+ masked_load_available = false,
+ masked_store_available = false
+ };
+};
+
+template <>
+EIGEN_STRONG_INLINE void prefetch<numext::int32_t>(const numext::int32_t* addr)
+{
+ svprfw(svptrue_b32(), addr, SV_PLDL1KEEP);
+}
+
+template <>
+EIGEN_STRONG_INLINE PacketXi pset1<PacketXi>(const numext::int32_t& from)
+{
+ return svdup_n_s32(from);
+}
+
+template <>
+EIGEN_STRONG_INLINE PacketXi plset<PacketXi>(const numext::int32_t& a)
+{
+ numext::int32_t c[packet_traits<numext::int32_t>::size];
+ for (int i = 0; i < packet_traits<numext::int32_t>::size; i++) c[i] = i;
+ return svadd_s32_z(svptrue_b32(), pset1<PacketXi>(a), svld1_s32(svptrue_b32(), c));
+}
+
+template <>
+EIGEN_STRONG_INLINE PacketXi padd<PacketXi>(const PacketXi& a, const PacketXi& b)
+{
+ return svadd_s32_z(svptrue_b32(), a, b);
+}
+
+template <>
+EIGEN_STRONG_INLINE PacketXi psub<PacketXi>(const PacketXi& a, const PacketXi& b)
+{
+ return svsub_s32_z(svptrue_b32(), a, b);
+}
+
+template <>
+EIGEN_STRONG_INLINE PacketXi pnegate(const PacketXi& a)
+{
+ return svneg_s32_z(svptrue_b32(), a);
+}
+
+template <>
+EIGEN_STRONG_INLINE PacketXi pconj(const PacketXi& a)
+{
+ return a;
+}
+
+template <>
+EIGEN_STRONG_INLINE PacketXi pmul<PacketXi>(const PacketXi& a, const PacketXi& b)
+{
+ return svmul_s32_z(svptrue_b32(), a, b);
+}
+
+template <>
+EIGEN_STRONG_INLINE PacketXi pdiv<PacketXi>(const PacketXi& a, const PacketXi& b)
+{
+ return svdiv_s32_z(svptrue_b32(), a, b);
+}
+
+template <>
+EIGEN_STRONG_INLINE PacketXi pmadd(const PacketXi& a, const PacketXi& b, const PacketXi& c)
+{
+ return svmla_s32_z(svptrue_b32(), c, a, b);
+}
+
+template <>
+EIGEN_STRONG_INLINE PacketXi pmin<PacketXi>(const PacketXi& a, const PacketXi& b)
+{
+ return svmin_s32_z(svptrue_b32(), a, b);
+}
+
+template <>
+EIGEN_STRONG_INLINE PacketXi pmax<PacketXi>(const PacketXi& a, const PacketXi& b)
+{
+ return svmax_s32_z(svptrue_b32(), a, b);
+}
+
+template <>
+EIGEN_STRONG_INLINE PacketXi pcmp_le<PacketXi>(const PacketXi& a, const PacketXi& b)
+{
+ return svdup_n_s32_z(svcmplt_s32(svptrue_b32(), a, b), 0xffffffffu);
+}
+
+template <>
+EIGEN_STRONG_INLINE PacketXi pcmp_lt<PacketXi>(const PacketXi& a, const PacketXi& b)
+{
+ return svdup_n_s32_z(svcmplt_s32(svptrue_b32(), a, b), 0xffffffffu);
+}
+
+template <>
+EIGEN_STRONG_INLINE PacketXi pcmp_eq<PacketXi>(const PacketXi& a, const PacketXi& b)
+{
+ return svdup_n_s32_z(svcmpeq_s32(svptrue_b32(), a, b), 0xffffffffu);
+}
+
+template <>
+EIGEN_STRONG_INLINE PacketXi ptrue<PacketXi>(const PacketXi& /*a*/)
+{
+ return svdup_n_s32_z(svptrue_b32(), 0xffffffffu);
+}
+
+template <>
+EIGEN_STRONG_INLINE PacketXi pzero<PacketXi>(const PacketXi& /*a*/)
+{
+ return svdup_n_s32_z(svptrue_b32(), 0);
+}
+
+template <>
+EIGEN_STRONG_INLINE PacketXi pand<PacketXi>(const PacketXi& a, const PacketXi& b)
+{
+ return svand_s32_z(svptrue_b32(), a, b);
+}
+
+template <>
+EIGEN_STRONG_INLINE PacketXi por<PacketXi>(const PacketXi& a, const PacketXi& b)
+{
+ return svorr_s32_z(svptrue_b32(), a, b);
+}
+
+template <>
+EIGEN_STRONG_INLINE PacketXi pxor<PacketXi>(const PacketXi& a, const PacketXi& b)
+{
+ return sveor_s32_z(svptrue_b32(), a, b);
+}
+
+template <>
+EIGEN_STRONG_INLINE PacketXi pandnot<PacketXi>(const PacketXi& a, const PacketXi& b)
+{
+ return svbic_s32_z(svptrue_b32(), a, b);
+}
+
+template <int N>
+EIGEN_STRONG_INLINE PacketXi parithmetic_shift_right(PacketXi a)
+{
+ return svasrd_n_s32_z(svptrue_b32(), a, N);
+}
+
+template <int N>
+EIGEN_STRONG_INLINE PacketXi plogical_shift_right(PacketXi a)
+{
+ return svreinterpret_s32_u32(svlsr_u32_z(svptrue_b32(), svreinterpret_u32_s32(a), svdup_n_u32_z(svptrue_b32(), N)));
+}
+
+template <int N>
+EIGEN_STRONG_INLINE PacketXi plogical_shift_left(PacketXi a)
+{
+ return svlsl_s32_z(svptrue_b32(), a, svdup_n_u32_z(svptrue_b32(), N));
+}
+
+template <>
+EIGEN_STRONG_INLINE PacketXi pload<PacketXi>(const numext::int32_t* from)
+{
+ EIGEN_DEBUG_ALIGNED_LOAD return svld1_s32(svptrue_b32(), from);
+}
+
+template <>
+EIGEN_STRONG_INLINE PacketXi ploadu<PacketXi>(const numext::int32_t* from)
+{
+ EIGEN_DEBUG_UNALIGNED_LOAD return svld1_s32(svptrue_b32(), from);
+}
+
+template <>
+EIGEN_STRONG_INLINE PacketXi ploaddup<PacketXi>(const numext::int32_t* from)
+{
+ svuint32_t indices = svindex_u32(0, 1); // index {base=0, base+step=1, base+step*2, ...}
+ indices = svzip1_u32(indices, indices); // index in the format {a0, a0, a1, a1, a2, a2, ...}
+ return svld1_gather_u32index_s32(svptrue_b32(), from, indices);
+}
+
+template <>
+EIGEN_STRONG_INLINE PacketXi ploadquad<PacketXi>(const numext::int32_t* from)
+{
+ svuint32_t indices = svindex_u32(0, 1); // index {base=0, base+step=1, base+step*2, ...}
+ indices = svzip1_u32(indices, indices); // index in the format {a0, a0, a1, a1, a2, a2, ...}
+ indices = svzip1_u32(indices, indices); // index in the format {a0, a0, a0, a0, a1, a1, a1, a1, ...}
+ return svld1_gather_u32index_s32(svptrue_b32(), from, indices);
+}
+
+template <>
+EIGEN_STRONG_INLINE void pstore<numext::int32_t>(numext::int32_t* to, const PacketXi& from)
+{
+ EIGEN_DEBUG_ALIGNED_STORE svst1_s32(svptrue_b32(), to, from);
+}
+
+template <>
+EIGEN_STRONG_INLINE void pstoreu<numext::int32_t>(numext::int32_t* to, const PacketXi& from)
+{
+ EIGEN_DEBUG_UNALIGNED_STORE svst1_s32(svptrue_b32(), to, from);
+}
+
+template <>
+EIGEN_DEVICE_FUNC inline PacketXi pgather<numext::int32_t, PacketXi>(const numext::int32_t* from, Index stride)
+{
+ // Indice format: {base=0, base+stride, base+stride*2, base+stride*3, ...}
+ svint32_t indices = svindex_s32(0, stride);
+ return svld1_gather_s32index_s32(svptrue_b32(), from, indices);
+}
+
+template <>
+EIGEN_DEVICE_FUNC inline void pscatter<numext::int32_t, PacketXi>(numext::int32_t* to, const PacketXi& from, Index stride)
+{
+ // Indice format: {base=0, base+stride, base+stride*2, base+stride*3, ...}
+ svint32_t indices = svindex_s32(0, stride);
+ svst1_scatter_s32index_s32(svptrue_b32(), to, indices, from);
+}
+
+template <>
+EIGEN_STRONG_INLINE numext::int32_t pfirst<PacketXi>(const PacketXi& a)
+{
+ // svlasta returns the first element if all predicate bits are 0
+ return svlasta_s32(svpfalse_b(), a);
+}
+
+template <>
+EIGEN_STRONG_INLINE PacketXi preverse(const PacketXi& a)
+{
+ return svrev_s32(a);
+}
+
+template <>
+EIGEN_STRONG_INLINE PacketXi pabs(const PacketXi& a)
+{
+ return svabs_s32_z(svptrue_b32(), a);
+}
+
+template <>
+EIGEN_STRONG_INLINE numext::int32_t predux<PacketXi>(const PacketXi& a)
+{
+ return static_cast<numext::int32_t>(svaddv_s32(svptrue_b32(), a));
+}
+
+template <>
+EIGEN_STRONG_INLINE numext::int32_t predux_mul<PacketXi>(const PacketXi& a)
+{
+ EIGEN_STATIC_ASSERT((EIGEN_ARM64_SVE_VL % 128 == 0),
+ EIGEN_INTERNAL_ERROR_PLEASE_FILE_A_BUG_REPORT);
+
+ // Multiply the vector by its reverse
+ svint32_t prod = svmul_s32_z(svptrue_b32(), a, svrev_s32(a));
+ svint32_t half_prod;
+
+ // Extract the high half of the vector. Depending on the VL more reductions need to be done
+ if (EIGEN_ARM64_SVE_VL >= 2048) {
+ half_prod = svtbl_s32(prod, svindex_u32(32, 1));
+ prod = svmul_s32_z(svptrue_b32(), prod, half_prod);
+ }
+ if (EIGEN_ARM64_SVE_VL >= 1024) {
+ half_prod = svtbl_s32(prod, svindex_u32(16, 1));
+ prod = svmul_s32_z(svptrue_b32(), prod, half_prod);
+ }
+ if (EIGEN_ARM64_SVE_VL >= 512) {
+ half_prod = svtbl_s32(prod, svindex_u32(8, 1));
+ prod = svmul_s32_z(svptrue_b32(), prod, half_prod);
+ }
+ if (EIGEN_ARM64_SVE_VL >= 256) {
+ half_prod = svtbl_s32(prod, svindex_u32(4, 1));
+ prod = svmul_s32_z(svptrue_b32(), prod, half_prod);
+ }
+ // Last reduction
+ half_prod = svtbl_s32(prod, svindex_u32(2, 1));
+ prod = svmul_s32_z(svptrue_b32(), prod, half_prod);
+
+ // The reduction is done to the first element.
+ return pfirst<PacketXi>(prod);
+}
+
+template <>
+EIGEN_STRONG_INLINE numext::int32_t predux_min<PacketXi>(const PacketXi& a)
+{
+ return svminv_s32(svptrue_b32(), a);
+}
+
+template <>
+EIGEN_STRONG_INLINE numext::int32_t predux_max<PacketXi>(const PacketXi& a)
+{
+ return svmaxv_s32(svptrue_b32(), a);
+}
+
+template <int N>
+EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<PacketXi, N>& kernel) {
+ int buffer[packet_traits<numext::int32_t>::size * N] = {0};
+ int i = 0;
+
+ PacketXi stride_index = svindex_s32(0, N);
+
+ for (i = 0; i < N; i++) {
+ svst1_scatter_s32index_s32(svptrue_b32(), buffer + i, stride_index, kernel.packet[i]);
+ }
+ for (i = 0; i < N; i++) {
+ kernel.packet[i] = svld1_s32(svptrue_b32(), buffer + i * packet_traits<numext::int32_t>::size);
+ }
+}
+
+/********************************* float32 ************************************/
+
+typedef svfloat32_t PacketXf __attribute__((arm_sve_vector_bits(EIGEN_ARM64_SVE_VL)));
+
+template <>
+struct packet_traits<float> : default_packet_traits {
+ typedef PacketXf type;
+ typedef PacketXf half;
+
+ enum {
+ Vectorizable = 1,
+ AlignedOnScalar = 1,
+ size = sve_packet_size_selector<float, EIGEN_ARM64_SVE_VL>::size,
+ HasHalfPacket = 0,
+
+ HasAdd = 1,
+ HasSub = 1,
+ HasShift = 1,
+ HasMul = 1,
+ HasNegate = 1,
+ HasAbs = 1,
+ HasArg = 0,
+ HasAbs2 = 1,
+ HasMin = 1,
+ HasMax = 1,
+ HasConj = 1,
+ HasSetLinear = 0,
+ HasBlend = 0,
+ HasReduxp = 0, // Not implemented in SVE
+
+ HasDiv = 1,
+ HasFloor = 1,
+
+ HasSin = EIGEN_FAST_MATH,
+ HasCos = EIGEN_FAST_MATH,
+ HasLog = 1,
+ HasExp = 1,
+ HasSqrt = 0,
+ HasTanh = EIGEN_FAST_MATH,
+ HasErf = EIGEN_FAST_MATH
+ };
+};
+
+template <>
+struct unpacket_traits<PacketXf> {
+ typedef float type;
+ typedef PacketXf half; // Half not yet implemented
+ typedef PacketXi integer_packet;
+
+ enum {
+ size = sve_packet_size_selector<float, EIGEN_ARM64_SVE_VL>::size,
+ alignment = Aligned64,
+ vectorizable = true,
+ masked_load_available = false,
+ masked_store_available = false
+ };
+};
+
+template <>
+EIGEN_STRONG_INLINE PacketXf pset1<PacketXf>(const float& from)
+{
+ return svdup_n_f32(from);
+}
+
+template <>
+EIGEN_STRONG_INLINE PacketXf pset1frombits<PacketXf>(numext::uint32_t from)
+{
+ return svreinterpret_f32_u32(svdup_n_u32_z(svptrue_b32(), from));
+}
+
+template <>
+EIGEN_STRONG_INLINE PacketXf plset<PacketXf>(const float& a)
+{
+ float c[packet_traits<float>::size];
+ for (int i = 0; i < packet_traits<float>::size; i++) c[i] = i;
+ return svadd_f32_z(svptrue_b32(), pset1<PacketXf>(a), svld1_f32(svptrue_b32(), c));
+}
+
+template <>
+EIGEN_STRONG_INLINE PacketXf padd<PacketXf>(const PacketXf& a, const PacketXf& b)
+{
+ return svadd_f32_z(svptrue_b32(), a, b);
+}
+
+template <>
+EIGEN_STRONG_INLINE PacketXf psub<PacketXf>(const PacketXf& a, const PacketXf& b)
+{
+ return svsub_f32_z(svptrue_b32(), a, b);
+}
+
+template <>
+EIGEN_STRONG_INLINE PacketXf pnegate(const PacketXf& a)
+{
+ return svneg_f32_z(svptrue_b32(), a);
+}
+
+template <>
+EIGEN_STRONG_INLINE PacketXf pconj(const PacketXf& a)
+{
+ return a;
+}
+
+template <>
+EIGEN_STRONG_INLINE PacketXf pmul<PacketXf>(const PacketXf& a, const PacketXf& b)
+{
+ return svmul_f32_z(svptrue_b32(), a, b);
+}
+
+template <>
+EIGEN_STRONG_INLINE PacketXf pdiv<PacketXf>(const PacketXf& a, const PacketXf& b)
+{
+ return svdiv_f32_z(svptrue_b32(), a, b);
+}
+
+template <>
+EIGEN_STRONG_INLINE PacketXf pmadd(const PacketXf& a, const PacketXf& b, const PacketXf& c)
+{
+ return svmla_f32_z(svptrue_b32(), c, a, b);
+}
+
+template <>
+EIGEN_STRONG_INLINE PacketXf pmin<PacketXf>(const PacketXf& a, const PacketXf& b)
+{
+ return svmin_f32_z(svptrue_b32(), a, b);
+}
+
+template <>
+EIGEN_STRONG_INLINE PacketXf pmin<PropagateNaN, PacketXf>(const PacketXf& a, const PacketXf& b)
+{
+ return pmin<PacketXf>(a, b);
+}
+
+template <>
+EIGEN_STRONG_INLINE PacketXf pmin<PropagateNumbers, PacketXf>(const PacketXf& a, const PacketXf& b)
+{
+ return svminnm_f32_z(svptrue_b32(), a, b);
+}
+
+template <>
+EIGEN_STRONG_INLINE PacketXf pmax<PacketXf>(const PacketXf& a, const PacketXf& b)
+{
+ return svmax_f32_z(svptrue_b32(), a, b);
+}
+
+template <>
+EIGEN_STRONG_INLINE PacketXf pmax<PropagateNaN, PacketXf>(const PacketXf& a, const PacketXf& b)
+{
+ return pmax<PacketXf>(a, b);
+}
+
+template <>
+EIGEN_STRONG_INLINE PacketXf pmax<PropagateNumbers, PacketXf>(const PacketXf& a, const PacketXf& b)
+{
+ return svmaxnm_f32_z(svptrue_b32(), a, b);
+}
+
+// Float comparisons in SVE return svbool (predicate). Use svdup to set active
+// lanes to 1 (0xffffffffu) and inactive lanes to 0.
+template <>
+EIGEN_STRONG_INLINE PacketXf pcmp_le<PacketXf>(const PacketXf& a, const PacketXf& b)
+{
+ return svreinterpret_f32_u32(svdup_n_u32_z(svcmplt_f32(svptrue_b32(), a, b), 0xffffffffu));
+}
+
+template <>
+EIGEN_STRONG_INLINE PacketXf pcmp_lt<PacketXf>(const PacketXf& a, const PacketXf& b)
+{
+ return svreinterpret_f32_u32(svdup_n_u32_z(svcmplt_f32(svptrue_b32(), a, b), 0xffffffffu));
+}
+
+template <>
+EIGEN_STRONG_INLINE PacketXf pcmp_eq<PacketXf>(const PacketXf& a, const PacketXf& b)
+{
+ return svreinterpret_f32_u32(svdup_n_u32_z(svcmpeq_f32(svptrue_b32(), a, b), 0xffffffffu));
+}
+
+// Do a predicate inverse (svnot_b_z) on the predicate resulted from the
+// greater/equal comparison (svcmpge_f32). Then fill a float vector with the
+// active elements.
+template <>
+EIGEN_STRONG_INLINE PacketXf pcmp_lt_or_nan<PacketXf>(const PacketXf& a, const PacketXf& b)
+{
+ return svreinterpret_f32_u32(svdup_n_u32_z(svnot_b_z(svptrue_b32(), svcmpge_f32(svptrue_b32(), a, b)), 0xffffffffu));
+}
+
+template <>
+EIGEN_STRONG_INLINE PacketXf pfloor<PacketXf>(const PacketXf& a)
+{
+ return svrintm_f32_z(svptrue_b32(), a);
+}
+
+template <>
+EIGEN_STRONG_INLINE PacketXf ptrue<PacketXf>(const PacketXf& /*a*/)
+{
+ return svreinterpret_f32_u32(svdup_n_u32_z(svptrue_b32(), 0xffffffffu));
+}
+
+// Logical Operations are not supported for float, so reinterpret casts
+template <>
+EIGEN_STRONG_INLINE PacketXf pand<PacketXf>(const PacketXf& a, const PacketXf& b)
+{
+ return svreinterpret_f32_u32(svand_u32_z(svptrue_b32(), svreinterpret_u32_f32(a), svreinterpret_u32_f32(b)));
+}
+
+template <>
+EIGEN_STRONG_INLINE PacketXf por<PacketXf>(const PacketXf& a, const PacketXf& b)
+{
+ return svreinterpret_f32_u32(svorr_u32_z(svptrue_b32(), svreinterpret_u32_f32(a), svreinterpret_u32_f32(b)));
+}
+
+template <>
+EIGEN_STRONG_INLINE PacketXf pxor<PacketXf>(const PacketXf& a, const PacketXf& b)
+{
+ return svreinterpret_f32_u32(sveor_u32_z(svptrue_b32(), svreinterpret_u32_f32(a), svreinterpret_u32_f32(b)));
+}
+
+template <>
+EIGEN_STRONG_INLINE PacketXf pandnot<PacketXf>(const PacketXf& a, const PacketXf& b)
+{
+ return svreinterpret_f32_u32(svbic_u32_z(svptrue_b32(), svreinterpret_u32_f32(a), svreinterpret_u32_f32(b)));
+}
+
+template <>
+EIGEN_STRONG_INLINE PacketXf pload<PacketXf>(const float* from)
+{
+ EIGEN_DEBUG_ALIGNED_LOAD return svld1_f32(svptrue_b32(), from);
+}
+
+template <>
+EIGEN_STRONG_INLINE PacketXf ploadu<PacketXf>(const float* from)
+{
+ EIGEN_DEBUG_UNALIGNED_LOAD return svld1_f32(svptrue_b32(), from);
+}
+
+template <>
+EIGEN_STRONG_INLINE PacketXf ploaddup<PacketXf>(const float* from)
+{
+ svuint32_t indices = svindex_u32(0, 1); // index {base=0, base+step=1, base+step*2, ...}
+ indices = svzip1_u32(indices, indices); // index in the format {a0, a0, a1, a1, a2, a2, ...}
+ return svld1_gather_u32index_f32(svptrue_b32(), from, indices);
+}
+
+template <>
+EIGEN_STRONG_INLINE PacketXf ploadquad<PacketXf>(const float* from)
+{
+ svuint32_t indices = svindex_u32(0, 1); // index {base=0, base+step=1, base+step*2, ...}
+ indices = svzip1_u32(indices, indices); // index in the format {a0, a0, a1, a1, a2, a2, ...}
+ indices = svzip1_u32(indices, indices); // index in the format {a0, a0, a0, a0, a1, a1, a1, a1, ...}
+ return svld1_gather_u32index_f32(svptrue_b32(), from, indices);
+}
+
+template <>
+EIGEN_STRONG_INLINE void pstore<float>(float* to, const PacketXf& from)
+{
+ EIGEN_DEBUG_ALIGNED_STORE svst1_f32(svptrue_b32(), to, from);
+}
+
+template <>
+EIGEN_STRONG_INLINE void pstoreu<float>(float* to, const PacketXf& from)
+{
+ EIGEN_DEBUG_UNALIGNED_STORE svst1_f32(svptrue_b32(), to, from);
+}
+
+template <>
+EIGEN_DEVICE_FUNC inline PacketXf pgather<float, PacketXf>(const float* from, Index stride)
+{
+ // Indice format: {base=0, base+stride, base+stride*2, base+stride*3, ...}
+ svint32_t indices = svindex_s32(0, stride);
+ return svld1_gather_s32index_f32(svptrue_b32(), from, indices);
+}
+
+template <>
+EIGEN_DEVICE_FUNC inline void pscatter<float, PacketXf>(float* to, const PacketXf& from, Index stride)
+{
+ // Indice format: {base=0, base+stride, base+stride*2, base+stride*3, ...}
+ svint32_t indices = svindex_s32(0, stride);
+ svst1_scatter_s32index_f32(svptrue_b32(), to, indices, from);
+}
+
+template <>
+EIGEN_STRONG_INLINE float pfirst<PacketXf>(const PacketXf& a)
+{
+ // svlasta returns the first element if all predicate bits are 0
+ return svlasta_f32(svpfalse_b(), a);
+}
+
+template <>
+EIGEN_STRONG_INLINE PacketXf preverse(const PacketXf& a)
+{
+ return svrev_f32(a);
+}
+
+template <>
+EIGEN_STRONG_INLINE PacketXf pabs(const PacketXf& a)
+{
+ return svabs_f32_z(svptrue_b32(), a);
+}
+
+// TODO(tellenbach): Should this go into MathFunctions.h? If so, change for
+// all vector extensions and the generic version.
+template <>
+EIGEN_STRONG_INLINE PacketXf pfrexp<PacketXf>(const PacketXf& a, PacketXf& exponent)
+{
+ return pfrexp_generic(a, exponent);
+}
+
+template <>
+EIGEN_STRONG_INLINE float predux<PacketXf>(const PacketXf& a)
+{
+ return svaddv_f32(svptrue_b32(), a);
+}
+
+// Other reduction functions:
+// mul
+// Only works for SVE Vls multiple of 128
+template <>
+EIGEN_STRONG_INLINE float predux_mul<PacketXf>(const PacketXf& a)
+{
+ EIGEN_STATIC_ASSERT((EIGEN_ARM64_SVE_VL % 128 == 0),
+ EIGEN_INTERNAL_ERROR_PLEASE_FILE_A_BUG_REPORT);
+ // Multiply the vector by its reverse
+ svfloat32_t prod = svmul_f32_z(svptrue_b32(), a, svrev_f32(a));
+ svfloat32_t half_prod;
+
+ // Extract the high half of the vector. Depending on the VL more reductions need to be done
+ if (EIGEN_ARM64_SVE_VL >= 2048) {
+ half_prod = svtbl_f32(prod, svindex_u32(32, 1));
+ prod = svmul_f32_z(svptrue_b32(), prod, half_prod);
+ }
+ if (EIGEN_ARM64_SVE_VL >= 1024) {
+ half_prod = svtbl_f32(prod, svindex_u32(16, 1));
+ prod = svmul_f32_z(svptrue_b32(), prod, half_prod);
+ }
+ if (EIGEN_ARM64_SVE_VL >= 512) {
+ half_prod = svtbl_f32(prod, svindex_u32(8, 1));
+ prod = svmul_f32_z(svptrue_b32(), prod, half_prod);
+ }
+ if (EIGEN_ARM64_SVE_VL >= 256) {
+ half_prod = svtbl_f32(prod, svindex_u32(4, 1));
+ prod = svmul_f32_z(svptrue_b32(), prod, half_prod);
+ }
+ // Last reduction
+ half_prod = svtbl_f32(prod, svindex_u32(2, 1));
+ prod = svmul_f32_z(svptrue_b32(), prod, half_prod);
+
+ // The reduction is done to the first element.
+ return pfirst<PacketXf>(prod);
+}
+
+template <>
+EIGEN_STRONG_INLINE float predux_min<PacketXf>(const PacketXf& a)
+{
+ return svminv_f32(svptrue_b32(), a);
+}
+
+template <>
+EIGEN_STRONG_INLINE float predux_max<PacketXf>(const PacketXf& a)
+{
+ return svmaxv_f32(svptrue_b32(), a);
+}
+
+template<int N>
+EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<PacketXf, N>& kernel)
+{
+ float buffer[packet_traits<float>::size * N] = {0};
+ int i = 0;
+
+ PacketXi stride_index = svindex_s32(0, N);
+
+ for (i = 0; i < N; i++) {
+ svst1_scatter_s32index_f32(svptrue_b32(), buffer + i, stride_index, kernel.packet[i]);
+ }
+
+ for (i = 0; i < N; i++) {
+ kernel.packet[i] = svld1_f32(svptrue_b32(), buffer + i * packet_traits<float>::size);
+ }
+}
+
+template<>
+EIGEN_STRONG_INLINE PacketXf pldexp<PacketXf>(const PacketXf& a, const PacketXf& exponent)
+{
+ return pldexp_generic(a, exponent);
+}
+
+} // namespace internal
+} // namespace Eigen
+
+#endif // EIGEN_PACKET_MATH_SVE_H
diff --git a/Eigen/src/Core/arch/SVE/TypeCasting.h b/Eigen/src/Core/arch/SVE/TypeCasting.h
new file mode 100644
index 000000000..7ba5d9cd1
--- /dev/null
+++ b/Eigen/src/Core/arch/SVE/TypeCasting.h
@@ -0,0 +1,49 @@
+// This file is part of Eigen, a lightweight C++ template library
+// for linear algebra.
+//
+// Copyright (C) 2020, Arm Limited and Contributors
+//
+// This Source Code Form is subject to the terms of the Mozilla
+// Public License v. 2.0. If a copy of the MPL was not distributed
+// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+#ifndef EIGEN_TYPE_CASTING_SVE_H
+#define EIGEN_TYPE_CASTING_SVE_H
+
+namespace Eigen {
+namespace internal {
+
+template <>
+struct type_casting_traits<float, numext::int32_t> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 };
+};
+
+template <>
+struct type_casting_traits<numext::int32_t, float> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 };
+};
+
+template <>
+EIGEN_STRONG_INLINE PacketXf pcast<PacketXi, PacketXf>(const PacketXi& a) {
+ return svcvt_f32_s32_z(svptrue_b32(), a);
+}
+
+template <>
+EIGEN_STRONG_INLINE PacketXi pcast<PacketXf, PacketXi>(const PacketXf& a) {
+ return svcvt_s32_f32_z(svptrue_b32(), a);
+}
+
+template <>
+EIGEN_STRONG_INLINE PacketXf preinterpret<PacketXf, PacketXi>(const PacketXi& a) {
+ return svreinterpret_f32_s32(a);
+}
+
+template <>
+EIGEN_STRONG_INLINE PacketXi preinterpret<PacketXi, PacketXf>(const PacketXf& a) {
+ return svreinterpret_s32_f32(a);
+}
+
+} // namespace internal
+} // namespace Eigen
+
+#endif // EIGEN_TYPE_CASTING_SVE_H
diff --git a/Eigen/src/Core/arch/SYCL/InteropHeaders.h b/Eigen/src/Core/arch/SYCL/InteropHeaders.h
new file mode 100644
index 000000000..10856ff5e
--- /dev/null
+++ b/Eigen/src/Core/arch/SYCL/InteropHeaders.h
@@ -0,0 +1,232 @@
+// This file is part of Eigen, a lightweight C++ template library
+// for linear algebra.
+//
+// Mehdi Goli Codeplay Software Ltd.
+// Ralph Potter Codeplay Software Ltd.
+// Luke Iwanski Codeplay Software Ltd.
+// Contact: <eigen@codeplay.com>
+//
+// This Source Code Form is subject to the terms of the Mozilla
+// Public License v. 2.0. If a copy of the MPL was not distributed
+// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+/*****************************************************************
+ * InteropHeaders.h
+ *
+ * \brief:
+ * InteropHeaders
+ *
+ *****************************************************************/
+
+#ifndef EIGEN_INTEROP_HEADERS_SYCL_H
+#define EIGEN_INTEROP_HEADERS_SYCL_H
+
+namespace Eigen {
+
+#if !defined(EIGEN_DONT_VECTORIZE_SYCL)
+
+namespace internal {
+
+template <int has_blend, int lengths>
+struct sycl_packet_traits : default_packet_traits {
+ enum {
+ Vectorizable = 1,
+ AlignedOnScalar = 1,
+ size = lengths,
+ HasHalfPacket = 0,
+ HasDiv = 1,
+ HasLog = 1,
+ HasExp = 1,
+ HasSqrt = 1,
+ HasRsqrt = 1,
+ HasSin = 1,
+ HasCos = 1,
+ HasTan = 1,
+ HasASin = 1,
+ HasACos = 1,
+ HasATan = 1,
+ HasSinh = 1,
+ HasCosh = 1,
+ HasTanh = 1,
+ HasLGamma = 0,
+ HasDiGamma = 0,
+ HasZeta = 0,
+ HasPolygamma = 0,
+ HasErf = 0,
+ HasErfc = 0,
+ HasNdtri = 0,
+ HasIGamma = 0,
+ HasIGammac = 0,
+ HasBetaInc = 0,
+ HasBlend = has_blend,
+ // This flag is used to indicate whether packet comparison is supported.
+ // pcmp_eq, pcmp_lt and pcmp_le should be defined for it to be true.
+ HasCmp = 1,
+ HasMax = 1,
+ HasMin = 1,
+ HasMul = 1,
+ HasAdd = 1,
+ HasFloor = 1,
+ HasRound = 1,
+ HasRint = 1,
+ HasLog1p = 1,
+ HasExpm1 = 1,
+ HasCeil = 1,
+ };
+};
+
+#ifdef SYCL_DEVICE_ONLY
+#define SYCL_PACKET_TRAITS(packet_type, has_blend, unpacket_type, lengths) \
+ template <> \
+ struct packet_traits<unpacket_type> \
+ : sycl_packet_traits<has_blend, lengths> { \
+ typedef packet_type type; \
+ typedef packet_type half; \
+ };
+
+SYCL_PACKET_TRAITS(cl::sycl::cl_float4, 1, float, 4)
+SYCL_PACKET_TRAITS(cl::sycl::cl_float4, 1, const float, 4)
+SYCL_PACKET_TRAITS(cl::sycl::cl_double2, 0, double, 2)
+SYCL_PACKET_TRAITS(cl::sycl::cl_double2, 0, const double, 2)
+#undef SYCL_PACKET_TRAITS
+
+// Make sure this is only available when targeting a GPU: we don't want to
+// introduce conflicts between these packet_traits definitions and the ones
+// we'll use on the host side (SSE, AVX, ...)
+#define SYCL_ARITHMETIC(packet_type) \
+ template <> \
+ struct is_arithmetic<packet_type> { \
+ enum { value = true }; \
+ };
+SYCL_ARITHMETIC(cl::sycl::cl_float4)
+SYCL_ARITHMETIC(cl::sycl::cl_double2)
+#undef SYCL_ARITHMETIC
+
+#define SYCL_UNPACKET_TRAITS(packet_type, unpacket_type, lengths) \
+ template <> \
+ struct unpacket_traits<packet_type> { \
+ typedef unpacket_type type; \
+ enum { size = lengths, vectorizable = true, alignment = Aligned16 }; \
+ typedef packet_type half; \
+ };
+SYCL_UNPACKET_TRAITS(cl::sycl::cl_float4, float, 4)
+SYCL_UNPACKET_TRAITS(cl::sycl::cl_double2, double, 2)
+
+#undef SYCL_UNPACKET_TRAITS
+#endif
+
+} // end namespace internal
+
+#endif
+
+namespace TensorSycl {
+namespace internal {
+
+template <typename PacketReturnType, int PacketSize>
+struct PacketWrapper;
+// This function should never get called on the device
+#ifndef SYCL_DEVICE_ONLY
+template <typename PacketReturnType, int PacketSize>
+struct PacketWrapper {
+ typedef typename ::Eigen::internal::unpacket_traits<PacketReturnType>::type
+ Scalar;
+ template <typename Index>
+ EIGEN_DEVICE_FUNC static Scalar scalarize(Index, PacketReturnType &) {
+ eigen_assert(false && "THERE IS NO PACKETIZE VERSION FOR THE CHOSEN TYPE");
+ abort();
+ }
+ EIGEN_DEVICE_FUNC static PacketReturnType convert_to_packet_type(Scalar in,
+ Scalar) {
+ return ::Eigen::internal::template plset<PacketReturnType>(in);
+ }
+ EIGEN_DEVICE_FUNC static void set_packet(PacketReturnType, Scalar *) {
+ eigen_assert(false && "THERE IS NO PACKETIZE VERSION FOR THE CHOSEN TYPE");
+ abort();
+ }
+};
+
+#elif defined(SYCL_DEVICE_ONLY)
+template <typename PacketReturnType>
+struct PacketWrapper<PacketReturnType, 4> {
+ typedef typename ::Eigen::internal::unpacket_traits<PacketReturnType>::type
+ Scalar;
+ template <typename Index>
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE static Scalar scalarize(Index index, PacketReturnType &in) {
+ switch (index) {
+ case 0:
+ return in.x();
+ case 1:
+ return in.y();
+ case 2:
+ return in.z();
+ case 3:
+ return in.w();
+ default:
+ //INDEX MUST BE BETWEEN 0 and 3.There is no abort function in SYCL kernel. so we cannot use abort here.
+ // The code will never reach here
+ __builtin_unreachable();
+ }
+ __builtin_unreachable();
+ }
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE static PacketReturnType convert_to_packet_type(
+ Scalar in, Scalar other) {
+ return PacketReturnType(in, other, other, other);
+ }
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE static void set_packet(PacketReturnType &lhs, Scalar *rhs) {
+ lhs = PacketReturnType(rhs[0], rhs[1], rhs[2], rhs[3]);
+ }
+};
+
+template <typename PacketReturnType>
+struct PacketWrapper<PacketReturnType, 1> {
+ typedef typename ::Eigen::internal::unpacket_traits<PacketReturnType>::type
+ Scalar;
+ template <typename Index>
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE static Scalar scalarize(Index, PacketReturnType &in) {
+ return in;
+ }
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE static PacketReturnType convert_to_packet_type(Scalar in,
+ Scalar) {
+ return PacketReturnType(in);
+ }
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE static void set_packet(PacketReturnType &lhs, Scalar *rhs) {
+ lhs = rhs[0];
+ }
+};
+
+template <typename PacketReturnType>
+struct PacketWrapper<PacketReturnType, 2> {
+ typedef typename ::Eigen::internal::unpacket_traits<PacketReturnType>::type
+ Scalar;
+ template <typename Index>
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE static Scalar scalarize(Index index, PacketReturnType &in) {
+ switch (index) {
+ case 0:
+ return in.x();
+ case 1:
+ return in.y();
+ default:
+ //INDEX MUST BE BETWEEN 0 and 1.There is no abort function in SYCL kernel. so we cannot use abort here.
+ // The code will never reach here
+ __builtin_unreachable();
+ }
+ __builtin_unreachable();
+ }
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE static PacketReturnType convert_to_packet_type(
+ Scalar in, Scalar other) {
+ return PacketReturnType(in, other);
+ }
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE static void set_packet(PacketReturnType &lhs, Scalar *rhs) {
+ lhs = PacketReturnType(rhs[0], rhs[1]);
+ }
+};
+
+#endif
+
+} // end namespace internal
+} // end namespace TensorSycl
+} // end namespace Eigen
+
+#endif // EIGEN_INTEROP_HEADERS_SYCL_H
diff --git a/Eigen/src/Core/arch/SYCL/MathFunctions.h b/Eigen/src/Core/arch/SYCL/MathFunctions.h
new file mode 100644
index 000000000..2ab0f2a76
--- /dev/null
+++ b/Eigen/src/Core/arch/SYCL/MathFunctions.h
@@ -0,0 +1,301 @@
+// This file is part of Eigen, a lightweight C++ template library
+// for linear algebra.
+//
+// Mehdi Goli Codeplay Software Ltd.
+// Ralph Potter Codeplay Software Ltd.
+// Luke Iwanski Codeplay Software Ltd.
+// Contact: <eigen@codeplay.com>
+//
+// This Source Code Form is subject to the terms of the Mozilla
+// Public License v. 2.0. If a copy of the MPL was not distributed
+// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+/*****************************************************************
+ * MathFunctions.h
+ *
+ * \brief:
+ * MathFunctions
+ *
+ *****************************************************************/
+
+#ifndef EIGEN_MATH_FUNCTIONS_SYCL_H
+#define EIGEN_MATH_FUNCTIONS_SYCL_H
+namespace Eigen {
+
+namespace internal {
+
+// Make sure this is only available when targeting a GPU: we don't want to
+// introduce conflicts between these packet_traits definitions and the ones
+// we'll use on the host side (SSE, AVX, ...)
+#if defined(SYCL_DEVICE_ONLY)
+#define SYCL_PLOG(packet_type) \
+ template <> \
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type plog<packet_type>( \
+ const packet_type& a) { \
+ return cl::sycl::log(a); \
+ }
+
+SYCL_PLOG(cl::sycl::cl_float4)
+SYCL_PLOG(cl::sycl::cl_double2)
+#undef SYCL_PLOG
+
+#define SYCL_PLOG1P(packet_type) \
+ template <> \
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type plog1p<packet_type>( \
+ const packet_type& a) { \
+ return cl::sycl::log1p(a); \
+ }
+
+SYCL_PLOG1P(cl::sycl::cl_float4)
+SYCL_PLOG1P(cl::sycl::cl_double2)
+#undef SYCL_PLOG1P
+
+#define SYCL_PLOG10(packet_type) \
+ template <> \
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type plog10<packet_type>( \
+ const packet_type& a) { \
+ return cl::sycl::log10(a); \
+ }
+
+SYCL_PLOG10(cl::sycl::cl_float4)
+SYCL_PLOG10(cl::sycl::cl_double2)
+#undef SYCL_PLOG10
+
+#define SYCL_PEXP(packet_type) \
+ template <> \
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type pexp<packet_type>( \
+ const packet_type& a) { \
+ return cl::sycl::exp(a); \
+ }
+
+SYCL_PEXP(cl::sycl::cl_float4)
+SYCL_PEXP(cl::sycl::cl_float)
+SYCL_PEXP(cl::sycl::cl_double2)
+#undef SYCL_PEXP
+
+#define SYCL_PEXPM1(packet_type) \
+ template <> \
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type pexpm1<packet_type>( \
+ const packet_type& a) { \
+ return cl::sycl::expm1(a); \
+ }
+
+SYCL_PEXPM1(cl::sycl::cl_float4)
+SYCL_PEXPM1(cl::sycl::cl_double2)
+#undef SYCL_PEXPM1
+
+#define SYCL_PSQRT(packet_type) \
+ template <> \
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type psqrt<packet_type>( \
+ const packet_type& a) { \
+ return cl::sycl::sqrt(a); \
+ }
+
+SYCL_PSQRT(cl::sycl::cl_float4)
+SYCL_PSQRT(cl::sycl::cl_double2)
+#undef SYCL_PSQRT
+
+#define SYCL_PRSQRT(packet_type) \
+ template <> \
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type prsqrt<packet_type>( \
+ const packet_type& a) { \
+ return cl::sycl::rsqrt(a); \
+ }
+
+SYCL_PRSQRT(cl::sycl::cl_float4)
+SYCL_PRSQRT(cl::sycl::cl_double2)
+#undef SYCL_PRSQRT
+
+/** \internal \returns the hyperbolic sine of \a a (coeff-wise) */
+#define SYCL_PSIN(packet_type) \
+ template <> \
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type psin<packet_type>( \
+ const packet_type& a) { \
+ return cl::sycl::sin(a); \
+ }
+
+SYCL_PSIN(cl::sycl::cl_float4)
+SYCL_PSIN(cl::sycl::cl_double2)
+#undef SYCL_PSIN
+
+/** \internal \returns the hyperbolic cosine of \a a (coeff-wise) */
+#define SYCL_PCOS(packet_type) \
+ template <> \
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type pcos<packet_type>( \
+ const packet_type& a) { \
+ return cl::sycl::cos(a); \
+ }
+
+SYCL_PCOS(cl::sycl::cl_float4)
+SYCL_PCOS(cl::sycl::cl_double2)
+#undef SYCL_PCOS
+
+/** \internal \returns the hyperbolic tan of \a a (coeff-wise) */
+#define SYCL_PTAN(packet_type) \
+ template <> \
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type ptan<packet_type>( \
+ const packet_type& a) { \
+ return cl::sycl::tan(a); \
+ }
+
+SYCL_PTAN(cl::sycl::cl_float4)
+SYCL_PTAN(cl::sycl::cl_double2)
+#undef SYCL_PTAN
+
+/** \internal \returns the hyperbolic sine of \a a (coeff-wise) */
+#define SYCL_PASIN(packet_type) \
+ template <> \
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type pasin<packet_type>( \
+ const packet_type& a) { \
+ return cl::sycl::asin(a); \
+ }
+
+SYCL_PASIN(cl::sycl::cl_float4)
+SYCL_PASIN(cl::sycl::cl_double2)
+#undef SYCL_PASIN
+
+/** \internal \returns the hyperbolic cosine of \a a (coeff-wise) */
+#define SYCL_PACOS(packet_type) \
+ template <> \
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type pacos<packet_type>( \
+ const packet_type& a) { \
+ return cl::sycl::acos(a); \
+ }
+
+SYCL_PACOS(cl::sycl::cl_float4)
+SYCL_PACOS(cl::sycl::cl_double2)
+#undef SYCL_PACOS
+
+/** \internal \returns the hyperbolic tan of \a a (coeff-wise) */
+#define SYCL_PATAN(packet_type) \
+ template <> \
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type patan<packet_type>( \
+ const packet_type& a) { \
+ return cl::sycl::atan(a); \
+ }
+
+SYCL_PATAN(cl::sycl::cl_float4)
+SYCL_PATAN(cl::sycl::cl_double2)
+#undef SYCL_PATAN
+
+/** \internal \returns the hyperbolic sine of \a a (coeff-wise) */
+#define SYCL_PSINH(packet_type) \
+ template <> \
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type psinh<packet_type>( \
+ const packet_type& a) { \
+ return cl::sycl::sinh(a); \
+ }
+
+SYCL_PSINH(cl::sycl::cl_float4)
+SYCL_PSINH(cl::sycl::cl_double2)
+#undef SYCL_PSINH
+
+/** \internal \returns the hyperbolic cosine of \a a (coeff-wise) */
+#define SYCL_PCOSH(packet_type) \
+ template <> \
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type pcosh<packet_type>( \
+ const packet_type& a) { \
+ return cl::sycl::cosh(a); \
+ }
+
+SYCL_PCOSH(cl::sycl::cl_float4)
+SYCL_PCOSH(cl::sycl::cl_double2)
+#undef SYCL_PCOSH
+
+/** \internal \returns the hyperbolic tan of \a a (coeff-wise) */
+#define SYCL_PTANH(packet_type) \
+ template <> \
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type ptanh<packet_type>( \
+ const packet_type& a) { \
+ return cl::sycl::tanh(a); \
+ }
+
+SYCL_PTANH(cl::sycl::cl_float4)
+SYCL_PTANH(cl::sycl::cl_double2)
+#undef SYCL_PTANH
+
+#define SYCL_PCEIL(packet_type) \
+ template <> \
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type pceil<packet_type>( \
+ const packet_type& a) { \
+ return cl::sycl::ceil(a); \
+ }
+
+SYCL_PCEIL(cl::sycl::cl_float4)
+SYCL_PCEIL(cl::sycl::cl_double2)
+#undef SYCL_PCEIL
+
+#define SYCL_PROUND(packet_type) \
+ template <> \
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type pround<packet_type>( \
+ const packet_type& a) { \
+ return cl::sycl::round(a); \
+ }
+
+SYCL_PROUND(cl::sycl::cl_float4)
+SYCL_PROUND(cl::sycl::cl_double2)
+#undef SYCL_PROUND
+
+#define SYCL_PRINT(packet_type) \
+ template <> \
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type print<packet_type>( \
+ const packet_type& a) { \
+ return cl::sycl::rint(a); \
+ }
+
+SYCL_PRINT(cl::sycl::cl_float4)
+SYCL_PRINT(cl::sycl::cl_double2)
+#undef SYCL_PRINT
+
+#define SYCL_FLOOR(packet_type) \
+ template <> \
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type pfloor<packet_type>( \
+ const packet_type& a) { \
+ return cl::sycl::floor(a); \
+ }
+
+SYCL_FLOOR(cl::sycl::cl_float4)
+SYCL_FLOOR(cl::sycl::cl_double2)
+#undef SYCL_FLOOR
+
+#define SYCL_PMIN(packet_type, expr) \
+ template <> \
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type pmin<packet_type>( \
+ const packet_type& a, const packet_type& b) { \
+ return expr; \
+ }
+
+SYCL_PMIN(cl::sycl::cl_float4, cl::sycl::fmin(a, b))
+SYCL_PMIN(cl::sycl::cl_double2, cl::sycl::fmin(a, b))
+#undef SYCL_PMIN
+
+#define SYCL_PMAX(packet_type, expr) \
+ template <> \
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type pmax<packet_type>( \
+ const packet_type& a, const packet_type& b) { \
+ return expr; \
+ }
+
+SYCL_PMAX(cl::sycl::cl_float4, cl::sycl::fmax(a, b))
+SYCL_PMAX(cl::sycl::cl_double2, cl::sycl::fmax(a, b))
+#undef SYCL_PMAX
+
+#define SYCL_PLDEXP(packet_type) \
+ template <> \
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type pldexp( \
+ const packet_type& a, const packet_type& exponent) { \
+ return cl::sycl::ldexp( \
+ a, exponent.template convert<cl::sycl::cl_int, \
+ cl::sycl::rounding_mode::automatic>()); \
+ }
+
+SYCL_PLDEXP(cl::sycl::cl_float4)
+SYCL_PLDEXP(cl::sycl::cl_double2)
+#undef SYCL_PLDEXP
+
+#endif
+} // end namespace internal
+
+} // end namespace Eigen
+
+#endif // EIGEN_MATH_FUNCTIONS_SYCL_H
diff --git a/Eigen/src/Core/arch/SYCL/PacketMath.h b/Eigen/src/Core/arch/SYCL/PacketMath.h
new file mode 100644
index 000000000..87badc076
--- /dev/null
+++ b/Eigen/src/Core/arch/SYCL/PacketMath.h
@@ -0,0 +1,670 @@
+// This file is part of Eigen, a lightweight C++ template library
+// for linear algebra.
+//
+// Mehdi Goli Codeplay Software Ltd.
+// Ralph Potter Codeplay Software Ltd.
+// Luke Iwanski Codeplay Software Ltd.
+// Contact: <eigen@codeplay.com>
+//
+// This Source Code Form is subject to the terms of the Mozilla
+// Public License v. 2.0. If a copy of the MPL was not distributed
+// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+/*****************************************************************
+ * PacketMath.h
+ *
+ * \brief:
+ * PacketMath
+ *
+ *****************************************************************/
+
+#ifndef EIGEN_PACKET_MATH_SYCL_H
+#define EIGEN_PACKET_MATH_SYCL_H
+#include <type_traits>
+namespace Eigen {
+
+namespace internal {
+#ifdef SYCL_DEVICE_ONLY
+
+#define SYCL_PLOADT_RO(address_space_target) \
+ template <typename packet_type, int Alignment> \
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE packet_type ploadt_ro( \
+ typename cl::sycl::multi_ptr< \
+ const typename unpacket_traits<packet_type>::type, \
+ cl::sycl::access::address_space::address_space_target>::pointer_t \
+ from) { \
+ typedef typename unpacket_traits<packet_type>::type scalar; \
+ typedef cl::sycl::multi_ptr< \
+ scalar, cl::sycl::access::address_space::address_space_target> \
+ multi_ptr; \
+ auto res = packet_type( \
+ static_cast<typename unpacket_traits<packet_type>::type>(0)); \
+ res.load(0, multi_ptr(const_cast<typename multi_ptr::pointer_t>(from))); \
+ return res; \
+ }
+
+SYCL_PLOADT_RO(global_space)
+SYCL_PLOADT_RO(local_space)
+#undef SYCL_PLOADT_RO
+#endif
+
+template <typename packet_type, int Alignment, typename T>
+EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE packet_type
+ploadt_ro(const Eigen::TensorSycl::internal::RangeAccess<
+ cl::sycl::access::mode::read_write, T>& from) {
+ return ploadt_ro<packet_type, Alignment>(from.get_pointer());
+}
+
+#ifdef SYCL_DEVICE_ONLY
+#define SYCL_PLOAD(address_space_target, Alignment, AlignedType) \
+ template <typename packet_type> \
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE packet_type pload##AlignedType( \
+ typename cl::sycl::multi_ptr< \
+ const typename unpacket_traits<packet_type>::type, \
+ cl::sycl::access::address_space::address_space_target>::pointer_t \
+ from) { \
+ return ploadt_ro<packet_type, Alignment>(from); \
+ }
+
+// global space
+SYCL_PLOAD(global_space, Unaligned, u)
+SYCL_PLOAD(global_space, Aligned, )
+// local space
+SYCL_PLOAD(local_space, Unaligned, u)
+SYCL_PLOAD(local_space, Aligned, )
+
+#undef SYCL_PLOAD
+#endif
+
+#define SYCL_PLOAD(Alignment, AlignedType) \
+ template <typename packet_type> \
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE packet_type pload##AlignedType( \
+ const Eigen::TensorSycl::internal::RangeAccess< \
+ cl::sycl::access::mode::read_write, \
+ typename unpacket_traits<packet_type>::type> \
+ from) { \
+ return ploadt_ro<packet_type, Alignment>(from); \
+ }
+SYCL_PLOAD(Unaligned, u)
+SYCL_PLOAD(Aligned, )
+#undef SYCL_PLOAD
+
+#ifdef SYCL_DEVICE_ONLY
+/** \internal \returns a packet version of \a *from.
+ * The pointer \a from must be aligned on a \a Alignment bytes boundary. */
+#define SYCL_PLOADT(address_space_target) \
+ template <typename packet_type, int Alignment> \
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE packet_type ploadt( \
+ typename cl::sycl::multi_ptr< \
+ const typename unpacket_traits<packet_type>::type, \
+ cl::sycl::access::address_space::address_space_target>::pointer_t \
+ from) { \
+ if (Alignment >= unpacket_traits<packet_type>::alignment) \
+ return pload<packet_type>(from); \
+ else \
+ return ploadu<packet_type>(from); \
+ }
+
+// global space
+SYCL_PLOADT(global_space)
+// local space
+SYCL_PLOADT(local_space)
+#undef SYCL_PLOADT
+#endif
+
+template <typename packet_type, int Alignment>
+EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE packet_type
+ploadt(const Eigen::TensorSycl::internal::RangeAccess<
+ cl::sycl::access::mode::read_write,
+ typename unpacket_traits<packet_type>::type>& from) {
+ return ploadt<packet_type, Alignment>(from.get_pointer());
+}
+#ifdef SYCL_DEVICE_ONLY
+
+// private_space
+#define SYCL_PLOADT_RO_SPECIAL(packet_type, Alignment) \
+ template <> \
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE packet_type \
+ ploadt_ro<packet_type, Alignment>( \
+ const typename unpacket_traits<packet_type>::type* from) { \
+ typedef typename unpacket_traits<packet_type>::type scalar; \
+ auto res = packet_type(static_cast<scalar>(0)); \
+ res.template load<cl::sycl::access::address_space::private_space>( \
+ 0, const_cast<scalar*>(from)); \
+ return res; \
+ }
+
+SYCL_PLOADT_RO_SPECIAL(cl::sycl::cl_float4, Aligned)
+SYCL_PLOADT_RO_SPECIAL(cl::sycl::cl_double2, Aligned)
+SYCL_PLOADT_RO_SPECIAL(cl::sycl::cl_float4, Unaligned)
+SYCL_PLOADT_RO_SPECIAL(cl::sycl::cl_double2, Unaligned)
+
+#define SYCL_PLOAD_SPECIAL(packet_type, alignment_type) \
+ template <> \
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE packet_type pload##alignment_type( \
+ const typename unpacket_traits<packet_type>::type* from) { \
+ typedef typename unpacket_traits<packet_type>::type scalar; \
+ auto res = packet_type(static_cast<scalar>(0)); \
+ res.template load<cl::sycl::access::address_space::private_space>( \
+ 0, const_cast<scalar*>(from)); \
+ return res; \
+ }
+SYCL_PLOAD_SPECIAL(cl::sycl::cl_float4, )
+SYCL_PLOAD_SPECIAL(cl::sycl::cl_double2, )
+SYCL_PLOAD_SPECIAL(cl::sycl::cl_float4, u)
+SYCL_PLOAD_SPECIAL(cl::sycl::cl_double2, u)
+
+#undef SYCL_PLOAD_SPECIAL
+
+#define SYCL_PSTORE(scalar, packet_type, address_space_target, alignment) \
+ template <> \
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void pstore##alignment( \
+ typename cl::sycl::multi_ptr< \
+ scalar, \
+ cl::sycl::access::address_space::address_space_target>::pointer_t \
+ to, \
+ const packet_type& from) { \
+ typedef cl::sycl::multi_ptr< \
+ scalar, cl::sycl::access::address_space::address_space_target> \
+ multi_ptr; \
+ from.store(0, multi_ptr(to)); \
+ }
+
+// global space
+SYCL_PSTORE(float, cl::sycl::cl_float4, global_space, )
+SYCL_PSTORE(float, cl::sycl::cl_float4, global_space, u)
+SYCL_PSTORE(double, cl::sycl::cl_double2, global_space, )
+SYCL_PSTORE(double, cl::sycl::cl_double2, global_space, u)
+SYCL_PSTORE(float, cl::sycl::cl_float4, local_space, )
+SYCL_PSTORE(float, cl::sycl::cl_float4, local_space, u)
+SYCL_PSTORE(double, cl::sycl::cl_double2, local_space, )
+SYCL_PSTORE(double, cl::sycl::cl_double2, local_space, u)
+
+SYCL_PSTORE(float, cl::sycl::cl_float4, private_space, )
+SYCL_PSTORE(float, cl::sycl::cl_float4, private_space, u)
+SYCL_PSTORE(double, cl::sycl::cl_double2, private_space, )
+SYCL_PSTORE(double, cl::sycl::cl_double2, private_space, u)
+#undef SYCL_PSTORE
+
+#define SYCL_PSTORE_T(address_space_target) \
+ template <typename scalar, typename packet_type, int Alignment> \
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void pstoret( \
+ typename cl::sycl::multi_ptr< \
+ scalar, \
+ cl::sycl::access::address_space::address_space_target>::pointer_t \
+ to, \
+ const packet_type& from) { \
+ if (Alignment) \
+ pstore(to, from); \
+ else \
+ pstoreu(to, from); \
+ }
+
+SYCL_PSTORE_T(global_space)
+
+SYCL_PSTORE_T(local_space)
+
+#undef SYCL_PSTORE_T
+
+#define SYCL_PSET1(packet_type) \
+ template <> \
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE packet_type pset1<packet_type>( \
+ const typename unpacket_traits<packet_type>::type& from) { \
+ return packet_type(from); \
+ }
+
+// global space
+SYCL_PSET1(cl::sycl::cl_float4)
+SYCL_PSET1(cl::sycl::cl_double2)
+
+#undef SYCL_PSET1
+
+template <typename packet_type>
+struct get_base_packet {
+ template <typename sycl_multi_pointer>
+ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type
+ get_ploaddup(sycl_multi_pointer) {}
+
+ template <typename sycl_multi_pointer>
+ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type
+ get_pgather(sycl_multi_pointer, Index) {}
+};
+
+template <>
+struct get_base_packet<cl::sycl::cl_float4> {
+ template <typename sycl_multi_pointer>
+ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE cl::sycl::cl_float4 get_ploaddup(
+ sycl_multi_pointer from) {
+ return cl::sycl::cl_float4(from[0], from[0], from[1], from[1]);
+ }
+ template <typename sycl_multi_pointer>
+ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE cl::sycl::cl_float4 get_pgather(
+ sycl_multi_pointer from, Index stride) {
+ return cl::sycl::cl_float4(from[0 * stride], from[1 * stride],
+ from[2 * stride], from[3 * stride]);
+ }
+
+ template <typename sycl_multi_pointer>
+ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void set_pscatter(
+ sycl_multi_pointer to, const cl::sycl::cl_float4& from, Index stride) {
+ auto tmp = stride;
+ to[0] = from.x();
+ to[tmp] = from.y();
+ to[tmp += stride] = from.z();
+ to[tmp += stride] = from.w();
+ }
+ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE cl::sycl::cl_float4 set_plset(
+ const float& a) {
+ return cl::sycl::cl_float4(static_cast<float>(a), static_cast<float>(a + 1),
+ static_cast<float>(a + 2),
+ static_cast<float>(a + 3));
+ }
+};
+
+template <>
+struct get_base_packet<cl::sycl::cl_double2> {
+ template <typename sycl_multi_pointer>
+ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE cl::sycl::cl_double2
+ get_ploaddup(const sycl_multi_pointer from) {
+ return cl::sycl::cl_double2(from[0], from[0]);
+ }
+
+ template <typename sycl_multi_pointer, typename Index>
+ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE cl::sycl::cl_double2 get_pgather(
+ const sycl_multi_pointer from, Index stride) {
+ return cl::sycl::cl_double2(from[0 * stride], from[1 * stride]);
+ }
+
+ template <typename sycl_multi_pointer>
+ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void set_pscatter(
+ sycl_multi_pointer to, const cl::sycl::cl_double2& from, Index stride) {
+ to[0] = from.x();
+ to[stride] = from.y();
+ }
+
+ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE cl::sycl::cl_double2 set_plset(
+ const double& a) {
+ return cl::sycl::cl_double2(static_cast<double>(a),
+ static_cast<double>(a + 1));
+ }
+};
+
+#define SYCL_PLOAD_DUP(address_space_target) \
+ template <typename packet_type> \
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type ploaddup( \
+ typename cl::sycl::multi_ptr< \
+ const typename unpacket_traits<packet_type>::type, \
+ cl::sycl::access::address_space::address_space_target>::pointer_t \
+ from) { \
+ return get_base_packet<packet_type>::get_ploaddup(from); \
+ }
+
+// global space
+SYCL_PLOAD_DUP(global_space)
+// local_space
+SYCL_PLOAD_DUP(local_space)
+#undef SYCL_PLOAD_DUP
+
+#define SYCL_PLOAD_DUP_SPECILIZE(packet_type) \
+ template <> \
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type ploaddup<packet_type>( \
+ const typename unpacket_traits<packet_type>::type* from) { \
+ return get_base_packet<packet_type>::get_ploaddup(from); \
+ }
+
+SYCL_PLOAD_DUP_SPECILIZE(cl::sycl::cl_float4)
+SYCL_PLOAD_DUP_SPECILIZE(cl::sycl::cl_double2)
+
+#undef SYCL_PLOAD_DUP_SPECILIZE
+
+#define SYCL_PLSET(packet_type) \
+ template <> \
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE packet_type plset<packet_type>( \
+ const typename unpacket_traits<packet_type>::type& a) { \
+ return get_base_packet<packet_type>::set_plset(a); \
+ }
+
+SYCL_PLSET(cl::sycl::cl_float4)
+SYCL_PLSET(cl::sycl::cl_double2)
+
+#undef SYCL_PLSET
+
+#define SYCL_PGATHER(address_space_target) \
+ template <typename Scalar, typename packet_type> \
+ EIGEN_DEVICE_FUNC inline packet_type pgather( \
+ typename cl::sycl::multi_ptr< \
+ const typename unpacket_traits<packet_type>::type, \
+ cl::sycl::access::address_space::address_space_target>::pointer_t \
+ from, \
+ Index stride) { \
+ return get_base_packet<packet_type>::get_pgather(from, stride); \
+ }
+
+// global space
+SYCL_PGATHER(global_space)
+// local space
+SYCL_PGATHER(local_space)
+
+#undef SYCL_PGATHER
+
+#define SYCL_PGATHER_SPECILIZE(scalar, packet_type) \
+ template <> \
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type \
+ pgather<scalar, packet_type>( \
+ const typename unpacket_traits<packet_type>::type* from, Index stride) { \
+ return get_base_packet<packet_type>::get_pgather(from, stride); \
+ }
+
+SYCL_PGATHER_SPECILIZE(float, cl::sycl::cl_float4)
+SYCL_PGATHER_SPECILIZE(double, cl::sycl::cl_double2)
+
+#undef SYCL_PGATHER_SPECILIZE
+
+#define SYCL_PSCATTER(address_space_target) \
+ template <typename Scalar, typename packet_type> \
+ EIGEN_DEVICE_FUNC inline void pscatter( \
+ typename cl::sycl::multi_ptr< \
+ typename unpacket_traits<packet_type>::type, \
+ cl::sycl::access::address_space::address_space_target>::pointer_t \
+ to, \
+ const packet_type& from, Index stride) { \
+ get_base_packet<packet_type>::set_pscatter(to, from, stride); \
+ }
+
+// global space
+SYCL_PSCATTER(global_space)
+// local space
+SYCL_PSCATTER(local_space)
+
+#undef SYCL_PSCATTER
+
+#define SYCL_PSCATTER_SPECILIZE(scalar, packet_type) \
+ template <> \
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pscatter<scalar, packet_type>( \
+ typename unpacket_traits<packet_type>::type * to, \
+ const packet_type& from, Index stride) { \
+ get_base_packet<packet_type>::set_pscatter(to, from, stride); \
+ }
+
+SYCL_PSCATTER_SPECILIZE(float, cl::sycl::cl_float4)
+SYCL_PSCATTER_SPECILIZE(double, cl::sycl::cl_double2)
+
+#undef SYCL_PSCATTER_SPECILIZE
+
+#define SYCL_PMAD(packet_type) \
+ template <> \
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE packet_type pmadd( \
+ const packet_type& a, const packet_type& b, const packet_type& c) { \
+ return cl::sycl::mad(a, b, c); \
+ }
+
+SYCL_PMAD(cl::sycl::cl_float4)
+SYCL_PMAD(cl::sycl::cl_double2)
+#undef SYCL_PMAD
+
+template <>
+EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE float pfirst<cl::sycl::cl_float4>(
+ const cl::sycl::cl_float4& a) {
+ return a.x();
+}
+template <>
+EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE double pfirst<cl::sycl::cl_double2>(
+ const cl::sycl::cl_double2& a) {
+ return a.x();
+}
+
+template <>
+EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE float predux<cl::sycl::cl_float4>(
+ const cl::sycl::cl_float4& a) {
+ return a.x() + a.y() + a.z() + a.w();
+}
+
+template <>
+EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE double predux<cl::sycl::cl_double2>(
+ const cl::sycl::cl_double2& a) {
+ return a.x() + a.y();
+}
+
+template <>
+EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE float predux_max<cl::sycl::cl_float4>(
+ const cl::sycl::cl_float4& a) {
+ return cl::sycl::fmax(cl::sycl::fmax(a.x(), a.y()),
+ cl::sycl::fmax(a.z(), a.w()));
+}
+template <>
+EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE double predux_max<cl::sycl::cl_double2>(
+ const cl::sycl::cl_double2& a) {
+ return cl::sycl::fmax(a.x(), a.y());
+}
+
+template <>
+EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE float predux_min<cl::sycl::cl_float4>(
+ const cl::sycl::cl_float4& a) {
+ return cl::sycl::fmin(cl::sycl::fmin(a.x(), a.y()),
+ cl::sycl::fmin(a.z(), a.w()));
+}
+template <>
+EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE double predux_min<cl::sycl::cl_double2>(
+ const cl::sycl::cl_double2& a) {
+ return cl::sycl::fmin(a.x(), a.y());
+}
+
+template <>
+EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE float predux_mul<cl::sycl::cl_float4>(
+ const cl::sycl::cl_float4& a) {
+ return a.x() * a.y() * a.z() * a.w();
+}
+template <>
+EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE double predux_mul<cl::sycl::cl_double2>(
+ const cl::sycl::cl_double2& a) {
+ return a.x() * a.y();
+}
+
+template <>
+EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE cl::sycl::cl_float4
+pabs<cl::sycl::cl_float4>(const cl::sycl::cl_float4& a) {
+ return cl::sycl::cl_float4(cl::sycl::fabs(a.x()), cl::sycl::fabs(a.y()),
+ cl::sycl::fabs(a.z()), cl::sycl::fabs(a.w()));
+}
+template <>
+EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE cl::sycl::cl_double2
+pabs<cl::sycl::cl_double2>(const cl::sycl::cl_double2& a) {
+ return cl::sycl::cl_double2(cl::sycl::fabs(a.x()), cl::sycl::fabs(a.y()));
+}
+
+template <typename Packet>
+EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet sycl_pcmp_le(const Packet &a,
+ const Packet &b) {
+ return ((a <= b)
+ .template convert<typename unpacket_traits<Packet>::type,
+ cl::sycl::rounding_mode::automatic>());
+}
+
+template <typename Packet>
+EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet sycl_pcmp_lt(const Packet &a,
+ const Packet &b) {
+ return ((a < b)
+ .template convert<typename unpacket_traits<Packet>::type,
+ cl::sycl::rounding_mode::automatic>());
+}
+
+template <typename Packet>
+EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet sycl_pcmp_eq(const Packet &a,
+ const Packet &b) {
+ return ((a == b)
+ .template convert<typename unpacket_traits<Packet>::type,
+ cl::sycl::rounding_mode::automatic>());
+}
+
+#define SYCL_PCMP(OP, TYPE) \
+ template <> \
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE TYPE pcmp_##OP<TYPE>(const TYPE &a, \
+ const TYPE &b) { \
+ return sycl_pcmp_##OP<TYPE>(a, b); \
+ }
+
+SYCL_PCMP(le, cl::sycl::cl_float4)
+SYCL_PCMP(lt, cl::sycl::cl_float4)
+SYCL_PCMP(eq, cl::sycl::cl_float4)
+SYCL_PCMP(le, cl::sycl::cl_double2)
+SYCL_PCMP(lt, cl::sycl::cl_double2)
+SYCL_PCMP(eq, cl::sycl::cl_double2)
+#undef SYCL_PCMP
+
+template <typename T> struct convert_to_integer;
+
+template <> struct convert_to_integer<float> {
+ using type = std::int32_t;
+ using packet_type = cl::sycl::cl_int4;
+};
+template <> struct convert_to_integer<double> {
+ using type = std::int64_t;
+ using packet_type = cl::sycl::cl_long2;
+};
+
+template <typename PacketIn>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE typename convert_to_integer<
+ typename unpacket_traits<PacketIn>::type>::packet_type
+vector_as_int(const PacketIn &p) {
+ return (
+ p.template convert<typename convert_to_integer<
+ typename unpacket_traits<PacketIn>::type>::type,
+ cl::sycl::rounding_mode::automatic>());
+}
+
+template <typename packetOut, typename PacketIn>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packetOut
+convert_vector(const PacketIn &p) {
+ return (p.template convert<typename unpacket_traits<packetOut>::type,
+ cl::sycl::rounding_mode::automatic>());
+}
+
+#define SYCL_PAND(TYPE) \
+ template <> \
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TYPE pand<TYPE>(const TYPE &a, \
+ const TYPE &b) { \
+ return convert_vector<TYPE>(vector_as_int(a) & vector_as_int(b)); \
+ }
+SYCL_PAND(cl::sycl::cl_float4)
+SYCL_PAND(cl::sycl::cl_double2)
+#undef SYCL_PAND
+
+#define SYCL_POR(TYPE) \
+ template <> \
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TYPE por<TYPE>(const TYPE &a, \
+ const TYPE &b) { \
+ return convert_vector<TYPE>(vector_as_int(a) | vector_as_int(b)); \
+ }
+
+SYCL_POR(cl::sycl::cl_float4)
+SYCL_POR(cl::sycl::cl_double2)
+#undef SYCL_POR
+
+#define SYCL_PXOR(TYPE) \
+ template <> \
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TYPE pxor<TYPE>(const TYPE &a, \
+ const TYPE &b) { \
+ return convert_vector<TYPE>(vector_as_int(a) ^ vector_as_int(b)); \
+ }
+
+SYCL_PXOR(cl::sycl::cl_float4)
+SYCL_PXOR(cl::sycl::cl_double2)
+#undef SYCL_PXOR
+
+#define SYCL_PANDNOT(TYPE) \
+ template <> \
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TYPE pandnot<TYPE>(const TYPE &a, \
+ const TYPE &b) { \
+ return convert_vector<TYPE>(vector_as_int(a) & (~vector_as_int(b))); \
+ }
+SYCL_PANDNOT(cl::sycl::cl_float4)
+SYCL_PANDNOT(cl::sycl::cl_double2)
+#undef SYCL_PANDNOT
+
+EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void ptranspose(
+ PacketBlock<cl::sycl::cl_float4, 4>& kernel) {
+ float tmp = kernel.packet[0].y();
+ kernel.packet[0].y() = kernel.packet[1].x();
+ kernel.packet[1].x() = tmp;
+
+ tmp = kernel.packet[0].z();
+ kernel.packet[0].z() = kernel.packet[2].x();
+ kernel.packet[2].x() = tmp;
+
+ tmp = kernel.packet[0].w();
+ kernel.packet[0].w() = kernel.packet[3].x();
+ kernel.packet[3].x() = tmp;
+
+ tmp = kernel.packet[1].z();
+ kernel.packet[1].z() = kernel.packet[2].y();
+ kernel.packet[2].y() = tmp;
+
+ tmp = kernel.packet[1].w();
+ kernel.packet[1].w() = kernel.packet[3].y();
+ kernel.packet[3].y() = tmp;
+
+ tmp = kernel.packet[2].w();
+ kernel.packet[2].w() = kernel.packet[3].z();
+ kernel.packet[3].z() = tmp;
+}
+
+EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void ptranspose(
+ PacketBlock<cl::sycl::cl_double2, 2>& kernel) {
+ double tmp = kernel.packet[0].y();
+ kernel.packet[0].y() = kernel.packet[1].x();
+ kernel.packet[1].x() = tmp;
+}
+
+template <>
+EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE cl::sycl::cl_float4 pblend(
+ const Selector<unpacket_traits<cl::sycl::cl_float4>::size>& ifPacket,
+ const cl::sycl::cl_float4& thenPacket,
+ const cl::sycl::cl_float4& elsePacket) {
+ cl::sycl::cl_int4 condition(
+ ifPacket.select[0] ? 0 : -1, ifPacket.select[1] ? 0 : -1,
+ ifPacket.select[2] ? 0 : -1, ifPacket.select[3] ? 0 : -1);
+ return cl::sycl::select(thenPacket, elsePacket, condition);
+}
+
+template <>
+inline cl::sycl::cl_double2 pblend(
+ const Selector<unpacket_traits<cl::sycl::cl_double2>::size>& ifPacket,
+ const cl::sycl::cl_double2& thenPacket,
+ const cl::sycl::cl_double2& elsePacket) {
+ cl::sycl::cl_long2 condition(ifPacket.select[0] ? 0 : -1,
+ ifPacket.select[1] ? 0 : -1);
+ return cl::sycl::select(thenPacket, elsePacket, condition);
+}
+#endif // SYCL_DEVICE_ONLY
+
+#define SYCL_PSTORE(alignment) \
+ template <typename packet_type> \
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void pstore##alignment( \
+ const Eigen::TensorSycl::internal::RangeAccess< \
+ cl::sycl::access::mode::read_write, \
+ typename unpacket_traits<packet_type>::type>& to, \
+ const packet_type& from) { \
+ pstore##alignment(to.get_pointer(), from); \
+ }
+
+// global space
+SYCL_PSTORE()
+SYCL_PSTORE(u)
+
+#undef SYCL_PSTORE
+
+template <typename scalar, typename packet_type, int Alignment>
+EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void pstoret(
+ Eigen::TensorSycl::internal::RangeAccess<
+ cl::sycl::access::mode::read_write,
+ typename unpacket_traits<packet_type>::type>
+ to,
+ const packet_type& from) {
+ pstoret<scalar, packet_type, Alignment>(to.get_pointer(), from);
+}
+
+} // end namespace internal
+
+} // end namespace Eigen
+
+#endif // EIGEN_PACKET_MATH_SYCL_H
diff --git a/Eigen/src/Core/arch/SYCL/SyclMemoryModel.h b/Eigen/src/Core/arch/SYCL/SyclMemoryModel.h
new file mode 100644
index 000000000..f81e59db5
--- /dev/null
+++ b/Eigen/src/Core/arch/SYCL/SyclMemoryModel.h
@@ -0,0 +1,694 @@
+/***************************************************************************
+ * Copyright (C) 2017 Codeplay Software Limited
+ * This Source Code Form is subject to the terms of the Mozilla
+ * Public License v. 2.0. If a copy of the MPL was not distributed
+ * with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
+ *
+ *
+ * SyclMemoryModel.h
+ *
+ * Description:
+ * Interface for SYCL buffers to behave as a non-dereferenceable pointer
+ * Interface for Placeholder accessor to behave as a pointer on both host
+ * and device
+ *
+ * Authors:
+ *
+ * Ruyman Reyes Codeplay Software Ltd.
+ * Mehdi Goli Codeplay Software Ltd.
+ * Vanya Yaneva Codeplay Software Ltd.
+ *
+ **************************************************************************/
+
+#if defined(EIGEN_USE_SYCL) && \
+ !defined(EIGEN_CXX11_TENSOR_TENSOR_SYCL_STORAGE_MEMORY_H)
+#define EIGEN_CXX11_TENSOR_TENSOR_SYCL_STORAGE_MEMORY_H
+
+#include <CL/sycl.hpp>
+#ifdef EIGEN_EXCEPTIONS
+#include <stdexcept>
+#endif
+#include <cstddef>
+#include <queue>
+#include <set>
+#include <unordered_map>
+
+namespace Eigen {
+namespace TensorSycl {
+namespace internal {
+
+using sycl_acc_target = cl::sycl::access::target;
+using sycl_acc_mode = cl::sycl::access::mode;
+
+/**
+ * Default values for template arguments
+ */
+using buffer_data_type_t = uint8_t;
+const sycl_acc_target default_acc_target = sycl_acc_target::global_buffer;
+const sycl_acc_mode default_acc_mode = sycl_acc_mode::read_write;
+
+/**
+ * PointerMapper
+ * Associates fake pointers with buffers.
+ *
+ */
+class PointerMapper {
+ public:
+ using base_ptr_t = std::intptr_t;
+
+ /* Structure of a virtual pointer
+ *
+ * |================================================|
+ * | POINTER ADDRESS |
+ * |================================================|
+ */
+ struct virtual_pointer_t {
+ /* Type for the pointers
+ */
+ base_ptr_t m_contents;
+
+ /** Conversions from virtual_pointer_t to
+ * void * should just reinterpret_cast the integer number
+ */
+ operator void *() const { return reinterpret_cast<void *>(m_contents); }
+
+ /**
+ * Convert back to the integer number.
+ */
+ operator base_ptr_t() const { return m_contents; }
+
+ /**
+ * Add a certain value to the pointer to create a
+ * new pointer to that offset
+ */
+ virtual_pointer_t operator+(size_t off) { return m_contents + off; }
+
+ /* Numerical order for sorting pointers in containers. */
+ bool operator<(virtual_pointer_t rhs) const {
+ return (static_cast<base_ptr_t>(m_contents) <
+ static_cast<base_ptr_t>(rhs.m_contents));
+ }
+
+ bool operator>(virtual_pointer_t rhs) const {
+ return (static_cast<base_ptr_t>(m_contents) >
+ static_cast<base_ptr_t>(rhs.m_contents));
+ }
+
+ /**
+ * Numerical order for sorting pointers in containers
+ */
+ bool operator==(virtual_pointer_t rhs) const {
+ return (static_cast<base_ptr_t>(m_contents) ==
+ static_cast<base_ptr_t>(rhs.m_contents));
+ }
+
+ /**
+ * Simple forward to the equality overload.
+ */
+ bool operator!=(virtual_pointer_t rhs) const {
+ return !(this->operator==(rhs));
+ }
+
+ /**
+ * Converts a void * into a virtual pointer structure.
+ * Note that this will only work if the void * was
+ * already a virtual_pointer_t, but we have no way of
+ * checking
+ */
+ virtual_pointer_t(const void *ptr)
+ : m_contents(reinterpret_cast<base_ptr_t>(ptr)){};
+
+ /**
+ * Creates a virtual_pointer_t from the given integer
+ * number
+ */
+ virtual_pointer_t(base_ptr_t u) : m_contents(u){};
+ };
+
+ /* Definition of a null pointer
+ */
+ const virtual_pointer_t null_virtual_ptr = nullptr;
+
+ /**
+ * Whether if a pointer is null or not.
+ * A pointer is nullptr if the value is of null_virtual_ptr
+ */
+ static inline bool is_nullptr(virtual_pointer_t ptr) {
+ return (static_cast<void *>(ptr) == nullptr);
+ }
+
+ /* basic type for all buffers
+ */
+ using buffer_t = cl::sycl::buffer_mem;
+
+ /**
+ * Node that stores information about a device allocation.
+ * Nodes are sorted by size to organise a free list of nodes
+ * that can be recovered.
+ */
+ struct pMapNode_t {
+ buffer_t m_buffer;
+ size_t m_size;
+ bool m_free;
+
+ pMapNode_t(buffer_t b, size_t size, bool f)
+ : m_buffer{b}, m_size{size}, m_free{f} {
+ m_buffer.set_final_data(nullptr);
+ }
+
+ bool operator<=(const pMapNode_t &rhs) { return (m_size <= rhs.m_size); }
+ };
+
+ /** Storage of the pointer / buffer tree
+ */
+ using pointerMap_t = std::map<virtual_pointer_t, pMapNode_t>;
+
+ /**
+ * Obtain the insertion point in the pointer map for
+ * a pointer of the given size.
+ * \param requiredSize Size attemted to reclaim
+ */
+ typename pointerMap_t::iterator get_insertion_point(size_t requiredSize) {
+ typename pointerMap_t::iterator retVal;
+ bool reuse = false;
+ if (!m_freeList.empty()) {
+ // try to re-use an existing block
+ for (auto freeElem : m_freeList) {
+ if (freeElem->second.m_size >= requiredSize) {
+ retVal = freeElem;
+ reuse = true;
+ // Element is not going to be free anymore
+ m_freeList.erase(freeElem);
+ break;
+ }
+ }
+ }
+ if (!reuse) {
+ retVal = std::prev(m_pointerMap.end());
+ }
+ return retVal;
+ }
+
+ /**
+ * Returns an iterator to the node that stores the information
+ * of the given virtual pointer from the given pointer map structure.
+ * If pointer is not found, throws std::out_of_range.
+ * If the pointer map structure is empty, throws std::out_of_range
+ *
+ * \param pMap the pointerMap_t structure storing all the pointers
+ * \param virtual_pointer_ptr The virtual pointer to obtain the node of
+ * \throws std::out:of_range if the pointer is not found or pMap is empty
+ */
+ typename pointerMap_t::iterator get_node(const virtual_pointer_t ptr) {
+ if (this->count() == 0) {
+ m_pointerMap.clear();
+ EIGEN_THROW_X(std::out_of_range("There are no pointers allocated\n"));
+
+ }
+ if (is_nullptr(ptr)) {
+ m_pointerMap.clear();
+ EIGEN_THROW_X(std::out_of_range("Cannot access null pointer\n"));
+ }
+ // The previous element to the lower bound is the node that
+ // holds this memory address
+ auto node = m_pointerMap.lower_bound(ptr);
+ // If the value of the pointer is not the one of the node
+ // then we return the previous one
+ if (node == std::end(m_pointerMap)) {
+ --node;
+ } else if (node->first != ptr) {
+ if (node == std::begin(m_pointerMap)) {
+ m_pointerMap.clear();
+ EIGEN_THROW_X(
+ std::out_of_range("The pointer is not registered in the map\n"));
+
+ }
+ --node;
+ }
+
+ return node;
+ }
+
+ /* get_buffer.
+ * Returns a buffer from the map using the pointer address
+ */
+ template <typename buffer_data_type = buffer_data_type_t>
+ cl::sycl::buffer<buffer_data_type, 1> get_buffer(
+ const virtual_pointer_t ptr) {
+ using sycl_buffer_t = cl::sycl::buffer<buffer_data_type, 1>;
+
+ // get_node() returns a `buffer_mem`, so we need to cast it to a `buffer<>`.
+ // We can do this without the `buffer_mem` being a pointer, as we
+ // only declare member variables in the base class (`buffer_mem`) and not in
+ // the child class (`buffer<>).
+ auto node = get_node(ptr);
+ eigen_assert(node->first == ptr || node->first < ptr);
+ eigen_assert(ptr < static_cast<virtual_pointer_t>(node->second.m_size +
+ node->first));
+ return *(static_cast<sycl_buffer_t *>(&node->second.m_buffer));
+ }
+
+ /**
+ * @brief Returns an accessor to the buffer of the given virtual pointer
+ * @param accessMode
+ * @param accessTarget
+ * @param ptr The virtual pointer
+ */
+ template <sycl_acc_mode access_mode = default_acc_mode,
+ sycl_acc_target access_target = default_acc_target,
+ typename buffer_data_type = buffer_data_type_t>
+ cl::sycl::accessor<buffer_data_type, 1, access_mode, access_target>
+ get_access(const virtual_pointer_t ptr) {
+ auto buf = get_buffer<buffer_data_type>(ptr);
+ return buf.template get_access<access_mode, access_target>();
+ }
+
+ /**
+ * @brief Returns an accessor to the buffer of the given virtual pointer
+ * in the given command group scope
+ * @param accessMode
+ * @param accessTarget
+ * @param ptr The virtual pointer
+ * @param cgh Reference to the command group scope
+ */
+ template <sycl_acc_mode access_mode = default_acc_mode,
+ sycl_acc_target access_target = default_acc_target,
+ typename buffer_data_type = buffer_data_type_t>
+ cl::sycl::accessor<buffer_data_type, 1, access_mode, access_target>
+ get_access(const virtual_pointer_t ptr, cl::sycl::handler &cgh) {
+ auto buf = get_buffer<buffer_data_type>(ptr);
+ return buf.template get_access<access_mode, access_target>(cgh);
+ }
+
+ /*
+ * Returns the offset from the base address of this pointer.
+ */
+ inline std::ptrdiff_t get_offset(const virtual_pointer_t ptr) {
+ // The previous element to the lower bound is the node that
+ // holds this memory address
+ auto node = get_node(ptr);
+ auto start = node->first;
+ eigen_assert(start == ptr || start < ptr);
+ eigen_assert(ptr < start + node->second.m_size);
+ return (ptr - start);
+ }
+
+ /*
+ * Returns the number of elements by which the given pointer is offset from
+ * the base address.
+ */
+ template <typename buffer_data_type>
+ inline size_t get_element_offset(const virtual_pointer_t ptr) {
+ return get_offset(ptr) / sizeof(buffer_data_type);
+ }
+
+ /**
+ * Constructs the PointerMapper structure.
+ */
+ PointerMapper(base_ptr_t baseAddress = 4096)
+ : m_pointerMap{}, m_freeList{}, m_baseAddress{baseAddress} {
+ if (m_baseAddress == 0) {
+ EIGEN_THROW_X(std::invalid_argument("Base address cannot be zero\n"));
+ }
+ };
+
+ /**
+ * PointerMapper cannot be copied or moved
+ */
+ PointerMapper(const PointerMapper &) = delete;
+
+ /**
+ * Empty the pointer list
+ */
+ inline void clear() {
+ m_freeList.clear();
+ m_pointerMap.clear();
+ }
+
+ /* add_pointer.
+ * Adds an existing pointer to the map and returns the virtual pointer id.
+ */
+ inline virtual_pointer_t add_pointer(const buffer_t &b) {
+ return add_pointer_impl(b);
+ }
+
+ /* add_pointer.
+ * Adds a pointer to the map and returns the virtual pointer id.
+ */
+ inline virtual_pointer_t add_pointer(buffer_t &&b) {
+ return add_pointer_impl(b);
+ }
+
+ /**
+ * @brief Fuses the given node with the previous nodes in the
+ * pointer map if they are free
+ *
+ * @param node A reference to the free node to be fused
+ */
+ void fuse_forward(typename pointerMap_t::iterator &node) {
+ while (node != std::prev(m_pointerMap.end())) {
+ // if following node is free
+ // remove it and extend the current node with its size
+ auto fwd_node = std::next(node);
+ if (!fwd_node->second.m_free) {
+ break;
+ }
+ auto fwd_size = fwd_node->second.m_size;
+ m_freeList.erase(fwd_node);
+ m_pointerMap.erase(fwd_node);
+
+ node->second.m_size += fwd_size;
+ }
+ }
+
+ /**
+ * @brief Fuses the given node with the following nodes in the
+ * pointer map if they are free
+ *
+ * @param node A reference to the free node to be fused
+ */
+ void fuse_backward(typename pointerMap_t::iterator &node) {
+ while (node != m_pointerMap.begin()) {
+ // if previous node is free, extend it
+ // with the size of the current one
+ auto prev_node = std::prev(node);
+ if (!prev_node->second.m_free) {
+ break;
+ }
+ prev_node->second.m_size += node->second.m_size;
+
+ // remove the current node
+ m_freeList.erase(node);
+ m_pointerMap.erase(node);
+
+ // point to the previous node
+ node = prev_node;
+ }
+ }
+
+ /* remove_pointer.
+ * Removes the given pointer from the map.
+ * The pointer is allowed to be reused only if ReUse if true.
+ */
+ template <bool ReUse = true>
+ void remove_pointer(const virtual_pointer_t ptr) {
+ if (is_nullptr(ptr)) {
+ return;
+ }
+ auto node = this->get_node(ptr);
+
+ node->second.m_free = true;
+ m_freeList.emplace(node);
+
+ // Fuse the node
+ // with free nodes before and after it
+ fuse_forward(node);
+ fuse_backward(node);
+
+ // If after fusing the node is the last one
+ // simply remove it (since it is free)
+ if (node == std::prev(m_pointerMap.end())) {
+ m_freeList.erase(node);
+ m_pointerMap.erase(node);
+ }
+ }
+
+ /* count.
+ * Return the number of active pointers (i.e, pointers that
+ * have been malloc but not freed).
+ */
+ size_t count() const { return (m_pointerMap.size() - m_freeList.size()); }
+
+ private:
+ /* add_pointer_impl.
+ * Adds a pointer to the map and returns the virtual pointer id.
+ * BufferT is either a const buffer_t& or a buffer_t&&.
+ */
+ template <class BufferT>
+ virtual_pointer_t add_pointer_impl(BufferT b) {
+ virtual_pointer_t retVal = nullptr;
+ size_t bufSize = b.get_count();
+ pMapNode_t p{b, bufSize, false};
+ // If this is the first pointer:
+ if (m_pointerMap.empty()) {
+ virtual_pointer_t initialVal{m_baseAddress};
+ m_pointerMap.emplace(initialVal, p);
+ return initialVal;
+ }
+
+ auto lastElemIter = get_insertion_point(bufSize);
+ // We are recovering an existing free node
+ if (lastElemIter->second.m_free) {
+ lastElemIter->second.m_buffer = b;
+ lastElemIter->second.m_free = false;
+
+ // If the recovered node is bigger than the inserted one
+ // add a new free node with the remaining space
+ if (lastElemIter->second.m_size > bufSize) {
+ // create a new node with the remaining space
+ auto remainingSize = lastElemIter->second.m_size - bufSize;
+ pMapNode_t p2{b, remainingSize, true};
+
+ // update size of the current node
+ lastElemIter->second.m_size = bufSize;
+
+ // add the new free node
+ auto newFreePtr = lastElemIter->first + bufSize;
+ auto freeNode = m_pointerMap.emplace(newFreePtr, p2).first;
+ m_freeList.emplace(freeNode);
+ }
+
+ retVal = lastElemIter->first;
+ } else {
+ size_t lastSize = lastElemIter->second.m_size;
+ retVal = lastElemIter->first + lastSize;
+ m_pointerMap.emplace(retVal, p);
+ }
+ return retVal;
+ }
+
+ /**
+ * Compare two iterators to pointer map entries according to
+ * the size of the allocation on the device.
+ */
+ struct SortBySize {
+ bool operator()(typename pointerMap_t::iterator a,
+ typename pointerMap_t::iterator b) const {
+ return ((a->first < b->first) && (a->second <= b->second)) ||
+ ((a->first < b->first) && (b->second <= a->second));
+ }
+ };
+
+ /* Maps the pointer addresses to buffer and size pairs.
+ */
+ pointerMap_t m_pointerMap;
+
+ /* List of free nodes available for re-using
+ */
+ std::set<typename pointerMap_t::iterator, SortBySize> m_freeList;
+
+ /* Base address used when issuing the first virtual pointer, allows users
+ * to specify alignment. Cannot be zero. */
+ std::intptr_t m_baseAddress;
+};
+
+/* remove_pointer.
+ * Removes the given pointer from the map.
+ * The pointer is allowed to be reused only if ReUse if true.
+ */
+template <>
+inline void PointerMapper::remove_pointer<false>(const virtual_pointer_t ptr) {
+ if (is_nullptr(ptr)) {
+ return;
+ }
+ m_pointerMap.erase(this->get_node(ptr));
+}
+
+/**
+ * Malloc-like interface to the pointer-mapper.
+ * Given a size, creates a byte-typed buffer and returns a
+ * fake pointer to keep track of it.
+ * \param size Size in bytes of the desired allocation
+ * \throw cl::sycl::exception if error while creating the buffer
+ */
+inline void *SYCLmalloc(size_t size, PointerMapper &pMap) {
+ if (size == 0) {
+ return nullptr;
+ }
+ // Create a generic buffer of the given size
+ using buffer_t = cl::sycl::buffer<buffer_data_type_t, 1>;
+ auto thePointer = pMap.add_pointer(buffer_t(cl::sycl::range<1>{size}));
+ // Store the buffer on the global list
+ return static_cast<void *>(thePointer);
+}
+
+/**
+ * Free-like interface to the pointer mapper.
+ * Given a fake-pointer created with the virtual-pointer malloc,
+ * destroys the buffer and remove it from the list.
+ * If ReUse is false, the pointer is not added to the freeList,
+ * it should be false only for sub-buffers.
+ */
+template <bool ReUse = true, typename PointerMapper>
+inline void SYCLfree(void *ptr, PointerMapper &pMap) {
+ pMap.template remove_pointer<ReUse>(ptr);
+}
+
+/**
+ * Clear all the memory allocated by SYCL.
+ */
+template <typename PointerMapper>
+inline void SYCLfreeAll(PointerMapper &pMap) {
+ pMap.clear();
+}
+
+template <cl::sycl::access::mode AcMd, typename T>
+struct RangeAccess {
+ static const auto global_access = cl::sycl::access::target::global_buffer;
+ static const auto is_place_holder = cl::sycl::access::placeholder::true_t;
+ typedef T scalar_t;
+ typedef scalar_t &ref_t;
+ typedef typename cl::sycl::global_ptr<scalar_t>::pointer_t ptr_t;
+
+ // the accessor type does not necessarily the same as T
+ typedef cl::sycl::accessor<scalar_t, 1, AcMd, global_access, is_place_holder>
+ accessor;
+
+ typedef RangeAccess<AcMd, T> self_t;
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE RangeAccess(accessor access,
+ size_t offset,
+ std::intptr_t virtual_ptr)
+ : access_(access), offset_(offset), virtual_ptr_(virtual_ptr) {}
+
+ RangeAccess(cl::sycl::buffer<scalar_t, 1> buff =
+ cl::sycl::buffer<scalar_t, 1>(cl::sycl::range<1>(1)))
+ : access_{accessor{buff}}, offset_(0), virtual_ptr_(-1) {}
+
+ // This should be only used for null constructor on the host side
+ RangeAccess(std::nullptr_t) : RangeAccess() {}
+ // This template parameter must be removed and scalar_t should be replaced
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ptr_t get_pointer() const {
+ return (access_.get_pointer().get() + offset_);
+ }
+ template <typename Index>
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE self_t &operator+=(Index offset) {
+ offset_ += (offset);
+ return *this;
+ }
+ template <typename Index>
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE self_t operator+(Index offset) const {
+ return self_t(access_, offset_ + offset, virtual_ptr_);
+ }
+ template <typename Index>
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE self_t operator-(Index offset) const {
+ return self_t(access_, offset_ - offset, virtual_ptr_);
+ }
+ template <typename Index>
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE self_t &operator-=(Index offset) {
+ offset_ -= offset;
+ return *this;
+ }
+
+ // THIS IS FOR NULL COMPARISON ONLY
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE friend bool operator==(
+ const RangeAccess &lhs, std::nullptr_t) {
+ return ((lhs.virtual_ptr_ == -1));
+ }
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE friend bool operator!=(
+ const RangeAccess &lhs, std::nullptr_t i) {
+ return !(lhs == i);
+ }
+
+ // THIS IS FOR NULL COMPARISON ONLY
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE friend bool operator==(
+ std::nullptr_t, const RangeAccess &rhs) {
+ return ((rhs.virtual_ptr_ == -1));
+ }
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE friend bool operator!=(
+ std::nullptr_t i, const RangeAccess &rhs) {
+ return !(i == rhs);
+ }
+ // Prefix operator (Increment and return value)
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE self_t &operator++() {
+ offset_++;
+ return (*this);
+ }
+
+ // Postfix operator (Return value and increment)
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE self_t operator++(int i) {
+ EIGEN_UNUSED_VARIABLE(i);
+ self_t temp_iterator(*this);
+ offset_++;
+ return temp_iterator;
+ }
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::ptrdiff_t get_size() const {
+ return (access_.get_count() - offset_);
+ }
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::ptrdiff_t get_offset() const {
+ return offset_;
+ }
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void set_offset(std::ptrdiff_t offset) {
+ offset_ = offset;
+ }
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ref_t operator*() const {
+ return *get_pointer();
+ }
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ref_t operator*() {
+ return *get_pointer();
+ }
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ptr_t operator->() = delete;
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ref_t operator[](int x) {
+ return *(get_pointer() + x);
+ }
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ref_t operator[](int x) const {
+ return *(get_pointer() + x);
+ }
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE scalar_t *get_virtual_pointer() const {
+ return reinterpret_cast<scalar_t *>(virtual_ptr_ +
+ (offset_ * sizeof(scalar_t)));
+ }
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE explicit operator bool() const {
+ return (virtual_ptr_ != -1);
+ }
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE operator RangeAccess<AcMd, const T>() {
+ return RangeAccess<AcMd, const T>(access_, offset_, virtual_ptr_);
+ }
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ operator RangeAccess<AcMd, const T>() const {
+ return RangeAccess<AcMd, const T>(access_, offset_, virtual_ptr_);
+ }
+ // binding placeholder accessors to a command group handler for SYCL
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void bind(
+ cl::sycl::handler &cgh) const {
+ cgh.require(access_);
+ }
+
+ private:
+ accessor access_;
+ size_t offset_;
+ std::intptr_t virtual_ptr_; // the location of the buffer in the map
+};
+
+template <cl::sycl::access::mode AcMd, typename T>
+struct RangeAccess<AcMd, const T> : RangeAccess<AcMd, T> {
+ typedef RangeAccess<AcMd, T> Base;
+ using Base::Base;
+};
+
+} // namespace internal
+} // namespace TensorSycl
+} // namespace Eigen
+
+#endif // EIGEN_CXX11_TENSOR_TENSOR_SYCL_STORAGE_MEMORY_H
diff --git a/Eigen/src/Core/arch/SYCL/TypeCasting.h b/Eigen/src/Core/arch/SYCL/TypeCasting.h
new file mode 100644
index 000000000..9208ab21d
--- /dev/null
+++ b/Eigen/src/Core/arch/SYCL/TypeCasting.h
@@ -0,0 +1,85 @@
+// This file is part of Eigen, a lightweight C++ template library
+// for linear algebra.
+//
+// Mehdi Goli Codeplay Software Ltd.
+// Ralph Potter Codeplay Software Ltd.
+// Luke Iwanski Codeplay Software Ltd.
+// Contact: <eigen@codeplay.com>
+//
+// This Source Code Form is subject to the terms of the Mozilla
+// Public License v. 2.0. If a copy of the MPL was not distributed
+// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+/*****************************************************************
+ * TypeCasting.h
+ *
+ * \brief:
+ * TypeCasting
+ *
+ *****************************************************************/
+
+#ifndef EIGEN_TYPE_CASTING_SYCL_H
+#define EIGEN_TYPE_CASTING_SYCL_H
+
+namespace Eigen {
+
+namespace internal {
+#ifdef SYCL_DEVICE_ONLY
+template <>
+struct type_casting_traits<float, int> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 };
+};
+
+template <>
+EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE cl::sycl::cl_int4
+pcast<cl::sycl::cl_float4, cl::sycl::cl_int4>(const cl::sycl::cl_float4& a) {
+ return a
+ .template convert<cl::sycl::cl_int, cl::sycl::rounding_mode::automatic>();
+}
+
+template <>
+struct type_casting_traits<int, float> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 };
+};
+
+template <>
+EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE cl::sycl::cl_float4
+pcast<cl::sycl::cl_int4, cl::sycl::cl_float4>(const cl::sycl::cl_int4& a) {
+ return a.template convert<cl::sycl::cl_float,
+ cl::sycl::rounding_mode::automatic>();
+}
+
+template <>
+struct type_casting_traits<double, float> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 2, TgtCoeffRatio = 1 };
+};
+
+template <>
+EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE cl::sycl::cl_float4
+pcast<cl::sycl::cl_double2, cl::sycl::cl_float4>(
+ const cl::sycl::cl_double2& a, const cl::sycl::cl_double2& b) {
+ auto a1 = a.template convert<cl::sycl::cl_float,
+ cl::sycl::rounding_mode::automatic>();
+ auto b1 = b.template convert<cl::sycl::cl_float,
+ cl::sycl::rounding_mode::automatic>();
+ return cl::sycl::float4(a1.x(), a1.y(), b1.x(), b1.y());
+}
+
+template <>
+struct type_casting_traits<float, double> {
+ enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 2 };
+};
+
+template <>
+EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE cl::sycl::cl_double2
+pcast<cl::sycl::cl_float4, cl::sycl::cl_double2>(const cl::sycl::cl_float4& a) {
+ // Simply discard the second half of the input
+ return cl::sycl::cl_double2(a.x(), a.y());
+}
+
+#endif
+} // end namespace internal
+
+} // end namespace Eigen
+
+#endif // EIGEN_TYPE_CASTING_SYCL_H
diff --git a/Eigen/src/Core/arch/ZVector/Complex.h b/Eigen/src/Core/arch/ZVector/Complex.h
index d39d2d105..0b9b33d99 100644
--- a/Eigen/src/Core/arch/ZVector/Complex.h
+++ b/Eigen/src/Core/arch/ZVector/Complex.h
@@ -15,6 +15,10 @@ namespace Eigen {
namespace internal {
+#if !defined(__ARCH__) || (defined(__ARCH__) && __ARCH__ >= 12)
+static Packet4ui p4ui_CONJ_XOR = { 0x00000000, 0x80000000, 0x00000000, 0x80000000 }; //vec_mergeh((Packet4ui)p4i_ZERO, (Packet4ui)p4f_MZERO);
+#endif
+
static Packet2ul p2ul_CONJ_XOR1 = (Packet2ul) vec_sld((Packet4ui) p2d_ZERO_, (Packet4ui) p2l_ZERO, 8);//{ 0x8000000000000000, 0x0000000000000000 };
static Packet2ul p2ul_CONJ_XOR2 = (Packet2ul) vec_sld((Packet4ui) p2l_ZERO, (Packet4ui) p2d_ZERO_, 8);//{ 0x8000000000000000, 0x0000000000000000 };
@@ -29,10 +33,14 @@ struct Packet2cf
{
EIGEN_STRONG_INLINE Packet2cf() {}
EIGEN_STRONG_INLINE explicit Packet2cf(const Packet4f& a) : v(a) {}
+#if !defined(__ARCH__) || (defined(__ARCH__) && __ARCH__ < 12)
union {
Packet4f v;
Packet1cd cd[2];
};
+#else
+ Packet4f v;
+#endif
};
template<> struct packet_traits<std::complex<float> > : default_packet_traits
@@ -83,69 +91,33 @@ template<> struct packet_traits<std::complex<double> > : default_packet_traits
};
};
-template<> struct unpacket_traits<Packet2cf> { typedef std::complex<float> type; enum {size=2, alignment=Aligned16}; typedef Packet2cf half; };
-template<> struct unpacket_traits<Packet1cd> { typedef std::complex<double> type; enum {size=1, alignment=Aligned16}; typedef Packet1cd half; };
+template<> struct unpacket_traits<Packet2cf> { typedef std::complex<float> type; enum {size=2, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false}; typedef Packet2cf half; };
+template<> struct unpacket_traits<Packet1cd> { typedef std::complex<double> type; enum {size=1, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false}; typedef Packet1cd half; };
/* Forward declaration */
EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet2cf,2>& kernel);
-template<> EIGEN_STRONG_INLINE Packet2cf pload <Packet2cf>(const std::complex<float>* from) { EIGEN_DEBUG_ALIGNED_LOAD return Packet2cf(pload<Packet4f>((const float*)from)); }
+/* complex<double> first */
template<> EIGEN_STRONG_INLINE Packet1cd pload <Packet1cd>(const std::complex<double>* from) { EIGEN_DEBUG_ALIGNED_LOAD return Packet1cd(pload<Packet2d>((const double*)from)); }
-template<> EIGEN_STRONG_INLINE Packet2cf ploadu<Packet2cf>(const std::complex<float>* from) { EIGEN_DEBUG_UNALIGNED_LOAD return Packet2cf(ploadu<Packet4f>((const float*)from)); }
template<> EIGEN_STRONG_INLINE Packet1cd ploadu<Packet1cd>(const std::complex<double>* from) { EIGEN_DEBUG_UNALIGNED_LOAD return Packet1cd(ploadu<Packet2d>((const double*)from)); }
-template<> EIGEN_STRONG_INLINE void pstore <std::complex<float> >(std::complex<float> * to, const Packet2cf& from) { EIGEN_DEBUG_ALIGNED_STORE pstore((float*)to, from.v); }
template<> EIGEN_STRONG_INLINE void pstore <std::complex<double> >(std::complex<double> * to, const Packet1cd& from) { EIGEN_DEBUG_ALIGNED_STORE pstore((double*)to, from.v); }
-template<> EIGEN_STRONG_INLINE void pstoreu<std::complex<float> >(std::complex<float> * to, const Packet2cf& from) { EIGEN_DEBUG_UNALIGNED_STORE pstoreu((float*)to, from.v); }
template<> EIGEN_STRONG_INLINE void pstoreu<std::complex<double> >(std::complex<double> * to, const Packet1cd& from) { EIGEN_DEBUG_UNALIGNED_STORE pstoreu((double*)to, from.v); }
template<> EIGEN_STRONG_INLINE Packet1cd pset1<Packet1cd>(const std::complex<double>& from)
{ /* here we really have to use unaligned loads :( */ return ploadu<Packet1cd>(&from); }
-template<> EIGEN_STRONG_INLINE Packet2cf pset1<Packet2cf>(const std::complex<float>& from)
-{
- Packet2cf res;
- res.cd[0] = Packet1cd(vec_ld2f((const float *)&from));
- res.cd[1] = res.cd[0];
- return res;
-}
-template<> EIGEN_DEVICE_FUNC inline Packet2cf pgather<std::complex<float>, Packet2cf>(const std::complex<float>* from, Index stride)
-{
- std::complex<float> EIGEN_ALIGN16 af[2];
- af[0] = from[0*stride];
- af[1] = from[1*stride];
- return pload<Packet2cf>(af);
-}
template<> EIGEN_DEVICE_FUNC inline Packet1cd pgather<std::complex<double>, Packet1cd>(const std::complex<double>* from, Index stride EIGEN_UNUSED)
{
return pload<Packet1cd>(from);
}
-template<> EIGEN_DEVICE_FUNC inline void pscatter<std::complex<float>, Packet2cf>(std::complex<float>* to, const Packet2cf& from, Index stride)
-{
- std::complex<float> EIGEN_ALIGN16 af[2];
- pstore<std::complex<float> >((std::complex<float> *) af, from);
- to[0*stride] = af[0];
- to[1*stride] = af[1];
-}
template<> EIGEN_DEVICE_FUNC inline void pscatter<std::complex<double>, Packet1cd>(std::complex<double>* to, const Packet1cd& from, Index stride EIGEN_UNUSED)
{
pstore<std::complex<double> >(to, from);
}
-
-template<> EIGEN_STRONG_INLINE Packet2cf padd<Packet2cf>(const Packet2cf& a, const Packet2cf& b) { return Packet2cf(padd<Packet4f>(a.v, b.v)); }
template<> EIGEN_STRONG_INLINE Packet1cd padd<Packet1cd>(const Packet1cd& a, const Packet1cd& b) { return Packet1cd(a.v + b.v); }
-template<> EIGEN_STRONG_INLINE Packet2cf psub<Packet2cf>(const Packet2cf& a, const Packet2cf& b) { return Packet2cf(psub<Packet4f>(a.v, b.v)); }
template<> EIGEN_STRONG_INLINE Packet1cd psub<Packet1cd>(const Packet1cd& a, const Packet1cd& b) { return Packet1cd(a.v - b.v); }
template<> EIGEN_STRONG_INLINE Packet1cd pnegate(const Packet1cd& a) { return Packet1cd(pnegate(Packet2d(a.v))); }
-template<> EIGEN_STRONG_INLINE Packet2cf pnegate(const Packet2cf& a) { return Packet2cf(pnegate(Packet4f(a.v))); }
template<> EIGEN_STRONG_INLINE Packet1cd pconj(const Packet1cd& a) { return Packet1cd((Packet2d)vec_xor((Packet2d)a.v, (Packet2d)p2ul_CONJ_XOR2)); }
-template<> EIGEN_STRONG_INLINE Packet2cf pconj(const Packet2cf& a)
-{
- Packet2cf res;
- res.v.v4f[0] = pconj(Packet1cd(reinterpret_cast<Packet2d>(a.v.v4f[0]))).v;
- res.v.v4f[1] = pconj(Packet1cd(reinterpret_cast<Packet2d>(a.v.v4f[1]))).v;
- return res;
-}
-
template<> EIGEN_STRONG_INLINE Packet1cd pmul<Packet1cd>(const Packet1cd& a, const Packet1cd& b)
{
Packet2d a_re, a_im, v1, v2;
@@ -163,27 +135,17 @@ template<> EIGEN_STRONG_INLINE Packet1cd pmul<Packet1cd>(const Packet1cd& a, con
return Packet1cd(v1 + v2);
}
-template<> EIGEN_STRONG_INLINE Packet2cf pmul<Packet2cf>(const Packet2cf& a, const Packet2cf& b)
-{
- Packet2cf res;
- res.v.v4f[0] = pmul(Packet1cd(reinterpret_cast<Packet2d>(a.v.v4f[0])), Packet1cd(reinterpret_cast<Packet2d>(b.v.v4f[0]))).v;
- res.v.v4f[1] = pmul(Packet1cd(reinterpret_cast<Packet2d>(a.v.v4f[1])), Packet1cd(reinterpret_cast<Packet2d>(b.v.v4f[1]))).v;
- return res;
-}
-
-template<> EIGEN_STRONG_INLINE Packet1cd pand <Packet1cd>(const Packet1cd& a, const Packet1cd& b) { return Packet1cd(vec_and(a.v,b.v)); }
-template<> EIGEN_STRONG_INLINE Packet2cf pand <Packet2cf>(const Packet2cf& a, const Packet2cf& b) { return Packet2cf(pand<Packet4f>(a.v,b.v)); }
-template<> EIGEN_STRONG_INLINE Packet1cd por <Packet1cd>(const Packet1cd& a, const Packet1cd& b) { return Packet1cd(vec_or(a.v,b.v)); }
-template<> EIGEN_STRONG_INLINE Packet2cf por <Packet2cf>(const Packet2cf& a, const Packet2cf& b) { return Packet2cf(por<Packet4f>(a.v,b.v)); }
-template<> EIGEN_STRONG_INLINE Packet1cd pxor <Packet1cd>(const Packet1cd& a, const Packet1cd& b) { return Packet1cd(vec_xor(a.v,b.v)); }
-template<> EIGEN_STRONG_INLINE Packet2cf pxor <Packet2cf>(const Packet2cf& a, const Packet2cf& b) { return Packet2cf(pxor<Packet4f>(a.v,b.v)); }
-template<> EIGEN_STRONG_INLINE Packet1cd pandnot<Packet1cd>(const Packet1cd& a, const Packet1cd& b) { return Packet1cd(vec_and(a.v, vec_nor(b.v,b.v))); }
-template<> EIGEN_STRONG_INLINE Packet2cf pandnot<Packet2cf>(const Packet2cf& a, const Packet2cf& b) { return Packet2cf(pandnot<Packet4f>(a.v,b.v)); }
-
+template<> EIGEN_STRONG_INLINE Packet1cd pand <Packet1cd>(const Packet1cd& a, const Packet1cd& b) { return Packet1cd(vec_and(a.v,b.v)); }
+template<> EIGEN_STRONG_INLINE Packet1cd por <Packet1cd>(const Packet1cd& a, const Packet1cd& b) { return Packet1cd(vec_or(a.v,b.v)); }
+template<> EIGEN_STRONG_INLINE Packet1cd pxor <Packet1cd>(const Packet1cd& a, const Packet1cd& b) { return Packet1cd(vec_xor(a.v,b.v)); }
+template<> EIGEN_STRONG_INLINE Packet1cd pandnot <Packet1cd>(const Packet1cd& a, const Packet1cd& b) { return Packet1cd(vec_and(a.v, vec_nor(b.v,b.v))); }
template<> EIGEN_STRONG_INLINE Packet1cd ploaddup<Packet1cd>(const std::complex<double>* from) { return pset1<Packet1cd>(*from); }
-template<> EIGEN_STRONG_INLINE Packet2cf ploaddup<Packet2cf>(const std::complex<float>* from) { return pset1<Packet2cf>(*from); }
+template<> EIGEN_STRONG_INLINE Packet1cd pcmp_eq(const Packet1cd& a, const Packet1cd& b) {
+ Packet2d eq = vec_cmpeq (a.v, b.v);
+ Packet2d tmp = { eq[1], eq[0] };
+ return (Packet1cd)pand<Packet2d>(eq, tmp);
+}
-template<> EIGEN_STRONG_INLINE void prefetch<std::complex<float> >(const std::complex<float> * addr) { EIGEN_ZVECTOR_PREFETCH(addr); }
template<> EIGEN_STRONG_INLINE void prefetch<std::complex<double> >(const std::complex<double> * addr) { EIGEN_ZVECTOR_PREFETCH(addr); }
template<> EIGEN_STRONG_INLINE std::complex<double> pfirst<Packet1cd>(const Packet1cd& a)
@@ -193,157 +155,157 @@ template<> EIGEN_STRONG_INLINE std::complex<double> pfirst<Packet1cd>(const Pac
return res;
}
-template<> EIGEN_STRONG_INLINE std::complex<float> pfirst<Packet2cf>(const Packet2cf& a)
-{
- std::complex<float> EIGEN_ALIGN16 res[2];
- pstore<std::complex<float> >(res, a);
-
- return res[0];
-}
template<> EIGEN_STRONG_INLINE Packet1cd preverse(const Packet1cd& a) { return a; }
-template<> EIGEN_STRONG_INLINE Packet2cf preverse(const Packet2cf& a)
+template<> EIGEN_STRONG_INLINE std::complex<double> predux<Packet1cd>(const Packet1cd& a)
{
- Packet2cf res;
- res.cd[0] = a.cd[1];
- res.cd[1] = a.cd[0];
- return res;
+ return pfirst(a);
}
-
-template<> EIGEN_STRONG_INLINE std::complex<double> predux<Packet1cd>(const Packet1cd& a)
+template<> EIGEN_STRONG_INLINE std::complex<double> predux_mul<Packet1cd>(const Packet1cd& a)
{
return pfirst(a);
}
-template<> EIGEN_STRONG_INLINE std::complex<float> predux<Packet2cf>(const Packet2cf& a)
+EIGEN_MAKE_CONJ_HELPER_CPLX_REAL(Packet1cd,Packet2d)
+
+template<> EIGEN_STRONG_INLINE Packet1cd pdiv<Packet1cd>(const Packet1cd& a, const Packet1cd& b)
{
- std::complex<float> res;
- Packet1cd b = padd<Packet1cd>(a.cd[0], a.cd[1]);
- vec_st2f(b.v, (float*)&res);
- return res;
+ // TODO optimize it for AltiVec
+ Packet1cd res = pmul(a,pconj(b));
+ Packet2d s = vec_madd(b.v, b.v, p2d_ZERO_);
+ return Packet1cd(pdiv(res.v, s + vec_perm(s, s, p16uc_REVERSE64)));
}
-template<> EIGEN_STRONG_INLINE Packet1cd preduxp<Packet1cd>(const Packet1cd* vecs)
+EIGEN_STRONG_INLINE Packet1cd pcplxflip/*<Packet1cd>*/(const Packet1cd& x)
{
- return vecs[0];
+ return Packet1cd(preverse(Packet2d(x.v)));
}
-template<> EIGEN_STRONG_INLINE Packet2cf preduxp<Packet2cf>(const Packet2cf* vecs)
+
+EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet1cd,2>& kernel)
{
- PacketBlock<Packet2cf,2> transpose;
- transpose.packet[0] = vecs[0];
- transpose.packet[1] = vecs[1];
- ptranspose(transpose);
+ Packet2d tmp = vec_perm(kernel.packet[0].v, kernel.packet[1].v, p16uc_TRANSPOSE64_HI);
+ kernel.packet[1].v = vec_perm(kernel.packet[0].v, kernel.packet[1].v, p16uc_TRANSPOSE64_LO);
+ kernel.packet[0].v = tmp;
+}
- return padd<Packet2cf>(transpose.packet[0], transpose.packet[1]);
-}
+/* complex<float> follows */
+template<> EIGEN_STRONG_INLINE Packet2cf pload <Packet2cf>(const std::complex<float>* from) { EIGEN_DEBUG_ALIGNED_LOAD return Packet2cf(pload<Packet4f>((const float*)from)); }
+template<> EIGEN_STRONG_INLINE Packet2cf ploadu<Packet2cf>(const std::complex<float>* from) { EIGEN_DEBUG_UNALIGNED_LOAD return Packet2cf(ploadu<Packet4f>((const float*)from)); }
+template<> EIGEN_STRONG_INLINE void pstore <std::complex<float> >(std::complex<float> * to, const Packet2cf& from) { EIGEN_DEBUG_ALIGNED_STORE pstore((float*)to, from.v); }
+template<> EIGEN_STRONG_INLINE void pstoreu<std::complex<float> >(std::complex<float> * to, const Packet2cf& from) { EIGEN_DEBUG_UNALIGNED_STORE pstoreu((float*)to, from.v); }
-template<> EIGEN_STRONG_INLINE std::complex<double> predux_mul<Packet1cd>(const Packet1cd& a)
+template<> EIGEN_STRONG_INLINE std::complex<float> pfirst<Packet2cf>(const Packet2cf& a)
{
- return pfirst(a);
+ std::complex<float> EIGEN_ALIGN16 res[2];
+ pstore<std::complex<float> >(res, a);
+
+ return res[0];
}
-template<> EIGEN_STRONG_INLINE std::complex<float> predux_mul<Packet2cf>(const Packet2cf& a)
+
+
+#if !defined(__ARCH__) || (defined(__ARCH__) && __ARCH__ < 12)
+template<> EIGEN_STRONG_INLINE Packet2cf pset1<Packet2cf>(const std::complex<float>& from)
{
- std::complex<float> res;
- Packet1cd b = pmul<Packet1cd>(a.cd[0], a.cd[1]);
- vec_st2f(b.v, (float*)&res);
+ Packet2cf res;
+ res.cd[0] = Packet1cd(vec_ld2f((const float *)&from));
+ res.cd[1] = res.cd[0];
return res;
}
-
-template<int Offset>
-struct palign_impl<Offset,Packet1cd>
+#else
+template<> EIGEN_STRONG_INLINE Packet2cf pset1<Packet2cf>(const std::complex<float>& from)
{
- static EIGEN_STRONG_INLINE void run(Packet1cd& /*first*/, const Packet1cd& /*second*/)
- {
- // FIXME is it sure we never have to align a Packet1cd?
- // Even though a std::complex<double> has 16 bytes, it is not necessarily aligned on a 16 bytes boundary...
- }
-};
+ Packet2cf res;
+ if((std::ptrdiff_t(&from) % 16) == 0)
+ res.v = pload<Packet4f>((const float *)&from);
+ else
+ res.v = ploadu<Packet4f>((const float *)&from);
+ res.v = vec_perm(res.v, res.v, p16uc_PSET64_HI);
+ return res;
+}
+#endif
-template<int Offset>
-struct palign_impl<Offset,Packet2cf>
+template<> EIGEN_DEVICE_FUNC inline Packet2cf pgather<std::complex<float>, Packet2cf>(const std::complex<float>* from, Index stride)
{
- static EIGEN_STRONG_INLINE void run(Packet2cf& first, const Packet2cf& second)
- {
- if (Offset == 1) {
- first.cd[0] = first.cd[1];
- first.cd[1] = second.cd[0];
- }
- }
-};
-
-template<> struct conj_helper<Packet1cd, Packet1cd, false,true>
+ std::complex<float> EIGEN_ALIGN16 af[2];
+ af[0] = from[0*stride];
+ af[1] = from[1*stride];
+ return pload<Packet2cf>(af);
+}
+template<> EIGEN_DEVICE_FUNC inline void pscatter<std::complex<float>, Packet2cf>(std::complex<float>* to, const Packet2cf& from, Index stride)
{
- EIGEN_STRONG_INLINE Packet1cd pmadd(const Packet1cd& x, const Packet1cd& y, const Packet1cd& c) const
- { return padd(pmul(x,y),c); }
+ std::complex<float> EIGEN_ALIGN16 af[2];
+ pstore<std::complex<float> >((std::complex<float> *) af, from);
+ to[0*stride] = af[0];
+ to[1*stride] = af[1];
+}
- EIGEN_STRONG_INLINE Packet1cd pmul(const Packet1cd& a, const Packet1cd& b) const
- {
- return internal::pmul(a, pconj(b));
- }
-};
+template<> EIGEN_STRONG_INLINE Packet2cf padd<Packet2cf>(const Packet2cf& a, const Packet2cf& b) { return Packet2cf(padd<Packet4f>(a.v, b.v)); }
+template<> EIGEN_STRONG_INLINE Packet2cf psub<Packet2cf>(const Packet2cf& a, const Packet2cf& b) { return Packet2cf(psub<Packet4f>(a.v, b.v)); }
+template<> EIGEN_STRONG_INLINE Packet2cf pnegate(const Packet2cf& a) { return Packet2cf(pnegate(Packet4f(a.v))); }
-template<> struct conj_helper<Packet1cd, Packet1cd, true,false>
-{
- EIGEN_STRONG_INLINE Packet1cd pmadd(const Packet1cd& x, const Packet1cd& y, const Packet1cd& c) const
- { return padd(pmul(x,y),c); }
+template<> EIGEN_STRONG_INLINE Packet2cf pand <Packet2cf>(const Packet2cf& a, const Packet2cf& b) { return Packet2cf(pand<Packet4f>(a.v,b.v)); }
+template<> EIGEN_STRONG_INLINE Packet2cf por <Packet2cf>(const Packet2cf& a, const Packet2cf& b) { return Packet2cf(por<Packet4f>(a.v,b.v)); }
+template<> EIGEN_STRONG_INLINE Packet2cf pxor <Packet2cf>(const Packet2cf& a, const Packet2cf& b) { return Packet2cf(pxor<Packet4f>(a.v,b.v)); }
+template<> EIGEN_STRONG_INLINE Packet2cf pandnot<Packet2cf>(const Packet2cf& a, const Packet2cf& b) { return Packet2cf(pandnot<Packet4f>(a.v,b.v)); }
- EIGEN_STRONG_INLINE Packet1cd pmul(const Packet1cd& a, const Packet1cd& b) const
- {
- return internal::pmul(pconj(a), b);
- }
-};
+template<> EIGEN_STRONG_INLINE Packet2cf ploaddup<Packet2cf>(const std::complex<float>* from) { return pset1<Packet2cf>(*from); }
-template<> struct conj_helper<Packet1cd, Packet1cd, true,true>
-{
- EIGEN_STRONG_INLINE Packet1cd pmadd(const Packet1cd& x, const Packet1cd& y, const Packet1cd& c) const
- { return padd(pmul(x,y),c); }
+template<> EIGEN_STRONG_INLINE void prefetch<std::complex<float> >(const std::complex<float> * addr) { EIGEN_ZVECTOR_PREFETCH(addr); }
- EIGEN_STRONG_INLINE Packet1cd pmul(const Packet1cd& a, const Packet1cd& b) const
- {
- return pconj(internal::pmul(a, b));
- }
-};
-template<> struct conj_helper<Packet2cf, Packet2cf, false,true>
-{
- EIGEN_STRONG_INLINE Packet2cf pmadd(const Packet2cf& x, const Packet2cf& y, const Packet2cf& c) const
- { return padd(pmul(x,y),c); }
+#if !defined(__ARCH__) || (defined(__ARCH__) && __ARCH__ < 12)
- EIGEN_STRONG_INLINE Packet2cf pmul(const Packet2cf& a, const Packet2cf& b) const
- {
- return internal::pmul(a, pconj(b));
- }
-};
+template<> EIGEN_STRONG_INLINE Packet2cf pcmp_eq(const Packet2cf& a, const Packet2cf& b) {
+ Packet4f eq = pcmp_eq<Packet4f> (a.v, b.v);
+ Packet2cf res;
+ Packet2d tmp1 = { eq.v4f[0][1], eq.v4f[0][0] };
+ Packet2d tmp2 = { eq.v4f[1][1], eq.v4f[1][0] };
+ res.v.v4f[0] = pand<Packet2d>(eq.v4f[0], tmp1);
+ res.v.v4f[1] = pand<Packet2d>(eq.v4f[1], tmp2);
+ return res;
+}
-template<> struct conj_helper<Packet2cf, Packet2cf, true,false>
+template<> EIGEN_STRONG_INLINE Packet2cf pconj(const Packet2cf& a)
{
- EIGEN_STRONG_INLINE Packet2cf pmadd(const Packet2cf& x, const Packet2cf& y, const Packet2cf& c) const
- { return padd(pmul(x,y),c); }
+ Packet2cf res;
+ res.v.v4f[0] = pconj(Packet1cd(reinterpret_cast<Packet2d>(a.v.v4f[0]))).v;
+ res.v.v4f[1] = pconj(Packet1cd(reinterpret_cast<Packet2d>(a.v.v4f[1]))).v;
+ return res;
+}
- EIGEN_STRONG_INLINE Packet2cf pmul(const Packet2cf& a, const Packet2cf& b) const
- {
- return internal::pmul(pconj(a), b);
- }
-};
+template<> EIGEN_STRONG_INLINE Packet2cf pmul<Packet2cf>(const Packet2cf& a, const Packet2cf& b)
+{
+ Packet2cf res;
+ res.v.v4f[0] = pmul(Packet1cd(reinterpret_cast<Packet2d>(a.v.v4f[0])), Packet1cd(reinterpret_cast<Packet2d>(b.v.v4f[0]))).v;
+ res.v.v4f[1] = pmul(Packet1cd(reinterpret_cast<Packet2d>(a.v.v4f[1])), Packet1cd(reinterpret_cast<Packet2d>(b.v.v4f[1]))).v;
+ return res;
+}
-template<> struct conj_helper<Packet2cf, Packet2cf, true,true>
+template<> EIGEN_STRONG_INLINE Packet2cf preverse(const Packet2cf& a)
{
- EIGEN_STRONG_INLINE Packet2cf pmadd(const Packet2cf& x, const Packet2cf& y, const Packet2cf& c) const
- { return padd(pmul(x,y),c); }
+ Packet2cf res;
+ res.cd[0] = a.cd[1];
+ res.cd[1] = a.cd[0];
+ return res;
+}
- EIGEN_STRONG_INLINE Packet2cf pmul(const Packet2cf& a, const Packet2cf& b) const
- {
- return pconj(internal::pmul(a, b));
- }
-};
+template<> EIGEN_STRONG_INLINE std::complex<float> predux<Packet2cf>(const Packet2cf& a)
+{
+ std::complex<float> res;
+ Packet1cd b = padd<Packet1cd>(a.cd[0], a.cd[1]);
+ vec_st2f(b.v, (float*)&res);
+ return res;
+}
-template<> EIGEN_STRONG_INLINE Packet1cd pdiv<Packet1cd>(const Packet1cd& a, const Packet1cd& b)
+template<> EIGEN_STRONG_INLINE std::complex<float> predux_mul<Packet2cf>(const Packet2cf& a)
{
- // TODO optimize it for AltiVec
- Packet1cd res = conj_helper<Packet1cd,Packet1cd,false,true>().pmul(a,b);
- Packet2d s = vec_madd(b.v, b.v, p2d_ZERO_);
- return Packet1cd(pdiv(res.v, s + vec_perm(s, s, p16uc_REVERSE64)));
+ std::complex<float> res;
+ Packet1cd b = pmul<Packet1cd>(a.cd[0], a.cd[1]);
+ vec_st2f(b.v, (float*)&res);
+ return res;
}
+EIGEN_MAKE_CONJ_HELPER_CPLX_REAL(Packet2cf,Packet4f)
+
template<> EIGEN_STRONG_INLINE Packet2cf pdiv<Packet2cf>(const Packet2cf& a, const Packet2cf& b)
{
// TODO optimize it for AltiVec
@@ -353,11 +315,6 @@ template<> EIGEN_STRONG_INLINE Packet2cf pdiv<Packet2cf>(const Packet2cf& a, con
return res;
}
-EIGEN_STRONG_INLINE Packet1cd pcplxflip/*<Packet1cd>*/(const Packet1cd& x)
-{
- return Packet1cd(preverse(Packet2d(x.v)));
-}
-
EIGEN_STRONG_INLINE Packet2cf pcplxflip/*<Packet2cf>*/(const Packet2cf& x)
{
Packet2cf res;
@@ -366,13 +323,6 @@ EIGEN_STRONG_INLINE Packet2cf pcplxflip/*<Packet2cf>*/(const Packet2cf& x)
return res;
}
-EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet1cd,2>& kernel)
-{
- Packet2d tmp = vec_perm(kernel.packet[0].v, kernel.packet[1].v, p16uc_TRANSPOSE64_HI);
- kernel.packet[1].v = vec_perm(kernel.packet[0].v, kernel.packet[1].v, p16uc_TRANSPOSE64_LO);
- kernel.packet[0].v = tmp;
-}
-
EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet2cf,2>& kernel)
{
Packet1cd tmp = kernel.packet[0].cd[1];
@@ -386,6 +336,88 @@ template<> EIGEN_STRONG_INLINE Packet2cf pblend(const Selector<2>& ifPacket, con
result.v = pblend<Packet4f>(ifPacket4, thenPacket.v, elsePacket.v);
return result;
}
+#else
+template<> EIGEN_STRONG_INLINE Packet2cf pcmp_eq(const Packet2cf& a, const Packet2cf& b) {
+ Packet4f eq = vec_cmpeq (a.v, b.v);
+ Packet4f tmp = { eq[1], eq[0], eq[3], eq[2] };
+ return (Packet2cf)pand<Packet4f>(eq, tmp);
+}
+template<> EIGEN_STRONG_INLINE Packet2cf pconj(const Packet2cf& a) { return Packet2cf(pxor<Packet4f>(a.v, reinterpret_cast<Packet4f>(p4ui_CONJ_XOR))); }
+template<> EIGEN_STRONG_INLINE Packet2cf pmul<Packet2cf>(const Packet2cf& a, const Packet2cf& b)
+{
+ Packet4f a_re, a_im, prod, prod_im;
+
+ // Permute and multiply the real parts of a and b
+ a_re = vec_perm(a.v, a.v, p16uc_PSET32_WODD);
+
+ // Get the imaginary parts of a
+ a_im = vec_perm(a.v, a.v, p16uc_PSET32_WEVEN);
+
+ // multiply a_im * b and get the conjugate result
+ prod_im = a_im * b.v;
+ prod_im = pxor<Packet4f>(prod_im, reinterpret_cast<Packet4f>(p4ui_CONJ_XOR));
+ // permute back to a proper order
+ prod_im = vec_perm(prod_im, prod_im, p16uc_COMPLEX32_REV);
+
+ // multiply a_re * b, add prod_im
+ prod = pmadd<Packet4f>(a_re, b.v, prod_im);
+
+ return Packet2cf(prod);
+}
+
+template<> EIGEN_STRONG_INLINE Packet2cf preverse(const Packet2cf& a)
+{
+ Packet4f rev_a;
+ rev_a = vec_perm(a.v, a.v, p16uc_COMPLEX32_REV2);
+ return Packet2cf(rev_a);
+}
+
+template<> EIGEN_STRONG_INLINE std::complex<float> predux<Packet2cf>(const Packet2cf& a)
+{
+ Packet4f b;
+ b = vec_sld(a.v, a.v, 8);
+ b = padd<Packet4f>(a.v, b);
+ return pfirst<Packet2cf>(Packet2cf(b));
+}
+
+template<> EIGEN_STRONG_INLINE std::complex<float> predux_mul<Packet2cf>(const Packet2cf& a)
+{
+ Packet4f b;
+ Packet2cf prod;
+ b = vec_sld(a.v, a.v, 8);
+ prod = pmul<Packet2cf>(a, Packet2cf(b));
+
+ return pfirst<Packet2cf>(prod);
+}
+
+EIGEN_MAKE_CONJ_HELPER_CPLX_REAL(Packet2cf,Packet4f)
+
+template<> EIGEN_STRONG_INLINE Packet2cf pdiv<Packet2cf>(const Packet2cf& a, const Packet2cf& b)
+{
+ // TODO optimize it for AltiVec
+ Packet2cf res = pmul(a, pconj(b));
+ Packet4f s = pmul<Packet4f>(b.v, b.v);
+ return Packet2cf(pdiv(res.v, padd<Packet4f>(s, vec_perm(s, s, p16uc_COMPLEX32_REV))));
+}
+
+template<> EIGEN_STRONG_INLINE Packet2cf pcplxflip<Packet2cf>(const Packet2cf& x)
+{
+ return Packet2cf(vec_perm(x.v, x.v, p16uc_COMPLEX32_REV));
+}
+
+EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet2cf,2>& kernel)
+{
+ Packet4f tmp = vec_perm(kernel.packet[0].v, kernel.packet[1].v, p16uc_TRANSPOSE64_HI);
+ kernel.packet[1].v = vec_perm(kernel.packet[0].v, kernel.packet[1].v, p16uc_TRANSPOSE64_LO);
+ kernel.packet[0].v = tmp;
+}
+
+template<> EIGEN_STRONG_INLINE Packet2cf pblend(const Selector<2>& ifPacket, const Packet2cf& thenPacket, const Packet2cf& elsePacket) {
+ Packet2cf result;
+ result.v = reinterpret_cast<Packet4f>(pblend<Packet2d>(ifPacket, reinterpret_cast<Packet2d>(thenPacket.v), reinterpret_cast<Packet2d>(elsePacket.v)));
+ return result;
+}
+#endif
} // end namespace internal
diff --git a/Eigen/src/Core/arch/ZVector/MathFunctions.h b/Eigen/src/Core/arch/ZVector/MathFunctions.h
index 5c7aa7256..1635e128c 100644
--- a/Eigen/src/Core/arch/ZVector/MathFunctions.h
+++ b/Eigen/src/Core/arch/ZVector/MathFunctions.h
@@ -20,6 +20,50 @@ namespace Eigen {
namespace internal {
+#if !defined(__ARCH__) || (defined(__ARCH__) && __ARCH__ >= 12)
+static _EIGEN_DECLARE_CONST_Packet4f(1 , 1.0f);
+static _EIGEN_DECLARE_CONST_Packet4f(half, 0.5f);
+static _EIGEN_DECLARE_CONST_Packet4i(0x7f, 0x7f);
+static _EIGEN_DECLARE_CONST_Packet4i(23, 23);
+
+static _EIGEN_DECLARE_CONST_Packet4f_FROM_INT(inv_mant_mask, ~0x7f800000);
+
+/* the smallest non denormalized float number */
+static _EIGEN_DECLARE_CONST_Packet4f_FROM_INT(min_norm_pos, 0x00800000);
+static _EIGEN_DECLARE_CONST_Packet4f_FROM_INT(minus_inf, 0xff800000); // -1.f/0.f
+static _EIGEN_DECLARE_CONST_Packet4f_FROM_INT(minus_nan, 0xffffffff);
+
+/* natural logarithm computed for 4 simultaneous float
+ return NaN for x <= 0
+*/
+static _EIGEN_DECLARE_CONST_Packet4f(cephes_SQRTHF, 0.707106781186547524f);
+static _EIGEN_DECLARE_CONST_Packet4f(cephes_log_p0, 7.0376836292E-2f);
+static _EIGEN_DECLARE_CONST_Packet4f(cephes_log_p1, - 1.1514610310E-1f);
+static _EIGEN_DECLARE_CONST_Packet4f(cephes_log_p2, 1.1676998740E-1f);
+static _EIGEN_DECLARE_CONST_Packet4f(cephes_log_p3, - 1.2420140846E-1f);
+static _EIGEN_DECLARE_CONST_Packet4f(cephes_log_p4, + 1.4249322787E-1f);
+static _EIGEN_DECLARE_CONST_Packet4f(cephes_log_p5, - 1.6668057665E-1f);
+static _EIGEN_DECLARE_CONST_Packet4f(cephes_log_p6, + 2.0000714765E-1f);
+static _EIGEN_DECLARE_CONST_Packet4f(cephes_log_p7, - 2.4999993993E-1f);
+static _EIGEN_DECLARE_CONST_Packet4f(cephes_log_p8, + 3.3333331174E-1f);
+static _EIGEN_DECLARE_CONST_Packet4f(cephes_log_q1, -2.12194440e-4f);
+static _EIGEN_DECLARE_CONST_Packet4f(cephes_log_q2, 0.693359375f);
+
+static _EIGEN_DECLARE_CONST_Packet4f(exp_hi, 88.3762626647950f);
+static _EIGEN_DECLARE_CONST_Packet4f(exp_lo, -88.3762626647949f);
+
+static _EIGEN_DECLARE_CONST_Packet4f(cephes_LOG2EF, 1.44269504088896341f);
+static _EIGEN_DECLARE_CONST_Packet4f(cephes_exp_C1, 0.693359375f);
+static _EIGEN_DECLARE_CONST_Packet4f(cephes_exp_C2, -2.12194440e-4f);
+
+static _EIGEN_DECLARE_CONST_Packet4f(cephes_exp_p0, 1.9875691500E-4f);
+static _EIGEN_DECLARE_CONST_Packet4f(cephes_exp_p1, 1.3981999507E-3f);
+static _EIGEN_DECLARE_CONST_Packet4f(cephes_exp_p2, 8.3334519073E-3f);
+static _EIGEN_DECLARE_CONST_Packet4f(cephes_exp_p3, 4.1665795894E-2f);
+static _EIGEN_DECLARE_CONST_Packet4f(cephes_exp_p4, 1.6666665459E-1f);
+static _EIGEN_DECLARE_CONST_Packet4f(cephes_exp_p5, 5.0000001201E-1f);
+#endif
+
static _EIGEN_DECLARE_CONST_Packet2d(1 , 1.0);
static _EIGEN_DECLARE_CONST_Packet2d(2 , 2.0);
static _EIGEN_DECLARE_CONST_Packet2d(half, 0.5);
@@ -93,43 +137,95 @@ Packet2d pexp<Packet2d>(const Packet2d& _x)
}
template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
-Packet4f pexp<Packet4f>(const Packet4f& x)
+Packet4f pexp<Packet4f>(const Packet4f& _x)
{
+#if !defined(__ARCH__) || (defined(__ARCH__) && __ARCH__ >= 12)
+ Packet4f x = _x;
+
+ Packet4f tmp, fx;
+ Packet4i emm0;
+
+ // clamp x
+ x = pmax(pmin(x, p4f_exp_hi), p4f_exp_lo);
+
+ // express exp(x) as exp(g + n*log(2))
+ fx = pmadd(x, p4f_cephes_LOG2EF, p4f_half);
+
+ fx = pfloor(fx);
+
+ tmp = pmul(fx, p4f_cephes_exp_C1);
+ Packet4f z = pmul(fx, p4f_cephes_exp_C2);
+ x = psub(x, tmp);
+ x = psub(x, z);
+
+ z = pmul(x,x);
+
+ Packet4f y = p4f_cephes_exp_p0;
+ y = pmadd(y, x, p4f_cephes_exp_p1);
+ y = pmadd(y, x, p4f_cephes_exp_p2);
+ y = pmadd(y, x, p4f_cephes_exp_p3);
+ y = pmadd(y, x, p4f_cephes_exp_p4);
+ y = pmadd(y, x, p4f_cephes_exp_p5);
+ y = pmadd(y, z, x);
+ y = padd(y, p4f_1);
+
+ // build 2^n
+ emm0 = (Packet4i){ (int)fx[0], (int)fx[1], (int)fx[2], (int)fx[3] };
+ emm0 = emm0 + p4i_0x7f;
+ emm0 = emm0 << reinterpret_cast<Packet4i>(p4i_23);
+
+ return pmax(pmul(y, reinterpret_cast<Packet4f>(emm0)), _x);
+#else
Packet4f res;
- res.v4f[0] = pexp<Packet2d>(x.v4f[0]);
- res.v4f[1] = pexp<Packet2d>(x.v4f[1]);
+ res.v4f[0] = pexp<Packet2d>(_x.v4f[0]);
+ res.v4f[1] = pexp<Packet2d>(_x.v4f[1]);
return res;
+#endif
}
template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
Packet2d psqrt<Packet2d>(const Packet2d& x)
{
- return __builtin_s390_vfsqdb(x);
+ return vec_sqrt(x);
}
template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
Packet4f psqrt<Packet4f>(const Packet4f& x)
{
Packet4f res;
+#if !defined(__ARCH__) || (defined(__ARCH__) && __ARCH__ >= 12)
+ res = vec_sqrt(x);
+#else
res.v4f[0] = psqrt<Packet2d>(x.v4f[0]);
res.v4f[1] = psqrt<Packet2d>(x.v4f[1]);
+#endif
return res;
}
template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
Packet2d prsqrt<Packet2d>(const Packet2d& x) {
- // Unfortunately we can't use the much faster mm_rqsrt_pd since it only provides an approximation.
return pset1<Packet2d>(1.0) / psqrt<Packet2d>(x);
}
template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
Packet4f prsqrt<Packet4f>(const Packet4f& x) {
Packet4f res;
+#if !defined(__ARCH__) || (defined(__ARCH__) && __ARCH__ >= 12)
+ res = pset1<Packet4f>(1.0) / psqrt<Packet4f>(x);
+#else
res.v4f[0] = prsqrt<Packet2d>(x.v4f[0]);
res.v4f[1] = prsqrt<Packet2d>(x.v4f[1]);
+#endif
return res;
}
+// Hyperbolic Tangent function.
+template <>
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet4f
+ptanh<Packet4f>(const Packet4f& x) {
+ return internal::generic_fast_tanh_float(x);
+}
+
} // end namespace internal
} // end namespace Eigen
diff --git a/Eigen/src/Core/arch/ZVector/PacketMath.h b/Eigen/src/Core/arch/ZVector/PacketMath.h
index 57b01fc63..1f55a90a5 100755
--- a/Eigen/src/Core/arch/ZVector/PacketMath.h
+++ b/Eigen/src/Core/arch/ZVector/PacketMath.h
@@ -10,26 +10,20 @@
#ifndef EIGEN_PACKET_MATH_ZVECTOR_H
#define EIGEN_PACKET_MATH_ZVECTOR_H
-#include <stdint.h>
-
namespace Eigen {
namespace internal {
#ifndef EIGEN_CACHEFRIENDLY_PRODUCT_THRESHOLD
-#define EIGEN_CACHEFRIENDLY_PRODUCT_THRESHOLD 4
+#define EIGEN_CACHEFRIENDLY_PRODUCT_THRESHOLD 16
#endif
#ifndef EIGEN_HAS_SINGLE_INSTRUCTION_MADD
#define EIGEN_HAS_SINGLE_INSTRUCTION_MADD
#endif
-#ifndef EIGEN_HAS_SINGLE_INSTRUCTION_CJMADD
-#define EIGEN_HAS_SINGLE_INSTRUCTION_CJMADD
-#endif
-
#ifndef EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS
-#define EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS 16
+#define EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS 32
#endif
typedef __vector int Packet4i;
@@ -41,21 +35,30 @@ typedef __vector double Packet2d;
typedef __vector unsigned long long Packet2ul;
typedef __vector long long Packet2l;
+// Z14 has builtin support for float vectors
+#if !defined(__ARCH__) || (defined(__ARCH__) && __ARCH__ >= 12)
+typedef __vector float Packet4f;
+#else
typedef struct {
Packet2d v4f[2];
} Packet4f;
+#endif
typedef union {
- int32_t i[4];
- uint32_t ui[4];
- int64_t l[2];
- uint64_t ul[2];
+ numext::int32_t i[4];
+ numext::uint32_t ui[4];
+ numext::int64_t l[2];
+ numext::uint64_t ul[2];
double d[2];
+ float f[4];
Packet4i v4i;
Packet4ui v4ui;
Packet2l v2l;
Packet2ul v2ul;
Packet2d v2d;
+#if !defined(__ARCH__) || (defined(__ARCH__) && __ARCH__ >= 12)
+ Packet4f v4f;
+#endif
} Packet;
// We don't want to write the same code all the time, but we need to reuse the constants
@@ -80,15 +83,31 @@ typedef union {
Packet2l p2l_##NAME = pset1<Packet2l>(X)
// These constants are endian-agnostic
-//static _EIGEN_DECLARE_CONST_FAST_Packet4i(ZERO, 0); //{ 0, 0, 0, 0,}
+static _EIGEN_DECLARE_CONST_FAST_Packet4i(ZERO, 0); //{ 0, 0, 0, 0,}
static _EIGEN_DECLARE_CONST_FAST_Packet4i(ONE, 1); //{ 1, 1, 1, 1}
static _EIGEN_DECLARE_CONST_FAST_Packet2d(ZERO, 0);
static _EIGEN_DECLARE_CONST_FAST_Packet2l(ZERO, 0);
static _EIGEN_DECLARE_CONST_FAST_Packet2l(ONE, 1);
-static Packet2d p2d_ONE = { 1.0, 1.0 };
-static Packet2d p2d_ZERO_ = { -0.0, -0.0 };
+static Packet2d p2d_ONE = { 1.0, 1.0 };
+static Packet2d p2d_ZERO_ = { numext::bit_cast<double>0x8000000000000000ull),
+ numext::bit_cast<double>0x8000000000000000ull) };
+
+#if !defined(__ARCH__) || (defined(__ARCH__) && __ARCH__ >= 12)
+#define _EIGEN_DECLARE_CONST_FAST_Packet4f(NAME,X) \
+ Packet4f p4f_##NAME = reinterpret_cast<Packet4f>(vec_splat_s32(X))
+
+#define _EIGEN_DECLARE_CONST_Packet4f(NAME,X) \
+ Packet4f p4f_##NAME = pset1<Packet4f>(X)
+
+#define _EIGEN_DECLARE_CONST_Packet4f_FROM_INT(NAME,X) \
+ const Packet4f p4f_##NAME = reinterpret_cast<Packet4f>(pset1<Packet4i>(X))
+
+static _EIGEN_DECLARE_CONST_FAST_Packet4f(ZERO, 0); //{ 0.0, 0.0, 0.0, 0.0}
+static _EIGEN_DECLARE_CONST_FAST_Packet4i(MINUS1,-1); //{ -1, -1, -1, -1}
+static Packet4f p4f_MZERO = { 0x80000000, 0x80000000, 0x80000000, 0x80000000};
+#endif
static Packet4i p4i_COUNTDOWN = { 0, 1, 2, 3 };
static Packet4f p4f_COUNTDOWN = { 0.0, 1.0, 2.0, 3.0 };
@@ -120,9 +139,9 @@ static Packet16uc p16uc_TRANSPOSE64_LO = vec_add(p16uc_PSET64_LO, p16uc_HALF64_0
static Packet16uc p16uc_TRANSPOSE64_HI = { 0,1,2,3, 4,5,6,7, 16,17,18,19, 20,21,22,23};
static Packet16uc p16uc_TRANSPOSE64_LO = { 8,9,10,11, 12,13,14,15, 24,25,26,27, 28,29,30,31};
-//static Packet16uc p16uc_COMPLEX32_REV = vec_sld(p16uc_REVERSE32, p16uc_REVERSE32, 8); //{ 4,5,6,7, 0,1,2,3, 12,13,14,15, 8,9,10,11 };
+static Packet16uc p16uc_COMPLEX32_REV = vec_sld(p16uc_REVERSE32, p16uc_REVERSE32, 8); //{ 4,5,6,7, 0,1,2,3, 12,13,14,15, 8,9,10,11 };
-//static Packet16uc p16uc_COMPLEX32_REV2 = vec_sld(p16uc_FORWARD, p16uc_FORWARD, 8); //{ 8,9,10,11, 12,13,14,15, 0,1,2,3, 4,5,6,7 };
+static Packet16uc p16uc_COMPLEX32_REV2 = vec_sld(p16uc_FORWARD, p16uc_FORWARD, 8); //{ 8,9,10,11, 12,13,14,15, 0,1,2,3, 4,5,6,7 };
#if EIGEN_HAS_BUILTIN(__builtin_prefetch) || EIGEN_COMP_GNUC
@@ -149,29 +168,31 @@ template<> struct packet_traits<int> : default_packet_traits
};
};
-template<> struct packet_traits<float> : default_packet_traits
-{
+template <>
+struct packet_traits<float> : default_packet_traits {
typedef Packet4f type;
typedef Packet4f half;
enum {
Vectorizable = 1,
AlignedOnScalar = 1,
- size=4,
+ size = 4,
HasHalfPacket = 0,
- HasAdd = 1,
- HasSub = 1,
- HasMul = 1,
- HasDiv = 1,
- HasMin = 1,
- HasMax = 1,
- HasAbs = 1,
- HasSin = 0,
- HasCos = 0,
- HasLog = 0,
- HasExp = 1,
+ HasAdd = 1,
+ HasSub = 1,
+ HasMul = 1,
+ HasDiv = 1,
+ HasMin = 1,
+ HasMax = 1,
+ HasAbs = 1,
+ HasSin = 0,
+ HasCos = 0,
+ HasLog = 0,
+ HasExp = 1,
HasSqrt = 1,
HasRsqrt = 1,
+ HasTanh = 1,
+ HasErf = 1,
HasRound = 1,
HasFloor = 1,
HasCeil = 1,
@@ -211,9 +232,9 @@ template<> struct packet_traits<double> : default_packet_traits
};
};
-template<> struct unpacket_traits<Packet4i> { typedef int type; enum {size=4, alignment=Aligned16}; typedef Packet4i half; };
-template<> struct unpacket_traits<Packet4f> { typedef float type; enum {size=4, alignment=Aligned16}; typedef Packet4f half; };
-template<> struct unpacket_traits<Packet2d> { typedef double type; enum {size=2, alignment=Aligned16}; typedef Packet2d half; };
+template<> struct unpacket_traits<Packet4i> { typedef int type; enum {size=4, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false}; typedef Packet4i half; };
+template<> struct unpacket_traits<Packet4f> { typedef float type; enum {size=4, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false}; typedef Packet4f half; };
+template<> struct unpacket_traits<Packet2d> { typedef double type; enum {size=2, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false}; typedef Packet2d half; };
/* Forward declaration */
EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet4f,4>& kernel);
@@ -258,82 +279,15 @@ inline std::ostream & operator <<(std::ostream & s, const Packet2d & v)
return s;
}
-/* Helper function to simulate a vec_splat_packet4f
- */
-template<int element> EIGEN_STRONG_INLINE Packet4f vec_splat_packet4f(const Packet4f& from)
+#if !defined(__ARCH__) || (defined(__ARCH__) && __ARCH__ >= 12)
+inline std::ostream & operator <<(std::ostream & s, const Packet4f & v)
{
- Packet4f splat;
- switch (element) {
- case 0:
- splat.v4f[0] = vec_splat(from.v4f[0], 0);
- splat.v4f[1] = splat.v4f[0];
- break;
- case 1:
- splat.v4f[0] = vec_splat(from.v4f[0], 1);
- splat.v4f[1] = splat.v4f[0];
- break;
- case 2:
- splat.v4f[0] = vec_splat(from.v4f[1], 0);
- splat.v4f[1] = splat.v4f[0];
- break;
- case 3:
- splat.v4f[0] = vec_splat(from.v4f[1], 1);
- splat.v4f[1] = splat.v4f[0];
- break;
- }
- return splat;
+ Packet vt;
+ vt.v4f = v;
+ s << vt.f[0] << ", " << vt.f[1] << ", " << vt.f[2] << ", " << vt.f[3];
+ return s;
}
-
-template<int Offset>
-struct palign_impl<Offset,Packet4i>
-{
- static EIGEN_STRONG_INLINE void run(Packet4i& first, const Packet4i& second)
- {
- switch (Offset % 4) {
- case 1:
- first = vec_sld(first, second, 4); break;
- case 2:
- first = vec_sld(first, second, 8); break;
- case 3:
- first = vec_sld(first, second, 12); break;
- }
- }
-};
-
-/* This is a tricky one, we have to translate float alignment to vector elements of sizeof double
- */
-template<int Offset>
-struct palign_impl<Offset,Packet4f>
-{
- static EIGEN_STRONG_INLINE void run(Packet4f& first, const Packet4f& second)
- {
- switch (Offset % 4) {
- case 1:
- first.v4f[0] = vec_sld(first.v4f[0], first.v4f[1], 8);
- first.v4f[1] = vec_sld(first.v4f[1], second.v4f[0], 8);
- break;
- case 2:
- first.v4f[0] = first.v4f[1];
- first.v4f[1] = second.v4f[0];
- break;
- case 3:
- first.v4f[0] = vec_sld(first.v4f[1], second.v4f[0], 8);
- first.v4f[1] = vec_sld(second.v4f[0], second.v4f[1], 8);
- break;
- }
- }
-};
-
-
-template<int Offset>
-struct palign_impl<Offset,Packet2d>
-{
- static EIGEN_STRONG_INLINE void run(Packet2d& first, const Packet2d& second)
- {
- if (Offset == 1)
- first = reinterpret_cast<Packet2d>(vec_sld(reinterpret_cast<Packet4i>(first), reinterpret_cast<Packet4i>(second), 8));
- }
-};
+#endif
template<> EIGEN_STRONG_INLINE Packet4i pload<Packet4i>(const int* from)
{
@@ -344,16 +298,6 @@ template<> EIGEN_STRONG_INLINE Packet4i pload<Packet4i>(const int* from)
return vfrom->v4i;
}
-template<> EIGEN_STRONG_INLINE Packet4f pload<Packet4f>(const float* from)
-{
- // FIXME: No intrinsic yet
- EIGEN_DEBUG_ALIGNED_LOAD
- Packet4f vfrom;
- vfrom.v4f[0] = vec_ld2f(&from[0]);
- vfrom.v4f[1] = vec_ld2f(&from[2]);
- return vfrom;
-}
-
template<> EIGEN_STRONG_INLINE Packet2d pload<Packet2d>(const double* from)
{
// FIXME: No intrinsic yet
@@ -372,15 +316,6 @@ template<> EIGEN_STRONG_INLINE void pstore<int>(int* to, const Packet4i& f
vto->v4i = from;
}
-template<> EIGEN_STRONG_INLINE void pstore<float>(float* to, const Packet4f& from)
-{
- // FIXME: No intrinsic yet
- EIGEN_DEBUG_ALIGNED_STORE
- vec_st2f(from.v4f[0], &to[0]);
- vec_st2f(from.v4f[1], &to[2]);
-}
-
-
template<> EIGEN_STRONG_INLINE void pstore<double>(double* to, const Packet2d& from)
{
// FIXME: No intrinsic yet
@@ -397,13 +332,6 @@ template<> EIGEN_STRONG_INLINE Packet4i pset1<Packet4i>(const int& from)
template<> EIGEN_STRONG_INLINE Packet2d pset1<Packet2d>(const double& from) {
return vec_splats(from);
}
-template<> EIGEN_STRONG_INLINE Packet4f pset1<Packet4f>(const float& from)
-{
- Packet4f to;
- to.v4f[0] = pset1<Packet2d>(static_cast<const double&>(from));
- to.v4f[1] = to.v4f[0];
- return to;
-}
template<> EIGEN_STRONG_INLINE void
pbroadcast4<Packet4i>(const int *a,
@@ -417,17 +345,6 @@ pbroadcast4<Packet4i>(const int *a,
}
template<> EIGEN_STRONG_INLINE void
-pbroadcast4<Packet4f>(const float *a,
- Packet4f& a0, Packet4f& a1, Packet4f& a2, Packet4f& a3)
-{
- a3 = pload<Packet4f>(a);
- a0 = vec_splat_packet4f<0>(a3);
- a1 = vec_splat_packet4f<1>(a3);
- a2 = vec_splat_packet4f<2>(a3);
- a3 = vec_splat_packet4f<3>(a3);
-}
-
-template<> EIGEN_STRONG_INLINE void
pbroadcast4<Packet2d>(const double *a,
Packet2d& a0, Packet2d& a1, Packet2d& a2, Packet2d& a3)
{
@@ -449,16 +366,6 @@ template<> EIGEN_DEVICE_FUNC inline Packet4i pgather<int, Packet4i>(const int* f
return pload<Packet4i>(ai);
}
-template<> EIGEN_DEVICE_FUNC inline Packet4f pgather<float, Packet4f>(const float* from, Index stride)
-{
- float EIGEN_ALIGN16 ai[4];
- ai[0] = from[0*stride];
- ai[1] = from[1*stride];
- ai[2] = from[2*stride];
- ai[3] = from[3*stride];
- return pload<Packet4f>(ai);
-}
-
template<> EIGEN_DEVICE_FUNC inline Packet2d pgather<double, Packet2d>(const double* from, Index stride)
{
double EIGEN_ALIGN16 af[2];
@@ -477,6 +384,269 @@ template<> EIGEN_DEVICE_FUNC inline void pscatter<int, Packet4i>(int* to, const
to[3*stride] = ai[3];
}
+template<> EIGEN_DEVICE_FUNC inline void pscatter<double, Packet2d>(double* to, const Packet2d& from, Index stride)
+{
+ double EIGEN_ALIGN16 af[2];
+ pstore<double>(af, from);
+ to[0*stride] = af[0];
+ to[1*stride] = af[1];
+}
+
+template<> EIGEN_STRONG_INLINE Packet4i padd<Packet4i>(const Packet4i& a, const Packet4i& b) { return (a + b); }
+template<> EIGEN_STRONG_INLINE Packet2d padd<Packet2d>(const Packet2d& a, const Packet2d& b) { return (a + b); }
+
+template<> EIGEN_STRONG_INLINE Packet4i psub<Packet4i>(const Packet4i& a, const Packet4i& b) { return (a - b); }
+template<> EIGEN_STRONG_INLINE Packet2d psub<Packet2d>(const Packet2d& a, const Packet2d& b) { return (a - b); }
+
+template<> EIGEN_STRONG_INLINE Packet4i pmul<Packet4i>(const Packet4i& a, const Packet4i& b) { return (a * b); }
+template<> EIGEN_STRONG_INLINE Packet2d pmul<Packet2d>(const Packet2d& a, const Packet2d& b) { return (a * b); }
+
+template<> EIGEN_STRONG_INLINE Packet4i pdiv<Packet4i>(const Packet4i& a, const Packet4i& b) { return (a / b); }
+template<> EIGEN_STRONG_INLINE Packet2d pdiv<Packet2d>(const Packet2d& a, const Packet2d& b) { return (a / b); }
+
+template<> EIGEN_STRONG_INLINE Packet4i pnegate(const Packet4i& a) { return (-a); }
+template<> EIGEN_STRONG_INLINE Packet2d pnegate(const Packet2d& a) { return (-a); }
+
+template<> EIGEN_STRONG_INLINE Packet4i pconj(const Packet4i& a) { return a; }
+template<> EIGEN_STRONG_INLINE Packet2d pconj(const Packet2d& a) { return a; }
+
+template<> EIGEN_STRONG_INLINE Packet4i pmadd(const Packet4i& a, const Packet4i& b, const Packet4i& c) { return padd<Packet4i>(pmul<Packet4i>(a, b), c); }
+template<> EIGEN_STRONG_INLINE Packet2d pmadd(const Packet2d& a, const Packet2d& b, const Packet2d& c) { return vec_madd(a, b, c); }
+
+template<> EIGEN_STRONG_INLINE Packet4i plset<Packet4i>(const int& a) { return padd<Packet4i>(pset1<Packet4i>(a), p4i_COUNTDOWN); }
+template<> EIGEN_STRONG_INLINE Packet2d plset<Packet2d>(const double& a) { return padd<Packet2d>(pset1<Packet2d>(a), p2d_COUNTDOWN); }
+
+template<> EIGEN_STRONG_INLINE Packet4i pmin<Packet4i>(const Packet4i& a, const Packet4i& b) { return vec_min(a, b); }
+template<> EIGEN_STRONG_INLINE Packet2d pmin<Packet2d>(const Packet2d& a, const Packet2d& b) { return vec_min(a, b); }
+
+template<> EIGEN_STRONG_INLINE Packet4i pmax<Packet4i>(const Packet4i& a, const Packet4i& b) { return vec_max(a, b); }
+template<> EIGEN_STRONG_INLINE Packet2d pmax<Packet2d>(const Packet2d& a, const Packet2d& b) { return vec_max(a, b); }
+
+template<> EIGEN_STRONG_INLINE Packet4i pand<Packet4i>(const Packet4i& a, const Packet4i& b) { return vec_and(a, b); }
+template<> EIGEN_STRONG_INLINE Packet2d pand<Packet2d>(const Packet2d& a, const Packet2d& b) { return vec_and(a, b); }
+
+template<> EIGEN_STRONG_INLINE Packet4i por<Packet4i>(const Packet4i& a, const Packet4i& b) { return vec_or(a, b); }
+template<> EIGEN_STRONG_INLINE Packet2d por<Packet2d>(const Packet2d& a, const Packet2d& b) { return vec_or(a, b); }
+
+template<> EIGEN_STRONG_INLINE Packet4i pxor<Packet4i>(const Packet4i& a, const Packet4i& b) { return vec_xor(a, b); }
+template<> EIGEN_STRONG_INLINE Packet2d pxor<Packet2d>(const Packet2d& a, const Packet2d& b) { return vec_xor(a, b); }
+
+template<> EIGEN_STRONG_INLINE Packet4i pandnot<Packet4i>(const Packet4i& a, const Packet4i& b) { return pand<Packet4i>(a, vec_nor(b, b)); }
+template<> EIGEN_STRONG_INLINE Packet2d pandnot<Packet2d>(const Packet2d& a, const Packet2d& b) { return vec_and(a, vec_nor(b, b)); }
+
+template<> EIGEN_STRONG_INLINE Packet2d pround<Packet2d>(const Packet2d& a) { return vec_round(a); }
+template<> EIGEN_STRONG_INLINE Packet2d pceil<Packet2d>(const Packet2d& a) { return vec_ceil(a); }
+template<> EIGEN_STRONG_INLINE Packet2d pfloor<Packet2d>(const Packet2d& a) { return vec_floor(a); }
+
+template<> EIGEN_STRONG_INLINE Packet4i ploadu<Packet4i>(const int* from) { return pload<Packet4i>(from); }
+template<> EIGEN_STRONG_INLINE Packet2d ploadu<Packet2d>(const double* from) { return pload<Packet2d>(from); }
+
+
+template<> EIGEN_STRONG_INLINE Packet4i ploaddup<Packet4i>(const int* from)
+{
+ Packet4i p = pload<Packet4i>(from);
+ return vec_perm(p, p, p16uc_DUPLICATE32_HI);
+}
+
+template<> EIGEN_STRONG_INLINE Packet2d ploaddup<Packet2d>(const double* from)
+{
+ Packet2d p = pload<Packet2d>(from);
+ return vec_perm(p, p, p16uc_PSET64_HI);
+}
+
+template<> EIGEN_STRONG_INLINE void pstoreu<int>(int* to, const Packet4i& from) { pstore<int>(to, from); }
+template<> EIGEN_STRONG_INLINE void pstoreu<double>(double* to, const Packet2d& from) { pstore<double>(to, from); }
+
+template<> EIGEN_STRONG_INLINE void prefetch<int>(const int* addr) { EIGEN_ZVECTOR_PREFETCH(addr); }
+template<> EIGEN_STRONG_INLINE void prefetch<double>(const double* addr) { EIGEN_ZVECTOR_PREFETCH(addr); }
+
+template<> EIGEN_STRONG_INLINE int pfirst<Packet4i>(const Packet4i& a) { int EIGEN_ALIGN16 x[4]; pstore(x, a); return x[0]; }
+template<> EIGEN_STRONG_INLINE double pfirst<Packet2d>(const Packet2d& a) { double EIGEN_ALIGN16 x[2]; pstore(x, a); return x[0]; }
+
+template<> EIGEN_STRONG_INLINE Packet4i preverse(const Packet4i& a)
+{
+ return reinterpret_cast<Packet4i>(vec_perm(reinterpret_cast<Packet16uc>(a), reinterpret_cast<Packet16uc>(a), p16uc_REVERSE32));
+}
+
+template<> EIGEN_STRONG_INLINE Packet2d preverse(const Packet2d& a)
+{
+ return reinterpret_cast<Packet2d>(vec_perm(reinterpret_cast<Packet16uc>(a), reinterpret_cast<Packet16uc>(a), p16uc_REVERSE64));
+}
+
+template<> EIGEN_STRONG_INLINE Packet4i pabs<Packet4i>(const Packet4i& a) { return vec_abs(a); }
+template<> EIGEN_STRONG_INLINE Packet2d pabs<Packet2d>(const Packet2d& a) { return vec_abs(a); }
+
+template<> EIGEN_STRONG_INLINE int predux<Packet4i>(const Packet4i& a)
+{
+ Packet4i b, sum;
+ b = vec_sld(a, a, 8);
+ sum = padd<Packet4i>(a, b);
+ b = vec_sld(sum, sum, 4);
+ sum = padd<Packet4i>(sum, b);
+ return pfirst(sum);
+}
+
+template<> EIGEN_STRONG_INLINE double predux<Packet2d>(const Packet2d& a)
+{
+ Packet2d b, sum;
+ b = reinterpret_cast<Packet2d>(vec_sld(reinterpret_cast<Packet4i>(a), reinterpret_cast<Packet4i>(a), 8));
+ sum = padd<Packet2d>(a, b);
+ return pfirst(sum);
+}
+
+// Other reduction functions:
+// mul
+template<> EIGEN_STRONG_INLINE int predux_mul<Packet4i>(const Packet4i& a)
+{
+ EIGEN_ALIGN16 int aux[4];
+ pstore(aux, a);
+ return aux[0] * aux[1] * aux[2] * aux[3];
+}
+
+template<> EIGEN_STRONG_INLINE double predux_mul<Packet2d>(const Packet2d& a)
+{
+ return pfirst(pmul(a, reinterpret_cast<Packet2d>(vec_sld(reinterpret_cast<Packet4i>(a), reinterpret_cast<Packet4i>(a), 8))));
+}
+
+// min
+template<> EIGEN_STRONG_INLINE int predux_min<Packet4i>(const Packet4i& a)
+{
+ Packet4i b, res;
+ b = pmin<Packet4i>(a, vec_sld(a, a, 8));
+ res = pmin<Packet4i>(b, vec_sld(b, b, 4));
+ return pfirst(res);
+}
+
+template<> EIGEN_STRONG_INLINE double predux_min<Packet2d>(const Packet2d& a)
+{
+ return pfirst(pmin<Packet2d>(a, reinterpret_cast<Packet2d>(vec_sld(reinterpret_cast<Packet4i>(a), reinterpret_cast<Packet4i>(a), 8))));
+}
+
+// max
+template<> EIGEN_STRONG_INLINE int predux_max<Packet4i>(const Packet4i& a)
+{
+ Packet4i b, res;
+ b = pmax<Packet4i>(a, vec_sld(a, a, 8));
+ res = pmax<Packet4i>(b, vec_sld(b, b, 4));
+ return pfirst(res);
+}
+
+// max
+template<> EIGEN_STRONG_INLINE double predux_max<Packet2d>(const Packet2d& a)
+{
+ return pfirst(pmax<Packet2d>(a, reinterpret_cast<Packet2d>(vec_sld(reinterpret_cast<Packet4i>(a), reinterpret_cast<Packet4i>(a), 8))));
+}
+
+EIGEN_DEVICE_FUNC inline void
+ptranspose(PacketBlock<Packet4i,4>& kernel) {
+ Packet4i t0 = vec_mergeh(kernel.packet[0], kernel.packet[2]);
+ Packet4i t1 = vec_mergel(kernel.packet[0], kernel.packet[2]);
+ Packet4i t2 = vec_mergeh(kernel.packet[1], kernel.packet[3]);
+ Packet4i t3 = vec_mergel(kernel.packet[1], kernel.packet[3]);
+ kernel.packet[0] = vec_mergeh(t0, t2);
+ kernel.packet[1] = vec_mergel(t0, t2);
+ kernel.packet[2] = vec_mergeh(t1, t3);
+ kernel.packet[3] = vec_mergel(t1, t3);
+}
+
+EIGEN_DEVICE_FUNC inline void
+ptranspose(PacketBlock<Packet2d,2>& kernel) {
+ Packet2d t0 = vec_perm(kernel.packet[0], kernel.packet[1], p16uc_TRANSPOSE64_HI);
+ Packet2d t1 = vec_perm(kernel.packet[0], kernel.packet[1], p16uc_TRANSPOSE64_LO);
+ kernel.packet[0] = t0;
+ kernel.packet[1] = t1;
+}
+
+template<> EIGEN_STRONG_INLINE Packet4i pblend(const Selector<4>& ifPacket, const Packet4i& thenPacket, const Packet4i& elsePacket) {
+ Packet4ui select = { ifPacket.select[0], ifPacket.select[1], ifPacket.select[2], ifPacket.select[3] };
+ Packet4ui mask = vec_cmpeq(select, reinterpret_cast<Packet4ui>(p4i_ONE));
+ return vec_sel(elsePacket, thenPacket, mask);
+}
+
+
+template<> EIGEN_STRONG_INLINE Packet2d pblend(const Selector<2>& ifPacket, const Packet2d& thenPacket, const Packet2d& elsePacket) {
+ Packet2ul select = { ifPacket.select[0], ifPacket.select[1] };
+ Packet2ul mask = vec_cmpeq(select, reinterpret_cast<Packet2ul>(p2l_ONE));
+ return vec_sel(elsePacket, thenPacket, mask);
+}
+
+/* z13 has no vector float support so we emulate that with double
+ z14 has proper vector float support.
+*/
+#if !defined(__ARCH__) || (defined(__ARCH__) && __ARCH__ < 12)
+/* Helper function to simulate a vec_splat_packet4f
+ */
+template<int element> EIGEN_STRONG_INLINE Packet4f vec_splat_packet4f(const Packet4f& from)
+{
+ Packet4f splat;
+ switch (element) {
+ case 0:
+ splat.v4f[0] = vec_splat(from.v4f[0], 0);
+ splat.v4f[1] = splat.v4f[0];
+ break;
+ case 1:
+ splat.v4f[0] = vec_splat(from.v4f[0], 1);
+ splat.v4f[1] = splat.v4f[0];
+ break;
+ case 2:
+ splat.v4f[0] = vec_splat(from.v4f[1], 0);
+ splat.v4f[1] = splat.v4f[0];
+ break;
+ case 3:
+ splat.v4f[0] = vec_splat(from.v4f[1], 1);
+ splat.v4f[1] = splat.v4f[0];
+ break;
+ }
+ return splat;
+}
+
+template<> EIGEN_STRONG_INLINE Packet4f pload<Packet4f>(const float* from)
+{
+ // FIXME: No intrinsic yet
+ EIGEN_DEBUG_ALIGNED_LOAD
+ Packet4f vfrom;
+ vfrom.v4f[0] = vec_ld2f(&from[0]);
+ vfrom.v4f[1] = vec_ld2f(&from[2]);
+ return vfrom;
+}
+
+template<> EIGEN_STRONG_INLINE void pstore<float>(float* to, const Packet4f& from)
+{
+ // FIXME: No intrinsic yet
+ EIGEN_DEBUG_ALIGNED_STORE
+ vec_st2f(from.v4f[0], &to[0]);
+ vec_st2f(from.v4f[1], &to[2]);
+}
+
+template<> EIGEN_STRONG_INLINE Packet4f pset1<Packet4f>(const float& from)
+{
+ Packet4f to;
+ to.v4f[0] = pset1<Packet2d>(static_cast<const double&>(from));
+ to.v4f[1] = to.v4f[0];
+ return to;
+}
+
+template<> EIGEN_STRONG_INLINE void
+pbroadcast4<Packet4f>(const float *a,
+ Packet4f& a0, Packet4f& a1, Packet4f& a2, Packet4f& a3)
+{
+ a3 = pload<Packet4f>(a);
+ a0 = vec_splat_packet4f<0>(a3);
+ a1 = vec_splat_packet4f<1>(a3);
+ a2 = vec_splat_packet4f<2>(a3);
+ a3 = vec_splat_packet4f<3>(a3);
+}
+
+template<> EIGEN_DEVICE_FUNC inline Packet4f pgather<float, Packet4f>(const float* from, Index stride)
+{
+ float EIGEN_ALIGN16 ai[4];
+ ai[0] = from[0*stride];
+ ai[1] = from[1*stride];
+ ai[2] = from[2*stride];
+ ai[3] = from[3*stride];
+ return pload<Packet4f>(ai);
+}
+
template<> EIGEN_DEVICE_FUNC inline void pscatter<float, Packet4f>(float* to, const Packet4f& from, Index stride)
{
float EIGEN_ALIGN16 ai[4];
@@ -487,15 +657,6 @@ template<> EIGEN_DEVICE_FUNC inline void pscatter<float, Packet4f>(float* to, co
to[3*stride] = ai[3];
}
-template<> EIGEN_DEVICE_FUNC inline void pscatter<double, Packet2d>(double* to, const Packet2d& from, Index stride)
-{
- double EIGEN_ALIGN16 af[2];
- pstore<double>(af, from);
- to[0*stride] = af[0];
- to[1*stride] = af[1];
-}
-
-template<> EIGEN_STRONG_INLINE Packet4i padd<Packet4i>(const Packet4i& a, const Packet4i& b) { return (a + b); }
template<> EIGEN_STRONG_INLINE Packet4f padd<Packet4f>(const Packet4f& a, const Packet4f& b)
{
Packet4f c;
@@ -503,9 +664,7 @@ template<> EIGEN_STRONG_INLINE Packet4f padd<Packet4f>(const Packet4f& a, const
c.v4f[1] = a.v4f[1] + b.v4f[1];
return c;
}
-template<> EIGEN_STRONG_INLINE Packet2d padd<Packet2d>(const Packet2d& a, const Packet2d& b) { return (a + b); }
-template<> EIGEN_STRONG_INLINE Packet4i psub<Packet4i>(const Packet4i& a, const Packet4i& b) { return (a - b); }
template<> EIGEN_STRONG_INLINE Packet4f psub<Packet4f>(const Packet4f& a, const Packet4f& b)
{
Packet4f c;
@@ -513,9 +672,7 @@ template<> EIGEN_STRONG_INLINE Packet4f psub<Packet4f>(const Packet4f& a, const
c.v4f[1] = a.v4f[1] - b.v4f[1];
return c;
}
-template<> EIGEN_STRONG_INLINE Packet2d psub<Packet2d>(const Packet2d& a, const Packet2d& b) { return (a - b); }
-template<> EIGEN_STRONG_INLINE Packet4i pmul<Packet4i>(const Packet4i& a, const Packet4i& b) { return (a * b); }
template<> EIGEN_STRONG_INLINE Packet4f pmul<Packet4f>(const Packet4f& a, const Packet4f& b)
{
Packet4f c;
@@ -523,9 +680,7 @@ template<> EIGEN_STRONG_INLINE Packet4f pmul<Packet4f>(const Packet4f& a, const
c.v4f[1] = a.v4f[1] * b.v4f[1];
return c;
}
-template<> EIGEN_STRONG_INLINE Packet2d pmul<Packet2d>(const Packet2d& a, const Packet2d& b) { return (a * b); }
-template<> EIGEN_STRONG_INLINE Packet4i pdiv<Packet4i>(const Packet4i& a, const Packet4i& b) { return (a / b); }
template<> EIGEN_STRONG_INLINE Packet4f pdiv<Packet4f>(const Packet4f& a, const Packet4f& b)
{
Packet4f c;
@@ -533,9 +688,7 @@ template<> EIGEN_STRONG_INLINE Packet4f pdiv<Packet4f>(const Packet4f& a, const
c.v4f[1] = a.v4f[1] / b.v4f[1];
return c;
}
-template<> EIGEN_STRONG_INLINE Packet2d pdiv<Packet2d>(const Packet2d& a, const Packet2d& b) { return (a / b); }
-template<> EIGEN_STRONG_INLINE Packet4i pnegate(const Packet4i& a) { return (-a); }
template<> EIGEN_STRONG_INLINE Packet4f pnegate(const Packet4f& a)
{
Packet4f c;
@@ -543,13 +696,7 @@ template<> EIGEN_STRONG_INLINE Packet4f pnegate(const Packet4f& a)
c.v4f[1] = -a.v4f[1];
return c;
}
-template<> EIGEN_STRONG_INLINE Packet2d pnegate(const Packet2d& a) { return (-a); }
-
-template<> EIGEN_STRONG_INLINE Packet4i pconj(const Packet4i& a) { return a; }
-template<> EIGEN_STRONG_INLINE Packet4f pconj(const Packet4f& a) { return a; }
-template<> EIGEN_STRONG_INLINE Packet2d pconj(const Packet2d& a) { return a; }
-template<> EIGEN_STRONG_INLINE Packet4i pmadd(const Packet4i& a, const Packet4i& b, const Packet4i& c) { return padd<Packet4i>(pmul<Packet4i>(a, b), c); }
template<> EIGEN_STRONG_INLINE Packet4f pmadd(const Packet4f& a, const Packet4f& b, const Packet4f& c)
{
Packet4f res;
@@ -557,14 +704,7 @@ template<> EIGEN_STRONG_INLINE Packet4f pmadd(const Packet4f& a, const Packet4f&
res.v4f[1] = vec_madd(a.v4f[1], b.v4f[1], c.v4f[1]);
return res;
}
-template<> EIGEN_STRONG_INLINE Packet2d pmadd(const Packet2d& a, const Packet2d& b, const Packet2d& c) { return vec_madd(a, b, c); }
-template<> EIGEN_STRONG_INLINE Packet4i plset<Packet4i>(const int& a) { return padd<Packet4i>(pset1<Packet4i>(a), p4i_COUNTDOWN); }
-template<> EIGEN_STRONG_INLINE Packet4f plset<Packet4f>(const float& a) { return padd<Packet4f>(pset1<Packet4f>(a), p4f_COUNTDOWN); }
-template<> EIGEN_STRONG_INLINE Packet2d plset<Packet2d>(const double& a) { return padd<Packet2d>(pset1<Packet2d>(a), p2d_COUNTDOWN); }
-
-template<> EIGEN_STRONG_INLINE Packet4i pmin<Packet4i>(const Packet4i& a, const Packet4i& b) { return vec_min(a, b); }
-template<> EIGEN_STRONG_INLINE Packet2d pmin<Packet2d>(const Packet2d& a, const Packet2d& b) { return vec_min(a, b); }
template<> EIGEN_STRONG_INLINE Packet4f pmin<Packet4f>(const Packet4f& a, const Packet4f& b)
{
Packet4f res;
@@ -573,8 +713,6 @@ template<> EIGEN_STRONG_INLINE Packet4f pmin<Packet4f>(const Packet4f& a, const
return res;
}
-template<> EIGEN_STRONG_INLINE Packet4i pmax<Packet4i>(const Packet4i& a, const Packet4i& b) { return vec_max(a, b); }
-template<> EIGEN_STRONG_INLINE Packet2d pmax<Packet2d>(const Packet2d& a, const Packet2d& b) { return vec_max(a, b); }
template<> EIGEN_STRONG_INLINE Packet4f pmax<Packet4f>(const Packet4f& a, const Packet4f& b)
{
Packet4f res;
@@ -583,8 +721,6 @@ template<> EIGEN_STRONG_INLINE Packet4f pmax<Packet4f>(const Packet4f& a, const
return res;
}
-template<> EIGEN_STRONG_INLINE Packet4i pand<Packet4i>(const Packet4i& a, const Packet4i& b) { return vec_and(a, b); }
-template<> EIGEN_STRONG_INLINE Packet2d pand<Packet2d>(const Packet2d& a, const Packet2d& b) { return vec_and(a, b); }
template<> EIGEN_STRONG_INLINE Packet4f pand<Packet4f>(const Packet4f& a, const Packet4f& b)
{
Packet4f res;
@@ -593,28 +729,22 @@ template<> EIGEN_STRONG_INLINE Packet4f pand<Packet4f>(const Packet4f& a, const
return res;
}
-template<> EIGEN_STRONG_INLINE Packet4i por<Packet4i>(const Packet4i& a, const Packet4i& b) { return vec_or(a, b); }
-template<> EIGEN_STRONG_INLINE Packet2d por<Packet2d>(const Packet2d& a, const Packet2d& b) { return vec_or(a, b); }
template<> EIGEN_STRONG_INLINE Packet4f por<Packet4f>(const Packet4f& a, const Packet4f& b)
{
Packet4f res;
- res.v4f[0] = pand(a.v4f[0], b.v4f[0]);
- res.v4f[1] = pand(a.v4f[1], b.v4f[1]);
+ res.v4f[0] = por(a.v4f[0], b.v4f[0]);
+ res.v4f[1] = por(a.v4f[1], b.v4f[1]);
return res;
}
-template<> EIGEN_STRONG_INLINE Packet4i pxor<Packet4i>(const Packet4i& a, const Packet4i& b) { return vec_xor(a, b); }
-template<> EIGEN_STRONG_INLINE Packet2d pxor<Packet2d>(const Packet2d& a, const Packet2d& b) { return vec_xor(a, b); }
template<> EIGEN_STRONG_INLINE Packet4f pxor<Packet4f>(const Packet4f& a, const Packet4f& b)
{
Packet4f res;
- res.v4f[0] = pand(a.v4f[0], b.v4f[0]);
- res.v4f[1] = pand(a.v4f[1], b.v4f[1]);
+ res.v4f[0] = pxor(a.v4f[0], b.v4f[0]);
+ res.v4f[1] = pxor(a.v4f[1], b.v4f[1]);
return res;
}
-template<> EIGEN_STRONG_INLINE Packet4i pandnot<Packet4i>(const Packet4i& a, const Packet4i& b) { return pand<Packet4i>(a, vec_nor(b, b)); }
-template<> EIGEN_STRONG_INLINE Packet2d pandnot<Packet2d>(const Packet2d& a, const Packet2d& b) { return vec_and(a, vec_nor(b, b)); }
template<> EIGEN_STRONG_INLINE Packet4f pandnot<Packet4f>(const Packet4f& a, const Packet4f& b)
{
Packet4f res;
@@ -630,7 +760,7 @@ template<> EIGEN_STRONG_INLINE Packet4f pround<Packet4f>(const Packet4f& a)
res.v4f[1] = vec_round(a.v4f[1]);
return res;
}
-template<> EIGEN_STRONG_INLINE Packet2d pround<Packet2d>(const Packet2d& a) { return vec_round(a); }
+
template<> EIGEN_STRONG_INLINE Packet4f pceil<Packet4f>(const Packet4f& a)
{
Packet4f res;
@@ -638,7 +768,7 @@ template<> EIGEN_STRONG_INLINE Packet4f pceil<Packet4f>(const Packet4f& a)
res.v4f[1] = vec_ceil(a.v4f[1]);
return res;
}
-template<> EIGEN_STRONG_INLINE Packet2d pceil<Packet2d>(const Packet2d& a) { return vec_ceil(a); }
+
template<> EIGEN_STRONG_INLINE Packet4f pfloor<Packet4f>(const Packet4f& a)
{
Packet4f res;
@@ -646,18 +776,6 @@ template<> EIGEN_STRONG_INLINE Packet4f pfloor<Packet4f>(const Packet4f& a)
res.v4f[1] = vec_floor(a.v4f[1]);
return res;
}
-template<> EIGEN_STRONG_INLINE Packet2d pfloor<Packet2d>(const Packet2d& a) { return vec_floor(a); }
-
-template<> EIGEN_STRONG_INLINE Packet4i ploadu<Packet4i>(const int* from) { return pload<Packet4i>(from); }
-template<> EIGEN_STRONG_INLINE Packet4f ploadu<Packet4f>(const float* from) { return pload<Packet4f>(from); }
-template<> EIGEN_STRONG_INLINE Packet2d ploadu<Packet2d>(const double* from) { return pload<Packet2d>(from); }
-
-
-template<> EIGEN_STRONG_INLINE Packet4i ploaddup<Packet4i>(const int* from)
-{
- Packet4i p = pload<Packet4i>(from);
- return vec_perm(p, p, p16uc_DUPLICATE32_HI);
-}
template<> EIGEN_STRONG_INLINE Packet4f ploaddup<Packet4f>(const float* from)
{
@@ -667,33 +785,7 @@ template<> EIGEN_STRONG_INLINE Packet4f ploaddup<Packet4f>(const float* from)
return p;
}
-template<> EIGEN_STRONG_INLINE Packet2d ploaddup<Packet2d>(const double* from)
-{
- Packet2d p = pload<Packet2d>(from);
- return vec_perm(p, p, p16uc_PSET64_HI);
-}
-
-template<> EIGEN_STRONG_INLINE void pstoreu<int>(int* to, const Packet4i& from) { pstore<int>(to, from); }
-template<> EIGEN_STRONG_INLINE void pstoreu<float>(float* to, const Packet4f& from) { pstore<float>(to, from); }
-template<> EIGEN_STRONG_INLINE void pstoreu<double>(double* to, const Packet2d& from) { pstore<double>(to, from); }
-
-template<> EIGEN_STRONG_INLINE void prefetch<int>(const int* addr) { EIGEN_ZVECTOR_PREFETCH(addr); }
-template<> EIGEN_STRONG_INLINE void prefetch<float>(const float* addr) { EIGEN_ZVECTOR_PREFETCH(addr); }
-template<> EIGEN_STRONG_INLINE void prefetch<double>(const double* addr) { EIGEN_ZVECTOR_PREFETCH(addr); }
-
-template<> EIGEN_STRONG_INLINE int pfirst<Packet4i>(const Packet4i& a) { int EIGEN_ALIGN16 x[4]; pstore(x, a); return x[0]; }
template<> EIGEN_STRONG_INLINE float pfirst<Packet4f>(const Packet4f& a) { float EIGEN_ALIGN16 x[2]; vec_st2f(a.v4f[0], &x[0]); return x[0]; }
-template<> EIGEN_STRONG_INLINE double pfirst<Packet2d>(const Packet2d& a) { double EIGEN_ALIGN16 x[2]; pstore(x, a); return x[0]; }
-
-template<> EIGEN_STRONG_INLINE Packet4i preverse(const Packet4i& a)
-{
- return reinterpret_cast<Packet4i>(vec_perm(reinterpret_cast<Packet16uc>(a), reinterpret_cast<Packet16uc>(a), p16uc_REVERSE32));
-}
-
-template<> EIGEN_STRONG_INLINE Packet2d preverse(const Packet2d& a)
-{
- return reinterpret_cast<Packet2d>(vec_perm(reinterpret_cast<Packet16uc>(a), reinterpret_cast<Packet16uc>(a), p16uc_REVERSE64));
-}
template<> EIGEN_STRONG_INLINE Packet4f preverse(const Packet4f& a)
{
@@ -703,8 +795,6 @@ template<> EIGEN_STRONG_INLINE Packet4f preverse(const Packet4f& a)
return rev;
}
-template<> EIGEN_STRONG_INLINE Packet4i pabs<Packet4i>(const Packet4i& a) { return vec_abs(a); }
-template<> EIGEN_STRONG_INLINE Packet2d pabs<Packet2d>(const Packet2d& a) { return vec_abs(a); }
template<> EIGEN_STRONG_INLINE Packet4f pabs<Packet4f>(const Packet4f& a)
{
Packet4f res;
@@ -713,23 +803,6 @@ template<> EIGEN_STRONG_INLINE Packet4f pabs<Packet4f>(const Packet4f& a)
return res;
}
-template<> EIGEN_STRONG_INLINE int predux<Packet4i>(const Packet4i& a)
-{
- Packet4i b, sum;
- b = vec_sld(a, a, 8);
- sum = padd<Packet4i>(a, b);
- b = vec_sld(sum, sum, 4);
- sum = padd<Packet4i>(sum, b);
- return pfirst(sum);
-}
-
-template<> EIGEN_STRONG_INLINE double predux<Packet2d>(const Packet2d& a)
-{
- Packet2d b, sum;
- b = reinterpret_cast<Packet2d>(vec_sld(reinterpret_cast<Packet4i>(a), reinterpret_cast<Packet4i>(a), 8));
- sum = padd<Packet2d>(a, b);
- return pfirst(sum);
-}
template<> EIGEN_STRONG_INLINE float predux<Packet4f>(const Packet4f& a)
{
Packet2d sum;
@@ -738,94 +811,12 @@ template<> EIGEN_STRONG_INLINE float predux<Packet4f>(const Packet4f& a)
return static_cast<float>(first);
}
-template<> EIGEN_STRONG_INLINE Packet4i preduxp<Packet4i>(const Packet4i* vecs)
-{
- Packet4i v[4], sum[4];
-
- // It's easier and faster to transpose then add as columns
- // Check: http://www.freevec.org/function/matrix_4x4_transpose_floats for explanation
- // Do the transpose, first set of moves
- v[0] = vec_mergeh(vecs[0], vecs[2]);
- v[1] = vec_mergel(vecs[0], vecs[2]);
- v[2] = vec_mergeh(vecs[1], vecs[3]);
- v[3] = vec_mergel(vecs[1], vecs[3]);
- // Get the resulting vectors
- sum[0] = vec_mergeh(v[0], v[2]);
- sum[1] = vec_mergel(v[0], v[2]);
- sum[2] = vec_mergeh(v[1], v[3]);
- sum[3] = vec_mergel(v[1], v[3]);
-
- // Now do the summation:
- // Lines 0+1
- sum[0] = padd<Packet4i>(sum[0], sum[1]);
- // Lines 2+3
- sum[1] = padd<Packet4i>(sum[2], sum[3]);
- // Add the results
- sum[0] = padd<Packet4i>(sum[0], sum[1]);
-
- return sum[0];
-}
-
-template<> EIGEN_STRONG_INLINE Packet2d preduxp<Packet2d>(const Packet2d* vecs)
-{
- Packet2d v[2], sum;
- v[0] = padd<Packet2d>(vecs[0], reinterpret_cast<Packet2d>(vec_sld(reinterpret_cast<Packet4ui>(vecs[0]), reinterpret_cast<Packet4ui>(vecs[0]), 8)));
- v[1] = padd<Packet2d>(vecs[1], reinterpret_cast<Packet2d>(vec_sld(reinterpret_cast<Packet4ui>(vecs[1]), reinterpret_cast<Packet4ui>(vecs[1]), 8)));
-
- sum = reinterpret_cast<Packet2d>(vec_sld(reinterpret_cast<Packet4ui>(v[0]), reinterpret_cast<Packet4ui>(v[1]), 8));
-
- return sum;
-}
-
-template<> EIGEN_STRONG_INLINE Packet4f preduxp<Packet4f>(const Packet4f* vecs)
-{
- PacketBlock<Packet4f,4> transpose;
- transpose.packet[0] = vecs[0];
- transpose.packet[1] = vecs[1];
- transpose.packet[2] = vecs[2];
- transpose.packet[3] = vecs[3];
- ptranspose(transpose);
-
- Packet4f sum = padd(transpose.packet[0], transpose.packet[1]);
- sum = padd(sum, transpose.packet[2]);
- sum = padd(sum, transpose.packet[3]);
- return sum;
-}
-
-// Other reduction functions:
-// mul
-template<> EIGEN_STRONG_INLINE int predux_mul<Packet4i>(const Packet4i& a)
-{
- EIGEN_ALIGN16 int aux[4];
- pstore(aux, a);
- return aux[0] * aux[1] * aux[2] * aux[3];
-}
-
-template<> EIGEN_STRONG_INLINE double predux_mul<Packet2d>(const Packet2d& a)
-{
- return pfirst(pmul(a, reinterpret_cast<Packet2d>(vec_sld(reinterpret_cast<Packet4i>(a), reinterpret_cast<Packet4i>(a), 8))));
-}
-
template<> EIGEN_STRONG_INLINE float predux_mul<Packet4f>(const Packet4f& a)
{
// Return predux_mul<Packet2d> of the subvectors product
return static_cast<float>(pfirst(predux_mul(pmul(a.v4f[0], a.v4f[1]))));
}
-// min
-template<> EIGEN_STRONG_INLINE int predux_min<Packet4i>(const Packet4i& a)
-{
- Packet4i b, res;
- b = pmin<Packet4i>(a, vec_sld(a, a, 8));
- res = pmin<Packet4i>(b, vec_sld(b, b, 4));
- return pfirst(res);
-}
-
-template<> EIGEN_STRONG_INLINE double predux_min<Packet2d>(const Packet2d& a)
-{
- return pfirst(pmin<Packet2d>(a, reinterpret_cast<Packet2d>(vec_sld(reinterpret_cast<Packet4i>(a), reinterpret_cast<Packet4i>(a), 8))));
-}
-
template<> EIGEN_STRONG_INLINE float predux_min<Packet4f>(const Packet4f& a)
{
Packet2d b, res;
@@ -834,21 +825,6 @@ template<> EIGEN_STRONG_INLINE float predux_min<Packet4f>(const Packet4f& a)
return static_cast<float>(pfirst(res));
}
-// max
-template<> EIGEN_STRONG_INLINE int predux_max<Packet4i>(const Packet4i& a)
-{
- Packet4i b, res;
- b = pmax<Packet4i>(a, vec_sld(a, a, 8));
- res = pmax<Packet4i>(b, vec_sld(b, b, 4));
- return pfirst(res);
-}
-
-// max
-template<> EIGEN_STRONG_INLINE double predux_max<Packet2d>(const Packet2d& a)
-{
- return pfirst(pmax<Packet2d>(a, reinterpret_cast<Packet2d>(vec_sld(reinterpret_cast<Packet4i>(a), reinterpret_cast<Packet4i>(a), 8))));
-}
-
template<> EIGEN_STRONG_INLINE float predux_max<Packet4f>(const Packet4f& a)
{
Packet2d b, res;
@@ -857,26 +833,6 @@ template<> EIGEN_STRONG_INLINE float predux_max<Packet4f>(const Packet4f& a)
return static_cast<float>(pfirst(res));
}
-EIGEN_DEVICE_FUNC inline void
-ptranspose(PacketBlock<Packet4i,4>& kernel) {
- Packet4i t0 = vec_mergeh(kernel.packet[0], kernel.packet[2]);
- Packet4i t1 = vec_mergel(kernel.packet[0], kernel.packet[2]);
- Packet4i t2 = vec_mergeh(kernel.packet[1], kernel.packet[3]);
- Packet4i t3 = vec_mergel(kernel.packet[1], kernel.packet[3]);
- kernel.packet[0] = vec_mergeh(t0, t2);
- kernel.packet[1] = vec_mergel(t0, t2);
- kernel.packet[2] = vec_mergeh(t1, t3);
- kernel.packet[3] = vec_mergel(t1, t3);
-}
-
-EIGEN_DEVICE_FUNC inline void
-ptranspose(PacketBlock<Packet2d,2>& kernel) {
- Packet2d t0 = vec_perm(kernel.packet[0], kernel.packet[1], p16uc_TRANSPOSE64_HI);
- Packet2d t1 = vec_perm(kernel.packet[0], kernel.packet[1], p16uc_TRANSPOSE64_LO);
- kernel.packet[0] = t0;
- kernel.packet[1] = t1;
-}
-
/* Split the Packet4f PacketBlock into 4 Packet2d PacketBlocks and transpose each one
*/
EIGEN_DEVICE_FUNC inline void
@@ -915,12 +871,6 @@ ptranspose(PacketBlock<Packet4f,4>& kernel) {
kernel.packet[3].v4f[1] = t3.packet[1];
}
-template<> EIGEN_STRONG_INLINE Packet4i pblend(const Selector<4>& ifPacket, const Packet4i& thenPacket, const Packet4i& elsePacket) {
- Packet4ui select = { ifPacket.select[0], ifPacket.select[1], ifPacket.select[2], ifPacket.select[3] };
- Packet4ui mask = vec_cmpeq(select, reinterpret_cast<Packet4ui>(p4i_ONE));
- return vec_sel(elsePacket, thenPacket, mask);
-}
-
template<> EIGEN_STRONG_INLINE Packet4f pblend(const Selector<4>& ifPacket, const Packet4f& thenPacket, const Packet4f& elsePacket) {
Packet2ul select_hi = { ifPacket.select[0], ifPacket.select[1] };
Packet2ul select_lo = { ifPacket.select[2], ifPacket.select[3] };
@@ -932,12 +882,177 @@ template<> EIGEN_STRONG_INLINE Packet4f pblend(const Selector<4>& ifPacket, cons
return result;
}
-template<> EIGEN_STRONG_INLINE Packet2d pblend(const Selector<2>& ifPacket, const Packet2d& thenPacket, const Packet2d& elsePacket) {
- Packet2ul select = { ifPacket.select[0], ifPacket.select[1] };
- Packet2ul mask = vec_cmpeq(select, reinterpret_cast<Packet2ul>(p2l_ONE));
+template<> Packet4f EIGEN_STRONG_INLINE pcmp_le<Packet4f>(const Packet4f& a, const Packet4f& b)
+{
+ Packet4f res;
+ res.v4f[0] = pcmp_le(a.v4f[0], b.v4f[0]);
+ res.v4f[1] = pcmp_le(a.v4f[1], b.v4f[1]);
+ return res;
+}
+
+template<> Packet4f EIGEN_STRONG_INLINE pcmp_lt<Packet4f>(const Packet4f& a, const Packet4f& b)
+{
+ Packet4f res;
+ res.v4f[0] = pcmp_lt(a.v4f[0], b.v4f[0]);
+ res.v4f[1] = pcmp_lt(a.v4f[1], b.v4f[1]);
+ return res;
+}
+
+template<> Packet4f EIGEN_STRONG_INLINE pcmp_eq<Packet4f>(const Packet4f& a, const Packet4f& b)
+{
+ Packet4f res;
+ res.v4f[0] = pcmp_eq(a.v4f[0], b.v4f[0]);
+ res.v4f[1] = pcmp_eq(a.v4f[1], b.v4f[1]);
+ return res;
+}
+
+#else
+template<> EIGEN_STRONG_INLINE Packet4f pload<Packet4f>(const float* from)
+{
+ // FIXME: No intrinsic yet
+ EIGEN_DEBUG_ALIGNED_LOAD
+ Packet *vfrom;
+ vfrom = (Packet *) from;
+ return vfrom->v4f;
+}
+
+template<> EIGEN_STRONG_INLINE void pstore<float>(float* to, const Packet4f& from)
+{
+ // FIXME: No intrinsic yet
+ EIGEN_DEBUG_ALIGNED_STORE
+ Packet *vto;
+ vto = (Packet *) to;
+ vto->v4f = from;
+}
+
+template<> EIGEN_STRONG_INLINE Packet4f pset1<Packet4f>(const float& from)
+{
+ return vec_splats(from);
+}
+
+template<> EIGEN_STRONG_INLINE void
+pbroadcast4<Packet4f>(const float *a,
+ Packet4f& a0, Packet4f& a1, Packet4f& a2, Packet4f& a3)
+{
+ a3 = pload<Packet4f>(a);
+ a0 = vec_splat(a3, 0);
+ a1 = vec_splat(a3, 1);
+ a2 = vec_splat(a3, 2);
+ a3 = vec_splat(a3, 3);
+}
+
+template<> EIGEN_DEVICE_FUNC inline Packet4f pgather<float, Packet4f>(const float* from, Index stride)
+{
+ float EIGEN_ALIGN16 af[4];
+ af[0] = from[0*stride];
+ af[1] = from[1*stride];
+ af[2] = from[2*stride];
+ af[3] = from[3*stride];
+ return pload<Packet4f>(af);
+}
+
+template<> EIGEN_DEVICE_FUNC inline void pscatter<float, Packet4f>(float* to, const Packet4f& from, Index stride)
+{
+ float EIGEN_ALIGN16 af[4];
+ pstore<float>((float*)af, from);
+ to[0*stride] = af[0];
+ to[1*stride] = af[1];
+ to[2*stride] = af[2];
+ to[3*stride] = af[3];
+}
+
+template<> EIGEN_STRONG_INLINE Packet4f padd<Packet4f>(const Packet4f& a, const Packet4f& b) { return (a + b); }
+template<> EIGEN_STRONG_INLINE Packet4f psub<Packet4f>(const Packet4f& a, const Packet4f& b) { return (a - b); }
+template<> EIGEN_STRONG_INLINE Packet4f pmul<Packet4f>(const Packet4f& a, const Packet4f& b) { return (a * b); }
+template<> EIGEN_STRONG_INLINE Packet4f pdiv<Packet4f>(const Packet4f& a, const Packet4f& b) { return (a / b); }
+template<> EIGEN_STRONG_INLINE Packet4f pnegate<Packet4f>(const Packet4f& a) { return (-a); }
+template<> EIGEN_STRONG_INLINE Packet4f pconj<Packet4f> (const Packet4f& a) { return a; }
+template<> EIGEN_STRONG_INLINE Packet4f pmadd<Packet4f> (const Packet4f& a, const Packet4f& b, const Packet4f& c) { return vec_madd(a, b, c); }
+template<> EIGEN_STRONG_INLINE Packet4f pmin<Packet4f> (const Packet4f& a, const Packet4f& b) { return vec_min(a, b); }
+template<> EIGEN_STRONG_INLINE Packet4f pmax<Packet4f> (const Packet4f& a, const Packet4f& b) { return vec_max(a, b); }
+template<> EIGEN_STRONG_INLINE Packet4f pand<Packet4f> (const Packet4f& a, const Packet4f& b) { return vec_and(a, b); }
+template<> EIGEN_STRONG_INLINE Packet4f por<Packet4f> (const Packet4f& a, const Packet4f& b) { return vec_or(a, b); }
+template<> EIGEN_STRONG_INLINE Packet4f pxor<Packet4f> (const Packet4f& a, const Packet4f& b) { return vec_xor(a, b); }
+template<> EIGEN_STRONG_INLINE Packet4f pandnot<Packet4f>(const Packet4f& a, const Packet4f& b) { return vec_and(a, vec_nor(b, b)); }
+template<> EIGEN_STRONG_INLINE Packet4f pround<Packet4f> (const Packet4f& a) { return vec_round(a); }
+template<> EIGEN_STRONG_INLINE Packet4f pceil<Packet4f> (const Packet4f& a) { return vec_ceil(a); }
+template<> EIGEN_STRONG_INLINE Packet4f pfloor<Packet4f> (const Packet4f& a) { return vec_floor(a); }
+template<> EIGEN_STRONG_INLINE Packet4f pabs<Packet4f> (const Packet4f& a) { return vec_abs(a); }
+template<> EIGEN_STRONG_INLINE float pfirst<Packet4f>(const Packet4f& a) { float EIGEN_ALIGN16 x[4]; pstore(x, a); return x[0]; }
+
+template<> EIGEN_STRONG_INLINE Packet4f ploaddup<Packet4f>(const float* from)
+{
+ Packet4f p = pload<Packet4f>(from);
+ return vec_perm(p, p, p16uc_DUPLICATE32_HI);
+}
+
+template<> EIGEN_STRONG_INLINE Packet4f preverse(const Packet4f& a)
+{
+ return reinterpret_cast<Packet4f>(vec_perm(reinterpret_cast<Packet16uc>(a), reinterpret_cast<Packet16uc>(a), p16uc_REVERSE32));
+}
+
+template<> EIGEN_STRONG_INLINE float predux<Packet4f>(const Packet4f& a)
+{
+ Packet4f b, sum;
+ b = vec_sld(a, a, 8);
+ sum = padd<Packet4f>(a, b);
+ b = vec_sld(sum, sum, 4);
+ sum = padd<Packet4f>(sum, b);
+ return pfirst(sum);
+}
+
+// Other reduction functions:
+// mul
+template<> EIGEN_STRONG_INLINE float predux_mul<Packet4f>(const Packet4f& a)
+{
+ Packet4f prod;
+ prod = pmul(a, vec_sld(a, a, 8));
+ return pfirst(pmul(prod, vec_sld(prod, prod, 4)));
+}
+
+// min
+template<> EIGEN_STRONG_INLINE float predux_min<Packet4f>(const Packet4f& a)
+{
+ Packet4f b, res;
+ b = pmin<Packet4f>(a, vec_sld(a, a, 8));
+ res = pmin<Packet4f>(b, vec_sld(b, b, 4));
+ return pfirst(res);
+}
+
+// max
+template<> EIGEN_STRONG_INLINE float predux_max<Packet4f>(const Packet4f& a)
+{
+ Packet4f b, res;
+ b = pmax<Packet4f>(a, vec_sld(a, a, 8));
+ res = pmax<Packet4f>(b, vec_sld(b, b, 4));
+ return pfirst(res);
+}
+
+EIGEN_DEVICE_FUNC inline void
+ptranspose(PacketBlock<Packet4f,4>& kernel) {
+ Packet4f t0 = vec_mergeh(kernel.packet[0], kernel.packet[2]);
+ Packet4f t1 = vec_mergel(kernel.packet[0], kernel.packet[2]);
+ Packet4f t2 = vec_mergeh(kernel.packet[1], kernel.packet[3]);
+ Packet4f t3 = vec_mergel(kernel.packet[1], kernel.packet[3]);
+ kernel.packet[0] = vec_mergeh(t0, t2);
+ kernel.packet[1] = vec_mergel(t0, t2);
+ kernel.packet[2] = vec_mergeh(t1, t3);
+ kernel.packet[3] = vec_mergel(t1, t3);
+}
+
+template<> EIGEN_STRONG_INLINE Packet4f pblend(const Selector<4>& ifPacket, const Packet4f& thenPacket, const Packet4f& elsePacket) {
+ Packet4ui select = { ifPacket.select[0], ifPacket.select[1], ifPacket.select[2], ifPacket.select[3] };
+ Packet4ui mask = vec_cmpeq(select, reinterpret_cast<Packet4ui>(p4i_ONE));
return vec_sel(elsePacket, thenPacket, mask);
}
+#endif
+
+template<> EIGEN_STRONG_INLINE void prefetch<float>(const float* addr) { EIGEN_ZVECTOR_PREFETCH(addr); }
+template<> EIGEN_STRONG_INLINE Packet4f ploadu<Packet4f> (const float* from) { return pload<Packet4f>(from); }
+template<> EIGEN_STRONG_INLINE void pstoreu<float>(float* to, const Packet4f& from) { pstore<float>(to, from); }
+template<> EIGEN_STRONG_INLINE Packet4f plset<Packet4f> (const float& a) { return padd<Packet4f>(pset1<Packet4f>(a), p4f_COUNTDOWN); }
+
} // end namespace internal
} // end namespace Eigen