aboutsummaryrefslogtreecommitdiff
path: root/examples/nist.cc
diff options
context:
space:
mode:
Diffstat (limited to 'examples/nist.cc')
-rw-r--r--examples/nist.cc202
1 files changed, 152 insertions, 50 deletions
diff --git a/examples/nist.cc b/examples/nist.cc
index 440ab5c..1773a0f 100644
--- a/examples/nist.cc
+++ b/examples/nist.cc
@@ -28,29 +28,61 @@
//
// Author: sameeragarwal@google.com (Sameer Agarwal)
//
-// NIST non-linear regression problems solved using Ceres.
+// The National Institute of Standards and Technology has released a
+// set of problems to test non-linear least squares solvers.
//
-// The data was obtained from
-// http://www.itl.nist.gov/div898/strd/nls/nls_main.shtml, where more
-// background on these problems can also be found.
+// More information about the background on these problems and
+// suggested evaluation methodology can be found at:
//
-// Currently not all problems are solved successfully. Some of the
-// failures are due to convergence to a local minimum, and some fail
-// because of numerical issues.
+// http://www.itl.nist.gov/div898/strd/nls/nls_info.shtml
//
-// TODO(sameeragarwal): Fix numerical issues so that all the problems
-// converge and then look at convergence to the wrong solution issues.
+// The problem data themselves can be found at
+//
+// http://www.itl.nist.gov/div898/strd/nls/nls_main.shtml
+//
+// The problems are divided into three levels of difficulty, Easy,
+// Medium and Hard. For each problem there are two starting guesses,
+// the first one far away from the global minimum and the second
+// closer to it.
+//
+// A problem is considered successfully solved, if every components of
+// the solution matches the globally optimal solution in at least 4
+// digits or more.
+//
+// This dataset was used for an evaluation of Non-linear least squares
+// solvers:
+//
+// P. F. Mondragon & B. Borchers, A Comparison of Nonlinear Regression
+// Codes, Journal of Modern Applied Statistical Methods, 4(1):343-351,
+// 2005.
+//
+// The results from Mondragon & Borchers can be summarized as
+// Excel Gnuplot GaussFit HBN MinPack
+// Average LRE 2.3 4.3 4.0 6.8 4.4
+// Winner 1 5 12 29 12
+//
+// Where the row Winner counts, the number of problems for which the
+// solver had the highest LRE.
+
+// In this file, we implement the same evaluation methodology using
+// Ceres. Currently using Levenberg-Marquard with DENSE_QR, we get
+//
+// Excel Gnuplot GaussFit HBN MinPack Ceres
+// Average LRE 2.3 4.3 4.0 6.8 4.4 9.4
+// Winner 0 0 5 11 2 41
#include <iostream>
+#include <iterator>
#include <fstream>
#include "ceres/ceres.h"
-#include "ceres/split.h"
#include "gflags/gflags.h"
#include "glog/logging.h"
#include "Eigen/Core"
DEFINE_string(nist_data_dir, "", "Directory containing the NIST non-linear"
"regression examples");
+DEFINE_string(minimizer, "trust_region",
+ "Minimizer type to use, choices are: line_search & trust_region");
DEFINE_string(trust_region_strategy, "levenberg_marquardt",
"Options are: levenberg_marquardt, dogleg");
DEFINE_string(dogleg, "traditional_dogleg",
@@ -60,21 +92,63 @@ DEFINE_string(linear_solver, "dense_qr", "Options are: "
"cgnr");
DEFINE_string(preconditioner, "jacobi", "Options are: "
"identity, jacobi");
+DEFINE_string(line_search, "armijo",
+ "Line search algorithm to use, choices are: armijo and wolfe.");
+DEFINE_string(line_search_direction, "lbfgs",
+ "Line search direction algorithm to use, choices: lbfgs, bfgs");
+DEFINE_int32(max_line_search_iterations, 20,
+ "Maximum number of iterations for each line search.");
+DEFINE_int32(max_line_search_restarts, 10,
+ "Maximum number of restarts of line search direction algorithm.");
+DEFINE_string(line_search_interpolation, "cubic",
+ "Degree of polynomial aproximation in line search, "
+ "choices are: bisection, quadratic & cubic.");
+DEFINE_int32(lbfgs_rank, 20,
+ "Rank of L-BFGS inverse Hessian approximation in line search.");
+DEFINE_bool(approximate_eigenvalue_bfgs_scaling, false,
+ "Use approximate eigenvalue scaling in (L)BFGS line search.");
+DEFINE_double(sufficient_decrease, 1.0e-4,
+ "Line search Armijo sufficient (function) decrease factor.");
+DEFINE_double(sufficient_curvature_decrease, 0.9,
+ "Line search Wolfe sufficient curvature decrease factor.");
DEFINE_int32(num_iterations, 10000, "Number of iterations");
DEFINE_bool(nonmonotonic_steps, false, "Trust region algorithm can use"
" nonmonotic steps");
DEFINE_double(initial_trust_region_radius, 1e4, "Initial trust region radius");
+namespace ceres {
+namespace examples {
+
using Eigen::Dynamic;
using Eigen::RowMajor;
typedef Eigen::Matrix<double, Dynamic, 1> Vector;
typedef Eigen::Matrix<double, Dynamic, Dynamic, RowMajor> Matrix;
+void SplitStringUsingChar(const string& full,
+ const char delim,
+ vector<string>* result) {
+ back_insert_iterator< vector<string> > it(*result);
+
+ const char* p = full.data();
+ const char* end = p + full.size();
+ while (p != end) {
+ if (*p == delim) {
+ ++p;
+ } else {
+ const char* start = p;
+ while (++p != end && *p != delim) {
+ // Skip to the next occurence of the delimiter.
+ }
+ *it++ = string(start, p - start);
+ }
+ }
+}
+
bool GetAndSplitLine(std::ifstream& ifs, std::vector<std::string>* pieces) {
pieces->clear();
char buf[256];
ifs.getline(buf, 256);
- ceres::SplitStringUsing(std::string(buf), " ", pieces);
+ SplitStringUsingChar(std::string(buf), ' ', pieces);
return true;
}
@@ -339,7 +413,7 @@ struct Nelson {
template <typename Model, int num_residuals, int num_parameters>
int RegressionDriver(const std::string& filename,
- const ceres::Solver::Options& options) {
+ const ceres::Solver::Options& options) {
NISTProblem nist_problem(FLAGS_nist_data_dir + filename);
CHECK_EQ(num_residuals, nist_problem.response_size());
CHECK_EQ(num_parameters, nist_problem.num_parameters());
@@ -347,11 +421,12 @@ int RegressionDriver(const std::string& filename,
Matrix predictor = nist_problem.predictor();
Matrix response = nist_problem.response();
Matrix final_parameters = nist_problem.final_parameters();
- std::vector<ceres::Solver::Summary> summaries(nist_problem.num_starts() + 1);
- std::cerr << filename << std::endl;
+
+ printf("%s\n", filename.c_str());
// Each NIST problem comes with multiple starting points, so we
// construct the problem from scratch for each case and solve it.
+ int num_success = 0;
for (int start = 0; start < nist_problem.num_starts(); ++start) {
Matrix initial_parameters = nist_problem.initial_parameters(start);
@@ -365,43 +440,49 @@ int RegressionDriver(const std::string& filename,
initial_parameters.data());
}
- Solve(options, &problem, &summaries[start]);
- }
-
- const double certified_cost = nist_problem.certified_cost();
-
- int num_success = 0;
- const int kMinNumMatchingDigits = 4;
- for (int start = 0; start < nist_problem.num_starts(); ++start) {
- const ceres::Solver::Summary& summary = summaries[start];
-
- int num_matching_digits = 0;
- if (IsSuccessfulTermination(summary.termination_type)
- && summary.final_cost < certified_cost) {
- num_matching_digits = kMinNumMatchingDigits + 1;
- } else {
- num_matching_digits =
- -std::log10(fabs(summary.final_cost - certified_cost) / certified_cost);
+ ceres::Solver::Summary summary;
+ Solve(options, &problem, &summary);
+
+ // Compute the LRE by comparing each component of the solution
+ // with the ground truth, and taking the minimum.
+ Matrix final_parameters = nist_problem.final_parameters();
+ const double kMaxNumSignificantDigits = 11;
+ double log_relative_error = kMaxNumSignificantDigits + 1;
+ for (int i = 0; i < num_parameters; ++i) {
+ const double tmp_lre =
+ -std::log10(std::fabs(final_parameters(i) - initial_parameters(i)) /
+ std::fabs(final_parameters(i)));
+ // The maximum LRE is capped at 11 - the precision at which the
+ // ground truth is known.
+ //
+ // The minimum LRE is capped at 0 - no digits match between the
+ // computed solution and the ground truth.
+ log_relative_error =
+ std::min(log_relative_error,
+ std::max(0.0, std::min(kMaxNumSignificantDigits, tmp_lre)));
}
- std::cerr << "start " << start + 1 << " " ;
- if (num_matching_digits <= kMinNumMatchingDigits) {
- std::cerr << "FAILURE";
- } else {
- std::cerr << "SUCCESS";
+ const int kMinNumMatchingDigits = 4;
+ if (log_relative_error >= kMinNumMatchingDigits) {
++num_success;
}
- std::cerr << " summary: "
- << summary.BriefReport()
- << " Certified cost: " << certified_cost
- << std::endl;
+ printf("start: %d status: %s lre: %4.1f initial cost: %e final cost:%e "
+ "certified cost: %e total iterations: %d\n",
+ start + 1,
+ log_relative_error < kMinNumMatchingDigits ? "FAILURE" : "SUCCESS",
+ log_relative_error,
+ summary.initial_cost,
+ summary.final_cost,
+ nist_problem.certified_cost(),
+ (summary.num_successful_steps + summary.num_unsuccessful_steps));
}
-
return num_success;
}
void SetMinimizerOptions(ceres::Solver::Options* options) {
+ CHECK(ceres::StringToMinimizerType(FLAGS_minimizer,
+ &options->minimizer_type));
CHECK(ceres::StringToLinearSolverType(FLAGS_linear_solver,
&options->linear_solver_type));
CHECK(ceres::StringToPreconditionerType(FLAGS_preconditioner,
@@ -410,10 +491,28 @@ void SetMinimizerOptions(ceres::Solver::Options* options) {
FLAGS_trust_region_strategy,
&options->trust_region_strategy_type));
CHECK(ceres::StringToDoglegType(FLAGS_dogleg, &options->dogleg_type));
+ CHECK(ceres::StringToLineSearchDirectionType(
+ FLAGS_line_search_direction,
+ &options->line_search_direction_type));
+ CHECK(ceres::StringToLineSearchType(FLAGS_line_search,
+ &options->line_search_type));
+ CHECK(ceres::StringToLineSearchInterpolationType(
+ FLAGS_line_search_interpolation,
+ &options->line_search_interpolation_type));
options->max_num_iterations = FLAGS_num_iterations;
options->use_nonmonotonic_steps = FLAGS_nonmonotonic_steps;
options->initial_trust_region_radius = FLAGS_initial_trust_region_radius;
+ options->max_lbfgs_rank = FLAGS_lbfgs_rank;
+ options->line_search_sufficient_function_decrease = FLAGS_sufficient_decrease;
+ options->line_search_sufficient_curvature_decrease =
+ FLAGS_sufficient_curvature_decrease;
+ options->max_num_line_search_step_size_iterations =
+ FLAGS_max_line_search_iterations;
+ options->max_num_line_search_direction_restarts =
+ FLAGS_max_line_search_restarts;
+ options->use_approximate_eigenvalue_bfgs_scaling =
+ FLAGS_approximate_eigenvalue_bfgs_scaling;
options->function_tolerance = 1e-18;
options->gradient_tolerance = 1e-18;
options->parameter_tolerance = 1e-18;
@@ -427,7 +526,7 @@ void SolveNISTProblems() {
ceres::Solver::Options options;
SetMinimizerOptions(&options);
- std::cerr << "Lower Difficulty\n";
+ std::cout << "Lower Difficulty\n";
int easy_success = 0;
easy_success += RegressionDriver<Misra1a, 1, 2>("Misra1a.dat", options);
easy_success += RegressionDriver<Chwirut, 1, 3>("Chwirut1.dat", options);
@@ -438,7 +537,7 @@ void SolveNISTProblems() {
easy_success += RegressionDriver<DanWood, 1, 2>("DanWood.dat", options);
easy_success += RegressionDriver<Misra1b, 1, 2>("Misra1b.dat", options);
- std::cerr << "\nMedium Difficulty\n";
+ std::cout << "\nMedium Difficulty\n";
int medium_success = 0;
medium_success += RegressionDriver<Kirby2, 1, 5>("Kirby2.dat", options);
medium_success += RegressionDriver<Hahn1, 1, 7>("Hahn1.dat", options);
@@ -452,7 +551,7 @@ void SolveNISTProblems() {
medium_success += RegressionDriver<Roszman1, 1, 4>("Roszman1.dat", options);
medium_success += RegressionDriver<ENSO, 1, 9>("ENSO.dat", options);
- std::cerr << "\nHigher Difficulty\n";
+ std::cout << "\nHigher Difficulty\n";
int hard_success = 0;
hard_success += RegressionDriver<MGH09, 1, 4>("MGH09.dat", options);
hard_success += RegressionDriver<Thurber, 1, 7>("Thurber.dat", options);
@@ -464,16 +563,19 @@ void SolveNISTProblems() {
hard_success += RegressionDriver<Rat43, 1, 4>("Rat43.dat", options);
hard_success += RegressionDriver<Bennet5, 1, 3>("Bennett5.dat", options);
- std::cerr << "\n";
- std::cerr << "Easy : " << easy_success << "/16\n";
- std::cerr << "Medium : " << medium_success << "/22\n";
- std::cerr << "Hard : " << hard_success << "/16\n";
- std::cerr << "Total : " << easy_success + medium_success + hard_success << "/54\n";
+ std::cout << "\n";
+ std::cout << "Easy : " << easy_success << "/16\n";
+ std::cout << "Medium : " << medium_success << "/22\n";
+ std::cout << "Hard : " << hard_success << "/16\n";
+ std::cout << "Total : " << easy_success + medium_success + hard_success << "/54\n";
}
+} // namespace examples
+} // namespace ceres
+
int main(int argc, char** argv) {
google::ParseCommandLineFlags(&argc, &argv, true);
google::InitGoogleLogging(argv[0]);
- SolveNISTProblems();
+ ceres::examples::SolveNISTProblems();
return 0;
};