diff options
Diffstat (limited to 'Eigen/src/Core/SolverBase.h')
-rw-r--r-- | Eigen/src/Core/SolverBase.h | 44 |
1 files changed, 41 insertions, 3 deletions
diff --git a/Eigen/src/Core/SolverBase.h b/Eigen/src/Core/SolverBase.h index 8a4adc229..501461042 100644 --- a/Eigen/src/Core/SolverBase.h +++ b/Eigen/src/Core/SolverBase.h @@ -14,8 +14,35 @@ namespace Eigen { namespace internal { +template<typename Derived> +struct solve_assertion { + template<bool Transpose_, typename Rhs> + static void run(const Derived& solver, const Rhs& b) { solver.template _check_solve_assertion<Transpose_>(b); } +}; + +template<typename Derived> +struct solve_assertion<Transpose<Derived> > +{ + typedef Transpose<Derived> type; + + template<bool Transpose_, typename Rhs> + static void run(const type& transpose, const Rhs& b) + { + internal::solve_assertion<typename internal::remove_all<Derived>::type>::template run<true>(transpose.nestedExpression(), b); + } +}; +template<typename Scalar, typename Derived> +struct solve_assertion<CwiseUnaryOp<Eigen::internal::scalar_conjugate_op<Scalar>, const Transpose<Derived> > > +{ + typedef CwiseUnaryOp<Eigen::internal::scalar_conjugate_op<Scalar>, const Transpose<Derived> > type; + template<bool Transpose_, typename Rhs> + static void run(const type& adjoint, const Rhs& b) + { + internal::solve_assertion<typename internal::remove_all<Transpose<Derived> >::type>::template run<true>(adjoint.nestedExpression(), b); + } +}; } // end namespace internal /** \class SolverBase @@ -35,7 +62,7 @@ namespace internal { * * \warning Currently, any other usage of transpose() and adjoint() are not supported and will produce compilation errors. * - * \sa class PartialPivLU, class FullPivLU + * \sa class PartialPivLU, class FullPivLU, class HouseholderQR, class ColPivHouseholderQR, class FullPivHouseholderQR, class CompleteOrthogonalDecomposition, class LLT, class LDLT, class SVDBase */ template<typename Derived> class SolverBase : public EigenBase<Derived> @@ -46,6 +73,9 @@ class SolverBase : public EigenBase<Derived> typedef typename internal::traits<Derived>::Scalar Scalar; typedef Scalar CoeffReturnType; + template<typename Derived_> + friend struct internal::solve_assertion; + enum { RowsAtCompileTime = internal::traits<Derived>::RowsAtCompileTime, ColsAtCompileTime = internal::traits<Derived>::ColsAtCompileTime, @@ -56,7 +86,8 @@ class SolverBase : public EigenBase<Derived> MaxSizeAtCompileTime = (internal::size_at_compile_time<internal::traits<Derived>::MaxRowsAtCompileTime, internal::traits<Derived>::MaxColsAtCompileTime>::ret), IsVectorAtCompileTime = internal::traits<Derived>::MaxRowsAtCompileTime == 1 - || internal::traits<Derived>::MaxColsAtCompileTime == 1 + || internal::traits<Derived>::MaxColsAtCompileTime == 1, + NumDimensions = int(MaxSizeAtCompileTime) == 1 ? 0 : bool(IsVectorAtCompileTime) ? 1 : 2 }; /** Default constructor */ @@ -74,7 +105,7 @@ class SolverBase : public EigenBase<Derived> inline const Solve<Derived, Rhs> solve(const MatrixBase<Rhs>& b) const { - eigen_assert(derived().rows()==b.rows() && "solve(): invalid number of rows of the right hand side matrix b"); + internal::solve_assertion<typename internal::remove_all<Derived>::type>::template run<false>(derived(), b); return Solve<Derived, Rhs>(derived(), b.derived()); } @@ -112,6 +143,13 @@ class SolverBase : public EigenBase<Derived> } protected: + + template<bool Transpose_, typename Rhs> + void _check_solve_assertion(const Rhs& b) const { + EIGEN_ONLY_USED_FOR_DEBUG(b); + eigen_assert(derived().m_isInitialized && "Solver is not initialized."); + eigen_assert((Transpose_?derived().cols():derived().rows())==b.rows() && "SolverBase::solve(): invalid number of rows of the right hand side matrix b"); + } }; namespace internal { |