Diffstat (limited to 'Eigen/src/Core/products/GeneralMatrixMatrix.h')
-rw-r--r--  Eigen/src/Core/products/GeneralMatrixMatrix.h | 74
1 file changed, 50 insertions(+), 24 deletions(-)
diff --git a/Eigen/src/Core/products/GeneralMatrixMatrix.h b/Eigen/src/Core/products/GeneralMatrixMatrix.h
index 41cfb0e03..caa65fccc 100644
--- a/Eigen/src/Core/products/GeneralMatrixMatrix.h
+++ b/Eigen/src/Core/products/GeneralMatrixMatrix.h
@@ -20,8 +20,9 @@ template<typename _LhsScalar, typename _RhsScalar> class level3_blocking;
template<
typename Index,
typename LhsScalar, int LhsStorageOrder, bool ConjugateLhs,
- typename RhsScalar, int RhsStorageOrder, bool ConjugateRhs>
-struct general_matrix_matrix_product<Index,LhsScalar,LhsStorageOrder,ConjugateLhs,RhsScalar,RhsStorageOrder,ConjugateRhs,RowMajor>
+ typename RhsScalar, int RhsStorageOrder, bool ConjugateRhs,
+ int ResInnerStride>
+struct general_matrix_matrix_product<Index,LhsScalar,LhsStorageOrder,ConjugateLhs,RhsScalar,RhsStorageOrder,ConjugateRhs,RowMajor,ResInnerStride>
{
typedef gebp_traits<RhsScalar,LhsScalar> Traits;
@@ -30,7 +31,7 @@ struct general_matrix_matrix_product<Index,LhsScalar,LhsStorageOrder,ConjugateLh
Index rows, Index cols, Index depth,
const LhsScalar* lhs, Index lhsStride,
const RhsScalar* rhs, Index rhsStride,
- ResScalar* res, Index resStride,
+ ResScalar* res, Index resIncr, Index resStride,
ResScalar alpha,
level3_blocking<RhsScalar,LhsScalar>& blocking,
GemmParallelInfo<Index>* info = 0)
@@ -39,8 +40,8 @@ struct general_matrix_matrix_product<Index,LhsScalar,LhsStorageOrder,ConjugateLh
general_matrix_matrix_product<Index,
RhsScalar, RhsStorageOrder==RowMajor ? ColMajor : RowMajor, ConjugateRhs,
LhsScalar, LhsStorageOrder==RowMajor ? ColMajor : RowMajor, ConjugateLhs,
- ColMajor>
- ::run(cols,rows,depth,rhs,rhsStride,lhs,lhsStride,res,resStride,alpha,blocking,info);
+ ColMajor,ResInnerStride>
+ ::run(cols,rows,depth,rhs,rhsStride,lhs,lhsStride,res,resIncr,resStride,alpha,blocking,info);
}
};
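Note: the new ResInnerStride template parameter and the resIncr runtime argument let the kernels write results whose consecutive coefficients are not adjacent in memory. A minimal usage sketch of the kind of destination this enables (illustrative names and sizes, assuming a build that includes this patch; as far as I can tell, such destinations previously had to go through a temporary):

// Sketch: write a GEMM result through a non-unit inner stride (illustrative).
#include <Eigen/Dense>
using namespace Eigen;

int main()
{
  MatrixXd a = MatrixXd::Random(8,8), b = MatrixXd::Random(8,8);
  MatrixXd buf = MatrixXd::Zero(16,8);   // backing storage with twice the rows

  // View every other row of 'buf' as an 8x8 matrix:
  // inner stride 2 (within a column), outer stride 16 (between columns).
  Map<MatrixXd, 0, Stride<Dynamic,Dynamic> >
      c(buf.data(), 8, 8, Stride<Dynamic,Dynamic>(16, 2));

  c.noalias() = a * b;  // resIncr==2 reaches the kernel through ResMapper
  return 0;
}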
@@ -49,8 +50,9 @@ struct general_matrix_matrix_product<Index,LhsScalar,LhsStorageOrder,ConjugateLh
template<
typename Index,
typename LhsScalar, int LhsStorageOrder, bool ConjugateLhs,
- typename RhsScalar, int RhsStorageOrder, bool ConjugateRhs>
-struct general_matrix_matrix_product<Index,LhsScalar,LhsStorageOrder,ConjugateLhs,RhsScalar,RhsStorageOrder,ConjugateRhs,ColMajor>
+ typename RhsScalar, int RhsStorageOrder, bool ConjugateRhs,
+ int ResInnerStride>
+struct general_matrix_matrix_product<Index,LhsScalar,LhsStorageOrder,ConjugateLhs,RhsScalar,RhsStorageOrder,ConjugateRhs,ColMajor,ResInnerStride>
{
typedef gebp_traits<LhsScalar,RhsScalar> Traits;
@@ -59,23 +61,23 @@ typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScala
static void run(Index rows, Index cols, Index depth,
const LhsScalar* _lhs, Index lhsStride,
const RhsScalar* _rhs, Index rhsStride,
- ResScalar* _res, Index resStride,
+ ResScalar* _res, Index resIncr, Index resStride,
ResScalar alpha,
level3_blocking<LhsScalar,RhsScalar>& blocking,
GemmParallelInfo<Index>* info = 0)
{
typedef const_blas_data_mapper<LhsScalar, Index, LhsStorageOrder> LhsMapper;
typedef const_blas_data_mapper<RhsScalar, Index, RhsStorageOrder> RhsMapper;
- typedef blas_data_mapper<typename Traits::ResScalar, Index, ColMajor> ResMapper;
- LhsMapper lhs(_lhs,lhsStride);
- RhsMapper rhs(_rhs,rhsStride);
- ResMapper res(_res, resStride);
+ typedef blas_data_mapper<typename Traits::ResScalar, Index, ColMajor,Unaligned,ResInnerStride> ResMapper;
+ LhsMapper lhs(_lhs, lhsStride);
+ RhsMapper rhs(_rhs, rhsStride);
+ ResMapper res(_res, resStride, resIncr);
Index kc = blocking.kc(); // cache block size along the K direction
Index mc = (std::min)(rows,blocking.mc()); // cache block size along the M direction
Index nc = (std::min)(cols,blocking.nc()); // cache block size along the N direction
- gemm_pack_lhs<LhsScalar, Index, LhsMapper, Traits::mr, Traits::LhsProgress, LhsStorageOrder> pack_lhs;
+ gemm_pack_lhs<LhsScalar, Index, LhsMapper, Traits::mr, Traits::LhsProgress, typename Traits::LhsPacket4Packing, LhsStorageOrder> pack_lhs;
gemm_pack_rhs<RhsScalar, Index, RhsMapper, Traits::nr, RhsStorageOrder> pack_rhs;
gebp_kernel<LhsScalar, RhsScalar, Index, ResMapper, Traits::mr, Traits::nr, ConjugateLhs, ConjugateRhs> gebp;
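For orientation, the packing helpers and the gebp micro-kernel declared above drive a loop nest that, in the sequential case, looks roughly like the following (a schematic only: the edge-size handling of actual_mc/actual_kc/actual_nc is omitted, the gebp argument list is abridged, and the real code skips re-packing B when it can be reused):

// Schematic of the sequential blocked GEMM driven by the declarations above.
for (Index i2 = 0; i2 < rows; i2 += mc) {                 // mc rows of lhs/result
  for (Index k2 = 0; k2 < depth; k2 += kc) {              // kc-deep slice of lhs
    pack_lhs(blockA, lhs.getSubMapper(i2, k2), kc, mc);   // A' packed for the L2 cache
    for (Index j2 = 0; j2 < cols; j2 += nc) {             // nc columns of rhs/result
      pack_rhs(blockB, rhs.getSubMapper(k2, j2), kc, nc); // B' packed sequentially
      // The micro-kernel writes through ResMapper, which now carries resIncr:
      gebp(res.getSubMapper(i2, j2), blockA, blockB, mc, kc, nc, alpha);
    }
  }
}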
@@ -108,7 +110,7 @@ static void run(Index rows, Index cols, Index depth,
// i.e., we test that info[tid].users equals 0.
// Then, we set info[tid].users to the number of threads to mark that all other threads are going to use it.
while(info[tid].users!=0) {}
- info[tid].users += threads;
+ info[tid].users = threads;
pack_lhs(blockA+info[tid].lhs_start*actual_kc, lhs.getSubMapper(info[tid].lhs_start,k), actual_kc, info[tid].lhs_length);
@@ -146,6 +148,9 @@ static void run(Index rows, Index cols, Index depth,
// Release all the sub blocks A'_i of A' for the current thread,
// i.e., we simply decrement the number of users by 1
for(Index i=0; i<threads; ++i)
+#if !EIGEN_HAS_CXX11_ATOMIC
+ #pragma omp atomic
+#endif
info[i].users -= 1;
}
}
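The users field is a shared reference count on each thread's packed lhs panel: with EIGEN_HAS_CXX11_ATOMIC it is a std::atomic<int>, and without it the decrement in the hunk above needs #pragma omp atomic. The switch from += to a plain store is presumably safe because the spin loop has just observed the count to be zero, and a store avoids an unprotected read-modify-write. A simplified sketch of the handshake, with hypothetical names:

#include <atomic>

struct PanelInfo { std::atomic<int> users{0}; };  // stands in for GemmParallelInfo

// Owner thread: publish a freshly packed panel for 'threads' consumers.
void publish(PanelInfo& info, int threads)
{
  while (info.users.load() != 0) {}  // spin until the previous round is fully released
  info.users.store(threads);         // '=' rather than '+=': count is known to be zero
  // ... pack_lhs into this thread's slice of blockA ...
}

// Every thread, once done reading the panel:
void release(PanelInfo& info)
{
  info.users.fetch_sub(1);           // the omp-atomic decrement in the non-C++11 path
}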
@@ -225,7 +230,7 @@ struct gemm_functor
Gemm::run(rows, cols, m_lhs.cols(),
&m_lhs.coeffRef(row,0), m_lhs.outerStride(),
&m_rhs.coeffRef(0,col), m_rhs.outerStride(),
- (Scalar*)&(m_dest.coeffRef(row,col)), m_dest.outerStride(),
+ (Scalar*)&(m_dest.coeffRef(row,col)), m_dest.innerStride(), m_dest.outerStride(),
m_actualAlpha, m_blocking, info);
}
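For reference, the two stride arguments forwarded here mean the following for a column-major destination; a plain MatrixXd reports innerStride()==1, so the extra argument changes nothing in the common case:

// For Map<MatrixXd, 0, Stride<Dynamic,Dynamic> > d(p, m, n, Stride<Dynamic,Dynamic>(ld, s)):
//   d.innerStride() == s    distance between consecutive coefficients within a column
//   d.outerStride() == ld   distance between the starts of consecutive columns
//   (a plain MatrixXd has innerStride() == 1 and outerStride() == rows())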
@@ -426,8 +431,14 @@ struct generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,GemmProduct>
template<typename Dst>
static void evalTo(Dst& dst, const Lhs& lhs, const Rhs& rhs)
{
- if((rhs.rows()+dst.rows()+dst.cols())<20 && rhs.rows()>0)
- lazyproduct::evalTo(dst, lhs, rhs);
+ // See http://eigen.tuxfamily.org/bz/show_bug.cgi?id=404 for a discussion and helper program
+ // to determine the following heuristic.
+ // EIGEN_GEMM_TO_COEFFBASED_THRESHOLD is typically defined to 20 in GeneralProduct.h,
+ // unless it has been specialized by the user or for a given architecture.
+ // Note that the condition rhs.rows()>0 was required because the lazy product path
+ // was not happy with empty inputs; it is unclear whether that is still the case.
+ if((rhs.rows()+dst.rows()+dst.cols())<EIGEN_GEMM_TO_COEFFBASED_THRESHOLD && rhs.rows()>0)
+ lazyproduct::eval_dynamic(dst, lhs, rhs, internal::assign_op<typename Dst::Scalar,Scalar>());
else
{
dst.setZero();
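Since GeneralProduct.h defines EIGEN_GEMM_TO_COEFFBASED_THRESHOLD inside an #ifndef guard, the crossover point can be tuned at compile time; a sketch (the value 15 is purely illustrative):

// Must be defined before any Eigen header is included; the stock default is 20.
#define EIGEN_GEMM_TO_COEFFBASED_THRESHOLD 15
#include <Eigen/Dense>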
@@ -438,8 +449,8 @@ struct generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,GemmProduct>
template<typename Dst>
static void addTo(Dst& dst, const Lhs& lhs, const Rhs& rhs)
{
- if((rhs.rows()+dst.rows()+dst.cols())<20 && rhs.rows()>0)
- lazyproduct::addTo(dst, lhs, rhs);
+ if((rhs.rows()+dst.rows()+dst.cols())<EIGEN_GEMM_TO_COEFFBASED_THRESHOLD && rhs.rows()>0)
+ lazyproduct::eval_dynamic(dst, lhs, rhs, internal::add_assign_op<typename Dst::Scalar,Scalar>());
else
scaleAndAddTo(dst,lhs, rhs, Scalar(1));
}
@@ -447,8 +458,8 @@ struct generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,GemmProduct>
template<typename Dst>
static void subTo(Dst& dst, const Lhs& lhs, const Rhs& rhs)
{
- if((rhs.rows()+dst.rows()+dst.cols())<20 && rhs.rows()>0)
- lazyproduct::subTo(dst, lhs, rhs);
+ if((rhs.rows()+dst.rows()+dst.cols())<EIGEN_GEMM_TO_COEFFBASED_THRESHOLD && rhs.rows()>0)
+ lazyproduct::eval_dynamic(dst, lhs, rhs, internal::sub_assign_op<typename Dst::Scalar,Scalar>());
else
scaleAndAddTo(dst, lhs, rhs, Scalar(-1));
}
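At the user level, the three entry points rewritten above correspond roughly to the following expressions (illustrative sizes, chosen above the threshold so the GEMM branch is taken):

MatrixXf a = MatrixXf::Random(50,50), b = MatrixXf::Random(50,50), c(50,50);
c.noalias()  = a * b;  // evalTo -> eval_dynamic with assign_op
c.noalias() += a * b;  // addTo  -> eval_dynamic with add_assign_op
c.noalias() -= a * b;  // subTo  -> eval_dynamic with sub_assign_op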
@@ -460,11 +471,25 @@ struct generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,GemmProduct>
if(a_lhs.cols()==0 || a_lhs.rows()==0 || a_rhs.cols()==0)
return;
+ if (dst.cols() == 1)
+ {
+ // Fall back to GEMV: the rhs (and thus dst) is a runtime column vector
+ typename Dest::ColXpr dst_vec(dst.col(0));
+ return internal::generic_product_impl<Lhs,typename Rhs::ConstColXpr,DenseShape,DenseShape,GemvProduct>
+ ::scaleAndAddTo(dst_vec, a_lhs, a_rhs.col(0), alpha);
+ }
+ else if (dst.rows() == 1)
+ {
+ // Fall back to GEMV: the lhs (and thus dst) is a runtime row vector
+ typename Dest::RowXpr dst_vec(dst.row(0));
+ return internal::generic_product_impl<typename Lhs::ConstRowXpr,Rhs,DenseShape,DenseShape,GemvProduct>
+ ::scaleAndAddTo(dst_vec, a_lhs.row(0), a_rhs, alpha);
+ }
+
typename internal::add_const_on_value_type<ActualLhsType>::type lhs = LhsBlasTraits::extract(a_lhs);
typename internal::add_const_on_value_type<ActualRhsType>::type rhs = RhsBlasTraits::extract(a_rhs);
- Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(a_lhs)
- * RhsBlasTraits::extractScalarFactor(a_rhs);
+ Scalar actualAlpha = combine_scalar_factors(alpha, a_lhs, a_rhs);
typedef internal::gemm_blocking_space<(Dest::Flags&RowMajorBit) ? RowMajor : ColMajor,LhsScalar,RhsScalar,
Dest::MaxRowsAtCompileTime,Dest::MaxColsAtCompileTime,MaxDepthAtCompileTime> BlockingType;
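The new branches above mean that a product whose destination degenerates to a single column or row at runtime is dispatched to GEMV instead of the full GEMM pipeline; for instance (illustrative sizes):

MatrixXd A = MatrixXd::Random(100,100);
MatrixXd B = MatrixXd::Random(100,1);  // a dynamic matrix that happens to be one column
MatrixXd C(100,1);
C.noalias() = A * B;  // dst.cols()==1 at runtime -> taken by the GemvProduct fallback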
@@ -475,7 +500,8 @@ struct generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,GemmProduct>
Index,
LhsScalar, (ActualLhsTypeCleaned::Flags&RowMajorBit) ? RowMajor : ColMajor, bool(LhsBlasTraits::NeedToConjugate),
RhsScalar, (ActualRhsTypeCleaned::Flags&RowMajorBit) ? RowMajor : ColMajor, bool(RhsBlasTraits::NeedToConjugate),
- (Dest::Flags&RowMajorBit) ? RowMajor : ColMajor>,
+ (Dest::Flags&RowMajorBit) ? RowMajor : ColMajor,
+ Dest::InnerStrideAtCompileTime>,
ActualLhsTypeCleaned, ActualRhsTypeCleaned, Dest, BlockingType> GemmFunctor;
BlockingType blocking(dst.rows(), dst.cols(), lhs.cols(), 1, true);
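Dest::InnerStrideAtCompileTime is how a compile-time stride reaches the ResInnerStride parameter introduced at the top of this patch; a sketch with illustrative sizes:

MatrixXf storage = MatrixXf::Zero(24, 8);  // backing columns of length 24
Map<MatrixXf, 0, Stride<Dynamic,3> >
    d(storage.data(), 8, 8, Stride<Dynamic,3>(24, 3));
// Dest::InnerStrideAtCompileTime == 3 here, so the GemmFunctor above is built on a
// general_matrix_matrix_product specialization with ResInnerStride == 3; plain
// matrices keep ResInnerStride == 1, leaving the common-case code generation unchanged.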