diff options
Diffstat (limited to 'Eigen/src/Core/products/SelfadjointMatrixMatrix_BLAS.h')
-rw-r--r-- | Eigen/src/Core/products/SelfadjointMatrixMatrix_BLAS.h | 72 |
1 files changed, 46 insertions, 26 deletions
diff --git a/Eigen/src/Core/products/SelfadjointMatrixMatrix_BLAS.h b/Eigen/src/Core/products/SelfadjointMatrixMatrix_BLAS.h index a45238d69..61396dbdf 100644 --- a/Eigen/src/Core/products/SelfadjointMatrixMatrix_BLAS.h +++ b/Eigen/src/Core/products/SelfadjointMatrixMatrix_BLAS.h @@ -40,20 +40,22 @@ namespace internal { /* Optimized selfadjoint matrix * matrix (?SYMM/?HEMM) product */ -#define EIGEN_BLAS_SYMM_L(EIGTYPE, BLASTYPE, EIGPREFIX, BLASPREFIX) \ +#define EIGEN_BLAS_SYMM_L(EIGTYPE, BLASTYPE, EIGPREFIX, BLASFUNC) \ template <typename Index, \ int LhsStorageOrder, bool ConjugateLhs, \ int RhsStorageOrder, bool ConjugateRhs> \ -struct product_selfadjoint_matrix<EIGTYPE,Index,LhsStorageOrder,true,ConjugateLhs,RhsStorageOrder,false,ConjugateRhs,ColMajor> \ +struct product_selfadjoint_matrix<EIGTYPE,Index,LhsStorageOrder,true,ConjugateLhs,RhsStorageOrder,false,ConjugateRhs,ColMajor,1> \ {\ \ static void run( \ Index rows, Index cols, \ const EIGTYPE* _lhs, Index lhsStride, \ const EIGTYPE* _rhs, Index rhsStride, \ - EIGTYPE* res, Index resStride, \ + EIGTYPE* res, Index resIncr, Index resStride, \ EIGTYPE alpha, level3_blocking<EIGTYPE, EIGTYPE>& /*blocking*/) \ { \ + EIGEN_ONLY_USED_FOR_DEBUG(resIncr); \ + eigen_assert(resIncr == 1); \ char side='L', uplo='L'; \ BlasIndex m, n, lda, ldb, ldc; \ const EIGTYPE *a, *b; \ @@ -81,25 +83,27 @@ struct product_selfadjoint_matrix<EIGTYPE,Index,LhsStorageOrder,true,ConjugateLh ldb = convert_index<BlasIndex>(b_tmp.outerStride()); \ } else b = _rhs; \ \ - BLASPREFIX##symm_(&side, &uplo, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (const BLASTYPE*)b, &ldb, &numext::real_ref(beta), (BLASTYPE*)res, &ldc); \ + BLASFUNC(&side, &uplo, &m, &n, (const BLASTYPE*)&numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (const BLASTYPE*)b, &ldb, (const BLASTYPE*)&numext::real_ref(beta), (BLASTYPE*)res, &ldc); \ \ } \ }; -#define EIGEN_BLAS_HEMM_L(EIGTYPE, BLASTYPE, EIGPREFIX, BLASPREFIX) \ +#define EIGEN_BLAS_HEMM_L(EIGTYPE, BLASTYPE, EIGPREFIX, BLASFUNC) \ template <typename Index, \ int LhsStorageOrder, bool ConjugateLhs, \ int RhsStorageOrder, bool ConjugateRhs> \ -struct product_selfadjoint_matrix<EIGTYPE,Index,LhsStorageOrder,true,ConjugateLhs,RhsStorageOrder,false,ConjugateRhs,ColMajor> \ +struct product_selfadjoint_matrix<EIGTYPE,Index,LhsStorageOrder,true,ConjugateLhs,RhsStorageOrder,false,ConjugateRhs,ColMajor,1> \ {\ static void run( \ Index rows, Index cols, \ const EIGTYPE* _lhs, Index lhsStride, \ const EIGTYPE* _rhs, Index rhsStride, \ - EIGTYPE* res, Index resStride, \ + EIGTYPE* res, Index resIncr, Index resStride, \ EIGTYPE alpha, level3_blocking<EIGTYPE, EIGTYPE>& /*blocking*/) \ { \ + EIGEN_ONLY_USED_FOR_DEBUG(resIncr); \ + eigen_assert(resIncr == 1); \ char side='L', uplo='L'; \ BlasIndex m, n, lda, ldb, ldc; \ const EIGTYPE *a, *b; \ @@ -144,33 +148,41 @@ struct product_selfadjoint_matrix<EIGTYPE,Index,LhsStorageOrder,true,ConjugateLh ldb = convert_index<BlasIndex>(b_tmp.outerStride()); \ } \ \ - BLASPREFIX##hemm_(&side, &uplo, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (const BLASTYPE*)b, &ldb, &numext::real_ref(beta), (BLASTYPE*)res, &ldc); \ + BLASFUNC(&side, &uplo, &m, &n, (const BLASTYPE*)&numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (const BLASTYPE*)b, &ldb, (const BLASTYPE*)&numext::real_ref(beta), (BLASTYPE*)res, &ldc); \ \ } \ }; -EIGEN_BLAS_SYMM_L(double, double, d, d) -EIGEN_BLAS_SYMM_L(float, float, f, s) -EIGEN_BLAS_HEMM_L(dcomplex, double, cd, z) -EIGEN_BLAS_HEMM_L(scomplex, float, cf, c) - +#ifdef EIGEN_USE_MKL +EIGEN_BLAS_SYMM_L(double, double, d, dsymm) +EIGEN_BLAS_SYMM_L(float, float, f, ssymm) +EIGEN_BLAS_HEMM_L(dcomplex, MKL_Complex16, cd, zhemm) +EIGEN_BLAS_HEMM_L(scomplex, MKL_Complex8, cf, chemm) +#else +EIGEN_BLAS_SYMM_L(double, double, d, dsymm_) +EIGEN_BLAS_SYMM_L(float, float, f, ssymm_) +EIGEN_BLAS_HEMM_L(dcomplex, double, cd, zhemm_) +EIGEN_BLAS_HEMM_L(scomplex, float, cf, chemm_) +#endif /* Optimized matrix * selfadjoint matrix (?SYMM/?HEMM) product */ -#define EIGEN_BLAS_SYMM_R(EIGTYPE, BLASTYPE, EIGPREFIX, BLASPREFIX) \ +#define EIGEN_BLAS_SYMM_R(EIGTYPE, BLASTYPE, EIGPREFIX, BLASFUNC) \ template <typename Index, \ int LhsStorageOrder, bool ConjugateLhs, \ int RhsStorageOrder, bool ConjugateRhs> \ -struct product_selfadjoint_matrix<EIGTYPE,Index,LhsStorageOrder,false,ConjugateLhs,RhsStorageOrder,true,ConjugateRhs,ColMajor> \ +struct product_selfadjoint_matrix<EIGTYPE,Index,LhsStorageOrder,false,ConjugateLhs,RhsStorageOrder,true,ConjugateRhs,ColMajor,1> \ {\ \ static void run( \ Index rows, Index cols, \ const EIGTYPE* _lhs, Index lhsStride, \ const EIGTYPE* _rhs, Index rhsStride, \ - EIGTYPE* res, Index resStride, \ + EIGTYPE* res, Index resIncr, Index resStride, \ EIGTYPE alpha, level3_blocking<EIGTYPE, EIGTYPE>& /*blocking*/) \ { \ + EIGEN_ONLY_USED_FOR_DEBUG(resIncr); \ + eigen_assert(resIncr == 1); \ char side='R', uplo='L'; \ BlasIndex m, n, lda, ldb, ldc; \ const EIGTYPE *a, *b; \ @@ -197,25 +209,27 @@ struct product_selfadjoint_matrix<EIGTYPE,Index,LhsStorageOrder,false,ConjugateL ldb = convert_index<BlasIndex>(b_tmp.outerStride()); \ } else b = _lhs; \ \ - BLASPREFIX##symm_(&side, &uplo, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (const BLASTYPE*)b, &ldb, &numext::real_ref(beta), (BLASTYPE*)res, &ldc); \ + BLASFUNC(&side, &uplo, &m, &n, (const BLASTYPE*)&numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (const BLASTYPE*)b, &ldb, (const BLASTYPE*)&numext::real_ref(beta), (BLASTYPE*)res, &ldc); \ \ } \ }; -#define EIGEN_BLAS_HEMM_R(EIGTYPE, BLASTYPE, EIGPREFIX, BLASPREFIX) \ +#define EIGEN_BLAS_HEMM_R(EIGTYPE, BLASTYPE, EIGPREFIX, BLASFUNC) \ template <typename Index, \ int LhsStorageOrder, bool ConjugateLhs, \ int RhsStorageOrder, bool ConjugateRhs> \ -struct product_selfadjoint_matrix<EIGTYPE,Index,LhsStorageOrder,false,ConjugateLhs,RhsStorageOrder,true,ConjugateRhs,ColMajor> \ +struct product_selfadjoint_matrix<EIGTYPE,Index,LhsStorageOrder,false,ConjugateLhs,RhsStorageOrder,true,ConjugateRhs,ColMajor,1> \ {\ static void run( \ Index rows, Index cols, \ const EIGTYPE* _lhs, Index lhsStride, \ const EIGTYPE* _rhs, Index rhsStride, \ - EIGTYPE* res, Index resStride, \ + EIGTYPE* res, Index resIncr, Index resStride, \ EIGTYPE alpha, level3_blocking<EIGTYPE, EIGTYPE>& /*blocking*/) \ { \ + EIGEN_ONLY_USED_FOR_DEBUG(resIncr); \ + eigen_assert(resIncr == 1); \ char side='R', uplo='L'; \ BlasIndex m, n, lda, ldb, ldc; \ const EIGTYPE *a, *b; \ @@ -259,15 +273,21 @@ struct product_selfadjoint_matrix<EIGTYPE,Index,LhsStorageOrder,false,ConjugateL ldb = convert_index<BlasIndex>(b_tmp.outerStride()); \ } \ \ - BLASPREFIX##hemm_(&side, &uplo, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (const BLASTYPE*)b, &ldb, &numext::real_ref(beta), (BLASTYPE*)res, &ldc); \ + BLASFUNC(&side, &uplo, &m, &n, (const BLASTYPE*)&numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (const BLASTYPE*)b, &ldb, (const BLASTYPE*)&numext::real_ref(beta), (BLASTYPE*)res, &ldc); \ } \ }; -EIGEN_BLAS_SYMM_R(double, double, d, d) -EIGEN_BLAS_SYMM_R(float, float, f, s) -EIGEN_BLAS_HEMM_R(dcomplex, double, cd, z) -EIGEN_BLAS_HEMM_R(scomplex, float, cf, c) - +#ifdef EIGEN_USE_MKL +EIGEN_BLAS_SYMM_R(double, double, d, dsymm) +EIGEN_BLAS_SYMM_R(float, float, f, ssymm) +EIGEN_BLAS_HEMM_R(dcomplex, MKL_Complex16, cd, zhemm) +EIGEN_BLAS_HEMM_R(scomplex, MKL_Complex8, cf, chemm) +#else +EIGEN_BLAS_SYMM_R(double, double, d, dsymm_) +EIGEN_BLAS_SYMM_R(float, float, f, ssymm_) +EIGEN_BLAS_HEMM_R(dcomplex, double, cd, zhemm_) +EIGEN_BLAS_HEMM_R(scomplex, float, cf, chemm_) +#endif } // end namespace internal } // end namespace Eigen |