diff options
Diffstat (limited to 'Eigen/src/Core/products/TriangularMatrixVector_BLAS.h')
-rw-r--r-- | Eigen/src/Core/products/TriangularMatrixVector_BLAS.h | 46 |
1 files changed, 30 insertions, 16 deletions
diff --git a/Eigen/src/Core/products/TriangularMatrixVector_BLAS.h b/Eigen/src/Core/products/TriangularMatrixVector_BLAS.h index 07bf26ce5..3d47a2b94 100644 --- a/Eigen/src/Core/products/TriangularMatrixVector_BLAS.h +++ b/Eigen/src/Core/products/TriangularMatrixVector_BLAS.h @@ -71,7 +71,7 @@ EIGEN_BLAS_TRMV_SPECIALIZE(dcomplex) EIGEN_BLAS_TRMV_SPECIALIZE(scomplex) // implements col-major: res += alpha * op(triangular) * vector -#define EIGEN_BLAS_TRMV_CM(EIGTYPE, BLASTYPE, EIGPREFIX, BLASPREFIX) \ +#define EIGEN_BLAS_TRMV_CM(EIGTYPE, BLASTYPE, EIGPREFIX, BLASPREFIX, BLASPOSTFIX) \ template<typename Index, int Mode, bool ConjLhs, bool ConjRhs> \ struct triangular_matrix_vector_product_trmv<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,ConjRhs,ColMajor> { \ enum { \ @@ -121,10 +121,10 @@ struct triangular_matrix_vector_product_trmv<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE, diag = IsUnitDiag ? 'U' : 'N'; \ \ /* call ?TRMV*/ \ - BLASPREFIX##trmv_(&uplo, &trans, &diag, &n, (const BLASTYPE*)_lhs, &lda, (BLASTYPE*)x, &incx); \ + BLASPREFIX##trmv##BLASPOSTFIX(&uplo, &trans, &diag, &n, (const BLASTYPE*)_lhs, &lda, (BLASTYPE*)x, &incx); \ \ /* Add op(a_tr)rhs into res*/ \ - BLASPREFIX##axpy_(&n, &numext::real_ref(alpha),(const BLASTYPE*)x, &incx, (BLASTYPE*)_res, &incy); \ + BLASPREFIX##axpy##BLASPOSTFIX(&n, (const BLASTYPE*)&numext::real_ref(alpha),(const BLASTYPE*)x, &incx, (BLASTYPE*)_res, &incy); \ /* Non-square case - doesn't fit to BLAS ?TRMV. Fall to default triangular product*/ \ if (size<(std::max)(rows,cols)) { \ if (ConjRhs) x_tmp = rhs.conjugate(); else x_tmp = rhs; \ @@ -142,18 +142,25 @@ struct triangular_matrix_vector_product_trmv<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE, m = convert_index<BlasIndex>(size); \ n = convert_index<BlasIndex>(cols-size); \ } \ - BLASPREFIX##gemv_(&trans, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (const BLASTYPE*)x, &incx, &numext::real_ref(beta), (BLASTYPE*)y, &incy); \ + BLASPREFIX##gemv##BLASPOSTFIX(&trans, &m, &n, (const BLASTYPE*)&numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (const BLASTYPE*)x, &incx, (const BLASTYPE*)&numext::real_ref(beta), (BLASTYPE*)y, &incy); \ } \ } \ }; -EIGEN_BLAS_TRMV_CM(double, double, d, d) -EIGEN_BLAS_TRMV_CM(dcomplex, double, cd, z) -EIGEN_BLAS_TRMV_CM(float, float, f, s) -EIGEN_BLAS_TRMV_CM(scomplex, float, cf, c) +#ifdef EIGEN_USE_MKL +EIGEN_BLAS_TRMV_CM(double, double, d, d,) +EIGEN_BLAS_TRMV_CM(dcomplex, MKL_Complex16, cd, z,) +EIGEN_BLAS_TRMV_CM(float, float, f, s,) +EIGEN_BLAS_TRMV_CM(scomplex, MKL_Complex8, cf, c,) +#else +EIGEN_BLAS_TRMV_CM(double, double, d, d, _) +EIGEN_BLAS_TRMV_CM(dcomplex, double, cd, z, _) +EIGEN_BLAS_TRMV_CM(float, float, f, s, _) +EIGEN_BLAS_TRMV_CM(scomplex, float, cf, c, _) +#endif // implements row-major: res += alpha * op(triangular) * vector -#define EIGEN_BLAS_TRMV_RM(EIGTYPE, BLASTYPE, EIGPREFIX, BLASPREFIX) \ +#define EIGEN_BLAS_TRMV_RM(EIGTYPE, BLASTYPE, EIGPREFIX, BLASPREFIX, BLASPOSTFIX) \ template<typename Index, int Mode, bool ConjLhs, bool ConjRhs> \ struct triangular_matrix_vector_product_trmv<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,ConjRhs,RowMajor> { \ enum { \ @@ -203,10 +210,10 @@ struct triangular_matrix_vector_product_trmv<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE, diag = IsUnitDiag ? 'U' : 'N'; \ \ /* call ?TRMV*/ \ - BLASPREFIX##trmv_(&uplo, &trans, &diag, &n, (const BLASTYPE*)_lhs, &lda, (BLASTYPE*)x, &incx); \ + BLASPREFIX##trmv##BLASPOSTFIX(&uplo, &trans, &diag, &n, (const BLASTYPE*)_lhs, &lda, (BLASTYPE*)x, &incx); \ \ /* Add op(a_tr)rhs into res*/ \ - BLASPREFIX##axpy_(&n, &numext::real_ref(alpha),(const BLASTYPE*)x, &incx, (BLASTYPE*)_res, &incy); \ + BLASPREFIX##axpy##BLASPOSTFIX(&n, (const BLASTYPE*)&numext::real_ref(alpha),(const BLASTYPE*)x, &incx, (BLASTYPE*)_res, &incy); \ /* Non-square case - doesn't fit to BLAS ?TRMV. Fall to default triangular product*/ \ if (size<(std::max)(rows,cols)) { \ if (ConjRhs) x_tmp = rhs.conjugate(); else x_tmp = rhs; \ @@ -224,15 +231,22 @@ struct triangular_matrix_vector_product_trmv<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE, m = convert_index<BlasIndex>(size); \ n = convert_index<BlasIndex>(cols-size); \ } \ - BLASPREFIX##gemv_(&trans, &n, &m, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (const BLASTYPE*)x, &incx, &numext::real_ref(beta), (BLASTYPE*)y, &incy); \ + BLASPREFIX##gemv##BLASPOSTFIX(&trans, &n, &m, (const BLASTYPE*)&numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (const BLASTYPE*)x, &incx, (const BLASTYPE*)&numext::real_ref(beta), (BLASTYPE*)y, &incy); \ } \ } \ }; -EIGEN_BLAS_TRMV_RM(double, double, d, d) -EIGEN_BLAS_TRMV_RM(dcomplex, double, cd, z) -EIGEN_BLAS_TRMV_RM(float, float, f, s) -EIGEN_BLAS_TRMV_RM(scomplex, float, cf, c) +#ifdef EIGEN_USE_MKL +EIGEN_BLAS_TRMV_RM(double, double, d, d,) +EIGEN_BLAS_TRMV_RM(dcomplex, MKL_Complex16, cd, z,) +EIGEN_BLAS_TRMV_RM(float, float, f, s,) +EIGEN_BLAS_TRMV_RM(scomplex, MKL_Complex8, cf, c,) +#else +EIGEN_BLAS_TRMV_RM(double, double, d, d,_) +EIGEN_BLAS_TRMV_RM(dcomplex, double, cd, z,_) +EIGEN_BLAS_TRMV_RM(float, float, f, s,_) +EIGEN_BLAS_TRMV_RM(scomplex, float, cf, c,_) +#endif } // end namespase internal |