diff options
Diffstat (limited to 'Eigen/src/Core/arch/AVX512/PacketMath.h')
-rw-r--r-- | Eigen/src/Core/arch/AVX512/PacketMath.h | 1989 |
1 files changed, 1488 insertions, 501 deletions
diff --git a/Eigen/src/Core/arch/AVX512/PacketMath.h b/Eigen/src/Core/arch/AVX512/PacketMath.h index f6500a16e..34d49ab66 100644 --- a/Eigen/src/Core/arch/AVX512/PacketMath.h +++ b/Eigen/src/Core/arch/AVX512/PacketMath.h @@ -19,10 +19,10 @@ namespace internal { #endif #ifndef EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS -#define EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS (2*sizeof(void*)) +#define EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS 32 #endif -#ifdef __FMA__ +#ifdef EIGEN_VECTORIZE_FMA #ifndef EIGEN_HAS_SINGLE_INSTRUCTION_MADD #define EIGEN_HAS_SINGLE_INSTRUCTION_MADD #endif @@ -31,6 +31,8 @@ namespace internal { typedef __m512 Packet16f; typedef __m512i Packet16i; typedef __m512d Packet8d; +typedef eigen_packet_wrapper<__m256i, 1> Packet16h; +typedef eigen_packet_wrapper<__m256i, 2> Packet16bf; template <> struct is_arithmetic<__m512> { @@ -45,6 +47,51 @@ struct is_arithmetic<__m512d> { enum { value = true }; }; +template<> struct is_arithmetic<Packet16h> { enum { value = true }; }; + +template <> +struct packet_traits<half> : default_packet_traits { + typedef Packet16h type; + // There is no half-size packet for Packet16h. + typedef Packet16h half; + enum { + Vectorizable = 1, + AlignedOnScalar = 1, + size = 16, + HasHalfPacket = 1, + + HasCmp = 1, + HasAdd = 1, + HasSub = 1, + HasMul = 1, + HasDiv = 1, + HasNegate = 1, + HasAbs = 1, + HasAbs2 = 0, + HasMin = 1, + HasMax = 1, + HasConj = 1, + HasSetLinear = 0, + HasLog = 1, + HasLog1p = 1, + HasExpm1 = 1, + HasExp = 1, + HasSqrt = 1, + HasRsqrt = 1, + HasSin = EIGEN_FAST_MATH, + HasCos = EIGEN_FAST_MATH, + HasTanh = EIGEN_FAST_MATH, + HasErf = EIGEN_FAST_MATH, + HasBlend = 0, + HasRound = 1, + HasFloor = 1, + HasCeil = 1, + HasRint = 1, + HasBessel = 1, + HasNdtri = 1 + }; +}; + template<> struct packet_traits<float> : default_packet_traits { typedef Packet16f type; @@ -54,15 +101,32 @@ template<> struct packet_traits<float> : default_packet_traits AlignedOnScalar = 1, size = 16, HasHalfPacket = 1, -#if EIGEN_GNUC_AT_LEAST(5, 3) -#ifdef EIGEN_VECTORIZE_AVX512DQ + + HasAbs = 1, + HasMin = 1, + HasMax = 1, + HasConj = 1, + HasBlend = 0, + HasSin = EIGEN_FAST_MATH, + HasCos = EIGEN_FAST_MATH, +#if EIGEN_GNUC_AT_LEAST(5, 3) || (!EIGEN_COMP_GNUC_STRICT) HasLog = 1, -#endif + HasLog1p = 1, + HasExpm1 = 1, + HasNdtri = 1, + HasBessel = 1, HasExp = 1, - HasSqrt = 1, - HasRsqrt = 1, + HasSqrt = EIGEN_FAST_MATH, + HasRsqrt = EIGEN_FAST_MATH, + HasTanh = EIGEN_FAST_MATH, + HasErf = EIGEN_FAST_MATH, #endif - HasDiv = 1 + HasCmp = 1, + HasDiv = 1, + HasRound = 1, + HasFloor = 1, + HasCeil = 1, + HasRint = 1 }; }; template<> struct packet_traits<double> : default_packet_traits @@ -74,11 +138,18 @@ template<> struct packet_traits<double> : default_packet_traits AlignedOnScalar = 1, size = 8, HasHalfPacket = 1, -#if EIGEN_GNUC_AT_LEAST(5, 3) - HasSqrt = 1, +#if EIGEN_GNUC_AT_LEAST(5, 3) || (!EIGEN_COMP_GNUC_STRICT) + HasLog = 1, + HasExp = 1, + HasSqrt = EIGEN_FAST_MATH, HasRsqrt = EIGEN_FAST_MATH, #endif - HasDiv = 1 + HasCmp = 1, + HasDiv = 1, + HasRound = 1, + HasFloor = 1, + HasCeil = 1, + HasRint = 1 }; }; @@ -98,19 +169,28 @@ template <> struct unpacket_traits<Packet16f> { typedef float type; typedef Packet8f half; - enum { size = 16, alignment=Aligned64 }; + typedef Packet16i integer_packet; + typedef uint16_t mask_t; + enum { size = 16, alignment=Aligned64, vectorizable=true, masked_load_available=true, masked_store_available=true }; }; template <> struct unpacket_traits<Packet8d> { typedef double type; typedef Packet4d half; - enum { size = 8, alignment=Aligned64 }; + enum { size = 8, alignment=Aligned64, vectorizable=true, masked_load_available=false, masked_store_available=false }; }; template <> struct unpacket_traits<Packet16i> { typedef int type; typedef Packet8i half; - enum { size = 16, alignment=Aligned64 }; + enum { size = 16, alignment=Aligned64, vectorizable=false, masked_load_available=false, masked_store_available=false }; +}; + +template<> +struct unpacket_traits<Packet16h> { + typedef Eigen::half type; + typedef Packet8h half; + enum {size=16, alignment=Aligned32, vectorizable=true, masked_load_available=false, masked_store_available=false}; }; template <> @@ -127,12 +207,39 @@ EIGEN_STRONG_INLINE Packet16i pset1<Packet16i>(const int& from) { } template <> +EIGEN_STRONG_INLINE Packet16f pset1frombits<Packet16f>(unsigned int from) { + return _mm512_castsi512_ps(_mm512_set1_epi32(from)); +} + +template <> +EIGEN_STRONG_INLINE Packet8d pset1frombits<Packet8d>(const numext::uint64_t from) { + return _mm512_castsi512_pd(_mm512_set1_epi64(from)); +} + +template<> EIGEN_STRONG_INLINE Packet16f pzero(const Packet16f& /*a*/) { return _mm512_setzero_ps(); } +template<> EIGEN_STRONG_INLINE Packet8d pzero(const Packet8d& /*a*/) { return _mm512_setzero_pd(); } +template<> EIGEN_STRONG_INLINE Packet16i pzero(const Packet16i& /*a*/) { return _mm512_setzero_si512(); } + +template<> EIGEN_STRONG_INLINE Packet16f peven_mask(const Packet16f& /*a*/) { + return _mm512_castsi512_ps(_mm512_set_epi32(0, -1, 0, -1, 0, -1, 0, -1, + 0, -1, 0, -1, 0, -1, 0, -1)); +} +template<> EIGEN_STRONG_INLINE Packet16i peven_mask(const Packet16i& /*a*/) { + return _mm512_set_epi32(0, -1, 0, -1, 0, -1, 0, -1, + 0, -1, 0, -1, 0, -1, 0, -1); +} +template<> EIGEN_STRONG_INLINE Packet8d peven_mask(const Packet8d& /*a*/) { + return _mm512_castsi512_pd(_mm512_set_epi32(0, 0, -1, -1, 0, 0, -1, -1, + 0, 0, -1, -1, 0, 0, -1, -1)); +} + +template <> EIGEN_STRONG_INLINE Packet16f pload1<Packet16f>(const float* from) { return _mm512_broadcastss_ps(_mm_load_ps1(from)); } template <> EIGEN_STRONG_INLINE Packet8d pload1<Packet8d>(const double* from) { - return _mm512_broadcastsd_pd(_mm_load_pd1(from)); + return _mm512_set1_pd(*from); } template <> @@ -158,6 +265,11 @@ EIGEN_STRONG_INLINE Packet8d padd<Packet8d>(const Packet8d& a, const Packet8d& b) { return _mm512_add_pd(a, b); } +template <> +EIGEN_STRONG_INLINE Packet16i padd<Packet16i>(const Packet16i& a, + const Packet16i& b) { + return _mm512_add_epi32(a, b); +} template <> EIGEN_STRONG_INLINE Packet16f psub<Packet16f>(const Packet16f& a, @@ -169,6 +281,11 @@ EIGEN_STRONG_INLINE Packet8d psub<Packet8d>(const Packet8d& a, const Packet8d& b) { return _mm512_sub_pd(a, b); } +template <> +EIGEN_STRONG_INLINE Packet16i psub<Packet16i>(const Packet16i& a, + const Packet16i& b) { + return _mm512_sub_epi32(a, b); +} template <> EIGEN_STRONG_INLINE Packet16f pnegate(const Packet16f& a) { @@ -202,6 +319,11 @@ EIGEN_STRONG_INLINE Packet8d pmul<Packet8d>(const Packet8d& a, const Packet8d& b) { return _mm512_mul_pd(a, b); } +template <> +EIGEN_STRONG_INLINE Packet16i pmul<Packet16i>(const Packet16i& a, + const Packet16i& b) { + return _mm512_mullo_epi32(a, b); +} template <> EIGEN_STRONG_INLINE Packet16f pdiv<Packet16f>(const Packet16f& a, @@ -214,7 +336,7 @@ EIGEN_STRONG_INLINE Packet8d pdiv<Packet8d>(const Packet8d& a, return _mm512_div_pd(a, b); } -#ifdef __FMA__ +#ifdef EIGEN_VECTORIZE_FMA template <> EIGEN_STRONG_INLINE Packet16f pmadd(const Packet16f& a, const Packet16f& b, const Packet16f& c) { @@ -228,51 +350,216 @@ EIGEN_STRONG_INLINE Packet8d pmadd(const Packet8d& a, const Packet8d& b, #endif template <> +EIGEN_DEVICE_FUNC inline Packet16f pselect(const Packet16f& mask, + const Packet16f& a, + const Packet16f& b) { + __mmask16 mask16 = _mm512_cmp_epi32_mask( + _mm512_castps_si512(mask), _mm512_setzero_epi32(), _MM_CMPINT_EQ); + return _mm512_mask_blend_ps(mask16, a, b); +} + +template <> +EIGEN_DEVICE_FUNC inline Packet8d pselect(const Packet8d& mask, + const Packet8d& a, + const Packet8d& b) { + __mmask8 mask8 = _mm512_cmp_epi64_mask(_mm512_castpd_si512(mask), + _mm512_setzero_epi32(), _MM_CMPINT_EQ); + return _mm512_mask_blend_pd(mask8, a, b); +} + +template <> EIGEN_STRONG_INLINE Packet16f pmin<Packet16f>(const Packet16f& a, const Packet16f& b) { - return _mm512_min_ps(a, b); + // Arguments are reversed to match NaN propagation behavior of std::min. + return _mm512_min_ps(b, a); } template <> EIGEN_STRONG_INLINE Packet8d pmin<Packet8d>(const Packet8d& a, const Packet8d& b) { - return _mm512_min_pd(a, b); + // Arguments are reversed to match NaN propagation behavior of std::min. + return _mm512_min_pd(b, a); } template <> EIGEN_STRONG_INLINE Packet16f pmax<Packet16f>(const Packet16f& a, const Packet16f& b) { - return _mm512_max_ps(a, b); + // Arguments are reversed to match NaN propagation behavior of std::max. + return _mm512_max_ps(b, a); } template <> EIGEN_STRONG_INLINE Packet8d pmax<Packet8d>(const Packet8d& a, const Packet8d& b) { - return _mm512_max_pd(a, b); + // Arguments are reversed to match NaN propagation behavior of std::max. + return _mm512_max_pd(b, a); } -template <> -EIGEN_STRONG_INLINE Packet16f pand<Packet16f>(const Packet16f& a, - const Packet16f& b) { +// Add specializations for min/max with prescribed NaN progation. +template<> +EIGEN_STRONG_INLINE Packet16f pmin<PropagateNumbers, Packet16f>(const Packet16f& a, const Packet16f& b) { + return pminmax_propagate_numbers(a, b, pmin<Packet16f>); +} +template<> +EIGEN_STRONG_INLINE Packet8d pmin<PropagateNumbers, Packet8d>(const Packet8d& a, const Packet8d& b) { + return pminmax_propagate_numbers(a, b, pmin<Packet8d>); +} +template<> +EIGEN_STRONG_INLINE Packet16f pmax<PropagateNumbers, Packet16f>(const Packet16f& a, const Packet16f& b) { + return pminmax_propagate_numbers(a, b, pmax<Packet16f>); +} +template<> +EIGEN_STRONG_INLINE Packet8d pmax<PropagateNumbers, Packet8d>(const Packet8d& a, const Packet8d& b) { + return pminmax_propagate_numbers(a, b, pmax<Packet8d>); +} +template<> +EIGEN_STRONG_INLINE Packet16f pmin<PropagateNaN, Packet16f>(const Packet16f& a, const Packet16f& b) { + return pminmax_propagate_nan(a, b, pmin<Packet16f>); +} +template<> +EIGEN_STRONG_INLINE Packet8d pmin<PropagateNaN, Packet8d>(const Packet8d& a, const Packet8d& b) { + return pminmax_propagate_nan(a, b, pmin<Packet8d>); +} +template<> +EIGEN_STRONG_INLINE Packet16f pmax<PropagateNaN, Packet16f>(const Packet16f& a, const Packet16f& b) { + return pminmax_propagate_nan(a, b, pmax<Packet16f>); +} +template<> +EIGEN_STRONG_INLINE Packet8d pmax<PropagateNaN, Packet8d>(const Packet8d& a, const Packet8d& b) { + return pminmax_propagate_nan(a, b, pmax<Packet8d>); +} + + #ifdef EIGEN_VECTORIZE_AVX512DQ - return _mm512_and_ps(a, b); +template<int I_> EIGEN_STRONG_INLINE Packet8f extract256(Packet16f x) { return _mm512_extractf32x8_ps(x,I_); } +template<int I_> EIGEN_STRONG_INLINE Packet2d extract128(Packet8d x) { return _mm512_extractf64x2_pd(x,I_); } +EIGEN_STRONG_INLINE Packet16f cat256(Packet8f a, Packet8f b) { return _mm512_insertf32x8(_mm512_castps256_ps512(a),b,1); } #else - Packet16f res = _mm512_undefined_ps(); - Packet4f lane0_a = _mm512_extractf32x4_ps(a, 0); - Packet4f lane0_b = _mm512_extractf32x4_ps(b, 0); - res = _mm512_insertf32x4(res, _mm_and_ps(lane0_a, lane0_b), 0); +// AVX512F does not define _mm512_extractf32x8_ps to extract _m256 from _m512 +template<int I_> EIGEN_STRONG_INLINE Packet8f extract256(Packet16f x) { + return _mm256_castsi256_ps(_mm512_extracti64x4_epi64( _mm512_castps_si512(x),I_)); +} - Packet4f lane1_a = _mm512_extractf32x4_ps(a, 1); - Packet4f lane1_b = _mm512_extractf32x4_ps(b, 1); - res = _mm512_insertf32x4(res, _mm_and_ps(lane1_a, lane1_b), 1); +// AVX512F does not define _mm512_extractf64x2_pd to extract _m128 from _m512 +template<int I_> EIGEN_STRONG_INLINE Packet2d extract128(Packet8d x) { + return _mm_castsi128_pd(_mm512_extracti32x4_epi32( _mm512_castpd_si512(x),I_)); +} + +EIGEN_STRONG_INLINE Packet16f cat256(Packet8f a, Packet8f b) { + return _mm512_castsi512_ps(_mm512_inserti64x4(_mm512_castsi256_si512(_mm256_castps_si256(a)), + _mm256_castps_si256(b),1)); +} +#endif + +// Helper function for bit packing snippet of low precision comparison. +// It packs the flags from 32x16 to 16x16. +EIGEN_STRONG_INLINE __m256i Pack32To16(Packet16f rf) { + // Split data into small pieces and handle with AVX instructions + // to guarantee internal order of vector. + // Operation: + // dst[15:0] := Saturate16(rf[31:0]) + // dst[31:16] := Saturate16(rf[63:32]) + // ... + // dst[255:240] := Saturate16(rf[255:224]) + __m256i lo = _mm256_castps_si256(extract256<0>(rf)); + __m256i hi = _mm256_castps_si256(extract256<1>(rf)); + __m128i result_lo = _mm_packs_epi32(_mm256_extractf128_si256(lo, 0), + _mm256_extractf128_si256(lo, 1)); + __m128i result_hi = _mm_packs_epi32(_mm256_extractf128_si256(hi, 0), + _mm256_extractf128_si256(hi, 1)); + return _mm256_insertf128_si256(_mm256_castsi128_si256(result_lo), result_hi, 1); +} + +template <> +EIGEN_STRONG_INLINE Packet16f pcmp_eq(const Packet16f& a, const Packet16f& b) { + __mmask16 mask = _mm512_cmp_ps_mask(a, b, _CMP_EQ_OQ); + return _mm512_castsi512_ps( + _mm512_mask_set1_epi32(_mm512_set1_epi32(0), mask, 0xffffffffu)); +} +template<> EIGEN_STRONG_INLINE Packet16f pcmp_le(const Packet16f& a, const Packet16f& b) { + __mmask16 mask = _mm512_cmp_ps_mask(a, b, _CMP_LE_OQ); + return _mm512_castsi512_ps( + _mm512_mask_set1_epi32(_mm512_set1_epi32(0), mask, 0xffffffffu)); +} + +template<> EIGEN_STRONG_INLINE Packet16f pcmp_lt(const Packet16f& a, const Packet16f& b) { + __mmask16 mask = _mm512_cmp_ps_mask(a, b, _CMP_LT_OQ); + return _mm512_castsi512_ps( + _mm512_mask_set1_epi32(_mm512_set1_epi32(0), mask, 0xffffffffu)); +} + +template<> EIGEN_STRONG_INLINE Packet16f pcmp_lt_or_nan(const Packet16f& a, const Packet16f& b) { + __mmask16 mask = _mm512_cmp_ps_mask(a, b, _CMP_NGE_UQ); + return _mm512_castsi512_ps( + _mm512_mask_set1_epi32(_mm512_set1_epi32(0), mask, 0xffffffffu)); +} + +template<> EIGEN_STRONG_INLINE Packet16i pcmp_eq(const Packet16i& a, const Packet16i& b) { + __mmask16 mask = _mm512_cmp_epi32_mask(a, b, _CMP_EQ_OQ); + return _mm512_mask_set1_epi32(_mm512_set1_epi32(0), mask, 0xffffffffu); +} - Packet4f lane2_a = _mm512_extractf32x4_ps(a, 2); - Packet4f lane2_b = _mm512_extractf32x4_ps(b, 2); - res = _mm512_insertf32x4(res, _mm_and_ps(lane2_a, lane2_b), 2); - Packet4f lane3_a = _mm512_extractf32x4_ps(a, 3); - Packet4f lane3_b = _mm512_extractf32x4_ps(b, 3); - res = _mm512_insertf32x4(res, _mm_and_ps(lane3_a, lane3_b), 3); +template <> +EIGEN_STRONG_INLINE Packet8d pcmp_eq(const Packet8d& a, const Packet8d& b) { + __mmask8 mask = _mm512_cmp_pd_mask(a, b, _CMP_EQ_OQ); + return _mm512_castsi512_pd( + _mm512_mask_set1_epi64(_mm512_set1_epi64(0), mask, 0xffffffffffffffffu)); +} +template <> +EIGEN_STRONG_INLINE Packet8d pcmp_le(const Packet8d& a, const Packet8d& b) { + __mmask8 mask = _mm512_cmp_pd_mask(a, b, _CMP_LE_OQ); + return _mm512_castsi512_pd( + _mm512_mask_set1_epi64(_mm512_set1_epi64(0), mask, 0xffffffffffffffffu)); +} +template <> +EIGEN_STRONG_INLINE Packet8d pcmp_lt(const Packet8d& a, const Packet8d& b) { + __mmask8 mask = _mm512_cmp_pd_mask(a, b, _CMP_LT_OQ); + return _mm512_castsi512_pd( + _mm512_mask_set1_epi64(_mm512_set1_epi64(0), mask, 0xffffffffffffffffu)); +} +template <> +EIGEN_STRONG_INLINE Packet8d pcmp_lt_or_nan(const Packet8d& a, const Packet8d& b) { + __mmask8 mask = _mm512_cmp_pd_mask(a, b, _CMP_NGE_UQ); + return _mm512_castsi512_pd( + _mm512_mask_set1_epi64(_mm512_set1_epi64(0), mask, 0xffffffffffffffffu)); +} - return res; +template<> EIGEN_STRONG_INLINE Packet16f print<Packet16f>(const Packet16f& a) { return _mm512_roundscale_ps(a, _MM_FROUND_CUR_DIRECTION); } +template<> EIGEN_STRONG_INLINE Packet8d print<Packet8d>(const Packet8d& a) { return _mm512_roundscale_pd(a, _MM_FROUND_CUR_DIRECTION); } + +template<> EIGEN_STRONG_INLINE Packet16f pceil<Packet16f>(const Packet16f& a) { return _mm512_roundscale_ps(a, _MM_FROUND_TO_POS_INF); } +template<> EIGEN_STRONG_INLINE Packet8d pceil<Packet8d>(const Packet8d& a) { return _mm512_roundscale_pd(a, _MM_FROUND_TO_POS_INF); } + +template<> EIGEN_STRONG_INLINE Packet16f pfloor<Packet16f>(const Packet16f& a) { return _mm512_roundscale_ps(a, _MM_FROUND_TO_NEG_INF); } +template<> EIGEN_STRONG_INLINE Packet8d pfloor<Packet8d>(const Packet8d& a) { return _mm512_roundscale_pd(a, _MM_FROUND_TO_NEG_INF); } + +template <> +EIGEN_STRONG_INLINE Packet16i ptrue<Packet16i>(const Packet16i& /*a*/) { + return _mm512_set1_epi32(0xffffffffu); +} + +template <> +EIGEN_STRONG_INLINE Packet16f ptrue<Packet16f>(const Packet16f& a) { + return _mm512_castsi512_ps(ptrue<Packet16i>(_mm512_castps_si512(a))); +} + +template <> +EIGEN_STRONG_INLINE Packet8d ptrue<Packet8d>(const Packet8d& a) { + return _mm512_castsi512_pd(ptrue<Packet16i>(_mm512_castpd_si512(a))); +} + +template <> +EIGEN_STRONG_INLINE Packet16i pand<Packet16i>(const Packet16i& a, + const Packet16i& b) { + return _mm512_and_si512(a,b); +} + +template <> +EIGEN_STRONG_INLINE Packet16f pand<Packet16f>(const Packet16f& a, + const Packet16f& b) { +#ifdef EIGEN_VECTORIZE_AVX512DQ + return _mm512_and_ps(a, b); +#else + return _mm512_castsi512_ps(pand(_mm512_castps_si512(a),_mm512_castps_si512(b))); #endif } template <> @@ -288,35 +575,21 @@ EIGEN_STRONG_INLINE Packet8d pand<Packet8d>(const Packet8d& a, Packet4d lane1_a = _mm512_extractf64x4_pd(a, 1); Packet4d lane1_b = _mm512_extractf64x4_pd(b, 1); - res = _mm512_insertf64x4(res, _mm256_and_pd(lane1_a, lane1_b), 1); - - return res; + return _mm512_insertf64x4(res, _mm256_and_pd(lane1_a, lane1_b), 1); #endif } + +template <> +EIGEN_STRONG_INLINE Packet16i por<Packet16i>(const Packet16i& a, const Packet16i& b) { + return _mm512_or_si512(a, b); +} + template <> -EIGEN_STRONG_INLINE Packet16f por<Packet16f>(const Packet16f& a, - const Packet16f& b) { +EIGEN_STRONG_INLINE Packet16f por<Packet16f>(const Packet16f& a, const Packet16f& b) { #ifdef EIGEN_VECTORIZE_AVX512DQ return _mm512_or_ps(a, b); #else - Packet16f res = _mm512_undefined_ps(); - Packet4f lane0_a = _mm512_extractf32x4_ps(a, 0); - Packet4f lane0_b = _mm512_extractf32x4_ps(b, 0); - res = _mm512_insertf32x4(res, _mm_or_ps(lane0_a, lane0_b), 0); - - Packet4f lane1_a = _mm512_extractf32x4_ps(a, 1); - Packet4f lane1_b = _mm512_extractf32x4_ps(b, 1); - res = _mm512_insertf32x4(res, _mm_or_ps(lane1_a, lane1_b), 1); - - Packet4f lane2_a = _mm512_extractf32x4_ps(a, 2); - Packet4f lane2_b = _mm512_extractf32x4_ps(b, 2); - res = _mm512_insertf32x4(res, _mm_or_ps(lane2_a, lane2_b), 2); - - Packet4f lane3_a = _mm512_extractf32x4_ps(a, 3); - Packet4f lane3_b = _mm512_extractf32x4_ps(b, 3); - res = _mm512_insertf32x4(res, _mm_or_ps(lane3_a, lane3_b), 3); - - return res; + return _mm512_castsi512_ps(por(_mm512_castps_si512(a),_mm512_castps_si512(b))); #endif } @@ -326,107 +599,80 @@ EIGEN_STRONG_INLINE Packet8d por<Packet8d>(const Packet8d& a, #ifdef EIGEN_VECTORIZE_AVX512DQ return _mm512_or_pd(a, b); #else - Packet8d res = _mm512_undefined_pd(); - Packet4d lane0_a = _mm512_extractf64x4_pd(a, 0); - Packet4d lane0_b = _mm512_extractf64x4_pd(b, 0); - res = _mm512_insertf64x4(res, _mm256_or_pd(lane0_a, lane0_b), 0); - - Packet4d lane1_a = _mm512_extractf64x4_pd(a, 1); - Packet4d lane1_b = _mm512_extractf64x4_pd(b, 1); - res = _mm512_insertf64x4(res, _mm256_or_pd(lane1_a, lane1_b), 1); - - return res; + return _mm512_castsi512_pd(por(_mm512_castpd_si512(a),_mm512_castpd_si512(b))); #endif } template <> -EIGEN_STRONG_INLINE Packet16f pxor<Packet16f>(const Packet16f& a, - const Packet16f& b) { +EIGEN_STRONG_INLINE Packet16i pxor<Packet16i>(const Packet16i& a, const Packet16i& b) { + return _mm512_xor_si512(a, b); +} + +template <> +EIGEN_STRONG_INLINE Packet16f pxor<Packet16f>(const Packet16f& a, const Packet16f& b) { #ifdef EIGEN_VECTORIZE_AVX512DQ return _mm512_xor_ps(a, b); #else - Packet16f res = _mm512_undefined_ps(); - Packet4f lane0_a = _mm512_extractf32x4_ps(a, 0); - Packet4f lane0_b = _mm512_extractf32x4_ps(b, 0); - res = _mm512_insertf32x4(res, _mm_xor_ps(lane0_a, lane0_b), 0); - - Packet4f lane1_a = _mm512_extractf32x4_ps(a, 1); - Packet4f lane1_b = _mm512_extractf32x4_ps(b, 1); - res = _mm512_insertf32x4(res, _mm_xor_ps(lane1_a, lane1_b), 1); - - Packet4f lane2_a = _mm512_extractf32x4_ps(a, 2); - Packet4f lane2_b = _mm512_extractf32x4_ps(b, 2); - res = _mm512_insertf32x4(res, _mm_xor_ps(lane2_a, lane2_b), 2); - - Packet4f lane3_a = _mm512_extractf32x4_ps(a, 3); - Packet4f lane3_b = _mm512_extractf32x4_ps(b, 3); - res = _mm512_insertf32x4(res, _mm_xor_ps(lane3_a, lane3_b), 3); - - return res; + return _mm512_castsi512_ps(pxor(_mm512_castps_si512(a),_mm512_castps_si512(b))); #endif } + template <> -EIGEN_STRONG_INLINE Packet8d pxor<Packet8d>(const Packet8d& a, - const Packet8d& b) { +EIGEN_STRONG_INLINE Packet8d pxor<Packet8d>(const Packet8d& a, const Packet8d& b) { #ifdef EIGEN_VECTORIZE_AVX512DQ return _mm512_xor_pd(a, b); #else - Packet8d res = _mm512_undefined_pd(); - Packet4d lane0_a = _mm512_extractf64x4_pd(a, 0); - Packet4d lane0_b = _mm512_extractf64x4_pd(b, 0); - res = _mm512_insertf64x4(res, _mm256_xor_pd(lane0_a, lane0_b), 0); - - Packet4d lane1_a = _mm512_extractf64x4_pd(a, 1); - Packet4d lane1_b = _mm512_extractf64x4_pd(b, 1); - res = _mm512_insertf64x4(res, _mm256_xor_pd(lane1_a, lane1_b), 1); - - return res; + return _mm512_castsi512_pd(pxor(_mm512_castpd_si512(a),_mm512_castpd_si512(b))); #endif } template <> -EIGEN_STRONG_INLINE Packet16f pandnot<Packet16f>(const Packet16f& a, - const Packet16f& b) { +EIGEN_STRONG_INLINE Packet16i pandnot<Packet16i>(const Packet16i& a, const Packet16i& b) { + return _mm512_andnot_si512(b, a); +} + +template <> +EIGEN_STRONG_INLINE Packet16f pandnot<Packet16f>(const Packet16f& a, const Packet16f& b) { #ifdef EIGEN_VECTORIZE_AVX512DQ - return _mm512_andnot_ps(a, b); + return _mm512_andnot_ps(b, a); #else - Packet16f res = _mm512_undefined_ps(); - Packet4f lane0_a = _mm512_extractf32x4_ps(a, 0); - Packet4f lane0_b = _mm512_extractf32x4_ps(b, 0); - res = _mm512_insertf32x4(res, _mm_andnot_ps(lane0_a, lane0_b), 0); - - Packet4f lane1_a = _mm512_extractf32x4_ps(a, 1); - Packet4f lane1_b = _mm512_extractf32x4_ps(b, 1); - res = _mm512_insertf32x4(res, _mm_andnot_ps(lane1_a, lane1_b), 1); - - Packet4f lane2_a = _mm512_extractf32x4_ps(a, 2); - Packet4f lane2_b = _mm512_extractf32x4_ps(b, 2); - res = _mm512_insertf32x4(res, _mm_andnot_ps(lane2_a, lane2_b), 2); - - Packet4f lane3_a = _mm512_extractf32x4_ps(a, 3); - Packet4f lane3_b = _mm512_extractf32x4_ps(b, 3); - res = _mm512_insertf32x4(res, _mm_andnot_ps(lane3_a, lane3_b), 3); - - return res; + return _mm512_castsi512_ps(pandnot(_mm512_castps_si512(a),_mm512_castps_si512(b))); #endif } template <> -EIGEN_STRONG_INLINE Packet8d pandnot<Packet8d>(const Packet8d& a, - const Packet8d& b) { +EIGEN_STRONG_INLINE Packet8d pandnot<Packet8d>(const Packet8d& a,const Packet8d& b) { #ifdef EIGEN_VECTORIZE_AVX512DQ - return _mm512_andnot_pd(a, b); + return _mm512_andnot_pd(b, a); #else - Packet8d res = _mm512_undefined_pd(); - Packet4d lane0_a = _mm512_extractf64x4_pd(a, 0); - Packet4d lane0_b = _mm512_extractf64x4_pd(b, 0); - res = _mm512_insertf64x4(res, _mm256_andnot_pd(lane0_a, lane0_b), 0); + return _mm512_castsi512_pd(pandnot(_mm512_castpd_si512(a),_mm512_castpd_si512(b))); +#endif +} - Packet4d lane1_a = _mm512_extractf64x4_pd(a, 1); - Packet4d lane1_b = _mm512_extractf64x4_pd(b, 1); - res = _mm512_insertf64x4(res, _mm256_andnot_pd(lane1_a, lane1_b), 1); +template<> EIGEN_STRONG_INLINE Packet16f pround<Packet16f>(const Packet16f& a) +{ + // Work-around for default std::round rounding mode. + const Packet16f mask = pset1frombits<Packet16f>(static_cast<numext::uint32_t>(0x80000000u)); + const Packet16f prev0dot5 = pset1frombits<Packet16f>(static_cast<numext::uint32_t>(0x3EFFFFFFu)); + return _mm512_roundscale_ps(padd(por(pand(a, mask), prev0dot5), a), _MM_FROUND_TO_ZERO); +} +template<> EIGEN_STRONG_INLINE Packet8d pround<Packet8d>(const Packet8d& a) +{ + // Work-around for default std::round rounding mode. + const Packet8d mask = pset1frombits<Packet8d>(static_cast<numext::uint64_t>(0x8000000000000000ull)); + const Packet8d prev0dot5 = pset1frombits<Packet8d>(static_cast<numext::uint64_t>(0x3FDFFFFFFFFFFFFFull)); + return _mm512_roundscale_pd(padd(por(pand(a, mask), prev0dot5), a), _MM_FROUND_TO_ZERO); +} - return res; -#endif +template<int N> EIGEN_STRONG_INLINE Packet16i parithmetic_shift_right(Packet16i a) { + return _mm512_srai_epi32(a, N); +} + +template<int N> EIGEN_STRONG_INLINE Packet16i plogical_shift_right(Packet16i a) { + return _mm512_srli_epi32(a, N); +} + +template<int N> EIGEN_STRONG_INLINE Packet16i plogical_shift_left(Packet16i a) { + return _mm512_slli_epi32(a, N); } template <> @@ -457,79 +703,65 @@ EIGEN_STRONG_INLINE Packet16i ploadu<Packet16i>(const int* from) { reinterpret_cast<const __m512i*>(from)); } +template <> +EIGEN_STRONG_INLINE Packet16f ploadu<Packet16f>(const float* from, uint16_t umask) { + __mmask16 mask = static_cast<__mmask16>(umask); + EIGEN_DEBUG_UNALIGNED_LOAD return _mm512_maskz_loadu_ps(mask, from); +} + // Loads 8 floats from memory a returns the packet // {a0, a0 a1, a1, a2, a2, a3, a3, a4, a4, a5, a5, a6, a6, a7, a7} template <> EIGEN_STRONG_INLINE Packet16f ploaddup<Packet16f>(const float* from) { - Packet8f lane0 = _mm256_broadcast_ps((const __m128*)(const void*)from); - // mimic an "inplace" permutation of the lower 128bits using a blend - lane0 = _mm256_blend_ps( - lane0, _mm256_castps128_ps256(_mm_permute_ps( - _mm256_castps256_ps128(lane0), _MM_SHUFFLE(1, 0, 1, 0))), - 15); - // then we can perform a consistent permutation on the global register to get - // everything in shape: - lane0 = _mm256_permute_ps(lane0, _MM_SHUFFLE(3, 3, 2, 2)); - - Packet8f lane1 = _mm256_broadcast_ps((const __m128*)(const void*)(from + 4)); - // mimic an "inplace" permutation of the lower 128bits using a blend - lane1 = _mm256_blend_ps( - lane1, _mm256_castps128_ps256(_mm_permute_ps( - _mm256_castps256_ps128(lane1), _MM_SHUFFLE(1, 0, 1, 0))), - 15); - // then we can perform a consistent permutation on the global register to get - // everything in shape: - lane1 = _mm256_permute_ps(lane1, _MM_SHUFFLE(3, 3, 2, 2)); + // an unaligned load is required here as there is no requirement + // on the alignment of input pointer 'from' + __m256i low_half = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(from)); + __m512 even_elements = _mm512_castsi512_ps(_mm512_cvtepu32_epi64(low_half)); + __m512 pairs = _mm512_permute_ps(even_elements, _MM_SHUFFLE(2, 2, 0, 0)); + return pairs; +} #ifdef EIGEN_VECTORIZE_AVX512DQ - Packet16f res = _mm512_undefined_ps(); - return _mm512_insertf32x8(res, lane0, 0); - return _mm512_insertf32x8(res, lane1, 1); - return res; -#else - Packet16f res = _mm512_undefined_ps(); - res = _mm512_insertf32x4(res, _mm256_extractf128_ps(lane0, 0), 0); - res = _mm512_insertf32x4(res, _mm256_extractf128_ps(lane0, 1), 1); - res = _mm512_insertf32x4(res, _mm256_extractf128_ps(lane1, 0), 2); - res = _mm512_insertf32x4(res, _mm256_extractf128_ps(lane1, 1), 3); - return res; -#endif -} +// FIXME: this does not look optimal, better load a Packet4d and shuffle... // Loads 4 doubles from memory a returns the packet {a0, a0 a1, a1, a2, a2, a3, // a3} template <> EIGEN_STRONG_INLINE Packet8d ploaddup<Packet8d>(const double* from) { - Packet4d lane0 = _mm256_broadcast_pd((const __m128d*)(const void*)from); - lane0 = _mm256_permute_pd(lane0, 3 << 2); - - Packet4d lane1 = _mm256_broadcast_pd((const __m128d*)(const void*)(from + 2)); - lane1 = _mm256_permute_pd(lane1, 3 << 2); - - Packet8d res = _mm512_undefined_pd(); - res = _mm512_insertf64x4(res, lane0, 0); - return _mm512_insertf64x4(res, lane1, 1); + __m512d x = _mm512_setzero_pd(); + x = _mm512_insertf64x2(x, _mm_loaddup_pd(&from[0]), 0); + x = _mm512_insertf64x2(x, _mm_loaddup_pd(&from[1]), 1); + x = _mm512_insertf64x2(x, _mm_loaddup_pd(&from[2]), 2); + x = _mm512_insertf64x2(x, _mm_loaddup_pd(&from[3]), 3); + return x; +} +#else +template <> +EIGEN_STRONG_INLINE Packet8d ploaddup<Packet8d>(const double* from) { + __m512d x = _mm512_setzero_pd(); + x = _mm512_mask_broadcastsd_pd(x, 0x3<<0, _mm_load_sd(from+0)); + x = _mm512_mask_broadcastsd_pd(x, 0x3<<2, _mm_load_sd(from+1)); + x = _mm512_mask_broadcastsd_pd(x, 0x3<<4, _mm_load_sd(from+2)); + x = _mm512_mask_broadcastsd_pd(x, 0x3<<6, _mm_load_sd(from+3)); + return x; } +#endif // Loads 4 floats from memory a returns the packet // {a0, a0 a0, a0, a1, a1, a1, a1, a2, a2, a2, a2, a3, a3, a3, a3} template <> EIGEN_STRONG_INLINE Packet16f ploadquad<Packet16f>(const float* from) { - Packet16f tmp = _mm512_undefined_ps(); - tmp = _mm512_insertf32x4(tmp, _mm_load_ps1(from), 0); - tmp = _mm512_insertf32x4(tmp, _mm_load_ps1(from + 1), 1); - tmp = _mm512_insertf32x4(tmp, _mm_load_ps1(from + 2), 2); - tmp = _mm512_insertf32x4(tmp, _mm_load_ps1(from + 3), 3); - return tmp; + Packet16f tmp = _mm512_castps128_ps512(ploadu<Packet4f>(from)); + const Packet16i scatter_mask = _mm512_set_epi32(3,3,3,3, 2,2,2,2, 1,1,1,1, 0,0,0,0); + return _mm512_permutexvar_ps(scatter_mask, tmp); } + // Loads 2 doubles from memory a returns the packet // {a0, a0 a0, a0, a1, a1, a1, a1} template <> EIGEN_STRONG_INLINE Packet8d ploadquad<Packet8d>(const double* from) { - Packet8d tmp = _mm512_undefined_pd(); - Packet2d tmp0 = _mm_load_pd1(from); - Packet2d tmp1 = _mm_load_pd1(from + 1); - Packet4d lane0 = _mm256_broadcastsd_pd(tmp0); - Packet4d lane1 = _mm256_broadcastsd_pd(tmp1); + __m256d lane0 = _mm256_set1_pd(*from); + __m256d lane1 = _mm256_set1_pd(*(from+1)); + __m512d tmp = _mm512_undefined_pd(); tmp = _mm512_insertf64x4(tmp, lane0, 0); return _mm512_insertf64x4(tmp, lane1, 1); } @@ -561,11 +793,16 @@ EIGEN_STRONG_INLINE void pstoreu<int>(int* to, const Packet16i& from) { EIGEN_DEBUG_UNALIGNED_STORE _mm512_storeu_si512( reinterpret_cast<__m512i*>(to), from); } +template <> +EIGEN_STRONG_INLINE void pstoreu<float>(float* to, const Packet16f& from, uint16_t umask) { + __mmask16 mask = static_cast<__mmask16>(umask); + EIGEN_DEBUG_UNALIGNED_STORE return _mm512_mask_storeu_ps(to, mask, from); +} template <> EIGEN_DEVICE_FUNC inline Packet16f pgather<float, Packet16f>(const float* from, Index stride) { - Packet16i stride_vector = _mm512_set1_epi32(stride); + Packet16i stride_vector = _mm512_set1_epi32(convert_index<int>(stride)); Packet16i stride_multiplier = _mm512_set_epi32(15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0); Packet16i indices = _mm512_mullo_epi32(stride_vector, stride_multiplier); @@ -575,7 +812,7 @@ EIGEN_DEVICE_FUNC inline Packet16f pgather<float, Packet16f>(const float* from, template <> EIGEN_DEVICE_FUNC inline Packet8d pgather<double, Packet8d>(const double* from, Index stride) { - Packet8i stride_vector = _mm256_set1_epi32(stride); + Packet8i stride_vector = _mm256_set1_epi32(convert_index<int>(stride)); Packet8i stride_multiplier = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0); Packet8i indices = _mm256_mullo_epi32(stride_vector, stride_multiplier); @@ -586,7 +823,7 @@ template <> EIGEN_DEVICE_FUNC inline void pscatter<float, Packet16f>(float* to, const Packet16f& from, Index stride) { - Packet16i stride_vector = _mm512_set1_epi32(stride); + Packet16i stride_vector = _mm512_set1_epi32(convert_index<int>(stride)); Packet16i stride_multiplier = _mm512_set_epi32(15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0); Packet16i indices = _mm512_mullo_epi32(stride_vector, stride_multiplier); @@ -596,7 +833,7 @@ template <> EIGEN_DEVICE_FUNC inline void pscatter<double, Packet8d>(double* to, const Packet8d& from, Index stride) { - Packet8i stride_vector = _mm256_set1_epi32(stride); + Packet8i stride_vector = _mm256_set1_epi32(convert_index<int>(stride)); Packet8i stride_multiplier = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0); Packet8i indices = _mm256_mullo_epi32(stride_vector, stride_multiplier); _mm512_i32scatter_pd(to, indices, from, 8); @@ -618,9 +855,9 @@ EIGEN_STRONG_INLINE void pstore1<Packet16i>(int* to, const int& a) { pstore(to, pa); } -template<> EIGEN_STRONG_INLINE void prefetch<float>(const float* addr) { _mm_prefetch((const char*)(addr), _MM_HINT_T0); } -template<> EIGEN_STRONG_INLINE void prefetch<double>(const double* addr) { _mm_prefetch((const char*)(addr), _MM_HINT_T0); } -template<> EIGEN_STRONG_INLINE void prefetch<int>(const int* addr) { _mm_prefetch((const char*)(addr), _MM_HINT_T0); } +template<> EIGEN_STRONG_INLINE void prefetch<float>(const float* addr) { _mm_prefetch((SsePrefetchPtrType)(addr), _MM_HINT_T0); } +template<> EIGEN_STRONG_INLINE void prefetch<double>(const double* addr) { _mm_prefetch((SsePrefetchPtrType)(addr), _MM_HINT_T0); } +template<> EIGEN_STRONG_INLINE void prefetch<int>(const int* addr) { _mm_prefetch((SsePrefetchPtrType)(addr), _MM_HINT_T0); } template <> EIGEN_STRONG_INLINE float pfirst<Packet16f>(const Packet16f& a) { @@ -648,20 +885,73 @@ template<> EIGEN_STRONG_INLINE Packet8d preverse(const Packet8d& a) template<> EIGEN_STRONG_INLINE Packet16f pabs(const Packet16f& a) { // _mm512_abs_ps intrinsic not found, so hack around it - return (__m512)_mm512_and_si512((__m512i)a, _mm512_set1_epi32(0x7fffffff)); + return _mm512_castsi512_ps(_mm512_and_si512(_mm512_castps_si512(a), _mm512_set1_epi32(0x7fffffff))); } template <> EIGEN_STRONG_INLINE Packet8d pabs(const Packet8d& a) { // _mm512_abs_ps intrinsic not found, so hack around it - return (__m512d)_mm512_and_si512((__m512i)a, - _mm512_set1_epi64(0x7fffffffffffffff)); + return _mm512_castsi512_pd(_mm512_and_si512(_mm512_castpd_si512(a), + _mm512_set1_epi64(0x7fffffffffffffff))); +} + +template<> +EIGEN_STRONG_INLINE Packet16f pfrexp<Packet16f>(const Packet16f& a, Packet16f& exponent){ + return pfrexp_generic(a, exponent); +} + +// Extract exponent without existence of Packet8l. +template<> +EIGEN_STRONG_INLINE +Packet8d pfrexp_generic_get_biased_exponent(const Packet8d& a) { + const Packet8d cst_exp_mask = pset1frombits<Packet8d>(static_cast<uint64_t>(0x7ff0000000000000ull)); + #ifdef EIGEN_VECTORIZE_AVX512DQ + return _mm512_cvtepi64_pd(_mm512_srli_epi64(_mm512_castpd_si512(pand(a, cst_exp_mask)), 52)); + #else + return _mm512_cvtepi32_pd(_mm512_cvtepi64_epi32(_mm512_srli_epi64(_mm512_castpd_si512(pand(a, cst_exp_mask)), 52))); + #endif +} + +template<> +EIGEN_STRONG_INLINE Packet8d pfrexp<Packet8d>(const Packet8d& a, Packet8d& exponent) { + return pfrexp_generic(a, exponent); +} + +template<> EIGEN_STRONG_INLINE Packet16f pldexp<Packet16f>(const Packet16f& a, const Packet16f& exponent) { + return pldexp_generic(a, exponent); +} + +template<> EIGEN_STRONG_INLINE Packet8d pldexp<Packet8d>(const Packet8d& a, const Packet8d& exponent) { + // Clamp exponent to [-2099, 2099] + const Packet8d max_exponent = pset1<Packet8d>(2099.0); + const Packet8i e = _mm512_cvtpd_epi32(pmin(pmax(exponent, pnegate(max_exponent)), max_exponent)); + + // Split 2^e into four factors and multiply. + const Packet8i bias = pset1<Packet8i>(1023); + Packet8i b = parithmetic_shift_right<2>(e); // floor(e/4) + + // 2^b + const Packet8i permute_idx = _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7); + Packet8i hi = _mm256_permutevar8x32_epi32(padd(b, bias), permute_idx); + Packet8i lo = _mm256_slli_epi64(hi, 52); + hi = _mm256_slli_epi64(_mm256_srli_epi64(hi, 32), 52); + Packet8d c = _mm512_castsi512_pd(_mm512_inserti64x4(_mm512_castsi256_si512(lo), hi, 1)); + Packet8d out = pmul(pmul(pmul(a, c), c), c); // a * 2^(3b) + + // 2^(e - 3b) + b = psub(psub(psub(e, b), b), b); // e - 3b + hi = _mm256_permutevar8x32_epi32(padd(b, bias), permute_idx); + lo = _mm256_slli_epi64(hi, 52); + hi = _mm256_slli_epi64(_mm256_srli_epi64(hi, 32), 52); + c = _mm512_castsi512_pd(_mm512_inserti64x4(_mm512_castsi256_si512(lo), hi, 1)); + out = pmul(out, c); // a * 2^e + return out; } #ifdef EIGEN_VECTORIZE_AVX512DQ // AVX512F does not define _mm512_extractf32x8_ps to extract _m256 from _m512 #define EIGEN_EXTRACT_8f_FROM_16f(INPUT, OUTPUT) \ - __m256 OUTPUT##_0 = _mm512_extractf32x8_ps(INPUT, 0) __m256 OUTPUT##_1 = \ - _mm512_extractf32x8_ps(INPUT, 1) + __m256 OUTPUT##_0 = _mm512_extractf32x8_ps(INPUT, 0); \ + __m256 OUTPUT##_1 = _mm512_extractf32x8_ps(INPUT, 1) #else #define EIGEN_EXTRACT_8f_FROM_16f(INPUT, OUTPUT) \ __m256 OUTPUT##_0 = _mm256_insertf128_ps( \ @@ -674,258 +964,64 @@ EIGEN_STRONG_INLINE Packet8d pabs(const Packet8d& a) { #ifdef EIGEN_VECTORIZE_AVX512DQ #define EIGEN_INSERT_8f_INTO_16f(OUTPUT, INPUTA, INPUTB) \ - OUTPUT = _mm512_insertf32x8(OUTPUT, INPUTA, 0); \ - OUTPUT = _mm512_insertf32x8(OUTPUT, INPUTB, 1); + OUTPUT = _mm512_insertf32x8(_mm512_castps256_ps512(INPUTA), INPUTB, 1); #else #define EIGEN_INSERT_8f_INTO_16f(OUTPUT, INPUTA, INPUTB) \ + OUTPUT = _mm512_undefined_ps(); \ OUTPUT = _mm512_insertf32x4(OUTPUT, _mm256_extractf128_ps(INPUTA, 0), 0); \ OUTPUT = _mm512_insertf32x4(OUTPUT, _mm256_extractf128_ps(INPUTA, 1), 1); \ OUTPUT = _mm512_insertf32x4(OUTPUT, _mm256_extractf128_ps(INPUTB, 0), 2); \ OUTPUT = _mm512_insertf32x4(OUTPUT, _mm256_extractf128_ps(INPUTB, 1), 3); #endif -template<> EIGEN_STRONG_INLINE Packet16f preduxp<Packet16f>(const Packet16f* -vecs) -{ - EIGEN_EXTRACT_8f_FROM_16f(vecs[0], vecs0); - EIGEN_EXTRACT_8f_FROM_16f(vecs[1], vecs1); - EIGEN_EXTRACT_8f_FROM_16f(vecs[2], vecs2); - EIGEN_EXTRACT_8f_FROM_16f(vecs[3], vecs3); - EIGEN_EXTRACT_8f_FROM_16f(vecs[4], vecs4); - EIGEN_EXTRACT_8f_FROM_16f(vecs[5], vecs5); - EIGEN_EXTRACT_8f_FROM_16f(vecs[6], vecs6); - EIGEN_EXTRACT_8f_FROM_16f(vecs[7], vecs7); - EIGEN_EXTRACT_8f_FROM_16f(vecs[8], vecs8); - EIGEN_EXTRACT_8f_FROM_16f(vecs[9], vecs9); - EIGEN_EXTRACT_8f_FROM_16f(vecs[10], vecs10); - EIGEN_EXTRACT_8f_FROM_16f(vecs[11], vecs11); - EIGEN_EXTRACT_8f_FROM_16f(vecs[12], vecs12); - EIGEN_EXTRACT_8f_FROM_16f(vecs[13], vecs13); - EIGEN_EXTRACT_8f_FROM_16f(vecs[14], vecs14); - EIGEN_EXTRACT_8f_FROM_16f(vecs[15], vecs15); - - __m256 hsum1 = _mm256_hadd_ps(vecs0_0, vecs1_0); - __m256 hsum2 = _mm256_hadd_ps(vecs2_0, vecs3_0); - __m256 hsum3 = _mm256_hadd_ps(vecs4_0, vecs5_0); - __m256 hsum4 = _mm256_hadd_ps(vecs6_0, vecs7_0); - - __m256 hsum5 = _mm256_hadd_ps(hsum1, hsum1); - __m256 hsum6 = _mm256_hadd_ps(hsum2, hsum2); - __m256 hsum7 = _mm256_hadd_ps(hsum3, hsum3); - __m256 hsum8 = _mm256_hadd_ps(hsum4, hsum4); - - __m256 perm1 = _mm256_permute2f128_ps(hsum5, hsum5, 0x23); - __m256 perm2 = _mm256_permute2f128_ps(hsum6, hsum6, 0x23); - __m256 perm3 = _mm256_permute2f128_ps(hsum7, hsum7, 0x23); - __m256 perm4 = _mm256_permute2f128_ps(hsum8, hsum8, 0x23); - - __m256 sum1 = _mm256_add_ps(perm1, hsum5); - __m256 sum2 = _mm256_add_ps(perm2, hsum6); - __m256 sum3 = _mm256_add_ps(perm3, hsum7); - __m256 sum4 = _mm256_add_ps(perm4, hsum8); - - __m256 blend1 = _mm256_blend_ps(sum1, sum2, 0xcc); - __m256 blend2 = _mm256_blend_ps(sum3, sum4, 0xcc); - - __m256 final = _mm256_blend_ps(blend1, blend2, 0xf0); - - hsum1 = _mm256_hadd_ps(vecs0_1, vecs1_1); - hsum2 = _mm256_hadd_ps(vecs2_1, vecs3_1); - hsum3 = _mm256_hadd_ps(vecs4_1, vecs5_1); - hsum4 = _mm256_hadd_ps(vecs6_1, vecs7_1); - - hsum5 = _mm256_hadd_ps(hsum1, hsum1); - hsum6 = _mm256_hadd_ps(hsum2, hsum2); - hsum7 = _mm256_hadd_ps(hsum3, hsum3); - hsum8 = _mm256_hadd_ps(hsum4, hsum4); - - perm1 = _mm256_permute2f128_ps(hsum5, hsum5, 0x23); - perm2 = _mm256_permute2f128_ps(hsum6, hsum6, 0x23); - perm3 = _mm256_permute2f128_ps(hsum7, hsum7, 0x23); - perm4 = _mm256_permute2f128_ps(hsum8, hsum8, 0x23); - - sum1 = _mm256_add_ps(perm1, hsum5); - sum2 = _mm256_add_ps(perm2, hsum6); - sum3 = _mm256_add_ps(perm3, hsum7); - sum4 = _mm256_add_ps(perm4, hsum8); - - blend1 = _mm256_blend_ps(sum1, sum2, 0xcc); - blend2 = _mm256_blend_ps(sum3, sum4, 0xcc); - - final = padd(final, _mm256_blend_ps(blend1, blend2, 0xf0)); - - hsum1 = _mm256_hadd_ps(vecs8_0, vecs9_0); - hsum2 = _mm256_hadd_ps(vecs10_0, vecs11_0); - hsum3 = _mm256_hadd_ps(vecs12_0, vecs13_0); - hsum4 = _mm256_hadd_ps(vecs14_0, vecs15_0); - - hsum5 = _mm256_hadd_ps(hsum1, hsum1); - hsum6 = _mm256_hadd_ps(hsum2, hsum2); - hsum7 = _mm256_hadd_ps(hsum3, hsum3); - hsum8 = _mm256_hadd_ps(hsum4, hsum4); - - perm1 = _mm256_permute2f128_ps(hsum5, hsum5, 0x23); - perm2 = _mm256_permute2f128_ps(hsum6, hsum6, 0x23); - perm3 = _mm256_permute2f128_ps(hsum7, hsum7, 0x23); - perm4 = _mm256_permute2f128_ps(hsum8, hsum8, 0x23); - - sum1 = _mm256_add_ps(perm1, hsum5); - sum2 = _mm256_add_ps(perm2, hsum6); - sum3 = _mm256_add_ps(perm3, hsum7); - sum4 = _mm256_add_ps(perm4, hsum8); - - blend1 = _mm256_blend_ps(sum1, sum2, 0xcc); - blend2 = _mm256_blend_ps(sum3, sum4, 0xcc); - - __m256 final_1 = _mm256_blend_ps(blend1, blend2, 0xf0); - - hsum1 = _mm256_hadd_ps(vecs8_1, vecs9_1); - hsum2 = _mm256_hadd_ps(vecs10_1, vecs11_1); - hsum3 = _mm256_hadd_ps(vecs12_1, vecs13_1); - hsum4 = _mm256_hadd_ps(vecs14_1, vecs15_1); - - hsum5 = _mm256_hadd_ps(hsum1, hsum1); - hsum6 = _mm256_hadd_ps(hsum2, hsum2); - hsum7 = _mm256_hadd_ps(hsum3, hsum3); - hsum8 = _mm256_hadd_ps(hsum4, hsum4); - - perm1 = _mm256_permute2f128_ps(hsum5, hsum5, 0x23); - perm2 = _mm256_permute2f128_ps(hsum6, hsum6, 0x23); - perm3 = _mm256_permute2f128_ps(hsum7, hsum7, 0x23); - perm4 = _mm256_permute2f128_ps(hsum8, hsum8, 0x23); - - sum1 = _mm256_add_ps(perm1, hsum5); - sum2 = _mm256_add_ps(perm2, hsum6); - sum3 = _mm256_add_ps(perm3, hsum7); - sum4 = _mm256_add_ps(perm4, hsum8); - - blend1 = _mm256_blend_ps(sum1, sum2, 0xcc); - blend2 = _mm256_blend_ps(sum3, sum4, 0xcc); - - final_1 = padd(final_1, _mm256_blend_ps(blend1, blend2, 0xf0)); - - __m512 final_output; - - EIGEN_INSERT_8f_INTO_16f(final_output, final, final_1); - return final_output; -} - -template<> EIGEN_STRONG_INLINE Packet8d preduxp<Packet8d>(const Packet8d* vecs) -{ - Packet4d vecs0_0 = _mm512_extractf64x4_pd(vecs[0], 0); - Packet4d vecs0_1 = _mm512_extractf64x4_pd(vecs[0], 1); - - Packet4d vecs1_0 = _mm512_extractf64x4_pd(vecs[1], 0); - Packet4d vecs1_1 = _mm512_extractf64x4_pd(vecs[1], 1); - - Packet4d vecs2_0 = _mm512_extractf64x4_pd(vecs[2], 0); - Packet4d vecs2_1 = _mm512_extractf64x4_pd(vecs[2], 1); - - Packet4d vecs3_0 = _mm512_extractf64x4_pd(vecs[3], 0); - Packet4d vecs3_1 = _mm512_extractf64x4_pd(vecs[3], 1); - - Packet4d vecs4_0 = _mm512_extractf64x4_pd(vecs[4], 0); - Packet4d vecs4_1 = _mm512_extractf64x4_pd(vecs[4], 1); - - Packet4d vecs5_0 = _mm512_extractf64x4_pd(vecs[5], 0); - Packet4d vecs5_1 = _mm512_extractf64x4_pd(vecs[5], 1); - - Packet4d vecs6_0 = _mm512_extractf64x4_pd(vecs[6], 0); - Packet4d vecs6_1 = _mm512_extractf64x4_pd(vecs[6], 1); - - Packet4d vecs7_0 = _mm512_extractf64x4_pd(vecs[7], 0); - Packet4d vecs7_1 = _mm512_extractf64x4_pd(vecs[7], 1); - - Packet4d tmp0, tmp1; - - tmp0 = _mm256_hadd_pd(vecs0_0, vecs1_0); - tmp0 = _mm256_add_pd(tmp0, _mm256_permute2f128_pd(tmp0, tmp0, 1)); - - tmp1 = _mm256_hadd_pd(vecs2_0, vecs3_0); - tmp1 = _mm256_add_pd(tmp1, _mm256_permute2f128_pd(tmp1, tmp1, 1)); - - __m256d final_0 = _mm256_blend_pd(tmp0, tmp1, 0xC); - - tmp0 = _mm256_hadd_pd(vecs0_1, vecs1_1); - tmp0 = _mm256_add_pd(tmp0, _mm256_permute2f128_pd(tmp0, tmp0, 1)); - - tmp1 = _mm256_hadd_pd(vecs2_1, vecs3_1); - tmp1 = _mm256_add_pd(tmp1, _mm256_permute2f128_pd(tmp1, tmp1, 1)); - - final_0 = padd(final_0, _mm256_blend_pd(tmp0, tmp1, 0xC)); - - tmp0 = _mm256_hadd_pd(vecs4_0, vecs5_0); - tmp0 = _mm256_add_pd(tmp0, _mm256_permute2f128_pd(tmp0, tmp0, 1)); - - tmp1 = _mm256_hadd_pd(vecs6_0, vecs7_0); - tmp1 = _mm256_add_pd(tmp1, _mm256_permute2f128_pd(tmp1, tmp1, 1)); - - __m256d final_1 = _mm256_blend_pd(tmp0, tmp1, 0xC); - - tmp0 = _mm256_hadd_pd(vecs4_1, vecs5_1); - tmp0 = _mm256_add_pd(tmp0, _mm256_permute2f128_pd(tmp0, tmp0, 1)); - - tmp1 = _mm256_hadd_pd(vecs6_1, vecs7_1); - tmp1 = _mm256_add_pd(tmp1, _mm256_permute2f128_pd(tmp1, tmp1, 1)); - - final_1 = padd(final_1, _mm256_blend_pd(tmp0, tmp1, 0xC)); - - __m512d final_output = _mm512_insertf64x4(final_output, final_0, 0); - - return _mm512_insertf64x4(final_output, final_1, 1); -} template <> EIGEN_STRONG_INLINE float predux<Packet16f>(const Packet16f& a) { - //#ifdef EIGEN_VECTORIZE_AVX512DQ -#if 0 - Packet8f lane0 = _mm512_extractf32x8_ps(a, 0); - Packet8f lane1 = _mm512_extractf32x8_ps(a, 1); - Packet8f sum = padd(lane0, lane1); - Packet8f tmp0 = _mm256_hadd_ps(sum, _mm256_permute2f128_ps(a, a, 1)); - tmp0 = _mm256_hadd_ps(tmp0, tmp0); - return pfirst(_mm256_hadd_ps(tmp0, tmp0)); +#ifdef EIGEN_VECTORIZE_AVX512DQ + __m256 lane0 = _mm512_extractf32x8_ps(a, 0); + __m256 lane1 = _mm512_extractf32x8_ps(a, 1); + Packet8f x = _mm256_add_ps(lane0, lane1); + return predux<Packet8f>(x); #else - Packet4f lane0 = _mm512_extractf32x4_ps(a, 0); - Packet4f lane1 = _mm512_extractf32x4_ps(a, 1); - Packet4f lane2 = _mm512_extractf32x4_ps(a, 2); - Packet4f lane3 = _mm512_extractf32x4_ps(a, 3); - Packet4f sum = padd(padd(lane0, lane1), padd(lane2, lane3)); + __m128 lane0 = _mm512_extractf32x4_ps(a, 0); + __m128 lane1 = _mm512_extractf32x4_ps(a, 1); + __m128 lane2 = _mm512_extractf32x4_ps(a, 2); + __m128 lane3 = _mm512_extractf32x4_ps(a, 3); + __m128 sum = _mm_add_ps(_mm_add_ps(lane0, lane1), _mm_add_ps(lane2, lane3)); sum = _mm_hadd_ps(sum, sum); sum = _mm_hadd_ps(sum, _mm_permute_ps(sum, 1)); - return pfirst(sum); + return _mm_cvtss_f32(sum); #endif } template <> EIGEN_STRONG_INLINE double predux<Packet8d>(const Packet8d& a) { - Packet4d lane0 = _mm512_extractf64x4_pd(a, 0); - Packet4d lane1 = _mm512_extractf64x4_pd(a, 1); - Packet4d sum = padd(lane0, lane1); - Packet4d tmp0 = _mm256_hadd_pd(sum, _mm256_permute2f128_pd(sum, sum, 1)); - return pfirst(_mm256_hadd_pd(tmp0, tmp0)); + __m256d lane0 = _mm512_extractf64x4_pd(a, 0); + __m256d lane1 = _mm512_extractf64x4_pd(a, 1); + __m256d sum = _mm256_add_pd(lane0, lane1); + __m256d tmp0 = _mm256_hadd_pd(sum, _mm256_permute2f128_pd(sum, sum, 1)); + return _mm_cvtsd_f64(_mm256_castpd256_pd128(_mm256_hadd_pd(tmp0, tmp0))); } template <> -EIGEN_STRONG_INLINE Packet8f predux_downto4<Packet16f>(const Packet16f& a) { +EIGEN_STRONG_INLINE Packet8f predux_half_dowto4<Packet16f>(const Packet16f& a) { #ifdef EIGEN_VECTORIZE_AVX512DQ - Packet8f lane0 = _mm512_extractf32x8_ps(a, 0); - Packet8f lane1 = _mm512_extractf32x8_ps(a, 1); - return padd(lane0, lane1); + __m256 lane0 = _mm512_extractf32x8_ps(a, 0); + __m256 lane1 = _mm512_extractf32x8_ps(a, 1); + return _mm256_add_ps(lane0, lane1); #else - Packet4f lane0 = _mm512_extractf32x4_ps(a, 0); - Packet4f lane1 = _mm512_extractf32x4_ps(a, 1); - Packet4f lane2 = _mm512_extractf32x4_ps(a, 2); - Packet4f lane3 = _mm512_extractf32x4_ps(a, 3); - Packet4f sum0 = padd(lane0, lane2); - Packet4f sum1 = padd(lane1, lane3); + __m128 lane0 = _mm512_extractf32x4_ps(a, 0); + __m128 lane1 = _mm512_extractf32x4_ps(a, 1); + __m128 lane2 = _mm512_extractf32x4_ps(a, 2); + __m128 lane3 = _mm512_extractf32x4_ps(a, 3); + __m128 sum0 = _mm_add_ps(lane0, lane2); + __m128 sum1 = _mm_add_ps(lane1, lane3); return _mm256_insertf128_ps(_mm256_castps128_ps256(sum0), sum1, 1); #endif } template <> -EIGEN_STRONG_INLINE Packet4d predux_downto4<Packet8d>(const Packet8d& a) { - Packet4d lane0 = _mm512_extractf64x4_pd(a, 0); - Packet4d lane1 = _mm512_extractf64x4_pd(a, 1); - Packet4d res = padd(lane0, lane1); - return res; +EIGEN_STRONG_INLINE Packet4d predux_half_dowto4<Packet8d>(const Packet8d& a) { + __m256d lane0 = _mm512_extractf64x4_pd(a, 0); + __m256d lane1 = _mm512_extractf64x4_pd(a, 1); + return _mm256_add_pd(lane0, lane1); } template <> @@ -939,108 +1035,70 @@ EIGEN_STRONG_INLINE float predux_mul<Packet16f>(const Packet16f& a) { res = pmul(res, _mm_permute_ps(res, _MM_SHUFFLE(0, 0, 3, 2))); return pfirst(pmul(res, _mm_permute_ps(res, _MM_SHUFFLE(0, 0, 0, 1)))); #else - Packet4f lane0 = _mm512_extractf32x4_ps(a, 0); - Packet4f lane1 = _mm512_extractf32x4_ps(a, 1); - Packet4f lane2 = _mm512_extractf32x4_ps(a, 2); - Packet4f lane3 = _mm512_extractf32x4_ps(a, 3); - Packet4f res = pmul(pmul(lane0, lane1), pmul(lane2, lane3)); + __m128 lane0 = _mm512_extractf32x4_ps(a, 0); + __m128 lane1 = _mm512_extractf32x4_ps(a, 1); + __m128 lane2 = _mm512_extractf32x4_ps(a, 2); + __m128 lane3 = _mm512_extractf32x4_ps(a, 3); + __m128 res = pmul(pmul(lane0, lane1), pmul(lane2, lane3)); res = pmul(res, _mm_permute_ps(res, _MM_SHUFFLE(0, 0, 3, 2))); return pfirst(pmul(res, _mm_permute_ps(res, _MM_SHUFFLE(0, 0, 0, 1)))); #endif } template <> EIGEN_STRONG_INLINE double predux_mul<Packet8d>(const Packet8d& a) { - Packet4d lane0 = _mm512_extractf64x4_pd(a, 0); - Packet4d lane1 = _mm512_extractf64x4_pd(a, 1); - Packet4d res = pmul(lane0, lane1); + __m256d lane0 = _mm512_extractf64x4_pd(a, 0); + __m256d lane1 = _mm512_extractf64x4_pd(a, 1); + __m256d res = pmul(lane0, lane1); res = pmul(res, _mm256_permute2f128_pd(res, res, 1)); return pfirst(pmul(res, _mm256_shuffle_pd(res, res, 1))); } template <> EIGEN_STRONG_INLINE float predux_min<Packet16f>(const Packet16f& a) { - Packet4f lane0 = _mm512_extractf32x4_ps(a, 0); - Packet4f lane1 = _mm512_extractf32x4_ps(a, 1); - Packet4f lane2 = _mm512_extractf32x4_ps(a, 2); - Packet4f lane3 = _mm512_extractf32x4_ps(a, 3); - Packet4f res = _mm_min_ps(_mm_min_ps(lane0, lane1), _mm_min_ps(lane2, lane3)); + __m128 lane0 = _mm512_extractf32x4_ps(a, 0); + __m128 lane1 = _mm512_extractf32x4_ps(a, 1); + __m128 lane2 = _mm512_extractf32x4_ps(a, 2); + __m128 lane3 = _mm512_extractf32x4_ps(a, 3); + __m128 res = _mm_min_ps(_mm_min_ps(lane0, lane1), _mm_min_ps(lane2, lane3)); res = _mm_min_ps(res, _mm_permute_ps(res, _MM_SHUFFLE(0, 0, 3, 2))); return pfirst(_mm_min_ps(res, _mm_permute_ps(res, _MM_SHUFFLE(0, 0, 0, 1)))); } template <> EIGEN_STRONG_INLINE double predux_min<Packet8d>(const Packet8d& a) { - Packet4d lane0 = _mm512_extractf64x4_pd(a, 0); - Packet4d lane1 = _mm512_extractf64x4_pd(a, 1); - Packet4d res = _mm256_min_pd(lane0, lane1); + __m256d lane0 = _mm512_extractf64x4_pd(a, 0); + __m256d lane1 = _mm512_extractf64x4_pd(a, 1); + __m256d res = _mm256_min_pd(lane0, lane1); res = _mm256_min_pd(res, _mm256_permute2f128_pd(res, res, 1)); return pfirst(_mm256_min_pd(res, _mm256_shuffle_pd(res, res, 1))); } template <> EIGEN_STRONG_INLINE float predux_max<Packet16f>(const Packet16f& a) { - Packet4f lane0 = _mm512_extractf32x4_ps(a, 0); - Packet4f lane1 = _mm512_extractf32x4_ps(a, 1); - Packet4f lane2 = _mm512_extractf32x4_ps(a, 2); - Packet4f lane3 = _mm512_extractf32x4_ps(a, 3); - Packet4f res = _mm_max_ps(_mm_max_ps(lane0, lane1), _mm_max_ps(lane2, lane3)); + __m128 lane0 = _mm512_extractf32x4_ps(a, 0); + __m128 lane1 = _mm512_extractf32x4_ps(a, 1); + __m128 lane2 = _mm512_extractf32x4_ps(a, 2); + __m128 lane3 = _mm512_extractf32x4_ps(a, 3); + __m128 res = _mm_max_ps(_mm_max_ps(lane0, lane1), _mm_max_ps(lane2, lane3)); res = _mm_max_ps(res, _mm_permute_ps(res, _MM_SHUFFLE(0, 0, 3, 2))); return pfirst(_mm_max_ps(res, _mm_permute_ps(res, _MM_SHUFFLE(0, 0, 0, 1)))); } + template <> EIGEN_STRONG_INLINE double predux_max<Packet8d>(const Packet8d& a) { - Packet4d lane0 = _mm512_extractf64x4_pd(a, 0); - Packet4d lane1 = _mm512_extractf64x4_pd(a, 1); - Packet4d res = _mm256_max_pd(lane0, lane1); + __m256d lane0 = _mm512_extractf64x4_pd(a, 0); + __m256d lane1 = _mm512_extractf64x4_pd(a, 1); + __m256d res = _mm256_max_pd(lane0, lane1); res = _mm256_max_pd(res, _mm256_permute2f128_pd(res, res, 1)); return pfirst(_mm256_max_pd(res, _mm256_shuffle_pd(res, res, 1))); } -template <int Offset> -struct palign_impl<Offset, Packet16f> { - static EIGEN_STRONG_INLINE void run(Packet16f& first, - const Packet16f& second) { - if (Offset != 0) { - __m512i first_idx = _mm512_set_epi32( - Offset + 15, Offset + 14, Offset + 13, Offset + 12, Offset + 11, - Offset + 10, Offset + 9, Offset + 8, Offset + 7, Offset + 6, - Offset + 5, Offset + 4, Offset + 3, Offset + 2, Offset + 1, Offset); - - __m512i second_idx = - _mm512_set_epi32(Offset - 1, Offset - 2, Offset - 3, Offset - 4, - Offset - 5, Offset - 6, Offset - 7, Offset - 8, - Offset - 9, Offset - 10, Offset - 11, Offset - 12, - Offset - 13, Offset - 14, Offset - 15, Offset - 16); - - unsigned short mask = 0xFFFF; - mask <<= (16 - Offset); - - first = _mm512_permutexvar_ps(first_idx, first); - Packet16f tmp = _mm512_permutexvar_ps(second_idx, second); - first = _mm512_mask_blend_ps(mask, first, tmp); - } - } -}; -template <int Offset> -struct palign_impl<Offset, Packet8d> { - static EIGEN_STRONG_INLINE void run(Packet8d& first, const Packet8d& second) { - if (Offset != 0) { - __m512i first_idx = _mm512_set_epi32( - 0, Offset + 7, 0, Offset + 6, 0, Offset + 5, 0, Offset + 4, 0, - Offset + 3, 0, Offset + 2, 0, Offset + 1, 0, Offset); - - __m512i second_idx = _mm512_set_epi32( - 0, Offset - 1, 0, Offset - 2, 0, Offset - 3, 0, Offset - 4, 0, - Offset - 5, 0, Offset - 6, 0, Offset - 7, 0, Offset - 8); - - unsigned char mask = 0xFF; - mask <<= (8 - Offset); - - first = _mm512_permutexvar_pd(first_idx, first); - Packet8d tmp = _mm512_permutexvar_pd(second_idx, second); - first = _mm512_mask_blend_pd(mask, first, tmp); - } - } -}; +template<> EIGEN_STRONG_INLINE bool predux_any(const Packet16f& x) +{ + Packet16i xi = _mm512_castps_si512(x); + __mmask16 tmp = _mm512_test_epi32_mask(xi,xi); + return !_mm512_kortestz(tmp,tmp); +} + #define PACK_OUTPUT(OUTPUT, INPUT, INDEX, STRIDE) \ @@ -1302,11 +1360,940 @@ EIGEN_STRONG_INLINE Packet16f pblend(const Selector<16>& /*ifPacket*/, return Packet16f(); } template <> -EIGEN_STRONG_INLINE Packet8d pblend(const Selector<8>& /*ifPacket*/, - const Packet8d& /*thenPacket*/, - const Packet8d& /*elsePacket*/) { - assert(false && "To be implemented"); - return Packet8d(); +EIGEN_STRONG_INLINE Packet8d pblend(const Selector<8>& ifPacket, + const Packet8d& thenPacket, + const Packet8d& elsePacket) { + __mmask8 m = (ifPacket.select[0] ) + | (ifPacket.select[1]<<1) + | (ifPacket.select[2]<<2) + | (ifPacket.select[3]<<3) + | (ifPacket.select[4]<<4) + | (ifPacket.select[5]<<5) + | (ifPacket.select[6]<<6) + | (ifPacket.select[7]<<7); + return _mm512_mask_blend_pd(m, elsePacket, thenPacket); +} + +// Packet math for Eigen::half +template<> EIGEN_STRONG_INLINE Packet16h pset1<Packet16h>(const Eigen::half& from) { + return _mm256_set1_epi16(from.x); +} + +template<> EIGEN_STRONG_INLINE Eigen::half pfirst<Packet16h>(const Packet16h& from) { + return half_impl::raw_uint16_to_half(static_cast<unsigned short>(_mm256_extract_epi16(from, 0))); +} + +template<> EIGEN_STRONG_INLINE Packet16h pload<Packet16h>(const Eigen::half* from) { + return _mm256_load_si256(reinterpret_cast<const __m256i*>(from)); +} + +template<> EIGEN_STRONG_INLINE Packet16h ploadu<Packet16h>(const Eigen::half* from) { + return _mm256_loadu_si256(reinterpret_cast<const __m256i*>(from)); +} + +template<> EIGEN_STRONG_INLINE void pstore<half>(Eigen::half* to, const Packet16h& from) { + // (void*) -> workaround clang warning: + // cast from 'Eigen::half *' to '__m256i *' increases required alignment from 2 to 32 + _mm256_store_si256((__m256i*)(void*)to, from); +} + +template<> EIGEN_STRONG_INLINE void pstoreu<half>(Eigen::half* to, const Packet16h& from) { + // (void*) -> workaround clang warning: + // cast from 'Eigen::half *' to '__m256i *' increases required alignment from 2 to 32 + _mm256_storeu_si256((__m256i*)(void*)to, from); +} + +template<> EIGEN_STRONG_INLINE Packet16h +ploaddup<Packet16h>(const Eigen::half* from) { + unsigned short a = from[0].x; + unsigned short b = from[1].x; + unsigned short c = from[2].x; + unsigned short d = from[3].x; + unsigned short e = from[4].x; + unsigned short f = from[5].x; + unsigned short g = from[6].x; + unsigned short h = from[7].x; + return _mm256_set_epi16(h, h, g, g, f, f, e, e, d, d, c, c, b, b, a, a); +} + +template<> EIGEN_STRONG_INLINE Packet16h +ploadquad(const Eigen::half* from) { + unsigned short a = from[0].x; + unsigned short b = from[1].x; + unsigned short c = from[2].x; + unsigned short d = from[3].x; + return _mm256_set_epi16(d, d, d, d, c, c, c, c, b, b, b, b, a, a, a, a); +} + +EIGEN_STRONG_INLINE Packet16f half2float(const Packet16h& a) { +#ifdef EIGEN_HAS_FP16_C + return _mm512_cvtph_ps(a); +#else + EIGEN_ALIGN64 half aux[16]; + pstore(aux, a); + float f0(aux[0]); + float f1(aux[1]); + float f2(aux[2]); + float f3(aux[3]); + float f4(aux[4]); + float f5(aux[5]); + float f6(aux[6]); + float f7(aux[7]); + float f8(aux[8]); + float f9(aux[9]); + float fa(aux[10]); + float fb(aux[11]); + float fc(aux[12]); + float fd(aux[13]); + float fe(aux[14]); + float ff(aux[15]); + + return _mm512_set_ps( + ff, fe, fd, fc, fb, fa, f9, f8, f7, f6, f5, f4, f3, f2, f1, f0); +#endif +} + +EIGEN_STRONG_INLINE Packet16h float2half(const Packet16f& a) { +#ifdef EIGEN_HAS_FP16_C + return _mm512_cvtps_ph(a, _MM_FROUND_TO_NEAREST_INT|_MM_FROUND_NO_EXC); +#else + EIGEN_ALIGN64 float aux[16]; + pstore(aux, a); + half h0(aux[0]); + half h1(aux[1]); + half h2(aux[2]); + half h3(aux[3]); + half h4(aux[4]); + half h5(aux[5]); + half h6(aux[6]); + half h7(aux[7]); + half h8(aux[8]); + half h9(aux[9]); + half ha(aux[10]); + half hb(aux[11]); + half hc(aux[12]); + half hd(aux[13]); + half he(aux[14]); + half hf(aux[15]); + + return _mm256_set_epi16( + hf.x, he.x, hd.x, hc.x, hb.x, ha.x, h9.x, h8.x, + h7.x, h6.x, h5.x, h4.x, h3.x, h2.x, h1.x, h0.x); +#endif +} + +template<> EIGEN_STRONG_INLINE Packet16h ptrue(const Packet16h& a) { + return ptrue(Packet8i(a)); +} + +template <> +EIGEN_STRONG_INLINE Packet16h pabs(const Packet16h& a) { + const __m256i sign_mask = _mm256_set1_epi16(static_cast<numext::uint16_t>(0x8000)); + return _mm256_andnot_si256(sign_mask, a); +} + +template <> +EIGEN_STRONG_INLINE Packet16h pmin<Packet16h>(const Packet16h& a, + const Packet16h& b) { + return float2half(pmin<Packet16f>(half2float(a), half2float(b))); +} + +template <> +EIGEN_STRONG_INLINE Packet16h pmax<Packet16h>(const Packet16h& a, + const Packet16h& b) { + return float2half(pmax<Packet16f>(half2float(a), half2float(b))); +} + +template <> +EIGEN_STRONG_INLINE Packet16h plset<Packet16h>(const half& a) { + return float2half(plset<Packet16f>(static_cast<float>(a))); +} + +template<> EIGEN_STRONG_INLINE Packet16h por(const Packet16h& a,const Packet16h& b) { + // in some cases Packet8i is a wrapper around __m256i, so we need to + // cast to Packet8i to call the correct overload. + return por(Packet8i(a),Packet8i(b)); +} +template<> EIGEN_STRONG_INLINE Packet16h pxor(const Packet16h& a,const Packet16h& b) { + return pxor(Packet8i(a),Packet8i(b)); +} +template<> EIGEN_STRONG_INLINE Packet16h pand(const Packet16h& a,const Packet16h& b) { + return pand(Packet8i(a),Packet8i(b)); +} +template<> EIGEN_STRONG_INLINE Packet16h pandnot(const Packet16h& a,const Packet16h& b) { + return pandnot(Packet8i(a),Packet8i(b)); +} + +template<> EIGEN_STRONG_INLINE Packet16h pselect(const Packet16h& mask, const Packet16h& a, const Packet16h& b) { + return _mm256_blendv_epi8(b, a, mask); +} + +template<> EIGEN_STRONG_INLINE Packet16h pround<Packet16h>(const Packet16h& a) { + return float2half(pround<Packet16f>(half2float(a))); +} + +template<> EIGEN_STRONG_INLINE Packet16h print<Packet16h>(const Packet16h& a) { + return float2half(print<Packet16f>(half2float(a))); +} + +template<> EIGEN_STRONG_INLINE Packet16h pceil<Packet16h>(const Packet16h& a) { + return float2half(pceil<Packet16f>(half2float(a))); +} + +template<> EIGEN_STRONG_INLINE Packet16h pfloor<Packet16h>(const Packet16h& a) { + return float2half(pfloor<Packet16f>(half2float(a))); +} + +template<> EIGEN_STRONG_INLINE Packet16h pcmp_eq(const Packet16h& a,const Packet16h& b) { + Packet16f af = half2float(a); + Packet16f bf = half2float(b); + return Pack32To16(pcmp_eq(af, bf)); +} + +template<> EIGEN_STRONG_INLINE Packet16h pcmp_le(const Packet16h& a,const Packet16h& b) { + return Pack32To16(pcmp_le(half2float(a), half2float(b))); +} + +template<> EIGEN_STRONG_INLINE Packet16h pcmp_lt(const Packet16h& a,const Packet16h& b) { + return Pack32To16(pcmp_lt(half2float(a), half2float(b))); +} + +template<> EIGEN_STRONG_INLINE Packet16h pcmp_lt_or_nan(const Packet16h& a,const Packet16h& b) { + return Pack32To16(pcmp_lt_or_nan(half2float(a), half2float(b))); +} + +template<> EIGEN_STRONG_INLINE Packet16h pconj(const Packet16h& a) { return a; } + +template<> EIGEN_STRONG_INLINE Packet16h pnegate(const Packet16h& a) { + Packet16h sign_mask = _mm256_set1_epi16(static_cast<unsigned short>(0x8000)); + return _mm256_xor_si256(a, sign_mask); +} + +template<> EIGEN_STRONG_INLINE Packet16h padd<Packet16h>(const Packet16h& a, const Packet16h& b) { + Packet16f af = half2float(a); + Packet16f bf = half2float(b); + Packet16f rf = padd(af, bf); + return float2half(rf); +} + +template<> EIGEN_STRONG_INLINE Packet16h psub<Packet16h>(const Packet16h& a, const Packet16h& b) { + Packet16f af = half2float(a); + Packet16f bf = half2float(b); + Packet16f rf = psub(af, bf); + return float2half(rf); +} + +template<> EIGEN_STRONG_INLINE Packet16h pmul<Packet16h>(const Packet16h& a, const Packet16h& b) { + Packet16f af = half2float(a); + Packet16f bf = half2float(b); + Packet16f rf = pmul(af, bf); + return float2half(rf); +} + +template<> EIGEN_STRONG_INLINE Packet16h pdiv<Packet16h>(const Packet16h& a, const Packet16h& b) { + Packet16f af = half2float(a); + Packet16f bf = half2float(b); + Packet16f rf = pdiv(af, bf); + return float2half(rf); +} + +template<> EIGEN_STRONG_INLINE half predux<Packet16h>(const Packet16h& from) { + Packet16f from_float = half2float(from); + return half(predux(from_float)); +} + +template <> +EIGEN_STRONG_INLINE Packet8h predux_half_dowto4<Packet16h>(const Packet16h& a) { + Packet8h lane0 = _mm256_extractf128_si256(a, 0); + Packet8h lane1 = _mm256_extractf128_si256(a, 1); + return padd<Packet8h>(lane0, lane1); +} + +template<> EIGEN_STRONG_INLINE Eigen::half predux_max<Packet16h>(const Packet16h& a) { + Packet16f af = half2float(a); + float reduced = predux_max<Packet16f>(af); + return Eigen::half(reduced); +} + +template<> EIGEN_STRONG_INLINE Eigen::half predux_min<Packet16h>(const Packet16h& a) { + Packet16f af = half2float(a); + float reduced = predux_min<Packet16f>(af); + return Eigen::half(reduced); +} + +template<> EIGEN_STRONG_INLINE half predux_mul<Packet16h>(const Packet16h& from) { + Packet16f from_float = half2float(from); + return half(predux_mul(from_float)); +} + +template<> EIGEN_STRONG_INLINE Packet16h preverse(const Packet16h& a) +{ + __m128i m = _mm_setr_epi8(14,15,12,13,10,11,8,9,6,7,4,5,2,3,0,1); + return _mm256_insertf128_si256( + _mm256_castsi128_si256(_mm_shuffle_epi8(_mm256_extractf128_si256(a,1),m)), + _mm_shuffle_epi8(_mm256_extractf128_si256(a,0),m), 1); +} + +template<> EIGEN_STRONG_INLINE Packet16h pgather<Eigen::half, Packet16h>(const Eigen::half* from, Index stride) +{ + return _mm256_set_epi16( + from[15*stride].x, from[14*stride].x, from[13*stride].x, from[12*stride].x, + from[11*stride].x, from[10*stride].x, from[9*stride].x, from[8*stride].x, + from[7*stride].x, from[6*stride].x, from[5*stride].x, from[4*stride].x, + from[3*stride].x, from[2*stride].x, from[1*stride].x, from[0*stride].x); +} + +template<> EIGEN_STRONG_INLINE void pscatter<half, Packet16h>(half* to, const Packet16h& from, Index stride) +{ + EIGEN_ALIGN64 half aux[16]; + pstore(aux, from); + to[stride*0] = aux[0]; + to[stride*1] = aux[1]; + to[stride*2] = aux[2]; + to[stride*3] = aux[3]; + to[stride*4] = aux[4]; + to[stride*5] = aux[5]; + to[stride*6] = aux[6]; + to[stride*7] = aux[7]; + to[stride*8] = aux[8]; + to[stride*9] = aux[9]; + to[stride*10] = aux[10]; + to[stride*11] = aux[11]; + to[stride*12] = aux[12]; + to[stride*13] = aux[13]; + to[stride*14] = aux[14]; + to[stride*15] = aux[15]; +} + +EIGEN_STRONG_INLINE void +ptranspose(PacketBlock<Packet16h,16>& kernel) { + __m256i a = kernel.packet[0]; + __m256i b = kernel.packet[1]; + __m256i c = kernel.packet[2]; + __m256i d = kernel.packet[3]; + __m256i e = kernel.packet[4]; + __m256i f = kernel.packet[5]; + __m256i g = kernel.packet[6]; + __m256i h = kernel.packet[7]; + __m256i i = kernel.packet[8]; + __m256i j = kernel.packet[9]; + __m256i k = kernel.packet[10]; + __m256i l = kernel.packet[11]; + __m256i m = kernel.packet[12]; + __m256i n = kernel.packet[13]; + __m256i o = kernel.packet[14]; + __m256i p = kernel.packet[15]; + + __m256i ab_07 = _mm256_unpacklo_epi16(a, b); + __m256i cd_07 = _mm256_unpacklo_epi16(c, d); + __m256i ef_07 = _mm256_unpacklo_epi16(e, f); + __m256i gh_07 = _mm256_unpacklo_epi16(g, h); + __m256i ij_07 = _mm256_unpacklo_epi16(i, j); + __m256i kl_07 = _mm256_unpacklo_epi16(k, l); + __m256i mn_07 = _mm256_unpacklo_epi16(m, n); + __m256i op_07 = _mm256_unpacklo_epi16(o, p); + + __m256i ab_8f = _mm256_unpackhi_epi16(a, b); + __m256i cd_8f = _mm256_unpackhi_epi16(c, d); + __m256i ef_8f = _mm256_unpackhi_epi16(e, f); + __m256i gh_8f = _mm256_unpackhi_epi16(g, h); + __m256i ij_8f = _mm256_unpackhi_epi16(i, j); + __m256i kl_8f = _mm256_unpackhi_epi16(k, l); + __m256i mn_8f = _mm256_unpackhi_epi16(m, n); + __m256i op_8f = _mm256_unpackhi_epi16(o, p); + + __m256i abcd_03 = _mm256_unpacklo_epi32(ab_07, cd_07); + __m256i abcd_47 = _mm256_unpackhi_epi32(ab_07, cd_07); + __m256i efgh_03 = _mm256_unpacklo_epi32(ef_07, gh_07); + __m256i efgh_47 = _mm256_unpackhi_epi32(ef_07, gh_07); + __m256i ijkl_03 = _mm256_unpacklo_epi32(ij_07, kl_07); + __m256i ijkl_47 = _mm256_unpackhi_epi32(ij_07, kl_07); + __m256i mnop_03 = _mm256_unpacklo_epi32(mn_07, op_07); + __m256i mnop_47 = _mm256_unpackhi_epi32(mn_07, op_07); + + __m256i abcd_8b = _mm256_unpacklo_epi32(ab_8f, cd_8f); + __m256i abcd_cf = _mm256_unpackhi_epi32(ab_8f, cd_8f); + __m256i efgh_8b = _mm256_unpacklo_epi32(ef_8f, gh_8f); + __m256i efgh_cf = _mm256_unpackhi_epi32(ef_8f, gh_8f); + __m256i ijkl_8b = _mm256_unpacklo_epi32(ij_8f, kl_8f); + __m256i ijkl_cf = _mm256_unpackhi_epi32(ij_8f, kl_8f); + __m256i mnop_8b = _mm256_unpacklo_epi32(mn_8f, op_8f); + __m256i mnop_cf = _mm256_unpackhi_epi32(mn_8f, op_8f); + + __m256i abcdefgh_01 = _mm256_unpacklo_epi64(abcd_03, efgh_03); + __m256i abcdefgh_23 = _mm256_unpackhi_epi64(abcd_03, efgh_03); + __m256i ijklmnop_01 = _mm256_unpacklo_epi64(ijkl_03, mnop_03); + __m256i ijklmnop_23 = _mm256_unpackhi_epi64(ijkl_03, mnop_03); + __m256i abcdefgh_45 = _mm256_unpacklo_epi64(abcd_47, efgh_47); + __m256i abcdefgh_67 = _mm256_unpackhi_epi64(abcd_47, efgh_47); + __m256i ijklmnop_45 = _mm256_unpacklo_epi64(ijkl_47, mnop_47); + __m256i ijklmnop_67 = _mm256_unpackhi_epi64(ijkl_47, mnop_47); + __m256i abcdefgh_89 = _mm256_unpacklo_epi64(abcd_8b, efgh_8b); + __m256i abcdefgh_ab = _mm256_unpackhi_epi64(abcd_8b, efgh_8b); + __m256i ijklmnop_89 = _mm256_unpacklo_epi64(ijkl_8b, mnop_8b); + __m256i ijklmnop_ab = _mm256_unpackhi_epi64(ijkl_8b, mnop_8b); + __m256i abcdefgh_cd = _mm256_unpacklo_epi64(abcd_cf, efgh_cf); + __m256i abcdefgh_ef = _mm256_unpackhi_epi64(abcd_cf, efgh_cf); + __m256i ijklmnop_cd = _mm256_unpacklo_epi64(ijkl_cf, mnop_cf); + __m256i ijklmnop_ef = _mm256_unpackhi_epi64(ijkl_cf, mnop_cf); + + // NOTE: no unpacklo/hi instr in this case, so using permute instr. + __m256i a_p_0 = _mm256_permute2x128_si256(abcdefgh_01, ijklmnop_01, 0x20); + __m256i a_p_1 = _mm256_permute2x128_si256(abcdefgh_23, ijklmnop_23, 0x20); + __m256i a_p_2 = _mm256_permute2x128_si256(abcdefgh_45, ijklmnop_45, 0x20); + __m256i a_p_3 = _mm256_permute2x128_si256(abcdefgh_67, ijklmnop_67, 0x20); + __m256i a_p_4 = _mm256_permute2x128_si256(abcdefgh_89, ijklmnop_89, 0x20); + __m256i a_p_5 = _mm256_permute2x128_si256(abcdefgh_ab, ijklmnop_ab, 0x20); + __m256i a_p_6 = _mm256_permute2x128_si256(abcdefgh_cd, ijklmnop_cd, 0x20); + __m256i a_p_7 = _mm256_permute2x128_si256(abcdefgh_ef, ijklmnop_ef, 0x20); + __m256i a_p_8 = _mm256_permute2x128_si256(abcdefgh_01, ijklmnop_01, 0x31); + __m256i a_p_9 = _mm256_permute2x128_si256(abcdefgh_23, ijklmnop_23, 0x31); + __m256i a_p_a = _mm256_permute2x128_si256(abcdefgh_45, ijklmnop_45, 0x31); + __m256i a_p_b = _mm256_permute2x128_si256(abcdefgh_67, ijklmnop_67, 0x31); + __m256i a_p_c = _mm256_permute2x128_si256(abcdefgh_89, ijklmnop_89, 0x31); + __m256i a_p_d = _mm256_permute2x128_si256(abcdefgh_ab, ijklmnop_ab, 0x31); + __m256i a_p_e = _mm256_permute2x128_si256(abcdefgh_cd, ijklmnop_cd, 0x31); + __m256i a_p_f = _mm256_permute2x128_si256(abcdefgh_ef, ijklmnop_ef, 0x31); + + kernel.packet[0] = a_p_0; + kernel.packet[1] = a_p_1; + kernel.packet[2] = a_p_2; + kernel.packet[3] = a_p_3; + kernel.packet[4] = a_p_4; + kernel.packet[5] = a_p_5; + kernel.packet[6] = a_p_6; + kernel.packet[7] = a_p_7; + kernel.packet[8] = a_p_8; + kernel.packet[9] = a_p_9; + kernel.packet[10] = a_p_a; + kernel.packet[11] = a_p_b; + kernel.packet[12] = a_p_c; + kernel.packet[13] = a_p_d; + kernel.packet[14] = a_p_e; + kernel.packet[15] = a_p_f; +} + +EIGEN_STRONG_INLINE void +ptranspose(PacketBlock<Packet16h,8>& kernel) { + EIGEN_ALIGN64 half in[8][16]; + pstore<half>(in[0], kernel.packet[0]); + pstore<half>(in[1], kernel.packet[1]); + pstore<half>(in[2], kernel.packet[2]); + pstore<half>(in[3], kernel.packet[3]); + pstore<half>(in[4], kernel.packet[4]); + pstore<half>(in[5], kernel.packet[5]); + pstore<half>(in[6], kernel.packet[6]); + pstore<half>(in[7], kernel.packet[7]); + + EIGEN_ALIGN64 half out[8][16]; + + for (int i = 0; i < 8; ++i) { + for (int j = 0; j < 8; ++j) { + out[i][j] = in[j][2*i]; + } + for (int j = 0; j < 8; ++j) { + out[i][j+8] = in[j][2*i+1]; + } + } + + kernel.packet[0] = pload<Packet16h>(out[0]); + kernel.packet[1] = pload<Packet16h>(out[1]); + kernel.packet[2] = pload<Packet16h>(out[2]); + kernel.packet[3] = pload<Packet16h>(out[3]); + kernel.packet[4] = pload<Packet16h>(out[4]); + kernel.packet[5] = pload<Packet16h>(out[5]); + kernel.packet[6] = pload<Packet16h>(out[6]); + kernel.packet[7] = pload<Packet16h>(out[7]); +} + +EIGEN_STRONG_INLINE void +ptranspose(PacketBlock<Packet16h,4>& kernel) { + EIGEN_ALIGN64 half in[4][16]; + pstore<half>(in[0], kernel.packet[0]); + pstore<half>(in[1], kernel.packet[1]); + pstore<half>(in[2], kernel.packet[2]); + pstore<half>(in[3], kernel.packet[3]); + + EIGEN_ALIGN64 half out[4][16]; + + for (int i = 0; i < 4; ++i) { + for (int j = 0; j < 4; ++j) { + out[i][j] = in[j][4*i]; + } + for (int j = 0; j < 4; ++j) { + out[i][j+4] = in[j][4*i+1]; + } + for (int j = 0; j < 4; ++j) { + out[i][j+8] = in[j][4*i+2]; + } + for (int j = 0; j < 4; ++j) { + out[i][j+12] = in[j][4*i+3]; + } + } + + kernel.packet[0] = pload<Packet16h>(out[0]); + kernel.packet[1] = pload<Packet16h>(out[1]); + kernel.packet[2] = pload<Packet16h>(out[2]); + kernel.packet[3] = pload<Packet16h>(out[3]); +} + +template <> struct is_arithmetic<Packet16bf> { enum { value = true }; }; + +template <> +struct packet_traits<bfloat16> : default_packet_traits { + typedef Packet16bf type; + typedef Packet8bf half; + enum { + Vectorizable = 1, + AlignedOnScalar = 1, + size = 16, + HasHalfPacket = 1, + HasBlend = 0, + HasInsert = 1, + HasSin = EIGEN_FAST_MATH, + HasCos = EIGEN_FAST_MATH, +#if EIGEN_GNUC_AT_LEAST(5, 3) || (!EIGEN_COMP_GNUC_STRICT) +#ifdef EIGEN_VECTORIZE_AVX512DQ + HasLog = 1, // Currently fails test with bad accuracy. + HasLog1p = 1, + HasExpm1 = 1, + HasNdtri = 1, + HasBessel = 1, +#endif + HasExp = 1, + HasSqrt = EIGEN_FAST_MATH, + HasRsqrt = EIGEN_FAST_MATH, + HasTanh = EIGEN_FAST_MATH, + HasErf = EIGEN_FAST_MATH, +#endif + HasCmp = 1, + HasDiv = 1 + }; +}; + +template <> +struct unpacket_traits<Packet16bf> +{ + typedef bfloat16 type; + enum {size=16, alignment=Aligned32, vectorizable=true, masked_load_available=false, masked_store_available=false}; + typedef Packet8bf half; +}; + +template <> +EIGEN_STRONG_INLINE Packet16bf pset1<Packet16bf>(const bfloat16& from) { + return _mm256_set1_epi16(from.value); +} + +template <> +EIGEN_STRONG_INLINE bfloat16 pfirst<Packet16bf>(const Packet16bf& from) { + bfloat16 t; + t.value = static_cast<unsigned short>(_mm256_extract_epi16(from, 0)); + return t; +} + +template <> +EIGEN_STRONG_INLINE Packet16bf pload<Packet16bf>(const bfloat16* from) { + return _mm256_load_si256(reinterpret_cast<const __m256i*>(from)); +} + +template <> +EIGEN_STRONG_INLINE Packet16bf ploadu<Packet16bf>(const bfloat16* from) { + return _mm256_loadu_si256(reinterpret_cast<const __m256i*>(from)); +} + +template <> +EIGEN_STRONG_INLINE void pstore<bfloat16>(bfloat16* to, + const Packet16bf& from) { + _mm256_store_si256(reinterpret_cast<__m256i*>(to), from); +} + +template <> +EIGEN_STRONG_INLINE void pstoreu<bfloat16>(bfloat16* to, + const Packet16bf& from) { + _mm256_storeu_si256(reinterpret_cast<__m256i*>(to), from); +} + +template<> EIGEN_STRONG_INLINE Packet16bf +ploaddup<Packet16bf>(const bfloat16* from) { + Packet16bf r; + unsigned short a = from[0].value; + unsigned short b = from[1].value; + unsigned short c = from[2].value; + unsigned short d = from[3].value; + unsigned short e = from[4].value; + unsigned short f = from[5].value; + unsigned short g = from[6].value; + unsigned short h = from[7].value; + return _mm256_set_epi16(h, h, g, g, f, f, e, e, d, d, c, c, b, b, a, a); +} + +template<> EIGEN_STRONG_INLINE Packet16bf +ploadquad(const bfloat16* from) { + Packet16bf r; + unsigned short a = from[0].value; + unsigned short b = from[1].value; + unsigned short c = from[2].value; + unsigned short d = from[3].value; + return _mm256_set_epi16(d, d, d, d, c, c, c, c, b, b, b, b, a, a, a, a); +} + +EIGEN_STRONG_INLINE Packet16f Bf16ToF32(const Packet16bf& a) { + return _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(a), 16)); +} + +// Convert float to bfloat16 according to round-to-nearest-even/denormals algorithm. +EIGEN_STRONG_INLINE Packet16bf F32ToBf16(const Packet16f& a) { + Packet16bf r; + +#if defined(EIGEN_VECTORIZE_AVX512BF16) && EIGEN_GNUC_AT_LEAST(10, 1) + // Since GCC 10.1 supports avx512bf16 and C style explicit cast + // (C++ static_cast is not supported yet), do converion via intrinsic + // and register path for performance. + r = (__m256i)(_mm512_cvtneps_pbh(a)); + +#else + __m512i t; + __m512i input = _mm512_castps_si512(a); + __m512i nan = _mm512_set1_epi32(0x7fc0); + + // uint32_t lsb = (input >> 16) & 1; + t = _mm512_and_si512(_mm512_srli_epi32(input, 16), _mm512_set1_epi32(1)); + // uint32_t rounding_bias = 0x7fff + lsb; + t = _mm512_add_epi32(t, _mm512_set1_epi32(0x7fff)); + // input += rounding_bias; + t = _mm512_add_epi32(t, input); + // input = input >> 16; + t = _mm512_srli_epi32(t, 16); + + // Check NaN before converting back to bf16 + __mmask16 mask = _mm512_cmp_ps_mask(a, a, _CMP_ORD_Q); + + t = _mm512_mask_blend_epi32(mask, nan, t); + // output.value = static_cast<uint16_t>(input); + r = _mm512_cvtepi32_epi16(t); +#endif // EIGEN_VECTORIZE_AVX512BF16 + + return r; +} + +template <> +EIGEN_STRONG_INLINE Packet16bf ptrue(const Packet16bf& a) { + return ptrue<Packet8i>(a); +} + +template <> +EIGEN_STRONG_INLINE Packet16bf por(const Packet16bf& a, const Packet16bf& b) { + return por<Packet8i>(a, b); +} + +template <> +EIGEN_STRONG_INLINE Packet16bf pxor(const Packet16bf& a, const Packet16bf& b) { + return pxor<Packet8i>(a, b); +} + +template <> +EIGEN_STRONG_INLINE Packet16bf pand(const Packet16bf& a, const Packet16bf& b) { + return pand<Packet8i>(a, b); +} + +template <> +EIGEN_STRONG_INLINE Packet16bf pandnot(const Packet16bf& a, + const Packet16bf& b) { + return pandnot<Packet8i>(a, b); +} + +template <> +EIGEN_STRONG_INLINE Packet16bf pselect(const Packet16bf& mask, + const Packet16bf& a, + const Packet16bf& b) { + // Input mask is expected to be all 0/1, handle it with 8-bit + // intrinsic for performance. + return _mm256_blendv_epi8(b, a, mask); +} + +template<> EIGEN_STRONG_INLINE Packet16bf pround<Packet16bf>(const Packet16bf& a) +{ + return F32ToBf16(pround<Packet16f>(Bf16ToF32(a))); +} + +template<> EIGEN_STRONG_INLINE Packet16bf print<Packet16bf>(const Packet16bf& a) { + return F32ToBf16(print<Packet16f>(Bf16ToF32(a))); +} + +template<> EIGEN_STRONG_INLINE Packet16bf pceil<Packet16bf>(const Packet16bf& a) { + return F32ToBf16(pceil<Packet16f>(Bf16ToF32(a))); +} + +template<> EIGEN_STRONG_INLINE Packet16bf pfloor<Packet16bf>(const Packet16bf& a) { + return F32ToBf16(pfloor<Packet16f>(Bf16ToF32(a))); +} + +template <> +EIGEN_STRONG_INLINE Packet16bf pcmp_eq(const Packet16bf& a, + const Packet16bf& b) { + return Pack32To16(pcmp_eq(Bf16ToF32(a), Bf16ToF32(b))); +} + +template <> +EIGEN_STRONG_INLINE Packet16bf pcmp_le(const Packet16bf& a, + const Packet16bf& b) { + return Pack32To16(pcmp_le(Bf16ToF32(a), Bf16ToF32(b))); +} + +template <> +EIGEN_STRONG_INLINE Packet16bf pcmp_lt(const Packet16bf& a, + const Packet16bf& b) { + return Pack32To16(pcmp_lt(Bf16ToF32(a), Bf16ToF32(b))); +} + +template <> +EIGEN_STRONG_INLINE Packet16bf pcmp_lt_or_nan(const Packet16bf& a, + const Packet16bf& b) { + return Pack32To16(pcmp_lt_or_nan(Bf16ToF32(a), Bf16ToF32(b))); +} + +template <> +EIGEN_STRONG_INLINE Packet16bf pnegate(const Packet16bf& a) { + Packet16bf sign_mask = _mm256_set1_epi16(static_cast<unsigned short>(0x8000)); + return _mm256_xor_si256(a, sign_mask); +} + +template <> +EIGEN_STRONG_INLINE Packet16bf pconj(const Packet16bf& a) { + return a; +} + +template <> +EIGEN_STRONG_INLINE Packet16bf pabs(const Packet16bf& a) { + const __m256i sign_mask = _mm256_set1_epi16(static_cast<numext::uint16_t>(0x8000)); + return _mm256_andnot_si256(sign_mask, a); +} + +template <> +EIGEN_STRONG_INLINE Packet16bf padd<Packet16bf>(const Packet16bf& a, + const Packet16bf& b) { + return F32ToBf16(padd<Packet16f>(Bf16ToF32(a), Bf16ToF32(b))); +} + +template <> +EIGEN_STRONG_INLINE Packet16bf psub<Packet16bf>(const Packet16bf& a, + const Packet16bf& b) { + return F32ToBf16(psub<Packet16f>(Bf16ToF32(a), Bf16ToF32(b))); +} + +template <> +EIGEN_STRONG_INLINE Packet16bf pmul<Packet16bf>(const Packet16bf& a, + const Packet16bf& b) { + return F32ToBf16(pmul<Packet16f>(Bf16ToF32(a), Bf16ToF32(b))); +} + +template <> +EIGEN_STRONG_INLINE Packet16bf pdiv<Packet16bf>(const Packet16bf& a, + const Packet16bf& b) { + return F32ToBf16(pdiv<Packet16f>(Bf16ToF32(a), Bf16ToF32(b))); +} + +template <> +EIGEN_STRONG_INLINE Packet16bf pmin<Packet16bf>(const Packet16bf& a, + const Packet16bf& b) { + return F32ToBf16(pmin<Packet16f>(Bf16ToF32(a), Bf16ToF32(b))); +} + +template <> +EIGEN_STRONG_INLINE Packet16bf pmax<Packet16bf>(const Packet16bf& a, + const Packet16bf& b) { + return F32ToBf16(pmax<Packet16f>(Bf16ToF32(a), Bf16ToF32(b))); +} + +template <> +EIGEN_STRONG_INLINE Packet16bf plset<Packet16bf>(const bfloat16& a) { + return F32ToBf16(plset<Packet16f>(static_cast<float>(a))); +} + +template <> +EIGEN_STRONG_INLINE Packet8bf predux_half_dowto4<Packet16bf>(const Packet16bf& a) { + Packet8bf lane0 = _mm256_extractf128_si256(a, 0); + Packet8bf lane1 = _mm256_extractf128_si256(a, 1); + return padd<Packet8bf>(lane0, lane1); +} + +template <> +EIGEN_STRONG_INLINE bfloat16 predux<Packet16bf>(const Packet16bf& p) { + return static_cast<bfloat16>(predux<Packet16f>(Bf16ToF32(p))); +} + +template <> +EIGEN_STRONG_INLINE bfloat16 predux_mul<Packet16bf>(const Packet16bf& from) { + return static_cast<bfloat16>(predux_mul<Packet16f>(Bf16ToF32(from))); +} + +template <> +EIGEN_STRONG_INLINE bfloat16 predux_min<Packet16bf>(const Packet16bf& from) { + return static_cast<bfloat16>(predux_min<Packet16f>(Bf16ToF32(from))); +} + +template <> +EIGEN_STRONG_INLINE bfloat16 predux_max<Packet16bf>(const Packet16bf& from) { + return static_cast<bfloat16>(predux_max<Packet16f>(Bf16ToF32(from))); +} + +template <> +EIGEN_STRONG_INLINE Packet16bf preverse(const Packet16bf& a) { + __m256i m = _mm256_setr_epi8(14,15,12,13,10,11,8,9,6,7,4,5,2,3,0,1, + 14,15,12,13,10,11,8,9,6,7,4,5,2,3,0,1); + + Packet16bf res; + // Swap hi and lo first because shuffle is in 128-bit lanes. + res = _mm256_permute2x128_si256(a, a, 1); + // Shuffle 8-bit values in src within 2*128-bit lanes. + return _mm256_shuffle_epi8(res, m); +} + +template <> +EIGEN_STRONG_INLINE Packet16bf pgather<bfloat16, Packet16bf>(const bfloat16* from, + Index stride) { + return _mm256_set_epi16( + from[15*stride].value, from[14*stride].value, from[13*stride].value, from[12*stride].value, + from[11*stride].value, from[10*stride].value, from[9*stride].value, from[8*stride].value, + from[7*stride].value, from[6*stride].value, from[5*stride].value, from[4*stride].value, + from[3*stride].value, from[2*stride].value, from[1*stride].value, from[0*stride].value); +} + +template <> +EIGEN_STRONG_INLINE void pscatter<bfloat16, Packet16bf>(bfloat16* to, + const Packet16bf& from, + Index stride) { + EIGEN_ALIGN64 bfloat16 aux[16]; + pstore(aux, from); + to[stride*0] = aux[0]; + to[stride*1] = aux[1]; + to[stride*2] = aux[2]; + to[stride*3] = aux[3]; + to[stride*4] = aux[4]; + to[stride*5] = aux[5]; + to[stride*6] = aux[6]; + to[stride*7] = aux[7]; + to[stride*8] = aux[8]; + to[stride*9] = aux[9]; + to[stride*10] = aux[10]; + to[stride*11] = aux[11]; + to[stride*12] = aux[12]; + to[stride*13] = aux[13]; + to[stride*14] = aux[14]; + to[stride*15] = aux[15]; +} + +EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet16bf,16>& kernel) { + __m256i a = kernel.packet[0]; + __m256i b = kernel.packet[1]; + __m256i c = kernel.packet[2]; + __m256i d = kernel.packet[3]; + __m256i e = kernel.packet[4]; + __m256i f = kernel.packet[5]; + __m256i g = kernel.packet[6]; + __m256i h = kernel.packet[7]; + __m256i i = kernel.packet[8]; + __m256i j = kernel.packet[9]; + __m256i k = kernel.packet[10]; + __m256i l = kernel.packet[11]; + __m256i m = kernel.packet[12]; + __m256i n = kernel.packet[13]; + __m256i o = kernel.packet[14]; + __m256i p = kernel.packet[15]; + + __m256i ab_07 = _mm256_unpacklo_epi16(a, b); + __m256i cd_07 = _mm256_unpacklo_epi16(c, d); + __m256i ef_07 = _mm256_unpacklo_epi16(e, f); + __m256i gh_07 = _mm256_unpacklo_epi16(g, h); + __m256i ij_07 = _mm256_unpacklo_epi16(i, j); + __m256i kl_07 = _mm256_unpacklo_epi16(k, l); + __m256i mn_07 = _mm256_unpacklo_epi16(m, n); + __m256i op_07 = _mm256_unpacklo_epi16(o, p); + + __m256i ab_8f = _mm256_unpackhi_epi16(a, b); + __m256i cd_8f = _mm256_unpackhi_epi16(c, d); + __m256i ef_8f = _mm256_unpackhi_epi16(e, f); + __m256i gh_8f = _mm256_unpackhi_epi16(g, h); + __m256i ij_8f = _mm256_unpackhi_epi16(i, j); + __m256i kl_8f = _mm256_unpackhi_epi16(k, l); + __m256i mn_8f = _mm256_unpackhi_epi16(m, n); + __m256i op_8f = _mm256_unpackhi_epi16(o, p); + + __m256i abcd_03 = _mm256_unpacklo_epi32(ab_07, cd_07); + __m256i abcd_47 = _mm256_unpackhi_epi32(ab_07, cd_07); + __m256i efgh_03 = _mm256_unpacklo_epi32(ef_07, gh_07); + __m256i efgh_47 = _mm256_unpackhi_epi32(ef_07, gh_07); + __m256i ijkl_03 = _mm256_unpacklo_epi32(ij_07, kl_07); + __m256i ijkl_47 = _mm256_unpackhi_epi32(ij_07, kl_07); + __m256i mnop_03 = _mm256_unpacklo_epi32(mn_07, op_07); + __m256i mnop_47 = _mm256_unpackhi_epi32(mn_07, op_07); + + __m256i abcd_8b = _mm256_unpacklo_epi32(ab_8f, cd_8f); + __m256i abcd_cf = _mm256_unpackhi_epi32(ab_8f, cd_8f); + __m256i efgh_8b = _mm256_unpacklo_epi32(ef_8f, gh_8f); + __m256i efgh_cf = _mm256_unpackhi_epi32(ef_8f, gh_8f); + __m256i ijkl_8b = _mm256_unpacklo_epi32(ij_8f, kl_8f); + __m256i ijkl_cf = _mm256_unpackhi_epi32(ij_8f, kl_8f); + __m256i mnop_8b = _mm256_unpacklo_epi32(mn_8f, op_8f); + __m256i mnop_cf = _mm256_unpackhi_epi32(mn_8f, op_8f); + + __m256i abcdefgh_01 = _mm256_unpacklo_epi64(abcd_03, efgh_03); + __m256i abcdefgh_23 = _mm256_unpackhi_epi64(abcd_03, efgh_03); + __m256i ijklmnop_01 = _mm256_unpacklo_epi64(ijkl_03, mnop_03); + __m256i ijklmnop_23 = _mm256_unpackhi_epi64(ijkl_03, mnop_03); + __m256i abcdefgh_45 = _mm256_unpacklo_epi64(abcd_47, efgh_47); + __m256i abcdefgh_67 = _mm256_unpackhi_epi64(abcd_47, efgh_47); + __m256i ijklmnop_45 = _mm256_unpacklo_epi64(ijkl_47, mnop_47); + __m256i ijklmnop_67 = _mm256_unpackhi_epi64(ijkl_47, mnop_47); + __m256i abcdefgh_89 = _mm256_unpacklo_epi64(abcd_8b, efgh_8b); + __m256i abcdefgh_ab = _mm256_unpackhi_epi64(abcd_8b, efgh_8b); + __m256i ijklmnop_89 = _mm256_unpacklo_epi64(ijkl_8b, mnop_8b); + __m256i ijklmnop_ab = _mm256_unpackhi_epi64(ijkl_8b, mnop_8b); + __m256i abcdefgh_cd = _mm256_unpacklo_epi64(abcd_cf, efgh_cf); + __m256i abcdefgh_ef = _mm256_unpackhi_epi64(abcd_cf, efgh_cf); + __m256i ijklmnop_cd = _mm256_unpacklo_epi64(ijkl_cf, mnop_cf); + __m256i ijklmnop_ef = _mm256_unpackhi_epi64(ijkl_cf, mnop_cf); + + // NOTE: no unpacklo/hi instr in this case, so using permute instr. + kernel.packet[0] = _mm256_permute2x128_si256(abcdefgh_01, ijklmnop_01, 0x20); + kernel.packet[1] = _mm256_permute2x128_si256(abcdefgh_23, ijklmnop_23, 0x20); + kernel.packet[2] = _mm256_permute2x128_si256(abcdefgh_45, ijklmnop_45, 0x20); + kernel.packet[3] = _mm256_permute2x128_si256(abcdefgh_67, ijklmnop_67, 0x20); + kernel.packet[4] = _mm256_permute2x128_si256(abcdefgh_89, ijklmnop_89, 0x20); + kernel.packet[5] = _mm256_permute2x128_si256(abcdefgh_ab, ijklmnop_ab, 0x20); + kernel.packet[6] = _mm256_permute2x128_si256(abcdefgh_cd, ijklmnop_cd, 0x20); + kernel.packet[7] = _mm256_permute2x128_si256(abcdefgh_ef, ijklmnop_ef, 0x20); + kernel.packet[8] = _mm256_permute2x128_si256(abcdefgh_01, ijklmnop_01, 0x31); + kernel.packet[9] = _mm256_permute2x128_si256(abcdefgh_23, ijklmnop_23, 0x31); + kernel.packet[10] = _mm256_permute2x128_si256(abcdefgh_45, ijklmnop_45, 0x31); + kernel.packet[11] = _mm256_permute2x128_si256(abcdefgh_67, ijklmnop_67, 0x31); + kernel.packet[12] = _mm256_permute2x128_si256(abcdefgh_89, ijklmnop_89, 0x31); + kernel.packet[13] = _mm256_permute2x128_si256(abcdefgh_ab, ijklmnop_ab, 0x31); + kernel.packet[14] = _mm256_permute2x128_si256(abcdefgh_cd, ijklmnop_cd, 0x31); + kernel.packet[15] = _mm256_permute2x128_si256(abcdefgh_ef, ijklmnop_ef, 0x31); +} + +EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet16bf,4>& kernel) { + __m256i a = kernel.packet[0]; + __m256i b = kernel.packet[1]; + __m256i c = kernel.packet[2]; + __m256i d = kernel.packet[3]; + + __m256i ab_07 = _mm256_unpacklo_epi16(a, b); + __m256i cd_07 = _mm256_unpacklo_epi16(c, d); + __m256i ab_8f = _mm256_unpackhi_epi16(a, b); + __m256i cd_8f = _mm256_unpackhi_epi16(c, d); + + __m256i abcd_03 = _mm256_unpacklo_epi32(ab_07, cd_07); + __m256i abcd_47 = _mm256_unpackhi_epi32(ab_07, cd_07); + __m256i abcd_8b = _mm256_unpacklo_epi32(ab_8f, cd_8f); + __m256i abcd_cf = _mm256_unpackhi_epi32(ab_8f, cd_8f); + + // NOTE: no unpacklo/hi instr in this case, so using permute instr. + kernel.packet[0] = _mm256_permute2x128_si256(abcd_03, abcd_47, 0x20); + kernel.packet[1] = _mm256_permute2x128_si256(abcd_8b, abcd_cf, 0x20); + kernel.packet[2] = _mm256_permute2x128_si256(abcd_03, abcd_47, 0x31); + kernel.packet[3] = _mm256_permute2x128_si256(abcd_8b, abcd_cf, 0x31); } } // end namespace internal |