aboutsummaryrefslogtreecommitdiff
path: root/internal/ceres/dense_normal_cholesky_solver.cc
diff options
context:
space:
mode:
Diffstat (limited to 'internal/ceres/dense_normal_cholesky_solver.cc')
-rw-r--r--internal/ceres/dense_normal_cholesky_solver.cc69
1 files changed, 67 insertions, 2 deletions
diff --git a/internal/ceres/dense_normal_cholesky_solver.cc b/internal/ceres/dense_normal_cholesky_solver.cc
index 96f5511..fbf3cbe 100644
--- a/internal/ceres/dense_normal_cholesky_solver.cc
+++ b/internal/ceres/dense_normal_cholesky_solver.cc
@@ -33,9 +33,11 @@
#include <cstddef>
#include "Eigen/Dense"
+#include "ceres/blas.h"
#include "ceres/dense_sparse_matrix.h"
#include "ceres/internal/eigen.h"
#include "ceres/internal/scoped_ptr.h"
+#include "ceres/lapack.h"
#include "ceres/linear_solver.h"
#include "ceres/types.h"
#include "ceres/wall_time.h"
@@ -52,6 +54,18 @@ LinearSolver::Summary DenseNormalCholeskySolver::SolveImpl(
const double* b,
const LinearSolver::PerSolveOptions& per_solve_options,
double* x) {
+ if (options_.dense_linear_algebra_library_type == EIGEN) {
+ return SolveUsingEigen(A, b, per_solve_options, x);
+ } else {
+ return SolveUsingLAPACK(A, b, per_solve_options, x);
+ }
+}
+
+LinearSolver::Summary DenseNormalCholeskySolver::SolveUsingEigen(
+ DenseSparseMatrix* A,
+ const double* b,
+ const LinearSolver::PerSolveOptions& per_solve_options,
+ double* x) {
EventLogger event_logger("DenseNormalCholeskySolver::Solve");
const int num_rows = A->num_rows();
@@ -62,6 +76,7 @@ LinearSolver::Summary DenseNormalCholeskySolver::SolveImpl(
lhs.setZero();
event_logger.AddEvent("Setup");
+
// lhs += A'A
//
// Using rankUpdate instead of GEMM, exposes the fact that its the
@@ -76,16 +91,66 @@ LinearSolver::Summary DenseNormalCholeskySolver::SolveImpl(
ConstVectorRef D(per_solve_options.D, num_cols);
lhs += D.array().square().matrix().asDiagonal();
}
+ event_logger.AddEvent("Product");
LinearSolver::Summary summary;
summary.num_iterations = 1;
summary.termination_type = TOLERANCE;
VectorRef(x, num_cols) =
- lhs.selfadjointView<Eigen::Upper>().ldlt().solve(rhs);
+ lhs.selfadjointView<Eigen::Upper>().llt().solve(rhs);
event_logger.AddEvent("Solve");
-
return summary;
}
+LinearSolver::Summary DenseNormalCholeskySolver::SolveUsingLAPACK(
+ DenseSparseMatrix* A,
+ const double* b,
+ const LinearSolver::PerSolveOptions& per_solve_options,
+ double* x) {
+ EventLogger event_logger("DenseNormalCholeskySolver::Solve");
+
+ if (per_solve_options.D != NULL) {
+ // Temporarily append a diagonal block to the A matrix, but undo
+ // it before returning the matrix to the user.
+ A->AppendDiagonal(per_solve_options.D);
+ }
+
+ const int num_cols = A->num_cols();
+ Matrix lhs(num_cols, num_cols);
+ event_logger.AddEvent("Setup");
+
+ // lhs = A'A
+ //
+ // Note: This is a bit delicate, it assumes that the stride on this
+ // matrix is the same as the number of rows.
+ BLAS::SymmetricRankKUpdate(A->num_rows(),
+ num_cols,
+ A->values(),
+ true,
+ 1.0,
+ 0.0,
+ lhs.data());
+
+ if (per_solve_options.D != NULL) {
+ // Undo the modifications to the matrix A.
+ A->RemoveDiagonal();
+ }
+
+ // TODO(sameeragarwal): Replace this with a gemv call for true blasness.
+ // rhs = A'b
+ VectorRef(x, num_cols) =
+ A->matrix().transpose() * ConstVectorRef(b, A->num_rows());
+ event_logger.AddEvent("Product");
+
+ const int info = LAPACK::SolveInPlaceUsingCholesky(num_cols, lhs.data(), x);
+ event_logger.AddEvent("Solve");
+
+ LinearSolver::Summary summary;
+ summary.num_iterations = 1;
+ summary.termination_type = info == 0 ? TOLERANCE : FAILURE;
+
+ event_logger.AddEvent("TearDown");
+ return summary;
+}
} // namespace internal
} // namespace ceres