diff options
Diffstat (limited to 'src/main/java/org/apache/commons/math3/optimization/fitting')
6 files changed, 1271 insertions, 0 deletions
diff --git a/src/main/java/org/apache/commons/math3/optimization/fitting/CurveFitter.java b/src/main/java/org/apache/commons/math3/optimization/fitting/CurveFitter.java new file mode 100644 index 0000000..26e39f5 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/optimization/fitting/CurveFitter.java @@ -0,0 +1,299 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.commons.math3.optimization.fitting; + +import java.util.ArrayList; +import java.util.List; + +import org.apache.commons.math3.analysis.DifferentiableMultivariateVectorFunction; +import org.apache.commons.math3.analysis.MultivariateMatrixFunction; +import org.apache.commons.math3.analysis.ParametricUnivariateFunction; +import org.apache.commons.math3.analysis.differentiation.DerivativeStructure; +import org.apache.commons.math3.analysis.differentiation.MultivariateDifferentiableVectorFunction; +import org.apache.commons.math3.optimization.DifferentiableMultivariateVectorOptimizer; +import org.apache.commons.math3.optimization.MultivariateDifferentiableVectorOptimizer; +import org.apache.commons.math3.optimization.PointVectorValuePair; + +/** Fitter for parametric univariate real functions y = f(x). 
+ * <br/>
+ * When a univariate real function y = f(x) does depend on some
+ * unknown parameters p<sub>0</sub>, p<sub>1</sub> ... p<sub>n-1</sub>,
+ * this class can be used to find these parameters. It does this
+ * by <em>fitting</em> the curve so it remains very close to a set of
+ * observed points (x<sub>0</sub>, y<sub>0</sub>), (x<sub>1</sub>,
+ * y<sub>1</sub>) ... (x<sub>k-1</sub>, y<sub>k-1</sub>). This fitting
+ * is done by finding the parameter values that minimize the objective
+ * function ∑(y<sub>i</sub>-f(x<sub>i</sub>))<sup>2</sup>. This is
+ * really a least squares problem.
+ * <p>
+ * Each constructor sets exactly one of the two optimizer fields and leaves
+ * the other one {@code null}; the {@code fit} methods dispatch on that.
+ * </p>
+ *
+ * @param <T> Function to use for the fit.
+ *
+ * @deprecated As of 3.1 (to be removed in 4.0).
+ * @since 2.0
+ */
+@Deprecated
+public class CurveFitter<T extends ParametricUnivariateFunction> {
+
+    /** Optimizer to use for the fitting ({@code null} when the
+     * non-deprecated constructor was used).
+     * @deprecated as of 3.1 replaced by {@link #optimizer}
+     */
+    @Deprecated
+    private final DifferentiableMultivariateVectorOptimizer oldOptimizer;
+
+    /** Optimizer to use for the fitting ({@code null} when the
+     * deprecated constructor was used).
+     */
+    private final MultivariateDifferentiableVectorOptimizer optimizer;
+
+    /** Observed points, in insertion order. */
+    private final List<WeightedObservedPoint> observations;
+
+    /** Simple constructor.
+     * @param optimizer optimizer to use for the fitting
+     * @deprecated as of 3.1 replaced by {@link #CurveFitter(MultivariateDifferentiableVectorOptimizer)}
+     */
+    @Deprecated
+    public CurveFitter(final DifferentiableMultivariateVectorOptimizer optimizer) {
+        this.oldOptimizer = optimizer;
+        this.optimizer    = null;
+        observations      = new ArrayList<WeightedObservedPoint>();
+    }
+
+    /** Simple constructor.
+     * @param optimizer optimizer to use for the fitting
+     * @since 3.1
+     */
+    public CurveFitter(final MultivariateDifferentiableVectorOptimizer optimizer) {
+        this.oldOptimizer = null;
+        this.optimizer    = optimizer;
+        observations      = new ArrayList<WeightedObservedPoint>();
+    }
+
+    /** Add an observed (x,y) point to the sample with unit weight.
+     * <p>Calling this method is equivalent to calling
+     * {@code addObservedPoint(1.0, x, y)}.</p>
+     * @param x abscissa of the point
+     * @param y observed value of the point at x, after fitting we should
+     * have f(x) as close as possible to this value
+     * @see #addObservedPoint(double, double, double)
+     * @see #addObservedPoint(WeightedObservedPoint)
+     * @see #getObservations()
+     */
+    public void addObservedPoint(double x, double y) {
+        addObservedPoint(1.0, x, y);
+    }
+
+    /** Add an observed weighted (x,y) point to the sample.
+     * @param weight weight of the observed point in the fit
+     * @param x abscissa of the point
+     * @param y observed value of the point at x, after fitting we should
+     * have f(x) as close as possible to this value
+     * @see #addObservedPoint(double, double)
+     * @see #addObservedPoint(WeightedObservedPoint)
+     * @see #getObservations()
+     */
+    public void addObservedPoint(double weight, double x, double y) {
+        observations.add(new WeightedObservedPoint(weight, x, y));
+    }
+
+    /** Add an observed weighted (x,y) point to the sample.
+     * <p>The instance is stored as given; no copy is made.</p>
+     * @param observed observed point to add
+     * @see #addObservedPoint(double, double)
+     * @see #addObservedPoint(double, double, double)
+     * @see #getObservations()
+     */
+    public void addObservedPoint(WeightedObservedPoint observed) {
+        observations.add(observed);
+    }
+
+    /** Get the observed points.
+     * <p>A new array is allocated on every call; it holds the same point
+     * instances as the internal list, in insertion order.</p>
+     * @return observed points
+     * @see #addObservedPoint(double, double)
+     * @see #addObservedPoint(double, double, double)
+     * @see #addObservedPoint(WeightedObservedPoint)
+     */
+    public WeightedObservedPoint[] getObservations() {
+        return observations.toArray(new WeightedObservedPoint[observations.size()]);
+    }
+
+    /**
+     * Remove all observations.
+     */
+    public void clearObservations() {
+        observations.clear();
+    }
+
+    /**
+     * Fit a curve.
+     * This method computes the coefficients of the curve that best
+     * fit the sample of observed points previously given through calls
+     * to the {@link #addObservedPoint(WeightedObservedPoint)
+     * addObservedPoint} method.
+     * <p>Delegates to the three-argument overload with
+     * {@code Integer.MAX_VALUE} as the evaluation limit.</p>
+     *
+     * @param f parametric function to fit.
+     * @param initialGuess first guess of the function parameters.
+     * @return the fitted parameters.
+     * @throws org.apache.commons.math3.exception.DimensionMismatchException
+     * if the start point dimension is wrong.
+     */
+    public double[] fit(T f, final double[] initialGuess) {
+        return fit(Integer.MAX_VALUE, f, initialGuess);
+    }
+
+    /**
+     * Fit a curve.
+     * This method computes the coefficients of the curve that best
+     * fit the sample of observed points previously given through calls
+     * to the {@link #addObservedPoint(WeightedObservedPoint)
+     * addObservedPoint} method.
+     *
+     * @param maxEval Maximum number of function evaluations.
+     * @param f parametric function to fit.
+     * @param initialGuess first guess of the function parameters.
+     * @return the fitted parameters, as obtained from
+     * {@code PointVectorValuePair.getPointRef()}.
+     * @throws org.apache.commons.math3.exception.TooManyEvaluationsException
+     * if the number of allowed evaluations is exceeded.
+     * @throws org.apache.commons.math3.exception.DimensionMismatchException
+     * if the start point dimension is wrong.
+     * @since 3.0
+     */
+    public double[] fit(int maxEval, T f,
+                        final double[] initialGuess) {
+        // prepare least squares problem: the targets are the observed
+        // ordinates, each one paired with the weight of its observation
+        double[] target  = new double[observations.size()];
+        double[] weights = new double[observations.size()];
+        int i = 0;
+        for (WeightedObservedPoint point : observations) {
+            target[i]  = point.getY();
+            weights[i] = point.getWeight();
+            ++i;
+        }
+
+        // perform the fit with whichever optimizer the constructor stored
+        final PointVectorValuePair optimum;
+        if (optimizer == null) {
+            // to be removed in 4.0
+            optimum = oldOptimizer.optimize(maxEval, new OldTheoreticalValuesFunction(f),
+                                            target, weights, initialGuess);
+        } else {
+            optimum = optimizer.optimize(maxEval, new TheoreticalValuesFunction(f),
+                                         target, weights, initialGuess);
+        }
+
+        // extract the coefficients
+        // NOTE(review): getPointRef() presumably exposes the optimizer's own
+        // array rather than a copy -- confirm before mutating the result.
+        return optimum.getPointRef();
+    }
+
+    /** Vectorial function computing function theoretical values
+     * (adapter for the deprecated optimizer API).
+     */
+    @Deprecated
+    private class OldTheoreticalValuesFunction
+        implements DifferentiableMultivariateVectorFunction {
+        /** Function to fit. */
+        private final ParametricUnivariateFunction f;
+
+        /** Simple constructor.
+         * @param f function to fit.
+         */
+        OldTheoreticalValuesFunction(final ParametricUnivariateFunction f) {
+            this.f = f;
+        }
+
+        /** {@inheritDoc}
+         * <p>Row i of the returned matrix is the gradient of the model with
+         * respect to its parameters, evaluated at the abscissa of
+         * observation i.</p>
+         */
+        public MultivariateMatrixFunction jacobian() {
+            return new MultivariateMatrixFunction() {
+                /** {@inheritDoc} */
+                public double[][] value(double[] point) {
+                    final double[][] jacobian = new double[observations.size()][];
+
+                    int i = 0;
+                    for (WeightedObservedPoint observed : observations) {
+                        jacobian[i++] = f.gradient(observed.getX(), point);
+                    }
+
+                    return jacobian;
+                }
+            };
+        }
+
+        /** {@inheritDoc} */
+        public double[] value(double[] point) {
+            // evaluate the model at every observed abscissa (these are the
+            // theoretical values; no residual is formed in this method)
+            final double[] values = new double[observations.size()];
+            int i = 0;
+            for (WeightedObservedPoint observed : observations) {
+                values[i++] = f.value(observed.getX(), point);
+            }
+
+            return values;
+        }
+    }
+
+    /** Vectorial function computing function theoretical values. 
*/ + private class TheoreticalValuesFunction implements MultivariateDifferentiableVectorFunction { + + /** Function to fit. */ + private final ParametricUnivariateFunction f; + + /** Simple constructor. + * @param f function to fit. + */ + TheoreticalValuesFunction(final ParametricUnivariateFunction f) { + this.f = f; + } + + /** {@inheritDoc} */ + public double[] value(double[] point) { + // compute the residuals + final double[] values = new double[observations.size()]; + int i = 0; + for (WeightedObservedPoint observed : observations) { + values[i++] = f.value(observed.getX(), point); + } + + return values; + } + + /** {@inheritDoc} */ + public DerivativeStructure[] value(DerivativeStructure[] point) { + + // extract parameters + final double[] parameters = new double[point.length]; + for (int k = 0; k < point.length; ++k) { + parameters[k] = point[k].getValue(); + } + + // compute the residuals + final DerivativeStructure[] values = new DerivativeStructure[observations.size()]; + int i = 0; + for (WeightedObservedPoint observed : observations) { + + // build the DerivativeStructure by adding first the value as a constant + // and then adding derivatives + DerivativeStructure vi = new DerivativeStructure(point.length, 1, f.value(observed.getX(), parameters)); + for (int k = 0; k < point.length; ++k) { + vi = vi.add(new DerivativeStructure(point.length, 1, k, 0.0)); + } + + values[i++] = vi; + + } + + return values; + } + + } + +} diff --git a/src/main/java/org/apache/commons/math3/optimization/fitting/GaussianFitter.java b/src/main/java/org/apache/commons/math3/optimization/fitting/GaussianFitter.java new file mode 100644 index 0000000..375f12e --- /dev/null +++ b/src/main/java/org/apache/commons/math3/optimization/fitting/GaussianFitter.java @@ -0,0 +1,371 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.commons.math3.optimization.fitting; + +import java.util.Arrays; +import java.util.Comparator; + +import org.apache.commons.math3.analysis.function.Gaussian; +import org.apache.commons.math3.exception.NullArgumentException; +import org.apache.commons.math3.exception.NumberIsTooSmallException; +import org.apache.commons.math3.exception.OutOfRangeException; +import org.apache.commons.math3.exception.ZeroException; +import org.apache.commons.math3.exception.NotStrictlyPositiveException; +import org.apache.commons.math3.exception.util.LocalizedFormats; +import org.apache.commons.math3.optimization.DifferentiableMultivariateVectorOptimizer; +import org.apache.commons.math3.util.FastMath; + +/** + * Fits points to a {@link + * org.apache.commons.math3.analysis.function.Gaussian.Parametric Gaussian} function. 
+ * <p>
+ * Usage example:
+ * <pre>
+ *   GaussianFitter fitter = new GaussianFitter(
+ *     new LevenbergMarquardtOptimizer());
+ *   fitter.addObservedPoint(4.0254623,  531026.0);
+ *   fitter.addObservedPoint(4.03128248, 984167.0);
+ *   fitter.addObservedPoint(4.03839603, 1887233.0);
+ *   fitter.addObservedPoint(4.04421621, 2687152.0);
+ *   fitter.addObservedPoint(4.05132976, 3461228.0);
+ *   fitter.addObservedPoint(4.05326982, 3580526.0);
+ *   fitter.addObservedPoint(4.05779662, 3439750.0);
+ *   fitter.addObservedPoint(4.0636168,  2877648.0);
+ *   fitter.addObservedPoint(4.06943698, 2175960.0);
+ *   fitter.addObservedPoint(4.07525716, 1447024.0);
+ *   fitter.addObservedPoint(4.08237071, 717104.0);
+ *   fitter.addObservedPoint(4.08366408, 620014.0);
+ *   double[] parameters = fitter.fit();
+ * </pre>
+ *
+ * @since 2.2
+ * @deprecated As of 3.1 (to be removed in 4.0).
+ */
+@Deprecated
+public class GaussianFitter extends CurveFitter<Gaussian.Parametric> {
+    /**
+     * Constructs an instance using the specified optimizer.
+     *
+     * @param optimizer Optimizer to use for the fitting.
+     */
+    public GaussianFitter(DifferentiableMultivariateVectorOptimizer optimizer) {
+        super(optimizer);
+    }
+
+    /**
+     * Fits a Gaussian function to the observed points.
+     * <p>
+     * The model handed to the optimizer is a {@code Gaussian.Parametric}
+     * whose {@code value} and {@code gradient} return positive infinity
+     * instead of propagating a {@code NotStrictlyPositiveException}.
+     * </p>
+     *
+     * @param initialGuess First guess values in the following order:
+     * <ul>
+     *  <li>Norm</li>
+     *  <li>Mean</li>
+     *  <li>Sigma</li>
+     * </ul>
+     * @return the parameters of the Gaussian function that best fits the
+     * observed points (in the same order as above).
+     * @since 3.0
+     */
+    public double[] fit(double[] initialGuess) {
+        final Gaussian.Parametric f = new Gaussian.Parametric() {
+                /** {@inheritDoc} */
+                @Override
+                public double value(double x, double ... p) {
+                    double v = Double.POSITIVE_INFINITY;
+                    try {
+                        v = super.value(x, p);
+                    } catch (NotStrictlyPositiveException e) { // NOPMD
+                        // Do nothing: the infinite default is returned.
+                        // NOTE(review): presumably this lets the optimizer step
+                        // away from an invalid (non-positive) sigma -- confirm.
+                    }
+                    return v;
+                }
+
+                /** {@inheritDoc} */
+                @Override
+                public double[] gradient(double x, double ... p) {
+                    double[] v = { Double.POSITIVE_INFINITY,
+                                   Double.POSITIVE_INFINITY,
+                                   Double.POSITIVE_INFINITY };
+                    try {
+                        v = super.gradient(x, p);
+                    } catch (NotStrictlyPositiveException e) { // NOPMD
+                        // Do nothing: the all-infinite default is returned.
+                    }
+                    return v;
+                }
+            };
+
+        return fit(f, initialGuess);
+    }
+
+    /**
+     * Fits a Gaussian function to the observed points, starting from the
+     * guess computed by {@link ParameterGuesser} on the current observations.
+     *
+     * @return the parameters of the Gaussian function that best fits the
+     * observed points (in the order norm, mean, sigma).
+     */
+    public double[] fit() {
+        final double[] guess = (new ParameterGuesser(getObservations())).guess();
+        return fit(guess);
+    }
+
+    /**
+     * Guesses the parameters {@code norm}, {@code mean}, and {@code sigma}
+     * of a {@link org.apache.commons.math3.analysis.function.Gaussian.Parametric}
+     * based on the specified observed points.
+     * <p>The guess is computed once, in the constructor.</p>
+     */
+    public static class ParameterGuesser {
+        /** Normalization factor. */
+        private final double norm;
+        /** Mean. */
+        private final double mean;
+        /** Standard deviation. */
+        private final double sigma;
+
+        /**
+         * Constructs instance with the specified observed points.
+         *
+         * @param observations Observed points from which to guess the
+         * parameters of the Gaussian.
+         * @throws NullArgumentException if {@code observations} is
+         * {@code null}.
+         * @throws NumberIsTooSmallException if there are less than 3
+         * observations.
+         */
+        public ParameterGuesser(WeightedObservedPoint[] observations) {
+            if (observations == null) {
+                throw new NullArgumentException(LocalizedFormats.INPUT_ARRAY);
+            }
+            if (observations.length < 3) {
+                throw new NumberIsTooSmallException(observations.length, 3, true);
+            }
+
+            // the guess works on a sorted copy; the caller's array is not reordered
+            final WeightedObservedPoint[] sorted = sortObservations(observations);
+            final double[] params = basicGuess(sorted);
+
+            norm  = params[0];
+            mean  = params[1];
+            sigma = params[2];
+        }
+
+        /**
+         * Gets an estimation of the parameters. 
+ * + * @return the guessed parameters, in the following order: + * <ul> + * <li>Normalization factor</li> + * <li>Mean</li> + * <li>Standard deviation</li> + * </ul> + */ + public double[] guess() { + return new double[] { norm, mean, sigma }; + } + + /** + * Sort the observations. + * + * @param unsorted Input observations. + * @return the input observations, sorted. + */ + private WeightedObservedPoint[] sortObservations(WeightedObservedPoint[] unsorted) { + final WeightedObservedPoint[] observations = unsorted.clone(); + final Comparator<WeightedObservedPoint> cmp + = new Comparator<WeightedObservedPoint>() { + /** {@inheritDoc} */ + public int compare(WeightedObservedPoint p1, + WeightedObservedPoint p2) { + if (p1 == null && p2 == null) { + return 0; + } + if (p1 == null) { + return -1; + } + if (p2 == null) { + return 1; + } + final int cmpX = Double.compare(p1.getX(), p2.getX()); + if (cmpX < 0) { + return -1; + } + if (cmpX > 0) { + return 1; + } + final int cmpY = Double.compare(p1.getY(), p2.getY()); + if (cmpY < 0) { + return -1; + } + if (cmpY > 0) { + return 1; + } + final int cmpW = Double.compare(p1.getWeight(), p2.getWeight()); + if (cmpW < 0) { + return -1; + } + if (cmpW > 0) { + return 1; + } + return 0; + } + }; + + Arrays.sort(observations, cmp); + return observations; + } + + /** + * Guesses the parameters based on the specified observed points. + * + * @param points Observed points, sorted. + * @return the guessed parameters (normalization factor, mean and + * sigma). 
+ */ + private double[] basicGuess(WeightedObservedPoint[] points) { + final int maxYIdx = findMaxY(points); + final double n = points[maxYIdx].getY(); + final double m = points[maxYIdx].getX(); + + double fwhmApprox; + try { + final double halfY = n + ((m - n) / 2); + final double fwhmX1 = interpolateXAtY(points, maxYIdx, -1, halfY); + final double fwhmX2 = interpolateXAtY(points, maxYIdx, 1, halfY); + fwhmApprox = fwhmX2 - fwhmX1; + } catch (OutOfRangeException e) { + // TODO: Exceptions should not be used for flow control. + fwhmApprox = points[points.length - 1].getX() - points[0].getX(); + } + final double s = fwhmApprox / (2 * FastMath.sqrt(2 * FastMath.log(2))); + + return new double[] { n, m, s }; + } + + /** + * Finds index of point in specified points with the largest Y. + * + * @param points Points to search. + * @return the index in specified points array. + */ + private int findMaxY(WeightedObservedPoint[] points) { + int maxYIdx = 0; + for (int i = 1; i < points.length; i++) { + if (points[i].getY() > points[maxYIdx].getY()) { + maxYIdx = i; + } + } + return maxYIdx; + } + + /** + * Interpolates using the specified points to determine X at the + * specified Y. + * + * @param points Points to use for interpolation. + * @param startIdx Index within points from which to start the search for + * interpolation bounds points. + * @param idxStep Index step for searching interpolation bounds points. + * @param y Y value for which X should be determined. + * @return the value of X for the specified Y. + * @throws ZeroException if {@code idxStep} is 0. + * @throws OutOfRangeException if specified {@code y} is not within the + * range of the specified {@code points}. 
+         */
+        private double interpolateXAtY(WeightedObservedPoint[] points,
+                                       int startIdx,
+                                       int idxStep,
+                                       double y)
+            throws OutOfRangeException {
+            if (idxStep == 0) {
+                throw new ZeroException();
+            }
+            final WeightedObservedPoint[] twoPoints
+                = getInterpolationPointsForY(points, startIdx, idxStep, y);
+            final WeightedObservedPoint p1 = twoPoints[0];
+            final WeightedObservedPoint p2 = twoPoints[1];
+            // exact hits are returned first; this also covers the case where both
+            // bracketing points have the same ordinate, so the division below
+            // never sees a zero denominator
+            if (p1.getY() == y) {
+                return p1.getX();
+            }
+            if (p2.getY() == y) {
+                return p2.getX();
+            }
+            // linear interpolation between the two bracketing points
+            return p1.getX() + (((y - p1.getY()) * (p2.getX() - p1.getX())) /
+                                (p2.getY() - p1.getY()));
+        }
+
+        /**
+         * Gets the two bounding interpolation points from the specified points
+         * suitable for determining X at the specified Y.
+         * <p>
+         * Consecutive pairs are scanned from {@code startIdx} in the direction
+         * of {@code idxStep}; the first pair whose ordinates bracket {@code y}
+         * is returned, ordered by increasing index in {@code points}.
+         * </p>
+         *
+         * @param points Points to use for interpolation.
+         * @param startIdx Index within points from which to start search for
+         * interpolation bounds points.
+         * @param idxStep Index step for search for interpolation bounds points.
+         * @param y Y value for which X should be determined.
+         * @return the array containing two points suitable for determining X at
+         * the specified Y.
+         * @throws ZeroException if {@code idxStep} is 0.
+         * @throws OutOfRangeException if specified {@code y} is not within the
+         * range of the specified {@code points}.
+         */
+        private WeightedObservedPoint[] getInterpolationPointsForY(WeightedObservedPoint[] points,
+                                                                   int startIdx,
+                                                                   int idxStep,
+                                                                   double y)
+            throws OutOfRangeException {
+            if (idxStep == 0) {
+                throw new ZeroException();
+            }
+            // the loop condition keeps the neighbour index i + idxStep inside the array
+            for (int i = startIdx;
+                 idxStep < 0 ? i + idxStep >= 0 : i + idxStep < points.length;
+                 i += idxStep) {
+                final WeightedObservedPoint p1 = points[i];
+                final WeightedObservedPoint p2 = points[i + idxStep];
+                if (isBetween(y, p1.getY(), p2.getY())) {
+                    if (idxStep < 0) {
+                        return new WeightedObservedPoint[] { p2, p1 };
+                    } else {
+                        return new WeightedObservedPoint[] { p1, p2 };
+                    }
+                }
+            }
+
+            // Boundaries are replaced by dummy values because the raised
+            // exception is caught and the message never displayed.
+            // TODO: Exceptions should not be used for flow control.
+            throw new OutOfRangeException(y,
+                                          Double.NEGATIVE_INFINITY,
+                                          Double.POSITIVE_INFINITY);
+        }
+
+        /**
+         * Determines whether a value is between two other values.
+         * The two boundaries may be given in either order.
+         *
+         * @param value Value to test whether it is between {@code boundary1}
+         * and {@code boundary2}.
+         * @param boundary1 One end of the range.
+         * @param boundary2 Other end of the range.
+         * @return {@code true} if {@code value} is between {@code boundary1} and
+         * {@code boundary2} (inclusive), {@code false} otherwise.
+         */
+        private boolean isBetween(double value,
+                                  double boundary1,
+                                  double boundary2) {
+            return (value >= boundary1 && value <= boundary2) ||
+                (value >= boundary2 && value <= boundary1);
+        }
+    }
+}
diff --git a/src/main/java/org/apache/commons/math3/optimization/fitting/HarmonicFitter.java b/src/main/java/org/apache/commons/math3/optimization/fitting/HarmonicFitter.java
new file mode 100644
index 0000000..85c6d18
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/optimization/fitting/HarmonicFitter.java
@@ -0,0 +1,384 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. 
You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.math3.optimization.fitting;
+
+import org.apache.commons.math3.optimization.DifferentiableMultivariateVectorOptimizer;
+import org.apache.commons.math3.analysis.function.HarmonicOscillator;
+import org.apache.commons.math3.exception.ZeroException;
+import org.apache.commons.math3.exception.NumberIsTooSmallException;
+import org.apache.commons.math3.exception.MathIllegalStateException;
+import org.apache.commons.math3.exception.util.LocalizedFormats;
+import org.apache.commons.math3.util.FastMath;
+
+/**
+ * Class that implements a curve fitting specialized for sinusoids.
+ *
+ * Harmonic fitting is a very simple case of curve fitting. The
+ * estimated coefficients are the amplitude a, the pulsation ω and
+ * the phase φ: <code>f (t) = a cos (ω t + φ)</code>. They are
+ * searched by a least square estimator initialized with a rough guess
+ * based on integrals.
+ *
+ * @deprecated As of 3.1 (to be removed in 4.0).
+ * @since 2.0
+ */
+@Deprecated
+public class HarmonicFitter extends CurveFitter<HarmonicOscillator.Parametric> {
+    /**
+     * Simple constructor.
+     * @param optimizer Optimizer to use for the fitting.
+     */
+    public HarmonicFitter(final DifferentiableMultivariateVectorOptimizer optimizer) {
+        super(optimizer);
+    }
+
+    /**
+     * Fit a harmonic function to the observed points.
+     *
+     * @param initialGuess First guess values in the following order:
+     * <ul>
+     *  <li>Amplitude</li>
+     *  <li>Angular frequency</li>
+     *  <li>Phase</li>
+     * </ul>
+     * @return the parameters of the harmonic function that best fits the
+     * observed points (in the same order as above).
+     */
+    public double[] fit(double[] initialGuess) {
+        return fit(new HarmonicOscillator.Parametric(), initialGuess);
+    }
+
+    /**
+     * Fit a harmonic function to the observed points.
+     * An initial guess will be automatically computed.
+     *
+     * @return the parameters of the harmonic function that best fits the
+     * observed points (see the other {@link #fit(double[]) fit} method).
+     * @throws NumberIsTooSmallException if the sample is too short for
+     * the first guess to be computed.
+     * @throws ZeroException if the first guess cannot be computed because
+     * the abscissa range is zero.
+     */
+    public double[] fit() {
+        return fit((new ParameterGuesser(getObservations())).guess());
+    }
+
+    /**
+     * This class guesses harmonic coefficients from a sample.
+     * <p>The algorithm used to guess the coefficients is as follows:</p>
+     *
+     * <p>We know f (t) at some sampling points t<sub>i</sub> and want to find a,
+     * ω and φ such that f (t) = a cos (ω t + φ). 
+ * </p> + * + * <p>From the analytical expression, we can compute two primitives : + * <pre> + * If2 (t) = ∫ f<sup>2</sup> = a<sup>2</sup> × [t + S (t)] / 2 + * If'2 (t) = ∫ f'<sup>2</sup> = a<sup>2</sup> ω<sup>2</sup> × [t - S (t)] / 2 + * where S (t) = sin (2 (ω t + φ)) / (2 ω) + * </pre> + * </p> + * + * <p>We can eliminate S between these expressions : + * <pre> + * If'2 (t) = a<sup>2</sup> ω<sup>2</sup> t - ω<sup>2</sup> If2 (t) + * </pre> + * </p> + * + * <p>The preceding expression shows that If'2 (t) is a linear + * combination of both t and If2 (t): If'2 (t) = A × t + B × If2 (t) + * </p> + * + * <p>From the primitive, we can deduce the same form for definite + * integrals between t<sub>1</sub> and t<sub>i</sub> for each t<sub>i</sub> : + * <pre> + * If'2 (t<sub>i</sub>) - If'2 (t<sub>1</sub>) = A × (t<sub>i</sub> - t<sub>1</sub>) + B × (If2 (t<sub>i</sub>) - If2 (t<sub>1</sub>)) + * </pre> + * </p> + * + * <p>We can find the coefficients A and B that best fit the sample + * to this linear expression by computing the definite integrals for + * each sample point. 
+ * </p> + * + * <p>For a bilinear expression z (x<sub>i</sub>, y<sub>i</sub>) = A × x<sub>i</sub> + B × y<sub>i</sub>, the + * coefficients A and B that minimize a least square criterion + * ∑ (z<sub>i</sub> - z (x<sub>i</sub>, y<sub>i</sub>))<sup>2</sup> are given by these expressions:</p> + * <pre> + * + * ∑y<sub>i</sub>y<sub>i</sub> ∑x<sub>i</sub>z<sub>i</sub> - ∑x<sub>i</sub>y<sub>i</sub> ∑y<sub>i</sub>z<sub>i</sub> + * A = ------------------------ + * ∑x<sub>i</sub>x<sub>i</sub> ∑y<sub>i</sub>y<sub>i</sub> - ∑x<sub>i</sub>y<sub>i</sub> ∑x<sub>i</sub>y<sub>i</sub> + * + * ∑x<sub>i</sub>x<sub>i</sub> ∑y<sub>i</sub>z<sub>i</sub> - ∑x<sub>i</sub>y<sub>i</sub> ∑x<sub>i</sub>z<sub>i</sub> + * B = ------------------------ + * ∑x<sub>i</sub>x<sub>i</sub> ∑y<sub>i</sub>y<sub>i</sub> - ∑x<sub>i</sub>y<sub>i</sub> ∑x<sub>i</sub>y<sub>i</sub> + * </pre> + * </p> + * + * + * <p>In fact, we can assume both a and ω are positive and + * compute them directly, knowing that A = a<sup>2</sup> ω<sup>2</sup> and that + * B = - ω<sup>2</sup>. 
The complete algorithm is therefore:</p> + * <pre> + * + * for each t<sub>i</sub> from t<sub>1</sub> to t<sub>n-1</sub>, compute: + * f (t<sub>i</sub>) + * f' (t<sub>i</sub>) = (f (t<sub>i+1</sub>) - f(t<sub>i-1</sub>)) / (t<sub>i+1</sub> - t<sub>i-1</sub>) + * x<sub>i</sub> = t<sub>i</sub> - t<sub>1</sub> + * y<sub>i</sub> = ∫ f<sup>2</sup> from t<sub>1</sub> to t<sub>i</sub> + * z<sub>i</sub> = ∫ f'<sup>2</sup> from t<sub>1</sub> to t<sub>i</sub> + * update the sums ∑x<sub>i</sub>x<sub>i</sub>, ∑y<sub>i</sub>y<sub>i</sub>, ∑x<sub>i</sub>y<sub>i</sub>, ∑x<sub>i</sub>z<sub>i</sub> and ∑y<sub>i</sub>z<sub>i</sub> + * end for + * + * |-------------------------- + * \ | ∑y<sub>i</sub>y<sub>i</sub> ∑x<sub>i</sub>z<sub>i</sub> - ∑x<sub>i</sub>y<sub>i</sub> ∑y<sub>i</sub>z<sub>i</sub> + * a = \ | ------------------------ + * \| ∑x<sub>i</sub>y<sub>i</sub> ∑x<sub>i</sub>z<sub>i</sub> - ∑x<sub>i</sub>x<sub>i</sub> ∑y<sub>i</sub>z<sub>i</sub> + * + * + * |-------------------------- + * \ | ∑x<sub>i</sub>y<sub>i</sub> ∑x<sub>i</sub>z<sub>i</sub> - ∑x<sub>i</sub>x<sub>i</sub> ∑y<sub>i</sub>z<sub>i</sub> + * ω = \ | ------------------------ + * \| ∑x<sub>i</sub>x<sub>i</sub> ∑y<sub>i</sub>y<sub>i</sub> - ∑x<sub>i</sub>y<sub>i</sub> ∑x<sub>i</sub>y<sub>i</sub> + * + * </pre> + * </p> + * + * <p>Once we know ω, we can compute: + * <pre> + * fc = ω f (t) cos (ω t) - f' (t) sin (ω t) + * fs = ω f (t) sin (ω t) + f' (t) cos (ω t) + * </pre> + * </p> + * + * <p>It appears that <code>fc = a ω cos (φ)</code> and + * <code>fs = -a ω sin (φ)</code>, so we can use these + * expressions to compute φ. The best estimate over the sample is + * given by averaging these expressions. + * </p> + * + * <p>Since integrals and means are involved in the preceding + * estimations, these operations run in O(n) time, where n is the + * number of measurements.</p> + */ + public static class ParameterGuesser { + /** Amplitude. */ + private final double a; + /** Angular frequency. 
*/
+        private final double omega;
+        /** Phase. */
+        private final double phi;
+
+        /**
+         * Simple constructor.
+         * <p>The three parameters are guessed here, once, from a sorted copy
+         * of the observations.</p>
+         *
+         * @param observations Sampled observations.
+         * @throws NumberIsTooSmallException if the sample is too short
+         * (fewer than 4 points).
+         * @throws ZeroException if the abscissa range is zero.
+         * @throws MathIllegalStateException when the guessing procedure cannot
+         * produce sensible results.
+         */
+        public ParameterGuesser(WeightedObservedPoint[] observations) {
+            if (observations.length < 4) {
+                throw new NumberIsTooSmallException(LocalizedFormats.INSUFFICIENT_OBSERVED_POINTS_IN_SAMPLE,
+                                                    observations.length, 4, true);
+            }
+
+            final WeightedObservedPoint[] sorted = sortObservations(observations);
+
+            final double aOmega[] = guessAOmega(sorted);
+            a = aOmega[0];
+            omega = aOmega[1];
+
+            // must come after omega is assigned: guessPhi reads that field
+            phi = guessPhi(sorted);
+        }
+
+        /**
+         * Gets an estimation of the parameters.
+         *
+         * @return the guessed parameters, in the following order:
+         * <ul>
+         *  <li>Amplitude</li>
+         *  <li>Angular frequency</li>
+         *  <li>Phase</li>
+         * </ul>
+         */
+        public double[] guess() {
+            return new double[] { a, omega, phi };
+        }
+
+        /**
+         * Sort the observations with respect to the abscissa.
+         *
+         * @param unsorted Input observations (left untouched; the sort works
+         * on a clone of this array).
+         * @return the input observations, sorted.
+         */
+        private WeightedObservedPoint[] sortObservations(WeightedObservedPoint[] unsorted) {
+            final WeightedObservedPoint[] observations = unsorted.clone();
+
+            // Since the samples are almost always already sorted, this
+            // method is implemented as an insertion sort that reorders the
+            // elements of the cloned array in place. Insertion sort is very
+            // efficient in this case.
+            WeightedObservedPoint curr = observations[0];
+            for (int j = 1; j < observations.length; ++j) {
+                WeightedObservedPoint prec = curr;
+                curr = observations[j];
+                if (curr.getX() < prec.getX()) {
+                    // the current element should be inserted closer to the beginning
+                    int i = j - 1;
+                    WeightedObservedPoint mI = observations[i];
+                    while ((i >= 0) && (curr.getX() < mI.getX())) {
+                        // shift the larger element one slot towards the end
+                        observations[i + 1] = mI;
+                        // the guard avoids reading observations[-1] once the
+                        // beginning of the array has been reached
+                        if (i-- != 0) {
+                            mI = observations[i];
+                        }
+                    }
+                    observations[i + 1] = curr;
+                    // slot j now holds the largest element seen so far; it is
+                    // the reference for the next comparison
+                    curr = observations[j];
+                }
+            }
+
+            return observations;
+        }
+
+        /**
+         * Estimate a first guess of the amplitude and angular frequency.
+         * This method assumes that the {@link #sortObservations(WeightedObservedPoint[])} method
+         * has been called previously.
+         *
+         * @param observations Observations, sorted w.r.t. abscissa.
+         * @throws ZeroException if the abscissa range is zero.
+         * @throws MathIllegalStateException when the guessing procedure cannot
+         * produce sensible results.
+         * @return the guessed amplitude (at index 0) and circular frequency
+         * (at index 1). 
+         */
+        private double[] guessAOmega(WeightedObservedPoint[] observations) {
+            final double[] aOmega = new double[2];
+
+            // initialize the sums for the linear model between the two integrals
+            double sx2 = 0;
+            double sy2 = 0;
+            double sxy = 0;
+            double sxz = 0;
+            double syz = 0;
+
+            double currentX = observations[0].getX();
+            double currentY = observations[0].getY();
+            double f2Integral = 0;
+            double fPrime2Integral = 0;
+            final double startX = currentX;
+            for (int i = 1; i < observations.length; ++i) {
+                // one step forward (each iteration works on a pair of consecutive points)
+                final double previousX = currentX;
+                final double previousY = currentY;
+                currentX = observations[i].getX();
+                currentY = observations[i].getY();
+
+                // update the integrals of f<sup>2</sup> and f'<sup>2</sup>
+                // considering a linear model for f (and therefore constant f')
+                final double dx = currentX - previousX;
+                final double dy = currentY - previousY;
+                final double f2StepIntegral =
+                    dx * (previousY * previousY + previousY * currentY + currentY * currentY) / 3;
+                final double fPrime2StepIntegral = dy * dy / dx;
+
+                final double x = currentX - startX;
+                f2Integral += f2StepIntegral;
+                fPrime2Integral += fPrime2StepIntegral;
+
+                sx2 += x * x;
+                sy2 += f2Integral * f2Integral;
+                sxy += x * f2Integral;
+                sxz += x * fPrime2Integral;
+                syz += f2Integral * fPrime2Integral;
+            }
+
+            // compute the amplitude and pulsation coefficients
+            double c1 = sy2 * sxz - sxy * syz;
+            double c2 = sxy * sxz - sx2 * syz;
+            double c3 = sx2 * sy2 - sxy * sxy;
+            if ((c1 / c2 < 0) || (c2 / c3 < 0)) {
+                final int last = observations.length - 1;
+                // Fallback when a ratio is negative: guess one period over the abscissa
+                // range (observations are sorted) and half the ordinate range.
+                final double xRange = observations[last].getX() - observations[0].getX();
+                if (xRange == 0) {
+                    throw new ZeroException();
+                }
+                aOmega[1] = 2 * Math.PI / xRange;
+
+                double yMin = Double.POSITIVE_INFINITY;
+                double yMax = Double.NEGATIVE_INFINITY;
+                for (int i = 0; i < observations.length; ++i) { // start at 0: the first sample counts too
+                    final double y = observations[i].getY();
+                    if (y < yMin) {
+                        yMin = y;
+                    }
+                    if (y > yMax) {
+                        yMax = y;
+                    }
+                }
+                aOmega[0] = 0.5 * (yMax - yMin);
+            } else {
+                if (c2 == 0) {
+                    // In some ill-conditioned cases (cf. MATH-844), the guesser
+                    // procedure cannot produce sensible results.
+                    throw new MathIllegalStateException(LocalizedFormats.ZERO_DENOMINATOR);
+                }
+
+                aOmega[0] = FastMath.sqrt(c1 / c2);
+                aOmega[1] = FastMath.sqrt(c2 / c3);
+            }
+
+            return aOmega;
+        }
+
+        /**
+         * Estimate a first guess of the phase, using the value of the {@code omega} field.
+         *
+         * @param observations Observations, sorted w.r.t. abscissa.
+         * @return the guessed phase.
+         */
+        private double guessPhi(WeightedObservedPoint[] observations) {
+            // initialize the means
+            double fcMean = 0;
+            double fsMean = 0;
+
+            double currentX = observations[0].getX();
+            double currentY = observations[0].getY();
+            for (int i = 1; i < observations.length; ++i) {
+                // one step forward
+                final double previousX = currentX;
+                final double previousY = currentY;
+                currentX = observations[i].getX();
+                currentY = observations[i].getY();
+                final double currentYPrime = (currentY - previousY) / (currentX - previousX);
+
+                double omegaX = omega * currentX;
+                double cosine = FastMath.cos(omegaX);
+                double sine = FastMath.sin(omegaX);
+                fcMean += omega * currentY * cosine - currentYPrime * sine;
+                fsMean += omega * currentY * sine + currentYPrime * cosine;
+            }
+
+            return FastMath.atan2(-fsMean, fcMean);
+        }
+    }
+}
diff --git a/src/main/java/org/apache/commons/math3/optimization/fitting/PolynomialFitter.java b/src/main/java/org/apache/commons/math3/optimization/fitting/PolynomialFitter.java
new file mode 100644
index 0000000..dbefcc2
--- /dev/null
+++
b/src/main/java/org/apache/commons/math3/optimization/fitting/PolynomialFitter.java
@@ -0,0 +1,111 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.math3.optimization.fitting;
+
+import org.apache.commons.math3.analysis.polynomials.PolynomialFunction;
+import org.apache.commons.math3.optimization.DifferentiableMultivariateVectorOptimizer;
+
+/**
+ * Polynomial fitting is a very simple case of {@link CurveFitter curve fitting}.
+ * The estimated coefficients are the polynomial coefficients (see the
+ * {@link #fit(double[]) fit} method).
+ *
+ * @deprecated As of 3.1 (to be removed in 4.0).
+ * @since 2.0
+ */
+@Deprecated
+public class PolynomialFitter extends CurveFitter<PolynomialFunction.Parametric> {
+    /** Polynomial degree, only used by the deprecated {@link #fit()} method.
+     * @deprecated
+     */
+    @Deprecated
+    private final int degree;
+
+    /**
+     * Simple constructor.
+     * <p>The polynomial fitters built this way fit complete polynomials,
+     * i.e. an n-degree polynomial has n+1 coefficients.</p>
+     *
+     * @param degree Maximal degree of the polynomial.
+     * @param optimizer Optimizer to use for the fitting.
+     * @deprecated Since 3.1 (to be removed in 4.0). Please use
+     * {@link #PolynomialFitter(DifferentiableMultivariateVectorOptimizer)} instead.
+     */
+    @Deprecated
+    public PolynomialFitter(int degree, final DifferentiableMultivariateVectorOptimizer optimizer) {
+        super(optimizer);
+        this.degree = degree;
+    }
+
+    /**
+     * Simple constructor, for use with the {@code fit} methods that take a first guess.
+     *
+     * @param optimizer Optimizer to use for the fitting.
+     * @since 3.1
+     */
+    public PolynomialFitter(DifferentiableMultivariateVectorOptimizer optimizer) {
+        super(optimizer);
+        degree = -1; // To avoid compilation error until the instance variable is removed.
+    }
+
+    /**
+     * Fit the weighted (x, y) points, starting from an all-zero guess of {@code degree + 1} coefficients.
+     *
+     * @return the coefficients of the polynomial that best fits the observed points.
+     * @throws org.apache.commons.math3.exception.ConvergenceException
+     * if the algorithm failed to converge.
+     * @deprecated Since 3.1 (to be removed in 4.0). Please use {@link #fit(double[])} instead.
+     */
+    @Deprecated
+    public double[] fit() {
+        return fit(new PolynomialFunction.Parametric(), new double[degree + 1]);
+    }
+
+    /**
+     * Get the coefficients of the polynomial fitting the weighted data points.
+     * The degree of the fitting polynomial is {@code guess.length - 1}.
+     *
+     * @param maxEval Maximum number of evaluations of the polynomial.
+     * @param guess First guess for the coefficients. They must be sorted in
+     * increasing order of the polynomial's degree.
+     * @return the coefficients of the polynomial that best fits the observed points.
+     * @throws org.apache.commons.math3.exception.TooManyEvaluationsException if
+     * the number of evaluations exceeds {@code maxEval}.
+     * @throws org.apache.commons.math3.exception.ConvergenceException
+     * if the algorithm failed to converge.
+     * @since 3.1
+     */
+    public double[] fit(int maxEval, double[] guess) {
+        return fit(maxEval, new PolynomialFunction.Parametric(), guess);
+    }
+
+    /**
+     * Get the coefficients of the polynomial fitting the weighted data points.
+     * The degree of the fitting polynomial is {@code guess.length - 1}.
+     *
+     * @param guess First guess for the coefficients. They must be sorted in
+     * increasing order of the polynomial's degree.
+     * @return the coefficients of the polynomial that best fits the observed points.
+     * @throws org.apache.commons.math3.exception.ConvergenceException
+     * if the algorithm failed to converge.
+     * @since 3.1
+     */
+    public double[] fit(double[] guess) {
+        return fit(new PolynomialFunction.Parametric(), guess);
+    }
+}
diff --git a/src/main/java/org/apache/commons/math3/optimization/fitting/WeightedObservedPoint.java b/src/main/java/org/apache/commons/math3/optimization/fitting/WeightedObservedPoint.java
new file mode 100644
index 0000000..899a502
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/optimization/fitting/WeightedObservedPoint.java
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.math3.optimization.fitting;
+
+import java.io.Serializable;
+
+/** This class is a simple container for a weighted observed point in
+ * {@link CurveFitter curve fitting}.
+ * <p>Instances of this class are guaranteed to be immutable.</p>
+ * @deprecated As of 3.1 (to be removed in 4.0).
+ * @since 2.0
+ */
+@Deprecated
+public class WeightedObservedPoint implements Serializable {
+
+    /** Serializable version id. */
+    private static final long serialVersionUID = 5306874947404636157L;
+
+    /** Weight of the measurement in the fitting process. */
+    private final double weight;
+
+    /** Abscissa of the point. */
+    private final double x;
+
+    /** Observed value of the function at x. */
+    private final double y;
+
+    /** Builds a point from its weight, abscissa and ordinate (stored as given).
+     * @param weight weight of the measurement in the fitting process
+     * @param x abscissa of the measurement
+     * @param y ordinate of the measurement
+     */
+    public WeightedObservedPoint(final double weight, final double x, final double y) {
+        this.weight = weight;
+        this.x = x;
+        this.y = y;
+    }
+
+    /** Get the weight of the measurement in the fitting process.
+     * @return weight of the measurement in the fitting process
+     */
+    public double getWeight() {
+        return weight;
+    }
+
+    /** Get the abscissa of the point.
+     * @return abscissa of the point
+     */
+    public double getX() {
+        return x;
+    }
+
+    /** Get the observed value of the function at x.
+     * @return observed value of the function at x
+     */
+    public double getY() {
+        return y;
+    }
+
+}
+
diff --git a/src/main/java/org/apache/commons/math3/optimization/fitting/package-info.java b/src/main/java/org/apache/commons/math3/optimization/fitting/package-info.java
new file mode 100644
index 0000000..b25e5fd
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/optimization/fitting/package-info.java
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.
You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+/**
+ *
+ * This package provides classes to perform curve fitting.
+ *
+ * <p>Curve fitting is a special case of a least squares problem
+ * where the parameters are the coefficients of a function <code>f</code>
+ * whose graph <code>y=f(x)</code> should pass through sample points, and
+ * where the objective function is the sum of the squared residuals
+ * <code>f(x<sub>i</sub>)-y<sub>i</sub></code> for observed points
+ * (x<sub>i</sub>, y<sub>i</sub>).</p>
+ *
+ *
+ */
+package org.apache.commons.math3.optimization.fitting;
|