diff options
Diffstat (limited to 'internal/ceres/cgnr_solver.cc')
-rw-r--r-- | internal/ceres/cgnr_solver.cc | 18 |
1 files changed, 10 insertions, 8 deletions
diff --git a/internal/ceres/cgnr_solver.cc b/internal/ceres/cgnr_solver.cc index 9b8f980..88e61d9 100644 --- a/internal/ceres/cgnr_solver.cc +++ b/internal/ceres/cgnr_solver.cc @@ -33,6 +33,7 @@ #include "ceres/block_jacobi_preconditioner.h" #include "ceres/cgnr_linear_operator.h" #include "ceres/conjugate_gradients_solver.h" +#include "ceres/internal/eigen.h" #include "ceres/linear_solver.h" #include "ceres/wall_time.h" #include "glog/logging.h" @@ -43,6 +44,10 @@ namespace internal { CgnrSolver::CgnrSolver(const LinearSolver::Options& options) : options_(options), preconditioner_(NULL) { + if (options_.preconditioner_type != JACOBI && + options_.preconditioner_type != IDENTITY) { + LOG(FATAL) << "CGNR only supports IDENTITY and JACOBI preconditioners."; + } } LinearSolver::Summary CgnrSolver::SolveImpl( @@ -53,9 +58,9 @@ LinearSolver::Summary CgnrSolver::SolveImpl( EventLogger event_logger("CgnrSolver::Solve"); // Form z = Atb. - scoped_array<double> z(new double[A->num_cols()]); - std::fill(z.get(), z.get() + A->num_cols(), 0.0); - A->LeftMultiply(b, z.get()); + Vector z(A->num_cols()); + z.setZero(); + A->LeftMultiply(b, z.data()); // Precondition if necessary. LinearSolver::PerSolveOptions cg_per_solve_options = per_solve_options; @@ -65,20 +70,17 @@ LinearSolver::Summary CgnrSolver::SolveImpl( } preconditioner_->Update(*A, per_solve_options.D); cg_per_solve_options.preconditioner = preconditioner_.get(); - } else if (options_.preconditioner_type != IDENTITY) { - LOG(FATAL) << "CGNR only supports IDENTITY and JACOBI preconditioners."; } // Solve (AtA + DtD)x = z (= Atb). - std::fill(x, x + A->num_cols(), 0.0); + VectorRef(x, A->num_cols()).setZero(); CgnrLinearOperator lhs(*A, per_solve_options.D); event_logger.AddEvent("Setup"); ConjugateGradientsSolver conjugate_gradient_solver(options_); LinearSolver::Summary summary = - conjugate_gradient_solver.Solve(&lhs, z.get(), cg_per_solve_options, x); + conjugate_gradient_solver.Solve(&lhs, z.data(), cg_per_solve_options, x); event_logger.AddEvent("Solve"); - return summary; } |