diff options
Diffstat (limited to 'Eigen/src/Core/products/TriangularSolverMatrix_BLAS.h')
-rw-r--r-- | Eigen/src/Core/products/TriangularSolverMatrix_BLAS.h | 52 |
1 files changed, 34 insertions, 18 deletions
diff --git a/Eigen/src/Core/products/TriangularSolverMatrix_BLAS.h b/Eigen/src/Core/products/TriangularSolverMatrix_BLAS.h index 88c0fb794..621194ce6 100644 --- a/Eigen/src/Core/products/TriangularSolverMatrix_BLAS.h +++ b/Eigen/src/Core/products/TriangularSolverMatrix_BLAS.h @@ -38,9 +38,9 @@ namespace Eigen { namespace internal { // implements LeftSide op(triangular)^-1 * general -#define EIGEN_BLAS_TRSM_L(EIGTYPE, BLASTYPE, BLASPREFIX) \ +#define EIGEN_BLAS_TRSM_L(EIGTYPE, BLASTYPE, BLASFUNC) \ template <typename Index, int Mode, bool Conjugate, int TriStorageOrder> \ -struct triangular_solve_matrix<EIGTYPE,Index,OnTheLeft,Mode,Conjugate,TriStorageOrder,ColMajor> \ +struct triangular_solve_matrix<EIGTYPE,Index,OnTheLeft,Mode,Conjugate,TriStorageOrder,ColMajor,1> \ { \ enum { \ IsLower = (Mode&Lower) == Lower, \ @@ -51,8 +51,10 @@ struct triangular_solve_matrix<EIGTYPE,Index,OnTheLeft,Mode,Conjugate,TriStorage static void run( \ Index size, Index otherSize, \ const EIGTYPE* _tri, Index triStride, \ - EIGTYPE* _other, Index otherStride, level3_blocking<EIGTYPE,EIGTYPE>& /*blocking*/) \ + EIGTYPE* _other, Index otherIncr, Index otherStride, level3_blocking<EIGTYPE,EIGTYPE>& /*blocking*/) \ { \ + EIGEN_ONLY_USED_FOR_DEBUG(otherIncr); \ + eigen_assert(otherIncr == 1); \ BlasIndex m = convert_index<BlasIndex>(size), n = convert_index<BlasIndex>(otherSize), lda, ldb; \ char side = 'L', uplo, diag='N', transa; \ /* Set alpha_ */ \ @@ -80,20 +82,26 @@ struct triangular_solve_matrix<EIGTYPE,Index,OnTheLeft,Mode,Conjugate,TriStorage } \ if (IsUnitDiag) diag='U'; \ /* call ?trsm*/ \ - BLASPREFIX##trsm_(&side, &uplo, &transa, &diag, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (BLASTYPE*)_other, &ldb); \ + BLASFUNC(&side, &uplo, &transa, &diag, &m, &n, (const BLASTYPE*)&numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (BLASTYPE*)_other, &ldb); \ } \ }; -EIGEN_BLAS_TRSM_L(double, double, d) -EIGEN_BLAS_TRSM_L(dcomplex, double, z) -EIGEN_BLAS_TRSM_L(float, float, s) -EIGEN_BLAS_TRSM_L(scomplex, float, c) - +#ifdef EIGEN_USE_MKL +EIGEN_BLAS_TRSM_L(double, double, dtrsm) +EIGEN_BLAS_TRSM_L(dcomplex, MKL_Complex16, ztrsm) +EIGEN_BLAS_TRSM_L(float, float, strsm) +EIGEN_BLAS_TRSM_L(scomplex, MKL_Complex8, ctrsm) +#else +EIGEN_BLAS_TRSM_L(double, double, dtrsm_) +EIGEN_BLAS_TRSM_L(dcomplex, double, ztrsm_) +EIGEN_BLAS_TRSM_L(float, float, strsm_) +EIGEN_BLAS_TRSM_L(scomplex, float, ctrsm_) +#endif // implements RightSide general * op(triangular)^-1 -#define EIGEN_BLAS_TRSM_R(EIGTYPE, BLASTYPE, BLASPREFIX) \ +#define EIGEN_BLAS_TRSM_R(EIGTYPE, BLASTYPE, BLASFUNC) \ template <typename Index, int Mode, bool Conjugate, int TriStorageOrder> \ -struct triangular_solve_matrix<EIGTYPE,Index,OnTheRight,Mode,Conjugate,TriStorageOrder,ColMajor> \ +struct triangular_solve_matrix<EIGTYPE,Index,OnTheRight,Mode,Conjugate,TriStorageOrder,ColMajor,1> \ { \ enum { \ IsLower = (Mode&Lower) == Lower, \ @@ -104,8 +112,10 @@ struct triangular_solve_matrix<EIGTYPE,Index,OnTheRight,Mode,Conjugate,TriStorag static void run( \ Index size, Index otherSize, \ const EIGTYPE* _tri, Index triStride, \ - EIGTYPE* _other, Index otherStride, level3_blocking<EIGTYPE,EIGTYPE>& /*blocking*/) \ + EIGTYPE* _other, Index otherIncr, Index otherStride, level3_blocking<EIGTYPE,EIGTYPE>& /*blocking*/) \ { \ + EIGEN_ONLY_USED_FOR_DEBUG(otherIncr); \ + eigen_assert(otherIncr == 1); \ BlasIndex m = convert_index<BlasIndex>(otherSize), n = convert_index<BlasIndex>(size), lda, ldb; \ char side = 'R', uplo, diag='N', transa; \ /* Set alpha_ */ \ @@ -133,16 +143,22 @@ struct triangular_solve_matrix<EIGTYPE,Index,OnTheRight,Mode,Conjugate,TriStorag } \ if (IsUnitDiag) diag='U'; \ /* call ?trsm*/ \ - BLASPREFIX##trsm_(&side, &uplo, &transa, &diag, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (BLASTYPE*)_other, &ldb); \ + BLASFUNC(&side, &uplo, &transa, &diag, &m, &n, (const BLASTYPE*)&numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (BLASTYPE*)_other, &ldb); \ /*std::cout << "TRMS_L specialization!\n";*/ \ } \ }; -EIGEN_BLAS_TRSM_R(double, double, d) -EIGEN_BLAS_TRSM_R(dcomplex, double, z) -EIGEN_BLAS_TRSM_R(float, float, s) -EIGEN_BLAS_TRSM_R(scomplex, float, c) - +#ifdef EIGEN_USE_MKL +EIGEN_BLAS_TRSM_R(double, double, dtrsm) +EIGEN_BLAS_TRSM_R(dcomplex, MKL_Complex16, ztrsm) +EIGEN_BLAS_TRSM_R(float, float, strsm) +EIGEN_BLAS_TRSM_R(scomplex, MKL_Complex8, ctrsm) +#else +EIGEN_BLAS_TRSM_R(double, double, dtrsm_) +EIGEN_BLAS_TRSM_R(dcomplex, double, ztrsm_) +EIGEN_BLAS_TRSM_R(float, float, strsm_) +EIGEN_BLAS_TRSM_R(scomplex, float, ctrsm_) +#endif } // end namespace internal |