diff options
Diffstat (limited to 'internal/ceres/dense_normal_cholesky_solver.cc')
-rw-r--r-- | internal/ceres/dense_normal_cholesky_solver.cc | 69 |
1 files changed, 67 insertions, 2 deletions
diff --git a/internal/ceres/dense_normal_cholesky_solver.cc b/internal/ceres/dense_normal_cholesky_solver.cc index 96f5511..fbf3cbe 100644 --- a/internal/ceres/dense_normal_cholesky_solver.cc +++ b/internal/ceres/dense_normal_cholesky_solver.cc @@ -33,9 +33,11 @@ #include <cstddef> #include "Eigen/Dense" +#include "ceres/blas.h" #include "ceres/dense_sparse_matrix.h" #include "ceres/internal/eigen.h" #include "ceres/internal/scoped_ptr.h" +#include "ceres/lapack.h" #include "ceres/linear_solver.h" #include "ceres/types.h" #include "ceres/wall_time.h" @@ -52,6 +54,18 @@ LinearSolver::Summary DenseNormalCholeskySolver::SolveImpl( const double* b, const LinearSolver::PerSolveOptions& per_solve_options, double* x) { + if (options_.dense_linear_algebra_library_type == EIGEN) { + return SolveUsingEigen(A, b, per_solve_options, x); + } else { + return SolveUsingLAPACK(A, b, per_solve_options, x); + } +} + +LinearSolver::Summary DenseNormalCholeskySolver::SolveUsingEigen( + DenseSparseMatrix* A, + const double* b, + const LinearSolver::PerSolveOptions& per_solve_options, + double* x) { EventLogger event_logger("DenseNormalCholeskySolver::Solve"); const int num_rows = A->num_rows(); @@ -62,6 +76,7 @@ LinearSolver::Summary DenseNormalCholeskySolver::SolveImpl( lhs.setZero(); event_logger.AddEvent("Setup"); + // lhs += A'A // // Using rankUpdate instead of GEMM, exposes the fact that its the @@ -76,16 +91,66 @@ LinearSolver::Summary DenseNormalCholeskySolver::SolveImpl( ConstVectorRef D(per_solve_options.D, num_cols); lhs += D.array().square().matrix().asDiagonal(); } + event_logger.AddEvent("Product"); LinearSolver::Summary summary; summary.num_iterations = 1; summary.termination_type = TOLERANCE; VectorRef(x, num_cols) = - lhs.selfadjointView<Eigen::Upper>().ldlt().solve(rhs); + lhs.selfadjointView<Eigen::Upper>().llt().solve(rhs); event_logger.AddEvent("Solve"); - return summary; } +LinearSolver::Summary DenseNormalCholeskySolver::SolveUsingLAPACK( + DenseSparseMatrix* A, + const double* b, + const LinearSolver::PerSolveOptions& per_solve_options, + double* x) { + EventLogger event_logger("DenseNormalCholeskySolver::Solve"); + + if (per_solve_options.D != NULL) { + // Temporarily append a diagonal block to the A matrix, but undo + // it before returning the matrix to the user. + A->AppendDiagonal(per_solve_options.D); + } + + const int num_cols = A->num_cols(); + Matrix lhs(num_cols, num_cols); + event_logger.AddEvent("Setup"); + + // lhs = A'A + // + // Note: This is a bit delicate, it assumes that the stride on this + // matrix is the same as the number of rows. + BLAS::SymmetricRankKUpdate(A->num_rows(), + num_cols, + A->values(), + true, + 1.0, + 0.0, + lhs.data()); + + if (per_solve_options.D != NULL) { + // Undo the modifications to the matrix A. + A->RemoveDiagonal(); + } + + // TODO(sameeragarwal): Replace this with a gemv call for true blasness. + // rhs = A'b + VectorRef(x, num_cols) = + A->matrix().transpose() * ConstVectorRef(b, A->num_rows()); + event_logger.AddEvent("Product"); + + const int info = LAPACK::SolveInPlaceUsingCholesky(num_cols, lhs.data(), x); + event_logger.AddEvent("Solve"); + + LinearSolver::Summary summary; + summary.num_iterations = 1; + summary.termination_type = info == 0 ? TOLERANCE : FAILURE; + + event_logger.AddEvent("TearDown"); + return summary; +} } // namespace internal } // namespace ceres |