diff options
Diffstat (limited to 'src/main/java/org/apache/commons/math3/analysis/interpolation/LoessInterpolator.java')
-rw-r--r-- | src/main/java/org/apache/commons/math3/analysis/interpolation/LoessInterpolator.java | 473 |
1 files changed, 473 insertions, 0 deletions
diff --git a/src/main/java/org/apache/commons/math3/analysis/interpolation/LoessInterpolator.java b/src/main/java/org/apache/commons/math3/analysis/interpolation/LoessInterpolator.java new file mode 100644 index 0000000..3ffbe93 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/analysis/interpolation/LoessInterpolator.java @@ -0,0 +1,473 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.math3.analysis.interpolation; + +import java.io.Serializable; +import java.util.Arrays; + +import org.apache.commons.math3.analysis.polynomials.PolynomialSplineFunction; +import org.apache.commons.math3.exception.DimensionMismatchException; +import org.apache.commons.math3.exception.NoDataException; +import org.apache.commons.math3.exception.NonMonotonicSequenceException; +import org.apache.commons.math3.exception.NotFiniteNumberException; +import org.apache.commons.math3.exception.NotPositiveException; +import org.apache.commons.math3.exception.NumberIsTooSmallException; +import org.apache.commons.math3.exception.OutOfRangeException; +import org.apache.commons.math3.exception.util.LocalizedFormats; +import org.apache.commons.math3.util.FastMath; +import org.apache.commons.math3.util.MathArrays; +import org.apache.commons.math3.util.MathUtils; + +/** + * Implements the <a href="http://en.wikipedia.org/wiki/Local_regression"> + * Local Regression Algorithm</a> (also Loess, Lowess) for interpolation of + * real univariate functions. + * <p> + * For reference, see + * <a href="http://amstat.tandfonline.com/doi/abs/10.1080/01621459.1979.10481038"> + * William S. Cleveland - Robust Locally Weighted Regression and Smoothing + * Scatterplots</a> + * </p> + * This class implements both the loess method and serves as an interpolation + * adapter to it, allowing one to build a spline on the obtained loess fit. + * + * @since 2.0 + */ +public class LoessInterpolator + implements UnivariateInterpolator, Serializable { + /** Default value of the bandwidth parameter. */ + public static final double DEFAULT_BANDWIDTH = 0.3; + /** Default value of the number of robustness iterations. */ + public static final int DEFAULT_ROBUSTNESS_ITERS = 2; + /** + * Default value for accuracy. + * @since 2.1 + */ + public static final double DEFAULT_ACCURACY = 1e-12; + /** serializable version identifier. */ + private static final long serialVersionUID = 5204927143605193821L; + /** + * The bandwidth parameter: when computing the loess fit at + * a particular point, this fraction of source points closest + * to the current point is taken into account for computing + * a least-squares regression. + * <p> + * A sensible value is usually 0.25 to 0.5.</p> + */ + private final double bandwidth; + /** + * The number of robustness iterations parameter: this many + * robustness iterations are done. + * <p> + * A sensible value is usually 0 (just the initial fit without any + * robustness iterations) to 4.</p> + */ + private final int robustnessIters; + /** + * If the median residual at a certain robustness iteration + * is less than this amount, no more iterations are done. + */ + private final double accuracy; + + /** + * Constructs a new {@link LoessInterpolator} + * with a bandwidth of {@link #DEFAULT_BANDWIDTH}, + * {@link #DEFAULT_ROBUSTNESS_ITERS} robustness iterations + * and an accuracy of {#link #DEFAULT_ACCURACY}. + * See {@link #LoessInterpolator(double, int, double)} for an explanation of + * the parameters. + */ + public LoessInterpolator() { + this.bandwidth = DEFAULT_BANDWIDTH; + this.robustnessIters = DEFAULT_ROBUSTNESS_ITERS; + this.accuracy = DEFAULT_ACCURACY; + } + + /** + * Construct a new {@link LoessInterpolator} + * with given bandwidth and number of robustness iterations. + * <p> + * Calling this constructor is equivalent to calling {link {@link + * #LoessInterpolator(double, int, double) LoessInterpolator(bandwidth, + * robustnessIters, LoessInterpolator.DEFAULT_ACCURACY)} + * </p> + * + * @param bandwidth when computing the loess fit at + * a particular point, this fraction of source points closest + * to the current point is taken into account for computing + * a least-squares regression. + * A sensible value is usually 0.25 to 0.5, the default value is + * {@link #DEFAULT_BANDWIDTH}. + * @param robustnessIters This many robustness iterations are done. + * A sensible value is usually 0 (just the initial fit without any + * robustness iterations) to 4, the default value is + * {@link #DEFAULT_ROBUSTNESS_ITERS}. + + * @see #LoessInterpolator(double, int, double) + */ + public LoessInterpolator(double bandwidth, int robustnessIters) { + this(bandwidth, robustnessIters, DEFAULT_ACCURACY); + } + + /** + * Construct a new {@link LoessInterpolator} + * with given bandwidth, number of robustness iterations and accuracy. + * + * @param bandwidth when computing the loess fit at + * a particular point, this fraction of source points closest + * to the current point is taken into account for computing + * a least-squares regression. + * A sensible value is usually 0.25 to 0.5, the default value is + * {@link #DEFAULT_BANDWIDTH}. + * @param robustnessIters This many robustness iterations are done. + * A sensible value is usually 0 (just the initial fit without any + * robustness iterations) to 4, the default value is + * {@link #DEFAULT_ROBUSTNESS_ITERS}. + * @param accuracy If the median residual at a certain robustness iteration + * is less than this amount, no more iterations are done. + * @throws OutOfRangeException if bandwidth does not lie in the interval [0,1]. + * @throws NotPositiveException if {@code robustnessIters} is negative. + * @see #LoessInterpolator(double, int) + * @since 2.1 + */ + public LoessInterpolator(double bandwidth, int robustnessIters, double accuracy) + throws OutOfRangeException, + NotPositiveException { + if (bandwidth < 0 || + bandwidth > 1) { + throw new OutOfRangeException(LocalizedFormats.BANDWIDTH, bandwidth, 0, 1); + } + this.bandwidth = bandwidth; + if (robustnessIters < 0) { + throw new NotPositiveException(LocalizedFormats.ROBUSTNESS_ITERATIONS, robustnessIters); + } + this.robustnessIters = robustnessIters; + this.accuracy = accuracy; + } + + /** + * Compute an interpolating function by performing a loess fit + * on the data at the original abscissae and then building a cubic spline + * with a + * {@link org.apache.commons.math3.analysis.interpolation.SplineInterpolator} + * on the resulting fit. + * + * @param xval the arguments for the interpolation points + * @param yval the values for the interpolation points + * @return A cubic spline built upon a loess fit to the data at the original abscissae + * @throws NonMonotonicSequenceException if {@code xval} not sorted in + * strictly increasing order. + * @throws DimensionMismatchException if {@code xval} and {@code yval} have + * different sizes. + * @throws NoDataException if {@code xval} or {@code yval} has zero size. + * @throws NotFiniteNumberException if any of the arguments and values are + * not finite real numbers. + * @throws NumberIsTooSmallException if the bandwidth is too small to + * accomodate the size of the input data (i.e. the bandwidth must be + * larger than 2/n). + */ + public final PolynomialSplineFunction interpolate(final double[] xval, + final double[] yval) + throws NonMonotonicSequenceException, + DimensionMismatchException, + NoDataException, + NotFiniteNumberException, + NumberIsTooSmallException { + return new SplineInterpolator().interpolate(xval, smooth(xval, yval)); + } + + /** + * Compute a weighted loess fit on the data at the original abscissae. + * + * @param xval Arguments for the interpolation points. + * @param yval Values for the interpolation points. + * @param weights point weights: coefficients by which the robustness weight + * of a point is multiplied. + * @return the values of the loess fit at corresponding original abscissae. + * @throws NonMonotonicSequenceException if {@code xval} not sorted in + * strictly increasing order. + * @throws DimensionMismatchException if {@code xval} and {@code yval} have + * different sizes. + * @throws NoDataException if {@code xval} or {@code yval} has zero size. + * @throws NotFiniteNumberException if any of the arguments and values are + not finite real numbers. + * @throws NumberIsTooSmallException if the bandwidth is too small to + * accomodate the size of the input data (i.e. the bandwidth must be + * larger than 2/n). + * @since 2.1 + */ + public final double[] smooth(final double[] xval, final double[] yval, + final double[] weights) + throws NonMonotonicSequenceException, + DimensionMismatchException, + NoDataException, + NotFiniteNumberException, + NumberIsTooSmallException { + if (xval.length != yval.length) { + throw new DimensionMismatchException(xval.length, yval.length); + } + + final int n = xval.length; + + if (n == 0) { + throw new NoDataException(); + } + + checkAllFiniteReal(xval); + checkAllFiniteReal(yval); + checkAllFiniteReal(weights); + + MathArrays.checkOrder(xval); + + if (n == 1) { + return new double[]{yval[0]}; + } + + if (n == 2) { + return new double[]{yval[0], yval[1]}; + } + + int bandwidthInPoints = (int) (bandwidth * n); + + if (bandwidthInPoints < 2) { + throw new NumberIsTooSmallException(LocalizedFormats.BANDWIDTH, + bandwidthInPoints, 2, true); + } + + final double[] res = new double[n]; + + final double[] residuals = new double[n]; + final double[] sortedResiduals = new double[n]; + + final double[] robustnessWeights = new double[n]; + + // Do an initial fit and 'robustnessIters' robustness iterations. + // This is equivalent to doing 'robustnessIters+1' robustness iterations + // starting with all robustness weights set to 1. + Arrays.fill(robustnessWeights, 1); + + for (int iter = 0; iter <= robustnessIters; ++iter) { + final int[] bandwidthInterval = {0, bandwidthInPoints - 1}; + // At each x, compute a local weighted linear regression + for (int i = 0; i < n; ++i) { + final double x = xval[i]; + + // Find out the interval of source points on which + // a regression is to be made. + if (i > 0) { + updateBandwidthInterval(xval, weights, i, bandwidthInterval); + } + + final int ileft = bandwidthInterval[0]; + final int iright = bandwidthInterval[1]; + + // Compute the point of the bandwidth interval that is + // farthest from x + final int edge; + if (xval[i] - xval[ileft] > xval[iright] - xval[i]) { + edge = ileft; + } else { + edge = iright; + } + + // Compute a least-squares linear fit weighted by + // the product of robustness weights and the tricube + // weight function. + // See http://en.wikipedia.org/wiki/Linear_regression + // (section "Univariate linear case") + // and http://en.wikipedia.org/wiki/Weighted_least_squares + // (section "Weighted least squares") + double sumWeights = 0; + double sumX = 0; + double sumXSquared = 0; + double sumY = 0; + double sumXY = 0; + double denom = FastMath.abs(1.0 / (xval[edge] - x)); + for (int k = ileft; k <= iright; ++k) { + final double xk = xval[k]; + final double yk = yval[k]; + final double dist = (k < i) ? x - xk : xk - x; + final double w = tricube(dist * denom) * robustnessWeights[k] * weights[k]; + final double xkw = xk * w; + sumWeights += w; + sumX += xkw; + sumXSquared += xk * xkw; + sumY += yk * w; + sumXY += yk * xkw; + } + + final double meanX = sumX / sumWeights; + final double meanY = sumY / sumWeights; + final double meanXY = sumXY / sumWeights; + final double meanXSquared = sumXSquared / sumWeights; + + final double beta; + if (FastMath.sqrt(FastMath.abs(meanXSquared - meanX * meanX)) < accuracy) { + beta = 0; + } else { + beta = (meanXY - meanX * meanY) / (meanXSquared - meanX * meanX); + } + + final double alpha = meanY - beta * meanX; + + res[i] = beta * x + alpha; + residuals[i] = FastMath.abs(yval[i] - res[i]); + } + + // No need to recompute the robustness weights at the last + // iteration, they won't be needed anymore + if (iter == robustnessIters) { + break; + } + + // Recompute the robustness weights. + + // Find the median residual. + // An arraycopy and a sort are completely tractable here, + // because the preceding loop is a lot more expensive + System.arraycopy(residuals, 0, sortedResiduals, 0, n); + Arrays.sort(sortedResiduals); + final double medianResidual = sortedResiduals[n / 2]; + + if (FastMath.abs(medianResidual) < accuracy) { + break; + } + + for (int i = 0; i < n; ++i) { + final double arg = residuals[i] / (6 * medianResidual); + if (arg >= 1) { + robustnessWeights[i] = 0; + } else { + final double w = 1 - arg * arg; + robustnessWeights[i] = w * w; + } + } + } + + return res; + } + + /** + * Compute a loess fit on the data at the original abscissae. + * + * @param xval the arguments for the interpolation points + * @param yval the values for the interpolation points + * @return values of the loess fit at corresponding original abscissae + * @throws NonMonotonicSequenceException if {@code xval} not sorted in + * strictly increasing order. + * @throws DimensionMismatchException if {@code xval} and {@code yval} have + * different sizes. + * @throws NoDataException if {@code xval} or {@code yval} has zero size. + * @throws NotFiniteNumberException if any of the arguments and values are + * not finite real numbers. + * @throws NumberIsTooSmallException if the bandwidth is too small to + * accomodate the size of the input data (i.e. the bandwidth must be + * larger than 2/n). + */ + public final double[] smooth(final double[] xval, final double[] yval) + throws NonMonotonicSequenceException, + DimensionMismatchException, + NoDataException, + NotFiniteNumberException, + NumberIsTooSmallException { + if (xval.length != yval.length) { + throw new DimensionMismatchException(xval.length, yval.length); + } + + final double[] unitWeights = new double[xval.length]; + Arrays.fill(unitWeights, 1.0); + + return smooth(xval, yval, unitWeights); + } + + /** + * Given an index interval into xval that embraces a certain number of + * points closest to {@code xval[i-1]}, update the interval so that it + * embraces the same number of points closest to {@code xval[i]}, + * ignoring zero weights. + * + * @param xval Arguments array. + * @param weights Weights array. + * @param i Index around which the new interval should be computed. + * @param bandwidthInterval a two-element array {left, right} such that: + * {@code (left==0 or xval[i] - xval[left-1] > xval[right] - xval[i])} + * and + * {@code (right==xval.length-1 or xval[right+1] - xval[i] > xval[i] - xval[left])}. + * The array will be updated. + */ + private static void updateBandwidthInterval(final double[] xval, final double[] weights, + final int i, + final int[] bandwidthInterval) { + final int left = bandwidthInterval[0]; + final int right = bandwidthInterval[1]; + + // The right edge should be adjusted if the next point to the right + // is closer to xval[i] than the leftmost point of the current interval + int nextRight = nextNonzero(weights, right); + if (nextRight < xval.length && xval[nextRight] - xval[i] < xval[i] - xval[left]) { + int nextLeft = nextNonzero(weights, bandwidthInterval[0]); + bandwidthInterval[0] = nextLeft; + bandwidthInterval[1] = nextRight; + } + } + + /** + * Return the smallest index {@code j} such that + * {@code j > i && (j == weights.length || weights[j] != 0)}. + * + * @param weights Weights array. + * @param i Index from which to start search. + * @return the smallest compliant index. + */ + private static int nextNonzero(final double[] weights, final int i) { + int j = i + 1; + while(j < weights.length && weights[j] == 0) { + ++j; + } + return j; + } + + /** + * Compute the + * <a href="http://en.wikipedia.org/wiki/Local_regression#Weight_function">tricube</a> + * weight function + * + * @param x Argument. + * @return <code>(1 - |x|<sup>3</sup>)<sup>3</sup></code> for |x| < 1, 0 otherwise. + */ + private static double tricube(final double x) { + final double absX = FastMath.abs(x); + if (absX >= 1.0) { + return 0.0; + } + final double tmp = 1 - absX * absX * absX; + return tmp * tmp * tmp; + } + + /** + * Check that all elements of an array are finite real numbers. + * + * @param values Values array. + * @throws org.apache.commons.math3.exception.NotFiniteNumberException + * if one of the values is not a finite real number. + */ + private static void checkAllFiniteReal(final double[] values) { + for (int i = 0; i < values.length; i++) { + MathUtils.checkFinite(values[i]); + } + } +} |