diff options
Diffstat (limited to 'examples/nist.cc')
-rw-r--r-- | examples/nist.cc | 202 |
1 files changed, 152 insertions, 50 deletions
diff --git a/examples/nist.cc b/examples/nist.cc index 440ab5c..1773a0f 100644 --- a/examples/nist.cc +++ b/examples/nist.cc @@ -28,29 +28,61 @@ // // Author: sameeragarwal@google.com (Sameer Agarwal) // -// NIST non-linear regression problems solved using Ceres. +// The National Institute of Standards and Technology has released a +// set of problems to test non-linear least squares solvers. // -// The data was obtained from -// http://www.itl.nist.gov/div898/strd/nls/nls_main.shtml, where more -// background on these problems can also be found. +// More information about the background on these problems and +// suggested evaluation methodology can be found at: // -// Currently not all problems are solved successfully. Some of the -// failures are due to convergence to a local minimum, and some fail -// because of numerical issues. +// http://www.itl.nist.gov/div898/strd/nls/nls_info.shtml // -// TODO(sameeragarwal): Fix numerical issues so that all the problems -// converge and then look at convergence to the wrong solution issues. +// The problem data themselves can be found at +// +// http://www.itl.nist.gov/div898/strd/nls/nls_main.shtml +// +// The problems are divided into three levels of difficulty, Easy, +// Medium and Hard. For each problem there are two starting guesses, +// the first one far away from the global minimum and the second +// closer to it. +// +// A problem is considered successfully solved, if every components of +// the solution matches the globally optimal solution in at least 4 +// digits or more. +// +// This dataset was used for an evaluation of Non-linear least squares +// solvers: +// +// P. F. Mondragon & B. Borchers, A Comparison of Nonlinear Regression +// Codes, Journal of Modern Applied Statistical Methods, 4(1):343-351, +// 2005. +// +// The results from Mondragon & Borchers can be summarized as +// Excel Gnuplot GaussFit HBN MinPack +// Average LRE 2.3 4.3 4.0 6.8 4.4 +// Winner 1 5 12 29 12 +// +// Where the row Winner counts, the number of problems for which the +// solver had the highest LRE. + +// In this file, we implement the same evaluation methodology using +// Ceres. Currently using Levenberg-Marquard with DENSE_QR, we get +// +// Excel Gnuplot GaussFit HBN MinPack Ceres +// Average LRE 2.3 4.3 4.0 6.8 4.4 9.4 +// Winner 0 0 5 11 2 41 #include <iostream> +#include <iterator> #include <fstream> #include "ceres/ceres.h" -#include "ceres/split.h" #include "gflags/gflags.h" #include "glog/logging.h" #include "Eigen/Core" DEFINE_string(nist_data_dir, "", "Directory containing the NIST non-linear" "regression examples"); +DEFINE_string(minimizer, "trust_region", + "Minimizer type to use, choices are: line_search & trust_region"); DEFINE_string(trust_region_strategy, "levenberg_marquardt", "Options are: levenberg_marquardt, dogleg"); DEFINE_string(dogleg, "traditional_dogleg", @@ -60,21 +92,63 @@ DEFINE_string(linear_solver, "dense_qr", "Options are: " "cgnr"); DEFINE_string(preconditioner, "jacobi", "Options are: " "identity, jacobi"); +DEFINE_string(line_search, "armijo", + "Line search algorithm to use, choices are: armijo and wolfe."); +DEFINE_string(line_search_direction, "lbfgs", + "Line search direction algorithm to use, choices: lbfgs, bfgs"); +DEFINE_int32(max_line_search_iterations, 20, + "Maximum number of iterations for each line search."); +DEFINE_int32(max_line_search_restarts, 10, + "Maximum number of restarts of line search direction algorithm."); +DEFINE_string(line_search_interpolation, "cubic", + "Degree of polynomial aproximation in line search, " + "choices are: bisection, quadratic & cubic."); +DEFINE_int32(lbfgs_rank, 20, + "Rank of L-BFGS inverse Hessian approximation in line search."); +DEFINE_bool(approximate_eigenvalue_bfgs_scaling, false, + "Use approximate eigenvalue scaling in (L)BFGS line search."); +DEFINE_double(sufficient_decrease, 1.0e-4, + "Line search Armijo sufficient (function) decrease factor."); +DEFINE_double(sufficient_curvature_decrease, 0.9, + "Line search Wolfe sufficient curvature decrease factor."); DEFINE_int32(num_iterations, 10000, "Number of iterations"); DEFINE_bool(nonmonotonic_steps, false, "Trust region algorithm can use" " nonmonotic steps"); DEFINE_double(initial_trust_region_radius, 1e4, "Initial trust region radius"); +namespace ceres { +namespace examples { + using Eigen::Dynamic; using Eigen::RowMajor; typedef Eigen::Matrix<double, Dynamic, 1> Vector; typedef Eigen::Matrix<double, Dynamic, Dynamic, RowMajor> Matrix; +void SplitStringUsingChar(const string& full, + const char delim, + vector<string>* result) { + back_insert_iterator< vector<string> > it(*result); + + const char* p = full.data(); + const char* end = p + full.size(); + while (p != end) { + if (*p == delim) { + ++p; + } else { + const char* start = p; + while (++p != end && *p != delim) { + // Skip to the next occurence of the delimiter. + } + *it++ = string(start, p - start); + } + } +} + bool GetAndSplitLine(std::ifstream& ifs, std::vector<std::string>* pieces) { pieces->clear(); char buf[256]; ifs.getline(buf, 256); - ceres::SplitStringUsing(std::string(buf), " ", pieces); + SplitStringUsingChar(std::string(buf), ' ', pieces); return true; } @@ -339,7 +413,7 @@ struct Nelson { template <typename Model, int num_residuals, int num_parameters> int RegressionDriver(const std::string& filename, - const ceres::Solver::Options& options) { + const ceres::Solver::Options& options) { NISTProblem nist_problem(FLAGS_nist_data_dir + filename); CHECK_EQ(num_residuals, nist_problem.response_size()); CHECK_EQ(num_parameters, nist_problem.num_parameters()); @@ -347,11 +421,12 @@ int RegressionDriver(const std::string& filename, Matrix predictor = nist_problem.predictor(); Matrix response = nist_problem.response(); Matrix final_parameters = nist_problem.final_parameters(); - std::vector<ceres::Solver::Summary> summaries(nist_problem.num_starts() + 1); - std::cerr << filename << std::endl; + + printf("%s\n", filename.c_str()); // Each NIST problem comes with multiple starting points, so we // construct the problem from scratch for each case and solve it. + int num_success = 0; for (int start = 0; start < nist_problem.num_starts(); ++start) { Matrix initial_parameters = nist_problem.initial_parameters(start); @@ -365,43 +440,49 @@ int RegressionDriver(const std::string& filename, initial_parameters.data()); } - Solve(options, &problem, &summaries[start]); - } - - const double certified_cost = nist_problem.certified_cost(); - - int num_success = 0; - const int kMinNumMatchingDigits = 4; - for (int start = 0; start < nist_problem.num_starts(); ++start) { - const ceres::Solver::Summary& summary = summaries[start]; - - int num_matching_digits = 0; - if (IsSuccessfulTermination(summary.termination_type) - && summary.final_cost < certified_cost) { - num_matching_digits = kMinNumMatchingDigits + 1; - } else { - num_matching_digits = - -std::log10(fabs(summary.final_cost - certified_cost) / certified_cost); + ceres::Solver::Summary summary; + Solve(options, &problem, &summary); + + // Compute the LRE by comparing each component of the solution + // with the ground truth, and taking the minimum. + Matrix final_parameters = nist_problem.final_parameters(); + const double kMaxNumSignificantDigits = 11; + double log_relative_error = kMaxNumSignificantDigits + 1; + for (int i = 0; i < num_parameters; ++i) { + const double tmp_lre = + -std::log10(std::fabs(final_parameters(i) - initial_parameters(i)) / + std::fabs(final_parameters(i))); + // The maximum LRE is capped at 11 - the precision at which the + // ground truth is known. + // + // The minimum LRE is capped at 0 - no digits match between the + // computed solution and the ground truth. + log_relative_error = + std::min(log_relative_error, + std::max(0.0, std::min(kMaxNumSignificantDigits, tmp_lre))); } - std::cerr << "start " << start + 1 << " " ; - if (num_matching_digits <= kMinNumMatchingDigits) { - std::cerr << "FAILURE"; - } else { - std::cerr << "SUCCESS"; + const int kMinNumMatchingDigits = 4; + if (log_relative_error >= kMinNumMatchingDigits) { ++num_success; } - std::cerr << " summary: " - << summary.BriefReport() - << " Certified cost: " << certified_cost - << std::endl; + printf("start: %d status: %s lre: %4.1f initial cost: %e final cost:%e " + "certified cost: %e total iterations: %d\n", + start + 1, + log_relative_error < kMinNumMatchingDigits ? "FAILURE" : "SUCCESS", + log_relative_error, + summary.initial_cost, + summary.final_cost, + nist_problem.certified_cost(), + (summary.num_successful_steps + summary.num_unsuccessful_steps)); } - return num_success; } void SetMinimizerOptions(ceres::Solver::Options* options) { + CHECK(ceres::StringToMinimizerType(FLAGS_minimizer, + &options->minimizer_type)); CHECK(ceres::StringToLinearSolverType(FLAGS_linear_solver, &options->linear_solver_type)); CHECK(ceres::StringToPreconditionerType(FLAGS_preconditioner, @@ -410,10 +491,28 @@ void SetMinimizerOptions(ceres::Solver::Options* options) { FLAGS_trust_region_strategy, &options->trust_region_strategy_type)); CHECK(ceres::StringToDoglegType(FLAGS_dogleg, &options->dogleg_type)); + CHECK(ceres::StringToLineSearchDirectionType( + FLAGS_line_search_direction, + &options->line_search_direction_type)); + CHECK(ceres::StringToLineSearchType(FLAGS_line_search, + &options->line_search_type)); + CHECK(ceres::StringToLineSearchInterpolationType( + FLAGS_line_search_interpolation, + &options->line_search_interpolation_type)); options->max_num_iterations = FLAGS_num_iterations; options->use_nonmonotonic_steps = FLAGS_nonmonotonic_steps; options->initial_trust_region_radius = FLAGS_initial_trust_region_radius; + options->max_lbfgs_rank = FLAGS_lbfgs_rank; + options->line_search_sufficient_function_decrease = FLAGS_sufficient_decrease; + options->line_search_sufficient_curvature_decrease = + FLAGS_sufficient_curvature_decrease; + options->max_num_line_search_step_size_iterations = + FLAGS_max_line_search_iterations; + options->max_num_line_search_direction_restarts = + FLAGS_max_line_search_restarts; + options->use_approximate_eigenvalue_bfgs_scaling = + FLAGS_approximate_eigenvalue_bfgs_scaling; options->function_tolerance = 1e-18; options->gradient_tolerance = 1e-18; options->parameter_tolerance = 1e-18; @@ -427,7 +526,7 @@ void SolveNISTProblems() { ceres::Solver::Options options; SetMinimizerOptions(&options); - std::cerr << "Lower Difficulty\n"; + std::cout << "Lower Difficulty\n"; int easy_success = 0; easy_success += RegressionDriver<Misra1a, 1, 2>("Misra1a.dat", options); easy_success += RegressionDriver<Chwirut, 1, 3>("Chwirut1.dat", options); @@ -438,7 +537,7 @@ void SolveNISTProblems() { easy_success += RegressionDriver<DanWood, 1, 2>("DanWood.dat", options); easy_success += RegressionDriver<Misra1b, 1, 2>("Misra1b.dat", options); - std::cerr << "\nMedium Difficulty\n"; + std::cout << "\nMedium Difficulty\n"; int medium_success = 0; medium_success += RegressionDriver<Kirby2, 1, 5>("Kirby2.dat", options); medium_success += RegressionDriver<Hahn1, 1, 7>("Hahn1.dat", options); @@ -452,7 +551,7 @@ void SolveNISTProblems() { medium_success += RegressionDriver<Roszman1, 1, 4>("Roszman1.dat", options); medium_success += RegressionDriver<ENSO, 1, 9>("ENSO.dat", options); - std::cerr << "\nHigher Difficulty\n"; + std::cout << "\nHigher Difficulty\n"; int hard_success = 0; hard_success += RegressionDriver<MGH09, 1, 4>("MGH09.dat", options); hard_success += RegressionDriver<Thurber, 1, 7>("Thurber.dat", options); @@ -464,16 +563,19 @@ void SolveNISTProblems() { hard_success += RegressionDriver<Rat43, 1, 4>("Rat43.dat", options); hard_success += RegressionDriver<Bennet5, 1, 3>("Bennett5.dat", options); - std::cerr << "\n"; - std::cerr << "Easy : " << easy_success << "/16\n"; - std::cerr << "Medium : " << medium_success << "/22\n"; - std::cerr << "Hard : " << hard_success << "/16\n"; - std::cerr << "Total : " << easy_success + medium_success + hard_success << "/54\n"; + std::cout << "\n"; + std::cout << "Easy : " << easy_success << "/16\n"; + std::cout << "Medium : " << medium_success << "/22\n"; + std::cout << "Hard : " << hard_success << "/16\n"; + std::cout << "Total : " << easy_success + medium_success + hard_success << "/54\n"; } +} // namespace examples +} // namespace ceres + int main(int argc, char** argv) { google::ParseCommandLineFlags(&argc, &argv, true); google::InitGoogleLogging(argv[0]); - SolveNISTProblems(); + ceres::examples::SolveNISTProblems(); return 0; }; |