diff options
Diffstat (limited to 'Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h')
-rw-r--r-- | Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h | 76 |
1 files changed, 41 insertions, 35 deletions
diff --git a/Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h b/Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h index e844e37d1..6ba0d9bdb 100644 --- a/Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h +++ b/Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h @@ -25,51 +25,54 @@ namespace internal { **********************************************************************/ // forward declarations (defined at the end of this file) -template<typename LhsScalar, typename RhsScalar, typename Index, int mr, int nr, bool ConjLhs, bool ConjRhs, int UpLo> +template<typename LhsScalar, typename RhsScalar, typename Index, int mr, int nr, bool ConjLhs, bool ConjRhs, int ResInnerStride, int UpLo> struct tribb_kernel; /* Optimized matrix-matrix product evaluating only one triangular half */ template <typename Index, typename LhsScalar, int LhsStorageOrder, bool ConjugateLhs, typename RhsScalar, int RhsStorageOrder, bool ConjugateRhs, - int ResStorageOrder, int UpLo, int Version = Specialized> + int ResStorageOrder, int ResInnerStride, int UpLo, int Version = Specialized> struct general_matrix_matrix_triangular_product; // as usual if the result is row major => we transpose the product template <typename Index, typename LhsScalar, int LhsStorageOrder, bool ConjugateLhs, - typename RhsScalar, int RhsStorageOrder, bool ConjugateRhs, int UpLo, int Version> -struct general_matrix_matrix_triangular_product<Index,LhsScalar,LhsStorageOrder,ConjugateLhs,RhsScalar,RhsStorageOrder,ConjugateRhs,RowMajor,UpLo,Version> + typename RhsScalar, int RhsStorageOrder, bool ConjugateRhs, + int ResInnerStride, int UpLo, int Version> +struct general_matrix_matrix_triangular_product<Index,LhsScalar,LhsStorageOrder,ConjugateLhs,RhsScalar,RhsStorageOrder,ConjugateRhs,RowMajor,ResInnerStride,UpLo,Version> { typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar; static EIGEN_STRONG_INLINE void run(Index size, Index depth,const LhsScalar* lhs, Index lhsStride, - const RhsScalar* rhs, Index rhsStride, ResScalar* res, Index resStride, + const RhsScalar* rhs, Index rhsStride, ResScalar* res, Index resIncr, Index resStride, const ResScalar& alpha, level3_blocking<RhsScalar,LhsScalar>& blocking) { general_matrix_matrix_triangular_product<Index, RhsScalar, RhsStorageOrder==RowMajor ? ColMajor : RowMajor, ConjugateRhs, LhsScalar, LhsStorageOrder==RowMajor ? ColMajor : RowMajor, ConjugateLhs, - ColMajor, UpLo==Lower?Upper:Lower> - ::run(size,depth,rhs,rhsStride,lhs,lhsStride,res,resStride,alpha,blocking); + ColMajor, ResInnerStride, UpLo==Lower?Upper:Lower> + ::run(size,depth,rhs,rhsStride,lhs,lhsStride,res,resIncr,resStride,alpha,blocking); } }; template <typename Index, typename LhsScalar, int LhsStorageOrder, bool ConjugateLhs, - typename RhsScalar, int RhsStorageOrder, bool ConjugateRhs, int UpLo, int Version> -struct general_matrix_matrix_triangular_product<Index,LhsScalar,LhsStorageOrder,ConjugateLhs,RhsScalar,RhsStorageOrder,ConjugateRhs,ColMajor,UpLo,Version> + typename RhsScalar, int RhsStorageOrder, bool ConjugateRhs, + int ResInnerStride, int UpLo, int Version> +struct general_matrix_matrix_triangular_product<Index,LhsScalar,LhsStorageOrder,ConjugateLhs,RhsScalar,RhsStorageOrder,ConjugateRhs,ColMajor,ResInnerStride,UpLo,Version> { typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar; static EIGEN_STRONG_INLINE void run(Index size, Index depth,const LhsScalar* _lhs, Index lhsStride, - const RhsScalar* _rhs, Index rhsStride, ResScalar* _res, Index resStride, + const RhsScalar* _rhs, Index rhsStride, + ResScalar* _res, Index resIncr, Index resStride, const ResScalar& alpha, level3_blocking<LhsScalar,RhsScalar>& blocking) { typedef gebp_traits<LhsScalar,RhsScalar> Traits; typedef const_blas_data_mapper<LhsScalar, Index, LhsStorageOrder> LhsMapper; typedef const_blas_data_mapper<RhsScalar, Index, RhsStorageOrder> RhsMapper; - typedef blas_data_mapper<typename Traits::ResScalar, Index, ColMajor> ResMapper; + typedef blas_data_mapper<typename Traits::ResScalar, Index, ColMajor, Unaligned, ResInnerStride> ResMapper; LhsMapper lhs(_lhs,lhsStride); RhsMapper rhs(_rhs,rhsStride); - ResMapper res(_res, resStride); + ResMapper res(_res, resStride, resIncr); Index kc = blocking.kc(); Index mc = (std::min)(size,blocking.mc()); @@ -84,10 +87,10 @@ struct general_matrix_matrix_triangular_product<Index,LhsScalar,LhsStorageOrder, ei_declare_aligned_stack_constructed_variable(LhsScalar, blockA, sizeA, blocking.blockA()); ei_declare_aligned_stack_constructed_variable(RhsScalar, blockB, sizeB, blocking.blockB()); - gemm_pack_lhs<LhsScalar, Index, LhsMapper, Traits::mr, Traits::LhsProgress, LhsStorageOrder> pack_lhs; + gemm_pack_lhs<LhsScalar, Index, LhsMapper, Traits::mr, Traits::LhsProgress, typename Traits::LhsPacket4Packing, LhsStorageOrder> pack_lhs; gemm_pack_rhs<RhsScalar, Index, RhsMapper, Traits::nr, RhsStorageOrder> pack_rhs; gebp_kernel<LhsScalar, RhsScalar, Index, ResMapper, Traits::mr, Traits::nr, ConjugateLhs, ConjugateRhs> gebp; - tribb_kernel<LhsScalar, RhsScalar, Index, Traits::mr, Traits::nr, ConjugateLhs, ConjugateRhs, UpLo> sybb; + tribb_kernel<LhsScalar, RhsScalar, Index, Traits::mr, Traits::nr, ConjugateLhs, ConjugateRhs, ResInnerStride, UpLo> sybb; for(Index k2=0; k2<depth; k2+=kc) { @@ -110,8 +113,7 @@ struct general_matrix_matrix_triangular_product<Index,LhsScalar,LhsStorageOrder, gebp(res.getSubMapper(i2, 0), blockA, blockB, actual_mc, actual_kc, (std::min)(size,i2), alpha, -1, -1, 0, 0); - - sybb(_res+resStride*i2 + i2, resStride, blockA, blockB + actual_kc*i2, actual_mc, actual_kc, alpha); + sybb(_res+resStride*i2 + resIncr*i2, resIncr, resStride, blockA, blockB + actual_kc*i2, actual_mc, actual_kc, alpha); if (UpLo==Upper) { @@ -133,7 +135,7 @@ struct general_matrix_matrix_triangular_product<Index,LhsScalar,LhsStorageOrder, // while the triangular block overlapping the diagonal is evaluated into a // small temporary buffer which is then accumulated into the result using a // triangular traversal. -template<typename LhsScalar, typename RhsScalar, typename Index, int mr, int nr, bool ConjLhs, bool ConjRhs, int UpLo> +template<typename LhsScalar, typename RhsScalar, typename Index, int mr, int nr, bool ConjLhs, bool ConjRhs, int ResInnerStride, int UpLo> struct tribb_kernel { typedef gebp_traits<LhsScalar,RhsScalar,ConjLhs,ConjRhs> Traits; @@ -142,11 +144,13 @@ struct tribb_kernel enum { BlockSize = meta_least_common_multiple<EIGEN_PLAIN_ENUM_MAX(mr,nr),EIGEN_PLAIN_ENUM_MIN(mr,nr)>::ret }; - void operator()(ResScalar* _res, Index resStride, const LhsScalar* blockA, const RhsScalar* blockB, Index size, Index depth, const ResScalar& alpha) + void operator()(ResScalar* _res, Index resIncr, Index resStride, const LhsScalar* blockA, const RhsScalar* blockB, Index size, Index depth, const ResScalar& alpha) { - typedef blas_data_mapper<ResScalar, Index, ColMajor> ResMapper; - ResMapper res(_res, resStride); - gebp_kernel<LhsScalar, RhsScalar, Index, ResMapper, mr, nr, ConjLhs, ConjRhs> gebp_kernel; + typedef blas_data_mapper<ResScalar, Index, ColMajor, Unaligned, ResInnerStride> ResMapper; + typedef blas_data_mapper<ResScalar, Index, ColMajor, Unaligned> BufferMapper; + ResMapper res(_res, resStride, resIncr); + gebp_kernel<LhsScalar, RhsScalar, Index, ResMapper, mr, nr, ConjLhs, ConjRhs> gebp_kernel1; + gebp_kernel<LhsScalar, RhsScalar, Index, BufferMapper, mr, nr, ConjLhs, ConjRhs> gebp_kernel2; Matrix<ResScalar,BlockSize,BlockSize,ColMajor> buffer((internal::constructor_without_unaligned_array_assert())); @@ -158,31 +162,32 @@ struct tribb_kernel const RhsScalar* actual_b = blockB+j*depth; if(UpLo==Upper) - gebp_kernel(res.getSubMapper(0, j), blockA, actual_b, j, depth, actualBlockSize, alpha, - -1, -1, 0, 0); - + gebp_kernel1(res.getSubMapper(0, j), blockA, actual_b, j, depth, actualBlockSize, alpha, + -1, -1, 0, 0); + // selfadjoint micro block { Index i = j; buffer.setZero(); // 1 - apply the kernel on the temporary buffer - gebp_kernel(ResMapper(buffer.data(), BlockSize), blockA+depth*i, actual_b, actualBlockSize, depth, actualBlockSize, alpha, - -1, -1, 0, 0); + gebp_kernel2(BufferMapper(buffer.data(), BlockSize), blockA+depth*i, actual_b, actualBlockSize, depth, actualBlockSize, alpha, + -1, -1, 0, 0); + // 2 - triangular accumulation for(Index j1=0; j1<actualBlockSize; ++j1) { - ResScalar* r = &res(i, j + j1); + typename ResMapper::LinearMapper r = res.getLinearMapper(i,j+j1); for(Index i1=UpLo==Lower ? j1 : 0; UpLo==Lower ? i1<actualBlockSize : i1<=j1; ++i1) - r[i1] += buffer(i1,j1); + r(i1) += buffer(i1,j1); } } if(UpLo==Lower) { Index i = j+actualBlockSize; - gebp_kernel(res.getSubMapper(i, j), blockA+depth*i, actual_b, size-i, - depth, actualBlockSize, alpha, -1, -1, 0, 0); + gebp_kernel1(res.getSubMapper(i, j), blockA+depth*i, actual_b, size-i, + depth, actualBlockSize, alpha, -1, -1, 0, 0); } } } @@ -286,23 +291,24 @@ struct general_product_to_triangular_selector<MatrixType,ProductType,UpLo,false> internal::general_matrix_matrix_triangular_product<Index, typename Lhs::Scalar, LhsIsRowMajor ? RowMajor : ColMajor, LhsBlasTraits::NeedToConjugate, typename Rhs::Scalar, RhsIsRowMajor ? RowMajor : ColMajor, RhsBlasTraits::NeedToConjugate, - IsRowMajor ? RowMajor : ColMajor, UpLo&(Lower|Upper)> + IsRowMajor ? RowMajor : ColMajor, MatrixType::InnerStrideAtCompileTime, UpLo&(Lower|Upper)> ::run(size, depth, &actualLhs.coeffRef(SkipDiag&&(UpLo&Lower)==Lower ? 1 : 0,0), actualLhs.outerStride(), &actualRhs.coeffRef(0,SkipDiag&&(UpLo&Upper)==Upper ? 1 : 0), actualRhs.outerStride(), - mat.data() + (SkipDiag ? (bool(IsRowMajor) != ((UpLo&Lower)==Lower) ? 1 : mat.outerStride() ) : 0), mat.outerStride(), actualAlpha, blocking); + mat.data() + (SkipDiag ? (bool(IsRowMajor) != ((UpLo&Lower)==Lower) ? mat.innerStride() : mat.outerStride() ) : 0), + mat.innerStride(), mat.outerStride(), actualAlpha, blocking); } }; template<typename MatrixType, unsigned int UpLo> template<typename ProductType> -TriangularView<MatrixType,UpLo>& TriangularViewImpl<MatrixType,UpLo,Dense>::_assignProduct(const ProductType& prod, const Scalar& alpha, bool beta) +EIGEN_DEVICE_FUNC TriangularView<MatrixType,UpLo>& TriangularViewImpl<MatrixType,UpLo,Dense>::_assignProduct(const ProductType& prod, const Scalar& alpha, bool beta) { EIGEN_STATIC_ASSERT((UpLo&UnitDiag)==0, WRITING_TO_TRIANGULAR_PART_WITH_UNIT_DIAGONAL_IS_NOT_SUPPORTED); eigen_assert(derived().nestedExpression().rows() == prod.rows() && derived().cols() == prod.cols()); - + general_product_to_triangular_selector<MatrixType, ProductType, UpLo, internal::traits<ProductType>::InnerSize==1>::run(derived().nestedExpression().const_cast_derived(), prod, alpha, beta); - + return derived(); } |