aboutsummaryrefslogtreecommitdiff
path: root/Eigen/src/Core/SolveTriangular.h
diff options
context:
space:
mode:
Diffstat (limited to 'Eigen/src/Core/SolveTriangular.h')
-rw-r--r--Eigen/src/Core/SolveTriangular.h35
1 files changed, 19 insertions, 16 deletions
diff --git a/Eigen/src/Core/SolveTriangular.h b/Eigen/src/Core/SolveTriangular.h
index 049890b25..dfbf99523 100644
--- a/Eigen/src/Core/SolveTriangular.h
+++ b/Eigen/src/Core/SolveTriangular.h
@@ -10,7 +10,7 @@
#ifndef EIGEN_SOLVETRIANGULAR_H
#define EIGEN_SOLVETRIANGULAR_H
-namespace Eigen {
+namespace Eigen {
namespace internal {
@@ -19,7 +19,7 @@ namespace internal {
template<typename LhsScalar, typename RhsScalar, typename Index, int Side, int Mode, bool Conjugate, int StorageOrder>
struct triangular_solve_vector;
-template <typename Scalar, typename Index, int Side, int Mode, bool Conjugate, int TriStorageOrder, int OtherStorageOrder>
+template <typename Scalar, typename Index, int Side, int Mode, bool Conjugate, int TriStorageOrder, int OtherStorageOrder, int OtherInnerStride>
struct triangular_solve_matrix;
// small helper struct extracting some traits on the underlying solver operation
@@ -54,7 +54,7 @@ struct triangular_solver_selector<Lhs,Rhs,Side,Mode,NoUnrolling,1>
typedef blas_traits<Lhs> LhsProductTraits;
typedef typename LhsProductTraits::ExtractType ActualLhsType;
typedef Map<Matrix<RhsScalar,Dynamic,1>, Aligned> MappedRhs;
- static void run(const Lhs& lhs, Rhs& rhs)
+ static EIGEN_DEVICE_FUNC void run(const Lhs& lhs, Rhs& rhs)
{
ActualLhsType actualLhs = LhsProductTraits::extract(lhs);
@@ -64,7 +64,7 @@ struct triangular_solver_selector<Lhs,Rhs,Side,Mode,NoUnrolling,1>
ei_declare_aligned_stack_constructed_variable(RhsScalar,actualRhs,rhs.size(),
(useRhsDirectly ? rhs.data() : 0));
-
+
if(!useRhsDirectly)
MappedRhs(actualRhs,rhs.size()) = rhs;
@@ -85,7 +85,7 @@ struct triangular_solver_selector<Lhs,Rhs,Side,Mode,NoUnrolling,Dynamic>
typedef blas_traits<Lhs> LhsProductTraits;
typedef typename LhsProductTraits::DirectLinearAccessType ActualLhsType;
- static void run(const Lhs& lhs, Rhs& rhs)
+ static EIGEN_DEVICE_FUNC void run(const Lhs& lhs, Rhs& rhs)
{
typename internal::add_const_on_value_type<ActualLhsType>::type actualLhs = LhsProductTraits::extract(lhs);
@@ -98,8 +98,8 @@ struct triangular_solver_selector<Lhs,Rhs,Side,Mode,NoUnrolling,Dynamic>
BlockingType blocking(rhs.rows(), rhs.cols(), size, 1, false);
triangular_solve_matrix<Scalar,Index,Side,Mode,LhsProductTraits::NeedToConjugate,(int(Lhs::Flags) & RowMajorBit) ? RowMajor : ColMajor,
- (Rhs::Flags&RowMajorBit) ? RowMajor : ColMajor>
- ::run(size, othersize, &actualLhs.coeffRef(0,0), actualLhs.outerStride(), &rhs.coeffRef(0,0), rhs.outerStride(), blocking);
+ (Rhs::Flags&RowMajorBit) ? RowMajor : ColMajor, Rhs::InnerStrideAtCompileTime>
+ ::run(size, othersize, &actualLhs.coeffRef(0,0), actualLhs.outerStride(), &rhs.coeffRef(0,0), rhs.innerStride(), rhs.outerStride(), blocking);
}
};
@@ -118,7 +118,7 @@ struct triangular_solver_unroller<Lhs,Rhs,Mode,LoopIndex,Size,false> {
DiagIndex = IsLower ? LoopIndex : Size - LoopIndex - 1,
StartIndex = IsLower ? 0 : DiagIndex+1
};
- static void run(const Lhs& lhs, Rhs& rhs)
+ static EIGEN_DEVICE_FUNC void run(const Lhs& lhs, Rhs& rhs)
{
if (LoopIndex>0)
rhs.coeffRef(DiagIndex) -= lhs.row(DiagIndex).template segment<LoopIndex>(StartIndex).transpose()
@@ -133,22 +133,22 @@ struct triangular_solver_unroller<Lhs,Rhs,Mode,LoopIndex,Size,false> {
template<typename Lhs, typename Rhs, int Mode, int LoopIndex, int Size>
struct triangular_solver_unroller<Lhs,Rhs,Mode,LoopIndex,Size,true> {
- static void run(const Lhs&, Rhs&) {}
+ static EIGEN_DEVICE_FUNC void run(const Lhs&, Rhs&) {}
};
template<typename Lhs, typename Rhs, int Mode>
struct triangular_solver_selector<Lhs,Rhs,OnTheLeft,Mode,CompleteUnrolling,1> {
- static void run(const Lhs& lhs, Rhs& rhs)
+ static EIGEN_DEVICE_FUNC void run(const Lhs& lhs, Rhs& rhs)
{ triangular_solver_unroller<Lhs,Rhs,Mode,0,Rhs::SizeAtCompileTime>::run(lhs,rhs); }
};
template<typename Lhs, typename Rhs, int Mode>
struct triangular_solver_selector<Lhs,Rhs,OnTheRight,Mode,CompleteUnrolling,1> {
- static void run(const Lhs& lhs, Rhs& rhs)
+ static EIGEN_DEVICE_FUNC void run(const Lhs& lhs, Rhs& rhs)
{
Transpose<const Lhs> trLhs(lhs);
Transpose<Rhs> trRhs(rhs);
-
+
triangular_solver_unroller<Transpose<const Lhs>,Transpose<Rhs>,
((Mode&Upper)==Upper ? Lower : Upper) | (Mode&UnitDiag),
0,Rhs::SizeAtCompileTime>::run(trLhs,trRhs);
@@ -164,11 +164,14 @@ struct triangular_solver_selector<Lhs,Rhs,OnTheRight,Mode,CompleteUnrolling,1> {
#ifndef EIGEN_PARSED_BY_DOXYGEN
template<typename MatrixType, unsigned int Mode>
template<int Side, typename OtherDerived>
-void TriangularViewImpl<MatrixType,Mode,Dense>::solveInPlace(const MatrixBase<OtherDerived>& _other) const
+EIGEN_DEVICE_FUNC void TriangularViewImpl<MatrixType,Mode,Dense>::solveInPlace(const MatrixBase<OtherDerived>& _other) const
{
OtherDerived& other = _other.const_cast_derived();
eigen_assert( derived().cols() == derived().rows() && ((Side==OnTheLeft && derived().cols() == other.rows()) || (Side==OnTheRight && derived().cols() == other.cols())) );
- eigen_assert((!(Mode & ZeroDiag)) && bool(Mode & (Upper|Lower)));
+ eigen_assert((!(int(Mode) & int(ZeroDiag))) && bool(int(Mode) & (int(Upper) | int(Lower))));
+ // If solving for a 0x0 matrix, nothing to do, simply return.
+ if (derived().cols() == 0)
+ return;
enum { copy = (internal::traits<OtherDerived>::Flags & RowMajorBit) && OtherDerived::IsVectorAtCompileTime && OtherDerived::SizeAtCompileTime!=1};
typedef typename internal::conditional<copy,
@@ -210,8 +213,8 @@ template<int Side, typename TriangularType, typename Rhs> struct triangular_solv
: m_triangularMatrix(tri), m_rhs(rhs)
{}
- inline Index rows() const { return m_rhs.rows(); }
- inline Index cols() const { return m_rhs.cols(); }
+ inline EIGEN_CONSTEXPR Index rows() const EIGEN_NOEXCEPT { return m_rhs.rows(); }
+ inline EIGEN_CONSTEXPR Index cols() const EIGEN_NOEXCEPT { return m_rhs.cols(); }
template<typename Dest> inline void evalTo(Dest& dst) const
{