diff options
Diffstat (limited to 'internal/ceres/dense_qr_solver.cc')
-rw-r--r-- | internal/ceres/dense_qr_solver.cc | 81 |
1 files changed, 78 insertions, 3 deletions
diff --git a/internal/ceres/dense_qr_solver.cc b/internal/ceres/dense_qr_solver.cc index 1fb9709..d76d58b 100644 --- a/internal/ceres/dense_qr_solver.cc +++ b/internal/ceres/dense_qr_solver.cc @@ -30,12 +30,13 @@ #include "ceres/dense_qr_solver.h" -#include <cstddef> +#include <cstddef> #include "Eigen/Dense" #include "ceres/dense_sparse_matrix.h" #include "ceres/internal/eigen.h" #include "ceres/internal/scoped_ptr.h" +#include "ceres/lapack.h" #include "ceres/linear_solver.h" #include "ceres/types.h" #include "ceres/wall_time.h" @@ -44,13 +45,87 @@ namespace ceres { namespace internal { DenseQRSolver::DenseQRSolver(const LinearSolver::Options& options) - : options_(options) {} + : options_(options) { + work_.resize(1); +} LinearSolver::Summary DenseQRSolver::SolveImpl( DenseSparseMatrix* A, const double* b, const LinearSolver::PerSolveOptions& per_solve_options, double* x) { + if (options_.dense_linear_algebra_library_type == EIGEN) { + return SolveUsingEigen(A, b, per_solve_options, x); + } else { + return SolveUsingLAPACK(A, b, per_solve_options, x); + } +} +LinearSolver::Summary DenseQRSolver::SolveUsingLAPACK( + DenseSparseMatrix* A, + const double* b, + const LinearSolver::PerSolveOptions& per_solve_options, + double* x) { + EventLogger event_logger("DenseQRSolver::Solve"); + + const int num_rows = A->num_rows(); + const int num_cols = A->num_cols(); + + if (per_solve_options.D != NULL) { + // Temporarily append a diagonal block to the A matrix, but undo + // it before returning the matrix to the user. + A->AppendDiagonal(per_solve_options.D); + } + + // TODO(sameeragarwal): Since we are copying anyways, the diagonal + // can be appended to the matrix instead of doing it on A. + lhs_ = A->matrix(); + + if (per_solve_options.D != NULL) { + // Undo the modifications to the matrix A. + A->RemoveDiagonal(); + } + + // rhs = [b;0] to account for the additional rows in the lhs. + if (rhs_.rows() != lhs_.rows()) { + rhs_.resize(lhs_.rows()); + } + rhs_.setZero(); + rhs_.head(num_rows) = ConstVectorRef(b, num_rows); + + if (work_.rows() == 1) { + const int work_size = + LAPACK::EstimateWorkSizeForQR(lhs_.rows(), lhs_.cols()); + VLOG(3) << "Working memory for Dense QR factorization: " + << work_size * sizeof(double); + work_.resize(work_size); + } + + const int info = LAPACK::SolveUsingQR(lhs_.rows(), + lhs_.cols(), + lhs_.data(), + work_.rows(), + work_.data(), + rhs_.data()); + event_logger.AddEvent("Solve"); + + LinearSolver::Summary summary; + summary.num_iterations = 1; + if (info == 0) { + VectorRef(x, num_cols) = rhs_.head(num_cols); + summary.termination_type = TOLERANCE; + } else { + summary.termination_type = FAILURE; + } + + event_logger.AddEvent("TearDown"); + return summary; +} + +LinearSolver::Summary DenseQRSolver::SolveUsingEigen( + DenseSparseMatrix* A, + const double* b, + const LinearSolver::PerSolveOptions& per_solve_options, + double* x) { EventLogger event_logger("DenseQRSolver::Solve"); const int num_rows = A->num_rows(); @@ -73,7 +148,7 @@ LinearSolver::Summary DenseQRSolver::SolveImpl( event_logger.AddEvent("Setup"); // Solve the system. - VectorRef(x, num_cols) = A->matrix().colPivHouseholderQr().solve(rhs_); + VectorRef(x, num_cols) = A->matrix().householderQr().solve(rhs_); event_logger.AddEvent("Solve"); if (per_solve_options.D != NULL) { |