diff options
Diffstat (limited to 'Eigen/src/Core/products/TriangularMatrixMatrix_BLAS.h')
-rw-r--r-- | Eigen/src/Core/products/TriangularMatrixMatrix_BLAS.h | 65 |
1 files changed, 40 insertions, 25 deletions
diff --git a/Eigen/src/Core/products/TriangularMatrixMatrix_BLAS.h b/Eigen/src/Core/products/TriangularMatrixMatrix_BLAS.h index aecded6bb..a98d12e4a 100644 --- a/Eigen/src/Core/products/TriangularMatrixMatrix_BLAS.h +++ b/Eigen/src/Core/products/TriangularMatrixMatrix_BLAS.h @@ -46,7 +46,7 @@ template <typename Scalar, typename Index, struct product_triangular_matrix_matrix_trmm : product_triangular_matrix_matrix<Scalar,Index,Mode, LhsIsTriangular,LhsStorageOrder,ConjugateLhs, - RhsStorageOrder, ConjugateRhs, ResStorageOrder, BuiltIn> {}; + RhsStorageOrder, ConjugateRhs, ResStorageOrder, 1, BuiltIn> {}; // try to go to BLAS specialization @@ -55,13 +55,15 @@ template <typename Index, int Mode, \ int LhsStorageOrder, bool ConjugateLhs, \ int RhsStorageOrder, bool ConjugateRhs> \ struct product_triangular_matrix_matrix<Scalar,Index, Mode, LhsIsTriangular, \ - LhsStorageOrder,ConjugateLhs, RhsStorageOrder,ConjugateRhs,ColMajor,Specialized> { \ + LhsStorageOrder,ConjugateLhs, RhsStorageOrder,ConjugateRhs,ColMajor,1,Specialized> { \ static inline void run(Index _rows, Index _cols, Index _depth, const Scalar* _lhs, Index lhsStride,\ - const Scalar* _rhs, Index rhsStride, Scalar* res, Index resStride, Scalar alpha, level3_blocking<Scalar,Scalar>& blocking) { \ + const Scalar* _rhs, Index rhsStride, Scalar* res, Index resIncr, Index resStride, Scalar alpha, level3_blocking<Scalar,Scalar>& blocking) { \ + EIGEN_ONLY_USED_FOR_DEBUG(resIncr); \ + eigen_assert(resIncr == 1); \ product_triangular_matrix_matrix_trmm<Scalar,Index,Mode, \ LhsIsTriangular,LhsStorageOrder,ConjugateLhs, \ RhsStorageOrder, ConjugateRhs, ColMajor>::run( \ - _rows, _cols, _depth, _lhs, lhsStride, _rhs, rhsStride, res, resStride, alpha, blocking); \ + _rows, _cols, _depth, _lhs, lhsStride, _rhs, rhsStride, res, resStride, alpha, blocking); \ } \ }; @@ -75,7 +77,7 @@ EIGEN_BLAS_TRMM_SPECIALIZE(scomplex, true) EIGEN_BLAS_TRMM_SPECIALIZE(scomplex, false) // implements col-major += alpha * op(triangular) * op(general) -#define EIGEN_BLAS_TRMM_L(EIGTYPE, BLASTYPE, EIGPREFIX, BLASPREFIX) \ +#define EIGEN_BLAS_TRMM_L(EIGTYPE, BLASTYPE, EIGPREFIX, BLASFUNC) \ template <typename Index, int Mode, \ int LhsStorageOrder, bool ConjugateLhs, \ int RhsStorageOrder, bool ConjugateRhs> \ @@ -115,8 +117,8 @@ struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,true, \ if (((nthr==1) && (((std::max)(rows,depth)-diagSize)/(double)diagSize < 0.5))) { \ /* Most likely no benefit to call TRMM or GEMM from BLAS */ \ product_triangular_matrix_matrix<EIGTYPE,Index,Mode,true, \ - LhsStorageOrder,ConjugateLhs, RhsStorageOrder, ConjugateRhs, ColMajor, BuiltIn>::run( \ - _rows, _cols, _depth, _lhs, lhsStride, _rhs, rhsStride, res, resStride, alpha, blocking); \ + LhsStorageOrder,ConjugateLhs, RhsStorageOrder, ConjugateRhs, ColMajor, 1, BuiltIn>::run( \ + _rows, _cols, _depth, _lhs, lhsStride, _rhs, rhsStride, res, 1, resStride, alpha, blocking); \ /*std::cout << "TRMM_L: A is not square! Go to Eigen TRMM implementation!\n";*/ \ } else { \ /* Make sense to call GEMM */ \ @@ -124,8 +126,8 @@ struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,true, \ MatrixLhs aa_tmp=lhsMap.template triangularView<Mode>(); \ BlasIndex aStride = convert_index<BlasIndex>(aa_tmp.outerStride()); \ gemm_blocking_space<ColMajor,EIGTYPE,EIGTYPE,Dynamic,Dynamic,Dynamic> gemm_blocking(_rows,_cols,_depth, 1, true); \ - general_matrix_matrix_product<Index,EIGTYPE,LhsStorageOrder,ConjugateLhs,EIGTYPE,RhsStorageOrder,ConjugateRhs,ColMajor>::run( \ - rows, cols, depth, aa_tmp.data(), aStride, _rhs, rhsStride, res, resStride, alpha, gemm_blocking, 0); \ + general_matrix_matrix_product<Index,EIGTYPE,LhsStorageOrder,ConjugateLhs,EIGTYPE,RhsStorageOrder,ConjugateRhs,ColMajor,1>::run( \ + rows, cols, depth, aa_tmp.data(), aStride, _rhs, rhsStride, res, 1, resStride, alpha, gemm_blocking, 0); \ \ /*std::cout << "TRMM_L: A is not square! Go to BLAS GEMM implementation! " << nthr<<" \n";*/ \ } \ @@ -172,7 +174,7 @@ struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,true, \ } \ /*std::cout << "TRMM_L: A is square! Go to BLAS TRMM implementation! \n";*/ \ /* call ?trmm*/ \ - BLASPREFIX##trmm_(&side, &uplo, &transa, &diag, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (BLASTYPE*)b, &ldb); \ + BLASFUNC(&side, &uplo, &transa, &diag, &m, &n, (const BLASTYPE*)&numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (BLASTYPE*)b, &ldb); \ \ /* Add op(a_triangular)*b into res*/ \ Map<MatrixX##EIGPREFIX, 0, OuterStride<> > res_tmp(res,rows,cols,OuterStride<>(resStride)); \ @@ -180,13 +182,20 @@ struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,true, \ } \ }; -EIGEN_BLAS_TRMM_L(double, double, d, d) -EIGEN_BLAS_TRMM_L(dcomplex, double, cd, z) -EIGEN_BLAS_TRMM_L(float, float, f, s) -EIGEN_BLAS_TRMM_L(scomplex, float, cf, c) +#ifdef EIGEN_USE_MKL +EIGEN_BLAS_TRMM_L(double, double, d, dtrmm) +EIGEN_BLAS_TRMM_L(dcomplex, MKL_Complex16, cd, ztrmm) +EIGEN_BLAS_TRMM_L(float, float, f, strmm) +EIGEN_BLAS_TRMM_L(scomplex, MKL_Complex8, cf, ctrmm) +#else +EIGEN_BLAS_TRMM_L(double, double, d, dtrmm_) +EIGEN_BLAS_TRMM_L(dcomplex, double, cd, ztrmm_) +EIGEN_BLAS_TRMM_L(float, float, f, strmm_) +EIGEN_BLAS_TRMM_L(scomplex, float, cf, ctrmm_) +#endif // implements col-major += alpha * op(general) * op(triangular) -#define EIGEN_BLAS_TRMM_R(EIGTYPE, BLASTYPE, EIGPREFIX, BLASPREFIX) \ +#define EIGEN_BLAS_TRMM_R(EIGTYPE, BLASTYPE, EIGPREFIX, BLASFUNC) \ template <typename Index, int Mode, \ int LhsStorageOrder, bool ConjugateLhs, \ int RhsStorageOrder, bool ConjugateRhs> \ @@ -225,8 +234,8 @@ struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,false, \ if ((nthr==1) && (((std::max)(cols,depth)-diagSize)/(double)diagSize < 0.5)) { \ /* Most likely no benefit to call TRMM or GEMM from BLAS*/ \ product_triangular_matrix_matrix<EIGTYPE,Index,Mode,false, \ - LhsStorageOrder,ConjugateLhs, RhsStorageOrder, ConjugateRhs, ColMajor, BuiltIn>::run( \ - _rows, _cols, _depth, _lhs, lhsStride, _rhs, rhsStride, res, resStride, alpha, blocking); \ + LhsStorageOrder,ConjugateLhs, RhsStorageOrder, ConjugateRhs, ColMajor, 1, BuiltIn>::run( \ + _rows, _cols, _depth, _lhs, lhsStride, _rhs, rhsStride, res, 1, resStride, alpha, blocking); \ /*std::cout << "TRMM_R: A is not square! Go to Eigen TRMM implementation!\n";*/ \ } else { \ /* Make sense to call GEMM */ \ @@ -234,8 +243,8 @@ struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,false, \ MatrixRhs aa_tmp=rhsMap.template triangularView<Mode>(); \ BlasIndex aStride = convert_index<BlasIndex>(aa_tmp.outerStride()); \ gemm_blocking_space<ColMajor,EIGTYPE,EIGTYPE,Dynamic,Dynamic,Dynamic> gemm_blocking(_rows,_cols,_depth, 1, true); \ - general_matrix_matrix_product<Index,EIGTYPE,LhsStorageOrder,ConjugateLhs,EIGTYPE,RhsStorageOrder,ConjugateRhs,ColMajor>::run( \ - rows, cols, depth, _lhs, lhsStride, aa_tmp.data(), aStride, res, resStride, alpha, gemm_blocking, 0); \ + general_matrix_matrix_product<Index,EIGTYPE,LhsStorageOrder,ConjugateLhs,EIGTYPE,RhsStorageOrder,ConjugateRhs,ColMajor,1>::run( \ + rows, cols, depth, _lhs, lhsStride, aa_tmp.data(), aStride, res, 1, resStride, alpha, gemm_blocking, 0); \ \ /*std::cout << "TRMM_R: A is not square! Go to BLAS GEMM implementation! " << nthr<<" \n";*/ \ } \ @@ -282,7 +291,7 @@ struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,false, \ } \ /*std::cout << "TRMM_R: A is square! Go to BLAS TRMM implementation! \n";*/ \ /* call ?trmm*/ \ - BLASPREFIX##trmm_(&side, &uplo, &transa, &diag, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (BLASTYPE*)b, &ldb); \ + BLASFUNC(&side, &uplo, &transa, &diag, &m, &n, (const BLASTYPE*)&numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (BLASTYPE*)b, &ldb); \ \ /* Add op(a_triangular)*b into res*/ \ Map<MatrixX##EIGPREFIX, 0, OuterStride<> > res_tmp(res,rows,cols,OuterStride<>(resStride)); \ @@ -290,11 +299,17 @@ struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,false, \ } \ }; -EIGEN_BLAS_TRMM_R(double, double, d, d) -EIGEN_BLAS_TRMM_R(dcomplex, double, cd, z) -EIGEN_BLAS_TRMM_R(float, float, f, s) -EIGEN_BLAS_TRMM_R(scomplex, float, cf, c) - +#ifdef EIGEN_USE_MKL +EIGEN_BLAS_TRMM_R(double, double, d, dtrmm) +EIGEN_BLAS_TRMM_R(dcomplex, MKL_Complex16, cd, ztrmm) +EIGEN_BLAS_TRMM_R(float, float, f, strmm) +EIGEN_BLAS_TRMM_R(scomplex, MKL_Complex8, cf, ctrmm) +#else +EIGEN_BLAS_TRMM_R(double, double, d, dtrmm_) +EIGEN_BLAS_TRMM_R(dcomplex, double, cd, ztrmm_) +EIGEN_BLAS_TRMM_R(float, float, f, strmm_) +EIGEN_BLAS_TRMM_R(scomplex, float, cf, ctrmm_) +#endif } // end namespace internal } // end namespace Eigen |