Diffstat (limited to 'src/main/java/org/apache/commons/math/stat/regression/OLSMultipleLinearRegression.java')
-rw-r--r-- | src/main/java/org/apache/commons/math/stat/regression/OLSMultipleLinearRegression.java | 233
1 file changed, 233 insertions(+), 0 deletions(-)
diff --git a/src/main/java/org/apache/commons/math/stat/regression/OLSMultipleLinearRegression.java b/src/main/java/org/apache/commons/math/stat/regression/OLSMultipleLinearRegression.java
new file mode 100644
index 0000000..22a59e8
--- /dev/null
+++ b/src/main/java/org/apache/commons/math/stat/regression/OLSMultipleLinearRegression.java
@@ -0,0 +1,233 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.commons.math.stat.regression;
+
+import org.apache.commons.math.linear.Array2DRowRealMatrix;
+import org.apache.commons.math.linear.LUDecompositionImpl;
+import org.apache.commons.math.linear.QRDecomposition;
+import org.apache.commons.math.linear.QRDecompositionImpl;
+import org.apache.commons.math.linear.RealMatrix;
+import org.apache.commons.math.linear.RealVector;
+import org.apache.commons.math.stat.StatUtils;
+import org.apache.commons.math.stat.descriptive.moment.SecondMoment;
+
+/**
+ * <p>Implements ordinary least squares (OLS) to estimate the parameters of a
+ * multiple linear regression model.</p>
+ *
+ * <p>The regression coefficients, <code>b</code>, satisfy the normal equations:
+ * <pre><code> X<sup>T</sup> X b = X<sup>T</sup> y </code></pre></p>
+ *
+ * <p>To solve the normal equations, this implementation uses QR decomposition
+ * of the <code>X</code> matrix. (See {@link QRDecompositionImpl} for details on the
+ * decomposition algorithm.) The <code>X</code> matrix, also known as the <i>design matrix,</i>
+ * has rows corresponding to sample observations and columns corresponding to independent
+ * variables. When the model is estimated using an intercept term (i.e. when
+ * {@link #isNoIntercept() isNoIntercept} is false, as it is by default), the <code>X</code>
+ * matrix includes an initial column identically equal to 1. We solve the normal equations
+ * as follows:
+ * <pre><code> X<sup>T</sup>X b = X<sup>T</sup> y
+ * (QR)<sup>T</sup> (QR) b = (QR)<sup>T</sup> y
+ * R<sup>T</sup> (Q<sup>T</sup>Q) R b = R<sup>T</sup> Q<sup>T</sup> y
+ * R<sup>T</sup> R b = R<sup>T</sup> Q<sup>T</sup> y
+ * (R<sup>T</sup>)<sup>-1</sup> R<sup>T</sup> R b = (R<sup>T</sup>)<sup>-1</sup> R<sup>T</sup> Q<sup>T</sup> y
+ * R b = Q<sup>T</sup> y </code></pre></p>
+ *
+ * <p>Given <code>Q</code> and <code>R</code>, the last equation is solved by back-substitution.</p>
+ *
+ * @version $Revision: 1073464 $ $Date: 2011-02-22 20:35:02 +0100 (Tue, 22 Feb 2011) $
+ * @since 2.0
+ */
+public class OLSMultipleLinearRegression extends AbstractMultipleLinearRegression {
+
+    /** Cached QR decomposition of X matrix */
+    private QRDecomposition qr = null;
+
+    /**
+     * Loads model x and y sample data, overriding any previous sample.
+     * Computes and caches the QR decomposition of the X matrix.
+     *
+     * @param y the [n,1] array representing the y sample
+     * @param x the [n,k] array representing the x sample
+     * @throws IllegalArgumentException if the x and y array data are not
+     *         compatible for the regression
+     */
+    public void newSampleData(double[] y, double[][] x) {
+        validateSampleData(x, y);
+        newYSampleData(y);
+        newXSampleData(x);
+    }
+
+    /**
+     * {@inheritDoc}
+     * <p>This implementation computes and caches the QR decomposition of the X matrix.</p>
+     */
+    @Override
+    public void newSampleData(double[] data, int nobs, int nvars) {
+        super.newSampleData(data, nobs, nvars);
+        qr = new QRDecompositionImpl(X);
+    }
+
+    /**
+     * <p>Compute the "hat" matrix.</p>
+     * <p>The hat matrix is defined in terms of the design matrix X
+     * by X(X<sup>T</sup>X)<sup>-1</sup>X<sup>T</sup>.</p>
+     * <p>The implementation here uses the QR decomposition to compute the
+     * hat matrix as Q I<sub>p</sub>Q<sup>T</sup> where I<sub>p</sub> is the
+     * p-dimensional identity matrix augmented by 0's. This computational
+     * formula is from "The Hat Matrix in Regression and ANOVA",
+     * David C. Hoaglin and Roy E. Welsch,
+     * <i>The American Statistician</i>, Vol. 32, No. 1 (Feb., 1978), pp. 17-22.</p>
+     *
+     * @return the hat matrix
+     */
+    public RealMatrix calculateHat() {
+        // Create augmented identity matrix
+        RealMatrix Q = qr.getQ();
+        final int p = qr.getR().getColumnDimension();
+        final int n = Q.getColumnDimension();
+        Array2DRowRealMatrix augI = new Array2DRowRealMatrix(n, n);
+        double[][] augIData = augI.getDataRef();
+        for (int i = 0; i < n; i++) {
+            for (int j = 0; j < n; j++) {
+                if (i == j && i < p) {
+                    augIData[i][j] = 1d;
+                } else {
+                    augIData[i][j] = 0d;
+                }
+            }
+        }
+
+        // Compute and return Hat matrix
+        return Q.multiply(augI).multiply(Q.transpose());
+    }
+
+    /**
+     * <p>Returns the sum of squared deviations of Y from its mean.</p>
+     *
+     * <p>If the model has no intercept term, <code>0</code> is used for the
+     * mean of Y - i.e., what is returned is the sum of the squared Y values.</p>
+     *
+     * <p>The value returned by this method is the SSTO value used in
+     * the {@link #calculateRSquared() R-squared} computation.</p>
+     *
+     * @return SSTO - the total sum of squares
+     * @see #isNoIntercept()
+     * @since 2.2
+     */
+    public double calculateTotalSumOfSquares() {
+        if (isNoIntercept()) {
+            return StatUtils.sumSq(Y.getData());
+        } else {
+            return new SecondMoment().evaluate(Y.getData());
+        }
+    }
+
+    /**
+     * Returns the sum of squared residuals.
+     *
+     * @return residual sum of squares
+     * @since 2.2
+     */
+    public double calculateResidualSumOfSquares() {
+        final RealVector residuals = calculateResiduals();
+        return residuals.dotProduct(residuals);
+    }
+
+    /**
+     * Returns the R-Squared statistic, defined by the formula <pre>
+     * R<sup>2</sup> = 1 - SSR / SSTO
+     * </pre>
+     * where SSR is the {@link #calculateResidualSumOfSquares() sum of squared residuals}
+     * and SSTO is the {@link #calculateTotalSumOfSquares() total sum of squares}.
+     *
+     * @return R-square statistic
+     * @since 2.2
+     */
+    public double calculateRSquared() {
+        return 1 - calculateResidualSumOfSquares() / calculateTotalSumOfSquares();
+    }
+
+    /**
+     * <p>Returns the adjusted R-squared statistic, defined by the formula <pre>
+     * R<sup>2</sup><sub>adj</sub> = 1 - [SSR (n - 1)] / [SSTO (n - p)]
+     * </pre>
+     * where SSR is the {@link #calculateResidualSumOfSquares() sum of squared residuals},
+     * SSTO is the {@link #calculateTotalSumOfSquares() total sum of squares}, n is the number
+     * of observations and p is the number of parameters estimated (including the intercept).</p>
+     *
+     * <p>If the regression is estimated without an intercept term, what is returned is <pre>
+     * <code> 1 - (1 - {@link #calculateRSquared()}) * (n / (n - p)) </code>
+     * </pre></p>
+     *
+     * @return adjusted R-Squared statistic
+     * @see #isNoIntercept()
+     * @since 2.2
+     */
+    public double calculateAdjustedRSquared() {
+        final double n = X.getRowDimension();
+        if (isNoIntercept()) {
+            return 1 - (1 - calculateRSquared()) * (n / (n - X.getColumnDimension()));
+        } else {
+            return 1 - (calculateResidualSumOfSquares() * (n - 1)) /
+                (calculateTotalSumOfSquares() * (n - X.getColumnDimension()));
+        }
+    }
+
+    /**
+     * {@inheritDoc}
+     * <p>This implementation computes and caches the QR decomposition of the X matrix
+     * once it is successfully loaded.</p>
+     */
+    @Override
+    protected void newXSampleData(double[][] x) {
+        super.newXSampleData(x);
+        qr = new QRDecompositionImpl(X);
+    }
+
+    /**
+     * Calculates the regression coefficients using OLS.
+     *
+     * @return beta
+     */
+    @Override
+    protected RealVector calculateBeta() {
+        return qr.getSolver().solve(Y);
+    }
+
+    /**
+     * <p>Calculates the variance-covariance matrix of the regression parameters.</p>
+     * <p>Var(b) = (X<sup>T</sup>X)<sup>-1</sup></p>
+     * <p>Uses QR decomposition to reduce (X<sup>T</sup>X)<sup>-1</sup>
+     * to (R<sup>T</sup>R)<sup>-1</sup>, with only the top p rows of
+     * R included, where p = the length of the beta vector.</p>
+     *
+     * @return The beta variance-covariance matrix
+     */
+    @Override
+    protected RealMatrix calculateBetaVariance() {
+        int p = X.getColumnDimension();
+        RealMatrix Raug = qr.getR().getSubMatrix(0, p - 1, 0, p - 1);
+        RealMatrix Rinv = new LUDecompositionImpl(Raug).getSolver().getInverse();
+        return Rinv.multiply(Rinv.transpose());
+    }
+
+}
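
For context, a minimal usage sketch against the API added in this diff. The sample data below is invented for illustration, and estimateRegressionParameters() is inherited from AbstractMultipleLinearRegression; since the default model includes an intercept, the returned beta vector has k + 1 entries.

import org.apache.commons.math.stat.regression.OLSMultipleLinearRegression;

public class OLSExample {
    public static void main(String[] args) {
        // Invented sample: n = 6 observations, k = 2 independent variables.
        double[] y = {11.0, 12.0, 13.0, 14.0, 15.0, 16.0};
        double[][] x = {
            {0.0, 0.0}, {2.0, 0.0}, {0.0, 3.0},
            {2.0, 3.0}, {4.0, 1.0}, {1.0, 4.0}
        };

        OLSMultipleLinearRegression regression = new OLSMultipleLinearRegression();
        regression.newSampleData(y, x); // validates dimensions, caches the QR decomposition of X

        // beta[0] is the intercept; beta[1] and beta[2] are the slope estimates.
        double[] beta = regression.estimateRegressionParameters();
        double rSquared = regression.calculateRSquared();           // 1 - SSR / SSTO
        double adjRSquared = regression.calculateAdjustedRSquared();

        System.out.println("intercept          = " + beta[0]);
        System.out.println("R-squared          = " + rSquared);
        System.out.println("adjusted R-squared = " + adjRSquared);
    }
}

Note that newSampleData(double[], double[][]) throws IllegalArgumentException when the x and y arrays are not compatible, per the javadoc above, so callers passing user-supplied data may want to catch it.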