Diffstat (limited to 'Eigen/src/Core/arch/AltiVec')
-rw-r--r--  Eigen/src/Core/arch/AltiVec/Complex.h              352
-rw-r--r--  Eigen/src/Core/arch/AltiVec/MathFunctions.h        270
-rw-r--r--  Eigen/src/Core/arch/AltiVec/MatrixProduct.h       2937
-rw-r--r--  Eigen/src/Core/arch/AltiVec/MatrixProductCommon.h   221
-rw-r--r--  Eigen/src/Core/arch/AltiVec/MatrixProductMMA.h      629
-rwxr-xr-x  Eigen/src/Core/arch/AltiVec/PacketMath.h           2384
6 files changed, 5991 insertions(+), 802 deletions(-)
diff --git a/Eigen/src/Core/arch/AltiVec/Complex.h b/Eigen/src/Core/arch/AltiVec/Complex.h
index 67db2f8ee..f424f11cf 100644
--- a/Eigen/src/Core/arch/AltiVec/Complex.h
+++ b/Eigen/src/Core/arch/AltiVec/Complex.h
@@ -29,8 +29,54 @@ static Packet2ul p2ul_CONJ_XOR2 = (Packet2ul) vec_sld((Packet4ui) p2d_MZERO, (P
//---------- float ----------
struct Packet2cf
{
- EIGEN_STRONG_INLINE explicit Packet2cf() : v(p4f_ZERO) {}
+ EIGEN_STRONG_INLINE explicit Packet2cf() {}
EIGEN_STRONG_INLINE explicit Packet2cf(const Packet4f& a) : v(a) {}
+
+ EIGEN_STRONG_INLINE Packet2cf pmul(const Packet2cf& a, const Packet2cf& b)
+ {
+ Packet4f v1, v2;
+
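+    // Standard complex product: with a = ar + i*ai and b = br + i*bi (per complex lane) this computes
+    // [ar*br - ai*bi, ar*bi + ai*br]; v1 holds the ar*b products and v2 the sign-adjusted, swapped
+    // ai*b products before the final add.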
+ // Permute and multiply the real parts of a and b
+ v1 = vec_perm(a.v, a.v, p16uc_PSET32_WODD);
+ // Get the imaginary parts of a
+ v2 = vec_perm(a.v, a.v, p16uc_PSET32_WEVEN);
+ // multiply a_re * b
+ v1 = vec_madd(v1, b.v, p4f_ZERO);
+ // multiply a_im * b and get the conjugate result
+ v2 = vec_madd(v2, b.v, p4f_ZERO);
+ v2 = reinterpret_cast<Packet4f>(pxor(v2, reinterpret_cast<Packet4f>(p4ui_CONJ_XOR)));
+ // permute back to a proper order
+ v2 = vec_perm(v2, v2, p16uc_COMPLEX32_REV);
+
+ return Packet2cf(padd<Packet4f>(v1, v2));
+ }
+
+ EIGEN_STRONG_INLINE Packet2cf& operator*=(const Packet2cf& b) {
+ v = pmul(Packet2cf(*this), b).v;
+ return *this;
+ }
+ EIGEN_STRONG_INLINE Packet2cf operator*(const Packet2cf& b) const {
+ return Packet2cf(*this) *= b;
+ }
+
+ EIGEN_STRONG_INLINE Packet2cf& operator+=(const Packet2cf& b) {
+ v = padd(v, b.v);
+ return *this;
+ }
+ EIGEN_STRONG_INLINE Packet2cf operator+(const Packet2cf& b) const {
+ return Packet2cf(*this) += b;
+ }
+ EIGEN_STRONG_INLINE Packet2cf& operator-=(const Packet2cf& b) {
+ v = psub(v, b.v);
+ return *this;
+ }
+ EIGEN_STRONG_INLINE Packet2cf operator-(const Packet2cf& b) const {
+ return Packet2cf(*this) -= b;
+ }
+ EIGEN_STRONG_INLINE Packet2cf operator-(void) const {
+ return Packet2cf(-v);
+ }
+
Packet4f v;
};
@@ -38,6 +84,7 @@ template<> struct packet_traits<std::complex<float> > : default_packet_traits
{
typedef Packet2cf type;
typedef Packet2cf half;
+ typedef Packet4f as_real;
enum {
Vectorizable = 1,
AlignedOnScalar = 1,
@@ -60,7 +107,7 @@ template<> struct packet_traits<std::complex<float> > : default_packet_traits
};
};
-template<> struct unpacket_traits<Packet2cf> { typedef std::complex<float> type; enum {size=2, alignment=Aligned16}; typedef Packet2cf half; };
+template<> struct unpacket_traits<Packet2cf> { typedef std::complex<float> type; enum {size=2, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false}; typedef Packet2cf half; typedef Packet4f as_real; };
template<> EIGEN_STRONG_INLINE Packet2cf pset1<Packet2cf>(const std::complex<float>& from)
{
@@ -80,16 +127,35 @@ template<> EIGEN_STRONG_INLINE Packet2cf ploaddup<Packet2cf>(const std::complex<
template<> EIGEN_STRONG_INLINE void pstore <std::complex<float> >(std::complex<float> * to, const Packet2cf& from) { pstore((float*)to, from.v); }
template<> EIGEN_STRONG_INLINE void pstoreu<std::complex<float> >(std::complex<float> * to, const Packet2cf& from) { pstoreu((float*)to, from.v); }
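+// Build a single Packet2cf from two std::complex<float> values loaded from separate addresses.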
+EIGEN_STRONG_INLINE Packet2cf pload2(const std::complex<float>* from0, const std::complex<float>* from1)
+{
+ Packet4f res0, res1;
+#ifdef __VSX__
+ __asm__ ("lxsdx %x0,%y1" : "=wa" (res0) : "Z" (*from0));
+ __asm__ ("lxsdx %x0,%y1" : "=wa" (res1) : "Z" (*from1));
+#ifdef _BIG_ENDIAN
+ __asm__ ("xxpermdi %x0, %x1, %x2, 0" : "=wa" (res0) : "wa" (res0), "wa" (res1));
+#else
+ __asm__ ("xxpermdi %x0, %x2, %x1, 0" : "=wa" (res0) : "wa" (res0), "wa" (res1));
+#endif
+#else
+ *reinterpret_cast<std::complex<float> *>(&res0) = *from0;
+ *reinterpret_cast<std::complex<float> *>(&res1) = *from1;
+ res0 = vec_perm(res0, res1, p16uc_TRANSPOSE64_HI);
+#endif
+ return Packet2cf(res0);
+}
+
template<> EIGEN_DEVICE_FUNC inline Packet2cf pgather<std::complex<float>, Packet2cf>(const std::complex<float>* from, Index stride)
{
- std::complex<float> EIGEN_ALIGN16 af[2];
+ EIGEN_ALIGN16 std::complex<float> af[2];
af[0] = from[0*stride];
af[1] = from[1*stride];
return pload<Packet2cf>(af);
}
template<> EIGEN_DEVICE_FUNC inline void pscatter<std::complex<float>, Packet2cf>(std::complex<float>* to, const Packet2cf& from, Index stride)
{
- std::complex<float> EIGEN_ALIGN16 af[2];
+ EIGEN_ALIGN16 std::complex<float> af[2];
pstore<std::complex<float> >((std::complex<float> *) af, from);
to[0*stride] = af[0];
to[1*stride] = af[1];
@@ -100,25 +166,6 @@ template<> EIGEN_STRONG_INLINE Packet2cf psub<Packet2cf>(const Packet2cf& a, con
template<> EIGEN_STRONG_INLINE Packet2cf pnegate(const Packet2cf& a) { return Packet2cf(pnegate(a.v)); }
template<> EIGEN_STRONG_INLINE Packet2cf pconj(const Packet2cf& a) { return Packet2cf(pxor<Packet4f>(a.v, reinterpret_cast<Packet4f>(p4ui_CONJ_XOR))); }
-template<> EIGEN_STRONG_INLINE Packet2cf pmul<Packet2cf>(const Packet2cf& a, const Packet2cf& b)
-{
- Packet4f v1, v2;
-
- // Permute and multiply the real parts of a and b
- v1 = vec_perm(a.v, a.v, p16uc_PSET32_WODD);
- // Get the imaginary parts of a
- v2 = vec_perm(a.v, a.v, p16uc_PSET32_WEVEN);
- // multiply a_re * b
- v1 = vec_madd(v1, b.v, p4f_ZERO);
- // multiply a_im * b and get the conjugate result
- v2 = vec_madd(v2, b.v, p4f_ZERO);
- v2 = reinterpret_cast<Packet4f>(pxor(v2, reinterpret_cast<Packet4f>(p4ui_CONJ_XOR)));
- // permute back to a proper order
- v2 = vec_perm(v2, v2, p16uc_COMPLEX32_REV);
-
- return Packet2cf(padd<Packet4f>(v1, v2));
-}
-
template<> EIGEN_STRONG_INLINE Packet2cf pand <Packet2cf>(const Packet2cf& a, const Packet2cf& b) { return Packet2cf(pand<Packet4f>(a.v, b.v)); }
template<> EIGEN_STRONG_INLINE Packet2cf por <Packet2cf>(const Packet2cf& a, const Packet2cf& b) { return Packet2cf(por<Packet4f>(a.v, b.v)); }
template<> EIGEN_STRONG_INLINE Packet2cf pxor <Packet2cf>(const Packet2cf& a, const Packet2cf& b) { return Packet2cf(pxor<Packet4f>(a.v, b.v)); }
@@ -128,7 +175,7 @@ template<> EIGEN_STRONG_INLINE void prefetch<std::complex<float> >(const std::co
template<> EIGEN_STRONG_INLINE std::complex<float> pfirst<Packet2cf>(const Packet2cf& a)
{
- std::complex<float> EIGEN_ALIGN16 res[2];
+ EIGEN_ALIGN16 std::complex<float> res[2];
pstore((float *)&res, a.v);
return res[0];
@@ -149,22 +196,6 @@ template<> EIGEN_STRONG_INLINE std::complex<float> predux<Packet2cf>(const Packe
return pfirst<Packet2cf>(Packet2cf(b));
}
-template<> EIGEN_STRONG_INLINE Packet2cf preduxp<Packet2cf>(const Packet2cf* vecs)
-{
- Packet4f b1, b2;
-#ifdef _BIG_ENDIAN
- b1 = vec_sld(vecs[0].v, vecs[1].v, 8);
- b2 = vec_sld(vecs[1].v, vecs[0].v, 8);
-#else
- b1 = vec_sld(vecs[1].v, vecs[0].v, 8);
- b2 = vec_sld(vecs[0].v, vecs[1].v, 8);
-#endif
- b2 = vec_sld(b2, b2, 8);
- b2 = padd<Packet4f>(b1, b2);
-
- return Packet2cf(b2);
-}
-
template<> EIGEN_STRONG_INLINE std::complex<float> predux_mul<Packet2cf>(const Packet2cf& a)
{
Packet4f b;
@@ -175,77 +206,12 @@ template<> EIGEN_STRONG_INLINE std::complex<float> predux_mul<Packet2cf>(const P
return pfirst<Packet2cf>(prod);
}
-template<int Offset>
-struct palign_impl<Offset,Packet2cf>
-{
- static EIGEN_STRONG_INLINE void run(Packet2cf& first, const Packet2cf& second)
- {
- if (Offset==1)
- {
-#ifdef _BIG_ENDIAN
- first.v = vec_sld(first.v, second.v, 8);
-#else
- first.v = vec_sld(second.v, first.v, 8);
-#endif
- }
- }
-};
-
-template<> struct conj_helper<Packet2cf, Packet2cf, false,true>
-{
- EIGEN_STRONG_INLINE Packet2cf pmadd(const Packet2cf& x, const Packet2cf& y, const Packet2cf& c) const
- { return padd(pmul(x,y),c); }
-
- EIGEN_STRONG_INLINE Packet2cf pmul(const Packet2cf& a, const Packet2cf& b) const
- {
- return internal::pmul(a, pconj(b));
- }
-};
-
-template<> struct conj_helper<Packet2cf, Packet2cf, true,false>
-{
- EIGEN_STRONG_INLINE Packet2cf pmadd(const Packet2cf& x, const Packet2cf& y, const Packet2cf& c) const
- { return padd(pmul(x,y),c); }
-
- EIGEN_STRONG_INLINE Packet2cf pmul(const Packet2cf& a, const Packet2cf& b) const
- {
- return internal::pmul(pconj(a), b);
- }
-};
-
-template<> struct conj_helper<Packet2cf, Packet2cf, true,true>
-{
- EIGEN_STRONG_INLINE Packet2cf pmadd(const Packet2cf& x, const Packet2cf& y, const Packet2cf& c) const
- { return padd(pmul(x,y),c); }
-
- EIGEN_STRONG_INLINE Packet2cf pmul(const Packet2cf& a, const Packet2cf& b) const
- {
- return pconj(internal::pmul(a, b));
- }
-};
-
-template<> struct conj_helper<Packet4f, Packet2cf, false,false>
-{
- EIGEN_STRONG_INLINE Packet2cf pmadd(const Packet4f& x, const Packet2cf& y, const Packet2cf& c) const
- { return padd(c, pmul(x,y)); }
-
- EIGEN_STRONG_INLINE Packet2cf pmul(const Packet4f& x, const Packet2cf& y) const
- { return Packet2cf(internal::pmul<Packet4f>(x, y.v)); }
-};
-
-template<> struct conj_helper<Packet2cf, Packet4f, false,false>
-{
- EIGEN_STRONG_INLINE Packet2cf pmadd(const Packet2cf& x, const Packet4f& y, const Packet2cf& c) const
- { return padd(c, pmul(x,y)); }
-
- EIGEN_STRONG_INLINE Packet2cf pmul(const Packet2cf& x, const Packet4f& y) const
- { return Packet2cf(internal::pmul<Packet4f>(x.v, y)); }
-};
+EIGEN_MAKE_CONJ_HELPER_CPLX_REAL(Packet2cf,Packet4f)
template<> EIGEN_STRONG_INLINE Packet2cf pdiv<Packet2cf>(const Packet2cf& a, const Packet2cf& b)
{
// TODO optimize it for AltiVec
- Packet2cf res = conj_helper<Packet2cf,Packet2cf,false,true>().pmul(a, b);
+ Packet2cf res = pmul(a, pconj(b));
Packet4f s = pmul<Packet4f>(b.v, b.v);
return Packet2cf(pdiv(res.v, padd<Packet4f>(s, vec_perm(s, s, p16uc_COMPLEX32_REV))));
}
@@ -262,6 +228,11 @@ EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet2cf,2>& kernel)
kernel.packet[0].v = tmp;
}
+template<> EIGEN_STRONG_INLINE Packet2cf pcmp_eq(const Packet2cf& a, const Packet2cf& b) {
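+  // Compare real and imaginary parts elementwise, then AND the mask with a copy whose re/im lanes are
+  // swapped, so that a complex lane is all-ones only when both of its components compare equal.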
+ Packet4f eq = reinterpret_cast<Packet4f>(vec_cmpeq(a.v,b.v));
+ return Packet2cf(vec_and(eq, vec_perm(eq, eq, p16uc_COMPLEX32_REV)));
+}
+
#ifdef __VSX__
template<> EIGEN_STRONG_INLINE Packet2cf pblend(const Selector<2>& ifPacket, const Packet2cf& thenPacket, const Packet2cf& elsePacket) {
Packet2cf result;
@@ -270,12 +241,62 @@ template<> EIGEN_STRONG_INLINE Packet2cf pblend(const Selector<2>& ifPacket, con
}
#endif
+template<> EIGEN_STRONG_INLINE Packet2cf psqrt<Packet2cf>(const Packet2cf& a)
+{
+ return psqrt_complex<Packet2cf>(a);
+}
+
//---------- double ----------
#ifdef __VSX__
struct Packet1cd
{
EIGEN_STRONG_INLINE Packet1cd() {}
EIGEN_STRONG_INLINE explicit Packet1cd(const Packet2d& a) : v(a) {}
+
+ EIGEN_STRONG_INLINE Packet1cd pmul(const Packet1cd& a, const Packet1cd& b)
+ {
+ Packet2d a_re, a_im, v1, v2;
+
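+    // Same scheme as the Packet2cf path: for a = ar + i*ai and b = br + i*bi this produces
+    // [ar*br - ai*bi, ar*bi + ai*br], i.e. the usual complex product.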
+ // Permute and multiply the real parts of a and b
+ a_re = vec_perm(a.v, a.v, p16uc_PSET64_HI);
+ // Get the imaginary parts of a
+ a_im = vec_perm(a.v, a.v, p16uc_PSET64_LO);
+ // multiply a_re * b
+ v1 = vec_madd(a_re, b.v, p2d_ZERO);
+ // multiply a_im * b and get the conjugate result
+ v2 = vec_madd(a_im, b.v, p2d_ZERO);
+ v2 = reinterpret_cast<Packet2d>(vec_sld(reinterpret_cast<Packet4ui>(v2), reinterpret_cast<Packet4ui>(v2), 8));
+ v2 = pxor(v2, reinterpret_cast<Packet2d>(p2ul_CONJ_XOR1));
+
+ return Packet1cd(padd<Packet2d>(v1, v2));
+ }
+
+ EIGEN_STRONG_INLINE Packet1cd& operator*=(const Packet1cd& b) {
+ v = pmul(Packet1cd(*this), b).v;
+ return *this;
+ }
+ EIGEN_STRONG_INLINE Packet1cd operator*(const Packet1cd& b) const {
+ return Packet1cd(*this) *= b;
+ }
+
+ EIGEN_STRONG_INLINE Packet1cd& operator+=(const Packet1cd& b) {
+ v = padd(v, b.v);
+ return *this;
+ }
+ EIGEN_STRONG_INLINE Packet1cd operator+(const Packet1cd& b) const {
+ return Packet1cd(*this) += b;
+ }
+ EIGEN_STRONG_INLINE Packet1cd& operator-=(const Packet1cd& b) {
+ v = psub(v, b.v);
+ return *this;
+ }
+ EIGEN_STRONG_INLINE Packet1cd operator-(const Packet1cd& b) const {
+ return Packet1cd(*this) -= b;
+ }
+ EIGEN_STRONG_INLINE Packet1cd operator-(void) const {
+ return Packet1cd(-v);
+ }
+
Packet2d v;
};
@@ -283,6 +304,7 @@ template<> struct packet_traits<std::complex<double> > : default_packet_traits
{
typedef Packet1cd type;
typedef Packet1cd half;
+ typedef Packet2d as_real;
enum {
Vectorizable = 1,
AlignedOnScalar = 0,
@@ -302,7 +324,7 @@ template<> struct packet_traits<std::complex<double> > : default_packet_traits
};
};
-template<> struct unpacket_traits<Packet1cd> { typedef std::complex<double> type; enum {size=1, alignment=Aligned16}; typedef Packet1cd half; };
+template<> struct unpacket_traits<Packet1cd> { typedef std::complex<double> type; enum {size=1, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false}; typedef Packet1cd half; typedef Packet2d as_real; };
template<> EIGEN_STRONG_INLINE Packet1cd pload <Packet1cd>(const std::complex<double>* from) { return Packet1cd(pload<Packet2d>((const double*)from)); }
template<> EIGEN_STRONG_INLINE Packet1cd ploadu<Packet1cd>(const std::complex<double>* from) { return Packet1cd(ploadu<Packet2d>((const double*)from)); }
@@ -312,19 +334,13 @@ template<> EIGEN_STRONG_INLINE void pstoreu<std::complex<double> >(std::complex<
template<> EIGEN_STRONG_INLINE Packet1cd pset1<Packet1cd>(const std::complex<double>& from)
{ /* here we really have to use unaligned loads :( */ return ploadu<Packet1cd>(&from); }
-template<> EIGEN_DEVICE_FUNC inline Packet1cd pgather<std::complex<double>, Packet1cd>(const std::complex<double>* from, Index stride)
+template<> EIGEN_DEVICE_FUNC inline Packet1cd pgather<std::complex<double>, Packet1cd>(const std::complex<double>* from, Index)
{
- std::complex<double> EIGEN_ALIGN16 af[2];
- af[0] = from[0*stride];
- af[1] = from[1*stride];
- return pload<Packet1cd>(af);
+ return pload<Packet1cd>(from);
}
-template<> EIGEN_DEVICE_FUNC inline void pscatter<std::complex<double>, Packet1cd>(std::complex<double>* to, const Packet1cd& from, Index stride)
+template<> EIGEN_DEVICE_FUNC inline void pscatter<std::complex<double>, Packet1cd>(std::complex<double>* to, const Packet1cd& from, Index)
{
- std::complex<double> EIGEN_ALIGN16 af[2];
- pstore<std::complex<double> >(af, from);
- to[0*stride] = af[0];
- to[1*stride] = af[1];
+ pstore<std::complex<double> >(to, from);
}
template<> EIGEN_STRONG_INLINE Packet1cd padd<Packet1cd>(const Packet1cd& a, const Packet1cd& b) { return Packet1cd(a.v + b.v); }
@@ -332,24 +348,6 @@ template<> EIGEN_STRONG_INLINE Packet1cd psub<Packet1cd>(const Packet1cd& a, con
template<> EIGEN_STRONG_INLINE Packet1cd pnegate(const Packet1cd& a) { return Packet1cd(pnegate(Packet2d(a.v))); }
template<> EIGEN_STRONG_INLINE Packet1cd pconj(const Packet1cd& a) { return Packet1cd(pxor(a.v, reinterpret_cast<Packet2d>(p2ul_CONJ_XOR2))); }
-template<> EIGEN_STRONG_INLINE Packet1cd pmul<Packet1cd>(const Packet1cd& a, const Packet1cd& b)
-{
- Packet2d a_re, a_im, v1, v2;
-
- // Permute and multiply the real parts of a and b
- a_re = vec_perm(a.v, a.v, p16uc_PSET64_HI);
- // Get the imaginary parts of a
- a_im = vec_perm(a.v, a.v, p16uc_PSET64_LO);
- // multiply a_re * b
- v1 = vec_madd(a_re, b.v, p2d_ZERO);
- // multiply a_im * b and get the conjugate result
- v2 = vec_madd(a_im, b.v, p2d_ZERO);
- v2 = reinterpret_cast<Packet2d>(vec_sld(reinterpret_cast<Packet4ui>(v2), reinterpret_cast<Packet4ui>(v2), 8));
- v2 = pxor(v2, reinterpret_cast<Packet2d>(p2ul_CONJ_XOR1));
-
- return Packet1cd(padd<Packet2d>(v1, v2));
-}
-
template<> EIGEN_STRONG_INLINE Packet1cd pand <Packet1cd>(const Packet1cd& a, const Packet1cd& b) { return Packet1cd(pand(a.v,b.v)); }
template<> EIGEN_STRONG_INLINE Packet1cd por <Packet1cd>(const Packet1cd& a, const Packet1cd& b) { return Packet1cd(por(a.v,b.v)); }
template<> EIGEN_STRONG_INLINE Packet1cd pxor <Packet1cd>(const Packet1cd& a, const Packet1cd& b) { return Packet1cd(pxor(a.v,b.v)); }
@@ -361,7 +359,7 @@ template<> EIGEN_STRONG_INLINE void prefetch<std::complex<double> >(const std::c
template<> EIGEN_STRONG_INLINE std::complex<double> pfirst<Packet1cd>(const Packet1cd& a)
{
- std::complex<double> EIGEN_ALIGN16 res[2];
+ EIGEN_ALIGN16 std::complex<double> res[2];
pstore<std::complex<double> >(res, a);
return res[0];
@@ -370,74 +368,15 @@ template<> EIGEN_STRONG_INLINE std::complex<double> pfirst<Packet1cd>(const Pac
template<> EIGEN_STRONG_INLINE Packet1cd preverse(const Packet1cd& a) { return a; }
template<> EIGEN_STRONG_INLINE std::complex<double> predux<Packet1cd>(const Packet1cd& a) { return pfirst(a); }
-template<> EIGEN_STRONG_INLINE Packet1cd preduxp<Packet1cd>(const Packet1cd* vecs) { return vecs[0]; }
template<> EIGEN_STRONG_INLINE std::complex<double> predux_mul<Packet1cd>(const Packet1cd& a) { return pfirst(a); }
-template<int Offset>
-struct palign_impl<Offset,Packet1cd>
-{
- static EIGEN_STRONG_INLINE void run(Packet1cd& /*first*/, const Packet1cd& /*second*/)
- {
- // FIXME is it sure we never have to align a Packet1cd?
- // Even though a std::complex<double> has 16 bytes, it is not necessarily aligned on a 16 bytes boundary...
- }
-};
-
-template<> struct conj_helper<Packet1cd, Packet1cd, false,true>
-{
- EIGEN_STRONG_INLINE Packet1cd pmadd(const Packet1cd& x, const Packet1cd& y, const Packet1cd& c) const
- { return padd(pmul(x,y),c); }
-
- EIGEN_STRONG_INLINE Packet1cd pmul(const Packet1cd& a, const Packet1cd& b) const
- {
- return internal::pmul(a, pconj(b));
- }
-};
-
-template<> struct conj_helper<Packet1cd, Packet1cd, true,false>
-{
- EIGEN_STRONG_INLINE Packet1cd pmadd(const Packet1cd& x, const Packet1cd& y, const Packet1cd& c) const
- { return padd(pmul(x,y),c); }
-
- EIGEN_STRONG_INLINE Packet1cd pmul(const Packet1cd& a, const Packet1cd& b) const
- {
- return internal::pmul(pconj(a), b);
- }
-};
-
-template<> struct conj_helper<Packet1cd, Packet1cd, true,true>
-{
- EIGEN_STRONG_INLINE Packet1cd pmadd(const Packet1cd& x, const Packet1cd& y, const Packet1cd& c) const
- { return padd(pmul(x,y),c); }
-
- EIGEN_STRONG_INLINE Packet1cd pmul(const Packet1cd& a, const Packet1cd& b) const
- {
- return pconj(internal::pmul(a, b));
- }
-};
-template<> struct conj_helper<Packet2d, Packet1cd, false,false>
-{
- EIGEN_STRONG_INLINE Packet1cd pmadd(const Packet2d& x, const Packet1cd& y, const Packet1cd& c) const
- { return padd(c, pmul(x,y)); }
-
- EIGEN_STRONG_INLINE Packet1cd pmul(const Packet2d& x, const Packet1cd& y) const
- { return Packet1cd(internal::pmul<Packet2d>(x, y.v)); }
-};
-
-template<> struct conj_helper<Packet1cd, Packet2d, false,false>
-{
- EIGEN_STRONG_INLINE Packet1cd pmadd(const Packet1cd& x, const Packet2d& y, const Packet1cd& c) const
- { return padd(c, pmul(x,y)); }
-
- EIGEN_STRONG_INLINE Packet1cd pmul(const Packet1cd& x, const Packet2d& y) const
- { return Packet1cd(internal::pmul<Packet2d>(x.v, y)); }
-};
+EIGEN_MAKE_CONJ_HELPER_CPLX_REAL(Packet1cd,Packet2d)
template<> EIGEN_STRONG_INLINE Packet1cd pdiv<Packet1cd>(const Packet1cd& a, const Packet1cd& b)
{
// TODO optimize it for AltiVec
- Packet1cd res = conj_helper<Packet1cd,Packet1cd,false,true>().pmul(a,b);
+ Packet1cd res = pmul(a,pconj(b));
Packet2d s = pmul<Packet2d>(b.v, b.v);
return Packet1cd(pdiv(res.v, padd<Packet2d>(s, vec_perm(s, s, p16uc_REVERSE64))));
}
@@ -453,6 +392,23 @@ EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet1cd,2>& kernel)
kernel.packet[1].v = vec_perm(kernel.packet[0].v, kernel.packet[1].v, p16uc_TRANSPOSE64_LO);
kernel.packet[0].v = tmp;
}
+
+template<> EIGEN_STRONG_INLINE Packet1cd pcmp_eq(const Packet1cd& a, const Packet1cd& b) {
+ // Compare real and imaginary parts of a and b to get the mask vector:
+ // [re(a)==re(b), im(a)==im(b)]
+ Packet2d eq = reinterpret_cast<Packet2d>(vec_cmpeq(a.v,b.v));
+ // Swap real/imag elements in the mask in to get:
+ // [im(a)==im(b), re(a)==re(b)]
+ Packet2d eq_swapped = reinterpret_cast<Packet2d>(vec_sld(reinterpret_cast<Packet4ui>(eq), reinterpret_cast<Packet4ui>(eq), 8));
+ // Return re(a)==re(b) & im(a)==im(b) by computing bitwise AND of eq and eq_swapped
+ return Packet1cd(vec_and(eq, eq_swapped));
+}
+
+template<> EIGEN_STRONG_INLINE Packet1cd psqrt<Packet1cd>(const Packet1cd& a)
+{
+ return psqrt_complex<Packet1cd>(a);
+}
+
#endif // __VSX__
} // end namespace internal
diff --git a/Eigen/src/Core/arch/AltiVec/MathFunctions.h b/Eigen/src/Core/arch/AltiVec/MathFunctions.h
index c5e4bede7..3a7a32936 100644
--- a/Eigen/src/Core/arch/AltiVec/MathFunctions.h
+++ b/Eigen/src/Core/arch/AltiVec/MathFunctions.h
@@ -9,10 +9,6 @@
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
-/* The sin, cos, exp, and log functions of this file come from
- * Julien Pommier's sse math library: http://gruntthepeon.free.fr/ssemath/
- */
-
#ifndef EIGEN_MATH_FUNCTIONS_ALTIVEC_H
#define EIGEN_MATH_FUNCTIONS_ALTIVEC_H
@@ -20,180 +16,28 @@ namespace Eigen {
namespace internal {
-static _EIGEN_DECLARE_CONST_Packet4f(1 , 1.0f);
-static _EIGEN_DECLARE_CONST_Packet4f(half, 0.5f);
-static _EIGEN_DECLARE_CONST_Packet4i(0x7f, 0x7f);
-static _EIGEN_DECLARE_CONST_Packet4i(23, 23);
-
-static _EIGEN_DECLARE_CONST_Packet4f_FROM_INT(inv_mant_mask, ~0x7f800000);
-
-/* the smallest non denormalized float number */
-static _EIGEN_DECLARE_CONST_Packet4f_FROM_INT(min_norm_pos, 0x00800000);
-static _EIGEN_DECLARE_CONST_Packet4f_FROM_INT(minus_inf, 0xff800000); // -1.f/0.f
-static _EIGEN_DECLARE_CONST_Packet4f_FROM_INT(minus_nan, 0xffffffff);
-
-/* natural logarithm computed for 4 simultaneous float
- return NaN for x <= 0
-*/
-static _EIGEN_DECLARE_CONST_Packet4f(cephes_SQRTHF, 0.707106781186547524f);
-static _EIGEN_DECLARE_CONST_Packet4f(cephes_log_p0, 7.0376836292E-2f);
-static _EIGEN_DECLARE_CONST_Packet4f(cephes_log_p1, - 1.1514610310E-1f);
-static _EIGEN_DECLARE_CONST_Packet4f(cephes_log_p2, 1.1676998740E-1f);
-static _EIGEN_DECLARE_CONST_Packet4f(cephes_log_p3, - 1.2420140846E-1f);
-static _EIGEN_DECLARE_CONST_Packet4f(cephes_log_p4, + 1.4249322787E-1f);
-static _EIGEN_DECLARE_CONST_Packet4f(cephes_log_p5, - 1.6668057665E-1f);
-static _EIGEN_DECLARE_CONST_Packet4f(cephes_log_p6, + 2.0000714765E-1f);
-static _EIGEN_DECLARE_CONST_Packet4f(cephes_log_p7, - 2.4999993993E-1f);
-static _EIGEN_DECLARE_CONST_Packet4f(cephes_log_p8, + 3.3333331174E-1f);
-static _EIGEN_DECLARE_CONST_Packet4f(cephes_log_q1, -2.12194440e-4f);
-static _EIGEN_DECLARE_CONST_Packet4f(cephes_log_q2, 0.693359375f);
-
-static _EIGEN_DECLARE_CONST_Packet4f(exp_hi, 88.3762626647950f);
-static _EIGEN_DECLARE_CONST_Packet4f(exp_lo, -88.3762626647949f);
-
-static _EIGEN_DECLARE_CONST_Packet4f(cephes_LOG2EF, 1.44269504088896341f);
-static _EIGEN_DECLARE_CONST_Packet4f(cephes_exp_C1, 0.693359375f);
-static _EIGEN_DECLARE_CONST_Packet4f(cephes_exp_C2, -2.12194440e-4f);
-
-static _EIGEN_DECLARE_CONST_Packet4f(cephes_exp_p0, 1.9875691500E-4f);
-static _EIGEN_DECLARE_CONST_Packet4f(cephes_exp_p1, 1.3981999507E-3f);
-static _EIGEN_DECLARE_CONST_Packet4f(cephes_exp_p2, 8.3334519073E-3f);
-static _EIGEN_DECLARE_CONST_Packet4f(cephes_exp_p3, 4.1665795894E-2f);
-static _EIGEN_DECLARE_CONST_Packet4f(cephes_exp_p4, 1.6666665459E-1f);
-static _EIGEN_DECLARE_CONST_Packet4f(cephes_exp_p5, 5.0000001201E-1f);
-
-#ifdef __VSX__
-static _EIGEN_DECLARE_CONST_Packet2d(1 , 1.0);
-static _EIGEN_DECLARE_CONST_Packet2d(2 , 2.0);
-static _EIGEN_DECLARE_CONST_Packet2d(half, 0.5);
-
-static _EIGEN_DECLARE_CONST_Packet2d(exp_hi, 709.437);
-static _EIGEN_DECLARE_CONST_Packet2d(exp_lo, -709.436139303);
-
-static _EIGEN_DECLARE_CONST_Packet2d(cephes_LOG2EF, 1.4426950408889634073599);
-
-static _EIGEN_DECLARE_CONST_Packet2d(cephes_exp_p0, 1.26177193074810590878e-4);
-static _EIGEN_DECLARE_CONST_Packet2d(cephes_exp_p1, 3.02994407707441961300e-2);
-static _EIGEN_DECLARE_CONST_Packet2d(cephes_exp_p2, 9.99999999999999999910e-1);
-
-static _EIGEN_DECLARE_CONST_Packet2d(cephes_exp_q0, 3.00198505138664455042e-6);
-static _EIGEN_DECLARE_CONST_Packet2d(cephes_exp_q1, 2.52448340349684104192e-3);
-static _EIGEN_DECLARE_CONST_Packet2d(cephes_exp_q2, 2.27265548208155028766e-1);
-static _EIGEN_DECLARE_CONST_Packet2d(cephes_exp_q3, 2.00000000000000000009e0);
-
-static _EIGEN_DECLARE_CONST_Packet2d(cephes_exp_C1, 0.693145751953125);
-static _EIGEN_DECLARE_CONST_Packet2d(cephes_exp_C2, 1.42860682030941723212e-6);
-
-#ifdef __POWER8_VECTOR__
-static Packet2l p2l_1023 = { 1023, 1023 };
-static Packet2ul p2ul_52 = { 52, 52 };
-#endif
-
-#endif
-
template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
Packet4f plog<Packet4f>(const Packet4f& _x)
{
- Packet4f x = _x;
-
- Packet4i emm0;
-
- /* isvalid_mask is 0 if x < 0 or x is NaN. */
- Packet4ui isvalid_mask = reinterpret_cast<Packet4ui>(vec_cmpge(x, p4f_ZERO));
- Packet4ui iszero_mask = reinterpret_cast<Packet4ui>(vec_cmpeq(x, p4f_ZERO));
-
- x = pmax(x, p4f_min_norm_pos); /* cut off denormalized stuff */
- emm0 = vec_sr(reinterpret_cast<Packet4i>(x),
- reinterpret_cast<Packet4ui>(p4i_23));
-
- /* keep only the fractional part */
- x = pand(x, p4f_inv_mant_mask);
- x = por(x, p4f_half);
-
- emm0 = psub(emm0, p4i_0x7f);
- Packet4f e = padd(vec_ctf(emm0, 0), p4f_1);
-
- /* part2:
- if( x < SQRTHF ) {
- e -= 1;
- x = x + x - 1.0;
- } else { x = x - 1.0; }
- */
- Packet4f mask = reinterpret_cast<Packet4f>(vec_cmplt(x, p4f_cephes_SQRTHF));
- Packet4f tmp = pand(x, mask);
- x = psub(x, p4f_1);
- e = psub(e, pand(p4f_1, mask));
- x = padd(x, tmp);
-
- Packet4f x2 = pmul(x,x);
- Packet4f x3 = pmul(x2,x);
-
- Packet4f y, y1, y2;
- y = pmadd(p4f_cephes_log_p0, x, p4f_cephes_log_p1);
- y1 = pmadd(p4f_cephes_log_p3, x, p4f_cephes_log_p4);
- y2 = pmadd(p4f_cephes_log_p6, x, p4f_cephes_log_p7);
- y = pmadd(y , x, p4f_cephes_log_p2);
- y1 = pmadd(y1, x, p4f_cephes_log_p5);
- y2 = pmadd(y2, x, p4f_cephes_log_p8);
- y = pmadd(y, x3, y1);
- y = pmadd(y, x3, y2);
- y = pmul(y, x3);
-
- y1 = pmul(e, p4f_cephes_log_q1);
- tmp = pmul(x2, p4f_half);
- y = padd(y, y1);
- x = psub(x, tmp);
- y2 = pmul(e, p4f_cephes_log_q2);
- x = padd(x, y);
- x = padd(x, y2);
- // negative arg will be NAN, 0 will be -INF
- x = vec_sel(x, p4f_minus_inf, iszero_mask);
- x = vec_sel(p4f_minus_nan, x, isvalid_mask);
- return x;
+ return plog_float(_x);
}
template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
Packet4f pexp<Packet4f>(const Packet4f& _x)
{
- Packet4f x = _x;
-
- Packet4f tmp, fx;
- Packet4i emm0;
-
- // clamp x
- x = pmax(pmin(x, p4f_exp_hi), p4f_exp_lo);
-
- // express exp(x) as exp(g + n*log(2))
- fx = pmadd(x, p4f_cephes_LOG2EF, p4f_half);
-
- fx = pfloor(fx);
-
- tmp = pmul(fx, p4f_cephes_exp_C1);
- Packet4f z = pmul(fx, p4f_cephes_exp_C2);
- x = psub(x, tmp);
- x = psub(x, z);
-
- z = pmul(x,x);
-
- Packet4f y = p4f_cephes_exp_p0;
- y = pmadd(y, x, p4f_cephes_exp_p1);
- y = pmadd(y, x, p4f_cephes_exp_p2);
- y = pmadd(y, x, p4f_cephes_exp_p3);
- y = pmadd(y, x, p4f_cephes_exp_p4);
- y = pmadd(y, x, p4f_cephes_exp_p5);
- y = pmadd(y, z, x);
- y = padd(y, p4f_1);
+ return pexp_float(_x);
+}
- // build 2^n
- emm0 = vec_cts(fx, 0);
- emm0 = vec_add(emm0, p4i_0x7f);
- emm0 = vec_sl(emm0, reinterpret_cast<Packet4ui>(p4i_23));
+template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
+Packet4f psin<Packet4f>(const Packet4f& _x)
+{
+ return psin_float(_x);
+}
- // Altivec's max & min operators just drop silent NaNs. Check NaNs in
- // inputs and return them unmodified.
- Packet4ui isnumber_mask = reinterpret_cast<Packet4ui>(vec_cmpeq(_x, _x));
- return vec_sel(_x, pmax(pmul(y, reinterpret_cast<Packet4f>(emm0)), _x),
- isnumber_mask);
+template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
+Packet4f pcos<Packet4f>(const Packet4f& _x)
+{
+ return pcos_float(_x);
}
#ifndef EIGEN_COMP_CLANG
@@ -225,95 +69,19 @@ Packet2d psqrt<Packet2d>(const Packet2d& x)
return vec_sqrt(x);
}
-// VSX support varies between different compilers and even different
-// versions of the same compiler. For gcc version >= 4.9.3, we can use
-// vec_cts to efficiently convert Packet2d to Packet2l. Otherwise, use
-// a slow version that works with older compilers.
-// Update: apparently vec_cts/vec_ctf intrinsics for 64-bit doubles
-// are buggy, https://gcc.gnu.org/bugzilla/show_bug.cgi?id=70963
-static inline Packet2l ConvertToPacket2l(const Packet2d& x) {
-#if EIGEN_GNUC_AT_LEAST(5, 4) || \
- (EIGEN_GNUC_AT(6, 1) && __GNUC_PATCHLEVEL__ >= 1)
- return vec_cts(x, 0); // TODO: check clang version.
-#else
- double tmp[2];
- memcpy(tmp, &x, sizeof(tmp));
- Packet2l l = { static_cast<long long>(tmp[0]),
- static_cast<long long>(tmp[1]) };
- return l;
-#endif
-}
-
template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
Packet2d pexp<Packet2d>(const Packet2d& _x)
{
- Packet2d x = _x;
-
- Packet2d tmp, fx;
- Packet2l emm0;
-
- // clamp x
- x = pmax(pmin(x, p2d_exp_hi), p2d_exp_lo);
-
- /* express exp(x) as exp(g + n*log(2)) */
- fx = pmadd(x, p2d_cephes_LOG2EF, p2d_half);
-
- fx = pfloor(fx);
-
- tmp = pmul(fx, p2d_cephes_exp_C1);
- Packet2d z = pmul(fx, p2d_cephes_exp_C2);
- x = psub(x, tmp);
- x = psub(x, z);
-
- Packet2d x2 = pmul(x,x);
-
- Packet2d px = p2d_cephes_exp_p0;
- px = pmadd(px, x2, p2d_cephes_exp_p1);
- px = pmadd(px, x2, p2d_cephes_exp_p2);
- px = pmul (px, x);
-
- Packet2d qx = p2d_cephes_exp_q0;
- qx = pmadd(qx, x2, p2d_cephes_exp_q1);
- qx = pmadd(qx, x2, p2d_cephes_exp_q2);
- qx = pmadd(qx, x2, p2d_cephes_exp_q3);
-
- x = pdiv(px,psub(qx,px));
- x = pmadd(p2d_2,x,p2d_1);
-
- // build 2^n
- emm0 = ConvertToPacket2l(fx);
-
-#ifdef __POWER8_VECTOR__
- emm0 = vec_add(emm0, p2l_1023);
- emm0 = vec_sl(emm0, p2ul_52);
-#else
- // Code is a bit complex for POWER7. There is actually a
- // vec_xxsldi intrinsic but it is not supported by some gcc versions.
- // So we shift (52-32) bits and do a word swap with zeros.
- _EIGEN_DECLARE_CONST_Packet4i(1023, 1023);
- _EIGEN_DECLARE_CONST_Packet4i(20, 20); // 52 - 32
-
- Packet4i emm04i = reinterpret_cast<Packet4i>(emm0);
- emm04i = vec_add(emm04i, p4i_1023);
- emm04i = vec_sl(emm04i, reinterpret_cast<Packet4ui>(p4i_20));
- static const Packet16uc perm = {
- 0x14, 0x15, 0x16, 0x17, 0x00, 0x01, 0x02, 0x03,
- 0x1c, 0x1d, 0x1e, 0x1f, 0x08, 0x09, 0x0a, 0x0b };
-#ifdef _BIG_ENDIAN
- emm0 = reinterpret_cast<Packet2l>(vec_perm(p4i_ZERO, emm04i, perm));
-#else
- emm0 = reinterpret_cast<Packet2l>(vec_perm(emm04i, p4i_ZERO, perm));
-#endif
-
+ return pexp_double(_x);
+}
#endif
- // Altivec's max & min operators just drop silent NaNs. Check NaNs in
- // inputs and return them unmodified.
- Packet2ul isnumber_mask = reinterpret_cast<Packet2ul>(vec_cmpeq(_x, _x));
- return vec_sel(_x, pmax(pmul(x, reinterpret_cast<Packet2d>(emm0)), _x),
- isnumber_mask);
+// Hyperbolic Tangent function.
+template <>
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet4f
+ptanh<Packet4f>(const Packet4f& x) {
+ return internal::generic_fast_tanh_float(x);
}
-#endif
} // end namespace internal
diff --git a/Eigen/src/Core/arch/AltiVec/MatrixProduct.h b/Eigen/src/Core/arch/AltiVec/MatrixProduct.h
new file mode 100644
index 000000000..3f79b97df
--- /dev/null
+++ b/Eigen/src/Core/arch/AltiVec/MatrixProduct.h
@@ -0,0 +1,2937 @@
+// This file is part of Eigen, a lightweight C++ template library
+// for linear algebra.
+//
+// Copyright (C) 2020 Everton Constantino (everton.constantino@ibm.com)
+// Copyright (C) 2021 Chip Kerchner (chip.kerchner@ibm.com)
+//
+// This Source Code Form is subject to the terms of the Mozilla
+// Public License v. 2.0. If a copy of the MPL was not distributed
+// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+#ifndef EIGEN_MATRIX_PRODUCT_ALTIVEC_H
+#define EIGEN_MATRIX_PRODUCT_ALTIVEC_H
+
+#ifndef EIGEN_ALTIVEC_USE_CUSTOM_PACK
+#define EIGEN_ALTIVEC_USE_CUSTOM_PACK 1
+#endif
+
+#include "MatrixProductCommon.h"
+
+// Since LLVM doesn't support dynamic dispatching, always force either MMA or VSX
+#if EIGEN_COMP_LLVM
+#if !defined(EIGEN_ALTIVEC_DISABLE_MMA) && !defined(EIGEN_ALTIVEC_MMA_ONLY)
+#ifdef __MMA__
+#define EIGEN_ALTIVEC_MMA_ONLY
+#else
+#define EIGEN_ALTIVEC_DISABLE_MMA
+#endif
+#endif
+#endif
+
+#ifdef __has_builtin
+#if __has_builtin(__builtin_mma_assemble_acc)
+ #define ALTIVEC_MMA_SUPPORT
+#endif
+#endif
+
+#if defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA)
+ #include "MatrixProductMMA.h"
+#endif
+
+/**************************************************************************************************
+ * TODO *
+ * - Check StorageOrder on dhs_pack (the innermost second loop seems unvectorized when it could be). *
+ * - Check the possibility of transposing as GETREAL and GETIMAG when needed. *
+ **************************************************************************************************/
+namespace Eigen {
+
+namespace internal {
+
+/**************************
+ * Constants and typedefs *
+ **************************/
+template<typename Scalar>
+struct quad_traits
+{
+ typedef typename packet_traits<Scalar>::type vectortype;
+ typedef PacketBlock<vectortype,4> type;
+ typedef vectortype rhstype;
+ enum
+ {
+ vectorsize = packet_traits<Scalar>::size,
+ size = 4,
+ rows = 4
+ };
+};
+
+template<>
+struct quad_traits<double>
+{
+ typedef Packet2d vectortype;
+ typedef PacketBlock<vectortype,4> type;
+ typedef PacketBlock<Packet2d,2> rhstype;
+ enum
+ {
+ vectorsize = packet_traits<double>::size,
+ size = 2,
+ rows = 4
+ };
+};
+
+// MatrixProduct decomposes complex vectors into a separate real vector and imaginary vector; this turned
+// out to be faster than Eigen's usual approach of keeping real/imaginary pairs interleaved in a single
+// vector. The constants below are used to extract from and convert between Eigen's layout and MatrixProduct's.
+
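+//[a,ai,b,bi],[c,ci,d,di] = [a,b,c,d]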
+const static Packet16uc p16uc_GETREAL32 = { 0, 1, 2, 3,
+ 8, 9, 10, 11,
+ 16, 17, 18, 19,
+ 24, 25, 26, 27};
+
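+//[a,ai,b,bi],[c,ci,d,di] = [ai,bi,ci,di]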
+const static Packet16uc p16uc_GETIMAG32 = { 4, 5, 6, 7,
+ 12, 13, 14, 15,
+ 20, 21, 22, 23,
+ 28, 29, 30, 31};
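+//[a,ai],[b,bi] = [a,b]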
+const static Packet16uc p16uc_GETREAL64 = { 0, 1, 2, 3, 4, 5, 6, 7,
+ 16, 17, 18, 19, 20, 21, 22, 23};
+
+//[a,ai],[b,bi] = [ai,bi]
+const static Packet16uc p16uc_GETIMAG64 = { 8, 9, 10, 11, 12, 13, 14, 15,
+ 24, 25, 26, 27, 28, 29, 30, 31};
+
+/*********************************************
+ * Single precision real and complex packing *
+ * *******************************************/
+
+/**
+ * Symm packing handles symmetric/adjoint blocks: as expected, the packing leaves the diagonal real,
+ * and whatever lies below the diagonal is copied from the corresponding upper-diagonal element and
+ * conjugated. There is no PanelMode available for symm packing.
+ *
+ * Packing in general is supposed to leave the lhs and rhs blocks in a layout that gemm can read easily
+ * with its rank-update instructions. The float32 and float64 versions differ because, at the moment,
+ * the accumulator size is fixed at 512 bits, so a 4x4 accumulator of 64-bit elements is not possible.
+ *
+ * As mentioned earlier, MatrixProduct splits complex numbers into a real vector and an imaginary vector,
+ * so packing has to take that into account: at the moment we pack the real parts first and then the
+ * imaginary parts. This is the main reason why complex packing is broken down into several different
+ * parts, and also why we end up with separate float32/64 and complex float32/64 versions.
+ **/
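+
+// Packed complex layout used by the helpers below: within each panel of vectorSize columns (or rows),
+// all real parts are stored first, starting at offset rir, followed by all imaginary parts, starting at
+// rii = rir + vectorDelta.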
+template<typename Scalar, typename Index, int StorageOrder>
+EIGEN_ALWAYS_INLINE std::complex<Scalar> getAdjointVal(Index i, Index j, const_blas_data_mapper<std::complex<Scalar>, Index, StorageOrder>& dt)
+{
+ std::complex<Scalar> v;
+ if(i < j)
+ {
+ v.real( dt(j,i).real());
+ v.imag(-dt(j,i).imag());
+ } else if(i > j)
+ {
+ v.real( dt(i,j).real());
+ v.imag( dt(i,j).imag());
+ } else {
+ v.real( dt(i,j).real());
+ v.imag((Scalar)0.0);
+ }
+ return v;
+}
+
+template<typename Scalar, typename Index, int StorageOrder, int N>
+EIGEN_STRONG_INLINE void symm_pack_complex_rhs_helper(std::complex<Scalar>* blockB, const std::complex<Scalar>* _rhs, Index rhsStride, Index rows, Index cols, Index k2)
+{
+ const Index depth = k2 + rows;
+ const_blas_data_mapper<std::complex<Scalar>, Index, StorageOrder> rhs(_rhs, rhsStride);
+ const Index vectorSize = N*quad_traits<Scalar>::vectorsize;
+ const Index vectorDelta = vectorSize * rows;
+ Scalar* blockBf = reinterpret_cast<Scalar *>(blockB);
+
+ Index rir = 0, rii, j = 0;
+ for(; j + vectorSize <= cols; j+=vectorSize)
+ {
+ rii = rir + vectorDelta;
+
+ for(Index i = k2; i < depth; i++)
+ {
+ for(Index k = 0; k < vectorSize; k++)
+ {
+ std::complex<Scalar> v = getAdjointVal<Scalar, Index, StorageOrder>(i, j + k, rhs);
+
+ blockBf[rir + k] = v.real();
+ blockBf[rii + k] = v.imag();
+ }
+ rir += vectorSize;
+ rii += vectorSize;
+ }
+
+ rir += vectorDelta;
+ }
+ if (j < cols)
+ {
+ rii = rir + ((cols - j) * rows);
+
+ for(Index i = k2; i < depth; i++)
+ {
+ Index k = j;
+ for(; k < cols; k++)
+ {
+ std::complex<Scalar> v = getAdjointVal<Scalar, Index, StorageOrder>(i, k, rhs);
+
+ blockBf[rir] = v.real();
+ blockBf[rii] = v.imag();
+
+ rir += 1;
+ rii += 1;
+ }
+ }
+ }
+}
+
+template<typename Scalar, typename Index, int StorageOrder>
+EIGEN_STRONG_INLINE void symm_pack_complex_lhs_helper(std::complex<Scalar>* blockA, const std::complex<Scalar>* _lhs, Index lhsStride, Index cols, Index rows)
+{
+ const Index depth = cols;
+ const_blas_data_mapper<std::complex<Scalar>, Index, StorageOrder> lhs(_lhs, lhsStride);
+ const Index vectorSize = quad_traits<Scalar>::vectorsize;
+ const Index vectorDelta = vectorSize * depth;
+ Scalar* blockAf = (Scalar *)(blockA);
+
+ Index rir = 0, rii, j = 0;
+ for(; j + vectorSize <= rows; j+=vectorSize)
+ {
+ rii = rir + vectorDelta;
+
+ for(Index i = 0; i < depth; i++)
+ {
+ for(Index k = 0; k < vectorSize; k++)
+ {
+ std::complex<Scalar> v = getAdjointVal<Scalar, Index, StorageOrder>(j+k, i, lhs);
+
+ blockAf[rir + k] = v.real();
+ blockAf[rii + k] = v.imag();
+ }
+ rir += vectorSize;
+ rii += vectorSize;
+ }
+
+ rir += vectorDelta;
+ }
+
+ if (j < rows)
+ {
+ rii = rir + ((rows - j) * depth);
+
+ for(Index i = 0; i < depth; i++)
+ {
+ Index k = j;
+ for(; k < rows; k++)
+ {
+ std::complex<Scalar> v = getAdjointVal<Scalar, Index, StorageOrder>(k, i, lhs);
+
+ blockAf[rir] = v.real();
+ blockAf[rii] = v.imag();
+
+ rir += 1;
+ rii += 1;
+ }
+ }
+ }
+}
+
+template<typename Scalar, typename Index, int StorageOrder, int N>
+EIGEN_STRONG_INLINE void symm_pack_rhs_helper(Scalar* blockB, const Scalar* _rhs, Index rhsStride, Index rows, Index cols, Index k2)
+{
+ const Index depth = k2 + rows;
+ const_blas_data_mapper<Scalar, Index, StorageOrder> rhs(_rhs, rhsStride);
+ const Index vectorSize = quad_traits<Scalar>::vectorsize;
+
+ Index ri = 0, j = 0;
+ for(; j + N*vectorSize <= cols; j+=N*vectorSize)
+ {
+ Index i = k2;
+ for(; i < depth; i++)
+ {
+ for(Index k = 0; k < N*vectorSize; k++)
+ {
+ if(i <= j+k)
+ blockB[ri + k] = rhs(j+k, i);
+ else
+ blockB[ri + k] = rhs(i, j+k);
+ }
+ ri += N*vectorSize;
+ }
+ }
+
+ if (j < cols)
+ {
+ for(Index i = k2; i < depth; i++)
+ {
+ Index k = j;
+ for(; k < cols; k++)
+ {
+ if(k <= i)
+ blockB[ri] = rhs(i, k);
+ else
+ blockB[ri] = rhs(k, i);
+ ri += 1;
+ }
+ }
+ }
+}
+
+template<typename Scalar, typename Index, int StorageOrder>
+EIGEN_STRONG_INLINE void symm_pack_lhs_helper(Scalar* blockA, const Scalar* _lhs, Index lhsStride, Index cols, Index rows)
+{
+ const Index depth = cols;
+ const_blas_data_mapper<Scalar, Index, StorageOrder> lhs(_lhs, lhsStride);
+ const Index vectorSize = quad_traits<Scalar>::vectorsize;
+
+ Index ri = 0, j = 0;
+ for(; j + vectorSize <= rows; j+=vectorSize)
+ {
+ Index i = 0;
+
+ for(; i < depth; i++)
+ {
+ for(Index k = 0; k < vectorSize; k++)
+ {
+ if(i <= j+k)
+ blockA[ri + k] = lhs(j+k, i);
+ else
+ blockA[ri + k] = lhs(i, j+k);
+ }
+ ri += vectorSize;
+ }
+ }
+
+ if (j < rows)
+ {
+ for(Index i = 0; i < depth; i++)
+ {
+ Index k = j;
+ for(; k < rows; k++)
+ {
+ if(i <= k)
+ blockA[ri] = lhs(k, i);
+ else
+ blockA[ri] = lhs(i, k);
+ ri += 1;
+ }
+ }
+ }
+}
+
+template<typename Index, int nr, int StorageOrder>
+struct symm_pack_rhs<std::complex<float>, Index, nr, StorageOrder>
+{
+ void operator()(std::complex<float>* blockB, const std::complex<float>* _rhs, Index rhsStride, Index rows, Index cols, Index k2)
+ {
+ symm_pack_complex_rhs_helper<float, Index, StorageOrder, 1>(blockB, _rhs, rhsStride, rows, cols, k2);
+ }
+};
+
+template<typename Index, int Pack1, int Pack2_dummy, int StorageOrder>
+struct symm_pack_lhs<std::complex<float>, Index, Pack1, Pack2_dummy, StorageOrder>
+{
+ void operator()(std::complex<float>* blockA, const std::complex<float>* _lhs, Index lhsStride, Index cols, Index rows)
+ {
+ symm_pack_complex_lhs_helper<float, Index, StorageOrder>(blockA, _lhs, lhsStride, cols, rows);
+ }
+};
+
+// *********** symm_pack std::complex<float64> ***********
+
+template<typename Index, int nr, int StorageOrder>
+struct symm_pack_rhs<std::complex<double>, Index, nr, StorageOrder>
+{
+ void operator()(std::complex<double>* blockB, const std::complex<double>* _rhs, Index rhsStride, Index rows, Index cols, Index k2)
+ {
+ symm_pack_complex_rhs_helper<double, Index, StorageOrder, 2>(blockB, _rhs, rhsStride, rows, cols, k2);
+ }
+};
+
+template<typename Index, int Pack1, int Pack2_dummy, int StorageOrder>
+struct symm_pack_lhs<std::complex<double>, Index, Pack1, Pack2_dummy, StorageOrder>
+{
+ void operator()(std::complex<double>* blockA, const std::complex<double>* _lhs, Index lhsStride, Index cols, Index rows)
+ {
+ symm_pack_complex_lhs_helper<double, Index, StorageOrder>(blockA, _lhs, lhsStride, cols, rows);
+ }
+};
+
+// *********** symm_pack float32 ***********
+template<typename Index, int nr, int StorageOrder>
+struct symm_pack_rhs<float, Index, nr, StorageOrder>
+{
+ void operator()(float* blockB, const float* _rhs, Index rhsStride, Index rows, Index cols, Index k2)
+ {
+ symm_pack_rhs_helper<float, Index, StorageOrder, 1>(blockB, _rhs, rhsStride, rows, cols, k2);
+ }
+};
+
+template<typename Index, int Pack1, int Pack2_dummy, int StorageOrder>
+struct symm_pack_lhs<float, Index, Pack1, Pack2_dummy, StorageOrder>
+{
+ void operator()(float* blockA, const float* _lhs, Index lhsStride, Index cols, Index rows)
+ {
+ symm_pack_lhs_helper<float, Index, StorageOrder>(blockA, _lhs, lhsStride, cols, rows);
+ }
+};
+
+// *********** symm_pack float64 ***********
+template<typename Index, int nr, int StorageOrder>
+struct symm_pack_rhs<double, Index, nr, StorageOrder>
+{
+ void operator()(double* blockB, const double* _rhs, Index rhsStride, Index rows, Index cols, Index k2)
+ {
+ symm_pack_rhs_helper<double, Index, StorageOrder, 2>(blockB, _rhs, rhsStride, rows, cols, k2);
+ }
+};
+
+template<typename Index, int Pack1, int Pack2_dummy, int StorageOrder>
+struct symm_pack_lhs<double, Index, Pack1, Pack2_dummy, StorageOrder>
+{
+ void operator()(double* blockA, const double* _lhs, Index lhsStride, Index cols, Index rows)
+ {
+ symm_pack_lhs_helper<double, Index, StorageOrder>(blockA, _lhs, lhsStride, cols, rows);
+ }
+};
+
+/**
+ * PanelMode
+ * Packing might be called several times before the block is multiplied by gebp_kernel; this happens
+ * because on special occasions it fills part of the block with other parts of the matrix. Two variables
+ * control how PanelMode should behave: offset and stride. The idea is that these variables describe the
+ * offset and stride the block will eventually have, and the packing must respect them. The process is to
+ * pack as you would normally, but leave the start of each panel at the correct offset and pad its end so
+ * it occupies the real stride the block will have. Gebp is aware of both blocks' stride and offset and
+ * behaves accordingly.
+ **/
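+//
+// For example, in dhs_pack with PanelMode enabled, each packed panel of width vectorSize advances the
+// write position by vectorSize*stride scalars in total: vectorSize*offset is skipped before the panel's
+// data, vectorSize*depth holds the packed data, and vectorSize*(stride - offset - depth) is left after it.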
+
+template<typename Scalar, typename Packet, typename Index>
+EIGEN_ALWAYS_INLINE void storeBlock(Scalar* to, PacketBlock<Packet,4>& block)
+{
+ const Index size = 16 / sizeof(Scalar);
+ pstore<Scalar>(to + (0 * size), block.packet[0]);
+ pstore<Scalar>(to + (1 * size), block.packet[1]);
+ pstore<Scalar>(to + (2 * size), block.packet[2]);
+ pstore<Scalar>(to + (3 * size), block.packet[3]);
+}
+
+template<typename Scalar, typename Packet, typename Index>
+EIGEN_ALWAYS_INLINE void storeBlock(Scalar* to, PacketBlock<Packet,2>& block)
+{
+ const Index size = 16 / sizeof(Scalar);
+ pstore<Scalar>(to + (0 * size), block.packet[0]);
+ pstore<Scalar>(to + (1 * size), block.packet[1]);
+}
+
+// General template for lhs & rhs complex packing.
+template<typename Scalar, typename Index, typename DataMapper, typename Packet, typename PacketC, int StorageOrder, bool Conjugate, bool PanelMode, bool UseLhs>
+struct dhs_cpack {
+ EIGEN_STRONG_INLINE void operator()(std::complex<Scalar>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
+ {
+ const Index vectorSize = quad_traits<Scalar>::vectorsize;
+ const Index vectorDelta = vectorSize * ((PanelMode) ? stride : depth);
+ Index rir = ((PanelMode) ? (vectorSize*offset) : 0), rii;
+ Scalar* blockAt = reinterpret_cast<Scalar *>(blockA);
+ Index j = 0;
+
+ for(; j + vectorSize <= rows; j+=vectorSize)
+ {
+ Index i = 0;
+
+ rii = rir + vectorDelta;
+
+ for(; i + vectorSize <= depth; i+=vectorSize)
+ {
+ PacketBlock<Packet,4> blockr, blocki;
+ PacketBlock<PacketC,8> cblock;
+
+ if (UseLhs) {
+ bload<DataMapper, PacketC, Index, 2, 0, StorageOrder>(cblock, lhs, j, i);
+ } else {
+ bload<DataMapper, PacketC, Index, 2, 0, StorageOrder>(cblock, lhs, i, j);
+ }
+
+ blockr.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[4].v, p16uc_GETREAL32);
+ blockr.packet[1] = vec_perm(cblock.packet[1].v, cblock.packet[5].v, p16uc_GETREAL32);
+ blockr.packet[2] = vec_perm(cblock.packet[2].v, cblock.packet[6].v, p16uc_GETREAL32);
+ blockr.packet[3] = vec_perm(cblock.packet[3].v, cblock.packet[7].v, p16uc_GETREAL32);
+
+ blocki.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[4].v, p16uc_GETIMAG32);
+ blocki.packet[1] = vec_perm(cblock.packet[1].v, cblock.packet[5].v, p16uc_GETIMAG32);
+ blocki.packet[2] = vec_perm(cblock.packet[2].v, cblock.packet[6].v, p16uc_GETIMAG32);
+ blocki.packet[3] = vec_perm(cblock.packet[3].v, cblock.packet[7].v, p16uc_GETIMAG32);
+
+ if(Conjugate)
+ {
+ blocki.packet[0] = -blocki.packet[0];
+ blocki.packet[1] = -blocki.packet[1];
+ blocki.packet[2] = -blocki.packet[2];
+ blocki.packet[3] = -blocki.packet[3];
+ }
+
+ if(((StorageOrder == RowMajor) && UseLhs) || (((StorageOrder == ColMajor) && !UseLhs)))
+ {
+ ptranspose(blockr);
+ ptranspose(blocki);
+ }
+
+ storeBlock<Scalar, Packet, Index>(blockAt + rir, blockr);
+ storeBlock<Scalar, Packet, Index>(blockAt + rii, blocki);
+
+ rir += 4*vectorSize;
+ rii += 4*vectorSize;
+ }
+ for(; i < depth; i++)
+ {
+ PacketBlock<Packet,1> blockr, blocki;
+ PacketBlock<PacketC,2> cblock;
+
+ if(((StorageOrder == ColMajor) && UseLhs) || (((StorageOrder == RowMajor) && !UseLhs)))
+ {
+ if (UseLhs) {
+ cblock.packet[0] = lhs.template loadPacket<PacketC>(j + 0, i);
+ cblock.packet[1] = lhs.template loadPacket<PacketC>(j + 2, i);
+ } else {
+ cblock.packet[0] = lhs.template loadPacket<PacketC>(i, j + 0);
+ cblock.packet[1] = lhs.template loadPacket<PacketC>(i, j + 2);
+ }
+ } else {
+ std::complex<Scalar> lhs0, lhs1;
+ if (UseLhs) {
+ lhs0 = lhs(j + 0, i);
+ lhs1 = lhs(j + 1, i);
+ cblock.packet[0] = pload2(&lhs0, &lhs1);
+ lhs0 = lhs(j + 2, i);
+ lhs1 = lhs(j + 3, i);
+ cblock.packet[1] = pload2(&lhs0, &lhs1);
+ } else {
+ lhs0 = lhs(i, j + 0);
+ lhs1 = lhs(i, j + 1);
+ cblock.packet[0] = pload2(&lhs0, &lhs1);
+ lhs0 = lhs(i, j + 2);
+ lhs1 = lhs(i, j + 3);
+ cblock.packet[1] = pload2(&lhs0, &lhs1);
+ }
+ }
+
+ blockr.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETREAL32);
+ blocki.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETIMAG32);
+
+ if(Conjugate)
+ {
+ blocki.packet[0] = -blocki.packet[0];
+ }
+
+ pstore<Scalar>(blockAt + rir, blockr.packet[0]);
+ pstore<Scalar>(blockAt + rii, blocki.packet[0]);
+
+ rir += vectorSize;
+ rii += vectorSize;
+ }
+
+ rir += ((PanelMode) ? (vectorSize*(2*stride - depth)) : vectorDelta);
+ }
+
+ if (j < rows)
+ {
+ if(PanelMode) rir += (offset*(rows - j - vectorSize));
+ rii = rir + (((PanelMode) ? stride : depth) * (rows - j));
+
+ for(Index i = 0; i < depth; i++)
+ {
+ Index k = j;
+ for(; k < rows; k++)
+ {
+ if (UseLhs) {
+ blockAt[rir] = lhs(k, i).real();
+
+ if(Conjugate)
+ blockAt[rii] = -lhs(k, i).imag();
+ else
+ blockAt[rii] = lhs(k, i).imag();
+ } else {
+ blockAt[rir] = lhs(i, k).real();
+
+ if(Conjugate)
+ blockAt[rii] = -lhs(i, k).imag();
+ else
+ blockAt[rii] = lhs(i, k).imag();
+ }
+
+ rir += 1;
+ rii += 1;
+ }
+ }
+ }
+ }
+};
+
+// General template for lhs & rhs packing.
+template<typename Scalar, typename Index, typename DataMapper, typename Packet, int StorageOrder, bool PanelMode, bool UseLhs>
+struct dhs_pack{
+ EIGEN_STRONG_INLINE void operator()(Scalar* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
+ {
+ const Index vectorSize = quad_traits<Scalar>::vectorsize;
+ Index ri = 0, j = 0;
+
+ for(; j + vectorSize <= rows; j+=vectorSize)
+ {
+ Index i = 0;
+
+ if(PanelMode) ri += vectorSize*offset;
+
+ for(; i + vectorSize <= depth; i+=vectorSize)
+ {
+ PacketBlock<Packet,4> block;
+
+ if (UseLhs) {
+ bload<DataMapper, Packet, Index, 4, 0, StorageOrder>(block, lhs, j, i);
+ } else {
+ bload<DataMapper, Packet, Index, 4, 0, StorageOrder>(block, lhs, i, j);
+ }
+ if(((StorageOrder == RowMajor) && UseLhs) || ((StorageOrder == ColMajor) && !UseLhs))
+ {
+ ptranspose(block);
+ }
+
+ storeBlock<Scalar, Packet, Index>(blockA + ri, block);
+
+ ri += 4*vectorSize;
+ }
+ for(; i < depth; i++)
+ {
+ if(((StorageOrder == RowMajor) && UseLhs) || ((StorageOrder == ColMajor) && !UseLhs))
+ {
+ if (UseLhs) {
+ blockA[ri+0] = lhs(j+0, i);
+ blockA[ri+1] = lhs(j+1, i);
+ blockA[ri+2] = lhs(j+2, i);
+ blockA[ri+3] = lhs(j+3, i);
+ } else {
+ blockA[ri+0] = lhs(i, j+0);
+ blockA[ri+1] = lhs(i, j+1);
+ blockA[ri+2] = lhs(i, j+2);
+ blockA[ri+3] = lhs(i, j+3);
+ }
+ } else {
+ Packet lhsV;
+ if (UseLhs) {
+ lhsV = lhs.template loadPacket<Packet>(j, i);
+ } else {
+ lhsV = lhs.template loadPacket<Packet>(i, j);
+ }
+ pstore<Scalar>(blockA + ri, lhsV);
+ }
+
+ ri += vectorSize;
+ }
+
+ if(PanelMode) ri += vectorSize*(stride - offset - depth);
+ }
+
+ if (j < rows)
+ {
+ if(PanelMode) ri += offset*(rows - j);
+
+ for(Index i = 0; i < depth; i++)
+ {
+ Index k = j;
+ for(; k < rows; k++)
+ {
+ if (UseLhs) {
+ blockA[ri] = lhs(k, i);
+ } else {
+ blockA[ri] = lhs(i, k);
+ }
+ ri += 1;
+ }
+ }
+ }
+ }
+};
+
+// General template for lhs packing, float64 specialization.
+template<typename Index, typename DataMapper, int StorageOrder, bool PanelMode>
+struct dhs_pack<double, Index, DataMapper, Packet2d, StorageOrder, PanelMode, true>
+{
+ EIGEN_STRONG_INLINE void operator()(double* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
+ {
+ const Index vectorSize = quad_traits<double>::vectorsize;
+ Index ri = 0, j = 0;
+
+ for(; j + vectorSize <= rows; j+=vectorSize)
+ {
+ Index i = 0;
+
+ if(PanelMode) ri += vectorSize*offset;
+
+ for(; i + vectorSize <= depth; i+=vectorSize)
+ {
+ PacketBlock<Packet2d,2> block;
+ if(StorageOrder == RowMajor)
+ {
+ block.packet[0] = lhs.template loadPacket<Packet2d>(j + 0, i);
+ block.packet[1] = lhs.template loadPacket<Packet2d>(j + 1, i);
+
+ ptranspose(block);
+ } else {
+ block.packet[0] = lhs.template loadPacket<Packet2d>(j, i + 0);
+ block.packet[1] = lhs.template loadPacket<Packet2d>(j, i + 1);
+ }
+
+ storeBlock<double, Packet2d, Index>(blockA + ri, block);
+
+ ri += 2*vectorSize;
+ }
+ for(; i < depth; i++)
+ {
+ if(StorageOrder == RowMajor)
+ {
+ blockA[ri+0] = lhs(j+0, i);
+ blockA[ri+1] = lhs(j+1, i);
+ } else {
+ Packet2d lhsV = lhs.template loadPacket<Packet2d>(j, i);
+ pstore<double>(blockA + ri, lhsV);
+ }
+
+ ri += vectorSize;
+ }
+
+ if(PanelMode) ri += vectorSize*(stride - offset - depth);
+ }
+
+ if (j < rows)
+ {
+ if(PanelMode) ri += offset*(rows - j);
+
+ for(Index i = 0; i < depth; i++)
+ {
+ Index k = j;
+ for(; k < rows; k++)
+ {
+ blockA[ri] = lhs(k, i);
+ ri += 1;
+ }
+ }
+ }
+ }
+};
+
+// General template for rhs packing, float64 specialization.
+template<typename Index, typename DataMapper, int StorageOrder, bool PanelMode>
+struct dhs_pack<double, Index, DataMapper, Packet2d, StorageOrder, PanelMode, false>
+{
+ EIGEN_STRONG_INLINE void operator()(double* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
+ {
+ const Index vectorSize = quad_traits<double>::vectorsize;
+ Index ri = 0, j = 0;
+
+ for(; j + 2*vectorSize <= cols; j+=2*vectorSize)
+ {
+ Index i = 0;
+
+ if(PanelMode) ri += offset*(2*vectorSize);
+
+ for(; i + vectorSize <= depth; i+=vectorSize)
+ {
+ PacketBlock<Packet2d,4> block;
+ if(StorageOrder == ColMajor)
+ {
+ PacketBlock<Packet2d,2> block1, block2;
+ block1.packet[0] = rhs.template loadPacket<Packet2d>(i, j + 0);
+ block1.packet[1] = rhs.template loadPacket<Packet2d>(i, j + 1);
+ block2.packet[0] = rhs.template loadPacket<Packet2d>(i, j + 2);
+ block2.packet[1] = rhs.template loadPacket<Packet2d>(i, j + 3);
+
+ ptranspose(block1);
+ ptranspose(block2);
+
+ pstore<double>(blockB + ri , block1.packet[0]);
+ pstore<double>(blockB + ri + 2, block2.packet[0]);
+ pstore<double>(blockB + ri + 4, block1.packet[1]);
+ pstore<double>(blockB + ri + 6, block2.packet[1]);
+ } else {
+ block.packet[0] = rhs.template loadPacket<Packet2d>(i + 0, j + 0); //[a1 a2]
+ block.packet[1] = rhs.template loadPacket<Packet2d>(i + 0, j + 2); //[a3 a4]
+ block.packet[2] = rhs.template loadPacket<Packet2d>(i + 1, j + 0); //[b1 b2]
+ block.packet[3] = rhs.template loadPacket<Packet2d>(i + 1, j + 2); //[b3 b4]
+
+ storeBlock<double, Packet2d, Index>(blockB + ri, block);
+ }
+
+ ri += 4*vectorSize;
+ }
+ for(; i < depth; i++)
+ {
+ if(StorageOrder == ColMajor)
+ {
+ blockB[ri+0] = rhs(i, j+0);
+ blockB[ri+1] = rhs(i, j+1);
+
+ ri += vectorSize;
+
+ blockB[ri+0] = rhs(i, j+2);
+ blockB[ri+1] = rhs(i, j+3);
+ } else {
+ Packet2d rhsV = rhs.template loadPacket<Packet2d>(i, j);
+ pstore<double>(blockB + ri, rhsV);
+
+ ri += vectorSize;
+
+ rhsV = rhs.template loadPacket<Packet2d>(i, j + 2);
+ pstore<double>(blockB + ri, rhsV);
+ }
+ ri += vectorSize;
+ }
+
+ if(PanelMode) ri += (2*vectorSize)*(stride - offset - depth);
+ }
+
+ if (j < cols)
+ {
+ if(PanelMode) ri += offset*(cols - j);
+
+ for(Index i = 0; i < depth; i++)
+ {
+ Index k = j;
+ for(; k < cols; k++)
+ {
+ blockB[ri] = rhs(i, k);
+ ri += 1;
+ }
+ }
+ }
+ }
+};
+
+// General template for lhs complex packing, float64 specialization.
+template<typename Index, typename DataMapper, typename Packet, typename PacketC, int StorageOrder, bool Conjugate, bool PanelMode>
+struct dhs_cpack<double, Index, DataMapper, Packet, PacketC, StorageOrder, Conjugate, PanelMode, true>
+{
+ EIGEN_STRONG_INLINE void operator()(std::complex<double>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
+ {
+ const Index vectorSize = quad_traits<double>::vectorsize;
+ const Index vectorDelta = vectorSize * ((PanelMode) ? stride : depth);
+ Index rir = ((PanelMode) ? (vectorSize*offset) : 0), rii;
+ double* blockAt = reinterpret_cast<double *>(blockA);
+ Index j = 0;
+
+ for(; j + vectorSize <= rows; j+=vectorSize)
+ {
+ Index i = 0;
+
+ rii = rir + vectorDelta;
+
+ for(; i + vectorSize <= depth; i+=vectorSize)
+ {
+ PacketBlock<Packet,2> blockr, blocki;
+ PacketBlock<PacketC,4> cblock;
+
+ if(StorageOrder == ColMajor)
+ {
+ cblock.packet[0] = lhs.template loadPacket<PacketC>(j, i + 0); //[a1 a1i]
+ cblock.packet[1] = lhs.template loadPacket<PacketC>(j, i + 1); //[b1 b1i]
+
+ cblock.packet[2] = lhs.template loadPacket<PacketC>(j + 1, i + 0); //[a2 a2i]
+ cblock.packet[3] = lhs.template loadPacket<PacketC>(j + 1, i + 1); //[b2 b2i]
+
+ blockr.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[2].v, p16uc_GETREAL64); //[a1 a2]
+ blockr.packet[1] = vec_perm(cblock.packet[1].v, cblock.packet[3].v, p16uc_GETREAL64); //[b1 b2]
+
+ blocki.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[2].v, p16uc_GETIMAG64);
+ blocki.packet[1] = vec_perm(cblock.packet[1].v, cblock.packet[3].v, p16uc_GETIMAG64);
+ } else {
+ cblock.packet[0] = lhs.template loadPacket<PacketC>(j + 0, i); //[a1 a1i]
+ cblock.packet[1] = lhs.template loadPacket<PacketC>(j + 1, i); //[a2 a2i]
+
+ cblock.packet[2] = lhs.template loadPacket<PacketC>(j + 0, i + 1); //[b1 b1i]
+          cblock.packet[3] = lhs.template loadPacket<PacketC>(j + 1, i + 1); //[b2 b2i]
+
+ blockr.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETREAL64); //[a1 a2]
+ blockr.packet[1] = vec_perm(cblock.packet[2].v, cblock.packet[3].v, p16uc_GETREAL64); //[b1 b2]
+
+ blocki.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETIMAG64);
+ blocki.packet[1] = vec_perm(cblock.packet[2].v, cblock.packet[3].v, p16uc_GETIMAG64);
+ }
+
+ if(Conjugate)
+ {
+ blocki.packet[0] = -blocki.packet[0];
+ blocki.packet[1] = -blocki.packet[1];
+ }
+
+ storeBlock<double, Packet, Index>(blockAt + rir, blockr);
+ storeBlock<double, Packet, Index>(blockAt + rii, blocki);
+
+ rir += 2*vectorSize;
+ rii += 2*vectorSize;
+ }
+ for(; i < depth; i++)
+ {
+ PacketBlock<Packet,1> blockr, blocki;
+ PacketBlock<PacketC,2> cblock;
+
+ cblock.packet[0] = lhs.template loadPacket<PacketC>(j + 0, i);
+ cblock.packet[1] = lhs.template loadPacket<PacketC>(j + 1, i);
+
+ blockr.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETREAL64);
+ blocki.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETIMAG64);
+
+ if(Conjugate)
+ {
+ blocki.packet[0] = -blocki.packet[0];
+ }
+
+ pstore<double>(blockAt + rir, blockr.packet[0]);
+ pstore<double>(blockAt + rii, blocki.packet[0]);
+
+ rir += vectorSize;
+ rii += vectorSize;
+ }
+
+ rir += ((PanelMode) ? (vectorSize*(2*stride - depth)) : vectorDelta);
+ }
+
+ if (j < rows)
+ {
+ if(PanelMode) rir += (offset*(rows - j - vectorSize));
+ rii = rir + (((PanelMode) ? stride : depth) * (rows - j));
+
+ for(Index i = 0; i < depth; i++)
+ {
+ Index k = j;
+ for(; k < rows; k++)
+ {
+ blockAt[rir] = lhs(k, i).real();
+
+ if(Conjugate)
+ blockAt[rii] = -lhs(k, i).imag();
+ else
+ blockAt[rii] = lhs(k, i).imag();
+
+ rir += 1;
+ rii += 1;
+ }
+ }
+ }
+ }
+};
+
+// Rhs complex packing (dhs_cpack), float64 specialization.
+template<typename Index, typename DataMapper, typename Packet, typename PacketC, int StorageOrder, bool Conjugate, bool PanelMode>
+struct dhs_cpack<double, Index, DataMapper, Packet, PacketC, StorageOrder, Conjugate, PanelMode, false>
+{
+ EIGEN_STRONG_INLINE void operator()(std::complex<double>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
+ {
+ const Index vectorSize = quad_traits<double>::vectorsize;
+ const Index vectorDelta = 2*vectorSize * ((PanelMode) ? stride : depth);
+ Index rir = ((PanelMode) ? (2*vectorSize*offset) : 0), rii;
+ double* blockBt = reinterpret_cast<double *>(blockB);
+ Index j = 0;
+
+ for(; j + 2*vectorSize <= cols; j+=2*vectorSize)
+ {
+ Index i = 0;
+
+ rii = rir + vectorDelta;
+
+ for(; i < depth; i++)
+ {
+ PacketBlock<PacketC,4> cblock;
+ PacketBlock<Packet,2> blockr, blocki;
+
+ bload<DataMapper, PacketC, Index, 2, 0, ColMajor>(cblock, rhs, i, j);
+
+ blockr.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETREAL64);
+ blockr.packet[1] = vec_perm(cblock.packet[2].v, cblock.packet[3].v, p16uc_GETREAL64);
+
+ blocki.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETIMAG64);
+ blocki.packet[1] = vec_perm(cblock.packet[2].v, cblock.packet[3].v, p16uc_GETIMAG64);
+
+ if(Conjugate)
+ {
+ blocki.packet[0] = -blocki.packet[0];
+ blocki.packet[1] = -blocki.packet[1];
+ }
+
+ storeBlock<double, Packet, Index>(blockBt + rir, blockr);
+ storeBlock<double, Packet, Index>(blockBt + rii, blocki);
+
+ rir += 2*vectorSize;
+ rii += 2*vectorSize;
+ }
+
+ rir += ((PanelMode) ? (2*vectorSize*(2*stride - depth)) : vectorDelta);
+ }
+
+ if (j < cols)
+ {
+ if(PanelMode) rir += (offset*(cols - j - 2*vectorSize));
+ rii = rir + (((PanelMode) ? stride : depth) * (cols - j));
+
+ for(Index i = 0; i < depth; i++)
+ {
+ Index k = j;
+ for(; k < cols; k++)
+ {
+ blockBt[rir] = rhs(i, k).real();
+
+ if(Conjugate)
+ blockBt[rii] = -rhs(i, k).imag();
+ else
+ blockBt[rii] = rhs(i, k).imag();
+
+ rir += 1;
+ rii += 1;
+ }
+ }
+ }
+ }
+};
+
+/**************
+ * GEMM utils *
+ **************/
+
+// 512-bit rank-1 update of acc. It can accumulate either positively or negatively (useful for complex gemm).
+template<typename Packet, bool NegativeAccumulate>
+EIGEN_ALWAYS_INLINE void pger_common(PacketBlock<Packet,4>* acc, const Packet& lhsV, const Packet* rhsV)
+{
+ if(NegativeAccumulate)
+ {
+ acc->packet[0] = vec_nmsub(lhsV, rhsV[0], acc->packet[0]);
+ acc->packet[1] = vec_nmsub(lhsV, rhsV[1], acc->packet[1]);
+ acc->packet[2] = vec_nmsub(lhsV, rhsV[2], acc->packet[2]);
+ acc->packet[3] = vec_nmsub(lhsV, rhsV[3], acc->packet[3]);
+ } else {
+ acc->packet[0] = vec_madd(lhsV, rhsV[0], acc->packet[0]);
+ acc->packet[1] = vec_madd(lhsV, rhsV[1], acc->packet[1]);
+ acc->packet[2] = vec_madd(lhsV, rhsV[2], acc->packet[2]);
+ acc->packet[3] = vec_madd(lhsV, rhsV[3], acc->packet[3]);
+ }
+}
+
+template<typename Packet, bool NegativeAccumulate>
+EIGEN_ALWAYS_INLINE void pger_common(PacketBlock<Packet,1>* acc, const Packet& lhsV, const Packet* rhsV)
+{
+ if(NegativeAccumulate)
+ {
+ acc->packet[0] = vec_nmsub(lhsV, rhsV[0], acc->packet[0]);
+ } else {
+ acc->packet[0] = vec_madd(lhsV, rhsV[0], acc->packet[0]);
+ }
+}
+
+template<int N, typename Scalar, typename Packet, bool NegativeAccumulate>
+EIGEN_ALWAYS_INLINE void pger(PacketBlock<Packet,N>* acc, const Scalar* lhs, const Packet* rhsV)
+{
+ Packet lhsV = pload<Packet>(lhs);
+
+ pger_common<Packet, NegativeAccumulate>(acc, lhsV, rhsV);
+}
+
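+// Load only the first remaining_rows scalars of a packet. On POWER9 this uses the
+// length-controlled load vec_xl_len; otherwise the valid lanes are filled one at a time.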
+template<typename Scalar, typename Packet, typename Index>
+EIGEN_ALWAYS_INLINE void loadPacketRemaining(const Scalar* lhs, Packet &lhsV, Index remaining_rows)
+{
+#ifdef _ARCH_PWR9
+ lhsV = vec_xl_len((Scalar *)lhs, remaining_rows * sizeof(Scalar));
+#else
+ Index i = 0;
+ do {
+ lhsV[i] = lhs[i];
+ } while (++i < remaining_rows);
+#endif
+}
+
+template<int N, typename Scalar, typename Packet, typename Index, bool NegativeAccumulate>
+EIGEN_ALWAYS_INLINE void pger(PacketBlock<Packet,N>* acc, const Scalar* lhs, const Packet* rhsV, Index remaining_rows)
+{
+ Packet lhsV;
+ loadPacketRemaining<Scalar, Packet, Index>(lhs, lhsV, remaining_rows);
+
+ pger_common<Packet, NegativeAccumulate>(acc, lhsV, rhsV);
+}
+
+// 512-bit rank-1 update of a complex acc. It takes decoupled accumulators as entries. It also takes care of mixed types real * complex and complex * real.
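+// In effect it accumulates accReal += lhsR*rhsR (plus or minus lhsI*rhsI) and
+// accImag += lhsR*rhsI + lhsI*rhsR, with the sign of each imaginary-product term flipped
+// according to the Conjugate flags and the term dropped when LhsIsReal or RhsIsReal.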
+template<int N, typename Packet, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
+EIGEN_ALWAYS_INLINE void pgerc_common(PacketBlock<Packet,N>* accReal, PacketBlock<Packet,N>* accImag, const Packet &lhsV, const Packet &lhsVi, const Packet* rhsV, const Packet* rhsVi)
+{
+ pger_common<Packet, false>(accReal, lhsV, rhsV);
+ if(LhsIsReal)
+ {
+ pger_common<Packet, ConjugateRhs>(accImag, lhsV, rhsVi);
+ EIGEN_UNUSED_VARIABLE(lhsVi);
+ } else {
+ if (!RhsIsReal) {
+ pger_common<Packet, ConjugateLhs == ConjugateRhs>(accReal, lhsVi, rhsVi);
+ pger_common<Packet, ConjugateRhs>(accImag, lhsV, rhsVi);
+ } else {
+ EIGEN_UNUSED_VARIABLE(rhsVi);
+ }
+ pger_common<Packet, ConjugateLhs>(accImag, lhsVi, rhsV);
+ }
+}
+
+template<int N, typename Scalar, typename Packet, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
+EIGEN_ALWAYS_INLINE void pgerc(PacketBlock<Packet,N>* accReal, PacketBlock<Packet,N>* accImag, const Scalar* lhs_ptr, const Scalar* lhs_ptr_imag, const Packet* rhsV, const Packet* rhsVi)
+{
+ Packet lhsV = ploadLhs<Scalar, Packet>(lhs_ptr);
+ Packet lhsVi;
+ if(!LhsIsReal) lhsVi = ploadLhs<Scalar, Packet>(lhs_ptr_imag);
+ else EIGEN_UNUSED_VARIABLE(lhs_ptr_imag);
+
+ pgerc_common<N, Packet, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(accReal, accImag, lhsV, lhsVi, rhsV, rhsVi);
+}
+
+template<typename Scalar, typename Packet, typename Index, bool LhsIsReal>
+EIGEN_ALWAYS_INLINE void loadPacketRemaining(const Scalar* lhs_ptr, const Scalar* lhs_ptr_imag, Packet &lhsV, Packet &lhsVi, Index remaining_rows)
+{
+#ifdef _ARCH_PWR9
+ lhsV = vec_xl_len((Scalar *)lhs_ptr, remaining_rows * sizeof(Scalar));
+ if(!LhsIsReal) lhsVi = vec_xl_len((Scalar *)lhs_ptr_imag, remaining_rows * sizeof(Scalar));
+ else EIGEN_UNUSED_VARIABLE(lhs_ptr_imag);
+#else
+ Index i = 0;
+ do {
+ lhsV[i] = lhs_ptr[i];
+ if(!LhsIsReal) lhsVi[i] = lhs_ptr_imag[i];
+ } while (++i < remaining_rows);
+ if(LhsIsReal) EIGEN_UNUSED_VARIABLE(lhs_ptr_imag);
+#endif
+}
+
+template<int N, typename Scalar, typename Packet, typename Index, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
+EIGEN_ALWAYS_INLINE void pgerc(PacketBlock<Packet,N>* accReal, PacketBlock<Packet,N>* accImag, const Scalar* lhs_ptr, const Scalar* lhs_ptr_imag, const Packet* rhsV, const Packet* rhsVi, Index remaining_rows)
+{
+ Packet lhsV, lhsVi;
+ loadPacketRemaining<Scalar, Packet, Index, LhsIsReal>(lhs_ptr, lhs_ptr_imag, lhsV, lhsVi, remaining_rows);
+
+ pgerc_common<N, Packet, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(accReal, accImag, lhsV, lhsVi, rhsV, rhsVi);
+}
+
+template<typename Scalar, typename Packet>
+EIGEN_ALWAYS_INLINE Packet ploadLhs(const Scalar* lhs)
+{
+ return ploadu<Packet>(lhs);
+}
+
+// Zero the accumulator on PacketBlock.
+template<typename Scalar, typename Packet>
+EIGEN_ALWAYS_INLINE void bsetzero(PacketBlock<Packet,4>& acc)
+{
+ acc.packet[0] = pset1<Packet>((Scalar)0);
+ acc.packet[1] = pset1<Packet>((Scalar)0);
+ acc.packet[2] = pset1<Packet>((Scalar)0);
+ acc.packet[3] = pset1<Packet>((Scalar)0);
+}
+
+template<typename Scalar, typename Packet>
+EIGEN_ALWAYS_INLINE void bsetzero(PacketBlock<Packet,1>& acc)
+{
+ acc.packet[0] = pset1<Packet>((Scalar)0);
+}
+
+// Scale the PacketBlock vectors by alpha.
+template<typename Packet>
+EIGEN_ALWAYS_INLINE void bscale(PacketBlock<Packet,4>& acc, PacketBlock<Packet,4>& accZ, const Packet& pAlpha)
+{
+ acc.packet[0] = pmadd(pAlpha, accZ.packet[0], acc.packet[0]);
+ acc.packet[1] = pmadd(pAlpha, accZ.packet[1], acc.packet[1]);
+ acc.packet[2] = pmadd(pAlpha, accZ.packet[2], acc.packet[2]);
+ acc.packet[3] = pmadd(pAlpha, accZ.packet[3], acc.packet[3]);
+}
+
+template<typename Packet>
+EIGEN_ALWAYS_INLINE void bscale(PacketBlock<Packet,1>& acc, PacketBlock<Packet,1>& accZ, const Packet& pAlpha)
+{
+ acc.packet[0] = pmadd(pAlpha, accZ.packet[0], acc.packet[0]);
+}
+
+template<typename Packet>
+EIGEN_ALWAYS_INLINE void bscalec_common(PacketBlock<Packet,4>& acc, PacketBlock<Packet,4>& accZ, const Packet& pAlpha)
+{
+ acc.packet[0] = pmul<Packet>(accZ.packet[0], pAlpha);
+ acc.packet[1] = pmul<Packet>(accZ.packet[1], pAlpha);
+ acc.packet[2] = pmul<Packet>(accZ.packet[2], pAlpha);
+ acc.packet[3] = pmul<Packet>(accZ.packet[3], pAlpha);
+}
+
+template<typename Packet>
+EIGEN_ALWAYS_INLINE void bscalec_common(PacketBlock<Packet,1>& acc, PacketBlock<Packet,1>& accZ, const Packet& pAlpha)
+{
+ acc.packet[0] = pmul<Packet>(accZ.packet[0], pAlpha);
+}
+
+// Complex version of PacketBlock scaling.
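+// Computes c = a * b on split real/imaginary accumulators:
+//   cReal = aReal*bReal - aImag*bImag,  cImag = aImag*bReal + aReal*bImag.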
+template<typename Packet, int N>
+EIGEN_ALWAYS_INLINE void bscalec(PacketBlock<Packet,N>& aReal, PacketBlock<Packet,N>& aImag, const Packet& bReal, const Packet& bImag, PacketBlock<Packet,N>& cReal, PacketBlock<Packet,N>& cImag)
+{
+ bscalec_common<Packet>(cReal, aReal, bReal);
+
+ bscalec_common<Packet>(cImag, aImag, bReal);
+
+ pger_common<Packet, true>(&cReal, bImag, aImag.packet);
+
+ pger_common<Packet, false>(&cImag, bImag, aReal.packet);
+}
+
+template<typename Packet>
+EIGEN_ALWAYS_INLINE void band(PacketBlock<Packet,4>& acc, const Packet& pMask)
+{
+ acc.packet[0] = pand(acc.packet[0], pMask);
+ acc.packet[1] = pand(acc.packet[1], pMask);
+ acc.packet[2] = pand(acc.packet[2], pMask);
+ acc.packet[3] = pand(acc.packet[3], pMask);
+}
+
+template<typename Packet>
+EIGEN_ALWAYS_INLINE void bscalec(PacketBlock<Packet,4>& aReal, PacketBlock<Packet,4>& aImag, const Packet& bReal, const Packet& bImag, PacketBlock<Packet,4>& cReal, PacketBlock<Packet,4>& cImag, const Packet& pMask)
+{
+ band<Packet>(aReal, pMask);
+ band<Packet>(aImag, pMask);
+
+ bscalec<Packet,4>(aReal, aImag, bReal, bImag, cReal, cImag);
+}
+
+// Load a PacketBlock; the N parameter makes tuning gemm easier, so we can add more accumulators as needed.
+template<typename DataMapper, typename Packet, typename Index, const Index accCols, int N, int StorageOrder>
+EIGEN_ALWAYS_INLINE void bload(PacketBlock<Packet,4>& acc, const DataMapper& res, Index row, Index col)
+{
+ if (StorageOrder == RowMajor) {
+ acc.packet[0] = res.template loadPacket<Packet>(row + 0, col + N*accCols);
+ acc.packet[1] = res.template loadPacket<Packet>(row + 1, col + N*accCols);
+ acc.packet[2] = res.template loadPacket<Packet>(row + 2, col + N*accCols);
+ acc.packet[3] = res.template loadPacket<Packet>(row + 3, col + N*accCols);
+ } else {
+ acc.packet[0] = res.template loadPacket<Packet>(row + N*accCols, col + 0);
+ acc.packet[1] = res.template loadPacket<Packet>(row + N*accCols, col + 1);
+ acc.packet[2] = res.template loadPacket<Packet>(row + N*accCols, col + 2);
+ acc.packet[3] = res.template loadPacket<Packet>(row + N*accCols, col + 3);
+ }
+}
+
+// An overload of bload when you have a PacketBlock with 8 vectors.
+template<typename DataMapper, typename Packet, typename Index, const Index accCols, int N, int StorageOrder>
+EIGEN_ALWAYS_INLINE void bload(PacketBlock<Packet,8>& acc, const DataMapper& res, Index row, Index col)
+{
+ if (StorageOrder == RowMajor) {
+ acc.packet[0] = res.template loadPacket<Packet>(row + 0, col + N*accCols);
+ acc.packet[1] = res.template loadPacket<Packet>(row + 1, col + N*accCols);
+ acc.packet[2] = res.template loadPacket<Packet>(row + 2, col + N*accCols);
+ acc.packet[3] = res.template loadPacket<Packet>(row + 3, col + N*accCols);
+ acc.packet[4] = res.template loadPacket<Packet>(row + 0, col + (N+1)*accCols);
+ acc.packet[5] = res.template loadPacket<Packet>(row + 1, col + (N+1)*accCols);
+ acc.packet[6] = res.template loadPacket<Packet>(row + 2, col + (N+1)*accCols);
+ acc.packet[7] = res.template loadPacket<Packet>(row + 3, col + (N+1)*accCols);
+ } else {
+ acc.packet[0] = res.template loadPacket<Packet>(row + N*accCols, col + 0);
+ acc.packet[1] = res.template loadPacket<Packet>(row + N*accCols, col + 1);
+ acc.packet[2] = res.template loadPacket<Packet>(row + N*accCols, col + 2);
+ acc.packet[3] = res.template loadPacket<Packet>(row + N*accCols, col + 3);
+ acc.packet[4] = res.template loadPacket<Packet>(row + (N+1)*accCols, col + 0);
+ acc.packet[5] = res.template loadPacket<Packet>(row + (N+1)*accCols, col + 1);
+ acc.packet[6] = res.template loadPacket<Packet>(row + (N+1)*accCols, col + 2);
+ acc.packet[7] = res.template loadPacket<Packet>(row + (N+1)*accCols, col + 3);
+ }
+}
+
+template<typename DataMapper, typename Packet, typename Index, const Index accCols, int N, int StorageOrder>
+EIGEN_ALWAYS_INLINE void bload(PacketBlock<Packet,2>& acc, const DataMapper& res, Index row, Index col)
+{
+ acc.packet[0] = res.template loadPacket<Packet>(row + N*accCols, col + 0);
+ acc.packet[1] = res.template loadPacket<Packet>(row + (N+1)*accCols, col + 0);
+}
+
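+// Lane masks for partially filled packets: the first remaining_rows lanes are all-ones and the
+// rest are zero, so band() can clear the lanes that fall outside the valid rows before storing.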
+const static Packet4i mask41 = { -1, 0, 0, 0 };
+const static Packet4i mask42 = { -1, -1, 0, 0 };
+const static Packet4i mask43 = { -1, -1, -1, 0 };
+
+const static Packet2l mask21 = { -1, 0 };
+
+template<typename Packet>
+EIGEN_ALWAYS_INLINE Packet bmask(const int remaining_rows)
+{
+ if (remaining_rows == 0) {
+ return pset1<Packet>(float(0.0)); // Not used
+ } else {
+ switch (remaining_rows) {
+ case 1: return Packet(mask41);
+ case 2: return Packet(mask42);
+ default: return Packet(mask43);
+ }
+ }
+}
+
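+// For double a packet holds only two lanes, so at most one row can remain and mask21 suffices.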
+template<>
+EIGEN_ALWAYS_INLINE Packet2d bmask<Packet2d>(const int remaining_rows)
+{
+ if (remaining_rows == 0) {
+ return pset1<Packet2d>(double(0.0)); // Not used
+ } else {
+ return Packet2d(mask21);
+ }
+}
+
+template<typename Packet>
+EIGEN_ALWAYS_INLINE void bscale(PacketBlock<Packet,4>& acc, PacketBlock<Packet,4>& accZ, const Packet& pAlpha, const Packet& pMask)
+{
+ band<Packet>(accZ, pMask);
+
+ bscale<Packet>(acc, accZ, pAlpha);
+}
+
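+// Broadcast four consecutive scalars a[0..3] into four splatted packets a0..a3. The generic
+// version forwards to pbroadcast4; the Packet2d overload below uses two loads plus vec_splat.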
+template<typename Packet>
+EIGEN_ALWAYS_INLINE void pbroadcast4_old(const __UNPACK_TYPE__(Packet)* a, Packet& a0, Packet& a1, Packet& a2, Packet& a3)
+{
+ pbroadcast4<Packet>(a, a0, a1, a2, a3);
+}
+
+template<>
+EIGEN_ALWAYS_INLINE void pbroadcast4_old<Packet2d>(const double* a, Packet2d& a0, Packet2d& a1, Packet2d& a2, Packet2d& a3)
+{
+ a1 = pload<Packet2d>(a);
+ a3 = pload<Packet2d>(a + 2);
+ a0 = vec_splat(a1, 0);
+ a1 = vec_splat(a1, 1);
+ a2 = vec_splat(a3, 0);
+ a3 = vec_splat(a3, 1);
+}
+
+// PEEL loop factor: number of depth (k) iterations unrolled per prefetch in the GEMM kernels below.
+#define PEEL 7
+
+template<typename Scalar, typename Packet, typename Index>
+EIGEN_ALWAYS_INLINE void MICRO_EXTRA_COL(
+ const Scalar* &lhs_ptr,
+ const Scalar* &rhs_ptr,
+ PacketBlock<Packet,1> &accZero,
+ Index remaining_rows,
+ Index remaining_cols)
+{
+ Packet rhsV[1];
+ rhsV[0] = pset1<Packet>(rhs_ptr[0]);
+ pger<1,Scalar, Packet, false>(&accZero, lhs_ptr, rhsV);
+ lhs_ptr += remaining_rows;
+ rhs_ptr += remaining_cols;
+}
+
+template<typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accRows>
+EIGEN_STRONG_INLINE void gemm_extra_col(
+ const DataMapper& res,
+ const Scalar* lhs_base,
+ const Scalar* rhs_base,
+ Index depth,
+ Index strideA,
+ Index offsetA,
+ Index row,
+ Index col,
+ Index remaining_rows,
+ Index remaining_cols,
+ const Packet& pAlpha)
+{
+ const Scalar* rhs_ptr = rhs_base;
+ const Scalar* lhs_ptr = lhs_base + row*strideA + remaining_rows*offsetA;
+ PacketBlock<Packet,1> accZero;
+
+ bsetzero<Scalar, Packet>(accZero);
+
+ Index remaining_depth = (depth & -accRows);
+ Index k = 0;
+ for(; k + PEEL <= remaining_depth; k+= PEEL)
+ {
+ EIGEN_POWER_PREFETCH(rhs_ptr);
+ EIGEN_POWER_PREFETCH(lhs_ptr);
+ for (int l = 0; l < PEEL; l++) {
+ MICRO_EXTRA_COL<Scalar, Packet, Index>(lhs_ptr, rhs_ptr, accZero, remaining_rows, remaining_cols);
+ }
+ }
+ for(; k < remaining_depth; k++)
+ {
+ MICRO_EXTRA_COL<Scalar, Packet, Index>(lhs_ptr, rhs_ptr, accZero, remaining_rows, remaining_cols);
+ }
+ for(; k < depth; k++)
+ {
+ Packet rhsV[1];
+ rhsV[0] = pset1<Packet>(rhs_ptr[0]);
+ pger<1, Scalar, Packet, Index, false>(&accZero, lhs_ptr, rhsV, remaining_rows);
+ lhs_ptr += remaining_rows;
+ rhs_ptr += remaining_cols;
+ }
+
+ accZero.packet[0] = vec_mul(pAlpha, accZero.packet[0]);
+ for(Index i = 0; i < remaining_rows; i++) {
+ res(row + i, col) += accZero.packet[0][i];
+ }
+}
+
+template<typename Scalar, typename Packet, typename Index, const Index accRows>
+EIGEN_ALWAYS_INLINE void MICRO_EXTRA_ROW(
+ const Scalar* &lhs_ptr,
+ const Scalar* &rhs_ptr,
+ PacketBlock<Packet,4> &accZero,
+ Index remaining_rows)
+{
+ Packet rhsV[4];
+ pbroadcast4<Packet>(rhs_ptr, rhsV[0], rhsV[1], rhsV[2], rhsV[3]);
+ pger<4, Scalar, Packet, false>(&accZero, lhs_ptr, rhsV);
+ lhs_ptr += remaining_rows;
+ rhs_ptr += accRows;
+}
+
+template<typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accRows, const Index accCols>
+EIGEN_STRONG_INLINE void gemm_extra_row(
+ const DataMapper& res,
+ const Scalar* lhs_base,
+ const Scalar* rhs_base,
+ Index depth,
+ Index strideA,
+ Index offsetA,
+ Index row,
+ Index col,
+ Index rows,
+ Index cols,
+ Index remaining_rows,
+ const Packet& pAlpha,
+ const Packet& pMask)
+{
+ const Scalar* rhs_ptr = rhs_base;
+ const Scalar* lhs_ptr = lhs_base + row*strideA + remaining_rows*offsetA;
+ PacketBlock<Packet,4> accZero, acc;
+
+ bsetzero<Scalar, Packet>(accZero);
+
+ Index remaining_depth = (col + accRows < cols) ? depth : (depth & -accRows);
+ Index k = 0;
+ for(; k + PEEL <= remaining_depth; k+= PEEL)
+ {
+ EIGEN_POWER_PREFETCH(rhs_ptr);
+ EIGEN_POWER_PREFETCH(lhs_ptr);
+ for (int l = 0; l < PEEL; l++) {
+ MICRO_EXTRA_ROW<Scalar, Packet, Index, accRows>(lhs_ptr, rhs_ptr, accZero, remaining_rows);
+ }
+ }
+ for(; k < remaining_depth; k++)
+ {
+ MICRO_EXTRA_ROW<Scalar, Packet, Index, accRows>(lhs_ptr, rhs_ptr, accZero, remaining_rows);
+ }
+
+ if ((remaining_depth == depth) && (rows >= accCols))
+ {
+ for(Index j = 0; j < 4; j++) {
+ acc.packet[j] = res.template loadPacket<Packet>(row, col + j);
+ }
+ bscale<Packet>(acc, accZero, pAlpha, pMask);
+ res.template storePacketBlock<Packet,4>(row, col, acc);
+ } else {
+ for(; k < depth; k++)
+ {
+ Packet rhsV[4];
+ pbroadcast4<Packet>(rhs_ptr, rhsV[0], rhsV[1], rhsV[2], rhsV[3]);
+ pger<4, Scalar, Packet, Index, false>(&accZero, lhs_ptr, rhsV, remaining_rows);
+ lhs_ptr += remaining_rows;
+ rhs_ptr += accRows;
+ }
+
+ for(Index j = 0; j < 4; j++) {
+ accZero.packet[j] = vec_mul(pAlpha, accZero.packet[j]);
+ }
+ for(Index j = 0; j < 4; j++) {
+ for(Index i = 0; i < remaining_rows; i++) {
+ res(row + i, col + j) += accZero.packet[j][i];
+ }
+ }
+ }
+}
+
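+// The MICRO_* macros unroll the GEMM micro-kernel over up to 8 lhs accumulators. Every *_ONE(iter)
+// expansion is guarded by (unroll_factor > iter), so iterations beyond the requested unroll factor
+// compile away (their variables are just marked unused).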
+#define MICRO_UNROLL(func) \
+ func(0) func(1) func(2) func(3) func(4) func(5) func(6) func(7)
+
+#define MICRO_UNROLL_WORK(func, func2, peel) \
+ MICRO_UNROLL(func2); \
+ func(0,peel) func(1,peel) func(2,peel) func(3,peel) \
+ func(4,peel) func(5,peel) func(6,peel) func(7,peel)
+
+#define MICRO_LOAD_ONE(iter) \
+ if (unroll_factor > iter) { \
+ lhsV##iter = ploadLhs<Scalar, Packet>(lhs_ptr##iter); \
+ lhs_ptr##iter += accCols; \
+ } else { \
+ EIGEN_UNUSED_VARIABLE(lhsV##iter); \
+ }
+
+#define MICRO_WORK_ONE(iter, peel) \
+ if (unroll_factor > iter) { \
+ pger_common<Packet, false>(&accZero##iter, lhsV##iter, rhsV##peel); \
+ }
+
+#define MICRO_TYPE_PEEL4(func, func2, peel) \
+ if (PEEL > peel) { \
+ Packet lhsV0, lhsV1, lhsV2, lhsV3, lhsV4, lhsV5, lhsV6, lhsV7; \
+ pbroadcast4<Packet>(rhs_ptr + (accRows * peel), rhsV##peel[0], rhsV##peel[1], rhsV##peel[2], rhsV##peel[3]); \
+ MICRO_UNROLL_WORK(func, func2, peel) \
+ } else { \
+ EIGEN_UNUSED_VARIABLE(rhsV##peel); \
+ }
+
+#define MICRO_TYPE_PEEL1(func, func2, peel) \
+ if (PEEL > peel) { \
+ Packet lhsV0, lhsV1, lhsV2, lhsV3, lhsV4, lhsV5, lhsV6, lhsV7; \
+ rhsV##peel[0] = pset1<Packet>(rhs_ptr[remaining_cols * peel]); \
+ MICRO_UNROLL_WORK(func, func2, peel) \
+ } else { \
+ EIGEN_UNUSED_VARIABLE(rhsV##peel); \
+ }
+
+#define MICRO_UNROLL_TYPE_PEEL(M, func, func1, func2) \
+ Packet rhsV0[M], rhsV1[M], rhsV2[M], rhsV3[M], rhsV4[M], rhsV5[M], rhsV6[M], rhsV7[M], rhsV8[M], rhsV9[M]; \
+ func(func1,func2,0); func(func1,func2,1); \
+ func(func1,func2,2); func(func1,func2,3); \
+ func(func1,func2,4); func(func1,func2,5); \
+ func(func1,func2,6); func(func1,func2,7); \
+ func(func1,func2,8); func(func1,func2,9);
+
+#define MICRO_UNROLL_TYPE_ONE(M, func, func1, func2) \
+ Packet rhsV0[M]; \
+ func(func1,func2,0);
+
+#define MICRO_ONE_PEEL4 \
+ MICRO_UNROLL_TYPE_PEEL(4, MICRO_TYPE_PEEL4, MICRO_WORK_ONE, MICRO_LOAD_ONE); \
+ rhs_ptr += (accRows * PEEL);
+
+#define MICRO_ONE4 \
+ MICRO_UNROLL_TYPE_ONE(4, MICRO_TYPE_PEEL4, MICRO_WORK_ONE, MICRO_LOAD_ONE); \
+ rhs_ptr += accRows;
+
+#define MICRO_ONE_PEEL1 \
+ MICRO_UNROLL_TYPE_PEEL(1, MICRO_TYPE_PEEL1, MICRO_WORK_ONE, MICRO_LOAD_ONE); \
+ rhs_ptr += (remaining_cols * PEEL);
+
+#define MICRO_ONE1 \
+ MICRO_UNROLL_TYPE_ONE(1, MICRO_TYPE_PEEL1, MICRO_WORK_ONE, MICRO_LOAD_ONE); \
+ rhs_ptr += remaining_cols;
+
+#define MICRO_DST_PTR_ONE(iter) \
+ if (unroll_factor > iter) { \
+ bsetzero<Scalar, Packet>(accZero##iter); \
+ } else { \
+ EIGEN_UNUSED_VARIABLE(accZero##iter); \
+ }
+
+#define MICRO_DST_PTR MICRO_UNROLL(MICRO_DST_PTR_ONE)
+
+#define MICRO_SRC_PTR_ONE(iter) \
+ if (unroll_factor > iter) { \
+ lhs_ptr##iter = lhs_base + ( (row/accCols) + iter )*strideA*accCols + accCols*offsetA; \
+ } else { \
+ EIGEN_UNUSED_VARIABLE(lhs_ptr##iter); \
+ }
+
+#define MICRO_SRC_PTR MICRO_UNROLL(MICRO_SRC_PTR_ONE)
+
+#define MICRO_PREFETCH_ONE(iter) \
+ if (unroll_factor > iter) { \
+ EIGEN_POWER_PREFETCH(lhs_ptr##iter); \
+ }
+
+#define MICRO_PREFETCH MICRO_UNROLL(MICRO_PREFETCH_ONE)
+
+#define MICRO_STORE_ONE(iter) \
+ if (unroll_factor > iter) { \
+ acc.packet[0] = res.template loadPacket<Packet>(row + iter*accCols, col + 0); \
+ acc.packet[1] = res.template loadPacket<Packet>(row + iter*accCols, col + 1); \
+ acc.packet[2] = res.template loadPacket<Packet>(row + iter*accCols, col + 2); \
+ acc.packet[3] = res.template loadPacket<Packet>(row + iter*accCols, col + 3); \
+ bscale<Packet>(acc, accZero##iter, pAlpha); \
+ res.template storePacketBlock<Packet,4>(row + iter*accCols, col, acc); \
+ }
+
+#define MICRO_STORE MICRO_UNROLL(MICRO_STORE_ONE)
+
+#define MICRO_COL_STORE_ONE(iter) \
+ if (unroll_factor > iter) { \
+ acc.packet[0] = res.template loadPacket<Packet>(row + iter*accCols, col + 0); \
+ bscale<Packet>(acc, accZero##iter, pAlpha); \
+ res.template storePacketBlock<Packet,1>(row + iter*accCols, col, acc); \
+ }
+
+#define MICRO_COL_STORE MICRO_UNROLL(MICRO_COL_STORE_ONE)
+
+template<int unroll_factor, typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accRows, const Index accCols>
+EIGEN_STRONG_INLINE void gemm_unrolled_iteration(
+ const DataMapper& res,
+ const Scalar* lhs_base,
+ const Scalar* rhs_base,
+ Index depth,
+ Index strideA,
+ Index offsetA,
+ Index& row,
+ Index col,
+ const Packet& pAlpha)
+{
+ const Scalar* rhs_ptr = rhs_base;
+ const Scalar* lhs_ptr0 = NULL, * lhs_ptr1 = NULL, * lhs_ptr2 = NULL, * lhs_ptr3 = NULL, * lhs_ptr4 = NULL, * lhs_ptr5 = NULL, * lhs_ptr6 = NULL, * lhs_ptr7 = NULL;
+ PacketBlock<Packet,4> accZero0, accZero1, accZero2, accZero3, accZero4, accZero5, accZero6, accZero7;
+ PacketBlock<Packet,4> acc;
+
+ MICRO_SRC_PTR
+ MICRO_DST_PTR
+
+ Index k = 0;
+ for(; k + PEEL <= depth; k+= PEEL)
+ {
+ EIGEN_POWER_PREFETCH(rhs_ptr);
+ MICRO_PREFETCH
+ MICRO_ONE_PEEL4
+ }
+ for(; k < depth; k++)
+ {
+ MICRO_ONE4
+ }
+ MICRO_STORE
+
+ row += unroll_factor*accCols;
+}
+
+template<int unroll_factor, typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accCols>
+EIGEN_STRONG_INLINE void gemm_unrolled_col_iteration(
+ const DataMapper& res,
+ const Scalar* lhs_base,
+ const Scalar* rhs_base,
+ Index depth,
+ Index strideA,
+ Index offsetA,
+ Index& row,
+ Index col,
+ Index remaining_cols,
+ const Packet& pAlpha)
+{
+ const Scalar* rhs_ptr = rhs_base;
+ const Scalar* lhs_ptr0 = NULL, * lhs_ptr1 = NULL, * lhs_ptr2 = NULL, * lhs_ptr3 = NULL, * lhs_ptr4 = NULL, * lhs_ptr5 = NULL, * lhs_ptr6 = NULL, *lhs_ptr7 = NULL;
+ PacketBlock<Packet,1> accZero0, accZero1, accZero2, accZero3, accZero4, accZero5, accZero6, accZero7;
+ PacketBlock<Packet,1> acc;
+
+ MICRO_SRC_PTR
+ MICRO_DST_PTR
+
+ Index k = 0;
+ for(; k + PEEL <= depth; k+= PEEL)
+ {
+ EIGEN_POWER_PREFETCH(rhs_ptr);
+ MICRO_PREFETCH
+ MICRO_ONE_PEEL1
+ }
+ for(; k < depth; k++)
+ {
+ MICRO_ONE1
+ }
+ MICRO_COL_STORE
+
+ row += unroll_factor*accCols;
+}
+
+template<typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accCols>
+EIGEN_STRONG_INLINE void gemm_unrolled_col(
+ const DataMapper& res,
+ const Scalar* lhs_base,
+ const Scalar* rhs_base,
+ Index depth,
+ Index strideA,
+ Index offsetA,
+ Index& row,
+ Index rows,
+ Index col,
+ Index remaining_cols,
+ const Packet& pAlpha)
+{
+#define MAX_UNROLL 6
+ while(row + MAX_UNROLL*accCols <= rows) {
+ gemm_unrolled_col_iteration<MAX_UNROLL, Scalar, Packet, DataMapper, Index, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_cols, pAlpha);
+ }
+ switch( (rows-row)/accCols ) {
+#if MAX_UNROLL > 7
+ case 7:
+ gemm_unrolled_col_iteration<7, Scalar, Packet, DataMapper, Index, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_cols, pAlpha);
+ break;
+#endif
+#if MAX_UNROLL > 6
+ case 6:
+ gemm_unrolled_col_iteration<6, Scalar, Packet, DataMapper, Index, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_cols, pAlpha);
+ break;
+#endif
+#if MAX_UNROLL > 5
+ case 5:
+ gemm_unrolled_col_iteration<5, Scalar, Packet, DataMapper, Index, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_cols, pAlpha);
+ break;
+#endif
+#if MAX_UNROLL > 4
+ case 4:
+ gemm_unrolled_col_iteration<4, Scalar, Packet, DataMapper, Index, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_cols, pAlpha);
+ break;
+#endif
+#if MAX_UNROLL > 3
+ case 3:
+ gemm_unrolled_col_iteration<3, Scalar, Packet, DataMapper, Index, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_cols, pAlpha);
+ break;
+#endif
+#if MAX_UNROLL > 2
+ case 2:
+ gemm_unrolled_col_iteration<2, Scalar, Packet, DataMapper, Index, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_cols, pAlpha);
+ break;
+#endif
+#if MAX_UNROLL > 1
+ case 1:
+ gemm_unrolled_col_iteration<1, Scalar, Packet, DataMapper, Index, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_cols, pAlpha);
+ break;
+#endif
+ default:
+ break;
+ }
+#undef MAX_UNROLL
+}
+
+/****************
+ * GEMM kernels *
+ ****************/
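+// Main real GEMM kernel: processes the output accRows columns at a time, unrolling over rows with
+// gemm_unrolled_iteration (blocks of accCols rows, up to MAX_UNROLL at once), then finishes the
+// leftover rows and columns with gemm_extra_row, gemm_unrolled_col and gemm_extra_col.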
+template<typename Scalar, typename Index, typename Packet, typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols>
+EIGEN_STRONG_INLINE void gemm(const DataMapper& res, const Scalar* blockA, const Scalar* blockB, Index rows, Index depth, Index cols, Scalar alpha, Index strideA, Index strideB, Index offsetA, Index offsetB)
+{
+ const Index remaining_rows = rows % accCols;
+ const Index remaining_cols = cols % accRows;
+
+ if( strideA == -1 ) strideA = depth;
+ if( strideB == -1 ) strideB = depth;
+
+ const Packet pAlpha = pset1<Packet>(alpha);
+ const Packet pMask = bmask<Packet>((const int)(remaining_rows));
+
+ Index col = 0;
+ for(; col + accRows <= cols; col += accRows)
+ {
+ const Scalar* rhs_base = blockB + col*strideB + accRows*offsetB;
+ const Scalar* lhs_base = blockA;
+ Index row = 0;
+
+#define MAX_UNROLL 6
+ while(row + MAX_UNROLL*accCols <= rows) {
+ gemm_unrolled_iteration<MAX_UNROLL, Scalar, Packet, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
+ }
+ switch( (rows-row)/accCols ) {
+#if MAX_UNROLL > 7
+ case 7:
+ gemm_unrolled_iteration<7, Scalar, Packet, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
+ break;
+#endif
+#if MAX_UNROLL > 6
+ case 6:
+ gemm_unrolled_iteration<6, Scalar, Packet, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
+ break;
+#endif
+#if MAX_UNROLL > 5
+ case 5:
+ gemm_unrolled_iteration<5, Scalar, Packet, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
+ break;
+#endif
+#if MAX_UNROLL > 4
+ case 4:
+ gemm_unrolled_iteration<4, Scalar, Packet, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
+ break;
+#endif
+#if MAX_UNROLL > 3
+ case 3:
+ gemm_unrolled_iteration<3, Scalar, Packet, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
+ break;
+#endif
+#if MAX_UNROLL > 2
+ case 2:
+ gemm_unrolled_iteration<2, Scalar, Packet, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
+ break;
+#endif
+#if MAX_UNROLL > 1
+ case 1:
+ gemm_unrolled_iteration<1, Scalar, Packet, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
+ break;
+#endif
+ default:
+ break;
+ }
+#undef MAX_UNROLL
+
+ if(remaining_rows > 0)
+ {
+ gemm_extra_row<Scalar, Packet, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, rows, cols, remaining_rows, pAlpha, pMask);
+ }
+ }
+
+ if(remaining_cols > 0)
+ {
+ const Scalar* rhs_base = blockB + col*strideB + remaining_cols*offsetB;
+ const Scalar* lhs_base = blockA;
+
+ for(; col < cols; col++)
+ {
+ Index row = 0;
+
+ gemm_unrolled_col<Scalar, Packet, DataMapper, Index, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, rows, col, remaining_cols, pAlpha);
+
+ if (remaining_rows > 0)
+ {
+ gemm_extra_col<Scalar, Packet, DataMapper, Index, accRows>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_rows, remaining_cols, pAlpha);
+ }
+ rhs_base++;
+ }
+ }
+}
+
+#define accColsC (accCols / 2)
+#define advanceRows ((LhsIsReal) ? 1 : 2)
+#define advanceCols ((RhsIsReal) ? 1 : 2)
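+// accColsC is the number of complex (Packetc) lanes spanned by accCols real lanes; advanceRows and
+// advanceCols reflect that a complex operand is packed as separate real and imaginary panels, so its
+// packed data advances by two strides where a real operand advances by one.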
+
+// PEEL_COMPLEX loop factor: number of depth (k) iterations unrolled per prefetch in the complex GEMM kernels below.
+#define PEEL_COMPLEX 3
+
+template<typename Scalar, typename Packet, typename Index, const Index accRows, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
+EIGEN_ALWAYS_INLINE void MICRO_COMPLEX_EXTRA_COL(
+ const Scalar* &lhs_ptr_real, const Scalar* &lhs_ptr_imag,
+ const Scalar* &rhs_ptr_real, const Scalar* &rhs_ptr_imag,
+ PacketBlock<Packet,1> &accReal, PacketBlock<Packet,1> &accImag,
+ Index remaining_rows,
+ Index remaining_cols)
+{
+ Packet rhsV[1], rhsVi[1];
+ rhsV[0] = pset1<Packet>(rhs_ptr_real[0]);
+ if(!RhsIsReal) rhsVi[0] = pset1<Packet>(rhs_ptr_imag[0]);
+ pgerc<1, Scalar, Packet, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal, &accImag, lhs_ptr_real, lhs_ptr_imag, rhsV, rhsVi);
+ lhs_ptr_real += remaining_rows;
+ if(!LhsIsReal) lhs_ptr_imag += remaining_rows;
+ else EIGEN_UNUSED_VARIABLE(lhs_ptr_imag);
+ rhs_ptr_real += remaining_cols;
+ if(!RhsIsReal) rhs_ptr_imag += remaining_cols;
+ else EIGEN_UNUSED_VARIABLE(rhs_ptr_imag);
+}
+
+template<typename Scalar, typename Packet, typename Packetc, typename DataMapper, typename Index, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
+EIGEN_STRONG_INLINE void gemm_complex_extra_col(
+ const DataMapper& res,
+ const Scalar* lhs_base,
+ const Scalar* rhs_base,
+ Index depth,
+ Index strideA,
+ Index offsetA,
+ Index strideB,
+ Index row,
+ Index col,
+ Index remaining_rows,
+ Index remaining_cols,
+ const Packet& pAlphaReal,
+ const Packet& pAlphaImag)
+{
+ const Scalar* rhs_ptr_real = rhs_base;
+ const Scalar* rhs_ptr_imag;
+ if(!RhsIsReal) rhs_ptr_imag = rhs_base + remaining_cols*strideB;
+ else EIGEN_UNUSED_VARIABLE(rhs_ptr_imag);
+ const Scalar* lhs_ptr_real = lhs_base + advanceRows*row*strideA + remaining_rows*offsetA;
+ const Scalar* lhs_ptr_imag;
+ if(!LhsIsReal) lhs_ptr_imag = lhs_ptr_real + remaining_rows*strideA;
+ else EIGEN_UNUSED_VARIABLE(lhs_ptr_imag);
+ PacketBlock<Packet,1> accReal, accImag;
+ PacketBlock<Packet,1> taccReal, taccImag;
+ PacketBlock<Packetc,1> acc0, acc1;
+
+ bsetzero<Scalar, Packet>(accReal);
+ bsetzero<Scalar, Packet>(accImag);
+
+ Index remaining_depth = (depth & -accRows);
+ Index k = 0;
+ for(; k + PEEL_COMPLEX <= remaining_depth; k+= PEEL_COMPLEX)
+ {
+ EIGEN_POWER_PREFETCH(rhs_ptr_real);
+ if(!RhsIsReal) {
+ EIGEN_POWER_PREFETCH(rhs_ptr_imag);
+ }
+ EIGEN_POWER_PREFETCH(lhs_ptr_real);
+ if(!LhsIsReal) {
+ EIGEN_POWER_PREFETCH(lhs_ptr_imag);
+ }
+ for (int l = 0; l < PEEL_COMPLEX; l++) {
+ MICRO_COMPLEX_EXTRA_COL<Scalar, Packet, Index, accRows, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(lhs_ptr_real, lhs_ptr_imag, rhs_ptr_real, rhs_ptr_imag, accReal, accImag, remaining_rows, remaining_cols);
+ }
+ }
+ for(; k < remaining_depth; k++)
+ {
+ MICRO_COMPLEX_EXTRA_COL<Scalar, Packet, Index, accRows, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(lhs_ptr_real, lhs_ptr_imag, rhs_ptr_real, rhs_ptr_imag, accReal, accImag, remaining_rows, remaining_cols);
+ }
+
+ for(; k < depth; k++)
+ {
+ Packet rhsV[1], rhsVi[1];
+ rhsV[0] = pset1<Packet>(rhs_ptr_real[0]);
+ if(!RhsIsReal) rhsVi[0] = pset1<Packet>(rhs_ptr_imag[0]);
+ pgerc<1, Scalar, Packet, Index, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal, &accImag, lhs_ptr_real, lhs_ptr_imag, rhsV, rhsVi, remaining_rows);
+ lhs_ptr_real += remaining_rows;
+ if(!LhsIsReal) lhs_ptr_imag += remaining_rows;
+ rhs_ptr_real += remaining_cols;
+ if(!RhsIsReal) rhs_ptr_imag += remaining_cols;
+ }
+
+ bscalec<Packet,1>(accReal, accImag, pAlphaReal, pAlphaImag, taccReal, taccImag);
+ bcouple_common<Packet, Packetc>(taccReal, taccImag, acc0, acc1);
+
+ if ((sizeof(Scalar) == sizeof(float)) && (remaining_rows == 1))
+ {
+ res(row + 0, col + 0) += pfirst<Packetc>(acc0.packet[0]);
+ } else {
+ acc0.packet[0] += res.template loadPacket<Packetc>(row + 0, col + 0);
+ res.template storePacketBlock<Packetc,1>(row + 0, col + 0, acc0);
+ if(remaining_rows > accColsC) {
+ res(row + accColsC, col + 0) += pfirst<Packetc>(acc1.packet[0]);
+ }
+ }
+}
+
+template<typename Scalar, typename Packet, typename Index, const Index accRows, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
+EIGEN_ALWAYS_INLINE void MICRO_COMPLEX_EXTRA_ROW(
+ const Scalar* &lhs_ptr_real, const Scalar* &lhs_ptr_imag,
+ const Scalar* &rhs_ptr_real, const Scalar* &rhs_ptr_imag,
+ PacketBlock<Packet,4> &accReal, PacketBlock<Packet,4> &accImag,
+ Index remaining_rows)
+{
+ Packet rhsV[4], rhsVi[4];
+ pbroadcast4_old<Packet>(rhs_ptr_real, rhsV[0], rhsV[1], rhsV[2], rhsV[3]);
+ if(!RhsIsReal) pbroadcast4_old<Packet>(rhs_ptr_imag, rhsVi[0], rhsVi[1], rhsVi[2], rhsVi[3]);
+ pgerc<4, Scalar, Packet, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal, &accImag, lhs_ptr_real, lhs_ptr_imag, rhsV, rhsVi);
+ lhs_ptr_real += remaining_rows;
+ if(!LhsIsReal) lhs_ptr_imag += remaining_rows;
+ else EIGEN_UNUSED_VARIABLE(lhs_ptr_imag);
+ rhs_ptr_real += accRows;
+ if(!RhsIsReal) rhs_ptr_imag += accRows;
+ else EIGEN_UNUSED_VARIABLE(rhs_ptr_imag);
+}
+
+template<typename Scalar, typename Packet, typename Packetc, typename DataMapper, typename Index, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
+EIGEN_STRONG_INLINE void gemm_complex_extra_row(
+ const DataMapper& res,
+ const Scalar* lhs_base,
+ const Scalar* rhs_base,
+ Index depth,
+ Index strideA,
+ Index offsetA,
+ Index strideB,
+ Index row,
+ Index col,
+ Index rows,
+ Index cols,
+ Index remaining_rows,
+ const Packet& pAlphaReal,
+ const Packet& pAlphaImag,
+ const Packet& pMask)
+{
+ const Scalar* rhs_ptr_real = rhs_base;
+ const Scalar* rhs_ptr_imag;
+ if(!RhsIsReal) rhs_ptr_imag = rhs_base + accRows*strideB;
+ else EIGEN_UNUSED_VARIABLE(rhs_ptr_imag);
+ const Scalar* lhs_ptr_real = lhs_base + advanceRows*row*strideA + remaining_rows*offsetA;
+ const Scalar* lhs_ptr_imag;
+ if(!LhsIsReal) lhs_ptr_imag = lhs_ptr_real + remaining_rows*strideA;
+ else EIGEN_UNUSED_VARIABLE(lhs_ptr_imag);
+ PacketBlock<Packet,4> accReal, accImag;
+ PacketBlock<Packet,4> taccReal, taccImag;
+ PacketBlock<Packetc,4> acc0, acc1;
+ PacketBlock<Packetc,8> tRes;
+
+ bsetzero<Scalar, Packet>(accReal);
+ bsetzero<Scalar, Packet>(accImag);
+
+ Index remaining_depth = (col + accRows < cols) ? depth : (depth & -accRows);
+ Index k = 0;
+ for(; k + PEEL_COMPLEX <= remaining_depth; k+= PEEL_COMPLEX)
+ {
+ EIGEN_POWER_PREFETCH(rhs_ptr_real);
+ if(!RhsIsReal) {
+ EIGEN_POWER_PREFETCH(rhs_ptr_imag);
+ }
+ EIGEN_POWER_PREFETCH(lhs_ptr_real);
+ if(!LhsIsReal) {
+ EIGEN_POWER_PREFETCH(lhs_ptr_imag);
+ }
+ for (int l = 0; l < PEEL_COMPLEX; l++) {
+ MICRO_COMPLEX_EXTRA_ROW<Scalar, Packet, Index, accRows, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(lhs_ptr_real, lhs_ptr_imag, rhs_ptr_real, rhs_ptr_imag, accReal, accImag, remaining_rows);
+ }
+ }
+ for(; k < remaining_depth; k++)
+ {
+ MICRO_COMPLEX_EXTRA_ROW<Scalar, Packet, Index, accRows, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(lhs_ptr_real, lhs_ptr_imag, rhs_ptr_real, rhs_ptr_imag, accReal, accImag, remaining_rows);
+ }
+
+ if ((remaining_depth == depth) && (rows >= accCols))
+ {
+ bload<DataMapper, Packetc, Index, accColsC, 0, ColMajor>(tRes, res, row, col);
+ bscalec<Packet>(accReal, accImag, pAlphaReal, pAlphaImag, taccReal, taccImag, pMask);
+ bcouple<Packet, Packetc>(taccReal, taccImag, tRes, acc0, acc1);
+ res.template storePacketBlock<Packetc,4>(row + 0, col, acc0);
+ res.template storePacketBlock<Packetc,4>(row + accColsC, col, acc1);
+ } else {
+ for(; k < depth; k++)
+ {
+ Packet rhsV[4], rhsVi[4];
+ pbroadcast4_old<Packet>(rhs_ptr_real, rhsV[0], rhsV[1], rhsV[2], rhsV[3]);
+ if(!RhsIsReal) pbroadcast4_old<Packet>(rhs_ptr_imag, rhsVi[0], rhsVi[1], rhsVi[2], rhsVi[3]);
+ pgerc<4, Scalar, Packet, Index, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal, &accImag, lhs_ptr_real, lhs_ptr_imag, rhsV, rhsVi, remaining_rows);
+ lhs_ptr_real += remaining_rows;
+ if(!LhsIsReal) lhs_ptr_imag += remaining_rows;
+ rhs_ptr_real += accRows;
+ if(!RhsIsReal) rhs_ptr_imag += accRows;
+ }
+
+ bscalec<Packet,4>(accReal, accImag, pAlphaReal, pAlphaImag, taccReal, taccImag);
+ bcouple_common<Packet, Packetc>(taccReal, taccImag, acc0, acc1);
+
+ if ((sizeof(Scalar) == sizeof(float)) && (remaining_rows == 1))
+ {
+ for(Index j = 0; j < 4; j++) {
+ res(row + 0, col + j) += pfirst<Packetc>(acc0.packet[j]);
+ }
+ } else {
+ for(Index j = 0; j < 4; j++) {
+ PacketBlock<Packetc,1> acc2;
+ acc2.packet[0] = res.template loadPacket<Packetc>(row + 0, col + j) + acc0.packet[j];
+ res.template storePacketBlock<Packetc,1>(row + 0, col + j, acc2);
+ if(remaining_rows > accColsC) {
+ res(row + accColsC, col + j) += pfirst<Packetc>(acc1.packet[j]);
+ }
+ }
+ }
+ }
+}
+
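+// Complex variants of the MICRO_* macros: up to 5 accumulator pairs (real/imaginary), using the
+// same guarded unrolling scheme as the real kernel above.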
+#define MICRO_COMPLEX_UNROLL(func) \
+ func(0) func(1) func(2) func(3) func(4)
+
+#define MICRO_COMPLEX_UNROLL_WORK(func, func2, peel) \
+ MICRO_COMPLEX_UNROLL(func2); \
+ func(0,peel) func(1,peel) func(2,peel) func(3,peel) func(4,peel)
+
+#define MICRO_COMPLEX_LOAD_ONE(iter) \
+ if (unroll_factor > iter) { \
+ lhsV##iter = ploadLhs<Scalar, Packet>(lhs_ptr_real##iter); \
+ lhs_ptr_real##iter += accCols; \
+ if(!LhsIsReal) { \
+ lhsVi##iter = ploadLhs<Scalar, Packet>(lhs_ptr_imag##iter); \
+ lhs_ptr_imag##iter += accCols; \
+ } else { \
+ EIGEN_UNUSED_VARIABLE(lhsVi##iter); \
+ } \
+ } else { \
+ EIGEN_UNUSED_VARIABLE(lhsV##iter); \
+ EIGEN_UNUSED_VARIABLE(lhsVi##iter); \
+ }
+
+#define MICRO_COMPLEX_WORK_ONE4(iter, peel) \
+ if (unroll_factor > iter) { \
+ pgerc_common<4, Packet, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal##iter, &accImag##iter, lhsV##iter, lhsVi##iter, rhsV##peel, rhsVi##peel); \
+ }
+
+#define MICRO_COMPLEX_WORK_ONE1(iter, peel) \
+ if (unroll_factor > iter) { \
+ pgerc_common<1, Packet, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal##iter, &accImag##iter, lhsV##iter, lhsVi##iter, rhsV##peel, rhsVi##peel); \
+ }
+
+#define MICRO_COMPLEX_TYPE_PEEL4(func, func2, peel) \
+ if (PEEL_COMPLEX > peel) { \
+ Packet lhsV0, lhsV1, lhsV2, lhsV3, lhsV4; \
+ Packet lhsVi0, lhsVi1, lhsVi2, lhsVi3, lhsVi4; \
+ pbroadcast4_old<Packet>(rhs_ptr_real + (accRows * peel), rhsV##peel[0], rhsV##peel[1], rhsV##peel[2], rhsV##peel[3]); \
+ if(!RhsIsReal) { \
+ pbroadcast4_old<Packet>(rhs_ptr_imag + (accRows * peel), rhsVi##peel[0], rhsVi##peel[1], rhsVi##peel[2], rhsVi##peel[3]); \
+ } else { \
+ EIGEN_UNUSED_VARIABLE(rhsVi##peel); \
+ } \
+ MICRO_COMPLEX_UNROLL_WORK(func, func2, peel) \
+ } else { \
+ EIGEN_UNUSED_VARIABLE(rhsV##peel); \
+ EIGEN_UNUSED_VARIABLE(rhsVi##peel); \
+ }
+
+#define MICRO_COMPLEX_TYPE_PEEL1(func, func2, peel) \
+ if (PEEL_COMPLEX > peel) { \
+ Packet lhsV0, lhsV1, lhsV2, lhsV3, lhsV4; \
+ Packet lhsVi0, lhsVi1, lhsVi2, lhsVi3, lhsVi4; \
+ rhsV##peel[0] = pset1<Packet>(rhs_ptr_real[remaining_cols * peel]); \
+ if(!RhsIsReal) { \
+ rhsVi##peel[0] = pset1<Packet>(rhs_ptr_imag[remaining_cols * peel]); \
+ } else { \
+ EIGEN_UNUSED_VARIABLE(rhsVi##peel); \
+ } \
+ MICRO_COMPLEX_UNROLL_WORK(func, func2, peel) \
+ } else { \
+ EIGEN_UNUSED_VARIABLE(rhsV##peel); \
+ EIGEN_UNUSED_VARIABLE(rhsVi##peel); \
+ }
+
+#define MICRO_COMPLEX_UNROLL_TYPE_PEEL(M, func, func1, func2) \
+ Packet rhsV0[M], rhsV1[M], rhsV2[M], rhsV3[M], rhsV4[M], rhsV5[M], rhsV6[M], rhsV7[M], rhsV8[M], rhsV9[M]; \
+ Packet rhsVi0[M], rhsVi1[M], rhsVi2[M], rhsVi3[M], rhsVi4[M], rhsVi5[M], rhsVi6[M], rhsVi7[M], rhsVi8[M], rhsVi9[M]; \
+ func(func1,func2,0); func(func1,func2,1); \
+ func(func1,func2,2); func(func1,func2,3); \
+ func(func1,func2,4); func(func1,func2,5); \
+ func(func1,func2,6); func(func1,func2,7); \
+ func(func1,func2,8); func(func1,func2,9);
+
+#define MICRO_COMPLEX_UNROLL_TYPE_ONE(M, func, func1, func2) \
+ Packet rhsV0[M], rhsVi0[M];\
+ func(func1,func2,0);
+
+#define MICRO_COMPLEX_ONE_PEEL4 \
+ MICRO_COMPLEX_UNROLL_TYPE_PEEL(4, MICRO_COMPLEX_TYPE_PEEL4, MICRO_COMPLEX_WORK_ONE4, MICRO_COMPLEX_LOAD_ONE); \
+ rhs_ptr_real += (accRows * PEEL_COMPLEX); \
+ if(!RhsIsReal) rhs_ptr_imag += (accRows * PEEL_COMPLEX);
+
+#define MICRO_COMPLEX_ONE4 \
+ MICRO_COMPLEX_UNROLL_TYPE_ONE(4, MICRO_COMPLEX_TYPE_PEEL4, MICRO_COMPLEX_WORK_ONE4, MICRO_COMPLEX_LOAD_ONE); \
+ rhs_ptr_real += accRows; \
+ if(!RhsIsReal) rhs_ptr_imag += accRows;
+
+#define MICRO_COMPLEX_ONE_PEEL1 \
+ MICRO_COMPLEX_UNROLL_TYPE_PEEL(1, MICRO_COMPLEX_TYPE_PEEL1, MICRO_COMPLEX_WORK_ONE1, MICRO_COMPLEX_LOAD_ONE); \
+ rhs_ptr_real += (remaining_cols * PEEL_COMPLEX); \
+ if(!RhsIsReal) rhs_ptr_imag += (remaining_cols * PEEL_COMPLEX);
+
+#define MICRO_COMPLEX_ONE1 \
+ MICRO_COMPLEX_UNROLL_TYPE_ONE(1, MICRO_COMPLEX_TYPE_PEEL1, MICRO_COMPLEX_WORK_ONE1, MICRO_COMPLEX_LOAD_ONE); \
+ rhs_ptr_real += remaining_cols; \
+ if(!RhsIsReal) rhs_ptr_imag += remaining_cols;
+
+#define MICRO_COMPLEX_DST_PTR_ONE(iter) \
+ if (unroll_factor > iter) { \
+ bsetzero<Scalar, Packet>(accReal##iter); \
+ bsetzero<Scalar, Packet>(accImag##iter); \
+ } else { \
+ EIGEN_UNUSED_VARIABLE(accReal##iter); \
+ EIGEN_UNUSED_VARIABLE(accImag##iter); \
+ }
+
+#define MICRO_COMPLEX_DST_PTR MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_DST_PTR_ONE)
+
+#define MICRO_COMPLEX_SRC_PTR_ONE(iter) \
+ if (unroll_factor > iter) { \
+ lhs_ptr_real##iter = lhs_base + ( ((advanceRows*row)/accCols) + iter*advanceRows )*strideA*accCols + accCols*offsetA; \
+ if(!LhsIsReal) { \
+ lhs_ptr_imag##iter = lhs_ptr_real##iter + accCols*strideA; \
+ } else { \
+ EIGEN_UNUSED_VARIABLE(lhs_ptr_imag##iter); \
+ } \
+ } else { \
+ EIGEN_UNUSED_VARIABLE(lhs_ptr_real##iter); \
+ EIGEN_UNUSED_VARIABLE(lhs_ptr_imag##iter); \
+ }
+
+#define MICRO_COMPLEX_SRC_PTR MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_SRC_PTR_ONE)
+
+#define MICRO_COMPLEX_PREFETCH_ONE(iter) \
+ if (unroll_factor > iter) { \
+ EIGEN_POWER_PREFETCH(lhs_ptr_real##iter); \
+ if(!LhsIsReal) { \
+ EIGEN_POWER_PREFETCH(lhs_ptr_imag##iter); \
+ } \
+ }
+
+#define MICRO_COMPLEX_PREFETCH MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_PREFETCH_ONE)
+
+#define MICRO_COMPLEX_STORE_ONE(iter) \
+ if (unroll_factor > iter) { \
+ bload<DataMapper, Packetc, Index, accColsC, 0, ColMajor>(tRes, res, row + iter*accCols, col); \
+ bscalec<Packet,4>(accReal##iter, accImag##iter, pAlphaReal, pAlphaImag, taccReal, taccImag); \
+ bcouple<Packet, Packetc>(taccReal, taccImag, tRes, acc0, acc1); \
+ res.template storePacketBlock<Packetc,4>(row + iter*accCols + 0, col, acc0); \
+ res.template storePacketBlock<Packetc,4>(row + iter*accCols + accColsC, col, acc1); \
+ }
+
+#define MICRO_COMPLEX_STORE MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_STORE_ONE)
+
+#define MICRO_COMPLEX_COL_STORE_ONE(iter) \
+ if (unroll_factor > iter) { \
+ bload<DataMapper, Packetc, Index, accColsC, 0, ColMajor>(tRes, res, row + iter*accCols, col); \
+ bscalec<Packet,1>(accReal##iter, accImag##iter, pAlphaReal, pAlphaImag, taccReal, taccImag); \
+ bcouple<Packet, Packetc>(taccReal, taccImag, tRes, acc0, acc1); \
+ res.template storePacketBlock<Packetc,1>(row + iter*accCols + 0, col, acc0); \
+ res.template storePacketBlock<Packetc,1>(row + iter*accCols + accColsC, col, acc1); \
+ }
+
+#define MICRO_COMPLEX_COL_STORE MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_COL_STORE_ONE)
+
+template<int unroll_factor, typename Scalar, typename Packet, typename Packetc, typename DataMapper, typename Index, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
+EIGEN_STRONG_INLINE void gemm_complex_unrolled_iteration(
+ const DataMapper& res,
+ const Scalar* lhs_base,
+ const Scalar* rhs_base,
+ Index depth,
+ Index strideA,
+ Index offsetA,
+ Index strideB,
+ Index& row,
+ Index col,
+ const Packet& pAlphaReal,
+ const Packet& pAlphaImag)
+{
+ const Scalar* rhs_ptr_real = rhs_base;
+ const Scalar* rhs_ptr_imag;
+ if(!RhsIsReal) {
+ rhs_ptr_imag = rhs_base + accRows*strideB;
+ } else {
+ EIGEN_UNUSED_VARIABLE(rhs_ptr_imag);
+ }
+ const Scalar* lhs_ptr_real0 = NULL, * lhs_ptr_imag0 = NULL, * lhs_ptr_real1 = NULL, * lhs_ptr_imag1 = NULL;
+ const Scalar* lhs_ptr_real2 = NULL, * lhs_ptr_imag2 = NULL, * lhs_ptr_real3 = NULL, * lhs_ptr_imag3 = NULL;
+ const Scalar* lhs_ptr_real4 = NULL, * lhs_ptr_imag4 = NULL;
+ PacketBlock<Packet,4> accReal0, accImag0, accReal1, accImag1;
+ PacketBlock<Packet,4> accReal2, accImag2, accReal3, accImag3;
+ PacketBlock<Packet,4> accReal4, accImag4;
+ PacketBlock<Packet,4> taccReal, taccImag;
+ PacketBlock<Packetc,4> acc0, acc1;
+ PacketBlock<Packetc,8> tRes;
+
+ MICRO_COMPLEX_SRC_PTR
+ MICRO_COMPLEX_DST_PTR
+
+ Index k = 0;
+ for(; k + PEEL_COMPLEX <= depth; k+= PEEL_COMPLEX)
+ {
+ EIGEN_POWER_PREFETCH(rhs_ptr_real);
+ if(!RhsIsReal) {
+ EIGEN_POWER_PREFETCH(rhs_ptr_imag);
+ }
+ MICRO_COMPLEX_PREFETCH
+ MICRO_COMPLEX_ONE_PEEL4
+ }
+ for(; k < depth; k++)
+ {
+ MICRO_COMPLEX_ONE4
+ }
+ MICRO_COMPLEX_STORE
+
+ row += unroll_factor*accCols;
+}
+
+template<int unroll_factor, typename Scalar, typename Packet, typename Packetc, typename DataMapper, typename Index, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
+EIGEN_STRONG_INLINE void gemm_complex_unrolled_col_iteration(
+ const DataMapper& res,
+ const Scalar* lhs_base,
+ const Scalar* rhs_base,
+ Index depth,
+ Index strideA,
+ Index offsetA,
+ Index strideB,
+ Index& row,
+ Index col,
+ Index remaining_cols,
+ const Packet& pAlphaReal,
+ const Packet& pAlphaImag)
+{
+ const Scalar* rhs_ptr_real = rhs_base;
+ const Scalar* rhs_ptr_imag;
+ if(!RhsIsReal) {
+ rhs_ptr_imag = rhs_base + remaining_cols*strideB;
+ } else {
+ EIGEN_UNUSED_VARIABLE(rhs_ptr_imag);
+ }
+ const Scalar* lhs_ptr_real0 = NULL, * lhs_ptr_imag0 = NULL, * lhs_ptr_real1 = NULL, * lhs_ptr_imag1 = NULL;
+ const Scalar* lhs_ptr_real2 = NULL, * lhs_ptr_imag2 = NULL, * lhs_ptr_real3 = NULL, * lhs_ptr_imag3 = NULL;
+ const Scalar* lhs_ptr_real4 = NULL, * lhs_ptr_imag4 = NULL;
+ PacketBlock<Packet,1> accReal0, accImag0, accReal1, accImag1;
+ PacketBlock<Packet,1> accReal2, accImag2, accReal3, accImag3;
+ PacketBlock<Packet,1> accReal4, accImag4;
+ PacketBlock<Packet,1> taccReal, taccImag;
+ PacketBlock<Packetc,1> acc0, acc1;
+ PacketBlock<Packetc,2> tRes;
+
+ MICRO_COMPLEX_SRC_PTR
+ MICRO_COMPLEX_DST_PTR
+
+ Index k = 0;
+ for(; k + PEEL_COMPLEX <= depth; k+= PEEL_COMPLEX)
+ {
+ EIGEN_POWER_PREFETCH(rhs_ptr_real);
+ if(!RhsIsReal) {
+ EIGEN_POWER_PREFETCH(rhs_ptr_imag);
+ }
+ MICRO_COMPLEX_PREFETCH
+ MICRO_COMPLEX_ONE_PEEL1
+ }
+ for(; k < depth; k++)
+ {
+ MICRO_COMPLEX_ONE1
+ }
+ MICRO_COMPLEX_COL_STORE
+
+ row += unroll_factor*accCols;
+}
+
+template<typename Scalar, typename Packet, typename Packetc, typename DataMapper, typename Index, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
+EIGEN_STRONG_INLINE void gemm_complex_unrolled_col(
+ const DataMapper& res,
+ const Scalar* lhs_base,
+ const Scalar* rhs_base,
+ Index depth,
+ Index strideA,
+ Index offsetA,
+ Index strideB,
+ Index& row,
+ Index rows,
+ Index col,
+ Index remaining_cols,
+ const Packet& pAlphaReal,
+ const Packet& pAlphaImag)
+{
+#define MAX_COMPLEX_UNROLL 3
+ while(row + MAX_COMPLEX_UNROLL*accCols <= rows) {
+ gemm_complex_unrolled_col_iteration<MAX_COMPLEX_UNROLL, Scalar, Packet, Packetc, DataMapper, Index, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, remaining_cols, pAlphaReal, pAlphaImag);
+ }
+ switch( (rows-row)/accCols ) {
+#if MAX_COMPLEX_UNROLL > 4
+ case 4:
+ gemm_complex_unrolled_col_iteration<4, Scalar, Packet, Packetc, DataMapper, Index, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, remaining_cols, pAlphaReal, pAlphaImag);
+ break;
+#endif
+#if MAX_COMPLEX_UNROLL > 3
+ case 3:
+ gemm_complex_unrolled_col_iteration<3, Scalar, Packet, Packetc, DataMapper, Index, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, remaining_cols, pAlphaReal, pAlphaImag);
+ break;
+#endif
+#if MAX_COMPLEX_UNROLL > 2
+ case 2:
+ gemm_complex_unrolled_col_iteration<2, Scalar, Packet, Packetc, DataMapper, Index, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, remaining_cols, pAlphaReal, pAlphaImag);
+ break;
+#endif
+#if MAX_COMPLEX_UNROLL > 1
+ case 1:
+ gemm_complex_unrolled_col_iteration<1, Scalar, Packet, Packetc, DataMapper, Index, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, remaining_cols, pAlphaReal, pAlphaImag);
+ break;
+#endif
+ default:
+ break;
+ }
+#undef MAX_COMPLEX_UNROLL
+}
+
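+// Main complex GEMM kernel: same structure as gemm() above, but working on split real/imaginary
+// packed panels and applying alpha as a complex scale via bscalec.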
+template<typename LhsScalar, typename RhsScalar, typename Scalarc, typename Scalar, typename Index, typename Packet, typename Packetc, typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
+EIGEN_STRONG_INLINE void gemm_complex(const DataMapper& res, const LhsScalar* blockAc, const RhsScalar* blockBc, Index rows, Index depth, Index cols, Scalarc alpha, Index strideA, Index strideB, Index offsetA, Index offsetB)
+{
+ const Index remaining_rows = rows % accCols;
+ const Index remaining_cols = cols % accRows;
+
+ if( strideA == -1 ) strideA = depth;
+ if( strideB == -1 ) strideB = depth;
+
+ const Packet pAlphaReal = pset1<Packet>(alpha.real());
+ const Packet pAlphaImag = pset1<Packet>(alpha.imag());
+ const Packet pMask = bmask<Packet>((const int)(remaining_rows));
+
+ const Scalar* blockA = (Scalar *) blockAc;
+ const Scalar* blockB = (Scalar *) blockBc;
+
+ Index col = 0;
+ for(; col + accRows <= cols; col += accRows)
+ {
+ const Scalar* rhs_base = blockB + advanceCols*col*strideB + accRows*offsetB;
+ const Scalar* lhs_base = blockA;
+ Index row = 0;
+
+#define MAX_COMPLEX_UNROLL 3
+ while(row + MAX_COMPLEX_UNROLL*accCols <= rows) {
+ gemm_complex_unrolled_iteration<MAX_COMPLEX_UNROLL, Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag);
+ }
+ switch( (rows-row)/accCols ) {
+#if MAX_COMPLEX_UNROLL > 4
+ case 4:
+ gemm_complex_unrolled_iteration<4, Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag);
+ break;
+#endif
+#if MAX_COMPLEX_UNROLL > 3
+ case 3:
+ gemm_complex_unrolled_iteration<3, Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag);
+ break;
+#endif
+#if MAX_COMPLEX_UNROLL > 2
+ case 2:
+ gemm_complex_unrolled_iteration<2, Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag);
+ break;
+#endif
+#if MAX_COMPLEX_UNROLL > 1
+ case 1:
+ gemm_complex_unrolled_iteration<1, Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag);
+ break;
+#endif
+ default:
+ break;
+ }
+#undef MAX_COMPLEX_UNROLL
+
+ if(remaining_rows > 0)
+ {
+ gemm_complex_extra_row<Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, rows, cols, remaining_rows, pAlphaReal, pAlphaImag, pMask);
+ }
+ }
+
+ if(remaining_cols > 0)
+ {
+ const Scalar* rhs_base = blockB + advanceCols*col*strideB + remaining_cols*offsetB;
+ const Scalar* lhs_base = blockA;
+
+ for(; col < cols; col++)
+ {
+ Index row = 0;
+
+ gemm_complex_unrolled_col<Scalar, Packet, Packetc, DataMapper, Index, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, rows, col, remaining_cols, pAlphaReal, pAlphaImag);
+
+ if (remaining_rows > 0)
+ {
+ gemm_complex_extra_col<Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, remaining_rows, remaining_cols, pAlphaReal, pAlphaImag);
+ }
+ rhs_base++;
+ }
+ }
+}
+
+#undef accColsC
+#undef advanceCols
+#undef advanceRows
+
+/************************************
+ * ppc64le template specializations *
+ ************************************/
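+// Hook the Power-specific packing routines (dhs_pack) into Eigen's generic gemm_pack_lhs /
+// gemm_pack_rhs interfaces for double and float.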
+template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
+struct gemm_pack_lhs<double, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode>
+{
+ void operator()(double* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride=0, Index offset=0);
+};
+
+template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
+void gemm_pack_lhs<double, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode>
+ ::operator()(double* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
+{
+ dhs_pack<double, Index, DataMapper, Packet2d, ColMajor, PanelMode, true> pack;
+ pack(blockA, lhs, depth, rows, stride, offset);
+}
+
+template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
+struct gemm_pack_lhs<double, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode>
+{
+ void operator()(double* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride=0, Index offset=0);
+};
+
+template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
+void gemm_pack_lhs<double, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode>
+ ::operator()(double* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
+{
+ dhs_pack<double, Index, DataMapper, Packet2d, RowMajor, PanelMode, true> pack;
+ pack(blockA, lhs, depth, rows, stride, offset);
+}
+
+#if EIGEN_ALTIVEC_USE_CUSTOM_PACK
+template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
+struct gemm_pack_rhs<double, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>
+{
+ void operator()(double* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0);
+};
+
+template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
+void gemm_pack_rhs<double, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>
+ ::operator()(double* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
+{
+ dhs_pack<double, Index, DataMapper, Packet2d, ColMajor, PanelMode, false> pack;
+ pack(blockB, rhs, depth, cols, stride, offset);
+}
+
+template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
+struct gemm_pack_rhs<double, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>
+{
+ void operator()(double* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0);
+};
+
+template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
+void gemm_pack_rhs<double, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>
+ ::operator()(double* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
+{
+ dhs_pack<double, Index, DataMapper, Packet2d, RowMajor, PanelMode, false> pack;
+ pack(blockB, rhs, depth, cols, stride, offset);
+}
+#endif
+
+template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
+struct gemm_pack_lhs<float, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode>
+{
+ void operator()(float* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride=0, Index offset=0);
+};
+
+template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
+void gemm_pack_lhs<float, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode>
+ ::operator()(float* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
+{
+ dhs_pack<float, Index, DataMapper, Packet4f, RowMajor, PanelMode, true> pack;
+ pack(blockA, lhs, depth, rows, stride, offset);
+}
+
+template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
+struct gemm_pack_lhs<float, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode>
+{
+ void operator()(float* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride=0, Index offset=0);
+};
+
+template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
+void gemm_pack_lhs<float, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode>
+ ::operator()(float* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
+{
+ dhs_pack<float, Index, DataMapper, Packet4f, ColMajor, PanelMode, true> pack;
+ pack(blockA, lhs, depth, rows, stride, offset);
+}
+
+template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
+struct gemm_pack_lhs<std::complex<float>, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode>
+{
+ void operator()(std::complex<float>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride=0, Index offset=0);
+};
+
+template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
+void gemm_pack_lhs<std::complex<float>, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode>
+ ::operator()(std::complex<float>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
+{
+ dhs_cpack<float, Index, DataMapper, Packet4f, Packet2cf, RowMajor, Conjugate, PanelMode, true> pack;
+ pack(blockA, lhs, depth, rows, stride, offset);
+}
+
+template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
+struct gemm_pack_lhs<std::complex<float>, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode>
+{
+ void operator()(std::complex<float>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride=0, Index offset=0);
+};
+
+template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
+void gemm_pack_lhs<std::complex<float>, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode>
+ ::operator()(std::complex<float>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
+{
+ dhs_cpack<float, Index, DataMapper, Packet4f, Packet2cf, ColMajor, Conjugate, PanelMode, true> pack;
+ pack(blockA, lhs, depth, rows, stride, offset);
+}
+
+#if EIGEN_ALTIVEC_USE_CUSTOM_PACK
+template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
+struct gemm_pack_rhs<float, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>
+{
+ void operator()(float* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0);
+};
+
+template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
+void gemm_pack_rhs<float, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>
+ ::operator()(float* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
+{
+ dhs_pack<float, Index, DataMapper, Packet4f, ColMajor, PanelMode, false> pack;
+ pack(blockB, rhs, depth, cols, stride, offset);
+}
+
+template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
+struct gemm_pack_rhs<float, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>
+{
+ void operator()(float* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0);
+};
+
+template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
+void gemm_pack_rhs<float, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>
+ ::operator()(float* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
+{
+ dhs_pack<float, Index, DataMapper, Packet4f, RowMajor, PanelMode, false> pack;
+ pack(blockB, rhs, depth, cols, stride, offset);
+}
+#endif
+
+template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
+struct gemm_pack_rhs<std::complex<float>, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>
+{
+ void operator()(std::complex<float>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0);
+};
+
+template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
+void gemm_pack_rhs<std::complex<float>, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>
+ ::operator()(std::complex<float>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
+{
+ dhs_cpack<float, Index, DataMapper, Packet4f, Packet2cf, ColMajor, Conjugate, PanelMode, false> pack;
+ pack(blockB, rhs, depth, cols, stride, offset);
+}
+
+template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
+struct gemm_pack_rhs<std::complex<float>, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>
+{
+ void operator()(std::complex<float>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0);
+};
+
+template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
+void gemm_pack_rhs<std::complex<float>, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>
+ ::operator()(std::complex<float>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
+{
+ dhs_cpack<float, Index, DataMapper, Packet4f, Packet2cf, RowMajor, Conjugate, PanelMode, false> pack;
+ pack(blockB, rhs, depth, cols, stride, offset);
+}
+
+template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
+struct gemm_pack_lhs<std::complex<double>, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode>
+{
+ void operator()(std::complex<double>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride=0, Index offset=0);
+};
+
+template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
+void gemm_pack_lhs<std::complex<double>, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode>
+ ::operator()(std::complex<double>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
+{
+ dhs_cpack<double, Index, DataMapper, Packet2d, Packet1cd, RowMajor, Conjugate, PanelMode, true> pack;
+ pack(blockA, lhs, depth, rows, stride, offset);
+}
+
+template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
+struct gemm_pack_lhs<std::complex<double>, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode>
+{
+ void operator()(std::complex<double>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride=0, Index offset=0);
+};
+
+template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
+void gemm_pack_lhs<std::complex<double>, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode>
+ ::operator()(std::complex<double>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
+{
+ dhs_cpack<double, Index, DataMapper, Packet2d, Packet1cd, ColMajor, Conjugate, PanelMode, true> pack;
+ pack(blockA, lhs, depth, rows, stride, offset);
+}
+
+template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
+struct gemm_pack_rhs<std::complex<double>, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>
+{
+ void operator()(std::complex<double>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0);
+};
+
+template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
+void gemm_pack_rhs<std::complex<double>, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>
+ ::operator()(std::complex<double>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
+{
+ dhs_cpack<double, Index, DataMapper, Packet2d, Packet1cd, ColMajor, Conjugate, PanelMode, false> pack;
+ pack(blockB, rhs, depth, cols, stride, offset);
+}
+
+template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
+struct gemm_pack_rhs<std::complex<double>, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>
+{
+ void operator()(std::complex<double>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0);
+};
+
+template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
+void gemm_pack_rhs<std::complex<double>, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>
+ ::operator()(std::complex<double>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
+{
+ dhs_cpack<double, Index, DataMapper, Packet2d, Packet1cd, RowMajor, Conjugate, PanelMode, false> pack;
+ pack(blockB, rhs, depth, cols, stride, offset);
+}
+
+// ********* gebp specializations *********
+template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
+struct gebp_kernel<float, float, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
+{
+ typedef typename quad_traits<float>::vectortype Packet;
+ typedef typename quad_traits<float>::rhstype RhsPacket;
+
+ void operator()(const DataMapper& res, const float* blockA, const float* blockB,
+ Index rows, Index depth, Index cols, float alpha,
+ Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0);
+};
+
+template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
+void gebp_kernel<float, float, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
+ ::operator()(const DataMapper& res, const float* blockA, const float* blockB,
+ Index rows, Index depth, Index cols, float alpha,
+ Index strideA, Index strideB, Index offsetA, Index offsetB)
+ {
+ const Index accRows = quad_traits<float>::rows;
+ const Index accCols = quad_traits<float>::size;
+ void (*gemm_function)(const DataMapper&, const float*, const float*, Index, Index, Index, float, Index, Index, Index, Index);
+
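+    // Select the GEMM implementation: MMA-only builds always call gemmMMA, builds with MMA support
+    // check at run time whether the CPU provides Power10 (ISA 3.1) MMA, and everything else falls
+    // back to the generic VSX gemm above.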
+ #ifdef EIGEN_ALTIVEC_MMA_ONLY
+ //generate with MMA only
+ gemm_function = &Eigen::internal::gemmMMA<float, Index, Packet, RhsPacket, DataMapper, accRows, accCols>;
+ #elif defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA)
+ if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){
+ gemm_function = &Eigen::internal::gemmMMA<float, Index, Packet, RhsPacket, DataMapper, accRows, accCols>;
+ }
+ else{
+ gemm_function = &Eigen::internal::gemm<float, Index, Packet, RhsPacket, DataMapper, accRows, accCols>;
+ }
+ #else
+ gemm_function = &Eigen::internal::gemm<float, Index, Packet, RhsPacket, DataMapper, accRows, accCols>;
+ #endif
+ gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
+ }
+
+template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
+struct gebp_kernel<std::complex<float>, std::complex<float>, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
+{
+ typedef Packet4f Packet;
+ typedef Packet2cf Packetc;
+ typedef Packet4f RhsPacket;
+
+ void operator()(const DataMapper& res, const std::complex<float>* blockA, const std::complex<float>* blockB,
+ Index rows, Index depth, Index cols, std::complex<float> alpha,
+ Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0);
+};
+
+template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
+void gebp_kernel<std::complex<float>, std::complex<float>, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
+ ::operator()(const DataMapper& res, const std::complex<float>* blockA, const std::complex<float>* blockB,
+ Index rows, Index depth, Index cols, std::complex<float> alpha,
+ Index strideA, Index strideB, Index offsetA, Index offsetB)
+ {
+ const Index accRows = quad_traits<float>::rows;
+ const Index accCols = quad_traits<float>::size;
+ void (*gemm_function)(const DataMapper&, const std::complex<float>*, const std::complex<float>*,
+ Index, Index, Index, std::complex<float>, Index, Index, Index, Index);
+
+ #ifdef EIGEN_ALTIVEC_MMA_ONLY
+ //generate with MMA only
+ gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<float>, std::complex<float>, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>;
+ #elif defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA)
+ if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){
+ gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<float>, std::complex<float>, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>;
+ }
+ else{
+ gemm_function = &Eigen::internal::gemm_complex<std::complex<float>, std::complex<float>, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>;
+ }
+ #else
+ gemm_function = &Eigen::internal::gemm_complex<std::complex<float>, std::complex<float>, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>;
+ #endif
+ gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
+ }
+
+template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
+struct gebp_kernel<float, std::complex<float>, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
+{
+ typedef Packet4f Packet;
+ typedef Packet2cf Packetc;
+ typedef Packet4f RhsPacket;
+
+ void operator()(const DataMapper& res, const float* blockA, const std::complex<float>* blockB,
+ Index rows, Index depth, Index cols, std::complex<float> alpha,
+ Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0);
+};
+
+template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
+void gebp_kernel<float, std::complex<float>, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
+ ::operator()(const DataMapper& res, const float* blockA, const std::complex<float>* blockB,
+ Index rows, Index depth, Index cols, std::complex<float> alpha,
+ Index strideA, Index strideB, Index offsetA, Index offsetB)
+ {
+ const Index accRows = quad_traits<float>::rows;
+ const Index accCols = quad_traits<float>::size;
+ void (*gemm_function)(const DataMapper&, const float*, const std::complex<float>*,
+ Index, Index, Index, std::complex<float>, Index, Index, Index, Index);
+ #ifdef EIGEN_ALTIVEC_MMA_ONLY
+ //generate with MMA only
+ gemm_function = &Eigen::internal::gemm_complexMMA<float, std::complex<float>, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>;
+ #elif defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA)
+ if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){
+ gemm_function = &Eigen::internal::gemm_complexMMA<float, std::complex<float>, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>;
+ }
+ else{
+ gemm_function = &Eigen::internal::gemm_complex<float, std::complex<float>, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>;
+ }
+ #else
+ gemm_function = &Eigen::internal::gemm_complex<float, std::complex<float>, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>;
+ #endif
+ gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
+ }
+
+template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
+struct gebp_kernel<std::complex<float>, float, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
+{
+ typedef Packet4f Packet;
+ typedef Packet2cf Packetc;
+ typedef Packet4f RhsPacket;
+
+ void operator()(const DataMapper& res, const std::complex<float>* blockA, const float* blockB,
+ Index rows, Index depth, Index cols, std::complex<float> alpha,
+ Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0);
+};
+
+template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
+void gebp_kernel<std::complex<float>, float, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
+ ::operator()(const DataMapper& res, const std::complex<float>* blockA, const float* blockB,
+ Index rows, Index depth, Index cols, std::complex<float> alpha,
+ Index strideA, Index strideB, Index offsetA, Index offsetB)
+ {
+ const Index accRows = quad_traits<float>::rows;
+ const Index accCols = quad_traits<float>::size;
+ void (*gemm_function)(const DataMapper&, const std::complex<float>*, const float*,
+ Index, Index, Index, std::complex<float>, Index, Index, Index, Index);
+ #ifdef EIGEN_ALTIVEC_MMA_ONLY
+ //generate with MMA only
+ gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<float>, float, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>;
+ #elif defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA)
+ if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){
+ gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<float>, float, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>;
+ }
+ else{
+ gemm_function = &Eigen::internal::gemm_complex<std::complex<float>, float, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>;
+ }
+ #else
+ gemm_function = &Eigen::internal::gemm_complex<std::complex<float>, float, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>;
+ #endif
+ gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
+ }
+
+template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
+struct gebp_kernel<double, double, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
+{
+ typedef typename quad_traits<double>::vectortype Packet;
+ typedef typename quad_traits<double>::rhstype RhsPacket;
+
+ void operator()(const DataMapper& res, const double* blockA, const double* blockB,
+ Index rows, Index depth, Index cols, double alpha,
+ Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0);
+};
+
+template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
+void gebp_kernel<double, double, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
+ ::operator()(const DataMapper& res, const double* blockA, const double* blockB,
+ Index rows, Index depth, Index cols, double alpha,
+ Index strideA, Index strideB, Index offsetA, Index offsetB)
+ {
+ const Index accRows = quad_traits<double>::rows;
+ const Index accCols = quad_traits<double>::size;
+ void (*gemm_function)(const DataMapper&, const double*, const double*, Index, Index, Index, double, Index, Index, Index, Index);
+
+ #ifdef EIGEN_ALTIVEC_MMA_ONLY
+ //generate with MMA only
+ gemm_function = &Eigen::internal::gemmMMA<double, Index, Packet, RhsPacket, DataMapper, accRows, accCols>;
+ #elif defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA)
+ if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){
+ gemm_function = &Eigen::internal::gemmMMA<double, Index, Packet, RhsPacket, DataMapper, accRows, accCols>;
+ }
+ else{
+ gemm_function = &Eigen::internal::gemm<double, Index, Packet, RhsPacket, DataMapper, accRows, accCols>;
+ }
+ #else
+ gemm_function = &Eigen::internal::gemm<double, Index, Packet, RhsPacket, DataMapper, accRows, accCols>;
+ #endif
+ gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
+ }
+
+template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
+struct gebp_kernel<std::complex<double>, std::complex<double>, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
+{
+ typedef quad_traits<double>::vectortype Packet;
+ typedef Packet1cd Packetc;
+ typedef quad_traits<double>::rhstype RhsPacket;
+
+ void operator()(const DataMapper& res, const std::complex<double>* blockA, const std::complex<double>* blockB,
+ Index rows, Index depth, Index cols, std::complex<double> alpha,
+ Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0);
+};
+
+template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
+void gebp_kernel<std::complex<double>, std::complex<double>, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
+ ::operator()(const DataMapper& res, const std::complex<double>* blockA, const std::complex<double>* blockB,
+ Index rows, Index depth, Index cols, std::complex<double> alpha,
+ Index strideA, Index strideB, Index offsetA, Index offsetB)
+ {
+ const Index accRows = quad_traits<double>::rows;
+ const Index accCols = quad_traits<double>::size;
+ void (*gemm_function)(const DataMapper&, const std::complex<double>*, const std::complex<double>*,
+ Index, Index, Index, std::complex<double>, Index, Index, Index, Index);
+ #ifdef EIGEN_ALTIVEC_MMA_ONLY
+ //generate with MMA only
+ gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<double>, std::complex<double>, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>;
+ #elif defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA)
+ if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){
+ gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<double>, std::complex<double>, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>;
+ }
+ else{
+ gemm_function = &Eigen::internal::gemm_complex<std::complex<double>, std::complex<double>, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>;
+ }
+ #else
+ gemm_function = &Eigen::internal::gemm_complex<std::complex<double>, std::complex<double>, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>;
+ #endif
+ gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
+ }
+
+template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
+struct gebp_kernel<std::complex<double>, double, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
+{
+ typedef quad_traits<double>::vectortype Packet;
+ typedef Packet1cd Packetc;
+ typedef quad_traits<double>::rhstype RhsPacket;
+
+ void operator()(const DataMapper& res, const std::complex<double>* blockA, const double* blockB,
+ Index rows, Index depth, Index cols, std::complex<double> alpha,
+ Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0);
+};
+
+template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
+void gebp_kernel<std::complex<double>, double, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
+ ::operator()(const DataMapper& res, const std::complex<double>* blockA, const double* blockB,
+ Index rows, Index depth, Index cols, std::complex<double> alpha,
+ Index strideA, Index strideB, Index offsetA, Index offsetB)
+ {
+ const Index accRows = quad_traits<double>::rows;
+ const Index accCols = quad_traits<double>::size;
+ void (*gemm_function)(const DataMapper&, const std::complex<double>*, const double*,
+ Index, Index, Index, std::complex<double>, Index, Index, Index, Index);
+ #ifdef EIGEN_ALTIVEC_MMA_ONLY
+ //generate with MMA only
+ gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<double>, double, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>;
+ #elif defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA)
+ if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){
+ gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<double>, double, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>;
+ }
+ else{
+ gemm_function = &Eigen::internal::gemm_complex<std::complex<double>, double, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>;
+ }
+ #else
+ gemm_function = &Eigen::internal::gemm_complex<std::complex<double>, double, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>;
+ #endif
+ gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
+ }
+
+template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
+struct gebp_kernel<double, std::complex<double>, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
+{
+ typedef quad_traits<double>::vectortype Packet;
+ typedef Packet1cd Packetc;
+ typedef quad_traits<double>::rhstype RhsPacket;
+
+ void operator()(const DataMapper& res, const double* blockA, const std::complex<double>* blockB,
+ Index rows, Index depth, Index cols, std::complex<double> alpha,
+ Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0);
+};
+
+template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
+void gebp_kernel<double, std::complex<double>, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
+ ::operator()(const DataMapper& res, const double* blockA, const std::complex<double>* blockB,
+ Index rows, Index depth, Index cols, std::complex<double> alpha,
+ Index strideA, Index strideB, Index offsetA, Index offsetB)
+ {
+ const Index accRows = quad_traits<double>::rows;
+ const Index accCols = quad_traits<double>::size;
+ void (*gemm_function)(const DataMapper&, const double*, const std::complex<double>*,
+ Index, Index, Index, std::complex<double>, Index, Index, Index, Index);
+ #ifdef EIGEN_ALTIVEC_MMA_ONLY
+ //generate with MMA only
+ gemm_function = &Eigen::internal::gemm_complexMMA<double, std::complex<double>, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>;
+ #elif defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA)
+ if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){
+ gemm_function = &Eigen::internal::gemm_complexMMA<double, std::complex<double>, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>;
+ }
+ else{
+ gemm_function = &Eigen::internal::gemm_complex<double, std::complex<double>, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>;
+ }
+ #else
+ gemm_function = &Eigen::internal::gemm_complex<double, std::complex<double>, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>;
+ #endif
+ gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
+ }
+} // end namespace internal
+
+} // end namespace Eigen
+
+#endif // EIGEN_MATRIX_PRODUCT_ALTIVEC_H
diff --git a/Eigen/src/Core/arch/AltiVec/MatrixProductCommon.h b/Eigen/src/Core/arch/AltiVec/MatrixProductCommon.h
new file mode 100644
index 000000000..33d543494
--- /dev/null
+++ b/Eigen/src/Core/arch/AltiVec/MatrixProductCommon.h
@@ -0,0 +1,221 @@
+//#define EIGEN_POWER_USE_PREFETCH // Use prefetching in gemm routines
+#ifdef EIGEN_POWER_USE_PREFETCH
+#define EIGEN_POWER_PREFETCH(p) prefetch(p)
+#else
+#define EIGEN_POWER_PREFETCH(p)
+#endif
+
+namespace Eigen {
+
+namespace internal {
+
+template<typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accRows>
+EIGEN_STRONG_INLINE void gemm_extra_col(
+ const DataMapper& res,
+ const Scalar* lhs_base,
+ const Scalar* rhs_base,
+ Index depth,
+ Index strideA,
+ Index offsetA,
+ Index row,
+ Index col,
+ Index remaining_rows,
+ Index remaining_cols,
+ const Packet& pAlpha);
+
+template<typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accRows, const Index accCols>
+EIGEN_STRONG_INLINE void gemm_extra_row(
+ const DataMapper& res,
+ const Scalar* lhs_base,
+ const Scalar* rhs_base,
+ Index depth,
+ Index strideA,
+ Index offsetA,
+ Index row,
+ Index col,
+ Index rows,
+ Index cols,
+ Index remaining_rows,
+ const Packet& pAlpha,
+ const Packet& pMask);
+
+template<typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accCols>
+EIGEN_STRONG_INLINE void gemm_unrolled_col(
+ const DataMapper& res,
+ const Scalar* lhs_base,
+ const Scalar* rhs_base,
+ Index depth,
+ Index strideA,
+ Index offsetA,
+ Index& row,
+ Index rows,
+ Index col,
+ Index remaining_cols,
+ const Packet& pAlpha);
+
+template<typename Packet>
+EIGEN_ALWAYS_INLINE Packet bmask(const int remaining_rows);
+
+template<typename Scalar, typename Packet, typename Packetc, typename DataMapper, typename Index, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
+EIGEN_STRONG_INLINE void gemm_complex_extra_col(
+ const DataMapper& res,
+ const Scalar* lhs_base,
+ const Scalar* rhs_base,
+ Index depth,
+ Index strideA,
+ Index offsetA,
+ Index strideB,
+ Index row,
+ Index col,
+ Index remaining_rows,
+ Index remaining_cols,
+ const Packet& pAlphaReal,
+ const Packet& pAlphaImag);
+
+template<typename Scalar, typename Packet, typename Packetc, typename DataMapper, typename Index, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
+EIGEN_STRONG_INLINE void gemm_complex_extra_row(
+ const DataMapper& res,
+ const Scalar* lhs_base,
+ const Scalar* rhs_base,
+ Index depth,
+ Index strideA,
+ Index offsetA,
+ Index strideB,
+ Index row,
+ Index col,
+ Index rows,
+ Index cols,
+ Index remaining_rows,
+ const Packet& pAlphaReal,
+ const Packet& pAlphaImag,
+ const Packet& pMask);
+
+template<typename Scalar, typename Packet, typename Packetc, typename DataMapper, typename Index, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
+EIGEN_STRONG_INLINE void gemm_complex_unrolled_col(
+ const DataMapper& res,
+ const Scalar* lhs_base,
+ const Scalar* rhs_base,
+ Index depth,
+ Index strideA,
+ Index offsetA,
+ Index strideB,
+ Index& row,
+ Index rows,
+ Index col,
+ Index remaining_cols,
+ const Packet& pAlphaReal,
+ const Packet& pAlphaImag);
+
+template<typename Scalar, typename Packet>
+EIGEN_ALWAYS_INLINE Packet ploadLhs(const Scalar* lhs);
+
+template<typename DataMapper, typename Packet, typename Index, const Index accCols, int N, int StorageOrder>
+EIGEN_ALWAYS_INLINE void bload(PacketBlock<Packet,4>& acc, const DataMapper& res, Index row, Index col);
+
+template<typename DataMapper, typename Packet, typename Index, const Index accCols, int N, int StorageOrder>
+EIGEN_ALWAYS_INLINE void bload(PacketBlock<Packet,8>& acc, const DataMapper& res, Index row, Index col);
+
+template<typename Packet>
+EIGEN_ALWAYS_INLINE void bscale(PacketBlock<Packet,4>& acc, PacketBlock<Packet,4>& accZ, const Packet& pAlpha);
+
+template<typename Packet, int N>
+EIGEN_ALWAYS_INLINE void bscalec(PacketBlock<Packet,N>& aReal, PacketBlock<Packet,N>& aImag, const Packet& bReal, const Packet& bImag, PacketBlock<Packet,N>& cReal, PacketBlock<Packet,N>& cImag);
+
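+//[a,b,c,d],[ai,bi,ci,di] = [a,ai,b,bi] - interleave the first two real/imaginary pairs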
+const static Packet16uc p16uc_SETCOMPLEX32_FIRST = { 0, 1, 2, 3,
+ 16, 17, 18, 19,
+ 4, 5, 6, 7,
+ 20, 21, 22, 23};
+
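+//[a,b,c,d],[ai,bi,ci,di] = [c,ci,d,di] - interleave the last two real/imaginary pairs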
+const static Packet16uc p16uc_SETCOMPLEX32_SECOND = { 8, 9, 10, 11,
+ 24, 25, 26, 27,
+ 12, 13, 14, 15,
+ 28, 29, 30, 31};
+//[a,b],[ai,bi] = [a,ai] - This is equivalent to p16uc_GETREAL64
+const static Packet16uc p16uc_SETCOMPLEX64_FIRST = { 0, 1, 2, 3, 4, 5, 6, 7,
+ 16, 17, 18, 19, 20, 21, 22, 23};
+
+//[a,b],[ai,bi] = [b,bi] - This is equivalent to p16uc_GETIMAG64
+const static Packet16uc p16uc_SETCOMPLEX64_SECOND = { 8, 9, 10, 11, 12, 13, 14, 15,
+ 24, 25, 26, 27, 28, 29, 30, 31};
+
+
+// Grab the two decoupled real/imaginary PacketBlocks and return two coupled (real/imaginary pairs) PacketBlocks.
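+// For instance, with Packet4f inputs taccReal.packet[k] = [r0 r1 r2 r3] and taccImag.packet[k] = [i0 i1 i2 i3],
+// acc1.packet[k] becomes [r0 i0 r1 i1] (the first two complex values) and acc2.packet[k] becomes [r2 i2 r3 i3].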
+template<typename Packet, typename Packetc>
+EIGEN_ALWAYS_INLINE void bcouple_common(PacketBlock<Packet,4>& taccReal, PacketBlock<Packet,4>& taccImag, PacketBlock<Packetc, 4>& acc1, PacketBlock<Packetc, 4>& acc2)
+{
+ acc1.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], p16uc_SETCOMPLEX32_FIRST);
+ acc1.packet[1].v = vec_perm(taccReal.packet[1], taccImag.packet[1], p16uc_SETCOMPLEX32_FIRST);
+ acc1.packet[2].v = vec_perm(taccReal.packet[2], taccImag.packet[2], p16uc_SETCOMPLEX32_FIRST);
+ acc1.packet[3].v = vec_perm(taccReal.packet[3], taccImag.packet[3], p16uc_SETCOMPLEX32_FIRST);
+
+ acc2.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], p16uc_SETCOMPLEX32_SECOND);
+ acc2.packet[1].v = vec_perm(taccReal.packet[1], taccImag.packet[1], p16uc_SETCOMPLEX32_SECOND);
+ acc2.packet[2].v = vec_perm(taccReal.packet[2], taccImag.packet[2], p16uc_SETCOMPLEX32_SECOND);
+ acc2.packet[3].v = vec_perm(taccReal.packet[3], taccImag.packet[3], p16uc_SETCOMPLEX32_SECOND);
+}
+
+template<typename Packet, typename Packetc>
+EIGEN_ALWAYS_INLINE void bcouple(PacketBlock<Packet,4>& taccReal, PacketBlock<Packet,4>& taccImag, PacketBlock<Packetc,8>& tRes, PacketBlock<Packetc, 4>& acc1, PacketBlock<Packetc, 4>& acc2)
+{
+ bcouple_common<Packet, Packetc>(taccReal, taccImag, acc1, acc2);
+
+ acc1.packet[0] = padd<Packetc>(tRes.packet[0], acc1.packet[0]);
+ acc1.packet[1] = padd<Packetc>(tRes.packet[1], acc1.packet[1]);
+ acc1.packet[2] = padd<Packetc>(tRes.packet[2], acc1.packet[2]);
+ acc1.packet[3] = padd<Packetc>(tRes.packet[3], acc1.packet[3]);
+
+ acc2.packet[0] = padd<Packetc>(tRes.packet[4], acc2.packet[0]);
+ acc2.packet[1] = padd<Packetc>(tRes.packet[5], acc2.packet[1]);
+ acc2.packet[2] = padd<Packetc>(tRes.packet[6], acc2.packet[2]);
+ acc2.packet[3] = padd<Packetc>(tRes.packet[7], acc2.packet[3]);
+}
+
+template<typename Packet, typename Packetc>
+EIGEN_ALWAYS_INLINE void bcouple_common(PacketBlock<Packet,1>& taccReal, PacketBlock<Packet,1>& taccImag, PacketBlock<Packetc, 1>& acc1, PacketBlock<Packetc, 1>& acc2)
+{
+ acc1.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], p16uc_SETCOMPLEX32_FIRST);
+
+ acc2.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], p16uc_SETCOMPLEX32_SECOND);
+}
+
+template<typename Packet, typename Packetc>
+EIGEN_ALWAYS_INLINE void bcouple(PacketBlock<Packet,1>& taccReal, PacketBlock<Packet,1>& taccImag, PacketBlock<Packetc,2>& tRes, PacketBlock<Packetc, 1>& acc1, PacketBlock<Packetc, 1>& acc2)
+{
+ bcouple_common<Packet, Packetc>(taccReal, taccImag, acc1, acc2);
+
+ acc1.packet[0] = padd<Packetc>(tRes.packet[0], acc1.packet[0]);
+
+ acc2.packet[0] = padd<Packetc>(tRes.packet[1], acc2.packet[0]);
+}
+
+template<>
+EIGEN_ALWAYS_INLINE void bcouple_common<Packet2d, Packet1cd>(PacketBlock<Packet2d,4>& taccReal, PacketBlock<Packet2d,4>& taccImag, PacketBlock<Packet1cd, 4>& acc1, PacketBlock<Packet1cd, 4>& acc2)
+{
+ acc1.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], p16uc_SETCOMPLEX64_FIRST);
+ acc1.packet[1].v = vec_perm(taccReal.packet[1], taccImag.packet[1], p16uc_SETCOMPLEX64_FIRST);
+ acc1.packet[2].v = vec_perm(taccReal.packet[2], taccImag.packet[2], p16uc_SETCOMPLEX64_FIRST);
+ acc1.packet[3].v = vec_perm(taccReal.packet[3], taccImag.packet[3], p16uc_SETCOMPLEX64_FIRST);
+
+ acc2.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], p16uc_SETCOMPLEX64_SECOND);
+ acc2.packet[1].v = vec_perm(taccReal.packet[1], taccImag.packet[1], p16uc_SETCOMPLEX64_SECOND);
+ acc2.packet[2].v = vec_perm(taccReal.packet[2], taccImag.packet[2], p16uc_SETCOMPLEX64_SECOND);
+ acc2.packet[3].v = vec_perm(taccReal.packet[3], taccImag.packet[3], p16uc_SETCOMPLEX64_SECOND);
+}
+
+template<>
+EIGEN_ALWAYS_INLINE void bcouple_common<Packet2d, Packet1cd>(PacketBlock<Packet2d,1>& taccReal, PacketBlock<Packet2d,1>& taccImag, PacketBlock<Packet1cd, 1>& acc1, PacketBlock<Packet1cd, 1>& acc2)
+{
+ acc1.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], p16uc_SETCOMPLEX64_FIRST);
+
+ acc2.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], p16uc_SETCOMPLEX64_SECOND);
+}
+
+// Generic RHS load (the MMA code path provides its own ploadRhsMMA because, with MMA enabled, the
+// double-precision RHS is loaded as a pair of vectors).
+template<typename Scalar, typename Packet>
+EIGEN_ALWAYS_INLINE Packet ploadRhs(const Scalar* rhs)
+{
+ return ploadu<Packet>(rhs);
+}
+
+} // end namespace internal
+} // end namespace Eigen
diff --git a/Eigen/src/Core/arch/AltiVec/MatrixProductMMA.h b/Eigen/src/Core/arch/AltiVec/MatrixProductMMA.h
new file mode 100644
index 000000000..6540c6fa6
--- /dev/null
+++ b/Eigen/src/Core/arch/AltiVec/MatrixProductMMA.h
@@ -0,0 +1,629 @@
+// This file is part of Eigen, a lightweight C++ template library
+// for linear algebra.
+//
+// Copyright (C) 2020 Everton Constantino (everton.constantino@ibm.com)
+// Copyright (C) 2021 Chip Kerchner (chip.kerchner@ibm.com)
+//
+// This Source Code Form is subject to the terms of the Mozilla
+// Public License v. 2.0. If a copy of the MPL was not distributed
+// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+#ifndef EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H
+#define EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H
+
+#pragma GCC target("cpu=power10")
+
+#ifdef __has_builtin
+#if !__has_builtin(__builtin_vsx_assemble_pair)
+#define __builtin_vsx_assemble_pair __builtin_mma_assemble_pair
+#endif
+#endif
+
+namespace Eigen {
+
+namespace internal {
+
+template<typename Scalar, typename Packet>
+EIGEN_ALWAYS_INLINE void bsetzeroMMA(__vector_quad* acc)
+{
+ __builtin_mma_xxsetaccz(acc);
+}
+
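+// Disassemble the MMA accumulator, scale it by alpha, add the result block loaded from memory and store it back.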
+template<typename DataMapper, typename Index, typename Packet, const Index accCols>
+EIGEN_ALWAYS_INLINE void storeAccumulator(Index i, Index j, const DataMapper& data, const Packet& alpha, __vector_quad* acc)
+{
+ PacketBlock<Packet, 4> result;
+ __builtin_mma_disassemble_acc(&result.packet, acc);
+
+ PacketBlock<Packet, 4> tRes;
+ bload<DataMapper, Packet, Index, accCols, 0, ColMajor>(tRes, data, i, j);
+
+ bscale<Packet>(tRes, result, alpha);
+
+ data.template storePacketBlock<Packet, 4>(i, j, tRes);
+}
+
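+// Complex variant of storeAccumulator: scales the real and imaginary accumulators by the complex alpha
+// (bscalec), couples them back into complex packets while adding the existing result (bcouple), and
+// stores them as two packet blocks.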
+template<typename DataMapper, typename Index, typename Packet, typename Packetc, const Index accColsC, int N>
+EIGEN_ALWAYS_INLINE void storeComplexAccumulator(Index i, Index j, const DataMapper& data, const Packet& alphaReal, const Packet& alphaImag, __vector_quad* accReal, __vector_quad* accImag)
+{
+ PacketBlock<Packet, 4> resultReal, resultImag;
+ __builtin_mma_disassemble_acc(&resultReal.packet, accReal);
+ __builtin_mma_disassemble_acc(&resultImag.packet, accImag);
+
+ PacketBlock<Packetc, 8> tRes;
+ bload<DataMapper, Packetc, Index, accColsC, N, ColMajor>(tRes, data, i, j);
+
+ PacketBlock<Packet,4> taccReal, taccImag;
+ bscalec<Packet,4>(resultReal, resultImag, alphaReal, alphaImag, taccReal, taccImag);
+
+ PacketBlock<Packetc, 4> acc1, acc2;
+ bcouple<Packet, Packetc>(taccReal, taccImag, tRes, acc1, acc2);
+
+ data.template storePacketBlock<Packetc, 4>(i + N*accColsC, j, acc1);
+ data.template storePacketBlock<Packetc, 4>(i + (N+1)*accColsC, j, acc2);
+}
+
+// Defaults to float32; since Eigen still supports C++03 we can't use default template arguments.
+template<typename LhsPacket, typename RhsPacket, bool NegativeAccumulate>
+EIGEN_ALWAYS_INLINE void pgerMMA(__vector_quad* acc, const RhsPacket& a, const LhsPacket& b)
+{
+ if(NegativeAccumulate)
+ {
+ __builtin_mma_xvf32gernp(acc, (__vector unsigned char)a, (__vector unsigned char)b);
+ } else {
+ __builtin_mma_xvf32gerpp(acc, (__vector unsigned char)a, (__vector unsigned char)b);
+ }
+}
+
+template<typename LhsPacket, typename RhsPacket, bool NegativeAccumulate>
+EIGEN_ALWAYS_INLINE void pgerMMA(__vector_quad* acc, const PacketBlock<Packet2d,2>& a, const Packet2d& b)
+{
+ __vector_pair* a0 = (__vector_pair *)(&a.packet[0]);
+ if(NegativeAccumulate)
+ {
+ __builtin_mma_xvf64gernp(acc, *a0, (__vector unsigned char)b);
+ } else {
+ __builtin_mma_xvf64gerpp(acc, *a0, (__vector unsigned char)b);
+ }
+}
+
+template<typename LhsPacket, typename RhsPacket, bool NegativeAccumulate>
+EIGEN_ALWAYS_INLINE void pgerMMA(__vector_quad* acc, const __vector_pair& a, const Packet2d& b)
+{
+ if(NegativeAccumulate)
+ {
+ __builtin_mma_xvf64gernp(acc, (__vector_pair)a, (__vector unsigned char)b);
+ } else {
+ __builtin_mma_xvf64gerpp(acc, (__vector_pair)a, (__vector unsigned char)b);
+ }
+}
+
+template<typename LhsPacket, typename RhsPacket, bool NegativeAccumulate>
+EIGEN_ALWAYS_INLINE void pgerMMA(__vector_quad*, const __vector_pair&, const Packet4f&)
+{
+  // Empty stub: only needed so that the __vector_pair code path compiles for float; it is never called.
+}
+
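+// Complex rank-1 (GER) update: accumulates the real part into accReal and the imaginary part into
+// accImag from the four real outer products, with the signs dictated by the conjugation flags and by
+// whether the LHS/RHS are purely real.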
+template<typename Scalar, typename Packet, typename RhsPacket, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
+EIGEN_ALWAYS_INLINE void pgercMMA(__vector_quad* accReal, __vector_quad* accImag, const Packet& lhsV, const Packet& lhsVi, const RhsPacket& rhsV, const RhsPacket& rhsVi)
+{
+ pgerMMA<Packet, RhsPacket, false>(accReal, rhsV, lhsV);
+ if(LhsIsReal) {
+ pgerMMA<Packet, RhsPacket, ConjugateRhs>(accImag, rhsVi, lhsV);
+ } else {
+ if(!RhsIsReal) {
+ pgerMMA<Packet, RhsPacket, ConjugateLhs == ConjugateRhs>(accReal, rhsVi, lhsVi);
+ pgerMMA<Packet, RhsPacket, ConjugateRhs>(accImag, rhsVi, lhsV);
+ } else {
+ EIGEN_UNUSED_VARIABLE(rhsVi);
+ }
+ pgerMMA<Packet, RhsPacket, ConjugateLhs>(accImag, rhsV, lhsVi);
+ }
+}
+
+// These overloads are necessary because, with MMA enabled, the double-precision RHS is loaded as a
+// pair of vectors rather than a single packet.
+template<typename Scalar, typename Packet>
+EIGEN_ALWAYS_INLINE void ploadRhsMMA(const Scalar* rhs, Packet& rhsV)
+{
+ rhsV = ploadRhs<Scalar, Packet>((const Scalar*)(rhs));
+}
+
+template<>
+EIGEN_ALWAYS_INLINE void ploadRhsMMA<double, PacketBlock<Packet2d, 2> >(const double* rhs, PacketBlock<Packet2d, 2>& rhsV)
+{
+ rhsV.packet[0] = ploadRhs<double, Packet2d>((const double *)((Packet2d *)rhs ));
+ rhsV.packet[1] = ploadRhs<double, Packet2d>((const double *)(((Packet2d *)rhs) + 1));
+}
+
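+// Load the double-precision RHS as a 256-bit __vector_pair (assembled from two vector loads on clang,
+// a single lxvp instruction elsewhere).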
+template<>
+EIGEN_ALWAYS_INLINE void ploadRhsMMA<double, __vector_pair>(const double* rhs, __vector_pair& rhsV)
+{
+#if EIGEN_COMP_LLVM
+ __builtin_vsx_assemble_pair(&rhsV,
+ (__vector unsigned char)(ploadRhs<double, Packet2d>((const double *)(((Packet2d *)rhs) + 1))),
+ (__vector unsigned char)(ploadRhs<double, Packet2d>((const double *)((Packet2d *)rhs ))));
+#else
+ __asm__ ("lxvp %x0,%1" : "=wa" (rhsV) : "Y" (*rhs));
+#endif
+}
+
+template<>
+EIGEN_ALWAYS_INLINE void ploadRhsMMA(const float*, __vector_pair&)
+{
+  // Empty stub: only needed so that the __vector_pair code path compiles for float; it is never called.
+}
+
+// PEEL_MMA loop factor: number of k iterations unrolled (peeled) per pass of the main GEMM loop.
+#define PEEL_MMA 7
+
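+// The MICRO_MMA_* macros emit the register-unrolled body of the micro kernel for up to 8 row blocks;
+// iterations beyond the requested unroll_factor are compiled out by the `unroll_factor > iter` guards.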
+#define MICRO_MMA_UNROLL(func) \
+ func(0) func(1) func(2) func(3) func(4) func(5) func(6) func(7)
+
+#define MICRO_MMA_LOAD_ONE(iter) \
+ if (unroll_factor > iter) { \
+ lhsV##iter = ploadLhs<Scalar, Packet>(lhs_ptr##iter); \
+ lhs_ptr##iter += accCols; \
+ } else { \
+ EIGEN_UNUSED_VARIABLE(lhsV##iter); \
+ }
+
+#define MICRO_MMA_WORK_ONE(iter, type, peel) \
+ if (unroll_factor > iter) { \
+ pgerMMA<Packet, type, false>(&accZero##iter, rhsV##peel, lhsV##iter); \
+ }
+
+#define MICRO_MMA_TYPE_PEEL(func, func2, type, peel) \
+ if (PEEL_MMA > peel) { \
+ Packet lhsV0, lhsV1, lhsV2, lhsV3, lhsV4, lhsV5, lhsV6, lhsV7; \
+ ploadRhsMMA<Scalar, type>(rhs_ptr + (accRows * peel), rhsV##peel); \
+ MICRO_MMA_UNROLL(func2); \
+ func(0,type,peel) func(1,type,peel) func(2,type,peel) func(3,type,peel) \
+ func(4,type,peel) func(5,type,peel) func(6,type,peel) func(7,type,peel) \
+ } else { \
+ EIGEN_UNUSED_VARIABLE(rhsV##peel); \
+ }
+
+#define MICRO_MMA_UNROLL_TYPE_PEEL(func, func2, type) \
+ type rhsV0, rhsV1, rhsV2, rhsV3, rhsV4, rhsV5, rhsV6, rhsV7, rhsV8, rhsV9; \
+ MICRO_MMA_TYPE_PEEL(func,func2,type,0); MICRO_MMA_TYPE_PEEL(func,func2,type,1); \
+ MICRO_MMA_TYPE_PEEL(func,func2,type,2); MICRO_MMA_TYPE_PEEL(func,func2,type,3); \
+ MICRO_MMA_TYPE_PEEL(func,func2,type,4); MICRO_MMA_TYPE_PEEL(func,func2,type,5); \
+ MICRO_MMA_TYPE_PEEL(func,func2,type,6); MICRO_MMA_TYPE_PEEL(func,func2,type,7); \
+ MICRO_MMA_TYPE_PEEL(func,func2,type,8); MICRO_MMA_TYPE_PEEL(func,func2,type,9);
+
+#define MICRO_MMA_UNROLL_TYPE_ONE(func, func2, type) \
+ type rhsV0; \
+ MICRO_MMA_TYPE_PEEL(func,func2,type,0);
+
+#define MICRO_MMA_ONE_PEEL \
+ if (sizeof(Scalar) == sizeof(float)) { \
+ MICRO_MMA_UNROLL_TYPE_PEEL(MICRO_MMA_WORK_ONE, MICRO_MMA_LOAD_ONE, RhsPacket); \
+ } else { \
+ MICRO_MMA_UNROLL_TYPE_PEEL(MICRO_MMA_WORK_ONE, MICRO_MMA_LOAD_ONE, __vector_pair); \
+ } \
+ rhs_ptr += (accRows * PEEL_MMA);
+
+#define MICRO_MMA_ONE \
+ if (sizeof(Scalar) == sizeof(float)) { \
+ MICRO_MMA_UNROLL_TYPE_ONE(MICRO_MMA_WORK_ONE, MICRO_MMA_LOAD_ONE, RhsPacket); \
+ } else { \
+ MICRO_MMA_UNROLL_TYPE_ONE(MICRO_MMA_WORK_ONE, MICRO_MMA_LOAD_ONE, __vector_pair); \
+ } \
+ rhs_ptr += accRows;
+
+#define MICRO_MMA_DST_PTR_ONE(iter) \
+ if (unroll_factor > iter) { \
+ bsetzeroMMA<Scalar, Packet>(&accZero##iter); \
+ } else { \
+ EIGEN_UNUSED_VARIABLE(accZero##iter); \
+ }
+
+#define MICRO_MMA_DST_PTR MICRO_MMA_UNROLL(MICRO_MMA_DST_PTR_ONE)
+
+#define MICRO_MMA_SRC_PTR_ONE(iter) \
+ if (unroll_factor > iter) { \
+ lhs_ptr##iter = lhs_base + ( (row/accCols) + iter )*strideA*accCols + accCols*offsetA; \
+ } else { \
+ EIGEN_UNUSED_VARIABLE(lhs_ptr##iter); \
+ }
+
+#define MICRO_MMA_SRC_PTR MICRO_MMA_UNROLL(MICRO_MMA_SRC_PTR_ONE)
+
+#define MICRO_MMA_PREFETCH_ONE(iter) \
+ if (unroll_factor > iter) { \
+ EIGEN_POWER_PREFETCH(lhs_ptr##iter); \
+ }
+
+#define MICRO_MMA_PREFETCH MICRO_MMA_UNROLL(MICRO_MMA_PREFETCH_ONE)
+
+#define MICRO_MMA_STORE_ONE(iter) \
+ if (unroll_factor > iter) { \
+ storeAccumulator<DataMapper, Index, Packet, accCols>(row + iter*accCols, col, res, pAlpha, &accZero##iter); \
+ }
+
+#define MICRO_MMA_STORE MICRO_MMA_UNROLL(MICRO_MMA_STORE_ONE)
+
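+// One unrolled iteration of the MMA GEMM kernel: computes an (unroll_factor*accCols) x accRows block of
+// the result and advances `row` (passed by reference) past the rows it has handled.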
+template<int unroll_factor, typename Scalar, typename Packet, typename RhsPacket, typename DataMapper, typename Index, const Index accRows, const Index accCols>
+EIGEN_STRONG_INLINE void gemm_unrolled_MMA_iteration(
+ const DataMapper& res,
+ const Scalar* lhs_base,
+ const Scalar* rhs_base,
+ Index depth,
+ Index strideA,
+ Index offsetA,
+ Index& row,
+ Index col,
+ const Packet& pAlpha)
+{
+ const Scalar* rhs_ptr = rhs_base;
+ const Scalar* lhs_ptr0 = NULL, * lhs_ptr1 = NULL, * lhs_ptr2 = NULL, * lhs_ptr3 = NULL, * lhs_ptr4 = NULL, * lhs_ptr5 = NULL, * lhs_ptr6 = NULL, * lhs_ptr7 = NULL;
+ __vector_quad accZero0, accZero1, accZero2, accZero3, accZero4, accZero5, accZero6, accZero7;
+
+ MICRO_MMA_SRC_PTR
+ MICRO_MMA_DST_PTR
+
+ Index k = 0;
+ for(; k + PEEL_MMA <= depth; k+= PEEL_MMA)
+ {
+ EIGEN_POWER_PREFETCH(rhs_ptr);
+ MICRO_MMA_PREFETCH
+ MICRO_MMA_ONE_PEEL
+ }
+ for(; k < depth; k++)
+ {
+ MICRO_MMA_ONE
+ }
+ MICRO_MMA_STORE
+
+ row += unroll_factor*accCols;
+}
+
+template<typename Scalar, typename Index, typename Packet, typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols>
+void gemmMMA(const DataMapper& res, const Scalar* blockA, const Scalar* blockB, Index rows, Index depth, Index cols, Scalar alpha, Index strideA, Index strideB, Index offsetA, Index offsetB)
+{
+ const Index remaining_rows = rows % accCols;
+ const Index remaining_cols = cols % accRows;
+
+ if( strideA == -1 ) strideA = depth;
+ if( strideB == -1 ) strideB = depth;
+
+ const Packet pAlpha = pset1<Packet>(alpha);
+ const Packet pMask = bmask<Packet>((const int)(remaining_rows));
+
+ Index col = 0;
+ for(; col + accRows <= cols; col += accRows)
+ {
+ const Scalar* rhs_base = blockB + col*strideB + accRows*offsetB;
+ const Scalar* lhs_base = blockA;
+
+ Index row = 0;
+#define MAX_MMA_UNROLL 7
+ while(row + MAX_MMA_UNROLL*accCols <= rows) {
+ gemm_unrolled_MMA_iteration<MAX_MMA_UNROLL, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
+ }
+ switch( (rows-row)/accCols ) {
+#if MAX_MMA_UNROLL > 7
+ case 7:
+ gemm_unrolled_MMA_iteration<7, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
+ break;
+#endif
+#if MAX_MMA_UNROLL > 6
+ case 6:
+ gemm_unrolled_MMA_iteration<6, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
+ break;
+#endif
+#if MAX_MMA_UNROLL > 5
+ case 5:
+ gemm_unrolled_MMA_iteration<5, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
+ break;
+#endif
+#if MAX_MMA_UNROLL > 4
+ case 4:
+ gemm_unrolled_MMA_iteration<4, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
+ break;
+#endif
+#if MAX_MMA_UNROLL > 3
+ case 3:
+ gemm_unrolled_MMA_iteration<3, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
+ break;
+#endif
+#if MAX_MMA_UNROLL > 2
+ case 2:
+ gemm_unrolled_MMA_iteration<2, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
+ break;
+#endif
+#if MAX_MMA_UNROLL > 1
+ case 1:
+ gemm_unrolled_MMA_iteration<1, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
+ break;
+#endif
+ default:
+ break;
+ }
+#undef MAX_MMA_UNROLL
+
+ if(remaining_rows > 0)
+ {
+ gemm_extra_row<Scalar, Packet, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, rows, cols, remaining_rows, pAlpha, pMask);
+ }
+ }
+
+ if(remaining_cols > 0)
+ {
+ const Scalar* rhs_base = blockB + col*strideB + remaining_cols*offsetB;
+ const Scalar* lhs_base = blockA;
+
+ for(; col < cols; col++)
+ {
+ Index row = 0;
+
+ gemm_unrolled_col<Scalar, Packet, DataMapper, Index, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, rows, col, remaining_cols, pAlpha);
+
+ if (remaining_rows > 0)
+ {
+ gemm_extra_col<Scalar, Packet, DataMapper, Index, accRows>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_rows, remaining_cols, pAlpha);
+ }
+ rhs_base++;
+ }
+ }
+}
+
+#define accColsC (accCols / 2)
+#define advanceRows ((LhsIsReal) ? 1 : 2)
+#define advanceCols ((RhsIsReal) ? 1 : 2)
+
+// PEEL_COMPLEX_MMA loop factor: number of k iterations unrolled (peeled) per pass of the complex GEMM loop.
+#define PEEL_COMPLEX_MMA 7
+
+#define MICRO_COMPLEX_MMA_UNROLL(func) \
+ func(0) func(1) func(2) func(3) func(4)
+
+#define MICRO_COMPLEX_MMA_LOAD_ONE(iter) \
+ if (unroll_factor > iter) { \
+ lhsV##iter = ploadLhs<Scalar, Packet>(lhs_ptr_real##iter); \
+ lhs_ptr_real##iter += accCols; \
+ if(!LhsIsReal) { \
+ lhsVi##iter = ploadLhs<Scalar, Packet>(lhs_ptr_imag##iter); \
+ lhs_ptr_imag##iter += accCols; \
+ } else { \
+ EIGEN_UNUSED_VARIABLE(lhsVi##iter); \
+ } \
+ } else { \
+ EIGEN_UNUSED_VARIABLE(lhsV##iter); \
+ EIGEN_UNUSED_VARIABLE(lhsVi##iter); \
+ }
+
+#define MICRO_COMPLEX_MMA_WORK_ONE(iter, type, peel) \
+ if (unroll_factor > iter) { \
+ pgercMMA<Scalar, Packet, type, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal##iter, &accImag##iter, lhsV##iter, lhsVi##iter, rhsV##peel, rhsVi##peel); \
+ }
+
+#define MICRO_COMPLEX_MMA_TYPE_PEEL(func, func2, type, peel) \
+ if (PEEL_COMPLEX_MMA > peel) { \
+ Packet lhsV0, lhsV1, lhsV2, lhsV3, lhsV4; \
+ Packet lhsVi0, lhsVi1, lhsVi2, lhsVi3, lhsVi4; \
+ ploadRhsMMA<Scalar, type>(rhs_ptr_real + (accRows * peel), rhsV##peel); \
+ if(!RhsIsReal) { \
+ ploadRhsMMA<Scalar, type>(rhs_ptr_imag + (accRows * peel), rhsVi##peel); \
+ } else { \
+ EIGEN_UNUSED_VARIABLE(rhsVi##peel); \
+ } \
+ MICRO_COMPLEX_MMA_UNROLL(func2); \
+ func(0,type,peel) func(1,type,peel) func(2,type,peel) func(3,type,peel) func(4,type,peel) \
+ } else { \
+ EIGEN_UNUSED_VARIABLE(rhsV##peel); \
+ EIGEN_UNUSED_VARIABLE(rhsVi##peel); \
+ }
+
+#define MICRO_COMPLEX_MMA_UNROLL_TYPE_PEEL(func, func2, type) \
+ type rhsV0, rhsV1, rhsV2, rhsV3, rhsV4, rhsV5, rhsV6, rhsV7, rhsV8, rhsV9; \
+ type rhsVi0, rhsVi1, rhsVi2, rhsVi3, rhsVi4, rhsVi5, rhsVi6, rhsVi7, rhsVi8, rhsVi9; \
+ MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,0); MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,1); \
+ MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,2); MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,3); \
+ MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,4); MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,5); \
+ MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,6); MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,7); \
+ MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,8); MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,9);
+
+#define MICRO_COMPLEX_MMA_UNROLL_TYPE_ONE(func, func2, type) \
+ type rhsV0, rhsVi0; \
+ MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,0);
+
+#define MICRO_COMPLEX_MMA_ONE_PEEL \
+ if (sizeof(Scalar) == sizeof(float)) { \
+ MICRO_COMPLEX_MMA_UNROLL_TYPE_PEEL(MICRO_COMPLEX_MMA_WORK_ONE, MICRO_COMPLEX_MMA_LOAD_ONE, RhsPacket); \
+ } else { \
+ MICRO_COMPLEX_MMA_UNROLL_TYPE_PEEL(MICRO_COMPLEX_MMA_WORK_ONE, MICRO_COMPLEX_MMA_LOAD_ONE, __vector_pair); \
+ } \
+ rhs_ptr_real += (accRows * PEEL_COMPLEX_MMA); \
+ if(!RhsIsReal) rhs_ptr_imag += (accRows * PEEL_COMPLEX_MMA);
+
+#define MICRO_COMPLEX_MMA_ONE \
+ if (sizeof(Scalar) == sizeof(float)) { \
+ MICRO_COMPLEX_MMA_UNROLL_TYPE_ONE(MICRO_COMPLEX_MMA_WORK_ONE, MICRO_COMPLEX_MMA_LOAD_ONE, RhsPacket); \
+ } else { \
+ MICRO_COMPLEX_MMA_UNROLL_TYPE_ONE(MICRO_COMPLEX_MMA_WORK_ONE, MICRO_COMPLEX_MMA_LOAD_ONE, __vector_pair); \
+ } \
+ rhs_ptr_real += accRows; \
+ if(!RhsIsReal) rhs_ptr_imag += accRows;
+
+#define MICRO_COMPLEX_MMA_DST_PTR_ONE(iter) \
+ if (unroll_factor > iter) { \
+ bsetzeroMMA<Scalar, Packet>(&accReal##iter); \
+ bsetzeroMMA<Scalar, Packet>(&accImag##iter); \
+ } else { \
+ EIGEN_UNUSED_VARIABLE(accReal##iter); \
+ EIGEN_UNUSED_VARIABLE(accImag##iter); \
+ }
+
+#define MICRO_COMPLEX_MMA_DST_PTR MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_MMA_DST_PTR_ONE)
+
+#define MICRO_COMPLEX_MMA_SRC_PTR_ONE(iter) \
+ if (unroll_factor > iter) { \
+ lhs_ptr_real##iter = lhs_base + ( ((advanceRows*row)/accCols) + iter*advanceRows )*strideA*accCols + accCols*offsetA; \
+ if(!LhsIsReal) { \
+ lhs_ptr_imag##iter = lhs_ptr_real##iter + accCols*strideA; \
+ } else { \
+ EIGEN_UNUSED_VARIABLE(lhs_ptr_imag##iter); \
+ } \
+ } else { \
+ EIGEN_UNUSED_VARIABLE(lhs_ptr_real##iter); \
+ EIGEN_UNUSED_VARIABLE(lhs_ptr_imag##iter); \
+ }
+
+#define MICRO_COMPLEX_MMA_SRC_PTR MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_MMA_SRC_PTR_ONE)
+
+#define MICRO_COMPLEX_MMA_PREFETCH_ONE(iter) \
+ if (unroll_factor > iter) { \
+ EIGEN_POWER_PREFETCH(lhs_ptr_real##iter); \
+ if(!LhsIsReal) { \
+ EIGEN_POWER_PREFETCH(lhs_ptr_imag##iter); \
+ } \
+ }
+
+#define MICRO_COMPLEX_MMA_PREFETCH MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_MMA_PREFETCH_ONE)
+
+#define MICRO_COMPLEX_MMA_STORE_ONE(iter) \
+ if (unroll_factor > iter) { \
+ storeComplexAccumulator<DataMapper, Index, Packet, Packetc, accColsC, 0>(row + iter*accCols, col, res, pAlphaReal, pAlphaImag, &accReal##iter, &accImag##iter); \
+ }
+
+#define MICRO_COMPLEX_MMA_STORE MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_MMA_STORE_ONE)
+
+template<int unroll_factor, typename Scalar, typename Packet, typename Packetc, typename RhsPacket, typename DataMapper, typename Index, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
+EIGEN_STRONG_INLINE void gemm_complex_unrolled_MMA_iteration(
+ const DataMapper& res,
+ const Scalar* lhs_base,
+ const Scalar* rhs_base,
+ Index depth,
+ Index strideA,
+ Index offsetA,
+ Index strideB,
+ Index& row,
+ Index col,
+ const Packet& pAlphaReal,
+ const Packet& pAlphaImag)
+{
+ const Scalar* rhs_ptr_real = rhs_base;
+ const Scalar* rhs_ptr_imag;
+ if(!RhsIsReal) {
+ rhs_ptr_imag = rhs_base + accRows*strideB;
+ } else {
+ EIGEN_UNUSED_VARIABLE(rhs_ptr_imag);
+ }
+ const Scalar* lhs_ptr_real0 = NULL, * lhs_ptr_imag0 = NULL, * lhs_ptr_real1 = NULL, * lhs_ptr_imag1 = NULL;
+ const Scalar* lhs_ptr_real2 = NULL, * lhs_ptr_imag2 = NULL, * lhs_ptr_real3 = NULL, * lhs_ptr_imag3 = NULL;
+ const Scalar* lhs_ptr_real4 = NULL, * lhs_ptr_imag4 = NULL;
+ __vector_quad accReal0, accImag0, accReal1, accImag1, accReal2, accImag2, accReal3, accImag3, accReal4, accImag4;
+
+ MICRO_COMPLEX_MMA_SRC_PTR
+ MICRO_COMPLEX_MMA_DST_PTR
+
+ Index k = 0;
+ for(; k + PEEL_COMPLEX_MMA <= depth; k+= PEEL_COMPLEX_MMA)
+ {
+ EIGEN_POWER_PREFETCH(rhs_ptr_real);
+ if(!RhsIsReal) {
+ EIGEN_POWER_PREFETCH(rhs_ptr_imag);
+ }
+ MICRO_COMPLEX_MMA_PREFETCH
+ MICRO_COMPLEX_MMA_ONE_PEEL
+ }
+ for(; k < depth; k++)
+ {
+ MICRO_COMPLEX_MMA_ONE
+ }
+ MICRO_COMPLEX_MMA_STORE
+
+ row += unroll_factor*accCols;
+}
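For orientation, the depth loop in gemm_complex_unrolled_MMA_iteration above follows the usual peel-then-remainder pattern: the main loop consumes PEEL_COMPLEX_MMA (7) depth steps per trip, fully unrolled through the MICRO_COMPLEX_MMA_* macros, and a second loop mops up the depth % 7 leftover steps. A minimal scalar sketch of that control flow (hypothetical names, not part of the patch; the real kernel operates on whole packets and MMA accumulators):

    void peeled_accumulate(const float* lhs, const float* rhs, float& acc, int depth)
    {
      const int PEEL = 7;                      // mirrors PEEL_COMPLEX_MMA
      int k = 0;
      for(; k + PEEL <= depth; k += PEEL) {    // main loop: PEEL fully unrolled steps
        for(int p = 0; p < PEEL; p++)
          acc += lhs[k + p] * rhs[k + p];      // the macros expand this body PEEL times
      }
      for(; k < depth; k++)                    // remainder loop: depth % PEEL steps
        acc += lhs[k] * rhs[k];
    }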
+
+template<typename LhsScalar, typename RhsScalar, typename Scalarc, typename Scalar, typename Index, typename Packet, typename Packetc, typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
+void gemm_complexMMA(const DataMapper& res, const LhsScalar* blockAc, const RhsScalar* blockBc, Index rows, Index depth, Index cols, Scalarc alpha, Index strideA, Index strideB, Index offsetA, Index offsetB)
+{
+ const Index remaining_rows = rows % accCols;
+ const Index remaining_cols = cols % accRows;
+
+ if( strideA == -1 ) strideA = depth;
+ if( strideB == -1 ) strideB = depth;
+
+ const Packet pAlphaReal = pset1<Packet>(alpha.real());
+ const Packet pAlphaImag = pset1<Packet>(alpha.imag());
+ const Packet pMask = bmask<Packet>((const int)(remaining_rows));
+
+ const Scalar* blockA = (Scalar *) blockAc;
+ const Scalar* blockB = (Scalar *) blockBc;
+
+ Index col = 0;
+ for(; col + accRows <= cols; col += accRows)
+ {
+ const Scalar* rhs_base = blockB + advanceCols*col*strideB + accRows*offsetB;
+ const Scalar* lhs_base = blockA;
+ Index row = 0;
+
+#define MAX_COMPLEX_MMA_UNROLL 4
+ while(row + MAX_COMPLEX_MMA_UNROLL*accCols <= rows) {
+ gemm_complex_unrolled_MMA_iteration<MAX_COMPLEX_MMA_UNROLL, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag);
+ }
+ switch( (rows-row)/accCols ) {
+#if MAX_COMPLEX_MMA_UNROLL > 4
+ case 4:
+ gemm_complex_unrolled_MMA_iteration<4, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag);
+ break;
+#endif
+#if MAX_COMPLEX_MMA_UNROLL > 3
+ case 3:
+ gemm_complex_unrolled_MMA_iteration<3, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag);
+ break;
+#endif
+#if MAX_COMPLEX_MMA_UNROLL > 2
+ case 2:
+ gemm_complex_unrolled_MMA_iteration<2, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag);
+ break;
+#endif
+#if MAX_COMPLEX_MMA_UNROLL > 1
+ case 1:
+ gemm_complex_unrolled_MMA_iteration<1, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag);
+ break;
+#endif
+ default:
+ break;
+ }
+#undef MAX_COMPLEX_MMA_UNROLL
+
+ if(remaining_rows > 0)
+ {
+ gemm_complex_extra_row<Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, rows, cols, remaining_rows, pAlphaReal, pAlphaImag, pMask);
+ }
+ }
+
+ if(remaining_cols > 0)
+ {
+ const Scalar* rhs_base = blockB + advanceCols*col*strideB + remaining_cols*offsetB;
+ const Scalar* lhs_base = blockA;
+
+ for(; col < cols; col++)
+ {
+ Index row = 0;
+
+ gemm_complex_unrolled_col<Scalar, Packet, Packetc, DataMapper, Index, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, rows, col, remaining_cols, pAlphaReal, pAlphaImag);
+
+ if (remaining_rows > 0)
+ {
+ gemm_complex_extra_col<Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, remaining_rows, remaining_cols, pAlphaReal, pAlphaImag);
+ }
+ rhs_base++;
+ }
+ }
+}
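The row handling in gemm_complexMMA mirrors the real-valued gemmMMA shown earlier: rows are first consumed in blocks of MAX_COMPLEX_MMA_UNROLL*accCols with the widest micro-kernel, the switch on (rows-row)/accCols then dispatches one narrower template instantiation for any remaining full blocks, and the leftover rows % accCols rows fall through to the gemm_complex_extra_row path. A compact sketch of that dispatch (hypothetical scalar stand-ins, for illustration only):

    void process_rows(Index rows, Index accCols)
    {
      const Index MAX_UNROLL = 4;                  // mirrors MAX_COMPLEX_MMA_UNROLL
      Index row = 0;
      while(row + MAX_UNROLL*accCols <= rows)
        row += MAX_UNROLL*accCols;                 // gemm_complex_unrolled_MMA_iteration<4>(...)
      switch( (rows - row) / accCols ) {
        case 3: row += 3*accCols; break;           // ...<3>(...)
        case 2: row += 2*accCols; break;           // ...<2>(...)
        case 1: row += 1*accCols; break;           // ...<1>(...)
        default: break;
      }
      // rows % accCols leftover rows are handled by gemm_complex_extra_row.
    }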
+
+#undef accColsC
+#undef advanceRows
+#undef advanceCols
+
+#pragma GCC reset_options
+} // end namespace internal
+
+} // end namespace Eigen
+
+#endif // EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H
+
diff --git a/Eigen/src/Core/arch/AltiVec/PacketMath.h b/Eigen/src/Core/arch/AltiVec/PacketMath.h
index b3f1ea199..2a440545b 100755
--- a/Eigen/src/Core/arch/AltiVec/PacketMath.h
+++ b/Eigen/src/Core/arch/AltiVec/PacketMath.h
@@ -22,31 +22,38 @@ namespace internal {
#define EIGEN_HAS_SINGLE_INSTRUCTION_MADD
#endif
-#ifndef EIGEN_HAS_SINGLE_INSTRUCTION_CJMADD
-#define EIGEN_HAS_SINGLE_INSTRUCTION_CJMADD
-#endif
-
// NOTE Altivec has 32 registers, but Eigen only accepts a value of 8 or 16
#ifndef EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS
#define EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS 32
#endif
-typedef __vector float Packet4f;
-typedef __vector int Packet4i;
-typedef __vector unsigned int Packet4ui;
-typedef __vector __bool int Packet4bi;
-typedef __vector short int Packet8i;
-typedef __vector unsigned char Packet16uc;
+typedef __vector float Packet4f;
+typedef __vector int Packet4i;
+typedef __vector unsigned int Packet4ui;
+typedef __vector __bool int Packet4bi;
+typedef __vector short int Packet8s;
+typedef __vector unsigned short int Packet8us;
+typedef __vector signed char Packet16c;
+typedef __vector unsigned char Packet16uc;
+typedef eigen_packet_wrapper<__vector unsigned short int,0> Packet8bf;
// We don't want to write the same code all the time, but we need to reuse the constants
// and it doesn't really work to declare them global, so we define macros instead
-
#define _EIGEN_DECLARE_CONST_FAST_Packet4f(NAME,X) \
- Packet4f p4f_##NAME = reinterpret_cast<Packet4f>(vec_splat_s32(X))
+ Packet4f p4f_##NAME = {X, X, X, X}
#define _EIGEN_DECLARE_CONST_FAST_Packet4i(NAME,X) \
Packet4i p4i_##NAME = vec_splat_s32(X)
+#define _EIGEN_DECLARE_CONST_FAST_Packet4ui(NAME,X) \
+ Packet4ui p4ui_##NAME = {X, X, X, X}
+
+#define _EIGEN_DECLARE_CONST_FAST_Packet8us(NAME,X) \
+ Packet8us p8us_##NAME = {X, X, X, X, X, X, X, X}
+
+#define _EIGEN_DECLARE_CONST_FAST_Packet16uc(NAME,X) \
+ Packet16uc p16uc_##NAME = {X, X, X, X, X, X, X, X, X, X, X, X, X, X, X, X}
+
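For reference, the _EIGEN_DECLARE_CONST_FAST_* helpers above simply expand to a brace-initialized vector variable following a fixed p<type>_<NAME> naming scheme, e.g.:

    // _EIGEN_DECLARE_CONST_FAST_Packet4ui(SIGN, 0x80000000u) expands to:
    Packet4ui p4ui_SIGN = {0x80000000u, 0x80000000u, 0x80000000u, 0x80000000u};
    // _EIGEN_DECLARE_CONST_FAST_Packet8us(ONE, 1) expands to:
    Packet8us p8us_ONE = {1, 1, 1, 1, 1, 1, 1, 1};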
#define _EIGEN_DECLARE_CONST_Packet4f(NAME,X) \
Packet4f p4f_##NAME = pset1<Packet4f>(X)
@@ -64,7 +71,7 @@ typedef __vector unsigned char Packet16uc;
#define DST_CHAN 1
#define DST_CTRL(size, count, stride) (((size) << 24) | ((count) << 16) | (stride))
-
+#define __UNPACK_TYPE__(PACKETNAME) typename unpacket_traits<PACKETNAME>::type
// These constants are endian-agnostic
static _EIGEN_DECLARE_CONST_FAST_Packet4f(ZERO, 0); //{ 0.0, 0.0, 0.0, 0.0}
@@ -72,25 +79,36 @@ static _EIGEN_DECLARE_CONST_FAST_Packet4i(ZERO, 0); //{ 0, 0, 0, 0,}
static _EIGEN_DECLARE_CONST_FAST_Packet4i(ONE,1); //{ 1, 1, 1, 1}
static _EIGEN_DECLARE_CONST_FAST_Packet4i(MINUS16,-16); //{ -16, -16, -16, -16}
static _EIGEN_DECLARE_CONST_FAST_Packet4i(MINUS1,-1); //{ -1, -1, -1, -1}
+static _EIGEN_DECLARE_CONST_FAST_Packet4ui(SIGN, 0x80000000u);
+static _EIGEN_DECLARE_CONST_FAST_Packet4ui(PREV0DOT5, 0x3EFFFFFFu);
+static _EIGEN_DECLARE_CONST_FAST_Packet8us(ONE,1); //{ 1, 1, 1, 1, 1, 1, 1, 1}
+static _EIGEN_DECLARE_CONST_FAST_Packet16uc(ONE,1);
static Packet4f p4f_MZERO = (Packet4f) vec_sl((Packet4ui)p4i_MINUS1, (Packet4ui)p4i_MINUS1); //{ 0x80000000, 0x80000000, 0x80000000, 0x80000000}
#ifndef __VSX__
static Packet4f p4f_ONE = vec_ctf(p4i_ONE, 0); //{ 1.0, 1.0, 1.0, 1.0}
#endif
-static Packet4f p4f_COUNTDOWN = { 0.0, 1.0, 2.0, 3.0 };
-static Packet4i p4i_COUNTDOWN = { 0, 1, 2, 3 };
+static Packet4f p4f_COUNTDOWN = { 0.0, 1.0, 2.0, 3.0 };
+static Packet4i p4i_COUNTDOWN = { 0, 1, 2, 3 };
+static Packet8s p8s_COUNTDOWN = { 0, 1, 2, 3, 4, 5, 6, 7 };
+static Packet8us p8us_COUNTDOWN = { 0, 1, 2, 3, 4, 5, 6, 7 };
+
+static Packet16c p16c_COUNTDOWN = { 0, 1, 2, 3, 4, 5, 6, 7,
+ 8, 9, 10, 11, 12, 13, 14, 15};
+static Packet16uc p16uc_COUNTDOWN = { 0, 1, 2, 3, 4, 5, 6, 7,
+ 8, 9, 10, 11, 12, 13, 14, 15};
static Packet16uc p16uc_REVERSE32 = { 12,13,14,15, 8,9,10,11, 4,5,6,7, 0,1,2,3 };
-static Packet16uc p16uc_DUPLICATE32_HI = { 0,1,2,3, 0,1,2,3, 4,5,6,7, 4,5,6,7 };
+static Packet16uc p16uc_REVERSE16 = { 14,15, 12,13, 10,11, 8,9, 6,7, 4,5, 2,3, 0,1 };
+static Packet16uc p16uc_REVERSE8 = { 15,14,13,12,11,10,9,8,7,6,5,4,3,2,1,0 };
-// Mask alignment
-#ifdef __PPC64__
-#define _EIGEN_MASK_ALIGNMENT 0xfffffffffffffff0
-#else
-#define _EIGEN_MASK_ALIGNMENT 0xfffffff0
-#endif
+static Packet16uc p16uc_DUPLICATE32_HI = { 0,1,2,3, 0,1,2,3, 4,5,6,7, 4,5,6,7 };
+static Packet16uc p16uc_DUPLICATE16_HI = { 0,1,0,1, 2,3,2,3, 4,5,4,5, 6,7,6,7 };
+static Packet16uc p16uc_DUPLICATE8_HI = { 0,0, 1,1, 2,2, 3,3, 4,4, 5,5, 6,6, 7,7 };
+static const Packet16uc p16uc_DUPLICATE16_EVEN= { 0,1 ,0,1, 4,5, 4,5, 8,9, 8,9, 12,13, 12,13 };
+static const Packet16uc p16uc_DUPLICATE16_ODD = { 2,3 ,2,3, 6,7, 6,7, 10,11, 10,11, 14,15, 14,15 };
-#define _EIGEN_ALIGNED_PTR(x) ((std::ptrdiff_t)(x) & _EIGEN_MASK_ALIGNMENT)
+static Packet16uc p16uc_QUADRUPLICATE16_HI = { 0,1,0,1,0,1,0,1, 2,3,2,3,2,3,2,3 };
// Handle endianness properly while loading constants
// Define global static constants:
@@ -103,7 +121,7 @@ static Packet16uc p16uc_PSET32_WODD = vec_sld((Packet16uc) vec_splat((Packet4u
static Packet16uc p16uc_PSET32_WEVEN = vec_sld(p16uc_DUPLICATE32_HI, (Packet16uc) vec_splat((Packet4ui)p16uc_FORWARD, 3), 8);//{ 4,5,6,7, 4,5,6,7, 12,13,14,15, 12,13,14,15 };
static Packet16uc p16uc_HALF64_0_16 = vec_sld((Packet16uc)p4i_ZERO, vec_splat((Packet16uc) vec_abs(p4i_MINUS16), 3), 8); //{ 0,0,0,0, 0,0,0,0, 16,16,16,16, 16,16,16,16};
#else
-static Packet16uc p16uc_FORWARD = p16uc_REVERSE32;
+static Packet16uc p16uc_FORWARD = p16uc_REVERSE32;
static Packet16uc p16uc_REVERSE64 = { 8,9,10,11, 12,13,14,15, 0,1,2,3, 4,5,6,7 };
static Packet16uc p16uc_PSET32_WODD = vec_sld((Packet16uc) vec_splat((Packet4ui)p16uc_FORWARD, 1), (Packet16uc) vec_splat((Packet4ui)p16uc_FORWARD, 3), 8);//{ 0,1,2,3, 0,1,2,3, 8,9,10,11, 8,9,10,11 };
static Packet16uc p16uc_PSET32_WEVEN = vec_sld((Packet16uc) vec_splat((Packet4ui)p16uc_FORWARD, 0), (Packet16uc) vec_splat((Packet4ui)p16uc_FORWARD, 2), 8);//{ 4,5,6,7, 4,5,6,7, 12,13,14,15, 12,13,14,15 };
@@ -129,27 +147,27 @@ static Packet16uc p16uc_COMPLEX32_REV2 = vec_sld(p16uc_PSET64_HI, p16uc_PSET64_L
#define EIGEN_PPC_PREFETCH(ADDR) asm( " dcbt [%[addr]]\n" :: [addr] "r" (ADDR) : "cc" );
#endif
-template<> struct packet_traits<float> : default_packet_traits
-{
+template <>
+struct packet_traits<float> : default_packet_traits {
typedef Packet4f type;
typedef Packet4f half;
enum {
Vectorizable = 1,
AlignedOnScalar = 1,
- size=4,
+ size = 4,
HasHalfPacket = 1,
- HasAdd = 1,
- HasSub = 1,
- HasMul = 1,
- HasDiv = 1,
- HasMin = 1,
- HasMax = 1,
- HasAbs = 1,
- HasSin = 0,
- HasCos = 0,
- HasLog = 0,
- HasExp = 1,
+ HasAdd = 1,
+ HasSub = 1,
+ HasMul = 1,
+ HasDiv = 1,
+ HasMin = 1,
+ HasMax = 1,
+ HasAbs = 1,
+ HasSin = EIGEN_FAST_MATH,
+ HasCos = EIGEN_FAST_MATH,
+ HasLog = 1,
+ HasExp = 1,
#ifdef __VSX__
HasSqrt = 1,
#if !EIGEN_COMP_CLANG
@@ -160,16 +178,62 @@ template<> struct packet_traits<float> : default_packet_traits
#else
HasSqrt = 0,
HasRsqrt = 0,
+ HasTanh = EIGEN_FAST_MATH,
+ HasErf = EIGEN_FAST_MATH,
#endif
HasRound = 1,
HasFloor = 1,
HasCeil = 1,
+ HasRint = 1,
HasNegate = 1,
HasBlend = 1
};
};
-template<> struct packet_traits<int> : default_packet_traits
-{
+template <>
+struct packet_traits<bfloat16> : default_packet_traits {
+ typedef Packet8bf type;
+ typedef Packet8bf half;
+ enum {
+ Vectorizable = 1,
+ AlignedOnScalar = 1,
+ size = 8,
+ HasHalfPacket = 0,
+
+ HasAdd = 1,
+ HasSub = 1,
+ HasMul = 1,
+ HasDiv = 1,
+ HasMin = 1,
+ HasMax = 1,
+ HasAbs = 1,
+ HasSin = EIGEN_FAST_MATH,
+ HasCos = EIGEN_FAST_MATH,
+ HasLog = 1,
+ HasExp = 1,
+#ifdef __VSX__
+ HasSqrt = 1,
+#if !EIGEN_COMP_CLANG
+ HasRsqrt = 1,
+#else
+ HasRsqrt = 0,
+#endif
+#else
+ HasSqrt = 0,
+ HasRsqrt = 0,
+ HasTanh = EIGEN_FAST_MATH,
+ HasErf = EIGEN_FAST_MATH,
+#endif
+ HasRound = 1,
+ HasFloor = 1,
+ HasCeil = 1,
+ HasRint = 1,
+ HasNegate = 1,
+ HasBlend = 1
+ };
+};
+
+template <>
+struct packet_traits<int> : default_packet_traits {
typedef Packet4i type;
typedef Packet4i half;
enum {
@@ -178,6 +242,25 @@ template<> struct packet_traits<int> : default_packet_traits
size = 4,
HasHalfPacket = 0,
+ HasAdd = 1,
+ HasSub = 1,
+ HasShift = 1,
+ HasMul = 1,
+ HasDiv = 0,
+ HasBlend = 1
+ };
+};
+
+template <>
+struct packet_traits<short int> : default_packet_traits {
+ typedef Packet8s type;
+ typedef Packet8s half;
+ enum {
+ Vectorizable = 1,
+ AlignedOnScalar = 1,
+ size = 8,
+ HasHalfPacket = 0,
+
HasAdd = 1,
HasSub = 1,
HasMul = 1,
@@ -186,9 +269,116 @@ template<> struct packet_traits<int> : default_packet_traits
};
};
+template <>
+struct packet_traits<unsigned short int> : default_packet_traits {
+ typedef Packet8us type;
+ typedef Packet8us half;
+ enum {
+ Vectorizable = 1,
+ AlignedOnScalar = 1,
+ size = 8,
+ HasHalfPacket = 0,
+
+ HasAdd = 1,
+ HasSub = 1,
+ HasMul = 1,
+ HasDiv = 0,
+ HasBlend = 1
+ };
+};
+
+template <>
+struct packet_traits<signed char> : default_packet_traits {
+ typedef Packet16c type;
+ typedef Packet16c half;
+ enum {
+ Vectorizable = 1,
+ AlignedOnScalar = 1,
+ size = 16,
+ HasHalfPacket = 0,
+
+ HasAdd = 1,
+ HasSub = 1,
+ HasMul = 1,
+ HasDiv = 0,
+ HasBlend = 1
+ };
+};
+
+template <>
+struct packet_traits<unsigned char> : default_packet_traits {
+ typedef Packet16uc type;
+ typedef Packet16uc half;
+ enum {
+ Vectorizable = 1,
+ AlignedOnScalar = 1,
+ size = 16,
+ HasHalfPacket = 0,
+
+ HasAdd = 1,
+ HasSub = 1,
+ HasMul = 1,
+ HasDiv = 0,
+ HasBlend = 1
+ };
+};
+
+template<> struct unpacket_traits<Packet4f>
+{
+ typedef float type;
+ typedef Packet4f half;
+ typedef Packet4i integer_packet;
+ enum {size=4, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false};
+};
+template<> struct unpacket_traits<Packet4i>
+{
+ typedef int type;
+ typedef Packet4i half;
+ enum {size=4, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false};
+};
+template<> struct unpacket_traits<Packet8s>
+{
+ typedef short int type;
+ typedef Packet8s half;
+ enum {size=8, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false};
+};
+template<> struct unpacket_traits<Packet8us>
+{
+ typedef unsigned short int type;
+ typedef Packet8us half;
+ enum {size=8, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false};
+};
+
+template<> struct unpacket_traits<Packet16c>
+{
+ typedef signed char type;
+ typedef Packet16c half;
+ enum {size=16, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false};
+};
+template<> struct unpacket_traits<Packet16uc>
+{
+ typedef unsigned char type;
+ typedef Packet16uc half;
+ enum {size=16, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false};
+};
-template<> struct unpacket_traits<Packet4f> { typedef float type; enum {size=4, alignment=Aligned16}; typedef Packet4f half; };
-template<> struct unpacket_traits<Packet4i> { typedef int type; enum {size=4, alignment=Aligned16}; typedef Packet4i half; };
+template<> struct unpacket_traits<Packet8bf>
+{
+ typedef bfloat16 type;
+ typedef Packet8bf half;
+ enum {size=8, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false};
+};
+inline std::ostream & operator <<(std::ostream & s, const Packet16c & v)
+{
+ union {
+ Packet16c v;
+ signed char n[16];
+ } vt;
+ vt.v = v;
+ for (int i=0; i< 16; i++)
+ s << vt.n[i] << ", ";
+ return s;
+}
inline std::ostream & operator <<(std::ostream & s, const Packet16uc & v)
{
@@ -198,7 +388,7 @@ inline std::ostream & operator <<(std::ostream & s, const Packet16uc & v)
} vt;
vt.v = v;
for (int i=0; i< 16; i++)
- s << (int)vt.n[i] << ", ";
+ s << vt.n[i] << ", ";
return s;
}
@@ -235,122 +425,366 @@ inline std::ostream & operator <<(std::ostream & s, const Packet4ui & v)
return s;
}
-// Need to define them first or we get specialization after instantiation errors
-template<> EIGEN_STRONG_INLINE Packet4f pload<Packet4f>(const float* from)
+template <typename Packet>
+EIGEN_STRONG_INLINE Packet pload_common(const __UNPACK_TYPE__(Packet)* from)
{
+ // some versions of GCC emit an "unused-but-set-parameter" warning here.
+ // ignoring these warnings for now.
+ EIGEN_UNUSED_VARIABLE(from);
EIGEN_DEBUG_ALIGNED_LOAD
#ifdef __VSX__
- return vec_vsx_ld(0, from);
+ return vec_xl(0, const_cast<__UNPACK_TYPE__(Packet)*>(from));
#else
return vec_ld(0, from);
#endif
}
+// Need to define them first or we get specialization after instantiation errors
+template<> EIGEN_STRONG_INLINE Packet4f pload<Packet4f>(const float* from)
+{
+ return pload_common<Packet4f>(from);
+}
+
template<> EIGEN_STRONG_INLINE Packet4i pload<Packet4i>(const int* from)
{
- EIGEN_DEBUG_ALIGNED_LOAD
-#ifdef __VSX__
- return vec_vsx_ld(0, from);
-#else
- return vec_ld(0, from);
-#endif
+ return pload_common<Packet4i>(from);
}
-template<> EIGEN_STRONG_INLINE void pstore<float>(float* to, const Packet4f& from)
+template<> EIGEN_STRONG_INLINE Packet8s pload<Packet8s>(const short int* from)
+{
+ return pload_common<Packet8s>(from);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8us pload<Packet8us>(const unsigned short int* from)
{
+ return pload_common<Packet8us>(from);
+}
+
+template<> EIGEN_STRONG_INLINE Packet16c pload<Packet16c>(const signed char* from)
+{
+ return pload_common<Packet16c>(from);
+}
+
+template<> EIGEN_STRONG_INLINE Packet16uc pload<Packet16uc>(const unsigned char* from)
+{
+ return pload_common<Packet16uc>(from);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf pload<Packet8bf>(const bfloat16* from)
+{
+ return pload_common<Packet8us>(reinterpret_cast<const unsigned short int*>(from));
+}
+
+template <typename Packet>
+EIGEN_STRONG_INLINE void pstore_common(__UNPACK_TYPE__(Packet)* to, const Packet& from){
+ // some versions of GCC emit an "unused-but-set-parameter" warning for the `to` parameter.
+ // ignoring these warnings for now.
+ EIGEN_UNUSED_VARIABLE(to);
EIGEN_DEBUG_ALIGNED_STORE
#ifdef __VSX__
- vec_vsx_st(from, 0, to);
+ vec_xst(from, 0, to);
#else
vec_st(from, 0, to);
#endif
}
+template<> EIGEN_STRONG_INLINE void pstore<float>(float* to, const Packet4f& from)
+{
+ pstore_common<Packet4f>(to, from);
+}
+
template<> EIGEN_STRONG_INLINE void pstore<int>(int* to, const Packet4i& from)
{
- EIGEN_DEBUG_ALIGNED_STORE
-#ifdef __VSX__
- vec_vsx_st(from, 0, to);
-#else
- vec_st(from, 0, to);
-#endif
+ pstore_common<Packet4i>(to, from);
}
-template<> EIGEN_STRONG_INLINE Packet4f pset1<Packet4f>(const float& from) {
- Packet4f v = {from, from, from, from};
+template<> EIGEN_STRONG_INLINE void pstore<short int>(short int* to, const Packet8s& from)
+{
+ pstore_common<Packet8s>(to, from);
+}
+
+template<> EIGEN_STRONG_INLINE void pstore<unsigned short int>(unsigned short int* to, const Packet8us& from)
+{
+ pstore_common<Packet8us>(to, from);
+}
+
+template<> EIGEN_STRONG_INLINE void pstore<bfloat16>(bfloat16* to, const Packet8bf& from)
+{
+ pstore_common<Packet8us>(reinterpret_cast<unsigned short int*>(to), from);
+}
+
+template<> EIGEN_STRONG_INLINE void pstore<signed char>(signed char* to, const Packet16c& from)
+{
+ pstore_common<Packet16c>(to, from);
+}
+
+template<> EIGEN_STRONG_INLINE void pstore<unsigned char>(unsigned char* to, const Packet16uc& from)
+{
+ pstore_common<Packet16uc>(to, from);
+}
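A minimal usage sketch for the aligned load/store paths above (a hypothetical snippet assuming the Eigen::internal context; the buffer must be 16-byte aligned, which EIGEN_ALIGN16 guarantees):

    EIGEN_ALIGN16 short int buf[8] = { 1, 2, 3, 4, 5, 6, 7, 8 };
    Packet8s v = pload<Packet8s>(buf);                  // aligned load (vec_xl / vec_ld)
    v = padd<Packet8s>(v, pset1<Packet8s>(short(10)));  // per-lane add
    pstore<short int>(buf, v);                          // aligned store; buf is now {11,...,18}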
+
+template<typename Packet>
+EIGEN_STRONG_INLINE Packet pset1_size4(const __UNPACK_TYPE__(Packet)& from)
+{
+ Packet v = {from, from, from, from};
return v;
}
-template<> EIGEN_STRONG_INLINE Packet4i pset1<Packet4i>(const int& from) {
- Packet4i v = {from, from, from, from};
+template<typename Packet>
+EIGEN_STRONG_INLINE Packet pset1_size8(const __UNPACK_TYPE__(Packet)& from)
+{
+ Packet v = {from, from, from, from, from, from, from, from};
return v;
}
-template<> EIGEN_STRONG_INLINE void
-pbroadcast4<Packet4f>(const float *a,
- Packet4f& a0, Packet4f& a1, Packet4f& a2, Packet4f& a3)
+
+template<typename Packet>
+EIGEN_STRONG_INLINE Packet pset1_size16(const __UNPACK_TYPE__(Packet)& from)
+{
+ Packet v = {from, from, from, from, from, from, from, from, from, from, from, from, from, from, from, from};
+ return v;
+}
+
+template<> EIGEN_STRONG_INLINE Packet4f pset1<Packet4f>(const float& from) {
+ return pset1_size4<Packet4f>(from);
+}
+
+template<> EIGEN_STRONG_INLINE Packet4i pset1<Packet4i>(const int& from) {
+ return pset1_size4<Packet4i>(from);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8s pset1<Packet8s>(const short int& from) {
+ return pset1_size8<Packet8s>(from);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8us pset1<Packet8us>(const unsigned short int& from) {
+ return pset1_size8<Packet8us>(from);
+}
+
+template<> EIGEN_STRONG_INLINE Packet16c pset1<Packet16c>(const signed char& from) {
+ return pset1_size16<Packet16c>(from);
+}
+
+template<> EIGEN_STRONG_INLINE Packet16uc pset1<Packet16uc>(const unsigned char& from) {
+ return pset1_size16<Packet16uc>(from);
+}
+
+template<> EIGEN_STRONG_INLINE Packet4f pset1frombits<Packet4f>(unsigned int from) {
+ return reinterpret_cast<Packet4f>(pset1<Packet4i>(from));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf pset1<Packet8bf>(const bfloat16& from) {
+ return pset1_size8<Packet8us>(reinterpret_cast<const unsigned short int&>(from));
+}
+
+template<typename Packet> EIGEN_STRONG_INLINE void
+pbroadcast4_common(const __UNPACK_TYPE__(Packet) *a,
+ Packet& a0, Packet& a1, Packet& a2, Packet& a3)
{
- a3 = pload<Packet4f>(a);
+ a3 = pload<Packet>(a);
a0 = vec_splat(a3, 0);
a1 = vec_splat(a3, 1);
a2 = vec_splat(a3, 2);
a3 = vec_splat(a3, 3);
}
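The pbroadcast4 specializations below rely on pbroadcast4_common: load four scalars with one aligned load, then splat each value across its own packet; GEMM kernels typically use this to feed one block of coefficients to four accumulators. A hypothetical usage sketch:

    EIGEN_ALIGN16 float a[4] = { 1.f, 2.f, 3.f, 4.f };
    Packet4f a0, a1, a2, a3;
    pbroadcast4<Packet4f>(a, a0, a1, a2, a3);   // a0 = {1,1,1,1}, ..., a3 = {4,4,4,4}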
+
+template<> EIGEN_STRONG_INLINE void
+pbroadcast4<Packet4f>(const float *a,
+ Packet4f& a0, Packet4f& a1, Packet4f& a2, Packet4f& a3)
+{
+ pbroadcast4_common<Packet4f>(a, a0, a1, a2, a3);
+}
template<> EIGEN_STRONG_INLINE void
pbroadcast4<Packet4i>(const int *a,
Packet4i& a0, Packet4i& a1, Packet4i& a2, Packet4i& a3)
{
- a3 = pload<Packet4i>(a);
- a0 = vec_splat(a3, 0);
- a1 = vec_splat(a3, 1);
- a2 = vec_splat(a3, 2);
- a3 = vec_splat(a3, 3);
+ pbroadcast4_common<Packet4i>(a, a0, a1, a2, a3);
+}
+
+template<typename Packet> EIGEN_DEVICE_FUNC inline Packet pgather_common(const __UNPACK_TYPE__(Packet)* from, Index stride)
+{
+ EIGEN_ALIGN16 __UNPACK_TYPE__(Packet) a[4];
+ a[0] = from[0*stride];
+ a[1] = from[1*stride];
+ a[2] = from[2*stride];
+ a[3] = from[3*stride];
+ return pload<Packet>(a);
}
template<> EIGEN_DEVICE_FUNC inline Packet4f pgather<float, Packet4f>(const float* from, Index stride)
{
- float EIGEN_ALIGN16 af[4];
- af[0] = from[0*stride];
- af[1] = from[1*stride];
- af[2] = from[2*stride];
- af[3] = from[3*stride];
- return pload<Packet4f>(af);
+ return pgather_common<Packet4f>(from, stride);
}
+
template<> EIGEN_DEVICE_FUNC inline Packet4i pgather<int, Packet4i>(const int* from, Index stride)
{
- int EIGEN_ALIGN16 ai[4];
- ai[0] = from[0*stride];
- ai[1] = from[1*stride];
- ai[2] = from[2*stride];
- ai[3] = from[3*stride];
- return pload<Packet4i>(ai);
+ return pgather_common<Packet4i>(from, stride);
+}
+
+template<typename Packet> EIGEN_DEVICE_FUNC inline Packet pgather_size8(const __UNPACK_TYPE__(Packet)* from, Index stride)
+{
+ EIGEN_ALIGN16 __UNPACK_TYPE__(Packet) a[8];
+ a[0] = from[0*stride];
+ a[1] = from[1*stride];
+ a[2] = from[2*stride];
+ a[3] = from[3*stride];
+ a[4] = from[4*stride];
+ a[5] = from[5*stride];
+ a[6] = from[6*stride];
+ a[7] = from[7*stride];
+ return pload<Packet>(a);
+}
+
+template<> EIGEN_DEVICE_FUNC inline Packet8s pgather<short int, Packet8s>(const short int* from, Index stride)
+{
+ return pgather_size8<Packet8s>(from, stride);
+}
+
+template<> EIGEN_DEVICE_FUNC inline Packet8us pgather<unsigned short int, Packet8us>(const unsigned short int* from, Index stride)
+{
+ return pgather_size8<Packet8us>(from, stride);
+}
+
+template<> EIGEN_DEVICE_FUNC inline Packet8bf pgather<bfloat16, Packet8bf>(const bfloat16* from, Index stride)
+{
+ return pgather_size8<Packet8bf>(from, stride);
}
+
+template<typename Packet> EIGEN_DEVICE_FUNC inline Packet pgather_size16(const __UNPACK_TYPE__(Packet)* from, Index stride)
+{
+ EIGEN_ALIGN16 __UNPACK_TYPE__(Packet) a[16];
+ a[0] = from[0*stride];
+ a[1] = from[1*stride];
+ a[2] = from[2*stride];
+ a[3] = from[3*stride];
+ a[4] = from[4*stride];
+ a[5] = from[5*stride];
+ a[6] = from[6*stride];
+ a[7] = from[7*stride];
+ a[8] = from[8*stride];
+ a[9] = from[9*stride];
+ a[10] = from[10*stride];
+ a[11] = from[11*stride];
+ a[12] = from[12*stride];
+ a[13] = from[13*stride];
+ a[14] = from[14*stride];
+ a[15] = from[15*stride];
+ return pload<Packet>(a);
+}
+
+
+template<> EIGEN_DEVICE_FUNC inline Packet16c pgather<signed char, Packet16c>(const signed char* from, Index stride)
+{
+ return pgather_size16<Packet16c>(from, stride);
+}
+
+template<> EIGEN_DEVICE_FUNC inline Packet16uc pgather<unsigned char, Packet16uc>(const unsigned char* from, Index stride)
+{
+ return pgather_size16<Packet16uc>(from, stride);
+}
+
+template<typename Packet> EIGEN_DEVICE_FUNC inline void pscatter_size4(__UNPACK_TYPE__(Packet)* to, const Packet& from, Index stride)
+{
+ EIGEN_ALIGN16 __UNPACK_TYPE__(Packet) a[4];
+ pstore<__UNPACK_TYPE__(Packet)>(a, from);
+ to[0*stride] = a[0];
+ to[1*stride] = a[1];
+ to[2*stride] = a[2];
+ to[3*stride] = a[3];
+}
+
template<> EIGEN_DEVICE_FUNC inline void pscatter<float, Packet4f>(float* to, const Packet4f& from, Index stride)
{
- float EIGEN_ALIGN16 af[4];
- pstore<float>(af, from);
- to[0*stride] = af[0];
- to[1*stride] = af[1];
- to[2*stride] = af[2];
- to[3*stride] = af[3];
+ pscatter_size4<Packet4f>(to, from, stride);
}
+
template<> EIGEN_DEVICE_FUNC inline void pscatter<int, Packet4i>(int* to, const Packet4i& from, Index stride)
{
- int EIGEN_ALIGN16 ai[4];
- pstore<int>((int *)ai, from);
- to[0*stride] = ai[0];
- to[1*stride] = ai[1];
- to[2*stride] = ai[2];
- to[3*stride] = ai[3];
+ pscatter_size4<Packet4i>(to, from, stride);
+}
+
+template<typename Packet> EIGEN_DEVICE_FUNC inline void pscatter_size8(__UNPACK_TYPE__(Packet)* to, const Packet& from, Index stride)
+{
+ EIGEN_ALIGN16 __UNPACK_TYPE__(Packet) a[8];
+ pstore<__UNPACK_TYPE__(Packet)>(a, from);
+ to[0*stride] = a[0];
+ to[1*stride] = a[1];
+ to[2*stride] = a[2];
+ to[3*stride] = a[3];
+ to[4*stride] = a[4];
+ to[5*stride] = a[5];
+ to[6*stride] = a[6];
+ to[7*stride] = a[7];
}
-template<> EIGEN_STRONG_INLINE Packet4f plset<Packet4f>(const float& a) { return pset1<Packet4f>(a) + p4f_COUNTDOWN; }
-template<> EIGEN_STRONG_INLINE Packet4i plset<Packet4i>(const int& a) { return pset1<Packet4i>(a) + p4i_COUNTDOWN; }
-template<> EIGEN_STRONG_INLINE Packet4f padd<Packet4f>(const Packet4f& a, const Packet4f& b) { return a + b; }
-template<> EIGEN_STRONG_INLINE Packet4i padd<Packet4i>(const Packet4i& a, const Packet4i& b) { return a + b; }
+template<> EIGEN_DEVICE_FUNC inline void pscatter<short int, Packet8s>(short int* to, const Packet8s& from, Index stride)
+{
+ pscatter_size8<Packet8s>(to, from, stride);
+}
-template<> EIGEN_STRONG_INLINE Packet4f psub<Packet4f>(const Packet4f& a, const Packet4f& b) { return a - b; }
-template<> EIGEN_STRONG_INLINE Packet4i psub<Packet4i>(const Packet4i& a, const Packet4i& b) { return a - b; }
+template<> EIGEN_DEVICE_FUNC inline void pscatter<unsigned short int, Packet8us>(unsigned short int* to, const Packet8us& from, Index stride)
+{
+ pscatter_size8<Packet8us>(to, from, stride);
+}
+
+template<> EIGEN_DEVICE_FUNC inline void pscatter<bfloat16, Packet8bf>(bfloat16* to, const Packet8bf& from, Index stride)
+{
+ pscatter_size8<Packet8bf>(to, from, stride);
+}
+
+template<typename Packet> EIGEN_DEVICE_FUNC inline void pscatter_size16(__UNPACK_TYPE__(Packet)* to, const Packet& from, Index stride)
+{
+ EIGEN_ALIGN16 __UNPACK_TYPE__(Packet) a[16];
+ pstore<__UNPACK_TYPE__(Packet)>(a, from);
+ to[0*stride] = a[0];
+ to[1*stride] = a[1];
+ to[2*stride] = a[2];
+ to[3*stride] = a[3];
+ to[4*stride] = a[4];
+ to[5*stride] = a[5];
+ to[6*stride] = a[6];
+ to[7*stride] = a[7];
+ to[8*stride] = a[8];
+ to[9*stride] = a[9];
+ to[10*stride] = a[10];
+ to[11*stride] = a[11];
+ to[12*stride] = a[12];
+ to[13*stride] = a[13];
+ to[14*stride] = a[14];
+ to[15*stride] = a[15];
+}
+
+template<> EIGEN_DEVICE_FUNC inline void pscatter<signed char, Packet16c>(signed char* to, const Packet16c& from, Index stride)
+{
+ pscatter_size16<Packet16c>(to, from, stride);
+}
+
+template<> EIGEN_DEVICE_FUNC inline void pscatter<unsigned char, Packet16uc>(unsigned char* to, const Packet16uc& from, Index stride)
+{
+ pscatter_size16<Packet16uc>(to, from, stride);
+}
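The gather/scatter specializations above go through an aligned stack buffer rather than any hardware gather (AltiVec/VSX has none); a typical use is touching a strided column, e.g. (hypothetical snippet):

    float data[16] = { /* ... */ };                     // any readable storage
    Packet4f col = pgather<float, Packet4f>(data, 4);   // data[0], data[4], data[8], data[12]
    pscatter<float, Packet4f>(data, pnegate(col), 4);   // write the negated lanes back, same stride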
+
+template<> EIGEN_STRONG_INLINE Packet4f plset<Packet4f>(const float& a) { return pset1<Packet4f>(a) + p4f_COUNTDOWN; }
+template<> EIGEN_STRONG_INLINE Packet4i plset<Packet4i>(const int& a) { return pset1<Packet4i>(a) + p4i_COUNTDOWN; }
+template<> EIGEN_STRONG_INLINE Packet8s plset<Packet8s>(const short int& a) { return pset1<Packet8s>(a) + p8s_COUNTDOWN; }
+template<> EIGEN_STRONG_INLINE Packet8us plset<Packet8us>(const unsigned short int& a) { return pset1<Packet8us>(a) + p8us_COUNTDOWN; }
+template<> EIGEN_STRONG_INLINE Packet16c plset<Packet16c>(const signed char& a) { return pset1<Packet16c>(a) + p16c_COUNTDOWN; }
+template<> EIGEN_STRONG_INLINE Packet16uc plset<Packet16uc>(const unsigned char& a) { return pset1<Packet16uc>(a) + p16uc_COUNTDOWN; }
+
+template<> EIGEN_STRONG_INLINE Packet4f padd<Packet4f> (const Packet4f& a, const Packet4f& b) { return a + b; }
+template<> EIGEN_STRONG_INLINE Packet4i padd<Packet4i> (const Packet4i& a, const Packet4i& b) { return a + b; }
+template<> EIGEN_STRONG_INLINE Packet4ui padd<Packet4ui> (const Packet4ui& a, const Packet4ui& b) { return a + b; }
+template<> EIGEN_STRONG_INLINE Packet8s padd<Packet8s> (const Packet8s& a, const Packet8s& b) { return a + b; }
+template<> EIGEN_STRONG_INLINE Packet8us padd<Packet8us> (const Packet8us& a, const Packet8us& b) { return a + b; }
+template<> EIGEN_STRONG_INLINE Packet16c padd<Packet16c> (const Packet16c& a, const Packet16c& b) { return a + b; }
+template<> EIGEN_STRONG_INLINE Packet16uc padd<Packet16uc>(const Packet16uc& a, const Packet16uc& b) { return a + b; }
+
+template<> EIGEN_STRONG_INLINE Packet4f psub<Packet4f> (const Packet4f& a, const Packet4f& b) { return a - b; }
+template<> EIGEN_STRONG_INLINE Packet4i psub<Packet4i> (const Packet4i& a, const Packet4i& b) { return a - b; }
+template<> EIGEN_STRONG_INLINE Packet8s psub<Packet8s> (const Packet8s& a, const Packet8s& b) { return a - b; }
+template<> EIGEN_STRONG_INLINE Packet8us psub<Packet8us> (const Packet8us& a, const Packet8us& b) { return a - b; }
+template<> EIGEN_STRONG_INLINE Packet16c psub<Packet16c> (const Packet16c& a, const Packet16c& b) { return a - b; }
+template<> EIGEN_STRONG_INLINE Packet16uc psub<Packet16uc>(const Packet16uc& a, const Packet16uc& b) { return a - b; }
template<> EIGEN_STRONG_INLINE Packet4f pnegate(const Packet4f& a) { return p4f_ZERO - a; }
template<> EIGEN_STRONG_INLINE Packet4i pnegate(const Packet4i& a) { return p4i_ZERO - a; }
@@ -358,8 +792,13 @@ template<> EIGEN_STRONG_INLINE Packet4i pnegate(const Packet4i& a) { return p4i_
template<> EIGEN_STRONG_INLINE Packet4f pconj(const Packet4f& a) { return a; }
template<> EIGEN_STRONG_INLINE Packet4i pconj(const Packet4i& a) { return a; }
-template<> EIGEN_STRONG_INLINE Packet4f pmul<Packet4f>(const Packet4f& a, const Packet4f& b) { return vec_madd(a,b, p4f_MZERO); }
-template<> EIGEN_STRONG_INLINE Packet4i pmul<Packet4i>(const Packet4i& a, const Packet4i& b) { return a * b; }
+template<> EIGEN_STRONG_INLINE Packet4f pmul<Packet4f> (const Packet4f& a, const Packet4f& b) { return vec_madd(a,b, p4f_MZERO); }
+template<> EIGEN_STRONG_INLINE Packet4i pmul<Packet4i> (const Packet4i& a, const Packet4i& b) { return a * b; }
+template<> EIGEN_STRONG_INLINE Packet8s pmul<Packet8s> (const Packet8s& a, const Packet8s& b) { return vec_mul(a,b); }
+template<> EIGEN_STRONG_INLINE Packet8us pmul<Packet8us> (const Packet8us& a, const Packet8us& b) { return vec_mul(a,b); }
+template<> EIGEN_STRONG_INLINE Packet16c pmul<Packet16c> (const Packet16c& a, const Packet16c& b) { return vec_mul(a,b); }
+template<> EIGEN_STRONG_INLINE Packet16uc pmul<Packet16uc>(const Packet16uc& a, const Packet16uc& b) { return vec_mul(a,b); }
+
template<> EIGEN_STRONG_INLINE Packet4f pdiv<Packet4f>(const Packet4f& a, const Packet4f& b)
{
@@ -387,85 +826,247 @@ template<> EIGEN_STRONG_INLINE Packet4i pdiv<Packet4i>(const Packet4i& /*a*/, co
// for some weird reasons, it has to be overloaded for packets of integers
template<> EIGEN_STRONG_INLINE Packet4f pmadd(const Packet4f& a, const Packet4f& b, const Packet4f& c) { return vec_madd(a,b,c); }
template<> EIGEN_STRONG_INLINE Packet4i pmadd(const Packet4i& a, const Packet4i& b, const Packet4i& c) { return a*b + c; }
+template<> EIGEN_STRONG_INLINE Packet8s pmadd(const Packet8s& a, const Packet8s& b, const Packet8s& c) { return vec_madd(a,b,c); }
+template<> EIGEN_STRONG_INLINE Packet8us pmadd(const Packet8us& a, const Packet8us& b, const Packet8us& c) { return vec_madd(a,b,c); }
-template<> EIGEN_STRONG_INLINE Packet4f pmin<Packet4f>(const Packet4f& a, const Packet4f& b) { return vec_min(a, b); }
+template<> EIGEN_STRONG_INLINE Packet4f pmin<Packet4f>(const Packet4f& a, const Packet4f& b)
+{
+ #ifdef __VSX__
+ // NOTE: about 10% slower than vec_min, but consistent with std::min and SSE regarding NaN
+ Packet4f ret;
+ __asm__ ("xvcmpgesp %x0,%x1,%x2\n\txxsel %x0,%x1,%x2,%x0" : "=&wa" (ret) : "wa" (a), "wa" (b));
+ return ret;
+ #else
+ return vec_min(a, b);
+ #endif
+}
template<> EIGEN_STRONG_INLINE Packet4i pmin<Packet4i>(const Packet4i& a, const Packet4i& b) { return vec_min(a, b); }
+template<> EIGEN_STRONG_INLINE Packet8s pmin<Packet8s>(const Packet8s& a, const Packet8s& b) { return vec_min(a, b); }
+template<> EIGEN_STRONG_INLINE Packet8us pmin<Packet8us>(const Packet8us& a, const Packet8us& b) { return vec_min(a, b); }
+template<> EIGEN_STRONG_INLINE Packet16c pmin<Packet16c>(const Packet16c& a, const Packet16c& b) { return vec_min(a, b); }
+template<> EIGEN_STRONG_INLINE Packet16uc pmin<Packet16uc>(const Packet16uc& a, const Packet16uc& b) { return vec_min(a, b); }
+
-template<> EIGEN_STRONG_INLINE Packet4f pmax<Packet4f>(const Packet4f& a, const Packet4f& b) { return vec_max(a, b); }
+template<> EIGEN_STRONG_INLINE Packet4f pmax<Packet4f>(const Packet4f& a, const Packet4f& b)
+{
+ #ifdef __VSX__
+ // NOTE: about 10% slower than vec_max, but consistent with std::max and SSE regarding NaN
+ Packet4f ret;
+ __asm__ ("xvcmpgtsp %x0,%x2,%x1\n\txxsel %x0,%x1,%x2,%x0" : "=&wa" (ret) : "wa" (a), "wa" (b));
+ return ret;
+ #else
+ return vec_max(a, b);
+ #endif
+}
template<> EIGEN_STRONG_INLINE Packet4i pmax<Packet4i>(const Packet4i& a, const Packet4i& b) { return vec_max(a, b); }
+template<> EIGEN_STRONG_INLINE Packet8s pmax<Packet8s>(const Packet8s& a, const Packet8s& b) { return vec_max(a, b); }
+template<> EIGEN_STRONG_INLINE Packet8us pmax<Packet8us>(const Packet8us& a, const Packet8us& b) { return vec_max(a, b); }
+template<> EIGEN_STRONG_INLINE Packet16c pmax<Packet16c>(const Packet16c& a, const Packet16c& b) { return vec_max(a, b); }
+template<> EIGEN_STRONG_INLINE Packet16uc pmax<Packet16uc>(const Packet16uc& a, const Packet16uc& b) { return vec_max(a, b); }
+
+template<> EIGEN_STRONG_INLINE Packet4f pcmp_le(const Packet4f& a, const Packet4f& b) { return reinterpret_cast<Packet4f>(vec_cmple(a,b)); }
+template<> EIGEN_STRONG_INLINE Packet4f pcmp_lt(const Packet4f& a, const Packet4f& b) { return reinterpret_cast<Packet4f>(vec_cmplt(a,b)); }
+template<> EIGEN_STRONG_INLINE Packet4f pcmp_eq(const Packet4f& a, const Packet4f& b) { return reinterpret_cast<Packet4f>(vec_cmpeq(a,b)); }
+template<> EIGEN_STRONG_INLINE Packet4f pcmp_lt_or_nan(const Packet4f& a, const Packet4f& b) {
+ Packet4f c = reinterpret_cast<Packet4f>(vec_cmpge(a,b));
+ return vec_nor(c,c);
+}
+
+template<> EIGEN_STRONG_INLINE Packet4i pcmp_le(const Packet4i& a, const Packet4i& b) { return reinterpret_cast<Packet4i>(vec_cmple(a,b)); }
+template<> EIGEN_STRONG_INLINE Packet4i pcmp_lt(const Packet4i& a, const Packet4i& b) { return reinterpret_cast<Packet4i>(vec_cmplt(a,b)); }
+template<> EIGEN_STRONG_INLINE Packet4i pcmp_eq(const Packet4i& a, const Packet4i& b) { return reinterpret_cast<Packet4i>(vec_cmpeq(a,b)); }
+template<> EIGEN_STRONG_INLINE Packet8s pcmp_le(const Packet8s& a, const Packet8s& b) { return reinterpret_cast<Packet8s>(vec_cmple(a,b)); }
+template<> EIGEN_STRONG_INLINE Packet8s pcmp_lt(const Packet8s& a, const Packet8s& b) { return reinterpret_cast<Packet8s>(vec_cmplt(a,b)); }
+template<> EIGEN_STRONG_INLINE Packet8s pcmp_eq(const Packet8s& a, const Packet8s& b) { return reinterpret_cast<Packet8s>(vec_cmpeq(a,b)); }
+template<> EIGEN_STRONG_INLINE Packet8us pcmp_le(const Packet8us& a, const Packet8us& b) { return reinterpret_cast<Packet8us>(vec_cmple(a,b)); }
+template<> EIGEN_STRONG_INLINE Packet8us pcmp_lt(const Packet8us& a, const Packet8us& b) { return reinterpret_cast<Packet8us>(vec_cmplt(a,b)); }
+template<> EIGEN_STRONG_INLINE Packet8us pcmp_eq(const Packet8us& a, const Packet8us& b) { return reinterpret_cast<Packet8us>(vec_cmpeq(a,b)); }
+template<> EIGEN_STRONG_INLINE Packet16c pcmp_le(const Packet16c& a, const Packet16c& b) { return reinterpret_cast<Packet16c>(vec_cmple(a,b)); }
+template<> EIGEN_STRONG_INLINE Packet16c pcmp_lt(const Packet16c& a, const Packet16c& b) { return reinterpret_cast<Packet16c>(vec_cmplt(a,b)); }
+template<> EIGEN_STRONG_INLINE Packet16c pcmp_eq(const Packet16c& a, const Packet16c& b) { return reinterpret_cast<Packet16c>(vec_cmpeq(a,b)); }
+template<> EIGEN_STRONG_INLINE Packet16uc pcmp_le(const Packet16uc& a, const Packet16uc& b) { return reinterpret_cast<Packet16uc>(vec_cmple(a,b)); }
+template<> EIGEN_STRONG_INLINE Packet16uc pcmp_lt(const Packet16uc& a, const Packet16uc& b) { return reinterpret_cast<Packet16uc>(vec_cmplt(a,b)); }
+template<> EIGEN_STRONG_INLINE Packet16uc pcmp_eq(const Packet16uc& a, const Packet16uc& b) { return reinterpret_cast<Packet16uc>(vec_cmpeq(a,b)); }
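These comparison specializations return full-width bit masks (all ones where the predicate holds, all zeros elsewhere), which is exactly what pselect below consumes. An illustrative combination:

    Packet4f a = pset1<Packet4f>(1.0f), b = pset1<Packet4f>(2.0f);
    Packet4f mask = pcmp_lt(a, b);       // all bits set in lanes where a < b
    Packet4f m    = pselect(mask, a, b); // per lane: mask ? a : b  (the minimum here)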
template<> EIGEN_STRONG_INLINE Packet4f pand<Packet4f>(const Packet4f& a, const Packet4f& b) { return vec_and(a, b); }
template<> EIGEN_STRONG_INLINE Packet4i pand<Packet4i>(const Packet4i& a, const Packet4i& b) { return vec_and(a, b); }
+template<> EIGEN_STRONG_INLINE Packet4ui pand<Packet4ui>(const Packet4ui& a, const Packet4ui& b) { return vec_and(a, b); }
+template<> EIGEN_STRONG_INLINE Packet8us pand<Packet8us>(const Packet8us& a, const Packet8us& b) { return vec_and(a, b); }
+template<> EIGEN_STRONG_INLINE Packet8bf pand<Packet8bf>(const Packet8bf& a, const Packet8bf& b) {
+ return pand<Packet8us>(a, b);
+}
+
template<> EIGEN_STRONG_INLINE Packet4f por<Packet4f>(const Packet4f& a, const Packet4f& b) { return vec_or(a, b); }
template<> EIGEN_STRONG_INLINE Packet4i por<Packet4i>(const Packet4i& a, const Packet4i& b) { return vec_or(a, b); }
+template<> EIGEN_STRONG_INLINE Packet8s por<Packet8s>(const Packet8s& a, const Packet8s& b) { return vec_or(a, b); }
+template<> EIGEN_STRONG_INLINE Packet8us por<Packet8us>(const Packet8us& a, const Packet8us& b) { return vec_or(a, b); }
+template<> EIGEN_STRONG_INLINE Packet8bf por<Packet8bf>(const Packet8bf& a, const Packet8bf& b) {
+ return por<Packet8us>(a, b);
+}
template<> EIGEN_STRONG_INLINE Packet4f pxor<Packet4f>(const Packet4f& a, const Packet4f& b) { return vec_xor(a, b); }
template<> EIGEN_STRONG_INLINE Packet4i pxor<Packet4i>(const Packet4i& a, const Packet4i& b) { return vec_xor(a, b); }
+template<> EIGEN_STRONG_INLINE Packet8bf pxor<Packet8bf>(const Packet8bf& a, const Packet8bf& b) {
+ return pxor<Packet8us>(a, b);
+}
-template<> EIGEN_STRONG_INLINE Packet4f pandnot<Packet4f>(const Packet4f& a, const Packet4f& b) { return vec_and(a, vec_nor(b, b)); }
-template<> EIGEN_STRONG_INLINE Packet4i pandnot<Packet4i>(const Packet4i& a, const Packet4i& b) { return vec_and(a, vec_nor(b, b)); }
+template<> EIGEN_STRONG_INLINE Packet4f pandnot<Packet4f>(const Packet4f& a, const Packet4f& b) { return vec_andc(a, b); }
+template<> EIGEN_STRONG_INLINE Packet4i pandnot<Packet4i>(const Packet4i& a, const Packet4i& b) { return vec_andc(a, b); }
-template<> EIGEN_STRONG_INLINE Packet4f pround<Packet4f>(const Packet4f& a) { return vec_round(a); }
+template<> EIGEN_STRONG_INLINE Packet4f pselect(const Packet4f& mask, const Packet4f& a, const Packet4f& b) {
+ return vec_sel(b, a, reinterpret_cast<Packet4ui>(mask));
+}
+
+template<> EIGEN_STRONG_INLINE Packet4f pround<Packet4f>(const Packet4f& a)
+{
+ Packet4f t = vec_add(reinterpret_cast<Packet4f>(vec_or(vec_and(reinterpret_cast<Packet4ui>(a), p4ui_SIGN), p4ui_PREV0DOT5)), a);
+ Packet4f res;
+
+#ifdef __VSX__
+ __asm__("xvrspiz %x0, %x1\n\t"
+ : "=&wa" (res)
+ : "wa" (t));
+#else
+ __asm__("vrfiz %0, %1\n\t"
+ : "=v" (res)
+ : "v" (t));
+#endif
+
+ return res;
+}
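The new pround rounds half-way cases away from zero: it ORs a's sign bit into p4ui_PREV0DOT5 (0x3EFFFFFF, the largest float below 0.5), adds that bias, then truncates toward zero with xvrspiz/vrfiz. A scalar sketch of the same trick, for illustration only:

    #include <cmath>
    float pround_scalar(float a)
    {
      // bias carries a's sign and a magnitude just under 0.5, so x.5 rounds away from
      // zero while values just below x.5 are not pushed over the boundary
      float bias = std::copysign(0.49999997f, a);   // 0x3EFFFFFF reinterpreted as a float
      return std::trunc(a + bias);                  // truncate toward zero, like xvrspiz
    }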
template<> EIGEN_STRONG_INLINE Packet4f pceil<Packet4f>(const Packet4f& a) { return vec_ceil(a); }
template<> EIGEN_STRONG_INLINE Packet4f pfloor<Packet4f>(const Packet4f& a) { return vec_floor(a); }
+template<> EIGEN_STRONG_INLINE Packet4f print<Packet4f>(const Packet4f& a)
+{
+ Packet4f res;
-#ifdef _BIG_ENDIAN
-template<> EIGEN_STRONG_INLINE Packet4f ploadu<Packet4f>(const float* from)
+ __asm__("xvrspic %x0, %x1\n\t"
+ : "=&wa" (res)
+ : "wa" (a));
+
+ return res;
+}
+
+template<typename Packet> EIGEN_STRONG_INLINE Packet ploadu_common(const __UNPACK_TYPE__(Packet)* from)
{
EIGEN_DEBUG_ALIGNED_LOAD
+#ifdef _BIG_ENDIAN
Packet16uc MSQ, LSQ;
Packet16uc mask;
MSQ = vec_ld(0, (unsigned char *)from); // most significant quadword
LSQ = vec_ld(15, (unsigned char *)from); // least significant quadword
mask = vec_lvsl(0, from); // create the permute mask
- return (Packet4f) vec_perm(MSQ, LSQ, mask); // align the data
+ //TODO: Add static_cast here
+ return (Packet) vec_perm(MSQ, LSQ, mask); // align the data
+#else
+ EIGEN_DEBUG_UNALIGNED_LOAD
+ return vec_xl(0, const_cast<__UNPACK_TYPE__(Packet)*>(from));
+#endif
+}
+template<> EIGEN_STRONG_INLINE Packet4f ploadu<Packet4f>(const float* from)
+{
+ return ploadu_common<Packet4f>(from);
}
template<> EIGEN_STRONG_INLINE Packet4i ploadu<Packet4i>(const int* from)
{
- EIGEN_DEBUG_ALIGNED_LOAD
- // Taken from http://developer.apple.com/hardwaredrivers/ve/alignment.html
- Packet16uc MSQ, LSQ;
- Packet16uc mask;
- MSQ = vec_ld(0, (unsigned char *)from); // most significant quadword
- LSQ = vec_ld(15, (unsigned char *)from); // least significant quadword
- mask = vec_lvsl(0, from); // create the permute mask
- return (Packet4i) vec_perm(MSQ, LSQ, mask); // align the data
+ return ploadu_common<Packet4i>(from);
}
-#else
-// We also need ot redefine little endian loading of Packet4i/Packet4f using VSX
-template<> EIGEN_STRONG_INLINE Packet4i ploadu<Packet4i>(const int* from)
+template<> EIGEN_STRONG_INLINE Packet8s ploadu<Packet8s>(const short int* from)
{
- EIGEN_DEBUG_UNALIGNED_LOAD
- return (Packet4i) vec_vsx_ld((long)from & 15, (const int*) _EIGEN_ALIGNED_PTR(from));
+ return ploadu_common<Packet8s>(from);
}
-template<> EIGEN_STRONG_INLINE Packet4f ploadu<Packet4f>(const float* from)
+template<> EIGEN_STRONG_INLINE Packet8us ploadu<Packet8us>(const unsigned short int* from)
{
- EIGEN_DEBUG_UNALIGNED_LOAD
- return (Packet4f) vec_vsx_ld((long)from & 15, (const float*) _EIGEN_ALIGNED_PTR(from));
+ return ploadu_common<Packet8us>(from);
+}
+template<> EIGEN_STRONG_INLINE Packet8bf ploadu<Packet8bf>(const bfloat16* from)
+{
+ return ploadu_common<Packet8us>(reinterpret_cast<const unsigned short int*>(from));
+}
+template<> EIGEN_STRONG_INLINE Packet16c ploadu<Packet16c>(const signed char* from)
+{
+ return ploadu_common<Packet16c>(from);
+}
+template<> EIGEN_STRONG_INLINE Packet16uc ploadu<Packet16uc>(const unsigned char* from)
+{
+ return ploadu_common<Packet16uc>(from);
}
-#endif
-template<> EIGEN_STRONG_INLINE Packet4f ploaddup<Packet4f>(const float* from)
+template<typename Packet> EIGEN_STRONG_INLINE Packet ploaddup_common(const __UNPACK_TYPE__(Packet)* from)
{
- Packet4f p;
- if((std::ptrdiff_t(from) % 16) == 0) p = pload<Packet4f>(from);
- else p = ploadu<Packet4f>(from);
+ Packet p;
+ if((std::ptrdiff_t(from) % 16) == 0) p = pload<Packet>(from);
+ else p = ploadu<Packet>(from);
return vec_perm(p, p, p16uc_DUPLICATE32_HI);
}
+template<> EIGEN_STRONG_INLINE Packet4f ploaddup<Packet4f>(const float* from)
+{
+ return ploaddup_common<Packet4f>(from);
+}
template<> EIGEN_STRONG_INLINE Packet4i ploaddup<Packet4i>(const int* from)
{
- Packet4i p;
- if((std::ptrdiff_t(from) % 16) == 0) p = pload<Packet4i>(from);
- else p = ploadu<Packet4i>(from);
- return vec_perm(p, p, p16uc_DUPLICATE32_HI);
+ return ploaddup_common<Packet4i>(from);
}
-#ifdef _BIG_ENDIAN
-template<> EIGEN_STRONG_INLINE void pstoreu<float>(float* to, const Packet4f& from)
+template<> EIGEN_STRONG_INLINE Packet8s ploaddup<Packet8s>(const short int* from)
+{
+ Packet8s p;
+ if((std::ptrdiff_t(from) % 16) == 0) p = pload<Packet8s>(from);
+ else p = ploadu<Packet8s>(from);
+ return vec_perm(p, p, p16uc_DUPLICATE16_HI);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8us ploaddup<Packet8us>(const unsigned short int* from)
+{
+ Packet8us p;
+ if((std::ptrdiff_t(from) % 16) == 0) p = pload<Packet8us>(from);
+ else p = ploadu<Packet8us>(from);
+ return vec_perm(p, p, p16uc_DUPLICATE16_HI);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8s ploadquad<Packet8s>(const short int* from)
+{
+ Packet8s p;
+ if((std::ptrdiff_t(from) % 16) == 0) p = pload<Packet8s>(from);
+ else p = ploadu<Packet8s>(from);
+ return vec_perm(p, p, p16uc_QUADRUPLICATE16_HI);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8us ploadquad<Packet8us>(const unsigned short int* from)
+{
+ Packet8us p;
+ if((std::ptrdiff_t(from) % 16) == 0) p = pload<Packet8us>(from);
+ else p = ploadu<Packet8us>(from);
+ return vec_perm(p, p, p16uc_QUADRUPLICATE16_HI);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf ploadquad<Packet8bf>(const bfloat16* from)
+{
+ return ploadquad<Packet8us>(reinterpret_cast<const unsigned short int*>(from));
+}
+
+template<> EIGEN_STRONG_INLINE Packet16c ploaddup<Packet16c>(const signed char* from)
+{
+ Packet16c p;
+ if((std::ptrdiff_t(from) % 16) == 0) p = pload<Packet16c>(from);
+ else p = ploadu<Packet16c>(from);
+ return vec_perm(p, p, p16uc_DUPLICATE8_HI);
+}
+
+template<> EIGEN_STRONG_INLINE Packet16uc ploaddup<Packet16uc>(const unsigned char* from)
+{
+ Packet16uc p;
+ if((std::ptrdiff_t(from) % 16) == 0) p = pload<Packet16uc>(from);
+ else p = ploadu<Packet16uc>(from);
+ return vec_perm(p, p, p16uc_DUPLICATE8_HI);
+}
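For the 16-bit and 8-bit packets, ploaddup loads the first half of a packet's worth of scalars and duplicates each value, while ploadquad loads a quarter and repeats each value four times. A hypothetical illustration (note that both helpers still read a full 16-byte packet before permuting, so the source must provide at least a packet of readable data):

    EIGEN_ALIGN16 short int src[8] = { 10, 20, 30, 40, 0, 0, 0, 0 };
    Packet8s dup  = ploaddup<Packet8s>(src);   // { 10,10, 20,20, 30,30, 40,40 }
    Packet8s quad = ploadquad<Packet8s>(src);  // { 10,10,10,10, 20,20,20,20 }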
+
+template<typename Packet> EIGEN_STRONG_INLINE void pstoreu_common(__UNPACK_TYPE__(Packet)* to, const Packet& from)
{
EIGEN_DEBUG_UNALIGNED_STORE
+#ifdef _BIG_ENDIAN
// Taken from http://developer.apple.com/hardwaredrivers/ve/alignment.html
// Warning: not thread safe!
Packet16uc MSQ, LSQ, edges;
@@ -479,45 +1080,69 @@ template<> EIGEN_STRONG_INLINE void pstoreu<float>(float* to, const Packet4f& f
MSQ = vec_perm(edges,(Packet16uc)from,align); // misalign the data (MSQ)
LSQ = vec_perm((Packet16uc)from,edges,align); // misalign the data (LSQ)
vec_st( LSQ, 15, (unsigned char *)to ); // Store the LSQ part first
- vec_st( MSQ, 0, (unsigned char *)to ); // Store the MSQ part
+ vec_st( MSQ, 0, (unsigned char *)to ); // Store the MSQ part second
+#else
+ vec_xst(from, 0, to);
+#endif
+}
+template<> EIGEN_STRONG_INLINE void pstoreu<float>(float* to, const Packet4f& from)
+{
+ pstoreu_common<Packet4f>(to, from);
}
template<> EIGEN_STRONG_INLINE void pstoreu<int>(int* to, const Packet4i& from)
{
- EIGEN_DEBUG_UNALIGNED_STORE
- // Taken from http://developer.apple.com/hardwaredrivers/ve/alignment.html
- // Warning: not thread safe!
- Packet16uc MSQ, LSQ, edges;
- Packet16uc edgeAlign, align;
-
- MSQ = vec_ld(0, (unsigned char *)to); // most significant quadword
- LSQ = vec_ld(15, (unsigned char *)to); // least significant quadword
- edgeAlign = vec_lvsl(0, to); // permute map to extract edges
- edges=vec_perm(LSQ, MSQ, edgeAlign); // extract the edges
- align = vec_lvsr( 0, to ); // permute map to misalign data
- MSQ = vec_perm(edges, (Packet16uc) from, align); // misalign the data (MSQ)
- LSQ = vec_perm((Packet16uc) from, edges, align); // misalign the data (LSQ)
- vec_st( LSQ, 15, (unsigned char *)to ); // Store the LSQ part first
- vec_st( MSQ, 0, (unsigned char *)to ); // Store the MSQ part
+ pstoreu_common<Packet4i>(to, from);
}
-#else
-// We also need ot redefine little endian loading of Packet4i/Packet4f using VSX
-template<> EIGEN_STRONG_INLINE void pstoreu<int>(int* to, const Packet4i& from)
+template<> EIGEN_STRONG_INLINE void pstoreu<short int>(short int* to, const Packet8s& from)
{
- EIGEN_DEBUG_ALIGNED_STORE
- vec_vsx_st(from, (long)to & 15, (int*) _EIGEN_ALIGNED_PTR(to));
+ pstoreu_common<Packet8s>(to, from);
}
-template<> EIGEN_STRONG_INLINE void pstoreu<float>(float* to, const Packet4f& from)
+template<> EIGEN_STRONG_INLINE void pstoreu<unsigned short int>(unsigned short int* to, const Packet8us& from)
{
- EIGEN_DEBUG_ALIGNED_STORE
- vec_vsx_st(from, (long)to & 15, (float*) _EIGEN_ALIGNED_PTR(to));
+ pstoreu_common<Packet8us>(to, from);
+}
+template<> EIGEN_STRONG_INLINE void pstoreu<bfloat16>(bfloat16* to, const Packet8bf& from)
+{
+ pstoreu_common<Packet8us>(reinterpret_cast<unsigned short int*>(to), from);
+}
+template<> EIGEN_STRONG_INLINE void pstoreu<signed char>(signed char* to, const Packet16c& from)
+{
+ pstoreu_common<Packet16c>(to, from);
+}
+template<> EIGEN_STRONG_INLINE void pstoreu<unsigned char>(unsigned char* to, const Packet16uc& from)
+{
+ pstoreu_common<Packet16uc>(to, from);
}
-#endif
template<> EIGEN_STRONG_INLINE void prefetch<float>(const float* addr) { EIGEN_PPC_PREFETCH(addr); }
template<> EIGEN_STRONG_INLINE void prefetch<int>(const int* addr) { EIGEN_PPC_PREFETCH(addr); }
-template<> EIGEN_STRONG_INLINE float pfirst<Packet4f>(const Packet4f& a) { float EIGEN_ALIGN16 x; vec_ste(a, 0, &x); return x; }
-template<> EIGEN_STRONG_INLINE int pfirst<Packet4i>(const Packet4i& a) { int EIGEN_ALIGN16 x; vec_ste(a, 0, &x); return x; }
+template<> EIGEN_STRONG_INLINE float pfirst<Packet4f>(const Packet4f& a) { EIGEN_ALIGN16 float x; vec_ste(a, 0, &x); return x; }
+template<> EIGEN_STRONG_INLINE int pfirst<Packet4i>(const Packet4i& a) { EIGEN_ALIGN16 int x; vec_ste(a, 0, &x); return x; }
+
+template<typename Packet> EIGEN_STRONG_INLINE __UNPACK_TYPE__(Packet) pfirst_common(const Packet& a) {
+ EIGEN_ALIGN16 __UNPACK_TYPE__(Packet) x;
+ vec_ste(a, 0, &x);
+ return x;
+}
+
+template<> EIGEN_STRONG_INLINE short int pfirst<Packet8s>(const Packet8s& a) {
+ return pfirst_common<Packet8s>(a);
+}
+
+template<> EIGEN_STRONG_INLINE unsigned short int pfirst<Packet8us>(const Packet8us& a) {
+ return pfirst_common<Packet8us>(a);
+}
+
+template<> EIGEN_STRONG_INLINE signed char pfirst<Packet16c>(const Packet16c& a)
+{
+ return pfirst_common<Packet16c>(a);
+}
+
+template<> EIGEN_STRONG_INLINE unsigned char pfirst<Packet16uc>(const Packet16uc& a)
+{
+ return pfirst_common<Packet16uc>(a);
+}
template<> EIGEN_STRONG_INLINE Packet4f preverse(const Packet4f& a)
{
@@ -525,10 +1150,296 @@ template<> EIGEN_STRONG_INLINE Packet4f preverse(const Packet4f& a)
}
template<> EIGEN_STRONG_INLINE Packet4i preverse(const Packet4i& a)
{
- return reinterpret_cast<Packet4i>(vec_perm(reinterpret_cast<Packet16uc>(a), reinterpret_cast<Packet16uc>(a), p16uc_REVERSE32)); }
+ return reinterpret_cast<Packet4i>(vec_perm(reinterpret_cast<Packet16uc>(a), reinterpret_cast<Packet16uc>(a), p16uc_REVERSE32));
+}
+template<> EIGEN_STRONG_INLINE Packet8s preverse(const Packet8s& a)
+{
+ return reinterpret_cast<Packet8s>(vec_perm(reinterpret_cast<Packet16uc>(a), reinterpret_cast<Packet16uc>(a), p16uc_REVERSE16));
+}
+template<> EIGEN_STRONG_INLINE Packet8us preverse(const Packet8us& a)
+{
+ return reinterpret_cast<Packet8us>(vec_perm(reinterpret_cast<Packet16uc>(a), reinterpret_cast<Packet16uc>(a), p16uc_REVERSE16));
+}
+template<> EIGEN_STRONG_INLINE Packet16c preverse(const Packet16c& a)
+{
+ return vec_perm(a, a, p16uc_REVERSE8);
+}
+template<> EIGEN_STRONG_INLINE Packet16uc preverse(const Packet16uc& a)
+{
+ return vec_perm(a, a, p16uc_REVERSE8);
+}
+template<> EIGEN_STRONG_INLINE Packet8bf preverse(const Packet8bf& a)
+{
+ return preverse<Packet8us>(a);
+}
template<> EIGEN_STRONG_INLINE Packet4f pabs(const Packet4f& a) { return vec_abs(a); }
template<> EIGEN_STRONG_INLINE Packet4i pabs(const Packet4i& a) { return vec_abs(a); }
+template<> EIGEN_STRONG_INLINE Packet8s pabs(const Packet8s& a) { return vec_abs(a); }
+template<> EIGEN_STRONG_INLINE Packet8us pabs(const Packet8us& a) { return a; }
+template<> EIGEN_STRONG_INLINE Packet16c pabs(const Packet16c& a) { return vec_abs(a); }
+template<> EIGEN_STRONG_INLINE Packet16uc pabs(const Packet16uc& a) { return a; }
+template<> EIGEN_STRONG_INLINE Packet8bf pabs(const Packet8bf& a) {
+ _EIGEN_DECLARE_CONST_FAST_Packet8us(abs_mask,0x7FFF);
+ return pand<Packet8us>(p8us_abs_mask, a);
+}
+
+template<int N> EIGEN_STRONG_INLINE Packet4i parithmetic_shift_right(const Packet4i& a)
+{ return vec_sra(a,reinterpret_cast<Packet4ui>(pset1<Packet4i>(N))); }
+template<int N> EIGEN_STRONG_INLINE Packet4i plogical_shift_right(const Packet4i& a)
+{ return vec_sr(a,reinterpret_cast<Packet4ui>(pset1<Packet4i>(N))); }
+template<int N> EIGEN_STRONG_INLINE Packet4i plogical_shift_left(const Packet4i& a)
+{ return vec_sl(a,reinterpret_cast<Packet4ui>(pset1<Packet4i>(N))); }
+template<int N> EIGEN_STRONG_INLINE Packet4f plogical_shift_left(const Packet4f& a)
+{
+ const _EIGEN_DECLARE_CONST_FAST_Packet4ui(mask, N);
+ Packet4ui r = vec_sl(reinterpret_cast<Packet4ui>(a), p4ui_mask);
+ return reinterpret_cast<Packet4f>(r);
+}
+
+template<int N> EIGEN_STRONG_INLINE Packet4f plogical_shift_right(const Packet4f& a)
+{
+ const _EIGEN_DECLARE_CONST_FAST_Packet4ui(mask, N);
+ Packet4ui r = vec_sr(reinterpret_cast<Packet4ui>(a), p4ui_mask);
+ return reinterpret_cast<Packet4f>(r);
+}
+
+template<int N> EIGEN_STRONG_INLINE Packet4ui plogical_shift_right(const Packet4ui& a)
+{
+ const _EIGEN_DECLARE_CONST_FAST_Packet4ui(mask, N);
+ return vec_sr(a, p4ui_mask);
+}
+
+template<int N> EIGEN_STRONG_INLINE Packet4ui plogical_shift_left(const Packet4ui& a)
+{
+ const _EIGEN_DECLARE_CONST_FAST_Packet4ui(mask, N);
+ return vec_sl(a, p4ui_mask);
+}
+
+template<int N> EIGEN_STRONG_INLINE Packet8us plogical_shift_left(const Packet8us& a)
+{
+ const _EIGEN_DECLARE_CONST_FAST_Packet8us(mask, N);
+ return vec_sl(a, p8us_mask);
+}
+template<int N> EIGEN_STRONG_INLINE Packet8us plogical_shift_right(const Packet8us& a)
+{
+ const _EIGEN_DECLARE_CONST_FAST_Packet8us(mask, N);
+ return vec_sr(a, p8us_mask);
+}
+
+EIGEN_STRONG_INLINE Packet4f Bf16ToF32Even(const Packet8bf& bf){
+ return plogical_shift_left<16>(reinterpret_cast<Packet4f>(bf.m_val));
+}
+
+EIGEN_STRONG_INLINE Packet4f Bf16ToF32Odd(const Packet8bf& bf){
+ const _EIGEN_DECLARE_CONST_FAST_Packet4ui(high_mask, 0xFFFF0000);
+ return pand<Packet4f>(
+ reinterpret_cast<Packet4f>(bf.m_val),
+ reinterpret_cast<Packet4f>(p4ui_high_mask)
+ );
+}
+
+// Simple interleaving of bool masks; prevents true (all-ones) lanes from
+// being rounded and canonicalized as NaNs, which F32ToBf16 would do.
+EIGEN_STRONG_INLINE Packet8bf F32ToBf16Bool(Packet4f even, Packet4f odd) {
+ const _EIGEN_DECLARE_CONST_FAST_Packet4ui(high_mask, 0xFFFF0000);
+ Packet4f bf_odd, bf_even;
+ bf_odd = pand(reinterpret_cast<Packet4f>(p4ui_high_mask), odd);
+ bf_even = plogical_shift_right<16>(even);
+ return reinterpret_cast<Packet8us>(por<Packet4f>(bf_even, bf_odd));
+}
+
+EIGEN_STRONG_INLINE Packet8bf F32ToBf16(Packet4f p4f){
+ Packet4ui input = reinterpret_cast<Packet4ui>(p4f);
+ Packet4ui lsb = plogical_shift_right<16>(input);
+ lsb = pand<Packet4ui>(lsb, reinterpret_cast<Packet4ui>(p4i_ONE));
+
+ _EIGEN_DECLARE_CONST_FAST_Packet4ui(BIAS,0x7FFFu);
+ Packet4ui rounding_bias = padd<Packet4ui>(lsb, p4ui_BIAS);
+ input = padd<Packet4ui>(input, rounding_bias);
+
+ //Test NaN and Subnormal - Begin
+ const _EIGEN_DECLARE_CONST_FAST_Packet4ui(exp_mask, 0x7F800000);
+ Packet4ui exp = pand<Packet4ui>(p4ui_exp_mask, reinterpret_cast<Packet4ui>(p4f));
+
+ const _EIGEN_DECLARE_CONST_FAST_Packet4ui(mantissa_mask, 0x7FFFFF);
+ Packet4ui mantissa = pand<Packet4ui>(p4ui_mantissa_mask, reinterpret_cast<Packet4ui>(p4f));
+
+ const _EIGEN_DECLARE_CONST_FAST_Packet4ui(max_exp, 0x7F800000);
+ Packet4bi is_max_exp = vec_cmpeq(exp, p4ui_max_exp);
+ Packet4bi is_zero_exp = vec_cmpeq(exp, reinterpret_cast<Packet4ui>(p4i_ZERO));
+
+ Packet4bi is_mant_zero = vec_cmpeq(mantissa, reinterpret_cast<Packet4ui>(p4i_ZERO));
+ Packet4ui nan_selector = pandnot<Packet4ui>(
+ reinterpret_cast<Packet4ui>(is_max_exp),
+ reinterpret_cast<Packet4ui>(is_mant_zero)
+ );
+
+ Packet4ui subnormal_selector = pandnot<Packet4ui>(
+ reinterpret_cast<Packet4ui>(is_zero_exp),
+ reinterpret_cast<Packet4ui>(is_mant_zero)
+ );
+
+ const _EIGEN_DECLARE_CONST_FAST_Packet4ui(nan, 0x7FC00000);
+ input = vec_sel(input, p4ui_nan, nan_selector);
+ input = vec_sel(input, reinterpret_cast<Packet4ui>(p4f), subnormal_selector);
+ //Test NaN and Subnormal - End
+
+ input = plogical_shift_right<16>(input);
+ return reinterpret_cast<Packet8us>(input);
+}
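The conversion above implements round-to-nearest-even: adding 0x7FFF plus the current bfloat16 LSB pushes ties toward an even mantissa, NaN inputs are replaced with a canonical quiet NaN, and subnormal inputs are passed through untouched. A standalone scalar sketch of the same logic (illustrative only; the helper name is not from the patch):

#include <cstdint>
#include <cstring>

// Scalar model of the vector F32ToBf16 above.
static std::uint16_t f32_to_bf16_scalar(float f) {
  std::uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));

  const std::uint32_t exp      = bits & 0x7F800000u;
  const std::uint32_t mantissa = bits & 0x007FFFFFu;

  if (exp == 0x7F800000u && mantissa != 0u)    // NaN: return a quiet-NaN pattern
    return static_cast<std::uint16_t>(0x7FC00000u >> 16);
  if (exp == 0u && mantissa != 0u)             // subnormal: truncate, no rounding
    return static_cast<std::uint16_t>(bits >> 16);

  const std::uint32_t lsb = (bits >> 16) & 1u; // current bfloat16 LSB
  bits += 0x7FFFu + lsb;                       // round to nearest, ties to even
  return static_cast<std::uint16_t>(bits >> 16);
}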
+
+EIGEN_STRONG_INLINE Packet8bf F32ToBf16(Packet4f even, Packet4f odd){
+ Packet4f bf_odd, bf_even;
+ bf_odd = reinterpret_cast<Packet4f>(F32ToBf16(odd).m_val);
+ bf_odd = plogical_shift_left<16>(bf_odd);
+ bf_even = reinterpret_cast<Packet4f>(F32ToBf16(even).m_val);
+ return reinterpret_cast<Packet8us>(por<Packet4f>(bf_even, bf_odd));
+}
+#define BF16_TO_F32_UNARY_OP_WRAPPER(OP, A) \
+ Packet4f a_even = Bf16ToF32Even(A);\
+ Packet4f a_odd = Bf16ToF32Odd(A);\
+ Packet4f op_even = OP(a_even);\
+ Packet4f op_odd = OP(a_odd);\
+ return F32ToBf16(op_even, op_odd);\
+
+#define BF16_TO_F32_BINARY_OP_WRAPPER(OP, A, B) \
+ Packet4f a_even = Bf16ToF32Even(A);\
+ Packet4f a_odd = Bf16ToF32Odd(A);\
+ Packet4f b_even = Bf16ToF32Even(B);\
+ Packet4f b_odd = Bf16ToF32Odd(B);\
+ Packet4f op_even = OP(a_even, b_even);\
+ Packet4f op_odd = OP(a_odd, b_odd);\
+ return F32ToBf16(op_even, op_odd);\
+
+#define BF16_TO_F32_BINARY_OP_WRAPPER_BOOL(OP, A, B) \
+ Packet4f a_even = Bf16ToF32Even(A);\
+ Packet4f a_odd = Bf16ToF32Odd(A);\
+ Packet4f b_even = Bf16ToF32Even(B);\
+ Packet4f b_odd = Bf16ToF32Odd(B);\
+ Packet4f op_even = OP(a_even, b_even);\
+ Packet4f op_odd = OP(a_odd, b_odd);\
+ return F32ToBf16Bool(op_even, op_odd);\
+
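Every BF16_TO_F32_* wrapper above follows the same pattern: widen the even and odd bfloat16 lanes to two float vectors, apply the float kernel, then narrow back with F32ToBf16 (or F32ToBf16Bool for comparison masks). A scalar sketch of that lane-wise strategy for addition, reusing the rounding helper sketched above (illustrative only):

#include <cstdint>
#include <cstring>

// Widen one bfloat16 to float by placing its 16 bits in the high half of the
// float representation; this is what Bf16ToF32Even/Bf16ToF32Odd do per lane.
static float bf16_to_f32_scalar(std::uint16_t b) {
  std::uint32_t bits = static_cast<std::uint32_t>(b) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

// bfloat16 "packet" addition expressed lane by lane: widen, add in float,
// narrow back with the rounding shown in F32ToBf16.
static void bf16_add8(const std::uint16_t a[8], const std::uint16_t b[8],
                      std::uint16_t out[8]) {
  for (int i = 0; i < 8; ++i) {
    float sum = bf16_to_f32_scalar(a[i]) + bf16_to_f32_scalar(b[i]);
    out[i] = f32_to_bf16_scalar(sum);  // helper from the previous sketch
  }
}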
+template<> EIGEN_STRONG_INLINE Packet8bf padd<Packet8bf>(const Packet8bf& a, const Packet8bf& b) {
+ BF16_TO_F32_BINARY_OP_WRAPPER(padd<Packet4f>, a, b);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf pmul<Packet8bf>(const Packet8bf& a, const Packet8bf& b) {
+ BF16_TO_F32_BINARY_OP_WRAPPER(pmul<Packet4f>, a, b);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf pdiv<Packet8bf>(const Packet8bf& a, const Packet8bf& b) {
+ BF16_TO_F32_BINARY_OP_WRAPPER(pdiv<Packet4f>, a, b);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf pnegate<Packet8bf>(const Packet8bf& a) {
+ BF16_TO_F32_UNARY_OP_WRAPPER(pnegate<Packet4f>, a);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf psub<Packet8bf>(const Packet8bf& a, const Packet8bf& b) {
+ BF16_TO_F32_BINARY_OP_WRAPPER(psub<Packet4f>, a, b);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf psqrt<Packet8bf> (const Packet8bf& a){
+ BF16_TO_F32_UNARY_OP_WRAPPER(vec_sqrt, a);
+}
+template<> EIGEN_STRONG_INLINE Packet8bf prsqrt<Packet8bf> (const Packet8bf& a){
+ BF16_TO_F32_UNARY_OP_WRAPPER(prsqrt<Packet4f>, a);
+}
+template<> EIGEN_STRONG_INLINE Packet8bf pexp<Packet8bf> (const Packet8bf& a){
+ BF16_TO_F32_UNARY_OP_WRAPPER(pexp_float, a);
+}
+
+template<> EIGEN_STRONG_INLINE Packet4f pldexp<Packet4f>(const Packet4f& a, const Packet4f& exponent) {
+ return pldexp_generic(a,exponent);
+}
+template<> EIGEN_STRONG_INLINE Packet8bf pldexp<Packet8bf> (const Packet8bf& a, const Packet8bf& exponent){
+ BF16_TO_F32_BINARY_OP_WRAPPER(pldexp<Packet4f>, a, exponent);
+}
+
+template<> EIGEN_STRONG_INLINE Packet4f pfrexp<Packet4f>(const Packet4f& a, Packet4f& exponent) {
+ return pfrexp_generic(a,exponent);
+}
+template<> EIGEN_STRONG_INLINE Packet8bf pfrexp<Packet8bf> (const Packet8bf& a, Packet8bf& e){
+ Packet4f a_even = Bf16ToF32Even(a);
+ Packet4f a_odd = Bf16ToF32Odd(a);
+ Packet4f e_even;
+ Packet4f e_odd;
+ Packet4f op_even = pfrexp<Packet4f>(a_even, e_even);
+ Packet4f op_odd = pfrexp<Packet4f>(a_odd, e_odd);
+ e = F32ToBf16(e_even, e_odd);
+ return F32ToBf16(op_even, op_odd);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf psin<Packet8bf> (const Packet8bf& a){
+ BF16_TO_F32_UNARY_OP_WRAPPER(psin_float, a);
+}
+template<> EIGEN_STRONG_INLINE Packet8bf pcos<Packet8bf> (const Packet8bf& a){
+ BF16_TO_F32_UNARY_OP_WRAPPER(pcos_float, a);
+}
+template<> EIGEN_STRONG_INLINE Packet8bf plog<Packet8bf> (const Packet8bf& a){
+ BF16_TO_F32_UNARY_OP_WRAPPER(plog_float, a);
+}
+template<> EIGEN_STRONG_INLINE Packet8bf pfloor<Packet8bf> (const Packet8bf& a){
+ BF16_TO_F32_UNARY_OP_WRAPPER(pfloor<Packet4f>, a);
+}
+template<> EIGEN_STRONG_INLINE Packet8bf pceil<Packet8bf> (const Packet8bf& a){
+ BF16_TO_F32_UNARY_OP_WRAPPER(pceil<Packet4f>, a);
+}
+template<> EIGEN_STRONG_INLINE Packet8bf pround<Packet8bf> (const Packet8bf& a){
+ BF16_TO_F32_UNARY_OP_WRAPPER(pround<Packet4f>, a);
+}
+template<> EIGEN_STRONG_INLINE Packet8bf print<Packet8bf> (const Packet8bf& a){
+ BF16_TO_F32_UNARY_OP_WRAPPER(print<Packet4f>, a);
+}
+template<> EIGEN_STRONG_INLINE Packet8bf pmadd(const Packet8bf& a, const Packet8bf& b, const Packet8bf& c) {
+ Packet4f a_even = Bf16ToF32Even(a);
+ Packet4f a_odd = Bf16ToF32Odd(a);
+ Packet4f b_even = Bf16ToF32Even(b);
+ Packet4f b_odd = Bf16ToF32Odd(b);
+ Packet4f c_even = Bf16ToF32Even(c);
+ Packet4f c_odd = Bf16ToF32Odd(c);
+ Packet4f pmadd_even = pmadd<Packet4f>(a_even, b_even, c_even);
+ Packet4f pmadd_odd = pmadd<Packet4f>(a_odd, b_odd, c_odd);
+ return F32ToBf16(pmadd_even, pmadd_odd);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf pmin<Packet8bf>(const Packet8bf& a, const Packet8bf& b) {
+ BF16_TO_F32_BINARY_OP_WRAPPER(pmin<Packet4f>, a, b);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf pmax<Packet8bf>(const Packet8bf& a, const Packet8bf& b) {
+ BF16_TO_F32_BINARY_OP_WRAPPER(pmax<Packet4f>, a, b);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf pcmp_lt(const Packet8bf& a, const Packet8bf& b) {
+ BF16_TO_F32_BINARY_OP_WRAPPER_BOOL(pcmp_lt<Packet4f>, a, b);
+}
+template<> EIGEN_STRONG_INLINE Packet8bf pcmp_lt_or_nan(const Packet8bf& a, const Packet8bf& b) {
+ BF16_TO_F32_BINARY_OP_WRAPPER_BOOL(pcmp_lt_or_nan<Packet4f>, a, b);
+}
+template<> EIGEN_STRONG_INLINE Packet8bf pcmp_le(const Packet8bf& a, const Packet8bf& b) {
+ BF16_TO_F32_BINARY_OP_WRAPPER_BOOL(pcmp_le<Packet4f>, a, b);
+}
+template<> EIGEN_STRONG_INLINE Packet8bf pcmp_eq(const Packet8bf& a, const Packet8bf& b) {
+ BF16_TO_F32_BINARY_OP_WRAPPER_BOOL(pcmp_eq<Packet4f>, a, b);
+}
+
+template<> EIGEN_STRONG_INLINE bfloat16 pfirst(const Packet8bf& a) {
+ return Eigen::bfloat16_impl::raw_uint16_to_bfloat16((pfirst<Packet8us>(a)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf ploaddup<Packet8bf>(const bfloat16* from)
+{
+ return ploaddup<Packet8us>(reinterpret_cast<const unsigned short int*>(from));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf plset<Packet8bf>(const bfloat16& a) {
+ bfloat16 countdown[8] = { bfloat16(0), bfloat16(1), bfloat16(2), bfloat16(3),
+ bfloat16(4), bfloat16(5), bfloat16(6), bfloat16(7) };
+ return padd<Packet8bf>(pset1<Packet8bf>(a), pload<Packet8bf>(countdown));
+}
template<> EIGEN_STRONG_INLINE float predux<Packet4f>(const Packet4f& a)
{
@@ -540,34 +1451,6 @@ template<> EIGEN_STRONG_INLINE float predux<Packet4f>(const Packet4f& a)
return pfirst(sum);
}
-template<> EIGEN_STRONG_INLINE Packet4f preduxp<Packet4f>(const Packet4f* vecs)
-{
- Packet4f v[4], sum[4];
-
- // It's easier and faster to transpose then add as columns
- // Check: http://www.freevec.org/function/matrix_4x4_transpose_floats for explanation
- // Do the transpose, first set of moves
- v[0] = vec_mergeh(vecs[0], vecs[2]);
- v[1] = vec_mergel(vecs[0], vecs[2]);
- v[2] = vec_mergeh(vecs[1], vecs[3]);
- v[3] = vec_mergel(vecs[1], vecs[3]);
- // Get the resulting vectors
- sum[0] = vec_mergeh(v[0], v[2]);
- sum[1] = vec_mergel(v[0], v[2]);
- sum[2] = vec_mergeh(v[1], v[3]);
- sum[3] = vec_mergel(v[1], v[3]);
-
- // Now do the summation:
- // Lines 0+1
- sum[0] = sum[0] + sum[1];
- // Lines 2+3
- sum[1] = sum[2] + sum[3];
- // Add the results
- sum[0] = sum[0] + sum[1];
-
- return sum[0];
-}
-
template<> EIGEN_STRONG_INLINE int predux<Packet4i>(const Packet4i& a)
{
Packet4i sum;
@@ -580,32 +1463,69 @@ template<> EIGEN_STRONG_INLINE int predux<Packet4i>(const Packet4i& a)
return pfirst(sum);
}
-template<> EIGEN_STRONG_INLINE Packet4i preduxp<Packet4i>(const Packet4i* vecs)
+template<> EIGEN_STRONG_INLINE bfloat16 predux<Packet8bf>(const Packet8bf& a)
{
- Packet4i v[4], sum[4];
+ float redux_even = predux<Packet4f>(Bf16ToF32Even(a));
+ float redux_odd = predux<Packet4f>(Bf16ToF32Odd(a));
+ float f32_result = redux_even + redux_odd;
+ return bfloat16(f32_result);
+}
+template<typename Packet> EIGEN_STRONG_INLINE __UNPACK_TYPE__(Packet) predux_size8(const Packet& a)
+{
+ union{
+ Packet v;
+ __UNPACK_TYPE__(Packet) n[8];
+ } vt;
+ vt.v = a;
+
+ EIGEN_ALIGN16 int first_loader[4] = { vt.n[0], vt.n[1], vt.n[2], vt.n[3] };
+ EIGEN_ALIGN16 int second_loader[4] = { vt.n[4], vt.n[5], vt.n[6], vt.n[7] };
+ Packet4i first_half = pload<Packet4i>(first_loader);
+ Packet4i second_half = pload<Packet4i>(second_loader);
+
+ return static_cast<__UNPACK_TYPE__(Packet)>(predux(first_half) + predux(second_half));
+}
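predux_size8 above works around the lack of a 16-bit horizontal add: it spills the eight lanes, widens them into two aligned int arrays, reloads those as Packet4i and reuses the existing 32-bit predux, so the partial sums cannot wrap in 16 bits. A scalar view of the same widening (illustrative only):

// Scalar equivalent of predux_size8: widen eight 16-bit lanes to int, sum,
// then truncate back to the element type, matching the static_cast above.
template <typename T>  // T is short or unsigned short here
static T predux8_scalar(const T v[8]) {
  int sum = 0;
  for (int i = 0; i < 8; ++i) sum += static_cast<int>(v[i]);
  return static_cast<T>(sum);
}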
+
+template<> EIGEN_STRONG_INLINE short int predux<Packet8s>(const Packet8s& a)
+{
+ return predux_size8<Packet8s>(a);
+}
+
+template<> EIGEN_STRONG_INLINE unsigned short int predux<Packet8us>(const Packet8us& a)
+{
+ return predux_size8<Packet8us>(a);
+}
+
+template<typename Packet> EIGEN_STRONG_INLINE __UNPACK_TYPE__(Packet) predux_size16(const Packet& a)
+{
+ union{
+ Packet v;
+ __UNPACK_TYPE__(Packet) n[16];
+ } vt;
+ vt.v = a;
+
+ EIGEN_ALIGN16 int first_loader[4] = { vt.n[0], vt.n[1], vt.n[2], vt.n[3] };
+ EIGEN_ALIGN16 int second_loader[4] = { vt.n[4], vt.n[5], vt.n[6], vt.n[7] };
+ EIGEN_ALIGN16 int third_loader[4] = { vt.n[8], vt.n[9], vt.n[10], vt.n[11] };
+ EIGEN_ALIGN16 int fourth_loader[4] = { vt.n[12], vt.n[13], vt.n[14], vt.n[15] };
- // It's easier and faster to transpose then add as columns
- // Check: http://www.freevec.org/function/matrix_4x4_transpose_floats for explanation
- // Do the transpose, first set of moves
- v[0] = vec_mergeh(vecs[0], vecs[2]);
- v[1] = vec_mergel(vecs[0], vecs[2]);
- v[2] = vec_mergeh(vecs[1], vecs[3]);
- v[3] = vec_mergel(vecs[1], vecs[3]);
- // Get the resulting vectors
- sum[0] = vec_mergeh(v[0], v[2]);
- sum[1] = vec_mergel(v[0], v[2]);
- sum[2] = vec_mergeh(v[1], v[3]);
- sum[3] = vec_mergel(v[1], v[3]);
+ Packet4i first_quarter = pload<Packet4i>(first_loader);
+ Packet4i second_quarter = pload<Packet4i>(second_loader);
+ Packet4i third_quarter = pload<Packet4i>(third_loader);
+ Packet4i fourth_quarter = pload<Packet4i>(fourth_loader);
- // Now do the summation:
- // Lines 0+1
- sum[0] = sum[0] + sum[1];
- // Lines 2+3
- sum[1] = sum[2] + sum[3];
- // Add the results
- sum[0] = sum[0] + sum[1];
+ return static_cast<__UNPACK_TYPE__(Packet)>(predux(first_quarter) + predux(second_quarter)
+ + predux(third_quarter) + predux(fourth_quarter));
+}
+
+template<> EIGEN_STRONG_INLINE signed char predux<Packet16c>(const Packet16c& a)
+{
+ return predux_size16<Packet16c>(a);
+}
- return sum[0];
+template<> EIGEN_STRONG_INLINE unsigned char predux<Packet16uc>(const Packet16uc& a)
+{
+ return predux_size16<Packet16uc>(a);
}
// Other reduction functions:
@@ -624,97 +1544,255 @@ template<> EIGEN_STRONG_INLINE int predux_mul<Packet4i>(const Packet4i& a)
return aux[0] * aux[1] * aux[2] * aux[3];
}
+template<> EIGEN_STRONG_INLINE short int predux_mul<Packet8s>(const Packet8s& a)
+{
+ Packet8s pair, quad, octo;
+
+ pair = vec_mul(a, vec_sld(a, a, 8));
+ quad = vec_mul(pair, vec_sld(pair, pair, 4));
+ octo = vec_mul(quad, vec_sld(quad, quad, 2));
+
+ return pfirst(octo);
+}
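The predux_mul specializations use a logarithmic reduction: each step multiplies the vector with a copy of itself rotated by half the remaining width (vec_sld with byte offsets 8, 4, 2, plus 1 for the 16-lane char variants), so lane 0 ends up holding the product of all lanes and pfirst extracts it. A scalar sketch of the ladder for eight 16-bit lanes (illustrative only):

#include <cstring>

// Scalar model of the vec_sld reduction ladder: rotate by 4, 2, then 1 lane(s)
// and multiply element-wise; after the last step lane 0 holds the full product.
static short predux_mul8_scalar(const short v[8]) {
  short cur[8], rot[8];
  std::memcpy(cur, v, sizeof(cur));
  for (int step = 4; step >= 1; step /= 2) {
    for (int i = 0; i < 8; ++i) rot[i] = cur[(i + step) % 8];                 // vec_sld
    for (int i = 0; i < 8; ++i) cur[i] = static_cast<short>(cur[i] * rot[i]); // vec_mul
  }
  return cur[0];  // pfirst
}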
+
+template<> EIGEN_STRONG_INLINE unsigned short int predux_mul<Packet8us>(const Packet8us& a)
+{
+ Packet8us pair, quad, octo;
+
+ pair = vec_mul(a, vec_sld(a, a, 8));
+ quad = vec_mul(pair, vec_sld(pair, pair, 4));
+ octo = vec_mul(quad, vec_sld(quad, quad, 2));
+
+ return pfirst(octo);
+}
+
+template<> EIGEN_STRONG_INLINE bfloat16 predux_mul<Packet8bf>(const Packet8bf& a)
+{
+ float redux_even = predux_mul<Packet4f>(Bf16ToF32Even(a));
+ float redux_odd = predux_mul<Packet4f>(Bf16ToF32Odd(a));
+ float f32_result = redux_even * redux_odd;
+ return bfloat16(f32_result);
+}
+
+
+template<> EIGEN_STRONG_INLINE signed char predux_mul<Packet16c>(const Packet16c& a)
+{
+ Packet16c pair, quad, octo, result;
+
+ pair = vec_mul(a, vec_sld(a, a, 8));
+ quad = vec_mul(pair, vec_sld(pair, pair, 4));
+ octo = vec_mul(quad, vec_sld(quad, quad, 2));
+ result = vec_mul(octo, vec_sld(octo, octo, 1));
+
+ return pfirst(result);
+}
+
+template<> EIGEN_STRONG_INLINE unsigned char predux_mul<Packet16uc>(const Packet16uc& a)
+{
+ Packet16uc pair, quad, octo, result;
+
+ pair = vec_mul(a, vec_sld(a, a, 8));
+ quad = vec_mul(pair, vec_sld(pair, pair, 4));
+ octo = vec_mul(quad, vec_sld(quad, quad, 2));
+ result = vec_mul(octo, vec_sld(octo, octo, 1));
+
+ return pfirst(result);
+}
+
// min
-template<> EIGEN_STRONG_INLINE float predux_min<Packet4f>(const Packet4f& a)
+template<typename Packet> EIGEN_STRONG_INLINE
+__UNPACK_TYPE__(Packet) predux_min4(const Packet& a)
{
- Packet4f b, res;
+ Packet b, res;
b = vec_min(a, vec_sld(a, a, 8));
res = vec_min(b, vec_sld(b, b, 4));
return pfirst(res);
}
+
+template<> EIGEN_STRONG_INLINE float predux_min<Packet4f>(const Packet4f& a)
+{
+ return predux_min4<Packet4f>(a);
+}
+
template<> EIGEN_STRONG_INLINE int predux_min<Packet4i>(const Packet4i& a)
{
- Packet4i b, res;
- b = vec_min(a, vec_sld(a, a, 8));
- res = vec_min(b, vec_sld(b, b, 4));
- return pfirst(res);
+ return predux_min4<Packet4i>(a);
+}
+
+template<> EIGEN_STRONG_INLINE bfloat16 predux_min<Packet8bf>(const Packet8bf& a)
+{
+ float redux_even = predux_min<Packet4f>(Bf16ToF32Even(a));
+ float redux_odd = predux_min<Packet4f>(Bf16ToF32Odd(a));
+ float f32_result = (std::min)(redux_even, redux_odd);
+ return bfloat16(f32_result);
}
+template<> EIGEN_STRONG_INLINE short int predux_min<Packet8s>(const Packet8s& a)
+{
+ Packet8s pair, quad, octo;
+
+ //pair = { Min(a0,a4), Min(a1,a5), Min(a2,a6), Min(a3,a7) }
+ pair = vec_min(a, vec_sld(a, a, 8));
+
+ //quad = { Min(a0, a4, a2, a6), Min(a1, a5, a3, a7) }
+ quad = vec_min(pair, vec_sld(pair, pair, 4));
+
+ //octo = { Min(a0, a4, a2, a6, a1, a5, a3, a7) }
+ octo = vec_min(quad, vec_sld(quad, quad, 2));
+ return pfirst(octo);
+}
+
+template<> EIGEN_STRONG_INLINE unsigned short int predux_min<Packet8us>(const Packet8us& a)
+{
+ Packet8us pair, quad, octo;
+
+ //pair = { Min(a0,a4), Min(a1,a5), Min(a2,a6), Min(a3,a7) }
+ pair = vec_min(a, vec_sld(a, a, 8));
+
+ //quad = { Min(a0, a4, a2, a6), Min(a1, a5, a3, a7) }
+ quad = vec_min(pair, vec_sld(pair, pair, 4));
+
+ //octo = { Min(a0, a4, a2, a6, a1, a5, a3, a7) }
+ octo = vec_min(quad, vec_sld(quad, quad, 2));
+ return pfirst(octo);
+}
+
+template<> EIGEN_STRONG_INLINE signed char predux_min<Packet16c>(const Packet16c& a)
+{
+ Packet16c pair, quad, octo, result;
+
+ pair = vec_min(a, vec_sld(a, a, 8));
+ quad = vec_min(pair, vec_sld(pair, pair, 4));
+ octo = vec_min(quad, vec_sld(quad, quad, 2));
+ result = vec_min(octo, vec_sld(octo, octo, 1));
+
+ return pfirst(result);
+}
+
+template<> EIGEN_STRONG_INLINE unsigned char predux_min<Packet16uc>(const Packet16uc& a)
+{
+ Packet16uc pair, quad, octo, result;
+
+ pair = vec_min(a, vec_sld(a, a, 8));
+ quad = vec_min(pair, vec_sld(pair, pair, 4));
+ octo = vec_min(quad, vec_sld(quad, quad, 2));
+ result = vec_min(octo, vec_sld(octo, octo, 1));
+
+ return pfirst(result);
+}
// max
-template<> EIGEN_STRONG_INLINE float predux_max<Packet4f>(const Packet4f& a)
+template<typename Packet> EIGEN_STRONG_INLINE __UNPACK_TYPE__(Packet) predux_max4(const Packet& a)
{
- Packet4f b, res;
+ Packet b, res;
b = vec_max(a, vec_sld(a, a, 8));
res = vec_max(b, vec_sld(b, b, 4));
return pfirst(res);
}
+template<> EIGEN_STRONG_INLINE float predux_max<Packet4f>(const Packet4f& a)
+{
+ return predux_max4<Packet4f>(a);
+}
+
template<> EIGEN_STRONG_INLINE int predux_max<Packet4i>(const Packet4i& a)
{
- Packet4i b, res;
- b = vec_max(a, vec_sld(a, a, 8));
- res = vec_max(b, vec_sld(b, b, 4));
- return pfirst(res);
+ return predux_max4<Packet4i>(a);
}
-template<int Offset>
-struct palign_impl<Offset,Packet4f>
+template<> EIGEN_STRONG_INLINE bfloat16 predux_max<Packet8bf>(const Packet8bf& a)
{
- static EIGEN_STRONG_INLINE void run(Packet4f& first, const Packet4f& second)
- {
-#ifdef _BIG_ENDIAN
- switch (Offset % 4) {
- case 1:
- first = vec_sld(first, second, 4); break;
- case 2:
- first = vec_sld(first, second, 8); break;
- case 3:
- first = vec_sld(first, second, 12); break;
- }
-#else
- switch (Offset % 4) {
- case 1:
- first = vec_sld(second, first, 12); break;
- case 2:
- first = vec_sld(second, first, 8); break;
- case 3:
- first = vec_sld(second, first, 4); break;
- }
-#endif
- }
-};
+ float redux_even = predux_max<Packet4f>(Bf16ToF32Even(a));
+ float redux_odd = predux_max<Packet4f>(Bf16ToF32Odd(a));
+ float f32_result = (std::max)(redux_even, redux_odd);
+ return bfloat16(f32_result);
+}
-template<int Offset>
-struct palign_impl<Offset,Packet4i>
+template<> EIGEN_STRONG_INLINE short int predux_max<Packet8s>(const Packet8s& a)
{
- static EIGEN_STRONG_INLINE void run(Packet4i& first, const Packet4i& second)
- {
-#ifdef _BIG_ENDIAN
- switch (Offset % 4) {
- case 1:
- first = vec_sld(first, second, 4); break;
- case 2:
- first = vec_sld(first, second, 8); break;
- case 3:
- first = vec_sld(first, second, 12); break;
- }
-#else
- switch (Offset % 4) {
- case 1:
- first = vec_sld(second, first, 12); break;
- case 2:
- first = vec_sld(second, first, 8); break;
- case 3:
- first = vec_sld(second, first, 4); break;
- }
-#endif
- }
-};
+ Packet8s pair, quad, octo;
+
+ //pair = { Max(a0,a4), Max(a1,a5), Max(a2,a6), Max(a3,a7) }
+ pair = vec_max(a, vec_sld(a, a, 8));
+
+ //quad = { Max(a0, a4, a2, a6), Max(a1, a5, a3, a7) }
+ quad = vec_max(pair, vec_sld(pair, pair, 4));
+
+ //octo = { Max(a0, a4, a2, a6, a1, a5, a3, a7) }
+ octo = vec_max(quad, vec_sld(quad, quad, 2));
+ return pfirst(octo);
+}
+
+template<> EIGEN_STRONG_INLINE unsigned short int predux_max<Packet8us>(const Packet8us& a)
+{
+ Packet8us pair, quad, octo;
+
+ //pair = { Max(a0,a4), Max(a1,a5), Max(a2,a6), Max(a3,a7) }
+ pair = vec_max(a, vec_sld(a, a, 8));
+
+ //quad = { Max(a0, a4, a2, a6), Max(a1, a5, a3, a7) }
+ quad = vec_max(pair, vec_sld(pair, pair, 4));
+
+ //octo = { Max(a0, a4, a2, a6, a1, a5, a3, a7) }
+ octo = vec_max(quad, vec_sld(quad, quad, 2));
+ return pfirst(octo);
+}
+
+template<> EIGEN_STRONG_INLINE signed char predux_max<Packet16c>(const Packet16c& a)
+{
+ Packet16c pair, quad, octo, result;
+
+ pair = vec_max(a, vec_sld(a, a, 8));
+ quad = vec_max(pair, vec_sld(pair, pair, 4));
+ octo = vec_max(quad, vec_sld(quad, quad, 2));
+ result = vec_max(octo, vec_sld(octo, octo, 1));
+
+ return pfirst(result);
+}
+
+template<> EIGEN_STRONG_INLINE unsigned char predux_max<Packet16uc>(const Packet16uc& a)
+{
+ Packet16uc pair, quad, octo, result;
+
+ pair = vec_max(a, vec_sld(a, a, 8));
+ quad = vec_max(pair, vec_sld(pair, pair, 4));
+ octo = vec_max(quad, vec_sld(quad, quad, 2));
+ result = vec_max(octo, vec_sld(octo, octo, 1));
+
+ return pfirst(result);
+}
+
+template<> EIGEN_STRONG_INLINE bool predux_any(const Packet4f& x)
+{
+ return vec_any_ne(x, pzero(x));
+}
+
+template <typename T> EIGEN_DEVICE_FUNC inline void
+ptranspose_common(PacketBlock<T,4>& kernel){
+ T t0, t1, t2, t3;
+ t0 = vec_mergeh(kernel.packet[0], kernel.packet[2]);
+ t1 = vec_mergel(kernel.packet[0], kernel.packet[2]);
+ t2 = vec_mergeh(kernel.packet[1], kernel.packet[3]);
+ t3 = vec_mergel(kernel.packet[1], kernel.packet[3]);
+ kernel.packet[0] = vec_mergeh(t0, t2);
+ kernel.packet[1] = vec_mergel(t0, t2);
+ kernel.packet[2] = vec_mergeh(t1, t3);
+ kernel.packet[3] = vec_mergel(t1, t3);
+}
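ptranspose_common is the classic two-round 4x4 transpose: the first round of vec_mergeh/vec_mergel interleaves rows (0,2) and (1,3), and a second round of merges on those intermediates leaves input column j in packet j. A scalar sketch of the merge semantics and the two rounds (illustrative only):

#include <array>

using Vec4 = std::array<float, 4>;

// vec_mergeh interleaves the high halves of two vectors, vec_mergel the low halves.
static Vec4 mergeh(const Vec4& x, const Vec4& y) { return {x[0], y[0], x[1], y[1]}; }
static Vec4 mergel(const Vec4& x, const Vec4& y) { return {x[2], y[2], x[3], y[3]}; }

// Two rounds of merges transpose a 4x4 block held as four row vectors.
static void transpose4x4_scalar(Vec4 r[4]) {
  Vec4 t0 = mergeh(r[0], r[2]), t1 = mergel(r[0], r[2]);
  Vec4 t2 = mergeh(r[1], r[3]), t3 = mergel(r[1], r[3]);
  r[0] = mergeh(t0, t2);  // {r0[0], r1[0], r2[0], r3[0]}, i.e. the first column
  r[1] = mergel(t0, t2);
  r[2] = mergeh(t1, t3);
  r[3] = mergel(t1, t3);
}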
EIGEN_DEVICE_FUNC inline void
ptranspose(PacketBlock<Packet4f,4>& kernel) {
- Packet4f t0, t1, t2, t3;
+ ptranspose_common<Packet4f>(kernel);
+}
+
+EIGEN_DEVICE_FUNC inline void
+ptranspose(PacketBlock<Packet4i,4>& kernel) {
+ ptranspose_common<Packet4i>(kernel);
+}
+
+EIGEN_DEVICE_FUNC inline void
+ptranspose(PacketBlock<Packet8s,4>& kernel) {
+ Packet8s t0, t1, t2, t3;
t0 = vec_mergeh(kernel.packet[0], kernel.packet[2]);
t1 = vec_mergel(kernel.packet[0], kernel.packet[2]);
t2 = vec_mergeh(kernel.packet[1], kernel.packet[3]);
@@ -726,8 +1804,8 @@ ptranspose(PacketBlock<Packet4f,4>& kernel) {
}
EIGEN_DEVICE_FUNC inline void
-ptranspose(PacketBlock<Packet4i,4>& kernel) {
- Packet4i t0, t1, t2, t3;
+ptranspose(PacketBlock<Packet8us,4>& kernel) {
+ Packet8us t0, t1, t2, t3;
t0 = vec_mergeh(kernel.packet[0], kernel.packet[2]);
t1 = vec_mergel(kernel.packet[0], kernel.packet[2]);
t2 = vec_mergeh(kernel.packet[1], kernel.packet[3]);
@@ -738,18 +1816,440 @@ ptranspose(PacketBlock<Packet4i,4>& kernel) {
kernel.packet[3] = vec_mergel(t1, t3);
}
-template<> EIGEN_STRONG_INLINE Packet4i pblend(const Selector<4>& ifPacket, const Packet4i& thenPacket, const Packet4i& elsePacket) {
+
+EIGEN_DEVICE_FUNC inline void
+ptranspose(PacketBlock<Packet8bf,4>& kernel) {
+ Packet8us t0, t1, t2, t3;
+
+ t0 = vec_mergeh(kernel.packet[0].m_val, kernel.packet[2].m_val);
+ t1 = vec_mergel(kernel.packet[0].m_val, kernel.packet[2].m_val);
+ t2 = vec_mergeh(kernel.packet[1].m_val, kernel.packet[3].m_val);
+ t3 = vec_mergel(kernel.packet[1].m_val, kernel.packet[3].m_val);
+ kernel.packet[0] = vec_mergeh(t0, t2);
+ kernel.packet[1] = vec_mergel(t0, t2);
+ kernel.packet[2] = vec_mergeh(t1, t3);
+ kernel.packet[3] = vec_mergel(t1, t3);
+}
+
+EIGEN_DEVICE_FUNC inline void
+ptranspose(PacketBlock<Packet16c,4>& kernel) {
+ Packet16c t0, t1, t2, t3;
+ t0 = vec_mergeh(kernel.packet[0], kernel.packet[2]);
+ t1 = vec_mergel(kernel.packet[0], kernel.packet[2]);
+ t2 = vec_mergeh(kernel.packet[1], kernel.packet[3]);
+ t3 = vec_mergel(kernel.packet[1], kernel.packet[3]);
+ kernel.packet[0] = vec_mergeh(t0, t2);
+ kernel.packet[1] = vec_mergel(t0, t2);
+ kernel.packet[2] = vec_mergeh(t1, t3);
+ kernel.packet[3] = vec_mergel(t1, t3);
+}
+
+
+EIGEN_DEVICE_FUNC inline void
+ptranspose(PacketBlock<Packet16uc,4>& kernel) {
+ Packet16uc t0, t1, t2, t3;
+ t0 = vec_mergeh(kernel.packet[0], kernel.packet[2]);
+ t1 = vec_mergel(kernel.packet[0], kernel.packet[2]);
+ t2 = vec_mergeh(kernel.packet[1], kernel.packet[3]);
+ t3 = vec_mergel(kernel.packet[1], kernel.packet[3]);
+ kernel.packet[0] = vec_mergeh(t0, t2);
+ kernel.packet[1] = vec_mergel(t0, t2);
+ kernel.packet[2] = vec_mergeh(t1, t3);
+ kernel.packet[3] = vec_mergel(t1, t3);
+}
+
+EIGEN_DEVICE_FUNC inline void
+ptranspose(PacketBlock<Packet8s,8>& kernel) {
+ Packet8s v[8], sum[8];
+
+ v[0] = vec_mergeh(kernel.packet[0], kernel.packet[4]);
+ v[1] = vec_mergel(kernel.packet[0], kernel.packet[4]);
+ v[2] = vec_mergeh(kernel.packet[1], kernel.packet[5]);
+ v[3] = vec_mergel(kernel.packet[1], kernel.packet[5]);
+ v[4] = vec_mergeh(kernel.packet[2], kernel.packet[6]);
+ v[5] = vec_mergel(kernel.packet[2], kernel.packet[6]);
+ v[6] = vec_mergeh(kernel.packet[3], kernel.packet[7]);
+ v[7] = vec_mergel(kernel.packet[3], kernel.packet[7]);
+ sum[0] = vec_mergeh(v[0], v[4]);
+ sum[1] = vec_mergel(v[0], v[4]);
+ sum[2] = vec_mergeh(v[1], v[5]);
+ sum[3] = vec_mergel(v[1], v[5]);
+ sum[4] = vec_mergeh(v[2], v[6]);
+ sum[5] = vec_mergel(v[2], v[6]);
+ sum[6] = vec_mergeh(v[3], v[7]);
+ sum[7] = vec_mergel(v[3], v[7]);
+
+ kernel.packet[0] = vec_mergeh(sum[0], sum[4]);
+ kernel.packet[1] = vec_mergel(sum[0], sum[4]);
+ kernel.packet[2] = vec_mergeh(sum[1], sum[5]);
+ kernel.packet[3] = vec_mergel(sum[1], sum[5]);
+ kernel.packet[4] = vec_mergeh(sum[2], sum[6]);
+ kernel.packet[5] = vec_mergel(sum[2], sum[6]);
+ kernel.packet[6] = vec_mergeh(sum[3], sum[7]);
+ kernel.packet[7] = vec_mergel(sum[3], sum[7]);
+}
+
+EIGEN_DEVICE_FUNC inline void
+ptranspose(PacketBlock<Packet8us,8>& kernel) {
+ Packet8us v[8], sum[8];
+
+ v[0] = vec_mergeh(kernel.packet[0], kernel.packet[4]);
+ v[1] = vec_mergel(kernel.packet[0], kernel.packet[4]);
+ v[2] = vec_mergeh(kernel.packet[1], kernel.packet[5]);
+ v[3] = vec_mergel(kernel.packet[1], kernel.packet[5]);
+ v[4] = vec_mergeh(kernel.packet[2], kernel.packet[6]);
+ v[5] = vec_mergel(kernel.packet[2], kernel.packet[6]);
+ v[6] = vec_mergeh(kernel.packet[3], kernel.packet[7]);
+ v[7] = vec_mergel(kernel.packet[3], kernel.packet[7]);
+ sum[0] = vec_mergeh(v[0], v[4]);
+ sum[1] = vec_mergel(v[0], v[4]);
+ sum[2] = vec_mergeh(v[1], v[5]);
+ sum[3] = vec_mergel(v[1], v[5]);
+ sum[4] = vec_mergeh(v[2], v[6]);
+ sum[5] = vec_mergel(v[2], v[6]);
+ sum[6] = vec_mergeh(v[3], v[7]);
+ sum[7] = vec_mergel(v[3], v[7]);
+
+ kernel.packet[0] = vec_mergeh(sum[0], sum[4]);
+ kernel.packet[1] = vec_mergel(sum[0], sum[4]);
+ kernel.packet[2] = vec_mergeh(sum[1], sum[5]);
+ kernel.packet[3] = vec_mergel(sum[1], sum[5]);
+ kernel.packet[4] = vec_mergeh(sum[2], sum[6]);
+ kernel.packet[5] = vec_mergel(sum[2], sum[6]);
+ kernel.packet[6] = vec_mergeh(sum[3], sum[7]);
+ kernel.packet[7] = vec_mergel(sum[3], sum[7]);
+}
+
+EIGEN_DEVICE_FUNC inline void
+ptranspose(PacketBlock<Packet8bf,8>& kernel) {
+ Packet8bf v[8], sum[8];
+
+ v[0] = vec_mergeh(kernel.packet[0].m_val, kernel.packet[4].m_val);
+ v[1] = vec_mergel(kernel.packet[0].m_val, kernel.packet[4].m_val);
+ v[2] = vec_mergeh(kernel.packet[1].m_val, kernel.packet[5].m_val);
+ v[3] = vec_mergel(kernel.packet[1].m_val, kernel.packet[5].m_val);
+ v[4] = vec_mergeh(kernel.packet[2].m_val, kernel.packet[6].m_val);
+ v[5] = vec_mergel(kernel.packet[2].m_val, kernel.packet[6].m_val);
+ v[6] = vec_mergeh(kernel.packet[3].m_val, kernel.packet[7].m_val);
+ v[7] = vec_mergel(kernel.packet[3].m_val, kernel.packet[7].m_val);
+ sum[0] = vec_mergeh(v[0].m_val, v[4].m_val);
+ sum[1] = vec_mergel(v[0].m_val, v[4].m_val);
+ sum[2] = vec_mergeh(v[1].m_val, v[5].m_val);
+ sum[3] = vec_mergel(v[1].m_val, v[5].m_val);
+ sum[4] = vec_mergeh(v[2].m_val, v[6].m_val);
+ sum[5] = vec_mergel(v[2].m_val, v[6].m_val);
+ sum[6] = vec_mergeh(v[3].m_val, v[7].m_val);
+ sum[7] = vec_mergel(v[3].m_val, v[7].m_val);
+
+ kernel.packet[0] = vec_mergeh(sum[0].m_val, sum[4].m_val);
+ kernel.packet[1] = vec_mergel(sum[0].m_val, sum[4].m_val);
+ kernel.packet[2] = vec_mergeh(sum[1].m_val, sum[5].m_val);
+ kernel.packet[3] = vec_mergel(sum[1].m_val, sum[5].m_val);
+ kernel.packet[4] = vec_mergeh(sum[2].m_val, sum[6].m_val);
+ kernel.packet[5] = vec_mergel(sum[2].m_val, sum[6].m_val);
+ kernel.packet[6] = vec_mergeh(sum[3].m_val, sum[7].m_val);
+ kernel.packet[7] = vec_mergel(sum[3].m_val, sum[7].m_val);
+}
+
+EIGEN_DEVICE_FUNC inline void
+ptranspose(PacketBlock<Packet16c,16>& kernel) {
+ Packet16c step1[16], step2[16], step3[16];
+
+ step1[0] = vec_mergeh(kernel.packet[0], kernel.packet[8]);
+ step1[1] = vec_mergel(kernel.packet[0], kernel.packet[8]);
+ step1[2] = vec_mergeh(kernel.packet[1], kernel.packet[9]);
+ step1[3] = vec_mergel(kernel.packet[1], kernel.packet[9]);
+ step1[4] = vec_mergeh(kernel.packet[2], kernel.packet[10]);
+ step1[5] = vec_mergel(kernel.packet[2], kernel.packet[10]);
+ step1[6] = vec_mergeh(kernel.packet[3], kernel.packet[11]);
+ step1[7] = vec_mergel(kernel.packet[3], kernel.packet[11]);
+ step1[8] = vec_mergeh(kernel.packet[4], kernel.packet[12]);
+ step1[9] = vec_mergel(kernel.packet[4], kernel.packet[12]);
+ step1[10] = vec_mergeh(kernel.packet[5], kernel.packet[13]);
+ step1[11] = vec_mergel(kernel.packet[5], kernel.packet[13]);
+ step1[12] = vec_mergeh(kernel.packet[6], kernel.packet[14]);
+ step1[13] = vec_mergel(kernel.packet[6], kernel.packet[14]);
+ step1[14] = vec_mergeh(kernel.packet[7], kernel.packet[15]);
+ step1[15] = vec_mergel(kernel.packet[7], kernel.packet[15]);
+
+ step2[0] = vec_mergeh(step1[0], step1[8]);
+ step2[1] = vec_mergel(step1[0], step1[8]);
+ step2[2] = vec_mergeh(step1[1], step1[9]);
+ step2[3] = vec_mergel(step1[1], step1[9]);
+ step2[4] = vec_mergeh(step1[2], step1[10]);
+ step2[5] = vec_mergel(step1[2], step1[10]);
+ step2[6] = vec_mergeh(step1[3], step1[11]);
+ step2[7] = vec_mergel(step1[3], step1[11]);
+ step2[8] = vec_mergeh(step1[4], step1[12]);
+ step2[9] = vec_mergel(step1[4], step1[12]);
+ step2[10] = vec_mergeh(step1[5], step1[13]);
+ step2[11] = vec_mergel(step1[5], step1[13]);
+ step2[12] = vec_mergeh(step1[6], step1[14]);
+ step2[13] = vec_mergel(step1[6], step1[14]);
+ step2[14] = vec_mergeh(step1[7], step1[15]);
+ step2[15] = vec_mergel(step1[7], step1[15]);
+
+ step3[0] = vec_mergeh(step2[0], step2[8]);
+ step3[1] = vec_mergel(step2[0], step2[8]);
+ step3[2] = vec_mergeh(step2[1], step2[9]);
+ step3[3] = vec_mergel(step2[1], step2[9]);
+ step3[4] = vec_mergeh(step2[2], step2[10]);
+ step3[5] = vec_mergel(step2[2], step2[10]);
+ step3[6] = vec_mergeh(step2[3], step2[11]);
+ step3[7] = vec_mergel(step2[3], step2[11]);
+ step3[8] = vec_mergeh(step2[4], step2[12]);
+ step3[9] = vec_mergel(step2[4], step2[12]);
+ step3[10] = vec_mergeh(step2[5], step2[13]);
+ step3[11] = vec_mergel(step2[5], step2[13]);
+ step3[12] = vec_mergeh(step2[6], step2[14]);
+ step3[13] = vec_mergel(step2[6], step2[14]);
+ step3[14] = vec_mergeh(step2[7], step2[15]);
+ step3[15] = vec_mergel(step2[7], step2[15]);
+
+ kernel.packet[0] = vec_mergeh(step3[0], step3[8]);
+ kernel.packet[1] = vec_mergel(step3[0], step3[8]);
+ kernel.packet[2] = vec_mergeh(step3[1], step3[9]);
+ kernel.packet[3] = vec_mergel(step3[1], step3[9]);
+ kernel.packet[4] = vec_mergeh(step3[2], step3[10]);
+ kernel.packet[5] = vec_mergel(step3[2], step3[10]);
+ kernel.packet[6] = vec_mergeh(step3[3], step3[11]);
+ kernel.packet[7] = vec_mergel(step3[3], step3[11]);
+ kernel.packet[8] = vec_mergeh(step3[4], step3[12]);
+ kernel.packet[9] = vec_mergel(step3[4], step3[12]);
+ kernel.packet[10] = vec_mergeh(step3[5], step3[13]);
+ kernel.packet[11] = vec_mergel(step3[5], step3[13]);
+ kernel.packet[12] = vec_mergeh(step3[6], step3[14]);
+ kernel.packet[13] = vec_mergel(step3[6], step3[14]);
+ kernel.packet[14] = vec_mergeh(step3[7], step3[15]);
+ kernel.packet[15] = vec_mergel(step3[7], step3[15]);
+}
+
+EIGEN_DEVICE_FUNC inline void
+ptranspose(PacketBlock<Packet16uc,16>& kernel) {
+ Packet16uc step1[16], step2[16], step3[16];
+
+ step1[0] = vec_mergeh(kernel.packet[0], kernel.packet[8]);
+ step1[1] = vec_mergel(kernel.packet[0], kernel.packet[8]);
+ step1[2] = vec_mergeh(kernel.packet[1], kernel.packet[9]);
+ step1[3] = vec_mergel(kernel.packet[1], kernel.packet[9]);
+ step1[4] = vec_mergeh(kernel.packet[2], kernel.packet[10]);
+ step1[5] = vec_mergel(kernel.packet[2], kernel.packet[10]);
+ step1[6] = vec_mergeh(kernel.packet[3], kernel.packet[11]);
+ step1[7] = vec_mergel(kernel.packet[3], kernel.packet[11]);
+ step1[8] = vec_mergeh(kernel.packet[4], kernel.packet[12]);
+ step1[9] = vec_mergel(kernel.packet[4], kernel.packet[12]);
+ step1[10] = vec_mergeh(kernel.packet[5], kernel.packet[13]);
+ step1[11] = vec_mergel(kernel.packet[5], kernel.packet[13]);
+ step1[12] = vec_mergeh(kernel.packet[6], kernel.packet[14]);
+ step1[13] = vec_mergel(kernel.packet[6], kernel.packet[14]);
+ step1[14] = vec_mergeh(kernel.packet[7], kernel.packet[15]);
+ step1[15] = vec_mergel(kernel.packet[7], kernel.packet[15]);
+
+ step2[0] = vec_mergeh(step1[0], step1[8]);
+ step2[1] = vec_mergel(step1[0], step1[8]);
+ step2[2] = vec_mergeh(step1[1], step1[9]);
+ step2[3] = vec_mergel(step1[1], step1[9]);
+ step2[4] = vec_mergeh(step1[2], step1[10]);
+ step2[5] = vec_mergel(step1[2], step1[10]);
+ step2[6] = vec_mergeh(step1[3], step1[11]);
+ step2[7] = vec_mergel(step1[3], step1[11]);
+ step2[8] = vec_mergeh(step1[4], step1[12]);
+ step2[9] = vec_mergel(step1[4], step1[12]);
+ step2[10] = vec_mergeh(step1[5], step1[13]);
+ step2[11] = vec_mergel(step1[5], step1[13]);
+ step2[12] = vec_mergeh(step1[6], step1[14]);
+ step2[13] = vec_mergel(step1[6], step1[14]);
+ step2[14] = vec_mergeh(step1[7], step1[15]);
+ step2[15] = vec_mergel(step1[7], step1[15]);
+
+ step3[0] = vec_mergeh(step2[0], step2[8]);
+ step3[1] = vec_mergel(step2[0], step2[8]);
+ step3[2] = vec_mergeh(step2[1], step2[9]);
+ step3[3] = vec_mergel(step2[1], step2[9]);
+ step3[4] = vec_mergeh(step2[2], step2[10]);
+ step3[5] = vec_mergel(step2[2], step2[10]);
+ step3[6] = vec_mergeh(step2[3], step2[11]);
+ step3[7] = vec_mergel(step2[3], step2[11]);
+ step3[8] = vec_mergeh(step2[4], step2[12]);
+ step3[9] = vec_mergel(step2[4], step2[12]);
+ step3[10] = vec_mergeh(step2[5], step2[13]);
+ step3[11] = vec_mergel(step2[5], step2[13]);
+ step3[12] = vec_mergeh(step2[6], step2[14]);
+ step3[13] = vec_mergel(step2[6], step2[14]);
+ step3[14] = vec_mergeh(step2[7], step2[15]);
+ step3[15] = vec_mergel(step2[7], step2[15]);
+
+ kernel.packet[0] = vec_mergeh(step3[0], step3[8]);
+ kernel.packet[1] = vec_mergel(step3[0], step3[8]);
+ kernel.packet[2] = vec_mergeh(step3[1], step3[9]);
+ kernel.packet[3] = vec_mergel(step3[1], step3[9]);
+ kernel.packet[4] = vec_mergeh(step3[2], step3[10]);
+ kernel.packet[5] = vec_mergel(step3[2], step3[10]);
+ kernel.packet[6] = vec_mergeh(step3[3], step3[11]);
+ kernel.packet[7] = vec_mergel(step3[3], step3[11]);
+ kernel.packet[8] = vec_mergeh(step3[4], step3[12]);
+ kernel.packet[9] = vec_mergel(step3[4], step3[12]);
+ kernel.packet[10] = vec_mergeh(step3[5], step3[13]);
+ kernel.packet[11] = vec_mergel(step3[5], step3[13]);
+ kernel.packet[12] = vec_mergeh(step3[6], step3[14]);
+ kernel.packet[13] = vec_mergel(step3[6], step3[14]);
+ kernel.packet[14] = vec_mergeh(step3[7], step3[15]);
+ kernel.packet[15] = vec_mergel(step3[7], step3[15]);
+}
+
+template<typename Packet> EIGEN_STRONG_INLINE
+Packet pblend4(const Selector<4>& ifPacket, const Packet& thenPacket, const Packet& elsePacket) {
Packet4ui select = { ifPacket.select[0], ifPacket.select[1], ifPacket.select[2], ifPacket.select[3] };
Packet4ui mask = reinterpret_cast<Packet4ui>(vec_cmpeq(reinterpret_cast<Packet4ui>(select), reinterpret_cast<Packet4ui>(p4i_ONE)));
return vec_sel(elsePacket, thenPacket, mask);
}
+template<> EIGEN_STRONG_INLINE Packet4i pblend(const Selector<4>& ifPacket, const Packet4i& thenPacket, const Packet4i& elsePacket) {
+ return pblend4<Packet4i>(ifPacket, thenPacket, elsePacket);
+}
+
template<> EIGEN_STRONG_INLINE Packet4f pblend(const Selector<4>& ifPacket, const Packet4f& thenPacket, const Packet4f& elsePacket) {
- Packet4ui select = { ifPacket.select[0], ifPacket.select[1], ifPacket.select[2], ifPacket.select[3] };
- Packet4ui mask = reinterpret_cast<Packet4ui>(vec_cmpeq(reinterpret_cast<Packet4ui>(select), reinterpret_cast<Packet4ui>(p4i_ONE)));
+ return pblend4<Packet4f>(ifPacket, thenPacket, elsePacket);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8s pblend(const Selector<8>& ifPacket, const Packet8s& thenPacket, const Packet8s& elsePacket) {
+ Packet8us select = { ifPacket.select[0], ifPacket.select[1], ifPacket.select[2], ifPacket.select[3],
+ ifPacket.select[4], ifPacket.select[5], ifPacket.select[6], ifPacket.select[7] };
+ Packet8us mask = reinterpret_cast<Packet8us>(vec_cmpeq(select, p8us_ONE));
+ Packet8s result = vec_sel(elsePacket, thenPacket, mask);
+ return result;
+}
+
+template<> EIGEN_STRONG_INLINE Packet8us pblend(const Selector<8>& ifPacket, const Packet8us& thenPacket, const Packet8us& elsePacket) {
+ Packet8us select = { ifPacket.select[0], ifPacket.select[1], ifPacket.select[2], ifPacket.select[3],
+ ifPacket.select[4], ifPacket.select[5], ifPacket.select[6], ifPacket.select[7] };
+ Packet8us mask = reinterpret_cast<Packet8us>(vec_cmpeq(reinterpret_cast<Packet8us>(select), p8us_ONE));
+ return vec_sel(elsePacket, thenPacket, mask);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf pblend(const Selector<8>& ifPacket, const Packet8bf& thenPacket, const Packet8bf& elsePacket) {
+ return pblend<Packet8us>(ifPacket, thenPacket, elsePacket);
+}
+
+template<> EIGEN_STRONG_INLINE Packet16c pblend(const Selector<16>& ifPacket, const Packet16c& thenPacket, const Packet16c& elsePacket) {
+ Packet16uc select = { ifPacket.select[0], ifPacket.select[1], ifPacket.select[2], ifPacket.select[3],
+ ifPacket.select[4], ifPacket.select[5], ifPacket.select[6], ifPacket.select[7],
+ ifPacket.select[8], ifPacket.select[9], ifPacket.select[10], ifPacket.select[11],
+ ifPacket.select[12], ifPacket.select[13], ifPacket.select[14], ifPacket.select[15] };
+
+ Packet16uc mask = reinterpret_cast<Packet16uc>(vec_cmpeq(reinterpret_cast<Packet16uc>(select), p16uc_ONE));
+ return vec_sel(elsePacket, thenPacket, mask);
+}
+
+template<> EIGEN_STRONG_INLINE Packet16uc pblend(const Selector<16>& ifPacket, const Packet16uc& thenPacket, const Packet16uc& elsePacket) {
+ Packet16uc select = { ifPacket.select[0], ifPacket.select[1], ifPacket.select[2], ifPacket.select[3],
+ ifPacket.select[4], ifPacket.select[5], ifPacket.select[6], ifPacket.select[7],
+ ifPacket.select[8], ifPacket.select[9], ifPacket.select[10], ifPacket.select[11],
+ ifPacket.select[12], ifPacket.select[13], ifPacket.select[14], ifPacket.select[15] };
+
+ Packet16uc mask = reinterpret_cast<Packet16uc>(vec_cmpeq(reinterpret_cast<Packet16uc>(select), p16uc_ONE));
return vec_sel(elsePacket, thenPacket, mask);
}
+template <>
+struct type_casting_traits<float, int> {
+ enum {
+ VectorizedCast = 1,
+ SrcCoeffRatio = 1,
+ TgtCoeffRatio = 1
+ };
+};
+
+template <>
+struct type_casting_traits<int, float> {
+ enum {
+ VectorizedCast = 1,
+ SrcCoeffRatio = 1,
+ TgtCoeffRatio = 1
+ };
+};
+
+template <>
+struct type_casting_traits<bfloat16, unsigned short int> {
+ enum {
+ VectorizedCast = 1,
+ SrcCoeffRatio = 1,
+ TgtCoeffRatio = 1
+ };
+};
+
+template <>
+struct type_casting_traits<unsigned short int, bfloat16> {
+ enum {
+ VectorizedCast = 1,
+ SrcCoeffRatio = 1,
+ TgtCoeffRatio = 1
+ };
+};
+
+template<> EIGEN_STRONG_INLINE Packet4i pcast<Packet4f, Packet4i>(const Packet4f& a) {
+ return vec_cts(a,0);
+}
+
+template<> EIGEN_STRONG_INLINE Packet4ui pcast<Packet4f, Packet4ui>(const Packet4f& a) {
+ return vec_ctu(a,0);
+}
+
+template<> EIGEN_STRONG_INLINE Packet4f pcast<Packet4i, Packet4f>(const Packet4i& a) {
+ return vec_ctf(a,0);
+}
+
+template<> EIGEN_STRONG_INLINE Packet4f pcast<Packet4ui, Packet4f>(const Packet4ui& a) {
+ return vec_ctf(a,0);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8us pcast<Packet8bf, Packet8us>(const Packet8bf& a) {
+ Packet4f float_even = Bf16ToF32Even(a);
+ Packet4f float_odd = Bf16ToF32Odd(a);
+ Packet4ui int_even = pcast<Packet4f, Packet4ui>(float_even);
+ Packet4ui int_odd = pcast<Packet4f, Packet4ui>(float_odd);
+ const _EIGEN_DECLARE_CONST_FAST_Packet4ui(low_mask, 0x0000FFFF);
+ Packet4ui low_even = pand<Packet4ui>(int_even, p4ui_low_mask);
+ Packet4ui low_odd = pand<Packet4ui>(int_odd, p4ui_low_mask);
+
+ //Check values that are bigger than USHRT_MAX (0xFFFF)
+ Packet4bi overflow_selector;
+ if(vec_any_gt(int_even, p4ui_low_mask)){
+ overflow_selector = vec_cmpgt(int_even, p4ui_low_mask);
+ low_even = vec_sel(low_even, p4ui_low_mask, overflow_selector);
+ }
+ if(vec_any_gt(int_odd, p4ui_low_mask)){
+ overflow_selector = vec_cmpgt(int_odd, p4ui_low_mask);
+ low_odd = vec_sel(low_odd, p4ui_low_mask, overflow_selector);
+ }
+
+ low_odd = plogical_shift_left<16>(low_odd);
+
+ Packet4ui int_final = por<Packet4ui>(low_even, low_odd);
+ return reinterpret_cast<Packet8us>(int_final);
+}
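The bfloat16 to unsigned short cast widens to float, converts with vec_ctu, and then clamps any lane whose value exceeds USHRT_MAX to 0xFFFF before the even and odd halves are packed back together. A scalar sketch of one lane, assuming the usual saturating behaviour of the float-to-unsigned conversion for out-of-range inputs (illustrative only):

#include <cstdint>
#include <cstring>

// Scalar model of pcast<Packet8bf, Packet8us> for a single lane.
static std::uint16_t bf16_to_u16_scalar(std::uint16_t bf) {
  std::uint32_t bits = static_cast<std::uint32_t>(bf) << 16;  // widen bf16 to float bits
  float f;
  std::memcpy(&f, &bits, sizeof(f));

  std::uint32_t u;
  if (!(f > 0.0f))         u = 0u;       // negatives and NaN clamp to 0
  else if (f >= 65535.0f)  u = 0xFFFFu;  // values above USHRT_MAX saturate
  else                     u = static_cast<std::uint32_t>(f);
  return static_cast<std::uint16_t>(u);
}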
+
+template<> EIGEN_STRONG_INLINE Packet8bf pcast<Packet8us, Packet8bf>(const Packet8us& a) {
+ //short -> int -> float -> bfloat16
+ const _EIGEN_DECLARE_CONST_FAST_Packet4ui(low_mask, 0x0000FFFF);
+ Packet4ui int_cast = reinterpret_cast<Packet4ui>(a);
+ Packet4ui int_even = pand<Packet4ui>(int_cast, p4ui_low_mask);
+ Packet4ui int_odd = plogical_shift_right<16>(int_cast);
+ Packet4f float_even = pcast<Packet4ui, Packet4f>(int_even);
+ Packet4f float_odd = pcast<Packet4ui, Packet4f>(int_odd);
+ return F32ToBf16(float_even, float_odd);
+}
+
+
+template<> EIGEN_STRONG_INLINE Packet4i preinterpret<Packet4i,Packet4f>(const Packet4f& a) {
+ return reinterpret_cast<Packet4i>(a);
+}
+
+template<> EIGEN_STRONG_INLINE Packet4f preinterpret<Packet4f,Packet4i>(const Packet4i& a) {
+ return reinterpret_cast<Packet4f>(a);
+}
+
+
//---------- double ----------
#ifdef __VSX__
@@ -764,9 +2264,12 @@ typedef __vector __bool long Packet2bl;
static Packet2l p2l_ONE = { 1, 1 };
static Packet2l p2l_ZERO = reinterpret_cast<Packet2l>(p4i_ZERO);
-static Packet2d p2d_ONE = { 1.0, 1.0 };
+static Packet2ul p2ul_SIGN = { 0x8000000000000000ull, 0x8000000000000000ull };
+static Packet2ul p2ul_PREV0DOT5 = { 0x3FDFFFFFFFFFFFFFull, 0x3FDFFFFFFFFFFFFFull };
+static Packet2d p2d_ONE = { 1.0, 1.0 };
static Packet2d p2d_ZERO = reinterpret_cast<Packet2d>(p4f_ZERO);
-static Packet2d p2d_MZERO = { -0.0, -0.0 };
+static Packet2d p2d_MZERO = { numext::bit_cast<double>(0x8000000000000000ull),
+ numext::bit_cast<double>(0x8000000000000000ull) };
#ifdef _BIG_ENDIAN
static Packet2d p2d_COUNTDOWN = reinterpret_cast<Packet2d>(vec_sld(reinterpret_cast<Packet4f>(p2d_ZERO), reinterpret_cast<Packet4f>(p2d_ONE), 8));
@@ -774,16 +2277,9 @@ static Packet2d p2d_COUNTDOWN = reinterpret_cast<Packet2d>(vec_sld(reinterpret_c
static Packet2d p2d_COUNTDOWN = reinterpret_cast<Packet2d>(vec_sld(reinterpret_cast<Packet4f>(p2d_ONE), reinterpret_cast<Packet4f>(p2d_ZERO), 8));
#endif
-template<int index> Packet2d vec_splat_dbl(Packet2d& a);
-
-template<> EIGEN_STRONG_INLINE Packet2d vec_splat_dbl<0>(Packet2d& a)
-{
- return reinterpret_cast<Packet2d>(vec_perm(a, a, p16uc_PSET64_HI));
-}
-
-template<> EIGEN_STRONG_INLINE Packet2d vec_splat_dbl<1>(Packet2d& a)
+template<int index> Packet2d vec_splat_dbl(Packet2d& a)
{
- return reinterpret_cast<Packet2d>(vec_perm(a, a, p16uc_PSET64_LO));
+ return vec_splat(a, index);
}
template<> struct packet_traits<double> : default_packet_traits
@@ -812,12 +2308,13 @@ template<> struct packet_traits<double> : default_packet_traits
HasRound = 1,
HasFloor = 1,
HasCeil = 1,
+ HasRint = 1,
HasNegate = 1,
HasBlend = 1
};
};
-template<> struct unpacket_traits<Packet2d> { typedef double type; enum {size=2, alignment=Aligned16}; typedef Packet2d half; };
+template<> struct unpacket_traits<Packet2d> { typedef double type; enum {size=2, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false}; typedef Packet2d half; };
inline std::ostream & operator <<(std::ostream & s, const Packet2l & v)
{
@@ -845,21 +2342,13 @@ inline std::ostream & operator <<(std::ostream & s, const Packet2d & v)
template<> EIGEN_STRONG_INLINE Packet2d pload<Packet2d>(const double* from)
{
EIGEN_DEBUG_ALIGNED_LOAD
-#ifdef __VSX__
- return vec_vsx_ld(0, from);
-#else
- return vec_ld(0, from);
-#endif
+ return vec_xl(0, const_cast<double *>(from)); // cast needed by Clang
}
template<> EIGEN_STRONG_INLINE void pstore<double>(double* to, const Packet2d& from)
{
EIGEN_DEBUG_ALIGNED_STORE
-#ifdef __VSX__
- vec_vsx_st(from, 0, to);
-#else
- vec_st(from, 0, to);
-#endif
+ vec_xst(from, 0, to);
}
template<> EIGEN_STRONG_INLINE Packet2d pset1<Packet2d>(const double& from) {
@@ -867,28 +2356,32 @@ template<> EIGEN_STRONG_INLINE Packet2d pset1<Packet2d>(const double& from) {
return v;
}
+template<> EIGEN_STRONG_INLINE Packet2d pset1frombits<Packet2d>(unsigned long from) {
+ Packet2l v = {static_cast<long long>(from), static_cast<long long>(from)};
+ return reinterpret_cast<Packet2d>(v);
+}
+
template<> EIGEN_STRONG_INLINE void
pbroadcast4<Packet2d>(const double *a,
Packet2d& a0, Packet2d& a1, Packet2d& a2, Packet2d& a3)
{
- a1 = pload<Packet2d>(a);
- a0 = vec_splat_dbl<0>(a1);
- a1 = vec_splat_dbl<1>(a1);
- a3 = pload<Packet2d>(a+2);
- a2 = vec_splat_dbl<0>(a3);
- a3 = vec_splat_dbl<1>(a3);
+ // This way is faster than vec_splat (at least for doubles on POWER9)
+ a0 = pset1<Packet2d>(a[0]);
+ a1 = pset1<Packet2d>(a[1]);
+ a2 = pset1<Packet2d>(a[2]);
+ a3 = pset1<Packet2d>(a[3]);
}
template<> EIGEN_DEVICE_FUNC inline Packet2d pgather<double, Packet2d>(const double* from, Index stride)
{
- double EIGEN_ALIGN16 af[2];
+ EIGEN_ALIGN16 double af[2];
af[0] = from[0*stride];
af[1] = from[1*stride];
return pload<Packet2d>(af);
}
template<> EIGEN_DEVICE_FUNC inline void pscatter<double, Packet2d>(double* to, const Packet2d& from, Index stride)
{
- double EIGEN_ALIGN16 af[2];
+ EIGEN_ALIGN16 double af[2];
pstore<double>(af, from);
to[0*stride] = af[0];
to[1*stride] = af[1];
@@ -910,9 +2403,29 @@ template<> EIGEN_STRONG_INLINE Packet2d pdiv<Packet2d>(const Packet2d& a, const
// for some weird reasons, it has to be overloaded for packets of integers
template<> EIGEN_STRONG_INLINE Packet2d pmadd(const Packet2d& a, const Packet2d& b, const Packet2d& c) { return vec_madd(a, b, c); }
-template<> EIGEN_STRONG_INLINE Packet2d pmin<Packet2d>(const Packet2d& a, const Packet2d& b) { return vec_min(a, b); }
+template<> EIGEN_STRONG_INLINE Packet2d pmin<Packet2d>(const Packet2d& a, const Packet2d& b)
+{
+ // NOTE: about 10% slower than vec_min, but consistent with std::min and SSE regarding NaN
+ Packet2d ret;
+ __asm__ ("xvcmpgedp %x0,%x1,%x2\n\txxsel %x0,%x1,%x2,%x0" : "=&wa" (ret) : "wa" (a), "wa" (b));
+ return ret;
+}
-template<> EIGEN_STRONG_INLINE Packet2d pmax<Packet2d>(const Packet2d& a, const Packet2d& b) { return vec_max(a, b); }
+template<> EIGEN_STRONG_INLINE Packet2d pmax<Packet2d>(const Packet2d& a, const Packet2d& b)
+{
+ // NOTE: about 10% slower than vec_max, but consistent with std::max and SSE regarding NaN
+ Packet2d ret;
+ __asm__ ("xvcmpgtdp %x0,%x2,%x1\n\txxsel %x0,%x1,%x2,%x0" : "=&wa" (ret) : "wa" (a), "wa" (b));
+ return ret;
+}
+
+template<> EIGEN_STRONG_INLINE Packet2d pcmp_le(const Packet2d& a, const Packet2d& b) { return reinterpret_cast<Packet2d>(vec_cmple(a,b)); }
+template<> EIGEN_STRONG_INLINE Packet2d pcmp_lt(const Packet2d& a, const Packet2d& b) { return reinterpret_cast<Packet2d>(vec_cmplt(a,b)); }
+template<> EIGEN_STRONG_INLINE Packet2d pcmp_eq(const Packet2d& a, const Packet2d& b) { return reinterpret_cast<Packet2d>(vec_cmpeq(a,b)); }
+template<> EIGEN_STRONG_INLINE Packet2d pcmp_lt_or_nan(const Packet2d& a, const Packet2d& b) {
+ Packet2d c = reinterpret_cast<Packet2d>(vec_cmpge(a,b));
+ return vec_nor(c,c);
+}
template<> EIGEN_STRONG_INLINE Packet2d pand<Packet2d>(const Packet2d& a, const Packet2d& b) { return vec_and(a, b); }
@@ -922,14 +2435,34 @@ template<> EIGEN_STRONG_INLINE Packet2d pxor<Packet2d>(const Packet2d& a, const
template<> EIGEN_STRONG_INLINE Packet2d pandnot<Packet2d>(const Packet2d& a, const Packet2d& b) { return vec_and(a, vec_nor(b, b)); }
-template<> EIGEN_STRONG_INLINE Packet2d pround<Packet2d>(const Packet2d& a) { return vec_round(a); }
+template<> EIGEN_STRONG_INLINE Packet2d pround<Packet2d>(const Packet2d& a)
+{
+ Packet2d t = vec_add(reinterpret_cast<Packet2d>(vec_or(vec_and(reinterpret_cast<Packet2ul>(a), p2ul_SIGN), p2ul_PREV0DOT5)), a);
+ Packet2d res;
+
+ __asm__("xvrdpiz %x0, %x1\n\t"
+ : "=&wa" (res)
+ : "wa" (t));
+
+ return res;
+}
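pround<Packet2d> builds round-half-away-from-zero out of two steps: it adds a bias whose magnitude is the largest double strictly below 0.5 (0x3FDFFFFFFFFFFFFF) with the sign copied from the input, then truncates toward zero with xvrdpiz. Using prev(0.5) rather than 0.5 avoids inputs such as 0.49999999999999994, where adding exactly 0.5 rounds up to 1.0 before the truncation. A scalar sketch (illustrative only):

#include <cmath>
#include <cstdint>
#include <cstring>

// Scalar model of the sign-copied bias plus truncation used by pround<Packet2d>.
static double round_half_away_scalar(double a) {
  const std::uint64_t prev_half_bits = 0x3FDFFFFFFFFFFFFFull;  // largest double < 0.5
  double prev_half;
  std::memcpy(&prev_half, &prev_half_bits, sizeof(prev_half));
  double bias = std::copysign(prev_half, a);  // vec_and with p2ul_SIGN, vec_or with p2ul_PREV0DOT5
  return std::trunc(a + bias);                // the xvrdpiz step: round toward zero
}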
template<> EIGEN_STRONG_INLINE Packet2d pceil<Packet2d>(const Packet2d& a) { return vec_ceil(a); }
template<> EIGEN_STRONG_INLINE Packet2d pfloor<Packet2d>(const Packet2d& a) { return vec_floor(a); }
+template<> EIGEN_STRONG_INLINE Packet2d print<Packet2d>(const Packet2d& a)
+{
+ Packet2d res;
+
+ __asm__("xvrdpic %x0, %x1\n\t"
+ : "=&wa" (res)
+ : "wa" (a));
+
+ return res;
+}
template<> EIGEN_STRONG_INLINE Packet2d ploadu<Packet2d>(const double* from)
{
- EIGEN_DEBUG_ALIGNED_LOAD
- return (Packet2d) vec_vsx_ld((long)from & 15, (const double*) _EIGEN_ALIGNED_PTR(from));
+ EIGEN_DEBUG_UNALIGNED_LOAD
+ return vec_xl(0, const_cast<double*>(from));
}
template<> EIGEN_STRONG_INLINE Packet2d ploaddup<Packet2d>(const double* from)
@@ -942,13 +2475,13 @@ template<> EIGEN_STRONG_INLINE Packet2d ploaddup<Packet2d>(const double* from)
template<> EIGEN_STRONG_INLINE void pstoreu<double>(double* to, const Packet2d& from)
{
- EIGEN_DEBUG_ALIGNED_STORE
- vec_vsx_st((Packet4f)from, (long)to & 15, (float*) _EIGEN_ALIGNED_PTR(to));
+ EIGEN_DEBUG_UNALIGNED_STORE
+ vec_xst(from, 0, to);
}
template<> EIGEN_STRONG_INLINE void prefetch<double>(const double* addr) { EIGEN_PPC_PREFETCH(addr); }
-template<> EIGEN_STRONG_INLINE double pfirst<Packet2d>(const Packet2d& a) { double EIGEN_ALIGN16 x[2]; pstore<double>(x, a); return x[0]; }
+template<> EIGEN_STRONG_INLINE double pfirst<Packet2d>(const Packet2d& a) { EIGEN_ALIGN16 double x[2]; pstore<double>(x, a); return x[0]; }
template<> EIGEN_STRONG_INLINE Packet2d preverse(const Packet2d& a)
{
@@ -956,6 +2489,177 @@ template<> EIGEN_STRONG_INLINE Packet2d preverse(const Packet2d& a)
}
template<> EIGEN_STRONG_INLINE Packet2d pabs(const Packet2d& a) { return vec_abs(a); }
+// VSX support varies between different compilers and even different
+// versions of the same compiler. For gcc version >= 4.9.3, we can use
+// vec_cts to efficiently convert Packet2d to Packet2l. Otherwise, use
+// a slow version that works with older compilers.
+// Update: apparently vec_cts/vec_ctf intrinsics for 64-bit doubles
+// are buggy, https://gcc.gnu.org/bugzilla/show_bug.cgi?id=70963
+template<>
+inline Packet2l pcast<Packet2d, Packet2l>(const Packet2d& x) {
+#if EIGEN_GNUC_AT_LEAST(5, 4) || \
+ (EIGEN_GNUC_AT(6, 1) && __GNUC_PATCHLEVEL__ >= 1)
+ return vec_cts(x, 0); // TODO: check clang version.
+#else
+ double tmp[2];
+ memcpy(tmp, &x, sizeof(tmp));
+ Packet2l l = { static_cast<long long>(tmp[0]),
+ static_cast<long long>(tmp[1]) };
+ return l;
+#endif
+}
+
+template<>
+inline Packet2d pcast<Packet2l, Packet2d>(const Packet2l& x) {
+ unsigned long long tmp[2];
+ memcpy(tmp, &x, sizeof(tmp));
+ Packet2d d = { static_cast<double>(tmp[0]),
+ static_cast<double>(tmp[1]) };
+ return d;
+}
+
+
+// Packet2l shifts.
+// For POWER8 we simply use vec_sr/l.
+//
+// Things are more complicated for POWER7. There is actually a
+// vec_xxsxdi intrinsic but it is not supported by some gcc versions.
+// So we need to shift by N % 32 and rearrange bytes.
+#ifdef __POWER8_VECTOR__
+
+template<int N>
+EIGEN_STRONG_INLINE Packet2l plogical_shift_left(const Packet2l& a) {
+ const Packet2ul shift = { N, N };
+ return vec_sl(a, shift);
+}
+
+template<int N>
+EIGEN_STRONG_INLINE Packet2l plogical_shift_right(const Packet2l& a) {
+ const Packet2ul shift = { N, N };
+ return vec_sr(a, shift);
+}
+
+#else
+
+// Shifts [A, B, C, D] to [B, 0, D, 0].
+// Used to implement left shifts for Packet2l.
+EIGEN_ALWAYS_INLINE Packet4i shift_even_left(const Packet4i& a) {
+ static const Packet16uc perm = {
+ 0x14, 0x15, 0x16, 0x17, 0x00, 0x01, 0x02, 0x03,
+ 0x1c, 0x1d, 0x1e, 0x1f, 0x08, 0x09, 0x0a, 0x0b };
+ #ifdef _BIG_ENDIAN
+ return vec_perm(p4i_ZERO, a, perm);
+ #else
+ return vec_perm(a, p4i_ZERO, perm);
+ #endif
+}
+
+// Shifts [A, B, C, D] to [0, A, 0, C].
+// Used to implement right shifts for Packet2l.
+EIGEN_ALWAYS_INLINE Packet4i shift_odd_right(const Packet4i& a) {
+ static const Packet16uc perm = {
+ 0x04, 0x05, 0x06, 0x07, 0x10, 0x11, 0x12, 0x13,
+ 0x0c, 0x0d, 0x0e, 0x0f, 0x18, 0x19, 0x1a, 0x1b };
+ #ifdef _BIG_ENDIAN
+ return vec_perm(p4i_ZERO, a, perm);
+ #else
+ return vec_perm(a, p4i_ZERO, perm);
+ #endif
+}
+
+template<int N, typename EnableIf = void>
+struct plogical_shift_left_impl;
+
+template<int N>
+struct plogical_shift_left_impl<N, typename enable_if<(N < 32) && (N >= 0)>::type> {
+ static EIGEN_STRONG_INLINE Packet2l run(const Packet2l& a) {
+ static const unsigned n = static_cast<unsigned>(N);
+ const Packet4ui shift = {n, n, n, n};
+ const Packet4i ai = reinterpret_cast<Packet4i>(a);
+ static const unsigned m = static_cast<unsigned>(32 - N);
+ const Packet4ui shift_right = {m, m, m, m};
+ const Packet4i out_hi = vec_sl(ai, shift);
+ const Packet4i out_lo = shift_even_left(vec_sr(ai, shift_right));
+ return reinterpret_cast<Packet2l>(por<Packet4i>(out_hi, out_lo));
+ }
+};
+
+template<int N>
+struct plogical_shift_left_impl<N, typename enable_if<(N >= 32)>::type> {
+ static EIGEN_STRONG_INLINE Packet2l run(const Packet2l& a) {
+ static const unsigned m = static_cast<unsigned>(N - 32);
+ const Packet4ui shift = {m, m, m, m};
+ const Packet4i ai = reinterpret_cast<Packet4i>(a);
+ return reinterpret_cast<Packet2l>(shift_even_left(vec_sl(ai, shift)));
+ }
+};
+
+template<int N>
+EIGEN_STRONG_INLINE Packet2l plogical_shift_left(const Packet2l& a) {
+ return plogical_shift_left_impl<N>::run(a);
+}
+
+template<int N, typename EnableIf = void>
+struct plogical_shift_right_impl;
+
+template<int N>
+struct plogical_shift_right_impl<N, typename enable_if<(N < 32) && (N >= 0)>::type> {
+ static EIGEN_STRONG_INLINE Packet2l run(const Packet2l& a) {
+ static const unsigned n = static_cast<unsigned>(N);
+ const Packet4ui shift = {n, n, n, n};
+ const Packet4i ai = reinterpret_cast<Packet4i>(a);
+ static const unsigned m = static_cast<unsigned>(32 - N);
+ const Packet4ui shift_left = {m, m, m, m};
+ const Packet4i out_lo = vec_sr(ai, shift);
+ const Packet4i out_hi = shift_odd_right(vec_sl(ai, shift_left));
+ return reinterpret_cast<Packet2l>(por<Packet4i>(out_hi, out_lo));
+ }
+};
+
+template<int N>
+struct plogical_shift_right_impl<N, typename enable_if<(N >= 32)>::type> {
+ static EIGEN_STRONG_INLINE Packet2l run(const Packet2l& a) {
+ static const unsigned m = static_cast<unsigned>(N - 32);
+ const Packet4ui shift = {m, m, m, m};
+ const Packet4i ai = reinterpret_cast<Packet4i>(a);
+ return reinterpret_cast<Packet2l>(shift_odd_right(vec_sr(ai, shift)));
+ }
+};
+
+template<int N>
+EIGEN_STRONG_INLINE Packet2l plogical_shift_right(const Packet2l& a) {
+ return plogical_shift_right_impl<N>::run(a);
+}
+#endif
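A scalar sketch of the POWER7 fallback for N < 32, under the assumption that each 64-bit lane is split into a low and a high 32-bit word: vec_sl supplies the in-word bits, while vec_sr(ai, 32 - N) followed by shift_even_left supplies the bits that cross from the low word into the high word. The function name is illustrative only.

#include <cstdint>

// Emulate a 64-bit logical left shift by N (0 <= N < 32) using only
// 32-bit shifts on the two halves of each lane.
inline std::uint64_t logical_shift_left_ref(std::uint64_t x, unsigned N) {
  const std::uint32_t lo = static_cast<std::uint32_t>(x);
  const std::uint32_t hi = static_cast<std::uint32_t>(x >> 32);
  const std::uint32_t lo_shifted = lo << N;                      // vec_sl on the low word
  const std::uint32_t hi_shifted = hi << N;                      // vec_sl on the high word
  const std::uint32_t carry = (N == 0) ? 0u : (lo >> (32 - N));  // cross-word bits (guard avoids UB at N == 0)
  return (static_cast<std::uint64_t>(hi_shifted | carry) << 32) | lo_shifted;
}

For 32 <= N < 64 the low word becomes zero and the high word is the original low word shifted by N - 32, which is what the N >= 32 specialization computes via shift_even_left(vec_sl(ai, N - 32)); the right-shift emulation mirrors the same idea with the carry flowing from the high word into the low word.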
+
+template<> EIGEN_STRONG_INLINE Packet2d pldexp<Packet2d>(const Packet2d& a, const Packet2d& exponent) {
+ // Clamp exponent to [-2099, 2099]
+ const Packet2d max_exponent = pset1<Packet2d>(2099.0);
+ const Packet2l e = pcast<Packet2d, Packet2l>(pmin(pmax(exponent, pnegate(max_exponent)), max_exponent));
+
+ // Split 2^e into four factors and multiply:
+ const Packet2l bias = { 1023, 1023 };
+ Packet2l b = plogical_shift_right<2>(e); // floor(e/4)
+ Packet2d c = reinterpret_cast<Packet2d>(plogical_shift_left<52>(b + bias));
+ Packet2d out = pmul(pmul(pmul(a, c), c), c); // a * 2^(3b)
+ b = psub(psub(psub(e, b), b), b); // e - 3b
+ c = reinterpret_cast<Packet2d>(plogical_shift_left<52>(b + bias)); // 2^(e - 3b)
+ out = pmul(out, c); // a * 2^e
+ return out;
+}
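A scalar model of the splitting idea in pldexp above, written with ordinary signed arithmetic rather than the packet shift primitives. Because the exponent is clamped to [-2099, 2099], each of the four factors 2^b, 2^b, 2^b and 2^(e - 3b) has an exponent well inside the normal double range, so it can be built directly from the exponent bits; multiplying them in sequence reproduces a * 2^e, including gradual underflow and overflow. The helper names are illustrative only.

#include <cstdint>
#include <cstring>

// Build 2^k from the exponent bits; valid while -1022 <= k <= 1023.
inline double two_pow_ref(std::int64_t k) {
  const std::uint64_t bits = static_cast<std::uint64_t>(k + 1023) << 52;
  double d;
  std::memcpy(&d, &bits, sizeof(d));
  return d;
}

// Scalar equivalent of pldexp for |e| <= 2099.
inline double ldexp_ref(double a, std::int64_t e) {
  const std::int64_t b = (e - (e < 0 ? 3 : 0)) / 4;  // floor(e / 4), so |b| <= 525
  const double c = two_pow_ref(b);                   // 2^b
  double out = ((a * c) * c) * c;                    // a * 2^(3b)
  return out * two_pow_ref(e - 3 * b);               // a * 2^e
}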
+
+
+// Extract the biased exponent of a Packet2d, returned as a Packet2d.
+template<>
+EIGEN_STRONG_INLINE
+Packet2d pfrexp_generic_get_biased_exponent(const Packet2d& a) {
+ return pcast<Packet2l, Packet2d>(plogical_shift_right<52>(reinterpret_cast<Packet2l>(pabs(a))));
+}
+
+template<> EIGEN_STRONG_INLINE Packet2d pfrexp<Packet2d> (const Packet2d& a, Packet2d& exponent) {
+ return pfrexp_generic(a, exponent);
+}
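A scalar model of the exponent extraction used above: for a finite, nonzero double the biased exponent occupies bits 52..62, so it can be read by clearing the sign bit with an absolute value and shifting right by 52; pfrexp_generic then turns this into the usual frexp-style contract (mantissa in [0.5, 1) plus an integer exponent). The _ref name is illustrative only.

#include <cmath>
#include <cstdint>
#include <cstring>

// Scalar equivalent of pfrexp_generic_get_biased_exponent for one lane.
inline double biased_exponent_ref(double a) {
  const double abs_a = std::fabs(a);       // pabs(a): clears the sign bit
  std::uint64_t bits;
  std::memcpy(&bits, &abs_a, sizeof(bits));
  return static_cast<double>(bits >> 52);  // plogical_shift_right<52>, result returned as a double
}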
+
template<> EIGEN_STRONG_INLINE double predux<Packet2d>(const Packet2d& a)
{
Packet2d b, sum;
@@ -964,20 +2668,6 @@ template<> EIGEN_STRONG_INLINE double predux<Packet2d>(const Packet2d& a)
return pfirst<Packet2d>(sum);
}
-template<> EIGEN_STRONG_INLINE Packet2d preduxp<Packet2d>(const Packet2d* vecs)
-{
- Packet2d v[2], sum;
- v[0] = vecs[0] + reinterpret_cast<Packet2d>(vec_sld(reinterpret_cast<Packet4f>(vecs[0]), reinterpret_cast<Packet4f>(vecs[0]), 8));
- v[1] = vecs[1] + reinterpret_cast<Packet2d>(vec_sld(reinterpret_cast<Packet4f>(vecs[1]), reinterpret_cast<Packet4f>(vecs[1]), 8));
-
-#ifdef _BIG_ENDIAN
- sum = reinterpret_cast<Packet2d>(vec_sld(reinterpret_cast<Packet4f>(v[0]), reinterpret_cast<Packet4f>(v[1]), 8));
-#else
- sum = reinterpret_cast<Packet2d>(vec_sld(reinterpret_cast<Packet4f>(v[1]), reinterpret_cast<Packet4f>(v[0]), 8));
-#endif
-
- return sum;
-}
// Other reduction functions:
// mul
template<> EIGEN_STRONG_INLINE double predux_mul<Packet2d>(const Packet2d& a)
@@ -997,20 +2687,6 @@ template<> EIGEN_STRONG_INLINE double predux_max<Packet2d>(const Packet2d& a)
return pfirst(pmax(a, reinterpret_cast<Packet2d>(vec_sld(reinterpret_cast<Packet4ui>(a), reinterpret_cast<Packet4ui>(a), 8))));
}
-template<int Offset>
-struct palign_impl<Offset,Packet2d>
-{
- static EIGEN_STRONG_INLINE void run(Packet2d& first, const Packet2d& second)
- {
- if (Offset == 1)
-#ifdef _BIG_ENDIAN
- first = reinterpret_cast<Packet2d>(vec_sld(reinterpret_cast<Packet4ui>(first), reinterpret_cast<Packet4ui>(second), 8));
-#else
- first = reinterpret_cast<Packet2d>(vec_sld(reinterpret_cast<Packet4ui>(second), reinterpret_cast<Packet4ui>(first), 8));
-#endif
- }
-};
-
EIGEN_DEVICE_FUNC inline void
ptranspose(PacketBlock<Packet2d,2>& kernel) {
Packet2d t0, t1;
@@ -1022,9 +2698,11 @@ ptranspose(PacketBlock<Packet2d,2>& kernel) {
template<> EIGEN_STRONG_INLINE Packet2d pblend(const Selector<2>& ifPacket, const Packet2d& thenPacket, const Packet2d& elsePacket) {
Packet2l select = { ifPacket.select[0], ifPacket.select[1] };
- Packet2bl mask = vec_cmpeq(reinterpret_cast<Packet2d>(select), reinterpret_cast<Packet2d>(p2l_ONE));
+ Packet2bl mask = reinterpret_cast<Packet2bl>( vec_cmpeq(reinterpret_cast<Packet2d>(select), reinterpret_cast<Packet2d>(p2l_ONE)) );
return vec_sel(elsePacket, thenPacket, mask);
}
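For reference, a scalar model of what the blend above computes per lane (the name is illustrative, not part of Eigen): lane i of the result comes from thenPacket when the selector equals 1 and from elsePacket otherwise, which is the mask that vec_cmpeq produces for vec_sel.

#include <cstdint>

// Scalar equivalent of pblend for the two double lanes of a Packet2d.
inline void pblend2_ref(const std::int64_t select[2],
                        const double then_v[2],
                        const double else_v[2],
                        double out[2]) {
  for (int i = 0; i < 2; ++i) {
    out[i] = (select[i] == 1) ? then_v[i] : else_v[i];
  }
}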
+
+
#endif // __VSX__
} // end namespace internal