aboutsummaryrefslogtreecommitdiff
path: root/Eigen/src/Core/products/GeneralMatrixMatrix_BLAS.h
diff options
context:
space:
mode:
authorYi Kong <yikong@google.com>2022-02-25 15:53:09 +0000
committerAutomerger Merge Worker <android-build-automerger-merge-worker@system.gserviceaccount.com>2022-02-25 15:53:09 +0000
commit10f298fc4175c1b8537c674f654a070c871960e5 (patch)
treefb979fb4cf4f8052c8cc66b1ec9516d91fcd859b /Eigen/src/Core/products/GeneralMatrixMatrix_BLAS.h
parent892aea0d75825c43d5b630e2060622cbba23694c (diff)
parent79df15ea886a5fc1b85de433f9b3518c68934bae (diff)
downloadeigen-10f298fc4175c1b8537c674f654a070c871960e5.tar.gz
Original change: https://android-review.googlesource.com/c/platform/external/eigen/+/1999079 Change-Id: I0c5108390c595f0d39af8797875f2b88accb7b56
Diffstat (limited to 'Eigen/src/Core/products/GeneralMatrixMatrix_BLAS.h')
-rw-r--r--Eigen/src/Core/products/GeneralMatrixMatrix_BLAS.h25
1 files changed, 17 insertions, 8 deletions
diff --git a/Eigen/src/Core/products/GeneralMatrixMatrix_BLAS.h b/Eigen/src/Core/products/GeneralMatrixMatrix_BLAS.h
index 7a3bdbf20..71abf4013 100644
--- a/Eigen/src/Core/products/GeneralMatrixMatrix_BLAS.h
+++ b/Eigen/src/Core/products/GeneralMatrixMatrix_BLAS.h
@@ -46,25 +46,27 @@ namespace internal {
// gemm specialization
-#define GEMM_SPECIALIZATION(EIGTYPE, EIGPREFIX, BLASTYPE, BLASPREFIX) \
+#define GEMM_SPECIALIZATION(EIGTYPE, EIGPREFIX, BLASTYPE, BLASFUNC) \
template< \
typename Index, \
int LhsStorageOrder, bool ConjugateLhs, \
int RhsStorageOrder, bool ConjugateRhs> \
-struct general_matrix_matrix_product<Index,EIGTYPE,LhsStorageOrder,ConjugateLhs,EIGTYPE,RhsStorageOrder,ConjugateRhs,ColMajor> \
+struct general_matrix_matrix_product<Index,EIGTYPE,LhsStorageOrder,ConjugateLhs,EIGTYPE,RhsStorageOrder,ConjugateRhs,ColMajor,1> \
{ \
typedef gebp_traits<EIGTYPE,EIGTYPE> Traits; \
\
static void run(Index rows, Index cols, Index depth, \
const EIGTYPE* _lhs, Index lhsStride, \
const EIGTYPE* _rhs, Index rhsStride, \
- EIGTYPE* res, Index resStride, \
+ EIGTYPE* res, Index resIncr, Index resStride, \
EIGTYPE alpha, \
level3_blocking<EIGTYPE, EIGTYPE>& /*blocking*/, \
GemmParallelInfo<Index>* /*info = 0*/) \
{ \
using std::conj; \
\
+ EIGEN_ONLY_USED_FOR_DEBUG(resIncr); \
+ eigen_assert(resIncr == 1); \
char transa, transb; \
BlasIndex m, n, k, lda, ldb, ldc; \
const EIGTYPE *a, *b; \
@@ -100,13 +102,20 @@ static void run(Index rows, Index cols, Index depth, \
ldb = convert_index<BlasIndex>(b_tmp.outerStride()); \
} else b = _rhs; \
\
- BLASPREFIX##gemm_(&transa, &transb, &m, &n, &k, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (const BLASTYPE*)b, &ldb, &numext::real_ref(beta), (BLASTYPE*)res, &ldc); \
+ BLASFUNC(&transa, &transb, &m, &n, &k, (const BLASTYPE*)&numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (const BLASTYPE*)b, &ldb, (const BLASTYPE*)&numext::real_ref(beta), (BLASTYPE*)res, &ldc); \
}};
-GEMM_SPECIALIZATION(double, d, double, d)
-GEMM_SPECIALIZATION(float, f, float, s)
-GEMM_SPECIALIZATION(dcomplex, cd, double, z)
-GEMM_SPECIALIZATION(scomplex, cf, float, c)
+#ifdef EIGEN_USE_MKL
+GEMM_SPECIALIZATION(double, d, double, dgemm)
+GEMM_SPECIALIZATION(float, f, float, sgemm)
+GEMM_SPECIALIZATION(dcomplex, cd, MKL_Complex16, zgemm)
+GEMM_SPECIALIZATION(scomplex, cf, MKL_Complex8, cgemm)
+#else
+GEMM_SPECIALIZATION(double, d, double, dgemm_)
+GEMM_SPECIALIZATION(float, f, float, sgemm_)
+GEMM_SPECIALIZATION(dcomplex, cd, double, zgemm_)
+GEMM_SPECIALIZATION(scomplex, cf, float, cgemm_)
+#endif
} // end namespase internal