Diffstat (limited to 'Eigen/src/Core/products/GeneralMatrixVector.h')
-rw-r--r-- | Eigen/src/Core/products/GeneralMatrixVector.h | 865 |
1 file changed, 382 insertions, 483 deletions
diff --git a/Eigen/src/Core/products/GeneralMatrixVector.h b/Eigen/src/Core/products/GeneralMatrixVector.h index 3c1a7fc40..dfb6aebce 100644 --- a/Eigen/src/Core/products/GeneralMatrixVector.h +++ b/Eigen/src/Core/products/GeneralMatrixVector.h @@ -1,7 +1,7 @@ // This file is part of Eigen, a lightweight C++ template library // for linear algebra. // -// Copyright (C) 2008-2009 Gael Guennebaud <gael.guennebaud@inria.fr> +// Copyright (C) 2008-2016 Gael Guennebaud <gael.guennebaud@inria.fr> // // This Source Code Form is subject to the terms of the Mozilla // Public License v. 2.0. If a copy of the MPL was not distributed @@ -14,11 +14,57 @@ namespace Eigen { namespace internal { +enum GEMVPacketSizeType { + GEMVPacketFull = 0, + GEMVPacketHalf, + GEMVPacketQuarter +}; + +template <int N, typename T1, typename T2, typename T3> +struct gemv_packet_cond { typedef T3 type; }; + +template <typename T1, typename T2, typename T3> +struct gemv_packet_cond<GEMVPacketFull, T1, T2, T3> { typedef T1 type; }; + +template <typename T1, typename T2, typename T3> +struct gemv_packet_cond<GEMVPacketHalf, T1, T2, T3> { typedef T2 type; }; + +template<typename LhsScalar, typename RhsScalar, int _PacketSize=GEMVPacketFull> +class gemv_traits +{ + typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar; + +#define PACKET_DECL_COND_PREFIX(prefix, name, packet_size) \ + typedef typename gemv_packet_cond<packet_size, \ + typename packet_traits<name ## Scalar>::type, \ + typename packet_traits<name ## Scalar>::half, \ + typename unpacket_traits<typename packet_traits<name ## Scalar>::half>::half>::type \ + prefix ## name ## Packet + + PACKET_DECL_COND_PREFIX(_, Lhs, _PacketSize); + PACKET_DECL_COND_PREFIX(_, Rhs, _PacketSize); + PACKET_DECL_COND_PREFIX(_, Res, _PacketSize); +#undef PACKET_DECL_COND_PREFIX + +public: + enum { + Vectorizable = unpacket_traits<_LhsPacket>::vectorizable && + unpacket_traits<_RhsPacket>::vectorizable && + int(unpacket_traits<_LhsPacket>::size)==int(unpacket_traits<_RhsPacket>::size), + LhsPacketSize = Vectorizable ? unpacket_traits<_LhsPacket>::size : 1, + RhsPacketSize = Vectorizable ? unpacket_traits<_RhsPacket>::size : 1, + ResPacketSize = Vectorizable ? unpacket_traits<_ResPacket>::size : 1 + }; + + typedef typename conditional<Vectorizable,_LhsPacket,LhsScalar>::type LhsPacket; + typedef typename conditional<Vectorizable,_RhsPacket,RhsScalar>::type RhsPacket; + typedef typename conditional<Vectorizable,_ResPacket,ResScalar>::type ResPacket; +}; + + /* Optimized col-major matrix * vector product: - * This algorithm processes 4 columns at onces that allows to both reduce - * the number of load/stores of the result by a factor 4 and to reduce - * the instruction dependency. Moreover, we know that all bands have the - * same alignment pattern. + * This algorithm processes the matrix in vertical panels, + * which are then processed horizontally in chunks of 8*PacketSize x 1 vertical segments. * * Mixing type logic: C += alpha * A * B * | A | B |alpha| comments @@ -27,56 +73,30 @@ namespace internal { * |cplx |real |cplx | invalid, the caller has to do tmp: = A * B; C += alpha*tmp * |cplx |real |real | optimal case, vectorization possible via real-cplx mul * - * Accesses to the matrix coefficients follow the following logic: - * - * - if all columns have the same alignment then - * - if the columns have the same alignment as the result vector, then easy!
(-> AllAligned case) - * - otherwise perform unaligned loads only (-> NoneAligned case) - * - otherwise - * - if even columns have the same alignment then - * // odd columns are guaranteed to have the same alignment too - * - if even or odd columns have the same alignment as the result, then - * // for a register size of 2 scalars, this is guarantee to be the case (e.g., SSE with double) - * - perform half aligned and half unaligned loads (-> EvenAligned case) - * - otherwise perform unaligned loads only (-> NoneAligned case) - * - otherwise, if the register size is 4 scalars (e.g., SSE with float) then - * - one over 4 consecutive columns is guaranteed to be aligned with the result vector, - * perform simple aligned loads for this column and aligned loads plus re-alignment for the other. (-> FirstAligned case) - * // this re-alignment is done by the palign function implemented for SSE in Eigen/src/Core/arch/SSE/PacketMath.h - * - otherwise, - * // if we get here, this means the register size is greater than 4 (e.g., AVX with floats), - * // we currently fall back to the NoneAligned case - * * The same reasoning applies for the transposed case. - * - * The last case (PacketSize>4) could probably be improved by generalizing the FirstAligned case, but since we do not support AVX yet... - * One might also wonder why in the EvenAligned case we perform unaligned loads instead of using the aligned-loads plus re-alignment - * strategy as in the FirstAligned case. The reason is that we observed that unaligned loads on a 8 byte boundary are not too slow - * compared to unaligned loads on a 4 byte boundary. - * */ template<typename Index, typename LhsScalar, typename LhsMapper, bool ConjugateLhs, typename RhsScalar, typename RhsMapper, bool ConjugateRhs, int Version> struct general_matrix_vector_product<Index,LhsScalar,LhsMapper,ColMajor,ConjugateLhs,RhsScalar,RhsMapper,ConjugateRhs,Version> { + typedef gemv_traits<LhsScalar,RhsScalar> Traits; + typedef gemv_traits<LhsScalar,RhsScalar,GEMVPacketHalf> HalfTraits; + typedef gemv_traits<LhsScalar,RhsScalar,GEMVPacketQuarter> QuarterTraits; + typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar; -enum { - Vectorizable = packet_traits<LhsScalar>::Vectorizable && packet_traits<RhsScalar>::Vectorizable - && int(packet_traits<LhsScalar>::size)==int(packet_traits<RhsScalar>::size), - LhsPacketSize = Vectorizable ? packet_traits<LhsScalar>::size : 1, - RhsPacketSize = Vectorizable ? packet_traits<RhsScalar>::size : 1, - ResPacketSize = Vectorizable ?
packet_traits<ResScalar>::size : 1 -}; + typedef typename Traits::LhsPacket LhsPacket; + typedef typename Traits::RhsPacket RhsPacket; + typedef typename Traits::ResPacket ResPacket; -typedef typename packet_traits<LhsScalar>::type _LhsPacket; -typedef typename packet_traits<RhsScalar>::type _RhsPacket; -typedef typename packet_traits<ResScalar>::type _ResPacket; + typedef typename HalfTraits::LhsPacket LhsPacketHalf; + typedef typename HalfTraits::RhsPacket RhsPacketHalf; + typedef typename HalfTraits::ResPacket ResPacketHalf; -typedef typename conditional<Vectorizable,_LhsPacket,LhsScalar>::type LhsPacket; -typedef typename conditional<Vectorizable,_RhsPacket,RhsScalar>::type RhsPacket; -typedef typename conditional<Vectorizable,_ResPacket,ResScalar>::type ResPacket; + typedef typename QuarterTraits::LhsPacket LhsPacketQuarter; + typedef typename QuarterTraits::RhsPacket RhsPacketQuarter; + typedef typename QuarterTraits::ResPacket ResPacketQuarter; -EIGEN_DONT_INLINE static void run( +EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE static void run( Index rows, Index cols, const LhsMapper& lhs, const RhsMapper& rhs, @@ -85,244 +105,187 @@ EIGEN_DONT_INLINE static void run( }; template<typename Index, typename LhsScalar, typename LhsMapper, bool ConjugateLhs, typename RhsScalar, typename RhsMapper, bool ConjugateRhs, int Version> -EIGEN_DONT_INLINE void general_matrix_vector_product<Index,LhsScalar,LhsMapper,ColMajor,ConjugateLhs,RhsScalar,RhsMapper,ConjugateRhs,Version>::run( +EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void general_matrix_vector_product<Index,LhsScalar,LhsMapper,ColMajor,ConjugateLhs,RhsScalar,RhsMapper,ConjugateRhs,Version>::run( Index rows, Index cols, - const LhsMapper& lhs, + const LhsMapper& alhs, const RhsMapper& rhs, ResScalar* res, Index resIncr, RhsScalar alpha) { EIGEN_UNUSED_VARIABLE(resIncr); eigen_internal_assert(resIncr==1); - #ifdef _EIGEN_ACCUMULATE_PACKETS - #error _EIGEN_ACCUMULATE_PACKETS has already been defined - #endif - #define _EIGEN_ACCUMULATE_PACKETS(Alignment0,Alignment13,Alignment2) \ - pstore(&res[j], \ - padd(pload<ResPacket>(&res[j]), \ - padd( \ - padd(pcj.pmul(lhs0.template load<LhsPacket, Alignment0>(j), ptmp0), \ - pcj.pmul(lhs1.template load<LhsPacket, Alignment13>(j), ptmp1)), \ - padd(pcj.pmul(lhs2.template load<LhsPacket, Alignment2>(j), ptmp2), \ - pcj.pmul(lhs3.template load<LhsPacket, Alignment13>(j), ptmp3)) ))) - - typedef typename LhsMapper::VectorMapper LhsScalars; + + // The following copy tells the compiler that lhs's attributes are not modified outside this function. + // This helps GCC to generate proper code. + LhsMapper lhs(alhs); conj_helper<LhsScalar,RhsScalar,ConjugateLhs,ConjugateRhs> cj; conj_helper<LhsPacket,RhsPacket,ConjugateLhs,ConjugateRhs> pcj; - if(ConjugateRhs) - alpha = numext::conj(alpha); - - enum { AllAligned = 0, EvenAligned, FirstAligned, NoneAligned }; - const Index columnsAtOnce = 4; - const Index peels = 2; - const Index LhsPacketAlignedMask = LhsPacketSize-1; - const Index ResPacketAlignedMask = ResPacketSize-1; -// const Index PeelAlignedMask = ResPacketSize*peels-1; - const Index size = rows; + conj_helper<LhsPacketHalf,RhsPacketHalf,ConjugateLhs,ConjugateRhs> pcj_half; + conj_helper<LhsPacketQuarter,RhsPacketQuarter,ConjugateLhs,ConjugateRhs> pcj_quarter; const Index lhsStride = lhs.stride(); - - // How many coeffs of the result do we have to skip to be aligned. - // Here we assume data are at least aligned on the base scalar type.
- Index alignedStart = internal::first_default_aligned(res,size); - Index alignedSize = ResPacketSize>1 ? alignedStart + ((size-alignedStart) & ~ResPacketAlignedMask) : 0; - const Index peeledSize = alignedSize - RhsPacketSize*peels - RhsPacketSize + 1; - - const Index alignmentStep = LhsPacketSize>1 ? (LhsPacketSize - lhsStride % LhsPacketSize) & LhsPacketAlignedMask : 0; - Index alignmentPattern = alignmentStep==0 ? AllAligned - : alignmentStep==(LhsPacketSize/2) ? EvenAligned - : FirstAligned; - - // we cannot assume the first element is aligned because of sub-matrices - const Index lhsAlignmentOffset = lhs.firstAligned(size); - - // find how many columns do we have to skip to be aligned with the result (if possible) - Index skipColumns = 0; - // if the data cannot be aligned (TODO add some compile time tests when possible, e.g. for floats) - if( (lhsAlignmentOffset < 0) || (lhsAlignmentOffset == size) || (UIntPtr(res)%sizeof(ResScalar)) ) - { - alignedSize = 0; - alignedStart = 0; - alignmentPattern = NoneAligned; - } - else if(LhsPacketSize > 4) - { - // TODO: extend the code to support aligned loads whenever possible when LhsPacketSize > 4. - // Currently, it seems to be better to perform unaligned loads anyway - alignmentPattern = NoneAligned; - } - else if (LhsPacketSize>1) + // TODO: for padded aligned inputs, we could enable aligned reads + enum { LhsAlignment = Unaligned, + ResPacketSize = Traits::ResPacketSize, + ResPacketSizeHalf = HalfTraits::ResPacketSize, + ResPacketSizeQuarter = QuarterTraits::ResPacketSize, + LhsPacketSize = Traits::LhsPacketSize, + HasHalf = (int)ResPacketSizeHalf < (int)ResPacketSize, + HasQuarter = (int)ResPacketSizeQuarter < (int)ResPacketSizeHalf + }; + + const Index n8 = rows-8*ResPacketSize+1; + const Index n4 = rows-4*ResPacketSize+1; + const Index n3 = rows-3*ResPacketSize+1; + const Index n2 = rows-2*ResPacketSize+1; + const Index n1 = rows-1*ResPacketSize+1; + const Index n_half = rows-1*ResPacketSizeHalf+1; + const Index n_quarter = rows-1*ResPacketSizeQuarter+1; + + // TODO: improve the following heuristic: + const Index block_cols = cols<128 ? 
cols : (lhsStride*sizeof(LhsScalar)<32000?16:4); + ResPacket palpha = pset1<ResPacket>(alpha); + ResPacketHalf palpha_half = pset1<ResPacketHalf>(alpha); + ResPacketQuarter palpha_quarter = pset1<ResPacketQuarter>(alpha); + + for(Index j2=0; j2<cols; j2+=block_cols) { - // eigen_internal_assert(size_t(firstLhs+lhsAlignmentOffset)%sizeof(LhsPacket)==0 || size<LhsPacketSize); - - while (skipColumns<LhsPacketSize && - alignedStart != ((lhsAlignmentOffset + alignmentStep*skipColumns)%LhsPacketSize)) - ++skipColumns; - if (skipColumns==LhsPacketSize) + Index jend = numext::mini(j2+block_cols,cols); + Index i=0; + for(; i<n8; i+=ResPacketSize*8) { - // nothing can be aligned, no need to skip any column - alignmentPattern = NoneAligned; - skipColumns = 0; + ResPacket c0 = pset1<ResPacket>(ResScalar(0)), + c1 = pset1<ResPacket>(ResScalar(0)), + c2 = pset1<ResPacket>(ResScalar(0)), + c3 = pset1<ResPacket>(ResScalar(0)), + c4 = pset1<ResPacket>(ResScalar(0)), + c5 = pset1<ResPacket>(ResScalar(0)), + c6 = pset1<ResPacket>(ResScalar(0)), + c7 = pset1<ResPacket>(ResScalar(0)); + + for(Index j=j2; j<jend; j+=1) + { + RhsPacket b0 = pset1<RhsPacket>(rhs(j,0)); + c0 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+LhsPacketSize*0,j),b0,c0); + c1 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+LhsPacketSize*1,j),b0,c1); + c2 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+LhsPacketSize*2,j),b0,c2); + c3 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+LhsPacketSize*3,j),b0,c3); + c4 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+LhsPacketSize*4,j),b0,c4); + c5 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+LhsPacketSize*5,j),b0,c5); + c6 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+LhsPacketSize*6,j),b0,c6); + c7 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+LhsPacketSize*7,j),b0,c7); + } + pstoreu(res+i+ResPacketSize*0, pmadd(c0,palpha,ploadu<ResPacket>(res+i+ResPacketSize*0))); + pstoreu(res+i+ResPacketSize*1, pmadd(c1,palpha,ploadu<ResPacket>(res+i+ResPacketSize*1))); + pstoreu(res+i+ResPacketSize*2, pmadd(c2,palpha,ploadu<ResPacket>(res+i+ResPacketSize*2))); + pstoreu(res+i+ResPacketSize*3, pmadd(c3,palpha,ploadu<ResPacket>(res+i+ResPacketSize*3))); + pstoreu(res+i+ResPacketSize*4, pmadd(c4,palpha,ploadu<ResPacket>(res+i+ResPacketSize*4))); + pstoreu(res+i+ResPacketSize*5, pmadd(c5,palpha,ploadu<ResPacket>(res+i+ResPacketSize*5))); + pstoreu(res+i+ResPacketSize*6, pmadd(c6,palpha,ploadu<ResPacket>(res+i+ResPacketSize*6))); + pstoreu(res+i+ResPacketSize*7, pmadd(c7,palpha,ploadu<ResPacket>(res+i+ResPacketSize*7))); } - else + if(i<n4) { - skipColumns = (std::min)(skipColumns,cols); - // note that the skiped columns are processed later. 
- } + ResPacket c0 = pset1<ResPacket>(ResScalar(0)), + c1 = pset1<ResPacket>(ResScalar(0)), + c2 = pset1<ResPacket>(ResScalar(0)), + c3 = pset1<ResPacket>(ResScalar(0)); - /* eigen_internal_assert( (alignmentPattern==NoneAligned) - || (skipColumns + columnsAtOnce >= cols) - || LhsPacketSize > size - || (size_t(firstLhs+alignedStart+lhsStride*skipColumns)%sizeof(LhsPacket))==0);*/ - } - else if(Vectorizable) - { - alignedStart = 0; - alignedSize = size; - alignmentPattern = AllAligned; - } - - const Index offset1 = (FirstAligned && alignmentStep==1)?3:1; - const Index offset3 = (FirstAligned && alignmentStep==1)?1:3; + for(Index j=j2; j<jend; j+=1) + { + RhsPacket b0 = pset1<RhsPacket>(rhs(j,0)); + c0 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+LhsPacketSize*0,j),b0,c0); + c1 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+LhsPacketSize*1,j),b0,c1); + c2 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+LhsPacketSize*2,j),b0,c2); + c3 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+LhsPacketSize*3,j),b0,c3); + } + pstoreu(res+i+ResPacketSize*0, pmadd(c0,palpha,ploadu<ResPacket>(res+i+ResPacketSize*0))); + pstoreu(res+i+ResPacketSize*1, pmadd(c1,palpha,ploadu<ResPacket>(res+i+ResPacketSize*1))); + pstoreu(res+i+ResPacketSize*2, pmadd(c2,palpha,ploadu<ResPacket>(res+i+ResPacketSize*2))); + pstoreu(res+i+ResPacketSize*3, pmadd(c3,palpha,ploadu<ResPacket>(res+i+ResPacketSize*3))); - Index columnBound = ((cols-skipColumns)/columnsAtOnce)*columnsAtOnce + skipColumns; - for (Index i=skipColumns; i<columnBound; i+=columnsAtOnce) - { - RhsPacket ptmp0 = pset1<RhsPacket>(alpha*rhs(i, 0)), - ptmp1 = pset1<RhsPacket>(alpha*rhs(i+offset1, 0)), - ptmp2 = pset1<RhsPacket>(alpha*rhs(i+2, 0)), - ptmp3 = pset1<RhsPacket>(alpha*rhs(i+offset3, 0)); + i+=ResPacketSize*4; + } + if(i<n3) + { + ResPacket c0 = pset1<ResPacket>(ResScalar(0)), + c1 = pset1<ResPacket>(ResScalar(0)), + c2 = pset1<ResPacket>(ResScalar(0)); - // this helps a lot generating better binary code - const LhsScalars lhs0 = lhs.getVectorMapper(0, i+0), lhs1 = lhs.getVectorMapper(0, i+offset1), - lhs2 = lhs.getVectorMapper(0, i+2), lhs3 = lhs.getVectorMapper(0, i+offset3); + for(Index j=j2; j<jend; j+=1) + { + RhsPacket b0 = pset1<RhsPacket>(rhs(j,0)); + c0 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+LhsPacketSize*0,j),b0,c0); + c1 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+LhsPacketSize*1,j),b0,c1); + c2 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+LhsPacketSize*2,j),b0,c2); + } + pstoreu(res+i+ResPacketSize*0, pmadd(c0,palpha,ploadu<ResPacket>(res+i+ResPacketSize*0))); + pstoreu(res+i+ResPacketSize*1, pmadd(c1,palpha,ploadu<ResPacket>(res+i+ResPacketSize*1))); + pstoreu(res+i+ResPacketSize*2, pmadd(c2,palpha,ploadu<ResPacket>(res+i+ResPacketSize*2))); - if (Vectorizable) + i+=ResPacketSize*3; + } + if(i<n2) { - /* explicit vectorization */ - // process initial unaligned coeffs - for (Index j=0; j<alignedStart; ++j) + ResPacket c0 = pset1<ResPacket>(ResScalar(0)), + c1 = pset1<ResPacket>(ResScalar(0)); + + for(Index j=j2; j<jend; j+=1) { - res[j] = cj.pmadd(lhs0(j), pfirst(ptmp0), res[j]); - res[j] = cj.pmadd(lhs1(j), pfirst(ptmp1), res[j]); - res[j] = cj.pmadd(lhs2(j), pfirst(ptmp2), res[j]); - res[j] = cj.pmadd(lhs3(j), pfirst(ptmp3), res[j]); + RhsPacket b0 = pset1<RhsPacket>(rhs(j,0)); + c0 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+LhsPacketSize*0,j),b0,c0); + c1 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+LhsPacketSize*1,j),b0,c1); } - - if 
(alignedSize>alignedStart) + pstoreu(res+i+ResPacketSize*0, pmadd(c0,palpha,ploadu<ResPacket>(res+i+ResPacketSize*0))); + pstoreu(res+i+ResPacketSize*1, pmadd(c1,palpha,ploadu<ResPacket>(res+i+ResPacketSize*1))); + i+=ResPacketSize*2; + } + if(i<n1) + { + ResPacket c0 = pset1<ResPacket>(ResScalar(0)); + for(Index j=j2; j<jend; j+=1) { - switch(alignmentPattern) - { - case AllAligned: - for (Index j = alignedStart; j<alignedSize; j+=ResPacketSize) - _EIGEN_ACCUMULATE_PACKETS(Aligned,Aligned,Aligned); - break; - case EvenAligned: - for (Index j = alignedStart; j<alignedSize; j+=ResPacketSize) - _EIGEN_ACCUMULATE_PACKETS(Aligned,Unaligned,Aligned); - break; - case FirstAligned: - { - Index j = alignedStart; - if(peels>1) - { - LhsPacket A00, A01, A02, A03, A10, A11, A12, A13; - ResPacket T0, T1; - - A01 = lhs1.template load<LhsPacket, Aligned>(alignedStart-1); - A02 = lhs2.template load<LhsPacket, Aligned>(alignedStart-2); - A03 = lhs3.template load<LhsPacket, Aligned>(alignedStart-3); - - for (; j<peeledSize; j+=peels*ResPacketSize) - { - A11 = lhs1.template load<LhsPacket, Aligned>(j-1+LhsPacketSize); palign<1>(A01,A11); - A12 = lhs2.template load<LhsPacket, Aligned>(j-2+LhsPacketSize); palign<2>(A02,A12); - A13 = lhs3.template load<LhsPacket, Aligned>(j-3+LhsPacketSize); palign<3>(A03,A13); - - A00 = lhs0.template load<LhsPacket, Aligned>(j); - A10 = lhs0.template load<LhsPacket, Aligned>(j+LhsPacketSize); - T0 = pcj.pmadd(A00, ptmp0, pload<ResPacket>(&res[j])); - T1 = pcj.pmadd(A10, ptmp0, pload<ResPacket>(&res[j+ResPacketSize])); - - T0 = pcj.pmadd(A01, ptmp1, T0); - A01 = lhs1.template load<LhsPacket, Aligned>(j-1+2*LhsPacketSize); palign<1>(A11,A01); - T0 = pcj.pmadd(A02, ptmp2, T0); - A02 = lhs2.template load<LhsPacket, Aligned>(j-2+2*LhsPacketSize); palign<2>(A12,A02); - T0 = pcj.pmadd(A03, ptmp3, T0); - pstore(&res[j],T0); - A03 = lhs3.template load<LhsPacket, Aligned>(j-3+2*LhsPacketSize); palign<3>(A13,A03); - T1 = pcj.pmadd(A11, ptmp1, T1); - T1 = pcj.pmadd(A12, ptmp2, T1); - T1 = pcj.pmadd(A13, ptmp3, T1); - pstore(&res[j+ResPacketSize],T1); - } - } - for (; j<alignedSize; j+=ResPacketSize) - _EIGEN_ACCUMULATE_PACKETS(Aligned,Unaligned,Unaligned); - break; - } - default: - for (Index j = alignedStart; j<alignedSize; j+=ResPacketSize) - _EIGEN_ACCUMULATE_PACKETS(Unaligned,Unaligned,Unaligned); - break; - } + RhsPacket b0 = pset1<RhsPacket>(rhs(j,0)); + c0 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+0,j),b0,c0); } - } // end explicit vectorization - - /* process remaining coeffs (or all if there is no explicit vectorization) */ - for (Index j=alignedSize; j<size; ++j) + pstoreu(res+i+ResPacketSize*0, pmadd(c0,palpha,ploadu<ResPacket>(res+i+ResPacketSize*0))); + i+=ResPacketSize; + } + if(HasHalf && i<n_half) { - res[j] = cj.pmadd(lhs0(j), pfirst(ptmp0), res[j]); - res[j] = cj.pmadd(lhs1(j), pfirst(ptmp1), res[j]); - res[j] = cj.pmadd(lhs2(j), pfirst(ptmp2), res[j]); - res[j] = cj.pmadd(lhs3(j), pfirst(ptmp3), res[j]); + ResPacketHalf c0 = pset1<ResPacketHalf>(ResScalar(0)); + for(Index j=j2; j<jend; j+=1) + { + RhsPacketHalf b0 = pset1<RhsPacketHalf>(rhs(j,0)); + c0 = pcj_half.pmadd(lhs.template load<LhsPacketHalf,LhsAlignment>(i+0,j),b0,c0); + } + pstoreu(res+i+ResPacketSizeHalf*0, pmadd(c0,palpha_half,ploadu<ResPacketHalf>(res+i+ResPacketSizeHalf*0))); + i+=ResPacketSizeHalf; } - } - - // process remaining first and last columns (at most columnsAtOnce-1) - Index end = cols; - Index start = columnBound; - do - { - for (Index k=start; k<end; ++k) + if(HasQuarter && 
i<n_quarter) { - RhsPacket ptmp0 = pset1<RhsPacket>(alpha*rhs(k, 0)); - const LhsScalars lhs0 = lhs.getVectorMapper(0, k); - - if (Vectorizable) + ResPacketQuarter c0 = pset1<ResPacketQuarter>(ResScalar(0)); + for(Index j=j2; j<jend; j+=1) { - /* explicit vectorization */ - // process first unaligned result's coeffs - for (Index j=0; j<alignedStart; ++j) - res[j] += cj.pmul(lhs0(j), pfirst(ptmp0)); - // process aligned result's coeffs - if (lhs0.template aligned<LhsPacket>(alignedStart)) - for (Index i = alignedStart;i<alignedSize;i+=ResPacketSize) - pstore(&res[i], pcj.pmadd(lhs0.template load<LhsPacket, Aligned>(i), ptmp0, pload<ResPacket>(&res[i]))); - else - for (Index i = alignedStart;i<alignedSize;i+=ResPacketSize) - pstore(&res[i], pcj.pmadd(lhs0.template load<LhsPacket, Unaligned>(i), ptmp0, pload<ResPacket>(&res[i]))); + RhsPacketQuarter b0 = pset1<RhsPacketQuarter>(rhs(j,0)); + c0 = pcj_quarter.pmadd(lhs.template load<LhsPacketQuarter,LhsAlignment>(i+0,j),b0,c0); } - - // process remaining scalars (or all if no explicit vectorization) - for (Index i=alignedSize; i<size; ++i) - res[i] += cj.pmul(lhs0(i), pfirst(ptmp0)); + pstoreu(res+i+ResPacketSizeQuarter*0, pmadd(c0,palpha_quarter,ploadu<ResPacketQuarter>(res+i+ResPacketSizeQuarter*0))); + i+=ResPacketSizeQuarter; } - if (skipColumns) + for(;i<rows;++i) { - start = 0; - end = skipColumns; - skipColumns = 0; + ResScalar c0(0); + for(Index j=j2; j<jend; j+=1) + c0 += cj.pmul(lhs(i,j), rhs(j,0)); + res[i] += alpha*c0; } - else - break; - } while(Vectorizable); - #undef _EIGEN_ACCUMULATE_PACKETS + } } /* Optimized row-major matrix * vector product: - * This algorithm processes 4 rows at onces that allows to both reduce + * This algorithm processes 4 rows at once, which allows us both to reduce * the number of load/stores of the result by a factor 4 and to reduce * the instruction dependency. Moreover, we know that all bands have the * same alignment pattern. @@ -334,25 +297,25 @@ EIGEN_DONT_INLINE void general_matrix_vector_product<Index,LhsScalar,LhsMapper,C template<typename Index, typename LhsScalar, typename LhsMapper, bool ConjugateLhs, typename RhsScalar, typename RhsMapper, bool ConjugateRhs, int Version> struct general_matrix_vector_product<Index,LhsScalar,LhsMapper,RowMajor,ConjugateLhs,RhsScalar,RhsMapper,ConjugateRhs,Version> { -typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar; - -enum { - Vectorizable = packet_traits<LhsScalar>::Vectorizable && packet_traits<RhsScalar>::Vectorizable - && int(packet_traits<LhsScalar>::size)==int(packet_traits<RhsScalar>::size), - LhsPacketSize = Vectorizable ? packet_traits<LhsScalar>::size : 1, - RhsPacketSize = Vectorizable ? packet_traits<RhsScalar>::size : 1, - ResPacketSize = Vectorizable ?
packet_traits<ResScalar>::size : 1 -}; + typedef gemv_traits<LhsScalar,RhsScalar> Traits; + typedef gemv_traits<LhsScalar,RhsScalar,GEMVPacketHalf> HalfTraits; + typedef gemv_traits<LhsScalar,RhsScalar,GEMVPacketQuarter> QuarterTraits; + + typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar; + + typedef typename Traits::LhsPacket LhsPacket; + typedef typename Traits::RhsPacket RhsPacket; + typedef typename Traits::ResPacket ResPacket; -typedef typename packet_traits<LhsScalar>::type _LhsPacket; -typedef typename packet_traits<RhsScalar>::type _RhsPacket; -typedef typename packet_traits<ResScalar>::type _ResPacket; + typedef typename HalfTraits::LhsPacket LhsPacketHalf; + typedef typename HalfTraits::RhsPacket RhsPacketHalf; + typedef typename HalfTraits::ResPacket ResPacketHalf; -typedef typename conditional<Vectorizable,_LhsPacket,LhsScalar>::type LhsPacket; -typedef typename conditional<Vectorizable,_RhsPacket,RhsScalar>::type RhsPacket; -typedef typename conditional<Vectorizable,_ResPacket,ResScalar>::type ResPacket; + typedef typename QuarterTraits::LhsPacket LhsPacketQuarter; + typedef typename QuarterTraits::RhsPacket RhsPacketQuarter; + typedef typename QuarterTraits::ResPacket ResPacketQuarter; -EIGEN_DONT_INLINE static void run( +EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE static void run( Index rows, Index cols, const LhsMapper& lhs, const RhsMapper& rhs, @@ -361,255 +324,191 @@ EIGEN_DONT_INLINE static void run( }; template<typename Index, typename LhsScalar, typename LhsMapper, bool ConjugateLhs, typename RhsScalar, typename RhsMapper, bool ConjugateRhs, int Version> -EIGEN_DONT_INLINE void general_matrix_vector_product<Index,LhsScalar,LhsMapper,RowMajor,ConjugateLhs,RhsScalar,RhsMapper,ConjugateRhs,Version>::run( +EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void general_matrix_vector_product<Index,LhsScalar,LhsMapper,RowMajor,ConjugateLhs,RhsScalar,RhsMapper,ConjugateRhs,Version>::run( Index rows, Index cols, - const LhsMapper& lhs, + const LhsMapper& alhs, const RhsMapper& rhs, ResScalar* res, Index resIncr, ResScalar alpha) { - eigen_internal_assert(rhs.stride()==1); - - #ifdef _EIGEN_ACCUMULATE_PACKETS - #error _EIGEN_ACCUMULATE_PACKETS has already been defined - #endif - - #define _EIGEN_ACCUMULATE_PACKETS(Alignment0,Alignment13,Alignment2) {\ - RhsPacket b = rhs.getVectorMapper(j, 0).template load<RhsPacket, Aligned>(0); \ - ptmp0 = pcj.pmadd(lhs0.template load<LhsPacket, Alignment0>(j), b, ptmp0); \ - ptmp1 = pcj.pmadd(lhs1.template load<LhsPacket, Alignment13>(j), b, ptmp1); \ - ptmp2 = pcj.pmadd(lhs2.template load<LhsPacket, Alignment2>(j), b, ptmp2); \ - ptmp3 = pcj.pmadd(lhs3.template load<LhsPacket, Alignment13>(j), b, ptmp3); } + // The following copy tells the compiler that lhs's attributes are not modified outside this function. + // This helps GCC to generate proper code. + LhsMapper lhs(alhs); + eigen_internal_assert(rhs.stride()==1); conj_helper<LhsScalar,RhsScalar,ConjugateLhs,ConjugateRhs> cj; conj_helper<LhsPacket,RhsPacket,ConjugateLhs,ConjugateRhs> pcj; - - typedef typename LhsMapper::VectorMapper LhsScalars; - - enum { AllAligned=0, EvenAligned=1, FirstAligned=2, NoneAligned=3 }; - const Index rowsAtOnce = 4; - const Index peels = 2; - const Index RhsPacketAlignedMask = RhsPacketSize-1; - const Index LhsPacketAlignedMask = LhsPacketSize-1; - const Index depth = cols; - const Index lhsStride = lhs.stride(); - - // How many coeffs of the result do we have to skip to be aligned.
- // Here we assume data are at least aligned on the base scalar type - // if that's not the case then vectorization is discarded, see below. - Index alignedStart = rhs.firstAligned(depth); - Index alignedSize = RhsPacketSize>1 ? alignedStart + ((depth-alignedStart) & ~RhsPacketAlignedMask) : 0; - const Index peeledSize = alignedSize - RhsPacketSize*peels - RhsPacketSize + 1; - - const Index alignmentStep = LhsPacketSize>1 ? (LhsPacketSize - lhsStride % LhsPacketSize) & LhsPacketAlignedMask : 0; - Index alignmentPattern = alignmentStep==0 ? AllAligned - : alignmentStep==(LhsPacketSize/2) ? EvenAligned - : FirstAligned; - - // we cannot assume the first element is aligned because of sub-matrices - const Index lhsAlignmentOffset = lhs.firstAligned(depth); - const Index rhsAlignmentOffset = rhs.firstAligned(rows); - - // find how many rows do we have to skip to be aligned with rhs (if possible) - Index skipRows = 0; - // if the data cannot be aligned (TODO add some compile time tests when possible, e.g. for floats) - if( (sizeof(LhsScalar)!=sizeof(RhsScalar)) || - (lhsAlignmentOffset < 0) || (lhsAlignmentOffset == depth) || - (rhsAlignmentOffset < 0) || (rhsAlignmentOffset == rows) ) - { - alignedSize = 0; - alignedStart = 0; - alignmentPattern = NoneAligned; - } - else if(LhsPacketSize > 4) - { - // TODO: extend the code to support aligned loads whenever possible when LhsPacketSize > 4. - alignmentPattern = NoneAligned; - } - else if (LhsPacketSize>1) + conj_helper<LhsPacketHalf,RhsPacketHalf,ConjugateLhs,ConjugateRhs> pcj_half; + conj_helper<LhsPacketQuarter,RhsPacketQuarter,ConjugateLhs,ConjugateRhs> pcj_quarter; + + // TODO: fine-tune the following heuristic. The rationale is that if the matrix is very large, + // processing 8 rows at once might be counterproductive wrt cache. + const Index n8 = lhs.stride()*sizeof(LhsScalar)>32000 ?
0 : rows-7; + const Index n4 = rows-3; + const Index n2 = rows-1; + + // TODO: for padded aligned inputs, we could enable aligned reads + enum { LhsAlignment = Unaligned, + ResPacketSize = Traits::ResPacketSize, + ResPacketSizeHalf = HalfTraits::ResPacketSize, + ResPacketSizeQuarter = QuarterTraits::ResPacketSize, + LhsPacketSize = Traits::LhsPacketSize, + LhsPacketSizeHalf = HalfTraits::LhsPacketSize, + LhsPacketSizeQuarter = QuarterTraits::LhsPacketSize, + HasHalf = (int)ResPacketSizeHalf < (int)ResPacketSize, + HasQuarter = (int)ResPacketSizeQuarter < (int)ResPacketSizeHalf + }; + + Index i=0; + for(; i<n8; i+=8) { - // eigen_internal_assert(size_t(firstLhs+lhsAlignmentOffset)%sizeof(LhsPacket)==0 || depth<LhsPacketSize); - - while (skipRows<LhsPacketSize && - alignedStart != ((lhsAlignmentOffset + alignmentStep*skipRows)%LhsPacketSize)) - ++skipRows; - if (skipRows==LhsPacketSize) + ResPacket c0 = pset1<ResPacket>(ResScalar(0)), + c1 = pset1<ResPacket>(ResScalar(0)), + c2 = pset1<ResPacket>(ResScalar(0)), + c3 = pset1<ResPacket>(ResScalar(0)), + c4 = pset1<ResPacket>(ResScalar(0)), + c5 = pset1<ResPacket>(ResScalar(0)), + c6 = pset1<ResPacket>(ResScalar(0)), + c7 = pset1<ResPacket>(ResScalar(0)); + + Index j=0; + for(; j+LhsPacketSize<=cols; j+=LhsPacketSize) { - // nothing can be aligned, no need to skip any column - alignmentPattern = NoneAligned; - skipRows = 0; + RhsPacket b0 = rhs.template load<RhsPacket, Unaligned>(j,0); + + c0 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+0,j),b0,c0); + c1 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+1,j),b0,c1); + c2 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+2,j),b0,c2); + c3 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+3,j),b0,c3); + c4 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+4,j),b0,c4); + c5 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+5,j),b0,c5); + c6 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+6,j),b0,c6); + c7 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+7,j),b0,c7); } - else + ResScalar cc0 = predux(c0); + ResScalar cc1 = predux(c1); + ResScalar cc2 = predux(c2); + ResScalar cc3 = predux(c3); + ResScalar cc4 = predux(c4); + ResScalar cc5 = predux(c5); + ResScalar cc6 = predux(c6); + ResScalar cc7 = predux(c7); + for(; j<cols; ++j) { - skipRows = (std::min)(skipRows,Index(rows)); - // note that the skiped columns are processed later. 
+ RhsScalar b0 = rhs(j,0); + + cc0 += cj.pmul(lhs(i+0,j), b0); + cc1 += cj.pmul(lhs(i+1,j), b0); + cc2 += cj.pmul(lhs(i+2,j), b0); + cc3 += cj.pmul(lhs(i+3,j), b0); + cc4 += cj.pmul(lhs(i+4,j), b0); + cc5 += cj.pmul(lhs(i+5,j), b0); + cc6 += cj.pmul(lhs(i+6,j), b0); + cc7 += cj.pmul(lhs(i+7,j), b0); } - /* eigen_internal_assert( alignmentPattern==NoneAligned - || LhsPacketSize==1 - || (skipRows + rowsAtOnce >= rows) - || LhsPacketSize > depth - || (size_t(firstLhs+alignedStart+lhsStride*skipRows)%sizeof(LhsPacket))==0);*/ + res[(i+0)*resIncr] += alpha*cc0; + res[(i+1)*resIncr] += alpha*cc1; + res[(i+2)*resIncr] += alpha*cc2; + res[(i+3)*resIncr] += alpha*cc3; + res[(i+4)*resIncr] += alpha*cc4; + res[(i+5)*resIncr] += alpha*cc5; + res[(i+6)*resIncr] += alpha*cc6; + res[(i+7)*resIncr] += alpha*cc7; } - else if(Vectorizable) + for(; i<n4; i+=4) { - alignedStart = 0; - alignedSize = depth; - alignmentPattern = AllAligned; - } - - const Index offset1 = (FirstAligned && alignmentStep==1)?3:1; - const Index offset3 = (FirstAligned && alignmentStep==1)?1:3; + ResPacket c0 = pset1<ResPacket>(ResScalar(0)), + c1 = pset1<ResPacket>(ResScalar(0)), + c2 = pset1<ResPacket>(ResScalar(0)), + c3 = pset1<ResPacket>(ResScalar(0)); - Index rowBound = ((rows-skipRows)/rowsAtOnce)*rowsAtOnce + skipRows; - for (Index i=skipRows; i<rowBound; i+=rowsAtOnce) - { - // FIXME: what is the purpose of this EIGEN_ALIGN_DEFAULT ?? - EIGEN_ALIGN_MAX ResScalar tmp0 = ResScalar(0); - ResScalar tmp1 = ResScalar(0), tmp2 = ResScalar(0), tmp3 = ResScalar(0); - - // this helps the compiler generating good binary code - const LhsScalars lhs0 = lhs.getVectorMapper(i+0, 0), lhs1 = lhs.getVectorMapper(i+offset1, 0), - lhs2 = lhs.getVectorMapper(i+2, 0), lhs3 = lhs.getVectorMapper(i+offset3, 0); + Index j=0; + for(; j+LhsPacketSize<=cols; j+=LhsPacketSize) + { + RhsPacket b0 = rhs.template load<RhsPacket, Unaligned>(j,0); - if (Vectorizable) + c0 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+0,j),b0,c0); + c1 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+1,j),b0,c1); + c2 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+2,j),b0,c2); + c3 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+3,j),b0,c3); + } + ResScalar cc0 = predux(c0); + ResScalar cc1 = predux(c1); + ResScalar cc2 = predux(c2); + ResScalar cc3 = predux(c3); + for(; j<cols; ++j) { - /* explicit vectorization */ - ResPacket ptmp0 = pset1<ResPacket>(ResScalar(0)), ptmp1 = pset1<ResPacket>(ResScalar(0)), - ptmp2 = pset1<ResPacket>(ResScalar(0)), ptmp3 = pset1<ResPacket>(ResScalar(0)); + RhsScalar b0 = rhs(j,0); - // process initial unaligned coeffs - // FIXME this loop get vectorized by the compiler ! 
- for (Index j=0; j<alignedStart; ++j) - { - RhsScalar b = rhs(j, 0); - tmp0 += cj.pmul(lhs0(j),b); tmp1 += cj.pmul(lhs1(j),b); - tmp2 += cj.pmul(lhs2(j),b); tmp3 += cj.pmul(lhs3(j),b); - } + cc0 += cj.pmul(lhs(i+0,j), b0); + cc1 += cj.pmul(lhs(i+1,j), b0); + cc2 += cj.pmul(lhs(i+2,j), b0); + cc3 += cj.pmul(lhs(i+3,j), b0); + } + res[(i+0)*resIncr] += alpha*cc0; + res[(i+1)*resIncr] += alpha*cc1; + res[(i+2)*resIncr] += alpha*cc2; + res[(i+3)*resIncr] += alpha*cc3; + } + for(; i<n2; i+=2) + { + ResPacket c0 = pset1<ResPacket>(ResScalar(0)), + c1 = pset1<ResPacket>(ResScalar(0)); - if (alignedSize>alignedStart) - { - switch(alignmentPattern) - { - case AllAligned: - for (Index j = alignedStart; j<alignedSize; j+=RhsPacketSize) - _EIGEN_ACCUMULATE_PACKETS(Aligned,Aligned,Aligned); - break; - case EvenAligned: - for (Index j = alignedStart; j<alignedSize; j+=RhsPacketSize) - _EIGEN_ACCUMULATE_PACKETS(Aligned,Unaligned,Aligned); - break; - case FirstAligned: - { - Index j = alignedStart; - if (peels>1) - { - /* Here we proccess 4 rows with with two peeled iterations to hide - * the overhead of unaligned loads. Moreover unaligned loads are handled - * using special shift/move operations between the two aligned packets - * overlaping the desired unaligned packet. This is *much* more efficient - * than basic unaligned loads. - */ - LhsPacket A01, A02, A03, A11, A12, A13; - A01 = lhs1.template load<LhsPacket, Aligned>(alignedStart-1); - A02 = lhs2.template load<LhsPacket, Aligned>(alignedStart-2); - A03 = lhs3.template load<LhsPacket, Aligned>(alignedStart-3); - - for (; j<peeledSize; j+=peels*RhsPacketSize) - { - RhsPacket b = rhs.getVectorMapper(j, 0).template load<RhsPacket, Aligned>(0); - A11 = lhs1.template load<LhsPacket, Aligned>(j-1+LhsPacketSize); palign<1>(A01,A11); - A12 = lhs2.template load<LhsPacket, Aligned>(j-2+LhsPacketSize); palign<2>(A02,A12); - A13 = lhs3.template load<LhsPacket, Aligned>(j-3+LhsPacketSize); palign<3>(A03,A13); - - ptmp0 = pcj.pmadd(lhs0.template load<LhsPacket, Aligned>(j), b, ptmp0); - ptmp1 = pcj.pmadd(A01, b, ptmp1); - A01 = lhs1.template load<LhsPacket, Aligned>(j-1+2*LhsPacketSize); palign<1>(A11,A01); - ptmp2 = pcj.pmadd(A02, b, ptmp2); - A02 = lhs2.template load<LhsPacket, Aligned>(j-2+2*LhsPacketSize); palign<2>(A12,A02); - ptmp3 = pcj.pmadd(A03, b, ptmp3); - A03 = lhs3.template load<LhsPacket, Aligned>(j-3+2*LhsPacketSize); palign<3>(A13,A03); - - b = rhs.getVectorMapper(j+RhsPacketSize, 0).template load<RhsPacket, Aligned>(0); - ptmp0 = pcj.pmadd(lhs0.template load<LhsPacket, Aligned>(j+LhsPacketSize), b, ptmp0); - ptmp1 = pcj.pmadd(A11, b, ptmp1); - ptmp2 = pcj.pmadd(A12, b, ptmp2); - ptmp3 = pcj.pmadd(A13, b, ptmp3); - } - } - for (; j<alignedSize; j+=RhsPacketSize) - _EIGEN_ACCUMULATE_PACKETS(Aligned,Unaligned,Unaligned); - break; - } - default: - for (Index j = alignedStart; j<alignedSize; j+=RhsPacketSize) - _EIGEN_ACCUMULATE_PACKETS(Unaligned,Unaligned,Unaligned); - break; - } - tmp0 += predux(ptmp0); - tmp1 += predux(ptmp1); - tmp2 += predux(ptmp2); - tmp3 += predux(ptmp3); - } - } // end explicit vectorization + Index j=0; + for(; j+LhsPacketSize<=cols; j+=LhsPacketSize) + { + RhsPacket b0 = rhs.template load<RhsPacket, Unaligned>(j,0); - // process remaining coeffs (or all if no explicit vectorization) - // FIXME this loop get vectorized by the compiler ! 
- for (Index j=alignedSize; j<depth; ++j) + c0 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+0,j),b0,c0); + c1 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+1,j),b0,c1); + } + ResScalar cc0 = predux(c0); + ResScalar cc1 = predux(c1); + for(; j<cols; ++j) { - RhsScalar b = rhs(j, 0); - tmp0 += cj.pmul(lhs0(j),b); tmp1 += cj.pmul(lhs1(j),b); - tmp2 += cj.pmul(lhs2(j),b); tmp3 += cj.pmul(lhs3(j),b); + RhsScalar b0 = rhs(j,0); + + cc0 += cj.pmul(lhs(i+0,j), b0); + cc1 += cj.pmul(lhs(i+1,j), b0); } - res[i*resIncr] += alpha*tmp0; - res[(i+offset1)*resIncr] += alpha*tmp1; - res[(i+2)*resIncr] += alpha*tmp2; - res[(i+offset3)*resIncr] += alpha*tmp3; + res[(i+0)*resIncr] += alpha*cc0; + res[(i+1)*resIncr] += alpha*cc1; } - - // process remaining first and last rows (at most columnsAtOnce-1) - Index end = rows; - Index start = rowBound; - do + for(; i<rows; ++i) { - for (Index i=start; i<end; ++i) + ResPacket c0 = pset1<ResPacket>(ResScalar(0)); + ResPacketHalf c0_h = pset1<ResPacketHalf>(ResScalar(0)); + ResPacketQuarter c0_q = pset1<ResPacketQuarter>(ResScalar(0)); + Index j=0; + for(; j+LhsPacketSize<=cols; j+=LhsPacketSize) { - EIGEN_ALIGN_MAX ResScalar tmp0 = ResScalar(0); - ResPacket ptmp0 = pset1<ResPacket>(tmp0); - const LhsScalars lhs0 = lhs.getVectorMapper(i, 0); - // process first unaligned result's coeffs - // FIXME this loop get vectorized by the compiler ! - for (Index j=0; j<alignedStart; ++j) - tmp0 += cj.pmul(lhs0(j), rhs(j, 0)); - - if (alignedSize>alignedStart) - { - // process aligned rhs coeffs - if (lhs0.template aligned<LhsPacket>(alignedStart)) - for (Index j = alignedStart;j<alignedSize;j+=RhsPacketSize) - ptmp0 = pcj.pmadd(lhs0.template load<LhsPacket, Aligned>(j), rhs.getVectorMapper(j, 0).template load<RhsPacket, Aligned>(0), ptmp0); - else - for (Index j = alignedStart;j<alignedSize;j+=RhsPacketSize) - ptmp0 = pcj.pmadd(lhs0.template load<LhsPacket, Unaligned>(j), rhs.getVectorMapper(j, 0).template load<RhsPacket, Aligned>(0), ptmp0); - tmp0 += predux(ptmp0); - } - - // process remaining scalars - // FIXME this loop get vectorized by the compiler ! - for (Index j=alignedSize; j<depth; ++j) - tmp0 += cj.pmul(lhs0(j), rhs(j, 0)); - res[i*resIncr] += alpha*tmp0; + RhsPacket b0 = rhs.template load<RhsPacket,Unaligned>(j,0); + c0 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i,j),b0,c0); } - if (skipRows) + ResScalar cc0 = predux(c0); + if (HasHalf) { + for(; j+LhsPacketSizeHalf<=cols; j+=LhsPacketSizeHalf) + { + RhsPacketHalf b0 = rhs.template load<RhsPacketHalf,Unaligned>(j,0); + c0_h = pcj_half.pmadd(lhs.template load<LhsPacketHalf,LhsAlignment>(i,j),b0,c0_h); + } + cc0 += predux(c0_h); + } + if (HasQuarter) { + for(; j+LhsPacketSizeQuarter<=cols; j+=LhsPacketSizeQuarter) + { + RhsPacketQuarter b0 = rhs.template load<RhsPacketQuarter,Unaligned>(j,0); + c0_q = pcj_quarter.pmadd(lhs.template load<LhsPacketQuarter,LhsAlignment>(i,j),b0,c0_q); + } + cc0 += predux(c0_q); + } + for(; j<cols; ++j) { - start = 0; - end = skipRows; - skipRows = 0; + cc0 += cj.pmul(lhs(i,j), rhs(j,0)); } - else - break; - } while(Vectorizable); - - #undef _EIGEN_ACCUMULATE_PACKETS + res[i*resIncr] += alpha*cc0; + } } } // end namespace internal |
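The gemv_traits class added at the top of this patch is, at heart, a three-way compile-time type selector: GEMVPacketFull picks packet_traits<T>::type, GEMVPacketHalf picks packet_traits<T>::half, and GEMVPacketQuarter picks the half of the half. A minimal standalone sketch of the same selection pattern, with invented packet types standing in for Eigen's real ones (the 8/4/1 sizes are illustrative assumptions):

    #include <iostream>

    // Invented stand-ins for Eigen's SIMD packet types (e.g. an 8-wide "full"
    // packet as with AVX floats); not Eigen's actual Packet8f/Packet4f types.
    struct FullPacket    { static const int size = 8; };
    struct HalfPacket    { static const int size = 4; };
    struct QuarterPacket { static const int size = 1; };

    enum GEMVPacketSizeType { GEMVPacketFull = 0, GEMVPacketHalf, GEMVPacketQuarter };

    // Same shape as gemv_packet_cond in the patch: the primary template yields
    // T3 (quarter), the two specializations yield T1 (full) and T2 (half).
    template <int N, typename T1, typename T2, typename T3>
    struct gemv_packet_cond { typedef T3 type; };
    template <typename T1, typename T2, typename T3>
    struct gemv_packet_cond<GEMVPacketFull, T1, T2, T3> { typedef T1 type; };
    template <typename T1, typename T2, typename T3>
    struct gemv_packet_cond<GEMVPacketHalf, T1, T2, T3> { typedef T2 type; };

    template <int PacketSize>
    struct demo_traits {
      typedef typename gemv_packet_cond<PacketSize, FullPacket, HalfPacket, QuarterPacket>::type ResPacket;
    };

    int main() {
      std::cout << demo_traits<GEMVPacketFull>::ResPacket::size    << ' '  // 8
                << demo_traits<GEMVPacketHalf>::ResPacket::size    << ' '  // 4
                << demo_traits<GEMVPacketQuarter>::ResPacket::size << '\n'; // 1
    }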
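The rewritten column-major kernel replaces the old alignment-pattern machinery with a simple blocked traversal: the matrix is cut into vertical panels of block_cols columns, and within a panel each chunk of result rows is accumulated in registers across all of the panel's columns before res is touched. A scalar model of that traversal (no SIMD, no conjugation; the block_cols value mirrors one branch of the heuristic in the patch):

    #include <algorithm>

    // Scalar model of the new col-major loop nest. 'lhs' is column-major with
    // leading dimension lhsStride; computes res += alpha * lhs * rhs.
    void gemv_colmajor_model(int rows, int cols, const float* lhs, int lhsStride,
                             const float* rhs, float* res, float alpha) {
      const int block_cols = cols < 128 ? cols : 16;   // heuristic from the patch
      for (int j2 = 0; j2 < cols; j2 += block_cols) {  // one vertical panel
        const int jend = std::min(j2 + block_cols, cols);
        for (int i = 0; i < rows; ++i) {               // one "segment" per row here
          float c0 = 0.0f;                             // register accumulator
          for (int j = j2; j < jend; ++j)
            c0 += lhs[i + j * lhsStride] * rhs[j];
          res[i] += alpha * c0;                        // single read-modify-write of res
        }
      }
    }

In the real kernel the i loop advances by 8, 4, 3, 2, and 1 packets at a time, so each res segment is loaded and stored once per panel rather than once per column, which is the point of the blocking.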
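After the multi-packet row chunks, the column-major kernel finishes the remaining rows with at most one half packet, one quarter packet, and finally scalars (the HasHalf/HasQuarter branches). A toy trace of that cascade, assuming a hypothetical 16/8/4-lane full/half/quarter hierarchy (e.g. AVX512 floats) chosen purely for illustration:

    #include <cstdio>

    // Hypothetical widths: 16-lane full, 8-lane half, 4-lane quarter packets.
    void tail_cascade_model(int rows) {
      int i = 0;
      for (; i + 16 <= rows; i += 16) std::printf("full packet    rows %d..%d\n", i, i + 15);
      if (i + 8 <= rows) { std::printf("half packet    rows %d..%d\n", i, i + 7); i += 8; }
      if (i + 4 <= rows) { std::printf("quarter packet rows %d..%d\n", i, i + 3); i += 4; }
      for (; i < rows; ++i) std::printf("scalar         row  %d\n", i);
    }

    int main() { tail_cascade_model(31); }  // 16 + 8 + 4 + 3 scalar rows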
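The row-major kernel works the other way around: each accumulator holds lane-wise partial sums of one row's dot product, and predux() folds the lanes into a scalar only after the vectorized j loop. A scalar model of one row, with a 4-lane array standing in for one SIMD register:

    // Model of one row of the row-major kernel: accumulate the dot product in
    // 4 independent lanes (one pretend SIMD register), then reduce horizontally.
    float dot_rowmajor_model(const float* lhs_row, const float* rhs, int cols) {
      float c0[4] = {0, 0, 0, 0};                // stand-in for a ResPacket
      int j = 0;
      for (; j + 4 <= cols; j += 4)              // vectorized body
        for (int k = 0; k < 4; ++k)
          c0[k] += lhs_row[j + k] * rhs[j + k];
      float cc0 = c0[0] + c0[1] + c0[2] + c0[3]; // predux(): horizontal add of lanes
      for (; j < cols; ++j)                      // scalar tail, as in the kernel
        cc0 += lhs_row[j] * rhs[j];
      return cc0;                                // caller does res[i*resIncr] += alpha*cc0
    }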
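Both kernels are internal and are reached through ordinary Eigen expressions, so nothing changes in user code with this patch. For reference, a small program whose products dispatch to these run() functions (the transposed product should reach the row-major path):

    #include <Eigen/Dense>
    #include <iostream>

    int main() {
      Eigen::MatrixXf A = Eigen::MatrixXf::Random(512, 512); // column-major by default
      Eigen::VectorXf b = Eigen::VectorXf::Random(512);
      Eigen::VectorXf c = Eigen::VectorXf::Zero(512);
      const float alpha = 2.0f;
      c.noalias() += alpha * A * b;             // col-major GEMV kernel
      c.noalias() += alpha * A.transpose() * b; // row-major (transposed) GEMV kernel
      std::cout << c.norm() << std::endl;
    }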