summaryrefslogtreecommitdiff
path: root/src/main/java/org/apache/commons/math/stat/regression/OLSMultipleLinearRegression.java
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/java/org/apache/commons/math/stat/regression/OLSMultipleLinearRegression.java')
-rw-r--r--src/main/java/org/apache/commons/math/stat/regression/OLSMultipleLinearRegression.java233
1 files changed, 233 insertions, 0 deletions
diff --git a/src/main/java/org/apache/commons/math/stat/regression/OLSMultipleLinearRegression.java b/src/main/java/org/apache/commons/math/stat/regression/OLSMultipleLinearRegression.java
new file mode 100644
index 0000000..22a59e8
--- /dev/null
+++ b/src/main/java/org/apache/commons/math/stat/regression/OLSMultipleLinearRegression.java
@@ -0,0 +1,233 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.commons.math.stat.regression;
+
+import org.apache.commons.math.linear.Array2DRowRealMatrix;
+import org.apache.commons.math.linear.LUDecompositionImpl;
+import org.apache.commons.math.linear.QRDecomposition;
+import org.apache.commons.math.linear.QRDecompositionImpl;
+import org.apache.commons.math.linear.RealMatrix;
+import org.apache.commons.math.linear.RealVector;
+import org.apache.commons.math.stat.StatUtils;
+import org.apache.commons.math.stat.descriptive.moment.SecondMoment;
+
+/**
+ * <p>Implements ordinary least squares (OLS) to estimate the parameters of a
+ * multiple linear regression model.</p>
+ *
+ * <p>The regression coefficients, <code>b</code>, satisfy the normal equations:
+ * <pre><code> X<sup>T</sup> X b = X<sup>T</sup> y </code></pre></p>
+ *
+ * <p>To solve the normal equations, this implementation uses QR decomposition
+ * of the <code>X</code> matrix. (See {@link QRDecompositionImpl} for details on the
+ * decomposition algorithm.) The <code>X</code> matrix, also known as the <i>design matrix,</i>
+ * has rows corresponding to sample observations and columns corresponding to independent
+ * variables. When the model is estimated using an intercept term (i.e. when
+ * {@link #isNoIntercept() isNoIntercept} is false as it is by default), the <code>X</code>
+ * matrix includes an initial column identically equal to 1. We solve the normal equations
+ * as follows:
+ * <pre><code> X<sup>T</sup>X b = X<sup>T</sup> y
+ * (QR)<sup>T</sup> (QR) b = (QR)<sup>T</sup>y
+ * R<sup>T</sup> (Q<sup>T</sup>Q) R b = R<sup>T</sup> Q<sup>T</sup> y
+ * R<sup>T</sup> R b = R<sup>T</sup> Q<sup>T</sup> y
+ * (R<sup>T</sup>)<sup>-1</sup> R<sup>T</sup> R b = (R<sup>T</sup>)<sup>-1</sup> R<sup>T</sup> Q<sup>T</sup> y
+ * R b = Q<sup>T</sup> y </code></pre></p>
+ *
+ * <p>Given <code>Q</code> and <code>R</code>, the last equation is solved by back-substitution.</p>
+ *
+ * @version $Revision: 1073464 $ $Date: 2011-02-22 20:35:02 +0100 (mar. 22 févr. 2011) $
+ * @since 2.0
+ */
+public class OLSMultipleLinearRegression extends AbstractMultipleLinearRegression {
+
+ /** Cached QR decomposition of X matrix */
+ private QRDecomposition qr = null;
+
+ /**
+ * Loads model x and y sample data, overriding any previous sample.
+ *
+ * Computes and caches QR decomposition of the X matrix.
+ * @param y the [n,1] array representing the y sample
+ * @param x the [n,k] array representing the x sample
+ * @throws IllegalArgumentException if the x and y array data are not
+ * compatible for the regression
+ */
+ public void newSampleData(double[] y, double[][] x) {
+ validateSampleData(x, y);
+ newYSampleData(y);
+ newXSampleData(x);
+ }
+
+ /**
+ * {@inheritDoc}
+ * <p>This implementation computes and caches the QR decomposition of the X matrix.</p>
+ */
+ @Override
+ public void newSampleData(double[] data, int nobs, int nvars) {
+ super.newSampleData(data, nobs, nvars);
+ qr = new QRDecompositionImpl(X);
+ }
+
+ /**
+ * <p>Compute the "hat" matrix.
+ * </p>
+ * <p>The hat matrix is defined in terms of the design matrix X
+ * by X(X<sup>T</sup>X)<sup>-1</sup>X<sup>T</sup>
+ * </p>
+ * <p>The implementation here uses the QR decomposition to compute the
+ * hat matrix as Q I<sub>p</sub>Q<sup>T</sup> where I<sub>p</sub> is the
+ * p-dimensional identity matrix augmented by 0's. This computational
+ * formula is from "The Hat Matrix in Regression and ANOVA",
+ * David C. Hoaglin and Roy E. Welsch,
+ * <i>The American Statistician</i>, Vol. 32, No. 1 (Feb., 1978), pp. 17-22.
+ *
+ * @return the hat matrix
+ */
+ public RealMatrix calculateHat() {
+ // Create augmented identity matrix
+ RealMatrix Q = qr.getQ();
+ final int p = qr.getR().getColumnDimension();
+ final int n = Q.getColumnDimension();
+ Array2DRowRealMatrix augI = new Array2DRowRealMatrix(n, n);
+ double[][] augIData = augI.getDataRef();
+ for (int i = 0; i < n; i++) {
+ for (int j =0; j < n; j++) {
+ if (i == j && i < p) {
+ augIData[i][j] = 1d;
+ } else {
+ augIData[i][j] = 0d;
+ }
+ }
+ }
+
+ // Compute and return Hat matrix
+ return Q.multiply(augI).multiply(Q.transpose());
+ }
+
+ /**
+ * <p>Returns the sum of squared deviations of Y from its mean.</p>
+ *
+ * <p>If the model has no intercept term, <code>0</code> is used for the
+ * mean of Y - i.e., what is returned is the sum of the squared Y values.</p>
+ *
+ * <p>The value returned by this method is the SSTO value used in
+ * the {@link #calculateRSquared() R-squared} computation.</p>
+ *
+ * @return SSTO - the total sum of squares
+ * @see #isNoIntercept()
+ * @since 2.2
+ */
+ public double calculateTotalSumOfSquares() {
+ if (isNoIntercept()) {
+ return StatUtils.sumSq(Y.getData());
+ } else {
+ return new SecondMoment().evaluate(Y.getData());
+ }
+ }
+
+ /**
+ * Returns the sum of squared residuals.
+ *
+ * @return residual sum of squares
+ * @since 2.2
+ */
+ public double calculateResidualSumOfSquares() {
+ final RealVector residuals = calculateResiduals();
+ return residuals.dotProduct(residuals);
+ }
+
+ /**
+ * Returns the R-Squared statistic, defined by the formula <pre>
+ * R<sup>2</sup> = 1 - SSR / SSTO
+ * </pre>
+ * where SSR is the {@link #calculateResidualSumOfSquares() sum of squared residuals}
+ * and SSTO is the {@link #calculateTotalSumOfSquares() total sum of squares}
+ *
+ * @return R-square statistic
+ * @since 2.2
+ */
+ public double calculateRSquared() {
+ return 1 - calculateResidualSumOfSquares() / calculateTotalSumOfSquares();
+ }
+
+ /**
+ * <p>Returns the adjusted R-squared statistic, defined by the formula <pre>
+ * R<sup>2</sup><sub>adj</sub> = 1 - [SSR (n - 1)] / [SSTO (n - p)]
+ * </pre>
+ * where SSR is the {@link #calculateResidualSumOfSquares() sum of squared residuals},
+ * SSTO is the {@link #calculateTotalSumOfSquares() total sum of squares}, n is the number
+ * of observations and p is the number of parameters estimated (including the intercept).</p>
+ *
+ * <p>If the regression is estimated without an intercept term, what is returned is <pre>
+ * <code> 1 - (1 - {@link #calculateRSquared()}) * (n / (n - p)) </code>
+ * </pre></p>
+ *
+ * @return adjusted R-Squared statistic
+ * @see #isNoIntercept()
+ * @since 2.2
+ */
+ public double calculateAdjustedRSquared() {
+ final double n = X.getRowDimension();
+ if (isNoIntercept()) {
+ return 1 - (1 - calculateRSquared()) * (n / (n - X.getColumnDimension()));
+ } else {
+ return 1 - (calculateResidualSumOfSquares() * (n - 1)) /
+ (calculateTotalSumOfSquares() * (n - X.getColumnDimension()));
+ }
+ }
+
+ /**
+ * {@inheritDoc}
+ * <p>This implementation computes and caches the QR decomposition of the X matrix
+ * once it is successfully loaded.</p>
+ */
+ @Override
+ protected void newXSampleData(double[][] x) {
+ super.newXSampleData(x);
+ qr = new QRDecompositionImpl(X);
+ }
+
+ /**
+ * Calculates the regression coefficients using OLS.
+ *
+ * @return beta
+ */
+ @Override
+ protected RealVector calculateBeta() {
+ return qr.getSolver().solve(Y);
+ }
+
+ /**
+ * <p>Calculates the variance-covariance matrix of the regression parameters.
+ * </p>
+ * <p>Var(b) = (X<sup>T</sup>X)<sup>-1</sup>
+ * </p>
+ * <p>Uses QR decomposition to reduce (X<sup>T</sup>X)<sup>-1</sup>
+ * to (R<sup>T</sup>R)<sup>-1</sup>, with only the top p rows of
+ * R included, where p = the length of the beta vector.</p>
+ *
+ * @return The beta variance-covariance matrix
+ */
+ @Override
+ protected RealMatrix calculateBetaVariance() {
+ int p = X.getColumnDimension();
+ RealMatrix Raug = qr.getR().getSubMatrix(0, p - 1 , 0, p - 1);
+ RealMatrix Rinv = new LUDecompositionImpl(Raug).getSolver().getInverse();
+ return Rinv.multiply(Rinv.transpose());
+ }
+
+}