diff options
Diffstat (limited to 'Eigen/src/QR/CompleteOrthogonalDecomposition.h')
-rw-r--r-- | Eigen/src/QR/CompleteOrthogonalDecomposition.h | 127 |
1 files changed, 100 insertions, 27 deletions
diff --git a/Eigen/src/QR/CompleteOrthogonalDecomposition.h b/Eigen/src/QR/CompleteOrthogonalDecomposition.h index 34c637b70..486d3373a 100644 --- a/Eigen/src/QR/CompleteOrthogonalDecomposition.h +++ b/Eigen/src/QR/CompleteOrthogonalDecomposition.h @@ -16,6 +16,9 @@ namespace internal { template <typename _MatrixType> struct traits<CompleteOrthogonalDecomposition<_MatrixType> > : traits<_MatrixType> { + typedef MatrixXpr XprKind; + typedef SolverStorage StorageKind; + typedef int StorageIndex; enum { Flags = 0 }; }; @@ -44,19 +47,21 @@ struct traits<CompleteOrthogonalDecomposition<_MatrixType> > * * \sa MatrixBase::completeOrthogonalDecomposition() */ -template <typename _MatrixType> -class CompleteOrthogonalDecomposition { +template <typename _MatrixType> class CompleteOrthogonalDecomposition + : public SolverBase<CompleteOrthogonalDecomposition<_MatrixType> > +{ public: typedef _MatrixType MatrixType; + typedef SolverBase<CompleteOrthogonalDecomposition> Base; + + template<typename Derived> + friend struct internal::solve_assertion; + + EIGEN_GENERIC_PUBLIC_INTERFACE(CompleteOrthogonalDecomposition) enum { - RowsAtCompileTime = MatrixType::RowsAtCompileTime, - ColsAtCompileTime = MatrixType::ColsAtCompileTime, MaxRowsAtCompileTime = MatrixType::MaxRowsAtCompileTime, MaxColsAtCompileTime = MatrixType::MaxColsAtCompileTime }; - typedef typename MatrixType::Scalar Scalar; - typedef typename MatrixType::RealScalar RealScalar; - typedef typename MatrixType::StorageIndex StorageIndex; typedef typename internal::plain_diag_type<MatrixType>::type HCoeffsType; typedef PermutationMatrix<ColsAtCompileTime, MaxColsAtCompileTime> PermutationType; @@ -131,9 +136,9 @@ class CompleteOrthogonalDecomposition { m_temp(matrix.cols()) { computeInPlace(); - } - + } + #ifdef EIGEN_PARSED_BY_DOXYGEN /** This method computes the minimum-norm solution X to a least squares * problem \f[\mathrm{minimize} \|A X - B\|, \f] where \b A is the matrix of * which \c *this is the complete orthogonal decomposition. @@ -145,11 +150,8 @@ class CompleteOrthogonalDecomposition { */ template <typename Rhs> inline const Solve<CompleteOrthogonalDecomposition, Rhs> solve( - const MatrixBase<Rhs>& b) const { - eigen_assert(m_cpqr.m_isInitialized && - "CompleteOrthogonalDecomposition is not initialized."); - return Solve<CompleteOrthogonalDecomposition, Rhs>(*this, b.derived()); - } + const MatrixBase<Rhs>& b) const; + #endif HouseholderSequenceType householderQ(void) const; HouseholderSequenceType matrixQ(void) const { return m_cpqr.householderQ(); } @@ -158,8 +160,8 @@ class CompleteOrthogonalDecomposition { */ MatrixType matrixZ() const { MatrixType Z = MatrixType::Identity(m_cpqr.cols(), m_cpqr.cols()); - applyZAdjointOnTheLeftInPlace(Z); - return Z.adjoint(); + applyZOnTheLeftInPlace<false>(Z); + return Z; } /** \returns a reference to the matrix where the complete orthogonal @@ -275,6 +277,7 @@ class CompleteOrthogonalDecomposition { */ inline const Inverse<CompleteOrthogonalDecomposition> pseudoInverse() const { + eigen_assert(m_cpqr.m_isInitialized && "CompleteOrthogonalDecomposition is not initialized."); return Inverse<CompleteOrthogonalDecomposition>(*this); } @@ -353,7 +356,7 @@ class CompleteOrthogonalDecomposition { inline RealScalar maxPivot() const { return m_cpqr.maxPivot(); } /** \brief Reports whether the complete orthogonal decomposition was - * succesful. + * successful. * * \note This function always returns \c Success. It is provided for * compatibility @@ -367,7 +370,10 @@ class CompleteOrthogonalDecomposition { #ifndef EIGEN_PARSED_BY_DOXYGEN template <typename RhsType, typename DstType> - EIGEN_DEVICE_FUNC void _solve_impl(const RhsType& rhs, DstType& dst) const; + void _solve_impl(const RhsType& rhs, DstType& dst) const; + + template<bool Conjugate, typename RhsType, typename DstType> + void _solve_impl_transposed(const RhsType &rhs, DstType &dst) const; #endif protected: @@ -375,8 +381,22 @@ class CompleteOrthogonalDecomposition { EIGEN_STATIC_ASSERT_NON_INTEGER(Scalar); } + template<bool Transpose_, typename Rhs> + void _check_solve_assertion(const Rhs& b) const { + EIGEN_ONLY_USED_FOR_DEBUG(b); + eigen_assert(m_cpqr.m_isInitialized && "CompleteOrthogonalDecomposition is not initialized."); + eigen_assert((Transpose_?derived().cols():derived().rows())==b.rows() && "CompleteOrthogonalDecomposition::solve(): invalid number of rows of the right hand side matrix b"); + } + void computeInPlace(); + /** Overwrites \b rhs with \f$ \mathbf{Z} * \mathbf{rhs} \f$ or + * \f$ \mathbf{\overline Z} * \mathbf{rhs} \f$ if \c Conjugate + * is set to \c true. + */ + template <bool Conjugate, typename Rhs> + void applyZOnTheLeftInPlace(Rhs& rhs) const; + /** Overwrites \b rhs with \f$ \mathbf{Z}^* * \mathbf{rhs} \f$. */ template <typename Rhs> @@ -452,7 +472,7 @@ void CompleteOrthogonalDecomposition<MatrixType>::computeInPlace() // Apply Z(k) to the first k rows of X_k m_cpqr.m_qr.topRightCorner(k, cols - rank + 1) .applyHouseholderOnTheRight( - m_cpqr.m_qr.row(k).tail(cols - rank).transpose(), m_zCoeffs(k), + m_cpqr.m_qr.row(k).tail(cols - rank).adjoint(), m_zCoeffs(k), &m_temp(0)); } if (k != rank - 1) { @@ -465,13 +485,35 @@ void CompleteOrthogonalDecomposition<MatrixType>::computeInPlace() } template <typename MatrixType> +template <bool Conjugate, typename Rhs> +void CompleteOrthogonalDecomposition<MatrixType>::applyZOnTheLeftInPlace( + Rhs& rhs) const { + const Index cols = this->cols(); + const Index nrhs = rhs.cols(); + const Index rank = this->rank(); + Matrix<typename Rhs::Scalar, Dynamic, 1> temp((std::max)(cols, nrhs)); + for (Index k = rank-1; k >= 0; --k) { + if (k != rank - 1) { + rhs.row(k).swap(rhs.row(rank - 1)); + } + rhs.middleRows(rank - 1, cols - rank + 1) + .applyHouseholderOnTheLeft( + matrixQTZ().row(k).tail(cols - rank).transpose().template conjugateIf<!Conjugate>(), zCoeffs().template conjugateIf<Conjugate>()(k), + &temp(0)); + if (k != rank - 1) { + rhs.row(k).swap(rhs.row(rank - 1)); + } + } +} + +template <typename MatrixType> template <typename Rhs> void CompleteOrthogonalDecomposition<MatrixType>::applyZAdjointOnTheLeftInPlace( Rhs& rhs) const { const Index cols = this->cols(); const Index nrhs = rhs.cols(); const Index rank = this->rank(); - Matrix<typename MatrixType::Scalar, Dynamic, 1> temp((std::max)(cols, nrhs)); + Matrix<typename Rhs::Scalar, Dynamic, 1> temp((std::max)(cols, nrhs)); for (Index k = 0; k < rank; ++k) { if (k != rank - 1) { rhs.row(k).swap(rhs.row(rank - 1)); @@ -491,8 +533,6 @@ template <typename _MatrixType> template <typename RhsType, typename DstType> void CompleteOrthogonalDecomposition<_MatrixType>::_solve_impl( const RhsType& rhs, DstType& dst) const { - eigen_assert(rhs.rows() == this->rows()); - const Index rank = this->rank(); if (rank == 0) { dst.setZero(); @@ -500,11 +540,8 @@ void CompleteOrthogonalDecomposition<_MatrixType>::_solve_impl( } // Compute c = Q^* * rhs - // Note that the matrix Q = H_0^* H_1^*... so its inverse is - // Q^* = (H_0 H_1 ...)^T typename RhsType::PlainObject c(rhs); - c.applyOnTheLeft( - householderSequence(matrixQTZ(), hCoeffs()).setLength(rank).transpose()); + c.applyOnTheLeft(matrixQ().setLength(rank).adjoint()); // Solve T z = c(1:rank, :) dst.topRows(rank) = matrixT() @@ -523,10 +560,45 @@ void CompleteOrthogonalDecomposition<_MatrixType>::_solve_impl( // Undo permutation to get x = P^{-1} * y. dst = colsPermutation() * dst; } + +template<typename _MatrixType> +template<bool Conjugate, typename RhsType, typename DstType> +void CompleteOrthogonalDecomposition<_MatrixType>::_solve_impl_transposed(const RhsType &rhs, DstType &dst) const +{ + const Index rank = this->rank(); + + if (rank == 0) { + dst.setZero(); + return; + } + + typename RhsType::PlainObject c(colsPermutation().transpose()*rhs); + + if (rank < cols()) { + applyZOnTheLeftInPlace<!Conjugate>(c); + } + + matrixT().topLeftCorner(rank, rank) + .template triangularView<Upper>() + .transpose().template conjugateIf<Conjugate>() + .solveInPlace(c.topRows(rank)); + + dst.topRows(rank) = c.topRows(rank); + dst.bottomRows(rows()-rank).setZero(); + + dst.applyOnTheLeft(householderQ().setLength(rank).template conjugateIf<!Conjugate>() ); +} #endif namespace internal { +template<typename MatrixType> +struct traits<Inverse<CompleteOrthogonalDecomposition<MatrixType> > > + : traits<typename Transpose<typename MatrixType::PlainObject>::PlainObject> +{ + enum { Flags = 0 }; +}; + template<typename DstXprType, typename MatrixType> struct Assignment<DstXprType, Inverse<CompleteOrthogonalDecomposition<MatrixType> >, internal::assign_op<typename DstXprType::Scalar,typename CompleteOrthogonalDecomposition<MatrixType>::Scalar>, Dense2Dense> { @@ -534,7 +606,8 @@ struct Assignment<DstXprType, Inverse<CompleteOrthogonalDecomposition<MatrixType typedef Inverse<CodType> SrcXprType; static void run(DstXprType &dst, const SrcXprType &src, const internal::assign_op<typename DstXprType::Scalar,typename CodType::Scalar> &) { - dst = src.nestedExpression().solve(MatrixType::Identity(src.rows(), src.rows())); + typedef Matrix<typename CodType::Scalar, CodType::RowsAtCompileTime, CodType::RowsAtCompileTime, 0, CodType::MaxRowsAtCompileTime, CodType::MaxRowsAtCompileTime> IdentityMatrixType; + dst = src.nestedExpression().solve(IdentityMatrixType::Identity(src.cols(), src.cols())); } }; |