diff options
author | Karl Shaffer <karlshaffer@google.com> | 2023-08-10 23:18:51 +0000 |
---|---|---|
committer | Automerger Merge Worker <android-build-automerger-merge-worker@system.gserviceaccount.com> | 2023-08-10 23:18:51 +0000 |
commit | 029d049e490dcd5fa609bb7632b0262d95f1bcce (patch) | |
tree | ace24ba4307d4978ee3134f7da671a77ad172da0 /src/main/java/org/apache/commons/math3/fitting/SimpleCurveFitter.java | |
parent | 4367a1c12f893ea7fb55036619f46d0e7b0634f3 (diff) | |
parent | 5484895ffd3d0c8337d159667cafc127c459f677 (diff) | |
download | apache-commons-math-029d049e490dcd5fa609bb7632b0262d95f1bcce.tar.gz |
Check-in commons-math 3.6.1 am: 1354beaf45 am: 0018f64b87 am: b3715644fb am: 5484895ffd
Original change: https://android-review.googlesource.com/c/platform/external/apache-commons-math/+/2702413
Change-Id: Idb04c8014ec76e9930d6d0aa22dac3b0b54333c8
Signed-off-by: Automerger Merge Worker <android-build-automerger-merge-worker@system.gserviceaccount.com>
Diffstat (limited to 'src/main/java/org/apache/commons/math3/fitting/SimpleCurveFitter.java')
-rw-r--r-- | src/main/java/org/apache/commons/math3/fitting/SimpleCurveFitter.java | 119 |
1 files changed, 119 insertions, 0 deletions
diff --git a/src/main/java/org/apache/commons/math3/fitting/SimpleCurveFitter.java b/src/main/java/org/apache/commons/math3/fitting/SimpleCurveFitter.java new file mode 100644 index 0000000..304f661 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/fitting/SimpleCurveFitter.java @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.math3.fitting; + +import org.apache.commons.math3.analysis.ParametricUnivariateFunction; +import org.apache.commons.math3.fitting.leastsquares.LeastSquaresBuilder; +import org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem; +import org.apache.commons.math3.linear.DiagonalMatrix; + +import java.util.Collection; + +/** + * Fits points to a user-defined {@link ParametricUnivariateFunction function}. + * + * @since 3.4 + */ +public class SimpleCurveFitter extends AbstractCurveFitter { + /** Function to fit. */ + private final ParametricUnivariateFunction function; + + /** Initial guess for the parameters. */ + private final double[] initialGuess; + + /** Maximum number of iterations of the optimization algorithm. */ + private final int maxIter; + + /** + * Contructor used by the factory methods. + * + * @param function Function to fit. + * @param initialGuess Initial guess. Cannot be {@code null}. Its length must be consistent with + * the number of parameters of the {@code function} to fit. + * @param maxIter Maximum number of iterations of the optimization algorithm. + */ + private SimpleCurveFitter( + ParametricUnivariateFunction function, double[] initialGuess, int maxIter) { + this.function = function; + this.initialGuess = initialGuess; + this.maxIter = maxIter; + } + + /** + * Creates a curve fitter. The maximum number of iterations of the optimization algorithm is set + * to {@link Integer#MAX_VALUE}. + * + * @param f Function to fit. + * @param start Initial guess for the parameters. Cannot be {@code null}. Its length must be + * consistent with the number of parameters of the function to fit. + * @return a curve fitter. + * @see #withStartPoint(double[]) + * @see #withMaxIterations(int) + */ + public static SimpleCurveFitter create(ParametricUnivariateFunction f, double[] start) { + return new SimpleCurveFitter(f, start, Integer.MAX_VALUE); + } + + /** + * Configure the start point (initial guess). + * + * @param newStart new start point (initial guess) + * @return a new instance. + */ + public SimpleCurveFitter withStartPoint(double[] newStart) { + return new SimpleCurveFitter(function, newStart.clone(), maxIter); + } + + /** + * Configure the maximum number of iterations. + * + * @param newMaxIter maximum number of iterations + * @return a new instance. + */ + public SimpleCurveFitter withMaxIterations(int newMaxIter) { + return new SimpleCurveFitter(function, initialGuess, newMaxIter); + } + + /** {@inheritDoc} */ + @Override + protected LeastSquaresProblem getProblem(Collection<WeightedObservedPoint> observations) { + // Prepare least-squares problem. + final int len = observations.size(); + final double[] target = new double[len]; + final double[] weights = new double[len]; + + int count = 0; + for (WeightedObservedPoint obs : observations) { + target[count] = obs.getY(); + weights[count] = obs.getWeight(); + ++count; + } + + final AbstractCurveFitter.TheoreticalValuesFunction model = + new AbstractCurveFitter.TheoreticalValuesFunction(function, observations); + + // Create an optimizer for fitting the curve to the observed points. + return new LeastSquaresBuilder() + .maxEvaluations(Integer.MAX_VALUE) + .maxIterations(maxIter) + .start(initialGuess) + .target(target) + .weight(new DiagonalMatrix(weights)) + .model(model.getModelFunction(), model.getModelFunctionJacobian()) + .build(); + } +} |