diff options
Diffstat (limited to 'src/main/java/org/apache/commons/math3/stat')
92 files changed, 25064 insertions, 0 deletions
diff --git a/src/main/java/org/apache/commons/math3/stat/Frequency.java b/src/main/java/org/apache/commons/math3/stat/Frequency.java new file mode 100644 index 0000000..276382c --- /dev/null +++ b/src/main/java/org/apache/commons/math3/stat/Frequency.java @@ -0,0 +1,664 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.commons.math3.stat;

import org.apache.commons.math3.exception.MathIllegalArgumentException;
import org.apache.commons.math3.exception.NullArgumentException;
import org.apache.commons.math3.exception.util.LocalizedFormats;
import org.apache.commons.math3.util.MathUtils;

import java.io.Serializable;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.SortedMap;
import java.util.TreeMap;

/**
 * Maintains a frequency distribution.
 *
 * <p>Accepts int, long, char or Comparable values. New values added must be comparable to those
 * that have been added, otherwise the add method will throw an IllegalArgumentException.
 *
 * <p>Integer values (int, long, Integer, Long) are not distinguished by type -- i.e. <code>
 * addValue(Long.valueOf(2)), addValue(2), addValue(2l)</code> all have the same effect (similarly
 * for arguments to <code>getCount,</code> etc.).
 *
 * <p>NOTE: byte and short values will be implicitly converted to int values by the compiler, thus
 * there are no explicit overloaded methods for these primitive types.
 *
 * <p>char values are converted by <code>addValue</code> to Character instances. As such, these
 * values are not comparable to integral values, so attempts to combine integral types with chars in
 * a frequency distribution will fail.
 *
 * <p>Float is not coerced to Double. Since they are not Comparable with each other the user must do
 * any necessary coercion. Float.NaN and Double.NaN are not treated specially; they may occur in
 * input and will occur in output if appropriate.
 *
 * <p>The values are ordered using the default (natural order), unless a <code>Comparator</code> is
 * supplied in the constructor.
 *
 * <p>This class is not thread-safe; external synchronization is required for concurrent use.
 */
public class Frequency implements Serializable {

    /** Serializable version identifier */
    private static final long serialVersionUID = -3845586908418844111L;

    /** Underlying collection: maps each distinct value to its occurrence count. Never null. */
    private final SortedMap<Comparable<?>, Long> freqTable;

    /** Default constructor. Values are ordered by their natural ordering. */
    public Frequency() {
        freqTable = new TreeMap<Comparable<?>, Long>();
    }

    /**
     * Constructor allowing values Comparator to be specified.
     *
     * @param comparator Comparator used to order values
     */
    @SuppressWarnings("unchecked") // TODO is the cast OK?
    public Frequency(Comparator<?> comparator) {
        freqTable =
                new TreeMap<Comparable<?>, Long>((Comparator<? super Comparable<?>>) comparator);
    }

    /**
     * Return a string representation of this frequency distribution: a header line followed by one
     * tab-separated line per distinct value (value, count, percentage, cumulative percentage).
     *
     * @return a string representation.
     */
    @Override
    public String toString() {
        NumberFormat nf = NumberFormat.getPercentInstance();
        StringBuilder outBuffer = new StringBuilder();
        outBuffer.append("Value \t Freq. \t Pct. \t Cum Pct. \n");
        Iterator<Comparable<?>> iter = freqTable.keySet().iterator();
        while (iter.hasNext()) {
            Comparable<?> value = iter.next();
            outBuffer.append(value);
            outBuffer.append('\t');
            outBuffer.append(getCount(value));
            outBuffer.append('\t');
            outBuffer.append(nf.format(getPct(value)));
            outBuffer.append('\t');
            outBuffer.append(nf.format(getCumPct(value)));
            outBuffer.append('\n');
        }
        return outBuffer.toString();
    }

    /**
     * Adds 1 to the frequency count for v.
     *
     * <p>If other objects have already been added to this Frequency, v must be comparable to those
     * that have already been added.
     *
     * @param v the value to add.
     * @throws MathIllegalArgumentException if <code>v</code> is not comparable with previous
     *     entries
     */
    public void addValue(Comparable<?> v) throws MathIllegalArgumentException {
        incrementValue(v, 1);
    }

    /**
     * Adds 1 to the frequency count for v.
     *
     * @param v the value to add.
     * @throws MathIllegalArgumentException if the table contains entries not comparable to Long
     */
    public void addValue(int v) throws MathIllegalArgumentException {
        addValue(Long.valueOf(v));
    }

    /**
     * Adds 1 to the frequency count for v.
     *
     * @param v the value to add.
     * @throws MathIllegalArgumentException if the table contains entries not comparable to Long
     */
    public void addValue(long v) throws MathIllegalArgumentException {
        addValue(Long.valueOf(v));
    }

    /**
     * Adds 1 to the frequency count for v.
     *
     * @param v the value to add.
     * @throws MathIllegalArgumentException if the table contains entries not comparable to Char
     */
    public void addValue(char v) throws MathIllegalArgumentException {
        addValue(Character.valueOf(v));
    }

    /**
     * Increments the frequency count for v.
     *
     * <p>If other objects have already been added to this Frequency, v must be comparable to those
     * that have already been added.
     *
     * @param v the value to add.
     * @param increment the amount by which the value should be incremented
     * @throws MathIllegalArgumentException if <code>v</code> is not comparable with previous
     *     entries
     * @since 3.1
     */
    public void incrementValue(Comparable<?> v, long increment)
            throws MathIllegalArgumentException {
        // Integers are normalized to Long so that addValue(2) and addValue(2L) share a bucket.
        Comparable<?> obj = v;
        if (v instanceof Integer) {
            obj = Long.valueOf(((Integer) v).longValue());
        }
        try {
            Long count = freqTable.get(obj);
            if (count == null) {
                freqTable.put(obj, Long.valueOf(increment));
            } else {
                freqTable.put(obj, Long.valueOf(count.longValue() + increment));
            }
        } catch (ClassCastException ex) {
            // TreeMap will throw ClassCastException if v is not comparable
            throw new MathIllegalArgumentException(
                    LocalizedFormats.INSTANCES_NOT_COMPARABLE_TO_EXISTING_VALUES,
                    v.getClass().getName());
        }
    }

    /**
     * Increments the frequency count for v.
     *
     * <p>If other objects have already been added to this Frequency, v must be comparable to those
     * that have already been added.
     *
     * @param v the value to add.
     * @param increment the amount by which the value should be incremented
     * @throws MathIllegalArgumentException if the table contains entries not comparable to Long
     * @since 3.3
     */
    public void incrementValue(int v, long increment) throws MathIllegalArgumentException {
        incrementValue(Long.valueOf(v), increment);
    }

    /**
     * Increments the frequency count for v.
     *
     * <p>If other objects have already been added to this Frequency, v must be comparable to those
     * that have already been added.
     *
     * @param v the value to add.
     * @param increment the amount by which the value should be incremented
     * @throws MathIllegalArgumentException if the table contains entries not comparable to Long
     * @since 3.3
     */
    public void incrementValue(long v, long increment) throws MathIllegalArgumentException {
        incrementValue(Long.valueOf(v), increment);
    }

    /**
     * Increments the frequency count for v.
     *
     * <p>If other objects have already been added to this Frequency, v must be comparable to those
     * that have already been added.
     *
     * @param v the value to add.
     * @param increment the amount by which the value should be incremented
     * @throws MathIllegalArgumentException if the table contains entries not comparable to Char
     * @since 3.3
     */
    public void incrementValue(char v, long increment) throws MathIllegalArgumentException {
        incrementValue(Character.valueOf(v), increment);
    }

    /** Clears the frequency table */
    public void clear() {
        freqTable.clear();
    }

    /**
     * Returns an Iterator over the set of values that have been added.
     *
     * <p>If added values are integral (i.e., integers, longs, Integers, or Longs), they are
     * converted to Longs when they are added, so the objects returned by the Iterator will in this
     * case be Longs.
     *
     * @return values Iterator
     */
    public Iterator<Comparable<?>> valuesIterator() {
        return freqTable.keySet().iterator();
    }

    /**
     * Return an Iterator over the set of keys and values that have been added. Using the entry set
     * to iterate is more efficient in the case where you need to access respective counts as well
     * as values, since it doesn't require a "get" for every key...the value is provided in the
     * Map.Entry.
     *
     * <p>If added values are integral (i.e., integers, longs, Integers, or Longs), they are
     * converted to Longs when they are added, so the values of the map entries returned by the
     * Iterator will in this case be Longs.
     *
     * @return entry set Iterator
     * @since 3.1
     */
    public Iterator<Map.Entry<Comparable<?>, Long>> entrySetIterator() {
        return freqTable.entrySet().iterator();
    }

    // -------------------------------------------------------------------------

    /**
     * Returns the sum of all frequencies.
     *
     * @return the total frequency count.
     */
    public long getSumFreq() {
        long result = 0;
        Iterator<Long> iterator = freqTable.values().iterator();
        while (iterator.hasNext()) {
            result += iterator.next().longValue();
        }
        return result;
    }

    /**
     * Returns the number of values equal to v. Returns 0 if the value is not comparable.
     *
     * @param v the value to lookup.
     * @return the frequency of v.
     */
    public long getCount(Comparable<?> v) {
        if (v instanceof Integer) {
            // Integers were normalized to Long on insertion; look up the Long bucket.
            return getCount(((Integer) v).longValue());
        }
        long result = 0;
        try {
            Long count = freqTable.get(v);
            if (count != null) {
                result = count.longValue();
            }
        } catch (ClassCastException ex) { // NOPMD
            // ignore and return 0 -- ClassCastException will be thrown if value is not comparable
        }
        return result;
    }

    /**
     * Returns the number of values equal to v.
     *
     * @param v the value to lookup.
     * @return the frequency of v.
     */
    public long getCount(int v) {
        return getCount(Long.valueOf(v));
    }

    /**
     * Returns the number of values equal to v.
     *
     * @param v the value to lookup.
     * @return the frequency of v.
     */
    public long getCount(long v) {
        return getCount(Long.valueOf(v));
    }

    /**
     * Returns the number of values equal to v.
     *
     * @param v the value to lookup.
     * @return the frequency of v.
     */
    public long getCount(char v) {
        return getCount(Character.valueOf(v));
    }

    /**
     * Returns the number of values in the frequency table.
     *
     * @return the number of unique values that have been added to the frequency table.
     * @see #valuesIterator()
     */
    public int getUniqueCount() {
        // Use size() directly rather than materializing the keySet view first.
        return freqTable.size();
    }

    /**
     * Returns the percentage of values that are equal to v (as a proportion between 0 and 1).
     *
     * <p>Returns <code>Double.NaN</code> if no values have been added. Returns 0 if at least one
     * value has been added, but v is not comparable to the values set.
     *
     * @param v the value to lookup
     * @return the proportion of values equal to v
     */
    public double getPct(Comparable<?> v) {
        final long sumFreq = getSumFreq();
        if (sumFreq == 0) {
            return Double.NaN;
        }
        return (double) getCount(v) / (double) sumFreq;
    }

    /**
     * Returns the percentage of values that are equal to v (as a proportion between 0 and 1).
     *
     * @param v the value to lookup
     * @return the proportion of values equal to v
     */
    public double getPct(int v) {
        return getPct(Long.valueOf(v));
    }

    /**
     * Returns the percentage of values that are equal to v (as a proportion between 0 and 1).
     *
     * @param v the value to lookup
     * @return the proportion of values equal to v
     */
    public double getPct(long v) {
        return getPct(Long.valueOf(v));
    }

    /**
     * Returns the percentage of values that are equal to v (as a proportion between 0 and 1).
     *
     * @param v the value to lookup
     * @return the proportion of values equal to v
     */
    public double getPct(char v) {
        return getPct(Character.valueOf(v));
    }

    // -----------------------------------------------------------------------------------------

    /**
     * Returns the cumulative frequency of values less than or equal to v.
     *
     * <p>Returns 0 if v is not comparable to the values set.
     *
     * @param v the value to lookup.
     * @return the proportion of values equal to v
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    public long getCumFreq(Comparable<?> v) {
        // Compute the total once; getSumFreq() is O(n) over the table.
        final long sumFreq = getSumFreq();
        if (sumFreq == 0) {
            return 0;
        }
        if (v instanceof Integer) {
            return getCumFreq(((Integer) v).longValue());
        }
        Comparator<Comparable<?>> c = (Comparator<Comparable<?>>) freqTable.comparator();
        if (c == null) {
            c = new NaturalComparator();
        }
        long result = 0;

        try {
            Long value = freqTable.get(v);
            if (value != null) {
                result = value.longValue();
            }
        } catch (ClassCastException ex) {
            return result; // v is not comparable
        }

        if (c.compare(v, freqTable.firstKey()) < 0) {
            return 0; // v is comparable, but less than first value
        }

        if (c.compare(v, freqTable.lastKey()) >= 0) {
            return sumFreq; // v is comparable, but greater than the last value
        }

        Iterator<Comparable<?>> values = valuesIterator();
        while (values.hasNext()) {
            Comparable<?> nextValue = values.next();
            if (c.compare(v, nextValue) > 0) {
                result += getCount(nextValue);
            } else {
                return result;
            }
        }
        return result;
    }

    /**
     * Returns the cumulative frequency of values less than or equal to v.
     *
     * <p>Returns 0 if v is not comparable to the values set.
     *
     * @param v the value to lookup
     * @return the proportion of values equal to v
     */
    public long getCumFreq(int v) {
        return getCumFreq(Long.valueOf(v));
    }

    /**
     * Returns the cumulative frequency of values less than or equal to v.
     *
     * <p>Returns 0 if v is not comparable to the values set.
     *
     * @param v the value to lookup
     * @return the proportion of values equal to v
     */
    public long getCumFreq(long v) {
        return getCumFreq(Long.valueOf(v));
    }

    /**
     * Returns the cumulative frequency of values less than or equal to v.
     *
     * <p>Returns 0 if v is not comparable to the values set.
     *
     * @param v the value to lookup
     * @return the proportion of values equal to v
     */
    public long getCumFreq(char v) {
        return getCumFreq(Character.valueOf(v));
    }

    // ----------------------------------------------------------------------------------------------

    /**
     * Returns the cumulative percentage of values less than or equal to v (as a proportion between
     * 0 and 1).
     *
     * <p>Returns <code>Double.NaN</code> if no values have been added. Returns 0 if at least one
     * value has been added, but v is not comparable to the values set.
     *
     * @param v the value to lookup
     * @return the proportion of values less than or equal to v
     */
    public double getCumPct(Comparable<?> v) {
        final long sumFreq = getSumFreq();
        if (sumFreq == 0) {
            return Double.NaN;
        }
        return (double) getCumFreq(v) / (double) sumFreq;
    }

    /**
     * Returns the cumulative percentage of values less than or equal to v (as a proportion between
     * 0 and 1).
     *
     * <p>Returns 0 if v is not comparable to the values set.
     *
     * @param v the value to lookup
     * @return the proportion of values less than or equal to v
     */
    public double getCumPct(int v) {
        return getCumPct(Long.valueOf(v));
    }

    /**
     * Returns the cumulative percentage of values less than or equal to v (as a proportion between
     * 0 and 1).
     *
     * <p>Returns 0 if v is not comparable to the values set.
     *
     * @param v the value to lookup
     * @return the proportion of values less than or equal to v
     */
    public double getCumPct(long v) {
        return getCumPct(Long.valueOf(v));
    }

    /**
     * Returns the cumulative percentage of values less than or equal to v (as a proportion between
     * 0 and 1).
     *
     * <p>Returns 0 if v is not comparable to the values set.
     *
     * @param v the value to lookup
     * @return the proportion of values less than or equal to v
     */
    public double getCumPct(char v) {
        return getCumPct(Character.valueOf(v));
    }

    /**
     * Returns the mode value(s) in comparator order.
     *
     * @return a list containing the value(s) which appear most often.
     * @since 3.3
     */
    public List<Comparable<?>> getMode() {
        long mostPopular = 0; // frequencies are always positive

        // Get the max count first, so we avoid having to recreate the List each time
        for (Long l : freqTable.values()) {
            long frequency = l.longValue();
            if (frequency > mostPopular) {
                mostPopular = frequency;
            }
        }

        List<Comparable<?>> modeList = new ArrayList<Comparable<?>>();
        for (Entry<Comparable<?>, Long> ent : freqTable.entrySet()) {
            long frequency = ent.getValue().longValue();
            if (frequency == mostPopular) {
                modeList.add(ent.getKey());
            }
        }
        return modeList;
    }

    // ----------------------------------------------------------------------------------------------

    /**
     * Merge another Frequency object's counts into this instance. This Frequency's counts will be
     * incremented (or set when not already set) by the counts represented by other.
     *
     * @param other the other {@link Frequency} object to be merged
     * @throws NullArgumentException if {@code other} is null
     * @since 3.1
     */
    public void merge(final Frequency other) throws NullArgumentException {
        MathUtils.checkNotNull(other, LocalizedFormats.NULL_NOT_ALLOWED);

        final Iterator<Map.Entry<Comparable<?>, Long>> iter = other.entrySetIterator();
        while (iter.hasNext()) {
            final Map.Entry<Comparable<?>, Long> entry = iter.next();
            incrementValue(entry.getKey(), entry.getValue().longValue());
        }
    }

    /**
     * Merge a {@link Collection} of {@link Frequency} objects into this instance. This Frequency's
     * counts will be incremented (or set when not already set) by the counts represented by each of
     * the others.
     *
     * @param others the other {@link Frequency} objects to be merged
     * @throws NullArgumentException if the collection is null
     * @since 3.1
     */
    public void merge(final Collection<Frequency> others) throws NullArgumentException {
        MathUtils.checkNotNull(others, LocalizedFormats.NULL_NOT_ALLOWED);

        for (final Frequency freq : others) {
            merge(freq);
        }
    }

    // ----------------------------------------------------------------------------------------------

    /**
     * A Comparator that compares comparable objects using the natural order. Copied from Commons
     * Collections ComparableComparator.
     *
     * @param <T> the type of the objects compared
     */
    private static class NaturalComparator<T extends Comparable<T>>
            implements Comparator<Comparable<T>>, Serializable {

        /** Serializable version identifier */
        private static final long serialVersionUID = -3852193713161395148L;

        /**
         * Compare the two {@link Comparable Comparable} arguments. This method is equivalent to:
         *
         * <pre>(({@link Comparable Comparable})o1).{@link Comparable#compareTo compareTo}(o2)</pre>
         *
         * @param o1 the first object
         * @param o2 the second object
         * @return result of comparison
         * @throws NullPointerException when <i>o1</i> is <code>null</code>, or when <code>
         *     ((Comparable)o1).compareTo(o2)</code> does
         * @throws ClassCastException when <i>o1</i> is not a {@link Comparable Comparable}, or when
         *     <code>((Comparable)o1).compareTo(o2)</code> does
         */
        @SuppressWarnings("unchecked") // cast to (T) may throw ClassCastException, see Javadoc
        public int compare(Comparable<T> o1, Comparable<T> o2) {
            return o1.compareTo((T) o2);
        }
    }

    /** {@inheritDoc} */
    @Override
    public int hashCode() {
        // freqTable is final and assigned in every constructor, so it is never null.
        final int prime = 31;
        return prime + freqTable.hashCode();
    }

    /** {@inheritDoc} */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (!(obj instanceof Frequency)) {
            return false;
        }
        Frequency other = (Frequency) obj;
        // freqTable is never null (final, set in all constructors); compare tables directly.
        return freqTable.equals(other.freqTable);
    }
} diff --git a/src/main/java/org/apache/commons/math3/stat/StatUtils.java b/src/main/java/org/apache/commons/math3/stat/StatUtils.java new file mode 100644 index 0000000..31d75e8 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/stat/StatUtils.java @@ -0,0 +1,852 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +package org.apache.commons.math3.stat; + +import org.apache.commons.math3.exception.DimensionMismatchException; +import org.apache.commons.math3.exception.MathIllegalArgumentException; +import org.apache.commons.math3.exception.NoDataException; +import org.apache.commons.math3.exception.NotPositiveException; +import org.apache.commons.math3.exception.NullArgumentException; +import org.apache.commons.math3.exception.NumberIsTooSmallException; +import org.apache.commons.math3.exception.util.LocalizedFormats; +import org.apache.commons.math3.stat.descriptive.DescriptiveStatistics; +import org.apache.commons.math3.stat.descriptive.UnivariateStatistic; +import org.apache.commons.math3.stat.descriptive.moment.GeometricMean; +import org.apache.commons.math3.stat.descriptive.moment.Mean; +import org.apache.commons.math3.stat.descriptive.moment.Variance; +import org.apache.commons.math3.stat.descriptive.rank.Max; +import org.apache.commons.math3.stat.descriptive.rank.Min; +import org.apache.commons.math3.stat.descriptive.rank.Percentile; +import org.apache.commons.math3.stat.descriptive.summary.Product; +import org.apache.commons.math3.stat.descriptive.summary.Sum; +import org.apache.commons.math3.stat.descriptive.summary.SumOfLogs; +import org.apache.commons.math3.stat.descriptive.summary.SumOfSquares; + +import java.util.List; + +/** + * StatUtils provides static methods for computing statistics based on data stored in double[] + * arrays. 
+ */ +public final class StatUtils { + + /** sum */ + private static final UnivariateStatistic SUM = new Sum(); + + /** sumSq */ + private static final UnivariateStatistic SUM_OF_SQUARES = new SumOfSquares(); + + /** prod */ + private static final UnivariateStatistic PRODUCT = new Product(); + + /** sumLog */ + private static final UnivariateStatistic SUM_OF_LOGS = new SumOfLogs(); + + /** min */ + private static final UnivariateStatistic MIN = new Min(); + + /** max */ + private static final UnivariateStatistic MAX = new Max(); + + /** mean */ + private static final UnivariateStatistic MEAN = new Mean(); + + /** variance */ + private static final Variance VARIANCE = new Variance(); + + /** percentile */ + private static final Percentile PERCENTILE = new Percentile(); + + /** geometric mean */ + private static final GeometricMean GEOMETRIC_MEAN = new GeometricMean(); + + /** Private Constructor */ + private StatUtils() {} + + /** + * Returns the sum of the values in the input array, or <code>Double.NaN</code> if the array is + * empty. + * + * <p>Throws <code>IllegalArgumentException</code> if the input array is null. + * + * @param values array of values to sum + * @return the sum of the values or <code>Double.NaN</code> if the array is empty + * @throws MathIllegalArgumentException if the array is null + */ + public static double sum(final double[] values) throws MathIllegalArgumentException { + return SUM.evaluate(values); + } + + /** + * Returns the sum of the entries in the specified portion of the input array, or <code> + * Double.NaN</code> if the designated subarray is empty. + * + * <p>Throws <code>IllegalArgumentException</code> if the array is null. 
+ * + * @param values the input array + * @param begin index of the first array element to include + * @param length the number of elements to include + * @return the sum of the values or Double.NaN if length = 0 + * @throws MathIllegalArgumentException if the array is null or the array index parameters are + * not valid + */ + public static double sum(final double[] values, final int begin, final int length) + throws MathIllegalArgumentException { + return SUM.evaluate(values, begin, length); + } + + /** + * Returns the sum of the squares of the entries in the input array, or <code>Double.NaN</code> + * if the array is empty. + * + * <p>Throws <code>IllegalArgumentException</code> if the array is null. + * + * @param values input array + * @return the sum of the squared values or <code>Double.NaN</code> if the array is empty + * @throws MathIllegalArgumentException if the array is null + */ + public static double sumSq(final double[] values) throws MathIllegalArgumentException { + return SUM_OF_SQUARES.evaluate(values); + } + + /** + * Returns the sum of the squares of the entries in the specified portion of the input array, or + * <code>Double.NaN</code> if the designated subarray is empty. + * + * <p>Throws <code>IllegalArgumentException</code> if the array is null. + * + * @param values the input array + * @param begin index of the first array element to include + * @param length the number of elements to include + * @return the sum of the squares of the values or Double.NaN if length = 0 + * @throws MathIllegalArgumentException if the array is null or the array index parameters are + * not valid + */ + public static double sumSq(final double[] values, final int begin, final int length) + throws MathIllegalArgumentException { + return SUM_OF_SQUARES.evaluate(values, begin, length); + } + + /** + * Returns the product of the entries in the input array, or <code>Double.NaN</code> if the + * array is empty. 
+ * + * <p>Throws <code>IllegalArgumentException</code> if the array is null. + * + * @param values the input array + * @return the product of the values or Double.NaN if the array is empty + * @throws MathIllegalArgumentException if the array is null + */ + public static double product(final double[] values) throws MathIllegalArgumentException { + return PRODUCT.evaluate(values); + } + + /** + * Returns the product of the entries in the specified portion of the input array, or <code> + * Double.NaN</code> if the designated subarray is empty. + * + * <p>Throws <code>IllegalArgumentException</code> if the array is null. + * + * @param values the input array + * @param begin index of the first array element to include + * @param length the number of elements to include + * @return the product of the values or Double.NaN if length = 0 + * @throws MathIllegalArgumentException if the array is null or the array index parameters are + * not valid + */ + public static double product(final double[] values, final int begin, final int length) + throws MathIllegalArgumentException { + return PRODUCT.evaluate(values, begin, length); + } + + /** + * Returns the sum of the natural logs of the entries in the input array, or <code>Double.NaN + * </code> if the array is empty. + * + * <p>Throws <code>IllegalArgumentException</code> if the array is null. + * + * <p>See {@link org.apache.commons.math3.stat.descriptive.summary.SumOfLogs}. + * + * @param values the input array + * @return the sum of the natural logs of the values or Double.NaN if the array is empty + * @throws MathIllegalArgumentException if the array is null + */ + public static double sumLog(final double[] values) throws MathIllegalArgumentException { + return SUM_OF_LOGS.evaluate(values); + } + + /** + * Returns the sum of the natural logs of the entries in the specified portion of the input + * array, or <code>Double.NaN</code> if the designated subarray is empty. 
+ * + * <p>Throws <code>IllegalArgumentException</code> if the array is null. + * + * <p>See {@link org.apache.commons.math3.stat.descriptive.summary.SumOfLogs}. + * + * @param values the input array + * @param begin index of the first array element to include + * @param length the number of elements to include + * @return the sum of the natural logs of the values or Double.NaN if length = 0 + * @throws MathIllegalArgumentException if the array is null or the array index parameters are + * not valid + */ + public static double sumLog(final double[] values, final int begin, final int length) + throws MathIllegalArgumentException { + return SUM_OF_LOGS.evaluate(values, begin, length); + } + + /** + * Returns the arithmetic mean of the entries in the input array, or <code>Double.NaN</code> if + * the array is empty. + * + * <p>Throws <code>IllegalArgumentException</code> if the array is null. + * + * <p>See {@link org.apache.commons.math3.stat.descriptive.moment.Mean} for details on the + * computing algorithm. + * + * @param values the input array + * @return the mean of the values or Double.NaN if the array is empty + * @throws MathIllegalArgumentException if the array is null + */ + public static double mean(final double[] values) throws MathIllegalArgumentException { + return MEAN.evaluate(values); + } + + /** + * Returns the arithmetic mean of the entries in the specified portion of the input array, or + * <code>Double.NaN</code> if the designated subarray is empty. + * + * <p>Throws <code>IllegalArgumentException</code> if the array is null. + * + * <p>See {@link org.apache.commons.math3.stat.descriptive.moment.Mean} for details on the + * computing algorithm. 
+ * + * @param values the input array + * @param begin index of the first array element to include + * @param length the number of elements to include + * @return the mean of the values or Double.NaN if length = 0 + * @throws MathIllegalArgumentException if the array is null or the array index parameters are + * not valid + */ + public static double mean(final double[] values, final int begin, final int length) + throws MathIllegalArgumentException { + return MEAN.evaluate(values, begin, length); + } + + /** + * Returns the geometric mean of the entries in the input array, or <code>Double.NaN</code> if + * the array is empty. + * + * <p>Throws <code>IllegalArgumentException</code> if the array is null. + * + * <p>See {@link org.apache.commons.math3.stat.descriptive.moment.GeometricMean} for details on + * the computing algorithm. + * + * @param values the input array + * @return the geometric mean of the values or Double.NaN if the array is empty + * @throws MathIllegalArgumentException if the array is null + */ + public static double geometricMean(final double[] values) throws MathIllegalArgumentException { + return GEOMETRIC_MEAN.evaluate(values); + } + + /** + * Returns the geometric mean of the entries in the specified portion of the input array, or + * <code>Double.NaN</code> if the designated subarray is empty. + * + * <p>Throws <code>IllegalArgumentException</code> if the array is null. + * + * <p>See {@link org.apache.commons.math3.stat.descriptive.moment.GeometricMean} for details on + * the computing algorithm. 
+ * + * @param values the input array + * @param begin index of the first array element to include + * @param length the number of elements to include + * @return the geometric mean of the values or Double.NaN if length = 0 + * @throws MathIllegalArgumentException if the array is null or the array index parameters are + * not valid + */ + public static double geometricMean(final double[] values, final int begin, final int length) + throws MathIllegalArgumentException { + return GEOMETRIC_MEAN.evaluate(values, begin, length); + } + + /** + * Returns the variance of the entries in the input array, or <code>Double.NaN</code> if the + * array is empty. + * + * <p>This method returns the bias-corrected sample variance (using {@code n - 1} in the + * denominator). Use {@link #populationVariance(double[])} for the non-bias-corrected population + * variance. + * + * <p>See {@link org.apache.commons.math3.stat.descriptive.moment.Variance} for details on the + * computing algorithm. + * + * <p>Returns 0 for a single-value (i.e. length = 1) sample. + * + * <p>Throws <code>MathIllegalArgumentException</code> if the array is null. + * + * @param values the input array + * @return the variance of the values or Double.NaN if the array is empty + * @throws MathIllegalArgumentException if the array is null + */ + public static double variance(final double[] values) throws MathIllegalArgumentException { + return VARIANCE.evaluate(values); + } + + /** + * Returns the variance of the entries in the specified portion of the input array, or <code> + * Double.NaN</code> if the designated subarray is empty. + * + * <p>This method returns the bias-corrected sample variance (using {@code n - 1} in the + * denominator). Use {@link #populationVariance(double[], int, int)} for the non-bias-corrected + * population variance. + * + * <p>See {@link org.apache.commons.math3.stat.descriptive.moment.Variance} for details on the + * computing algorithm. + * + * <p>Returns 0 for a single-value (i.e. 
length = 1) sample. + * + * <p>Throws <code>MathIllegalArgumentException</code> if the array is null or the array index + * parameters are not valid. + * + * @param values the input array + * @param begin index of the first array element to include + * @param length the number of elements to include + * @return the variance of the values or Double.NaN if length = 0 + * @throws MathIllegalArgumentException if the array is null or the array index parameters are + * not valid + */ + public static double variance(final double[] values, final int begin, final int length) + throws MathIllegalArgumentException { + return VARIANCE.evaluate(values, begin, length); + } + + /** + * Returns the variance of the entries in the specified portion of the input array, using the + * precomputed mean value. Returns <code>Double.NaN</code> if the designated subarray is empty. + * + * <p>This method returns the bias-corrected sample variance (using {@code n - 1} in the + * denominator). Use {@link #populationVariance(double[], double, int, int)} for the + * non-bias-corrected population variance. + * + * <p>See {@link org.apache.commons.math3.stat.descriptive.moment.Variance} for details on the + * computing algorithm. + * + * <p>The formula used assumes that the supplied mean value is the arithmetic mean of the sample + * data, not a known population parameter. This method is supplied only to save computation when + * the mean has already been computed. + * + * <p>Returns 0 for a single-value (i.e. length = 1) sample. + * + * <p>Throws <code>MathIllegalArgumentException</code> if the array is null or the array index + * parameters are not valid. 
+ * + * @param values the input array + * @param mean the precomputed mean value + * @param begin index of the first array element to include + * @param length the number of elements to include + * @return the variance of the values or Double.NaN if length = 0 + * @throws MathIllegalArgumentException if the array is null or the array index parameters are + * not valid + */ + public static double variance( + final double[] values, final double mean, final int begin, final int length) + throws MathIllegalArgumentException { + return VARIANCE.evaluate(values, mean, begin, length); + } + + /** + * Returns the variance of the entries in the input array, using the precomputed mean value. + * Returns <code>Double.NaN</code> if the array is empty. + * + * <p>This method returns the bias-corrected sample variance (using {@code n - 1} in the + * denominator). Use {@link #populationVariance(double[], double)} for the non-bias-corrected + * population variance. + * + * <p>See {@link org.apache.commons.math3.stat.descriptive.moment.Variance} for details on the + * computing algorithm. + * + * <p>The formula used assumes that the supplied mean value is the arithmetic mean of the sample + * data, not a known population parameter. This method is supplied only to save computation when + * the mean has already been computed. + * + * <p>Returns 0 for a single-value (i.e. length = 1) sample. + * + * <p>Throws <code>MathIllegalArgumentException</code> if the array is null. 
+ * + * @param values the input array + * @param mean the precomputed mean value + * @return the variance of the values or Double.NaN if the array is empty + * @throws MathIllegalArgumentException if the array is null + */ + public static double variance(final double[] values, final double mean) + throws MathIllegalArgumentException { + return VARIANCE.evaluate(values, mean); + } + + /** + * Returns the <a href="http://en.wikibooks.org/wiki/Statistics/Summary/Variance">population + * variance</a> of the entries in the input array, or <code>Double.NaN</code> if the array is + * empty. + * + * <p>See {@link org.apache.commons.math3.stat.descriptive.moment.Variance} for details on the + * formula and computing algorithm. + * + * <p>Returns 0 for a single-value (i.e. length = 1) sample. + * + * <p>Throws <code>MathIllegalArgumentException</code> if the array is null. + * + * @param values the input array + * @return the population variance of the values or Double.NaN if the array is empty + * @throws MathIllegalArgumentException if the array is null + */ + public static double populationVariance(final double[] values) + throws MathIllegalArgumentException { + return new Variance(false).evaluate(values); + } + + /** + * Returns the <a href="http://en.wikibooks.org/wiki/Statistics/Summary/Variance">population + * variance</a> of the entries in the specified portion of the input array, or <code>Double.NaN + * </code> if the designated subarray is empty. + * + * <p>See {@link org.apache.commons.math3.stat.descriptive.moment.Variance} for details on the + * computing algorithm. + * + * <p>Returns 0 for a single-value (i.e. length = 1) sample. + * + * <p>Throws <code>MathIllegalArgumentException</code> if the array is null or the array index + * parameters are not valid. 
+ * + * @param values the input array + * @param begin index of the first array element to include + * @param length the number of elements to include + * @return the population variance of the values or Double.NaN if length = 0 + * @throws MathIllegalArgumentException if the array is null or the array index parameters are + * not valid + */ + public static double populationVariance( + final double[] values, final int begin, final int length) + throws MathIllegalArgumentException { + return new Variance(false).evaluate(values, begin, length); + } + + /** + * Returns the <a href="http://en.wikibooks.org/wiki/Statistics/Summary/Variance">population + * variance</a> of the entries in the specified portion of the input array, using the + * precomputed mean value. Returns <code>Double.NaN</code> if the designated subarray is empty. + * + * <p>See {@link org.apache.commons.math3.stat.descriptive.moment.Variance} for details on the + * computing algorithm. + * + * <p>The formula used assumes that the supplied mean value is the arithmetic mean of the sample + * data, not a known population parameter. This method is supplied only to save computation when + * the mean has already been computed. + * + * <p>Returns 0 for a single-value (i.e. length = 1) sample. + * + * <p>Throws <code>MathIllegalArgumentException</code> if the array is null or the array index + * parameters are not valid. 
+ * + * @param values the input array + * @param mean the precomputed mean value + * @param begin index of the first array element to include + * @param length the number of elements to include + * @return the population variance of the values or Double.NaN if length = 0 + * @throws MathIllegalArgumentException if the array is null or the array index parameters are + * not valid + */ + public static double populationVariance( + final double[] values, final double mean, final int begin, final int length) + throws MathIllegalArgumentException { + return new Variance(false).evaluate(values, mean, begin, length); + } + + /** + * Returns the <a href="http://en.wikibooks.org/wiki/Statistics/Summary/Variance">population + * variance</a> of the entries in the input array, using the precomputed mean value. Returns + * <code>Double.NaN</code> if the array is empty. + * + * <p>See {@link org.apache.commons.math3.stat.descriptive.moment.Variance} for details on the + * computing algorithm. + * + * <p>The formula used assumes that the supplied mean value is the arithmetic mean of the sample + * data, not a known population parameter. This method is supplied only to save computation when + * the mean has already been computed. + * + * <p>Returns 0 for a single-value (i.e. length = 1) sample. + * + * <p>Throws <code>MathIllegalArgumentException</code> if the array is null. + * + * @param values the input array + * @param mean the precomputed mean value + * @return the population variance of the values or Double.NaN if the array is empty + * @throws MathIllegalArgumentException if the array is null + */ + public static double populationVariance(final double[] values, final double mean) + throws MathIllegalArgumentException { + return new Variance(false).evaluate(values, mean); + } + + /** + * Returns the maximum of the entries in the input array, or <code>Double.NaN</code> if the + * array is empty. + * + * <p>Throws <code>MathIllegalArgumentException</code> if the array is null. 
+ * + * <p> + * + * <ul> + * <li>The result is <code>NaN</code> iff all values are <code>NaN</code> (i.e. <code>NaN + * </code> values have no impact on the value of the statistic). + * <li>If any of the values equals <code>Double.POSITIVE_INFINITY</code>, the result is <code> + * Double.POSITIVE_INFINITY.</code> + * </ul> + * + * @param values the input array + * @return the maximum of the values or Double.NaN if the array is empty + * @throws MathIllegalArgumentException if the array is null + */ + public static double max(final double[] values) throws MathIllegalArgumentException { + return MAX.evaluate(values); + } + + /** + * Returns the maximum of the entries in the specified portion of the input array, or <code> + * Double.NaN</code> if the designated subarray is empty. + * + * <p>Throws <code>MathIllegalArgumentException</code> if the array is null or the array index + * parameters are not valid. + * + * <p> + * + * <ul> + * <li>The result is <code>NaN</code> iff all values are <code>NaN</code> (i.e. <code>NaN + * </code> values have no impact on the value of the statistic). + * <li>If any of the values equals <code>Double.POSITIVE_INFINITY</code>, the result is <code> + * Double.POSITIVE_INFINITY.</code> + * </ul> + * + * @param values the input array + * @param begin index of the first array element to include + * @param length the number of elements to include + * @return the maximum of the values or Double.NaN if length = 0 + * @throws MathIllegalArgumentException if the array is null or the array index parameters are + * not valid + */ + public static double max(final double[] values, final int begin, final int length) + throws MathIllegalArgumentException { + return MAX.evaluate(values, begin, length); + } + + /** + * Returns the minimum of the entries in the input array, or <code>Double.NaN</code> if the + * array is empty. + * + * <p>Throws <code>MathIllegalArgumentException</code> if the array is null. 
+ * + * <p> + * + * <ul> + * <li>The result is <code>NaN</code> iff all values are <code>NaN</code> (i.e. <code>NaN + * </code> values have no impact on the value of the statistic). + * <li>If any of the values equals <code>Double.NEGATIVE_INFINITY</code>, the result is <code> + * Double.NEGATIVE_INFINITY.</code> + * </ul> + * + * @param values the input array + * @return the minimum of the values or Double.NaN if the array is empty + * @throws MathIllegalArgumentException if the array is null + */ + public static double min(final double[] values) throws MathIllegalArgumentException { + return MIN.evaluate(values); + } + + /** + * Returns the minimum of the entries in the specified portion of the input array, or <code> + * Double.NaN</code> if the designated subarray is empty. + * + * <p>Throws <code>MathIllegalArgumentException</code> if the array is null or the array index + * parameters are not valid. + * + * <p> + * + * <ul> + * <li>The result is <code>NaN</code> iff all values are <code>NaN</code> (i.e. <code>NaN + * </code> values have no impact on the value of the statistic). + * <li>If any of the values equals <code>Double.NEGATIVE_INFINITY</code>, the result is <code> + * Double.NEGATIVE_INFINITY.</code> + * </ul> + * + * @param values the input array + * @param begin index of the first array element to include + * @param length the number of elements to include + * @return the minimum of the values or Double.NaN if length = 0 + * @throws MathIllegalArgumentException if the array is null or the array index parameters are + * not valid + */ + public static double min(final double[] values, final int begin, final int length) + throws MathIllegalArgumentException { + return MIN.evaluate(values, begin, length); + } + + /** + * Returns an estimate of the <code>p</code>th percentile of the values in the <code>values + * </code> array. 
+ * + * <p> + * + * <ul> + * <li>Returns <code>Double.NaN</code> if <code>values</code> has length <code>0</code> + * <li>Returns (for any value of <code>p</code>) <code>values[0]</code> if <code>values</code> + * has length <code>1</code> + * <li>Throws <code>IllegalArgumentException</code> if <code>values</code> is null or p is not + * a valid quantile value (p must be greater than 0 and less than or equal to 100) + * </ul> + * + * <p>See {@link org.apache.commons.math3.stat.descriptive.rank.Percentile} for a description of + * the percentile estimation algorithm used. + * + * @param values input array of values + * @param p the percentile value to compute + * @return the percentile value or Double.NaN if the array is empty + * @throws MathIllegalArgumentException if <code>values</code> is null or p is invalid + */ + public static double percentile(final double[] values, final double p) + throws MathIllegalArgumentException { + return PERCENTILE.evaluate(values, p); + } + + /** + * Returns an estimate of the <code>p</code>th percentile of the values in the <code>values + * </code> array, starting with the element in (0-based) position <code>begin</code> in the + * array and including <code>length</code> values. + * + * <p> + * + * <ul> + * <li>Returns <code>Double.NaN</code> if <code>length = 0</code> + * <li>Returns (for any value of <code>p</code>) <code>values[begin]</code> if <code> + * length = 1 </code> + * <li>Throws <code>MathIllegalArgumentException</code> if <code>values</code> is null , + * <code>begin</code> or <code>length</code> is invalid, or <code>p</code> is not a valid + * quantile value (p must be greater than 0 and less than or equal to 100) + * </ul> + * + * <p>See {@link org.apache.commons.math3.stat.descriptive.rank.Percentile} for a description of + * the percentile estimation algorithm used. 
+ * + * @param values array of input values + * @param p the percentile to compute + * @param begin the first (0-based) element to include in the computation + * @param length the number of array elements to include + * @return the percentile value + * @throws MathIllegalArgumentException if the parameters are not valid or the input array is + * null + */ + public static double percentile( + final double[] values, final int begin, final int length, final double p) + throws MathIllegalArgumentException { + return PERCENTILE.evaluate(values, begin, length, p); + } + + /** + * Returns the sum of the (signed) differences between corresponding elements of the input + * arrays -- i.e., sum(sample1[i] - sample2[i]). + * + * @param sample1 the first array + * @param sample2 the second array + * @return sum of paired differences + * @throws DimensionMismatchException if the arrays do not have the same (positive) length. + * @throws NoDataException if the sample arrays are empty. + */ + public static double sumDifference(final double[] sample1, final double[] sample2) + throws DimensionMismatchException, NoDataException { + int n = sample1.length; + if (n != sample2.length) { + throw new DimensionMismatchException(n, sample2.length); + } + if (n <= 0) { + throw new NoDataException(LocalizedFormats.INSUFFICIENT_DIMENSION); + } + double result = 0; + for (int i = 0; i < n; i++) { + result += sample1[i] - sample2[i]; + } + return result; + } + + /** + * Returns the mean of the (signed) differences between corresponding elements of the input + * arrays -- i.e., sum(sample1[i] - sample2[i]) / sample1.length. + * + * @param sample1 the first array + * @param sample2 the second array + * @return mean of paired differences + * @throws DimensionMismatchException if the arrays do not have the same (positive) length. + * @throws NoDataException if the sample arrays are empty. 
+ */ + public static double meanDifference(final double[] sample1, final double[] sample2) + throws DimensionMismatchException, NoDataException { + return sumDifference(sample1, sample2) / sample1.length; + } + + /** + * Returns the variance of the (signed) differences between corresponding elements of the input + * arrays -- i.e., var(sample1[i] - sample2[i]). + * + * @param sample1 the first array + * @param sample2 the second array + * @param meanDifference the mean difference between corresponding entries + * @see #meanDifference(double[],double[]) + * @return variance of paired differences + * @throws DimensionMismatchException if the arrays do not have the same length. + * @throws NumberIsTooSmallException if the arrays length is less than 2. + */ + public static double varianceDifference( + final double[] sample1, final double[] sample2, double meanDifference) + throws DimensionMismatchException, NumberIsTooSmallException { + double sum1 = 0d; + double sum2 = 0d; + double diff = 0d; + int n = sample1.length; + if (n != sample2.length) { + throw new DimensionMismatchException(n, sample2.length); + } + if (n < 2) { + throw new NumberIsTooSmallException(n, 2, true); + } + for (int i = 0; i < n; i++) { + diff = sample1[i] - sample2[i]; + sum1 += (diff - meanDifference) * (diff - meanDifference); + sum2 += diff - meanDifference; + } + return (sum1 - (sum2 * sum2 / n)) / (n - 1); + } + + /** + * Normalize (standardize) the sample, so it is has a mean of 0 and a standard deviation of 1. + * + * @param sample Sample to normalize. + * @return normalized (standardized) sample. 
+ * @since 2.2 + */ + public static double[] normalize(final double[] sample) { + DescriptiveStatistics stats = new DescriptiveStatistics(); + + // Add the data from the series to stats + for (int i = 0; i < sample.length; i++) { + stats.addValue(sample[i]); + } + + // Compute mean and standard deviation + double mean = stats.getMean(); + double standardDeviation = stats.getStandardDeviation(); + + // initialize the standardizedSample, which has the same length as the sample + double[] standardizedSample = new double[sample.length]; + + for (int i = 0; i < sample.length; i++) { + // z = (x- mean)/standardDeviation + standardizedSample[i] = (sample[i] - mean) / standardDeviation; + } + return standardizedSample; + } + + /** + * Returns the sample mode(s). The mode is the most frequently occurring value in the sample. If + * there is a unique value with maximum frequency, this value is returned as the only element of + * the output array. Otherwise, the returned array contains the maximum frequency elements in + * increasing order. For example, if {@code sample} is {0, 12, 5, 6, 0, 13, 5, 17}, the returned + * array will have length two, with 0 in the first element and 5 in the second. + * + * <p>NaN values are ignored when computing the mode - i.e., NaNs will never appear in the + * output array. If the sample includes only NaNs or has length 0, an empty array is returned. + * + * @param sample input data + * @return array of array of the most frequently occurring element(s) sorted in ascending order. + * @throws MathIllegalArgumentException if the indices are invalid or the array is null + * @since 3.3 + */ + public static double[] mode(double[] sample) throws MathIllegalArgumentException { + if (sample == null) { + throw new NullArgumentException(LocalizedFormats.INPUT_ARRAY); + } + return getMode(sample, 0, sample.length); + } + + /** + * Returns the sample mode(s). The mode is the most frequently occurring value in the sample. 
If + * there is a unique value with maximum frequency, this value is returned as the only element of + * the output array. Otherwise, the returned array contains the maximum frequency elements in + * increasing order. For example, if {@code sample} is {0, 12, 5, 6, 0, 13, 5, 17}, the returned + * array will have length two, with 0 in the first element and 5 in the second. + * + * <p>NaN values are ignored when computing the mode - i.e., NaNs will never appear in the + * output array. If the sample includes only NaNs or has length 0, an empty array is returned. + * + * @param sample input data + * @param begin index (0-based) of the first array element to include + * @param length the number of elements to include + * @return array of array of the most frequently occurring element(s) sorted in ascending order. + * @throws MathIllegalArgumentException if the indices are invalid or the array is null + * @since 3.3 + */ + public static double[] mode(double[] sample, final int begin, final int length) { + if (sample == null) { + throw new NullArgumentException(LocalizedFormats.INPUT_ARRAY); + } + + if (begin < 0) { + throw new NotPositiveException(LocalizedFormats.START_POSITION, Integer.valueOf(begin)); + } + + if (length < 0) { + throw new NotPositiveException(LocalizedFormats.LENGTH, Integer.valueOf(length)); + } + + return getMode(sample, begin, length); + } + + /** + * Private helper method. Assumes parameters have been validated. + * + * @param values input data + * @param begin index (0-based) of the first array element to include + * @param length the number of elements to include + * @return array of array of the most frequently occurring element(s) sorted in ascending order. 
     */
    private static double[] getMode(double[] values, final int begin, final int length) {
        // Add the values to the frequency table; NaNs are skipped so they can never
        // appear in the output.
        Frequency freq = new Frequency();
        for (int i = begin; i < begin + length; i++) {
            final double value = values[i];
            if (!Double.isNaN(value)) {
                freq.addValue(Double.valueOf(value));
            }
        }
        // Frequency.getMode() yields the maximum-frequency values; the increasing order of
        // the result relies on the ordering of that list.
        List<Comparable<?>> list = freq.getMode();
        // Convert the list to an array of primitive double
        double[] modes = new double[list.size()];
        int i = 0;
        for (Comparable<?> c : list) {
            modes[i++] = ((Double) c).doubleValue();
        }
        return modes;
    }
}
diff --git a/src/main/java/org/apache/commons/math3/stat/clustering/Cluster.java b/src/main/java/org/apache/commons/math3/stat/clustering/Cluster.java
new file mode 100644
index 0000000..8d9483e
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/stat/clustering/Cluster.java
@@ -0,0 +1,76 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.commons.math3.stat.clustering;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;

/**
 * Cluster holding a set of {@link Clusterable} points.
 * @param <T> the type of points that can be clustered
 * @since 2.0
 * @deprecated As of 3.2 (to be removed in 4.0),
 * use {@link org.apache.commons.math3.ml.clustering.Cluster} instead
 */
@Deprecated
public class Cluster<T extends Clusterable<T>> implements Serializable {

    /** Serializable version identifier. */
    private static final long serialVersionUID = -3442297081515880464L;

    /** The points contained in this cluster. */
    private final List<T> points;

    /** Center of the cluster (may be {@code null} for algorithms that do not define centers). */
    private final T center;

    /**
     * Build a cluster centered at a specified point.
     * @param center the point which is to be the center of this cluster
     */
    public Cluster(final T center) {
        this.center = center;
        points = new ArrayList<T>();
    }

    /**
     * Add a point to this cluster.
     * @param point point to add
     */
    public void addPoint(final T point) {
        points.add(point);
    }

    /**
     * Get the points contained in the cluster.
     * <p>
     * Note: this returns the live internal list, not a copy; mutations made by the
     * caller are reflected in the cluster.
     * @return points contained in the cluster
     */
    public List<T> getPoints() {
        return points;
    }

    /**
     * Get the point chosen to be the center of this cluster.
     * @return chosen cluster center (may be {@code null})
     */
    public T getCenter() {
        return center;
    }

}
diff --git a/src/main/java/org/apache/commons/math3/stat/clustering/Clusterable.java b/src/main/java/org/apache/commons/math3/stat/clustering/Clusterable.java
new file mode 100644
index 0000000..f9818f3
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/stat/clustering/Clusterable.java
@@ -0,0 +1,48 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.
You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.commons.math3.stat.clustering;

import java.util.Collection;

/**
 * Interface for points that can be clustered together.
 * @param <T> the type of point that can be clustered
 * @since 2.0
 * @deprecated As of 3.2 (to be removed in 4.0),
 * use {@link org.apache.commons.math3.ml.clustering.Clusterable} instead
 */
@Deprecated
public interface Clusterable<T> {

    /**
     * Returns the distance from the given point.
     *
     * @param p the point to compute the distance from
     * @return the distance from the given point
     */
    double distanceFrom(T p);

    /**
     * Returns the centroid of the given Collection of points.
     *
     * @param p the Collection of points to compute the centroid of
     * @return the centroid of the given Collection of Points
     */
    T centroidOf(Collection<T> p);

}
diff --git a/src/main/java/org/apache/commons/math3/stat/clustering/DBSCANClusterer.java b/src/main/java/org/apache/commons/math3/stat/clustering/DBSCANClusterer.java
new file mode 100644
index 0000000..13247eb
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/stat/clustering/DBSCANClusterer.java
@@ -0,0 +1,226 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.math3.stat.clustering; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import org.apache.commons.math3.exception.NotPositiveException; +import org.apache.commons.math3.exception.NullArgumentException; +import org.apache.commons.math3.util.MathUtils; + +/** + * DBSCAN (density-based spatial clustering of applications with noise) algorithm. + * <p> + * The DBSCAN algorithm forms clusters based on the idea of density connectivity, i.e. + * a point p is density connected to another point q, if there exists a chain of + * points p<sub>i</sub>, with i = 1 .. n and p<sub>1</sub> = p and p<sub>n</sub> = q, + * such that each pair <p<sub>i</sub>, p<sub>i+1</sub>> is directly density-reachable. + * A point q is directly density-reachable from point p if it is in the ε-neighborhood + * of this point. + * <p> + * Any point that is not density-reachable from a formed cluster is treated as noise, and + * will thus not be present in the result. + * <p> + * The algorithm requires two parameters: + * <ul> + * <li>eps: the distance that defines the ε-neighborhood of a point + * <li>minPoints: the minimum number of density-connected points required to form a cluster + * </ul> + * <p> + * <b>Note:</b> as DBSCAN is not a centroid-based clustering algorithm, the resulting + * {@link Cluster} objects will have no defined center, i.e. {@link Cluster#getCenter()} will + * return {@code null}. 
 *
 * @param <T> type of the points to cluster
 * @see <a href="http://en.wikipedia.org/wiki/DBSCAN">DBSCAN (wikipedia)</a>
 * @see <a href="http://www.dbs.ifi.lmu.de/Publikationen/Papers/KDD-96.final.frame.pdf">
 * A Density-Based Algorithm for Discovering Clusters in Large Spatial Databases with Noise</a>
 * @since 3.1
 * @deprecated As of 3.2 (to be removed in 4.0),
 * use {@link org.apache.commons.math3.ml.clustering.DBSCANClusterer} instead
 */
@Deprecated
public class DBSCANClusterer<T extends Clusterable<T>> {

    /** Maximum radius of the neighborhood to be considered. */
    private final double eps;

    /** Minimum number of points needed for a cluster. */
    private final int minPts;

    /** Status of a point during the clustering process. */
    private enum PointStatus {
        /** The point is considered to be noise. */
        NOISE,
        /** The point is already part of a cluster. */
        PART_OF_CLUSTER
    }

    /**
     * Creates a new instance of a DBSCANClusterer.
     *
     * @param eps maximum radius of the neighborhood to be considered
     * @param minPts minimum number of points needed for a cluster
     * @throws NotPositiveException if {@code eps < 0.0} or {@code minPts < 0}
     */
    public DBSCANClusterer(final double eps, final int minPts)
        throws NotPositiveException {
        // Both parameters need only be non-negative; zero is accepted.
        if (eps < 0.0d) {
            throw new NotPositiveException(eps);
        }
        if (minPts < 0) {
            throw new NotPositiveException(minPts);
        }
        this.eps = eps;
        this.minPts = minPts;
    }

    /**
     * Returns the maximum radius of the neighborhood to be considered.
     *
     * @return maximum radius of the neighborhood
     */
    public double getEps() {
        return eps;
    }

    /**
     * Returns the minimum number of points needed for a cluster.
     *
     * @return minimum number of points needed for a cluster
     */
    public int getMinPts() {
        return minPts;
    }

    /**
     * Performs DBSCAN cluster analysis.
+ * <p> + * <b>Note:</b> as DBSCAN is not a centroid-based clustering algorithm, the resulting + * {@link Cluster} objects will have no defined center, i.e. {@link Cluster#getCenter()} will + * return {@code null}. + * + * @param points the points to cluster + * @return the list of clusters + * @throws NullArgumentException if the data points are null + */ + public List<Cluster<T>> cluster(final Collection<T> points) throws NullArgumentException { + + // sanity checks + MathUtils.checkNotNull(points); + + final List<Cluster<T>> clusters = new ArrayList<Cluster<T>>(); + final Map<Clusterable<T>, PointStatus> visited = new HashMap<Clusterable<T>, PointStatus>(); + + for (final T point : points) { + if (visited.get(point) != null) { + continue; + } + final List<T> neighbors = getNeighbors(point, points); + if (neighbors.size() >= minPts) { + // DBSCAN does not care about center points + final Cluster<T> cluster = new Cluster<T>(null); + clusters.add(expandCluster(cluster, point, neighbors, points, visited)); + } else { + visited.put(point, PointStatus.NOISE); + } + } + + return clusters; + } + + /** + * Expands the cluster to include density-reachable items. 
    /**
     * Expands the cluster to include all density-reachable items.
     * <p>
     * Implements the standard DBSCAN expansion: a work list of seed points is
     * scanned left to right and grows while new core points are discovered,
     * so the loop bound must be re-read on every iteration.
     *
     * @param cluster Cluster to expand
     * @param point Point to add to cluster
     * @param neighbors List of neighbors
     * @param points the data set
     * @param visited the set of already visited points
     * @return the expanded cluster
     */
    private Cluster<T> expandCluster(final Cluster<T> cluster,
                                     final T point,
                                     final List<T> neighbors,
                                     final Collection<T> points,
                                     final Map<Clusterable<T>, PointStatus> visited) {
        cluster.addPoint(point);
        visited.put(point, PointStatus.PART_OF_CLUSTER);

        List<T> seeds = new ArrayList<T>(neighbors);
        int index = 0;
        while (index < seeds.size()) {
            final T current = seeds.get(index);
            PointStatus pStatus = visited.get(current);
            // only check non-visited points: if this seed is itself a core
            // point, its neighborhood joins the work list
            if (pStatus == null) {
                final List<T> currentNeighbors = getNeighbors(current, points);
                if (currentNeighbors.size() >= minPts) {
                    seeds = merge(seeds, currentNeighbors);
                }
            }

            // points previously marked NOISE are reclassified here as border
            // points of this cluster; points already in a cluster stay put
            if (pStatus != PointStatus.PART_OF_CLUSTER) {
                visited.put(current, PointStatus.PART_OF_CLUSTER);
                cluster.addPoint(current);
            }

            index++;
        }
        return cluster;
    }

    /**
     * Returns a list of density-reachable neighbors of a {@code point},
     * i.e. all points of the data set within distance {@link #eps}.
     *
     * @param point the point to look for
     * @param points possible neighbors
     * @return the List of neighbors
     */
    private List<T> getNeighbors(final T point, final Collection<T> points) {
        final List<T> neighbors = new ArrayList<T>();
        for (final T neighbor : points) {
            // NOTE(review): identity comparison (!=) deliberately excludes only
            // this exact instance; a distinct-but-equal point still counts as
            // a neighbor — presumably intended, verify against callers
            if (point != neighbor && neighbor.distanceFrom(point) <= eps) {
                neighbors.add(neighbor);
            }
        }
        return neighbors;
    }
+ * + * @param one first list + * @param two second list + * @return merged lists + */ + private List<T> merge(final List<T> one, final List<T> two) { + final Set<T> oneSet = new HashSet<T>(one); + for (T item : two) { + if (!oneSet.contains(item)) { + one.add(item); + } + } + return one; + } +} diff --git a/src/main/java/org/apache/commons/math3/stat/clustering/EuclideanDoublePoint.java b/src/main/java/org/apache/commons/math3/stat/clustering/EuclideanDoublePoint.java new file mode 100644 index 0000000..32c236c --- /dev/null +++ b/src/main/java/org/apache/commons/math3/stat/clustering/EuclideanDoublePoint.java @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.math3.stat.clustering; + +import java.io.Serializable; +import java.util.Collection; +import java.util.Arrays; + +import org.apache.commons.math3.util.MathArrays; + +/** + * A simple implementation of {@link Clusterable} for points with double coordinates. 
@Deprecated
public class EuclideanDoublePoint implements Clusterable<EuclideanDoublePoint>, Serializable {

    /** Serializable version identifier. */
    private static final long serialVersionUID = 8026472786091227632L;

    /** Point coordinates. */
    private final double[] point;

    /**
     * Build an instance wrapping a double array.
     * <p>
     * The wrapped array is referenced, it is <em>not</em> copied, so later
     * modification of the array is visible through this point.
     *
     * @param point the n-dimensional point in double precision space
     */
    public EuclideanDoublePoint(final double[] point) {
        this.point = point;
    }

    /** {@inheritDoc} */
    public EuclideanDoublePoint centroidOf(final Collection<EuclideanDoublePoint> points) {
        final double[] centroid = new double[getPoint().length];
        for (final EuclideanDoublePoint p : points) {
            for (int i = 0; i < centroid.length; i++) {
                centroid[i] += p.getPoint()[i];
            }
        }
        for (int i = 0; i < centroid.length; i++) {
            centroid[i] /= points.size();
        }
        return new EuclideanDoublePoint(centroid);
    }

    /** {@inheritDoc} */
    public double distanceFrom(final EuclideanDoublePoint p) {
        return MathArrays.distance(point, p.getPoint());
    }

    /** {@inheritDoc} */
    @Override
    public boolean equals(final Object other) {
        if (!(other instanceof EuclideanDoublePoint)) {
            return false;
        }
        return Arrays.equals(point, ((EuclideanDoublePoint) other).point);
    }

    /**
     * Get the n-dimensional point in double precision space.
     *
     * @return a reference (not a copy!) to the wrapped array
     */
    public double[] getPoint() {
        return point;
    }

    /** {@inheritDoc} */
    @Override
    public int hashCode() {
        return Arrays.hashCode(point);
    }

    /** {@inheritDoc} */
    @Override
    public String toString() {
        return Arrays.toString(point);
    }

}
+ * @since 2.0 + * @deprecated As of 3.2 (to be removed in 4.0), + * use {@link org.apache.commons.math3.ml.clustering.DoublePoint} instead + */ +@Deprecated +public class EuclideanIntegerPoint implements Clusterable<EuclideanIntegerPoint>, Serializable { + + /** Serializable version identifier. */ + private static final long serialVersionUID = 3946024775784901369L; + + /** Point coordinates. */ + private final int[] point; + + /** + * Build an instance wrapping an integer array. + * <p>The wrapped array is referenced, it is <em>not</em> copied.</p> + * @param point the n-dimensional point in integer space + */ + public EuclideanIntegerPoint(final int[] point) { + this.point = point; + } + + /** + * Get the n-dimensional point in integer space. + * @return a reference (not a copy!) to the wrapped array + */ + public int[] getPoint() { + return point; + } + + /** {@inheritDoc} */ + public double distanceFrom(final EuclideanIntegerPoint p) { + return MathArrays.distance(point, p.getPoint()); + } + + /** {@inheritDoc} */ + public EuclideanIntegerPoint centroidOf(final Collection<EuclideanIntegerPoint> points) { + int[] centroid = new int[getPoint().length]; + for (EuclideanIntegerPoint p : points) { + for (int i = 0; i < centroid.length; i++) { + centroid[i] += p.getPoint()[i]; + } + } + for (int i = 0; i < centroid.length; i++) { + centroid[i] /= points.size(); + } + return new EuclideanIntegerPoint(centroid); + } + + /** {@inheritDoc} */ + @Override + public boolean equals(final Object other) { + if (!(other instanceof EuclideanIntegerPoint)) { + return false; + } + return Arrays.equals(point, ((EuclideanIntegerPoint) other).point); + } + + /** {@inheritDoc} */ + @Override + public int hashCode() { + return Arrays.hashCode(point); + } + + /** + * {@inheritDoc} + * @since 2.1 + */ + @Override + public String toString() { + return Arrays.toString(point); + } + +} diff --git a/src/main/java/org/apache/commons/math3/stat/clustering/KMeansPlusPlusClusterer.java 
b/src/main/java/org/apache/commons/math3/stat/clustering/KMeansPlusPlusClusterer.java new file mode 100644 index 0000000..07cec09 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/stat/clustering/KMeansPlusPlusClusterer.java @@ -0,0 +1,514 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.commons.math3.stat.clustering; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Random; + +import org.apache.commons.math3.exception.ConvergenceException; +import org.apache.commons.math3.exception.MathIllegalArgumentException; +import org.apache.commons.math3.exception.NumberIsTooSmallException; +import org.apache.commons.math3.exception.util.LocalizedFormats; +import org.apache.commons.math3.stat.descriptive.moment.Variance; +import org.apache.commons.math3.util.MathUtils; + +/** + * Clustering algorithm based on David Arthur and Sergei Vassilvitski k-means++ algorithm. 
+ * @param <T> type of the points to cluster + * @see <a href="http://en.wikipedia.org/wiki/K-means%2B%2B">K-means++ (wikipedia)</a> + * @since 2.0 + * @deprecated As of 3.2 (to be removed in 4.0), + * use {@link org.apache.commons.math3.ml.clustering.KMeansPlusPlusClusterer} instead + */ +@Deprecated +public class KMeansPlusPlusClusterer<T extends Clusterable<T>> { + + /** Strategies to use for replacing an empty cluster. */ + public enum EmptyClusterStrategy { + + /** Split the cluster with largest distance variance. */ + LARGEST_VARIANCE, + + /** Split the cluster with largest number of points. */ + LARGEST_POINTS_NUMBER, + + /** Create a cluster around the point farthest from its centroid. */ + FARTHEST_POINT, + + /** Generate an error. */ + ERROR + + } + + /** Random generator for choosing initial centers. */ + private final Random random; + + /** Selected strategy for empty clusters. */ + private final EmptyClusterStrategy emptyStrategy; + + /** Build a clusterer. + * <p> + * The default strategy for handling empty clusters that may appear during + * algorithm iterations is to split the cluster with largest distance variance. + * </p> + * @param random random generator to use for choosing initial centers + */ + public KMeansPlusPlusClusterer(final Random random) { + this(random, EmptyClusterStrategy.LARGEST_VARIANCE); + } + + /** Build a clusterer. + * @param random random generator to use for choosing initial centers + * @param emptyStrategy strategy to use for handling empty clusters that + * may appear during algorithm iterations + * @since 2.2 + */ + public KMeansPlusPlusClusterer(final Random random, final EmptyClusterStrategy emptyStrategy) { + this.random = random; + this.emptyStrategy = emptyStrategy; + } + + /** + * Runs the K-means++ clustering algorithm. 
+ * + * @param points the points to cluster + * @param k the number of clusters to split the data into + * @param numTrials number of trial runs + * @param maxIterationsPerTrial the maximum number of iterations to run the algorithm + * for at each trial run. If negative, no maximum will be used + * @return a list of clusters containing the points + * @throws MathIllegalArgumentException if the data points are null or the number + * of clusters is larger than the number of data points + * @throws ConvergenceException if an empty cluster is encountered and the + * {@link #emptyStrategy} is set to {@code ERROR} + */ + public List<Cluster<T>> cluster(final Collection<T> points, final int k, + int numTrials, int maxIterationsPerTrial) + throws MathIllegalArgumentException, ConvergenceException { + + // at first, we have not found any clusters list yet + List<Cluster<T>> best = null; + double bestVarianceSum = Double.POSITIVE_INFINITY; + + // do several clustering trials + for (int i = 0; i < numTrials; ++i) { + + // compute a clusters list + List<Cluster<T>> clusters = cluster(points, k, maxIterationsPerTrial); + + // compute the variance of the current list + double varianceSum = 0.0; + for (final Cluster<T> cluster : clusters) { + if (!cluster.getPoints().isEmpty()) { + + // compute the distance variance of the current cluster + final T center = cluster.getCenter(); + final Variance stat = new Variance(); + for (final T point : cluster.getPoints()) { + stat.increment(point.distanceFrom(center)); + } + varianceSum += stat.getResult(); + + } + } + + if (varianceSum <= bestVarianceSum) { + // this one is the best we have found so far, remember it + best = clusters; + bestVarianceSum = varianceSum; + } + + } + + // return the best clusters list found + return best; + + } + + /** + * Runs the K-means++ clustering algorithm. 
    /**
     * Runs the K-means++ clustering algorithm: k-means++ seeding followed by
     * Lloyd-style iterations until the assignment stabilizes or the iteration
     * budget is exhausted.
     *
     * @param points the points to cluster
     * @param k the number of clusters to split the data into
     * @param maxIterations the maximum number of iterations to run the algorithm
     *     for. If negative, no maximum will be used
     * @return a list of clusters containing the points
     * @throws MathIllegalArgumentException if the data points are null or the number
     *     of clusters is larger than the number of data points
     * @throws ConvergenceException if an empty cluster is encountered and the
     *     {@link #emptyStrategy} is set to {@code ERROR}
     */
    public List<Cluster<T>> cluster(final Collection<T> points, final int k,
                                    final int maxIterations)
        throws MathIllegalArgumentException, ConvergenceException {

        // sanity checks
        MathUtils.checkNotNull(points);

        // number of clusters has to be smaller or equal the number of data points
        if (points.size() < k) {
            throw new NumberIsTooSmallException(points.size(), k, false);
        }

        // create the initial clusters (k-means++ seeding)
        List<Cluster<T>> clusters = chooseInitialCenters(points, k, random);

        // create an array containing the latest assignment of a point to a cluster
        // no need to initialize the array, as it will be filled with the first assignment
        int[] assignments = new int[points.size()];
        assignPointsToClusters(clusters, points, assignments);

        // iterate through updating the centers until we're done
        final int max = (maxIterations < 0) ? Integer.MAX_VALUE : maxIterations;
        for (int count = 0; count < max; count++) {
            boolean emptyCluster = false;
            List<Cluster<T>> newClusters = new ArrayList<Cluster<T>>();
            for (final Cluster<T> cluster : clusters) {
                final T newCenter;
                if (cluster.getPoints().isEmpty()) {
                    // an empty cluster is re-seeded according to the configured
                    // strategy; ERROR (the default branch) aborts instead
                    switch (emptyStrategy) {
                        case LARGEST_VARIANCE :
                            newCenter = getPointFromLargestVarianceCluster(clusters);
                            break;
                        case LARGEST_POINTS_NUMBER :
                            newCenter = getPointFromLargestNumberCluster(clusters);
                            break;
                        case FARTHEST_POINT :
                            newCenter = getFarthestPoint(clusters);
                            break;
                        default :
                            throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
                    }
                    emptyCluster = true;
                } else {
                    // regular Lloyd update: move the center to the centroid
                    newCenter = cluster.getCenter().centroidOf(cluster.getPoints());
                }
                newClusters.add(new Cluster<T>(newCenter));
            }
            int changes = assignPointsToClusters(newClusters, points, assignments);
            clusters = newClusters;

            // if there were no more changes in the point-to-cluster assignment
            // and there are no empty clusters left, return the current clusters
            if (changes == 0 && !emptyCluster) {
                return clusters;
            }
        }
        // iteration budget exhausted: return the latest clustering
        return clusters;
    }
    /**
     * Adds the given points to the closest {@link Cluster}.
     *
     * @param <T> type of the points to cluster
     * @param clusters the {@link Cluster}s to add the points to
     * @param points the points to add to the given {@link Cluster}s
     * @param assignments points assignments to clusters; updated in place
     * @return the number of points assigned to different clusters as the iteration before
     */
    private static <T extends Clusterable<T>> int
        assignPointsToClusters(final List<Cluster<T>> clusters, final Collection<T> points,
                               final int[] assignments) {
        int assignedDifferently = 0;
        int pointIndex = 0;
        for (final T p : points) {
            int clusterIndex = getNearestCluster(clusters, p);
            if (clusterIndex != assignments[pointIndex]) {
                assignedDifferently++;
            }

            Cluster<T> cluster = clusters.get(clusterIndex);
            cluster.addPoint(p);
            assignments[pointIndex++] = clusterIndex;
        }

        return assignedDifferently;
    }

    /**
     * Use K-means++ to choose the initial centers: the first center is drawn
     * uniformly, each subsequent one with probability proportional to the
     * squared distance D(x)^2 to the nearest already-chosen center.
     *
     * @param <T> type of the points to cluster
     * @param points the points to choose the initial centers from
     * @param k the number of centers to choose
     * @param random random generator to use
     * @return the initial centers
     */
    private static <T extends Clusterable<T>> List<Cluster<T>>
        chooseInitialCenters(final Collection<T> points, final int k, final Random random) {

        // Convert to list for indexed access. Make it unmodifiable, since removal of items
        // would screw up the logic of this method.
        final List<T> pointList = Collections.unmodifiableList(new ArrayList<T> (points));

        // The number of points in the list.
        final int numPoints = pointList.size();

        // Set the corresponding element in this array to indicate when
        // elements of pointList are no longer available.
        final boolean[] taken = new boolean[numPoints];

        // The resulting list of initial centers.
        final List<Cluster<T>> resultSet = new ArrayList<Cluster<T>>();

        // Choose one center uniformly at random from among the data points.
        final int firstPointIndex = random.nextInt(numPoints);

        final T firstPoint = pointList.get(firstPointIndex);

        resultSet.add(new Cluster<T>(firstPoint));

        // Must mark it as taken
        taken[firstPointIndex] = true;

        // To keep track of the minimum distance squared of elements of
        // pointList to elements of resultSet.
        final double[] minDistSquared = new double[numPoints];

        // Initialize the elements. Since the only point in resultSet is firstPoint,
        // this is very easy.
        for (int i = 0; i < numPoints; i++) {
            if (i != firstPointIndex) { // That point isn't considered
                double d = firstPoint.distanceFrom(pointList.get(i));
                minDistSquared[i] = d*d;
            }
        }

        while (resultSet.size() < k) {

            // Sum up the squared distances for the points in pointList not
            // already taken.
            double distSqSum = 0.0;

            for (int i = 0; i < numPoints; i++) {
                if (!taken[i]) {
                    distSqSum += minDistSquared[i];
                }
            }

            // Add one new data point as a center. Each point x is chosen with
            // probability proportional to D(x)^2
            final double r = random.nextDouble() * distSqSum;

            // The index of the next point to be added to the resultSet.
            int nextPointIndex = -1;

            // Sum through the squared min distances again, stopping when
            // sum >= r.
            double sum = 0.0;
            for (int i = 0; i < numPoints; i++) {
                if (!taken[i]) {
                    sum += minDistSquared[i];
                    if (sum >= r) {
                        nextPointIndex = i;
                        break;
                    }
                }
            }

            // If it's not set to >= 0, the point wasn't found in the previous
            // for loop, probably because distances are extremely small. Just pick
            // the last available point.
            if (nextPointIndex == -1) {
                for (int i = numPoints - 1; i >= 0; i--) {
                    if (!taken[i]) {
                        nextPointIndex = i;
                        break;
                    }
                }
            }

            // We found one.
            if (nextPointIndex >= 0) {

                final T p = pointList.get(nextPointIndex);

                resultSet.add(new Cluster<T> (p));

                // Mark it as taken.
                taken[nextPointIndex] = true;

                if (resultSet.size() < k) {
                    // Now update elements of minDistSquared. We only have to compute
                    // the distance to the new center to do this.
                    for (int j = 0; j < numPoints; j++) {
                        // Only have to worry about the points still not taken.
                        if (!taken[j]) {
                            double d = p.distanceFrom(pointList.get(j));
                            double d2 = d * d;
                            if (d2 < minDistSquared[j]) {
                                minDistSquared[j] = d2;
                            }
                        }
                    }
                }

            } else {
                // None found --
                // Break from the while loop to prevent
                // an infinite loop.
                break;
            }
        }

        return resultSet;
    }

    /**
     * Get a random point from the {@link Cluster} with the largest distance variance.
     * <p>
     * Note: the chosen point is <em>removed</em> from its current cluster's
     * point list before being returned.
     *
     * @param clusters the {@link Cluster}s to search
     * @return a random point from the selected cluster
     * @throws ConvergenceException if clusters are all empty
     */
    private T getPointFromLargestVarianceCluster(final Collection<Cluster<T>> clusters)
        throws ConvergenceException {

        double maxVariance = Double.NEGATIVE_INFINITY;
        Cluster<T> selected = null;
        for (final Cluster<T> cluster : clusters) {
            if (!cluster.getPoints().isEmpty()) {

                // compute the distance variance of the current cluster
                final T center = cluster.getCenter();
                final Variance stat = new Variance();
                for (final T point : cluster.getPoints()) {
                    stat.increment(point.distanceFrom(center));
                }
                final double variance = stat.getResult();

                // select the cluster with the largest variance
                if (variance > maxVariance) {
                    maxVariance = variance;
                    selected = cluster;
                }

            }
        }

        // did we find at least one non-empty cluster ?
        if (selected == null) {
            throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
        }

        // extract a random point from the cluster
        final List<T> selectedPoints = selected.getPoints();
        return selectedPoints.remove(random.nextInt(selectedPoints.size()));

    }
+ if (selected == null) { + throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS); + } + + // extract a random point from the cluster + final List<T> selectedPoints = selected.getPoints(); + return selectedPoints.remove(random.nextInt(selectedPoints.size())); + + } + + /** + * Get a random point from the {@link Cluster} with the largest number of points + * + * @param clusters the {@link Cluster}s to search + * @return a random point from the selected cluster + * @throws ConvergenceException if clusters are all empty + */ + private T getPointFromLargestNumberCluster(final Collection<Cluster<T>> clusters) throws ConvergenceException { + + int maxNumber = 0; + Cluster<T> selected = null; + for (final Cluster<T> cluster : clusters) { + + // get the number of points of the current cluster + final int number = cluster.getPoints().size(); + + // select the cluster with the largest number of points + if (number > maxNumber) { + maxNumber = number; + selected = cluster; + } + + } + + // did we find at least one non-empty cluster ? 
+ if (selected == null) { + throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS); + } + + // extract a random point from the cluster + final List<T> selectedPoints = selected.getPoints(); + return selectedPoints.remove(random.nextInt(selectedPoints.size())); + + } + + /** + * Get the point farthest to its cluster center + * + * @param clusters the {@link Cluster}s to search + * @return point farthest to its cluster center + * @throws ConvergenceException if clusters are all empty + */ + private T getFarthestPoint(final Collection<Cluster<T>> clusters) throws ConvergenceException { + + double maxDistance = Double.NEGATIVE_INFINITY; + Cluster<T> selectedCluster = null; + int selectedPoint = -1; + for (final Cluster<T> cluster : clusters) { + + // get the farthest point + final T center = cluster.getCenter(); + final List<T> points = cluster.getPoints(); + for (int i = 0; i < points.size(); ++i) { + final double distance = points.get(i).distanceFrom(center); + if (distance > maxDistance) { + maxDistance = distance; + selectedCluster = cluster; + selectedPoint = i; + } + } + + } + + // did we find at least one non-empty cluster ? 
+ if (selectedCluster == null) { + throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS); + } + + return selectedCluster.getPoints().remove(selectedPoint); + + } + + /** + * Returns the nearest {@link Cluster} to the given point + * + * @param <T> type of the points to cluster + * @param clusters the {@link Cluster}s to search + * @param point the point to find the nearest {@link Cluster} for + * @return the index of the nearest {@link Cluster} to the given point + */ + private static <T extends Clusterable<T>> int + getNearestCluster(final Collection<Cluster<T>> clusters, final T point) { + double minDistance = Double.MAX_VALUE; + int clusterIndex = 0; + int minCluster = 0; + for (final Cluster<T> c : clusters) { + final double distance = point.distanceFrom(c.getCenter()); + if (distance < minDistance) { + minDistance = distance; + minCluster = clusterIndex; + } + clusterIndex++; + } + return minCluster; + } + +} diff --git a/src/main/java/org/apache/commons/math3/stat/clustering/package-info.java b/src/main/java/org/apache/commons/math3/stat/clustering/package-info.java new file mode 100644 index 0000000..f6b8d3e --- /dev/null +++ b/src/main/java/org/apache/commons/math3/stat/clustering/package-info.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +/** + * <h2>All classes and sub-packages of this package are deprecated.</h2> + * <h3>Please use their replacements, to be found under + * <ul> + * <li>{@link org.apache.commons.math3.ml.clustering}</li> + * </ul> + * </h3> + * + * <p> + * Clustering algorithms. + * </p> + */ +package org.apache.commons.math3.stat.clustering; diff --git a/src/main/java/org/apache/commons/math3/stat/correlation/Covariance.java b/src/main/java/org/apache/commons/math3/stat/correlation/Covariance.java new file mode 100644 index 0000000..c462401 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/stat/correlation/Covariance.java @@ -0,0 +1,295 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
public class Covariance {

    /** Covariance matrix, or {@code null} for an instance created with no data. */
    private final RealMatrix covarianceMatrix;

    /** Number of observations (length of covariate vectors). */
    private final int n;

    /**
     * Create a Covariance with no data.
     */
    public Covariance() {
        super();
        covarianceMatrix = null;
        n = 0;
    }

    /**
     * Create a Covariance matrix from a rectangular array
     * whose columns represent covariates.
     *
     * <p>The <code>biasCorrected</code> parameter determines whether or not
     * covariance estimates are bias-corrected.</p>
     *
     * <p>The input array must be rectangular with at least one column
     * and two rows.</p>
     *
     * @param data rectangular array with columns representing covariates
     * @param biasCorrected true means covariances are bias-corrected
     * @throws MathIllegalArgumentException if the input data array is not
     * rectangular with at least two rows and one column.
     * @throws NotStrictlyPositiveException if the input data array has no
     * rows or no columns.
     */
    public Covariance(double[][] data, boolean biasCorrected)
        throws MathIllegalArgumentException, NotStrictlyPositiveException {
        this(new BlockRealMatrix(data), biasCorrected);
    }

    /**
     * Create a Covariance matrix from a rectangular array
     * whose columns represent covariates. Covariances are bias-corrected.
     *
     * <p>The input array must be rectangular with at least one column
     * and two rows.</p>
     *
     * @param data rectangular array with columns representing covariates
     * @throws MathIllegalArgumentException if the input data array is not
     * rectangular with at least two rows and one column.
     * @throws NotStrictlyPositiveException if the input data array has no
     * rows or no columns.
     */
    public Covariance(double[][] data)
        throws MathIllegalArgumentException, NotStrictlyPositiveException {
        this(data, true);
    }
+ * + * <p>The <code>biasCorrected</code> parameter determines whether or not + * covariance estimates are bias-corrected.</p> + * + * <p>The matrix must have at least one column and two rows</p> + * + * @param matrix matrix with columns representing covariates + * @param biasCorrected true means covariances are bias-corrected + * @throws MathIllegalArgumentException if the input matrix does not have + * at least two rows and one column + */ + public Covariance(RealMatrix matrix, boolean biasCorrected) + throws MathIllegalArgumentException { + checkSufficientData(matrix); + n = matrix.getRowDimension(); + covarianceMatrix = computeCovarianceMatrix(matrix, biasCorrected); + } + + /** + * Create a covariance matrix from a matrix whose columns + * represent covariates. + * + * <p>The matrix must have at least one column and two rows</p> + * + * @param matrix matrix with columns representing covariates + * @throws MathIllegalArgumentException if the input matrix does not have + * at least two rows and one column + */ + public Covariance(RealMatrix matrix) throws MathIllegalArgumentException { + this(matrix, true); + } + + /** + * Returns the covariance matrix + * + * @return covariance matrix + */ + public RealMatrix getCovarianceMatrix() { + return covarianceMatrix; + } + + /** + * Returns the number of observations (length of covariate vectors) + * + * @return number of observations + */ + public int getN() { + return n; + } + + /** + * Compute a covariance matrix from a matrix whose columns represent + * covariates. 
+ * @param matrix input matrix (must have at least one column and two rows) + * @param biasCorrected determines whether or not covariance estimates are bias-corrected + * @return covariance matrix + * @throws MathIllegalArgumentException if the matrix does not contain sufficient data + */ + protected RealMatrix computeCovarianceMatrix(RealMatrix matrix, boolean biasCorrected) + throws MathIllegalArgumentException { + int dimension = matrix.getColumnDimension(); + Variance variance = new Variance(biasCorrected); + RealMatrix outMatrix = new BlockRealMatrix(dimension, dimension); + for (int i = 0; i < dimension; i++) { + for (int j = 0; j < i; j++) { + double cov = covariance(matrix.getColumn(i), matrix.getColumn(j), biasCorrected); + outMatrix.setEntry(i, j, cov); + outMatrix.setEntry(j, i, cov); + } + outMatrix.setEntry(i, i, variance.evaluate(matrix.getColumn(i))); + } + return outMatrix; + } + + /** + * Create a covariance matrix from a matrix whose columns represent + * covariates. Covariances are computed using the bias-corrected formula. + * @param matrix input matrix (must have at least one column and two rows) + * @return covariance matrix + * @throws MathIllegalArgumentException if matrix does not contain sufficient data + * @see #Covariance + */ + protected RealMatrix computeCovarianceMatrix(RealMatrix matrix) + throws MathIllegalArgumentException { + return computeCovarianceMatrix(matrix, true); + } + + /** + * Compute a covariance matrix from a rectangular array whose columns represent + * covariates. + * @param data input array (must have at least one column and two rows) + * @param biasCorrected determines whether or not covariance estimates are bias-corrected + * @return covariance matrix + * @throws MathIllegalArgumentException if the data array does not contain sufficient + * data + * @throws NotStrictlyPositiveException if the input data array is not + * rectangular with at least one row and one column. 
+ */ + protected RealMatrix computeCovarianceMatrix(double[][] data, boolean biasCorrected) + throws MathIllegalArgumentException, NotStrictlyPositiveException { + return computeCovarianceMatrix(new BlockRealMatrix(data), biasCorrected); + } + + /** + * Create a covariance matrix from a rectangular array whose columns represent + * covariates. Covariances are computed using the bias-corrected formula. + * @param data input array (must have at least one column and two rows) + * @return covariance matrix + * @throws MathIllegalArgumentException if the data array does not contain sufficient data + * @throws NotStrictlyPositiveException if the input data array is not + * rectangular with at least one row and one column. + * @see #Covariance + */ + protected RealMatrix computeCovarianceMatrix(double[][] data) + throws MathIllegalArgumentException, NotStrictlyPositiveException { + return computeCovarianceMatrix(data, true); + } + + /** + * Computes the covariance between the two arrays. + * + * <p>Array lengths must match and the common length must be at least 2.</p> + * + * @param xArray first data array + * @param yArray second data array + * @param biasCorrected if true, returned value will be bias-corrected + * @return returns the covariance for the two arrays + * @throws MathIllegalArgumentException if the arrays lengths do not match or + * there is insufficient data + */ + public double covariance(final double[] xArray, final double[] yArray, boolean biasCorrected) + throws MathIllegalArgumentException { + Mean mean = new Mean(); + double result = 0d; + int length = xArray.length; + if (length != yArray.length) { + throw new MathIllegalArgumentException( + LocalizedFormats.DIMENSIONS_MISMATCH_SIMPLE, length, yArray.length); + } else if (length < 2) { + throw new MathIllegalArgumentException( + LocalizedFormats.INSUFFICIENT_OBSERVED_POINTS_IN_SAMPLE, length, 2); + } else { + double xMean = mean.evaluate(xArray); + double yMean = mean.evaluate(yArray); + for (int i = 
0; i < length; i++) { + double xDev = xArray[i] - xMean; + double yDev = yArray[i] - yMean; + result += (xDev * yDev - result) / (i + 1); + } + } + return biasCorrected ? result * ((double) length / (double)(length - 1)) : result; + } + + /** + * Computes the covariance between the two arrays, using the bias-corrected + * formula. + * + * <p>Array lengths must match and the common length must be at least 2.</p> + * + * @param xArray first data array + * @param yArray second data array + * @return returns the covariance for the two arrays + * @throws MathIllegalArgumentException if the arrays lengths do not match or + * there is insufficient data + */ + public double covariance(final double[] xArray, final double[] yArray) + throws MathIllegalArgumentException { + return covariance(xArray, yArray, true); + } + + /** + * Throws MathIllegalArgumentException if the matrix does not have at least + * one column and two rows. + * @param matrix matrix to check + * @throws MathIllegalArgumentException if the matrix does not contain sufficient data + * to compute covariance + */ + private void checkSufficientData(final RealMatrix matrix) throws MathIllegalArgumentException { + int nRows = matrix.getRowDimension(); + int nCols = matrix.getColumnDimension(); + if (nRows < 2 || nCols < 1) { + throw new MathIllegalArgumentException( + LocalizedFormats.INSUFFICIENT_ROWS_AND_COLUMNS, + nRows, nCols); + } + } +} diff --git a/src/main/java/org/apache/commons/math3/stat/correlation/KendallsCorrelation.java b/src/main/java/org/apache/commons/math3/stat/correlation/KendallsCorrelation.java new file mode 100644 index 0000000..d38cf71 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/stat/correlation/KendallsCorrelation.java @@ -0,0 +1,272 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.math3.stat.correlation; + +import org.apache.commons.math3.exception.DimensionMismatchException; +import org.apache.commons.math3.linear.BlockRealMatrix; +import org.apache.commons.math3.linear.MatrixUtils; +import org.apache.commons.math3.linear.RealMatrix; +import org.apache.commons.math3.util.FastMath; +import org.apache.commons.math3.util.Pair; + +import java.util.Arrays; +import java.util.Comparator; + +/** + * Implementation of Kendall's Tau-b rank correlation</a>. + * <p> + * A pair of observations (x<sub>1</sub>, y<sub>1</sub>) and + * (x<sub>2</sub>, y<sub>2</sub>) are considered <i>concordant</i> if + * x<sub>1</sub> < x<sub>2</sub> and y<sub>1</sub> < y<sub>2</sub> + * or x<sub>2</sub> < x<sub>1</sub> and y<sub>2</sub> < y<sub>1</sub>. + * The pair is <i>discordant</i> if x<sub>1</sub> < x<sub>2</sub> and + * y<sub>2</sub> < y<sub>1</sub> or x<sub>2</sub> < x<sub>1</sub> and + * y<sub>1</sub> < y<sub>2</sub>. If either x<sub>1</sub> = x<sub>2</sub> + * or y<sub>1</sub> = y<sub>2</sub>, the pair is neither concordant nor + * discordant. 
+ * <p> + * Kendall's Tau-b is defined as: + * <pre> + * tau<sub>b</sub> = (n<sub>c</sub> - n<sub>d</sub>) / sqrt((n<sub>0</sub> - n<sub>1</sub>) * (n<sub>0</sub> - n<sub>2</sub>)) + * </pre> + * <p> + * where: + * <ul> + * <li>n<sub>0</sub> = n * (n - 1) / 2</li> + * <li>n<sub>c</sub> = Number of concordant pairs</li> + * <li>n<sub>d</sub> = Number of discordant pairs</li> + * <li>n<sub>1</sub> = sum of t<sub>i</sub> * (t<sub>i</sub> - 1) / 2 for all i</li> + * <li>n<sub>2</sub> = sum of u<sub>j</sub> * (u<sub>j</sub> - 1) / 2 for all j</li> + * <li>t<sub>i</sub> = Number of tied values in the i<sup>th</sup> group of ties in x</li> + * <li>u<sub>j</sub> = Number of tied values in the j<sup>th</sup> group of ties in y</li> + * </ul> + * <p> + * This implementation uses the O(n log n) algorithm described in + * William R. Knight's 1966 paper "A Computer Method for Calculating + * Kendall's Tau with Ungrouped Data" in the Journal of the American + * Statistical Association. + * + * @see <a href="http://en.wikipedia.org/wiki/Kendall_tau_rank_correlation_coefficient"> + * Kendall tau rank correlation coefficient (Wikipedia)</a> + * @see <a href="http://www.jstor.org/stable/2282833">A Computer + * Method for Calculating Kendall's Tau with Ungrouped Data</a> + * + * @since 3.3 + */ +public class KendallsCorrelation { + + /** correlation matrix */ + private final RealMatrix correlationMatrix; + + /** + * Create a KendallsCorrelation instance without data. + */ + public KendallsCorrelation() { + correlationMatrix = null; + } + + /** + * Create a KendallsCorrelation from a rectangular array + * whose columns represent values of variables to be correlated. + * + * @param data rectangular array with columns representing variables + * @throws IllegalArgumentException if the input data array is not + * rectangular with at least two rows and two columns. 
+ */ + public KendallsCorrelation(double[][] data) { + this(MatrixUtils.createRealMatrix(data)); + } + + /** + * Create a KendallsCorrelation from a RealMatrix whose columns + * represent variables to be correlated. + * + * @param matrix matrix with columns representing variables to correlate + */ + public KendallsCorrelation(RealMatrix matrix) { + correlationMatrix = computeCorrelationMatrix(matrix); + } + + /** + * Returns the correlation matrix. + * + * @return correlation matrix + */ + public RealMatrix getCorrelationMatrix() { + return correlationMatrix; + } + + /** + * Computes the Kendall's Tau rank correlation matrix for the columns of + * the input matrix. + * + * @param matrix matrix with columns representing variables to correlate + * @return correlation matrix + */ + public RealMatrix computeCorrelationMatrix(final RealMatrix matrix) { + int nVars = matrix.getColumnDimension(); + RealMatrix outMatrix = new BlockRealMatrix(nVars, nVars); + for (int i = 0; i < nVars; i++) { + for (int j = 0; j < i; j++) { + double corr = correlation(matrix.getColumn(i), matrix.getColumn(j)); + outMatrix.setEntry(i, j, corr); + outMatrix.setEntry(j, i, corr); + } + outMatrix.setEntry(i, i, 1d); + } + return outMatrix; + } + + /** + * Computes the Kendall's Tau rank correlation matrix for the columns of + * the input rectangular array. The columns of the array represent values + * of variables to be correlated. + * + * @param matrix matrix with columns representing variables to correlate + * @return correlation matrix + */ + public RealMatrix computeCorrelationMatrix(final double[][] matrix) { + return computeCorrelationMatrix(new BlockRealMatrix(matrix)); + } + + /** + * Computes the Kendall's Tau rank correlation coefficient between the two arrays. 
+ * + * @param xArray first data array + * @param yArray second data array + * @return Returns Kendall's Tau rank correlation coefficient for the two arrays + * @throws DimensionMismatchException if the arrays lengths do not match + */ + public double correlation(final double[] xArray, final double[] yArray) + throws DimensionMismatchException { + + if (xArray.length != yArray.length) { + throw new DimensionMismatchException(xArray.length, yArray.length); + } + + final int n = xArray.length; + final long numPairs = sum(n - 1); + + @SuppressWarnings("unchecked") + Pair<Double, Double>[] pairs = new Pair[n]; + for (int i = 0; i < n; i++) { + pairs[i] = new Pair<Double, Double>(xArray[i], yArray[i]); + } + + Arrays.sort(pairs, new Comparator<Pair<Double, Double>>() { + /** {@inheritDoc} */ + public int compare(Pair<Double, Double> pair1, Pair<Double, Double> pair2) { + int compareFirst = pair1.getFirst().compareTo(pair2.getFirst()); + return compareFirst != 0 ? compareFirst : pair1.getSecond().compareTo(pair2.getSecond()); + } + }); + + long tiedXPairs = 0; + long tiedXYPairs = 0; + long consecutiveXTies = 1; + long consecutiveXYTies = 1; + Pair<Double, Double> prev = pairs[0]; + for (int i = 1; i < n; i++) { + final Pair<Double, Double> curr = pairs[i]; + if (curr.getFirst().equals(prev.getFirst())) { + consecutiveXTies++; + if (curr.getSecond().equals(prev.getSecond())) { + consecutiveXYTies++; + } else { + tiedXYPairs += sum(consecutiveXYTies - 1); + consecutiveXYTies = 1; + } + } else { + tiedXPairs += sum(consecutiveXTies - 1); + consecutiveXTies = 1; + tiedXYPairs += sum(consecutiveXYTies - 1); + consecutiveXYTies = 1; + } + prev = curr; + } + tiedXPairs += sum(consecutiveXTies - 1); + tiedXYPairs += sum(consecutiveXYTies - 1); + + long swaps = 0; + @SuppressWarnings("unchecked") + Pair<Double, Double>[] pairsDestination = new Pair[n]; + for (int segmentSize = 1; segmentSize < n; segmentSize <<= 1) { + for (int offset = 0; offset < n; offset += 2 * segmentSize) 
{ + int i = offset; + final int iEnd = FastMath.min(i + segmentSize, n); + int j = iEnd; + final int jEnd = FastMath.min(j + segmentSize, n); + + int copyLocation = offset; + while (i < iEnd || j < jEnd) { + if (i < iEnd) { + if (j < jEnd) { + if (pairs[i].getSecond().compareTo(pairs[j].getSecond()) <= 0) { + pairsDestination[copyLocation] = pairs[i]; + i++; + } else { + pairsDestination[copyLocation] = pairs[j]; + j++; + swaps += iEnd - i; + } + } else { + pairsDestination[copyLocation] = pairs[i]; + i++; + } + } else { + pairsDestination[copyLocation] = pairs[j]; + j++; + } + copyLocation++; + } + } + final Pair<Double, Double>[] pairsTemp = pairs; + pairs = pairsDestination; + pairsDestination = pairsTemp; + } + + long tiedYPairs = 0; + long consecutiveYTies = 1; + prev = pairs[0]; + for (int i = 1; i < n; i++) { + final Pair<Double, Double> curr = pairs[i]; + if (curr.getSecond().equals(prev.getSecond())) { + consecutiveYTies++; + } else { + tiedYPairs += sum(consecutiveYTies - 1); + consecutiveYTies = 1; + } + prev = curr; + } + tiedYPairs += sum(consecutiveYTies - 1); + + final long concordantMinusDiscordant = numPairs - tiedXPairs - tiedYPairs + tiedXYPairs - 2 * swaps; + final double nonTiedPairsMultiplied = (numPairs - tiedXPairs) * (double) (numPairs - tiedYPairs); + return concordantMinusDiscordant / FastMath.sqrt(nonTiedPairsMultiplied); + } + + /** + * Returns the sum of the number from 1 .. 
n according to Gauss' summation formula: + * \[ \sum\limits_{k=1}^n k = \frac{n(n + 1)}{2} \] + * + * @param n the summation end + * @return the sum of the number from 1 to n + */ + private static long sum(long n) { + return n * (n + 1) / 2l; + } +} diff --git a/src/main/java/org/apache/commons/math3/stat/correlation/PearsonsCorrelation.java b/src/main/java/org/apache/commons/math3/stat/correlation/PearsonsCorrelation.java new file mode 100644 index 0000000..53d17ab --- /dev/null +++ b/src/main/java/org/apache/commons/math3/stat/correlation/PearsonsCorrelation.java @@ -0,0 +1,330 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.commons.math3.stat.correlation; + +import org.apache.commons.math3.distribution.TDistribution; +import org.apache.commons.math3.exception.util.LocalizedFormats; +import org.apache.commons.math3.exception.MathIllegalArgumentException; +import org.apache.commons.math3.exception.NullArgumentException; +import org.apache.commons.math3.exception.DimensionMismatchException; +import org.apache.commons.math3.linear.RealMatrix; +import org.apache.commons.math3.linear.BlockRealMatrix; +import org.apache.commons.math3.stat.regression.SimpleRegression; +import org.apache.commons.math3.util.FastMath; + +/** + * Computes Pearson's product-moment correlation coefficients for pairs of arrays + * or columns of a matrix. + * + * <p>The constructors that take <code>RealMatrix</code> or + * <code>double[][]</code> arguments generate correlation matrices. The + * columns of the input matrices are assumed to represent variable values. + * Correlations are given by the formula</p> + * + * <p><code>cor(X, Y) = Σ[(x<sub>i</sub> - E(X))(y<sub>i</sub> - E(Y))] / [(n - 1)s(X)s(Y)]</code> + * where <code>E(X)</code> is the mean of <code>X</code>, <code>E(Y)</code> + * is the mean of the <code>Y</code> values and s(X), s(Y) are standard deviations.</p> + * + * <p>To compute the correlation coefficient for a single pair of arrays, use {@link #PearsonsCorrelation()} + * to construct an instance with no data and then {@link #correlation(double[], double[])}. + * Correlation matrices can also be computed directly from an instance with no data using + * {@link #computeCorrelationMatrix(double[][])}. 
In order to use {@link #getCorrelationMatrix()}, + * {@link #getCorrelationPValues()}, or {@link #getCorrelationStandardErrors()}; however, one of the + * constructors supplying data or a covariance matrix must be used to create the instance.</p> + * + * @since 2.0 + */ +public class PearsonsCorrelation { + + /** correlation matrix */ + private final RealMatrix correlationMatrix; + + /** number of observations */ + private final int nObs; + + /** + * Create a PearsonsCorrelation instance without data. + */ + public PearsonsCorrelation() { + super(); + correlationMatrix = null; + nObs = 0; + } + + /** + * Create a PearsonsCorrelation from a rectangular array + * whose columns represent values of variables to be correlated. + * + * Throws MathIllegalArgumentException if the input array does not have at least + * two columns and two rows. Pairwise correlations are set to NaN if one + * of the correlates has zero variance. + * + * @param data rectangular array with columns representing variables + * @throws MathIllegalArgumentException if the input data array is not + * rectangular with at least two rows and two columns. + * @see #correlation(double[], double[]) + */ + public PearsonsCorrelation(double[][] data) { + this(new BlockRealMatrix(data)); + } + + /** + * Create a PearsonsCorrelation from a RealMatrix whose columns + * represent variables to be correlated. + * + * Throws MathIllegalArgumentException if the matrix does not have at least + * two columns and two rows. Pairwise correlations are set to NaN if one + * of the correlates has zero variance. 
+ * + * @param matrix matrix with columns representing variables to correlate + * @throws MathIllegalArgumentException if the matrix does not contain sufficient data + * @see #correlation(double[], double[]) + */ + public PearsonsCorrelation(RealMatrix matrix) { + nObs = matrix.getRowDimension(); + correlationMatrix = computeCorrelationMatrix(matrix); + } + + /** + * Create a PearsonsCorrelation from a {@link Covariance}. The correlation + * matrix is computed by scaling the Covariance's covariance matrix. + * The Covariance instance must have been created from a data matrix with + * columns representing variable values. + * + * @param covariance Covariance instance + */ + public PearsonsCorrelation(Covariance covariance) { + RealMatrix covarianceMatrix = covariance.getCovarianceMatrix(); + if (covarianceMatrix == null) { + throw new NullArgumentException(LocalizedFormats.COVARIANCE_MATRIX); + } + nObs = covariance.getN(); + correlationMatrix = covarianceToCorrelation(covarianceMatrix); + } + + /** + * Create a PearsonsCorrelation from a covariance matrix. The correlation + * matrix is computed by scaling the covariance matrix. + * + * @param covarianceMatrix covariance matrix + * @param numberOfObservations the number of observations in the dataset used to compute + * the covariance matrix + */ + public PearsonsCorrelation(RealMatrix covarianceMatrix, int numberOfObservations) { + nObs = numberOfObservations; + correlationMatrix = covarianceToCorrelation(covarianceMatrix); + } + + /** + * Returns the correlation matrix. 
+ * + * <p>This method will return null if the argumentless constructor was used + * to create this instance, even if {@link #computeCorrelationMatrix(double[][])} + * has been called before it is activated.</p> + * + * @return correlation matrix + */ + public RealMatrix getCorrelationMatrix() { + return correlationMatrix; + } + + /** + * Returns a matrix of standard errors associated with the estimates + * in the correlation matrix.<br/> + * <code>getCorrelationStandardErrors().getEntry(i,j)</code> is the standard + * error associated with <code>getCorrelationMatrix.getEntry(i,j)</code> + * + * <p>The formula used to compute the standard error is <br/> + * <code>SE<sub>r</sub> = ((1 - r<sup>2</sup>) / (n - 2))<sup>1/2</sup></code> + * where <code>r</code> is the estimated correlation coefficient and + * <code>n</code> is the number of observations in the source dataset.</p> + * + * <p>To use this method, one of the constructors that supply an input + * matrix must have been used to create this instance.</p> + * + * @return matrix of correlation standard errors + * @throws NullPointerException if this instance was created with no data + */ + public RealMatrix getCorrelationStandardErrors() { + int nVars = correlationMatrix.getColumnDimension(); + double[][] out = new double[nVars][nVars]; + for (int i = 0; i < nVars; i++) { + for (int j = 0; j < nVars; j++) { + double r = correlationMatrix.getEntry(i, j); + out[i][j] = FastMath.sqrt((1 - r * r) /(nObs - 2)); + } + } + return new BlockRealMatrix(out); + } + + /** + * Returns a matrix of p-values associated with the (two-sided) null + * hypothesis that the corresponding correlation coefficient is zero. 
+ * + * <p><code>getCorrelationPValues().getEntry(i,j)</code> is the probability + * that a random variable distributed as <code>t<sub>n-2</sub></code> takes + * a value with absolute value greater than or equal to <br> + * <code>|r|((n - 2) / (1 - r<sup>2</sup>))<sup>1/2</sup></code></p> + * + * <p>The values in the matrix are sometimes referred to as the + * <i>significance</i> of the corresponding correlation coefficients.</p> + * + * <p>To use this method, one of the constructors that supply an input + * matrix must have been used to create this instance.</p> + * + * @return matrix of p-values + * @throws org.apache.commons.math3.exception.MaxCountExceededException + * if an error occurs estimating probabilities + * @throws NullPointerException if this instance was created with no data + */ + public RealMatrix getCorrelationPValues() { + TDistribution tDistribution = new TDistribution(nObs - 2); + int nVars = correlationMatrix.getColumnDimension(); + double[][] out = new double[nVars][nVars]; + for (int i = 0; i < nVars; i++) { + for (int j = 0; j < nVars; j++) { + if (i == j) { + out[i][j] = 0d; + } else { + double r = correlationMatrix.getEntry(i, j); + double t = FastMath.abs(r * FastMath.sqrt((nObs - 2)/(1 - r * r))); + out[i][j] = 2 * tDistribution.cumulativeProbability(-t); + } + } + } + return new BlockRealMatrix(out); + } + + + /** + * Computes the correlation matrix for the columns of the + * input matrix, using {@link #correlation(double[], double[])}. + * + * Throws MathIllegalArgumentException if the matrix does not have at least + * two columns and two rows. Pairwise correlations are set to NaN if one + * of the correlates has zero variance. 
+ * + * @param matrix matrix with columns representing variables to correlate + * @return correlation matrix + * @throws MathIllegalArgumentException if the matrix does not contain sufficient data + * @see #correlation(double[], double[]) + */ + public RealMatrix computeCorrelationMatrix(RealMatrix matrix) { + checkSufficientData(matrix); + int nVars = matrix.getColumnDimension(); + RealMatrix outMatrix = new BlockRealMatrix(nVars, nVars); + for (int i = 0; i < nVars; i++) { + for (int j = 0; j < i; j++) { + double corr = correlation(matrix.getColumn(i), matrix.getColumn(j)); + outMatrix.setEntry(i, j, corr); + outMatrix.setEntry(j, i, corr); + } + outMatrix.setEntry(i, i, 1d); + } + return outMatrix; + } + + /** + * Computes the correlation matrix for the columns of the + * input rectangular array. The columns of the array represent values + * of variables to be correlated. + * + * Throws MathIllegalArgumentException if the matrix does not have at least + * two columns and two rows or if the array is not rectangular. Pairwise + * correlations are set to NaN if one of the correlates has zero variance. + * + * @param data matrix with columns representing variables to correlate + * @return correlation matrix + * @throws MathIllegalArgumentException if the array does not contain sufficient data + * @see #correlation(double[], double[]) + */ + public RealMatrix computeCorrelationMatrix(double[][] data) { + return computeCorrelationMatrix(new BlockRealMatrix(data)); + } + + /** + * Computes the Pearson's product-moment correlation coefficient between two arrays. + * + * <p>Throws MathIllegalArgumentException if the arrays do not have the same length + * or their common length is less than 2. 
Returns {@code NaN} if either of the arrays + * has zero variance (i.e., if one of the arrays does not contain at least two distinct + * values).</p> + * + * @param xArray first data array + * @param yArray second data array + * @return Returns Pearson's correlation coefficient for the two arrays + * @throws DimensionMismatchException if the arrays lengths do not match + * @throws MathIllegalArgumentException if there is insufficient data + */ + public double correlation(final double[] xArray, final double[] yArray) { + SimpleRegression regression = new SimpleRegression(); + if (xArray.length != yArray.length) { + throw new DimensionMismatchException(xArray.length, yArray.length); + } else if (xArray.length < 2) { + throw new MathIllegalArgumentException(LocalizedFormats.INSUFFICIENT_DIMENSION, + xArray.length, 2); + } else { + for(int i=0; i<xArray.length; i++) { + regression.addData(xArray[i], yArray[i]); + } + return regression.getR(); + } + } + + /** + * Derives a correlation matrix from a covariance matrix. + * + * <p>Uses the formula <br/> + * <code>r(X,Y) = cov(X,Y)/s(X)s(Y)</code> where + * <code>r(·,·)</code> is the correlation coefficient and + * <code>s(·)</code> means standard deviation.</p> + * + * @param covarianceMatrix the covariance matrix + * @return correlation matrix + */ + public RealMatrix covarianceToCorrelation(RealMatrix covarianceMatrix) { + int nVars = covarianceMatrix.getColumnDimension(); + RealMatrix outMatrix = new BlockRealMatrix(nVars, nVars); + for (int i = 0; i < nVars; i++) { + double sigma = FastMath.sqrt(covarianceMatrix.getEntry(i, i)); + outMatrix.setEntry(i, i, 1d); + for (int j = 0; j < i; j++) { + double entry = covarianceMatrix.getEntry(i, j) / + (sigma * FastMath.sqrt(covarianceMatrix.getEntry(j, j))); + outMatrix.setEntry(i, j, entry); + outMatrix.setEntry(j, i, entry); + } + } + return outMatrix; + } + + /** + * Throws MathIllegalArgumentException if the matrix does not have at least + * two columns and two rows. 
+ * + * @param matrix matrix to check for sufficiency + * @throws MathIllegalArgumentException if there is insufficient data + */ + private void checkSufficientData(final RealMatrix matrix) { + int nRows = matrix.getRowDimension(); + int nCols = matrix.getColumnDimension(); + if (nRows < 2 || nCols < 2) { + throw new MathIllegalArgumentException(LocalizedFormats.INSUFFICIENT_ROWS_AND_COLUMNS, + nRows, nCols); + } + } +} diff --git a/src/main/java/org/apache/commons/math3/stat/correlation/SpearmansCorrelation.java b/src/main/java/org/apache/commons/math3/stat/correlation/SpearmansCorrelation.java new file mode 100644 index 0000000..80c0a54 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/stat/correlation/SpearmansCorrelation.java @@ -0,0 +1,262 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.commons.math3.stat.correlation; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +import org.apache.commons.math3.exception.DimensionMismatchException; +import org.apache.commons.math3.exception.MathIllegalArgumentException; +import org.apache.commons.math3.exception.util.LocalizedFormats; +import org.apache.commons.math3.linear.BlockRealMatrix; +import org.apache.commons.math3.linear.RealMatrix; +import org.apache.commons.math3.stat.ranking.NaNStrategy; +import org.apache.commons.math3.stat.ranking.NaturalRanking; +import org.apache.commons.math3.stat.ranking.RankingAlgorithm; + +/** + * Spearman's rank correlation. This implementation performs a rank + * transformation on the input data and then computes {@link PearsonsCorrelation} + * on the ranked data. + * <p> + * By default, ranks are computed using {@link NaturalRanking} with default + * strategies for handling NaNs and ties in the data (NaNs maximal, ties averaged). + * The ranking algorithm can be set using a constructor argument. + * + * @since 2.0 + */ +public class SpearmansCorrelation { + + /** Input data */ + private final RealMatrix data; + + /** Ranking algorithm */ + private final RankingAlgorithm rankingAlgorithm; + + /** Rank correlation */ + private final PearsonsCorrelation rankCorrelation; + + /** + * Create a SpearmansCorrelation without data. + */ + public SpearmansCorrelation() { + this(new NaturalRanking()); + } + + /** + * Create a SpearmansCorrelation with the given ranking algorithm. + * <p> + * From version 4.0 onwards this constructor will throw an exception + * if the provided {@link NaturalRanking} uses a {@link NaNStrategy#REMOVED} strategy. 
+ * + * @param rankingAlgorithm ranking algorithm + * @since 3.1 + */ + public SpearmansCorrelation(final RankingAlgorithm rankingAlgorithm) { + data = null; + this.rankingAlgorithm = rankingAlgorithm; + rankCorrelation = null; + } + + /** + * Create a SpearmansCorrelation from the given data matrix. + * + * @param dataMatrix matrix of data with columns representing + * variables to correlate + */ + public SpearmansCorrelation(final RealMatrix dataMatrix) { + this(dataMatrix, new NaturalRanking()); + } + + /** + * Create a SpearmansCorrelation with the given input data matrix + * and ranking algorithm. + * <p> + * From version 4.0 onwards this constructor will throw an exception + * if the provided {@link NaturalRanking} uses a {@link NaNStrategy#REMOVED} strategy. + * + * @param dataMatrix matrix of data with columns representing + * variables to correlate + * @param rankingAlgorithm ranking algorithm + */ + public SpearmansCorrelation(final RealMatrix dataMatrix, final RankingAlgorithm rankingAlgorithm) { + this.rankingAlgorithm = rankingAlgorithm; + this.data = rankTransform(dataMatrix); + rankCorrelation = new PearsonsCorrelation(data); + } + + /** + * Calculate the Spearman Rank Correlation Matrix. + * + * @return Spearman Rank Correlation Matrix + * @throws NullPointerException if this instance was created with no data + */ + public RealMatrix getCorrelationMatrix() { + return rankCorrelation.getCorrelationMatrix(); + } + + /** + * Returns a {@link PearsonsCorrelation} instance constructed from the + * ranked input data. 
That is, + * <code>new SpearmansCorrelation(matrix).getRankCorrelation()</code> + * is equivalent to + * <code>new PearsonsCorrelation(rankTransform(matrix))</code> where + * <code>rankTransform(matrix)</code> is the result of applying the + * configured <code>RankingAlgorithm</code> to each of the columns of + * <code>matrix.</code> + * + * <p>Returns null if this instance was created with no data.</p> + * + * @return PearsonsCorrelation among ranked column data + */ + public PearsonsCorrelation getRankCorrelation() { + return rankCorrelation; + } + + /** + * Computes the Spearman's rank correlation matrix for the columns of the + * input matrix. + * + * @param matrix matrix with columns representing variables to correlate + * @return correlation matrix + */ + public RealMatrix computeCorrelationMatrix(final RealMatrix matrix) { + final RealMatrix matrixCopy = rankTransform(matrix); + return new PearsonsCorrelation().computeCorrelationMatrix(matrixCopy); + } + + /** + * Computes the Spearman's rank correlation matrix for the columns of the + * input rectangular array. The columns of the array represent values + * of variables to be correlated. + * + * @param matrix matrix with columns representing variables to correlate + * @return correlation matrix + */ + public RealMatrix computeCorrelationMatrix(final double[][] matrix) { + return computeCorrelationMatrix(new BlockRealMatrix(matrix)); + } + + /** + * Computes the Spearman's rank correlation coefficient between the two arrays. 
+ * + * @param xArray first data array + * @param yArray second data array + * @return Returns Spearman's rank correlation coefficient for the two arrays + * @throws DimensionMismatchException if the arrays lengths do not match + * @throws MathIllegalArgumentException if the array length is less than 2 + */ + public double correlation(final double[] xArray, final double[] yArray) { + if (xArray.length != yArray.length) { + throw new DimensionMismatchException(xArray.length, yArray.length); + } else if (xArray.length < 2) { + throw new MathIllegalArgumentException(LocalizedFormats.INSUFFICIENT_DIMENSION, + xArray.length, 2); + } else { + double[] x = xArray; + double[] y = yArray; + if (rankingAlgorithm instanceof NaturalRanking && + NaNStrategy.REMOVED == ((NaturalRanking) rankingAlgorithm).getNanStrategy()) { + final Set<Integer> nanPositions = new HashSet<Integer>(); + + nanPositions.addAll(getNaNPositions(xArray)); + nanPositions.addAll(getNaNPositions(yArray)); + + x = removeValues(xArray, nanPositions); + y = removeValues(yArray, nanPositions); + } + return new PearsonsCorrelation().correlation(rankingAlgorithm.rank(x), rankingAlgorithm.rank(y)); + } + } + + /** + * Applies rank transform to each of the columns of <code>matrix</code> + * using the current <code>rankingAlgorithm</code>. 
+ * + * @param matrix matrix to transform + * @return a rank-transformed matrix + */ + private RealMatrix rankTransform(final RealMatrix matrix) { + RealMatrix transformed = null; + + if (rankingAlgorithm instanceof NaturalRanking && + ((NaturalRanking) rankingAlgorithm).getNanStrategy() == NaNStrategy.REMOVED) { + final Set<Integer> nanPositions = new HashSet<Integer>(); + for (int i = 0; i < matrix.getColumnDimension(); i++) { + nanPositions.addAll(getNaNPositions(matrix.getColumn(i))); + } + + // if we have found NaN values, we have to update the matrix size + if (!nanPositions.isEmpty()) { + transformed = new BlockRealMatrix(matrix.getRowDimension() - nanPositions.size(), + matrix.getColumnDimension()); + for (int i = 0; i < transformed.getColumnDimension(); i++) { + transformed.setColumn(i, removeValues(matrix.getColumn(i), nanPositions)); + } + } + } + + if (transformed == null) { + transformed = matrix.copy(); + } + + for (int i = 0; i < transformed.getColumnDimension(); i++) { + transformed.setColumn(i, rankingAlgorithm.rank(transformed.getColumn(i))); + } + + return transformed; + } + + /** + * Returns a list containing the indices of NaN values in the input array. + * + * @param input the input array + * @return a list of NaN positions in the input array + */ + private List<Integer> getNaNPositions(final double[] input) { + final List<Integer> positions = new ArrayList<Integer>(); + for (int i = 0; i < input.length; i++) { + if (Double.isNaN(input[i])) { + positions.add(i); + } + } + return positions; + } + + /** + * Removes all values from the input array at the specified indices. 
+ * + * @param input the input array + * @param indices a set containing the indices to be removed + * @return the input array without the values at the specified indices + */ + private double[] removeValues(final double[] input, final Set<Integer> indices) { + if (indices.isEmpty()) { + return input; + } + final double[] result = new double[input.length - indices.size()]; + for (int i = 0, j = 0; i < input.length; i++) { + if (!indices.contains(i)) { + result[j++] = input[i]; + } + } + return result; + } +} diff --git a/src/main/java/org/apache/commons/math3/stat/correlation/StorelessBivariateCovariance.java b/src/main/java/org/apache/commons/math3/stat/correlation/StorelessBivariateCovariance.java new file mode 100644 index 0000000..1a798d2 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/stat/correlation/StorelessBivariateCovariance.java @@ -0,0 +1,138 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.math3.stat.correlation; + +import org.apache.commons.math3.exception.NumberIsTooSmallException; +import org.apache.commons.math3.exception.util.LocalizedFormats; + +/** + * Bivariate Covariance implementation that does not require input data to be + * stored in memory. 
+ * + * <p>This class is based on a paper written by Philippe Pébay: + * <a href="http://prod.sandia.gov/techlib/access-control.cgi/2008/086212.pdf"> + * Formulas for Robust, One-Pass Parallel Computation of Covariances and + * Arbitrary-Order Statistical Moments</a>, 2008, Technical Report SAND2008-6212, + * Sandia National Laboratories. It computes the covariance for a pair of variables. + * Use {@link StorelessCovariance} to estimate an entire covariance matrix.</p> + * + * <p>Note: This class is package private as it is only used internally in + * the {@link StorelessCovariance} class.</p> + * + * @since 3.0 + */ +class StorelessBivariateCovariance { + + /** the mean of variable x */ + private double meanX; + + /** the mean of variable y */ + private double meanY; + + /** number of observations */ + private double n; + + /** the running covariance estimate */ + private double covarianceNumerator; + + /** flag for bias correction */ + private boolean biasCorrected; + + /** + * Create an empty {@link StorelessBivariateCovariance} instance with + * bias correction. + */ + StorelessBivariateCovariance() { + this(true); + } + + /** + * Create an empty {@link StorelessBivariateCovariance} instance. + * + * @param biasCorrection if <code>true</code> the covariance estimate is corrected + * for bias, i.e. n-1 in the denominator, otherwise there is no bias correction, + * i.e. n in the denominator. + */ + StorelessBivariateCovariance(final boolean biasCorrection) { + meanX = meanY = 0.0; + n = 0; + covarianceNumerator = 0.0; + biasCorrected = biasCorrection; + } + + /** + * Update the covariance estimation with a pair of variables (x, y). 
+ * + * @param x the x value + * @param y the y value + */ + public void increment(final double x, final double y) { + n++; + final double deltaX = x - meanX; + final double deltaY = y - meanY; + meanX += deltaX / n; + meanY += deltaY / n; + covarianceNumerator += ((n - 1.0) / n) * deltaX * deltaY; + } + + /** + * Appends another bivariate covariance calculation to this. + * After this operation, statistics returned should be close to what would + * have been obtained by by performing all of the {@link #increment(double, double)} + * operations in {@code cov} directly on this. + * + * @param cov StorelessBivariateCovariance instance to append. + */ + public void append(StorelessBivariateCovariance cov) { + double oldN = n; + n += cov.n; + final double deltaX = cov.meanX - meanX; + final double deltaY = cov.meanY - meanY; + meanX += deltaX * cov.n / n; + meanY += deltaY * cov.n / n; + covarianceNumerator += cov.covarianceNumerator + oldN * cov.n / n * deltaX * deltaY; + } + + /** + * Returns the number of observations. + * + * @return number of observations + */ + public double getN() { + return n; + } + + /** + * Return the current covariance estimate. 
+ * + * @return the current covariance + * @throws NumberIsTooSmallException if the number of observations + * is < 2 + */ + public double getResult() throws NumberIsTooSmallException { + if (n < 2) { + throw new NumberIsTooSmallException(LocalizedFormats.INSUFFICIENT_DIMENSION, + n, 2, true); + } + if (biasCorrected) { + return covarianceNumerator / (n - 1d); + } else { + return covarianceNumerator / n; + } + } +} + diff --git a/src/main/java/org/apache/commons/math3/stat/correlation/StorelessCovariance.java b/src/main/java/org/apache/commons/math3/stat/correlation/StorelessCovariance.java new file mode 100644 index 0000000..7e927ca --- /dev/null +++ b/src/main/java/org/apache/commons/math3/stat/correlation/StorelessCovariance.java @@ -0,0 +1,229 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.commons.math3.stat.correlation; + +import org.apache.commons.math3.exception.DimensionMismatchException; +import org.apache.commons.math3.exception.MathUnsupportedOperationException; +import org.apache.commons.math3.exception.NumberIsTooSmallException; +import org.apache.commons.math3.linear.MatrixUtils; +import org.apache.commons.math3.linear.RealMatrix; + +/** + * Covariance implementation that does not require input data to be + * stored in memory. The size of the covariance matrix is specified in the + * constructor. Specific elements of the matrix are incrementally updated with + * calls to incrementRow() or increment Covariance(). + * + * <p>This class is based on a paper written by Philippe Pébay: + * <a href="http://prod.sandia.gov/techlib/access-control.cgi/2008/086212.pdf"> + * Formulas for Robust, One-Pass Parallel Computation of Covariances and + * Arbitrary-Order Statistical Moments</a>, 2008, Technical Report SAND2008-6212, + * Sandia National Laboratories.</p> + * + * <p>Note: the underlying covariance matrix is symmetric, thus only the + * upper triangular part of the matrix is stored and updated each increment.</p> + * + * @since 3.0 + */ +public class StorelessCovariance extends Covariance { + + /** the square covariance matrix (upper triangular part) */ + private StorelessBivariateCovariance[] covMatrix; + + /** dimension of the square covariance matrix */ + private int dimension; + + /** + * Create a bias corrected covariance matrix with a given dimension. + * + * @param dim the dimension of the square covariance matrix + */ + public StorelessCovariance(final int dim) { + this(dim, true); + } + + /** + * Create a covariance matrix with a given number of rows and columns and the + * indicated bias correction. + * + * @param dim the dimension of the covariance matrix + * @param biasCorrected if <code>true</code> the covariance estimate is corrected + * for bias, i.e. 
n-1 in the denominator, otherwise there is no bias correction, + * i.e. n in the denominator. + */ + public StorelessCovariance(final int dim, final boolean biasCorrected) { + dimension = dim; + covMatrix = new StorelessBivariateCovariance[dimension * (dimension + 1) / 2]; + initializeMatrix(biasCorrected); + } + + /** + * Initialize the internal two-dimensional array of + * {@link StorelessBivariateCovariance} instances. + * + * @param biasCorrected if the covariance estimate shall be corrected for bias + */ + private void initializeMatrix(final boolean biasCorrected) { + for(int i = 0; i < dimension; i++){ + for(int j = 0; j < dimension; j++){ + setElement(i, j, new StorelessBivariateCovariance(biasCorrected)); + } + } + } + + /** + * Returns the index (i, j) translated into the one-dimensional + * array used to store the upper triangular part of the symmetric + * covariance matrix. + * + * @param i the row index + * @param j the column index + * @return the corresponding index in the matrix array + */ + private int indexOf(final int i, final int j) { + return j < i ? i * (i + 1) / 2 + j : j * (j + 1) / 2 + i; + } + + /** + * Gets the element at index (i, j) from the covariance matrix + * @param i the row index + * @param j the column index + * @return the {@link StorelessBivariateCovariance} element at the given index + */ + private StorelessBivariateCovariance getElement(final int i, final int j) { + return covMatrix[indexOf(i, j)]; + } + + /** + * Sets the covariance element at index (i, j) in the covariance matrix + * @param i the row index + * @param j the column index + * @param cov the {@link StorelessBivariateCovariance} element to be set + */ + private void setElement(final int i, final int j, + final StorelessBivariateCovariance cov) { + covMatrix[indexOf(i, j)] = cov; + } + + /** + * Get the covariance for an individual element of the covariance matrix. 
+ * + * @param xIndex row index in the covariance matrix + * @param yIndex column index in the covariance matrix + * @return the covariance of the given element + * @throws NumberIsTooSmallException if the number of observations + * in the cell is < 2 + */ + public double getCovariance(final int xIndex, + final int yIndex) + throws NumberIsTooSmallException { + + return getElement(xIndex, yIndex).getResult(); + + } + + /** + * Increment the covariance matrix with one row of data. + * + * @param data array representing one row of data. + * @throws DimensionMismatchException if the length of <code>rowData</code> + * does not match with the covariance matrix + */ + public void increment(final double[] data) + throws DimensionMismatchException { + + int length = data.length; + if (length != dimension) { + throw new DimensionMismatchException(length, dimension); + } + + // only update the upper triangular part of the covariance matrix + // as only these parts are actually stored + for (int i = 0; i < length; i++){ + for (int j = i; j < length; j++){ + getElement(i, j).increment(data[i], data[j]); + } + } + + } + + /** + * Appends {@code sc} to this, effectively aggregating the computations in {@code sc} + * with this. After invoking this method, covariances returned should be close + * to what would have been obtained by performing all of the {@link #increment(double[])} + * operations in {@code sc} directly on this. 
+ * + * @param sc externally computed StorelessCovariance to add to this + * @throws DimensionMismatchException if the dimension of sc does not match this + * @since 3.3 + */ + public void append(StorelessCovariance sc) throws DimensionMismatchException { + if (sc.dimension != dimension) { + throw new DimensionMismatchException(sc.dimension, dimension); + } + + // only update the upper triangular part of the covariance matrix + // as only these parts are actually stored + for (int i = 0; i < dimension; i++) { + for (int j = i; j < dimension; j++) { + getElement(i, j).append(sc.getElement(i, j)); + } + } + } + + /** + * {@inheritDoc} + * @throws NumberIsTooSmallException if the number of observations + * in a cell is < 2 + */ + @Override + public RealMatrix getCovarianceMatrix() throws NumberIsTooSmallException { + return MatrixUtils.createRealMatrix(getData()); + } + + /** + * Return the covariance matrix as two-dimensional array. + * + * @return a two-dimensional double array of covariance values + * @throws NumberIsTooSmallException if the number of observations + * for a cell is < 2 + */ + public double[][] getData() throws NumberIsTooSmallException { + final double[][] data = new double[dimension][dimension]; + for (int i = 0; i < dimension; i++) { + for (int j = 0; j < dimension; j++) { + data[i][j] = getElement(i, j).getResult(); + } + } + return data; + } + + /** + * This {@link Covariance} method is not supported by a {@link StorelessCovariance}, + * since the number of bivariate observations does not have to be the same for different + * pairs of covariates - i.e., N as defined in {@link Covariance#getN()} is undefined. 
+ * + * @return nothing as this implementation always throws a + * {@link MathUnsupportedOperationException} + * @throws MathUnsupportedOperationException in all cases + */ + @Override + public int getN() + throws MathUnsupportedOperationException { + throw new MathUnsupportedOperationException(); + } +} diff --git a/src/main/java/org/apache/commons/math3/stat/correlation/package-info.java b/src/main/java/org/apache/commons/math3/stat/correlation/package-info.java new file mode 100644 index 0000000..adf285e --- /dev/null +++ b/src/main/java/org/apache/commons/math3/stat/correlation/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/** + * + * Correlations/Covariance computations. 
+ * + */ +package org.apache.commons.math3.stat.correlation; diff --git a/src/main/java/org/apache/commons/math3/stat/descriptive/AbstractStorelessUnivariateStatistic.java b/src/main/java/org/apache/commons/math3/stat/descriptive/AbstractStorelessUnivariateStatistic.java new file mode 100644 index 0000000..4249994 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/stat/descriptive/AbstractStorelessUnivariateStatistic.java @@ -0,0 +1,187 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.math3.stat.descriptive; + +import org.apache.commons.math3.exception.util.LocalizedFormats; +import org.apache.commons.math3.exception.MathIllegalArgumentException; +import org.apache.commons.math3.exception.NullArgumentException; +import org.apache.commons.math3.util.MathUtils; +import org.apache.commons.math3.util.Precision; + +/** + * + * Abstract implementation of the {@link StorelessUnivariateStatistic} interface. 
+ * <p> + * Provides default <code>evaluate()</code> and <code>incrementAll(double[])</code> + * implementations.</p> + * <p> + * <strong>Note that these implementations are not synchronized.</strong></p> + * + */ +public abstract class AbstractStorelessUnivariateStatistic + extends AbstractUnivariateStatistic + implements StorelessUnivariateStatistic { + + /** + * This default implementation calls {@link #clear}, then invokes + * {@link #increment} in a loop over the the input array, and then uses + * {@link #getResult} to compute the return value. + * <p> + * Note that this implementation changes the internal state of the + * statistic. Its side effects are the same as invoking {@link #clear} and + * then {@link #incrementAll(double[])}.</p> + * <p> + * Implementations may override this method with a more efficient and + * possibly more accurate implementation that works directly with the + * input array.</p> + * <p> + * If the array is null, a MathIllegalArgumentException is thrown.</p> + * @param values input array + * @return the value of the statistic applied to the input array + * @throws MathIllegalArgumentException if values is null + * @see org.apache.commons.math3.stat.descriptive.UnivariateStatistic#evaluate(double[]) + */ + @Override + public double evaluate(final double[] values) throws MathIllegalArgumentException { + if (values == null) { + throw new NullArgumentException(LocalizedFormats.INPUT_ARRAY); + } + return evaluate(values, 0, values.length); + } + + /** + * This default implementation calls {@link #clear}, then invokes + * {@link #increment} in a loop over the specified portion of the input + * array, and then uses {@link #getResult} to compute the return value. + * <p> + * Note that this implementation changes the internal state of the + * statistic. 
Its side effects are the same as invoking {@link #clear} and + * then {@link #incrementAll(double[], int, int)}.</p> + * <p> + * Implementations may override this method with a more efficient and + * possibly more accurate implementation that works directly with the + * input array.</p> + * <p> + * If the array is null or the index parameters are not valid, an + * MathIllegalArgumentException is thrown.</p> + * @param values the input array + * @param begin the index of the first element to include + * @param length the number of elements to include + * @return the value of the statistic applied to the included array entries + * @throws MathIllegalArgumentException if the array is null or the indices are not valid + * @see org.apache.commons.math3.stat.descriptive.UnivariateStatistic#evaluate(double[], int, int) + */ + @Override + public double evaluate(final double[] values, final int begin, + final int length) throws MathIllegalArgumentException { + if (test(values, begin, length)) { + clear(); + incrementAll(values, begin, length); + } + return getResult(); + } + + /** + * {@inheritDoc} + */ + @Override + public abstract StorelessUnivariateStatistic copy(); + + /** + * {@inheritDoc} + */ + public abstract void clear(); + + /** + * {@inheritDoc} + */ + public abstract double getResult(); + + /** + * {@inheritDoc} + */ + public abstract void increment(final double d); + + /** + * This default implementation just calls {@link #increment} in a loop over + * the input array. 
+ * <p> + * Throws IllegalArgumentException if the input values array is null.</p> + * + * @param values values to add + * @throws MathIllegalArgumentException if values is null + * @see org.apache.commons.math3.stat.descriptive.StorelessUnivariateStatistic#incrementAll(double[]) + */ + public void incrementAll(double[] values) throws MathIllegalArgumentException { + if (values == null) { + throw new NullArgumentException(LocalizedFormats.INPUT_ARRAY); + } + incrementAll(values, 0, values.length); + } + + /** + * This default implementation just calls {@link #increment} in a loop over + * the specified portion of the input array. + * <p> + * Throws IllegalArgumentException if the input values array is null.</p> + * + * @param values array holding values to add + * @param begin index of the first array element to add + * @param length number of array elements to add + * @throws MathIllegalArgumentException if values is null + * @see org.apache.commons.math3.stat.descriptive.StorelessUnivariateStatistic#incrementAll(double[], int, int) + */ + public void incrementAll(double[] values, int begin, int length) throws MathIllegalArgumentException { + if (test(values, begin, length)) { + int k = begin + length; + for (int i = begin; i < k; i++) { + increment(values[i]); + } + } + } + + /** + * Returns true iff <code>object</code> is an + * <code>AbstractStorelessUnivariateStatistic</code> returning the same + * values as this for <code>getResult()</code> and <code>getN()</code> + * @param object object to test equality against. 
+ * @return true if object returns the same value as this + */ + @Override + public boolean equals(Object object) { + if (object == this ) { + return true; + } + if (object instanceof AbstractStorelessUnivariateStatistic == false) { + return false; + } + AbstractStorelessUnivariateStatistic stat = (AbstractStorelessUnivariateStatistic) object; + return Precision.equalsIncludingNaN(stat.getResult(), this.getResult()) && + Precision.equalsIncludingNaN(stat.getN(), this.getN()); + } + + /** + * Returns hash code based on getResult() and getN() + * + * @return hash code + */ + @Override + public int hashCode() { + return 31* (31 + MathUtils.hash(getResult())) + MathUtils.hash(getN()); + } + +} diff --git a/src/main/java/org/apache/commons/math3/stat/descriptive/AbstractUnivariateStatistic.java b/src/main/java/org/apache/commons/math3/stat/descriptive/AbstractUnivariateStatistic.java new file mode 100644 index 0000000..9abe45a --- /dev/null +++ b/src/main/java/org/apache/commons/math3/stat/descriptive/AbstractUnivariateStatistic.java @@ -0,0 +1,263 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.commons.math3.stat.descriptive; + +import org.apache.commons.math3.exception.NotPositiveException; +import org.apache.commons.math3.exception.NullArgumentException; +import org.apache.commons.math3.exception.NumberIsTooLargeException; +import org.apache.commons.math3.exception.MathIllegalArgumentException; +import org.apache.commons.math3.exception.util.LocalizedFormats; +import org.apache.commons.math3.util.MathArrays; + +/** + * Abstract base class for all implementations of the + * {@link UnivariateStatistic} interface. + * <p> + * Provides a default implementation of <code>evaluate(double[]),</code> + * delegating to <code>evaluate(double[], int, int)</code> in the natural way. + * </p> + * <p> + * Also includes a <code>test</code> method that performs generic parameter + * validation for the <code>evaluate</code> methods.</p> + * + */ +public abstract class AbstractUnivariateStatistic + implements UnivariateStatistic { + + /** Stored data. */ + private double[] storedData; + + /** + * Set the data array. + * <p> + * The stored value is a copy of the parameter array, not the array itself. + * </p> + * @param values data array to store (may be null to remove stored data) + * @see #evaluate() + */ + public void setData(final double[] values) { + storedData = (values == null) ? null : values.clone(); + } + + /** + * Get a copy of the stored data array. + * @return copy of the stored data array (may be null) + */ + public double[] getData() { + return (storedData == null) ? null : storedData.clone(); + } + + /** + * Get a reference to the stored data array. + * @return reference to the stored data array (may be null) + */ + protected double[] getDataRef() { + return storedData; + } + + /** + * Set the data array. The input array is copied, not referenced. 
+ * + * @param values data array to store + * @param begin the index of the first element to include + * @param length the number of elements to include + * @throws MathIllegalArgumentException if values is null or the indices + * are not valid + * @see #evaluate() + */ + public void setData(final double[] values, final int begin, final int length) + throws MathIllegalArgumentException { + if (values == null) { + throw new NullArgumentException(LocalizedFormats.INPUT_ARRAY); + } + + if (begin < 0) { + throw new NotPositiveException(LocalizedFormats.START_POSITION, begin); + } + + if (length < 0) { + throw new NotPositiveException(LocalizedFormats.LENGTH, length); + } + + if (begin + length > values.length) { + throw new NumberIsTooLargeException(LocalizedFormats.SUBARRAY_ENDS_AFTER_ARRAY_END, + begin + length, values.length, true); + } + storedData = new double[length]; + System.arraycopy(values, begin, storedData, 0, length); + } + + /** + * Returns the result of evaluating the statistic over the stored data. + * <p> + * The stored array is the one which was set by previous calls to {@link #setData(double[])}. + * </p> + * @return the value of the statistic applied to the stored data + * @throws MathIllegalArgumentException if the stored data array is null + */ + public double evaluate() throws MathIllegalArgumentException { + return evaluate(storedData); + } + + /** + * {@inheritDoc} + */ + public double evaluate(final double[] values) throws MathIllegalArgumentException { + test(values, 0, 0); + return evaluate(values, 0, values.length); + } + + /** + * {@inheritDoc} + */ + public abstract double evaluate(final double[] values, final int begin, final int length) + throws MathIllegalArgumentException; + + /** + * {@inheritDoc} + */ + public abstract UnivariateStatistic copy(); + + /** + * This method is used by <code>evaluate(double[], int, int)</code> methods + * to verify that the input parameters designate a subarray of positive length. 
+ * <p> + * <ul> + * <li>returns <code>true</code> iff the parameters designate a subarray of + * positive length</li> + * <li>throws <code>MathIllegalArgumentException</code> if the array is null + * or the indices are invalid</li> + * <li>returns <code>false</code> if the array is non-null, but + * <code>length</code> is 0.</li> + * </ul></p> + * + * @param values the input array + * @param begin index of the first array element to include + * @param length the number of elements to include + * @return true if the parameters are valid and designate a subarray of positive length + * @throws MathIllegalArgumentException if the indices are invalid or the array is null + */ + protected boolean test( + final double[] values, + final int begin, + final int length) throws MathIllegalArgumentException { + return MathArrays.verifyValues(values, begin, length, false); + } + + /** + * This method is used by <code>evaluate(double[], int, int)</code> methods + * to verify that the input parameters designate a subarray of positive length. 
+ * <p> + * <ul> + * <li>returns <code>true</code> iff the parameters designate a subarray of + * non-negative length</li> + * <li>throws <code>IllegalArgumentException</code> if the array is null + * or the indices are invalid</li> + * <li>returns <code>false</code> if the array is non-null, but + * <code>length</code> is 0 unless <code>allowEmpty</code> is <code>true</code></li> + * </ul></p> + * + * @param values the input array + * @param begin index of the first array element to include + * @param length the number of elements to include + * @param allowEmpty if <code>true</code> then zero length arrays are allowed + * @return true if the parameters are valid + * @throws MathIllegalArgumentException if the indices are invalid or the array is null + * @since 3.0 + */ + protected boolean test(final double[] values, final int begin, + final int length, final boolean allowEmpty) throws MathIllegalArgumentException { + return MathArrays.verifyValues(values, begin, length, allowEmpty); + } + + /** + * This method is used by <code>evaluate(double[], double[], int, int)</code> methods + * to verify that the begin and length parameters designate a subarray of positive length + * and the weights are all non-negative, non-NaN, finite, and not all zero. 
+ * <p> + * <ul> + * <li>returns <code>true</code> iff the parameters designate a subarray of + * positive length and the weights array contains legitimate values.</li> + * <li>throws <code>IllegalArgumentException</code> if any of the following are true: + * <ul><li>the values array is null</li> + * <li>the weights array is null</li> + * <li>the weights array does not have the same length as the values array</li> + * <li>the weights array contains one or more infinite values</li> + * <li>the weights array contains one or more NaN values</li> + * <li>the weights array contains negative values</li> + * <li>the start and length arguments do not determine a valid array</li></ul> + * </li> + * <li>returns <code>false</code> if the array is non-null, but + * <code>length</code> is 0.</li> + * </ul></p> + * + * @param values the input array + * @param weights the weights array + * @param begin index of the first array element to include + * @param length the number of elements to include + * @return true if the parameters are valid and designate a subarray of positive length + * @throws MathIllegalArgumentException if the indices are invalid or the array is null + * @since 2.1 + */ + protected boolean test( + final double[] values, + final double[] weights, + final int begin, + final int length) throws MathIllegalArgumentException { + return MathArrays.verifyValues(values, weights, begin, length, false); + } + + /** + * This method is used by <code>evaluate(double[], double[], int, int)</code> methods + * to verify that the begin and length parameters designate a subarray of positive length + * and the weights are all non-negative, non-NaN, finite, and not all zero. 
+ * <p> + * <ul> + * <li>returns <code>true</code> iff the parameters designate a subarray of + * non-negative length and the weights array contains legitimate values.</li> + * <li>throws <code>MathIllegalArgumentException</code> if any of the following are true: + * <ul><li>the values array is null</li> + * <li>the weights array is null</li> + * <li>the weights array does not have the same length as the values array</li> + * <li>the weights array contains one or more infinite values</li> + * <li>the weights array contains one or more NaN values</li> + * <li>the weights array contains negative values</li> + * <li>the start and length arguments do not determine a valid array</li></ul> + * </li> + * <li>returns <code>false</code> if the array is non-null, but + * <code>length</code> is 0 unless <code>allowEmpty</code> is <code>true</code>.</li> + * </ul></p> + * + * @param values the input array. + * @param weights the weights array. + * @param begin index of the first array element to include. + * @param length the number of elements to include. + * @param allowEmpty if {@code true} then allow zero length arrays to pass. + * @return {@code true} if the parameters are valid. + * @throws NullArgumentException if either of the arrays are null + * @throws MathIllegalArgumentException if the array indices are not valid, + * the weights array contains NaN, infinite or negative elements, or there + * are no positive weights. 
+ * @since 3.0 + */ + protected boolean test(final double[] values, final double[] weights, + final int begin, final int length, final boolean allowEmpty) throws MathIllegalArgumentException { + + return MathArrays.verifyValues(values, weights, begin, length, allowEmpty); + } +} + diff --git a/src/main/java/org/apache/commons/math3/stat/descriptive/AggregateSummaryStatistics.java b/src/main/java/org/apache/commons/math3/stat/descriptive/AggregateSummaryStatistics.java new file mode 100644 index 0000000..6ab3c33 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/stat/descriptive/AggregateSummaryStatistics.java @@ -0,0 +1,422 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.commons.math3.stat.descriptive; + +import java.io.Serializable; +import java.util.Collection; +import java.util.Iterator; + +import org.apache.commons.math3.exception.NullArgumentException; + +/** + * <p> + * An aggregator for {@code SummaryStatistics} from several data sets or + * data set partitions. 
In its simplest usage mode, the client creates an + * instance via the zero-argument constructor, then uses + * {@link #createContributingStatistics()} to obtain a {@code SummaryStatistics} + * for each individual data set / partition. The per-set statistics objects + * are used as normal, and at any time the aggregate statistics for all the + * contributors can be obtained from this object. + * </p><p> + * Clients with specialized requirements can use alternative constructors to + * control the statistics implementations and initial values used by the + * contributing and the internal aggregate {@code SummaryStatistics} objects. + * </p><p> + * A static {@link #aggregate(Collection)} method is also included that computes + * aggregate statistics directly from a Collection of SummaryStatistics instances. + * </p><p> + * When {@link #createContributingStatistics()} is used to create SummaryStatistics + * instances to be aggregated concurrently, the created instances' + * {@link SummaryStatistics#addValue(double)} methods must synchronize on the aggregating + * instance maintained by this class. In multithreaded environments, if the functionality + * provided by {@link #aggregate(Collection)} is adequate, that method should be used + * to avoid unnecessary computation and synchronization delays.</p> + * + * @since 2.0 + * + */ +public class AggregateSummaryStatistics implements StatisticalSummary, + Serializable { + + + /** Serializable version identifier */ + private static final long serialVersionUID = -8207112444016386906L; + + /** + * A SummaryStatistics serving as a prototype for creating SummaryStatistics + * contributing to this aggregate + */ + private final SummaryStatistics statisticsPrototype; + + /** + * The SummaryStatistics in which aggregate statistics are accumulated. + */ + private final SummaryStatistics statistics; + + /** + * Initializes a new AggregateSummaryStatistics with default statistics + * implementations. 
+ * + */ + public AggregateSummaryStatistics() { + // No try-catch or throws NAE because arg is guaranteed non-null + this(new SummaryStatistics()); + } + + /** + * Initializes a new AggregateSummaryStatistics with the specified statistics + * object as a prototype for contributing statistics and for the internal + * aggregate statistics. This provides for customized statistics implementations + * to be used by contributing and aggregate statistics. + * + * @param prototypeStatistics a {@code SummaryStatistics} serving as a + * prototype both for the internal aggregate statistics and for + * contributing statistics obtained via the + * {@code createContributingStatistics()} method. Being a prototype + * means that other objects are initialized by copying this object's state. + * If {@code null}, a new, default statistics object is used. Any statistic + * values in the prototype are propagated to contributing statistics + * objects and (once) into these aggregate statistics. + * @throws NullArgumentException if prototypeStatistics is null + * @see #createContributingStatistics() + */ + public AggregateSummaryStatistics(SummaryStatistics prototypeStatistics) throws NullArgumentException { + this(prototypeStatistics, + prototypeStatistics == null ? null : new SummaryStatistics(prototypeStatistics)); + } + + /** + * Initializes a new AggregateSummaryStatistics with the specified statistics + * object as a prototype for contributing statistics and for the internal + * aggregate statistics. This provides for different statistics implementations + * to be used by contributing and aggregate statistics and for an initial + * state to be supplied for the aggregate statistics. + * + * @param prototypeStatistics a {@code SummaryStatistics} serving as a + * prototype both for the internal aggregate statistics and for + * contributing statistics obtained via the + * {@code createContributingStatistics()} method. 
Being a prototype + * means that other objects are initialized by copying this object's state. + * If {@code null}, a new, default statistics object is used. Any statistic + * values in the prototype are propagated to contributing statistics + * objects, but not into these aggregate statistics. + * @param initialStatistics a {@code SummaryStatistics} to serve as the + * internal aggregate statistics object. If {@code null}, a new, default + * statistics object is used. + * @see #createContributingStatistics() + */ + public AggregateSummaryStatistics(SummaryStatistics prototypeStatistics, + SummaryStatistics initialStatistics) { + this.statisticsPrototype = + (prototypeStatistics == null) ? new SummaryStatistics() : prototypeStatistics; + this.statistics = + (initialStatistics == null) ? new SummaryStatistics() : initialStatistics; + } + + /** + * {@inheritDoc}. This version returns the maximum over all the aggregated + * data. + * + * @see StatisticalSummary#getMax() + */ + public double getMax() { + synchronized (statistics) { + return statistics.getMax(); + } + } + + /** + * {@inheritDoc}. This version returns the mean of all the aggregated data. + * + * @see StatisticalSummary#getMean() + */ + public double getMean() { + synchronized (statistics) { + return statistics.getMean(); + } + } + + /** + * {@inheritDoc}. This version returns the minimum over all the aggregated + * data. + * + * @see StatisticalSummary#getMin() + */ + public double getMin() { + synchronized (statistics) { + return statistics.getMin(); + } + } + + /** + * {@inheritDoc}. This version returns a count of all the aggregated data. + * + * @see StatisticalSummary#getN() + */ + public long getN() { + synchronized (statistics) { + return statistics.getN(); + } + } + + /** + * {@inheritDoc}. This version returns the standard deviation of all the + * aggregated data. 
+ * + * @see StatisticalSummary#getStandardDeviation() + */ + public double getStandardDeviation() { + synchronized (statistics) { + return statistics.getStandardDeviation(); + } + } + + /** + * {@inheritDoc}. This version returns a sum of all the aggregated data. + * + * @see StatisticalSummary#getSum() + */ + public double getSum() { + synchronized (statistics) { + return statistics.getSum(); + } + } + + /** + * {@inheritDoc}. This version returns the variance of all the aggregated + * data. + * + * @see StatisticalSummary#getVariance() + */ + public double getVariance() { + synchronized (statistics) { + return statistics.getVariance(); + } + } + + /** + * Returns the sum of the logs of all the aggregated data. + * + * @return the sum of logs + * @see SummaryStatistics#getSumOfLogs() + */ + public double getSumOfLogs() { + synchronized (statistics) { + return statistics.getSumOfLogs(); + } + } + + /** + * Returns the geometric mean of all the aggregated data. + * + * @return the geometric mean + * @see SummaryStatistics#getGeometricMean() + */ + public double getGeometricMean() { + synchronized (statistics) { + return statistics.getGeometricMean(); + } + } + + /** + * Returns the sum of the squares of all the aggregated data. + * + * @return The sum of squares + * @see SummaryStatistics#getSumsq() + */ + public double getSumsq() { + synchronized (statistics) { + return statistics.getSumsq(); + } + } + + /** + * Returns a statistic related to the Second Central Moment. Specifically, + * what is returned is the sum of squared deviations from the sample mean + * among all of the aggregated data. + * + * @return second central moment statistic + * @see SummaryStatistics#getSecondMoment() + */ + public double getSecondMoment() { + synchronized (statistics) { + return statistics.getSecondMoment(); + } + } + + /** + * Return a {@link StatisticalSummaryValues} instance reporting current + * aggregate statistics. 
+ * + * @return Current values of aggregate statistics + */ + public StatisticalSummary getSummary() { + synchronized (statistics) { + return new StatisticalSummaryValues(getMean(), getVariance(), getN(), + getMax(), getMin(), getSum()); + } + } + + /** + * Creates and returns a {@code SummaryStatistics} whose data will be + * aggregated with those of this {@code AggregateSummaryStatistics}. + * + * @return a {@code SummaryStatistics} whose data will be aggregated with + * those of this {@code AggregateSummaryStatistics}. The initial state + * is a copy of the configured prototype statistics. + */ + public SummaryStatistics createContributingStatistics() { + SummaryStatistics contributingStatistics + = new AggregatingSummaryStatistics(statistics); + + // No try - catch or advertising NAE because neither argument will ever be null + SummaryStatistics.copy(statisticsPrototype, contributingStatistics); + + return contributingStatistics; + } + + /** + * Computes aggregate summary statistics. This method can be used to combine statistics + * computed over partitions or subsamples - i.e., the StatisticalSummaryValues returned + * should contain the same values that would have been obtained by computing a single + * StatisticalSummary over the combined dataset. + * <p> + * Returns null if the collection is empty or null. + * </p> + * + * @param statistics collection of SummaryStatistics to aggregate + * @return summary statistics for the combined dataset + */ + public static StatisticalSummaryValues aggregate(Collection<? extends StatisticalSummary> statistics) { + if (statistics == null) { + return null; + } + Iterator<? 
extends StatisticalSummary> iterator = statistics.iterator(); + if (!iterator.hasNext()) { + return null; + } + StatisticalSummary current = iterator.next(); + long n = current.getN(); + double min = current.getMin(); + double sum = current.getSum(); + double max = current.getMax(); + double var = current.getVariance(); + double m2 = var * (n - 1d); + double mean = current.getMean(); + while (iterator.hasNext()) { + current = iterator.next(); + if (current.getMin() < min || Double.isNaN(min)) { + min = current.getMin(); + } + if (current.getMax() > max || Double.isNaN(max)) { + max = current.getMax(); + } + sum += current.getSum(); + final double oldN = n; + final double curN = current.getN(); + n += curN; + final double meanDiff = current.getMean() - mean; + mean = sum / n; + final double curM2 = current.getVariance() * (curN - 1d); + m2 = m2 + curM2 + meanDiff * meanDiff * oldN * curN / n; + } + final double variance; + if (n == 0) { + variance = Double.NaN; + } else if (n == 1) { + variance = 0d; + } else { + variance = m2 / (n - 1); + } + return new StatisticalSummaryValues(mean, variance, n, max, min, sum); + } + + /** + * A SummaryStatistics that also forwards all values added to it to a second + * {@code SummaryStatistics} for aggregation. 
+ * + * @since 2.0 + */ + private static class AggregatingSummaryStatistics extends SummaryStatistics { + + /** + * The serialization version of this class + */ + private static final long serialVersionUID = 1L; + + /** + * An additional SummaryStatistics into which values added to these + * statistics (and possibly others) are aggregated + */ + private final SummaryStatistics aggregateStatistics; + + /** + * Initializes a new AggregatingSummaryStatistics with the specified + * aggregate statistics object + * + * @param aggregateStatistics a {@code SummaryStatistics} into which + * values added to this statistics object should be aggregated + */ + AggregatingSummaryStatistics(SummaryStatistics aggregateStatistics) { + this.aggregateStatistics = aggregateStatistics; + } + + /** + * {@inheritDoc}. This version adds the provided value to the configured + * aggregate after adding it to these statistics. + * + * @see SummaryStatistics#addValue(double) + */ + @Override + public void addValue(double value) { + super.addValue(value); + synchronized (aggregateStatistics) { + aggregateStatistics.addValue(value); + } + } + + /** + * Returns true iff <code>object</code> is a + * <code>SummaryStatistics</code> instance and all statistics have the + * same values as this. + * @param object the object to test equality against. 
+ * @return true if object equals this + */ + @Override + public boolean equals(Object object) { + if (object == this) { + return true; + } + if (object instanceof AggregatingSummaryStatistics == false) { + return false; + } + AggregatingSummaryStatistics stat = (AggregatingSummaryStatistics)object; + return super.equals(stat) && + aggregateStatistics.equals(stat.aggregateStatistics); + } + + /** + * Returns hash code based on values of statistics + * @return hash code + */ + @Override + public int hashCode() { + return 123 + super.hashCode() + aggregateStatistics.hashCode(); + } + } +} diff --git a/src/main/java/org/apache/commons/math3/stat/descriptive/DescriptiveStatistics.java b/src/main/java/org/apache/commons/math3/stat/descriptive/DescriptiveStatistics.java new file mode 100644 index 0000000..b215bc8 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/stat/descriptive/DescriptiveStatistics.java @@ -0,0 +1,777 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.commons.math3.stat.descriptive; + +import java.io.Serializable; +import java.lang.reflect.InvocationTargetException; +import java.util.Arrays; + +import org.apache.commons.math3.exception.MathIllegalArgumentException; +import org.apache.commons.math3.exception.NullArgumentException; +import org.apache.commons.math3.exception.MathIllegalStateException; +import org.apache.commons.math3.exception.util.LocalizedFormats; +import org.apache.commons.math3.stat.descriptive.moment.GeometricMean; +import org.apache.commons.math3.stat.descriptive.moment.Kurtosis; +import org.apache.commons.math3.stat.descriptive.moment.Mean; +import org.apache.commons.math3.stat.descriptive.moment.Skewness; +import org.apache.commons.math3.stat.descriptive.moment.Variance; +import org.apache.commons.math3.stat.descriptive.rank.Max; +import org.apache.commons.math3.stat.descriptive.rank.Min; +import org.apache.commons.math3.stat.descriptive.rank.Percentile; +import org.apache.commons.math3.stat.descriptive.summary.Sum; +import org.apache.commons.math3.stat.descriptive.summary.SumOfSquares; +import org.apache.commons.math3.util.MathUtils; +import org.apache.commons.math3.util.ResizableDoubleArray; +import org.apache.commons.math3.util.FastMath; + + +/** + * Maintains a dataset of values of a single variable and computes descriptive + * statistics based on stored data. The {@link #getWindowSize() windowSize} + * property sets a limit on the number of values that can be stored in the + * dataset. The default value, INFINITE_WINDOW, puts no limit on the size of + * the dataset. This value should be used with caution, as the backing store + * will grow without bound in this case. For very large datasets, + * {@link SummaryStatistics}, which does not store the dataset, should be used + * instead of this class. 
If <code>windowSize</code> is not INFINITE_WINDOW and + * more values are added than can be stored in the dataset, new values are + * added in a "rolling" manner, with new values replacing the "oldest" values + * in the dataset. + * + * <p>Note: this class is not threadsafe. Use + * {@link SynchronizedDescriptiveStatistics} if concurrent access from multiple + * threads is required.</p> + * + */ +public class DescriptiveStatistics implements StatisticalSummary, Serializable { + + /** + * Represents an infinite window size. When the {@link #getWindowSize()} + * returns this value, there is no limit to the number of data values + * that can be stored in the dataset. + */ + public static final int INFINITE_WINDOW = -1; + + /** Serialization UID */ + private static final long serialVersionUID = 4133067267405273064L; + + /** Name of the setQuantile method. */ + private static final String SET_QUANTILE_METHOD_NAME = "setQuantile"; + + /** hold the window size **/ + protected int windowSize = INFINITE_WINDOW; + + /** + * Stored data values + */ + private ResizableDoubleArray eDA = new ResizableDoubleArray(); + + /** Mean statistic implementation - can be reset by setter. */ + private UnivariateStatistic meanImpl = new Mean(); + + /** Geometric mean statistic implementation - can be reset by setter. */ + private UnivariateStatistic geometricMeanImpl = new GeometricMean(); + + /** Kurtosis statistic implementation - can be reset by setter. */ + private UnivariateStatistic kurtosisImpl = new Kurtosis(); + + /** Maximum statistic implementation - can be reset by setter. */ + private UnivariateStatistic maxImpl = new Max(); + + /** Minimum statistic implementation - can be reset by setter. */ + private UnivariateStatistic minImpl = new Min(); + + /** Percentile statistic implementation - can be reset by setter. */ + private UnivariateStatistic percentileImpl = new Percentile(); + + /** Skewness statistic implementation - can be reset by setter. 
*/ + private UnivariateStatistic skewnessImpl = new Skewness(); + + /** Variance statistic implementation - can be reset by setter. */ + private UnivariateStatistic varianceImpl = new Variance(); + + /** Sum of squares statistic implementation - can be reset by setter. */ + private UnivariateStatistic sumsqImpl = new SumOfSquares(); + + /** Sum statistic implementation - can be reset by setter. */ + private UnivariateStatistic sumImpl = new Sum(); + + /** + * Construct a DescriptiveStatistics instance with an infinite window + */ + public DescriptiveStatistics() { + } + + /** + * Construct a DescriptiveStatistics instance with the specified window + * + * @param window the window size. + * @throws MathIllegalArgumentException if window size is less than 1 but + * not equal to {@link #INFINITE_WINDOW} + */ + public DescriptiveStatistics(int window) throws MathIllegalArgumentException { + setWindowSize(window); + } + + /** + * Construct a DescriptiveStatistics instance with an infinite window + * and the initial data values in double[] initialDoubleArray. + * If initialDoubleArray is null, then this constructor corresponds to + * DescriptiveStatistics() + * + * @param initialDoubleArray the initial double[]. + */ + public DescriptiveStatistics(double[] initialDoubleArray) { + if (initialDoubleArray != null) { + eDA = new ResizableDoubleArray(initialDoubleArray); + } + } + + /** + * Copy constructor. Construct a new DescriptiveStatistics instance that + * is a copy of original. + * + * @param original DescriptiveStatistics instance to copy + * @throws NullArgumentException if original is null + */ + public DescriptiveStatistics(DescriptiveStatistics original) throws NullArgumentException { + copy(original, this); + } + + /** + * Adds the value to the dataset. 
If the dataset is at the maximum size + * (i.e., the number of stored elements equals the currently configured + * windowSize), the first (oldest) element in the dataset is discarded + * to make room for the new value. + * + * @param v the value to be added + */ + public void addValue(double v) { + if (windowSize != INFINITE_WINDOW) { + if (getN() == windowSize) { + eDA.addElementRolling(v); + } else if (getN() < windowSize) { + eDA.addElement(v); + } + } else { + eDA.addElement(v); + } + } + + /** + * Removes the most recent value from the dataset. + * + * @throws MathIllegalStateException if there are no elements stored + */ + public void removeMostRecentValue() throws MathIllegalStateException { + try { + eDA.discardMostRecentElements(1); + } catch (MathIllegalArgumentException ex) { + throw new MathIllegalStateException(LocalizedFormats.NO_DATA); + } + } + + /** + * Replaces the most recently stored value with the given value. + * There must be at least one element stored to call this method. + * + * @param v the value to replace the most recent stored value + * @return replaced value + * @throws MathIllegalStateException if there are no elements stored + */ + public double replaceMostRecentValue(double v) throws MathIllegalStateException { + return eDA.substituteMostRecentElement(v); + } + + /** + * Returns the <a href="http://www.xycoon.com/arithmetic_mean.htm"> + * arithmetic mean </a> of the available values + * @return The mean or Double.NaN if no values have been added. + */ + public double getMean() { + return apply(meanImpl); + } + + /** + * Returns the <a href="http://www.xycoon.com/geometric_mean.htm"> + * geometric mean </a> of the available values. + * <p> + * See {@link GeometricMean} for details on the computing algorithm.</p> + * + * @return The geometricMean, Double.NaN if no values have been added, + * or if any negative values have been added. 
+ */ + public double getGeometricMean() { + return apply(geometricMeanImpl); + } + + /** + * Returns the (sample) variance of the available values. + * + * <p>This method returns the bias-corrected sample variance (using {@code n - 1} in + * the denominator). Use {@link #getPopulationVariance()} for the non-bias-corrected + * population variance.</p> + * + * @return The variance, Double.NaN if no values have been added + * or 0.0 for a single value set. + */ + public double getVariance() { + return apply(varianceImpl); + } + + /** + * Returns the <a href="http://en.wikibooks.org/wiki/Statistics/Summary/Variance"> + * population variance</a> of the available values. + * + * @return The population variance, Double.NaN if no values have been added, + * or 0.0 for a single value set. + */ + public double getPopulationVariance() { + return apply(new Variance(false)); + } + + /** + * Returns the standard deviation of the available values. + * @return The standard deviation, Double.NaN if no values have been added + * or 0.0 for a single value set. + */ + public double getStandardDeviation() { + double stdDev = Double.NaN; + if (getN() > 0) { + if (getN() > 1) { + stdDev = FastMath.sqrt(getVariance()); + } else { + stdDev = 0.0; + } + } + return stdDev; + } + + /** + * Returns the quadratic mean, a.k.a. + * <a href="http://mathworld.wolfram.com/Root-Mean-Square.html"> + * root-mean-square</a> of the available values + * @return The quadratic mean or {@code Double.NaN} if no values + * have been added. + */ + public double getQuadraticMean() { + final long n = getN(); + return n > 0 ? FastMath.sqrt(getSumsq() / n) : Double.NaN; + } + + /** + * Returns the skewness of the available values. Skewness is a + * measure of the asymmetry of a given distribution. + * + * @return The skewness, Double.NaN if less than 3 values have been added. + */ + public double getSkewness() { + return apply(skewnessImpl); + } + + /** + * Returns the Kurtosis of the available values. 
Kurtosis is a + * measure of the "peakedness" of a distribution. + * + * @return The kurtosis, Double.NaN if less than 4 values have been added. + */ + public double getKurtosis() { + return apply(kurtosisImpl); + } + + /** + * Returns the maximum of the available values + * @return The max or Double.NaN if no values have been added. + */ + public double getMax() { + return apply(maxImpl); + } + + /** + * Returns the minimum of the available values + * @return The min or Double.NaN if no values have been added. + */ + public double getMin() { + return apply(minImpl); + } + + /** + * Returns the number of available values + * @return The number of available values + */ + public long getN() { + return eDA.getNumElements(); + } + + /** + * Returns the sum of the values that have been added to Univariate. + * @return The sum or Double.NaN if no values have been added + */ + public double getSum() { + return apply(sumImpl); + } + + /** + * Returns the sum of the squares of the available values. + * @return The sum of the squares or Double.NaN if no + * values have been added. + */ + public double getSumsq() { + return apply(sumsqImpl); + } + + /** + * Resets all statistics and storage + */ + public void clear() { + eDA.clear(); + } + + + /** + * Returns the maximum number of values that can be stored in the + * dataset, or INFINITE_WINDOW (-1) if there is no limit. + * + * @return The current window size or -1 if its Infinite. + */ + public int getWindowSize() { + return windowSize; + } + + /** + * WindowSize controls the number of values that contribute to the + * reported statistics. For example, if windowSize is set to 3 and the + * values {1,2,3,4,5} have been added <strong> in that order</strong> then + * the <i>available values</i> are {3,4,5} and all reported statistics will + * be based on these values. 
If {@code windowSize} is decreased as a result + * of this call and there are more than the new value of elements in the + * current dataset, values from the front of the array are discarded to + * reduce the dataset to {@code windowSize} elements. + * + * @param windowSize sets the size of the window. + * @throws MathIllegalArgumentException if window size is less than 1 but + * not equal to {@link #INFINITE_WINDOW} + */ + public void setWindowSize(int windowSize) throws MathIllegalArgumentException { + if (windowSize < 1 && windowSize != INFINITE_WINDOW) { + throw new MathIllegalArgumentException( + LocalizedFormats.NOT_POSITIVE_WINDOW_SIZE, windowSize); + } + + this.windowSize = windowSize; + + // We need to check to see if we need to discard elements + // from the front of the array. If the windowSize is less than + // the current number of elements. + if (windowSize != INFINITE_WINDOW && windowSize < eDA.getNumElements()) { + eDA.discardFrontElements(eDA.getNumElements() - windowSize); + } + } + + /** + * Returns the current set of values in an array of double primitives. + * The order of addition is preserved. The returned array is a fresh + * copy of the underlying data -- i.e., it is not a reference to the + * stored data. + * + * @return returns the current set of numbers in the order in which they + * were added to this set + */ + public double[] getValues() { + return eDA.getElements(); + } + + /** + * Returns the current set of values in an array of double primitives, + * sorted in ascending order. The returned array is a fresh + * copy of the underlying data -- i.e., it is not a reference to the + * stored data. 
+ * @return returns the current set of + * numbers sorted in ascending order + */ + public double[] getSortedValues() { + double[] sort = getValues(); + Arrays.sort(sort); + return sort; + } + + /** + * Returns the element at the specified index + * @param index The Index of the element + * @return return the element at the specified index + */ + public double getElement(int index) { + return eDA.getElement(index); + } + + /** + * Returns an estimate for the pth percentile of the stored values. + * <p> + * The implementation provided here follows the first estimation procedure presented + * <a href="http://www.itl.nist.gov/div898/handbook/prc/section2/prc252.htm">here.</a> + * </p><p> + * <strong>Preconditions</strong>:<ul> + * <li><code>0 < p ≤ 100</code> (otherwise an + * <code>MathIllegalArgumentException</code> is thrown)</li> + * <li>at least one value must be stored (returns <code>Double.NaN + * </code> otherwise)</li> + * </ul></p> + * + * @param p the requested percentile (scaled from 0 - 100) + * @return An estimate for the pth percentile of the stored data + * @throws MathIllegalStateException if percentile implementation has been + * overridden and the supplied implementation does not support setQuantile + * @throws MathIllegalArgumentException if p is not a valid quantile + */ + public double getPercentile(double p) throws MathIllegalStateException, MathIllegalArgumentException { + if (percentileImpl instanceof Percentile) { + ((Percentile) percentileImpl).setQuantile(p); + } else { + try { + percentileImpl.getClass().getMethod(SET_QUANTILE_METHOD_NAME, + new Class[] {Double.TYPE}).invoke(percentileImpl, + new Object[] {Double.valueOf(p)}); + } catch (NoSuchMethodException e1) { // Setter guard should prevent + throw new MathIllegalStateException( + LocalizedFormats.PERCENTILE_IMPLEMENTATION_UNSUPPORTED_METHOD, + percentileImpl.getClass().getName(), SET_QUANTILE_METHOD_NAME); + } catch (IllegalAccessException e2) { + throw new MathIllegalStateException( 
+ LocalizedFormats.PERCENTILE_IMPLEMENTATION_CANNOT_ACCESS_METHOD, + SET_QUANTILE_METHOD_NAME, percentileImpl.getClass().getName()); + } catch (InvocationTargetException e3) { + throw new IllegalStateException(e3.getCause()); + } + } + return apply(percentileImpl); + } + + /** + * Generates a text report displaying univariate statistics from values + * that have been added. Each statistic is displayed on a separate + * line. + * + * @return String with line feeds displaying statistics + */ + @Override + public String toString() { + StringBuilder outBuffer = new StringBuilder(); + String endl = "\n"; + outBuffer.append("DescriptiveStatistics:").append(endl); + outBuffer.append("n: ").append(getN()).append(endl); + outBuffer.append("min: ").append(getMin()).append(endl); + outBuffer.append("max: ").append(getMax()).append(endl); + outBuffer.append("mean: ").append(getMean()).append(endl); + outBuffer.append("std dev: ").append(getStandardDeviation()) + .append(endl); + try { + // No catch for MIAE because actual parameter is valid below + outBuffer.append("median: ").append(getPercentile(50)).append(endl); + } catch (MathIllegalStateException ex) { + outBuffer.append("median: unavailable").append(endl); + } + outBuffer.append("skewness: ").append(getSkewness()).append(endl); + outBuffer.append("kurtosis: ").append(getKurtosis()).append(endl); + return outBuffer.toString(); + } + + /** + * Apply the given statistic to the data associated with this set of statistics. + * @param stat the statistic to apply + * @return the computed value of the statistic. + */ + public double apply(UnivariateStatistic stat) { + // No try-catch or advertised exception here because arguments are guaranteed valid + return eDA.compute(stat); + } + + // Implementation getters and setter + + /** + * Returns the currently configured mean implementation. 
+ * + * @return the UnivariateStatistic implementing the mean + * @since 1.2 + */ + public synchronized UnivariateStatistic getMeanImpl() { + return meanImpl; + } + + /** + * <p>Sets the implementation for the mean.</p> + * + * @param meanImpl the UnivariateStatistic instance to use + * for computing the mean + * @since 1.2 + */ + public synchronized void setMeanImpl(UnivariateStatistic meanImpl) { + this.meanImpl = meanImpl; + } + + /** + * Returns the currently configured geometric mean implementation. + * + * @return the UnivariateStatistic implementing the geometric mean + * @since 1.2 + */ + public synchronized UnivariateStatistic getGeometricMeanImpl() { + return geometricMeanImpl; + } + + /** + * <p>Sets the implementation for the gemoetric mean.</p> + * + * @param geometricMeanImpl the UnivariateStatistic instance to use + * for computing the geometric mean + * @since 1.2 + */ + public synchronized void setGeometricMeanImpl( + UnivariateStatistic geometricMeanImpl) { + this.geometricMeanImpl = geometricMeanImpl; + } + + /** + * Returns the currently configured kurtosis implementation. + * + * @return the UnivariateStatistic implementing the kurtosis + * @since 1.2 + */ + public synchronized UnivariateStatistic getKurtosisImpl() { + return kurtosisImpl; + } + + /** + * <p>Sets the implementation for the kurtosis.</p> + * + * @param kurtosisImpl the UnivariateStatistic instance to use + * for computing the kurtosis + * @since 1.2 + */ + public synchronized void setKurtosisImpl(UnivariateStatistic kurtosisImpl) { + this.kurtosisImpl = kurtosisImpl; + } + + /** + * Returns the currently configured maximum implementation. 
+ * + * @return the UnivariateStatistic implementing the maximum + * @since 1.2 + */ + public synchronized UnivariateStatistic getMaxImpl() { + return maxImpl; + } + + /** + * <p>Sets the implementation for the maximum.</p> + * + * @param maxImpl the UnivariateStatistic instance to use + * for computing the maximum + * @since 1.2 + */ + public synchronized void setMaxImpl(UnivariateStatistic maxImpl) { + this.maxImpl = maxImpl; + } + + /** + * Returns the currently configured minimum implementation. + * + * @return the UnivariateStatistic implementing the minimum + * @since 1.2 + */ + public synchronized UnivariateStatistic getMinImpl() { + return minImpl; + } + + /** + * <p>Sets the implementation for the minimum.</p> + * + * @param minImpl the UnivariateStatistic instance to use + * for computing the minimum + * @since 1.2 + */ + public synchronized void setMinImpl(UnivariateStatistic minImpl) { + this.minImpl = minImpl; + } + + /** + * Returns the currently configured percentile implementation. + * + * @return the UnivariateStatistic implementing the percentile + * @since 1.2 + */ + public synchronized UnivariateStatistic getPercentileImpl() { + return percentileImpl; + } + + /** + * Sets the implementation to be used by {@link #getPercentile(double)}. + * The supplied <code>UnivariateStatistic</code> must provide a + * <code>setQuantile(double)</code> method; otherwise + * <code>IllegalArgumentException</code> is thrown. 
+ * + * @param percentileImpl the percentileImpl to set + * @throws MathIllegalArgumentException if the supplied implementation does not + * provide a <code>setQuantile</code> method + * @since 1.2 + */ + public synchronized void setPercentileImpl(UnivariateStatistic percentileImpl) + throws MathIllegalArgumentException { + try { + percentileImpl.getClass().getMethod(SET_QUANTILE_METHOD_NAME, + new Class[] {Double.TYPE}).invoke(percentileImpl, + new Object[] {Double.valueOf(50.0d)}); + } catch (NoSuchMethodException e1) { + throw new MathIllegalArgumentException( + LocalizedFormats.PERCENTILE_IMPLEMENTATION_UNSUPPORTED_METHOD, + percentileImpl.getClass().getName(), SET_QUANTILE_METHOD_NAME); + } catch (IllegalAccessException e2) { + throw new MathIllegalArgumentException( + LocalizedFormats.PERCENTILE_IMPLEMENTATION_CANNOT_ACCESS_METHOD, + SET_QUANTILE_METHOD_NAME, percentileImpl.getClass().getName()); + } catch (InvocationTargetException e3) { + throw new IllegalArgumentException(e3.getCause()); + } + this.percentileImpl = percentileImpl; + } + + /** + * Returns the currently configured skewness implementation. + * + * @return the UnivariateStatistic implementing the skewness + * @since 1.2 + */ + public synchronized UnivariateStatistic getSkewnessImpl() { + return skewnessImpl; + } + + /** + * <p>Sets the implementation for the skewness.</p> + * + * @param skewnessImpl the UnivariateStatistic instance to use + * for computing the skewness + * @since 1.2 + */ + public synchronized void setSkewnessImpl( + UnivariateStatistic skewnessImpl) { + this.skewnessImpl = skewnessImpl; + } + + /** + * Returns the currently configured variance implementation. 
+ * + * @return the UnivariateStatistic implementing the variance + * @since 1.2 + */ + public synchronized UnivariateStatistic getVarianceImpl() { + return varianceImpl; + } + + /** + * <p>Sets the implementation for the variance.</p> + * + * @param varianceImpl the UnivariateStatistic instance to use + * for computing the variance + * @since 1.2 + */ + public synchronized void setVarianceImpl( + UnivariateStatistic varianceImpl) { + this.varianceImpl = varianceImpl; + } + + /** + * Returns the currently configured sum of squares implementation. + * + * @return the UnivariateStatistic implementing the sum of squares + * @since 1.2 + */ + public synchronized UnivariateStatistic getSumsqImpl() { + return sumsqImpl; + } + + /** + * <p>Sets the implementation for the sum of squares.</p> + * + * @param sumsqImpl the UnivariateStatistic instance to use + * for computing the sum of squares + * @since 1.2 + */ + public synchronized void setSumsqImpl(UnivariateStatistic sumsqImpl) { + this.sumsqImpl = sumsqImpl; + } + + /** + * Returns the currently configured sum implementation. + * + * @return the UnivariateStatistic implementing the sum + * @since 1.2 + */ + public synchronized UnivariateStatistic getSumImpl() { + return sumImpl; + } + + /** + * <p>Sets the implementation for the sum.</p> + * + * @param sumImpl the UnivariateStatistic instance to use + * for computing the sum + * @since 1.2 + */ + public synchronized void setSumImpl(UnivariateStatistic sumImpl) { + this.sumImpl = sumImpl; + } + + /** + * Returns a copy of this DescriptiveStatistics instance with the same internal state. + * + * @return a copy of this + */ + public DescriptiveStatistics copy() { + DescriptiveStatistics result = new DescriptiveStatistics(); + // No try-catch or advertised exception because parms are guaranteed valid + copy(this, result); + return result; + } + + /** + * Copies source to dest. 
+ * <p>Neither source nor dest can be null.</p> + * + * @param source DescriptiveStatistics to copy + * @param dest DescriptiveStatistics to copy to + * @throws NullArgumentException if either source or dest is null + */ + public static void copy(DescriptiveStatistics source, DescriptiveStatistics dest) + throws NullArgumentException { + MathUtils.checkNotNull(source); + MathUtils.checkNotNull(dest); + // Copy data and window size + dest.eDA = source.eDA.copy(); + dest.windowSize = source.windowSize; + + // Copy implementations + dest.maxImpl = source.maxImpl.copy(); + dest.meanImpl = source.meanImpl.copy(); + dest.minImpl = source.minImpl.copy(); + dest.sumImpl = source.sumImpl.copy(); + dest.varianceImpl = source.varianceImpl.copy(); + dest.sumsqImpl = source.sumsqImpl.copy(); + dest.geometricMeanImpl = source.geometricMeanImpl.copy(); + dest.kurtosisImpl = source.kurtosisImpl; + dest.skewnessImpl = source.skewnessImpl; + dest.percentileImpl = source.percentileImpl; + } +} diff --git a/src/main/java/org/apache/commons/math3/stat/descriptive/MultivariateSummaryStatistics.java b/src/main/java/org/apache/commons/math3/stat/descriptive/MultivariateSummaryStatistics.java new file mode 100644 index 0000000..3ede26e --- /dev/null +++ b/src/main/java/org/apache/commons/math3/stat/descriptive/MultivariateSummaryStatistics.java @@ -0,0 +1,635 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.math3.stat.descriptive; + +import java.io.Serializable; +import java.util.Arrays; + +import org.apache.commons.math3.exception.util.LocalizedFormats; +import org.apache.commons.math3.exception.DimensionMismatchException; +import org.apache.commons.math3.exception.MathIllegalStateException; +import org.apache.commons.math3.linear.RealMatrix; +import org.apache.commons.math3.stat.descriptive.moment.GeometricMean; +import org.apache.commons.math3.stat.descriptive.moment.Mean; +import org.apache.commons.math3.stat.descriptive.moment.VectorialCovariance; +import org.apache.commons.math3.stat.descriptive.rank.Max; +import org.apache.commons.math3.stat.descriptive.rank.Min; +import org.apache.commons.math3.stat.descriptive.summary.Sum; +import org.apache.commons.math3.stat.descriptive.summary.SumOfLogs; +import org.apache.commons.math3.stat.descriptive.summary.SumOfSquares; +import org.apache.commons.math3.util.MathUtils; +import org.apache.commons.math3.util.MathArrays; +import org.apache.commons.math3.util.Precision; +import org.apache.commons.math3.util.FastMath; + +/** + * <p>Computes summary statistics for a stream of n-tuples added using the + * {@link #addValue(double[]) addValue} method. The data values are not stored + * in memory, so this class can be used to compute statistics for very large + * n-tuple streams.</p> + * + * <p>The {@link StorelessUnivariateStatistic} instances used to maintain + * summary state and compute statistics are configurable via setters. 
+ * For example, the default implementation for the mean can be overridden by + * calling {@link #setMeanImpl(StorelessUnivariateStatistic[])}. Actual + * parameters to these methods must implement the + * {@link StorelessUnivariateStatistic} interface and configuration must be + * completed before <code>addValue</code> is called. No configuration is + * necessary to use the default, commons-math provided implementations.</p> + * + * <p>To compute statistics for a stream of n-tuples, construct a + * MultivariateStatistics instance with dimension n and then use + * {@link #addValue(double[])} to add n-tuples. The <code>getXxx</code> + * methods where Xxx is a statistic return an array of <code>double</code> + * values, where for <code>i = 0,...,n-1</code> the i<sup>th</sup> array element is the + * value of the given statistic for data range consisting of the i<sup>th</sup> element of + * each of the input n-tuples. For example, if <code>addValue</code> is called + * with actual parameters {0, 1, 2}, then {3, 4, 5} and finally {6, 7, 8}, + * <code>getSum</code> will return a three-element array with values + * {0+3+6, 1+4+7, 2+5+8}</p> + * + * <p>Note: This class is not thread-safe. Use + * {@link SynchronizedMultivariateSummaryStatistics} if concurrent access from multiple + * threads is required.</p> + * + * @since 1.2 + */ +public class MultivariateSummaryStatistics + implements StatisticalMultivariateSummary, Serializable { + + /** Serialization UID */ + private static final long serialVersionUID = 2271900808994826718L; + + /** Dimension of the data. */ + private int k; + + /** Count of values that have been added */ + private long n = 0; + + /** Sum statistic implementation - can be reset by setter. */ + private StorelessUnivariateStatistic[] sumImpl; + + /** Sum of squares statistic implementation - can be reset by setter. */ + private StorelessUnivariateStatistic[] sumSqImpl; + + /** Minimum statistic implementation - can be reset by setter. 
*/ + private StorelessUnivariateStatistic[] minImpl; + + /** Maximum statistic implementation - can be reset by setter. */ + private StorelessUnivariateStatistic[] maxImpl; + + /** Sum of log statistic implementation - can be reset by setter. */ + private StorelessUnivariateStatistic[] sumLogImpl; + + /** Geometric mean statistic implementation - can be reset by setter. */ + private StorelessUnivariateStatistic[] geoMeanImpl; + + /** Mean statistic implementation - can be reset by setter. */ + private StorelessUnivariateStatistic[] meanImpl; + + /** Covariance statistic implementation - cannot be reset. */ + private VectorialCovariance covarianceImpl; + + /** + * Construct a MultivariateSummaryStatistics instance + * @param k dimension of the data + * @param isCovarianceBiasCorrected if true, the unbiased sample + * covariance is computed, otherwise the biased population covariance + * is computed + */ + public MultivariateSummaryStatistics(int k, boolean isCovarianceBiasCorrected) { + this.k = k; + + sumImpl = new StorelessUnivariateStatistic[k]; + sumSqImpl = new StorelessUnivariateStatistic[k]; + minImpl = new StorelessUnivariateStatistic[k]; + maxImpl = new StorelessUnivariateStatistic[k]; + sumLogImpl = new StorelessUnivariateStatistic[k]; + geoMeanImpl = new StorelessUnivariateStatistic[k]; + meanImpl = new StorelessUnivariateStatistic[k]; + + for (int i = 0; i < k; ++i) { + sumImpl[i] = new Sum(); + sumSqImpl[i] = new SumOfSquares(); + minImpl[i] = new Min(); + maxImpl[i] = new Max(); + sumLogImpl[i] = new SumOfLogs(); + geoMeanImpl[i] = new GeometricMean(); + meanImpl[i] = new Mean(); + } + + covarianceImpl = + new VectorialCovariance(k, isCovarianceBiasCorrected); + + } + + /** + * Add an n-tuple to the data + * + * @param value the n-tuple to add + * @throws DimensionMismatchException if the length of the array + * does not match the one used at construction + */ + public void addValue(double[] value) throws DimensionMismatchException { + 
checkDimension(value.length); + for (int i = 0; i < k; ++i) { + double v = value[i]; + sumImpl[i].increment(v); + sumSqImpl[i].increment(v); + minImpl[i].increment(v); + maxImpl[i].increment(v); + sumLogImpl[i].increment(v); + geoMeanImpl[i].increment(v); + meanImpl[i].increment(v); + } + covarianceImpl.increment(value); + n++; + } + + /** + * Returns the dimension of the data + * @return The dimension of the data + */ + public int getDimension() { + return k; + } + + /** + * Returns the number of available values + * @return The number of available values + */ + public long getN() { + return n; + } + + /** + * Returns an array of the results of a statistic. + * @param stats univariate statistic array + * @return results array + */ + private double[] getResults(StorelessUnivariateStatistic[] stats) { + double[] results = new double[stats.length]; + for (int i = 0; i < results.length; ++i) { + results[i] = stats[i].getResult(); + } + return results; + } + + /** + * Returns an array whose i<sup>th</sup> entry is the sum of the + * i<sup>th</sup> entries of the arrays that have been added using + * {@link #addValue(double[])} + * + * @return the array of component sums + */ + public double[] getSum() { + return getResults(sumImpl); + } + + /** + * Returns an array whose i<sup>th</sup> entry is the sum of squares of the + * i<sup>th</sup> entries of the arrays that have been added using + * {@link #addValue(double[])} + * + * @return the array of component sums of squares + */ + public double[] getSumSq() { + return getResults(sumSqImpl); + } + + /** + * Returns an array whose i<sup>th</sup> entry is the sum of logs of the + * i<sup>th</sup> entries of the arrays that have been added using + * {@link #addValue(double[])} + * + * @return the array of component log sums + */ + public double[] getSumLog() { + return getResults(sumLogImpl); + } + + /** + * Returns an array whose i<sup>th</sup> entry is the mean of the + * i<sup>th</sup> entries of the arrays that have been 
added using + * {@link #addValue(double[])} + * + * @return the array of component means + */ + public double[] getMean() { + return getResults(meanImpl); + } + + /** + * Returns an array whose i<sup>th</sup> entry is the standard deviation of the + * i<sup>th</sup> entries of the arrays that have been added using + * {@link #addValue(double[])} + * + * @return the array of component standard deviations + */ + public double[] getStandardDeviation() { + double[] stdDev = new double[k]; + if (getN() < 1) { + Arrays.fill(stdDev, Double.NaN); + } else if (getN() < 2) { + Arrays.fill(stdDev, 0.0); + } else { + RealMatrix matrix = covarianceImpl.getResult(); + for (int i = 0; i < k; ++i) { + stdDev[i] = FastMath.sqrt(matrix.getEntry(i, i)); + } + } + return stdDev; + } + + /** + * Returns the covariance matrix of the values that have been added. + * + * @return the covariance matrix + */ + public RealMatrix getCovariance() { + return covarianceImpl.getResult(); + } + + /** + * Returns an array whose i<sup>th</sup> entry is the maximum of the + * i<sup>th</sup> entries of the arrays that have been added using + * {@link #addValue(double[])} + * + * @return the array of component maxima + */ + public double[] getMax() { + return getResults(maxImpl); + } + + /** + * Returns an array whose i<sup>th</sup> entry is the minimum of the + * i<sup>th</sup> entries of the arrays that have been added using + * {@link #addValue(double[])} + * + * @return the array of component minima + */ + public double[] getMin() { + return getResults(minImpl); + } + + /** + * Returns an array whose i<sup>th</sup> entry is the geometric mean of the + * i<sup>th</sup> entries of the arrays that have been added using + * {@link #addValue(double[])} + * + * @return the array of component geometric means + */ + public double[] getGeometricMean() { + return getResults(geoMeanImpl); + } + + /** + * Generates a text report displaying + * summary statistics from values that + * have been added. 
+ * @return String with line feeds displaying statistics + */ + @Override + public String toString() { + final String separator = ", "; + final String suffix = System.getProperty("line.separator"); + StringBuilder outBuffer = new StringBuilder(); + outBuffer.append("MultivariateSummaryStatistics:" + suffix); + outBuffer.append("n: " + getN() + suffix); + append(outBuffer, getMin(), "min: ", separator, suffix); + append(outBuffer, getMax(), "max: ", separator, suffix); + append(outBuffer, getMean(), "mean: ", separator, suffix); + append(outBuffer, getGeometricMean(), "geometric mean: ", separator, suffix); + append(outBuffer, getSumSq(), "sum of squares: ", separator, suffix); + append(outBuffer, getSumLog(), "sum of logarithms: ", separator, suffix); + append(outBuffer, getStandardDeviation(), "standard deviation: ", separator, suffix); + outBuffer.append("covariance: " + getCovariance().toString() + suffix); + return outBuffer.toString(); + } + + /** + * Append a text representation of an array to a buffer. + * @param buffer buffer to fill + * @param data data array + * @param prefix text prefix + * @param separator elements separator + * @param suffix text suffix + */ + private void append(StringBuilder buffer, double[] data, + String prefix, String separator, String suffix) { + buffer.append(prefix); + for (int i = 0; i < data.length; ++i) { + if (i > 0) { + buffer.append(separator); + } + buffer.append(data[i]); + } + buffer.append(suffix); + } + + /** + * Resets all statistics and storage + */ + public void clear() { + this.n = 0; + for (int i = 0; i < k; ++i) { + minImpl[i].clear(); + maxImpl[i].clear(); + sumImpl[i].clear(); + sumLogImpl[i].clear(); + sumSqImpl[i].clear(); + geoMeanImpl[i].clear(); + meanImpl[i].clear(); + } + covarianceImpl.clear(); + } + + /** + * Returns true iff <code>object</code> is a <code>MultivariateSummaryStatistics</code> + * instance and all statistics have the same values as this. 
+ * @param object the object to test equality against. + * @return true if object equals this + */ + @Override + public boolean equals(Object object) { + if (object == this ) { + return true; + } + if (object instanceof MultivariateSummaryStatistics == false) { + return false; + } + MultivariateSummaryStatistics stat = (MultivariateSummaryStatistics) object; + return MathArrays.equalsIncludingNaN(stat.getGeometricMean(), getGeometricMean()) && + MathArrays.equalsIncludingNaN(stat.getMax(), getMax()) && + MathArrays.equalsIncludingNaN(stat.getMean(), getMean()) && + MathArrays.equalsIncludingNaN(stat.getMin(), getMin()) && + Precision.equalsIncludingNaN(stat.getN(), getN()) && + MathArrays.equalsIncludingNaN(stat.getSum(), getSum()) && + MathArrays.equalsIncludingNaN(stat.getSumSq(), getSumSq()) && + MathArrays.equalsIncludingNaN(stat.getSumLog(), getSumLog()) && + stat.getCovariance().equals( getCovariance()); + } + + /** + * Returns hash code based on values of statistics + * + * @return hash code + */ + @Override + public int hashCode() { + int result = 31 + MathUtils.hash(getGeometricMean()); + result = result * 31 + MathUtils.hash(getGeometricMean()); + result = result * 31 + MathUtils.hash(getMax()); + result = result * 31 + MathUtils.hash(getMean()); + result = result * 31 + MathUtils.hash(getMin()); + result = result * 31 + MathUtils.hash(getN()); + result = result * 31 + MathUtils.hash(getSum()); + result = result * 31 + MathUtils.hash(getSumSq()); + result = result * 31 + MathUtils.hash(getSumLog()); + result = result * 31 + getCovariance().hashCode(); + return result; + } + + // Getters and setters for statistics implementations + /** + * Sets statistics implementations. 
+ * @param newImpl new implementations for statistics + * @param oldImpl old implementations for statistics + * @throws DimensionMismatchException if the array dimension + * does not match the one used at construction + * @throws MathIllegalStateException if data has already been added + * (i.e. if n > 0) + */ + private void setImpl(StorelessUnivariateStatistic[] newImpl, + StorelessUnivariateStatistic[] oldImpl) throws MathIllegalStateException, + DimensionMismatchException { + checkEmpty(); + checkDimension(newImpl.length); + System.arraycopy(newImpl, 0, oldImpl, 0, newImpl.length); + } + + /** + * Returns the currently configured Sum implementation + * + * @return the StorelessUnivariateStatistic implementing the sum + */ + public StorelessUnivariateStatistic[] getSumImpl() { + return sumImpl.clone(); + } + + /** + * <p>Sets the implementation for the Sum.</p> + * <p>This method must be activated before any data has been added - i.e., + * before {@link #addValue(double[]) addValue} has been used to add data; + * otherwise an IllegalStateException will be thrown.</p> + * + * @param sumImpl the StorelessUnivariateStatistic instance to use + * for computing the Sum + * @throws DimensionMismatchException if the array dimension + * does not match the one used at construction + * @throws MathIllegalStateException if data has already been added + * (i.e if n > 0) + */ + public void setSumImpl(StorelessUnivariateStatistic[] sumImpl) + throws MathIllegalStateException, DimensionMismatchException { + setImpl(sumImpl, this.sumImpl); + } + + /** + * Returns the currently configured sum of squares implementation + * + * @return the StorelessUnivariateStatistic implementing the sum of squares + */ + public StorelessUnivariateStatistic[] getSumsqImpl() { + return sumSqImpl.clone(); + } + + /** + * <p>Sets the implementation for the sum of squares.</p> + * <p>This method must be activated before any data has been added - i.e., + * before {@link #addValue(double[]) addValue} has 
been used to add data; + * otherwise an IllegalStateException will be thrown.</p> + * + * @param sumsqImpl the StorelessUnivariateStatistic instance to use + * for computing the sum of squares + * @throws DimensionMismatchException if the array dimension + * does not match the one used at construction + * @throws MathIllegalStateException if data has already been added + * (i.e if n > 0) + */ + public void setSumsqImpl(StorelessUnivariateStatistic[] sumsqImpl) + throws MathIllegalStateException, DimensionMismatchException { + setImpl(sumsqImpl, this.sumSqImpl); + } + + /** + * Returns the currently configured minimum implementation + * + * @return the StorelessUnivariateStatistic implementing the minimum + */ + public StorelessUnivariateStatistic[] getMinImpl() { + return minImpl.clone(); + } + + /** + * <p>Sets the implementation for the minimum.</p> + * <p>This method must be activated before any data has been added - i.e., + * before {@link #addValue(double[]) addValue} has been used to add data; + * otherwise an IllegalStateException will be thrown.</p> + * + * @param minImpl the StorelessUnivariateStatistic instance to use + * for computing the minimum + * @throws DimensionMismatchException if the array dimension + * does not match the one used at construction + * @throws MathIllegalStateException if data has already been added + * (i.e if n > 0) + */ + public void setMinImpl(StorelessUnivariateStatistic[] minImpl) + throws MathIllegalStateException, DimensionMismatchException { + setImpl(minImpl, this.minImpl); + } + + /** + * Returns the currently configured maximum implementation + * + * @return the StorelessUnivariateStatistic implementing the maximum + */ + public StorelessUnivariateStatistic[] getMaxImpl() { + return maxImpl.clone(); + } + + /** + * <p>Sets the implementation for the maximum.</p> + * <p>This method must be activated before any data has been added - i.e., + * before {@link #addValue(double[]) addValue} has been used to add data; + * 
otherwise an IllegalStateException will be thrown.</p> + * + * @param maxImpl the StorelessUnivariateStatistic instance to use + * for computing the maximum + * @throws DimensionMismatchException if the array dimension + * does not match the one used at construction + * @throws MathIllegalStateException if data has already been added + * (i.e if n > 0) + */ + public void setMaxImpl(StorelessUnivariateStatistic[] maxImpl) + throws MathIllegalStateException, DimensionMismatchException{ + setImpl(maxImpl, this.maxImpl); + } + + /** + * Returns the currently configured sum of logs implementation + * + * @return the StorelessUnivariateStatistic implementing the log sum + */ + public StorelessUnivariateStatistic[] getSumLogImpl() { + return sumLogImpl.clone(); + } + + /** + * <p>Sets the implementation for the sum of logs.</p> + * <p>This method must be activated before any data has been added - i.e., + * before {@link #addValue(double[]) addValue} has been used to add data; + * otherwise an IllegalStateException will be thrown.</p> + * + * @param sumLogImpl the StorelessUnivariateStatistic instance to use + * for computing the log sum + * @throws DimensionMismatchException if the array dimension + * does not match the one used at construction + * @throws MathIllegalStateException if data has already been added + * (i.e if n > 0) + */ + public void setSumLogImpl(StorelessUnivariateStatistic[] sumLogImpl) + throws MathIllegalStateException, DimensionMismatchException{ + setImpl(sumLogImpl, this.sumLogImpl); + } + + /** + * Returns the currently configured geometric mean implementation + * + * @return the StorelessUnivariateStatistic implementing the geometric mean + */ + public StorelessUnivariateStatistic[] getGeoMeanImpl() { + return geoMeanImpl.clone(); + } + + /** + * <p>Sets the implementation for the geometric mean.</p> + * <p>This method must be activated before any data has been added - i.e., + * before {@link #addValue(double[]) addValue} has been used to add 
data; + * otherwise an IllegalStateException will be thrown.</p> + * + * @param geoMeanImpl the StorelessUnivariateStatistic instance to use + * for computing the geometric mean + * @throws DimensionMismatchException if the array dimension + * does not match the one used at construction + * @throws MathIllegalStateException if data has already been added + * (i.e if n > 0) + */ + public void setGeoMeanImpl(StorelessUnivariateStatistic[] geoMeanImpl) + throws MathIllegalStateException, DimensionMismatchException { + setImpl(geoMeanImpl, this.geoMeanImpl); + } + + /** + * Returns the currently configured mean implementation + * + * @return the StorelessUnivariateStatistic implementing the mean + */ + public StorelessUnivariateStatistic[] getMeanImpl() { + return meanImpl.clone(); + } + + /** + * <p>Sets the implementation for the mean.</p> + * <p>This method must be activated before any data has been added - i.e., + * before {@link #addValue(double[]) addValue} has been used to add data; + * otherwise an IllegalStateException will be thrown.</p> + * + * @param meanImpl the StorelessUnivariateStatistic instance to use + * for computing the mean + * @throws DimensionMismatchException if the array dimension + * does not match the one used at construction + * @throws MathIllegalStateException if data has already been added + * (i.e if n > 0) + */ + public void setMeanImpl(StorelessUnivariateStatistic[] meanImpl) + throws MathIllegalStateException, DimensionMismatchException{ + setImpl(meanImpl, this.meanImpl); + } + + /** + * Throws MathIllegalStateException if the statistic is not empty. + * @throws MathIllegalStateException if n > 0. + */ + private void checkEmpty() throws MathIllegalStateException { + if (n > 0) { + throw new MathIllegalStateException( + LocalizedFormats.VALUES_ADDED_BEFORE_CONFIGURING_STATISTIC, n); + } + } + + /** + * Throws DimensionMismatchException if dimension != k. 
+ * @param dimension dimension to check + * @throws DimensionMismatchException if dimension != k + */ + private void checkDimension(int dimension) throws DimensionMismatchException { + if (dimension != k) { + throw new DimensionMismatchException(dimension, k); + } + } +} diff --git a/src/main/java/org/apache/commons/math3/stat/descriptive/StatisticalMultivariateSummary.java b/src/main/java/org/apache/commons/math3/stat/descriptive/StatisticalMultivariateSummary.java new file mode 100644 index 0000000..bfe4deb --- /dev/null +++ b/src/main/java/org/apache/commons/math3/stat/descriptive/StatisticalMultivariateSummary.java @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.math3.stat.descriptive; + +import org.apache.commons.math3.linear.RealMatrix; + +/** + * Reporting interface for basic multivariate statistics. 
+ * + * @since 1.2 + */ +public interface StatisticalMultivariateSummary { + + /** + * Returns the dimension of the data + * @return The dimension of the data + */ + int getDimension(); + + /** + * Returns an array whose i<sup>th</sup> entry is the + * mean of the i<sup>th</sup> entries of the arrays + * that correspond to each multivariate sample + * + * @return the array of component means + */ + double[] getMean(); + + /** + * Returns the covariance of the available values. + * @return The covariance, null if no multivariate sample + * have been added or a zeroed matrix for a single value set. + */ + RealMatrix getCovariance(); + + /** + * Returns an array whose i<sup>th</sup> entry is the + * standard deviation of the i<sup>th</sup> entries of the arrays + * that correspond to each multivariate sample + * + * @return the array of component standard deviations + */ + double[] getStandardDeviation(); + + /** + * Returns an array whose i<sup>th</sup> entry is the + * maximum of the i<sup>th</sup> entries of the arrays + * that correspond to each multivariate sample + * + * @return the array of component maxima + */ + double[] getMax(); + + /** + * Returns an array whose i<sup>th</sup> entry is the + * minimum of the i<sup>th</sup> entries of the arrays + * that correspond to each multivariate sample + * + * @return the array of component minima + */ + double[] getMin(); + + /** + * Returns the number of available values + * @return The number of available values + */ + long getN(); + + /** + * Returns an array whose i<sup>th</sup> entry is the + * geometric mean of the i<sup>th</sup> entries of the arrays + * that correspond to each multivariate sample + * + * @return the array of component geometric means + */ + double[] getGeometricMean(); + + /** + * Returns an array whose i<sup>th</sup> entry is the + * sum of the i<sup>th</sup> entries of the arrays + * that correspond to each multivariate sample + * + * @return the array of component sums + */ + double[] 
getSum(); + + /** + * Returns an array whose i<sup>th</sup> entry is the + * sum of squares of the i<sup>th</sup> entries of the arrays + * that correspond to each multivariate sample + * + * @return the array of component sums of squares + */ + double[] getSumSq(); + + /** + * Returns an array whose i<sup>th</sup> entry is the + * sum of logs of the i<sup>th</sup> entries of the arrays + * that correspond to each multivariate sample + * + * @return the array of component log sums + */ + double[] getSumLog(); + +} diff --git a/src/main/java/org/apache/commons/math3/stat/descriptive/StatisticalSummary.java b/src/main/java/org/apache/commons/math3/stat/descriptive/StatisticalSummary.java new file mode 100644 index 0000000..2f310ac --- /dev/null +++ b/src/main/java/org/apache/commons/math3/stat/descriptive/StatisticalSummary.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.math3.stat.descriptive; + +/** + * Reporting interface for basic univariate statistics. 
+ * + */ +public interface StatisticalSummary { + + /** + * Returns the <a href="http://www.xycoon.com/arithmetic_mean.htm"> + * arithmetic mean </a> of the available values + * @return The mean or Double.NaN if no values have been added. + */ + double getMean(); + /** + * Returns the variance of the available values. + * @return The variance, Double.NaN if no values have been added + * or 0.0 for a single value set. + */ + double getVariance(); + /** + * Returns the standard deviation of the available values. + * @return The standard deviation, Double.NaN if no values have been added + * or 0.0 for a single value set. + */ + double getStandardDeviation(); + /** + * Returns the maximum of the available values + * @return The max or Double.NaN if no values have been added. + */ + double getMax(); + /** + * Returns the minimum of the available values + * @return The min or Double.NaN if no values have been added. + */ + double getMin(); + /** + * Returns the number of available values + * @return The number of available values + */ + long getN(); + /** + * Returns the sum of the values that have been added to Univariate. + * @return The sum or Double.NaN if no values have been added + */ + double getSum(); + +} diff --git a/src/main/java/org/apache/commons/math3/stat/descriptive/StatisticalSummaryValues.java b/src/main/java/org/apache/commons/math3/stat/descriptive/StatisticalSummaryValues.java new file mode 100644 index 0000000..e216e9b --- /dev/null +++ b/src/main/java/org/apache/commons/math3/stat/descriptive/StatisticalSummaryValues.java @@ -0,0 +1,186 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.math3.stat.descriptive; + +import java.io.Serializable; + +import org.apache.commons.math3.util.FastMath; +import org.apache.commons.math3.util.MathUtils; +import org.apache.commons.math3.util.Precision; + +/** + * Value object representing the results of a univariate statistical summary. + * + */ +public class StatisticalSummaryValues implements Serializable, + StatisticalSummary { + + /** Serialization id */ + private static final long serialVersionUID = -5108854841843722536L; + + /** The sample mean */ + private final double mean; + + /** The sample variance */ + private final double variance; + + /** The number of observations in the sample */ + private final long n; + + /** The maximum value */ + private final double max; + + /** The minimum value */ + private final double min; + + /** The sum of the sample values */ + private final double sum; + + /** + * Constructor + * + * @param mean the sample mean + * @param variance the sample variance + * @param n the number of observations in the sample + * @param max the maximum value + * @param min the minimum value + * @param sum the sum of the values + */ + public StatisticalSummaryValues(double mean, double variance, long n, + double max, double min, double sum) { + super(); + this.mean = mean; + this.variance = variance; + this.n = n; + this.max = max; + this.min = min; + this.sum = sum; + } + + /** + * @return Returns the max. + */ + public double getMax() { + return max; + } + + /** + * @return Returns the mean. 
+ */ + public double getMean() { + return mean; + } + + /** + * @return Returns the min. + */ + public double getMin() { + return min; + } + + /** + * @return Returns the number of values. + */ + public long getN() { + return n; + } + + /** + * @return Returns the sum. + */ + public double getSum() { + return sum; + } + + /** + * @return Returns the standard deviation + */ + public double getStandardDeviation() { + return FastMath.sqrt(variance); + } + + /** + * @return Returns the variance. + */ + public double getVariance() { + return variance; + } + + /** + * Returns true iff <code>object</code> is a + * <code>StatisticalSummaryValues</code> instance and all statistics have + * the same values as this. + * + * @param object the object to test equality against. + * @return true if object equals this + */ + @Override + public boolean equals(Object object) { + if (object == this ) { + return true; + } + if (object instanceof StatisticalSummaryValues == false) { + return false; + } + StatisticalSummaryValues stat = (StatisticalSummaryValues) object; + return Precision.equalsIncludingNaN(stat.getMax(), getMax()) && + Precision.equalsIncludingNaN(stat.getMean(), getMean()) && + Precision.equalsIncludingNaN(stat.getMin(), getMin()) && + Precision.equalsIncludingNaN(stat.getN(), getN()) && + Precision.equalsIncludingNaN(stat.getSum(), getSum()) && + Precision.equalsIncludingNaN(stat.getVariance(), getVariance()); + } + + /** + * Returns hash code based on values of statistics + * + * @return hash code + */ + @Override + public int hashCode() { + int result = 31 + MathUtils.hash(getMax()); + result = result * 31 + MathUtils.hash(getMean()); + result = result * 31 + MathUtils.hash(getMin()); + result = result * 31 + MathUtils.hash(getN()); + result = result * 31 + MathUtils.hash(getSum()); + result = result * 31 + MathUtils.hash(getVariance()); + return result; + } + + /** + * Generates a text report displaying values of statistics. 
+ * Each statistic is displayed on a separate line. + * + * @return String with line feeds displaying statistics + */ + @Override + public String toString() { + StringBuffer outBuffer = new StringBuffer(); + String endl = "\n"; + outBuffer.append("StatisticalSummaryValues:").append(endl); + outBuffer.append("n: ").append(getN()).append(endl); + outBuffer.append("min: ").append(getMin()).append(endl); + outBuffer.append("max: ").append(getMax()).append(endl); + outBuffer.append("mean: ").append(getMean()).append(endl); + outBuffer.append("std dev: ").append(getStandardDeviation()) + .append(endl); + outBuffer.append("variance: ").append(getVariance()).append(endl); + outBuffer.append("sum: ").append(getSum()).append(endl); + return outBuffer.toString(); + } + +} diff --git a/src/main/java/org/apache/commons/math3/stat/descriptive/StorelessUnivariateStatistic.java b/src/main/java/org/apache/commons/math3/stat/descriptive/StorelessUnivariateStatistic.java new file mode 100644 index 0000000..e1c2464 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/stat/descriptive/StorelessUnivariateStatistic.java @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.commons.math3.stat.descriptive; + +import org.apache.commons.math3.exception.MathIllegalArgumentException; + +/** + * Extends the definition of {@link UnivariateStatistic} with + * {@link #increment} and {@link #incrementAll(double[])} methods for adding + * values and updating internal state. + * <p> + * This interface is designed to be used for calculating statistics that can be + * computed in one pass through the data without storing the full array of + * sample values.</p> + * + */ +public interface StorelessUnivariateStatistic extends UnivariateStatistic { + + /** + * Updates the internal state of the statistic to reflect the addition of the new value. + * @param d the new value. + */ + void increment(double d); + + /** + * Updates the internal state of the statistic to reflect addition of + * all values in the values array. Does not clear the statistic first -- + * i.e., the values are added <strong>incrementally</strong> to the dataset. + * + * @param values array holding the new values to add + * @throws MathIllegalArgumentException if the array is null + */ + void incrementAll(double[] values) throws MathIllegalArgumentException; + + /** + * Updates the internal state of the statistic to reflect addition of + * the values in the designated portion of the values array. Does not + * clear the statistic first -- i.e., the values are added + * <strong>incrementally</strong> to the dataset. + * + * @param values array holding the new values to add + * @param start the array index of the first value to add + * @param length the number of elements to add + * @throws MathIllegalArgumentException if the array is null or the index + */ + void incrementAll(double[] values, int start, int length) throws MathIllegalArgumentException; + + /** + * Returns the current value of the Statistic. + * @return value of the statistic, <code>Double.NaN</code> if it + * has been cleared or just instantiated. 
+ */ + double getResult(); + + /** + * Returns the number of values that have been added. + * @return the number of values. + */ + long getN(); + + /** + * Clears the internal state of the Statistic + */ + void clear(); + + /** + * Returns a copy of the statistic with the same internal state. + * + * @return a copy of the statistic + */ + StorelessUnivariateStatistic copy(); + +} diff --git a/src/main/java/org/apache/commons/math3/stat/descriptive/SummaryStatistics.java b/src/main/java/org/apache/commons/math3/stat/descriptive/SummaryStatistics.java new file mode 100644 index 0000000..62fee80 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/stat/descriptive/SummaryStatistics.java @@ -0,0 +1,765 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.commons.math3.stat.descriptive; + +import java.io.Serializable; + +import org.apache.commons.math3.exception.MathIllegalStateException; +import org.apache.commons.math3.exception.NullArgumentException; +import org.apache.commons.math3.exception.util.LocalizedFormats; +import org.apache.commons.math3.stat.descriptive.moment.GeometricMean; +import org.apache.commons.math3.stat.descriptive.moment.Mean; +import org.apache.commons.math3.stat.descriptive.moment.SecondMoment; +import org.apache.commons.math3.stat.descriptive.moment.Variance; +import org.apache.commons.math3.stat.descriptive.rank.Max; +import org.apache.commons.math3.stat.descriptive.rank.Min; +import org.apache.commons.math3.stat.descriptive.summary.Sum; +import org.apache.commons.math3.stat.descriptive.summary.SumOfLogs; +import org.apache.commons.math3.stat.descriptive.summary.SumOfSquares; +import org.apache.commons.math3.util.MathUtils; +import org.apache.commons.math3.util.Precision; +import org.apache.commons.math3.util.FastMath; + +/** + * <p> + * Computes summary statistics for a stream of data values added using the + * {@link #addValue(double) addValue} method. The data values are not stored in + * memory, so this class can be used to compute statistics for very large data + * streams. + * </p> + * <p> + * The {@link StorelessUnivariateStatistic} instances used to maintain summary + * state and compute statistics are configurable via setters. For example, the + * default implementation for the variance can be overridden by calling + * {@link #setVarianceImpl(StorelessUnivariateStatistic)}. Actual parameters to + * these methods must implement the {@link StorelessUnivariateStatistic} + * interface and configuration must be completed before <code>addValue</code> + * is called. No configuration is necessary to use the default, commons-math + * provided implementations. + * </p> + * <p> + * Note: This class is not thread-safe. 
Use + * {@link SynchronizedSummaryStatistics} if concurrent access from multiple + * threads is required. + * </p> + */ +public class SummaryStatistics implements StatisticalSummary, Serializable { + + /** Serialization UID */ + private static final long serialVersionUID = -2021321786743555871L; + + /** count of values that have been added */ + private long n = 0; + + /** SecondMoment is used to compute the mean and variance */ + private SecondMoment secondMoment = new SecondMoment(); + + /** sum of values that have been added */ + private Sum sum = new Sum(); + + /** sum of the square of each value that has been added */ + private SumOfSquares sumsq = new SumOfSquares(); + + /** min of values that have been added */ + private Min min = new Min(); + + /** max of values that have been added */ + private Max max = new Max(); + + /** sumLog of values that have been added */ + private SumOfLogs sumLog = new SumOfLogs(); + + /** geoMean of values that have been added */ + private GeometricMean geoMean = new GeometricMean(sumLog); + + /** mean of values that have been added */ + private Mean mean = new Mean(secondMoment); + + /** variance of values that have been added */ + private Variance variance = new Variance(secondMoment); + + /** Sum statistic implementation - can be reset by setter. */ + private StorelessUnivariateStatistic sumImpl = sum; + + /** Sum of squares statistic implementation - can be reset by setter. */ + private StorelessUnivariateStatistic sumsqImpl = sumsq; + + /** Minimum statistic implementation - can be reset by setter. */ + private StorelessUnivariateStatistic minImpl = min; + + /** Maximum statistic implementation - can be reset by setter. */ + private StorelessUnivariateStatistic maxImpl = max; + + /** Sum of log statistic implementation - can be reset by setter. */ + private StorelessUnivariateStatistic sumLogImpl = sumLog; + + /** Geometric mean statistic implementation - can be reset by setter. 
*/ + private StorelessUnivariateStatistic geoMeanImpl = geoMean; + + /** Mean statistic implementation - can be reset by setter. */ + private StorelessUnivariateStatistic meanImpl = mean; + + /** Variance statistic implementation - can be reset by setter. */ + private StorelessUnivariateStatistic varianceImpl = variance; + + /** + * Construct a SummaryStatistics instance + */ + public SummaryStatistics() { + } + + /** + * A copy constructor. Creates a deep-copy of the {@code original}. + * + * @param original the {@code SummaryStatistics} instance to copy + * @throws NullArgumentException if original is null + */ + public SummaryStatistics(SummaryStatistics original) throws NullArgumentException { + copy(original, this); + } + + /** + * Return a {@link StatisticalSummaryValues} instance reporting current + * statistics. + * @return Current values of statistics + */ + public StatisticalSummary getSummary() { + return new StatisticalSummaryValues(getMean(), getVariance(), getN(), + getMax(), getMin(), getSum()); + } + + /** + * Add a value to the data + * @param value the value to add + */ + public void addValue(double value) { + sumImpl.increment(value); + sumsqImpl.increment(value); + minImpl.increment(value); + maxImpl.increment(value); + sumLogImpl.increment(value); + secondMoment.increment(value); + // If mean, variance or geomean have been overridden, + // need to increment these + if (meanImpl != mean) { + meanImpl.increment(value); + } + if (varianceImpl != variance) { + varianceImpl.increment(value); + } + if (geoMeanImpl != geoMean) { + geoMeanImpl.increment(value); + } + n++; + } + + /** + * Returns the number of available values + * @return The number of available values + */ + public long getN() { + return n; + } + + /** + * Returns the sum of the values that have been added + * @return The sum or <code>Double.NaN</code> if no values have been added + */ + public double getSum() { + return sumImpl.getResult(); + } + + /** + * Returns the sum of the 
squares of the values that have been added. + * <p> + * Double.NaN is returned if no values have been added. + * </p> + * @return The sum of squares + */ + public double getSumsq() { + return sumsqImpl.getResult(); + } + + /** + * Returns the mean of the values that have been added. + * <p> + * Double.NaN is returned if no values have been added. + * </p> + * @return the mean + */ + public double getMean() { + return meanImpl.getResult(); + } + + /** + * Returns the standard deviation of the values that have been added. + * <p> + * Double.NaN is returned if no values have been added. + * </p> + * @return the standard deviation + */ + public double getStandardDeviation() { + double stdDev = Double.NaN; + if (getN() > 0) { + if (getN() > 1) { + stdDev = FastMath.sqrt(getVariance()); + } else { + stdDev = 0.0; + } + } + return stdDev; + } + + /** + * Returns the quadratic mean, a.k.a. + * <a href="http://mathworld.wolfram.com/Root-Mean-Square.html"> + * root-mean-square</a> of the available values + * @return The quadratic mean or {@code Double.NaN} if no values + * have been added. + */ + public double getQuadraticMean() { + final long size = getN(); + return size > 0 ? FastMath.sqrt(getSumsq() / size) : Double.NaN; + } + + /** + * Returns the (sample) variance of the available values. + * + * <p>This method returns the bias-corrected sample variance (using {@code n - 1} in + * the denominator). Use {@link #getPopulationVariance()} for the non-bias-corrected + * population variance.</p> + * + * <p>Double.NaN is returned if no values have been added.</p> + * + * @return the variance + */ + public double getVariance() { + return varianceImpl.getResult(); + } + + /** + * Returns the <a href="http://en.wikibooks.org/wiki/Statistics/Summary/Variance"> + * population variance</a> of the values that have been added. 
+ * + * <p>Double.NaN is returned if no values have been added.</p> + * + * @return the population variance + */ + public double getPopulationVariance() { + Variance populationVariance = new Variance(secondMoment); + populationVariance.setBiasCorrected(false); + return populationVariance.getResult(); + } + + /** + * Returns the maximum of the values that have been added. + * <p> + * Double.NaN is returned if no values have been added. + * </p> + * @return the maximum + */ + public double getMax() { + return maxImpl.getResult(); + } + + /** + * Returns the minimum of the values that have been added. + * <p> + * Double.NaN is returned if no values have been added. + * </p> + * @return the minimum + */ + public double getMin() { + return minImpl.getResult(); + } + + /** + * Returns the geometric mean of the values that have been added. + * <p> + * Double.NaN is returned if no values have been added. + * </p> + * @return the geometric mean + */ + public double getGeometricMean() { + return geoMeanImpl.getResult(); + } + + /** + * Returns the sum of the logs of the values that have been added. + * <p> + * Double.NaN is returned if no values have been added. + * </p> + * @return the sum of logs + * @since 1.2 + */ + public double getSumOfLogs() { + return sumLogImpl.getResult(); + } + + /** + * Returns a statistic related to the Second Central Moment. Specifically, + * what is returned is the sum of squared deviations from the sample mean + * among the values that have been added. + * <p> + * Returns <code>Double.NaN</code> if no data values have been added and + * returns <code>0</code> if there is just one value in the data set.</p> + * <p> + * @return second central moment statistic + * @since 2.0 + */ + public double getSecondMoment() { + return secondMoment.getResult(); + } + + /** + * Generates a text report displaying summary statistics from values that + * have been added. 
+ * @return String with line feeds displaying statistics + * @since 1.2 + */ + @Override + public String toString() { + StringBuilder outBuffer = new StringBuilder(); + String endl = "\n"; + outBuffer.append("SummaryStatistics:").append(endl); + outBuffer.append("n: ").append(getN()).append(endl); + outBuffer.append("min: ").append(getMin()).append(endl); + outBuffer.append("max: ").append(getMax()).append(endl); + outBuffer.append("sum: ").append(getSum()).append(endl); + outBuffer.append("mean: ").append(getMean()).append(endl); + outBuffer.append("geometric mean: ").append(getGeometricMean()) + .append(endl); + outBuffer.append("variance: ").append(getVariance()).append(endl); + outBuffer.append("population variance: ").append(getPopulationVariance()).append(endl); + outBuffer.append("second moment: ").append(getSecondMoment()).append(endl); + outBuffer.append("sum of squares: ").append(getSumsq()).append(endl); + outBuffer.append("standard deviation: ").append(getStandardDeviation()) + .append(endl); + outBuffer.append("sum of logs: ").append(getSumOfLogs()).append(endl); + return outBuffer.toString(); + } + + /** + * Resets all statistics and storage + */ + public void clear() { + this.n = 0; + minImpl.clear(); + maxImpl.clear(); + sumImpl.clear(); + sumLogImpl.clear(); + sumsqImpl.clear(); + geoMeanImpl.clear(); + secondMoment.clear(); + if (meanImpl != mean) { + meanImpl.clear(); + } + if (varianceImpl != variance) { + varianceImpl.clear(); + } + } + + /** + * Returns true iff <code>object</code> is a + * <code>SummaryStatistics</code> instance and all statistics have the + * same values as this. + * @param object the object to test equality against. 
+ * @return true if object equals this + */ + @Override + public boolean equals(Object object) { + if (object == this) { + return true; + } + if (object instanceof SummaryStatistics == false) { + return false; + } + SummaryStatistics stat = (SummaryStatistics)object; + return Precision.equalsIncludingNaN(stat.getGeometricMean(), getGeometricMean()) && + Precision.equalsIncludingNaN(stat.getMax(), getMax()) && + Precision.equalsIncludingNaN(stat.getMean(), getMean()) && + Precision.equalsIncludingNaN(stat.getMin(), getMin()) && + Precision.equalsIncludingNaN(stat.getN(), getN()) && + Precision.equalsIncludingNaN(stat.getSum(), getSum()) && + Precision.equalsIncludingNaN(stat.getSumsq(), getSumsq()) && + Precision.equalsIncludingNaN(stat.getVariance(), getVariance()); + } + + /** + * Returns hash code based on values of statistics + * @return hash code + */ + @Override + public int hashCode() { + int result = 31 + MathUtils.hash(getGeometricMean()); + result = result * 31 + MathUtils.hash(getGeometricMean()); + result = result * 31 + MathUtils.hash(getMax()); + result = result * 31 + MathUtils.hash(getMean()); + result = result * 31 + MathUtils.hash(getMin()); + result = result * 31 + MathUtils.hash(getN()); + result = result * 31 + MathUtils.hash(getSum()); + result = result * 31 + MathUtils.hash(getSumsq()); + result = result * 31 + MathUtils.hash(getVariance()); + return result; + } + + // Getters and setters for statistics implementations + /** + * Returns the currently configured Sum implementation + * @return the StorelessUnivariateStatistic implementing the sum + * @since 1.2 + */ + public StorelessUnivariateStatistic getSumImpl() { + return sumImpl; + } + + /** + * <p> + * Sets the implementation for the Sum. + * </p> + * <p> + * This method cannot be activated after data has been added - i.e., + * after {@link #addValue(double) addValue} has been used to add data. + * If it is activated after data has been added, an IllegalStateException + * will be thrown. 
     * </p>
     * @param sumImpl the StorelessUnivariateStatistic instance to use for
     * computing the Sum
     * @throws MathIllegalStateException if data has already been added (i.e. if n > 0)
     * @since 1.2
     */
    public void setSumImpl(StorelessUnivariateStatistic sumImpl)
    throws MathIllegalStateException {
        checkEmpty();
        this.sumImpl = sumImpl;
    }

    /**
     * Returns the currently configured sum of squares implementation
     * @return the StorelessUnivariateStatistic implementing the sum of squares
     * @since 1.2
     */
    public StorelessUnivariateStatistic getSumsqImpl() {
        return sumsqImpl;
    }

    /**
     * Sets the implementation for the sum of squares.
     * <p>
     * This method cannot be called after data has been added - i.e., after
     * {@link #addValue(double) addValue} has been used to add data.
     * </p>
     * @param sumsqImpl the StorelessUnivariateStatistic instance to use for
     * computing the sum of squares
     * @throws MathIllegalStateException if data has already been added (i.e. if n > 0)
     * @since 1.2
     */
    public void setSumsqImpl(StorelessUnivariateStatistic sumsqImpl)
    throws MathIllegalStateException {
        checkEmpty();
        this.sumsqImpl = sumsqImpl;
    }

    /**
     * Returns the currently configured minimum implementation
     * @return the StorelessUnivariateStatistic implementing the minimum
     * @since 1.2
     */
    public StorelessUnivariateStatistic getMinImpl() {
        return minImpl;
    }

    /**
     * Sets the implementation for the minimum.
     * <p>
     * This method cannot be called after data has been added - i.e., after
     * {@link #addValue(double) addValue} has been used to add data.
     * </p>
     * @param minImpl the StorelessUnivariateStatistic instance to use for
     * computing the minimum
     * @throws MathIllegalStateException if data has already been added (i.e. if n > 0)
     * @since 1.2
     */
    public void setMinImpl(StorelessUnivariateStatistic minImpl)
    throws MathIllegalStateException {
        checkEmpty();
        this.minImpl = minImpl;
    }

    /**
     * Returns the currently configured maximum implementation
     * @return the StorelessUnivariateStatistic implementing the maximum
     * @since 1.2
     */
    public StorelessUnivariateStatistic getMaxImpl() {
        return maxImpl;
    }

    /**
     * Sets the implementation for the maximum.
     * <p>
     * This method cannot be called after data has been added - i.e., after
     * {@link #addValue(double) addValue} has been used to add data.
     * </p>
     * @param maxImpl the StorelessUnivariateStatistic instance to use for
     * computing the maximum
     * @throws MathIllegalStateException if data has already been added (i.e. if n > 0)
     * @since 1.2
     */
    public void setMaxImpl(StorelessUnivariateStatistic maxImpl)
    throws MathIllegalStateException {
        checkEmpty();
        this.maxImpl = maxImpl;
    }

    /**
     * Returns the currently configured sum of logs implementation
     * @return the StorelessUnivariateStatistic implementing the log sum
     * @since 1.2
     */
    public StorelessUnivariateStatistic getSumLogImpl() {
        return sumLogImpl;
    }

    /**
     * Sets the implementation for the sum of logs.
     * <p>
     * This method cannot be called after data has been added - i.e., after
     * {@link #addValue(double) addValue} has been used to add data.
     * </p>
     * @param sumLogImpl the StorelessUnivariateStatistic instance to use for
     * computing the log sum
     * @throws MathIllegalStateException if data has already been added (i.e. if n > 0)
     * @since 1.2
     */
    public void setSumLogImpl(StorelessUnivariateStatistic sumLogImpl)
    throws MathIllegalStateException {
        checkEmpty();
        this.sumLogImpl = sumLogImpl;
        // Keep the geometric mean's internal sum-of-logs in sync with the
        // newly configured implementation.
        geoMean.setSumLogImpl(sumLogImpl);
    }

    /**
     * Returns the currently configured geometric mean implementation
     * @return the StorelessUnivariateStatistic implementing the geometric mean
     * @since 1.2
     */
    public StorelessUnivariateStatistic getGeoMeanImpl() {
        return geoMeanImpl;
    }

    /**
     * Sets the implementation for the geometric mean.
     * <p>
     * This method cannot be called after data has been added - i.e., after
     * {@link #addValue(double) addValue} has been used to add data.
     * </p>
     * @param geoMeanImpl the StorelessUnivariateStatistic instance to use for
     * computing the geometric mean
     * @throws MathIllegalStateException if data has already been added (i.e. if n > 0)
     * @since 1.2
     */
    public void setGeoMeanImpl(StorelessUnivariateStatistic geoMeanImpl)
    throws MathIllegalStateException {
        checkEmpty();
        // Note: unlike setSumLogImpl, no other statistic is re-linked here.
        this.geoMeanImpl = geoMeanImpl;
    }

    /**
     * Returns the currently configured mean implementation
     * @return the StorelessUnivariateStatistic implementing the mean
     * @since 1.2
     */
    public StorelessUnivariateStatistic getMeanImpl() {
        return meanImpl;
    }

    /**
     * Sets the implementation for the mean.
     * <p>
     * This method cannot be called after data has been added - i.e., after
     * {@link #addValue(double) addValue} has been used to add data.
     * </p>
     * @param meanImpl the StorelessUnivariateStatistic instance to use for
     * computing the mean
     * @throws MathIllegalStateException if data has already been added (i.e. if n > 0)
     * @since 1.2
     */
    public void setMeanImpl(StorelessUnivariateStatistic meanImpl)
    throws MathIllegalStateException {
        checkEmpty();
        this.meanImpl = meanImpl;
    }

    /**
     * Returns the currently configured variance implementation
     * @return the StorelessUnivariateStatistic implementing the variance
     * @since 1.2
     */
    public StorelessUnivariateStatistic getVarianceImpl() {
        return varianceImpl;
    }

    /**
     * Sets the implementation for the variance.
     * <p>
     * This method cannot be called after data has been added - i.e., after
     * {@link #addValue(double) addValue} has been used to add data.
     * </p>
     * @param varianceImpl the StorelessUnivariateStatistic instance to use for
     * computing the variance
     * @throws MathIllegalStateException if data has already been added (i.e. if n > 0)
     * @since 1.2
     */
    public void setVarianceImpl(StorelessUnivariateStatistic varianceImpl)
    throws MathIllegalStateException {
        checkEmpty();
        this.varianceImpl = varianceImpl;
    }

    /**
     * Throws MathIllegalStateException if n > 0, i.e. if any data has been
     * added; used to guard the set*Impl configuration methods above.
     * @throws MathIllegalStateException if data has been added
     */
    private void checkEmpty() throws MathIllegalStateException {
        if (n > 0) {
            throw new MathIllegalStateException(
                LocalizedFormats.VALUES_ADDED_BEFORE_CONFIGURING_STATISTIC, n);
        }
    }

    /**
     * Returns a copy of this SummaryStatistics instance with the same internal state.
     *
     * @return a copy of this
     */
    public SummaryStatistics copy() {
        SummaryStatistics result = new SummaryStatistics();
        // No try-catch or advertised exception because arguments are guaranteed non-null
        copy(this, result);
        return result;
    }

    /**
     * Copies source to dest.
     * <p>Neither source nor dest can be null.</p>
     *
     * @param source SummaryStatistics to copy
     * @param dest SummaryStatistics to copy to
     * @throws NullArgumentException if either source or dest is null
     */
    public static void copy(SummaryStatistics source, SummaryStatistics dest)
    throws NullArgumentException {
        MathUtils.checkNotNull(source);
        MathUtils.checkNotNull(dest);
        // Deep-copy the configurable statistic implementations and the count.
        dest.maxImpl = source.maxImpl.copy();
        dest.minImpl = source.minImpl.copy();
        dest.sumImpl = source.sumImpl.copy();
        dest.sumLogImpl = source.sumLogImpl.copy();
        dest.sumsqImpl = source.sumsqImpl.copy();
        dest.secondMoment = source.secondMoment.copy();
        dest.n = source.n;

        // Keep commons-math supplied statistics with embedded moments in synch:
        // rebuild them around dest's own secondMoment / sumLogImpl copies so
        // that dest preserves the same internal state-sharing as source.
        if (source.getVarianceImpl() instanceof Variance) {
            dest.varianceImpl = new Variance(dest.secondMoment);
        } else {
            dest.varianceImpl = source.varianceImpl.copy();
        }
        if (source.meanImpl instanceof Mean) {
            dest.meanImpl = new Mean(dest.secondMoment);
        } else {
            dest.meanImpl = source.meanImpl.copy();
        }
        if (source.getGeoMeanImpl() instanceof GeometricMean) {
            dest.geoMeanImpl = new GeometricMean((SumOfLogs) dest.sumLogImpl);
        } else {
            dest.geoMeanImpl = source.geoMeanImpl.copy();
        }

        // Make sure that if stat == statImpl in source, the same aliasing
        // holds in dest; otherwise copy stat into dest's pre-existing object.
        if (source.geoMean == source.geoMeanImpl) {
            dest.geoMean = (GeometricMean) dest.geoMeanImpl;
        } else {
            GeometricMean.copy(source.geoMean, dest.geoMean);
        }
        if (source.max == source.maxImpl) {
            dest.max = (Max) dest.maxImpl;
        } else {
            Max.copy(source.max, dest.max);
        }
        if (source.mean == source.meanImpl) {
            dest.mean = (Mean) dest.meanImpl;
        } else {
            Mean.copy(source.mean, dest.mean);
        }
        if (source.min == source.minImpl) {
            dest.min = (Min) dest.minImpl;
        } else {
            Min.copy(source.min, dest.min);
        }
        if (source.sum == source.sumImpl) {
            dest.sum = (Sum) dest.sumImpl;
        } else {
            Sum.copy(source.sum, dest.sum);
        }
        if (source.variance == source.varianceImpl) {
            dest.variance = (Variance) dest.varianceImpl;
        } else {
            Variance.copy(source.variance, dest.variance);
        }
        if (source.sumLog == source.sumLogImpl) {
            dest.sumLog = (SumOfLogs) dest.sumLogImpl;
        } else {
            SumOfLogs.copy(source.sumLog, dest.sumLog);
        }
        if (source.sumsq == source.sumsqImpl) {
            dest.sumsq = (SumOfSquares) dest.sumsqImpl;
        } else {
            SumOfSquares.copy(source.sumsq, dest.sumsq);
        }
    }
}
diff --git a/src/main/java/org/apache/commons/math3/stat/descriptive/SynchronizedDescriptiveStatistics.java b/src/main/java/org/apache/commons/math3/stat/descriptive/SynchronizedDescriptiveStatistics.java
new file mode 100644
index 0000000..270e4aa
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/stat/descriptive/SynchronizedDescriptiveStatistics.java
@@ -0,0 +1,192 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.commons.math3.stat.descriptive;

import org.apache.commons.math3.exception.MathIllegalArgumentException;
import org.apache.commons.math3.exception.NullArgumentException;
import org.apache.commons.math3.util.MathUtils;

/**
 * Implementation of
 * {@link org.apache.commons.math3.stat.descriptive.DescriptiveStatistics} that
 * is safe to use in a multithreaded environment. Multiple threads can safely
 * operate on a single instance without causing runtime exceptions due to race
 * conditions. In effect, this implementation makes modification and access
 * methods atomic operations for a single instance. That is to say, as one
 * thread is computing a statistic from the instance, no other thread can modify
 * the instance nor compute another statistic.
 * <p>
 * Thread safety is obtained by declaring every public method
 * {@code synchronized}, i.e. all operations lock the instance monitor.
 * </p>
 *
 * @since 1.2
 */
public class SynchronizedDescriptiveStatistics extends DescriptiveStatistics {

    /** Serialization UID */
    private static final long serialVersionUID = 1L;

    /**
     * Construct an instance with infinite window
     */
    public SynchronizedDescriptiveStatistics() {
        // no try-catch or advertised IAE because arg is valid
        this(INFINITE_WINDOW);
    }

    /**
     * Construct an instance with finite window
     * @param window the finite window size.
     * @throws MathIllegalArgumentException if window size is less than 1 but
     * not equal to {@link #INFINITE_WINDOW}
     */
    public SynchronizedDescriptiveStatistics(int window) throws MathIllegalArgumentException {
        super(window);
    }

    /**
     * A copy constructor. Creates a deep-copy of the {@code original}.
     *
     * @param original the {@code SynchronizedDescriptiveStatistics} instance to copy
     * @throws NullArgumentException if original is null
     */
    public SynchronizedDescriptiveStatistics(SynchronizedDescriptiveStatistics original)
    throws NullArgumentException {
        copy(original, this);
    }

    /** {@inheritDoc} */
    @Override
    public synchronized void addValue(double v) {
        super.addValue(v);
    }

    /** {@inheritDoc} */
    @Override
    public synchronized double apply(UnivariateStatistic stat) {
        return super.apply(stat);
    }

    /** {@inheritDoc} */
    @Override
    public synchronized void clear() {
        super.clear();
    }

    /** {@inheritDoc} */
    @Override
    public synchronized double getElement(int index) {
        return super.getElement(index);
    }

    /** {@inheritDoc} */
    @Override
    public synchronized long getN() {
        return super.getN();
    }

    /** {@inheritDoc} */
    @Override
    public synchronized double getStandardDeviation() {
        return super.getStandardDeviation();
    }

    /** {@inheritDoc} */
    @Override
    public synchronized double getQuadraticMean() {
        return super.getQuadraticMean();
    }

    /** {@inheritDoc} */
    @Override
    public synchronized double[] getValues() {
        return super.getValues();
    }

    /** {@inheritDoc} */
    @Override
    public synchronized int getWindowSize() {
        return super.getWindowSize();
    }

    /** {@inheritDoc} */
    @Override
    public synchronized void setWindowSize(int windowSize) throws MathIllegalArgumentException {
        super.setWindowSize(windowSize);
    }

    /** {@inheritDoc} */
    @Override
    public synchronized String toString() {
        return super.toString();
    }

    /**
     * Returns a copy of this SynchronizedDescriptiveStatistics instance with the
     * same internal state.
     *
     * @return a copy of this
     */
    @Override
    public synchronized SynchronizedDescriptiveStatistics copy() {
        SynchronizedDescriptiveStatistics result =
            new SynchronizedDescriptiveStatistics();
        // No try-catch or advertised exception because arguments are guaranteed non-null
        copy(this, result);
        return result;
    }

    /**
     * Copies source to dest.
     * <p>Neither source nor dest can be null.</p>
     * <p>Acquires synchronization lock on source, then dest before copying.</p>
     *
     * @param source SynchronizedDescriptiveStatistics to copy
     * @param dest SynchronizedDescriptiveStatistics to copy to
     * @throws NullArgumentException if either source or dest is null
     */
    public static void copy(SynchronizedDescriptiveStatistics source,
                            SynchronizedDescriptiveStatistics dest)
        throws NullArgumentException {
        MathUtils.checkNotNull(source);
        MathUtils.checkNotNull(dest);
        // Lock ordering is source first, then dest.
        synchronized (source) {
            synchronized (dest) {
                DescriptiveStatistics.copy(source, dest);
            }
        }
    }
}
diff --git a/src/main/java/org/apache/commons/math3/stat/descriptive/SynchronizedMultivariateSummaryStatistics.java b/src/main/java/org/apache/commons/math3/stat/descriptive/SynchronizedMultivariateSummaryStatistics.java
new file mode 100644
index 0000000..889eb3a
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/stat/descriptive/SynchronizedMultivariateSummaryStatistics.java
@@ -0,0 +1,297 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.
You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.commons.math3.stat.descriptive;

import org.apache.commons.math3.exception.DimensionMismatchException;
import org.apache.commons.math3.exception.MathIllegalStateException;
import org.apache.commons.math3.linear.RealMatrix;

/**
 * Implementation of
 * {@link org.apache.commons.math3.stat.descriptive.MultivariateSummaryStatistics} that
 * is safe to use in a multithreaded environment. Multiple threads can safely
 * operate on a single instance without causing runtime exceptions due to race
 * conditions. In effect, this implementation makes modification and access
 * methods atomic operations for a single instance. That is to say, as one
 * thread is computing a statistic from the instance, no other thread can modify
 * the instance nor compute another statistic.
 * <p>
 * Thread safety is obtained by declaring every public method
 * {@code synchronized}, i.e. all operations lock the instance monitor.
 * </p>
 * @since 1.2
 */
public class SynchronizedMultivariateSummaryStatistics
    extends MultivariateSummaryStatistics {

    /** Serialization UID */
    private static final long serialVersionUID = 7099834153347155363L;

    /**
     * Construct a SynchronizedMultivariateSummaryStatistics instance
     * @param k dimension of the data
     * @param isCovarianceBiasCorrected if true, the unbiased sample
     * covariance is computed, otherwise the biased population covariance
     * is computed
     */
    public SynchronizedMultivariateSummaryStatistics(int k, boolean isCovarianceBiasCorrected) {
        super(k, isCovarianceBiasCorrected);
    }

    /** {@inheritDoc} */
    @Override
    public synchronized void addValue(double[] value) throws DimensionMismatchException {
        super.addValue(value);
    }

    /** {@inheritDoc} */
    @Override
    public synchronized int getDimension() {
        return super.getDimension();
    }

    /** {@inheritDoc} */
    @Override
    public synchronized long getN() {
        return super.getN();
    }

    /** {@inheritDoc} */
    @Override
    public synchronized double[] getSum() {
        return super.getSum();
    }

    /** {@inheritDoc} */
    @Override
    public synchronized double[] getSumSq() {
        return super.getSumSq();
    }

    /** {@inheritDoc} */
    @Override
    public synchronized double[] getSumLog() {
        return super.getSumLog();
    }

    /** {@inheritDoc} */
    @Override
    public synchronized double[] getMean() {
        return super.getMean();
    }

    /** {@inheritDoc} */
    @Override
    public synchronized double[] getStandardDeviation() {
        return super.getStandardDeviation();
    }

    /** {@inheritDoc} */
    @Override
    public synchronized RealMatrix getCovariance() {
        return super.getCovariance();
    }

    /** {@inheritDoc} */
    @Override
    public synchronized double[] getMax() {
        return super.getMax();
    }

    /** {@inheritDoc} */
    @Override
    public synchronized double[] getMin() {
        return super.getMin();
    }

    /** {@inheritDoc} */
    @Override
    public synchronized double[] getGeometricMean() {
        return super.getGeometricMean();
    }

    /** {@inheritDoc} */
    @Override
    public synchronized String toString() {
        return super.toString();
    }

    /** {@inheritDoc} */
    @Override
    public synchronized void clear() {
        super.clear();
    }

    /** {@inheritDoc} */
    @Override
    public synchronized boolean equals(Object object) {
        return super.equals(object);
    }

    /** {@inheritDoc} */
    @Override
    public synchronized int hashCode() {
        return super.hashCode();
    }

    /** {@inheritDoc} */
    @Override
    public synchronized StorelessUnivariateStatistic[] getSumImpl() {
        return super.getSumImpl();
    }

    /** {@inheritDoc} */
    @Override
    public synchronized void setSumImpl(StorelessUnivariateStatistic[] sumImpl)
        throws DimensionMismatchException, MathIllegalStateException {
        super.setSumImpl(sumImpl);
    }

    /** {@inheritDoc} */
    @Override
    public synchronized StorelessUnivariateStatistic[] getSumsqImpl() {
        return super.getSumsqImpl();
    }

    /** {@inheritDoc} */
    @Override
    public synchronized void setSumsqImpl(StorelessUnivariateStatistic[] sumsqImpl)
        throws DimensionMismatchException, MathIllegalStateException {
        super.setSumsqImpl(sumsqImpl);
    }

    /** {@inheritDoc} */
    @Override
    public synchronized StorelessUnivariateStatistic[] getMinImpl() {
        return super.getMinImpl();
    }

    /** {@inheritDoc} */
    @Override
    public synchronized void setMinImpl(StorelessUnivariateStatistic[] minImpl)
        throws DimensionMismatchException, MathIllegalStateException {
        super.setMinImpl(minImpl);
    }

    /** {@inheritDoc} */
    @Override
    public synchronized StorelessUnivariateStatistic[] getMaxImpl() {
        return super.getMaxImpl();
    }

    /** {@inheritDoc} */
    @Override
    public synchronized void setMaxImpl(StorelessUnivariateStatistic[] maxImpl)
        throws DimensionMismatchException, MathIllegalStateException {
        super.setMaxImpl(maxImpl);
    }

    /** {@inheritDoc} */
    @Override
    public synchronized StorelessUnivariateStatistic[] getSumLogImpl() {
        return super.getSumLogImpl();
    }

    /** {@inheritDoc} */
    @Override
    public synchronized void setSumLogImpl(StorelessUnivariateStatistic[] sumLogImpl)
        throws DimensionMismatchException, MathIllegalStateException {
        super.setSumLogImpl(sumLogImpl);
    }

    /** {@inheritDoc} */
    @Override
    public synchronized StorelessUnivariateStatistic[] getGeoMeanImpl() {
        return super.getGeoMeanImpl();
    }

    /** {@inheritDoc} */
    @Override
    public synchronized void setGeoMeanImpl(StorelessUnivariateStatistic[] geoMeanImpl)
        throws DimensionMismatchException, MathIllegalStateException {
        super.setGeoMeanImpl(geoMeanImpl);
    }

    /** {@inheritDoc} */
    @Override
    public synchronized StorelessUnivariateStatistic[] getMeanImpl() {
        return super.getMeanImpl();
    }

    /** {@inheritDoc} */
    @Override
    public synchronized void setMeanImpl(StorelessUnivariateStatistic[] meanImpl)
        throws DimensionMismatchException, MathIllegalStateException {
        super.setMeanImpl(meanImpl);
    }
}
diff --git a/src/main/java/org/apache/commons/math3/stat/descriptive/SynchronizedSummaryStatistics.java b/src/main/java/org/apache/commons/math3/stat/descriptive/SynchronizedSummaryStatistics.java
new file mode 100644
index 0000000..7eaf9ac
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/stat/descriptive/SynchronizedSummaryStatistics.java
@@ -0,0 +1,366 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.
You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.commons.math3.stat.descriptive;

import org.apache.commons.math3.exception.MathIllegalStateException;
import org.apache.commons.math3.exception.NullArgumentException;
import org.apache.commons.math3.util.MathUtils;

/**
 * Implementation of
 * {@link org.apache.commons.math3.stat.descriptive.SummaryStatistics} that
 * is safe to use in a multithreaded environment. Multiple threads can safely
 * operate on a single instance without causing runtime exceptions due to race
 * conditions. In effect, this implementation makes modification and access
 * methods atomic operations for a single instance. That is to say, as one
 * thread is computing a statistic from the instance, no other thread can modify
 * the instance nor compute another statistic.
 * <p>
 * Thread safety is obtained by declaring every public method
 * {@code synchronized}, i.e. all operations lock the instance monitor.
 * </p>
 *
 * @since 1.2
 */
public class SynchronizedSummaryStatistics extends SummaryStatistics {

    /** Serialization UID */
    private static final long serialVersionUID = 1909861009042253704L;

    /**
     * Construct a SynchronizedSummaryStatistics instance
     */
    public SynchronizedSummaryStatistics() {
        super();
    }

    /**
     * A copy constructor. Creates a deep-copy of the {@code original}.
     *
     * @param original the {@code SynchronizedSummaryStatistics} instance to copy
     * @throws NullArgumentException if original is null
     */
    public SynchronizedSummaryStatistics(SynchronizedSummaryStatistics original)
    throws NullArgumentException {
        copy(original, this);
    }

    /** {@inheritDoc} */
    @Override
    public synchronized StatisticalSummary getSummary() {
        return super.getSummary();
    }

    /** {@inheritDoc} */
    @Override
    public synchronized void addValue(double value) {
        super.addValue(value);
    }

    /** {@inheritDoc} */
    @Override
    public synchronized long getN() {
        return super.getN();
    }

    /** {@inheritDoc} */
    @Override
    public synchronized double getSum() {
        return super.getSum();
    }

    /** {@inheritDoc} */
    @Override
    public synchronized double getSumsq() {
        return super.getSumsq();
    }

    /** {@inheritDoc} */
    @Override
    public synchronized double getMean() {
        return super.getMean();
    }

    /** {@inheritDoc} */
    @Override
    public synchronized double getStandardDeviation() {
        return super.getStandardDeviation();
    }

    /** {@inheritDoc} */
    @Override
    public synchronized double getQuadraticMean() {
        return super.getQuadraticMean();
    }

    /** {@inheritDoc} */
    @Override
    public synchronized double getVariance() {
        return super.getVariance();
    }

    /** {@inheritDoc} */
    @Override
    public synchronized double getPopulationVariance() {
        return super.getPopulationVariance();
    }

    /** {@inheritDoc} */
    @Override
    public synchronized double getMax() {
        return super.getMax();
    }

    /** {@inheritDoc} */
    @Override
    public synchronized double getMin() {
        return super.getMin();
    }

    /** {@inheritDoc} */
    @Override
    public synchronized double getGeometricMean() {
        return super.getGeometricMean();
    }

    /** {@inheritDoc} */
    @Override
    public synchronized String toString() {
        return super.toString();
    }

    /** {@inheritDoc} */
    @Override
    public synchronized void clear() {
        super.clear();
    }

    /** {@inheritDoc} */
    @Override
    public synchronized boolean equals(Object object) {
        return super.equals(object);
    }

    /** {@inheritDoc} */
    @Override
    public synchronized int hashCode() {
        return super.hashCode();
    }

    /** {@inheritDoc} */
    @Override
    public synchronized StorelessUnivariateStatistic getSumImpl() {
        return super.getSumImpl();
    }

    /** {@inheritDoc} */
    @Override
    public synchronized void setSumImpl(StorelessUnivariateStatistic sumImpl)
        throws MathIllegalStateException {
        super.setSumImpl(sumImpl);
    }

    /** {@inheritDoc} */
    @Override
    public synchronized StorelessUnivariateStatistic getSumsqImpl() {
        return super.getSumsqImpl();
    }

    /** {@inheritDoc} */
    @Override
    public synchronized void setSumsqImpl(StorelessUnivariateStatistic sumsqImpl)
        throws MathIllegalStateException {
        super.setSumsqImpl(sumsqImpl);
    }

    /** {@inheritDoc} */
    @Override
    public synchronized StorelessUnivariateStatistic getMinImpl() {
        return super.getMinImpl();
    }

    /** {@inheritDoc} */
    @Override
    public synchronized void setMinImpl(StorelessUnivariateStatistic minImpl)
        throws MathIllegalStateException {
        super.setMinImpl(minImpl);
    }

    /** {@inheritDoc} */
    @Override
    public synchronized StorelessUnivariateStatistic getMaxImpl() {
        return super.getMaxImpl();
    }

    /** {@inheritDoc} */
    @Override
    public synchronized void setMaxImpl(StorelessUnivariateStatistic maxImpl)
        throws MathIllegalStateException {
        super.setMaxImpl(maxImpl);
    }

    /** {@inheritDoc} */
    @Override
    public synchronized StorelessUnivariateStatistic getSumLogImpl() {
        return super.getSumLogImpl();
    }

    /** {@inheritDoc} */
    @Override
    public synchronized void setSumLogImpl(StorelessUnivariateStatistic sumLogImpl)
        throws MathIllegalStateException {
        super.setSumLogImpl(sumLogImpl);
    }

    /** {@inheritDoc} */
    @Override
    public synchronized StorelessUnivariateStatistic getGeoMeanImpl() {
        return super.getGeoMeanImpl();
    }

    /** {@inheritDoc} */
    @Override
    public synchronized void setGeoMeanImpl(StorelessUnivariateStatistic geoMeanImpl)
        throws MathIllegalStateException {
        super.setGeoMeanImpl(geoMeanImpl);
    }

    /** {@inheritDoc} */
    @Override
    public synchronized StorelessUnivariateStatistic getMeanImpl() {
        return super.getMeanImpl();
    }

    /** {@inheritDoc} */
    @Override
    public synchronized void setMeanImpl(StorelessUnivariateStatistic meanImpl)
        throws MathIllegalStateException {
        super.setMeanImpl(meanImpl);
    }

    /** {@inheritDoc} */
    @Override
    public synchronized StorelessUnivariateStatistic getVarianceImpl() {
        return super.getVarianceImpl();
    }

    /** {@inheritDoc} */
    @Override
    public synchronized void setVarianceImpl(StorelessUnivariateStatistic varianceImpl)
        throws MathIllegalStateException {
        super.setVarianceImpl(varianceImpl);
    }

    /**
     * Returns a copy of this SynchronizedSummaryStatistics instance with the
     * same internal state.
     *
     * @return a copy of this
     */
    @Override
    public synchronized SynchronizedSummaryStatistics copy() {
        SynchronizedSummaryStatistics result =
            new SynchronizedSummaryStatistics();
        // No try-catch or advertised exception because arguments are guaranteed non-null
        copy(this, result);
        return result;
    }

    /**
     * Copies source to dest.
     * <p>Neither source nor dest can be null.</p>
     * <p>Acquires synchronization lock on source, then dest before copying.</p>
     *
     * @param source SynchronizedSummaryStatistics to copy
     * @param dest SynchronizedSummaryStatistics to copy to
     * @throws NullArgumentException if either source or dest is null
     */
    public static void copy(SynchronizedSummaryStatistics source,
                            SynchronizedSummaryStatistics dest)
        throws NullArgumentException {
        MathUtils.checkNotNull(source);
        MathUtils.checkNotNull(dest);
        // Lock ordering is source first, then dest.
        synchronized (source) {
            synchronized (dest) {
                SummaryStatistics.copy(source, dest);
            }
        }
    }

}
diff --git a/src/main/java/org/apache/commons/math3/stat/descriptive/UnivariateStatistic.java b/src/main/java/org/apache/commons/math3/stat/descriptive/UnivariateStatistic.java
new file mode 100644
index 0000000..5d6c9fe
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/stat/descriptive/UnivariateStatistic.java
@@ -0,0 +1,55 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.commons.math3.stat.descriptive;

import org.apache.commons.math3.exception.MathIllegalArgumentException;
import org.apache.commons.math3.util.MathArrays;


/**
 * Base interface implemented by all statistics.
 *
 */
public interface UnivariateStatistic extends MathArrays.Function {
    /**
     * Returns the result of evaluating the statistic over the input array.
     *
     * @param values input array
     * @return the value of the statistic applied to the input array
     * @throws MathIllegalArgumentException if values is null
     */
    double evaluate(double[] values) throws MathIllegalArgumentException;

    /**
     * Returns the result of evaluating the statistic over the specified entries
     * in the input array.
     *
     * @param values the input array
     * @param begin the index of the first element to include
     * @param length the number of elements to include
     * @return the value of the statistic applied to the included array entries
     * @throws MathIllegalArgumentException if values is null or the indices are invalid
     */
    double evaluate(double[] values, int begin, int length) throws MathIllegalArgumentException;

    /**
     * Returns a copy of the statistic with the same internal state.
     *
     * @return a copy of the statistic
     */
    UnivariateStatistic copy();
}
diff --git a/src/main/java/org/apache/commons/math3/stat/descriptive/WeightedEvaluation.java b/src/main/java/org/apache/commons/math3/stat/descriptive/WeightedEvaluation.java
new file mode 100644
index 0000000..01693dc
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/stat/descriptive/WeightedEvaluation.java
@@ -0,0 +1,57 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.commons.math3.stat.descriptive;
+
+import org.apache.commons.math3.exception.MathIllegalArgumentException;
+
+/**
+ * Weighted evaluation for statistics.
+ *
+ * @since 2.1
+ */
+public interface WeightedEvaluation {
+
+    /**
+     * Returns the result of evaluating the statistic over the input array,
+     * using the supplied weights.
+     *
+     * @param values input array
+     * @param weights array of weights
+     * @return the value of the weighted statistic applied to the input array
+     * @throws MathIllegalArgumentException if either array is null, lengths
+     * do not match, weights contain NaN, negative or infinite values, or
+     * weights does not include at least one positive value
+     */
+    double evaluate(double[] values, double[] weights) throws MathIllegalArgumentException;
+
+    /**
+     * Returns the result of evaluating the statistic over the specified entries
+     * in the input array, using corresponding entries in the supplied weights array.
+ *
+ * @param values the input array
+ * @param weights array of weights
+ * @param begin the index of the first element to include
+ * @param length the number of elements to include
+ * @return the value of the weighted statistic applied to the included array entries
+ * @throws MathIllegalArgumentException if either array is null, lengths
+ * do not match, indices are invalid, weights contain NaN, negative or
+ * infinite values, or weights does not include at least one positive value
+ */
+ double evaluate(double[] values, double[] weights, int begin, int length)
+ throws MathIllegalArgumentException;
+
+}
diff --git a/src/main/java/org/apache/commons/math3/stat/descriptive/moment/FirstMoment.java b/src/main/java/org/apache/commons/math3/stat/descriptive/moment/FirstMoment.java
new file mode 100644
index 0000000..c153724
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/stat/descriptive/moment/FirstMoment.java
@@ -0,0 +1,169 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +package org.apache.commons.math3.stat.descriptive.moment; + +import java.io.Serializable; + +import org.apache.commons.math3.exception.NullArgumentException; +import org.apache.commons.math3.stat.descriptive.AbstractStorelessUnivariateStatistic; +import org.apache.commons.math3.util.MathUtils; + +/** + * Computes the first moment (arithmetic mean). Uses the definitional formula: + * <p> + * mean = sum(x_i) / n </p> + * <p> + * where <code>n</code> is the number of observations. </p> + * <p> + * To limit numeric errors, the value of the statistic is computed using the + * following recursive updating algorithm: </p> + * <p> + * <ol> + * <li>Initialize <code>m = </code> the first value</li> + * <li>For each additional value, update using <br> + * <code>m = m + (new value - m) / (number of observations)</code></li> + * </ol></p> + * <p> + * Returns <code>Double.NaN</code> if the dataset is empty. Note that + * Double.NaN may also be returned if the input includes NaN and / or infinite + * values.</p> + * <p> + * <strong>Note that this implementation is not synchronized.</strong> If + * multiple threads access an instance of this class concurrently, and at least + * one of the threads invokes the <code>increment()</code> or + * <code>clear()</code> method, it must be synchronized externally.</p> + * + */ +class FirstMoment extends AbstractStorelessUnivariateStatistic + implements Serializable { + + /** Serializable version identifier */ + private static final long serialVersionUID = 6112755307178490473L; + + + /** Count of values that have been added */ + protected long n; + + /** First moment of values that have been added */ + protected double m1; + + /** + * Deviation of most recently added value from previous first moment. + * Retained to prevent repeated computation in higher order moments. + */ + protected double dev; + + /** + * Deviation of most recently added value from previous first moment, + * normalized by previous sample size. 
Retained to prevent repeated + * computation in higher order moments + */ + protected double nDev; + + /** + * Create a FirstMoment instance + */ + FirstMoment() { + n = 0; + m1 = Double.NaN; + dev = Double.NaN; + nDev = Double.NaN; + } + + /** + * Copy constructor, creates a new {@code FirstMoment} identical + * to the {@code original} + * + * @param original the {@code FirstMoment} instance to copy + * @throws NullArgumentException if original is null + */ + FirstMoment(FirstMoment original) throws NullArgumentException { + super(); + copy(original, this); + } + + /** + * {@inheritDoc} + */ + @Override + public void increment(final double d) { + if (n == 0) { + m1 = 0.0; + } + n++; + double n0 = n; + dev = d - m1; + nDev = dev / n0; + m1 += nDev; + } + + /** + * {@inheritDoc} + */ + @Override + public void clear() { + m1 = Double.NaN; + n = 0; + dev = Double.NaN; + nDev = Double.NaN; + } + + /** + * {@inheritDoc} + */ + @Override + public double getResult() { + return m1; + } + + /** + * {@inheritDoc} + */ + public long getN() { + return n; + } + + /** + * {@inheritDoc} + */ + @Override + public FirstMoment copy() { + FirstMoment result = new FirstMoment(); + // No try-catch or advertised exception because args are guaranteed non-null + copy(this, result); + return result; + } + + /** + * Copies source to dest. 
+ * <p>Neither source nor dest can be null.</p> + * + * @param source FirstMoment to copy + * @param dest FirstMoment to copy to + * @throws NullArgumentException if either source or dest is null + */ + public static void copy(FirstMoment source, FirstMoment dest) + throws NullArgumentException { + MathUtils.checkNotNull(source); + MathUtils.checkNotNull(dest); + dest.setData(source.getDataRef()); + dest.n = source.n; + dest.m1 = source.m1; + dest.dev = source.dev; + dest.nDev = source.nDev; + } +} diff --git a/src/main/java/org/apache/commons/math3/stat/descriptive/moment/FourthMoment.java b/src/main/java/org/apache/commons/math3/stat/descriptive/moment/FourthMoment.java new file mode 100644 index 0000000..0c199d8 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/stat/descriptive/moment/FourthMoment.java @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.math3.stat.descriptive.moment; + +import java.io.Serializable; + +import org.apache.commons.math3.exception.NullArgumentException; +import org.apache.commons.math3.util.MathUtils; + +/** + * Computes a statistic related to the Fourth Central Moment. 
Specifically,
+ * what is computed is the sum of
+ * <p>
+ * (x_i - xbar) ^ 4, </p>
+ * <p>
+ * where the x_i are the
+ * sample observations and xbar is the sample mean. </p>
+ * <p>
+ * The following recursive updating formula is used: </p>
+ * <p>
+ * Let <ul>
+ * <li> dev = (current obs - previous mean) </li>
+ * <li> m2 = previous value of {@link SecondMoment} </li>
+ * <li> m3 = previous value of {@link ThirdMoment} </li>
+ * <li> n = number of observations (including current obs) </li>
+ * </ul>
+ * Then </p>
+ * <p>
+ * new value = old value - 4 * (dev/n) * m3 + 6 * (dev/n)^2 * m2 + <br>
+ * [n^2 - 3 * (n-1)] * dev^4 * (n-1) / n^3 </p>
+ * <p>
+ * Returns <code>Double.NaN</code> if no data values have been added and
+ * returns <code>0</code> if there is just one value in the data set. Note that
+ * Double.NaN may also be returned if the input includes NaN and / or infinite
+ * values. </p>
+ * <p>
+ * <strong>Note that this implementation is not synchronized.</strong> If
+ * multiple threads access an instance of this class concurrently, and at least
+ * one of the threads invokes the <code>increment()</code> or
+ * <code>clear()</code> method, it must be synchronized externally.
</p> + * + */ +class FourthMoment extends ThirdMoment implements Serializable{ + + /** Serializable version identifier */ + private static final long serialVersionUID = 4763990447117157611L; + + /** fourth moment of values that have been added */ + private double m4; + + /** + * Create a FourthMoment instance + */ + FourthMoment() { + super(); + m4 = Double.NaN; + } + + /** + * Copy constructor, creates a new {@code FourthMoment} identical + * to the {@code original} + * + * @param original the {@code FourthMoment} instance to copy + * @throws NullArgumentException if original is null + */ + FourthMoment(FourthMoment original) throws NullArgumentException { + super(); + copy(original, this); + } + + /** + * {@inheritDoc} + */ + @Override + public void increment(final double d) { + if (n < 1) { + m4 = 0.0; + m3 = 0.0; + m2 = 0.0; + m1 = 0.0; + } + + double prevM3 = m3; + double prevM2 = m2; + + super.increment(d); + + double n0 = n; + + m4 = m4 - 4.0 * nDev * prevM3 + 6.0 * nDevSq * prevM2 + + ((n0 * n0) - 3 * (n0 -1)) * (nDevSq * nDevSq * (n0 - 1) * n0); + } + + /** + * {@inheritDoc} + */ + @Override + public double getResult() { + return m4; + } + + /** + * {@inheritDoc} + */ + @Override + public void clear() { + super.clear(); + m4 = Double.NaN; + } + + /** + * {@inheritDoc} + */ + @Override + public FourthMoment copy() { + FourthMoment result = new FourthMoment(); + // No try-catch or advertised exception because args are guaranteed non-null + copy(this, result); + return result; + } + + /** + * Copies source to dest. 
+ * <p>Neither source nor dest can be null.</p> + * + * @param source FourthMoment to copy + * @param dest FourthMoment to copy to + * @throws NullArgumentException if either source or dest is null + */ + public static void copy(FourthMoment source, FourthMoment dest) + throws NullArgumentException { + MathUtils.checkNotNull(source); + MathUtils.checkNotNull(dest); + ThirdMoment.copy(source, dest); + dest.m4 = source.m4; + } +} diff --git a/src/main/java/org/apache/commons/math3/stat/descriptive/moment/GeometricMean.java b/src/main/java/org/apache/commons/math3/stat/descriptive/moment/GeometricMean.java new file mode 100644 index 0000000..bfee9df --- /dev/null +++ b/src/main/java/org/apache/commons/math3/stat/descriptive/moment/GeometricMean.java @@ -0,0 +1,214 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.commons.math3.stat.descriptive.moment; + +import java.io.Serializable; + +import org.apache.commons.math3.exception.MathIllegalArgumentException; +import org.apache.commons.math3.exception.MathIllegalStateException; +import org.apache.commons.math3.exception.NullArgumentException; +import org.apache.commons.math3.exception.util.LocalizedFormats; +import org.apache.commons.math3.stat.descriptive.AbstractStorelessUnivariateStatistic; +import org.apache.commons.math3.stat.descriptive.StorelessUnivariateStatistic; +import org.apache.commons.math3.stat.descriptive.summary.SumOfLogs; +import org.apache.commons.math3.util.FastMath; +import org.apache.commons.math3.util.MathUtils; + +/** + * Returns the <a href="http://www.xycoon.com/geometric_mean.htm"> + * geometric mean </a> of the available values. + * <p> + * Uses a {@link SumOfLogs} instance to compute sum of logs and returns + * <code> exp( 1/n (sum of logs) ).</code> Therefore, </p> + * <ul> + * <li>If any of values are < 0, the result is <code>NaN.</code></li> + * <li>If all values are non-negative and less than + * <code>Double.POSITIVE_INFINITY</code>, but at least one value is 0, the + * result is <code>0.</code></li> + * <li>If both <code>Double.POSITIVE_INFINITY</code> and + * <code>Double.NEGATIVE_INFINITY</code> are among the values, the result is + * <code>NaN.</code></li> + * </ul> </p> + * <p> + * <strong>Note that this implementation is not synchronized.</strong> If + * multiple threads access an instance of this class concurrently, and at least + * one of the threads invokes the <code>increment()</code> or + * <code>clear()</code> method, it must be synchronized externally.</p> + * + * + */ +public class GeometricMean extends AbstractStorelessUnivariateStatistic implements Serializable { + + /** Serializable version identifier */ + private static final long serialVersionUID = -8178734905303459453L; + + /** Wrapped SumOfLogs instance */ + private StorelessUnivariateStatistic 
sumOfLogs; + + /** + * Create a GeometricMean instance + */ + public GeometricMean() { + sumOfLogs = new SumOfLogs(); + } + + /** + * Copy constructor, creates a new {@code GeometricMean} identical + * to the {@code original} + * + * @param original the {@code GeometricMean} instance to copy + * @throws NullArgumentException if original is null + */ + public GeometricMean(GeometricMean original) throws NullArgumentException { + super(); + copy(original, this); + } + + /** + * Create a GeometricMean instance using the given SumOfLogs instance + * @param sumOfLogs sum of logs instance to use for computation + */ + public GeometricMean(SumOfLogs sumOfLogs) { + this.sumOfLogs = sumOfLogs; + } + + /** + * {@inheritDoc} + */ + @Override + public GeometricMean copy() { + GeometricMean result = new GeometricMean(); + // no try-catch or advertised exception because args guaranteed non-null + copy(this, result); + return result; + } + + /** + * {@inheritDoc} + */ + @Override + public void increment(final double d) { + sumOfLogs.increment(d); + } + + /** + * {@inheritDoc} + */ + @Override + public double getResult() { + if (sumOfLogs.getN() > 0) { + return FastMath.exp(sumOfLogs.getResult() / sumOfLogs.getN()); + } else { + return Double.NaN; + } + } + + /** + * {@inheritDoc} + */ + @Override + public void clear() { + sumOfLogs.clear(); + } + + /** + * Returns the geometric mean of the entries in the specified portion + * of the input array. + * <p> + * See {@link GeometricMean} for details on the computing algorithm.</p> + * <p> + * Throws <code>IllegalArgumentException</code> if the array is null.</p> + * + * @param values input array containing the values + * @param begin first array element to include + * @param length the number of elements to include + * @return the geometric mean or Double.NaN if length = 0 or + * any of the values are <= 0. 
+ * @throws MathIllegalArgumentException if the input array is null or the array + * index parameters are not valid + */ + @Override + public double evaluate( + final double[] values, final int begin, final int length) + throws MathIllegalArgumentException { + return FastMath.exp( + sumOfLogs.evaluate(values, begin, length) / length); + } + + /** + * {@inheritDoc} + */ + public long getN() { + return sumOfLogs.getN(); + } + + /** + * <p>Sets the implementation for the sum of logs.</p> + * <p>This method must be activated before any data has been added - i.e., + * before {@link #increment(double) increment} has been used to add data; + * otherwise an IllegalStateException will be thrown.</p> + * + * @param sumLogImpl the StorelessUnivariateStatistic instance to use + * for computing the log sum + * @throws MathIllegalStateException if data has already been added + * (i.e if n > 0) + */ + public void setSumLogImpl(StorelessUnivariateStatistic sumLogImpl) + throws MathIllegalStateException { + checkEmpty(); + this.sumOfLogs = sumLogImpl; + } + + /** + * Returns the currently configured sum of logs implementation + * + * @return the StorelessUnivariateStatistic implementing the log sum + */ + public StorelessUnivariateStatistic getSumLogImpl() { + return sumOfLogs; + } + + /** + * Copies source to dest. + * <p>Neither source nor dest can be null.</p> + * + * @param source GeometricMean to copy + * @param dest GeometricMean to copy to + * @throws NullArgumentException if either source or dest is null + */ + public static void copy(GeometricMean source, GeometricMean dest) + throws NullArgumentException { + MathUtils.checkNotNull(source); + MathUtils.checkNotNull(dest); + dest.setData(source.getDataRef()); + dest.sumOfLogs = source.sumOfLogs.copy(); + } + + + /** + * Throws MathIllegalStateException if n > 0. 
+ * @throws MathIllegalStateException if data has been added to this statistic + */ + private void checkEmpty() throws MathIllegalStateException { + if (getN() > 0) { + throw new MathIllegalStateException( + LocalizedFormats.VALUES_ADDED_BEFORE_CONFIGURING_STATISTIC, + getN()); + } + } + +} diff --git a/src/main/java/org/apache/commons/math3/stat/descriptive/moment/Kurtosis.java b/src/main/java/org/apache/commons/math3/stat/descriptive/moment/Kurtosis.java new file mode 100644 index 0000000..be04fbe --- /dev/null +++ b/src/main/java/org/apache/commons/math3/stat/descriptive/moment/Kurtosis.java @@ -0,0 +1,226 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.math3.stat.descriptive.moment; + +import java.io.Serializable; + +import org.apache.commons.math3.exception.MathIllegalArgumentException; +import org.apache.commons.math3.exception.NullArgumentException; +import org.apache.commons.math3.stat.descriptive.AbstractStorelessUnivariateStatistic; +import org.apache.commons.math3.util.FastMath; +import org.apache.commons.math3.util.MathUtils; + + +/** + * Computes the Kurtosis of the available values. 
+ * <p>
+ * We use the following (unbiased) formula to define kurtosis:</p>
+ * <p>
+ * kurtosis = { [n(n+1) / (n -1)(n - 2)(n-3)] sum[(x_i - mean)^4] / std^4 } - [3(n-1)^2 / (n-2)(n-3)]
+ * </p><p>
+ * where n is the number of values, mean is the {@link Mean} and std is the
+ * {@link StandardDeviation}</p>
+ * <p>
+ * Note that this statistic is undefined for n < 4. <code>Double.NaN</code>
+ * is returned when there is not sufficient data to compute the statistic.
+ * Note that Double.NaN may also be returned if the input includes NaN
+ * and / or infinite values.</p>
+ * <p>
+ * <strong>Note that this implementation is not synchronized.</strong> If
+ * multiple threads access an instance of this class concurrently, and at least
+ * one of the threads invokes the <code>increment()</code> or
+ * <code>clear()</code> method, it must be synchronized externally.</p>
+ *
+ */
+public class Kurtosis extends AbstractStorelessUnivariateStatistic implements Serializable {
+
+    /** Serializable version identifier */
+    private static final long serialVersionUID = 2784465764798260919L;
+
+    /**Fourth Moment on which this statistic is based */
+    protected FourthMoment moment;
+
+    /**
+     * Determines whether or not this statistic can be incremented or cleared.
+ * <p> + * Statistics based on (constructed from) external moments cannot + * be incremented or cleared.</p> + */ + protected boolean incMoment; + + /** + * Construct a Kurtosis + */ + public Kurtosis() { + incMoment = true; + moment = new FourthMoment(); + } + + /** + * Construct a Kurtosis from an external moment + * + * @param m4 external Moment + */ + public Kurtosis(final FourthMoment m4) { + incMoment = false; + this.moment = m4; + } + + /** + * Copy constructor, creates a new {@code Kurtosis} identical + * to the {@code original} + * + * @param original the {@code Kurtosis} instance to copy + * @throws NullArgumentException if original is null + */ + public Kurtosis(Kurtosis original) throws NullArgumentException { + copy(original, this); + } + + /** + * {@inheritDoc} + * <p>Note that when {@link #Kurtosis(FourthMoment)} is used to + * create a Variance, this method does nothing. In that case, the + * FourthMoment should be incremented directly.</p> + */ + @Override + public void increment(final double d) { + if (incMoment) { + moment.increment(d); + } + } + + /** + * {@inheritDoc} + */ + @Override + public double getResult() { + double kurtosis = Double.NaN; + if (moment.getN() > 3) { + double variance = moment.m2 / (moment.n - 1); + if (moment.n <= 3 || variance < 10E-20) { + kurtosis = 0.0; + } else { + double n = moment.n; + kurtosis = + (n * (n + 1) * moment.getResult() - + 3 * moment.m2 * moment.m2 * (n - 1)) / + ((n - 1) * (n -2) * (n -3) * variance * variance); + } + } + return kurtosis; + } + + /** + * {@inheritDoc} + */ + @Override + public void clear() { + if (incMoment) { + moment.clear(); + } + } + + /** + * {@inheritDoc} + */ + public long getN() { + return moment.getN(); + } + + /* UnvariateStatistic Approach */ + + /** + * Returns the kurtosis of the entries in the specified portion of the + * input array. 
+ * <p> + * See {@link Kurtosis} for details on the computing algorithm.</p> + * <p> + * Throws <code>IllegalArgumentException</code> if the array is null.</p> + * + * @param values the input array + * @param begin index of the first array element to include + * @param length the number of elements to include + * @return the kurtosis of the values or Double.NaN if length is less than 4 + * @throws MathIllegalArgumentException if the input array is null or the array + * index parameters are not valid + */ + @Override + public double evaluate(final double[] values,final int begin, final int length) + throws MathIllegalArgumentException { + // Initialize the kurtosis + double kurt = Double.NaN; + + if (test(values, begin, length) && length > 3) { + + // Compute the mean and standard deviation + Variance variance = new Variance(); + variance.incrementAll(values, begin, length); + double mean = variance.moment.m1; + double stdDev = FastMath.sqrt(variance.getResult()); + + // Sum the ^4 of the distance from the mean divided by the + // standard deviation + double accum3 = 0.0; + for (int i = begin; i < begin + length; i++) { + accum3 += FastMath.pow(values[i] - mean, 4.0); + } + accum3 /= FastMath.pow(stdDev, 4.0d); + + // Get N + double n0 = length; + + double coefficientOne = + (n0 * (n0 + 1)) / ((n0 - 1) * (n0 - 2) * (n0 - 3)); + double termTwo = + (3 * FastMath.pow(n0 - 1, 2.0)) / ((n0 - 2) * (n0 - 3)); + + // Calculate kurtosis + kurt = (coefficientOne * accum3) - termTwo; + } + return kurt; + } + + /** + * {@inheritDoc} + */ + @Override + public Kurtosis copy() { + Kurtosis result = new Kurtosis(); + // No try-catch because args are guaranteed non-null + copy(this, result); + return result; + } + + /** + * Copies source to dest. 
+ * <p>Neither source nor dest can be null.</p> + * + * @param source Kurtosis to copy + * @param dest Kurtosis to copy to + * @throws NullArgumentException if either source or dest is null + */ + public static void copy(Kurtosis source, Kurtosis dest) + throws NullArgumentException { + MathUtils.checkNotNull(source); + MathUtils.checkNotNull(dest); + dest.setData(source.getDataRef()); + dest.moment = source.moment.copy(); + dest.incMoment = source.incMoment; + } + +} diff --git a/src/main/java/org/apache/commons/math3/stat/descriptive/moment/Mean.java b/src/main/java/org/apache/commons/math3/stat/descriptive/moment/Mean.java new file mode 100644 index 0000000..aac3d78 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/stat/descriptive/moment/Mean.java @@ -0,0 +1,286 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.commons.math3.stat.descriptive.moment; + +import java.io.Serializable; + +import org.apache.commons.math3.exception.MathIllegalArgumentException; +import org.apache.commons.math3.exception.NullArgumentException; +import org.apache.commons.math3.stat.descriptive.AbstractStorelessUnivariateStatistic; +import org.apache.commons.math3.stat.descriptive.WeightedEvaluation; +import org.apache.commons.math3.stat.descriptive.summary.Sum; +import org.apache.commons.math3.util.MathUtils; + +/** + * <p>Computes the arithmetic mean of a set of values. Uses the definitional + * formula:</p> + * <p> + * mean = sum(x_i) / n + * </p> + * <p>where <code>n</code> is the number of observations. + * </p> + * <p>When {@link #increment(double)} is used to add data incrementally from a + * stream of (unstored) values, the value of the statistic that + * {@link #getResult()} returns is computed using the following recursive + * updating algorithm: </p> + * <ol> + * <li>Initialize <code>m = </code> the first value</li> + * <li>For each additional value, update using <br> + * <code>m = m + (new value - m) / (number of observations)</code></li> + * </ol> + * <p> If {@link #evaluate(double[])} is used to compute the mean of an array + * of stored values, a two-pass, corrected algorithm is used, starting with + * the definitional formula computed using the array of stored values and then + * correcting this by adding the mean deviation of the data values from the + * arithmetic mean. See, e.g. "Comparison of Several Algorithms for Computing + * Sample Means and Variances," Robert F. Ling, Journal of the American + * Statistical Association, Vol. 69, No. 348 (Dec., 1974), pp. 859-866. </p> + * <p> + * Returns <code>Double.NaN</code> if the dataset is empty. Note that + * Double.NaN may also be returned if the input includes NaN and / or infinite + * values. 
+ * </p> + * <strong>Note that this implementation is not synchronized.</strong> If + * multiple threads access an instance of this class concurrently, and at least + * one of the threads invokes the <code>increment()</code> or + * <code>clear()</code> method, it must be synchronized externally. + * + */ +public class Mean extends AbstractStorelessUnivariateStatistic + implements Serializable, WeightedEvaluation { + + /** Serializable version identifier */ + private static final long serialVersionUID = -1296043746617791564L; + + /** First moment on which this statistic is based. */ + protected FirstMoment moment; + + /** + * Determines whether or not this statistic can be incremented or cleared. + * <p> + * Statistics based on (constructed from) external moments cannot + * be incremented or cleared.</p> + */ + protected boolean incMoment; + + /** Constructs a Mean. */ + public Mean() { + incMoment = true; + moment = new FirstMoment(); + } + + /** + * Constructs a Mean with an External Moment. + * + * @param m1 the moment + */ + public Mean(final FirstMoment m1) { + this.moment = m1; + incMoment = false; + } + + /** + * Copy constructor, creates a new {@code Mean} identical + * to the {@code original} + * + * @param original the {@code Mean} instance to copy + * @throws NullArgumentException if original is null + */ + public Mean(Mean original) throws NullArgumentException { + copy(original, this); + } + + /** + * {@inheritDoc} + * <p>Note that when {@link #Mean(FirstMoment)} is used to + * create a Mean, this method does nothing. 
In that case, the + * FirstMoment should be incremented directly.</p> + */ + @Override + public void increment(final double d) { + if (incMoment) { + moment.increment(d); + } + } + + /** + * {@inheritDoc} + */ + @Override + public void clear() { + if (incMoment) { + moment.clear(); + } + } + + /** + * {@inheritDoc} + */ + @Override + public double getResult() { + return moment.m1; + } + + /** + * {@inheritDoc} + */ + public long getN() { + return moment.getN(); + } + + /** + * Returns the arithmetic mean of the entries in the specified portion of + * the input array, or <code>Double.NaN</code> if the designated subarray + * is empty. + * <p> + * Throws <code>IllegalArgumentException</code> if the array is null.</p> + * <p> + * See {@link Mean} for details on the computing algorithm.</p> + * + * @param values the input array + * @param begin index of the first array element to include + * @param length the number of elements to include + * @return the mean of the values or Double.NaN if length = 0 + * @throws MathIllegalArgumentException if the array is null or the array index + * parameters are not valid + */ + @Override + public double evaluate(final double[] values,final int begin, final int length) + throws MathIllegalArgumentException { + if (test(values, begin, length)) { + Sum sum = new Sum(); + double sampleSize = length; + + // Compute initial estimate using definitional formula + double xbar = sum.evaluate(values, begin, length) / sampleSize; + + // Compute correction factor in second pass + double correction = 0; + for (int i = begin; i < begin + length; i++) { + correction += values[i] - xbar; + } + return xbar + (correction/sampleSize); + } + return Double.NaN; + } + + /** + * Returns the weighted arithmetic mean of the entries in the specified portion of + * the input array, or <code>Double.NaN</code> if the designated subarray + * is empty. 
+ * <p> + * Throws <code>IllegalArgumentException</code> if either array is null.</p> + * <p> + * See {@link Mean} for details on the computing algorithm. The two-pass algorithm + * described above is used here, with weights applied in computing both the original + * estimate and the correction factor.</p> + * <p> + * Throws <code>IllegalArgumentException</code> if any of the following are true: + * <ul><li>the values array is null</li> + * <li>the weights array is null</li> + * <li>the weights array does not have the same length as the values array</li> + * <li>the weights array contains one or more infinite values</li> + * <li>the weights array contains one or more NaN values</li> + * <li>the weights array contains negative values</li> + * <li>the start and length arguments do not determine a valid array</li> + * </ul></p> + * + * @param values the input array + * @param weights the weights array + * @param begin index of the first array element to include + * @param length the number of elements to include + * @return the mean of the values or Double.NaN if length = 0 + * @throws MathIllegalArgumentException if the parameters are not valid + * @since 2.1 + */ + public double evaluate(final double[] values, final double[] weights, + final int begin, final int length) throws MathIllegalArgumentException { + if (test(values, weights, begin, length)) { + Sum sum = new Sum(); + + // Compute initial estimate using definitional formula + double sumw = sum.evaluate(weights,begin,length); + double xbarw = sum.evaluate(values, weights, begin, length) / sumw; + + // Compute correction factor in second pass + double correction = 0; + for (int i = begin; i < begin + length; i++) { + correction += weights[i] * (values[i] - xbarw); + } + return xbarw + (correction/sumw); + } + return Double.NaN; + } + + /** + * Returns the weighted arithmetic mean of the entries in the input array. 
+ * <p> + * Throws <code>MathIllegalArgumentException</code> if either array is null.</p> + * <p> + * See {@link Mean} for details on the computing algorithm. The two-pass algorithm + * described above is used here, with weights applied in computing both the original + * estimate and the correction factor.</p> + * <p> + * Throws <code>MathIllegalArgumentException</code> if any of the following are true: + * <ul><li>the values array is null</li> + * <li>the weights array is null</li> + * <li>the weights array does not have the same length as the values array</li> + * <li>the weights array contains one or more infinite values</li> + * <li>the weights array contains one or more NaN values</li> + * <li>the weights array contains negative values</li> + * </ul></p> + * + * @param values the input array + * @param weights the weights array + * @return the mean of the values or Double.NaN if length = 0 + * @throws MathIllegalArgumentException if the parameters are not valid + * @since 2.1 + */ + public double evaluate(final double[] values, final double[] weights) + throws MathIllegalArgumentException { + return evaluate(values, weights, 0, values.length); + } + + /** + * {@inheritDoc} + */ + @Override + public Mean copy() { + Mean result = new Mean(); + // No try-catch or advertised exception because args are guaranteed non-null + copy(this, result); + return result; + } + + + /** + * Copies source to dest. 
+ * <p>Neither source nor dest can be null.</p> + * + * @param source Mean to copy + * @param dest Mean to copy to + * @throws NullArgumentException if either source or dest is null + */ + public static void copy(Mean source, Mean dest) + throws NullArgumentException { + MathUtils.checkNotNull(source); + MathUtils.checkNotNull(dest); + dest.setData(source.getDataRef()); + dest.incMoment = source.incMoment; + dest.moment = source.moment.copy(); + } +} diff --git a/src/main/java/org/apache/commons/math3/stat/descriptive/moment/SecondMoment.java b/src/main/java/org/apache/commons/math3/stat/descriptive/moment/SecondMoment.java new file mode 100644 index 0000000..12715c0 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/stat/descriptive/moment/SecondMoment.java @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.math3.stat.descriptive.moment; + +import java.io.Serializable; + +import org.apache.commons.math3.exception.NullArgumentException; +import org.apache.commons.math3.util.MathUtils; + +/** + * Computes a statistic related to the Second Central Moment. Specifically, + * what is computed is the sum of squared deviations from the sample mean. 
+ * <p>
+ * The following recursive updating formula is used:</p>
+ * <p>
+ * Let <ul>
+ * <li> dev = (current obs - previous mean) </li>
+ * <li> n = number of observations (including current obs) </li>
+ * </ul>
+ * Then</p>
+ * <p>
+ * new value = old value + dev^2 * (n -1) / n.</p>
+ * <p>
+ * Returns <code>Double.NaN</code> if no data values have been added and
+ * returns <code>0</code> if there is just one value in the data set.
+ * Note that Double.NaN may also be returned if the input includes NaN
+ * and / or infinite values.</p>
+ * <p>
+ * <strong>Note that this implementation is not synchronized.</strong> If
+ * multiple threads access an instance of this class concurrently, and at least
+ * one of the threads invokes the <code>increment()</code> or
+ * <code>clear()</code> method, it must be synchronized externally.</p>
+ *
+ */
+public class SecondMoment extends FirstMoment implements Serializable {
+
+    /** Serializable version identifier */
+    private static final long serialVersionUID = 3942403127395076445L;
+
+    /** second moment of values that have been added */
+    protected double m2;
+
+    /**
+     * Create a SecondMoment instance
+     */
+    public SecondMoment() {
+        super();
+        // NaN until the first increment(), matching the class contract.
+        m2 = Double.NaN;
+    }
+
+    /**
+     * Copy constructor, creates a new {@code SecondMoment} identical
+     * to the {@code original}
+     *
+     * @param original the {@code SecondMoment} instance to copy
+     * @throws NullArgumentException if original is null
+     */
+    public SecondMoment(SecondMoment original)
+        throws NullArgumentException {
+        super(original);
+        this.m2 = original.m2;
+    }
+
+    /**
+     * {@inheritDoc}
+     */
+    @Override
+    public void increment(final double d) {
+        if (n < 1) {
+            // First observation: seed both moments at 0 instead of NaN.
+            m1 = m2 = 0.0;
+        }
+        // super.increment(d) updates n, m1 and the protected scratch
+        // fields dev (= d - old mean) and nDev (= dev / n) used below.
+        super.increment(d);
+        // n has already been incremented, so (n - 1) counts prior values;
+        // this is the "dev^2 * (n - 1) / n" term of the updating formula.
+        m2 += ((double) n - 1) * dev * nDev;
+    }
+
+    /**
+     * {@inheritDoc}
+     */
+    @Override
+    public void clear() {
+        super.clear();
+        m2 = Double.NaN;
+    }
+
+    /**
+     * {@inheritDoc}
+     */
+    @Override
+    public double getResult() {
+        return m2;
+    }
+
+    /**
+     * {@inheritDoc}
+     */
+ 
@Override + public SecondMoment copy() { + SecondMoment result = new SecondMoment(); + // no try-catch or advertised NAE because args are guaranteed non-null + copy(this, result); + return result; + } + + /** + * Copies source to dest. + * <p>Neither source nor dest can be null.</p> + * + * @param source SecondMoment to copy + * @param dest SecondMoment to copy to + * @throws NullArgumentException if either source or dest is null + */ + public static void copy(SecondMoment source, SecondMoment dest) + throws NullArgumentException { + MathUtils.checkNotNull(source); + MathUtils.checkNotNull(dest); + FirstMoment.copy(source, dest); + dest.m2 = source.m2; + } + +} diff --git a/src/main/java/org/apache/commons/math3/stat/descriptive/moment/SemiVariance.java b/src/main/java/org/apache/commons/math3/stat/descriptive/moment/SemiVariance.java new file mode 100644 index 0000000..563119a --- /dev/null +++ b/src/main/java/org/apache/commons/math3/stat/descriptive/moment/SemiVariance.java @@ -0,0 +1,369 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.commons.math3.stat.descriptive.moment; + +import java.io.Serializable; + +import org.apache.commons.math3.exception.MathIllegalArgumentException; +import org.apache.commons.math3.exception.NullArgumentException; +import org.apache.commons.math3.stat.descriptive.AbstractUnivariateStatistic; +import org.apache.commons.math3.util.MathUtils; + +/** + * <p>Computes the semivariance of a set of values with respect to a given cutoff value. + * We define the <i>downside semivariance</i> of a set of values <code>x</code> + * against the <i>cutoff value</i> <code>cutoff</code> to be <br/> + * <code>Σ (x[i] - target)<sup>2</sup> / df</code> <br/> + * where the sum is taken over all <code>i</code> such that <code>x[i] < cutoff</code> + * and <code>df</code> is the length of <code>x</code> (non-bias-corrected) or + * one less than this number (bias corrected). The <i>upside semivariance</i> + * is defined similarly, with the sum taken over values of <code>x</code> that + * exceed the cutoff value.</p> + * + * <p>The cutoff value defaults to the mean, bias correction defaults to <code>true</code> + * and the "variance direction" (upside or downside) defaults to downside. 
The variance direction + * and bias correction may be set using property setters or their values can provided as + * parameters to {@link #evaluate(double[], double, Direction, boolean, int, int)}.</p> + * + * <p>If the input array is null, <code>evaluate</code> methods throw + * <code>IllegalArgumentException.</code> If the array has length 1, <code>0</code> + * is returned, regardless of the value of the <code>cutoff.</code> + * + * <p><strong>Note that this class is not intended to be threadsafe.</strong> If + * multiple threads access an instance of this class concurrently, and one or + * more of these threads invoke property setters, external synchronization must + * be provided to ensure correct results.</p> + * + * @since 2.1 + */ +public class SemiVariance extends AbstractUnivariateStatistic implements Serializable { + + /** + * The UPSIDE Direction is used to specify that the observations above the + * cutoff point will be used to calculate SemiVariance. + */ + public static final Direction UPSIDE_VARIANCE = Direction.UPSIDE; + + /** + * The DOWNSIDE Direction is used to specify that the observations below + * the cutoff point will be used to calculate SemiVariance + */ + public static final Direction DOWNSIDE_VARIANCE = Direction.DOWNSIDE; + + /** Serializable version identifier */ + private static final long serialVersionUID = -2653430366886024994L; + + /** + * Determines whether or not bias correction is applied when computing the + * value of the statisic. True means that bias is corrected. + */ + private boolean biasCorrected = true; + + /** + * Determines whether to calculate downside or upside SemiVariance. + */ + private Direction varianceDirection = Direction.DOWNSIDE; + + /** + * Constructs a SemiVariance with default (true) <code>biasCorrected</code> + * property and default (Downside) <code>varianceDirection</code> property. 
+ */ + public SemiVariance() { + } + + /** + * Constructs a SemiVariance with the specified <code>biasCorrected</code> + * property and default (Downside) <code>varianceDirection</code> property. + * + * @param biasCorrected setting for bias correction - true means + * bias will be corrected and is equivalent to using the argumentless + * constructor + */ + public SemiVariance(final boolean biasCorrected) { + this.biasCorrected = biasCorrected; + } + + + /** + * Constructs a SemiVariance with the specified <code>Direction</code> property + * and default (true) <code>biasCorrected</code> property + * + * @param direction setting for the direction of the SemiVariance + * to calculate + */ + public SemiVariance(final Direction direction) { + this.varianceDirection = direction; + } + + + /** + * Constructs a SemiVariance with the specified <code>isBiasCorrected</code> + * property and the specified <code>Direction</code> property. + * + * @param corrected setting for bias correction - true means + * bias will be corrected and is equivalent to using the argumentless + * constructor + * + * @param direction setting for the direction of the SemiVariance + * to calculate + */ + public SemiVariance(final boolean corrected, final Direction direction) { + this.biasCorrected = corrected; + this.varianceDirection = direction; + } + + + /** + * Copy constructor, creates a new {@code SemiVariance} identical + * to the {@code original} + * + * @param original the {@code SemiVariance} instance to copy + * @throws NullArgumentException if original is null + */ + public SemiVariance(final SemiVariance original) throws NullArgumentException { + copy(original, this); + } + + + /** + * {@inheritDoc} + */ + @Override + public SemiVariance copy() { + SemiVariance result = new SemiVariance(); + // No try-catch or advertised exception because args are guaranteed non-null + copy(this, result); + return result; + } + + + /** + * Copies source to dest. 
+ * <p>Neither source nor dest can be null.</p> + * + * @param source SemiVariance to copy + * @param dest SemiVariance to copy to + * @throws NullArgumentException if either source or dest is null + */ + public static void copy(final SemiVariance source, SemiVariance dest) + throws NullArgumentException { + MathUtils.checkNotNull(source); + MathUtils.checkNotNull(dest); + dest.setData(source.getDataRef()); + dest.biasCorrected = source.biasCorrected; + dest.varianceDirection = source.varianceDirection; + } + + /** + * <p>Returns the {@link SemiVariance} of the designated values against the mean, using + * instance properties varianceDirection and biasCorrection.</p> + * + * <p>Returns <code>NaN</code> if the array is empty and throws + * <code>IllegalArgumentException</code> if the array is null.</p> + * + * @param values the input array + * @param start index of the first array element to include + * @param length the number of elements to include + * @return the SemiVariance + * @throws MathIllegalArgumentException if the parameters are not valid + * + */ + @Override + public double evaluate(final double[] values, final int start, final int length) + throws MathIllegalArgumentException { + double m = (new Mean()).evaluate(values, start, length); + return evaluate(values, m, varianceDirection, biasCorrected, 0, values.length); + } + + + /** + * This method calculates {@link SemiVariance} for the entire array against the mean, using + * the current value of the biasCorrection instance property. 
+     *
+     * @param values the input array
+     * @param direction the {@link Direction} of the semivariance
+     * @return the SemiVariance
+     * @throws MathIllegalArgumentException if values is null
+     *
+     */
+    public double evaluate(final double[] values, Direction direction)
+        throws MathIllegalArgumentException {
+        // Cutoff defaults to the mean of the whole array.
+        double m = (new Mean()).evaluate(values);
+        return evaluate (values, m, direction, biasCorrected, 0, values.length);
+    }
+
+    /**
+     * <p>Returns the {@link SemiVariance} of the designated values against the cutoff, using
+     * instance properties varianceDirection and biasCorrection.</p>
+     *
+     * <p>Returns <code>NaN</code> if the array is empty and throws
+     * <code>MathIllegalArgumentException</code> if the array is null.</p>
+     *
+     * @param values the input array
+     * @param cutoff the reference point
+     * @return the SemiVariance
+     * @throws MathIllegalArgumentException if values is null
+     */
+    public double evaluate(final double[] values, final double cutoff)
+        throws MathIllegalArgumentException {
+        return evaluate(values, cutoff, varianceDirection, biasCorrected, 0, values.length);
+    }
+
+    /**
+     * <p>Returns the {@link SemiVariance} of the designated values against the cutoff in the
+     * given direction, using the current value of the biasCorrection instance property.</p>
+     *
+     * <p>Returns <code>NaN</code> if the array is empty and throws
+     * <code>MathIllegalArgumentException</code> if the array is null.</p>
+     *
+     * @param values the input array
+     * @param cutoff the reference point
+     * @param direction the {@link Direction} of the semivariance
+     * @return the SemiVariance
+     * @throws MathIllegalArgumentException if values is null
+     */
+    public double evaluate(final double[] values, final double cutoff, final Direction direction)
+        throws MathIllegalArgumentException {
+        return evaluate(values, cutoff, direction, biasCorrected, 0, values.length);
+    }
+
+
+    /**
+     * <p>Returns the {@link SemiVariance} of the designated values against the cutoff
+     * in the given 
direction with the provided bias correction.</p> + * + * <p>Returns <code>NaN</code> if the array is empty and throws + * <code>IllegalArgumentException</code> if the array is null.</p> + * + * @param values the input array + * @param cutoff the reference point + * @param direction the {@link Direction} of the semivariance + * @param corrected the BiasCorrection flag + * @param start index of the first array element to include + * @param length the number of elements to include + * @return the SemiVariance + * @throws MathIllegalArgumentException if the parameters are not valid + * + */ + public double evaluate (final double[] values, final double cutoff, final Direction direction, + final boolean corrected, final int start, final int length) throws MathIllegalArgumentException { + + test(values, start, length); + if (values.length == 0) { + return Double.NaN; + } else { + if (values.length == 1) { + return 0.0; + } else { + final boolean booleanDirection = direction.getDirection(); + + double dev = 0.0; + double sumsq = 0.0; + for (int i = start; i < length; i++) { + if ((values[i] > cutoff) == booleanDirection) { + dev = values[i] - cutoff; + sumsq += dev * dev; + } + } + + if (corrected) { + return sumsq / (length - 1.0); + } else { + return sumsq / length; + } + } + } + } + + /** + * Returns true iff biasCorrected property is set to true. + * + * @return the value of biasCorrected. + */ + public boolean isBiasCorrected() { + return biasCorrected; + } + + /** + * Sets the biasCorrected property. + * + * @param biasCorrected new biasCorrected property value + */ + public void setBiasCorrected(boolean biasCorrected) { + this.biasCorrected = biasCorrected; + } + + /** + * Returns the varianceDirection property. 
+ * + * @return the varianceDirection + */ + public Direction getVarianceDirection () { + return varianceDirection; + } + + /** + * Sets the variance direction + * + * @param varianceDirection the direction of the semivariance + */ + public void setVarianceDirection(Direction varianceDirection) { + this.varianceDirection = varianceDirection; + } + + /** + * The direction of the semivariance - either upside or downside. The direction + * is represented by boolean, with true corresponding to UPSIDE semivariance. + */ + public enum Direction { + /** + * The UPSIDE Direction is used to specify that the observations above the + * cutoff point will be used to calculate SemiVariance + */ + UPSIDE (true), + + /** + * The DOWNSIDE Direction is used to specify that the observations below + * the cutoff point will be used to calculate SemiVariance + */ + DOWNSIDE (false); + + /** + * boolean value UPSIDE <-> true + */ + private boolean direction; + + /** + * Create a Direction with the given value. + * + * @param b boolean value representing the Direction. True corresponds to UPSIDE. + */ + Direction (boolean b) { + direction = b; + } + + /** + * Returns the value of this Direction. True corresponds to UPSIDE. + * + * @return true if direction is UPSIDE; false otherwise + */ + boolean getDirection () { + return direction; + } + } +} diff --git a/src/main/java/org/apache/commons/math3/stat/descriptive/moment/Skewness.java b/src/main/java/org/apache/commons/math3/stat/descriptive/moment/Skewness.java new file mode 100644 index 0000000..b4703eb --- /dev/null +++ b/src/main/java/org/apache/commons/math3/stat/descriptive/moment/Skewness.java @@ -0,0 +1,228 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.math3.stat.descriptive.moment; + +import java.io.Serializable; + +import org.apache.commons.math3.exception.MathIllegalArgumentException; +import org.apache.commons.math3.exception.NullArgumentException; +import org.apache.commons.math3.stat.descriptive.AbstractStorelessUnivariateStatistic; +import org.apache.commons.math3.util.FastMath; +import org.apache.commons.math3.util.MathUtils; + +/** + * Computes the skewness of the available values. + * <p> + * We use the following (unbiased) formula to define skewness:</p> + * <p> + * skewness = [n / (n -1) (n - 2)] sum[(x_i - mean)^3] / std^3 </p> + * <p> + * where n is the number of values, mean is the {@link Mean} and std is the + * {@link StandardDeviation} </p> + * <p> + * Note that this statistic is undefined for n < 3. <code>Double.Nan</code> + * is returned when there is not sufficient data to compute the statistic. + * Double.NaN may also be returned if the input includes NaN and / or + * infinite values.</p> + * <p> + * <strong>Note that this implementation is not synchronized.</strong> If + * multiple threads access an instance of this class concurrently, and at least + * one of the threads invokes the <code>increment()</code> or + * <code>clear()</code> method, it must be synchronized externally. 
</p> + * + */ +public class Skewness extends AbstractStorelessUnivariateStatistic implements Serializable { + + /** Serializable version identifier */ + private static final long serialVersionUID = 7101857578996691352L; + + /** Third moment on which this statistic is based */ + protected ThirdMoment moment = null; + + /** + * Determines whether or not this statistic can be incremented or cleared. + * <p> + * Statistics based on (constructed from) external moments cannot + * be incremented or cleared.</p> + */ + protected boolean incMoment; + + /** + * Constructs a Skewness + */ + public Skewness() { + incMoment = true; + moment = new ThirdMoment(); + } + + /** + * Constructs a Skewness with an external moment + * @param m3 external moment + */ + public Skewness(final ThirdMoment m3) { + incMoment = false; + this.moment = m3; + } + + /** + * Copy constructor, creates a new {@code Skewness} identical + * to the {@code original} + * + * @param original the {@code Skewness} instance to copy + * @throws NullArgumentException if original is null + */ + public Skewness(Skewness original) throws NullArgumentException { + copy(original, this); + } + + /** + * {@inheritDoc} + * <p>Note that when {@link #Skewness(ThirdMoment)} is used to + * create a Skewness, this method does nothing. In that case, the + * ThirdMoment should be incremented directly.</p> + */ + @Override + public void increment(final double d) { + if (incMoment) { + moment.increment(d); + } + } + + /** + * Returns the value of the statistic based on the values that have been added. + * <p> + * See {@link Skewness} for the definition used in the computation.</p> + * + * @return the skewness of the available values. 
+ */ + @Override + public double getResult() { + + if (moment.n < 3) { + return Double.NaN; + } + double variance = moment.m2 / (moment.n - 1); + if (variance < 10E-20) { + return 0.0d; + } else { + double n0 = moment.getN(); + return (n0 * moment.m3) / + ((n0 - 1) * (n0 -2) * FastMath.sqrt(variance) * variance); + } + } + + /** + * {@inheritDoc} + */ + public long getN() { + return moment.getN(); + } + + /** + * {@inheritDoc} + */ + @Override + public void clear() { + if (incMoment) { + moment.clear(); + } + } + + /** + * Returns the Skewness of the entries in the specifed portion of the + * input array. + * <p> + * See {@link Skewness} for the definition used in the computation.</p> + * <p> + * Throws <code>IllegalArgumentException</code> if the array is null.</p> + * + * @param values the input array + * @param begin the index of the first array element to include + * @param length the number of elements to include + * @return the skewness of the values or Double.NaN if length is less than + * 3 + * @throws MathIllegalArgumentException if the array is null or the array index + * parameters are not valid + */ + @Override + public double evaluate(final double[] values,final int begin, + final int length) throws MathIllegalArgumentException { + + // Initialize the skewness + double skew = Double.NaN; + + if (test(values, begin, length) && length > 2 ){ + Mean mean = new Mean(); + // Get the mean and the standard deviation + double m = mean.evaluate(values, begin, length); + + // Calc the std, this is implemented here instead + // of using the standardDeviation method eliminate + // a duplicate pass to get the mean + double accum = 0.0; + double accum2 = 0.0; + for (int i = begin; i < begin + length; i++) { + final double d = values[i] - m; + accum += d * d; + accum2 += d; + } + final double variance = (accum - (accum2 * accum2 / length)) / (length - 1); + + double accum3 = 0.0; + for (int i = begin; i < begin + length; i++) { + final double d = values[i] - m; + 
accum3 += d * d * d; + } + accum3 /= variance * FastMath.sqrt(variance); + + // Get N + double n0 = length; + + // Calculate skewness + skew = (n0 / ((n0 - 1) * (n0 - 2))) * accum3; + } + return skew; + } + + /** + * {@inheritDoc} + */ + @Override + public Skewness copy() { + Skewness result = new Skewness(); + // No try-catch or advertised exception because args are guaranteed non-null + copy(this, result); + return result; + } + + /** + * Copies source to dest. + * <p>Neither source nor dest can be null.</p> + * + * @param source Skewness to copy + * @param dest Skewness to copy to + * @throws NullArgumentException if either source or dest is null + */ + public static void copy(Skewness source, Skewness dest) + throws NullArgumentException { + MathUtils.checkNotNull(source); + MathUtils.checkNotNull(dest); + dest.setData(source.getDataRef()); + dest.moment = new ThirdMoment(source.moment.copy()); + dest.incMoment = source.incMoment; + } +} diff --git a/src/main/java/org/apache/commons/math3/stat/descriptive/moment/StandardDeviation.java b/src/main/java/org/apache/commons/math3/stat/descriptive/moment/StandardDeviation.java new file mode 100644 index 0000000..a6248c5 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/stat/descriptive/moment/StandardDeviation.java @@ -0,0 +1,280 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.math3.stat.descriptive.moment; + +import java.io.Serializable; + +import org.apache.commons.math3.exception.MathIllegalArgumentException; +import org.apache.commons.math3.exception.NullArgumentException; +import org.apache.commons.math3.stat.descriptive.AbstractStorelessUnivariateStatistic; +import org.apache.commons.math3.util.FastMath; +import org.apache.commons.math3.util.MathUtils; + +/** + * Computes the sample standard deviation. The standard deviation + * is the positive square root of the variance. This implementation wraps a + * {@link Variance} instance. The <code>isBiasCorrected</code> property of the + * wrapped Variance instance is exposed, so that this class can be used to + * compute both the "sample standard deviation" (the square root of the + * bias-corrected "sample variance") or the "population standard deviation" + * (the square root of the non-bias-corrected "population variance"). See + * {@link Variance} for more information. 
+ * <p> + * <strong>Note that this implementation is not synchronized.</strong> If + * multiple threads access an instance of this class concurrently, and at least + * one of the threads invokes the <code>increment()</code> or + * <code>clear()</code> method, it must be synchronized externally.</p> + * + */ +public class StandardDeviation extends AbstractStorelessUnivariateStatistic + implements Serializable { + + /** Serializable version identifier */ + private static final long serialVersionUID = 5728716329662425188L; + + /** Wrapped Variance instance */ + private Variance variance = null; + + /** + * Constructs a StandardDeviation. Sets the underlying {@link Variance} + * instance's <code>isBiasCorrected</code> property to true. + */ + public StandardDeviation() { + variance = new Variance(); + } + + /** + * Constructs a StandardDeviation from an external second moment. + * + * @param m2 the external moment + */ + public StandardDeviation(final SecondMoment m2) { + variance = new Variance(m2); + } + + /** + * Copy constructor, creates a new {@code StandardDeviation} identical + * to the {@code original} + * + * @param original the {@code StandardDeviation} instance to copy + * @throws NullArgumentException if original is null + */ + public StandardDeviation(StandardDeviation original) throws NullArgumentException { + copy(original, this); + } + + /** + * Contructs a StandardDeviation with the specified value for the + * <code>isBiasCorrected</code> property. If this property is set to + * <code>true</code>, the {@link Variance} used in computing results will + * use the bias-corrected, or "sample" formula. See {@link Variance} for + * details. 
+     *
+     * @param isBiasCorrected whether or not the variance computation will use
+     * the bias-corrected formula
+     */
+    public StandardDeviation(boolean isBiasCorrected) {
+        variance = new Variance(isBiasCorrected);
+    }
+
+    /**
+     * Constructs a StandardDeviation with the specified value for the
+     * <code>isBiasCorrected</code> property and the supplied external moment.
+     * If <code>isBiasCorrected</code> is set to <code>true</code>, the
+     * {@link Variance} used in computing results will use the bias-corrected,
+     * or "sample" formula. See {@link Variance} for details.
+     *
+     * @param isBiasCorrected whether or not the variance computation will use
+     * the bias-corrected formula
+     * @param m2 the external moment
+     */
+    public StandardDeviation(boolean isBiasCorrected, SecondMoment m2) {
+        variance = new Variance(isBiasCorrected, m2);
+    }
+
+    /**
+     * {@inheritDoc}
+     */
+    @Override
+    public void increment(final double d) {
+        // All state lives in the wrapped Variance; this class only
+        // applies sqrt on the way out.
+        variance.increment(d);
+    }
+
+    /**
+     * {@inheritDoc}
+     */
+    public long getN() {
+        return variance.getN();
+    }
+
+    /**
+     * {@inheritDoc}
+     */
+    @Override
+    public double getResult() {
+        // Standard deviation = positive square root of the variance.
+        return FastMath.sqrt(variance.getResult());
+    }
+
+    /**
+     * {@inheritDoc}
+     */
+    @Override
+    public void clear() {
+        variance.clear();
+    }
+
+    /**
+     * Returns the Standard Deviation of the entries in the input array, or
+     * <code>Double.NaN</code> if the array is empty.
+     * <p>
+     * Returns 0 for a single-value (i.e. 
length = 1) sample.</p> + * <p> + * Throws <code>MathIllegalArgumentException</code> if the array is null.</p> + * <p> + * Does not change the internal state of the statistic.</p> + * + * @param values the input array + * @return the standard deviation of the values or Double.NaN if length = 0 + * @throws MathIllegalArgumentException if the array is null + */ + @Override + public double evaluate(final double[] values) throws MathIllegalArgumentException { + return FastMath.sqrt(variance.evaluate(values)); + } + + /** + * Returns the Standard Deviation of the entries in the specified portion of + * the input array, or <code>Double.NaN</code> if the designated subarray + * is empty. + * <p> + * Returns 0 for a single-value (i.e. length = 1) sample. </p> + * <p> + * Throws <code>MathIllegalArgumentException</code> if the array is null.</p> + * <p> + * Does not change the internal state of the statistic.</p> + * + * @param values the input array + * @param begin index of the first array element to include + * @param length the number of elements to include + * @return the standard deviation of the values or Double.NaN if length = 0 + * @throws MathIllegalArgumentException if the array is null or the array index + * parameters are not valid + */ + @Override + public double evaluate(final double[] values, final int begin, final int length) + throws MathIllegalArgumentException { + return FastMath.sqrt(variance.evaluate(values, begin, length)); + } + + /** + * Returns the Standard Deviation of the entries in the specified portion of + * the input array, using the precomputed mean value. Returns + * <code>Double.NaN</code> if the designated subarray is empty. + * <p> + * Returns 0 for a single-value (i.e. length = 1) sample.</p> + * <p> + * The formula used assumes that the supplied mean value is the arithmetic + * mean of the sample data, not a known population parameter. 
This method + * is supplied only to save computation when the mean has already been + * computed.</p> + * <p> + * Throws <code>IllegalArgumentException</code> if the array is null.</p> + * <p> + * Does not change the internal state of the statistic.</p> + * + * @param values the input array + * @param mean the precomputed mean value + * @param begin index of the first array element to include + * @param length the number of elements to include + * @return the standard deviation of the values or Double.NaN if length = 0 + * @throws MathIllegalArgumentException if the array is null or the array index + * parameters are not valid + */ + public double evaluate(final double[] values, final double mean, + final int begin, final int length) throws MathIllegalArgumentException { + return FastMath.sqrt(variance.evaluate(values, mean, begin, length)); + } + + /** + * Returns the Standard Deviation of the entries in the input array, using + * the precomputed mean value. Returns + * <code>Double.NaN</code> if the designated subarray is empty. + * <p> + * Returns 0 for a single-value (i.e. length = 1) sample.</p> + * <p> + * The formula used assumes that the supplied mean value is the arithmetic + * mean of the sample data, not a known population parameter. This method + * is supplied only to save computation when the mean has already been + * computed.</p> + * <p> + * Throws <code>MathIllegalArgumentException</code> if the array is null.</p> + * <p> + * Does not change the internal state of the statistic.</p> + * + * @param values the input array + * @param mean the precomputed mean value + * @return the standard deviation of the values or Double.NaN if length = 0 + * @throws MathIllegalArgumentException if the array is null + */ + public double evaluate(final double[] values, final double mean) + throws MathIllegalArgumentException { + return FastMath.sqrt(variance.evaluate(values, mean)); + } + + /** + * @return Returns the isBiasCorrected. 
+ */ + public boolean isBiasCorrected() { + return variance.isBiasCorrected(); + } + + /** + * @param isBiasCorrected The isBiasCorrected to set. + */ + public void setBiasCorrected(boolean isBiasCorrected) { + variance.setBiasCorrected(isBiasCorrected); + } + + /** + * {@inheritDoc} + */ + @Override + public StandardDeviation copy() { + StandardDeviation result = new StandardDeviation(); + // No try-catch or advertised exception because args are guaranteed non-null + copy(this, result); + return result; + } + + + /** + * Copies source to dest. + * <p>Neither source nor dest can be null.</p> + * + * @param source StandardDeviation to copy + * @param dest StandardDeviation to copy to + * @throws NullArgumentException if either source or dest is null + */ + public static void copy(StandardDeviation source, StandardDeviation dest) + throws NullArgumentException { + MathUtils.checkNotNull(source); + MathUtils.checkNotNull(dest); + dest.setData(source.getDataRef()); + dest.variance = source.variance.copy(); + } + +} diff --git a/src/main/java/org/apache/commons/math3/stat/descriptive/moment/ThirdMoment.java b/src/main/java/org/apache/commons/math3/stat/descriptive/moment/ThirdMoment.java new file mode 100644 index 0000000..43a9ca1 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/stat/descriptive/moment/ThirdMoment.java @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.math3.stat.descriptive.moment; + +import java.io.Serializable; + +import org.apache.commons.math3.exception.NullArgumentException; +import org.apache.commons.math3.util.MathUtils; + + +/** + * Computes a statistic related to the Third Central Moment. Specifically, + * what is computed is the sum of cubed deviations from the sample mean. + * <p> + * The following recursive updating formula is used:</p> + * <p> + * Let <ul> + * <li> dev = (current obs - previous mean) </li> + * <li> m2 = previous value of {@link SecondMoment} </li> + * <li> n = number of observations (including current obs) </li> + * </ul> + * Then</p> + * <p> + * new value = old value - 3 * (dev/n) * m2 + (n-1) * (n -2) * (dev^3/n^2)</p> + * <p> + * Returns <code>Double.NaN</code> if no data values have been added and + * returns <code>0</code> if there is just one value in the data set. 
+ * Note that Double.NaN may also be returned if the input includes NaN + * and / or infinite values.</p> + * <p> + * <strong>Note that this implementation is not synchronized.</strong> If + * multiple threads access an instance of this class concurrently, and at least + * one of the threads invokes the <code>increment()</code> or + * <code>clear()</code> method, it must be synchronized externally.</p> + * + */ +class ThirdMoment extends SecondMoment implements Serializable { + + /** Serializable version identifier */ + private static final long serialVersionUID = -7818711964045118679L; + + /** third moment of values that have been added */ + protected double m3; + + /** + * Square of deviation of most recently added value from previous first + * moment, normalized by previous sample size. Retained to prevent + * repeated computation in higher order moments. nDevSq = nDev * nDev. + */ + protected double nDevSq; + + /** + * Create a FourthMoment instance + */ + ThirdMoment() { + super(); + m3 = Double.NaN; + nDevSq = Double.NaN; + } + + /** + * Copy constructor, creates a new {@code ThirdMoment} identical + * to the {@code original} + * + * @param original the {@code ThirdMoment} instance to copy + * @throws NullArgumentException if orginal is null + */ + ThirdMoment(ThirdMoment original) throws NullArgumentException { + copy(original, this); + } + + /** + * {@inheritDoc} + */ + @Override + public void increment(final double d) { + if (n < 1) { + m3 = m2 = m1 = 0.0; + } + + double prevM2 = m2; + super.increment(d); + nDevSq = nDev * nDev; + double n0 = n; + m3 = m3 - 3.0 * nDev * prevM2 + (n0 - 1) * (n0 - 2) * nDevSq * dev; + } + + /** + * {@inheritDoc} + */ + @Override + public double getResult() { + return m3; + } + + /** + * {@inheritDoc} + */ + @Override + public void clear() { + super.clear(); + m3 = Double.NaN; + nDevSq = Double.NaN; + } + + /** + * {@inheritDoc} + */ + @Override + public ThirdMoment copy() { + ThirdMoment result = new ThirdMoment(); + // No 
try-catch or advertised exception because args are guaranteed non-null + copy(this, result); + return result; + } + + /** + * Copies source to dest. + * <p>Neither source nor dest can be null.</p> + * + * @param source ThirdMoment to copy + * @param dest ThirdMoment to copy to + * @throws NullArgumentException if either source or dest is null + */ + public static void copy(ThirdMoment source, ThirdMoment dest) + throws NullArgumentException { + MathUtils.checkNotNull(source); + MathUtils.checkNotNull(dest); + SecondMoment.copy(source, dest); + dest.m3 = source.m3; + dest.nDevSq = source.nDevSq; + } + +} diff --git a/src/main/java/org/apache/commons/math3/stat/descriptive/moment/Variance.java b/src/main/java/org/apache/commons/math3/stat/descriptive/moment/Variance.java new file mode 100644 index 0000000..1ba48e9 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/stat/descriptive/moment/Variance.java @@ -0,0 +1,627 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.commons.math3.stat.descriptive.moment; + +import java.io.Serializable; + +import org.apache.commons.math3.exception.MathIllegalArgumentException; +import org.apache.commons.math3.exception.NullArgumentException; +import org.apache.commons.math3.exception.util.LocalizedFormats; +import org.apache.commons.math3.stat.descriptive.WeightedEvaluation; +import org.apache.commons.math3.stat.descriptive.AbstractStorelessUnivariateStatistic; +import org.apache.commons.math3.util.MathUtils; + +/** + * Computes the variance of the available values. By default, the unbiased + * "sample variance" definitional formula is used: + * <p> + * variance = sum((x_i - mean)^2) / (n - 1) </p> + * <p> + * where mean is the {@link Mean} and <code>n</code> is the number + * of sample observations.</p> + * <p> + * The definitional formula does not have good numerical properties, so + * this implementation does not compute the statistic using the definitional + * formula. <ul> + * <li> The <code>getResult</code> method computes the variance using + * updating formulas based on West's algorithm, as described in + * <a href="http://doi.acm.org/10.1145/359146.359152"> Chan, T. F. and + * J. G. Lewis 1979, <i>Communications of the ACM</i>, + * vol. 22 no. 9, pp. 526-531.</a></li> + * <li> The <code>evaluate</code> methods leverage the fact that they have the + * full array of values in memory to execute a two-pass algorithm. + * Specifically, these methods use the "corrected two-pass algorithm" from + * Chan, Golub, Levesque, <i>Algorithms for Computing the Sample Variance</i>, + * American Statistician, vol. 37, no. 3 (1983) pp. 242-247.</li></ul> + * Note that adding values using <code>increment</code> or + * <code>incrementAll</code> and then executing <code>getResult</code> will + * sometimes give a different, less accurate, result than executing + * <code>evaluate</code> with the full array of values. 
The former approach + * should only be used when the full array of values is not available.</p> + * <p> + * The "population variance" ( sum((x_i - mean)^2) / n ) can also + * be computed using this statistic. The <code>isBiasCorrected</code> + * property determines whether the "population" or "sample" value is + * returned by the <code>evaluate</code> and <code>getResult</code> methods. + * To compute population variances, set this property to <code>false.</code> + * </p> + * <p> + * <strong>Note that this implementation is not synchronized.</strong> If + * multiple threads access an instance of this class concurrently, and at least + * one of the threads invokes the <code>increment()</code> or + * <code>clear()</code> method, it must be synchronized externally.</p> + * + */ +public class Variance extends AbstractStorelessUnivariateStatistic implements Serializable, WeightedEvaluation { + + /** Serializable version identifier */ + private static final long serialVersionUID = -9111962718267217978L; + + /** SecondMoment is used in incremental calculation of Variance*/ + protected SecondMoment moment = null; + + /** + * Whether or not {@link #increment(double)} should increment + * the internal second moment. When a Variance is constructed with an + * external SecondMoment as a constructor parameter, this property is + * set to false and increments must be applied to the second moment + * directly. + */ + protected boolean incMoment = true; + + /** + * Whether or not bias correction is applied when computing the + * value of the statistic. True means that bias is corrected. See + * {@link Variance} for details on the formula. + */ + private boolean isBiasCorrected = true; + + /** + * Constructs a Variance with default (true) <code>isBiasCorrected</code> + * property. + */ + public Variance() { + moment = new SecondMoment(); + } + + /** + * Constructs a Variance based on an external second moment. 
+ * When this constructor is used, the statistic may only be + * incremented via the moment, i.e., {@link #increment(double)} + * does nothing; whereas {@code m2.increment(value)} increments + * both {@code m2} and the Variance instance constructed from it. + * + * @param m2 the SecondMoment (Third or Fourth moments work + * here as well.) + */ + public Variance(final SecondMoment m2) { + incMoment = false; + this.moment = m2; + } + + /** + * Constructs a Variance with the specified <code>isBiasCorrected</code> + * property + * + * @param isBiasCorrected setting for bias correction - true means + * bias will be corrected and is equivalent to using the argumentless + * constructor + */ + public Variance(boolean isBiasCorrected) { + moment = new SecondMoment(); + this.isBiasCorrected = isBiasCorrected; + } + + /** + * Constructs a Variance with the specified <code>isBiasCorrected</code> + * property and the supplied external second moment. + * + * @param isBiasCorrected setting for bias correction - true means + * bias will be corrected + * @param m2 the SecondMoment (Third or Fourth moments work + * here as well.) + */ + public Variance(boolean isBiasCorrected, SecondMoment m2) { + incMoment = false; + this.moment = m2; + this.isBiasCorrected = isBiasCorrected; + } + + /** + * Copy constructor, creates a new {@code Variance} identical + * to the {@code original} + * + * @param original the {@code Variance} instance to copy + * @throws NullArgumentException if original is null + */ + public Variance(Variance original) throws NullArgumentException { + copy(original, this); + } + + /** + * {@inheritDoc} + * <p>If all values are available, it is more accurate to use + * {@link #evaluate(double[])} rather than adding values one at a time + * using this method and then executing {@link #getResult}, since + * <code>evaluate</code> leverages the fact that is has the full + * list of values together to execute a two-pass algorithm. 
+ * See {@link Variance}.</p> + * + * <p>Note also that when {@link #Variance(SecondMoment)} is used to + * create a Variance, this method does nothing. In that case, the + * SecondMoment should be incremented directly.</p> + */ + @Override + public void increment(final double d) { + if (incMoment) { + moment.increment(d); + } + } + + /** + * {@inheritDoc} + */ + @Override + public double getResult() { + if (moment.n == 0) { + return Double.NaN; + } else if (moment.n == 1) { + return 0d; + } else { + if (isBiasCorrected) { + return moment.m2 / (moment.n - 1d); + } else { + return moment.m2 / (moment.n); + } + } + } + + /** + * {@inheritDoc} + */ + public long getN() { + return moment.getN(); + } + + /** + * {@inheritDoc} + */ + @Override + public void clear() { + if (incMoment) { + moment.clear(); + } + } + + /** + * Returns the variance of the entries in the input array, or + * <code>Double.NaN</code> if the array is empty. + * <p> + * See {@link Variance} for details on the computing algorithm.</p> + * <p> + * Returns 0 for a single-value (i.e. length = 1) sample.</p> + * <p> + * Throws <code>MathIllegalArgumentException</code> if the array is null.</p> + * <p> + * Does not change the internal state of the statistic.</p> + * + * @param values the input array + * @return the variance of the values or Double.NaN if length = 0 + * @throws MathIllegalArgumentException if the array is null + */ + @Override + public double evaluate(final double[] values) throws MathIllegalArgumentException { + if (values == null) { + throw new NullArgumentException(LocalizedFormats.INPUT_ARRAY); + } + return evaluate(values, 0, values.length); + } + + /** + * Returns the variance of the entries in the specified portion of + * the input array, or <code>Double.NaN</code> if the designated subarray + * is empty. Note that Double.NaN may also be returned if the input + * includes NaN and / or infinite values. 
+ * <p> + * See {@link Variance} for details on the computing algorithm.</p> + * <p> + * Returns 0 for a single-value (i.e. length = 1) sample.</p> + * <p> + * Does not change the internal state of the statistic.</p> + * <p> + * Throws <code>MathIllegalArgumentException</code> if the array is null.</p> + * + * @param values the input array + * @param begin index of the first array element to include + * @param length the number of elements to include + * @return the variance of the values or Double.NaN if length = 0 + * @throws MathIllegalArgumentException if the array is null or the array index + * parameters are not valid + */ + @Override + public double evaluate(final double[] values, final int begin, final int length) + throws MathIllegalArgumentException { + + double var = Double.NaN; + + if (test(values, begin, length)) { + clear(); + if (length == 1) { + var = 0.0; + } else if (length > 1) { + Mean mean = new Mean(); + double m = mean.evaluate(values, begin, length); + var = evaluate(values, m, begin, length); + } + } + return var; + } + + /** + * <p>Returns the weighted variance of the entries in the specified portion of + * the input array, or <code>Double.NaN</code> if the designated subarray + * is empty.</p> + * <p> + * Uses the formula <pre> + * Σ(weights[i]*(values[i] - weightedMean)<sup>2</sup>)/(Σ(weights[i]) - 1) + * </pre> + * where weightedMean is the weighted mean</p> + * <p> + * This formula will not return the same result as the unweighted variance when all + * weights are equal, unless all weights are equal to 1. The formula assumes that + * weights are to be treated as "expansion values," as will be the case if for example + * the weights represent frequency counts. To normalize weights so that the denominator + * in the variance computation equals the length of the input vector minus one, use <pre> + * <code>evaluate(values, MathArrays.normalizeArray(weights, values.length)); </code> + * </pre> + * <p> + * Returns 0 for a single-value (i.e. 
length = 1) sample.</p> + * <p> + * Throws <code>IllegalArgumentException</code> if any of the following are true: + * <ul><li>the values array is null</li> + * <li>the weights array is null</li> + * <li>the weights array does not have the same length as the values array</li> + * <li>the weights array contains one or more infinite values</li> + * <li>the weights array contains one or more NaN values</li> + * <li>the weights array contains negative values</li> + * <li>the start and length arguments do not determine a valid array</li> + * </ul></p> + * <p> + * Does not change the internal state of the statistic.</p> + * <p> + * Throws <code>MathIllegalArgumentException</code> if either array is null.</p> + * + * @param values the input array + * @param weights the weights array + * @param begin index of the first array element to include + * @param length the number of elements to include + * @return the weighted variance of the values or Double.NaN if length = 0 + * @throws MathIllegalArgumentException if the parameters are not valid + * @since 2.1 + */ + public double evaluate(final double[] values, final double[] weights, + final int begin, final int length) throws MathIllegalArgumentException { + + double var = Double.NaN; + + if (test(values, weights,begin, length)) { + clear(); + if (length == 1) { + var = 0.0; + } else if (length > 1) { + Mean mean = new Mean(); + double m = mean.evaluate(values, weights, begin, length); + var = evaluate(values, weights, m, begin, length); + } + } + return var; + } + + /** + * <p> + * Returns the weighted variance of the entries in the the input array.</p> + * <p> + * Uses the formula <pre> + * Σ(weights[i]*(values[i] - weightedMean)<sup>2</sup>)/(Σ(weights[i]) - 1) + * </pre> + * where weightedMean is the weighted mean</p> + * <p> + * This formula will not return the same result as the unweighted variance when all + * weights are equal, unless all weights are equal to 1. 
The formula assumes that + * weights are to be treated as "expansion values," as will be the case if for example + * the weights represent frequency counts. To normalize weights so that the denominator + * in the variance computation equals the length of the input vector minus one, use <pre> + * <code>evaluate(values, MathArrays.normalizeArray(weights, values.length)); </code> + * </pre> + * <p> + * Returns 0 for a single-value (i.e. length = 1) sample.</p> + * <p> + * Throws <code>MathIllegalArgumentException</code> if any of the following are true: + * <ul><li>the values array is null</li> + * <li>the weights array is null</li> + * <li>the weights array does not have the same length as the values array</li> + * <li>the weights array contains one or more infinite values</li> + * <li>the weights array contains one or more NaN values</li> + * <li>the weights array contains negative values</li> + * </ul></p> + * <p> + * Does not change the internal state of the statistic.</p> + * <p> + * Throws <code>MathIllegalArgumentException</code> if either array is null.</p> + * + * @param values the input array + * @param weights the weights array + * @return the weighted variance of the values + * @throws MathIllegalArgumentException if the parameters are not valid + * @since 2.1 + */ + public double evaluate(final double[] values, final double[] weights) + throws MathIllegalArgumentException { + return evaluate(values, weights, 0, values.length); + } + + /** + * Returns the variance of the entries in the specified portion of + * the input array, using the precomputed mean value. Returns + * <code>Double.NaN</code> if the designated subarray is empty. + * <p> + * See {@link Variance} for details on the computing algorithm.</p> + * <p> + * The formula used assumes that the supplied mean value is the arithmetic + * mean of the sample data, not a known population parameter. 
This method + * is supplied only to save computation when the mean has already been + * computed.</p> + * <p> + * Returns 0 for a single-value (i.e. length = 1) sample.</p> + * <p> + * Throws <code>MathIllegalArgumentException</code> if the array is null.</p> + * <p> + * Does not change the internal state of the statistic.</p> + * + * @param values the input array + * @param mean the precomputed mean value + * @param begin index of the first array element to include + * @param length the number of elements to include + * @return the variance of the values or Double.NaN if length = 0 + * @throws MathIllegalArgumentException if the array is null or the array index + * parameters are not valid + */ + public double evaluate(final double[] values, final double mean, + final int begin, final int length) throws MathIllegalArgumentException { + + double var = Double.NaN; + + if (test(values, begin, length)) { + if (length == 1) { + var = 0.0; + } else if (length > 1) { + double accum = 0.0; + double dev = 0.0; + double accum2 = 0.0; + for (int i = begin; i < begin + length; i++) { + dev = values[i] - mean; + accum += dev * dev; + accum2 += dev; + } + double len = length; + if (isBiasCorrected) { + var = (accum - (accum2 * accum2 / len)) / (len - 1.0); + } else { + var = (accum - (accum2 * accum2 / len)) / len; + } + } + } + return var; + } + + /** + * Returns the variance of the entries in the input array, using the + * precomputed mean value. Returns <code>Double.NaN</code> if the array + * is empty. + * <p> + * See {@link Variance} for details on the computing algorithm.</p> + * <p> + * If <code>isBiasCorrected</code> is <code>true</code> the formula used + * assumes that the supplied mean value is the arithmetic mean of the + * sample data, not a known population parameter. 
If the mean is a known + * population parameter, or if the "population" version of the variance is + * desired, set <code>isBiasCorrected</code> to <code>false</code> before + * invoking this method.</p> + * <p> + * Returns 0 for a single-value (i.e. length = 1) sample.</p> + * <p> + * Throws <code>MathIllegalArgumentException</code> if the array is null.</p> + * <p> + * Does not change the internal state of the statistic.</p> + * + * @param values the input array + * @param mean the precomputed mean value + * @return the variance of the values or Double.NaN if the array is empty + * @throws MathIllegalArgumentException if the array is null + */ + public double evaluate(final double[] values, final double mean) throws MathIllegalArgumentException { + return evaluate(values, mean, 0, values.length); + } + + /** + * Returns the weighted variance of the entries in the specified portion of + * the input array, using the precomputed weighted mean value. Returns + * <code>Double.NaN</code> if the designated subarray is empty. + * <p> + * Uses the formula <pre> + * Σ(weights[i]*(values[i] - mean)<sup>2</sup>)/(Σ(weights[i]) - 1) + * </pre></p> + * <p> + * The formula used assumes that the supplied mean value is the weighted arithmetic + * mean of the sample data, not a known population parameter. This method + * is supplied only to save computation when the mean has already been + * computed.</p> + * <p> + * This formula will not return the same result as the unweighted variance when all + * weights are equal, unless all weights are equal to 1. The formula assumes that + * weights are to be treated as "expansion values," as will be the case if for example + * the weights represent frequency counts. 
To normalize weights so that the denominator + * in the variance computation equals the length of the input vector minus one, use <pre> + * <code>evaluate(values, MathArrays.normalizeArray(weights, values.length), mean); </code> + * </pre> + * <p> + * Returns 0 for a single-value (i.e. length = 1) sample.</p> + * <p> + * Throws <code>MathIllegalArgumentException</code> if any of the following are true: + * <ul><li>the values array is null</li> + * <li>the weights array is null</li> + * <li>the weights array does not have the same length as the values array</li> + * <li>the weights array contains one or more infinite values</li> + * <li>the weights array contains one or more NaN values</li> + * <li>the weights array contains negative values</li> + * <li>the start and length arguments do not determine a valid array</li> + * </ul></p> + * <p> + * Does not change the internal state of the statistic.</p> + * + * @param values the input array + * @param weights the weights array + * @param mean the precomputed weighted mean value + * @param begin index of the first array element to include + * @param length the number of elements to include + * @return the variance of the values or Double.NaN if length = 0 + * @throws MathIllegalArgumentException if the parameters are not valid + * @since 2.1 + */ + public double evaluate(final double[] values, final double[] weights, + final double mean, final int begin, final int length) + throws MathIllegalArgumentException { + + double var = Double.NaN; + + if (test(values, weights, begin, length)) { + if (length == 1) { + var = 0.0; + } else if (length > 1) { + double accum = 0.0; + double dev = 0.0; + double accum2 = 0.0; + for (int i = begin; i < begin + length; i++) { + dev = values[i] - mean; + accum += weights[i] * (dev * dev); + accum2 += weights[i] * dev; + } + + double sumWts = 0; + for (int i = begin; i < begin + length; i++) { + sumWts += weights[i]; + } + + if (isBiasCorrected) { + var = (accum - (accum2 * accum2 / 
sumWts)) / (sumWts - 1.0); + } else { + var = (accum - (accum2 * accum2 / sumWts)) / sumWts; + } + } + } + return var; + } + + /** + * <p>Returns the weighted variance of the values in the input array, using + * the precomputed weighted mean value.</p> + * <p> + * Uses the formula <pre> + * Σ(weights[i]*(values[i] - mean)<sup>2</sup>)/(Σ(weights[i]) - 1) + * </pre></p> + * <p> + * The formula used assumes that the supplied mean value is the weighted arithmetic + * mean of the sample data, not a known population parameter. This method + * is supplied only to save computation when the mean has already been + * computed.</p> + * <p> + * This formula will not return the same result as the unweighted variance when all + * weights are equal, unless all weights are equal to 1. The formula assumes that + * weights are to be treated as "expansion values," as will be the case if for example + * the weights represent frequency counts. To normalize weights so that the denominator + * in the variance computation equals the length of the input vector minus one, use <pre> + * <code>evaluate(values, MathArrays.normalizeArray(weights, values.length), mean); </code> + * </pre> + * <p> + * Returns 0 for a single-value (i.e. 
length = 1) sample.</p> + * <p> + * Throws <code>MathIllegalArgumentException</code> if any of the following are true: + * <ul><li>the values array is null</li> + * <li>the weights array is null</li> + * <li>the weights array does not have the same length as the values array</li> + * <li>the weights array contains one or more infinite values</li> + * <li>the weights array contains one or more NaN values</li> + * <li>the weights array contains negative values</li> + * </ul></p> + * <p> + * Does not change the internal state of the statistic.</p> + * + * @param values the input array + * @param weights the weights array + * @param mean the precomputed weighted mean value + * @return the variance of the values or Double.NaN if length = 0 + * @throws MathIllegalArgumentException if the parameters are not valid + * @since 2.1 + */ + public double evaluate(final double[] values, final double[] weights, final double mean) + throws MathIllegalArgumentException { + return evaluate(values, weights, mean, 0, values.length); + } + + /** + * @return Returns the isBiasCorrected. + */ + public boolean isBiasCorrected() { + return isBiasCorrected; + } + + /** + * @param biasCorrected The isBiasCorrected to set. + */ + public void setBiasCorrected(boolean biasCorrected) { + this.isBiasCorrected = biasCorrected; + } + + /** + * {@inheritDoc} + */ + @Override + public Variance copy() { + Variance result = new Variance(); + // No try-catch or advertised exception because parameters are guaranteed non-null + copy(this, result); + return result; + } + + /** + * Copies source to dest. 
+ * <p>Neither source nor dest can be null.</p> + * + * @param source Variance to copy + * @param dest Variance to copy to + * @throws NullArgumentException if either source or dest is null + */ + public static void copy(Variance source, Variance dest) + throws NullArgumentException { + MathUtils.checkNotNull(source); + MathUtils.checkNotNull(dest); + dest.setData(source.getDataRef()); + dest.moment = source.moment.copy(); + dest.isBiasCorrected = source.isBiasCorrected; + dest.incMoment = source.incMoment; + } +} diff --git a/src/main/java/org/apache/commons/math3/stat/descriptive/moment/VectorialCovariance.java b/src/main/java/org/apache/commons/math3/stat/descriptive/moment/VectorialCovariance.java new file mode 100644 index 0000000..7f6f903 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/stat/descriptive/moment/VectorialCovariance.java @@ -0,0 +1,157 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.commons.math3.stat.descriptive.moment; + +import java.io.Serializable; +import java.util.Arrays; + +import org.apache.commons.math3.exception.DimensionMismatchException; +import org.apache.commons.math3.linear.MatrixUtils; +import org.apache.commons.math3.linear.RealMatrix; + +/** + * Returns the covariance matrix of the available vectors. + * @since 1.2 + */ +public class VectorialCovariance implements Serializable { + + /** Serializable version identifier */ + private static final long serialVersionUID = 4118372414238930270L; + + /** Sums for each component. */ + private final double[] sums; + + /** Sums of products for each component. */ + private final double[] productsSums; + + /** Indicator for bias correction. */ + private final boolean isBiasCorrected; + + /** Number of vectors in the sample. */ + private long n; + + /** Constructs a VectorialCovariance. + * @param dimension vectors dimension + * @param isBiasCorrected if true, computed the unbiased sample covariance, + * otherwise computes the biased population covariance + */ + public VectorialCovariance(int dimension, boolean isBiasCorrected) { + sums = new double[dimension]; + productsSums = new double[dimension * (dimension + 1) / 2]; + n = 0; + this.isBiasCorrected = isBiasCorrected; + } + + /** + * Add a new vector to the sample. + * @param v vector to add + * @throws DimensionMismatchException if the vector does not have the right dimension + */ + public void increment(double[] v) throws DimensionMismatchException { + if (v.length != sums.length) { + throw new DimensionMismatchException(v.length, sums.length); + } + int k = 0; + for (int i = 0; i < v.length; ++i) { + sums[i] += v[i]; + for (int j = 0; j <= i; ++j) { + productsSums[k++] += v[i] * v[j]; + } + } + n++; + } + + /** + * Get the covariance matrix. 
+ * @return covariance matrix + */ + public RealMatrix getResult() { + + int dimension = sums.length; + RealMatrix result = MatrixUtils.createRealMatrix(dimension, dimension); + + if (n > 1) { + double c = 1.0 / (n * (isBiasCorrected ? (n - 1) : n)); + int k = 0; + for (int i = 0; i < dimension; ++i) { + for (int j = 0; j <= i; ++j) { + double e = c * (n * productsSums[k++] - sums[i] * sums[j]); + result.setEntry(i, j, e); + result.setEntry(j, i, e); + } + } + } + + return result; + + } + + /** + * Get the number of vectors in the sample. + * @return number of vectors in the sample + */ + public long getN() { + return n; + } + + /** + * Clears the internal state of the Statistic + */ + public void clear() { + n = 0; + Arrays.fill(sums, 0.0); + Arrays.fill(productsSums, 0.0); + } + + /** {@inheritDoc} */ + @Override + public int hashCode() { + final int prime = 31; + int result = 1; + result = prime * result + (isBiasCorrected ? 1231 : 1237); + result = prime * result + (int) (n ^ (n >>> 32)); + result = prime * result + Arrays.hashCode(productsSums); + result = prime * result + Arrays.hashCode(sums); + return result; + } + + /** {@inheritDoc} */ + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (!(obj instanceof VectorialCovariance)) { + return false; + } + VectorialCovariance other = (VectorialCovariance) obj; + if (isBiasCorrected != other.isBiasCorrected) { + return false; + } + if (n != other.n) { + return false; + } + if (!Arrays.equals(productsSums, other.productsSums)) { + return false; + } + if (!Arrays.equals(sums, other.sums)) { + return false; + } + return true; + } + +} diff --git a/src/main/java/org/apache/commons/math3/stat/descriptive/moment/VectorialMean.java b/src/main/java/org/apache/commons/math3/stat/descriptive/moment/VectorialMean.java new file mode 100644 index 0000000..e06b3bc --- /dev/null +++ b/src/main/java/org/apache/commons/math3/stat/descriptive/moment/VectorialMean.java @@ -0,0 +1,105 @@ 
+/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.math3.stat.descriptive.moment; + +import java.io.Serializable; +import java.util.Arrays; + +import org.apache.commons.math3.exception.DimensionMismatchException; + +/** + * Returns the arithmetic mean of the available vectors. + * @since 1.2 + */ +public class VectorialMean implements Serializable { + + /** Serializable version identifier */ + private static final long serialVersionUID = 8223009086481006892L; + + /** Means for each component. */ + private final Mean[] means; + + /** Constructs a VectorialMean. + * @param dimension vectors dimension + */ + public VectorialMean(int dimension) { + means = new Mean[dimension]; + for (int i = 0; i < dimension; ++i) { + means[i] = new Mean(); + } + } + + /** + * Add a new vector to the sample. + * @param v vector to add + * @throws DimensionMismatchException if the vector does not have the right dimension + */ + public void increment(double[] v) throws DimensionMismatchException { + if (v.length != means.length) { + throw new DimensionMismatchException(v.length, means.length); + } + for (int i = 0; i < v.length; ++i) { + means[i].increment(v[i]); + } + } + + /** + * Get the mean vector. 
+ * @return mean vector + */ + public double[] getResult() { + double[] result = new double[means.length]; + for (int i = 0; i < result.length; ++i) { + result[i] = means[i].getResult(); + } + return result; + } + + /** + * Get the number of vectors in the sample. + * @return number of vectors in the sample + */ + public long getN() { + return (means.length == 0) ? 0 : means[0].getN(); + } + + /** {@inheritDoc} */ + @Override + public int hashCode() { + final int prime = 31; + int result = 1; + result = prime * result + Arrays.hashCode(means); + return result; + } + + /** {@inheritDoc} */ + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (!(obj instanceof VectorialMean)) { + return false; + } + VectorialMean other = (VectorialMean) obj; + if (!Arrays.equals(means, other.means)) { + return false; + } + return true; + } + +} diff --git a/src/main/java/org/apache/commons/math3/stat/descriptive/moment/package-info.java b/src/main/java/org/apache/commons/math3/stat/descriptive/moment/package-info.java new file mode 100644 index 0000000..e23ead7 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/stat/descriptive/moment/package-info.java @@ -0,0 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +/** + * Summary statistics based on moments. + */ +package org.apache.commons.math3.stat.descriptive.moment; diff --git a/src/main/java/org/apache/commons/math3/stat/descriptive/package-info.java b/src/main/java/org/apache/commons/math3/stat/descriptive/package-info.java new file mode 100644 index 0000000..92fa5b3 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/stat/descriptive/package-info.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/** + * + * Generic univariate summary statistic objects. 
 *
 * <h3>UnivariateStatistic API Usage Examples:</h3>
 *
 * <h4>UnivariateStatistic:</h4>
 * <code>/&lowast; evaluation approach &lowast;/<br/>
 * double[] values = new double[] { 1, 2, 3, 4, 5 };<br/>
 * <span style="font-weight: bold;">UnivariateStatistic stat = new Mean();</span><br/>
 * out.println("mean = " + <span style="font-weight: bold;">stat.evaluate(values)</span>);<br/>
 * </code>
 *
 * <h4>StorelessUnivariateStatistic:</h4>
 * <code>/&lowast; incremental approach &lowast;/<br/>
 * double[] values = new double[] { 1, 2, 3, 4, 5 };<br/>
 * <span style="font-weight: bold;">StorelessUnivariateStatistic stat = new Mean();</span><br/>
 * out.println("mean before adding a value is NaN = " + <span style="font-weight: bold;">stat.getResult()</span>);<br/>
 * for (int i = 0; i &lt; values.length; i++) {<br/>
 *     <span style="font-weight: bold;">stat.increment(values[i]);</span><br/>
 *     out.println("current mean = " + <span style="font-weight: bold;">stat.getResult()</span>);<br/>
 * }<br/>
 * <span style="font-weight: bold;"> stat.clear();</span><br/>
 * out.println("mean after clear is NaN = " + <span style="font-weight: bold;">stat.getResult()</span>);
 * </code>
 *
 */
package org.apache.commons.math3.stat.descriptive;
diff --git a/src/main/java/org/apache/commons/math3/stat/descriptive/rank/Max.java b/src/main/java/org/apache/commons/math3/stat/descriptive/rank/Max.java
new file mode 100644
index 0000000..75f145f
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/stat/descriptive/rank/Max.java
@@ -0,0 +1,171 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.math3.stat.descriptive.rank; + +import java.io.Serializable; + +import org.apache.commons.math3.exception.MathIllegalArgumentException; +import org.apache.commons.math3.exception.NullArgumentException; +import org.apache.commons.math3.stat.descriptive.AbstractStorelessUnivariateStatistic; +import org.apache.commons.math3.util.MathUtils; + +/** + * Returns the maximum of the available values. + * <p> + * <ul> + * <li>The result is <code>NaN</code> iff all values are <code>NaN</code> + * (i.e. <code>NaN</code> values have no impact on the value of the statistic).</li> + * <li>If any of the values equals <code>Double.POSITIVE_INFINITY</code>, + * the result is <code>Double.POSITIVE_INFINITY.</code></li> + * </ul></p> +* <p> + * <strong>Note that this implementation is not synchronized.</strong> If + * multiple threads access an instance of this class concurrently, and at least + * one of the threads invokes the <code>increment()</code> or + * <code>clear()</code> method, it must be synchronized externally.</p> + * + */ +public class Max extends AbstractStorelessUnivariateStatistic implements Serializable { + + /** Serializable version identifier */ + private static final long serialVersionUID = -5593383832225844641L; + + /** Number of values that have been added */ + private long n; + + /** Current value of the statistic */ + private double value; + + /** + * Create a Max instance + */ + public Max() { + n = 0; + value = Double.NaN; + } + + /** + * Copy constructor, creates a new {@code Max} identical + * to the {@code 
original} + * + * @param original the {@code Max} instance to copy + * @throws NullArgumentException if original is null + */ + public Max(Max original) throws NullArgumentException { + copy(original, this); + } + + /** + * {@inheritDoc} + */ + @Override + public void increment(final double d) { + if (d > value || Double.isNaN(value)) { + value = d; + } + n++; + } + + /** + * {@inheritDoc} + */ + @Override + public void clear() { + value = Double.NaN; + n = 0; + } + + /** + * {@inheritDoc} + */ + @Override + public double getResult() { + return value; + } + + /** + * {@inheritDoc} + */ + public long getN() { + return n; + } + + /** + * Returns the maximum of the entries in the specified portion of + * the input array, or <code>Double.NaN</code> if the designated subarray + * is empty. + * <p> + * Throws <code>MathIllegalArgumentException</code> if the array is null or + * the array index parameters are not valid.</p> + * <p> + * <ul> + * <li>The result is <code>NaN</code> iff all values are <code>NaN</code> + * (i.e. <code>NaN</code> values have no impact on the value of the statistic).</li> + * <li>If any of the values equals <code>Double.POSITIVE_INFINITY</code>, + * the result is <code>Double.POSITIVE_INFINITY.</code></li> + * </ul></p> + * + * @param values the input array + * @param begin index of the first array element to include + * @param length the number of elements to include + * @return the maximum of the values or Double.NaN if length = 0 + * @throws MathIllegalArgumentException if the array is null or the array index + * parameters are not valid + */ + @Override + public double evaluate(final double[] values, final int begin, final int length) + throws MathIllegalArgumentException { + double max = Double.NaN; + if (test(values, begin, length)) { + max = values[begin]; + for (int i = begin; i < begin + length; i++) { + if (!Double.isNaN(values[i])) { + max = (max > values[i]) ? 
max : values[i]; + } + } + } + return max; + } + + /** + * {@inheritDoc} + */ + @Override + public Max copy() { + Max result = new Max(); + // No try-catch or advertised exception because args are non-null + copy(this, result); + return result; + } + + /** + * Copies source to dest. + * <p>Neither source nor dest can be null.</p> + * + * @param source Max to copy + * @param dest Max to copy to + * @throws NullArgumentException if either source or dest is null + */ + public static void copy(Max source, Max dest) + throws NullArgumentException { + MathUtils.checkNotNull(source); + MathUtils.checkNotNull(dest); + dest.setData(source.getDataRef()); + dest.n = source.n; + dest.value = source.value; + } +} diff --git a/src/main/java/org/apache/commons/math3/stat/descriptive/rank/Median.java b/src/main/java/org/apache/commons/math3/stat/descriptive/rank/Median.java new file mode 100644 index 0000000..6350a0b --- /dev/null +++ b/src/main/java/org/apache/commons/math3/stat/descriptive/rank/Median.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.commons.math3.stat.descriptive.rank; + +import java.io.Serializable; + +import org.apache.commons.math3.exception.MathIllegalArgumentException; +import org.apache.commons.math3.exception.NullArgumentException; +import org.apache.commons.math3.stat.ranking.NaNStrategy; +import org.apache.commons.math3.util.KthSelector; + + +/** + * Returns the median of the available values. This is the same as the 50th percentile. + * See {@link Percentile} for a description of the algorithm used. + * <p> + * <strong>Note that this implementation is not synchronized.</strong> If + * multiple threads access an instance of this class concurrently, and at least + * one of the threads invokes the <code>increment()</code> or + * <code>clear()</code> method, it must be synchronized externally.</p> + * + */ +public class Median extends Percentile implements Serializable { + + /** Serializable version identifier */ + private static final long serialVersionUID = -3961477041290915687L; + + /** Fixed quantile. */ + private static final double FIXED_QUANTILE_50 = 50.0; + + /** + * Default constructor. + */ + public Median() { + // No try-catch or advertised exception - arg is valid + super(FIXED_QUANTILE_50); + } + + /** + * Copy constructor, creates a new {@code Median} identical + * to the {@code original} + * + * @param original the {@code Median} instance to copy + * @throws NullArgumentException if original is null + */ + public Median(Median original) throws NullArgumentException { + super(original); + } + + /** + * Constructs a Median with the specific {@link EstimationType}, {@link NaNStrategy} and {@link PivotingStrategy}. 
+ * + * @param estimationType one of the percentile {@link EstimationType estimation types} + * @param nanStrategy one of {@link NaNStrategy} to handle with NaNs + * @param kthSelector {@link KthSelector} to use for pivoting during search + * @throws MathIllegalArgumentException if p is not within (0,100] + * @throws NullArgumentException if type or NaNStrategy passed is null + */ + private Median(final EstimationType estimationType, final NaNStrategy nanStrategy, + final KthSelector kthSelector) + throws MathIllegalArgumentException { + super(FIXED_QUANTILE_50, estimationType, nanStrategy, kthSelector); + } + + /** {@inheritDoc} */ + @Override + public Median withEstimationType(final EstimationType newEstimationType) { + return new Median(newEstimationType, getNaNStrategy(), getKthSelector()); + } + + /** {@inheritDoc} */ + @Override + public Median withNaNStrategy(final NaNStrategy newNaNStrategy) { + return new Median(getEstimationType(), newNaNStrategy, getKthSelector()); + } + + /** {@inheritDoc} */ + @Override + public Median withKthSelector(final KthSelector newKthSelector) { + return new Median(getEstimationType(), getNaNStrategy(), newKthSelector); + } + +} diff --git a/src/main/java/org/apache/commons/math3/stat/descriptive/rank/Min.java b/src/main/java/org/apache/commons/math3/stat/descriptive/rank/Min.java new file mode 100644 index 0000000..c87e6f1 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/stat/descriptive/rank/Min.java @@ -0,0 +1,171 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.math3.stat.descriptive.rank; + +import java.io.Serializable; + +import org.apache.commons.math3.exception.MathIllegalArgumentException; +import org.apache.commons.math3.exception.NullArgumentException; +import org.apache.commons.math3.stat.descriptive.AbstractStorelessUnivariateStatistic; +import org.apache.commons.math3.util.MathUtils; + +/** + * Returns the minimum of the available values. + * <p> + * <ul> + * <li>The result is <code>NaN</code> iff all values are <code>NaN</code> + * (i.e. <code>NaN</code> values have no impact on the value of the statistic).</li> + * <li>If any of the values equals <code>Double.NEGATIVE_INFINITY</code>, + * the result is <code>Double.NEGATIVE_INFINITY.</code></li> + * </ul></p> + * <p> + * <strong>Note that this implementation is not synchronized.</strong> If + * multiple threads access an instance of this class concurrently, and at least + * one of the threads invokes the <code>increment()</code> or + * <code>clear()</code> method, it must be synchronized externally.</p> + * + */ +public class Min extends AbstractStorelessUnivariateStatistic implements Serializable { + + /** Serializable version identifier */ + private static final long serialVersionUID = -2941995784909003131L; + + /**Number of values that have been added */ + private long n; + + /**Current value of the statistic */ + private double value; + + /** + * Create a Min instance + */ + public Min() { + n = 0; + value = Double.NaN; + } + + /** + * Copy constructor, creates a new {@code Min} identical + * to the {@code 
original} + * + * @param original the {@code Min} instance to copy + * @throws NullArgumentException if original is null + */ + public Min(Min original) throws NullArgumentException { + copy(original, this); + } + + /** + * {@inheritDoc} + */ + @Override + public void increment(final double d) { + if (d < value || Double.isNaN(value)) { + value = d; + } + n++; + } + + /** + * {@inheritDoc} + */ + @Override + public void clear() { + value = Double.NaN; + n = 0; + } + + /** + * {@inheritDoc} + */ + @Override + public double getResult() { + return value; + } + + /** + * {@inheritDoc} + */ + public long getN() { + return n; + } + + /** + * Returns the minimum of the entries in the specified portion of + * the input array, or <code>Double.NaN</code> if the designated subarray + * is empty. + * <p> + * Throws <code>MathIllegalArgumentException</code> if the array is null or + * the array index parameters are not valid.</p> + * <p> + * <ul> + * <li>The result is <code>NaN</code> iff all values are <code>NaN</code> + * (i.e. <code>NaN</code> values have no impact on the value of the statistic).</li> + * <li>If any of the values equals <code>Double.NEGATIVE_INFINITY</code>, + * the result is <code>Double.NEGATIVE_INFINITY.</code></li> + * </ul> </p> + * + * @param values the input array + * @param begin index of the first array element to include + * @param length the number of elements to include + * @return the minimum of the values or Double.NaN if length = 0 + * @throws MathIllegalArgumentException if the array is null or the array index + * parameters are not valid + */ + @Override + public double evaluate(final double[] values,final int begin, final int length) + throws MathIllegalArgumentException { + double min = Double.NaN; + if (test(values, begin, length)) { + min = values[begin]; + for (int i = begin; i < begin + length; i++) { + if (!Double.isNaN(values[i])) { + min = (min < values[i]) ? 
min : values[i]; + } + } + } + return min; + } + + /** + * {@inheritDoc} + */ + @Override + public Min copy() { + Min result = new Min(); + // No try-catch or advertised exception - args are non-null + copy(this, result); + return result; + } + + /** + * Copies source to dest. + * <p>Neither source nor dest can be null.</p> + * + * @param source Min to copy + * @param dest Min to copy to + * @throws NullArgumentException if either source or dest is null + */ + public static void copy(Min source, Min dest) + throws NullArgumentException { + MathUtils.checkNotNull(source); + MathUtils.checkNotNull(dest); + dest.setData(source.getDataRef()); + dest.n = source.n; + dest.value = source.value; + } +} diff --git a/src/main/java/org/apache/commons/math3/stat/descriptive/rank/PSquarePercentile.java b/src/main/java/org/apache/commons/math3/stat/descriptive/rank/PSquarePercentile.java new file mode 100644 index 0000000..b8bc274 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/stat/descriptive/rank/PSquarePercentile.java @@ -0,0 +1,997 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.commons.math3.stat.descriptive.rank; + +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.Serializable; +import java.text.DecimalFormat; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.List; + +import org.apache.commons.math3.analysis.UnivariateFunction; +import org.apache.commons.math3.analysis.interpolation.LinearInterpolator; +import org.apache.commons.math3.analysis.interpolation.NevilleInterpolator; +import org.apache.commons.math3.analysis.interpolation.UnivariateInterpolator; +import org.apache.commons.math3.exception.InsufficientDataException; +import org.apache.commons.math3.exception.OutOfRangeException; +import org.apache.commons.math3.exception.util.LocalizedFormats; +import org.apache.commons.math3.stat.descriptive.AbstractStorelessUnivariateStatistic; +import org.apache.commons.math3.stat.descriptive.StorelessUnivariateStatistic; +import org.apache.commons.math3.util.MathArrays; +import org.apache.commons.math3.util.MathUtils; +import org.apache.commons.math3.util.Precision; + +/** + * A {@link StorelessUnivariateStatistic} estimating percentiles using the + * <ahref=http://www.cs.wustl.edu/~jain/papers/ftp/psqr.pdf>P<SUP>2</SUP></a> + * Algorithm as explained by <a href=http://www.cse.wustl.edu/~jain/>Raj + * Jain</a> and Imrich Chlamtac in + * <a href=http://www.cse.wustl.edu/~jain/papers/psqr.htm>P<SUP>2</SUP> Algorithm + * for Dynamic Calculation of Quantiles and Histogram Without Storing + * Observations</a>. + * <p> + * Note: This implementation is not synchronized and produces an approximate + * result. 
For small samples, where data can be stored and processed in memory, + * {@link Percentile} should be used.</p> + * + */ +public class PSquarePercentile extends AbstractStorelessUnivariateStatistic + implements StorelessUnivariateStatistic, Serializable { + + /** + * The maximum array size used for psquare algorithm + */ + private static final int PSQUARE_CONSTANT = 5; + + /** + * A Default quantile needed in case if user prefers to use default no + * argument constructor. + */ + private static final double DEFAULT_QUANTILE_DESIRED = 50d; + + /** + * Serial ID + */ + private static final long serialVersionUID = 2283912083175715479L; + + /** + * A decimal formatter for print convenience + */ + private static final DecimalFormat DECIMAL_FORMAT = new DecimalFormat( + "00.00"); + + /** + * Initial list of 5 numbers corresponding to 5 markers. <b>NOTE:</b>watch + * out for the add methods that are overloaded + */ + private final List<Double> initialFive = new FixedCapacityList<Double>( + PSQUARE_CONSTANT); + + /** + * The quantile needed should be in range of 0-1. The constructor + * {@link #PSquarePercentile(double)} ensures that passed in percentile is + * divided by 100. + */ + private final double quantile; + + /** + * lastObservation is the last observation value/input sample. No need to + * serialize + */ + private transient double lastObservation; + + /** + * Markers is the marker collection object which comes to effect + * only after 5 values are inserted + */ + private PSquareMarkers markers = null; + + /** + * Computed p value (i,e percentile value of data set hither to received) + */ + private double pValue = Double.NaN; + + /** + * Counter to count the values/observations accepted into this data set + */ + private long countOfObservations; + + /** + * Constructs a PSquarePercentile with the specific percentile value. 
+ * @param p the percentile + * @throws OutOfRangeException if p is not greater than 0 and less + * than or equal to 100 + */ + public PSquarePercentile(final double p) { + if (p > 100 || p < 0) { + throw new OutOfRangeException(LocalizedFormats.OUT_OF_RANGE, + p, 0, 100); + } + this.quantile = p / 100d;// always set it within (0,1] + } + + /** + * Default constructor that assumes a {@link #DEFAULT_QUANTILE_DESIRED + * default quantile} needed + */ + PSquarePercentile() { + this(DEFAULT_QUANTILE_DESIRED); + } + + /** + * {@inheritDoc} + */ + @Override + public int hashCode() { + double result = getResult(); + result = Double.isNaN(result) ? 37 : result; + final double markersHash = markers == null ? 0 : markers.hashCode(); + final double[] toHash = {result, quantile, markersHash, countOfObservations}; + return Arrays.hashCode(toHash); + } + + /** + * Returns true iff {@code o} is a {@code PSquarePercentile} returning the + * same values as this for {@code getResult()} and {@code getN()} and also + * having equal markers + * + * @param o object to compare + * @return true if {@code o} is a {@code PSquarePercentile} with + * equivalent internal state + */ + @Override + public boolean equals(Object o) { + boolean result = false; + if (this == o) { + result = true; + } else if (o != null && o instanceof PSquarePercentile) { + PSquarePercentile that = (PSquarePercentile) o; + boolean isNotNull = markers != null && that.markers != null; + boolean isNull = markers == null && that.markers == null; + result = isNotNull ? markers.equals(that.markers) : isNull; + // markers as in the case of first + // five observations + result = result && getN() == that.getN(); + } + return result; + } + + /** + * {@inheritDoc}The internal state updated due to the new value in this + * context is basically of the marker positions and computation of the + * approximate quantile. + * + * @param observation the observation currently being added. 
+ */ + @Override + public void increment(final double observation) { + // Increment counter + countOfObservations++; + + // Store last observation + this.lastObservation = observation; + + // 0. Use Brute force for <5 + if (markers == null) { + if (initialFive.add(observation)) { + Collections.sort(initialFive); + pValue = + initialFive + .get((int) (quantile * (initialFive.size() - 1))); + return; + } + // 1. Initialize once after 5th observation + markers = newMarkers(initialFive, quantile); + } + // 2. process a Data Point and return pValue + pValue = markers.processDataPoint(observation); + } + + /** + * Returns a string containing the last observation, the current estimate + * of the quantile and all markers. + * + * @return string representation of state data + */ + @Override + public String toString() { + + if (markers == null) { + return String.format("obs=%s pValue=%s", + DECIMAL_FORMAT.format(lastObservation), + DECIMAL_FORMAT.format(pValue)); + } else { + return String.format("obs=%s markers=%s", + DECIMAL_FORMAT.format(lastObservation), markers.toString()); + } + } + + /** + * {@inheritDoc} + */ + public long getN() { + return countOfObservations; + } + + /** + * {@inheritDoc} + */ + @Override + public StorelessUnivariateStatistic copy() { + // multiply quantile by 100 now as anyway constructor divides it by 100 + PSquarePercentile copy = new PSquarePercentile(100d * quantile); + + if (markers != null) { + copy.markers = (PSquareMarkers) markers.clone(); + } + copy.countOfObservations = countOfObservations; + copy.pValue = pValue; + copy.initialFive.clear(); + copy.initialFive.addAll(initialFive); + return copy; + } + + /** + * Returns the quantile estimated by this statistic in the range [0.0-1.0] + * + * @return quantile estimated by {@link #getResult()} + */ + public double quantile() { + return quantile; + } + + /** + * {@inheritDoc}. This basically clears all the markers, the + * initialFive list and sets countOfObservations to 0. 
+ */ + @Override + public void clear() { + markers = null; + initialFive.clear(); + countOfObservations = 0L; + pValue = Double.NaN; + } + + /** + * {@inheritDoc} + */ + @Override + public double getResult() { + if (Double.compare(quantile, 1d) == 0) { + pValue = maximum(); + } else if (Double.compare(quantile, 0d) == 0) { + pValue = minimum(); + } + return pValue; + } + + /** + * @return maximum in the data set added to this statistic + */ + private double maximum() { + double val = Double.NaN; + if (markers != null) { + val = markers.height(PSQUARE_CONSTANT); + } else if (!initialFive.isEmpty()) { + val = initialFive.get(initialFive.size() - 1); + } + return val; + } + + /** + * @return minimum in the data set added to this statistic + */ + private double minimum() { + double val = Double.NaN; + if (markers != null) { + val = markers.height(1); + } else if (!initialFive.isEmpty()) { + val = initialFive.get(0); + } + return val; + } + + /** + * Markers is an encapsulation of the five markers/buckets as indicated in + * the original works. + */ + private static class Markers implements PSquareMarkers, Serializable { + /** + * Serial version id + */ + private static final long serialVersionUID = 1L; + + /** Low marker index */ + private static final int LOW = 2; + + /** High marker index */ + private static final int HIGH = 4; + + /** + * Array of 5+1 Markers (The first marker is dummy just so we + * can match the rest of indexes [1-5] indicated in the original works + * which follows unit based index) + */ + private final Marker[] markerArray; + + /** + * Kth cell belonging to [1-5] of the markerArray. 
No need for + * this to be serialized + */ + private transient int k = -1; + + /** + * Constructor + * + * @param theMarkerArray marker array to be used + */ + private Markers(final Marker[] theMarkerArray) { + MathUtils.checkNotNull(theMarkerArray); + markerArray = theMarkerArray; + for (int i = 1; i < PSQUARE_CONSTANT; i++) { + markerArray[i].previous(markerArray[i - 1]) + .next(markerArray[i + 1]).index(i); + } + markerArray[0].previous(markerArray[0]).next(markerArray[1]) + .index(0); + markerArray[5].previous(markerArray[4]).next(markerArray[5]) + .index(5); + } + + /** + * Constructor + * + * @param initialFive elements required to build Marker + * @param p quantile required to be computed + */ + private Markers(final List<Double> initialFive, final double p) { + this(createMarkerArray(initialFive, p)); + } + + /** + * Creates a marker array using initial five elements and a quantile + * + * @param initialFive list of initial five elements + * @param p the pth quantile + * @return Marker array + */ + private static Marker[] createMarkerArray( + final List<Double> initialFive, final double p) { + final int countObserved = + initialFive == null ? -1 : initialFive.size(); + if (countObserved < PSQUARE_CONSTANT) { + throw new InsufficientDataException( + LocalizedFormats.INSUFFICIENT_OBSERVED_POINTS_IN_SAMPLE, + countObserved, PSQUARE_CONSTANT); + } + Collections.sort(initialFive); + return new Marker[] { + new Marker(),// Null Marker + new Marker(initialFive.get(0), 1, 0, 1), + new Marker(initialFive.get(1), 1 + 2 * p, p / 2, 2), + new Marker(initialFive.get(2), 1 + 4 * p, p, 3), + new Marker(initialFive.get(3), 3 + 2 * p, (1 + p) / 2, 4), + new Marker(initialFive.get(4), 5, 1, 5) }; + } + + /** + * {@inheritDoc} + */ + @Override + public int hashCode() { + return Arrays.deepHashCode(markerArray); + } + + /** + * {@inheritDoc}.This equals method basically checks for marker array to + * be deep equals. 
+ * + * @param o is the other object + * @return true if the object compares with this object are equivalent + */ + @Override + public boolean equals(Object o) { + boolean result = false; + if (this == o) { + result = true; + } else if (o != null && o instanceof Markers) { + Markers that = (Markers) o; + result = Arrays.deepEquals(markerArray, that.markerArray); + } + return result; + } + + /** + * Process a data point + * + * @param inputDataPoint is the data point passed + * @return computed percentile + */ + public double processDataPoint(final double inputDataPoint) { + + // 1. Find cell and update minima and maxima + final int kthCell = findCellAndUpdateMinMax(inputDataPoint); + + // 2. Increment positions + incrementPositions(1, kthCell + 1, 5); + + // 2a. Update desired position with increments + updateDesiredPositions(); + + // 3. Adjust heights of m[2-4] if necessary + adjustHeightsOfMarkers(); + + // 4. Return percentile + return getPercentileValue(); + } + + /** + * Returns the percentile computed thus far. + * + * @return height of mid point marker + */ + public double getPercentileValue() { + return height(3); + } + + /** + * Finds the cell where the input observation / value fits. + * + * @param observation the input value to be checked for + * @return kth cell (of the markers ranging from 1-5) where observed + * sample fits + */ + private int findCellAndUpdateMinMax(final double observation) { + k = -1; + if (observation < height(1)) { + markerArray[1].markerHeight = observation; + k = 1; + } else if (observation < height(2)) { + k = 1; + } else if (observation < height(3)) { + k = 2; + } else if (observation < height(4)) { + k = 3; + } else if (observation <= height(5)) { + k = 4; + } else { + markerArray[5].markerHeight = observation; + k = 4; + } + return k; + } + + /** + * Adjust marker heights by setting quantile estimates to middle markers. 
+ */ + private void adjustHeightsOfMarkers() { + for (int i = LOW; i <= HIGH; i++) { + estimate(i); + } + } + + /** + * {@inheritDoc} + */ + public double estimate(final int index) { + if (index < LOW || index > HIGH) { + throw new OutOfRangeException(index, LOW, HIGH); + } + return markerArray[index].estimate(); + } + + /** + * Increment positions by d. Refer to algorithm paper for the + * definition of d. + * + * @param d The increment value for the position + * @param startIndex start index of the marker array + * @param endIndex end index of the marker array + */ + private void incrementPositions(final int d, final int startIndex, + final int endIndex) { + for (int i = startIndex; i <= endIndex; i++) { + markerArray[i].incrementPosition(d); + } + } + + /** + * Desired positions incremented by bucket width. The bucket width is + * basically the desired increments. + */ + private void updateDesiredPositions() { + for (int i = 1; i < markerArray.length; i++) { + markerArray[i].updateDesiredPosition(); + } + } + + /** + * Sets previous and next markers after default read is done. 
+ * + * @param anInputStream the input stream to be deserialized + * @throws ClassNotFoundException thrown when a desired class not found + * @throws IOException thrown due to any io errors + */ + private void readObject(ObjectInputStream anInputStream) + throws ClassNotFoundException, IOException { + // always perform the default de-serialization first + anInputStream.defaultReadObject(); + // Build links + for (int i = 1; i < PSQUARE_CONSTANT; i++) { + markerArray[i].previous(markerArray[i - 1]) + .next(markerArray[i + 1]).index(i); + } + markerArray[0].previous(markerArray[0]).next(markerArray[1]) + .index(0); + markerArray[5].previous(markerArray[4]).next(markerArray[5]) + .index(5); + } + + /** + * Return marker height given index + * + * @param markerIndex index of marker within (1,6) + * @return marker height + */ + public double height(final int markerIndex) { + if (markerIndex >= markerArray.length || markerIndex <= 0) { + throw new OutOfRangeException(markerIndex, 1, + markerArray.length); + } + return markerArray[markerIndex].markerHeight; + } + + /** + * {@inheritDoc}.Clone Markers + * + * @return cloned object + */ + @Override + public Object clone() { + return new Markers(new Marker[] { new Marker(), + (Marker) markerArray[1].clone(), + (Marker) markerArray[2].clone(), + (Marker) markerArray[3].clone(), + (Marker) markerArray[4].clone(), + (Marker) markerArray[5].clone() }); + + } + + /** + * Returns string representation of the Marker array. 
+ * + * @return Markers as a string + */ + @Override + public String toString() { + return String.format("m1=[%s],m2=[%s],m3=[%s],m4=[%s],m5=[%s]", + markerArray[1].toString(), markerArray[2].toString(), + markerArray[3].toString(), markerArray[4].toString(), + markerArray[5].toString()); + } + + } + + /** + * The class modeling the attributes of the marker of the P-square algorithm + */ + private static class Marker implements Serializable, Cloneable { + + /** + * Serial Version ID + */ + private static final long serialVersionUID = -3575879478288538431L; + + /** + * The marker index which is just a serial number for the marker in the + * marker array of 5+1. + */ + private int index; + + /** + * The integral marker position. Refer to the variable n in the original + * works. + */ + private double intMarkerPosition; + + /** + * Desired marker position. Refer to the variable n' in the original + * works. + */ + private double desiredMarkerPosition; + + /** + * Marker height or the quantile. Refer to the variable q in the + * original works. + */ + private double markerHeight; + + /** + * Desired marker increment. Refer to the variable dn' in the original + * works. + */ + private double desiredMarkerIncrement; + + /** + * Next and previous markers for easy linked navigation in loops. this + * is not serialized as they can be rebuilt during deserialization. 
+ */ + private transient Marker next; + + /** + * The previous marker links + */ + private transient Marker previous; + + /** + * Nonlinear interpolator + */ + private final UnivariateInterpolator nonLinear = + new NevilleInterpolator(); + + /** + * Linear interpolator which is not serializable + */ + private transient UnivariateInterpolator linear = + new LinearInterpolator(); + + /** + * Default constructor + */ + private Marker() { + this.next = this.previous = this; + } + + /** + * Constructor of the marker with parameters + * + * @param heightOfMarker represent the quantile value + * @param makerPositionDesired represent the desired marker position + * @param markerPositionIncrement represent increments for position + * @param markerPositionNumber represent the position number of marker + */ + private Marker(double heightOfMarker, double makerPositionDesired, + double markerPositionIncrement, double markerPositionNumber) { + this(); + this.markerHeight = heightOfMarker; + this.desiredMarkerPosition = makerPositionDesired; + this.desiredMarkerIncrement = markerPositionIncrement; + this.intMarkerPosition = markerPositionNumber; + } + + /** + * Sets the previous marker. + * + * @param previousMarker the previous marker to the current marker in + * the array of markers + * @return this instance + */ + private Marker previous(final Marker previousMarker) { + MathUtils.checkNotNull(previousMarker); + this.previous = previousMarker; + return this; + } + + /** + * Sets the next marker. + * + * @param nextMarker the next marker to the current marker in the array + * of markers + * @return this instance + */ + private Marker next(final Marker nextMarker) { + MathUtils.checkNotNull(nextMarker); + this.next = nextMarker; + return this; + } + + /** + * Sets the index of the marker. 
+ * + * @param indexOfMarker the array index of the marker in marker array + * @return this instance + */ + private Marker index(final int indexOfMarker) { + this.index = indexOfMarker; + return this; + } + + /** + * Update desired Position with increment. + */ + private void updateDesiredPosition() { + desiredMarkerPosition += desiredMarkerIncrement; + } + + /** + * Increment Position by d. + * + * @param d a delta value to increment + */ + private void incrementPosition(final int d) { + intMarkerPosition += d; + } + + /** + * Difference between desired and actual position + * + * @return difference between desired and actual position + */ + private double difference() { + return desiredMarkerPosition - intMarkerPosition; + } + + /** + * Estimate the quantile for the current marker. + * + * @return estimated quantile + */ + private double estimate() { + final double di = difference(); + final boolean isNextHigher = + next.intMarkerPosition - intMarkerPosition > 1; + final boolean isPreviousLower = + previous.intMarkerPosition - intMarkerPosition < -1; + + if (di >= 1 && isNextHigher || di <= -1 && isPreviousLower) { + final int d = di >= 0 ? 1 : -1; + final double[] xval = + new double[] { previous.intMarkerPosition, + intMarkerPosition, next.intMarkerPosition }; + final double[] yval = + new double[] { previous.markerHeight, markerHeight, + next.markerHeight }; + final double xD = intMarkerPosition + d; + + UnivariateFunction univariateFunction = + nonLinear.interpolate(xval, yval); + markerHeight = univariateFunction.value(xD); + + // If parabolic estimate is bad then turn linear + if (isEstimateBad(yval, markerHeight)) { + int delta = xD - xval[1] > 0 ? 
1 : -1; + final double[] xBad = + new double[] { xval[1], xval[1 + delta] }; + final double[] yBad = + new double[] { yval[1], yval[1 + delta] }; + MathArrays.sortInPlace(xBad, yBad);// since d can be +/- 1 + univariateFunction = linear.interpolate(xBad, yBad); + markerHeight = univariateFunction.value(xD); + } + incrementPosition(d); + } + return markerHeight; + } + + /** + * Check if parabolic/nonlinear estimate is bad by checking if the + * ordinate found is beyond the y[0] and y[2]. + * + * @param y the array to get the bounds + * @param yD the estimate + * @return true if yD is a bad estimate + */ + private boolean isEstimateBad(final double[] y, final double yD) { + return yD <= y[0] || yD >= y[2]; + } + + /** + * {@inheritDoc}<i>This equals method checks for marker attributes and + * as well checks if navigation pointers (next and previous) are the same + * between this and passed in object</i> + * + * @param o Other object + * @return true if this equals passed in other object o + */ + @Override + public boolean equals(Object o) { + boolean result = false; + if (this == o) { + result = true; + } else if (o != null && o instanceof Marker) { + Marker that = (Marker) o; + + result = Double.compare(markerHeight, that.markerHeight) == 0; + result = + result && + Double.compare(intMarkerPosition, + that.intMarkerPosition) == 0; + result = + result && + Double.compare(desiredMarkerPosition, + that.desiredMarkerPosition) == 0; + result = + result && + Double.compare(desiredMarkerIncrement, + that.desiredMarkerIncrement) == 0; + + result = result && next.index == that.next.index; + result = result && previous.index == that.previous.index; + } + return result; + } + + /** {@inheritDoc} */ + @Override + public int hashCode() { + return Arrays.hashCode(new double[] {markerHeight, intMarkerPosition, + desiredMarkerIncrement, desiredMarkerPosition, previous.index, next.index}); + } + + /** + * Read Object to deserialize. 
+ * + * @param anInstream Stream Object data + * @throws IOException thrown for IO Errors + * @throws ClassNotFoundException thrown for class not being found + */ + private void readObject(ObjectInputStream anInstream) + throws ClassNotFoundException, IOException { + anInstream.defaultReadObject(); + previous=next=this; + linear = new LinearInterpolator(); + } + + /** + * Clone this instance. + * + * @return cloned marker + */ + @Override + public Object clone() { + return new Marker(markerHeight, desiredMarkerPosition, + desiredMarkerIncrement, intMarkerPosition); + } + + /** + * {@inheritDoc} + */ + @Override + public String toString() { + return String.format( + "index=%.0f,n=%.0f,np=%.2f,q=%.2f,dn=%.2f,prev=%d,next=%d", + (double) index, Precision.round(intMarkerPosition, 0), + Precision.round(desiredMarkerPosition, 2), + Precision.round(markerHeight, 2), + Precision.round(desiredMarkerIncrement, 2), previous.index, + next.index); + } + } + + /** + * A simple fixed capacity list that has an upper bound to growth. + * Once its capacity is reached, {@code add} is a no-op, returning + * {@code false}. + * + * @param <E> + */ + private static class FixedCapacityList<E> extends ArrayList<E> implements + Serializable { + /** + * Serialization Version Id + */ + private static final long serialVersionUID = 2283952083075725479L; + /** + * Capacity of the list + */ + private final int capacity; + + /** + * This constructor constructs the list with given capacity and as well + * as stores the capacity + * + * @param fixedCapacity the capacity to be fixed for this list + */ + FixedCapacityList(final int fixedCapacity) { + super(fixedCapacity); + this.capacity = fixedCapacity; + } + + /** + * {@inheritDoc} In addition it checks if the {@link #size()} returns a + * size that is within capacity and if true it adds; otherwise the list + * contents are unchanged and {@code false} is returned. 
+ * + * @return true if addition is successful and false otherwise + */ + @Override + public boolean add(final E e) { + return size() < capacity ? super.add(e) : false; + } + + /** + * {@inheritDoc} In addition it checks if the sum of Collection size and + * this instance's {@link #size()} returns a value that is within + * capacity and if true it adds the collection; otherwise the list + * contents are unchanged and {@code false} is returned. + * + * @return true if addition is successful and false otherwise + */ + @Override + public boolean addAll(Collection<? extends E> collection) { + boolean isCollectionLess = + collection != null && + collection.size() + size() <= capacity; + return isCollectionLess ? super.addAll(collection) : false; + } + } + + /** + * A creation method to build Markers + * + * @param initialFive list of initial five elements + * @param p the quantile desired + * @return an instance of PSquareMarkers + */ + public static PSquareMarkers newMarkers(final List<Double> initialFive, + final double p) { + return new Markers(initialFive, p); + } + + /** + * An interface that encapsulates abstractions of the + * P-square algorithm markers as is explained in the original works. This + * interface is exposed with protected access to help in testability. + */ + protected interface PSquareMarkers extends Cloneable { + /** + * Returns Percentile value computed thus far. + * + * @return percentile + */ + double getPercentileValue(); + + /** + * A clone function to clone the current instance. It's created as an + * interface method as well for convenience though Cloneable is just a + * marker interface. + * + * @return clone of this instance + */ + Object clone(); + + /** + * Returns the marker height (or percentile) of a given marker index. 
+ * + * @param markerIndex is the index of marker in the marker array + * @return percentile value of the marker index passed + * @throws OutOfRangeException in case the index is not within [1-5] + */ + double height(final int markerIndex); + + /** + * Process a data point by moving the marker heights based on estimator. + * + * @param inputDataPoint is the data point passed + * @return computed percentile + */ + double processDataPoint(final double inputDataPoint); + + /** + * An Estimate of the percentile value of a given Marker + * + * @param index the marker's index in the array of markers + * @return percentile estimate + * @throws OutOfRangeException in case if index is not within [1-5] + */ + double estimate(final int index); + } +} diff --git a/src/main/java/org/apache/commons/math3/stat/descriptive/rank/Percentile.java b/src/main/java/org/apache/commons/math3/stat/descriptive/rank/Percentile.java new file mode 100644 index 0000000..bba9e7c --- /dev/null +++ b/src/main/java/org/apache/commons/math3/stat/descriptive/rank/Percentile.java @@ -0,0 +1,1072 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.commons.math3.stat.descriptive.rank; + +import java.io.Serializable; +import java.util.Arrays; +import java.util.BitSet; + +import org.apache.commons.math3.exception.MathIllegalArgumentException; +import org.apache.commons.math3.exception.MathUnsupportedOperationException; +import org.apache.commons.math3.exception.NullArgumentException; +import org.apache.commons.math3.exception.OutOfRangeException; +import org.apache.commons.math3.exception.util.LocalizedFormats; +import org.apache.commons.math3.stat.descriptive.AbstractUnivariateStatistic; +import org.apache.commons.math3.stat.ranking.NaNStrategy; +import org.apache.commons.math3.util.FastMath; +import org.apache.commons.math3.util.KthSelector; +import org.apache.commons.math3.util.MathArrays; +import org.apache.commons.math3.util.MathUtils; +import org.apache.commons.math3.util.MedianOf3PivotingStrategy; +import org.apache.commons.math3.util.PivotingStrategyInterface; +import org.apache.commons.math3.util.Precision; + +/** + * Provides percentile computation. + * <p> + * There are several commonly used methods for estimating percentiles (a.k.a. + * quantiles) based on sample data. For large samples, the different methods + * agree closely, but when sample sizes are small, different methods will give + * significantly different results. The algorithm implemented here works as follows: + * <ol> + * <li>Let <code>n</code> be the length of the (sorted) array and + * <code>0 < p <= 100</code> be the desired percentile.</li> + * <li>If <code> n = 1 </code> return the unique array element (regardless of + * the value of <code>p</code>); otherwise </li> + * <li>Compute the estimated percentile position + * <code> pos = p * (n + 1) / 100</code> and the difference, <code>d</code> + * between <code>pos</code> and <code>floor(pos)</code> (i.e. 
the fractional + * part of <code>pos</code>).</li> + * <li> If <code>pos < 1</code> return the smallest element in the array.</li> + * <li> Else if <code>pos >= n</code> return the largest element in the array.</li> + * <li> Else let <code>lower</code> be the element in position + * <code>floor(pos)</code> in the array and let <code>upper</code> be the + * next element in the array. Return <code>lower + d * (upper - lower)</code> + * </li> + * </ol></p> + * <p> + * To compute percentiles, the data must be at least partially ordered. Input + * arrays are copied and recursively partitioned using an ordering definition. + * The ordering used by <code>Arrays.sort(double[])</code> is the one determined + * by {@link java.lang.Double#compareTo(Double)}. This ordering makes + * <code>Double.NaN</code> larger than any other value (including + * <code>Double.POSITIVE_INFINITY</code>). Therefore, for example, the median + * (50th percentile) of + * <code>{0, 1, 2, 3, 4, Double.NaN}</code> evaluates to <code>2.5.</code></p> + * <p> + * Since percentile estimation usually involves interpolation between array + * elements, arrays containing <code>NaN</code> or infinite values will often + * result in <code>NaN</code> or infinite values returned.</p> + * <p> + * Further, to include different estimation types such as R1, R2 as mentioned in + * <a href="http://en.wikipedia.org/wiki/Quantile">Quantile page(wikipedia)</a>, + * a type specific NaN handling strategy is used to closely match with the + * typically observed results from popular tools like R(R1-R9), Excel(R7).</p> + * <p> + * Since 2.2, Percentile uses only selection instead of complete sorting + * and caches selection algorithm state between calls to the various + * {@code evaluate} methods. This greatly improves efficiency, both for a single + * percentile and multiple percentile computations. 
To maximize performance when + * multiple percentiles are computed based on the same data, users should set the + * data array once using either one of the {@link #evaluate(double[], double)} or + * {@link #setData(double[])} methods and thereafter {@link #evaluate(double)} + * with just the percentile provided. + * </p> + * <p> + * <strong>Note that this implementation is not synchronized.</strong> If + * multiple threads access an instance of this class concurrently, and at least + * one of the threads invokes the <code>increment()</code> or + * <code>clear()</code> method, it must be synchronized externally.</p> + * + */ +public class Percentile extends AbstractUnivariateStatistic implements Serializable { + + /** Serializable version identifier */ + private static final long serialVersionUID = -8091216485095130416L; + + /** Maximum number of partitioning pivots cached (each level double the number of pivots). */ + private static final int MAX_CACHED_LEVELS = 10; + + /** Maximum number of cached pivots in the pivots cached array */ + private static final int PIVOTS_HEAP_LENGTH = 0x1 << MAX_CACHED_LEVELS - 1; + + /** Default KthSelector used with default pivoting strategy */ + private final KthSelector kthSelector; + + /** Any of the {@link EstimationType}s such as {@link EstimationType#LEGACY CM} can be used. */ + private final EstimationType estimationType; + + /** NaN Handling of the input as defined by {@link NaNStrategy} */ + private final NaNStrategy nanStrategy; + + /** Determines what percentile is computed when evaluate() is activated + * with no quantile argument */ + private double quantile; + + /** Cached pivots. */ + private int[] cachedPivots; + + /** + * Constructs a Percentile with the following defaults. 
+ * <ul> + * <li>default quantile: 50.0, can be reset with {@link #setQuantile(double)}</li> + * <li>default estimation type: {@link EstimationType#LEGACY}, + * can be reset with {@link #withEstimationType(EstimationType)}</li> + * <li>default NaN strategy: {@link NaNStrategy#REMOVED}, + * can be reset with {@link #withNaNStrategy(NaNStrategy)}</li> + * <li>a KthSelector that makes use of {@link MedianOf3PivotingStrategy}, + * can be reset with {@link #withKthSelector(KthSelector)}</li> + * </ul> + */ + public Percentile() { + // No try-catch or advertised exception here - arg is valid + this(50.0); + } + + /** + * Constructs a Percentile with the specific quantile value and the following + * <ul> + * <li>default method type: {@link EstimationType#LEGACY}</li> + * <li>default NaN strategy: {@link NaNStrategy#REMOVED}</li> + * <li>a Kth Selector : {@link KthSelector}</li> + * </ul> + * @param quantile the quantile + * @throws MathIllegalArgumentException if p is not greater than 0 and less + * than or equal to 100 + */ + public Percentile(final double quantile) throws MathIllegalArgumentException { + this(quantile, EstimationType.LEGACY, NaNStrategy.REMOVED, + new KthSelector(new MedianOf3PivotingStrategy())); + } + + /** + * Copy constructor, creates a new {@code Percentile} identical + * to the {@code original} + * + * @param original the {@code Percentile} instance to copy + * @throws NullArgumentException if original is null + */ + public Percentile(final Percentile original) throws NullArgumentException { + + MathUtils.checkNotNull(original); + estimationType = original.getEstimationType(); + nanStrategy = original.getNaNStrategy(); + kthSelector = original.getKthSelector(); + + setData(original.getDataRef()); + if (original.cachedPivots != null) { + System.arraycopy(original.cachedPivots, 0, cachedPivots, 0, original.cachedPivots.length); + } + setQuantile(original.quantile); + + } + + /** + * Constructs a Percentile with the specific quantile value, + * 
{@link EstimationType}, {@link NaNStrategy} and {@link KthSelector}. + * + * @param quantile the quantile to be computed + * @param estimationType one of the percentile {@link EstimationType estimation types} + * @param nanStrategy one of {@link NaNStrategy} to handle with NaNs + * @param kthSelector a {@link KthSelector} to use for pivoting during search + * @throws MathIllegalArgumentException if p is not within (0,100] + * @throws NullArgumentException if type or NaNStrategy passed is null + */ + protected Percentile(final double quantile, + final EstimationType estimationType, + final NaNStrategy nanStrategy, + final KthSelector kthSelector) + throws MathIllegalArgumentException { + setQuantile(quantile); + cachedPivots = null; + MathUtils.checkNotNull(estimationType); + MathUtils.checkNotNull(nanStrategy); + MathUtils.checkNotNull(kthSelector); + this.estimationType = estimationType; + this.nanStrategy = nanStrategy; + this.kthSelector = kthSelector; + } + + /** {@inheritDoc} */ + @Override + public void setData(final double[] values) { + if (values == null) { + cachedPivots = null; + } else { + cachedPivots = new int[PIVOTS_HEAP_LENGTH]; + Arrays.fill(cachedPivots, -1); + } + super.setData(values); + } + + /** {@inheritDoc} */ + @Override + public void setData(final double[] values, final int begin, final int length) + throws MathIllegalArgumentException { + if (values == null) { + cachedPivots = null; + } else { + cachedPivots = new int[PIVOTS_HEAP_LENGTH]; + Arrays.fill(cachedPivots, -1); + } + super.setData(values, begin, length); + } + + /** + * Returns the result of evaluating the statistic over the stored data. 
+ * <p> + * The stored array is the one which was set by previous calls to + * {@link #setData(double[])} + * </p> + * @param p the percentile value to compute + * @return the value of the statistic applied to the stored data + * @throws MathIllegalArgumentException if p is not a valid quantile value + * (p must be greater than 0 and less than or equal to 100) + */ + public double evaluate(final double p) throws MathIllegalArgumentException { + return evaluate(getDataRef(), p); + } + + /** + * Returns an estimate of the <code>p</code>th percentile of the values + * in the <code>values</code> array. + * <p> + * Calls to this method do not modify the internal <code>quantile</code> + * state of this statistic.</p> + * <p> + * <ul> + * <li>Returns <code>Double.NaN</code> if <code>values</code> has length + * <code>0</code></li> + * <li>Returns (for any value of <code>p</code>) <code>values[0]</code> + * if <code>values</code> has length <code>1</code></li> + * <li>Throws <code>MathIllegalArgumentException</code> if <code>values</code> + * is null or p is not a valid quantile value (p must be greater than 0 + * and less than or equal to 100) </li> + * </ul></p> + * <p> + * See {@link Percentile} for a description of the percentile estimation + * algorithm used.</p> + * + * @param values input array of values + * @param p the percentile value to compute + * @return the percentile value or Double.NaN if the array is empty + * @throws MathIllegalArgumentException if <code>values</code> is null + * or p is invalid + */ + public double evaluate(final double[] values, final double p) + throws MathIllegalArgumentException { + test(values, 0, 0); + return evaluate(values, 0, values.length, p); + } + + /** + * Returns an estimate of the <code>quantile</code>th percentile of the + * designated values in the <code>values</code> array. The quantile + * estimated is determined by the <code>quantile</code> property. 
     * <p>
     * <ul>
     * <li>Returns <code>Double.NaN</code> if <code>length = 0</code></li>
     * <li>Returns (for any value of <code>quantile</code>)
     * <code>values[begin]</code> if <code>length = 1 </code></li>
     * <li>Throws <code>MathIllegalArgumentException</code> if <code>values</code>
     * is null, or <code>start</code> or <code>length</code> is invalid</li>
     * </ul></p>
     * <p>
     * See {@link Percentile} for a description of the percentile estimation
     * algorithm used.</p>
     *
     * @param values the input array
     * @param start index of the first array element to include
     * @param length the number of elements to include
     * @return the percentile value
     * @throws MathIllegalArgumentException if the parameters are not valid
     *
     */
    @Override
    public double evaluate(final double[] values, final int start, final int length)
        throws MathIllegalArgumentException {
        // Delegates to the 4-arg overload using the configured quantile.
        return evaluate(values, start, length, quantile);
    }

    /**
     * Returns an estimate of the <code>p</code>th percentile of the values
     * in the <code>values</code> array, starting with the element in (0-based)
     * position <code>begin</code> in the array and including <code>length</code>
     * values.
     * <p>
     * Calls to this method do not modify the internal <code>quantile</code>
     * state of this statistic.</p>
     * <p>
     * <ul>
     * <li>Returns <code>Double.NaN</code> if <code>length = 0</code></li>
     * <li>Returns (for any value of <code>p</code>) <code>values[begin]</code>
     * if <code>length = 1 </code></li>
     * <li>Throws <code>MathIllegalArgumentException</code> if <code>values</code>
     * is null, <code>begin</code> or <code>length</code> is invalid, or
     * <code>p</code> is not a valid quantile value (p must be greater than 0
     * and less than or equal to 100)</li>
     * </ul></p>
     * <p>
     * See {@link Percentile} for a description of the percentile estimation
     * algorithm used.</p>
     *
     * @param values array of input values
     * @param begin the first (0-based) element to include in the computation
     * @param length the number of array elements to include
     * @param p the percentile to compute
     * @return the percentile value
     * @throws MathIllegalArgumentException if the parameters are not valid or the
     * input array is null
     */
    public double evaluate(final double[] values, final int begin,
                           final int length, final double p)
        throws MathIllegalArgumentException {

        test(values, begin, length);
        if (p > 100 || p <= 0) {
            throw new OutOfRangeException(
                    LocalizedFormats.OUT_OF_BOUNDS_QUANTILE_VALUE, p, 0, 100);
        }
        if (length == 0) {
            return Double.NaN;
        }
        if (length == 1) {
            return values[begin]; // always return single value for n = 1
        }

        // The work array has the configured NaN strategy applied; when the
        // input is the stored data array, cached pivots are reused.
        final double[] work = getWorkArray(values, begin, length);
        final int[] pivotsHeap = getPivots(values);
        // The REMOVED NaN strategy may yield an empty work array: report NaN.
        return work.length == 0 ? Double.NaN :
                    estimationType.evaluate(work, pivotsHeap, p, kthSelector);
    }

    /** Select a pivot index as the median of three
     * <p>
     * <b>Note:</b> With the effect of allowing {@link KthSelector} to be set on
     * {@link Percentile} instances (thus indirectly {@link PivotingStrategy})
     * this method won't take effect any more and hence is unsupported.
     * @param work data array
     * @param begin index of the first element of the slice
     * @param end index after the last element of the slice
     * @return the index of the median element chosen between the
     * first, the middle and the last element of the array slice
     * @deprecated Please refrain from using this method (as it won't take effect)
     * and instead use {@link Percentile#withKthSelector(KthSelector)} if
     * required.
     *
     */
    @Deprecated
    int medianOf3(final double[] work, final int begin, final int end) {
        return new MedianOf3PivotingStrategy().pivotIndex(work, begin, end);
        //throw new MathUnsupportedOperationException();
    }

    /**
     * Returns the value of the quantile field (determines what percentile is
     * computed when evaluate() is called with no quantile argument).
     *
     * @return quantile set while construction or {@link #setQuantile(double)}
     */
    public double getQuantile() {
        return quantile;
    }

    /**
     * Sets the value of the quantile field (determines what percentile is
     * computed when evaluate() is called with no quantile argument).
     *
     * @param p a quantile value such that 0 &lt; p &lt;= 100
     * @throws MathIllegalArgumentException if p is not greater than 0 and less
     * than or equal to 100
     */
    public void setQuantile(final double p) throws MathIllegalArgumentException {
        if (p <= 0 || p > 100) {
            throw new OutOfRangeException(
                    LocalizedFormats.OUT_OF_BOUNDS_QUANTILE_VALUE, p, 0, 100);
        }
        quantile = p;
    }

    /**
     * {@inheritDoc}
     */
    @Override
    public Percentile copy() {
        // Copy constructor preserves quantile, estimation type, NaN strategy
        // and KthSelector of this instance.
        return new Percentile(this);
    }

    /**
     * Copies source to dest.
     * @param source Percentile to copy
     * @param dest Percentile to copy to
     * @exception MathUnsupportedOperationException always thrown since 3.4
     * @deprecated as of 3.4 this method does not work anymore, as it fails to
     * copy internal states between instances configured with different
     * {@link EstimationType estimation type}, {@link NaNStrategy NaN handling strategies}
     * and {@link KthSelector kthSelector}; it therefore always
     * throws {@link MathUnsupportedOperationException}
     */
    @Deprecated
    public static void copy(final Percentile source, final Percentile dest)
        throws MathUnsupportedOperationException {
        throw new MathUnsupportedOperationException();
    }

    /**
     * Get the work array to operate on. Makes use of prior {@code storedData} if
     * it exists, or else does a check on NaNs and copies a subset of the array
     * defined by the begin and length parameters. The set {@link #nanStrategy} will
     * be used to either retain/remove/replace any NaNs present before returning
     * the resultant array.
     *
     * @param values the array of numbers
     * @param begin index to start reading the array
     * @param length the length of array to be read from the begin index
     * @return work array sliced from values in the range [begin,begin+length)
     * @throws MathIllegalArgumentException if values or indices are invalid
     */
    protected double[] getWorkArray(final double[] values, final int begin, final int length) {
        final double[] work;
        if (values == getDataRef()) {
            // Fast path: evaluating the stored data array. It is used in
            // place (no copy, no NaN processing) so that the cached pivots
            // heap stays consistent with the array being selected on.
            work = getDataRef();
        } else {
            switch (nanStrategy) {
            case MAXIMAL:// Replace NaNs with +INFs
                work = replaceAndSlice(values, begin, length, Double.NaN, Double.POSITIVE_INFINITY);
                break;
            case MINIMAL:// Replace NaNs with -INFs
                work = replaceAndSlice(values, begin, length, Double.NaN, Double.NEGATIVE_INFINITY);
                break;
            case REMOVED:// Drop NaNs from data
                work = removeAndSlice(values, begin, length, Double.NaN);
                break;
            case FAILED:// just throw exception as NaN is unacceptable
                work = copyOf(values, begin, length);
                MathArrays.checkNotNaN(work);
                break;
            default: //FIXED: NaNs are left in place
                work = copyOf(values,begin,length);
                break;
            }
        }
        return work;
    }

    /**
     * Make a copy of the array for the slice defined by array part from
     * [begin, begin+length)
     * @param values the input array
     * @param begin start index of the array to include
     * @param length number of elements to include from begin
     * @return copy of a slice of the original array
     */
    private static double[] copyOf(final double[] values, final int begin, final int length) {
        // verifyValues throws MathIllegalArgumentException on null array or
        // out-of-range begin/length before any copying is attempted.
        MathArrays.verifyValues(values, begin, length);
        return MathArrays.copyOfRange(values, begin, begin + length);
    }

    /**
     * Replace every occurrence of a given value with a replacement value in a
     * copied slice of array defined by array part from [begin, begin+length).
     * @param values the input array
     * @param begin start index of the array to include
     * @param length number of elements to include from begin
     * @param original the value to be replaced
     * @param replacement the value to be used as replacement
     * @return the copy of sliced array with replaced values
     */
    private static double[] replaceAndSlice(final double[] values,
                                            final int begin, final int length,
                                            final double original,
                                            final double replacement) {
        final double[] temp = copyOf(values, begin, length);
        for(int i = 0; i < length; i++) {
            // equalsIncludingNaN is required so that NaN itself can be the
            // value being replaced (NaN != NaN under plain comparison).
            temp[i] = Precision.equalsIncludingNaN(original, temp[i]) ?
                      replacement : temp[i];
        }
        return temp;
    }

    /**
     * Remove the occurrence of a given value in a copied slice of array
     * defined by the array part from [begin, begin+length).
     * @param values the input array
     * @param begin start index of the array to include
     * @param length number of elements to include from begin
     * @param removedValue the value to be removed from the sliced array
     * @return the copy of the sliced array after removing the removedValue
     */
    private static double[] removeAndSlice(final double[] values,
                                           final int begin, final int length,
                                           final double removedValue) {
        MathArrays.verifyValues(values, begin, length);
        final double[] temp;
        // Bit i of this set marks that values[begin + i] equals removedValue.
        final BitSet bits = new BitSet(length);
        for (int i = begin; i < begin+length; i++) {
            if (Precision.equalsIncludingNaN(removedValue, values[i])) {
                bits.set(i - begin);
            }
        }
        if (bits.isEmpty()) {
            temp = copyOf(values, begin, length); // Nothing removed, just copy
        } else if(bits.cardinality() == length){
            temp = new double[0]; // All removed, just empty
        }else { // Some removable: compact the kept runs into a new array
            temp = new double[length - bits.cardinality()];
            int start = begin; //start index from source array (i.e values)
            int dest = 0; //dest index in destination array (i.e temp)
            int nextOne = -1; //index of the next set bit (next removed value)
            int bitSetPtr = 0; //current scan position within the bit set
            // Copy each run of kept values that precedes a removed value,
            // then skip past the contiguous block of removed values.
            while ((nextOne = bits.nextSetBit(bitSetPtr)) != -1) {
                final int lengthToCopy = nextOne - bitSetPtr;
                System.arraycopy(values, start, temp, dest, lengthToCopy);
                dest += lengthToCopy;
                start = begin + (bitSetPtr = bits.nextClearBit(nextOne));
            }
            // Copy any residue past the last removed value up to begin+length.
            if (start < begin + length) {
                System.arraycopy(values,start,temp,dest,begin + length - start);
            }
        }
        return temp;
    }

    /**
     * Get pivots which is either cached or a newly created one.
     * Cached pivots are only valid for the stored data array; any other
     * input gets a fresh heap initialized to -1 ("unset").
     *
     * @param values array containing the input numbers
     * @return cached pivots or a newly created one
     */
    private int[] getPivots(final double[] values) {
        final int[] pivotsHeap;
        if (values == getDataRef()) {
            pivotsHeap = cachedPivots;
        } else {
            pivotsHeap = new int[PIVOTS_HEAP_LENGTH];
            Arrays.fill(pivotsHeap, -1);
        }
        return pivotsHeap;
    }

    /**
     * Get the estimation {@link EstimationType type} used for computation.
     *
     * @return the {@code estimationType} set
     */
    public EstimationType getEstimationType() {
        return estimationType;
    }

    /**
     * Build a new instance similar to the current one except for the
     * {@link EstimationType estimation type}.
     * <p>
     * This method is intended to be used as part of a fluent-type builder
     * pattern. Building finely tune instances should be done as follows:
     * </p>
     * <pre>
     *   Percentile customized = new Percentile(quantile).
     *                           withEstimationType(estimationType).
     *                           withNaNStrategy(nanStrategy).
     *                           withKthSelector(kthSelector);
     * </pre>
     * <p>
     * If any of the {@code withXxx} method is omitted, the default value for
     * the corresponding customization parameter will be used.
     * </p>
     * @param newEstimationType estimation type for the new instance
     * @return a new instance, with changed estimation type
     * @throws NullArgumentException when newEstimationType is null
     */
    public Percentile withEstimationType(final EstimationType newEstimationType) {
        // Constructor re-validates quantile and null-checks all parameters.
        return new Percentile(quantile, newEstimationType, nanStrategy, kthSelector);
    }

    /**
     * Get the {@link NaNStrategy NaN Handling} strategy used for computation.
     * @return {@code NaN Handling} strategy set during construction
     */
    public NaNStrategy getNaNStrategy() {
        return nanStrategy;
    }

    /**
     * Build a new instance similar to the current one except for the
     * {@link NaNStrategy NaN handling} strategy.
     * <p>
     * This method is intended to be used as part of a fluent-type builder
     * pattern. Building finely tune instances should be done as follows:
     * </p>
     * <pre>
     *   Percentile customized = new Percentile(quantile).
     *                           withEstimationType(estimationType).
     *                           withNaNStrategy(nanStrategy).
     *                           withKthSelector(kthSelector);
     * </pre>
     * <p>
     * If any of the {@code withXxx} method is omitted, the default value for
     * the corresponding customization parameter will be used.
     * </p>
     * @param newNaNStrategy NaN strategy for the new instance
     * @return a new instance, with changed NaN handling strategy
     * @throws NullArgumentException when newNaNStrategy is null
     */
    public Percentile withNaNStrategy(final NaNStrategy newNaNStrategy) {
        return new Percentile(quantile, estimationType, newNaNStrategy, kthSelector);
    }

    /**
     * Get the {@link KthSelector kthSelector} used for computation.
     * @return the {@code kthSelector} set
     */
    public KthSelector getKthSelector() {
        return kthSelector;
    }

    /**
     * Get the {@link PivotingStrategyInterface} used in KthSelector for computation.
     * @return the pivoting strategy set
     */
    public PivotingStrategyInterface getPivotingStrategy() {
        return kthSelector.getPivotingStrategy();
    }

    /**
     * Build a new instance similar to the current one except for the
     * {@link KthSelector kthSelector} instance specifically set.
     * <p>
     * This method is intended to be used as part of a fluent-type builder
     * pattern. Building finely tune instances should be done as follows:
     * </p>
     * <pre>
     *   Percentile customized = new Percentile(quantile).
     *                           withEstimationType(estimationType).
     *                           withNaNStrategy(nanStrategy).
     *                           withKthSelector(newKthSelector);
     * </pre>
     * <p>
     * If any of the {@code withXxx} method is omitted, the default value for
     * the corresponding customization parameter will be used.
     * </p>
     * @param newKthSelector KthSelector for the new instance
     * @return a new instance, with changed KthSelector
     * @throws NullArgumentException when newKthSelector is null
     */
    public Percentile withKthSelector(final KthSelector newKthSelector) {
        return new Percentile(quantile, estimationType, nanStrategy,
                              newKthSelector);
    }

    /**
     * An enum for various estimation strategies of a percentile referred in
     * <a href="http://en.wikipedia.org/wiki/Quantile">wikipedia on quantile</a>
     * with the names of enum matching those of types mentioned in
     * wikipedia.
     * <p>
     * Each enum corresponding to the specific type of estimation in wikipedia
     * implements the respective formulae that specializes in the below aspects
     * <ul>
     * <li>An <b>index method</b> to calculate approximate index of the
     * estimate</li>
     * <li>An <b>estimate method</b> to estimate a value found at the earlier
     * computed index</li>
     * <li>A <b> minLimit</b> on the quantile for which first element of sorted
     * input is returned as an estimate </li>
     * <li>A <b> maxLimit</b> on the quantile for which last element of sorted
     * input is returned as an estimate </li>
     * </ul>
     * <p>
     * Users can now create {@link Percentile} by explicitly passing this enum;
     * such as by invoking {@link Percentile#withEstimationType(EstimationType)}
     * <p>
     * References:
     * <ol>
     * <li>
     * <a href="http://en.wikipedia.org/wiki/Quantile">Wikipedia on quantile</a>
     * </li>
     * <li>
     * <a href="https://www.amherst.edu/media/view/129116/.../Sample+Quantiles.pdf">
     * Hyndman, R. J. and Fan, Y. (1996) Sample quantiles in statistical
     * packages, American Statistician 50, 361–365</a> </li>
     * <li>
     * <a href="http://stat.ethz.ch/R-manual/R-devel/library/stats/html/quantile.html">
     * R-Manual </a></li>
     * </ol>
     *
     */
    public enum EstimationType {
        /**
         * This is the default type used in the {@link Percentile}. This method
         * has the following formulae for index and estimates<br>
         * \( \begin{align}
         * &index = (N+1)p\ \\
         * &estimate = x_{\lceil h\,-\,1/2 \rceil} \\
         * &minLimit = 0 \\
         * &maxLimit = 1 \\
         * \end{align}\)
         */
        LEGACY("Legacy Apache Commons Math") {
            /**
             * {@inheritDoc}. This method in particular makes use of existing
             * Apache Commons Math style of picking up the index.
             */
            @Override
            protected double index(final double p, final int length) {
                final double minLimit = 0d;
                final double maxLimit = 1d;
                return Double.compare(p, minLimit) == 0 ? 0 :
                       Double.compare(p, maxLimit) == 0 ?
                               length : p * (length + 1);
            }
        },
        /**
         * The method R_1 has the following formulae for index and estimates<br>
         * \( \begin{align}
         * &index= Np + 1/2\,  \\
         * &estimate= x_{\lceil h\,-\,1/2 \rceil} \\
         * &minLimit = 0 \\
         * \end{align}\)
         */
        R_1("R-1") {

            @Override
            protected double index(final double p, final int length) {
                final double minLimit = 0d;
                return Double.compare(p, minLimit) == 0 ? 0 : length * p + 0.5;
            }

            /**
             * {@inheritDoc} This method in particular for R_1 uses
             * ceil(pos - 0.5).
             */
            @Override
            protected double estimate(final double[] values,
                                      final int[] pivotsHeap, final double pos,
                                      final int length, final KthSelector selector) {
                return super.estimate(values, pivotsHeap, FastMath.ceil(pos - 0.5), length, selector);
            }

        },
        /**
         * The method R_2 has the following formulae for index and estimates<br>
         * \( \begin{align}
         * &index= Np + 1/2\, \\
         * &estimate=\frac{x_{\lceil h\,-\,1/2 \rceil} +
         * x_{\lfloor h\,+\,1/2 \rfloor}}{2} \\
         * &minLimit = 0 \\
         * &maxLimit = 1 \\
         * \end{align}\)
         */
        R_2("R-2") {

            @Override
            protected double index(final double p, final int length) {
                final double minLimit = 0d;
                final double maxLimit = 1d;
                return Double.compare(p, maxLimit) == 0 ? length :
                       Double.compare(p, minLimit) == 0 ? 0 : length * p + 0.5;
            }

            /**
             * {@inheritDoc} This method in particular for R_2 averages the
             * values at ceil(pos - 0.5) and floor(pos + 0.5).
             */
            @Override
            protected double estimate(final double[] values,
                                      final int[] pivotsHeap, final double pos,
                                      final int length, final KthSelector selector) {
                final double low =
                        super.estimate(values, pivotsHeap, FastMath.ceil(pos - 0.5), length, selector);
                final double high =
                        super.estimate(values, pivotsHeap,FastMath.floor(pos + 0.5), length, selector);
                return (low + high) / 2;
            }

        },
        /**
         * The method R_3 has the following formulae for index and estimates<br>
         * \( \begin{align}
         * &index= Np \\
         * &estimate= x_{\lfloor h \rceil}\, \\
         * &minLimit = 0.5/N \\
         * \end{align}\)
         */
        R_3("R-3") {
            @Override
            protected double index(final double p, final int length) {
                final double minLimit = 1d/2 / length;
                // Note: R-3 uses <= at the lower boundary (other types use <);
                // rint rounds the index to the nearest integer.
                return Double.compare(p, minLimit) <= 0 ?
                        0 : FastMath.rint(length * p);
            }

        },
        /**
         * The method R_4 has the following formulae for index and estimates<br>
         * \( \begin{align}
         * &index= Np\, \\
         * &estimate= x_{\lfloor h \rfloor} + (h -
         * \lfloor h \rfloor) (x_{\lfloor h \rfloor + 1} - x_{\lfloor h
         * \rfloor}) \\
         * &minLimit = 1/N \\
         * &maxLimit = 1 \\
         * \end{align}\)
         */
        R_4("R-4") {
            @Override
            protected double index(final double p, final int length) {
                final double minLimit = 1d / length;
                final double maxLimit = 1d;
                return Double.compare(p, minLimit) < 0 ? 0 :
                       Double.compare(p, maxLimit) == 0 ? length : length * p;
            }

        },
        /**
         * The method R_5 has the following formulae for index and estimates<br>
         * \( \begin{align}
         * &index= Np + 1/2\\
         * &estimate= x_{\lfloor h \rfloor} + (h -
         * \lfloor h \rfloor) (x_{\lfloor h \rfloor + 1} - x_{\lfloor h
         * \rfloor}) \\
         * &minLimit = 0.5/N \\
         * &maxLimit = (N-0.5)/N
         * \end{align}\)
         */
        R_5("R-5"){

            @Override
            protected double index(final double p, final int length) {
                final double minLimit = 1d/2 / length;
                final double maxLimit = (length - 0.5) / length;
                return Double.compare(p, minLimit) < 0 ? 0 :
                       Double.compare(p, maxLimit) >= 0 ?
                               length : length * p + 0.5;
            }
        },
        /**
         * The method R_6 has the following formulae for index and estimates<br>
         * \( \begin{align}
         * &index= (N + 1)p \\
         * &estimate= x_{\lfloor h \rfloor} + (h -
         * \lfloor h \rfloor) (x_{\lfloor h \rfloor + 1} - x_{\lfloor h
         * \rfloor}) \\
         * &minLimit = 1/(N+1) \\
         * &maxLimit = N/(N+1) \\
         * \end{align}\)
         * <p>
         * <b>Note:</b> This method computes the index in a manner very close to
         * the default Commons Math Percentile existing implementation. However
         * the difference to be noted is in picking up the limits with which
         * first element (p &lt; 1/(N+1)) and last element (p &gt; N/(N+1)) are
         * done. While in the default case these are done with p=0 and p=1
         * respectively.
         */
        R_6("R-6"){

            @Override
            protected double index(final double p, final int length) {
                final double minLimit = 1d / (length + 1);
                final double maxLimit = 1d * length / (length + 1);
                return Double.compare(p, minLimit) < 0 ? 0 :
                       Double.compare(p, maxLimit) >= 0 ?
                               length : (length + 1) * p;
            }
        },

        /**
         * The method R_7 implements Microsoft Excel style computation has the
         * following formulae for index and estimates.<br>
         * \( \begin{align}
         * &index = (N-1)p + 1 \\
         * &estimate = x_{\lfloor h \rfloor} + (h -
         * \lfloor h \rfloor) (x_{\lfloor h \rfloor + 1} - x_{\lfloor h
         * \rfloor}) \\
         * &minLimit = 0 \\
         * &maxLimit = 1 \\
         * \end{align}\)
         */
        R_7("R-7") {
            @Override
            protected double index(final double p, final int length) {
                final double minLimit = 0d;
                final double maxLimit = 1d;
                return Double.compare(p, minLimit) == 0 ? 0 :
                       Double.compare(p, maxLimit) == 0 ?
                               length : 1 + (length - 1) * p;
            }

        },

        /**
         * The method R_8 has the following formulae for index and estimates<br>
         * \( \begin{align}
         * &index = (N + 1/3)p + 1/3  \\
         * &estimate = x_{\lfloor h \rfloor} + (h -
           \lfloor h \rfloor) (x_{\lfloor h \rfloor + 1} - x_{\lfloor h
         * \rfloor}) \\
         * &minLimit = (2/3)/(N+1/3) \\
         * &maxLimit = (N-1/3)/(N+1/3) \\
         * \end{align}\)
         * <p>
         * As per Ref [2,3] this approach is most recommended as it provides
         * an approximate median-unbiased estimate regardless of distribution.
         */
        R_8("R-8") {
            @Override
            protected double index(final double p, final int length) {
                final double minLimit = 2 * (1d / 3) / (length + 1d / 3);
                final double maxLimit =
                        (length - 1d / 3) / (length + 1d / 3);
                return Double.compare(p, minLimit) < 0 ? 0 :
                       Double.compare(p, maxLimit) >= 0 ? length :
                               (length + 1d / 3) * p + 1d / 3;
            }
        },

        /**
         * The method R_9 has the following formulae for index and estimates<br>
         * \( \begin{align}
         * &index = (N + 1/4)p + 3/8\\
         * &estimate = x_{\lfloor h \rfloor} + (h -
           \lfloor h \rfloor) (x_{\lfloor h \rfloor + 1} - x_{\lfloor h
         * \rfloor}) \\
         * &minLimit = (5/8)/(N+1/4) \\
         * &maxLimit = (N-3/8)/(N+1/4) \\
         * \end{align}\)
         */
        R_9("R-9") {
            @Override
            protected double index(final double p, final int length) {
                final double minLimit = 5d/8 / (length + 0.25);
                final double maxLimit = (length - 3d/8) / (length + 0.25);
                return Double.compare(p, minLimit) < 0 ? 0 :
                       Double.compare(p, maxLimit) >= 0 ? length :
                               (length + 0.25) * p + 3d/8;
            }

        },
        ;

        /** Simple name such as R-1, R-2 corresponding to those in wikipedia.
         */
        private final String name;

        /**
         * Constructor
         *
         * @param type name of estimation type as per wikipedia
         */
        EstimationType(final String type) {
            this.name = type;
        }

        /**
         * Finds the index of array that can be used as starting index to
         * {@link #estimate(double[], int[], double, int, KthSelector) estimate}
         * percentile. The calculation of index calculation is specific to each
         * {@link EstimationType}.
         *
         * @param p the p<sup>th</sup> quantile
         * @param length the total number of array elements in the work array
         * @return a computed real valued index as explained in the wikipedia
         */
        protected abstract double index(final double p, final int length);

        /**
         * Estimation based on K<sup>th</sup> selection. This may be overridden
         * in specific enums to compute slightly different estimations.
         *
         * @param work array of numbers to be used for finding the percentile
         * @param pivotsHeap an earlier populated cache if it exists; will be used
         * @param pos indicated positional index prior computed from calling
         * {@link #index(double, int)}
         * @param length size of array considered
         * @param selector a {@link KthSelector} used for pivoting during search
         * @return estimated percentile
         */
        protected double estimate(final double[] work, final int[] pivotsHeap,
                                  final double pos, final int length,
                                  final KthSelector selector) {

            final double fpos = FastMath.floor(pos);
            final int intPos = (int) fpos;
            final double dif = pos - fpos;

            // Clamp to the first/last order statistic when the computed
            // index falls outside [1, length).
            if (pos < 1) {
                return selector.select(work, pivotsHeap, 0);
            }
            if (pos >= length) {
                return selector.select(work, pivotsHeap, length - 1);
            }

            // Linear interpolation between the two bracketing order statistics.
            final double lower = selector.select(work, pivotsHeap, intPos - 1);
            final double upper = selector.select(work, pivotsHeap, intPos);
            return lower + dif * (upper - lower);
        }

        /**
         * Evaluate method to compute the percentile for a given bounded array
         * using earlier computed pivots heap.<br>
         * This basically calls the {@link #index(double, int) index} and then
         * {@link #estimate(double[], int[], double, int, KthSelector) estimate}
         * functions to return the estimated percentile value.
         *
         * @param work array of numbers to be used for finding the percentile
         * @param pivotsHeap a prior cached heap which can speed up estimation
         * @param p the p<sup>th</sup> quantile to be computed
         * @param selector a {@link KthSelector} used for pivoting during search
         * @return estimated percentile
         * @throws OutOfRangeException if p is out of range
         * @throws NullArgumentException if work array is null
         */
        protected double evaluate(final double[] work, final int[] pivotsHeap, final double p,
                                  final KthSelector selector) {
            MathUtils.checkNotNull(work);
            if (p > 100 || p <= 0) {
                throw new OutOfRangeException(LocalizedFormats.OUT_OF_BOUNDS_QUANTILE_VALUE,
                                              p, 0, 100);
            }
            // p is a percentage; index() expects a quantile in (0, 1].
            return estimate(work, pivotsHeap, index(p/100d, work.length), work.length, selector);
        }

        /**
         * Evaluate method to compute the percentile for a given bounded array.
         * This basically calls the {@link #index(double, int) index} and then
         * {@link #estimate(double[], int[], double, int, KthSelector) estimate}
         * functions to return the estimated percentile value. Please
         * note that this method does not make use of cached pivots.
+ * + * @param work array of numbers to be used for finding the percentile + * @param p the p<sup>th</sup> quantile to be computed + * @return estimated percentile + * @param selector a {@link KthSelector} used for pivoting during search + * @throws OutOfRangeException if length or p is out of range + * @throws NullArgumentException if work array is null + */ + public double evaluate(final double[] work, final double p, final KthSelector selector) { + return this.evaluate(work, null, p, selector); + } + + /** + * Gets the name of the enum + * + * @return the name + */ + String getName() { + return name; + } + } +} diff --git a/src/main/java/org/apache/commons/math3/stat/descriptive/rank/package-info.java b/src/main/java/org/apache/commons/math3/stat/descriptive/rank/package-info.java new file mode 100644 index 0000000..da37b37 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/stat/descriptive/rank/package-info.java @@ -0,0 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/** + * Summary statistics based on ranks. 
+ */ +package org.apache.commons.math3.stat.descriptive.rank; diff --git a/src/main/java/org/apache/commons/math3/stat/descriptive/summary/Product.java b/src/main/java/org/apache/commons/math3/stat/descriptive/summary/Product.java new file mode 100644 index 0000000..7d313a5 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/stat/descriptive/summary/Product.java @@ -0,0 +1,230 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.math3.stat.descriptive.summary; + +import java.io.Serializable; + +import org.apache.commons.math3.exception.MathIllegalArgumentException; +import org.apache.commons.math3.exception.NullArgumentException; +import org.apache.commons.math3.stat.descriptive.AbstractStorelessUnivariateStatistic; +import org.apache.commons.math3.stat.descriptive.WeightedEvaluation; +import org.apache.commons.math3.util.FastMath; +import org.apache.commons.math3.util.MathUtils; + +/** + * Returns the product of the available values. + * <p> + * If there are no values in the dataset, then 1 is returned. 
+ * If any of the values are + * <code>NaN</code>, then <code>NaN</code> is returned.</p> + * <p> + * <strong>Note that this implementation is not synchronized.</strong> If + * multiple threads access an instance of this class concurrently, and at least + * one of the threads invokes the <code>increment()</code> or + * <code>clear()</code> method, it must be synchronized externally.</p> + * + */ +public class Product extends AbstractStorelessUnivariateStatistic implements Serializable, WeightedEvaluation { + + /** Serializable version identifier */ + private static final long serialVersionUID = 2824226005990582538L; + + /**The number of values that have been added */ + private long n; + + /** + * The current Running Product. + */ + private double value; + + /** + * Create a Product instance + */ + public Product() { + n = 0; + value = 1; + } + + /** + * Copy constructor, creates a new {@code Product} identical + * to the {@code original} + * + * @param original the {@code Product} instance to copy + * @throws NullArgumentException if original is null + */ + public Product(Product original) throws NullArgumentException { + copy(original, this); + } + + /** + * {@inheritDoc} + */ + @Override + public void increment(final double d) { + value *= d; + n++; + } + + /** + * {@inheritDoc} + */ + @Override + public double getResult() { + return value; + } + + /** + * {@inheritDoc} + */ + public long getN() { + return n; + } + + /** + * {@inheritDoc} + */ + @Override + public void clear() { + value = 1; + n = 0; + } + + /** + * Returns the product of the entries in the specified portion of + * the input array, or <code>Double.NaN</code> if the designated subarray + * is empty. 
+ * <p> + * Throws <code>MathIllegalArgumentException</code> if the array is null.</p> + * + * @param values the input array + * @param begin index of the first array element to include + * @param length the number of elements to include + * @return the product of the values or 1 if length = 0 + * @throws MathIllegalArgumentException if the array is null or the array index + * parameters are not valid + */ + @Override + public double evaluate(final double[] values, final int begin, final int length) + throws MathIllegalArgumentException { + double product = Double.NaN; + if (test(values, begin, length, true)) { + product = 1.0; + for (int i = begin; i < begin + length; i++) { + product *= values[i]; + } + } + return product; + } + + /** + * <p>Returns the weighted product of the entries in the specified portion of + * the input array, or <code>Double.NaN</code> if the designated subarray + * is empty.</p> + * + * <p>Throws <code>MathIllegalArgumentException</code> if any of the following are true: + * <ul><li>the values array is null</li> + * <li>the weights array is null</li> + * <li>the weights array does not have the same length as the values array</li> + * <li>the weights array contains one or more infinite values</li> + * <li>the weights array contains one or more NaN values</li> + * <li>the weights array contains negative values</li> + * <li>the start and length arguments do not determine a valid array</li> + * </ul></p> + * + * <p>Uses the formula, <pre> + * weighted product = ∏values[i]<sup>weights[i]</sup> + * </pre> + * that is, the weights are applied as exponents when computing the weighted product.</p> + * + * @param values the input array + * @param weights the weights array + * @param begin index of the first array element to include + * @param length the number of elements to include + * @return the product of the values or 1 if length = 0 + * @throws MathIllegalArgumentException if the parameters are not valid + * @since 2.1 + */ + public double 
evaluate(final double[] values, final double[] weights, + final int begin, final int length) throws MathIllegalArgumentException { + double product = Double.NaN; + if (test(values, weights, begin, length, true)) { + product = 1.0; + for (int i = begin; i < begin + length; i++) { + product *= FastMath.pow(values[i], weights[i]); + } + } + return product; + } + + /** + * <p>Returns the weighted product of the entries in the input array.</p> + * + * <p>Throws <code>MathIllegalArgumentException</code> if any of the following are true: + * <ul><li>the values array is null</li> + * <li>the weights array is null</li> + * <li>the weights array does not have the same length as the values array</li> + * <li>the weights array contains one or more infinite values</li> + * <li>the weights array contains one or more NaN values</li> + * <li>the weights array contains negative values</li> + * </ul></p> + * + * <p>Uses the formula, <pre> + * weighted product = ∏values[i]<sup>weights[i]</sup> + * </pre> + * that is, the weights are applied as exponents when computing the weighted product.</p> + * + * @param values the input array + * @param weights the weights array + * @return the product of the values or Double.NaN if length = 0 + * @throws MathIllegalArgumentException if the parameters are not valid + * @since 2.1 + */ + public double evaluate(final double[] values, final double[] weights) + throws MathIllegalArgumentException { + return evaluate(values, weights, 0, values.length); + } + + + /** + * {@inheritDoc} + */ + @Override + public Product copy() { + Product result = new Product(); + // No try-catch or advertised exception because args are valid + copy(this, result); + return result; + } + + /** + * Copies source to dest. 
+ * <p>Neither source nor dest can be null.</p> + * + * @param source Product to copy + * @param dest Product to copy to + * @throws NullArgumentException if either source or dest is null + */ + public static void copy(Product source, Product dest) + throws NullArgumentException { + MathUtils.checkNotNull(source); + MathUtils.checkNotNull(dest); + dest.setData(source.getDataRef()); + dest.n = source.n; + dest.value = source.value; + } + +} diff --git a/src/main/java/org/apache/commons/math3/stat/descriptive/summary/Sum.java b/src/main/java/org/apache/commons/math3/stat/descriptive/summary/Sum.java new file mode 100644 index 0000000..e12b6a1 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/stat/descriptive/summary/Sum.java @@ -0,0 +1,226 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.math3.stat.descriptive.summary; + +import java.io.Serializable; + +import org.apache.commons.math3.exception.MathIllegalArgumentException; +import org.apache.commons.math3.exception.NullArgumentException; +import org.apache.commons.math3.stat.descriptive.AbstractStorelessUnivariateStatistic; +import org.apache.commons.math3.util.MathUtils; + + +/** + * Returns the sum of the available values. 
+ * <p> + * If there are no values in the dataset, then 0 is returned. + * If any of the values are + * <code>NaN</code>, then <code>NaN</code> is returned.</p> + * <p> + * <strong>Note that this implementation is not synchronized.</strong> If + * multiple threads access an instance of this class concurrently, and at least + * one of the threads invokes the <code>increment()</code> or + * <code>clear()</code> method, it must be synchronized externally.</p> + * + */ +public class Sum extends AbstractStorelessUnivariateStatistic implements Serializable { + + /** Serializable version identifier */ + private static final long serialVersionUID = -8231831954703408316L; + + /** */ + private long n; + + /** + * The currently running sum. + */ + private double value; + + /** + * Create a Sum instance + */ + public Sum() { + n = 0; + value = 0; + } + + /** + * Copy constructor, creates a new {@code Sum} identical + * to the {@code original} + * + * @param original the {@code Sum} instance to copy + * @throws NullArgumentException if original is null + */ + public Sum(Sum original) throws NullArgumentException { + copy(original, this); + } + + /** + * {@inheritDoc} + */ + @Override + public void increment(final double d) { + value += d; + n++; + } + + /** + * {@inheritDoc} + */ + @Override + public double getResult() { + return value; + } + + /** + * {@inheritDoc} + */ + public long getN() { + return n; + } + + /** + * {@inheritDoc} + */ + @Override + public void clear() { + value = 0; + n = 0; + } + + /** + * The sum of the entries in the specified portion of + * the input array, or 0 if the designated subarray + * is empty. 
+ * <p> + * Throws <code>MathIllegalArgumentException</code> if the array is null.</p> + * + * @param values the input array + * @param begin index of the first array element to include + * @param length the number of elements to include + * @return the sum of the values or 0 if length = 0 + * @throws MathIllegalArgumentException if the array is null or the array index + * parameters are not valid + */ + @Override + public double evaluate(final double[] values, final int begin, final int length) + throws MathIllegalArgumentException { + double sum = Double.NaN; + if (test(values, begin, length, true)) { + sum = 0.0; + for (int i = begin; i < begin + length; i++) { + sum += values[i]; + } + } + return sum; + } + + /** + * The weighted sum of the entries in the specified portion of + * the input array, or 0 if the designated subarray + * is empty. + * <p> + * Throws <code>MathIllegalArgumentException</code> if any of the following are true: + * <ul><li>the values array is null</li> + * <li>the weights array is null</li> + * <li>the weights array does not have the same length as the values array</li> + * <li>the weights array contains one or more infinite values</li> + * <li>the weights array contains one or more NaN values</li> + * <li>the weights array contains negative values</li> + * <li>the start and length arguments do not determine a valid array</li> + * </ul></p> + * <p> + * Uses the formula, <pre> + * weighted sum = Σ(values[i] * weights[i]) + * </pre></p> + * + * @param values the input array + * @param weights the weights array + * @param begin index of the first array element to include + * @param length the number of elements to include + * @return the sum of the values or 0 if length = 0 + * @throws MathIllegalArgumentException if the parameters are not valid + * @since 2.1 + */ + public double evaluate(final double[] values, final double[] weights, + final int begin, final int length) throws MathIllegalArgumentException { + double sum = Double.NaN; + if 
(test(values, weights, begin, length, true)) { + sum = 0.0; + for (int i = begin; i < begin + length; i++) { + sum += values[i] * weights[i]; + } + } + return sum; + } + + /** + * The weighted sum of the entries in the the input array. + * <p> + * Throws <code>MathIllegalArgumentException</code> if any of the following are true: + * <ul><li>the values array is null</li> + * <li>the weights array is null</li> + * <li>the weights array does not have the same length as the values array</li> + * <li>the weights array contains one or more infinite values</li> + * <li>the weights array contains one or more NaN values</li> + * <li>the weights array contains negative values</li> + * </ul></p> + * <p> + * Uses the formula, <pre> + * weighted sum = Σ(values[i] * weights[i]) + * </pre></p> + * + * @param values the input array + * @param weights the weights array + * @return the sum of the values or Double.NaN if length = 0 + * @throws MathIllegalArgumentException if the parameters are not valid + * @since 2.1 + */ + public double evaluate(final double[] values, final double[] weights) + throws MathIllegalArgumentException { + return evaluate(values, weights, 0, values.length); + } + + /** + * {@inheritDoc} + */ + @Override + public Sum copy() { + Sum result = new Sum(); + // No try-catch or advertised exception because args are valid + copy(this, result); + return result; + } + + /** + * Copies source to dest. 
+ * <p>Neither source nor dest can be null.</p> + * + * @param source Sum to copy + * @param dest Sum to copy to + * @throws NullArgumentException if either source or dest is null + */ + public static void copy(Sum source, Sum dest) + throws NullArgumentException { + MathUtils.checkNotNull(source); + MathUtils.checkNotNull(dest); + dest.setData(source.getDataRef()); + dest.n = source.n; + dest.value = source.value; + } + +} diff --git a/src/main/java/org/apache/commons/math3/stat/descriptive/summary/SumOfLogs.java b/src/main/java/org/apache/commons/math3/stat/descriptive/summary/SumOfLogs.java new file mode 100644 index 0000000..19718af --- /dev/null +++ b/src/main/java/org/apache/commons/math3/stat/descriptive/summary/SumOfLogs.java @@ -0,0 +1,170 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.commons.math3.stat.descriptive.summary; + +import java.io.Serializable; + +import org.apache.commons.math3.exception.MathIllegalArgumentException; +import org.apache.commons.math3.exception.NullArgumentException; +import org.apache.commons.math3.stat.descriptive.AbstractStorelessUnivariateStatistic; +import org.apache.commons.math3.util.FastMath; +import org.apache.commons.math3.util.MathUtils; + +/** + * Returns the sum of the natural logs for this collection of values. + * <p> + * Uses {@link org.apache.commons.math3.util.FastMath#log(double)} to compute the logs. + * Therefore, + * <ul> + * <li>If any of values are < 0, the result is <code>NaN.</code></li> + * <li>If all values are non-negative and less than + * <code>Double.POSITIVE_INFINITY</code>, but at least one value is 0, the + * result is <code>Double.NEGATIVE_INFINITY.</code></li> + * <li>If both <code>Double.POSITIVE_INFINITY</code> and + * <code>Double.NEGATIVE_INFINITY</code> are among the values, the result is + * <code>NaN.</code></li> + * </ul></p> + * <p> + * <strong>Note that this implementation is not synchronized.</strong> If + * multiple threads access an instance of this class concurrently, and at least + * one of the threads invokes the <code>increment()</code> or + * <code>clear()</code> method, it must be synchronized externally.</p> + * + */ +public class SumOfLogs extends AbstractStorelessUnivariateStatistic implements Serializable { + + /** Serializable version identifier */ + private static final long serialVersionUID = -370076995648386763L; + + /**Number of values that have been added */ + private int n; + + /** + * The currently running value + */ + private double value; + + /** + * Create a SumOfLogs instance + */ + public SumOfLogs() { + value = 0d; + n = 0; + } + + /** + * Copy constructor, creates a new {@code SumOfLogs} identical + * to the {@code original} + * + * @param original the {@code SumOfLogs} instance to copy + * @throws NullArgumentException 
if original is null + */ + public SumOfLogs(SumOfLogs original) throws NullArgumentException { + copy(original, this); + } + + /** + * {@inheritDoc} + */ + @Override + public void increment(final double d) { + value += FastMath.log(d); + n++; + } + + /** + * {@inheritDoc} + */ + @Override + public double getResult() { + return value; + } + + /** + * {@inheritDoc} + */ + public long getN() { + return n; + } + + /** + * {@inheritDoc} + */ + @Override + public void clear() { + value = 0d; + n = 0; + } + + /** + * Returns the sum of the natural logs of the entries in the specified portion of + * the input array, or <code>Double.NaN</code> if the designated subarray + * is empty. + * <p> + * Throws <code>MathIllegalArgumentException</code> if the array is null.</p> + * <p> + * See {@link SumOfLogs}.</p> + * + * @param values the input array + * @param begin index of the first array element to include + * @param length the number of elements to include + * @return the sum of the natural logs of the values or 0 if + * length = 0 + * @throws MathIllegalArgumentException if the array is null or the array index + * parameters are not valid + */ + @Override + public double evaluate(final double[] values, final int begin, final int length) + throws MathIllegalArgumentException { + double sumLog = Double.NaN; + if (test(values, begin, length, true)) { + sumLog = 0.0; + for (int i = begin; i < begin + length; i++) { + sumLog += FastMath.log(values[i]); + } + } + return sumLog; + } + + /** + * {@inheritDoc} + */ + @Override + public SumOfLogs copy() { + SumOfLogs result = new SumOfLogs(); + // No try-catch or advertised exception here because args are valid + copy(this, result); + return result; + } + + /** + * Copies source to dest. 
+ * <p>Neither source nor dest can be null.</p> + * + * @param source SumOfLogs to copy + * @param dest SumOfLogs to copy to + * @throws NullArgumentException if either source or dest is null + */ + public static void copy(SumOfLogs source, SumOfLogs dest) + throws NullArgumentException { + MathUtils.checkNotNull(source); + MathUtils.checkNotNull(dest); + dest.setData(source.getDataRef()); + dest.n = source.n; + dest.value = source.value; + } +} diff --git a/src/main/java/org/apache/commons/math3/stat/descriptive/summary/SumOfSquares.java b/src/main/java/org/apache/commons/math3/stat/descriptive/summary/SumOfSquares.java new file mode 100644 index 0000000..161d8c8 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/stat/descriptive/summary/SumOfSquares.java @@ -0,0 +1,159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.commons.math3.stat.descriptive.summary; + +import java.io.Serializable; + +import org.apache.commons.math3.exception.MathIllegalArgumentException; +import org.apache.commons.math3.exception.NullArgumentException; +import org.apache.commons.math3.stat.descriptive.AbstractStorelessUnivariateStatistic; +import org.apache.commons.math3.util.MathUtils; + +/** + * Returns the sum of the squares of the available values. + * <p> + * If there are no values in the dataset, then 0 is returned. + * If any of the values are + * <code>NaN</code>, then <code>NaN</code> is returned.</p> + * <p> + * <strong>Note that this implementation is not synchronized.</strong> If + * multiple threads access an instance of this class concurrently, and at least + * one of the threads invokes the <code>increment()</code> or + * <code>clear()</code> method, it must be synchronized externally.</p> + * + */ +public class SumOfSquares extends AbstractStorelessUnivariateStatistic implements Serializable { + + /** Serializable version identifier */ + private static final long serialVersionUID = 1460986908574398008L; + + /** */ + private long n; + + /** + * The currently running sumSq + */ + private double value; + + /** + * Create a SumOfSquares instance + */ + public SumOfSquares() { + n = 0; + value = 0; + } + + /** + * Copy constructor, creates a new {@code SumOfSquares} identical + * to the {@code original} + * + * @param original the {@code SumOfSquares} instance to copy + * @throws NullArgumentException if original is null + */ + public SumOfSquares(SumOfSquares original) throws NullArgumentException { + copy(original, this); + } + + /** + * {@inheritDoc} + */ + @Override + public void increment(final double d) { + value += d * d; + n++; + } + + /** + * {@inheritDoc} + */ + @Override + public double getResult() { + return value; + } + + /** + * {@inheritDoc} + */ + public long getN() { + return n; + } + + /** + * {@inheritDoc} + */ + @Override + public void clear() { + 
value = 0; + n = 0; + } + + /** + * Returns the sum of the squares of the entries in the specified portion of + * the input array, or <code>Double.NaN</code> if the designated subarray + * is empty. + * <p> + * Throws <code>MathIllegalArgumentException</code> if the array is null.</p> + * + * @param values the input array + * @param begin index of the first array element to include + * @param length the number of elements to include + * @return the sum of the squares of the values or 0 if length = 0 + * @throws MathIllegalArgumentException if the array is null or the array index + * parameters are not valid + */ + @Override + public double evaluate(final double[] values,final int begin, final int length) + throws MathIllegalArgumentException { + double sumSq = Double.NaN; + if (test(values, begin, length, true)) { + sumSq = 0.0; + for (int i = begin; i < begin + length; i++) { + sumSq += values[i] * values[i]; + } + } + return sumSq; + } + + /** + * {@inheritDoc} + */ + @Override + public SumOfSquares copy() { + SumOfSquares result = new SumOfSquares(); + // no try-catch or advertised exception here because args are valid + copy(this, result); + return result; + } + + /** + * Copies source to dest. 
+ * <p>Neither source nor dest can be null.</p> + * + * @param source SumOfSquares to copy + * @param dest SumOfSquares to copy to + * @throws NullArgumentException if either source or dest is null + */ + public static void copy(SumOfSquares source, SumOfSquares dest) + throws NullArgumentException { + MathUtils.checkNotNull(source); + MathUtils.checkNotNull(dest); + dest.setData(source.getDataRef()); + dest.n = source.n; + dest.value = source.value; + } + +} diff --git a/src/main/java/org/apache/commons/math3/stat/descriptive/summary/package-info.java b/src/main/java/org/apache/commons/math3/stat/descriptive/summary/package-info.java new file mode 100644 index 0000000..2f07145 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/stat/descriptive/summary/package-info.java @@ -0,0 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/** + * Other summary statistics. 
+ */ +package org.apache.commons.math3.stat.descriptive.summary; diff --git a/src/main/java/org/apache/commons/math3/stat/inference/AlternativeHypothesis.java b/src/main/java/org/apache/commons/math3/stat/inference/AlternativeHypothesis.java new file mode 100644 index 0000000..527067e --- /dev/null +++ b/src/main/java/org/apache/commons/math3/stat/inference/AlternativeHypothesis.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.math3.stat.inference; + +/** + * Represents an alternative hypothesis for a hypothesis test. + * + * @since 3.3 + */ +public enum AlternativeHypothesis { + + /** + * Represents a two-sided test. H0: p=p0, H1: p ≠ p0 + */ + TWO_SIDED, + + /** + * Represents a right-sided test. H0: p ≤ p0, H1: p > p0. + */ + GREATER_THAN, + + /** + * Represents a left-sided test. H0: p ≥ p0, H1: p < p0. 
+ */ + LESS_THAN +} diff --git a/src/main/java/org/apache/commons/math3/stat/inference/BinomialTest.java b/src/main/java/org/apache/commons/math3/stat/inference/BinomialTest.java new file mode 100644 index 0000000..2efe091 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/stat/inference/BinomialTest.java @@ -0,0 +1,160 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.math3.stat.inference; + +import org.apache.commons.math3.distribution.BinomialDistribution; +import org.apache.commons.math3.exception.MathIllegalArgumentException; +import org.apache.commons.math3.exception.MathInternalError; +import org.apache.commons.math3.exception.NotPositiveException; +import org.apache.commons.math3.exception.NullArgumentException; +import org.apache.commons.math3.exception.OutOfRangeException; +import org.apache.commons.math3.exception.util.LocalizedFormats; + +/** + * Implements binomial test statistics. + * <p> + * Exact test for the statistical significance of deviations from a + * theoretically expected distribution of observations into two categories. 
+ * + * @see <a href="http://en.wikipedia.org/wiki/Binomial_test">Binomial test (Wikipedia)</a> + * @since 3.3 + */ +public class BinomialTest { + + /** + * Returns whether the null hypothesis can be rejected with the given confidence level. + * <p> + * <strong>Preconditions</strong>: + * <ul> + * <li>Number of trials must be ≥ 0.</li> + * <li>Number of successes must be ≥ 0.</li> + * <li>Number of successes must be ≤ number of trials.</li> + * <li>Probability must be ≥ 0 and ≤ 1.</li> + * </ul> + * + * @param numberOfTrials number of trials performed + * @param numberOfSuccesses number of successes observed + * @param probability assumed probability of a single trial under the null hypothesis + * @param alternativeHypothesis type of hypothesis being evaluated (one- or two-sided) + * @param alpha significance level of the test + * @return true if the null hypothesis can be rejected with confidence {@code 1 - alpha} + * @throws NotPositiveException if {@code numberOfTrials} or {@code numberOfSuccesses} is negative + * @throws OutOfRangeException if {@code probability} is not between 0 and 1 + * @throws MathIllegalArgumentException if {@code numberOfTrials} < {@code numberOfSuccesses} or + * if {@code alternateHypothesis} is null. + * @see AlternativeHypothesis + */ + public boolean binomialTest(int numberOfTrials, int numberOfSuccesses, double probability, + AlternativeHypothesis alternativeHypothesis, double alpha) { + double pValue = binomialTest(numberOfTrials, numberOfSuccesses, probability, alternativeHypothesis); + return pValue < alpha; + } + + /** + * Returns the <i>observed significance level</i>, or + * <a href="http://www.cas.lancs.ac.uk/glossary_v1.1/hyptest.html#pvalue">p-value</a>, + * associated with a <a href="http://en.wikipedia.org/wiki/Binomial_test"> Binomial test</a>. + * <p> + * The number returned is the smallest significance level at which one can reject the null hypothesis. 
+ * The form of the hypothesis depends on {@code alternativeHypothesis}.</p> + * <p> + * The p-Value represents the likelihood of getting a result at least as extreme as the sample, + * given the provided {@code probability} of success on a single trial. For single-sided tests, + * this value can be directly derived from the Binomial distribution. For the two-sided test, + * the implementation works as follows: we start by looking at the most extreme cases + * (0 success and n success where n is the number of trials from the sample) and determine their likelihood. + * The lower value is added to the p-Value (if both values are equal, both are added). Then we continue with + * the next extreme value, until we added the value for the actual observed sample.</p> + * <p> + * <strong>Preconditions</strong>: + * <ul> + * <li>Number of trials must be ≥ 0.</li> + * <li>Number of successes must be ≥ 0.</li> + * <li>Number of successes must be ≤ number of trials.</li> + * <li>Probability must be ≥ 0 and ≤ 1.</li> + * </ul></p> + * + * @param numberOfTrials number of trials performed + * @param numberOfSuccesses number of successes observed + * @param probability assumed probability of a single trial under the null hypothesis + * @param alternativeHypothesis type of hypothesis being evaluated (one- or two-sided) + * @return p-value + * @throws NotPositiveException if {@code numberOfTrials} or {@code numberOfSuccesses} is negative + * @throws OutOfRangeException if {@code probability} is not between 0 and 1 + * @throws MathIllegalArgumentException if {@code numberOfTrials} < {@code numberOfSuccesses} or + * if {@code alternateHypothesis} is null. 
+ * @see AlternativeHypothesis + */ + public double binomialTest(int numberOfTrials, int numberOfSuccesses, double probability, + AlternativeHypothesis alternativeHypothesis) { + if (numberOfTrials < 0) { + throw new NotPositiveException(numberOfTrials); + } + if (numberOfSuccesses < 0) { + throw new NotPositiveException(numberOfSuccesses); + } + if (probability < 0 || probability > 1) { + throw new OutOfRangeException(probability, 0, 1); + } + if (numberOfTrials < numberOfSuccesses) { + throw new MathIllegalArgumentException( + LocalizedFormats.BINOMIAL_INVALID_PARAMETERS_ORDER, + numberOfTrials, numberOfSuccesses); + } + if (alternativeHypothesis == null) { + throw new NullArgumentException(); + } + + // pass a null rng to avoid unneeded overhead as we will not sample from this distribution + final BinomialDistribution distribution = new BinomialDistribution(null, numberOfTrials, probability); + switch (alternativeHypothesis) { + case GREATER_THAN: + return 1 - distribution.cumulativeProbability(numberOfSuccesses - 1); + case LESS_THAN: + return distribution.cumulativeProbability(numberOfSuccesses); + case TWO_SIDED: + int criticalValueLow = 0; + int criticalValueHigh = numberOfTrials; + double pTotal = 0; + + while (true) { + double pLow = distribution.probability(criticalValueLow); + double pHigh = distribution.probability(criticalValueHigh); + + if (pLow == pHigh) { + pTotal += 2 * pLow; + criticalValueLow++; + criticalValueHigh--; + } else if (pLow < pHigh) { + pTotal += pLow; + criticalValueLow++; + } else { + pTotal += pHigh; + criticalValueHigh--; + } + + if (criticalValueLow > numberOfSuccesses || criticalValueHigh < numberOfSuccesses) { + break; + } + } + return pTotal; + default: + throw new MathInternalError(LocalizedFormats. 
OUT_OF_RANGE_SIMPLE, alternativeHypothesis, + AlternativeHypothesis.TWO_SIDED, AlternativeHypothesis.LESS_THAN); + } + } +} diff --git a/src/main/java/org/apache/commons/math3/stat/inference/ChiSquareTest.java b/src/main/java/org/apache/commons/math3/stat/inference/ChiSquareTest.java new file mode 100644 index 0000000..7e97ac1 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/stat/inference/ChiSquareTest.java @@ -0,0 +1,602 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.commons.math3.stat.inference; + +import org.apache.commons.math3.distribution.ChiSquaredDistribution; +import org.apache.commons.math3.exception.DimensionMismatchException; +import org.apache.commons.math3.exception.MaxCountExceededException; +import org.apache.commons.math3.exception.NotPositiveException; +import org.apache.commons.math3.exception.NotStrictlyPositiveException; +import org.apache.commons.math3.exception.NullArgumentException; +import org.apache.commons.math3.exception.OutOfRangeException; +import org.apache.commons.math3.exception.ZeroException; +import org.apache.commons.math3.exception.util.LocalizedFormats; +import org.apache.commons.math3.util.FastMath; +import org.apache.commons.math3.util.MathArrays; + +/** + * Implements Chi-Square test statistics. + * + * <p>This implementation handles both known and unknown distributions.</p> + * + * <p>Two samples tests can be used when the distribution is unknown <i>a priori</i> + * but provided by one sample, or when the hypothesis under test is that the two + * samples come from the same underlying distribution.</p> + * + */ +public class ChiSquareTest { + + /** + * Construct a ChiSquareTest + */ + public ChiSquareTest() { + super(); + } + + /** + * Computes the <a href="http://www.itl.nist.gov/div898/handbook/eda/section3/eda35f.htm"> + * Chi-Square statistic</a> comparing <code>observed</code> and <code>expected</code> + * frequency counts. + * <p> + * This statistic can be used to perform a Chi-Square test evaluating the null + * hypothesis that the observed counts follow the expected distribution.</p> + * <p> + * <strong>Preconditions</strong>: <ul> + * <li>Expected counts must all be positive. + * </li> + * <li>Observed counts must all be ≥ 0. + * </li> + * <li>The observed and expected arrays must have the same length and + * their common length must be at least 2. 
+ * </li></ul></p><p> + * If any of the preconditions are not met, an + * <code>IllegalArgumentException</code> is thrown.</p> + * <p><strong>Note: </strong>This implementation rescales the + * <code>expected</code> array if necessary to ensure that the sum of the + * expected and observed counts are equal.</p> + * + * @param observed array of observed frequency counts + * @param expected array of expected frequency counts + * @return chiSquare test statistic + * @throws NotPositiveException if <code>observed</code> has negative entries + * @throws NotStrictlyPositiveException if <code>expected</code> has entries that are + * not strictly positive + * @throws DimensionMismatchException if the arrays length is less than 2 + */ + public double chiSquare(final double[] expected, final long[] observed) + throws NotPositiveException, NotStrictlyPositiveException, + DimensionMismatchException { + + if (expected.length < 2) { + throw new DimensionMismatchException(expected.length, 2); + } + if (expected.length != observed.length) { + throw new DimensionMismatchException(expected.length, observed.length); + } + MathArrays.checkPositive(expected); + MathArrays.checkNonNegative(observed); + + double sumExpected = 0d; + double sumObserved = 0d; + for (int i = 0; i < observed.length; i++) { + sumExpected += expected[i]; + sumObserved += observed[i]; + } + double ratio = 1.0d; + boolean rescale = false; + if (FastMath.abs(sumExpected - sumObserved) > 10E-6) { + ratio = sumObserved / sumExpected; + rescale = true; + } + double sumSq = 0.0d; + for (int i = 0; i < observed.length; i++) { + if (rescale) { + final double dev = observed[i] - ratio * expected[i]; + sumSq += dev * dev / (ratio * expected[i]); + } else { + final double dev = observed[i] - expected[i]; + sumSq += dev * dev / expected[i]; + } + } + return sumSq; + + } + + /** + * Returns the <i>observed significance level</i>, or <a href= + * "http://www.cas.lancs.ac.uk/glossary_v1.1/hyptest.html#pvalue"> + * p-value</a>, 
associated with a + * <a href="http://www.itl.nist.gov/div898/handbook/eda/section3/eda35f.htm"> + * Chi-square goodness of fit test</a> comparing the <code>observed</code> + * frequency counts to those in the <code>expected</code> array. + * <p> + * The number returned is the smallest significance level at which one can reject + * the null hypothesis that the observed counts conform to the frequency distribution + * described by the expected counts.</p> + * <p> + * <strong>Preconditions</strong>: <ul> + * <li>Expected counts must all be positive. + * </li> + * <li>Observed counts must all be ≥ 0. + * </li> + * <li>The observed and expected arrays must have the same length and + * their common length must be at least 2. + * </li></ul></p><p> + * If any of the preconditions are not met, an + * <code>IllegalArgumentException</code> is thrown.</p> + * <p><strong>Note: </strong>This implementation rescales the + * <code>expected</code> array if necessary to ensure that the sum of the + * expected and observed counts are equal.</p> + * + * @param observed array of observed frequency counts + * @param expected array of expected frequency counts + * @return p-value + * @throws NotPositiveException if <code>observed</code> has negative entries + * @throws NotStrictlyPositiveException if <code>expected</code> has entries that are + * not strictly positive + * @throws DimensionMismatchException if the arrays length is less than 2 + * @throws MaxCountExceededException if an error occurs computing the p-value + */ + public double chiSquareTest(final double[] expected, final long[] observed) + throws NotPositiveException, NotStrictlyPositiveException, + DimensionMismatchException, MaxCountExceededException { + + // pass a null rng to avoid unneeded overhead as we will not sample from this distribution + final ChiSquaredDistribution distribution = + new ChiSquaredDistribution(null, expected.length - 1.0); + return 1.0 - distribution.cumulativeProbability(chiSquare(expected, 
observed)); + } + + /** + * Performs a <a href="http://www.itl.nist.gov/div898/handbook/eda/section3/eda35f.htm"> + * Chi-square goodness of fit test</a> evaluating the null hypothesis that the + * observed counts conform to the frequency distribution described by the expected + * counts, with significance level <code>alpha</code>. Returns true iff the null + * hypothesis can be rejected with 100 * (1 - alpha) percent confidence. + * <p> + * <strong>Example:</strong><br> + * To test the hypothesis that <code>observed</code> follows + * <code>expected</code> at the 99% level, use </p><p> + * <code>chiSquareTest(expected, observed, 0.01) </code></p> + * <p> + * <strong>Preconditions</strong>: <ul> + * <li>Expected counts must all be positive. + * </li> + * <li>Observed counts must all be ≥ 0. + * </li> + * <li>The observed and expected arrays must have the same length and + * their common length must be at least 2. + * <li> <code> 0 < alpha < 0.5 </code> + * </li></ul></p><p> + * If any of the preconditions are not met, an + * <code>IllegalArgumentException</code> is thrown.</p> + * <p><strong>Note: </strong>This implementation rescales the + * <code>expected</code> array if necessary to ensure that the sum of the + * expected and observed counts are equal.</p> + * + * @param observed array of observed frequency counts + * @param expected array of expected frequency counts + * @param alpha significance level of the test + * @return true iff null hypothesis can be rejected with confidence + * 1 - alpha + * @throws NotPositiveException if <code>observed</code> has negative entries + * @throws NotStrictlyPositiveException if <code>expected</code> has entries that are + * not strictly positive + * @throws DimensionMismatchException if the arrays length is less than 2 + * @throws OutOfRangeException if <code>alpha</code> is not in the range (0, 0.5] + * @throws MaxCountExceededException if an error occurs computing the p-value + */ + public boolean chiSquareTest(final 
double[] expected, final long[] observed, + final double alpha) + throws NotPositiveException, NotStrictlyPositiveException, + DimensionMismatchException, OutOfRangeException, MaxCountExceededException { + + if ((alpha <= 0) || (alpha > 0.5)) { + throw new OutOfRangeException(LocalizedFormats.OUT_OF_BOUND_SIGNIFICANCE_LEVEL, + alpha, 0, 0.5); + } + return chiSquareTest(expected, observed) < alpha; + + } + + /** + * Computes the Chi-Square statistic associated with a + * <a href="http://www.itl.nist.gov/div898/handbook/prc/section4/prc45.htm"> + * chi-square test of independence</a> based on the input <code>counts</code> + * array, viewed as a two-way table. + * <p> + * The rows of the 2-way table are + * <code>count[0], ... , count[count.length - 1] </code></p> + * <p> + * <strong>Preconditions</strong>: <ul> + * <li>All counts must be ≥ 0. + * </li> + * <li>The count array must be rectangular (i.e. all count[i] subarrays + * must have the same length). + * </li> + * <li>The 2-way table represented by <code>counts</code> must have at + * least 2 columns and at least 2 rows. 
+ * </li> + * </li></ul></p><p> + * If any of the preconditions are not met, an + * <code>IllegalArgumentException</code> is thrown.</p> + * + * @param counts array representation of 2-way table + * @return chiSquare test statistic + * @throws NullArgumentException if the array is null + * @throws DimensionMismatchException if the array is not rectangular + * @throws NotPositiveException if {@code counts} has negative entries + */ + public double chiSquare(final long[][] counts) + throws NullArgumentException, NotPositiveException, + DimensionMismatchException { + + checkArray(counts); + int nRows = counts.length; + int nCols = counts[0].length; + + // compute row, column and total sums + double[] rowSum = new double[nRows]; + double[] colSum = new double[nCols]; + double total = 0.0d; + for (int row = 0; row < nRows; row++) { + for (int col = 0; col < nCols; col++) { + rowSum[row] += counts[row][col]; + colSum[col] += counts[row][col]; + total += counts[row][col]; + } + } + + // compute expected counts and chi-square + double sumSq = 0.0d; + double expected = 0.0d; + for (int row = 0; row < nRows; row++) { + for (int col = 0; col < nCols; col++) { + expected = (rowSum[row] * colSum[col]) / total; + sumSq += ((counts[row][col] - expected) * + (counts[row][col] - expected)) / expected; + } + } + return sumSq; + + } + + /** + * Returns the <i>observed significance level</i>, or <a href= + * "http://www.cas.lancs.ac.uk/glossary_v1.1/hyptest.html#pvalue"> + * p-value</a>, associated with a + * <a href="http://www.itl.nist.gov/div898/handbook/prc/section4/prc45.htm"> + * chi-square test of independence</a> based on the input <code>counts</code> + * array, viewed as a two-way table. + * <p> + * The rows of the 2-way table are + * <code>count[0], ... , count[count.length - 1] </code></p> + * <p> + * <strong>Preconditions</strong>: <ul> + * <li>All counts must be ≥ 0. + * </li> + * <li>The count array must be rectangular (i.e. 
all count[i] subarrays must have + * the same length). + * </li> + * <li>The 2-way table represented by <code>counts</code> must have at least 2 + * columns and at least 2 rows. + * </li> + * </li></ul></p><p> + * If any of the preconditions are not met, an + * <code>IllegalArgumentException</code> is thrown.</p> + * + * @param counts array representation of 2-way table + * @return p-value + * @throws NullArgumentException if the array is null + * @throws DimensionMismatchException if the array is not rectangular + * @throws NotPositiveException if {@code counts} has negative entries + * @throws MaxCountExceededException if an error occurs computing the p-value + */ + public double chiSquareTest(final long[][] counts) + throws NullArgumentException, DimensionMismatchException, + NotPositiveException, MaxCountExceededException { + + checkArray(counts); + double df = ((double) counts.length -1) * ((double) counts[0].length - 1); + // pass a null rng to avoid unneeded overhead as we will not sample from this distribution + final ChiSquaredDistribution distribution = new ChiSquaredDistribution(df); + return 1 - distribution.cumulativeProbability(chiSquare(counts)); + + } + + /** + * Performs a <a href="http://www.itl.nist.gov/div898/handbook/prc/section4/prc45.htm"> + * chi-square test of independence</a> evaluating the null hypothesis that the + * classifications represented by the counts in the columns of the input 2-way table + * are independent of the rows, with significance level <code>alpha</code>. + * Returns true iff the null hypothesis can be rejected with 100 * (1 - alpha) percent + * confidence. + * <p> + * The rows of the 2-way table are + * <code>count[0], ... , count[count.length - 1] </code></p> + * <p> + * <strong>Example:</strong><br> + * To test the null hypothesis that the counts in + * <code>count[0], ... 
, count[count.length - 1] </code> + * all correspond to the same underlying probability distribution at the 99% level, use</p> + * <p><code>chiSquareTest(counts, 0.01)</code></p> + * <p> + * <strong>Preconditions</strong>: <ul> + * <li>All counts must be ≥ 0. + * </li> + * <li>The count array must be rectangular (i.e. all count[i] subarrays must have the + * same length).</li> + * <li>The 2-way table represented by <code>counts</code> must have at least 2 columns and + * at least 2 rows.</li> + * </li></ul></p><p> + * If any of the preconditions are not met, an + * <code>IllegalArgumentException</code> is thrown.</p> + * + * @param counts array representation of 2-way table + * @param alpha significance level of the test + * @return true iff null hypothesis can be rejected with confidence + * 1 - alpha + * @throws NullArgumentException if the array is null + * @throws DimensionMismatchException if the array is not rectangular + * @throws NotPositiveException if {@code counts} has any negative entries + * @throws OutOfRangeException if <code>alpha</code> is not in the range (0, 0.5] + * @throws MaxCountExceededException if an error occurs computing the p-value + */ + public boolean chiSquareTest(final long[][] counts, final double alpha) + throws NullArgumentException, DimensionMismatchException, + NotPositiveException, OutOfRangeException, MaxCountExceededException { + + if ((alpha <= 0) || (alpha > 0.5)) { + throw new OutOfRangeException(LocalizedFormats.OUT_OF_BOUND_SIGNIFICANCE_LEVEL, + alpha, 0, 0.5); + } + return chiSquareTest(counts) < alpha; + + } + + /** + * <p>Computes a + * <a href="http://www.itl.nist.gov/div898/software/dataplot/refman1/auxillar/chi2samp.htm"> + * Chi-Square two sample test statistic</a> comparing bin frequency counts + * in <code>observed1</code> and <code>observed2</code>. The + * sums of frequency counts in the two samples are not required to be the + * same. 
The formula used to compute the test statistic is</p> + * <code> + * ∑[(K * observed1[i] - observed2[i]/K)<sup>2</sup> / (observed1[i] + observed2[i])] + * </code> where + * <br/><code>K = &sqrt;[&sum(observed2 / ∑(observed1)]</code> + * </p> + * <p>This statistic can be used to perform a Chi-Square test evaluating the + * null hypothesis that both observed counts follow the same distribution.</p> + * <p> + * <strong>Preconditions</strong>: <ul> + * <li>Observed counts must be non-negative. + * </li> + * <li>Observed counts for a specific bin must not both be zero. + * </li> + * <li>Observed counts for a specific sample must not all be 0. + * </li> + * <li>The arrays <code>observed1</code> and <code>observed2</code> must have + * the same length and their common length must be at least 2. + * </li></ul></p><p> + * If any of the preconditions are not met, an + * <code>IllegalArgumentException</code> is thrown.</p> + * + * @param observed1 array of observed frequency counts of the first data set + * @param observed2 array of observed frequency counts of the second data set + * @return chiSquare test statistic + * @throws DimensionMismatchException the the length of the arrays does not match + * @throws NotPositiveException if any entries in <code>observed1</code> or + * <code>observed2</code> are negative + * @throws ZeroException if either all counts of <code>observed1</code> or + * <code>observed2</code> are zero, or if the count at some index is zero + * for both arrays + * @since 1.2 + */ + public double chiSquareDataSetsComparison(long[] observed1, long[] observed2) + throws DimensionMismatchException, NotPositiveException, ZeroException { + + // Make sure lengths are same + if (observed1.length < 2) { + throw new DimensionMismatchException(observed1.length, 2); + } + if (observed1.length != observed2.length) { + throw new DimensionMismatchException(observed1.length, observed2.length); + } + + // Ensure non-negative counts + 
MathArrays.checkNonNegative(observed1); + MathArrays.checkNonNegative(observed2); + + // Compute and compare count sums + long countSum1 = 0; + long countSum2 = 0; + boolean unequalCounts = false; + double weight = 0.0; + for (int i = 0; i < observed1.length; i++) { + countSum1 += observed1[i]; + countSum2 += observed2[i]; + } + // Ensure neither sample is uniformly 0 + if (countSum1 == 0 || countSum2 == 0) { + throw new ZeroException(); + } + // Compare and compute weight only if different + unequalCounts = countSum1 != countSum2; + if (unequalCounts) { + weight = FastMath.sqrt((double) countSum1 / (double) countSum2); + } + // Compute ChiSquare statistic + double sumSq = 0.0d; + double dev = 0.0d; + double obs1 = 0.0d; + double obs2 = 0.0d; + for (int i = 0; i < observed1.length; i++) { + if (observed1[i] == 0 && observed2[i] == 0) { + throw new ZeroException(LocalizedFormats.OBSERVED_COUNTS_BOTTH_ZERO_FOR_ENTRY, i); + } else { + obs1 = observed1[i]; + obs2 = observed2[i]; + if (unequalCounts) { // apply weights + dev = obs1/weight - obs2 * weight; + } else { + dev = obs1 - obs2; + } + sumSq += (dev * dev) / (obs1 + obs2); + } + } + return sumSq; + } + + /** + * <p>Returns the <i>observed significance level</i>, or <a href= + * "http://www.cas.lancs.ac.uk/glossary_v1.1/hyptest.html#pvalue"> + * p-value</a>, associated with a Chi-Square two sample test comparing + * bin frequency counts in <code>observed1</code> and + * <code>observed2</code>. + * </p> + * <p>The number returned is the smallest significance level at which one + * can reject the null hypothesis that the observed counts conform to the + * same distribution. + * </p> + * <p>See {@link #chiSquareDataSetsComparison(long[], long[])} for details + * on the formula used to compute the test statistic. The degrees of + * of freedom used to perform the test is one less than the common length + * of the input observed count arrays. 
+ * </p> + * <strong>Preconditions</strong>: <ul> + * <li>Observed counts must be non-negative. + * </li> + * <li>Observed counts for a specific bin must not both be zero. + * </li> + * <li>Observed counts for a specific sample must not all be 0. + * </li> + * <li>The arrays <code>observed1</code> and <code>observed2</code> must + * have the same length and + * their common length must be at least 2. + * </li></ul><p> + * If any of the preconditions are not met, an + * <code>IllegalArgumentException</code> is thrown.</p> + * + * @param observed1 array of observed frequency counts of the first data set + * @param observed2 array of observed frequency counts of the second data set + * @return p-value + * @throws DimensionMismatchException the the length of the arrays does not match + * @throws NotPositiveException if any entries in <code>observed1</code> or + * <code>observed2</code> are negative + * @throws ZeroException if either all counts of <code>observed1</code> or + * <code>observed2</code> are zero, or if the count at the same index is zero + * for both arrays + * @throws MaxCountExceededException if an error occurs computing the p-value + * @since 1.2 + */ + public double chiSquareTestDataSetsComparison(long[] observed1, long[] observed2) + throws DimensionMismatchException, NotPositiveException, ZeroException, + MaxCountExceededException { + + // pass a null rng to avoid unneeded overhead as we will not sample from this distribution + final ChiSquaredDistribution distribution = + new ChiSquaredDistribution(null, (double) observed1.length - 1); + return 1 - distribution.cumulativeProbability( + chiSquareDataSetsComparison(observed1, observed2)); + + } + + /** + * <p>Performs a Chi-Square two sample test comparing two binned data + * sets. The test evaluates the null hypothesis that the two lists of + * observed counts conform to the same frequency distribution, with + * significance level <code>alpha</code>. 
Returns true iff the null + * hypothesis can be rejected with 100 * (1 - alpha) percent confidence. + * </p> + * <p>See {@link #chiSquareDataSetsComparison(long[], long[])} for + * details on the formula used to compute the Chisquare statistic used + * in the test. The degrees of of freedom used to perform the test is + * one less than the common length of the input observed count arrays. + * </p> + * <strong>Preconditions</strong>: <ul> + * <li>Observed counts must be non-negative. + * </li> + * <li>Observed counts for a specific bin must not both be zero. + * </li> + * <li>Observed counts for a specific sample must not all be 0. + * </li> + * <li>The arrays <code>observed1</code> and <code>observed2</code> must + * have the same length and their common length must be at least 2. + * </li> + * <li> <code> 0 < alpha < 0.5 </code> + * </li></ul><p> + * If any of the preconditions are not met, an + * <code>IllegalArgumentException</code> is thrown.</p> + * + * @param observed1 array of observed frequency counts of the first data set + * @param observed2 array of observed frequency counts of the second data set + * @param alpha significance level of the test + * @return true iff null hypothesis can be rejected with confidence + * 1 - alpha + * @throws DimensionMismatchException the the length of the arrays does not match + * @throws NotPositiveException if any entries in <code>observed1</code> or + * <code>observed2</code> are negative + * @throws ZeroException if either all counts of <code>observed1</code> or + * <code>observed2</code> are zero, or if the count at the same index is zero + * for both arrays + * @throws OutOfRangeException if <code>alpha</code> is not in the range (0, 0.5] + * @throws MaxCountExceededException if an error occurs performing the test + * @since 1.2 + */ + public boolean chiSquareTestDataSetsComparison(final long[] observed1, + final long[] observed2, + final double alpha) + throws DimensionMismatchException, NotPositiveException, + 
ZeroException, OutOfRangeException, MaxCountExceededException { + + if (alpha <= 0 || + alpha > 0.5) { + throw new OutOfRangeException(LocalizedFormats.OUT_OF_BOUND_SIGNIFICANCE_LEVEL, + alpha, 0, 0.5); + } + return chiSquareTestDataSetsComparison(observed1, observed2) < alpha; + + } + + /** + * Checks to make sure that the input long[][] array is rectangular, + * has at least 2 rows and 2 columns, and has all non-negative entries. + * + * @param in input 2-way table to check + * @throws NullArgumentException if the array is null + * @throws DimensionMismatchException if the array is not valid + * @throws NotPositiveException if the array contains any negative entries + */ + private void checkArray(final long[][] in) + throws NullArgumentException, DimensionMismatchException, + NotPositiveException { + + if (in.length < 2) { + throw new DimensionMismatchException(in.length, 2); + } + + if (in[0].length < 2) { + throw new DimensionMismatchException(in[0].length, 2); + } + + MathArrays.checkRectangular(in); + MathArrays.checkNonNegative(in); + + } + +} diff --git a/src/main/java/org/apache/commons/math3/stat/inference/GTest.java b/src/main/java/org/apache/commons/math3/stat/inference/GTest.java new file mode 100644 index 0000000..de1fbe3 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/stat/inference/GTest.java @@ -0,0 +1,538 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.math3.stat.inference; + +import org.apache.commons.math3.distribution.ChiSquaredDistribution; +import org.apache.commons.math3.exception.DimensionMismatchException; +import org.apache.commons.math3.exception.MaxCountExceededException; +import org.apache.commons.math3.exception.NotPositiveException; +import org.apache.commons.math3.exception.NotStrictlyPositiveException; +import org.apache.commons.math3.exception.OutOfRangeException; +import org.apache.commons.math3.exception.ZeroException; +import org.apache.commons.math3.exception.util.LocalizedFormats; +import org.apache.commons.math3.util.FastMath; +import org.apache.commons.math3.util.MathArrays; + +/** + * Implements <a href="http://en.wikipedia.org/wiki/G-test">G Test</a> + * statistics. + * + * <p>This is known in statistical genetics as the McDonald-Kreitman test. + * The implementation handles both known and unknown distributions.</p> + * + * <p>Two samples tests can be used when the distribution is unknown <i>a priori</i> + * but provided by one sample, or when the hypothesis under test is that the two + * samples come from the same underlying distribution.</p> + * + * @since 3.1 + */ +public class GTest { + + /** + * Computes the <a href="http://en.wikipedia.org/wiki/G-test">G statistic + * for Goodness of Fit</a> comparing {@code observed} and {@code expected} + * frequency counts. 
+ * + * <p>This statistic can be used to perform a G test (Log-Likelihood Ratio + * Test) evaluating the null hypothesis that the observed counts follow the + * expected distribution.</p> + * + * <p><strong>Preconditions</strong>: <ul> + * <li>Expected counts must all be positive. </li> + * <li>Observed counts must all be ≥ 0. </li> + * <li>The observed and expected arrays must have the same length and their + * common length must be at least 2. </li></ul></p> + * + * <p>If any of the preconditions are not met, a + * {@code MathIllegalArgumentException} is thrown.</p> + * + * <p><strong>Note:</strong>This implementation rescales the + * {@code expected} array if necessary to ensure that the sum of the + * expected and observed counts are equal.</p> + * + * @param observed array of observed frequency counts + * @param expected array of expected frequency counts + * @return G-Test statistic + * @throws NotPositiveException if {@code observed} has negative entries + * @throws NotStrictlyPositiveException if {@code expected} has entries that + * are not strictly positive + * @throws DimensionMismatchException if the array lengths do not match or + * are less than 2. 
+ */ + public double g(final double[] expected, final long[] observed) + throws NotPositiveException, NotStrictlyPositiveException, + DimensionMismatchException { + + if (expected.length < 2) { + throw new DimensionMismatchException(expected.length, 2); + } + if (expected.length != observed.length) { + throw new DimensionMismatchException(expected.length, observed.length); + } + MathArrays.checkPositive(expected); + MathArrays.checkNonNegative(observed); + + double sumExpected = 0d; + double sumObserved = 0d; + for (int i = 0; i < observed.length; i++) { + sumExpected += expected[i]; + sumObserved += observed[i]; + } + double ratio = 1d; + boolean rescale = false; + if (FastMath.abs(sumExpected - sumObserved) > 10E-6) { + ratio = sumObserved / sumExpected; + rescale = true; + } + double sum = 0d; + for (int i = 0; i < observed.length; i++) { + final double dev = rescale ? + FastMath.log((double) observed[i] / (ratio * expected[i])) : + FastMath.log((double) observed[i] / expected[i]); + sum += ((double) observed[i]) * dev; + } + return 2d * sum; + } + + /** + * Returns the <i>observed significance level</i>, or <a href= + * "http://www.cas.lancs.ac.uk/glossary_v1.1/hyptest.html#pvalue"> p-value</a>, + * associated with a G-Test for goodness of fit</a> comparing the + * {@code observed} frequency counts to those in the {@code expected} array. + * + * <p>The number returned is the smallest significance level at which one + * can reject the null hypothesis that the observed counts conform to the + * frequency distribution described by the expected counts.</p> + * + * <p>The probability returned is the tail probability beyond + * {@link #g(double[], long[]) g(expected, observed)} + * in the ChiSquare distribution with degrees of freedom one less than the + * common length of {@code expected} and {@code observed}.</p> + * + * <p> <strong>Preconditions</strong>: <ul> + * <li>Expected counts must all be positive. </li> + * <li>Observed counts must all be ≥ 0. 
</li> + * <li>The observed and expected arrays must have the + * same length and their common length must be at least 2.</li> + * </ul></p> + * + * <p>If any of the preconditions are not met, a + * {@code MathIllegalArgumentException} is thrown.</p> + * + * <p><strong>Note:</strong>This implementation rescales the + * {@code expected} array if necessary to ensure that the sum of the + * expected and observed counts are equal.</p> + * + * @param observed array of observed frequency counts + * @param expected array of expected frequency counts + * @return p-value + * @throws NotPositiveException if {@code observed} has negative entries + * @throws NotStrictlyPositiveException if {@code expected} has entries that + * are not strictly positive + * @throws DimensionMismatchException if the array lengths do not match or + * are less than 2. + * @throws MaxCountExceededException if an error occurs computing the + * p-value. + */ + public double gTest(final double[] expected, final long[] observed) + throws NotPositiveException, NotStrictlyPositiveException, + DimensionMismatchException, MaxCountExceededException { + + // pass a null rng to avoid unneeded overhead as we will not sample from this distribution + final ChiSquaredDistribution distribution = + new ChiSquaredDistribution(null, expected.length - 1.0); + return 1.0 - distribution.cumulativeProbability(g(expected, observed)); + } + + /** + * Returns the intrinsic (Hardy-Weinberg proportions) p-Value, as described + * in p64-69 of McDonald, J.H. 2009. Handbook of Biological Statistics + * (2nd ed.). Sparky House Publishing, Baltimore, Maryland. 
+ * + * <p> The probability returned is the tail probability beyond + * {@link #g(double[], long[]) g(expected, observed)} + * in the ChiSquare distribution with degrees of freedom two less than the + * common length of {@code expected} and {@code observed}.</p> + * + * @param observed array of observed frequency counts + * @param expected array of expected frequency counts + * @return p-value + * @throws NotPositiveException if {@code observed} has negative entries + * @throws NotStrictlyPositiveException {@code expected} has entries that are + * not strictly positive + * @throws DimensionMismatchException if the array lengths do not match or + * are less than 2. + * @throws MaxCountExceededException if an error occurs computing the + * p-value. + */ + public double gTestIntrinsic(final double[] expected, final long[] observed) + throws NotPositiveException, NotStrictlyPositiveException, + DimensionMismatchException, MaxCountExceededException { + + // pass a null rng to avoid unneeded overhead as we will not sample from this distribution + final ChiSquaredDistribution distribution = + new ChiSquaredDistribution(null, expected.length - 2.0); + return 1.0 - distribution.cumulativeProbability(g(expected, observed)); + } + + /** + * Performs a G-Test (Log-Likelihood Ratio Test) for goodness of fit + * evaluating the null hypothesis that the observed counts conform to the + * frequency distribution described by the expected counts, with + * significance level {@code alpha}. Returns true iff the null + * hypothesis can be rejected with {@code 100 * (1 - alpha)} percent confidence. 
+ * + * <p><strong>Example:</strong><br> To test the hypothesis that + * {@code observed} follows {@code expected} at the 99% level, + * use </p><p> + * {@code gTest(expected, observed, 0.01)}</p> + * + * <p>Returns true iff {@link #gTest(double[], long[]) + * gTestGoodnessOfFitPValue(expected, observed)} < alpha</p> + * + * <p><strong>Preconditions</strong>: <ul> + * <li>Expected counts must all be positive. </li> + * <li>Observed counts must all be ≥ 0. </li> + * <li>The observed and expected arrays must have the same length and their + * common length must be at least 2. + * <li> {@code 0 < alpha < 0.5} </li></ul></p> + * + * <p>If any of the preconditions are not met, a + * {@code MathIllegalArgumentException} is thrown.</p> + * + * <p><strong>Note:</strong>This implementation rescales the + * {@code expected} array if necessary to ensure that the sum of the + * expected and observed counts are equal.</p> + * + * @param observed array of observed frequency counts + * @param expected array of expected frequency counts + * @param alpha significance level of the test + * @return true iff null hypothesis can be rejected with confidence 1 - + * alpha + * @throws NotPositiveException if {@code observed} has negative entries + * @throws NotStrictlyPositiveException if {@code expected} has entries that + * are not strictly positive + * @throws DimensionMismatchException if the array lengths do not match or + * are less than 2. + * @throws MaxCountExceededException if an error occurs computing the + * p-value. 
+ * @throws OutOfRangeException if {@code alpha} is outside the + * interval {@code (0, 0.5]} + */ + public boolean gTest(final double[] expected, final long[] observed, + final double alpha) + throws NotPositiveException, NotStrictlyPositiveException, + DimensionMismatchException, OutOfRangeException, MaxCountExceededException { + + if ((alpha <= 0) || (alpha > 0.5)) { + throw new OutOfRangeException(LocalizedFormats.OUT_OF_BOUND_SIGNIFICANCE_LEVEL, + alpha, 0, 0.5); + } + return gTest(expected, observed) < alpha; + } + + /** + * Calculates the <a href= + * "http://en.wikipedia.org/wiki/Entropy_%28information_theory%29">Shannon + * entropy</a> for 2 Dimensional Matrix. The value returned is the entropy + * of the vector formed by concatenating the rows (or columns) of {@code k} + * to form a vector. See {@link #entropy(long[])}. + * + * @param k 2 Dimensional Matrix of long values (for ex. the counts of a + * trials) + * @return Shannon Entropy of the given Matrix + * + */ + private double entropy(final long[][] k) { + double h = 0d; + double sum_k = 0d; + for (int i = 0; i < k.length; i++) { + for (int j = 0; j < k[i].length; j++) { + sum_k += (double) k[i][j]; + } + } + for (int i = 0; i < k.length; i++) { + for (int j = 0; j < k[i].length; j++) { + if (k[i][j] != 0) { + final double p_ij = (double) k[i][j] / sum_k; + h += p_ij * FastMath.log(p_ij); + } + } + } + return -h; + } + + /** + * Calculates the <a href="http://en.wikipedia.org/wiki/Entropy_%28information_theory%29"> + * Shannon entropy</a> for a vector. The values of {@code k} are taken to be + * incidence counts of the values of a random variable. What is returned is <br/> + * -&sum;p<sub>i</sub>log(p<sub>i</sub>)<br/> + * where p<sub>i</sub> = k[i] / (sum of elements in k) + * + * @param k Vector (for ex. 
Row Sums of a trials) + * @return Shannon Entropy of the given Vector + * + */ + private double entropy(final long[] k) { + double h = 0d; + double sum_k = 0d; + for (int i = 0; i < k.length; i++) { + sum_k += (double) k[i]; + } + for (int i = 0; i < k.length; i++) { + if (k[i] != 0) { + final double p_i = (double) k[i] / sum_k; + h += p_i * FastMath.log(p_i); + } + } + return -h; + } + + /** + * <p>Computes a G (Log-Likelihood Ratio) two sample test statistic for + * independence comparing frequency counts in + * {@code observed1} and {@code observed2}. The sums of frequency + * counts in the two samples are not required to be the same. The formula + * used to compute the test statistic is </p> + * + * <p>{@code 2 * totalSum * [H(rowSums) + H(colSums) - H(k)]}</p> + * + * <p> where {@code H} is the + * <a href="http://en.wikipedia.org/wiki/Entropy_%28information_theory%29"> + * Shannon Entropy</a> of the random variable formed by viewing the elements + * of the argument array as incidence counts; <br/> + * {@code k} is a matrix with rows {@code [observed1, observed2]}; <br/> + * {@code rowSums, colSums} are the row/col sums of {@code k}; <br> + * and {@code totalSum} is the overall sum of all entries in {@code k}.</p> + * + * <p>This statistic can be used to perform a G test evaluating the null + * hypothesis that both observed counts are independent </p> + * + * <p> <strong>Preconditions</strong>: <ul> + * <li>Observed counts must be non-negative. </li> + * <li>Observed counts for a specific bin must not both be zero. </li> + * <li>Observed counts for a specific sample must not all be 0. </li> + * <li>The arrays {@code observed1} and {@code observed2} must have + * the same length and their common length must be at least 2. 
</li></ul></p> + * + * <p>If any of the preconditions are not met, a + * {@code MathIllegalArgumentException} is thrown.</p> + * + * @param observed1 array of observed frequency counts of the first data set + * @param observed2 array of observed frequency counts of the second data + * set + * @return G-Test statistic + * @throws DimensionMismatchException if the lengths of the arrays do not + * match or their common length is less than 2 + * @throws NotPositiveException if any entry in {@code observed1} or + * {@code observed2} is negative + * @throws ZeroException if either all counts of + * {@code observed1} or {@code observed2} are zero, or if the count + * at the same index is zero for both arrays. + */ + public double gDataSetsComparison(final long[] observed1, final long[] observed2) + throws DimensionMismatchException, NotPositiveException, ZeroException { + + // Make sure lengths are same + if (observed1.length < 2) { + throw new DimensionMismatchException(observed1.length, 2); + } + if (observed1.length != observed2.length) { + throw new DimensionMismatchException(observed1.length, observed2.length); + } + + // Ensure non-negative counts + MathArrays.checkNonNegative(observed1); + MathArrays.checkNonNegative(observed2); + + // Compute and compare count sums + long countSum1 = 0; + long countSum2 = 0; + + // Build column sums and the 2 x n contingency matrix + final long[] collSums = new long[observed1.length]; + final long[][] k = new long[2][observed1.length]; + + for (int i = 0; i < observed1.length; i++) { + if (observed1[i] == 0 && observed2[i] == 0) { + throw new ZeroException(LocalizedFormats.OBSERVED_COUNTS_BOTTH_ZERO_FOR_ENTRY, i); + } else { + countSum1 += observed1[i]; + countSum2 += observed2[i]; + collSums[i] = observed1[i] + observed2[i]; + k[0][i] = observed1[i]; + k[1][i] = observed2[i]; + } + } + // Ensure neither sample is uniformly 0 + if (countSum1 == 0 || countSum2 == 0) { + throw new ZeroException(); + } + final long[] rowSums = {countSum1, countSum2}; + 
final double sum = (double) countSum1 + (double) countSum2; + return 2 * sum * (entropy(rowSums) + entropy(collSums) - entropy(k)); + } + + /** + * Calculates the root log-likelihood ratio for 2 state Datasets. See + * {@link #gDataSetsComparison(long[], long[] )}. + * + * <p>Given two events A and B, let k11 be the number of times both events + * occur, k12 the incidence of B without A, k21 the count of A without B, + * and k22 the number of times neither A nor B occurs. What is returned + * by this method is </p> + * + * <p>{@code (sgn) sqrt(gValueDataSetsComparison({k11, k12}, {k21, k22})}</p> + * + * <p>where {@code sgn} is -1 if {@code k11 / (k11 + k12) < k21 / (k21 + k22))};<br/> + * 1 otherwise.</p> + * + * <p>Signed root LLR has two advantages over the basic LLR: a) it is positive + * where k11 is bigger than expected, negative where it is lower b) if there is + * no difference it is asymptotically normally distributed. This allows one + * to talk about "number of standard deviations" which is a more common frame + * of reference than the chi^2 distribution.</p> + * + * @param k11 number of times the two events occurred together (AB) + * @param k12 number of times the second event occurred WITHOUT the + * first event (notA,B) + * @param k21 number of times the first event occurred WITHOUT the + * second event (A, notB) + * @param k22 number of times something else occurred (i.e. 
was neither + * of these events (notA, notB) + * @return root log-likelihood ratio + * + */ + public double rootLogLikelihoodRatio(final long k11, long k12, + final long k21, final long k22) { + final double llr = gDataSetsComparison( + new long[]{k11, k12}, new long[]{k21, k22}); + double sqrt = FastMath.sqrt(llr); + if ((double) k11 / (k11 + k12) < (double) k21 / (k21 + k22)) { + sqrt = -sqrt; + } + return sqrt; + } + + /** + * <p>Returns the <i>observed significance level</i>, or <a href= + * "http://www.cas.lancs.ac.uk/glossary_v1.1/hyptest.html#pvalue"> + * p-value</a>, associated with a G-Value (Log-Likelihood Ratio) for two + * sample test comparing bin frequency counts in {@code observed1} and + * {@code observed2}.</p> + * + * <p>The number returned is the smallest significance level at which one + * can reject the null hypothesis that the observed counts conform to the + * same distribution. </p> + * + * <p>See {@link #gTest(double[], long[])} for details + * on how the p-value is computed. The degrees of of freedom used to + * perform the test is one less than the common length of the input observed + * count arrays.</p> + * + * <p><strong>Preconditions</strong>: + * <ul> <li>Observed counts must be non-negative. </li> + * <li>Observed counts for a specific bin must not both be zero. </li> + * <li>Observed counts for a specific sample must not all be 0. </li> + * <li>The arrays {@code observed1} and {@code observed2} must + * have the same length and their common length must be at least 2. 
</li> + * </ul></p> + * <p> If any of the preconditions are not met, a + * {@code MathIllegalArgumentException} is thrown.</p> + * + * @param observed1 array of observed frequency counts of the first data set + * @param observed2 array of observed frequency counts of the second data + * set + * @return p-value + * @throws DimensionMismatchException if the lengths of the arrays do not + * match or their common length is less than 2 + * @throws NotPositiveException if any of the entries in {@code observed1} or + * {@code observed2} are negative + * @throws ZeroException if either all counts of {@code observed1} or + * {@code observed2} are zero, or if the count at some index is + * zero for both arrays + * @throws MaxCountExceededException if an error occurs computing the + * p-value. + */ + public double gTestDataSetsComparison(final long[] observed1, + final long[] observed2) + throws DimensionMismatchException, NotPositiveException, ZeroException, + MaxCountExceededException { + + // pass a null rng to avoid unneeded overhead as we will not sample from this distribution + final ChiSquaredDistribution distribution = + new ChiSquaredDistribution(null, (double) observed1.length - 1); + return 1 - distribution.cumulativeProbability( + gDataSetsComparison(observed1, observed2)); + } + + /** + * <p>Performs a G-Test (Log-Likelihood Ratio Test) comparing two binned + * data sets. The test evaluates the null hypothesis that the two lists + * of observed counts conform to the same frequency distribution, with + * significance level {@code alpha}. Returns true iff the null + * hypothesis can be rejected with 100 * (1 - alpha) percent confidence. + * </p> + * <p>See {@link #gDataSetsComparison(long[], long[])} for details + * on the formula used to compute the G (LLR) statistic used in the test and + * {@link #gTest(double[], long[])} for information on how + * the observed significance level is computed. 
The degrees of freedom used + * to perform the test is one less than the common length of the input observed + * count arrays. </p> + * + * <strong>Preconditions</strong>: <ul> + * <li>Observed counts must be non-negative. </li> + * <li>Observed counts for a specific bin must not both be zero. </li> + * <li>Observed counts for a specific sample must not all be 0. </li> + * <li>The arrays {@code observed1} and {@code observed2} must + * have the same length and their common length must be at least 2. </li> + * <li>{@code 0 < alpha <= 0.5} </li></ul></p> + * + * <p>If any of the preconditions are not met, a + * {@code MathIllegalArgumentException} is thrown.</p> + * + * @param observed1 array of observed frequency counts of the first data set + * @param observed2 array of observed frequency counts of the second data + * set + * @param alpha significance level of the test + * @return true iff null hypothesis can be rejected with confidence 1 - + * alpha + * @throws DimensionMismatchException if the lengths of the arrays do not + * match + * @throws NotPositiveException if any of the entries in {@code observed1} or + * {@code observed2} are negative + * @throws ZeroException if either all counts of {@code observed1} or + * {@code observed2} are zero, or if the count at some index is + * zero for both arrays + * @throws OutOfRangeException if {@code alpha} is not in the range + * (0, 0.5] + * @throws MaxCountExceededException if an error occurs performing the test + */ + public boolean gTestDataSetsComparison( + final long[] observed1, + final long[] observed2, + final double alpha) + throws DimensionMismatchException, NotPositiveException, + ZeroException, OutOfRangeException, MaxCountExceededException { + + if (alpha <= 0 || alpha > 0.5) { + throw new OutOfRangeException( + LocalizedFormats.OUT_OF_BOUND_SIGNIFICANCE_LEVEL, alpha, 0, 0.5); + } + return gTestDataSetsComparison(observed1, observed2) < alpha; + } +} diff --git 
a/src/main/java/org/apache/commons/math3/stat/inference/KolmogorovSmirnovTest.java b/src/main/java/org/apache/commons/math3/stat/inference/KolmogorovSmirnovTest.java new file mode 100644 index 0000000..6b70e9b --- /dev/null +++ b/src/main/java/org/apache/commons/math3/stat/inference/KolmogorovSmirnovTest.java @@ -0,0 +1,1270 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.commons.math3.stat.inference; + +import java.math.BigDecimal; +import java.util.Arrays; +import java.util.HashSet; + +import org.apache.commons.math3.distribution.EnumeratedRealDistribution; +import org.apache.commons.math3.distribution.RealDistribution; +import org.apache.commons.math3.distribution.UniformRealDistribution; +import org.apache.commons.math3.exception.InsufficientDataException; +import org.apache.commons.math3.exception.MathArithmeticException; +import org.apache.commons.math3.exception.MathInternalError; +import org.apache.commons.math3.exception.NullArgumentException; +import org.apache.commons.math3.exception.NumberIsTooLargeException; +import org.apache.commons.math3.exception.OutOfRangeException; +import org.apache.commons.math3.exception.TooManyIterationsException; +import org.apache.commons.math3.exception.util.LocalizedFormats; +import org.apache.commons.math3.fraction.BigFraction; +import org.apache.commons.math3.fraction.BigFractionField; +import org.apache.commons.math3.fraction.FractionConversionException; +import org.apache.commons.math3.linear.Array2DRowFieldMatrix; +import org.apache.commons.math3.linear.FieldMatrix; +import org.apache.commons.math3.linear.MatrixUtils; +import org.apache.commons.math3.linear.RealMatrix; +import org.apache.commons.math3.random.JDKRandomGenerator; +import org.apache.commons.math3.random.RandomGenerator; +import org.apache.commons.math3.random.Well19937c; +import org.apache.commons.math3.util.CombinatoricsUtils; +import org.apache.commons.math3.util.FastMath; +import org.apache.commons.math3.util.MathArrays; +import org.apache.commons.math3.util.MathUtils; + +/** + * Implementation of the <a href="http://en.wikipedia.org/wiki/Kolmogorov-Smirnov_test"> + * Kolmogorov-Smirnov (K-S) test</a> for equality of continuous distributions. 
+ * <p> + * The K-S test uses a statistic based on the maximum deviation of the empirical distribution of + * sample data points from the distribution expected under the null hypothesis. For one-sample tests + * evaluating the null hypothesis that a set of sample data points follow a given distribution, the + * test statistic is \(D_n=\sup_x |F_n(x)-F(x)|\), where \(F\) is the expected distribution and + * \(F_n\) is the empirical distribution of the \(n\) sample data points. The distribution of + * \(D_n\) is estimated using a method based on [1] with certain quick decisions for extreme values + * given in [2]. + * </p> + * <p> + * Two-sample tests are also supported, evaluating the null hypothesis that the two samples + * {@code x} and {@code y} come from the same underlying distribution. In this case, the test + * statistic is \(D_{n,m}=\sup_t | F_n(t)-F_m(t)|\) where \(n\) is the length of {@code x}, \(m\) is + * the length of {@code y}, \(F_n\) is the empirical distribution that puts mass \(1/n\) at each of + * the values in {@code x} and \(F_m\) is the empirical distribution of the {@code y} values. The + * default 2-sample test method, {@link #kolmogorovSmirnovTest(double[], double[])} works as + * follows: + * <ul> + * <li>For small samples (where the product of the sample sizes is less than + * {@value #LARGE_SAMPLE_PRODUCT}), the method presented in [4] is used to compute the + * exact p-value for the 2-sample test.</li> + * <li>When the product of the sample sizes exceeds {@value #LARGE_SAMPLE_PRODUCT}, the asymptotic + * distribution of \(D_{n,m}\) is used. See {@link #approximateP(double, int, int)} for details on + * the approximation.</li> + * </ul></p><p> + * If the product of the sample sizes is less than {@value #LARGE_SAMPLE_PRODUCT} and the sample + * data contains ties, random jitter is added to the sample data to break ties before applying + * the algorithm above. 
Alternatively, the {@link #bootstrap(double[], double[], int, boolean)} + * method, modeled after <a href="http://sekhon.berkeley.edu/matching/ks.boot.html">ks.boot</a> + * in the R Matching package [3], can be used if ties are known to be present in the data. + * </p> + * <p> + * In the two-sample case, \(D_{n,m}\) has a discrete distribution. This makes the p-value + * associated with the null hypothesis \(H_0 : D_{n,m} \ge d \) differ from \(H_0 : D_{n,m} > d \) + * by the mass of the observed value \(d\). To distinguish these, the two-sample tests use a boolean + * {@code strict} parameter. This parameter is ignored for large samples. + * </p> + * <p> + * The methods used by the 2-sample default implementation are also exposed directly: + * <ul> + * <li>{@link #exactP(double, int, int, boolean)} computes exact 2-sample p-values</li> + * <li>{@link #approximateP(double, int, int)} uses the asymptotic distribution The {@code boolean} + * arguments in the first two methods allow the probability used to estimate the p-value to be + * expressed using strict or non-strict inequality. See + * {@link #kolmogorovSmirnovTest(double[], double[], boolean)}.</li> + * </ul> + * </p> + * <p> + * References: + * <ul> + * <li>[1] <a href="http://www.jstatsoft.org/v08/i18/"> Evaluating Kolmogorov's Distribution</a> by + * George Marsaglia, Wai Wan Tsang, and Jingbo Wang</li> + * <li>[2] <a href="http://www.jstatsoft.org/v39/i11/"> Computing the Two-Sided Kolmogorov-Smirnov + * Distribution</a> by Richard Simard and Pierre L'Ecuyer</li> + * <li>[3] Jasjeet S. Sekhon. 2011. <a href="http://www.jstatsoft.org/article/view/v042i07"> + * Multivariate and Propensity Score Matching Software with Automated Balance Optimization: + * The Matching package for R</a> Journal of Statistical Software, 42(7): 1-52.</li> + * <li>[4] Wilcox, Rand. 2012. Introduction to Robust Estimation and Hypothesis Testing, + * Chapter 5, 3rd Ed. 
Academic Press.</li> + * </ul> + * <br/> + * Note that [1] contains an error in computing h, refer to <a + * href="https://issues.apache.org/jira/browse/MATH-437">MATH-437</a> for details. + * </p> + * + * @since 3.3 + */ +public class KolmogorovSmirnovTest { + + /** + * Bound on the number of partial sums in {@link #ksSum(double, double, int)} + */ + protected static final int MAXIMUM_PARTIAL_SUM_COUNT = 100000; + + /** Convergence criterion for {@link #ksSum(double, double, int)} */ + protected static final double KS_SUM_CAUCHY_CRITERION = 1E-20; + + /** Convergence criterion for the sums in #pelzGood(double, double, int)} */ + protected static final double PG_SUM_RELATIVE_ERROR = 1.0e-10; + + /** No longer used. */ + @Deprecated + protected static final int SMALL_SAMPLE_PRODUCT = 200; + + /** + * When product of sample sizes exceeds this value, 2-sample K-S test uses asymptotic + * distribution to compute the p-value. + */ + protected static final int LARGE_SAMPLE_PRODUCT = 10000; + + /** Default number of iterations used by {@link #monteCarloP(double, int, int, boolean, int)}. + * Deprecated as of version 3.6, as this method is no longer needed. */ + @Deprecated + protected static final int MONTE_CARLO_ITERATIONS = 1000000; + + /** Random data generator used by {@link #monteCarloP(double, int, int, boolean, int)} */ + private final RandomGenerator rng; + + /** + * Construct a KolmogorovSmirnovTest instance with a default random data generator. + */ + public KolmogorovSmirnovTest() { + rng = new Well19937c(); + } + + /** + * Construct a KolmogorovSmirnovTest with the provided random data generator. + * The #monteCarloP(double, int, int, boolean, int) that uses the generator supplied to this + * constructor is deprecated as of version 3.6. 
+ * + * @param rng random data generator used by {@link #monteCarloP(double, int, int, boolean, int)} + */ + @Deprecated + public KolmogorovSmirnovTest(RandomGenerator rng) { + this.rng = rng; + } + + /** + * Computes the <i>p-value</i>, or <i>observed significance level</i>, of a one-sample <a + * href="http://en.wikipedia.org/wiki/Kolmogorov-Smirnov_test"> Kolmogorov-Smirnov test</a> + * evaluating the null hypothesis that {@code data} conforms to {@code distribution}. If + * {@code exact} is true, the distribution used to compute the p-value is computed using + * extended precision. See {@link #cdfExact(double, int)}. + * + * @param distribution reference distribution + * @param data sample being being evaluated + * @param exact whether or not to force exact computation of the p-value + * @return the p-value associated with the null hypothesis that {@code data} is a sample from + * {@code distribution} + * @throws InsufficientDataException if {@code data} does not have length at least 2 + * @throws NullArgumentException if {@code data} is null + */ + public double kolmogorovSmirnovTest(RealDistribution distribution, double[] data, boolean exact) { + return 1d - cdf(kolmogorovSmirnovStatistic(distribution, data), data.length, exact); + } + + /** + * Computes the one-sample Kolmogorov-Smirnov test statistic, \(D_n=\sup_x |F_n(x)-F(x)|\) where + * \(F\) is the distribution (cdf) function associated with {@code distribution}, \(n\) is the + * length of {@code data} and \(F_n\) is the empirical distribution that puts mass \(1/n\) at + * each of the values in {@code data}. 
+ * + * @param distribution reference distribution + * @param data sample being evaluated + * @return Kolmogorov-Smirnov statistic \(D_n\) + * @throws InsufficientDataException if {@code data} does not have length at least 2 + * @throws NullArgumentException if {@code data} is null + */ + public double kolmogorovSmirnovStatistic(RealDistribution distribution, double[] data) { + checkArray(data); + final int n = data.length; + final double nd = n; + final double[] dataCopy = new double[n]; + System.arraycopy(data, 0, dataCopy, 0, n); + Arrays.sort(dataCopy); + double d = 0d; + for (int i = 1; i <= n; i++) { + final double yi = distribution.cumulativeProbability(dataCopy[i - 1]); + final double currD = FastMath.max(yi - (i - 1) / nd, i / nd - yi); + if (currD > d) { + d = currD; + } + } + return d; + } + + /** + * Computes the <i>p-value</i>, or <i>observed significance level</i>, of a two-sample <a + * href="http://en.wikipedia.org/wiki/Kolmogorov-Smirnov_test"> Kolmogorov-Smirnov test</a> + * evaluating the null hypothesis that {@code x} and {@code y} are samples drawn from the same + * probability distribution. Specifically, what is returned is an estimate of the probability + * that the {@link #kolmogorovSmirnovStatistic(double[], double[])} associated with a randomly + * selected partition of the combined sample into subsamples of sizes {@code x.length} and + * {@code y.length} will strictly exceed (if {@code strict} is {@code true}) or be at least as + * large as {@code strict = false}) as {@code kolmogorovSmirnovStatistic(x, y)}. + * <ul> + * <li>For small samples (where the product of the sample sizes is less than + * {@value #LARGE_SAMPLE_PRODUCT}), the exact p-value is computed using the method presented + * in [4], implemented in {@link #exactP(double, int, int, boolean)}. </li> + * <li>When the product of the sample sizes exceeds {@value #LARGE_SAMPLE_PRODUCT}, the + * asymptotic distribution of \(D_{n,m}\) is used. 
See {@link #approximateP(double, int, int)} + * for details on the approximation.</li> + * </ul><p> + * If {@code x.length * y.length} < {@value #LARGE_SAMPLE_PRODUCT} and the combined set of values in + * {@code x} and {@code y} contains ties, random jitter is added to {@code x} and {@code y} to + * break ties before computing \(D_{n,m}\) and the p-value. The jitter is uniformly distributed + * on (-minDelta / 2, minDelta / 2) where minDelta is the smallest pairwise difference between + * values in the combined sample.</p> + * <p> + * If ties are known to be present in the data, {@link #bootstrap(double[], double[], int, boolean)} + * may be used as an alternative method for estimating the p-value.</p> + * + * @param x first sample dataset + * @param y second sample dataset + * @param strict whether or not the probability to compute is expressed as a strict inequality + * (ignored for large samples) + * @return p-value associated with the null hypothesis that {@code x} and {@code y} represent + * samples from the same distribution + * @throws InsufficientDataException if either {@code x} or {@code y} does not have length at + * least 2 + * @throws NullArgumentException if either {@code x} or {@code y} is null + * @see #bootstrap(double[], double[], int, boolean) + */ + public double kolmogorovSmirnovTest(double[] x, double[] y, boolean strict) { + final long lengthProduct = (long) x.length * y.length; + double[] xa = null; + double[] ya = null; + if (lengthProduct < LARGE_SAMPLE_PRODUCT && hasTies(x,y)) { + xa = MathArrays.copyOf(x); + ya = MathArrays.copyOf(y); + fixTies(xa, ya); + } else { + xa = x; + ya = y; + } + if (lengthProduct < LARGE_SAMPLE_PRODUCT) { + return exactP(kolmogorovSmirnovStatistic(xa, ya), x.length, y.length, strict); + } + return approximateP(kolmogorovSmirnovStatistic(x, y), x.length, y.length); + } + + /** + * Computes the <i>p-value</i>, or <i>observed significance level</i>, of a two-sample <a + * 
href="http://en.wikipedia.org/wiki/Kolmogorov-Smirnov_test"> Kolmogorov-Smirnov test</a> + * evaluating the null hypothesis that {@code x} and {@code y} are samples drawn from the same + * probability distribution. Assumes the strict form of the inequality used to compute the + * p-value. See {@link #kolmogorovSmirnovTest(RealDistribution, double[], boolean)}. + * + * @param x first sample dataset + * @param y second sample dataset + * @return p-value associated with the null hypothesis that {@code x} and {@code y} represent + * samples from the same distribution + * @throws InsufficientDataException if either {@code x} or {@code y} does not have length at + * least 2 + * @throws NullArgumentException if either {@code x} or {@code y} is null + */ + public double kolmogorovSmirnovTest(double[] x, double[] y) { + return kolmogorovSmirnovTest(x, y, true); + } + + /** + * Computes the two-sample Kolmogorov-Smirnov test statistic, \(D_{n,m}=\sup_x |F_n(x)-F_m(x)|\) + * where \(n\) is the length of {@code x}, \(m\) is the length of {@code y}, \(F_n\) is the + * empirical distribution that puts mass \(1/n\) at each of the values in {@code x} and \(F_m\) + * is the empirical distribution of the {@code y} values. 
+ * + * @param x first sample + * @param y second sample + * @return test statistic \(D_{n,m}\) used to evaluate the null hypothesis that {@code x} and + * {@code y} represent samples from the same underlying distribution + * @throws InsufficientDataException if either {@code x} or {@code y} does not have length at + * least 2 + * @throws NullArgumentException if either {@code x} or {@code y} is null + */ + public double kolmogorovSmirnovStatistic(double[] x, double[] y) { + return integralKolmogorovSmirnovStatistic(x, y)/((double)(x.length * (long)y.length)); + } + + /** + * Computes the two-sample Kolmogorov-Smirnov test statistic, \(D_{n,m}=\sup_x |F_n(x)-F_m(x)|\) + * where \(n\) is the length of {@code x}, \(m\) is the length of {@code y}, \(F_n\) is the + * empirical distribution that puts mass \(1/n\) at each of the values in {@code x} and \(F_m\) + * is the empirical distribution of the {@code y} values. Finally \(n m D_{n,m}\) is returned + * as long value. + * + * @param x first sample + * @param y second sample + * @return test statistic \(n m D_{n,m}\) used to evaluate the null hypothesis that {@code x} and + * {@code y} represent samples from the same underlying distribution + * @throws InsufficientDataException if either {@code x} or {@code y} does not have length at + * least 2 + * @throws NullArgumentException if either {@code x} or {@code y} is null + */ + private long integralKolmogorovSmirnovStatistic(double[] x, double[] y) { + checkArray(x); + checkArray(y); + // Copy and sort the sample arrays + final double[] sx = MathArrays.copyOf(x); + final double[] sy = MathArrays.copyOf(y); + Arrays.sort(sx); + Arrays.sort(sy); + final int n = sx.length; + final int m = sy.length; + + int rankX = 0; + int rankY = 0; + long curD = 0l; + + // Find the max difference between cdf_x and cdf_y + long supD = 0l; + do { + double z = Double.compare(sx[rankX], sy[rankY]) <= 0 ? 
sx[rankX] : sy[rankY]; + while(rankX < n && Double.compare(sx[rankX], z) == 0) { + rankX += 1; + curD += m; + } + while(rankY < m && Double.compare(sy[rankY], z) == 0) { + rankY += 1; + curD -= n; + } + if (curD > supD) { + supD = curD; + } + else if (-curD > supD) { + supD = -curD; + } + } while(rankX < n && rankY < m); + return supD; + } + + /** + * Computes the <i>p-value</i>, or <i>observed significance level</i>, of a one-sample <a + * href="http://en.wikipedia.org/wiki/Kolmogorov-Smirnov_test"> Kolmogorov-Smirnov test</a> + * evaluating the null hypothesis that {@code data} conforms to {@code distribution}. + * + * @param distribution reference distribution + * @param data sample being being evaluated + * @return the p-value associated with the null hypothesis that {@code data} is a sample from + * {@code distribution} + * @throws InsufficientDataException if {@code data} does not have length at least 2 + * @throws NullArgumentException if {@code data} is null + */ + public double kolmogorovSmirnovTest(RealDistribution distribution, double[] data) { + return kolmogorovSmirnovTest(distribution, data, false); + } + + /** + * Performs a <a href="http://en.wikipedia.org/wiki/Kolmogorov-Smirnov_test"> Kolmogorov-Smirnov + * test</a> evaluating the null hypothesis that {@code data} conforms to {@code distribution}. 
+ * + * @param distribution reference distribution + * @param data sample being being evaluated + * @param alpha significance level of the test + * @return true iff the null hypothesis that {@code data} is a sample from {@code distribution} + * can be rejected with confidence 1 - {@code alpha} + * @throws InsufficientDataException if {@code data} does not have length at least 2 + * @throws NullArgumentException if {@code data} is null + */ + public boolean kolmogorovSmirnovTest(RealDistribution distribution, double[] data, double alpha) { + if ((alpha <= 0) || (alpha > 0.5)) { + throw new OutOfRangeException(LocalizedFormats.OUT_OF_BOUND_SIGNIFICANCE_LEVEL, alpha, 0, 0.5); + } + return kolmogorovSmirnovTest(distribution, data) < alpha; + } + + /** + * Estimates the <i>p-value</i> of a two-sample + * <a href="http://en.wikipedia.org/wiki/Kolmogorov-Smirnov_test"> Kolmogorov-Smirnov test</a> + * evaluating the null hypothesis that {@code x} and {@code y} are samples drawn from the same + * probability distribution. This method estimates the p-value by repeatedly sampling sets of size + * {@code x.length} and {@code y.length} from the empirical distribution of the combined sample. + * When {@code strict} is true, this is equivalent to the algorithm implemented in the R function + * {@code ks.boot}, described in <pre> + * Jasjeet S. Sekhon. 2011. 'Multivariate and Propensity Score Matching + * Software with Automated Balance Optimization: The Matching package for R.' + * Journal of Statistical Software, 42(7): 1-52. 
+ * </pre> + * @param x first sample + * @param y second sample + * @param iterations number of bootstrap resampling iterations + * @param strict whether or not the null hypothesis is expressed as a strict inequality + * @return estimated p-value + */ + public double bootstrap(double[] x, double[] y, int iterations, boolean strict) { + final int xLength = x.length; + final int yLength = y.length; + final double[] combined = new double[xLength + yLength]; + System.arraycopy(x, 0, combined, 0, xLength); + System.arraycopy(y, 0, combined, xLength, yLength); + final EnumeratedRealDistribution dist = new EnumeratedRealDistribution(rng, combined); + final long d = integralKolmogorovSmirnovStatistic(x, y); + int greaterCount = 0; + int equalCount = 0; + double[] curX; + double[] curY; + long curD; + for (int i = 0; i < iterations; i++) { + curX = dist.sample(xLength); + curY = dist.sample(yLength); + curD = integralKolmogorovSmirnovStatistic(curX, curY); + if (curD > d) { + greaterCount++; + } else if (curD == d) { + equalCount++; + } + } + return strict ? greaterCount / (double) iterations : + (greaterCount + equalCount) / (double) iterations; + } + + /** + * Computes {@code bootstrap(x, y, iterations, true)}. + * This is equivalent to ks.boot(x,y, nboots=iterations) using the R Matching + * package function. See #bootstrap(double[], double[], int, boolean). + * + * @param x first sample + * @param y second sample + * @param iterations number of bootstrap resampling iterations + * @return estimated p-value + */ + public double bootstrap(double[] x, double[] y, int iterations) { + return bootstrap(x, y, iterations, true); + } + + /** + * Calculates \(P(D_n < d)\) using the method described in [1] with quick decisions for extreme + * values given in [2] (see above). The result is not exact as with + * {@link #cdfExact(double, int)} because calculations are based on + * {@code double} rather than {@link org.apache.commons.math3.fraction.BigFraction}. 
+ * + * @param d statistic + * @param n sample size + * @return \(P(D_n < d)\) + * @throws MathArithmeticException if algorithm fails to convert {@code h} to a + * {@link org.apache.commons.math3.fraction.BigFraction} in expressing {@code d} as \((k + * - h) / m\) for integer {@code k, m} and \(0 \le h < 1\) + */ + public double cdf(double d, int n) + throws MathArithmeticException { + return cdf(d, n, false); + } + + /** + * Calculates {@code P(D_n < d)}. The result is exact in the sense that BigFraction/BigReal is + * used everywhere at the expense of very slow execution time. Almost never choose this in real + * applications unless you are very sure; this is almost solely for verification purposes. + * Normally, you would choose {@link #cdf(double, int)}. See the class + * javadoc for definitions and algorithm description. + * + * @param d statistic + * @param n sample size + * @return \(P(D_n < d)\) + * @throws MathArithmeticException if the algorithm fails to convert {@code h} to a + * {@link org.apache.commons.math3.fraction.BigFraction} in expressing {@code d} as \((k + * - h) / m\) for integer {@code k, m} and \(0 \le h < 1\) + */ + public double cdfExact(double d, int n) + throws MathArithmeticException { + return cdf(d, n, true); + } + + /** + * Calculates {@code P(D_n < d)} using method described in [1] with quick decisions for extreme + * values given in [2] (see above). + * + * @param d statistic + * @param n sample size + * @param exact whether the probability should be calculated exact using + * {@link org.apache.commons.math3.fraction.BigFraction} everywhere at the expense of + * very slow execution time, or if {@code double} should be used convenient places to + * gain speed. Almost never choose {@code true} in real applications unless you are very + * sure; {@code true} is almost solely for verification purposes. 
 * @return \(P(D_n < d)\)
 * @throws MathArithmeticException if algorithm fails to convert {@code h} to a
 *         {@link org.apache.commons.math3.fraction.BigFraction} in expressing {@code d} as \((k
 *         - h) / m\) for integer {@code k, m} and \(0 \le h < 1\).
 */
public double cdf(double d, int n, boolean exact)
    throws MathArithmeticException {

    final double ninv = 1 / ((double) n);
    final double ninvhalf = 0.5 * ninv;

    // Quick decisions for extreme d (see [2] in the class javadoc): P is identically 0
    // below 1/(2n) and identically 1 at or above 1; closed forms exist on
    // (1/(2n), 1/n] and [1 - 1/n, 1).
    if (d <= ninvhalf) {
        return 0;
    } else if (ninvhalf < d && d <= ninv) {
        double res = 1;
        final double f = 2 * d - ninv;
        // n! f^n = n*f * (n-1)*f * ... * 1*f
        for (int i = 1; i <= n; ++i) {
            res *= i * f;
        }
        return res;
    } else if (1 - ninv <= d && d < 1) {
        return 1 - 2 * Math.pow(1 - d, n);
    } else if (1 <= d) {
        return 1;
    }
    if (exact) {
        return exactK(d, n);
    }
    // Double-based matrix powering for moderate n; larger n switches to the
    // Pelz-Good asymptotic approximation (threshold chosen by the implementation).
    if (n <= 140) {
        return roundedK(d, n);
    }
    return pelzGood(d, n);
}

/**
 * Calculates the exact value of {@code P(D_n < d)} using the method described in [1] (reference
 * in class javadoc above) and {@link org.apache.commons.math3.fraction.BigFraction} (see
 * above).
 *
 * @param d statistic
 * @param n sample size
 * @return the two-sided probability of \(P(D_n < d)\)
 * @throws MathArithmeticException if algorithm fails to convert {@code h} to a
 *         {@link org.apache.commons.math3.fraction.BigFraction} in expressing {@code d} as \((k
 *         - h) / m\) for integer {@code k, m} and \(0 \le h < 1\).
 */
private double exactK(double d, int n)
    throws MathArithmeticException {

    final int k = (int) Math.ceil(n * d);

    // Per [1]: P(D_n < d) = (n! / n^n) * (H^n)[k-1][k-1], with H built by createExactH.
    final FieldMatrix<BigFraction> H = this.createExactH(d, n);
    final FieldMatrix<BigFraction> Hpower = H.power(n);

    BigFraction pFrac = Hpower.getEntry(k - 1, k - 1);

    // Apply the n!/n^n factor one i/n term at a time to keep the fraction small.
    for (int i = 1; i <= n; ++i) {
        pFrac = pFrac.multiply(i).divide(n);
    }

    /*
     * BigFraction.doubleValue converts numerator to double and the denominator to double and
     * divides afterwards. That gives NaN quite easy. This does not (scale is the number of
     * digits):
     */
    return pFrac.bigDecimalValue(20, BigDecimal.ROUND_HALF_UP).doubleValue();
}

/**
 * Calculates {@code P(D_n < d)} using method described in [1] and doubles (see above).
 *
 * @param d statistic
 * @param n sample size
 * @return \(P(D_n < d)\)
 */
private double roundedK(double d, int n) {

    final int k = (int) Math.ceil(n * d);
    // Same matrix-power formula as exactK, but in double precision.
    final RealMatrix H = this.createRoundedH(d, n);
    final RealMatrix Hpower = H.power(n);

    double pFrac = Hpower.getEntry(k - 1, k - 1);
    // Apply the n!/n^n factor incrementally to limit over/underflow.
    for (int i = 1; i <= n; ++i) {
        pFrac *= (double) i / (double) n;
    }

    return pFrac;
}

/**
 * Computes the Pelz-Good approximation for \(P(D_n < d)\) as described in [2] in the class javadoc.
 *
 * @param d value of d-statistic (x in [2])
 * @param n sample size
 * @return \(P(D_n < d)\)
 * @since 3.4
 */
public double pelzGood(double d, int n) {
    // Change the variable since approximation is for the distribution evaluated at d / sqrt(n)
    final double sqrtN = FastMath.sqrt(n);
    final double z = d * sqrtN;
    final double z2 = d * d * n;
    final double z4 = z2 * z2;
    final double z6 = z4 * z2;
    final double z8 = z4 * z4;

    // Eventual return value
    double ret = 0;

    // Compute K_0(z)
    double sum = 0;
    double increment = 0;
    double kTerm = 0;
    double z2Term = MathUtils.PI_SQUARED / (8 * z2);
    int k = 1;
    for (; k < MAXIMUM_PARTIAL_SUM_COUNT; k++) {
        kTerm = 2 * k - 1;
        increment = FastMath.exp(-z2Term * kTerm * kTerm);
        sum += increment;
        // Terms are positive and decreasing; stop when relatively negligible.
        if (increment <= PG_SUM_RELATIVE_ERROR * sum) {
            break;
        }
    }
    if (k == MAXIMUM_PARTIAL_SUM_COUNT) {
        throw new TooManyIterationsException(MAXIMUM_PARTIAL_SUM_COUNT);
    }
    ret = sum * FastMath.sqrt(2 * FastMath.PI) / z;

    // K_1(z)
    // Sum is -inf to inf, but k term is always (k + 1/2) ^ 2, so really have
    // twice the sum from k = 0 to inf (k = -1 is same as 0, -2 same as 1, ...)
    final double twoZ2 = 2 * z2;
    sum = 0;
    kTerm = 0;
    double kTerm2 = 0;
    for (k = 0; k < MAXIMUM_PARTIAL_SUM_COUNT; k++) {
        kTerm = k + 0.5;
        kTerm2 = kTerm * kTerm;
        increment = (MathUtils.PI_SQUARED * kTerm2 - z2) * FastMath.exp(-MathUtils.PI_SQUARED * kTerm2 / twoZ2);
        sum += increment;
        if (FastMath.abs(increment) < PG_SUM_RELATIVE_ERROR * FastMath.abs(sum)) {
            break;
        }
    }
    if (k == MAXIMUM_PARTIAL_SUM_COUNT) {
        throw new TooManyIterationsException(MAXIMUM_PARTIAL_SUM_COUNT);
    }
    final double sqrtHalfPi = FastMath.sqrt(FastMath.PI / 2);
    // Instead of doubling sum, divide by 3 instead of 6
    ret += sum * sqrtHalfPi / (3 * z4 * sqrtN);

    // K_2(z)
    // Same drill as K_1, but with two doubly infinite sums, all k terms are even powers.
    final double z4Term = 2 * z4;
    final double z6Term = 6 * z6;
    z2Term = 5 * z2;
    final double pi4 = MathUtils.PI_SQUARED * MathUtils.PI_SQUARED;
    sum = 0;
    kTerm = 0;
    kTerm2 = 0;
    for (k = 0; k < MAXIMUM_PARTIAL_SUM_COUNT; k++) {
        kTerm = k + 0.5;
        kTerm2 = kTerm * kTerm;
        increment = (z6Term + z4Term + MathUtils.PI_SQUARED * (z4Term - z2Term) * kTerm2 +
                pi4 * (1 - twoZ2) * kTerm2 * kTerm2) * FastMath.exp(-MathUtils.PI_SQUARED * kTerm2 / twoZ2);
        sum += increment;
        if (FastMath.abs(increment) < PG_SUM_RELATIVE_ERROR * FastMath.abs(sum)) {
            break;
        }
    }
    if (k == MAXIMUM_PARTIAL_SUM_COUNT) {
        throw new TooManyIterationsException(MAXIMUM_PARTIAL_SUM_COUNT);
    }
    double sum2 = 0;
    kTerm2 = 0;
    for (k = 1; k < MAXIMUM_PARTIAL_SUM_COUNT; k++) {
        kTerm2 = k * k;
        increment = MathUtils.PI_SQUARED * kTerm2 * FastMath.exp(-MathUtils.PI_SQUARED * kTerm2 / twoZ2);
        sum2 += increment;
        if (FastMath.abs(increment) < PG_SUM_RELATIVE_ERROR * FastMath.abs(sum2)) {
            break;
        }
    }
    if (k == MAXIMUM_PARTIAL_SUM_COUNT) {
        throw new TooManyIterationsException(MAXIMUM_PARTIAL_SUM_COUNT);
    }
    // Again, adjust coefficients instead of doubling sum, sum2
    ret += (sqrtHalfPi / n) * (sum / (36 * z2 * z2 * z2 * z) - sum2 / (18 * z2 * z));

    // K_3(z) One more time with feeling - two doubly infinite sums, all k powers even.
    // Multiply coefficient denominators by 2, so omit doubling sums.
    final double pi6 = pi4 * MathUtils.PI_SQUARED;
    sum = 0;
    double kTerm4 = 0;
    double kTerm6 = 0;
    for (k = 0; k < MAXIMUM_PARTIAL_SUM_COUNT; k++) {
        kTerm = k + 0.5;
        kTerm2 = kTerm * kTerm;
        kTerm4 = kTerm2 * kTerm2;
        kTerm6 = kTerm4 * kTerm2;
        increment = (pi6 * kTerm6 * (5 - 30 * z2) + pi4 * kTerm4 * (-60 * z2 + 212 * z4) +
                MathUtils.PI_SQUARED * kTerm2 * (135 * z4 - 96 * z6) - 30 * z6 - 90 * z8) *
                FastMath.exp(-MathUtils.PI_SQUARED * kTerm2 / twoZ2);
        sum += increment;
        if (FastMath.abs(increment) < PG_SUM_RELATIVE_ERROR * FastMath.abs(sum)) {
            break;
        }
    }
    if (k == MAXIMUM_PARTIAL_SUM_COUNT) {
        throw new TooManyIterationsException(MAXIMUM_PARTIAL_SUM_COUNT);
    }
    sum2 = 0;
    for (k = 1; k < MAXIMUM_PARTIAL_SUM_COUNT; k++) {
        kTerm2 = k * k;
        kTerm4 = kTerm2 * kTerm2;
        increment = (-pi4 * kTerm4 + 3 * MathUtils.PI_SQUARED * kTerm2 * z2) *
                FastMath.exp(-MathUtils.PI_SQUARED * kTerm2 / twoZ2);
        sum2 += increment;
        if (FastMath.abs(increment) < PG_SUM_RELATIVE_ERROR * FastMath.abs(sum2)) {
            break;
        }
    }
    if (k == MAXIMUM_PARTIAL_SUM_COUNT) {
        throw new TooManyIterationsException(MAXIMUM_PARTIAL_SUM_COUNT);
    }
    return ret + (sqrtHalfPi / (sqrtN * n)) * (sum / (3240 * z6 * z4) +
            sum2 / (108 * z6));

}

/***
 * Creates {@code H} of size {@code m x m} as described in [1] (see above).
 *
 * @param d statistic
 * @param n sample size
 * @return H matrix
 * @throws NumberIsTooLargeException if fractional part is greater than 1
 * @throws FractionConversionException if algorithm fails to convert {@code h} to a
 *         {@link org.apache.commons.math3.fraction.BigFraction} in expressing {@code d} as \((k
 *         - h) / m\) for integer {@code k, m} and \(0 <= h < 1\).
 */
private FieldMatrix<BigFraction> createExactH(double d, int n)
    throws NumberIsTooLargeException, FractionConversionException {

    final int k = (int) Math.ceil(n * d);
    final int m = 2 * k - 1;
    final double hDouble = k - n * d;
    if (hDouble >= 1) {
        throw new NumberIsTooLargeException(hDouble, 1.0, false);
    }
    // Convert h to an exact fraction, progressively loosening the tolerance if the
    // continued-fraction conversion fails to converge at the tighter settings.
    BigFraction h = null;
    try {
        h = new BigFraction(hDouble, 1.0e-20, 10000);
    } catch (final FractionConversionException e1) {
        try {
            h = new BigFraction(hDouble, 1.0e-10, 10000);
        } catch (final FractionConversionException e2) {
            h = new BigFraction(hDouble, 1.0e-5, 10000);
        }
    }
    final BigFraction[][] Hdata = new BigFraction[m][m];

    /*
     * Start by filling everything with either 0 or 1.
     */
    for (int i = 0; i < m; ++i) {
        for (int j = 0; j < m; ++j) {
            if (i - j + 1 < 0) {
                Hdata[i][j] = BigFraction.ZERO;
            } else {
                Hdata[i][j] = BigFraction.ONE;
            }
        }
    }

    /*
     * Setting up power-array to avoid calculating the same value twice: hPowers[0] = h^1 ...
     * hPowers[m-1] = h^m
     */
    final BigFraction[] hPowers = new BigFraction[m];
    hPowers[0] = h;
    for (int i = 1; i < m; ++i) {
        hPowers[i] = h.multiply(hPowers[i - 1]);
    }

    /*
     * First column and last row has special values (each other reversed).
     */
    for (int i = 0; i < m; ++i) {
        Hdata[i][0] = Hdata[i][0].subtract(hPowers[i]);
        Hdata[m - 1][i] = Hdata[m - 1][i].subtract(hPowers[m - i - 1]);
    }

    /*
     * [1] states: "For 1/2 < h < 1 the bottom left element of the matrix should be (1 - 2*h^m +
     * (2h - 1)^m )/m!" Since 0 <= h < 1, then if h > 1/2 is sufficient to check:
     */
    if (h.compareTo(BigFraction.ONE_HALF) == 1) {
        Hdata[m - 1][0] = Hdata[m - 1][0].add(h.multiply(2).subtract(1).pow(m));
    }

    /*
     * Aside from the first column and last row, the (i, j)-th element is 1/(i - j + 1)! if i -
     * j + 1 >= 0, else 0. 1's and 0's are already put, so only division with (i - j + 1)! is
     * needed in the elements that have 1's. There is no need to calculate (i - j + 1)! and then
     * divide - small steps avoid overflows. Note that i - j + 1 > 0 <=> i + 1 > j instead of
     * j'ing all the way to m. Also note that it is started at g = 2 because dividing by 1 isn't
     * really necessary.
     */
    for (int i = 0; i < m; ++i) {
        for (int j = 0; j < i + 1; ++j) {
            if (i - j + 1 > 0) {
                for (int g = 2; g <= i - j + 1; ++g) {
                    Hdata[i][j] = Hdata[i][j].divide(g);
                }
            }
        }
    }
    return new Array2DRowFieldMatrix<BigFraction>(BigFractionField.getInstance(), Hdata);
}

/***
 * Creates {@code H} of size {@code m x m} as described in [1] (see above)
 * using double-precision. Mirrors {@link #createExactH(double, int)} with
 * {@code double} arithmetic in place of {@code BigFraction}.
 *
 * @param d statistic
 * @param n sample size
 * @return H matrix
 * @throws NumberIsTooLargeException if fractional part is greater than 1
 */
private RealMatrix createRoundedH(double d, int n)
    throws NumberIsTooLargeException {

    final int k = (int) Math.ceil(n * d);
    final int m = 2 * k - 1;
    final double h = k - n * d;
    if (h >= 1) {
        throw new NumberIsTooLargeException(h, 1.0, false);
    }
    final double[][] Hdata = new double[m][m];

    /*
     * Start by filling everything with either 0 or 1.
     */
    for (int i = 0; i < m; ++i) {
        for (int j = 0; j < m; ++j) {
            if (i - j + 1 < 0) {
                Hdata[i][j] = 0;
            } else {
                Hdata[i][j] = 1;
            }
        }
    }

    /*
     * Setting up power-array to avoid calculating the same value twice: hPowers[0] = h^1 ...
     * hPowers[m-1] = h^m
     */
    final double[] hPowers = new double[m];
    hPowers[0] = h;
    for (int i = 1; i < m; ++i) {
        hPowers[i] = h * hPowers[i - 1];
    }

    /*
     * First column and last row has special values (each other reversed).
     */
    for (int i = 0; i < m; ++i) {
        Hdata[i][0] = Hdata[i][0] - hPowers[i];
        Hdata[m - 1][i] -= hPowers[m - i - 1];
    }

    /*
     * [1] states: "For 1/2 < h < 1 the bottom left element of the matrix should be (1 - 2*h^m +
     * (2h - 1)^m )/m!" Since 0 <= h < 1, then if h > 1/2 is sufficient to check:
     */
    if (Double.compare(h, 0.5) > 0) {
        Hdata[m - 1][0] += FastMath.pow(2 * h - 1, m);
    }

    /*
     * Aside from the first column and last row, the (i, j)-th element is 1/(i - j + 1)! if i -
     * j + 1 >= 0, else 0. 1's and 0's are already put, so only division with (i - j + 1)! is
     * needed in the elements that have 1's. There is no need to calculate (i - j + 1)! and then
     * divide - small steps avoid overflows. Note that i - j + 1 > 0 <=> i + 1 > j instead of
     * j'ing all the way to m. Also note that it is started at g = 2 because dividing by 1 isn't
     * really necessary.
     */
    for (int i = 0; i < m; ++i) {
        for (int j = 0; j < i + 1; ++j) {
            if (i - j + 1 > 0) {
                for (int g = 2; g <= i - j + 1; ++g) {
                    Hdata[i][j] /= g;
                }
            }
        }
    }
    return MatrixUtils.createRealMatrix(Hdata);
}

/**
 * Verifies that {@code array} has length at least 2.
 *
 * @param array array to test
 * @throws NullArgumentException if array is null
 * @throws InsufficientDataException if array is too short
 */
private void checkArray(double[] array) {
    if (array == null) {
        throw new NullArgumentException(LocalizedFormats.NULL_NOT_ALLOWED);
    }
    if (array.length < 2) {
        throw new InsufficientDataException(LocalizedFormats.INSUFFICIENT_OBSERVED_POINTS_IN_SAMPLE, array.length,
                                            2);
    }
}

/**
 * Computes \( 1 + 2 \sum_{i=1}^\infty (-1)^i e^{-2 i^2 t^2} \) stopping when successive partial
 * sums are within {@code tolerance} of one another, or when {@code maxIterations} partial sums
 * have been computed. If the sum does not converge before {@code maxIterations} iterations a
 * {@link TooManyIterationsException} is thrown.
+ * + * @param t argument + * @param tolerance Cauchy criterion for partial sums + * @param maxIterations maximum number of partial sums to compute + * @return Kolmogorov sum evaluated at t + * @throws TooManyIterationsException if the series does not converge + */ + public double ksSum(double t, double tolerance, int maxIterations) { + if (t == 0.0) { + return 0.0; + } + + // TODO: for small t (say less than 1), the alternative expansion in part 3 of [1] + // from class javadoc should be used. + + final double x = -2 * t * t; + int sign = -1; + long i = 1; + double partialSum = 0.5d; + double delta = 1; + while (delta > tolerance && i < maxIterations) { + delta = FastMath.exp(x * i * i); + partialSum += sign * delta; + sign *= -1; + i++; + } + if (i == maxIterations) { + throw new TooManyIterationsException(maxIterations); + } + return partialSum * 2; + } + + /** + * Given a d-statistic in the range [0, 1] and the two sample sizes n and m, + * an integral d-statistic in the range [0, n*m] is calculated, that can be used for + * comparison with other integral d-statistics. Depending whether {@code strict} is + * {@code true} or not, the returned value divided by (n*m) is greater than + * (resp greater than or equal to) the given d value (allowing some tolerance). 
+ * + * @param d a d-statistic in the range [0, 1] + * @param n first sample size + * @param m second sample size + * @param strict whether the returned value divided by (n*m) is allowed to be equal to d + * @return the integral d-statistic in the range [0, n*m] + */ + private static long calculateIntegralD(double d, int n, int m, boolean strict) { + final double tol = 1e-12; // d-values within tol of one another are considered equal + long nm = n * (long)m; + long upperBound = (long)FastMath.ceil((d - tol) * nm); + long lowerBound = (long)FastMath.floor((d + tol) * nm); + if (strict && lowerBound == upperBound) { + return upperBound + 1l; + } + else { + return upperBound; + } + } + + /** + * Computes \(P(D_{n,m} > d)\) if {@code strict} is {@code true}; otherwise \(P(D_{n,m} \ge + * d)\), where \(D_{n,m}\) is the 2-sample Kolmogorov-Smirnov statistic. See + * {@link #kolmogorovSmirnovStatistic(double[], double[])} for the definition of \(D_{n,m}\). + * <p> + * The returned probability is exact, implemented by unwinding the recursive function + * definitions presented in [4] (class javadoc). + * </p> + * + * @param d D-statistic value + * @param n first sample size + * @param m second sample size + * @param strict whether or not the probability to compute is expressed as a strict inequality + * @return probability that a randomly selected m-n partition of m + n generates \(D_{n,m}\) + * greater than (resp. greater than or equal to) {@code d} + */ + public double exactP(double d, int n, int m, boolean strict) { + return 1 - n(m, n, m, n, calculateIntegralD(d, m, n, strict), strict) / + CombinatoricsUtils.binomialCoefficientDouble(n + m, m); + } + + /** + * Uses the Kolmogorov-Smirnov distribution to approximate \(P(D_{n,m} > d)\) where \(D_{n,m}\) + * is the 2-sample Kolmogorov-Smirnov statistic. See + * {@link #kolmogorovSmirnovStatistic(double[], double[])} for the definition of \(D_{n,m}\). 
+ * <p> + * Specifically, what is returned is \(1 - k(d \sqrt{mn / (m + n)})\) where \(k(t) = 1 + 2 + * \sum_{i=1}^\infty (-1)^i e^{-2 i^2 t^2}\). See {@link #ksSum(double, double, int)} for + * details on how convergence of the sum is determined. This implementation passes {@code ksSum} + * {@value #KS_SUM_CAUCHY_CRITERION} as {@code tolerance} and + * {@value #MAXIMUM_PARTIAL_SUM_COUNT} as {@code maxIterations}. + * </p> + * + * @param d D-statistic value + * @param n first sample size + * @param m second sample size + * @return approximate probability that a randomly selected m-n partition of m + n generates + * \(D_{n,m}\) greater than {@code d} + */ + public double approximateP(double d, int n, int m) { + final double dm = m; + final double dn = n; + return 1 - ksSum(d * FastMath.sqrt((dm * dn) / (dm + dn)), + KS_SUM_CAUCHY_CRITERION, MAXIMUM_PARTIAL_SUM_COUNT); + } + + /** + * Fills a boolean array randomly with a fixed number of {@code true} values. + * The method uses a simplified version of the Fisher-Yates shuffle algorithm. + * By processing first the {@code true} values followed by the remaining {@code false} values + * less random numbers need to be generated. The method is optimized for the case + * that the number of {@code true} values is larger than or equal to the number of + * {@code false} values. + * + * @param b boolean array + * @param numberOfTrueValues number of {@code true} values the boolean array should finally have + * @param rng random data generator + */ + static void fillBooleanArrayRandomlyWithFixedNumberTrueValues(final boolean[] b, final int numberOfTrueValues, final RandomGenerator rng) { + Arrays.fill(b, true); + for (int k = numberOfTrueValues; k < b.length; k++) { + final int r = rng.nextInt(k + 1); + b[(b[r]) ? r : k] = false; + } + } + + /** + * Uses Monte Carlo simulation to approximate \(P(D_{n,m} > d)\) where \(D_{n,m}\) is the + * 2-sample Kolmogorov-Smirnov statistic. 
See + * {@link #kolmogorovSmirnovStatistic(double[], double[])} for the definition of \(D_{n,m}\). + * <p> + * The simulation generates {@code iterations} random partitions of {@code m + n} into an + * {@code n} set and an {@code m} set, computing \(D_{n,m}\) for each partition and returning + * the proportion of values that are greater than {@code d}, or greater than or equal to + * {@code d} if {@code strict} is {@code false}. + * </p> + * + * @param d D-statistic value + * @param n first sample size + * @param m second sample size + * @param iterations number of random partitions to generate + * @param strict whether or not the probability to compute is expressed as a strict inequality + * @return proportion of randomly generated m-n partitions of m + n that result in \(D_{n,m}\) + * greater than (resp. greater than or equal to) {@code d} + */ + public double monteCarloP(final double d, final int n, final int m, final boolean strict, + final int iterations) { + return integralMonteCarloP(calculateIntegralD(d, n, m, strict), n, m, iterations); + } + + /** + * Uses Monte Carlo simulation to approximate \(P(D_{n,m} >= d/(n*m))\) where \(D_{n,m}\) is the + * 2-sample Kolmogorov-Smirnov statistic. + * <p> + * Here d is the D-statistic represented as long value. + * The real D-statistic is obtained by dividing d by n*m. + * See also {@link #monteCarloP(double, int, int, boolean, int)}. 
+ * + * @param d integral D-statistic + * @param n first sample size + * @param m second sample size + * @param iterations number of random partitions to generate + * @return proportion of randomly generated m-n partitions of m + n that result in \(D_{n,m}\) + * greater than or equal to {@code d/(n*m))} + */ + private double integralMonteCarloP(final long d, final int n, final int m, final int iterations) { + + // ensure that nn is always the max of (n, m) to require fewer random numbers + final int nn = FastMath.max(n, m); + final int mm = FastMath.min(n, m); + final int sum = nn + mm; + + int tail = 0; + final boolean b[] = new boolean[sum]; + for (int i = 0; i < iterations; i++) { + fillBooleanArrayRandomlyWithFixedNumberTrueValues(b, nn, rng); + long curD = 0l; + for(int j = 0; j < b.length; ++j) { + if (b[j]) { + curD += mm; + if (curD >= d) { + tail++; + break; + } + } else { + curD -= nn; + if (curD <= -d) { + tail++; + break; + } + } + } + } + return (double) tail / iterations; + } + + /** + * If there are no ties in the combined dataset formed from x and y, this + * method is a no-op. If there are ties, a uniform random deviate in + * (-minDelta / 2, minDelta / 2) - {0} is added to each value in x and y, where + * minDelta is the minimum difference between unequal values in the combined + * sample. A fixed seed is used to generate the jitter, so repeated activations + * with the same input arrays result in the same values. + * + * NOTE: if there are ties in the data, this method overwrites the data in + * x and y with the jittered values. 
 *
 * @param x first sample
 * @param y second sample
 */
private static void fixTies(double[] x, double[] y) {
    final double[] values = MathArrays.unique(MathArrays.concatenate(x,y));
    if (values.length == x.length + y.length) {
        return;  // There are no ties
    }

    // Find the smallest difference between values, or 1 if all values are the same
    double minDelta = 1;
    double prev = values[0];
    double delta = 1;
    for (int i = 1; i < values.length; i++) {
        // NOTE(review): prev - values[i] is only non-negative if MathArrays.unique
        // returns the values in descending order - confirm; otherwise minDelta could
        // go negative and the jitter bounds below would be inverted.
        delta = prev - values[i];
        if (delta < minDelta) {
            minDelta = delta;
        }
        prev = values[i];
    }
    minDelta /= 2;

    // Add jitter using a fixed seed (so same arguments always give same results),
    // low-initialization-overhead generator
    final RealDistribution dist =
            new UniformRealDistribution(new JDKRandomGenerator(100), -minDelta, minDelta);

    // It is theoretically possible that jitter does not break ties, so repeat
    // until all ties are gone. Bound the loop and throw MIE if bound is exceeded.
    int ct = 0;
    boolean ties = true;
    do {
        jitter(x, dist);
        jitter(y, dist);
        ties = hasTies(x, y);
        ct++;
    } while (ties && ct < 1000);
    if (ties) {
        throw new MathInternalError(); // Should never happen
    }
}

/**
 * Returns true iff there are ties in the combined sample
 * formed from x and y.
 *
 * @param x first sample
 * @param y second sample
 * @return true if x and y together contain ties
 */
private static boolean hasTies(double[] x, double[] y) {
    // A duplicate is detected when HashSet.add returns false (boxed-Double equality).
    final HashSet<Double> values = new HashSet<Double>();
    for (int i = 0; i < x.length; i++) {
        if (!values.add(x[i])) {
            return true;
        }
    }
    for (int i = 0; i < y.length; i++) {
        if (!values.add(y[i])) {
            return true;
        }
    }
    return false;
}

/**
 * Adds random jitter to {@code data} using deviates sampled from {@code dist}.
 * <p>
 * Note that jitter is applied in-place - i.e., the array
 * values are overwritten with the result of applying jitter.</p>
 *
 * @param data input/output data array - entries overwritten by the method
 * @param dist probability distribution to sample for jitter values
 * @throws NullPointerException if either of the parameters is null
 */
private static void jitter(double[] data, RealDistribution dist) {
    for (int i = 0; i < data.length; i++) {
        data[i] += dist.sample();
    }
}

/**
 * The function C(i, j) defined in [4] (class javadoc), formula (5.5).
 * defined to return 1 if |i/n - j/m| <= c; 0 otherwise. Here c is scaled up
 * and recoded as a long to avoid rounding errors in comparison tests, so what
 * is actually tested is |im - jn| <= cmn.
 *
 * @param i first path parameter
 * @param j second path parameter
 * @param m first sample size
 * @param n second sample size
 * @param cmn integral D-statistic (see {@link #calculateIntegralD(double, int, int, boolean)})
 * @param strict whether or not the null hypothesis uses strict inequality
 * @return C(i,j) for given m, n, c
 */
private static int c(int i, int j, int m, int n, long cmn, boolean strict) {
    // Long arithmetic keeps i*n and j*m exact.
    // NOTE(review): <= is used for the strict case and < for the non-strict case;
    // this apparent inversion matches the complement formed by exactP - confirm
    // against [4] before changing.
    if (strict) {
        return FastMath.abs(i*(long)n - j*(long)m) <= cmn ? 1 : 0;
    }
    return FastMath.abs(i*(long)n - j*(long)m) < cmn ? 1 : 0;
}

/**
 * The function N(i, j) defined in [4] (class javadoc).
 * Returns the number of paths over the lattice {(i,j) : 0 <= i <= n, 0 <= j <= m}
 * from (0,0) to (i,j) satisfying C(h,k, m, n, c) = 1 for each (h,k) on the path.
 * The return value is integral, but subject to overflow, so it is maintained and
 * returned as a double.
 *
 * @param i first path parameter
 * @param j second path parameter
 * @param m first sample size
 * @param n second sample size
 * @param cnm integral D-statistic (see {@link #calculateIntegralD(double, int, int, boolean)})
 * @param strict whether or not the null hypothesis uses strict inequality
 * @return number or paths to (i, j) from (0,0) representing D-values as large as c for given m, n
 */
private static double n(int i, int j, int m, int n, long cnm, boolean strict) {
    /*
     * Unwind the recursive definition given in [4].
     * Compute n(1,1), n(1,2)...n(2,1), n(2,2)... up to n(i,j), one row at a time.
     * When n(i,*) are being computed, lag[] holds the values of n(i - 1, *).
     */
    final double[] lag = new double[n];
    double last = 0;
    // Seed row 0 with the indicator C(0, k + 1).
    for (int k = 0; k < n; k++) {
        lag[k] = c(0, k + 1, m, n, cnm, strict);
    }
    for (int k = 1; k <= i; k++) {
        // Column 0 of row k.
        last = c(k, 0, m, n, cnm, strict);
        for (int l = 1; l <= j; l++) {
            // Paths reaching (k, l) come from (k-1, l) [lag] or (k, l-1) [last],
            // and count only when C(k, l) = 1.
            lag[l - 1] = c(k, l, m, n, cnm, strict) * (last + lag[l - 1]);
            last = lag[l - 1];
        }
    }
    return last;
}
}
diff --git a/src/main/java/org/apache/commons/math3/stat/inference/MannWhitneyUTest.java b/src/main/java/org/apache/commons/math3/stat/inference/MannWhitneyUTest.java
new file mode 100644
index 0000000..82fddb3
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/stat/inference/MannWhitneyUTest.java
@@ -0,0 +1,238 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.math3.stat.inference; + +import org.apache.commons.math3.distribution.NormalDistribution; +import org.apache.commons.math3.exception.ConvergenceException; +import org.apache.commons.math3.exception.MaxCountExceededException; +import org.apache.commons.math3.exception.NoDataException; +import org.apache.commons.math3.exception.NullArgumentException; +import org.apache.commons.math3.stat.ranking.NaNStrategy; +import org.apache.commons.math3.stat.ranking.NaturalRanking; +import org.apache.commons.math3.stat.ranking.TiesStrategy; +import org.apache.commons.math3.util.FastMath; + +/** + * An implementation of the Mann-Whitney U test (also called Wilcoxon rank-sum test). + * + */ +public class MannWhitneyUTest { + + /** Ranking algorithm. */ + private NaturalRanking naturalRanking; + + /** + * Create a test instance using where NaN's are left in place and ties get + * the average of applicable ranks. Use this unless you are very sure of + * what you are doing. + */ + public MannWhitneyUTest() { + naturalRanking = new NaturalRanking(NaNStrategy.FIXED, + TiesStrategy.AVERAGE); + } + + /** + * Create a test instance using the given strategies for NaN's and ties. + * Only use this if you are sure of what you are doing. 
+ * + * @param nanStrategy + * specifies the strategy that should be used for Double.NaN's + * @param tiesStrategy + * specifies the strategy that should be used for ties + */ + public MannWhitneyUTest(final NaNStrategy nanStrategy, + final TiesStrategy tiesStrategy) { + naturalRanking = new NaturalRanking(nanStrategy, tiesStrategy); + } + + /** + * Ensures that the provided arrays fulfills the assumptions. + * + * @param x first sample + * @param y second sample + * @throws NullArgumentException if {@code x} or {@code y} are {@code null}. + * @throws NoDataException if {@code x} or {@code y} are zero-length. + */ + private void ensureDataConformance(final double[] x, final double[] y) + throws NullArgumentException, NoDataException { + + if (x == null || + y == null) { + throw new NullArgumentException(); + } + if (x.length == 0 || + y.length == 0) { + throw new NoDataException(); + } + } + + /** Concatenate the samples into one array. + * @param x first sample + * @param y second sample + * @return concatenated array + */ + private double[] concatenateSamples(final double[] x, final double[] y) { + final double[] z = new double[x.length + y.length]; + + System.arraycopy(x, 0, z, 0, x.length); + System.arraycopy(y, 0, z, x.length, y.length); + + return z; + } + + /** + * Computes the <a + * href="http://en.wikipedia.org/wiki/Mann%E2%80%93Whitney_U"> Mann-Whitney + * U statistic</a> comparing mean for two independent samples possibly of + * different length. + * <p> + * This statistic can be used to perform a Mann-Whitney U test evaluating + * the null hypothesis that the two independent samples has equal mean. + * </p> + * <p> + * Let X<sub>i</sub> denote the i'th individual of the first sample and + * Y<sub>j</sub> the j'th individual in the second sample. Note that the + * samples would often have different length. 
+ * </p> + * <p> + * <strong>Preconditions</strong>: + * <ul> + * <li>All observations in the two samples are independent.</li> + * <li>The observations are at least ordinal (continuous are also ordinal).</li> + * </ul> + * </p> + * + * @param x the first sample + * @param y the second sample + * @return Mann-Whitney U statistic (maximum of U<sup>x</sup> and U<sup>y</sup>) + * @throws NullArgumentException if {@code x} or {@code y} are {@code null}. + * @throws NoDataException if {@code x} or {@code y} are zero-length. + */ + public double mannWhitneyU(final double[] x, final double[] y) + throws NullArgumentException, NoDataException { + + ensureDataConformance(x, y); + + final double[] z = concatenateSamples(x, y); + final double[] ranks = naturalRanking.rank(z); + + double sumRankX = 0; + + /* + * The ranks for x is in the first x.length entries in ranks because x + * is in the first x.length entries in z + */ + for (int i = 0; i < x.length; ++i) { + sumRankX += ranks[i]; + } + + /* + * U1 = R1 - (n1 * (n1 + 1)) / 2 where R1 is sum of ranks for sample 1, + * e.g. x, n1 is the number of observations in sample 1. 
+ */ + final double U1 = sumRankX - ((long) x.length * (x.length + 1)) / 2; + + /* + * It can be shown that U1 + U2 = n1 * n2 + */ + final double U2 = (long) x.length * y.length - U1; + + return FastMath.max(U1, U2); + } + + /** + * @param Umin smallest Mann-Whitney U value + * @param n1 number of subjects in first sample + * @param n2 number of subjects in second sample + * @return two-sided asymptotic p-value + * @throws ConvergenceException if the p-value can not be computed + * due to a convergence error + * @throws MaxCountExceededException if the maximum number of + * iterations is exceeded + */ + private double calculateAsymptoticPValue(final double Umin, + final int n1, + final int n2) + throws ConvergenceException, MaxCountExceededException { + + /* long multiplication to avoid overflow (double not used due to efficiency + * and to avoid precision loss) + */ + final long n1n2prod = (long) n1 * n2; + + // http://en.wikipedia.org/wiki/Mann%E2%80%93Whitney_U#Normal_approximation + final double EU = n1n2prod / 2.0; + final double VarU = n1n2prod * (n1 + n2 + 1) / 12.0; + + final double z = (Umin - EU) / FastMath.sqrt(VarU); + + // No try-catch or advertised exception because args are valid + // pass a null rng to avoid unneeded overhead as we will not sample from this distribution + final NormalDistribution standardNormal = new NormalDistribution(null, 0, 1); + + return 2 * standardNormal.cumulativeProbability(z); + } + + /** + * Returns the asymptotic <i>observed significance level</i>, or <a href= + * "http://www.cas.lancs.ac.uk/glossary_v1.1/hyptest.html#pvalue"> + * p-value</a>, associated with a <a + * href="http://en.wikipedia.org/wiki/Mann%E2%80%93Whitney_U"> Mann-Whitney + * U statistic</a> comparing mean for two independent samples. + * <p> + * Let X<sub>i</sub> denote the i'th individual of the first sample and + * Y<sub>j</sub> the j'th individual in the second sample. Note that the + * samples would often have different length. 
+ * </p> + * <p> + * <strong>Preconditions</strong>: + * <ul> + * <li>All observations in the two samples are independent.</li> + * <li>The observations are at least ordinal (continuous are also ordinal).</li> + * </ul> + * </p><p> + * Ties give rise to biased variance at the moment. See e.g. <a + * href="http://mlsc.lboro.ac.uk/resources/statistics/Mannwhitney.pdf" + * >http://mlsc.lboro.ac.uk/resources/statistics/Mannwhitney.pdf</a>.</p> + * + * @param x the first sample + * @param y the second sample + * @return asymptotic p-value + * @throws NullArgumentException if {@code x} or {@code y} are {@code null}. + * @throws NoDataException if {@code x} or {@code y} are zero-length. + * @throws ConvergenceException if the p-value can not be computed due to a + * convergence error + * @throws MaxCountExceededException if the maximum number of iterations + * is exceeded + */ + public double mannWhitneyUTest(final double[] x, final double[] y) + throws NullArgumentException, NoDataException, + ConvergenceException, MaxCountExceededException { + + ensureDataConformance(x, y); + + final double Umax = mannWhitneyU(x, y); + + /* + * It can be shown that U1 + U2 = n1 * n2 + */ + final double Umin = (long) x.length * y.length - Umax; + + return calculateAsymptoticPValue(Umin, x.length, y.length); + } + +} diff --git a/src/main/java/org/apache/commons/math3/stat/inference/OneWayAnova.java b/src/main/java/org/apache/commons/math3/stat/inference/OneWayAnova.java new file mode 100644 index 0000000..d0c5fc1 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/stat/inference/OneWayAnova.java @@ -0,0 +1,355 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.math3.stat.inference; + +import java.util.ArrayList; +import java.util.Collection; + +import org.apache.commons.math3.distribution.FDistribution; +import org.apache.commons.math3.exception.ConvergenceException; +import org.apache.commons.math3.exception.DimensionMismatchException; +import org.apache.commons.math3.exception.MaxCountExceededException; +import org.apache.commons.math3.exception.NullArgumentException; +import org.apache.commons.math3.exception.OutOfRangeException; +import org.apache.commons.math3.exception.util.LocalizedFormats; +import org.apache.commons.math3.stat.descriptive.SummaryStatistics; +import org.apache.commons.math3.util.MathUtils; + +/** + * Implements one-way ANOVA (analysis of variance) statistics. + * + * <p> Tests for differences between two or more categories of univariate data + * (for example, the body mass index of accountants, lawyers, doctors and + * computer programmers). When two categories are given, this is equivalent to + * the {@link org.apache.commons.math3.stat.inference.TTest}. 
+ * </p><p> + * Uses the {@link org.apache.commons.math3.distribution.FDistribution + * commons-math F Distribution implementation} to estimate exact p-values.</p> + * <p>This implementation is based on a description at + * http://faculty.vassar.edu/lowry/ch13pt1.html</p> + * <pre> + * Abbreviations: bg = between groups, + * wg = within groups, + * ss = sum squared deviations + * </pre> + * + * @since 1.2 + */ +public class OneWayAnova { + + /** + * Default constructor. + */ + public OneWayAnova() { + } + + /** + * Computes the ANOVA F-value for a collection of <code>double[]</code> + * arrays. + * + * <p><strong>Preconditions</strong>: <ul> + * <li>The categoryData <code>Collection</code> must contain + * <code>double[]</code> arrays.</li> + * <li> There must be at least two <code>double[]</code> arrays in the + * <code>categoryData</code> collection and each of these arrays must + * contain at least two values.</li></ul></p><p> + * This implementation computes the F statistic using the definitional + * formula<pre> + * F = msbg/mswg</pre> + * where<pre> + * msbg = between group mean square + * mswg = within group mean square</pre> + * are as defined <a href="http://faculty.vassar.edu/lowry/ch13pt1.html"> + * here</a></p> + * + * @param categoryData <code>Collection</code> of <code>double[]</code> + * arrays each containing data for one category + * @return Fvalue + * @throws NullArgumentException if <code>categoryData</code> is <code>null</code> + * @throws DimensionMismatchException if the length of the <code>categoryData</code> + * array is less than 2 or a contained <code>double[]</code> array does not have + * at least two values + */ + public double anovaFValue(final Collection<double[]> categoryData) + throws NullArgumentException, DimensionMismatchException { + + AnovaStats a = anovaStats(categoryData); + return a.F; + + } + + /** + * Computes the ANOVA P-value for a collection of <code>double[]</code> + * arrays. 
+ * + * <p><strong>Preconditions</strong>: <ul> + * <li>The categoryData <code>Collection</code> must contain + * <code>double[]</code> arrays.</li> + * <li> There must be at least two <code>double[]</code> arrays in the + * <code>categoryData</code> collection and each of these arrays must + * contain at least two values.</li></ul></p><p> + * This implementation uses the + * {@link org.apache.commons.math3.distribution.FDistribution + * commons-math F Distribution implementation} to estimate the exact + * p-value, using the formula<pre> + * p = 1 - cumulativeProbability(F)</pre> + * where <code>F</code> is the F value and <code>cumulativeProbability</code> + * is the commons-math implementation of the F distribution.</p> + * + * @param categoryData <code>Collection</code> of <code>double[]</code> + * arrays each containing data for one category + * @return Pvalue + * @throws NullArgumentException if <code>categoryData</code> is <code>null</code> + * @throws DimensionMismatchException if the length of the <code>categoryData</code> + * array is less than 2 or a contained <code>double[]</code> array does not have + * at least two values + * @throws ConvergenceException if the p-value can not be computed due to a convergence error + * @throws MaxCountExceededException if the maximum number of iterations is exceeded + */ + public double anovaPValue(final Collection<double[]> categoryData) + throws NullArgumentException, DimensionMismatchException, + ConvergenceException, MaxCountExceededException { + + final AnovaStats a = anovaStats(categoryData); + // No try-catch or advertised exception because args are valid + // pass a null rng to avoid unneeded overhead as we will not sample from this distribution + final FDistribution fdist = new FDistribution(null, a.dfbg, a.dfwg); + return 1.0 - fdist.cumulativeProbability(a.F); + + } + + /** + * Computes the ANOVA P-value for a collection of {@link SummaryStatistics}. 
+ * + * <p><strong>Preconditions</strong>: <ul> + * <li>The categoryData <code>Collection</code> must contain + * {@link SummaryStatistics}.</li> + * <li> There must be at least two {@link SummaryStatistics} in the + * <code>categoryData</code> collection and each of these statistics must + * contain at least two values.</li></ul></p><p> + * This implementation uses the + * {@link org.apache.commons.math3.distribution.FDistribution + * commons-math F Distribution implementation} to estimate the exact + * p-value, using the formula<pre> + * p = 1 - cumulativeProbability(F)</pre> + * where <code>F</code> is the F value and <code>cumulativeProbability</code> + * is the commons-math implementation of the F distribution.</p> + * + * @param categoryData <code>Collection</code> of {@link SummaryStatistics} + * each containing data for one category + * @param allowOneElementData if true, allow computation for one catagory + * only or for one data element per category + * @return Pvalue + * @throws NullArgumentException if <code>categoryData</code> is <code>null</code> + * @throws DimensionMismatchException if the length of the <code>categoryData</code> + * array is less than 2 or a contained {@link SummaryStatistics} does not have + * at least two values + * @throws ConvergenceException if the p-value can not be computed due to a convergence error + * @throws MaxCountExceededException if the maximum number of iterations is exceeded + * @since 3.2 + */ + public double anovaPValue(final Collection<SummaryStatistics> categoryData, + final boolean allowOneElementData) + throws NullArgumentException, DimensionMismatchException, + ConvergenceException, MaxCountExceededException { + + final AnovaStats a = anovaStats(categoryData, allowOneElementData); + // pass a null rng to avoid unneeded overhead as we will not sample from this distribution + final FDistribution fdist = new FDistribution(null, a.dfbg, a.dfwg); + return 1.0 - fdist.cumulativeProbability(a.F); + + } + + /** + * 
This method calls the method that actually does the calculations (except + * P-value). + * + * @param categoryData + * <code>Collection</code> of <code>double[]</code> arrays each + * containing data for one category + * @return computed AnovaStats + * @throws NullArgumentException + * if <code>categoryData</code> is <code>null</code> + * @throws DimensionMismatchException + * if the length of the <code>categoryData</code> array is less + * than 2 or a contained <code>double[]</code> array does not + * contain at least two values + */ + private AnovaStats anovaStats(final Collection<double[]> categoryData) + throws NullArgumentException, DimensionMismatchException { + + MathUtils.checkNotNull(categoryData); + + final Collection<SummaryStatistics> categoryDataSummaryStatistics = + new ArrayList<SummaryStatistics>(categoryData.size()); + + // convert arrays to SummaryStatistics + for (final double[] data : categoryData) { + final SummaryStatistics dataSummaryStatistics = new SummaryStatistics(); + categoryDataSummaryStatistics.add(dataSummaryStatistics); + for (final double val : data) { + dataSummaryStatistics.addValue(val); + } + } + + return anovaStats(categoryDataSummaryStatistics, false); + + } + + /** + * Performs an ANOVA test, evaluating the null hypothesis that there + * is no difference among the means of the data categories. + * + * <p><strong>Preconditions</strong>: <ul> + * <li>The categoryData <code>Collection</code> must contain + * <code>double[]</code> arrays.</li> + * <li> There must be at least two <code>double[]</code> arrays in the + * <code>categoryData</code> collection and each of these arrays must + * contain at least two values.</li> + * <li>alpha must be strictly greater than 0 and less than or equal to 0.5. 
+ * </li></ul></p><p> + * This implementation uses the + * {@link org.apache.commons.math3.distribution.FDistribution + * commons-math F Distribution implementation} to estimate the exact + * p-value, using the formula<pre> + * p = 1 - cumulativeProbability(F)</pre> + * where <code>F</code> is the F value and <code>cumulativeProbability</code> + * is the commons-math implementation of the F distribution.</p> + * <p>True is returned iff the estimated p-value is less than alpha.</p> + * + * @param categoryData <code>Collection</code> of <code>double[]</code> + * arrays each containing data for one category + * @param alpha significance level of the test + * @return true if the null hypothesis can be rejected with + * confidence 1 - alpha + * @throws NullArgumentException if <code>categoryData</code> is <code>null</code> + * @throws DimensionMismatchException if the length of the <code>categoryData</code> + * array is less than 2 or a contained <code>double[]</code> array does not have + * at least two values + * @throws OutOfRangeException if <code>alpha</code> is not in the range (0, 0.5] + * @throws ConvergenceException if the p-value can not be computed due to a convergence error + * @throws MaxCountExceededException if the maximum number of iterations is exceeded + */ + public boolean anovaTest(final Collection<double[]> categoryData, + final double alpha) + throws NullArgumentException, DimensionMismatchException, + OutOfRangeException, ConvergenceException, MaxCountExceededException { + + if ((alpha <= 0) || (alpha > 0.5)) { + throw new OutOfRangeException( + LocalizedFormats.OUT_OF_BOUND_SIGNIFICANCE_LEVEL, + alpha, 0, 0.5); + } + return anovaPValue(categoryData) < alpha; + + } + + /** + * This method actually does the calculations (except P-value). 
+ * + * @param categoryData <code>Collection</code> of <code>double[]</code> + * arrays each containing data for one category + * @param allowOneElementData if true, allow computation for one catagory + * only or for one data element per category + * @return computed AnovaStats + * @throws NullArgumentException if <code>categoryData</code> is <code>null</code> + * @throws DimensionMismatchException if <code>allowOneElementData</code> is false and the number of + * categories is less than 2 or a contained SummaryStatistics does not contain + * at least two values + */ + private AnovaStats anovaStats(final Collection<SummaryStatistics> categoryData, + final boolean allowOneElementData) + throws NullArgumentException, DimensionMismatchException { + + MathUtils.checkNotNull(categoryData); + + if (!allowOneElementData) { + // check if we have enough categories + if (categoryData.size() < 2) { + throw new DimensionMismatchException(LocalizedFormats.TWO_OR_MORE_CATEGORIES_REQUIRED, + categoryData.size(), 2); + } + + // check if each category has enough data + for (final SummaryStatistics array : categoryData) { + if (array.getN() <= 1) { + throw new DimensionMismatchException(LocalizedFormats.TWO_OR_MORE_VALUES_IN_CATEGORY_REQUIRED, + (int) array.getN(), 2); + } + } + } + + int dfwg = 0; + double sswg = 0; + double totsum = 0; + double totsumsq = 0; + int totnum = 0; + + for (final SummaryStatistics data : categoryData) { + + final double sum = data.getSum(); + final double sumsq = data.getSumsq(); + final int num = (int) data.getN(); + totnum += num; + totsum += sum; + totsumsq += sumsq; + + dfwg += num - 1; + final double ss = sumsq - ((sum * sum) / num); + sswg += ss; + } + + final double sst = totsumsq - ((totsum * totsum) / totnum); + final double ssbg = sst - sswg; + final int dfbg = categoryData.size() - 1; + final double msbg = ssbg / dfbg; + final double mswg = sswg / dfwg; + final double F = msbg / mswg; + + return new AnovaStats(dfbg, dfwg, F); + + } + + /** + 
Convenience class to pass dfbg,dfwg,F values around within OneWayAnova. + No get/set methods provided. + */ + private static class AnovaStats { + + /** Degrees of freedom in numerator (between groups). */ + private final int dfbg; + + /** Degrees of freedom in denominator (within groups). */ + private final int dfwg; + + /** Statistic. */ + private final double F; + + /** + * Constructor + * @param dfbg degrees of freedom in numerator (between groups) + * @param dfwg degrees of freedom in denominator (within groups) + * @param F statistic + */ + private AnovaStats(int dfbg, int dfwg, double F) { + this.dfbg = dfbg; + this.dfwg = dfwg; + this.F = F; + } + } + +} diff --git a/src/main/java/org/apache/commons/math3/stat/inference/TTest.java b/src/main/java/org/apache/commons/math3/stat/inference/TTest.java new file mode 100644 index 0000000..b0f76f6 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/stat/inference/TTest.java @@ -0,0 +1,1184 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.commons.math3.stat.inference; + +import org.apache.commons.math3.distribution.TDistribution; +import org.apache.commons.math3.exception.DimensionMismatchException; +import org.apache.commons.math3.exception.MathIllegalArgumentException; +import org.apache.commons.math3.exception.MaxCountExceededException; +import org.apache.commons.math3.exception.NoDataException; +import org.apache.commons.math3.exception.NotStrictlyPositiveException; +import org.apache.commons.math3.exception.NullArgumentException; +import org.apache.commons.math3.exception.NumberIsTooSmallException; +import org.apache.commons.math3.exception.OutOfRangeException; +import org.apache.commons.math3.exception.util.LocalizedFormats; +import org.apache.commons.math3.stat.StatUtils; +import org.apache.commons.math3.stat.descriptive.StatisticalSummary; +import org.apache.commons.math3.util.FastMath; + +/** + * An implementation for Student's t-tests. + * <p> + * Tests can be:<ul> + * <li>One-sample or two-sample</li> + * <li>One-sided or two-sided</li> + * <li>Paired or unpaired (for two-sample tests)</li> + * <li>Homoscedastic (equal variance assumption) or heteroscedastic + * (for two sample tests)</li> + * <li>Fixed significance level (boolean-valued) or returning p-values. + * </li></ul></p> + * <p> + * Test statistics are available for all tests. Methods including "Test" in + * in their names perform tests, all other methods return t-statistics. Among + * the "Test" methods, <code>double-</code>valued methods return p-values; + * <code>boolean-</code>valued methods perform fixed significance level tests. + * Significance levels are always specified as numbers between 0 and 0.5 + * (e.g. 
tests at the 95% level use <code>alpha=0.05</code>).</p> + * <p> + * Input to tests can be either <code>double[]</code> arrays or + * {@link StatisticalSummary} instances.</p><p> + * Uses commons-math {@link org.apache.commons.math3.distribution.TDistribution} + * implementation to estimate exact p-values.</p> + * + */ +public class TTest { + /** + * Computes a paired, 2-sample t-statistic based on the data in the input + * arrays. The t-statistic returned is equivalent to what would be returned by + * computing the one-sample t-statistic {@link #t(double, double[])}, with + * <code>mu = 0</code> and the sample array consisting of the (signed) + * differences between corresponding entries in <code>sample1</code> and + * <code>sample2.</code> + * <p> + * <strong>Preconditions</strong>: <ul> + * <li>The input arrays must have the same length and their common length + * must be at least 2. + * </li></ul></p> + * + * @param sample1 array of sample data values + * @param sample2 array of sample data values + * @return t statistic + * @throws NullArgumentException if the arrays are <code>null</code> + * @throws NoDataException if the arrays are empty + * @throws DimensionMismatchException if the length of the arrays is not equal + * @throws NumberIsTooSmallException if the length of the arrays is < 2 + */ + public double pairedT(final double[] sample1, final double[] sample2) + throws NullArgumentException, NoDataException, + DimensionMismatchException, NumberIsTooSmallException { + + checkSampleData(sample1); + checkSampleData(sample2); + double meanDifference = StatUtils.meanDifference(sample1, sample2); + return t(meanDifference, 0, + StatUtils.varianceDifference(sample1, sample2, meanDifference), + sample1.length); + + } + + /** + * Returns the <i>observed significance level</i>, or + * <i> p-value</i>, associated with a paired, two-sample, two-tailed t-test + * based on the data in the input arrays. 
+ * <p> + * The number returned is the smallest significance level + * at which one can reject the null hypothesis that the mean of the paired + * differences is 0 in favor of the two-sided alternative that the mean paired + * difference is not equal to 0. For a one-sided test, divide the returned + * value by 2.</p> + * <p> + * This test is equivalent to a one-sample t-test computed using + * {@link #tTest(double, double[])} with <code>mu = 0</code> and the sample + * array consisting of the signed differences between corresponding elements of + * <code>sample1</code> and <code>sample2.</code></p> + * <p> + * <strong>Usage Note:</strong><br> + * The validity of the p-value depends on the assumptions of the parametric + * t-test procedure, as discussed + * <a href="http://www.basic.nwu.edu/statguidefiles/ttest_unpaired_ass_viol.html"> + * here</a></p> + * <p> + * <strong>Preconditions</strong>: <ul> + * <li>The input array lengths must be the same and their common length must + * be at least 2. 
+ * </li></ul></p> + * + * @param sample1 array of sample data values + * @param sample2 array of sample data values + * @return p-value for t-test + * @throws NullArgumentException if the arrays are <code>null</code> + * @throws NoDataException if the arrays are empty + * @throws DimensionMismatchException if the length of the arrays is not equal + * @throws NumberIsTooSmallException if the length of the arrays is < 2 + * @throws MaxCountExceededException if an error occurs computing the p-value + */ + public double pairedTTest(final double[] sample1, final double[] sample2) + throws NullArgumentException, NoDataException, DimensionMismatchException, + NumberIsTooSmallException, MaxCountExceededException { + + double meanDifference = StatUtils.meanDifference(sample1, sample2); + return tTest(meanDifference, 0, + StatUtils.varianceDifference(sample1, sample2, meanDifference), + sample1.length); + + } + + /** + * Performs a paired t-test evaluating the null hypothesis that the + * mean of the paired differences between <code>sample1</code> and + * <code>sample2</code> is 0 in favor of the two-sided alternative that the + * mean paired difference is not equal to 0, with significance level + * <code>alpha</code>. + * <p> + * Returns <code>true</code> iff the null hypothesis can be rejected with + * confidence <code>1 - alpha</code>. To perform a 1-sided test, use + * <code>alpha * 2</code></p> + * <p> + * <strong>Usage Note:</strong><br> + * The validity of the test depends on the assumptions of the parametric + * t-test procedure, as discussed + * <a href="http://www.basic.nwu.edu/statguidefiles/ttest_unpaired_ass_viol.html"> + * here</a></p> + * <p> + * <strong>Preconditions</strong>: <ul> + * <li>The input array lengths must be the same and their common length + * must be at least 2. 
+ * </li> + * <li> <code> 0 < alpha < 0.5 </code> + * </li></ul></p> + * + * @param sample1 array of sample data values + * @param sample2 array of sample data values + * @param alpha significance level of the test + * @return true if the null hypothesis can be rejected with + * confidence 1 - alpha + * @throws NullArgumentException if the arrays are <code>null</code> + * @throws NoDataException if the arrays are empty + * @throws DimensionMismatchException if the length of the arrays is not equal + * @throws NumberIsTooSmallException if the length of the arrays is < 2 + * @throws OutOfRangeException if <code>alpha</code> is not in the range (0, 0.5] + * @throws MaxCountExceededException if an error occurs computing the p-value + */ + public boolean pairedTTest(final double[] sample1, final double[] sample2, + final double alpha) + throws NullArgumentException, NoDataException, DimensionMismatchException, + NumberIsTooSmallException, OutOfRangeException, MaxCountExceededException { + + checkSignificanceLevel(alpha); + return pairedTTest(sample1, sample2) < alpha; + + } + + /** + * Computes a <a href="http://www.itl.nist.gov/div898/handbook/prc/section2/prc22.htm#formula"> + * t statistic </a> given observed values and a comparison constant. + * <p> + * This statistic can be used to perform a one sample t-test for the mean. + * </p><p> + * <strong>Preconditions</strong>: <ul> + * <li>The observed array length must be at least 2. 
+ * </li></ul></p> + * + * @param mu comparison constant + * @param observed array of values + * @return t statistic + * @throws NullArgumentException if <code>observed</code> is <code>null</code> + * @throws NumberIsTooSmallException if the length of <code>observed</code> is < 2 + */ + public double t(final double mu, final double[] observed) + throws NullArgumentException, NumberIsTooSmallException { + + checkSampleData(observed); + // No try-catch or advertised exception because args have just been checked + return t(StatUtils.mean(observed), mu, StatUtils.variance(observed), + observed.length); + + } + + /** + * Computes a <a href="http://www.itl.nist.gov/div898/handbook/prc/section2/prc22.htm#formula"> + * t statistic </a> to use in comparing the mean of the dataset described by + * <code>sampleStats</code> to <code>mu</code>. + * <p> + * This statistic can be used to perform a one sample t-test for the mean. + * </p><p> + * <strong>Preconditions</strong>: <ul> + * <li><code>observed.getN() ≥ 2</code>. + * </li></ul></p> + * + * @param mu comparison constant + * @param sampleStats DescriptiveStatistics holding sample summary statitstics + * @return t statistic + * @throws NullArgumentException if <code>sampleStats</code> is <code>null</code> + * @throws NumberIsTooSmallException if the number of samples is < 2 + */ + public double t(final double mu, final StatisticalSummary sampleStats) + throws NullArgumentException, NumberIsTooSmallException { + + checkSampleData(sampleStats); + return t(sampleStats.getMean(), mu, sampleStats.getVariance(), + sampleStats.getN()); + + } + + /** + * Computes a 2-sample t statistic, under the hypothesis of equal + * subpopulation variances. To compute a t-statistic without the + * equal variances hypothesis, use {@link #t(double[], double[])}. 
+ * <p>
+ * This statistic can be used to perform a (homoscedastic) two-sample
+ * t-test to compare sample means.</p>
+ * <p>
+ * The t-statistic is</p>
+ * <p>
+ * <code> t = (m1 - m2) / (sqrt(1/n1 +1/n2) sqrt(var))</code>
+ * </p><p>
+ * where <strong><code>n1</code></strong> is the size of first sample;
+ * <strong><code> n2</code></strong> is the size of second sample;
+ * <strong><code> m1</code></strong> is the mean of first sample;
+ * <strong><code> m2</code></strong> is the mean of second sample
+ * and <strong><code>var</code></strong> is the pooled variance estimate:
+ * </p><p>
+ * <code>var = sqrt(((n1 - 1)var1 + (n2 - 1)var2) / ((n1-1) + (n2-1)))</code>
+ * </p><p>
+ * with <strong><code>var1</code></strong> the variance of the first sample and
+ * <strong><code>var2</code></strong> the variance of the second sample.
+ * </p><p>
+ * <strong>Preconditions</strong>: <ul>
+ * <li>The observed array lengths must both be at least 2.
+ * </li></ul></p>
+ *
+ * @param sample1 array of sample data values
+ * @param sample2 array of sample data values
+ * @return t statistic
+ * @throws NullArgumentException if the arrays are <code>null</code>
+ * @throws NumberIsTooSmallException if the length of the arrays is < 2
+ */
+ public double homoscedasticT(final double[] sample1, final double[] sample2)
+ throws NullArgumentException, NumberIsTooSmallException {
+
+ checkSampleData(sample1);
+ checkSampleData(sample2);
+ // No try-catch or advertised exception because args have just been checked
+ return homoscedasticT(StatUtils.mean(sample1), StatUtils.mean(sample2),
+ StatUtils.variance(sample1), StatUtils.variance(sample2),
+ sample1.length, sample2.length);
+
+ }
+
+ /**
+ * Computes a 2-sample t statistic, without the hypothesis of equal
+ * subpopulation variances. To compute a t-statistic assuming equal
+ * variances, use {@link #homoscedasticT(double[], double[])}. 
+ * <p>
+ * This statistic can be used to perform a two-sample t-test to compare
+ * sample means.</p>
+ * <p>
+ * The t-statistic is</p>
+ * <p>
+ * <code> t = (m1 - m2) / sqrt(var1/n1 + var2/n2)</code>
+ * </p><p>
+ * where <strong><code>n1</code></strong> is the size of the first sample
+ * <strong><code> n2</code></strong> is the size of the second sample;
+ * <strong><code> m1</code></strong> is the mean of the first sample;
+ * <strong><code> m2</code></strong> is the mean of the second sample;
+ * <strong><code> var1</code></strong> is the variance of the first sample;
+ * <strong><code> var2</code></strong> is the variance of the second sample;
+ * </p><p>
+ * <strong>Preconditions</strong>: <ul>
+ * <li>The observed array lengths must both be at least 2.
+ * </li></ul></p>
+ *
+ * @param sample1 array of sample data values
+ * @param sample2 array of sample data values
+ * @return t statistic
+ * @throws NullArgumentException if the arrays are <code>null</code>
+ * @throws NumberIsTooSmallException if the length of the arrays is < 2
+ */
+ public double t(final double[] sample1, final double[] sample2)
+ throws NullArgumentException, NumberIsTooSmallException {
+
+ checkSampleData(sample1);
+ checkSampleData(sample2);
+ // No try-catch or advertised exception because args have just been checked
+ return t(StatUtils.mean(sample1), StatUtils.mean(sample2),
+ StatUtils.variance(sample1), StatUtils.variance(sample2),
+ sample1.length, sample2.length);
+
+ }
+
+ /**
+ * Computes a 2-sample t statistic, comparing the means of the datasets
+ * described by two {@link StatisticalSummary} instances, without the
+ * assumption of equal subpopulation variances. Use
+ * {@link #homoscedasticT(StatisticalSummary, StatisticalSummary)} to
+ * compute a t-statistic under the equal variances assumption. 
+ * <p> + * This statistic can be used to perform a two-sample t-test to compare + * sample means.</p> + * <p> + * The returned t-statistic is</p> + * <p> + * <code> t = (m1 - m2) / sqrt(var1/n1 + var2/n2)</code> + * </p><p> + * where <strong><code>n1</code></strong> is the size of the first sample; + * <strong><code> n2</code></strong> is the size of the second sample; + * <strong><code> m1</code></strong> is the mean of the first sample; + * <strong><code> m2</code></strong> is the mean of the second sample + * <strong><code> var1</code></strong> is the variance of the first sample; + * <strong><code> var2</code></strong> is the variance of the second sample + * </p><p> + * <strong>Preconditions</strong>: <ul> + * <li>The datasets described by the two Univariates must each contain + * at least 2 observations. + * </li></ul></p> + * + * @param sampleStats1 StatisticalSummary describing data from the first sample + * @param sampleStats2 StatisticalSummary describing data from the second sample + * @return t statistic + * @throws NullArgumentException if the sample statistics are <code>null</code> + * @throws NumberIsTooSmallException if the number of samples is < 2 + */ + public double t(final StatisticalSummary sampleStats1, + final StatisticalSummary sampleStats2) + throws NullArgumentException, NumberIsTooSmallException { + + checkSampleData(sampleStats1); + checkSampleData(sampleStats2); + return t(sampleStats1.getMean(), sampleStats2.getMean(), + sampleStats1.getVariance(), sampleStats2.getVariance(), + sampleStats1.getN(), sampleStats2.getN()); + + } + + /** + * Computes a 2-sample t statistic, comparing the means of the datasets + * described by two {@link StatisticalSummary} instances, under the + * assumption of equal subpopulation variances. To compute a t-statistic + * without the equal variances assumption, use + * {@link #t(StatisticalSummary, StatisticalSummary)}. 
+ * <p> + * This statistic can be used to perform a (homoscedastic) two-sample + * t-test to compare sample means.</p> + * <p> + * The t-statistic returned is</p> + * <p> + * <code> t = (m1 - m2) / (sqrt(1/n1 +1/n2) sqrt(var))</code> + * </p><p> + * where <strong><code>n1</code></strong> is the size of first sample; + * <strong><code> n2</code></strong> is the size of second sample; + * <strong><code> m1</code></strong> is the mean of first sample; + * <strong><code> m2</code></strong> is the mean of second sample + * and <strong><code>var</code></strong> is the pooled variance estimate: + * </p><p> + * <code>var = sqrt(((n1 - 1)var1 + (n2 - 1)var2) / ((n1-1) + (n2-1)))</code> + * </p><p> + * with <strong><code>var1</code></strong> the variance of the first sample and + * <strong><code>var2</code></strong> the variance of the second sample. + * </p><p> + * <strong>Preconditions</strong>: <ul> + * <li>The datasets described by the two Univariates must each contain + * at least 2 observations. 
+ * </li></ul></p> + * + * @param sampleStats1 StatisticalSummary describing data from the first sample + * @param sampleStats2 StatisticalSummary describing data from the second sample + * @return t statistic + * @throws NullArgumentException if the sample statistics are <code>null</code> + * @throws NumberIsTooSmallException if the number of samples is < 2 + */ + public double homoscedasticT(final StatisticalSummary sampleStats1, + final StatisticalSummary sampleStats2) + throws NullArgumentException, NumberIsTooSmallException { + + checkSampleData(sampleStats1); + checkSampleData(sampleStats2); + return homoscedasticT(sampleStats1.getMean(), sampleStats2.getMean(), + sampleStats1.getVariance(), sampleStats2.getVariance(), + sampleStats1.getN(), sampleStats2.getN()); + + } + + /** + * Returns the <i>observed significance level</i>, or + * <i>p-value</i>, associated with a one-sample, two-tailed t-test + * comparing the mean of the input array with the constant <code>mu</code>. + * <p> + * The number returned is the smallest significance level + * at which one can reject the null hypothesis that the mean equals + * <code>mu</code> in favor of the two-sided alternative that the mean + * is different from <code>mu</code>. For a one-sided test, divide the + * returned value by 2.</p> + * <p> + * <strong>Usage Note:</strong><br> + * The validity of the test depends on the assumptions of the parametric + * t-test procedure, as discussed + * <a href="http://www.basic.nwu.edu/statguidefiles/ttest_unpaired_ass_viol.html">here</a> + * </p><p> + * <strong>Preconditions</strong>: <ul> + * <li>The observed array length must be at least 2. 
+ * </li></ul></p> + * + * @param mu constant value to compare sample mean against + * @param sample array of sample data values + * @return p-value + * @throws NullArgumentException if the sample array is <code>null</code> + * @throws NumberIsTooSmallException if the length of the array is < 2 + * @throws MaxCountExceededException if an error occurs computing the p-value + */ + public double tTest(final double mu, final double[] sample) + throws NullArgumentException, NumberIsTooSmallException, + MaxCountExceededException { + + checkSampleData(sample); + // No try-catch or advertised exception because args have just been checked + return tTest(StatUtils.mean(sample), mu, StatUtils.variance(sample), + sample.length); + + } + + /** + * Performs a <a href="http://www.itl.nist.gov/div898/handbook/eda/section3/eda353.htm"> + * two-sided t-test</a> evaluating the null hypothesis that the mean of the population from + * which <code>sample</code> is drawn equals <code>mu</code>. + * <p> + * Returns <code>true</code> iff the null hypothesis can be + * rejected with confidence <code>1 - alpha</code>. To + * perform a 1-sided test, use <code>alpha * 2</code></p> + * <p> + * <strong>Examples:</strong><br><ol> + * <li>To test the (2-sided) hypothesis <code>sample mean = mu </code> at + * the 95% level, use <br><code>tTest(mu, sample, 0.05) </code> + * </li> + * <li>To test the (one-sided) hypothesis <code> sample mean < mu </code> + * at the 99% level, first verify that the measured sample mean is less + * than <code>mu</code> and then use + * <br><code>tTest(mu, sample, 0.02) </code> + * </li></ol></p> + * <p> + * <strong>Usage Note:</strong><br> + * The validity of the test depends on the assumptions of the one-sample + * parametric t-test procedure, as discussed + * <a href="http://www.basic.nwu.edu/statguidefiles/sg_glos.html#one-sample">here</a> + * </p><p> + * <strong>Preconditions</strong>: <ul> + * <li>The observed array length must be at least 2. 
+ * </li></ul></p>
+ *
+ * @param mu constant value to compare sample mean against
+ * @param sample array of sample data values
+ * @param alpha significance level of the test
+ * @return true if the null hypothesis can be rejected with
+ * confidence 1 - alpha
+ * @throws NullArgumentException if the sample array is <code>null</code>
+ * @throws NumberIsTooSmallException if the length of the array is < 2
+ * @throws OutOfRangeException if <code>alpha</code> is not in the range (0, 0.5]
+ * @throws MaxCountExceededException if an error occurs computing the p-value
+ */
+ public boolean tTest(final double mu, final double[] sample, final double alpha)
+ throws NullArgumentException, NumberIsTooSmallException,
+ OutOfRangeException, MaxCountExceededException {
+
+ checkSignificanceLevel(alpha);
+ return tTest(mu, sample) < alpha;
+
+ }
+
+ /**
+ * Returns the <i>observed significance level</i>, or
+ * <i>p-value</i>, associated with a one-sample, two-tailed t-test
+ * comparing the mean of the dataset described by <code>sampleStats</code>
+ * with the constant <code>mu</code>.
+ * <p>
+ * The number returned is the smallest significance level
+ * at which one can reject the null hypothesis that the mean equals
+ * <code>mu</code> in favor of the two-sided alternative that the mean
+ * is different from <code>mu</code>. For a one-sided test, divide the
+ * returned value by 2.</p>
+ * <p>
+ * <strong>Usage Note:</strong><br>
+ * The validity of the test depends on the assumptions of the parametric
+ * t-test procedure, as discussed
+ * <a href="http://www.basic.nwu.edu/statguidefiles/ttest_unpaired_ass_viol.html">
+ * here</a></p>
+ * <p>
+ * <strong>Preconditions</strong>: <ul>
+ * <li>The sample must contain at least 2 observations. 
+ * </li></ul></p> + * + * @param mu constant value to compare sample mean against + * @param sampleStats StatisticalSummary describing sample data + * @return p-value + * @throws NullArgumentException if <code>sampleStats</code> is <code>null</code> + * @throws NumberIsTooSmallException if the number of samples is < 2 + * @throws MaxCountExceededException if an error occurs computing the p-value + */ + public double tTest(final double mu, final StatisticalSummary sampleStats) + throws NullArgumentException, NumberIsTooSmallException, + MaxCountExceededException { + + checkSampleData(sampleStats); + return tTest(sampleStats.getMean(), mu, sampleStats.getVariance(), + sampleStats.getN()); + + } + + /** + * Performs a <a href="http://www.itl.nist.gov/div898/handbook/eda/section3/eda353.htm"> + * two-sided t-test</a> evaluating the null hypothesis that the mean of the + * population from which the dataset described by <code>stats</code> is + * drawn equals <code>mu</code>. + * <p> + * Returns <code>true</code> iff the null hypothesis can be rejected with + * confidence <code>1 - alpha</code>. To perform a 1-sided test, use + * <code>alpha * 2.</code></p> + * <p> + * <strong>Examples:</strong><br><ol> + * <li>To test the (2-sided) hypothesis <code>sample mean = mu </code> at + * the 95% level, use <br><code>tTest(mu, sampleStats, 0.05) </code> + * </li> + * <li>To test the (one-sided) hypothesis <code> sample mean < mu </code> + * at the 99% level, first verify that the measured sample mean is less + * than <code>mu</code> and then use + * <br><code>tTest(mu, sampleStats, 0.02) </code> + * </li></ol></p> + * <p> + * <strong>Usage Note:</strong><br> + * The validity of the test depends on the assumptions of the one-sample + * parametric t-test procedure, as discussed + * <a href="http://www.basic.nwu.edu/statguidefiles/sg_glos.html#one-sample">here</a> + * </p><p> + * <strong>Preconditions</strong>: <ul> + * <li>The sample must include at least 2 observations. 
+ * </li></ul></p>
+ *
+ * @param mu constant value to compare sample mean against
+ * @param sampleStats StatisticalSummary describing sample data values
+ * @param alpha significance level of the test
+ * @return true if the null hypothesis can be rejected with
+ * confidence 1 - alpha
+ * @throws NullArgumentException if <code>sampleStats</code> is <code>null</code>
+ * @throws NumberIsTooSmallException if the number of samples is < 2
+ * @throws OutOfRangeException if <code>alpha</code> is not in the range (0, 0.5]
+ * @throws MaxCountExceededException if an error occurs computing the p-value
+ */
+ public boolean tTest(final double mu, final StatisticalSummary sampleStats,
+ final double alpha)
+ throws NullArgumentException, NumberIsTooSmallException,
+ OutOfRangeException, MaxCountExceededException {
+
+ checkSignificanceLevel(alpha);
+ return tTest(mu, sampleStats) < alpha;
+
+ }
+
+ /**
+ * Returns the <i>observed significance level</i>, or
+ * <i>p-value</i>, associated with a two-sample, two-tailed t-test
+ * comparing the means of the input arrays.
+ * <p>
+ * The number returned is the smallest significance level
+ * at which one can reject the null hypothesis that the two means are
+ * equal in favor of the two-sided alternative that they are different.
+ * For a one-sided test, divide the returned value by 2.</p>
+ * <p>
+ * The test does not assume that the underlying population variances are
+ * equal and it uses approximated degrees of freedom computed from the
+ * sample data to compute the p-value. 
The t-statistic used is as defined in + * {@link #t(double[], double[])} and the Welch-Satterthwaite approximation + * to the degrees of freedom is used, + * as described + * <a href="http://www.itl.nist.gov/div898/handbook/prc/section3/prc31.htm"> + * here.</a> To perform the test under the assumption of equal subpopulation + * variances, use {@link #homoscedasticTTest(double[], double[])}.</p> + * <p> + * <strong>Usage Note:</strong><br> + * The validity of the p-value depends on the assumptions of the parametric + * t-test procedure, as discussed + * <a href="http://www.basic.nwu.edu/statguidefiles/ttest_unpaired_ass_viol.html"> + * here</a></p> + * <p> + * <strong>Preconditions</strong>: <ul> + * <li>The observed array lengths must both be at least 2. + * </li></ul></p> + * + * @param sample1 array of sample data values + * @param sample2 array of sample data values + * @return p-value for t-test + * @throws NullArgumentException if the arrays are <code>null</code> + * @throws NumberIsTooSmallException if the length of the arrays is < 2 + * @throws MaxCountExceededException if an error occurs computing the p-value + */ + public double tTest(final double[] sample1, final double[] sample2) + throws NullArgumentException, NumberIsTooSmallException, + MaxCountExceededException { + + checkSampleData(sample1); + checkSampleData(sample2); + // No try-catch or advertised exception because args have just been checked + return tTest(StatUtils.mean(sample1), StatUtils.mean(sample2), + StatUtils.variance(sample1), StatUtils.variance(sample2), + sample1.length, sample2.length); + + } + + /** + * Returns the <i>observed significance level</i>, or + * <i>p-value</i>, associated with a two-sample, two-tailed t-test + * comparing the means of the input arrays, under the assumption that + * the two samples are drawn from subpopulations with equal variances. 
+ * To perform the test without the equal variances assumption, use + * {@link #tTest(double[], double[])}.</p> + * <p> + * The number returned is the smallest significance level + * at which one can reject the null hypothesis that the two means are + * equal in favor of the two-sided alternative that they are different. + * For a one-sided test, divide the returned value by 2.</p> + * <p> + * A pooled variance estimate is used to compute the t-statistic. See + * {@link #homoscedasticT(double[], double[])}. The sum of the sample sizes + * minus 2 is used as the degrees of freedom.</p> + * <p> + * <strong>Usage Note:</strong><br> + * The validity of the p-value depends on the assumptions of the parametric + * t-test procedure, as discussed + * <a href="http://www.basic.nwu.edu/statguidefiles/ttest_unpaired_ass_viol.html"> + * here</a></p> + * <p> + * <strong>Preconditions</strong>: <ul> + * <li>The observed array lengths must both be at least 2. + * </li></ul></p> + * + * @param sample1 array of sample data values + * @param sample2 array of sample data values + * @return p-value for t-test + * @throws NullArgumentException if the arrays are <code>null</code> + * @throws NumberIsTooSmallException if the length of the arrays is < 2 + * @throws MaxCountExceededException if an error occurs computing the p-value + */ + public double homoscedasticTTest(final double[] sample1, final double[] sample2) + throws NullArgumentException, NumberIsTooSmallException, + MaxCountExceededException { + + checkSampleData(sample1); + checkSampleData(sample2); + // No try-catch or advertised exception because args have just been checked + return homoscedasticTTest(StatUtils.mean(sample1), + StatUtils.mean(sample2), + StatUtils.variance(sample1), + StatUtils.variance(sample2), + sample1.length, sample2.length); + + } + + /** + * Performs a + * <a href="http://www.itl.nist.gov/div898/handbook/eda/section3/eda353.htm"> + * two-sided t-test</a> evaluating the null hypothesis that 
<code>sample1</code> + * and <code>sample2</code> are drawn from populations with the same mean, + * with significance level <code>alpha</code>. This test does not assume + * that the subpopulation variances are equal. To perform the test assuming + * equal variances, use + * {@link #homoscedasticTTest(double[], double[], double)}. + * <p> + * Returns <code>true</code> iff the null hypothesis that the means are + * equal can be rejected with confidence <code>1 - alpha</code>. To + * perform a 1-sided test, use <code>alpha * 2</code></p> + * <p> + * See {@link #t(double[], double[])} for the formula used to compute the + * t-statistic. Degrees of freedom are approximated using the + * <a href="http://www.itl.nist.gov/div898/handbook/prc/section3/prc31.htm"> + * Welch-Satterthwaite approximation.</a></p> + * <p> + * <strong>Examples:</strong><br><ol> + * <li>To test the (2-sided) hypothesis <code>mean 1 = mean 2 </code> at + * the 95% level, use + * <br><code>tTest(sample1, sample2, 0.05). </code> + * </li> + * <li>To test the (one-sided) hypothesis <code> mean 1 < mean 2 </code>, + * at the 99% level, first verify that the measured mean of <code>sample 1</code> + * is less than the mean of <code>sample 2</code> and then use + * <br><code>tTest(sample1, sample2, 0.02) </code> + * </li></ol></p> + * <p> + * <strong>Usage Note:</strong><br> + * The validity of the test depends on the assumptions of the parametric + * t-test procedure, as discussed + * <a href="http://www.basic.nwu.edu/statguidefiles/ttest_unpaired_ass_viol.html"> + * here</a></p> + * <p> + * <strong>Preconditions</strong>: <ul> + * <li>The observed array lengths must both be at least 2. 
+ * </li> + * <li> <code> 0 < alpha < 0.5 </code> + * </li></ul></p> + * + * @param sample1 array of sample data values + * @param sample2 array of sample data values + * @param alpha significance level of the test + * @return true if the null hypothesis can be rejected with + * confidence 1 - alpha + * @throws NullArgumentException if the arrays are <code>null</code> + * @throws NumberIsTooSmallException if the length of the arrays is < 2 + * @throws OutOfRangeException if <code>alpha</code> is not in the range (0, 0.5] + * @throws MaxCountExceededException if an error occurs computing the p-value + */ + public boolean tTest(final double[] sample1, final double[] sample2, + final double alpha) + throws NullArgumentException, NumberIsTooSmallException, + OutOfRangeException, MaxCountExceededException { + + checkSignificanceLevel(alpha); + return tTest(sample1, sample2) < alpha; + + } + + /** + * Performs a + * <a href="http://www.itl.nist.gov/div898/handbook/eda/section3/eda353.htm"> + * two-sided t-test</a> evaluating the null hypothesis that <code>sample1</code> + * and <code>sample2</code> are drawn from populations with the same mean, + * with significance level <code>alpha</code>, assuming that the + * subpopulation variances are equal. Use + * {@link #tTest(double[], double[], double)} to perform the test without + * the assumption of equal variances. + * <p> + * Returns <code>true</code> iff the null hypothesis that the means are + * equal can be rejected with confidence <code>1 - alpha</code>. To + * perform a 1-sided test, use <code>alpha * 2.</code> To perform the test + * without the assumption of equal subpopulation variances, use + * {@link #tTest(double[], double[], double)}.</p> + * <p> + * A pooled variance estimate is used to compute the t-statistic. See + * {@link #t(double[], double[])} for the formula. 
The sum of the sample + * sizes minus 2 is used as the degrees of freedom.</p> + * <p> + * <strong>Examples:</strong><br><ol> + * <li>To test the (2-sided) hypothesis <code>mean 1 = mean 2 </code> at + * the 95% level, use <br><code>tTest(sample1, sample2, 0.05). </code> + * </li> + * <li>To test the (one-sided) hypothesis <code> mean 1 < mean 2, </code> + * at the 99% level, first verify that the measured mean of + * <code>sample 1</code> is less than the mean of <code>sample 2</code> + * and then use + * <br><code>tTest(sample1, sample2, 0.02) </code> + * </li></ol></p> + * <p> + * <strong>Usage Note:</strong><br> + * The validity of the test depends on the assumptions of the parametric + * t-test procedure, as discussed + * <a href="http://www.basic.nwu.edu/statguidefiles/ttest_unpaired_ass_viol.html"> + * here</a></p> + * <p> + * <strong>Preconditions</strong>: <ul> + * <li>The observed array lengths must both be at least 2. + * </li> + * <li> <code> 0 < alpha < 0.5 </code> + * </li></ul></p> + * + * @param sample1 array of sample data values + * @param sample2 array of sample data values + * @param alpha significance level of the test + * @return true if the null hypothesis can be rejected with + * confidence 1 - alpha + * @throws NullArgumentException if the arrays are <code>null</code> + * @throws NumberIsTooSmallException if the length of the arrays is < 2 + * @throws OutOfRangeException if <code>alpha</code> is not in the range (0, 0.5] + * @throws MaxCountExceededException if an error occurs computing the p-value + */ + public boolean homoscedasticTTest(final double[] sample1, final double[] sample2, + final double alpha) + throws NullArgumentException, NumberIsTooSmallException, + OutOfRangeException, MaxCountExceededException { + + checkSignificanceLevel(alpha); + return homoscedasticTTest(sample1, sample2) < alpha; + + } + + /** + * Returns the <i>observed significance level</i>, or + * <i>p-value</i>, associated with a two-sample, two-tailed t-test + 
* comparing the means of the datasets described by two StatisticalSummary + * instances. + * <p> + * The number returned is the smallest significance level + * at which one can reject the null hypothesis that the two means are + * equal in favor of the two-sided alternative that they are different. + * For a one-sided test, divide the returned value by 2.</p> + * <p> + * The test does not assume that the underlying population variances are + * equal and it uses approximated degrees of freedom computed from the + * sample data to compute the p-value. To perform the test assuming + * equal variances, use + * {@link #homoscedasticTTest(StatisticalSummary, StatisticalSummary)}.</p> + * <p> + * <strong>Usage Note:</strong><br> + * The validity of the p-value depends on the assumptions of the parametric + * t-test procedure, as discussed + * <a href="http://www.basic.nwu.edu/statguidefiles/ttest_unpaired_ass_viol.html"> + * here</a></p> + * <p> + * <strong>Preconditions</strong>: <ul> + * <li>The datasets described by the two Univariates must each contain + * at least 2 observations. 
+ * </li></ul></p> + * + * @param sampleStats1 StatisticalSummary describing data from the first sample + * @param sampleStats2 StatisticalSummary describing data from the second sample + * @return p-value for t-test + * @throws NullArgumentException if the sample statistics are <code>null</code> + * @throws NumberIsTooSmallException if the number of samples is < 2 + * @throws MaxCountExceededException if an error occurs computing the p-value + */ + public double tTest(final StatisticalSummary sampleStats1, + final StatisticalSummary sampleStats2) + throws NullArgumentException, NumberIsTooSmallException, + MaxCountExceededException { + + checkSampleData(sampleStats1); + checkSampleData(sampleStats2); + return tTest(sampleStats1.getMean(), sampleStats2.getMean(), + sampleStats1.getVariance(), sampleStats2.getVariance(), + sampleStats1.getN(), sampleStats2.getN()); + + } + + /** + * Returns the <i>observed significance level</i>, or + * <i>p-value</i>, associated with a two-sample, two-tailed t-test + * comparing the means of the datasets described by two StatisticalSummary + * instances, under the hypothesis of equal subpopulation variances. To + * perform a test without the equal variances assumption, use + * {@link #tTest(StatisticalSummary, StatisticalSummary)}. + * <p> + * The number returned is the smallest significance level + * at which one can reject the null hypothesis that the two means are + * equal in favor of the two-sided alternative that they are different. + * For a one-sided test, divide the returned value by 2.</p> + * <p> + * See {@link #homoscedasticT(double[], double[])} for the formula used to + * compute the t-statistic. 
The sum of the sample sizes minus 2 is used as + * the degrees of freedom.</p> + * <p> + * <strong>Usage Note:</strong><br> + * The validity of the p-value depends on the assumptions of the parametric + * t-test procedure, as discussed + * <a href="http://www.basic.nwu.edu/statguidefiles/ttest_unpaired_ass_viol.html">here</a> + * </p><p> + * <strong>Preconditions</strong>: <ul> + * <li>The datasets described by the two Univariates must each contain + * at least 2 observations. + * </li></ul></p> + * + * @param sampleStats1 StatisticalSummary describing data from the first sample + * @param sampleStats2 StatisticalSummary describing data from the second sample + * @return p-value for t-test + * @throws NullArgumentException if the sample statistics are <code>null</code> + * @throws NumberIsTooSmallException if the number of samples is < 2 + * @throws MaxCountExceededException if an error occurs computing the p-value + */ + public double homoscedasticTTest(final StatisticalSummary sampleStats1, + final StatisticalSummary sampleStats2) + throws NullArgumentException, NumberIsTooSmallException, + MaxCountExceededException { + + checkSampleData(sampleStats1); + checkSampleData(sampleStats2); + return homoscedasticTTest(sampleStats1.getMean(), + sampleStats2.getMean(), + sampleStats1.getVariance(), + sampleStats2.getVariance(), + sampleStats1.getN(), sampleStats2.getN()); + + } + + /** + * Performs a + * <a href="http://www.itl.nist.gov/div898/handbook/eda/section3/eda353.htm"> + * two-sided t-test</a> evaluating the null hypothesis that + * <code>sampleStats1</code> and <code>sampleStats2</code> describe + * datasets drawn from populations with the same mean, with significance + * level <code>alpha</code>. This test does not assume that the + * subpopulation variances are equal. To perform the test under the equal + * variances assumption, use + * {@link #homoscedasticTTest(StatisticalSummary, StatisticalSummary)}. 
+ * <p> + * Returns <code>true</code> iff the null hypothesis that the means are + * equal can be rejected with confidence <code>1 - alpha</code>. To + * perform a 1-sided test, use <code>alpha * 2</code></p> + * <p> + * See {@link #t(double[], double[])} for the formula used to compute the + * t-statistic. Degrees of freedom are approximated using the + * <a href="http://www.itl.nist.gov/div898/handbook/prc/section3/prc31.htm"> + * Welch-Satterthwaite approximation.</a></p> + * <p> + * <strong>Examples:</strong><br><ol> + * <li>To test the (2-sided) hypothesis <code>mean 1 = mean 2 </code> at + * the 95%, use + * <br><code>tTest(sampleStats1, sampleStats2, 0.05) </code> + * </li> + * <li>To test the (one-sided) hypothesis <code> mean 1 < mean 2 </code> + * at the 99% level, first verify that the measured mean of + * <code>sample 1</code> is less than the mean of <code>sample 2</code> + * and then use + * <br><code>tTest(sampleStats1, sampleStats2, 0.02) </code> + * </li></ol></p> + * <p> + * <strong>Usage Note:</strong><br> + * The validity of the test depends on the assumptions of the parametric + * t-test procedure, as discussed + * <a href="http://www.basic.nwu.edu/statguidefiles/ttest_unpaired_ass_viol.html"> + * here</a></p> + * <p> + * <strong>Preconditions</strong>: <ul> + * <li>The datasets described by the two Univariates must each contain + * at least 2 observations. 
+ * </li> + * <li> <code> 0 < alpha < 0.5 </code> + * </li></ul></p> + * + * @param sampleStats1 StatisticalSummary describing sample data values + * @param sampleStats2 StatisticalSummary describing sample data values + * @param alpha significance level of the test + * @return true if the null hypothesis can be rejected with + * confidence 1 - alpha + * @throws NullArgumentException if the sample statistics are <code>null</code> + * @throws NumberIsTooSmallException if the number of samples is < 2 + * @throws OutOfRangeException if <code>alpha</code> is not in the range (0, 0.5] + * @throws MaxCountExceededException if an error occurs computing the p-value + */ + public boolean tTest(final StatisticalSummary sampleStats1, + final StatisticalSummary sampleStats2, + final double alpha) + throws NullArgumentException, NumberIsTooSmallException, + OutOfRangeException, MaxCountExceededException { + + checkSignificanceLevel(alpha); + return tTest(sampleStats1, sampleStats2) < alpha; + + } + + //----------------------------------------------- Protected methods + + /** + * Computes approximate degrees of freedom for 2-sample t-test. + * + * @param v1 first sample variance + * @param v2 second sample variance + * @param n1 first sample n + * @param n2 second sample n + * @return approximate degrees of freedom + */ + protected double df(double v1, double v2, double n1, double n2) { + return (((v1 / n1) + (v2 / n2)) * ((v1 / n1) + (v2 / n2))) / + ((v1 * v1) / (n1 * n1 * (n1 - 1d)) + (v2 * v2) / + (n2 * n2 * (n2 - 1d))); + } + + /** + * Computes t test statistic for 1-sample t-test. + * + * @param m sample mean + * @param mu constant to test against + * @param v sample variance + * @param n sample n + * @return t test statistic + */ + protected double t(final double m, final double mu, + final double v, final double n) { + return (m - mu) / FastMath.sqrt(v / n); + } + + /** + * Computes t test statistic for 2-sample t-test. 
+ * <p> + * Does not assume that subpopulation variances are equal.</p> + * + * @param m1 first sample mean + * @param m2 second sample mean + * @param v1 first sample variance + * @param v2 second sample variance + * @param n1 first sample n + * @param n2 second sample n + * @return t test statistic + */ + protected double t(final double m1, final double m2, + final double v1, final double v2, + final double n1, final double n2) { + return (m1 - m2) / FastMath.sqrt((v1 / n1) + (v2 / n2)); + } + + /** + * Computes t test statistic for 2-sample t-test under the hypothesis + * of equal subpopulation variances. + * + * @param m1 first sample mean + * @param m2 second sample mean + * @param v1 first sample variance + * @param v2 second sample variance + * @param n1 first sample n + * @param n2 second sample n + * @return t test statistic + */ + protected double homoscedasticT(final double m1, final double m2, + final double v1, final double v2, + final double n1, final double n2) { + final double pooledVariance = ((n1 - 1) * v1 + (n2 -1) * v2 ) / (n1 + n2 - 2); + return (m1 - m2) / FastMath.sqrt(pooledVariance * (1d / n1 + 1d / n2)); + } + + /** + * Computes p-value for 2-sided, 1-sample t-test. + * + * @param m sample mean + * @param mu constant to test against + * @param v sample variance + * @param n sample n + * @return p-value + * @throws MaxCountExceededException if an error occurs computing the p-value + * @throws MathIllegalArgumentException if n is not greater than 1 + */ + protected double tTest(final double m, final double mu, + final double v, final double n) + throws MaxCountExceededException, MathIllegalArgumentException { + + final double t = FastMath.abs(t(m, mu, v, n)); + // pass a null rng to avoid unneeded overhead as we will not sample from this distribution + final TDistribution distribution = new TDistribution(null, n - 1); + return 2.0 * distribution.cumulativeProbability(-t); + + } + + /** + * Computes p-value for 2-sided, 2-sample t-test. 
+ * <p> + * Does not assume subpopulation variances are equal. Degrees of freedom + * are estimated from the data.</p> + * + * @param m1 first sample mean + * @param m2 second sample mean + * @param v1 first sample variance + * @param v2 second sample variance + * @param n1 first sample n + * @param n2 second sample n + * @return p-value + * @throws MaxCountExceededException if an error occurs computing the p-value + * @throws NotStrictlyPositiveException if the estimated degrees of freedom is not + * strictly positive + */ + protected double tTest(final double m1, final double m2, + final double v1, final double v2, + final double n1, final double n2) + throws MaxCountExceededException, NotStrictlyPositiveException { + + final double t = FastMath.abs(t(m1, m2, v1, v2, n1, n2)); + final double degreesOfFreedom = df(v1, v2, n1, n2); + // pass a null rng to avoid unneeded overhead as we will not sample from this distribution + final TDistribution distribution = new TDistribution(null, degreesOfFreedom); + return 2.0 * distribution.cumulativeProbability(-t); + + } + + /** + * Computes p-value for 2-sided, 2-sample t-test, under the assumption + * of equal subpopulation variances. 
+ * <p> + * The sum of the sample sizes minus 2 is used as degrees of freedom.</p> + * + * @param m1 first sample mean + * @param m2 second sample mean + * @param v1 first sample variance + * @param v2 second sample variance + * @param n1 first sample n + * @param n2 second sample n + * @return p-value + * @throws MaxCountExceededException if an error occurs computing the p-value + * @throws NotStrictlyPositiveException if the estimated degrees of freedom is not + * strictly positive + */ + protected double homoscedasticTTest(double m1, double m2, + double v1, double v2, + double n1, double n2) + throws MaxCountExceededException, NotStrictlyPositiveException { + + final double t = FastMath.abs(homoscedasticT(m1, m2, v1, v2, n1, n2)); + final double degreesOfFreedom = n1 + n2 - 2; + // pass a null rng to avoid unneeded overhead as we will not sample from this distribution + final TDistribution distribution = new TDistribution(null, degreesOfFreedom); + return 2.0 * distribution.cumulativeProbability(-t); + + } + + /** + * Check significance level. + * + * @param alpha significance level + * @throws OutOfRangeException if the significance level is out of bounds. + */ + private void checkSignificanceLevel(final double alpha) + throws OutOfRangeException { + + if (alpha <= 0 || alpha > 0.5) { + throw new OutOfRangeException(LocalizedFormats.SIGNIFICANCE_LEVEL, + alpha, 0.0, 0.5); + } + + } + + /** + * Check sample data. + * + * @param data Sample data. + * @throws NullArgumentException if {@code data} is {@code null}. + * @throws NumberIsTooSmallException if there is not enough sample data. + */ + private void checkSampleData(final double[] data) + throws NullArgumentException, NumberIsTooSmallException { + + if (data == null) { + throw new NullArgumentException(); + } + if (data.length < 2) { + throw new NumberIsTooSmallException( + LocalizedFormats.INSUFFICIENT_DATA_FOR_T_STATISTIC, + data.length, 2, true); + } + + } + + /** + * Check sample data. 
+ * + * @param stat Statistical summary. + * @throws NullArgumentException if {@code data} is {@code null}. + * @throws NumberIsTooSmallException if there is not enough sample data. + */ + private void checkSampleData(final StatisticalSummary stat) + throws NullArgumentException, NumberIsTooSmallException { + + if (stat == null) { + throw new NullArgumentException(); + } + if (stat.getN() < 2) { + throw new NumberIsTooSmallException( + LocalizedFormats.INSUFFICIENT_DATA_FOR_T_STATISTIC, + stat.getN(), 2, true); + } + + } + +} diff --git a/src/main/java/org/apache/commons/math3/stat/inference/TestUtils.java b/src/main/java/org/apache/commons/math3/stat/inference/TestUtils.java new file mode 100644 index 0000000..a92fb19 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/stat/inference/TestUtils.java @@ -0,0 +1,547 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.commons.math3.stat.inference; + +import java.util.Collection; + +import org.apache.commons.math3.distribution.RealDistribution; +import org.apache.commons.math3.exception.ConvergenceException; +import org.apache.commons.math3.exception.DimensionMismatchException; +import org.apache.commons.math3.exception.InsufficientDataException; +import org.apache.commons.math3.exception.MaxCountExceededException; +import org.apache.commons.math3.exception.NoDataException; +import org.apache.commons.math3.exception.NotPositiveException; +import org.apache.commons.math3.exception.NotStrictlyPositiveException; +import org.apache.commons.math3.exception.NullArgumentException; +import org.apache.commons.math3.exception.NumberIsTooSmallException; +import org.apache.commons.math3.exception.OutOfRangeException; +import org.apache.commons.math3.exception.ZeroException; +import org.apache.commons.math3.stat.descriptive.StatisticalSummary; + +/** + * A collection of static methods to create inference test instances or to + * perform inference tests. + * + * @since 1.1 + */ +public class TestUtils { + + /** Singleton TTest instance. */ + private static final TTest T_TEST = new TTest(); + + /** Singleton ChiSquareTest instance. */ + private static final ChiSquareTest CHI_SQUARE_TEST = new ChiSquareTest(); + + /** Singleton OneWayAnova instance. */ + private static final OneWayAnova ONE_WAY_ANANOVA = new OneWayAnova(); + + /** Singleton G-Test instance. */ + private static final GTest G_TEST = new GTest(); + + /** Singleton K-S test instance */ + private static final KolmogorovSmirnovTest KS_TEST = new KolmogorovSmirnovTest(); + + /** + * Prevent instantiation. 
+ */ + private TestUtils() { + super(); + } + + // CHECKSTYLE: stop JavadocMethodCheck + + /** + * @see org.apache.commons.math3.stat.inference.TTest#homoscedasticT(double[], double[]) + */ + public static double homoscedasticT(final double[] sample1, final double[] sample2) + throws NullArgumentException, NumberIsTooSmallException { + return T_TEST.homoscedasticT(sample1, sample2); + } + + /** + * @see org.apache.commons.math3.stat.inference.TTest#homoscedasticT(org.apache.commons.math3.stat.descriptive.StatisticalSummary, org.apache.commons.math3.stat.descriptive.StatisticalSummary) + */ + public static double homoscedasticT(final StatisticalSummary sampleStats1, + final StatisticalSummary sampleStats2) + throws NullArgumentException, NumberIsTooSmallException { + return T_TEST.homoscedasticT(sampleStats1, sampleStats2); + } + + /** + * @see org.apache.commons.math3.stat.inference.TTest#homoscedasticTTest(double[], double[], double) + */ + public static boolean homoscedasticTTest(final double[] sample1, final double[] sample2, + final double alpha) + throws NullArgumentException, NumberIsTooSmallException, + OutOfRangeException, MaxCountExceededException { + return T_TEST.homoscedasticTTest(sample1, sample2, alpha); + } + + /** + * @see org.apache.commons.math3.stat.inference.TTest#homoscedasticTTest(double[], double[]) + */ + public static double homoscedasticTTest(final double[] sample1, final double[] sample2) + throws NullArgumentException, NumberIsTooSmallException, MaxCountExceededException { + return T_TEST.homoscedasticTTest(sample1, sample2); + } + + /** + * @see org.apache.commons.math3.stat.inference.TTest#homoscedasticTTest(org.apache.commons.math3.stat.descriptive.StatisticalSummary, org.apache.commons.math3.stat.descriptive.StatisticalSummary) + */ + public static double homoscedasticTTest(final StatisticalSummary sampleStats1, + final StatisticalSummary sampleStats2) + throws NullArgumentException, NumberIsTooSmallException, 
MaxCountExceededException { + return T_TEST.homoscedasticTTest(sampleStats1, sampleStats2); + } + + /** + * @see org.apache.commons.math3.stat.inference.TTest#pairedT(double[], double[]) + */ + public static double pairedT(final double[] sample1, final double[] sample2) + throws NullArgumentException, NoDataException, + DimensionMismatchException, NumberIsTooSmallException { + return T_TEST.pairedT(sample1, sample2); + } + + /** + * @see org.apache.commons.math3.stat.inference.TTest#pairedTTest(double[], double[], double) + */ + public static boolean pairedTTest(final double[] sample1, final double[] sample2, + final double alpha) + throws NullArgumentException, NoDataException, DimensionMismatchException, + NumberIsTooSmallException, OutOfRangeException, MaxCountExceededException { + return T_TEST.pairedTTest(sample1, sample2, alpha); + } + + /** + * @see org.apache.commons.math3.stat.inference.TTest#pairedTTest(double[], double[]) + */ + public static double pairedTTest(final double[] sample1, final double[] sample2) + throws NullArgumentException, NoDataException, DimensionMismatchException, + NumberIsTooSmallException, MaxCountExceededException { + return T_TEST.pairedTTest(sample1, sample2); + } + + /** + * @see org.apache.commons.math3.stat.inference.TTest#t(double, double[]) + */ + public static double t(final double mu, final double[] observed) + throws NullArgumentException, NumberIsTooSmallException { + return T_TEST.t(mu, observed); + } + + /** + * @see org.apache.commons.math3.stat.inference.TTest#t(double, org.apache.commons.math3.stat.descriptive.StatisticalSummary) + */ + public static double t(final double mu, final StatisticalSummary sampleStats) + throws NullArgumentException, NumberIsTooSmallException { + return T_TEST.t(mu, sampleStats); + } + + /** + * @see org.apache.commons.math3.stat.inference.TTest#t(double[], double[]) + */ + public static double t(final double[] sample1, final double[] sample2) + throws NullArgumentException, 
NumberIsTooSmallException { + return T_TEST.t(sample1, sample2); + } + + /** + * @see org.apache.commons.math3.stat.inference.TTest#t(org.apache.commons.math3.stat.descriptive.StatisticalSummary, org.apache.commons.math3.stat.descriptive.StatisticalSummary) + */ + public static double t(final StatisticalSummary sampleStats1, + final StatisticalSummary sampleStats2) + throws NullArgumentException, NumberIsTooSmallException { + return T_TEST.t(sampleStats1, sampleStats2); + } + + /** + * @see org.apache.commons.math3.stat.inference.TTest#tTest(double, double[], double) + */ + public static boolean tTest(final double mu, final double[] sample, final double alpha) + throws NullArgumentException, NumberIsTooSmallException, + OutOfRangeException, MaxCountExceededException { + return T_TEST.tTest(mu, sample, alpha); + } + + /** + * @see org.apache.commons.math3.stat.inference.TTest#tTest(double, double[]) + */ + public static double tTest(final double mu, final double[] sample) + throws NullArgumentException, NumberIsTooSmallException, + MaxCountExceededException { + return T_TEST.tTest(mu, sample); + } + + /** + * @see org.apache.commons.math3.stat.inference.TTest#tTest(double, org.apache.commons.math3.stat.descriptive.StatisticalSummary, double) + */ + public static boolean tTest(final double mu, final StatisticalSummary sampleStats, + final double alpha) + throws NullArgumentException, NumberIsTooSmallException, + OutOfRangeException, MaxCountExceededException { + return T_TEST.tTest(mu, sampleStats, alpha); + } + + /** + * @see org.apache.commons.math3.stat.inference.TTest#tTest(double, org.apache.commons.math3.stat.descriptive.StatisticalSummary) + */ + public static double tTest(final double mu, final StatisticalSummary sampleStats) + throws NullArgumentException, NumberIsTooSmallException, + MaxCountExceededException { + return T_TEST.tTest(mu, sampleStats); + } + + /** + * @see org.apache.commons.math3.stat.inference.TTest#tTest(double[], double[], double) + */ + 
public static boolean tTest(final double[] sample1, final double[] sample2, + final double alpha) + throws NullArgumentException, NumberIsTooSmallException, + OutOfRangeException, MaxCountExceededException { + return T_TEST.tTest(sample1, sample2, alpha); + } + + /** + * @see org.apache.commons.math3.stat.inference.TTest#tTest(double[], double[]) + */ + public static double tTest(final double[] sample1, final double[] sample2) + throws NullArgumentException, NumberIsTooSmallException, + MaxCountExceededException { + return T_TEST.tTest(sample1, sample2); + } + + /** + * @see org.apache.commons.math3.stat.inference.TTest#tTest(org.apache.commons.math3.stat.descriptive.StatisticalSummary, org.apache.commons.math3.stat.descriptive.StatisticalSummary, double) + */ + public static boolean tTest(final StatisticalSummary sampleStats1, + final StatisticalSummary sampleStats2, + final double alpha) + throws NullArgumentException, NumberIsTooSmallException, + OutOfRangeException, MaxCountExceededException { + return T_TEST.tTest(sampleStats1, sampleStats2, alpha); + } + + /** + * @see org.apache.commons.math3.stat.inference.TTest#tTest(org.apache.commons.math3.stat.descriptive.StatisticalSummary, org.apache.commons.math3.stat.descriptive.StatisticalSummary) + */ + public static double tTest(final StatisticalSummary sampleStats1, + final StatisticalSummary sampleStats2) + throws NullArgumentException, NumberIsTooSmallException, + MaxCountExceededException { + return T_TEST.tTest(sampleStats1, sampleStats2); + } + + /** + * @see org.apache.commons.math3.stat.inference.ChiSquareTest#chiSquare(double[], long[]) + */ + public static double chiSquare(final double[] expected, final long[] observed) + throws NotPositiveException, NotStrictlyPositiveException, + DimensionMismatchException { + return CHI_SQUARE_TEST.chiSquare(expected, observed); + } + + /** + * @see org.apache.commons.math3.stat.inference.ChiSquareTest#chiSquare(long[][]) + */ + public static double chiSquare(final 
long[][] counts) + throws NullArgumentException, NotPositiveException, + DimensionMismatchException { + return CHI_SQUARE_TEST.chiSquare(counts); + } + + /** + * @see org.apache.commons.math3.stat.inference.ChiSquareTest#chiSquareTest(double[], long[], double) + */ + public static boolean chiSquareTest(final double[] expected, final long[] observed, + final double alpha) + throws NotPositiveException, NotStrictlyPositiveException, + DimensionMismatchException, OutOfRangeException, MaxCountExceededException { + return CHI_SQUARE_TEST.chiSquareTest(expected, observed, alpha); + } + + /** + * @see org.apache.commons.math3.stat.inference.ChiSquareTest#chiSquareTest(double[], long[]) + */ + public static double chiSquareTest(final double[] expected, final long[] observed) + throws NotPositiveException, NotStrictlyPositiveException, + DimensionMismatchException, MaxCountExceededException { + return CHI_SQUARE_TEST.chiSquareTest(expected, observed); + } + + /** + * @see org.apache.commons.math3.stat.inference.ChiSquareTest#chiSquareTest(long[][], double) + */ + public static boolean chiSquareTest(final long[][] counts, final double alpha) + throws NullArgumentException, DimensionMismatchException, + NotPositiveException, OutOfRangeException, MaxCountExceededException { + return CHI_SQUARE_TEST.chiSquareTest(counts, alpha); + } + + /** + * @see org.apache.commons.math3.stat.inference.ChiSquareTest#chiSquareTest(long[][]) + */ + public static double chiSquareTest(final long[][] counts) + throws NullArgumentException, DimensionMismatchException, + NotPositiveException, MaxCountExceededException { + return CHI_SQUARE_TEST.chiSquareTest(counts); + } + + /** + * @see org.apache.commons.math3.stat.inference.ChiSquareTest#chiSquareDataSetsComparison(long[], long[]) + * + * @since 1.2 + */ + public static double chiSquareDataSetsComparison(final long[] observed1, + final long[] observed2) + throws DimensionMismatchException, NotPositiveException, ZeroException { + return 
CHI_SQUARE_TEST.chiSquareDataSetsComparison(observed1, observed2); + } + + /** + * @see org.apache.commons.math3.stat.inference.ChiSquareTest#chiSquareTestDataSetsComparison(long[], long[]) + * + * @since 1.2 + */ + public static double chiSquareTestDataSetsComparison(final long[] observed1, + final long[] observed2) + throws DimensionMismatchException, NotPositiveException, ZeroException, + MaxCountExceededException { + return CHI_SQUARE_TEST.chiSquareTestDataSetsComparison(observed1, observed2); + } + + /** + * @see org.apache.commons.math3.stat.inference.ChiSquareTest#chiSquareTestDataSetsComparison(long[], long[], double) + * + * @since 1.2 + */ + public static boolean chiSquareTestDataSetsComparison(final long[] observed1, + final long[] observed2, + final double alpha) + throws DimensionMismatchException, NotPositiveException, + ZeroException, OutOfRangeException, MaxCountExceededException { + return CHI_SQUARE_TEST.chiSquareTestDataSetsComparison(observed1, observed2, alpha); + } + + /** + * @see org.apache.commons.math3.stat.inference.OneWayAnova#anovaFValue(Collection) + * + * @since 1.2 + */ + public static double oneWayAnovaFValue(final Collection<double[]> categoryData) + throws NullArgumentException, DimensionMismatchException { + return ONE_WAY_ANANOVA.anovaFValue(categoryData); + } + + /** + * @see org.apache.commons.math3.stat.inference.OneWayAnova#anovaPValue(Collection) + * + * @since 1.2 + */ + public static double oneWayAnovaPValue(final Collection<double[]> categoryData) + throws NullArgumentException, DimensionMismatchException, + ConvergenceException, MaxCountExceededException { + return ONE_WAY_ANANOVA.anovaPValue(categoryData); + } + + /** + * @see org.apache.commons.math3.stat.inference.OneWayAnova#anovaTest(Collection,double) + * + * @since 1.2 + */ + public static boolean oneWayAnovaTest(final Collection<double[]> categoryData, + final double alpha) + throws NullArgumentException, DimensionMismatchException, + OutOfRangeException, 
ConvergenceException, MaxCountExceededException { + return ONE_WAY_ANANOVA.anovaTest(categoryData, alpha); + } + + /** + * @see org.apache.commons.math3.stat.inference.GTest#g(double[], long[]) + * @since 3.1 + */ + public static double g(final double[] expected, final long[] observed) + throws NotPositiveException, NotStrictlyPositiveException, + DimensionMismatchException { + return G_TEST.g(expected, observed); + } + + /** + * @see org.apache.commons.math3.stat.inference.GTest#gTest( double[], long[] ) + * @since 3.1 + */ + public static double gTest(final double[] expected, final long[] observed) + throws NotPositiveException, NotStrictlyPositiveException, + DimensionMismatchException, MaxCountExceededException { + return G_TEST.gTest(expected, observed); + } + + /** + * @see org.apache.commons.math3.stat.inference.GTest#gTestIntrinsic(double[], long[] ) + * @since 3.1 + */ + public static double gTestIntrinsic(final double[] expected, final long[] observed) + throws NotPositiveException, NotStrictlyPositiveException, + DimensionMismatchException, MaxCountExceededException { + return G_TEST.gTestIntrinsic(expected, observed); + } + + /** + * @see org.apache.commons.math3.stat.inference.GTest#gTest( double[],long[],double) + * @since 3.1 + */ + public static boolean gTest(final double[] expected, final long[] observed, + final double alpha) + throws NotPositiveException, NotStrictlyPositiveException, + DimensionMismatchException, OutOfRangeException, MaxCountExceededException { + return G_TEST.gTest(expected, observed, alpha); + } + + /** + * @see org.apache.commons.math3.stat.inference.GTest#gDataSetsComparison(long[], long[]) + * @since 3.1 + */ + public static double gDataSetsComparison(final long[] observed1, + final long[] observed2) + throws DimensionMismatchException, NotPositiveException, ZeroException { + return G_TEST.gDataSetsComparison(observed1, observed2); + } + + /** + * @see 
org.apache.commons.math3.stat.inference.GTest#rootLogLikelihoodRatio(long, long, long, long) + * @since 3.1 + */ + public static double rootLogLikelihoodRatio(final long k11, final long k12, final long k21, final long k22) + throws DimensionMismatchException, NotPositiveException, ZeroException { + return G_TEST.rootLogLikelihoodRatio(k11, k12, k21, k22); + } + + + /** + * @see org.apache.commons.math3.stat.inference.GTest#gTestDataSetsComparison(long[], long[]) + * @since 3.1 + */ + public static double gTestDataSetsComparison(final long[] observed1, + final long[] observed2) + throws DimensionMismatchException, NotPositiveException, ZeroException, + MaxCountExceededException { + return G_TEST.gTestDataSetsComparison(observed1, observed2); + } + + /** + * @see org.apache.commons.math3.stat.inference.GTest#gTestDataSetsComparison(long[],long[],double) + * @since 3.1 + */ + public static boolean gTestDataSetsComparison(final long[] observed1, + final long[] observed2, + final double alpha) + throws DimensionMismatchException, NotPositiveException, + ZeroException, OutOfRangeException, MaxCountExceededException { + return G_TEST.gTestDataSetsComparison(observed1, observed2, alpha); + } + + /** + * @see org.apache.commons.math3.stat.inference.KolmogorovSmirnovTest#kolmogorovSmirnovStatistic(RealDistribution, double[]) + * @since 3.3 + */ + public static double kolmogorovSmirnovStatistic(RealDistribution dist, double[] data) + throws InsufficientDataException, NullArgumentException { + return KS_TEST.kolmogorovSmirnovStatistic(dist, data); + } + + /** + * @see org.apache.commons.math3.stat.inference.KolmogorovSmirnovTest#kolmogorovSmirnovTest(RealDistribution, double[]) + * @since 3.3 + */ + public static double kolmogorovSmirnovTest(RealDistribution dist, double[] data) + throws InsufficientDataException, NullArgumentException { + return KS_TEST.kolmogorovSmirnovTest(dist, data); + } + + /** + * @see 
org.apache.commons.math3.stat.inference.KolmogorovSmirnovTest#kolmogorovSmirnovTest(RealDistribution, double[], boolean) + * @since 3.3 + */ + public static double kolmogorovSmirnovTest(RealDistribution dist, double[] data, boolean strict) + throws InsufficientDataException, NullArgumentException { + return KS_TEST.kolmogorovSmirnovTest(dist, data, strict); + } + + /** + * @see org.apache.commons.math3.stat.inference.KolmogorovSmirnovTest#kolmogorovSmirnovTest(RealDistribution, double[], double) + * @since 3.3 + */ + public static boolean kolmogorovSmirnovTest(RealDistribution dist, double[] data, double alpha) + throws InsufficientDataException, NullArgumentException { + return KS_TEST.kolmogorovSmirnovTest(dist, data, alpha); + } + + /** + * @see org.apache.commons.math3.stat.inference.KolmogorovSmirnovTest#kolmogorovSmirnovStatistic(double[], double[]) + * @since 3.3 + */ + public static double kolmogorovSmirnovStatistic(double[] x, double[] y) + throws InsufficientDataException, NullArgumentException { + return KS_TEST.kolmogorovSmirnovStatistic(x, y); + } + + /** + * @see org.apache.commons.math3.stat.inference.KolmogorovSmirnovTest#kolmogorovSmirnovTest(double[], double[]) + * @since 3.3 + */ + public static double kolmogorovSmirnovTest(double[] x, double[] y) + throws InsufficientDataException, NullArgumentException { + return KS_TEST.kolmogorovSmirnovTest(x, y); + } + + /** + * @see org.apache.commons.math3.stat.inference.KolmogorovSmirnovTest#kolmogorovSmirnovTest(double[], double[], boolean) + * @since 3.3 + */ + public static double kolmogorovSmirnovTest(double[] x, double[] y, boolean strict) + throws InsufficientDataException, NullArgumentException { + return KS_TEST.kolmogorovSmirnovTest(x, y, strict); + } + + /** + * @see org.apache.commons.math3.stat.inference.KolmogorovSmirnovTest#exactP(double, int, int, boolean) + * @since 3.3 + */ + public static double exactP(double d, int m, int n, boolean strict) { + return KS_TEST.exactP(d, n, m, strict); + 
} + + /** + * @see org.apache.commons.math3.stat.inference.KolmogorovSmirnovTest#approximateP(double, int, int) + * @since 3.3 + */ + public static double approximateP(double d, int n, int m) { + return KS_TEST.approximateP(d, n, m); + } + + /** + * @see org.apache.commons.math3.stat.inference.KolmogorovSmirnovTest#monteCarloP(double, int, int, boolean, int) + * @since 3.3 + */ + public static double monteCarloP(double d, int n, int m, boolean strict, int iterations) { + return KS_TEST.monteCarloP(d, n, m, strict, iterations); + } + + + // CHECKSTYLE: resume JavadocMethodCheck + +} diff --git a/src/main/java/org/apache/commons/math3/stat/inference/WilcoxonSignedRankTest.java b/src/main/java/org/apache/commons/math3/stat/inference/WilcoxonSignedRankTest.java new file mode 100644 index 0000000..bd4d7e2 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/stat/inference/WilcoxonSignedRankTest.java @@ -0,0 +1,325 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.commons.math3.stat.inference; + +import org.apache.commons.math3.distribution.NormalDistribution; +import org.apache.commons.math3.exception.ConvergenceException; +import org.apache.commons.math3.exception.DimensionMismatchException; +import org.apache.commons.math3.exception.MaxCountExceededException; +import org.apache.commons.math3.exception.NoDataException; +import org.apache.commons.math3.exception.NullArgumentException; +import org.apache.commons.math3.exception.NumberIsTooLargeException; +import org.apache.commons.math3.stat.ranking.NaNStrategy; +import org.apache.commons.math3.stat.ranking.NaturalRanking; +import org.apache.commons.math3.stat.ranking.TiesStrategy; +import org.apache.commons.math3.util.FastMath; + +/** + * An implementation of the Wilcoxon signed-rank test. + * + */ +public class WilcoxonSignedRankTest { + + /** Ranking algorithm. */ + private NaturalRanking naturalRanking; + + /** + * Create a test instance where NaN's are left in place and ties get + * the average of applicable ranks. Use this unless you are very sure + * of what you are doing. + */ + public WilcoxonSignedRankTest() { + naturalRanking = new NaturalRanking(NaNStrategy.FIXED, + TiesStrategy.AVERAGE); + } + + /** + * Create a test instance using the given strategies for NaN's and ties. + * Only use this if you are sure of what you are doing. + * + * @param nanStrategy + * specifies the strategy that should be used for Double.NaN's + * @param tiesStrategy + * specifies the strategy that should be used for ties + */ + public WilcoxonSignedRankTest(final NaNStrategy nanStrategy, + final TiesStrategy tiesStrategy) { + naturalRanking = new NaturalRanking(nanStrategy, tiesStrategy); + } + + /** + * Ensures that the provided arrays fulfills the assumptions. + * + * @param x first sample + * @param y second sample + * @throws NullArgumentException if {@code x} or {@code y} are {@code null}. 
+ * @throws NoDataException if {@code x} or {@code y} are zero-length. + * @throws DimensionMismatchException if {@code x} and {@code y} do not + * have the same length. + */ + private void ensureDataConformance(final double[] x, final double[] y) + throws NullArgumentException, NoDataException, DimensionMismatchException { + + if (x == null || + y == null) { + throw new NullArgumentException(); + } + if (x.length == 0 || + y.length == 0) { + throw new NoDataException(); + } + if (y.length != x.length) { + throw new DimensionMismatchException(y.length, x.length); + } + } + + /** + * Calculates y[i] - x[i] for all i + * + * @param x first sample + * @param y second sample + * @return z = y - x + */ + private double[] calculateDifferences(final double[] x, final double[] y) { + + final double[] z = new double[x.length]; + + for (int i = 0; i < x.length; ++i) { + z[i] = y[i] - x[i]; + } + + return z; + } + + /** + * Calculates |z[i]| for all i + * + * @param z sample + * @return |z| + * @throws NullArgumentException if {@code z} is {@code null} + * @throws NoDataException if {@code z} is zero-length. + */ + private double[] calculateAbsoluteDifferences(final double[] z) + throws NullArgumentException, NoDataException { + + if (z == null) { + throw new NullArgumentException(); + } + + if (z.length == 0) { + throw new NoDataException(); + } + + final double[] zAbs = new double[z.length]; + + for (int i = 0; i < z.length; ++i) { + zAbs[i] = FastMath.abs(z[i]); + } + + return zAbs; + } + + /** + * Computes the <a + * href="http://en.wikipedia.org/wiki/Wilcoxon_signed-rank_test"> + * Wilcoxon signed ranked statistic</a> comparing mean for two related + * samples or repeated measurements on a single sample. + * <p> + * This statistic can be used to perform a Wilcoxon signed ranked test + * evaluating the null hypothesis that the two related samples or repeated + * measurements on a single sample has equal mean. 
 * </p>
 * <p>
 * Let X<sub>i</sub> denote the i'th individual of the first sample and
 * Y<sub>i</sub> the related i'th individual in the second sample. Let
 * Z<sub>i</sub> = Y<sub>i</sub> - X<sub>i</sub>.
 * </p>
 * <p>
 * <strong>Preconditions</strong>:
 * <ul>
 * <li>The differences Z<sub>i</sub> must be independent.</li>
 * <li>Each Z<sub>i</sub> comes from a continuous population (all
 * Z<sub>i</sub> must be identically distributed) and is symmetric about a
 * common median.</li>
 * <li>The values that X<sub>i</sub> and Y<sub>i</sub> represent are
 * ordered, so the comparisons greater than, less than, and equal to are
 * meaningful.</li>
 * </ul>
 * </p>
 *
 * @param x the first sample
 * @param y the second sample
 * @return wilcoxonSignedRank statistic (the larger of W+ and W-)
 * @throws NullArgumentException if {@code x} or {@code y} are {@code null}.
 * @throws NoDataException if {@code x} or {@code y} are zero-length.
 * @throws DimensionMismatchException if {@code x} and {@code y} do not
 * have the same length.
+ */ + public double wilcoxonSignedRank(final double[] x, final double[] y) + throws NullArgumentException, NoDataException, DimensionMismatchException { + + ensureDataConformance(x, y); + + // throws IllegalArgumentException if x and y are not correctly + // specified + final double[] z = calculateDifferences(x, y); + final double[] zAbs = calculateAbsoluteDifferences(z); + + final double[] ranks = naturalRanking.rank(zAbs); + + double Wplus = 0; + + for (int i = 0; i < z.length; ++i) { + if (z[i] > 0) { + Wplus += ranks[i]; + } + } + + final int N = x.length; + final double Wminus = (((double) (N * (N + 1))) / 2.0) - Wplus; + + return FastMath.max(Wplus, Wminus); + } + + /** + * Algorithm inspired by + * http://www.fon.hum.uva.nl/Service/Statistics/Signed_Rank_Algorihms.html#C + * by Rob van Son, Institute of Phonetic Sciences & IFOTT, + * University of Amsterdam + * + * @param Wmax largest Wilcoxon signed rank value + * @param N number of subjects (corresponding to x.length) + * @return two-sided exact p-value + */ + private double calculateExactPValue(final double Wmax, final int N) { + + // Total number of outcomes (equal to 2^N but a lot faster) + final int m = 1 << N; + + int largerRankSums = 0; + + for (int i = 0; i < m; ++i) { + int rankSum = 0; + + // Generate all possible rank sums + for (int j = 0; j < N; ++j) { + + // (i >> j) & 1 extract i's j-th bit from the right + if (((i >> j) & 1) == 1) { + rankSum += j + 1; + } + } + + if (rankSum >= Wmax) { + ++largerRankSums; + } + } + + /* + * largerRankSums / m gives the one-sided p-value, so it's multiplied + * with 2 to get the two-sided p-value + */ + return 2 * ((double) largerRankSums) / ((double) m); + } + + /** + * @param Wmin smallest Wilcoxon signed rank value + * @param N number of subjects (corresponding to x.length) + * @return two-sided asymptotic p-value + */ + private double calculateAsymptoticPValue(final double Wmin, final int N) { + + final double ES = (double) (N * (N + 1)) / 4.0; + + /* 
Same as (but saves computations): + * final double VarW = ((double) (N * (N + 1) * (2*N + 1))) / 24; + */ + final double VarS = ES * ((double) (2 * N + 1) / 6.0); + + // - 0.5 is a continuity correction + final double z = (Wmin - ES - 0.5) / FastMath.sqrt(VarS); + + // No try-catch or advertised exception because args are valid + // pass a null rng to avoid unneeded overhead as we will not sample from this distribution + final NormalDistribution standardNormal = new NormalDistribution(null, 0, 1); + + return 2*standardNormal.cumulativeProbability(z); + } + + /** + * Returns the <i>observed significance level</i>, or <a href= + * "http://www.cas.lancs.ac.uk/glossary_v1.1/hyptest.html#pvalue"> + * p-value</a>, associated with a <a + * href="http://en.wikipedia.org/wiki/Wilcoxon_signed-rank_test"> + * Wilcoxon signed ranked statistic</a> comparing mean for two related + * samples or repeated measurements on a single sample. + * <p> + * Let X<sub>i</sub> denote the i'th individual of the first sample and + * Y<sub>i</sub> the related i'th individual in the second sample. Let + * Z<sub>i</sub> = Y<sub>i</sub> - X<sub>i</sub>. + * </p> + * <p> + * <strong>Preconditions</strong>: + * <ul> + * <li>The differences Z<sub>i</sub> must be independent.</li> + * <li>Each Z<sub>i</sub> comes from a continuous population (they must be + * identical) and is symmetric about a common median.</li> + * <li>The values that X<sub>i</sub> and Y<sub>i</sub> represent are + * ordered, so the comparisons greater than, less than, and equal to are + * meaningful.</li> + * </ul> + * </p> + * + * @param x the first sample + * @param y the second sample + * @param exactPValue + * if the exact p-value is wanted (only works for x.length <= 30, + * if true and x.length > 30, this is ignored because + * calculations may take too long) + * @return p-value + * @throws NullArgumentException if {@code x} or {@code y} are {@code null}. + * @throws NoDataException if {@code x} or {@code y} are zero-length. 
+ * @throws DimensionMismatchException if {@code x} and {@code y} do not + * have the same length. + * @throws NumberIsTooLargeException if {@code exactPValue} is {@code true} + * and {@code x.length} > 30 + * @throws ConvergenceException if the p-value can not be computed due to + * a convergence error + * @throws MaxCountExceededException if the maximum number of iterations + * is exceeded + */ + public double wilcoxonSignedRankTest(final double[] x, final double[] y, + final boolean exactPValue) + throws NullArgumentException, NoDataException, DimensionMismatchException, + NumberIsTooLargeException, ConvergenceException, MaxCountExceededException { + + ensureDataConformance(x, y); + + final int N = x.length; + final double Wmax = wilcoxonSignedRank(x, y); + + if (exactPValue && N > 30) { + throw new NumberIsTooLargeException(N, 30, true); + } + + if (exactPValue) { + return calculateExactPValue(Wmax, N); + } else { + final double Wmin = ( (double)(N*(N+1)) / 2.0 ) - Wmax; + return calculateAsymptoticPValue(Wmin, N); + } + } +} diff --git a/src/main/java/org/apache/commons/math3/stat/inference/package-info.java b/src/main/java/org/apache/commons/math3/stat/inference/package-info.java new file mode 100644 index 0000000..a36a080 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/stat/inference/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/** + * + * Classes providing hypothesis testing. + * + */ +package org.apache.commons.math3.stat.inference; diff --git a/src/main/java/org/apache/commons/math3/stat/interval/AgrestiCoullInterval.java b/src/main/java/org/apache/commons/math3/stat/interval/AgrestiCoullInterval.java new file mode 100644 index 0000000..b71c718 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/stat/interval/AgrestiCoullInterval.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.math3.stat.interval; + +import org.apache.commons.math3.distribution.NormalDistribution; +import org.apache.commons.math3.util.FastMath; + +/** + * Implements the Agresti-Coull method for creating a binomial proportion confidence interval. 
+ *
+ * @see <a
+ * href="http://en.wikipedia.org/wiki/Binomial_proportion_confidence_interval#Agresti-Coull_Interval">
+ * Agresti-Coull interval (Wikipedia)</a>
+ * @since 3.3
+ */
+public class AgrestiCoullInterval implements BinomialConfidenceInterval {
+
+    /** {@inheritDoc} */
+    public ConfidenceInterval createInterval(int numberOfTrials, int numberOfSuccesses, double confidenceLevel) {
+        IntervalUtils.checkParameters(numberOfTrials, numberOfSuccesses, confidenceLevel);
+        // Two-sided interval: split the error probability equally between the tails.
+        final double alpha = (1.0 - confidenceLevel) / 2;
+        final NormalDistribution normalDistribution = new NormalDistribution();
+        // Standard normal quantile for the upper tail.
+        final double z = normalDistribution.inverseCumulativeProbability(1 - alpha);
+        final double zSquared = FastMath.pow(z, 2);
+        // Agresti-Coull adjustment: z^2 pseudo-trials and z^2/2 pseudo-successes.
+        final double modifiedNumberOfTrials = numberOfTrials + zSquared;
+        final double modifiedSuccessesRatio = (1.0 / modifiedNumberOfTrials) * (numberOfSuccesses + 0.5 * zSquared);
+        // Half-width of the interval around the adjusted success ratio.
+        final double difference = z *
+                                  FastMath.sqrt(1.0 / modifiedNumberOfTrials * modifiedSuccessesRatio *
+                                                (1 - modifiedSuccessesRatio));
+        return new ConfidenceInterval(modifiedSuccessesRatio - difference, modifiedSuccessesRatio + difference,
+                                      confidenceLevel);
+    }
+
+}
diff --git a/src/main/java/org/apache/commons/math3/stat/interval/BinomialConfidenceInterval.java b/src/main/java/org/apache/commons/math3/stat/interval/BinomialConfidenceInterval.java
new file mode 100644
index 0000000..532679a
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/stat/interval/BinomialConfidenceInterval.java
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.
You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.commons.math3.stat.interval;
+
+import org.apache.commons.math3.exception.NumberIsTooLargeException;
+import org.apache.commons.math3.exception.OutOfRangeException;
+import org.apache.commons.math3.exception.NotPositiveException;
+import org.apache.commons.math3.exception.NotStrictlyPositiveException;
+
+/**
+ * Interface to generate confidence intervals for a binomial proportion.
+ * <p>
+ * Implementations provided in this package: {@code AgrestiCoullInterval},
+ * {@code ClopperPearsonInterval}, {@code NormalApproximationInterval} and
+ * {@code WilsonScoreInterval}; convenience factory methods are available in
+ * {@code IntervalUtils}.
+ * </p>
+ *
+ * @see <a
+ * href="http://en.wikipedia.org/wiki/Binomial_proportion_confidence_interval">Binomial
+ * proportion confidence interval (Wikipedia)</a>
+ * @since 3.3
+ */
+public interface BinomialConfidenceInterval {
+
+    /**
+     * Create a confidence interval for the true probability of success
+     * of an unknown binomial distribution with the given observed number
+     * of trials, successes and confidence level.
+     * <p>
+     * Preconditions:
+     * <ul>
+     * <li>{@code numberOfTrials} must be positive</li>
+     * <li>{@code numberOfSuccesses} may not exceed {@code numberOfTrials}</li>
+     * <li>{@code confidenceLevel} must be strictly between 0 and 1 (exclusive)</li>
+     * </ul>
+     * </p>
+     *
+     * @param numberOfTrials number of trials
+     * @param numberOfSuccesses number of successes
+     * @param confidenceLevel desired probability that the true probability of
+     *        success falls within the returned interval
+     * @return Confidence interval containing the probability of success with
+     *         probability {@code confidenceLevel}
+     * @throws NotStrictlyPositiveException if {@code numberOfTrials <= 0}.
+     * @throws NotPositiveException if {@code numberOfSuccesses < 0}.
+     * @throws NumberIsTooLargeException if {@code numberOfSuccesses > numberOfTrials}.
+     * @throws OutOfRangeException if {@code confidenceLevel} is not in the interval {@code (0, 1)}.
+     */
+    ConfidenceInterval createInterval(int numberOfTrials, int numberOfSuccesses, double confidenceLevel)
+        throws NotStrictlyPositiveException, NotPositiveException,
+               NumberIsTooLargeException, OutOfRangeException;
+
+}
diff --git a/src/main/java/org/apache/commons/math3/stat/interval/ClopperPearsonInterval.java b/src/main/java/org/apache/commons/math3/stat/interval/ClopperPearsonInterval.java
new file mode 100644
index 0000000..34a0d57
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/stat/interval/ClopperPearsonInterval.java
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.commons.math3.stat.interval;
+
+import org.apache.commons.math3.distribution.FDistribution;
+
+/**
+ * Implements the Clopper-Pearson method for creating a binomial proportion confidence interval.
+ * + * @see <a + * href="http://en.wikipedia.org/wiki/Binomial_proportion_confidence_interval#Clopper-Pearson_interval"> + * Clopper-Pearson interval (Wikipedia)</a> + * @since 3.3 + */ +public class ClopperPearsonInterval implements BinomialConfidenceInterval { + + /** {@inheritDoc} */ + public ConfidenceInterval createInterval(int numberOfTrials, int numberOfSuccesses, + double confidenceLevel) { + IntervalUtils.checkParameters(numberOfTrials, numberOfSuccesses, confidenceLevel); + double lowerBound = 0; + double upperBound = 0; + final double alpha = (1.0 - confidenceLevel) / 2.0; + + final FDistribution distributionLowerBound = new FDistribution(2 * (numberOfTrials - numberOfSuccesses + 1), + 2 * numberOfSuccesses); + final double fValueLowerBound = distributionLowerBound.inverseCumulativeProbability(1 - alpha); + if (numberOfSuccesses > 0) { + lowerBound = numberOfSuccesses / + (numberOfSuccesses + (numberOfTrials - numberOfSuccesses + 1) * fValueLowerBound); + } + + final FDistribution distributionUpperBound = new FDistribution(2 * (numberOfSuccesses + 1), + 2 * (numberOfTrials - numberOfSuccesses)); + final double fValueUpperBound = distributionUpperBound.inverseCumulativeProbability(1 - alpha); + if (numberOfSuccesses > 0) { + upperBound = (numberOfSuccesses + 1) * fValueUpperBound / + (numberOfTrials - numberOfSuccesses + (numberOfSuccesses + 1) * fValueUpperBound); + } + + return new ConfidenceInterval(lowerBound, upperBound, confidenceLevel); + } + +} diff --git a/src/main/java/org/apache/commons/math3/stat/interval/ConfidenceInterval.java b/src/main/java/org/apache/commons/math3/stat/interval/ConfidenceInterval.java new file mode 100644 index 0000000..0147c8c --- /dev/null +++ b/src/main/java/org/apache/commons/math3/stat/interval/ConfidenceInterval.java @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.commons.math3.stat.interval;
+
+import org.apache.commons.math3.exception.MathIllegalArgumentException;
+import org.apache.commons.math3.exception.util.LocalizedFormats;
+
+/**
+ * Represents an interval estimate of a population parameter.
+ *
+ * @since 3.3
+ */
+public class ConfidenceInterval {
+
+    /** Lower endpoint of the interval */
+    private double lowerBound;
+
+    /** Upper endpoint of the interval */
+    private double upperBound;
+
+    /**
+     * The asserted probability that the interval contains the population
+     * parameter
+     */
+    private double confidenceLevel;
+
+    /**
+     * Create a confidence interval with the given bounds and confidence level.
+     * <p>
+     * Preconditions:
+     * <ul>
+     * <li>{@code lower} must be strictly less than {@code upper}</li>
+     * <li>{@code confidenceLevel} must be strictly between 0 and 1 (exclusive)</li>
+     * </ul>
+     * </p>
+     *
+     * @param lowerBound lower endpoint of the interval
+     * @param upperBound upper endpoint of the interval
+     * @param confidenceLevel coverage probability
+     * @throws MathIllegalArgumentException if the preconditions are not met
+     */
+    public ConfidenceInterval(double lowerBound, double upperBound, double confidenceLevel) {
+        checkParameters(lowerBound, upperBound, confidenceLevel);
+        this.lowerBound = lowerBound;
+        this.upperBound = upperBound;
+        this.confidenceLevel = confidenceLevel;
+    }
+
+    /**
+     * @return the lower endpoint of the interval
+     */
+    public double getLowerBound() {
+        return lowerBound;
+    }
+
+    /**
+     * @return the upper endpoint of the interval
+     */
+    public double getUpperBound() {
+        return upperBound;
+    }
+
+    /**
+     * @return the asserted probability that the interval contains the
+     *         population parameter
+     */
+    public double getConfidenceLevel() {
+        return confidenceLevel;
+    }
+
+    /**
+     * @return String representation of the confidence interval
+     */
+    @Override
+    public String toString() {
+        return "[" + lowerBound + ";" + upperBound + "] (confidence level:" + confidenceLevel + ")";
+    }
+
+    /**
+     * Verifies that (lower, upper) is a valid non-empty interval and confidence
+     * is strictly between 0 and 1.
+     *
+     * @param lower lower endpoint
+     * @param upper upper endpoint
+     * @param confidence confidence level
+     * @throws MathIllegalArgumentException if {@code lower >= upper} or if
+     *         {@code confidence} is not strictly between 0 and 1
+     */
+    private void checkParameters(double lower, double upper, double confidence) {
+        if (lower >= upper) {
+            throw new MathIllegalArgumentException(LocalizedFormats.LOWER_BOUND_NOT_BELOW_UPPER_BOUND, lower, upper);
+        }
+        if (confidence <= 0 || confidence >= 1) {
+            throw new MathIllegalArgumentException(LocalizedFormats.OUT_OF_BOUNDS_CONFIDENCE_LEVEL, confidence, 0, 1);
+        }
+    }
+}
diff --git a/src/main/java/org/apache/commons/math3/stat/interval/IntervalUtils.java b/src/main/java/org/apache/commons/math3/stat/interval/IntervalUtils.java
new file mode 100644
index 0000000..0613c99
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/stat/interval/IntervalUtils.java
@@ -0,0 +1,174 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.commons.math3.stat.interval;
+
+import org.apache.commons.math3.exception.NotPositiveException;
+import org.apache.commons.math3.exception.NotStrictlyPositiveException;
+import org.apache.commons.math3.exception.NumberIsTooLargeException;
+import org.apache.commons.math3.exception.OutOfRangeException;
+import org.apache.commons.math3.exception.util.LocalizedFormats;
+
+/**
+ * Factory methods to generate confidence intervals for a binomial proportion.
+ * The supported methods are:
+ * <ul>
+ * <li>Agresti-Coull interval</li>
+ * <li>Clopper-Pearson method (exact method)</li>
+ * <li>Normal approximation (based on central limit theorem)</li>
+ * <li>Wilson score interval</li>
+ * </ul>
+ *
+ * @since 3.3
+ */
+public final class IntervalUtils {
+
+    /** Singleton Agresti-Coull instance. */
+    private static final BinomialConfidenceInterval AGRESTI_COULL = new AgrestiCoullInterval();
+
+    /** Singleton Clopper-Pearson instance. */
+    private static final BinomialConfidenceInterval CLOPPER_PEARSON = new ClopperPearsonInterval();
+
+    /** Singleton NormalApproximation instance. */
+    private static final BinomialConfidenceInterval NORMAL_APPROXIMATION = new NormalApproximationInterval();
+
+    /** Singleton Wilson score instance. */
+    private static final BinomialConfidenceInterval WILSON_SCORE = new WilsonScoreInterval();
+
+    /**
+     * Prevent instantiation.
+     */
+    private IntervalUtils() {
+    }
+
+    /**
+     * Create an Agresti-Coull binomial confidence interval for the true
+     * probability of success of an unknown binomial distribution with the given
+     * observed number of trials, successes and confidence level.
+     *
+     * @param numberOfTrials number of trials
+     * @param numberOfSuccesses number of successes
+     * @param confidenceLevel desired probability that the true probability of
+     *        success falls within the returned interval
+     * @return Confidence interval containing the probability of success with
+     *         probability {@code confidenceLevel}
+     * @throws NotStrictlyPositiveException if {@code numberOfTrials <= 0}.
+     * @throws NotPositiveException if {@code numberOfSuccesses < 0}.
+     * @throws NumberIsTooLargeException if {@code numberOfSuccesses > numberOfTrials}.
+     * @throws OutOfRangeException if {@code confidenceLevel} is not in the interval {@code (0, 1)}.
+     */
+    public static ConfidenceInterval getAgrestiCoullInterval(int numberOfTrials, int numberOfSuccesses,
+                                                             double confidenceLevel) {
+        return AGRESTI_COULL.createInterval(numberOfTrials, numberOfSuccesses, confidenceLevel);
+    }
+
+    /**
+     * Create a Clopper-Pearson binomial confidence interval for the true
+     * probability of success of an unknown binomial distribution with the given
+     * observed number of trials, successes and confidence level.
+     * <p>
+     * Preconditions:
+     * <ul>
+     * <li>{@code numberOfTrials} must be positive</li>
+     * <li>{@code numberOfSuccesses} may not exceed {@code numberOfTrials}</li>
+     * <li>{@code confidenceLevel} must be strictly between 0 and 1 (exclusive)</li>
+     * </ul>
+     * </p>
+     *
+     * @param numberOfTrials number of trials
+     * @param numberOfSuccesses number of successes
+     * @param confidenceLevel desired probability that the true probability of
+     *        success falls within the returned interval
+     * @return Confidence interval containing the probability of success with
+     *         probability {@code confidenceLevel}
+     * @throws NotStrictlyPositiveException if {@code numberOfTrials <= 0}.
+     * @throws NotPositiveException if {@code numberOfSuccesses < 0}.
+     * @throws NumberIsTooLargeException if {@code numberOfSuccesses > numberOfTrials}.
+     * @throws OutOfRangeException if {@code confidenceLevel} is not in the interval {@code (0, 1)}.
+     */
+    public static ConfidenceInterval getClopperPearsonInterval(int numberOfTrials, int numberOfSuccesses,
+                                                               double confidenceLevel) {
+        return CLOPPER_PEARSON.createInterval(numberOfTrials, numberOfSuccesses, confidenceLevel);
+    }
+
+    /**
+     * Create a binomial confidence interval for the true probability of success
+     * of an unknown binomial distribution with the given observed number of
+     * trials, successes and confidence level using the Normal approximation to
+     * the binomial distribution.
+     *
+     * @param numberOfTrials number of trials
+     * @param numberOfSuccesses number of successes
+     * @param confidenceLevel desired probability that the true probability of
+     *        success falls within the interval
+     * @return Confidence interval containing the probability of success with
+     *         probability {@code confidenceLevel}
+     * @throws NotStrictlyPositiveException if {@code numberOfTrials <= 0}.
+     * @throws NotPositiveException if {@code numberOfSuccesses < 0}.
+     * @throws NumberIsTooLargeException if {@code numberOfSuccesses > numberOfTrials}.
+     * @throws OutOfRangeException if {@code confidenceLevel} is not in the interval {@code (0, 1)}.
+     */
+    public static ConfidenceInterval getNormalApproximationInterval(int numberOfTrials, int numberOfSuccesses,
+                                                                    double confidenceLevel) {
+        return NORMAL_APPROXIMATION.createInterval(numberOfTrials, numberOfSuccesses, confidenceLevel);
+    }
+
+    /**
+     * Create a Wilson score binomial confidence interval for the true
+     * probability of success of an unknown binomial distribution with the given
+     * observed number of trials, successes and confidence level.
+     *
+     * @param numberOfTrials number of trials
+     * @param numberOfSuccesses number of successes
+     * @param confidenceLevel desired probability that the true probability of
+     *        success falls within the returned interval
+     * @return Confidence interval containing the probability of success with
+     *         probability {@code confidenceLevel}
+     * @throws NotStrictlyPositiveException if {@code numberOfTrials <= 0}.
+     * @throws NotPositiveException if {@code numberOfSuccesses < 0}.
+     * @throws NumberIsTooLargeException if {@code numberOfSuccesses > numberOfTrials}.
+     * @throws OutOfRangeException if {@code confidenceLevel} is not in the interval {@code (0, 1)}.
+     */
+    public static ConfidenceInterval getWilsonScoreInterval(int numberOfTrials, int numberOfSuccesses,
+                                                            double confidenceLevel) {
+        return WILSON_SCORE.createInterval(numberOfTrials, numberOfSuccesses, confidenceLevel);
+    }
+
+    /**
+     * Verifies that parameters satisfy preconditions.
+     *
+     * @param numberOfTrials number of trials (must be positive)
+     * @param numberOfSuccesses number of successes (must not exceed numberOfTrials)
+     * @param confidenceLevel confidence level (must be strictly between 0 and 1)
+     * @throws NotStrictlyPositiveException if {@code numberOfTrials <= 0}.
+     * @throws NotPositiveException if {@code numberOfSuccesses < 0}.
+     * @throws NumberIsTooLargeException if {@code numberOfSuccesses > numberOfTrials}.
+     * @throws OutOfRangeException if {@code confidenceLevel} is not in the interval {@code (0, 1)}.
+     */
+    static void checkParameters(int numberOfTrials, int numberOfSuccesses, double confidenceLevel) {
+        if (numberOfTrials <= 0) {
+            throw new NotStrictlyPositiveException(LocalizedFormats.NUMBER_OF_TRIALS, numberOfTrials);
+        }
+        if (numberOfSuccesses < 0) {
+            throw new NotPositiveException(LocalizedFormats.NEGATIVE_NUMBER_OF_SUCCESSES, numberOfSuccesses);
+        }
+        if (numberOfSuccesses > numberOfTrials) {
+            throw new NumberIsTooLargeException(LocalizedFormats.NUMBER_OF_SUCCESS_LARGER_THAN_POPULATION_SIZE,
+                                                numberOfSuccesses, numberOfTrials, true);
+        }
+        if (confidenceLevel <= 0 || confidenceLevel >= 1) {
+            throw new OutOfRangeException(LocalizedFormats.OUT_OF_BOUNDS_CONFIDENCE_LEVEL,
+                                          confidenceLevel, 0, 1);
+        }
+    }
+
+}
diff --git a/src/main/java/org/apache/commons/math3/stat/interval/NormalApproximationInterval.java b/src/main/java/org/apache/commons/math3/stat/interval/NormalApproximationInterval.java
new file mode 100644
index 0000000..25a213a
--- /dev/null
+++
b/src/main/java/org/apache/commons/math3/stat/interval/NormalApproximationInterval.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.math3.stat.interval; + +import org.apache.commons.math3.distribution.NormalDistribution; +import org.apache.commons.math3.util.FastMath; + +/** + * Implements the normal approximation method for creating a binomial proportion confidence interval. 
+ *
+ * @see <a
+ * href="http://en.wikipedia.org/wiki/Binomial_proportion_confidence_interval#Normal_approximation_interval">
+ * Normal approximation interval (Wikipedia)</a>
+ * @since 3.3
+ */
+public class NormalApproximationInterval implements BinomialConfidenceInterval {
+
+    /** {@inheritDoc} */
+    public ConfidenceInterval createInterval(int numberOfTrials, int numberOfSuccesses,
+                                             double confidenceLevel) {
+        IntervalUtils.checkParameters(numberOfTrials, numberOfSuccesses, confidenceLevel);
+        // Observed proportion of successes.
+        final double mean = (double) numberOfSuccesses / (double) numberOfTrials;
+        // Two-sided interval: split the error probability equally between the tails.
+        final double alpha = (1.0 - confidenceLevel) / 2;
+        final NormalDistribution normalDistribution = new NormalDistribution();
+        // Half-width z * sqrt(p * (1 - p) / n) from the CLT approximation.
+        final double difference = normalDistribution.inverseCumulativeProbability(1 - alpha) *
+                                  FastMath.sqrt(1.0 / numberOfTrials * mean * (1 - mean));
+        return new ConfidenceInterval(mean - difference, mean + difference, confidenceLevel);
+    }
+
+}
diff --git a/src/main/java/org/apache/commons/math3/stat/interval/WilsonScoreInterval.java b/src/main/java/org/apache/commons/math3/stat/interval/WilsonScoreInterval.java
new file mode 100644
index 0000000..9932835
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/stat/interval/WilsonScoreInterval.java
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.commons.math3.stat.interval;
+
+import org.apache.commons.math3.distribution.NormalDistribution;
+import org.apache.commons.math3.util.FastMath;
+
+/**
+ * Implements the Wilson score method for creating a binomial proportion confidence interval.
+ *
+ * @see <a
+ * href="http://en.wikipedia.org/wiki/Binomial_proportion_confidence_interval#Wilson_score_interval">
+ * Wilson score interval (Wikipedia)</a>
+ * @since 3.3
+ */
+public class WilsonScoreInterval implements BinomialConfidenceInterval {
+
+    /** {@inheritDoc} */
+    public ConfidenceInterval createInterval(int numberOfTrials, int numberOfSuccesses, double confidenceLevel) {
+        IntervalUtils.checkParameters(numberOfTrials, numberOfSuccesses, confidenceLevel);
+        // Two-sided interval: split the error probability equally between the tails.
+        final double alpha = (1.0 - confidenceLevel) / 2;
+        final NormalDistribution normalDistribution = new NormalDistribution();
+        // Standard normal quantile for the upper tail.
+        final double z = normalDistribution.inverseCumulativeProbability(1 - alpha);
+        final double zSquared = FastMath.pow(z, 2);
+        // Observed proportion of successes.
+        final double mean = (double) numberOfSuccesses / (double) numberOfTrials;
+
+        // 1 / (1 + z^2/n): common scale factor of the Wilson score bounds.
+        final double factor = 1.0 / (1 + (1.0 / numberOfTrials) * zSquared);
+        // p + z^2/(2n): recentered success ratio.
+        final double modifiedSuccessRatio = mean + (1.0 / (2 * numberOfTrials)) * zSquared;
+        // z * sqrt(p(1-p)/n + z^2/(4n^2)): half-width before rescaling by factor.
+        final double difference = z *
+                                  FastMath.sqrt(1.0 / numberOfTrials * mean * (1 - mean) +
+                                                (1.0 / (4 * FastMath.pow(numberOfTrials, 2)) * zSquared));
+
+        final double lowerBound = factor * (modifiedSuccessRatio - difference);
+        final double upperBound = factor * (modifiedSuccessRatio + difference);
+        return new ConfidenceInterval(lowerBound, upperBound, confidenceLevel);
+    }
+
+}
diff --git a/src/main/java/org/apache/commons/math3/stat/interval/package-info.java b/src/main/java/org/apache/commons/math3/stat/interval/package-info.java
new file mode 100644
index 0000000..34f43d9
--- /dev/null
+++
b/src/main/java/org/apache/commons/math3/stat/interval/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/** + * + * Classes providing binomial proportion confidence interval construction. + * + */ +package org.apache.commons.math3.stat.interval; diff --git a/src/main/java/org/apache/commons/math3/stat/package-info.java b/src/main/java/org/apache/commons/math3/stat/package-info.java new file mode 100644 index 0000000..1df9698 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/stat/package-info.java @@ -0,0 +1,18 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/** Data storage, manipulation and summary routines. */ +package org.apache.commons.math3.stat; diff --git a/src/main/java/org/apache/commons/math3/stat/ranking/NaNStrategy.java b/src/main/java/org/apache/commons/math3/stat/ranking/NaNStrategy.java new file mode 100644 index 0000000..1a916ef --- /dev/null +++ b/src/main/java/org/apache/commons/math3/stat/ranking/NaNStrategy.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.commons.math3.stat.ranking; + +/** + * Strategies for handling NaN values in rank transformations. 
 * <ul>
 * <li>MINIMAL - NaNs are treated as minimal in the ordering, equivalent to
 * (that is, tied with) <code>Double.NEGATIVE_INFINITY</code>.</li>
 * <li>MAXIMAL - NaNs are treated as maximal in the ordering, equivalent to
 * <code>Double.POSITIVE_INFINITY</code></li>
 * <li>REMOVED - NaNs are removed before the rank transform is applied</li>
 * <li>FIXED - NaNs are left "in place," that is the rank transformation is
 * applied to the other elements in the input array, but the NaN elements
 * are returned unchanged.</li>
 * <li>FAILED - If any NaN is encountered in the input array, an appropriate
 * exception is thrown</li>
 * </ul>
 *
 * @since 2.0
 */
public enum NaNStrategy {

    /** NaNs are considered minimal in the ordering, i.e. tied with
     * <code>Double.NEGATIVE_INFINITY</code>. */
    MINIMAL,

    /** NaNs are considered maximal in the ordering, i.e. tied with
     * <code>Double.POSITIVE_INFINITY</code>. */
    MAXIMAL,

    /** NaNs are removed from the data before computing ranks. */
    REMOVED,

    /** NaNs are left in place; ranking is applied to the remaining values. */
    FIXED,

    /** Any NaN in the input causes an exception to be thrown.
     * @since 3.1
     */
    FAILED
}

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.commons.math3.stat.ranking; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; + +import org.apache.commons.math3.exception.MathInternalError; +import org.apache.commons.math3.exception.NotANumberException; +import org.apache.commons.math3.random.RandomDataGenerator; +import org.apache.commons.math3.random.RandomGenerator; +import org.apache.commons.math3.util.FastMath; + + +/** + * <p> Ranking based on the natural ordering on doubles.</p> + * <p>NaNs are treated according to the configured {@link NaNStrategy} and ties + * are handled using the selected {@link TiesStrategy}. + * Configuration settings are supplied in optional constructor arguments. + * Defaults are {@link NaNStrategy#FAILED} and {@link TiesStrategy#AVERAGE}, + * respectively. 
 * When using {@link TiesStrategy#RANDOM}, a
 * {@link RandomGenerator} may be supplied as a constructor argument.</p>
 * <p>Examples:
 * <table border="1" cellpadding="3">
 * <tr><th colspan="3">
 * Input data: (20, 17, 30, 42.3, 17, 50, Double.NaN, Double.NEGATIVE_INFINITY, 17)
 * </th></tr>
 * <tr><th>NaNStrategy</th><th>TiesStrategy</th>
 * <th><code>rank(data)</code></th>
 * <tr>
 * <td>default (NaNs maximal)</td>
 * <td>default (ties averaged)</td>
 * <td>(5, 3, 6, 7, 3, 8, 9, 1, 3)</td></tr>
 * <tr>
 * <td>default (NaNs maximal)</td>
 * <td>MINIMUM</td>
 * <td>(5, 2, 6, 7, 2, 8, 9, 1, 2)</td></tr>
 * <tr>
 * <td>MINIMAL</td>
 * <td>default (ties averaged)</td>
 * <td>(6, 4, 7, 8, 4, 9, 1.5, 1.5, 4)</td></tr>
 * <tr>
 * <td>REMOVED</td>
 * <td>SEQUENTIAL</td>
 * <td>(5, 2, 6, 7, 3, 8, 1, 4)</td></tr>
 * <tr>
 * <td>MINIMAL</td>
 * <td>MAXIMUM</td>
 * <td>(6, 5, 7, 8, 5, 9, 2, 2, 5)</td></tr></table></p>
 *
 * @since 2.0
 */
public class NaturalRanking implements RankingAlgorithm {

    /** default NaN strategy */
    public static final NaNStrategy DEFAULT_NAN_STRATEGY = NaNStrategy.FAILED;

    /** default ties strategy */
    public static final TiesStrategy DEFAULT_TIES_STRATEGY = TiesStrategy.AVERAGE;

    /** NaN strategy - defaults to NaNStrategy.FAILED */
    private final NaNStrategy nanStrategy;

    /** Ties strategy - defaults to ties averaged */
    private final TiesStrategy tiesStrategy;

    /** Source of random data - used only when ties strategy is RANDOM;
     * may be null when the configuration cannot use RANDOM ties. */
    private final RandomDataGenerator randomData;

    /**
     * Create a NaturalRanking with default strategies for handling ties and
     * NaNs ({@link NaNStrategy#FAILED} and {@link TiesStrategy#AVERAGE}).
     */
    public NaturalRanking() {
        super();
        tiesStrategy = DEFAULT_TIES_STRATEGY;
        nanStrategy = DEFAULT_NAN_STRATEGY;
        // AVERAGE ties never draw random values, so no generator is needed.
        randomData = null;
    }

    /**
     * Create a NaturalRanking with the given TiesStrategy.
     *
     * @param tiesStrategy the TiesStrategy to use
     */
    public NaturalRanking(TiesStrategy tiesStrategy) {
        super();
        this.tiesStrategy = tiesStrategy;
        nanStrategy = DEFAULT_NAN_STRATEGY;
        // The strategy may be RANDOM, so a default random source is created.
        randomData = new RandomDataGenerator();
    }

    /**
     * Create a NaturalRanking with the given NaNStrategy.
     *
     * @param nanStrategy the NaNStrategy to use
     */
    public NaturalRanking(NaNStrategy nanStrategy) {
        super();
        this.nanStrategy = nanStrategy;
        tiesStrategy = DEFAULT_TIES_STRATEGY;
        // AVERAGE ties never draw random values, so no generator is needed.
        randomData = null;
    }

    /**
     * Create a NaturalRanking with the given NaNStrategy and TiesStrategy.
     *
     * @param nanStrategy NaNStrategy to use
     * @param tiesStrategy TiesStrategy to use
     */
    public NaturalRanking(NaNStrategy nanStrategy, TiesStrategy tiesStrategy) {
        super();
        this.nanStrategy = nanStrategy;
        this.tiesStrategy = tiesStrategy;
        randomData = new RandomDataGenerator();
    }

    /**
     * Create a NaturalRanking with TiesStrategy.RANDOM and the given
     * RandomGenerator as the source of random data.
     *
     * @param randomGenerator source of random data
     */
    public NaturalRanking(RandomGenerator randomGenerator) {
        super();
        this.tiesStrategy = TiesStrategy.RANDOM;
        nanStrategy = DEFAULT_NAN_STRATEGY;
        randomData = new RandomDataGenerator(randomGenerator);
    }


    /**
     * Create a NaturalRanking with the given NaNStrategy, TiesStrategy.RANDOM
     * and the given source of random data.
     *
     * @param nanStrategy NaNStrategy to use
     * @param randomGenerator source of random data
     */
    public NaturalRanking(NaNStrategy nanStrategy,
            RandomGenerator randomGenerator) {
        super();
        this.nanStrategy = nanStrategy;
        this.tiesStrategy = TiesStrategy.RANDOM;
        randomData = new RandomDataGenerator(randomGenerator);
    }

    /**
     * Return the NaNStrategy
     *
     * @return returns the NaNStrategy
     */
    public NaNStrategy getNanStrategy() {
        return nanStrategy;
    }

    /**
     * Return the TiesStrategy
     *
     * @return the TiesStrategy
     */
    public TiesStrategy getTiesStrategy() {
        return tiesStrategy;
    }

    /**
     * Rank <code>data</code> using the natural ordering on Doubles, with
     * NaN values handled according to <code>nanStrategy</code> and ties
     * resolved using <code>tiesStrategy.</code>
     *
     * <p>NOTE(review): a zero-length input reaches <code>ranks[0]</code> below
     * and throws ArrayIndexOutOfBoundsException — confirm whether callers
     * guarantee non-empty input.</p>
     *
     * @param data array to be ranked
     * @return array of ranks
     * @throws NotANumberException if the selected {@link NaNStrategy} is {@code FAILED}
     * and a {@link Double#NaN} is encountered in the input data
     */
    public double[] rank(double[] data) {

        // Array recording initial positions of data to be ranked
        IntDoublePair[] ranks = new IntDoublePair[data.length];
        for (int i = 0; i < data.length; i++) {
            ranks[i] = new IntDoublePair(data[i], i);
        }

        // Recode, remove or record positions of NaNs
        List<Integer> nanPositions = null;
        switch (nanStrategy) {
            case MAXIMAL: // Replace NaNs with +INFs
                recodeNaNs(ranks, Double.POSITIVE_INFINITY);
                break;
            case MINIMAL: // Replace NaNs with -INFs
                recodeNaNs(ranks, Double.NEGATIVE_INFINITY);
                break;
            case REMOVED: // Drop NaNs from data
                ranks = removeNaNs(ranks);
                break;
            case FIXED:   // Record positions of NaNs
                nanPositions = getNanPositions(ranks);
                break;
            case FAILED:
                nanPositions = getNanPositions(ranks);
                if (nanPositions.size() > 0) {
                    throw new NotANumberException();
                }
                break;
            default: // this should not happen unless NaNStrategy enum is changed
                throw new MathInternalError();
        }

        // Sort the IntDoublePairs (by value; original positions are carried along)
        Arrays.sort(ranks);

        // Walk the sorted array, filling output array using sorted positions,
        // resolving ties as we go
        double[] out = new double[ranks.length];
        int pos = 1;  // position in sorted array
        out[ranks[0].getPosition()] = pos;
        List<Integer> tiesTrace = new ArrayList<Integer>();
        tiesTrace.add(ranks[0].getPosition());
        for (int i = 1; i < ranks.length; i++) {
            if (Double.compare(ranks[i].getValue(), ranks[i - 1].getValue()) > 0) {
                // tie sequence has ended (or had length 1)
                pos = i + 1;
                if (tiesTrace.size() > 1) {  // if seq is nontrivial, resolve
                    resolveTie(out, tiesTrace);
                }
                tiesTrace = new ArrayList<Integer>();
                tiesTrace.add(ranks[i].getPosition());
            } else {
                // tie sequence continues
                tiesTrace.add(ranks[i].getPosition());
            }
            out[ranks[i].getPosition()] = pos;
        }
        if (tiesTrace.size() > 1) {  // handle tie sequence at end
            resolveTie(out, tiesTrace);
        }
        if (nanStrategy == NaNStrategy.FIXED) {
            // Put NaNs back at their recorded original positions
            restoreNaNs(out, nanPositions);
        }
        return out;
    }

    /**
     * Returns an array that is a copy of the input array with IntDoublePairs
     * having NaN values removed.
+ * + * @param ranks input array + * @return array with NaN-valued entries removed + */ + private IntDoublePair[] removeNaNs(IntDoublePair[] ranks) { + if (!containsNaNs(ranks)) { + return ranks; + } + IntDoublePair[] outRanks = new IntDoublePair[ranks.length]; + int j = 0; + for (int i = 0; i < ranks.length; i++) { + if (Double.isNaN(ranks[i].getValue())) { + // drop, but adjust original ranks of later elements + for (int k = i + 1; k < ranks.length; k++) { + ranks[k] = new IntDoublePair( + ranks[k].getValue(), ranks[k].getPosition() - 1); + } + } else { + outRanks[j] = new IntDoublePair( + ranks[i].getValue(), ranks[i].getPosition()); + j++; + } + } + IntDoublePair[] returnRanks = new IntDoublePair[j]; + System.arraycopy(outRanks, 0, returnRanks, 0, j); + return returnRanks; + } + + /** + * Recodes NaN values to the given value. + * + * @param ranks array to recode + * @param value the value to replace NaNs with + */ + private void recodeNaNs(IntDoublePair[] ranks, double value) { + for (int i = 0; i < ranks.length; i++) { + if (Double.isNaN(ranks[i].getValue())) { + ranks[i] = new IntDoublePair( + value, ranks[i].getPosition()); + } + } + } + + /** + * Checks for presence of NaNs in <code>ranks.</code> + * + * @param ranks array to be searched for NaNs + * @return true iff ranks contains one or more NaNs + */ + private boolean containsNaNs(IntDoublePair[] ranks) { + for (int i = 0; i < ranks.length; i++) { + if (Double.isNaN(ranks[i].getValue())) { + return true; + } + } + return false; + } + + /** + * Resolve a sequence of ties, using the configured {@link TiesStrategy}. + * The input <code>ranks</code> array is expected to take the same value + * for all indices in <code>tiesTrace</code>. The common value is recoded + * according to the tiesStrategy. For example, if ranks = <5,8,2,6,2,7,1,2>, + * tiesTrace = <2,4,7> and tiesStrategy is MINIMUM, ranks will be unchanged. + * The same array and trace with tiesStrategy AVERAGE will come out + * <5,8,3,6,3,7,1,3>. 
     *
     * @param ranks array of ranks
     * @param tiesTrace list of indices where <code>ranks</code> is constant
     * -- that is, for any i and j in TiesTrace, <code> ranks[i] == ranks[j]
     * </code>
     */
    private void resolveTie(double[] ranks, List<Integer> tiesTrace) {

        // constant value of ranks over tiesTrace
        final double c = ranks[tiesTrace.get(0)];

        // length of sequence of tied ranks
        final int length = tiesTrace.size();

        switch (tiesStrategy) {
            case AVERAGE:  // Replace ranks with average: (2c + length - 1) / 2
                fill(ranks, tiesTrace, (2 * c + length - 1) / 2d);
                break;
            case MAXIMUM:   // Replace ranks with maximum values: c + length - 1
                fill(ranks, tiesTrace, c + length - 1);
                break;
            case MINIMUM:   // Replace ties with minimum (the common value c)
                fill(ranks, tiesTrace, c);
                break;
            case RANDOM:    // Fill with random integral values in [c, c + length - 1]
                Iterator<Integer> iterator = tiesTrace.iterator();
                long f = FastMath.round(c);
                while (iterator.hasNext()) {
                    // No advertised exception because args are guaranteed valid
                    ranks[iterator.next()] =
                        randomData.nextLong(f, f + length - 1);
                }
                break;
            case SEQUENTIAL:  // Fill sequentially from c to c + length - 1
                // walk and fill
                // NOTE: 'iterator' and 'f' are declared in the RANDOM case
                // above and intentionally reused here (Java switch scoping).
                iterator = tiesTrace.iterator();
                f = FastMath.round(c);
                int i = 0;
                while (iterator.hasNext()) {
                    ranks[iterator.next()] = f + i++;
                }
                break;
            default: // this should not happen unless TiesStrategy enum is changed
                throw new MathInternalError();
        }
    }

    /**
     * Sets <code>data[i] = value</code> for each i in <code>tiesTrace.</code>
     *
     * @param data array to modify
     * @param tiesTrace list of index values to set
     * @param value value to set
     */
    private void fill(double[] data, List<Integer> tiesTrace, double value) {
        Iterator<Integer> iterator = tiesTrace.iterator();
        while (iterator.hasNext()) {
            data[iterator.next()] = value;
        }
    }

    /**
     * Set <code>ranks[i] = Double.NaN</code> for each i in <code>nanPositions.</code>
     *
     * @param ranks array to modify
     * @param nanPositions list of index values to set to <code>Double.NaN</code>
     */
    private void restoreNaNs(double[] ranks, List<Integer> nanPositions) {
        if (nanPositions.size() == 0) {
            return;
        }
        Iterator<Integer> iterator = nanPositions.iterator();
        while (iterator.hasNext()) {
            ranks[iterator.next().intValue()] = Double.NaN;
        }

    }

    /**
     * Returns a list of indexes where <code>ranks</code> is <code>NaN.</code>
     *
     * @param ranks array to search for <code>NaNs</code>
     * @return list of indexes i such that <code>ranks[i] = NaN</code>
     */
    private List<Integer> getNanPositions(IntDoublePair[] ranks) {
        ArrayList<Integer> out = new ArrayList<Integer>();
        for (int i = 0; i < ranks.length; i++) {
            if (Double.isNaN(ranks[i].getValue())) {
                out.add(Integer.valueOf(i));
            }
        }
        return out;
    }

    /**
     * Represents the position of a double value in an ordering.
     * Comparable interface is implemented so Arrays.sort can be used
     * to sort an array of IntDoublePairs by value. Note that the
     * implicitly defined natural ordering is NOT consistent with equals.
     */
    private static class IntDoublePair implements Comparable<IntDoublePair> {

        /** Value of the pair */
        private final double value;

        /** Original position of the pair */
        private final int position;

        /**
         * Construct an IntDoublePair with the given value and position.
         * @param value the value of the pair
         * @param position the original position
         */
        IntDoublePair(double value, int position) {
            this.value = value;
            this.position = position;
        }

        /**
         * Compare this IntDoublePair to another pair.
         * Only the <strong>values</strong> are compared.
         *
         * @param other the other pair to compare this to
         * @return result of <code>Double.compare(value, other.value)</code>
         */
        public int compareTo(IntDoublePair other) {
            return Double.compare(value, other.value);
        }

        // N.B. equals() and hashCode() are not implemented; see MATH-610 for discussion.

        /**
         * Returns the value of the pair.
         * @return value
         */
        public double getValue() {
            return value;
        }

        /**
         * Returns the original position of the pair.
         * @return position
         */
        public int getPosition() {
            return position;
        }
    }
}

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.commons.math3.stat.ranking;

/**
 * Interface representing a rank transformation.
 *
 * @since 2.0
 */
public interface RankingAlgorithm {
    /**
     * <p>Performs a rank transformation on the input data, returning an array
     * of ranks.</p>
     *
     * <p>Ranks should be 1-based - that is, the smallest value
     * returned in an array of ranks should be greater than or equal to one,
     * rather than 0.
 Ranks should in general take integer values, though
     * implementations may return averages or other floating point values
     * to resolve ties in the input data.</p>
     *
     * @param data array of data to be ranked
     * @return an array of ranks corresponding to the elements of the input array
     */
    double[] rank (double[] data);
}

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.commons.math3.stat.ranking;

/**
 * Strategies for handling tied values in rank transformations.
 * <ul>
 * <li>SEQUENTIAL - Ties are assigned ranks in order of occurrence in the original array,
 * for example (1,3,4,3) is ranked as (1,2,4,3)</li>
 * <li>MINIMUM - Tied values are assigned the minimum applicable rank, or the rank
 * of the first occurrence. For example, (1,3,4,3) is ranked as (1,2,4,2)</li>
 * <li>MAXIMUM - Tied values are assigned the maximum applicable rank, or the rank
 * of the last occurrence. For example, (1,3,4,3) is ranked as (1,3,4,3)</li>
 * <li>AVERAGE - Tied values are assigned the average of the applicable ranks.
 * For example, (1,3,4,3) is ranked as (1,2.5,4,2.5)</li>
 * <li>RANDOM - Tied values are assigned a random integer rank from among the
 * applicable values. The assigned rank will always be an integer, (inclusively)
 * between the values returned by the MINIMUM and MAXIMUM strategies.</li>
 * </ul>
 *
 * @since 2.0
 */
public enum TiesStrategy {

    /** Ties assigned sequential ranks in order of occurrence */
    SEQUENTIAL,

    /** Ties get the minimum applicable rank */
    MINIMUM,

    /** Ties get the maximum applicable rank */
    MAXIMUM,

    /** Ties get the average of applicable ranks */
    AVERAGE,

    /** Ties get a random integral value from among applicable ranks */
    RANDOM
}

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.
 You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
/**
 *
 * Classes providing rank transformations.
 *
 */
package org.apache.commons.math3.stat.ranking;

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.commons.math3.stat.regression;

import org.apache.commons.math3.exception.DimensionMismatchException;
import org.apache.commons.math3.exception.InsufficientDataException;
import org.apache.commons.math3.exception.MathIllegalArgumentException;
import org.apache.commons.math3.exception.NoDataException;
import org.apache.commons.math3.exception.NullArgumentException;
import org.apache.commons.math3.exception.util.LocalizedFormats;
import org.apache.commons.math3.linear.NonSquareMatrixException;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.RealVector;
import org.apache.commons.math3.linear.ArrayRealVector;
import org.apache.commons.math3.stat.descriptive.moment.Variance;
import org.apache.commons.math3.util.FastMath;

/**
 * Abstract base class for implementations of MultipleLinearRegression.
 * Holds the design matrix X and observation vector y and provides the
 * common estimation entry points; concrete subclasses supply the actual
 * calculation of beta, its variance, and the error variance.
 * @since 2.0
 */
public abstract class AbstractMultipleLinearRegression implements
    MultipleLinearRegression {

    /** X sample data (design matrix, possibly augmented with a column of 1s). */
    private RealMatrix xMatrix;

    /** Y sample data (observation vector). */
    private RealVector yVector;

    /** Whether or not the regression model includes an intercept. True means no intercept. */
    private boolean noIntercept = false;

    /**
     * @return the X sample data.
     */
    protected RealMatrix getX() {
        return xMatrix;
    }

    /**
     * @return the Y sample data.
     */
    protected RealVector getY() {
        return yVector;
    }

    /**
     * @return true if the model has no intercept term; false otherwise
     * @since 2.2
     */
    public boolean isNoIntercept() {
        return noIntercept;
    }

    /**
     * @param noIntercept true means the model is to be estimated without an intercept term
     * @since 2.2
     */
    public void setNoIntercept(boolean noIntercept) {
        this.noIntercept = noIntercept;
    }

    /**
     * <p>Loads model x and y sample data from a flat input array, overriding any previous sample.
     * </p>
     * <p>Assumes that rows are concatenated with y values first in each row. For example, an input
     * <code>data</code> array containing the sequence of values (1, 2, 3, 4, 5, 6, 7, 8, 9) with
     * <code>nobs = 3</code> and <code>nvars = 2</code> creates a regression dataset with two
     * independent variables, as below:
     * <pre>
     *   y   x[0]  x[1]
     *   --------------
     *   1     2     3
     *   4     5     6
     *   7     8     9
     * </pre>
     * </p>
     * <p>Note that there is no need to add an initial unitary column (column of 1's) when
     * specifying a model including an intercept term. If {@link #isNoIntercept()} is <code>true</code>,
     * the X matrix will be created without an initial column of "1"s; otherwise this column will
     * be added.
     * </p>
     * <p>Throws IllegalArgumentException if any of the following preconditions fail:
     * <ul><li><code>data</code> cannot be null</li>
     * <li><code>data.length = nobs * (nvars + 1)</li>
     * <li><code>nobs > nvars</code></li></ul>
     * </p>
     *
     * @param data input data array
     * @param nobs number of observations (rows)
     * @param nvars number of independent variables (columns, not counting y)
     * @throws NullArgumentException if the data array is null
     * @throws DimensionMismatchException if the length of the data array is not equal
     * to <code>nobs * (nvars + 1)</code>
     * @throws InsufficientDataException if <code>nobs</code> is less than
     * <code>nvars + 1</code>
     */
    public void newSampleData(double[] data, int nobs, int nvars) {
        if (data == null) {
            throw new NullArgumentException();
        }
        if (data.length != nobs * (nvars + 1)) {
            throw new DimensionMismatchException(data.length, nobs * (nvars + 1));
        }
        if (nobs <= nvars) {
            throw new InsufficientDataException(LocalizedFormats.INSUFFICIENT_OBSERVED_POINTS_IN_SAMPLE, nobs, nvars + 1);
        }
        double[] y = new double[nobs];
        // With an intercept, column 0 of X is reserved for the constant 1s.
        final int cols = noIntercept ? nvars: nvars + 1;
        double[][] x = new double[nobs][cols];
        // Walk the flat array row by row: y value first, then the x values.
        int pointer = 0;
        for (int i = 0; i < nobs; i++) {
            y[i] = data[pointer++];
            if (!noIntercept) {
                x[i][0] = 1.0d;
            }
            for (int j = noIntercept ? 0 : 1; j < cols; j++) {
                x[i][j] = data[pointer++];
            }
        }
        this.xMatrix = new Array2DRowRealMatrix(x);
        this.yVector = new ArrayRealVector(y);
    }

    /**
     * Loads new y sample data, overriding any previous data.
     *
     * @param y the array representing the y sample
     * @throws NullArgumentException if y is null
     * @throws NoDataException if y is empty
     */
    protected void newYSampleData(double[] y) {
        if (y == null) {
            throw new NullArgumentException();
        }
        if (y.length == 0) {
            throw new NoDataException();
        }
        this.yVector = new ArrayRealVector(y);
    }

    /**
     * <p>Loads new x sample data, overriding any previous data.
     * </p>
     * The input <code>x</code> array should have one row for each sample
     * observation, with columns corresponding to independent variables.
     * For example, if <pre>
     * <code> x = new double[][] {{1, 2}, {3, 4}, {5, 6}} </code></pre>
     * then <code>setXSampleData(x) </code> results in a model with two independent
     * variables and 3 observations:
     * <pre>
     *   x[0]  x[1]
     *   ----------
     *     1    2
     *     3    4
     *     5    6
     * </pre>
     * </p>
     * <p>Note that there is no need to add an initial unitary column (column of 1's) when
     * specifying a model including an intercept term.
     * </p>
     * @param x the rectangular array representing the x sample
     * @throws NullArgumentException if x is null
     * @throws NoDataException if x is empty
     * @throws DimensionMismatchException if x is not rectangular
     */
    protected void newXSampleData(double[][] x) {
        if (x == null) {
            throw new NullArgumentException();
        }
        if (x.length == 0) {
            throw new NoDataException();
        }
        if (noIntercept) {
            this.xMatrix = new Array2DRowRealMatrix(x, true);
        } else { // Augment design matrix with initial unitary column
            final int nVars = x[0].length;
            final double[][] xAug = new double[x.length][nVars + 1];
            for (int i = 0; i < x.length; i++) {
                // Rectangularity check: every row must match row 0's width.
                if (x[i].length != nVars) {
                    throw new DimensionMismatchException(x[i].length, nVars);
                }
                xAug[i][0] = 1.0d;
                System.arraycopy(x[i], 0, xAug[i], 1, nVars);
            }
            // xAug is freshly allocated, so no defensive copy is needed (false).
            this.xMatrix = new Array2DRowRealMatrix(xAug, false);
        }
    }

    /**
     * Validates sample data. Checks that
     * <ul><li>Neither x nor y is null or empty;</li>
     * <li>The length (i.e. number of rows) of x equals the length of y</li>
     * <li>x has at least one more row than it has columns (i.e.
there is + * sufficient data to estimate regression coefficients for each of the + * columns in x plus an intercept.</li> + * </ul> + * + * @param x the [n,k] array representing the x data + * @param y the [n,1] array representing the y data + * @throws NullArgumentException if {@code x} or {@code y} is null + * @throws DimensionMismatchException if {@code x} and {@code y} do not + * have the same length + * @throws NoDataException if {@code x} or {@code y} are zero-length + * @throws MathIllegalArgumentException if the number of rows of {@code x} + * is not larger than the number of columns + 1 + */ + protected void validateSampleData(double[][] x, double[] y) throws MathIllegalArgumentException { + if ((x == null) || (y == null)) { + throw new NullArgumentException(); + } + if (x.length != y.length) { + throw new DimensionMismatchException(y.length, x.length); + } + if (x.length == 0) { // Must be no y data either + throw new NoDataException(); + } + if (x[0].length + 1 > x.length) { + throw new MathIllegalArgumentException( + LocalizedFormats.NOT_ENOUGH_DATA_FOR_NUMBER_OF_PREDICTORS, + x.length, x[0].length); + } + } + + /** + * Validates that the x data and covariance matrix have the same + * number of rows and that the covariance matrix is square. 
+ * + * @param x the [n,k] array representing the x sample + * @param covariance the [n,n] array representing the covariance matrix + * @throws DimensionMismatchException if the number of rows in x is not equal + * to the number of rows in covariance + * @throws NonSquareMatrixException if the covariance matrix is not square + */ + protected void validateCovarianceData(double[][] x, double[][] covariance) { + if (x.length != covariance.length) { + throw new DimensionMismatchException(x.length, covariance.length); + } + if (covariance.length > 0 && covariance.length != covariance[0].length) { + throw new NonSquareMatrixException(covariance.length, covariance[0].length); + } + } + + /** + * {@inheritDoc} + */ + public double[] estimateRegressionParameters() { + RealVector b = calculateBeta(); + return b.toArray(); + } + + /** + * {@inheritDoc} + */ + public double[] estimateResiduals() { + RealVector b = calculateBeta(); + RealVector e = yVector.subtract(xMatrix.operate(b)); + return e.toArray(); + } + + /** + * {@inheritDoc} + */ + public double[][] estimateRegressionParametersVariance() { + return calculateBetaVariance().getData(); + } + + /** + * {@inheritDoc} + */ + public double[] estimateRegressionParametersStandardErrors() { + double[][] betaVariance = estimateRegressionParametersVariance(); + double sigma = calculateErrorVariance(); + int length = betaVariance[0].length; + double[] result = new double[length]; + for (int i = 0; i < length; i++) { + result[i] = FastMath.sqrt(sigma * betaVariance[i][i]); + } + return result; + } + + /** + * {@inheritDoc} + */ + public double estimateRegressandVariance() { + return calculateYVariance(); + } + + /** + * Estimates the variance of the error. + * + * @return estimate of the error variance + * @since 2.2 + */ + public double estimateErrorVariance() { + return calculateErrorVariance(); + + } + + /** + * Estimates the standard error of the regression. 
     *
     * @return regression standard error
     * @since 2.2
     */
    public double estimateRegressionStandardError() {
        return FastMath.sqrt(estimateErrorVariance());
    }

    /**
     * Calculates the beta of multiple linear regression in matrix notation.
     *
     * @return beta
     */
    protected abstract RealVector calculateBeta();

    /**
     * Calculates the beta variance of multiple linear regression in matrix
     * notation.
     *
     * @return beta variance
     */
    protected abstract RealMatrix calculateBetaVariance();


    /**
     * Calculates the variance of the y values.
     *
     * @return Y variance
     */
    protected double calculateYVariance() {
        return new Variance().evaluate(yVector.toArray());
    }

    /**
     * <p>Calculates the variance of the error term.</p>
     * Uses the formula <pre>
     * var(u) = u · u / (n - k)
     * </pre>
     * where n and k are the row and column dimensions of the design
     * matrix X.
     *
     * @return error variance estimate
     * @since 2.2
     */
    protected double calculateErrorVariance() {
        // Residual sum of squares divided by the degrees of freedom (n - k).
        RealVector residuals = calculateResiduals();
        return residuals.dotProduct(residuals) /
               (xMatrix.getRowDimension() - xMatrix.getColumnDimension());
    }

    /**
     * Calculates the residuals of multiple linear regression in matrix
     * notation.
     *
     * <pre>
     * u = y - X * b
     * </pre>
     *
     * @return The residuals [n,1] matrix
     */
    protected RealVector calculateResiduals() {
        RealVector b = calculateBeta();
        return yVector.subtract(xMatrix.operate(b));
    }

}
diff --git a/src/main/java/org/apache/commons/math3/stat/regression/GLSMultipleLinearRegression.java b/src/main/java/org/apache/commons/math3/stat/regression/GLSMultipleLinearRegression.java
new file mode 100644
index 0000000..1644e6d
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/stat/regression/GLSMultipleLinearRegression.java
@@ -0,0 +1,135 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.
See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.commons.math3.stat.regression;

import org.apache.commons.math3.linear.LUDecomposition;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.RealVector;

/**
 * The GLS implementation of multiple linear regression.
 *
 * GLS assumes a general covariance matrix Omega of the error
 * <pre>
 * u ~ N(0, Omega)
 * </pre>
 *
 * Estimated by GLS,
 * <pre>
 * b=(X' Omega^-1 X)^-1X'Omega^-1 y
 * </pre>
 * whose variance is
 * <pre>
 * Var(b)=(X' Omega^-1 X)^-1
 * </pre>
 * @since 2.0
 */
public class GLSMultipleLinearRegression extends AbstractMultipleLinearRegression {

    /** Covariance matrix of the error term. */
    private RealMatrix Omega;

    /** Inverse of the covariance matrix, lazily computed by getOmegaInverse()
     * and reset to null whenever new covariance data is loaded. */
    private RealMatrix OmegaInverse;

    /** Replace sample data, overriding any previous sample.
+ * @param y y values of the sample + * @param x x values of the sample + * @param covariance array representing the covariance matrix + */ + public void newSampleData(double[] y, double[][] x, double[][] covariance) { + validateSampleData(x, y); + newYSampleData(y); + newXSampleData(x); + validateCovarianceData(x, covariance); + newCovarianceData(covariance); + } + + /** + * Add the covariance data. + * + * @param omega the [n,n] array representing the covariance + */ + protected void newCovarianceData(double[][] omega){ + this.Omega = new Array2DRowRealMatrix(omega); + this.OmegaInverse = null; + } + + /** + * Get the inverse of the covariance. + * <p>The inverse of the covariance matrix is lazily evaluated and cached.</p> + * @return inverse of the covariance + */ + protected RealMatrix getOmegaInverse() { + if (OmegaInverse == null) { + OmegaInverse = new LUDecomposition(Omega).getSolver().getInverse(); + } + return OmegaInverse; + } + + /** + * Calculates beta by GLS. + * <pre> + * b=(X' Omega^-1 X)^-1X'Omega^-1 y + * </pre> + * @return beta + */ + @Override + protected RealVector calculateBeta() { + RealMatrix OI = getOmegaInverse(); + RealMatrix XT = getX().transpose(); + RealMatrix XTOIX = XT.multiply(OI).multiply(getX()); + RealMatrix inverse = new LUDecomposition(XTOIX).getSolver().getInverse(); + return inverse.multiply(XT).multiply(OI).operate(getY()); + } + + /** + * Calculates the variance on the beta. + * <pre> + * Var(b)=(X' Omega^-1 X)^-1 + * </pre> + * @return The beta variance matrix + */ + @Override + protected RealMatrix calculateBetaVariance() { + RealMatrix OI = getOmegaInverse(); + RealMatrix XTOIX = getX().transpose().multiply(OI).multiply(getX()); + return new LUDecomposition(XTOIX).getSolver().getInverse(); + } + + + /** + * Calculates the estimated variance of the error term using the formula + * <pre> + * Var(u) = Tr(u' Omega^-1 u)/(n-k) + * </pre> + * where n and k are the row and column dimensions of the design + * matrix X. 
     *
     * @return error variance
     * @since 2.2
     */
    @Override
    protected double calculateErrorVariance() {
        // Generalized residual sum of squares u' Omega^-1 u over the
        // degrees of freedom (n - k).
        RealVector residuals = calculateResiduals();
        double t = residuals.dotProduct(getOmegaInverse().operate(residuals));
        return t / (getX().getRowDimension() - getX().getColumnDimension());

    }

}
diff --git a/src/main/java/org/apache/commons/math3/stat/regression/MillerUpdatingRegression.java b/src/main/java/org/apache/commons/math3/stat/regression/MillerUpdatingRegression.java
new file mode 100644
index 0000000..3fe3c03
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/stat/regression/MillerUpdatingRegression.java
@@ -0,0 +1,1101 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.commons.math3.stat.regression;

import java.util.Arrays;
import org.apache.commons.math3.exception.util.LocalizedFormats;
import org.apache.commons.math3.util.FastMath;
import org.apache.commons.math3.util.Precision;
import org.apache.commons.math3.util.MathArrays;

/**
 * This class is a concrete implementation of the {@link UpdatingMultipleLinearRegression} interface.
 *
 * <p>The algorithm is described in: <pre>
 * Algorithm AS 274: Least Squares Routines to Supplement Those of Gentleman
 * Author(s): Alan J. Miller
 * Source: Journal of the Royal Statistical Society.
 * Series C (Applied Statistics), Vol. 41, No. 2
 * (1992), pp. 458-478
 * Published by: Blackwell Publishing for the Royal Statistical Society
 * Stable URL: http://www.jstor.org/stable/2347583 </pre></p>
 *
 * <p>This method for multiple regression forms the solution to the OLS problem
 * by updating the QR decomposition as described by Gentleman.</p>
 *
 * @since 3.0
 */
public class MillerUpdatingRegression implements UpdatingMultipleLinearRegression {

    /** number of variables in regression (includes the constant column when hasIntercept) */
    private final int nvars;
    /** diagonals of cross products matrix */
    private final double[] d;
    /** the elements of the R`Y */
    private final double[] rhs;
    /** the off diagonal portion of the R matrix, packed row-wise as a triangle */
    private final double[] r;
    /** the tolerance for each of the variables */
    private final double[] tol;
    /** residual sum of squares for all nested regressions */
    private final double[] rss;
    /** order of the regressors */
    private final int[] vorder;
    /** scratch space for tolerance calc */
    private final double[] work_tolset;
    /** number of observations entered */
    private long nobs = 0;
    /** sum of squared errors of largest regression */
    private double sserr = 0.0;
    /** has rss been called? */
    private boolean rss_set = false;
    /** has the tolerance setting method been called */
    private boolean tol_set = false;
    /** flags for variables with linear dependency problems */
    private final boolean[] lindep;
    /** singular x values */
    private final double[] x_sing;
    /** workspace for singularity method */
    private final double[] work_sing;
    /** summation of Y variable */
    private double sumy = 0.0;
    /** summation of squared Y values */
    private double sumsqy = 0.0;
    /** boolean flag whether a regression constant is added */
    private boolean hasIntercept;
    /** zero tolerance */
    private final double epsilon;
    /**
     * Set the default constructor to private access
     * to prevent inadvertent instantiation
     */
    @SuppressWarnings("unused")
    private MillerUpdatingRegression() {
        // Delegates with an invalid variable count; the target constructor
        // always throws, so this path can never produce an instance.
        this(-1, false, Double.NaN);
    }

    /**
     * This is the augmented constructor for the MillerUpdatingRegression class.
     *
     * @param numberOfVariables number of regressors to expect, not including constant
     * @param includeConstant include a constant automatically
     * @param errorTolerance zero tolerance, how machine zero is determined
     * @throws ModelSpecificationException if {@code numberOfVariables is less than 1}
     */
    public MillerUpdatingRegression(int numberOfVariables, boolean includeConstant, double errorTolerance)
            throws ModelSpecificationException {
        if (numberOfVariables < 1) {
            throw new ModelSpecificationException(LocalizedFormats.NO_REGRESSORS);
        }
        // When a constant is included it occupies an extra (leading) column.
        if (includeConstant) {
            this.nvars = numberOfVariables + 1;
        } else {
            this.nvars = numberOfVariables;
        }
        this.hasIntercept = includeConstant;
        this.nobs = 0;
        this.d = new double[this.nvars];
        this.rhs = new double[this.nvars];
        // Strict upper triangle of R, hence nvars*(nvars-1)/2 entries.
        this.r = new double[this.nvars * (this.nvars - 1) / 2];
        this.tol = new double[this.nvars];
        this.rss = new double[this.nvars];
        this.vorder = new int[this.nvars];
        this.x_sing = new double[this.nvars];
        this.work_sing = new double[this.nvars];
        this.work_tolset = new double[this.nvars];
        this.lindep = new boolean[this.nvars];
        for (int i = 0; i < this.nvars; i++) {
            vorder[i] = i;
        }
        // The magnitude of the supplied tolerance is used: a non-positive
        // value is negated rather than rejected.
        if (errorTolerance > 0) {
            this.epsilon = errorTolerance;
        } else {
            this.epsilon = -errorTolerance;
        }
    }

    /**
     * Primary constructor for the MillerUpdatingRegression.
     *
     * @param numberOfVariables maximum number of potential regressors
     * @param includeConstant include a constant automatically
     * @throws ModelSpecificationException if {@code numberOfVariables is less than 1}
     */
    public MillerUpdatingRegression(int numberOfVariables, boolean includeConstant)
            throws ModelSpecificationException {
        this(numberOfVariables, includeConstant, Precision.EPSILON);
    }

    /**
     * A getter method which determines whether a constant is included.
     * @return true regression has an intercept, false no intercept
     */
    public boolean hasIntercept() {
        return this.hasIntercept;
    }

    /**
     * Gets the number of observations added to the regression model.
     * @return number of observations
     */
    public long getN() {
        return this.nobs;
    }

    /**
     * Adds an observation to the regression model.
     * @param x the array with regressor values
     * @param y the value of dependent variable given these regressors
     * @exception ModelSpecificationException if the length of {@code x} does not equal
     * the number of independent variables in the model
     */
    public void addObservation(final double[] x, final double y)
            throws ModelSpecificationException {

        if ((!this.hasIntercept && x.length != nvars) ||
                (this.hasIntercept && x.length + 1 != nvars)) {
            throw new ModelSpecificationException(LocalizedFormats.INVALID_REGRESSION_OBSERVATION,
                    x.length, nvars);
        }
        // include() overwrites its x argument, so pass a copy (or a freshly
        // built augmented row) to avoid contaminating the caller's data.
        if (!this.hasIntercept) {
            include(MathArrays.copyOf(x, x.length), 1.0, y);
        } else {
            final double[] tmp = new double[x.length + 1];
            System.arraycopy(x, 0, tmp, 1, x.length);
            tmp[0] = 1.0;
            include(tmp, 1.0, y);
        }
        ++nobs;

    }

    /**
     * Adds multiple observations to the model.
     * @param x observations on the regressors
     * @param y observations on the regressand
     * @throws ModelSpecificationException if {@code x} is not rectangular, does not match
     * the length of {@code y} or does not contain sufficient data to estimate the model
     */
    public void addObservations(double[][] x, double[] y) throws ModelSpecificationException {
        if ((x == null) || (y == null) || (x.length != y.length)) {
            throw new ModelSpecificationException(
                    LocalizedFormats.DIMENSIONS_MISMATCH_SIMPLE,
                    (x == null) ? 0 : x.length,
                    (y == null) ? 0 : y.length);
        }
        if (x.length == 0) {  // Must be no y data either
            throw new ModelSpecificationException(
                    LocalizedFormats.NO_DATA);
        }
        if (x[0].length + 1 > x.length) {
            throw new ModelSpecificationException(
                    LocalizedFormats.NOT_ENOUGH_DATA_FOR_NUMBER_OF_PREDICTORS,
                    x.length, x[0].length);
        }
        for (int i = 0; i < x.length; i++) {
            addObservation(x[i], y[i]);
        }
    }

    /**
     * The include method is where the QR decomposition occurs. This statement forms all
     * intermediate data which will be used for all derivative measures.
     * According to the miller paper, note that in the original implementation the x vector
     * is overwritten. In this implementation, the include method is passed a copy of the
     * original data vector so that there is no contamination of the data. Additionally,
     * this method differs slightly from Gentleman's method, in that the assumption is
     * of dense design matrices, there is some advantage in using the original gentleman algorithm
     * on sparse matrices.
     *
     * @param x observations on the regressors
     * @param wi weight of the this observation (-1,1)
     * @param yi observation on the regressand
     */
    private void include(final double[] x, final double wi, final double yi) {
        int nextr = 0;
        double w = wi;
        double y = yi;
        double xi;
        double di;
        double wxi;
        double dpi;
        double xk;
        double _w;
        // Any cached residual sums of squares are stale once new data arrives.
        this.rss_set = false;
        sumy = smartAdd(yi, sumy);
        sumsqy = smartAdd(sumsqy, yi * yi);
        // Sweep the observation through each row of the triangular system,
        // applying Gentleman-style square-root-free Givens updates.
        for (int i = 0; i < x.length; i++) {
            if (w == 0.0) {
                return;
            }
            xi = x[i];

            if (xi == 0.0) {
                // Nothing to rotate on this row; skip its off-diagonal entries.
                nextr += nvars - i - 1;
                continue;
            }
            di = d[i];
            wxi = w * xi;
            _w = w;
            if (di != 0.0) {
                dpi = smartAdd(di, wxi * xi);
                final double tmp = wxi * xi / di;
                if (FastMath.abs(tmp) > Precision.EPSILON) {
                    w = (di * w) / dpi;
                }
            } else {
                dpi = wxi * xi;
                w = 0.0;
            }
            d[i] = dpi;
            // Update the remainder of row i of R and fold the rotated residual
            // of the observation into the rows below.
            for (int k = i + 1; k < nvars; k++) {
                xk = x[k];
                x[k] = smartAdd(xk, -xi * r[nextr]);
                if (di != 0.0) {
                    r[nextr] = smartAdd(di * r[nextr], (_w * xi) * xk) / dpi;
                } else {
                    r[nextr] = xk / xi;
                }
                ++nextr;
            }
            xk = y;
            y = smartAdd(xk, -xi * rhs[i]);
            if (di != 0.0) {
                rhs[i] = smartAdd(di * rhs[i], wxi * xk) / dpi;
            } else {
                rhs[i] = xk / xi;
            }
        }
        // Whatever weight survives the sweep contributes to the error sum of squares.
        sserr = smartAdd(sserr, w * y * y);
    }

    /**
     * Adds to number a and b such that the contamination due to
     * numerical smallness of one addend does not corrupt the sum.
     *
     * @param nreq how many of the regressors to include (either in canonical
     * order, or in the current reordered state)
     * @return an array with the estimated slope coefficients
     * @throws ModelSpecificationException if {@code nreq} is less than 1
     * or greater than the number of independent variables
     */
    private double[] regcf(int nreq) throws ModelSpecificationException {
        int nextr;
        if (nreq < 1) {
            throw new ModelSpecificationException(LocalizedFormats.NO_REGRESSORS);
        }
        if (nreq > this.nvars) {
            throw new ModelSpecificationException(
                    LocalizedFormats.TOO_MANY_REGRESSORS, nreq, this.nvars);
        }
        if (!this.tol_set) {
            tolset();
        }
        final double[] ret = new double[nreq];
        boolean rankProblem = false;
        // Back-substitution through the packed triangular system, from the
        // last requested regressor up to the first.
        for (int i = nreq - 1; i > -1; i--) {
            if (FastMath.sqrt(d[i]) < tol[i]) {
                // Effectively-zero diagonal: coefficient indeterminate.
                ret[i] = 0.0;
                d[i] = 0.0;
                rankProblem = true;
            } else {
                ret[i] = rhs[i];
                // Index of r[i][i+1] in the packed row-wise triangle.
                nextr = i * (nvars + nvars - i - 1) / 2;
                for (int j = i + 1; j < nreq; j++) {
                    ret[i] = smartAdd(ret[i], -r[nextr] * ret[j]);
                    ++nextr;
                }
            }
        }
        if (rankProblem) {
            // Flag coefficients of linearly dependent columns as NaN.
            for (int i = 0; i < nreq; i++) {
                if (this.lindep[i]) {
                    ret[i] = Double.NaN;
                }
            }
        }
        return ret;
    }

    /**
     * The method which checks for singularities and then eliminates the offending
     * columns.
     */
    private void singcheck() {
        int pos;
        for (int i = 0; i < nvars; i++) {
            work_sing[i] = FastMath.sqrt(d[i]);
        }
        for (int col = 0; col < nvars; col++) {
            // Set elements within R to zero if they are less than tol(col) in
            // absolute value after being scaled by the square root of their row
            // multiplier
            final double temp = tol[col];
            pos = col - 1;
            for (int row = 0; row < col - 1; row++) {
                if (FastMath.abs(r[pos]) * work_sing[row] < temp) {
                    r[pos] = 0.0;
                }
                pos += nvars - row - 2;
            }
            // If diagonal element is near zero, set it to zero, set appropriate
            // element of LINDEP, and use INCLUD to augment the projections in
            // the lower rows of the orthogonalization.
            lindep[col] = false;
            if (work_sing[col] < temp) {
                lindep[col] = true;
                if (col < nvars - 1) {
                    // Re-inject the zeroed row's content as a synthetic
                    // observation so the information flows into lower rows.
                    Arrays.fill(x_sing, 0.0);
                    int _pi = col * (nvars + nvars - col - 1) / 2;
                    for (int _xi = col + 1; _xi < nvars; _xi++, _pi++) {
                        x_sing[_xi] = r[_pi];
                        r[_pi] = 0.0;
                    }
                    final double y = rhs[col];
                    final double weight = d[col];
                    d[col] = 0.0;
                    rhs[col] = 0.0;
                    this.include(x_sing, weight, y);
                } else {
                    // Last column: its contribution goes straight to the error SS.
                    sserr += d[col] * rhs[col] * rhs[col];
                }
            }
        }
    }

    /**
     * Calculates the sum of squared errors for the full regression
     * and all subsets in the following manner: <pre>
     * rss[] ={
     * ResidualSumOfSquares_allNvars,
     * ResidualSumOfSquares_FirstNvars-1,
     * ResidualSumOfSquares_FirstNvars-2,
     * ..., ResidualSumOfSquares_FirstVariable} </pre>
     */
    private void ss() {
        double total = sserr;
        rss[nvars - 1] = sserr;
        // Dropping variable i adds its explained sum of squares d[i]*rhs[i]^2
        // back into the residual, giving all nested-subset RSS values in one pass.
        for (int i = nvars - 1; i > 0; i--) {
            total += d[i] * rhs[i] * rhs[i];
            rss[i - 1] = total;
        }
        rss_set = true;
    }

    /**
     * Calculates the cov matrix assuming only the first nreq variables are
     * included in the calculation. The returned array contains a symmetric
     * matrix stored in lower triangular form. The matrix will have
     * ( nreq + 1 ) * nreq / 2 elements. For illustration <pre>
     * cov =
     * {
     *  cov_00,
     *  cov_10, cov_11,
     *  cov_20, cov_21, cov22,
     *  ...
     * @param rinv the storage for the inverse of r
     * @param nreq how many of the regressors to include (either in canonical
     * order, or in the current reordered state)
     */
    private void inverse(double[] rinv, int nreq) {
        // Walk the packed triangle backwards, computing each entry of R^-1 by
        // back-substitution; linearly dependent rows are skipped wholesale.
        int pos = nreq * (nreq - 1) / 2 - 1;
        int pos1 = -1;
        int pos2 = -1;
        double total = 0.0;
        Arrays.fill(rinv, Double.NaN);
        for (int row = nreq - 1; row > 0; --row) {
            if (!this.lindep[row]) {
                final int start = (row - 1) * (nvars + nvars - row) / 2;
                for (int col = nreq; col > row; --col) {
                    pos1 = start;
                    pos2 = pos;
                    total = 0.0;
                    for (int k = row; k < col - 1; k++) {
                        pos2 += nreq - k - 1;
                        if (!this.lindep[k]) {
                            total += -r[pos1] * rinv[pos2];
                        }
                        ++pos1;
                    }
                    rinv[pos] = total - r[pos1];
                    --pos;
                }
            } else {
                pos -= nreq - row;
            }
        }
    }

    /**
     * In the original algorithm only the partial correlations of the regressors
     * is returned to the user. In this implementation, we have <pre>
     * corr =
     * {
     *   corrxx - lower triangular
     *   corrxy - bottom row of the matrix
     * }
     * Replaces subroutines PCORR and COR of:
     * ALGORITHM AS274 APPL. STATIST. (1992) VOL.41, NO. 2 </pre>
     *
     * <p>Calculate partial correlations after the variables in rows
     * 1, 2, ..., IN have been forced into the regression.
     * If IN = 1, and the first row of R represents a constant in the
     * model, then the usual simple correlations are returned.</p>
     *
     * <p>If IN = 0, the value returned in array CORMAT for the correlation
     * of variables Xi & Xj is: <pre>
     * sum ( Xi.Xj ) / Sqrt ( sum (Xi^2) . sum (Xj^2) )</pre></p>
     *
     * <p>On return, array CORMAT contains the upper triangle of the matrix of
     * partial correlations stored by rows, excluding the 1's on the diagonal.
     * e.g. if IN = 2, the consecutive elements returned are:
     * (3,4) (3,5) ... (3,ncol), (4,5) (4,6) ... (4,ncol), etc.
     * Array YCORR stores the partial correlations with the Y-variable
     * starting with YCORR(IN+1) = partial correlation with the variable in
     * position (IN+1). </p>
     *
     * @param in how many of the regressors to include (either in canonical
     * order, or in the current reordered state)
     * @return an array with the partial correlations of the remainder of
     * regressors with each other and the regressand, in lower triangular form
     */
    public double[] getPartialCorrelations(int in) {
        final double[] output = new double[(nvars - in + 1) * (nvars - in) / 2];
        int pos;
        int pos1;
        int pos2;
        // Offsets translate absolute column indices into the compact rms/work arrays.
        final int rms_off = -in;
        final int wrk_off = -(in + 1);
        final double[] rms = new double[nvars - in];
        final double[] work = new double[nvars - in - 1];
        double sumxx;
        double sumxy;
        double sumyy;
        final int offXX = (nvars - in) * (nvars - in - 1) / 2;
        // NOTE(review): the allocations above run before this range check, so an
        // out-of-range 'in' can still throw NegativeArraySizeException rather
        // than returning null — confirm intended behavior.
        if (in < -1 || in >= nvars) {
            return null;
        }
        final int nvm = nvars - 1;
        final int base_pos = r.length - (nvm - in) * (nvm - in + 1) / 2;
        if (d[in] > 0.0) {
            rms[in + rms_off] = 1.0 / FastMath.sqrt(d[in]);
        }
        // Reciprocal root of each column's (partial) sum of squares.
        for (int col = in + 1; col < nvars; col++) {
            pos = base_pos + col - 1 - in;
            sumxx = d[col];
            for (int row = in; row < col; row++) {
                sumxx += d[row] * r[pos] * r[pos];
                pos += nvars - row - 2;
            }
            if (sumxx > 0.0) {
                rms[col + rms_off] = 1.0 / FastMath.sqrt(sumxx);
            } else {
                rms[col + rms_off] = 0.0;
            }
        }
        sumyy = sserr;
        for (int row = in; row < nvars; row++) {
            sumyy += d[row] * rhs[row] * rhs[row];
        }
        if (sumyy > 0.0) {
            sumyy = 1.0 / FastMath.sqrt(sumyy);
        }
        pos = 0;
        // Accumulate cross products and scale to correlations.
        for (int col1 = in; col1 < nvars; col1++) {
            sumxy = 0.0;
            Arrays.fill(work, 0.0);
            pos1 = base_pos + col1 - in - 1;
            for (int row = in; row < col1; row++) {
                pos2 = pos1 + 1;
                for (int col2 = col1 + 1; col2 < nvars; col2++) {
                    work[col2 + wrk_off] += d[row] * r[pos1] * r[pos2];
                    pos2++;
                }
                sumxy += d[row] * r[pos1] * rhs[row];
                pos1 += nvars - row - 2;
            }
            pos2 = pos1 + 1;
            for (int col2 = col1 + 1; col2 < nvars; col2++) {
                work[col2 + wrk_off] += d[col1] * r[pos2];
                ++pos2;
                output[ (col2 - 1 - in) * (col2 - in) / 2 + col1 - in] =
                        work[col2 + wrk_off] * rms[col1 + rms_off] * rms[col2 + rms_off];
                ++pos;
            }
            sumxy += d[col1] * rhs[col1];
            output[col1 + rms_off + offXX] = sumxy * rms[col1 + rms_off] * sumyy;
        }

        return output;
    }

    /**
     * ALGORITHM AS274 APPL. STATIST. (1992) VOL.41, NO. 2.
     * Move variable from position FROM to position TO in an
     * orthogonal reduction produced by AS75.1.
     *
     * @param from initial position
     * @param to destination
     */
    private void vmove(int from, int to) {
        double d1;
        double d2;
        double X;
        double d1new;
        double d2new;
        double cbar;
        double sbar;
        double Y;
        int first;
        int inc;
        int m1;
        int m2;
        int mp1;
        int pos;
        boolean bSkipTo40 = false;
        if (from == to) {
            return;
        }
        if (!this.rss_set) {
            ss();
        }
        // The move is realized as a sequence of adjacent swaps, walking either
        // forward or backward depending on the direction of the move.
        int count = 0;
        if (from < to) {
            first = from;
            inc = 1;
            count = to - from;
        } else {
            first = from - 1;
            inc = -1;
            count = from - to;
        }

        int m = first;
        int idx = 0;
        while (idx < count) {
            // Packed-triangle indices of row m and row m+1.
            m1 = m * (nvars + nvars - m - 1) / 2;
            m2 = m1 + nvars - m - 1;
            mp1 = m + 1;

            d1 = d[m];
            d2 = d[mp1];
            // Special cases.
+ if (d1 > this.epsilon || d2 > this.epsilon) { + X = r[m1]; + if (FastMath.abs(X) * FastMath.sqrt(d1) < tol[mp1]) { + X = 0.0; + } + if (d1 < this.epsilon || FastMath.abs(X) < this.epsilon) { + d[m] = d2; + d[mp1] = d1; + r[m1] = 0.0; + for (int col = m + 2; col < nvars; col++) { + ++m1; + X = r[m1]; + r[m1] = r[m2]; + r[m2] = X; + ++m2; + } + X = rhs[m]; + rhs[m] = rhs[mp1]; + rhs[mp1] = X; + bSkipTo40 = true; + //break; + } else if (d2 < this.epsilon) { + d[m] = d1 * X * X; + r[m1] = 1.0 / X; + for (int _i = m1 + 1; _i < m1 + nvars - m - 1; _i++) { + r[_i] /= X; + } + rhs[m] /= X; + bSkipTo40 = true; + //break; + } + if (!bSkipTo40) { + d1new = d2 + d1 * X * X; + cbar = d2 / d1new; + sbar = X * d1 / d1new; + d2new = d1 * cbar; + d[m] = d1new; + d[mp1] = d2new; + r[m1] = sbar; + for (int col = m + 2; col < nvars; col++) { + ++m1; + Y = r[m1]; + r[m1] = cbar * r[m2] + sbar * Y; + r[m2] = Y - X * r[m2]; + ++m2; + } + Y = rhs[m]; + rhs[m] = cbar * rhs[mp1] + sbar * Y; + rhs[mp1] = Y - X * rhs[mp1]; + } + } + if (m > 0) { + pos = m; + for (int row = 0; row < m; row++) { + X = r[pos]; + r[pos] = r[pos - 1]; + r[pos - 1] = X; + pos += nvars - row - 2; + } + } + // Adjust variable order (VORDER), the tolerances (TOL) and + // the vector of residual sums of squares (RSS). + m1 = vorder[m]; + vorder[m] = vorder[mp1]; + vorder[mp1] = m1; + X = tol[m]; + tol[m] = tol[mp1]; + tol[mp1] = X; + rss[m] = rss[mp1] + d[mp1] * rhs[mp1] * rhs[mp1]; + + m += inc; + ++idx; + } + } + + /** + * ALGORITHM AS274 APPL. STATIST. (1992) VOL.41, NO. 2 + * + * <p> Re-order the variables in an orthogonal reduction produced by + * AS75.1 so that the N variables in LIST start at position POS1, + * though will not necessarily be in the same order as in LIST. + * Any variables in VORDER before position POS1 are not moved. + * Auxiliary routine called: VMOVE. 
 </p>
     *
     * <p>This internal method reorders the regressors.</p>
     *
     * @param list the regressors to move
     * @param pos1 where the list will be placed
     * @return -1 error, 0 everything ok
     */
    private int reorderRegressors(int[] list, int pos1) {
        int next;
        int i;
        int l;
        // Reject an empty list or one that cannot fit between pos1 and nvars.
        if (list.length < 1 || list.length > nvars + 1 - pos1) {
            return -1;
        }
        // 'next' is the destination slot for the next matched regressor.
        next = pos1;
        i = pos1;
        while (i < nvars) {
            l = vorder[i];
            for (int j = 0; j < list.length; j++) {
                if (l == list[j] && i > next) {
                    // A requested regressor lies beyond the target zone:
                    // rotate it down into slot 'next' via adjacent swaps.
                    this.vmove(i, next);
                    ++next;
                    if (next >= list.length + pos1) {
                        // All requested regressors are now in the leading block.
                        return 0;
                    } else {
                        break;
                    }
                }
            }
            ++i;
        }
        return 0;
    }

    /**
     * Gets the diagonal of the Hat matrix also known as the leverage matrix.
     *
     * @param row_data a single observation's regressor values (without the
     * intercept entry; when the model has an intercept a leading 1.0 is
     * prepended internally)
     * @return the diagonal element of the hatmatrix
     */
    public double getDiagonalOfHatMatrix(double[] row_data) {
        // wk is a forward-substitution workspace: wk[col] holds the col-th
        // component of the observation expressed in the orthogonal basis.
        double[] wk = new double[this.nvars];
        int pos;
        double total;

        if (row_data.length > nvars) {
            return Double.NaN;
        }
        double[] xrow;
        if (this.hasIntercept) {
            xrow = new double[row_data.length + 1];
            xrow[0] = 1.0;
            System.arraycopy(row_data, 0, xrow, 1, row_data.length);
        } else {
            xrow = row_data;
        }
        double hii = 0.0;
        for (int col = 0; col < xrow.length; col++) {
            if (FastMath.sqrt(d[col]) < tol[col]) {
                // Column is numerically linearly dependent: contributes nothing.
                wk[col] = 0.0;
            } else {
                // Forward substitution against the packed triangular factor r;
                // 'pos' walks down column 'col' of r, row by row.
                pos = col - 1;
                total = xrow[col];
                for (int row = 0; row < col; row++) {
                    total = smartAdd(total, -wk[row] * r[pos]);
                    pos += nvars - row - 2;
                }
                wk[col] = total;
                // Accumulate the leverage h_ii = sum(total^2 / d[col]).
                hii = smartAdd(hii, (total * total) / d[col]);
            }
        }
        return hii;
    }

    /**
     * Gets the order of the regressors, useful if some type of reordering
     * has been called. Calling regress with int[]{} args will trigger
     * a reordering.
 *
     * @return int[] with the current order of the regressors
     */
    public int[] getOrderOfRegressors(){
        // Defensive copy so callers cannot mutate the internal permutation.
        return MathArrays.copyOf(vorder);
    }

    /**
     * Conducts a regression on the data in the model, using all regressors.
     *
     * @return RegressionResults the structure holding all regression results
     * @exception ModelSpecificationException - thrown if number of observations is
     * less than the number of variables
     */
    public RegressionResults regress() throws ModelSpecificationException {
        return regress(this.nvars);
    }

    /**
     * Conducts a regression on the data in the model, using a subset of regressors.
     *
     * @param numberOfRegressors many of the regressors to include (either in canonical
     * order, or in the current reordered state)
     * @return RegressionResults the structure holding all regression results
     * @exception ModelSpecificationException - thrown if number of observations is
     * less than the number of variables or number of regressors requested
     * is greater than the regressors in the model
     */
    public RegressionResults regress(int numberOfRegressors) throws ModelSpecificationException {
        if (this.nobs <= numberOfRegressors) {
            throw new ModelSpecificationException(
                    LocalizedFormats.NOT_ENOUGH_DATA_FOR_NUMBER_OF_PREDICTORS,
                    this.nobs, numberOfRegressors);
        }
        if( numberOfRegressors > this.nvars ){
            throw new ModelSpecificationException(
                    LocalizedFormats.TOO_MANY_REGRESSORS, numberOfRegressors, this.nvars);
        }

        // Refresh tolerances and flag numerically singular columns before solving.
        tolset();
        singcheck();

        double[] beta = this.regcf(numberOfRegressors);

        ss();

        // Covariance of the estimates, in symmetric compressed (packed) storage.
        double[] cov = this.cov(numberOfRegressors);

        // Rank = number of columns not flagged linearly dependent.
        int rnk = 0;
        for (int i = 0; i < this.lindep.length; i++) {
            if (!this.lindep[i]) {
                ++rnk;
            }
        }

        // Results must be reported in canonical variable order; detect whether
        // a previous reordering left the leading block permuted.
        boolean needsReorder = false;
        for (int i = 0; i < numberOfRegressors; i++) {
            if (this.vorder[i] != i) {
                needsReorder = true;
                break;
            }
        }
        if (!needsReorder) {
            return new RegressionResults(
                    beta, new double[][]{cov}, true, this.nobs, rnk,
                    this.sumy, this.sumsqy, this.sserr, this.hasIntercept, false);
        } else {
            double[] betaNew = new double[beta.length];
            double[] covNew = new double[cov.length];

            // newIndices[i] = position in the current (reordered) solution of
            // the variable whose canonical index is i.
            // NOTE(review): betaNew/newIndices are sized beta.length
            // (= numberOfRegressors); if a prior reorder placed a variable with
            // canonical index >= numberOfRegressors into the leading block,
            // betaNew[i] would index out of bounds — confirm against callers.
            int[] newIndices = new int[beta.length];
            for (int i = 0; i < nvars; i++) {
                for (int j = 0; j < numberOfRegressors; j++) {
                    if (this.vorder[j] == i) {
                        betaNew[i] = beta[j];
                        newIndices[i] = j;
                    }
                }
            }

            // Remap the packed lower-triangular covariance to canonical order:
            // element (i,j) of the packed array lives at i*(i+1)/2 + j, i >= j.
            int idx1 = 0;
            int idx2;
            int _i;
            int _j;
            for (int i = 0; i < beta.length; i++) {
                _i = newIndices[i];
                for (int j = 0; j <= i; j++, idx1++) {
                    _j = newIndices[j];
                    if (_i > _j) {
                        idx2 = _i * (_i + 1) / 2 + _j;
                    } else {
                        idx2 = _j * (_j + 1) / 2 + _i;
                    }
                    covNew[idx1] = cov[idx2];
                }
            }
            return new RegressionResults(
                    betaNew, new double[][]{covNew}, true, this.nobs, rnk,
                    this.sumy, this.sumsqy, this.sserr, this.hasIntercept, false);
        }
    }

    /**
     * Conducts a regression on the data in the model, using regressors in array
     * Calling this method will change the internal order of the regressors
     * and care is required in interpreting the hatmatrix.
+ * + * @param variablesToInclude array of variables to include in regression + * @return RegressionResults the structure holding all regression results + * @exception ModelSpecificationException - thrown if number of observations is + * less than the number of variables, the number of regressors requested + * is greater than the regressors in the model or a regressor index in + * regressor array does not exist + */ + public RegressionResults regress(int[] variablesToInclude) throws ModelSpecificationException { + if (variablesToInclude.length > this.nvars) { + throw new ModelSpecificationException( + LocalizedFormats.TOO_MANY_REGRESSORS, variablesToInclude.length, this.nvars); + } + if (this.nobs <= this.nvars) { + throw new ModelSpecificationException( + LocalizedFormats.NOT_ENOUGH_DATA_FOR_NUMBER_OF_PREDICTORS, + this.nobs, this.nvars); + } + Arrays.sort(variablesToInclude); + int iExclude = 0; + for (int i = 0; i < variablesToInclude.length; i++) { + if (i >= this.nvars) { + throw new ModelSpecificationException( + LocalizedFormats.INDEX_LARGER_THAN_MAX, i, this.nvars); + } + if (i > 0 && variablesToInclude[i] == variablesToInclude[i - 1]) { + variablesToInclude[i] = -1; + ++iExclude; + } + } + int[] series; + if (iExclude > 0) { + int j = 0; + series = new int[variablesToInclude.length - iExclude]; + for (int i = 0; i < variablesToInclude.length; i++) { + if (variablesToInclude[i] > -1) { + series[j] = variablesToInclude[i]; + ++j; + } + } + } else { + series = variablesToInclude; + } + + reorderRegressors(series, 0); + tolset(); + singcheck(); + + double[] beta = this.regcf(series.length); + + ss(); + + double[] cov = this.cov(series.length); + + int rnk = 0; + for (int i = 0; i < this.lindep.length; i++) { + if (!this.lindep[i]) { + ++rnk; + } + } + + boolean needsReorder = false; + for (int i = 0; i < this.nvars; i++) { + if (this.vorder[i] != series[i]) { + needsReorder = true; + break; + } + } + if (!needsReorder) { + return new RegressionResults( + beta, 
new double[][]{cov}, true, this.nobs, rnk, + this.sumy, this.sumsqy, this.sserr, this.hasIntercept, false); + } else { + double[] betaNew = new double[beta.length]; + int[] newIndices = new int[beta.length]; + for (int i = 0; i < series.length; i++) { + for (int j = 0; j < this.vorder.length; j++) { + if (this.vorder[j] == series[i]) { + betaNew[i] = beta[ j]; + newIndices[i] = j; + } + } + } + double[] covNew = new double[cov.length]; + int idx1 = 0; + int idx2; + int _i; + int _j; + for (int i = 0; i < beta.length; i++) { + _i = newIndices[i]; + for (int j = 0; j <= i; j++, idx1++) { + _j = newIndices[j]; + if (_i > _j) { + idx2 = _i * (_i + 1) / 2 + _j; + } else { + idx2 = _j * (_j + 1) / 2 + _i; + } + covNew[idx1] = cov[idx2]; + } + } + return new RegressionResults( + betaNew, new double[][]{covNew}, true, this.nobs, rnk, + this.sumy, this.sumsqy, this.sserr, this.hasIntercept, false); + } + } +} diff --git a/src/main/java/org/apache/commons/math3/stat/regression/ModelSpecificationException.java b/src/main/java/org/apache/commons/math3/stat/regression/ModelSpecificationException.java new file mode 100644 index 0000000..f3804db --- /dev/null +++ b/src/main/java/org/apache/commons/math3/stat/regression/ModelSpecificationException.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.commons.math3.stat.regression;

import org.apache.commons.math3.exception.MathIllegalArgumentException;
import org.apache.commons.math3.exception.util.Localizable;

/**
 * Exception thrown when a regression model is not correctly specified.
 *
 * @since 3.0
 */
public class ModelSpecificationException extends MathIllegalArgumentException {
    /** Serializable version Id. */
    private static final long serialVersionUID = 4206514456095401070L;

    /**
     * Construct an exception with the given message pattern and arguments.
     *
     * @param pattern message pattern describing the specification error.
     *
     * @param args arguments.
     */
    public ModelSpecificationException(Localizable pattern,
                                       Object ... args) {
        super(pattern, args);
    }
}
diff --git a/src/main/java/org/apache/commons/math3/stat/regression/MultipleLinearRegression.java b/src/main/java/org/apache/commons/math3/stat/regression/MultipleLinearRegression.java
new file mode 100644
index 0000000..866214f
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/stat/regression/MultipleLinearRegression.java
@@ -0,0 +1,69 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.commons.math3.stat.regression;

/**
 * The multiple linear regression can be represented in matrix-notation.
 * <pre>
 *  y=X*b+u
 * </pre>
 * where y is an <code>n-vector</code> <b>regressand</b>, X is a <code>[n,k]</code> matrix whose <code>k</code> columns are called
 * <b>regressors</b>, b is <code>k-vector</code> of <b>regression parameters</b> and <code>u</code> is an <code>n-vector</code>
 * of <b>error terms</b> or <b>residuals</b>.
 *
 * The notation is quite standard in literature,
 * cf eg <a href="http://www.econ.queensu.ca/ETM">Davidson and MacKinnon, Econometrics Theory and Methods, 2004</a>.
 * @since 2.0
 */
public interface MultipleLinearRegression {

    /**
     * Estimates the regression parameters b.
     *
     * @return The [k,1] array representing b
     */
    double[] estimateRegressionParameters();

    /**
     * Estimates the variance of the regression parameters, ie Var(b).
     *
     * @return The [k,k] array representing the variance of b
     */
    double[][] estimateRegressionParametersVariance();

    /**
     * Estimates the residuals, ie u = y - X*b.
     *
     * @return The [n,1] array representing the residuals
     */
    double[] estimateResiduals();

    /**
     * Returns the variance of the regressand, ie Var(y).
     *
     * @return The double representing the variance of y
     */
    double estimateRegressandVariance();

    /**
     * Returns the standard errors of the regression parameters.
     *
     * @return standard errors of estimated regression parameters
     */
    double[] estimateRegressionParametersStandardErrors();

}
diff --git a/src/main/java/org/apache/commons/math3/stat/regression/OLSMultipleLinearRegression.java b/src/main/java/org/apache/commons/math3/stat/regression/OLSMultipleLinearRegression.java
new file mode 100644
index 0000000..7fff940
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/stat/regression/OLSMultipleLinearRegression.java
@@ -0,0 +1,285 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.commons.math3.stat.regression;

import org.apache.commons.math3.exception.MathIllegalArgumentException;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.LUDecomposition;
import org.apache.commons.math3.linear.QRDecomposition;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
import org.apache.commons.math3.stat.StatUtils;
import org.apache.commons.math3.stat.descriptive.moment.SecondMoment;

/**
 * <p>Implements ordinary least squares (OLS) to estimate the parameters of a
 * multiple linear regression model.</p>
 *
 * <p>The regression coefficients, <code>b</code>, satisfy the normal equations:
 * <pre><code> X<sup>T</sup> X b = X<sup>T</sup> y </code></pre></p>
 *
 * <p>To solve the normal equations, this implementation uses QR decomposition
 * of the <code>X</code> matrix. (See {@link QRDecomposition} for details on the
 * decomposition algorithm.) The <code>X</code> matrix, also known as the <i>design matrix,</i>
 * has rows corresponding to sample observations and columns corresponding to independent
 * variables.  When the model is estimated using an intercept term (i.e. when
 * {@link #isNoIntercept() isNoIntercept} is false as it is by default), the <code>X</code>
 * matrix includes an initial column identically equal to 1.  We solve the normal equations
 * as follows:
 * <pre><code> X<sup>T</sup>X b = X<sup>T</sup> y
 * (QR)<sup>T</sup> (QR) b = (QR)<sup>T</sup>y
 * R<sup>T</sup> (Q<sup>T</sup>Q) R b = R<sup>T</sup> Q<sup>T</sup> y
 * R<sup>T</sup> R b = R<sup>T</sup> Q<sup>T</sup> y
 * (R<sup>T</sup>)<sup>-1</sup> R<sup>T</sup> R b = (R<sup>T</sup>)<sup>-1</sup> R<sup>T</sup> Q<sup>T</sup> y
 * R b = Q<sup>T</sup> y </code></pre></p>
 *
 * <p>Given <code>Q</code> and <code>R</code>, the last equation is solved by back-substitution.</p>
 *
 * @since 2.0
 */
public class OLSMultipleLinearRegression extends AbstractMultipleLinearRegression {

    /** Cached QR decomposition of X matrix */
    private QRDecomposition qr = null;

    /** Singularity threshold for QR decomposition */
    private final double threshold;

    /**
     * Create an empty OLSMultipleLinearRegression instance.
     */
    public OLSMultipleLinearRegression() {
        this(0d);
    }

    /**
     * Create an empty OLSMultipleLinearRegression instance, using the given
     * singularity threshold for the QR decomposition.
     *
     * @param threshold the singularity threshold
     * @since 3.3
     */
    public OLSMultipleLinearRegression(final double threshold) {
        this.threshold = threshold;
    }

    /**
     * Loads model x and y sample data, overriding any previous sample.
     *
     * Computes and caches QR decomposition of the X matrix.
     * @param y the [n,1] array representing the y sample
     * @param x the [n,k] array representing the x sample
     * @throws MathIllegalArgumentException if the x and y array data are not
     *             compatible for the regression
     */
    public void newSampleData(double[] y, double[][] x) throws MathIllegalArgumentException {
        validateSampleData(x, y);
        newYSampleData(y);
        // newXSampleData is overridden below and refreshes the cached QR.
        newXSampleData(x);
    }

    /**
     * {@inheritDoc}
     * <p>This implementation computes and caches the QR decomposition of the X matrix.</p>
     */
    @Override
    public void newSampleData(double[] data, int nobs, int nvars) {
        super.newSampleData(data, nobs, nvars);
        qr = new QRDecomposition(getX(), threshold);
    }

    /**
     * <p>Compute the "hat" matrix.
     * </p>
     * <p>The hat matrix is defined in terms of the design matrix X
     *  by X(X<sup>T</sup>X)<sup>-1</sup>X<sup>T</sup>
     * </p>
     * <p>The implementation here uses the QR decomposition to compute the
     * hat matrix as Q I<sub>p</sub>Q<sup>T</sup> where I<sub>p</sub> is the
     * p-dimensional identity matrix augmented by 0's.  This computational
     * formula is from "The Hat Matrix in Regression and ANOVA",
     * David C. Hoaglin and Roy E. Welsch,
     * <i>The American Statistician</i>, Vol. 32, No. 1 (Feb., 1978), pp. 17-22.
     * </p>
     * <p>Data for the model must have been successfully loaded using one of
     * the {@code newSampleData} methods before invoking this method; otherwise
     * a {@code NullPointerException} will be thrown.</p>
     *
     * @return the hat matrix
     * @throws NullPointerException unless method {@code newSampleData} has been
     * called beforehand.
     */
    public RealMatrix calculateHat() {
        // Create augmented identity matrix
        RealMatrix Q = qr.getQ();
        final int p = qr.getR().getColumnDimension();
        final int n = Q.getColumnDimension();
        // No try-catch or advertised NotStrictlyPositiveException - NPE above if n < 3
        Array2DRowRealMatrix augI = new Array2DRowRealMatrix(n, n);
        double[][] augIData = augI.getDataRef();
        // I_p augmented by zeros: ones on the first p diagonal entries only.
        for (int i = 0; i < n; i++) {
            for (int j =0; j < n; j++) {
                if (i == j && i < p) {
                    augIData[i][j] = 1d;
                } else {
                    augIData[i][j] = 0d;
                }
            }
        }

        // Compute and return Hat matrix
        // No DME advertised - args valid if we get here
        return Q.multiply(augI).multiply(Q.transpose());
    }

    /**
     * <p>Returns the sum of squared deviations of Y from its mean.</p>
     *
     * <p>If the model has no intercept term, <code>0</code> is used for the
     * mean of Y - i.e., what is returned is the sum of the squared Y values.</p>
     *
     * <p>The value returned by this method is the SSTO value used in
     * the {@link #calculateRSquared() R-squared} computation.</p>
     *
     * @return SSTO - the total sum of squares
     * @throws NullPointerException if the sample has not been set
     * @see #isNoIntercept()
     * @since 2.2
     */
    public double calculateTotalSumOfSquares() {
        if (isNoIntercept()) {
            return StatUtils.sumSq(getY().toArray());
        } else {
            // SecondMoment accumulates sum((y_i - ybar)^2) in one pass.
            return new SecondMoment().evaluate(getY().toArray());
        }
    }

    /**
     * Returns the sum of squared residuals.
     *
     * @return residual sum of squares
     * @since 2.2
     * @throws org.apache.commons.math3.linear.SingularMatrixException if the design matrix is singular
     * @throws NullPointerException if the data for the model have not been loaded
     */
    public double calculateResidualSumOfSquares() {
        final RealVector residuals = calculateResiduals();
        // No advertised DME, args are valid
        return residuals.dotProduct(residuals);
    }

    /**
     * Returns the R-Squared statistic, defined by the formula <pre>
     * R<sup>2</sup> = 1 - SSR / SSTO
     * </pre>
     * where SSR is the {@link #calculateResidualSumOfSquares() sum of squared residuals}
     * and SSTO is the {@link #calculateTotalSumOfSquares() total sum of squares}
     *
     * <p>If there is no variance in y, i.e., SSTO = 0, NaN is returned.</p>
     *
     * @return R-square statistic
     * @throws NullPointerException if the sample has not been set
     * @throws org.apache.commons.math3.linear.SingularMatrixException if the design matrix is singular
     * @since 2.2
     */
    public double calculateRSquared() {
        return 1 - calculateResidualSumOfSquares() / calculateTotalSumOfSquares();
    }

    /**
     * <p>Returns the adjusted R-squared statistic, defined by the formula <pre>
     * R<sup>2</sup><sub>adj</sub> = 1 - [SSR (n - 1)] / [SSTO (n - p)]
     * </pre>
     * where SSR is the {@link #calculateResidualSumOfSquares() sum of squared residuals},
     * SSTO is the {@link #calculateTotalSumOfSquares() total sum of squares}, n is the number
     * of observations and p is the number of parameters estimated (including the intercept).</p>
     *
     * <p>If the regression is estimated without an intercept term, what is returned is <pre>
     * <code> 1 - (1 - {@link #calculateRSquared()}) * (n / (n - p)) </code>
     * </pre></p>
     *
     * <p>If there is no variance in y, i.e., SSTO = 0, NaN is returned.</p>
     *
     * @return adjusted R-Squared statistic
     * @throws NullPointerException if the sample has not been set
     * @throws org.apache.commons.math3.linear.SingularMatrixException if the design matrix is singular
     * @see #isNoIntercept()
     * @since 2.2
     */
    public double calculateAdjustedRSquared() {
        final double n = getX().getRowDimension();
        if (isNoIntercept()) {
            return 1 - (1 - calculateRSquared()) * (n / (n - getX().getColumnDimension()));
        } else {
            return 1 - (calculateResidualSumOfSquares() * (n - 1)) /
                (calculateTotalSumOfSquares() * (n - getX().getColumnDimension()));
        }
    }

    /**
     * {@inheritDoc}
     * <p>This implementation computes and caches the QR decomposition of the X matrix
     * once it is successfully loaded.</p>
     */
    @Override
    protected void newXSampleData(double[][] x) {
        super.newXSampleData(x);
        qr = new QRDecomposition(getX(), threshold);
    }

    /**
     * Calculates the regression coefficients using OLS.
     *
     * <p>Data for the model must have been successfully loaded using one of
     * the {@code newSampleData} methods before invoking this method; otherwise
     * a {@code NullPointerException} will be thrown.</p>
     *
     * @return beta
     * @throws org.apache.commons.math3.linear.SingularMatrixException if the design matrix is singular
     * @throws NullPointerException if the data for the model have not been loaded
     */
    @Override
    protected RealVector calculateBeta() {
        // Solves R b = Q^T y by back-substitution (see class javadoc).
        return qr.getSolver().solve(getY());
    }

    /**
     * <p>Calculates the variance-covariance matrix of the regression parameters.
     * </p>
     * <p>Var(b) = (X<sup>T</sup>X)<sup>-1</sup>
     * </p>
     * <p>Uses QR decomposition to reduce (X<sup>T</sup>X)<sup>-1</sup>
     * to (R<sup>T</sup>R)<sup>-1</sup>, with only the top p rows of
     * R included, where p = the length of the beta vector.</p>
     *
     * <p>Data for the model must have been successfully loaded using one of
     * the {@code newSampleData} methods before invoking this method; otherwise
     * a {@code NullPointerException} will be thrown.</p>
     *
     * @return The beta variance-covariance matrix
     * @throws org.apache.commons.math3.linear.SingularMatrixException if the design matrix is singular
     * @throws NullPointerException if the data for the model have not been loaded
     */
    @Override
    protected RealMatrix calculateBetaVariance() {
        int p = getX().getColumnDimension();
        RealMatrix Raug = qr.getR().getSubMatrix(0, p - 1 , 0, p - 1);
        RealMatrix Rinv = new LUDecomposition(Raug).getSolver().getInverse();
        // (R^T R)^-1 = R^-1 (R^-1)^T
        return Rinv.multiply(Rinv.transpose());
    }

}
diff --git a/src/main/java/org/apache/commons/math3/stat/regression/RegressionResults.java b/src/main/java/org/apache/commons/math3/stat/regression/RegressionResults.java
new file mode 100644
index 0000000..70faeac
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/stat/regression/RegressionResults.java
@@ -0,0 +1,421 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.
  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.commons.math3.stat.regression;

import java.io.Serializable;
import java.util.Arrays;
import org.apache.commons.math3.util.FastMath;
import org.apache.commons.math3.util.MathArrays;
import org.apache.commons.math3.exception.OutOfRangeException;

/**
 * Results of a Multiple Linear Regression model fit.
 *
 * @since 3.0
 */
public class RegressionResults implements Serializable {

    /** INDEX of Sum of Squared Errors */
    private static final int SSE_IDX = 0;
    /** INDEX of Sum of Squares of Model */
    private static final int SST_IDX = 1;
    /** INDEX of R-Squared of regression */
    private static final int RSQ_IDX = 2;
    /** INDEX of Mean Squared Error */
    private static final int MSE_IDX = 3;
    /** INDEX of Adjusted R Squared */
    private static final int ADJRSQ_IDX = 4;
    /** UID */
    private static final long serialVersionUID = 1l;
    /** regression slope parameters */
    private final double[] parameters;
    /** variance covariance matrix of parameters;
     * either square, or a single row in symmetric compressed (packed) storage
     * depending on isSymmetricVCD */
    private final double[][] varCovData;
    /** boolean flag for variance covariance matrix in symm compressed storage */
    private final boolean isSymmetricVCD;
    /** rank of the solution */
    @SuppressWarnings("unused")
    private final int rank;
    /** number of observations on which results are based */
    private final long nobs;
    /** boolean flag indicator of whether a constant was included*/
    private final boolean containsConstant;
    /** array storing global results, SSE, MSE, RSQ, adjRSQ */
    private final double[] globalFitInfo;

    /**
     * Set the default
 constructor to private access
     * to prevent inadvertent instantiation
     */
    @SuppressWarnings("unused")
    private RegressionResults() {
        this.parameters = null;
        this.varCovData = null;
        this.rank = -1;
        this.nobs = -1;
        this.containsConstant = false;
        this.isSymmetricVCD = false;
        this.globalFitInfo = null;
    }

    /**
     * Constructor for Regression Results.
     *
     * @param parameters a double array with the regression slope estimates
     * @param varcov the variance covariance matrix, stored either in a square matrix
     * or as a compressed
     * @param isSymmetricCompressed a flag which denotes that the variance covariance
     * matrix is in symmetric compressed format
     * @param nobs the number of observations of the regression estimation
     * @param rank the number of independent variables in the regression
     * @param sumy the sum of the dependent variable
     * @param sumysq the sum of the squared dependent variable
     * @param sse sum of squared errors
     * @param containsConstant true model has constant,  false model does not have constant
     * @param copyData if true a deep copy of all input data is made, if false only references
     * are copied and the RegressionResults become mutable
     */
    public RegressionResults(
            final double[] parameters, final double[][] varcov,
            final boolean isSymmetricCompressed,
            final long nobs, final int rank,
            final double sumy, final double sumysq, final double sse,
            final boolean containsConstant,
            final boolean copyData) {
        if (copyData) {
            this.parameters = MathArrays.copyOf(parameters);
            this.varCovData = new double[varcov.length][];
            for (int i = 0; i < varcov.length; i++) {
                this.varCovData[i] = MathArrays.copyOf(varcov[i]);
            }
        } else {
            this.parameters = parameters;
            this.varCovData = varcov;
        }
        this.isSymmetricVCD = isSymmetricCompressed;
        this.nobs = nobs;
        this.rank = rank;
        this.containsConstant = containsConstant;
        this.globalFitInfo = new double[5];
        Arrays.fill(this.globalFitInfo, Double.NaN);

        // SST: total sum of squares, centered about the mean only when the
        // model includes a constant term.
        if (rank > 0) {
            this.globalFitInfo[SST_IDX] = containsConstant ?
                    (sumysq - sumy * sumy / nobs) : sumysq;
        }

        // NOTE(review): when nobs == rank the MSE denominator is zero and the
        // stored value is +/-Infinity or NaN — confirm callers guard this.
        this.globalFitInfo[SSE_IDX] = sse;
        this.globalFitInfo[MSE_IDX] = this.globalFitInfo[SSE_IDX] /
                (nobs - rank);
        this.globalFitInfo[RSQ_IDX] = 1.0 -
                this.globalFitInfo[SSE_IDX] /
                this.globalFitInfo[SST_IDX];

        // Adjusted R^2 uses different formulas with/without an intercept.
        if (!containsConstant) {
            this.globalFitInfo[ADJRSQ_IDX] = 1.0-
                    (1.0 - this.globalFitInfo[RSQ_IDX]) *
                    ( (double) nobs / ( (double) (nobs - rank)));
        } else {
            this.globalFitInfo[ADJRSQ_IDX] = 1.0 - (sse * (nobs - 1.0)) /
                    (globalFitInfo[SST_IDX] * (nobs - rank));
        }
    }

    /**
     * <p>Returns the parameter estimate for the regressor at the given index.</p>
     *
     * <p>A redundant regressor will have its redundancy flag set, as well as
     *  a parameters estimated equal to {@code Double.NaN}</p>
     *
     * @param index Index.
     * @return the parameters estimated for regressor at index.
     * @throws OutOfRangeException if {@code index} is not in the interval
     * {@code [0, number of parameters)}.
 */
    public double getParameterEstimate(int index) throws OutOfRangeException {
        // NaN (rather than an exception) when no estimation was performed.
        if (parameters == null) {
            return Double.NaN;
        }
        if (index < 0 || index >= this.parameters.length) {
            throw new OutOfRangeException(index, 0, this.parameters.length - 1);
        }
        return this.parameters[index];
    }

    /**
     * <p>Returns a copy of the regression parameters estimates.</p>
     *
     * <p>The parameter estimates are returned in the natural order of the data.</p>
     *
     * <p>A redundant regressor will have its redundancy flag set, as will
     *  a parameter estimate equal to {@code Double.NaN}.</p>
     *
     * @return array of parameter estimates, null if no estimation occurred
     */
    public double[] getParameterEstimates() {
        if (this.parameters == null) {
            return null;
        }
        // Defensive copy keeps this result object effectively immutable.
        return MathArrays.copyOf(parameters);
    }

    /**
     * Returns the <a href="http://www.xycoon.com/standerrorb(1).htm">standard
     * error of the parameter estimate at index</a>,
     * usually denoted s(b<sub>index</sub>).
     *
     * @param index Index.
     * @return the standard errors associated with parameters estimated at index.
     * @throws OutOfRangeException if {@code index} is not in the interval
     * {@code [0, number of parameters)}.
     */
    public double getStdErrorOfEstimate(int index) throws OutOfRangeException {
        if (parameters == null) {
            return Double.NaN;
        }
        if (index < 0 || index >= this.parameters.length) {
            throw new OutOfRangeException(index, 0, this.parameters.length - 1);
        }
        double var = this.getVcvElement(index, index);
        // Double.MIN_VALUE guards against a zero/denormal variance; NaN is
        // returned for redundant (linearly dependent) regressors.
        if (!Double.isNaN(var) && var > Double.MIN_VALUE) {
            return FastMath.sqrt(var);
        }
        return Double.NaN;
    }

    /**
     * <p>Returns the <a href="http://www.xycoon.com/standerrorb(1).htm">standard
     * error of the parameter estimates</a>,
     * usually denoted s(b<sub>i</sub>).</p>
     *
     * <p>If there are problems with an ill conditioned design matrix then the regressor
     * which is redundant will be assigned <code>Double.NaN</code>. </p>
     *
     * @return an array standard errors associated with parameters estimates,
     *  null if no estimation occurred
     */
    public double[] getStdErrorOfEstimates() {
        if (parameters == null) {
            return null;
        }
        double[] se = new double[this.parameters.length];
        for (int i = 0; i < this.parameters.length; i++) {
            double var = this.getVcvElement(i, i);
            if (!Double.isNaN(var) && var > Double.MIN_VALUE) {
                se[i] = FastMath.sqrt(var);
                continue;
            }
            se[i] = Double.NaN;
        }
        return se;
    }

    /**
     * <p>Returns the covariance between regression parameters i and j.</p>
     *
     * <p>If there are problems with an ill conditioned design matrix then the covariance
     * which involves redundant columns will be assigned {@code Double.NaN}. </p>
     *
     * @param i {@code i}th regression parameter.
     * @param j {@code j}th regression parameter.
     * @return the covariance of the parameter estimates.
     * @throws OutOfRangeException if {@code i} or {@code j} is not in the
     * interval {@code [0, number of parameters)}.
     */
    public double getCovarianceOfParameters(int i, int j) throws OutOfRangeException {
        if (parameters == null) {
            return Double.NaN;
        }
        if (i < 0 || i >= this.parameters.length) {
            throw new OutOfRangeException(i, 0, this.parameters.length - 1);
        }
        if (j < 0 || j >= this.parameters.length) {
            throw new OutOfRangeException(j, 0, this.parameters.length - 1);
        }
        return this.getVcvElement(i, j);
    }

    /**
     * <p>Returns the number of parameters estimated in the model.</p>
     *
     * <p>This is the maximum number of regressors, some techniques may drop
     * redundant parameters</p>
     *
     * @return number of regressors, -1 if not estimated
     */
    public int getNumberOfParameters() {
        if (this.parameters == null) {
            return -1;
        }
        return this.parameters.length;
    }

    /**
     * Returns the number of observations added to the regression model.
+ * + * @return Number of observations, -1 if an error condition prevents estimation + */ + public long getN() { + return this.nobs; + } + + /** + * <p>Returns the sum of squared deviations of the y values about their mean.</p> + * + * <p>This is defined as SSTO + * <a href="http://www.xycoon.com/SumOfSquares.htm">here</a>.</p> + * + * <p>If {@code n < 2}, this returns {@code Double.NaN}.</p> + * + * @return sum of squared deviations of y values + */ + public double getTotalSumSquares() { + return this.globalFitInfo[SST_IDX]; + } + + /** + * <p>Returns the sum of squared deviations of the predicted y values about + * their mean (which equals the mean of y).</p> + * + * <p>This is usually abbreviated SSR or SSM. It is defined as SSM + * <a href="http://www.xycoon.com/SumOfSquares.htm">here</a></p> + * + * <p><strong>Preconditions</strong>: <ul> + * <li>At least two observations (with at least two different x values) + * must have been added before invoking this method. If this method is + * invoked before a model can be estimated, <code>Double.NaN</code> is + * returned. + * </li></ul></p> + * + * @return sum of squared deviations of predicted y values + */ + public double getRegressionSumSquares() { + return this.globalFitInfo[SST_IDX] - this.globalFitInfo[SSE_IDX]; + } + + /** + * <p>Returns the <a href="http://www.xycoon.com/SumOfSquares.htm"> + * sum of squared errors</a> (SSE) associated with the regression + * model.</p> + * + * <p>The return value is constrained to be non-negative - i.e., if due to + * rounding errors the computational formula returns a negative result, + * 0 is returned.</p> + * + * <p><strong>Preconditions</strong>: <ul> + * <li>numberOfParameters data pairs + * must have been added before invoking this method. If this method is + * invoked before a model can be estimated, <code>Double,NaN</code> is + * returned. 
+ * </li></ul></p> + * + * @return sum of squared errors associated with the regression model + */ + public double getErrorSumSquares() { + return this.globalFitInfo[ SSE_IDX]; + } + + /** + * <p>Returns the sum of squared errors divided by the degrees of freedom, + * usually abbreviated MSE.</p> + * + * <p>If there are fewer than <strong>numberOfParameters + 1</strong> data pairs in the model, + * or if there is no variation in <code>x</code>, this returns + * <code>Double.NaN</code>.</p> + * + * @return sum of squared deviations of y values + */ + public double getMeanSquareError() { + return this.globalFitInfo[ MSE_IDX]; + } + + /** + * <p>Returns the <a href="http://www.xycoon.com/coefficient1.htm"> + * coefficient of multiple determination</a>, + * usually denoted r-square.</p> + * + * <p><strong>Preconditions</strong>: <ul> + * <li>At least numberOfParameters observations (with at least numberOfParameters different x values) + * must have been added before invoking this method. If this method is + * invoked before a model can be estimated, {@code Double,NaN} is + * returned. 
+ * </li></ul></p> + * + * @return r-square, a double in the interval [0, 1] + */ + public double getRSquared() { + return this.globalFitInfo[ RSQ_IDX]; + } + + /** + * <p>Returns the adjusted R-squared statistic, defined by the formula <pre> + * R<sup>2</sup><sub>adj</sub> = 1 - [SSR (n - 1)] / [SSTO (n - p)] + * </pre> + * where SSR is the sum of squared residuals}, + * SSTO is the total sum of squares}, n is the number + * of observations and p is the number of parameters estimated (including the intercept).</p> + * + * <p>If the regression is estimated without an intercept term, what is returned is <pre> + * <code> 1 - (1 - {@link #getRSquared()} ) * (n / (n - p)) </code> + * </pre></p> + * + * @return adjusted R-Squared statistic + */ + public double getAdjustedRSquared() { + return this.globalFitInfo[ ADJRSQ_IDX]; + } + + /** + * Returns true if the regression model has been computed including an intercept. + * In this case, the coefficient of the intercept is the first element of the + * {@link #getParameterEstimates() parameter estimates}. + * @return true if the model has an intercept term + */ + public boolean hasIntercept() { + return this.containsConstant; + } + + /** + * Gets the i-jth element of the variance-covariance matrix. 
+ * + * @param i first variable index + * @param j second variable index + * @return the requested variance-covariance matrix entry + */ + private double getVcvElement(int i, int j) { + if (this.isSymmetricVCD) { + if (this.varCovData.length > 1) { + //could be stored in upper or lower triangular + if (i == j) { + return varCovData[i][i]; + } else if (i >= varCovData[j].length) { + return varCovData[i][j]; + } else { + return varCovData[j][i]; + } + } else {//could be in single array + if (i > j) { + return varCovData[0][(i + 1) * i / 2 + j]; + } else { + return varCovData[0][(j + 1) * j / 2 + i]; + } + } + } else { + return this.varCovData[i][j]; + } + } +} diff --git a/src/main/java/org/apache/commons/math3/stat/regression/SimpleRegression.java b/src/main/java/org/apache/commons/math3/stat/regression/SimpleRegression.java new file mode 100644 index 0000000..02bf8f4 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/stat/regression/SimpleRegression.java @@ -0,0 +1,881 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.commons.math3.stat.regression; +import java.io.Serializable; + +import org.apache.commons.math3.distribution.TDistribution; +import org.apache.commons.math3.exception.MathIllegalArgumentException; +import org.apache.commons.math3.exception.NoDataException; +import org.apache.commons.math3.exception.OutOfRangeException; +import org.apache.commons.math3.exception.util.LocalizedFormats; +import org.apache.commons.math3.util.FastMath; +import org.apache.commons.math3.util.Precision; + +/** + * Estimates an ordinary least squares regression model + * with one independent variable. + * <p> + * <code> y = intercept + slope * x </code></p> + * <p> + * Standard errors for <code>intercept</code> and <code>slope</code> are + * available as well as ANOVA, r-square and Pearson's r statistics.</p> + * <p> + * Observations (x,y pairs) can be added to the model one at a time or they + * can be provided in a 2-dimensional array. The observations are not stored + * in memory, so there is no limit to the number of observations that can be + * added to the model.</p> + * <p> + * <strong>Usage Notes</strong>: <ul> + * <li> When there are fewer than two observations in the model, or when + * there is no variation in the x values (i.e. all x values are the same) + * all statistics return <code>NaN</code>. At least two observations with + * different x coordinates are required to estimate a bivariate regression + * model. + * </li> + * <li> Getters for the statistics always compute values based on the current + * set of observations -- i.e., you can get statistics, then add more data + * and get updated statistics without using a new instance. There is no + * "compute" method that updates all statistics. Each of the getters performs + * the necessary computations to return the requested statistic. + * </li> + * <li> The intercept term may be suppressed by passing {@code false} to + * the {@link #SimpleRegression(boolean)} constructor. 
When the + * {@code hasIntercept} property is false, the model is estimated without a + * constant term and {@link #getIntercept()} returns {@code 0}.</li> + * </ul></p> + * + */ +public class SimpleRegression implements Serializable, UpdatingMultipleLinearRegression { + + /** Serializable version identifier */ + private static final long serialVersionUID = -3004689053607543335L; + + /** sum of x values */ + private double sumX = 0d; + + /** total variation in x (sum of squared deviations from xbar) */ + private double sumXX = 0d; + + /** sum of y values */ + private double sumY = 0d; + + /** total variation in y (sum of squared deviations from ybar) */ + private double sumYY = 0d; + + /** sum of products */ + private double sumXY = 0d; + + /** number of observations */ + private long n = 0; + + /** mean of accumulated x values, used in updating formulas */ + private double xbar = 0; + + /** mean of accumulated y values, used in updating formulas */ + private double ybar = 0; + + /** include an intercept or not */ + private final boolean hasIntercept; + // ---------------------Public methods-------------------------------------- + + /** + * Create an empty SimpleRegression instance + */ + public SimpleRegression() { + this(true); + } + /** + * Create a SimpleRegression instance, specifying whether or not to estimate + * an intercept. + * + * <p>Use {@code false} to estimate a model with no intercept. When the + * {@code hasIntercept} property is false, the model is estimated without a + * constant term and {@link #getIntercept()} returns {@code 0}.</p> + * + * @param includeIntercept whether or not to include an intercept term in + * the regression model + */ + public SimpleRegression(boolean includeIntercept) { + super(); + hasIntercept = includeIntercept; + } + + /** + * Adds the observation (x,y) to the regression data set. 
+ * <p> + * Uses updating formulas for means and sums of squares defined in + * "Algorithms for Computing the Sample Variance: Analysis and + * Recommendations", Chan, T.F., Golub, G.H., and LeVeque, R.J. + * 1983, American Statistician, vol. 37, pp. 242-247, referenced in + * Weisberg, S. "Applied Linear Regression". 2nd Ed. 1985.</p> + * + * + * @param x independent variable value + * @param y dependent variable value + */ + public void addData(final double x,final double y) { + if (n == 0) { + xbar = x; + ybar = y; + } else { + if( hasIntercept ){ + final double fact1 = 1.0 + n; + final double fact2 = n / (1.0 + n); + final double dx = x - xbar; + final double dy = y - ybar; + sumXX += dx * dx * fact2; + sumYY += dy * dy * fact2; + sumXY += dx * dy * fact2; + xbar += dx / fact1; + ybar += dy / fact1; + } + } + if( !hasIntercept ){ + sumXX += x * x ; + sumYY += y * y ; + sumXY += x * y ; + } + sumX += x; + sumY += y; + n++; + } + + /** + * Appends data from another regression calculation to this one. 
+ * + * <p>The mean update formulae are based on a paper written by Philippe + * Pébay: + * <a + * href="http://prod.sandia.gov/techlib/access-control.cgi/2008/086212.pdf"> + * Formulas for Robust, One-Pass Parallel Computation of Covariances and + * Arbitrary-Order Statistical Moments</a>, 2008, Technical Report + * SAND2008-6212, Sandia National Laboratories.</p> + * + * @param reg model to append data from + * @since 3.3 + */ + public void append(SimpleRegression reg) { + if (n == 0) { + xbar = reg.xbar; + ybar = reg.ybar; + sumXX = reg.sumXX; + sumYY = reg.sumYY; + sumXY = reg.sumXY; + } else { + if (hasIntercept) { + final double fact1 = reg.n / (double) (reg.n + n); + final double fact2 = n * reg.n / (double) (reg.n + n); + final double dx = reg.xbar - xbar; + final double dy = reg.ybar - ybar; + sumXX += reg.sumXX + dx * dx * fact2; + sumYY += reg.sumYY + dy * dy * fact2; + sumXY += reg.sumXY + dx * dy * fact2; + xbar += dx * fact1; + ybar += dy * fact1; + }else{ + sumXX += reg.sumXX; + sumYY += reg.sumYY; + sumXY += reg.sumXY; + } + } + sumX += reg.sumX; + sumY += reg.sumY; + n += reg.n; + } + + /** + * Removes the observation (x,y) from the regression data set. + * <p> + * Mirrors the addData method. This method permits the use of + * SimpleRegression instances in streaming mode where the regression + * is applied to a sliding "window" of observations, however the caller is + * responsible for maintaining the set of observations in the window.</p> + * + * The method has no effect if there are no points of data (i.e. 
n=0) + * + * @param x independent variable value + * @param y dependent variable value + */ + public void removeData(final double x,final double y) { + if (n > 0) { + if (hasIntercept) { + final double fact1 = n - 1.0; + final double fact2 = n / (n - 1.0); + final double dx = x - xbar; + final double dy = y - ybar; + sumXX -= dx * dx * fact2; + sumYY -= dy * dy * fact2; + sumXY -= dx * dy * fact2; + xbar -= dx / fact1; + ybar -= dy / fact1; + } else { + final double fact1 = n - 1.0; + sumXX -= x * x; + sumYY -= y * y; + sumXY -= x * y; + xbar -= x / fact1; + ybar -= y / fact1; + } + sumX -= x; + sumY -= y; + n--; + } + } + + /** + * Adds the observations represented by the elements in + * <code>data</code>. + * <p> + * <code>(data[0][0],data[0][1])</code> will be the first observation, then + * <code>(data[1][0],data[1][1])</code>, etc.</p> + * <p> + * This method does not replace data that has already been added. The + * observations represented by <code>data</code> are added to the existing + * dataset.</p> + * <p> + * To replace all data, use <code>clear()</code> before adding the new + * data.</p> + * + * @param data array of observations to be added + * @throws ModelSpecificationException if the length of {@code data[i]} is not + * greater than or equal to 2 + */ + public void addData(final double[][] data) throws ModelSpecificationException { + for (int i = 0; i < data.length; i++) { + if( data[i].length < 2 ){ + throw new ModelSpecificationException(LocalizedFormats.INVALID_REGRESSION_OBSERVATION, + data[i].length, 2); + } + addData(data[i][0], data[i][1]); + } + } + + /** + * Adds one observation to the regression model. 
+ * + * @param x the independent variables which form the design matrix + * @param y the dependent or response variable + * @throws ModelSpecificationException if the length of {@code x} does not equal + * the number of independent variables in the model + */ + public void addObservation(final double[] x,final double y) + throws ModelSpecificationException { + if( x == null || x.length == 0 ){ + throw new ModelSpecificationException(LocalizedFormats.INVALID_REGRESSION_OBSERVATION,x!=null?x.length:0, 1); + } + addData( x[0], y ); + } + + /** + * Adds a series of observations to the regression model. The lengths of + * x and y must be the same and x must be rectangular. + * + * @param x a series of observations on the independent variables + * @param y a series of observations on the dependent variable + * The length of x and y must be the same + * @throws ModelSpecificationException if {@code x} is not rectangular, does not match + * the length of {@code y} or does not contain sufficient data to estimate the model + */ + public void addObservations(final double[][] x,final double[] y) throws ModelSpecificationException { + if ((x == null) || (y == null) || (x.length != y.length)) { + throw new ModelSpecificationException( + LocalizedFormats.DIMENSIONS_MISMATCH_SIMPLE, + (x == null) ? 0 : x.length, + (y == null) ? 0 : y.length); + } + boolean obsOk=true; + for( int i = 0 ; i < x.length; i++){ + if( x[i] == null || x[i].length == 0 ){ + obsOk = false; + } + } + if( !obsOk ){ + throw new ModelSpecificationException( + LocalizedFormats.NOT_ENOUGH_DATA_FOR_NUMBER_OF_PREDICTORS, + 0, 1); + } + for( int i = 0 ; i < x.length ; i++){ + addData( x[i][0], y[i] ); + } + } + + /** + * Removes observations represented by the elements in <code>data</code>. + * <p> + * If the array is larger than the current n, only the first n elements are + * processed. 
This method permits the use of SimpleRegression instances in + * streaming mode where the regression is applied to a sliding "window" of + * observations, however the caller is responsible for maintaining the set + * of observations in the window.</p> + * <p> + * To remove all data, use <code>clear()</code>.</p> + * + * @param data array of observations to be removed + */ + public void removeData(double[][] data) { + for (int i = 0; i < data.length && n > 0; i++) { + removeData(data[i][0], data[i][1]); + } + } + + /** + * Clears all data from the model. + */ + public void clear() { + sumX = 0d; + sumXX = 0d; + sumY = 0d; + sumYY = 0d; + sumXY = 0d; + n = 0; + } + + /** + * Returns the number of observations that have been added to the model. + * + * @return n number of observations that have been added. + */ + public long getN() { + return n; + } + + /** + * Returns the "predicted" <code>y</code> value associated with the + * supplied <code>x</code> value, based on the data that has been + * added to the model when this method is activated. + * <p> + * <code> predict(x) = intercept + slope * x </code></p> + * <p> + * <strong>Preconditions</strong>: <ul> + * <li>At least two observations (with at least two different x values) + * must have been added before invoking this method. If this method is + * invoked before a model can be estimated, <code>Double,NaN</code> is + * returned. + * </li></ul></p> + * + * @param x input <code>x</code> value + * @return predicted <code>y</code> value + */ + public double predict(final double x) { + final double b1 = getSlope(); + if (hasIntercept) { + return getIntercept(b1) + b1 * x; + } + return b1 * x; + } + + /** + * Returns the intercept of the estimated regression line, if + * {@link #hasIntercept()} is true; otherwise 0. + * <p> + * The least squares estimate of the intercept is computed using the + * <a href="http://www.xycoon.com/estimation4.htm">normal equations</a>. 
+ * The intercept is sometimes denoted b0.</p> + * <p> + * <strong>Preconditions</strong>: <ul> + * <li>At least two observations (with at least two different x values) + * must have been added before invoking this method. If this method is + * invoked before a model can be estimated, <code>Double,NaN</code> is + * returned. + * </li></ul></p> + * + * @return the intercept of the regression line if the model includes an + * intercept; 0 otherwise + * @see #SimpleRegression(boolean) + */ + public double getIntercept() { + return hasIntercept ? getIntercept(getSlope()) : 0.0; + } + + /** + * Returns true if the model includes an intercept term. + * + * @return true if the regression includes an intercept; false otherwise + * @see #SimpleRegression(boolean) + */ + public boolean hasIntercept() { + return hasIntercept; + } + + /** + * Returns the slope of the estimated regression line. + * <p> + * The least squares estimate of the slope is computed using the + * <a href="http://www.xycoon.com/estimation4.htm">normal equations</a>. + * The slope is sometimes denoted b1.</p> + * <p> + * <strong>Preconditions</strong>: <ul> + * <li>At least two observations (with at least two different x values) + * must have been added before invoking this method. If this method is + * invoked before a model can be estimated, <code>Double.NaN</code> is + * returned. + * </li></ul></p> + * + * @return the slope of the regression line + */ + public double getSlope() { + if (n < 2) { + return Double.NaN; //not enough data + } + if (FastMath.abs(sumXX) < 10 * Double.MIN_VALUE) { + return Double.NaN; //not enough variation in x + } + return sumXY / sumXX; + } + + /** + * Returns the <a href="http://www.xycoon.com/SumOfSquares.htm"> + * sum of squared errors</a> (SSE) associated with the regression + * model. 
+ * <p> + * The sum is computed using the computational formula</p> + * <p> + * <code>SSE = SYY - (SXY * SXY / SXX)</code></p> + * <p> + * where <code>SYY</code> is the sum of the squared deviations of the y + * values about their mean, <code>SXX</code> is similarly defined and + * <code>SXY</code> is the sum of the products of x and y mean deviations. + * </p><p> + * The sums are accumulated using the updating algorithm referenced in + * {@link #addData}.</p> + * <p> + * The return value is constrained to be non-negative - i.e., if due to + * rounding errors the computational formula returns a negative result, + * 0 is returned.</p> + * <p> + * <strong>Preconditions</strong>: <ul> + * <li>At least two observations (with at least two different x values) + * must have been added before invoking this method. If this method is + * invoked before a model can be estimated, <code>Double,NaN</code> is + * returned. + * </li></ul></p> + * + * @return sum of squared errors associated with the regression model + */ + public double getSumSquaredErrors() { + return FastMath.max(0d, sumYY - sumXY * sumXY / sumXX); + } + + /** + * Returns the sum of squared deviations of the y values about their mean. + * <p> + * This is defined as SSTO + * <a href="http://www.xycoon.com/SumOfSquares.htm">here</a>.</p> + * <p> + * If <code>n < 2</code>, this returns <code>Double.NaN</code>.</p> + * + * @return sum of squared deviations of y values + */ + public double getTotalSumSquares() { + if (n < 2) { + return Double.NaN; + } + return sumYY; + } + + /** + * Returns the sum of squared deviations of the x values about their mean. + * + * If <code>n < 2</code>, this returns <code>Double.NaN</code>.</p> + * + * @return sum of squared deviations of x values + */ + public double getXSumSquares() { + if (n < 2) { + return Double.NaN; + } + return sumXX; + } + + /** + * Returns the sum of crossproducts, x<sub>i</sub>*y<sub>i</sub>. 
+ * + * @return sum of cross products + */ + public double getSumOfCrossProducts() { + return sumXY; + } + + /** + * Returns the sum of squared deviations of the predicted y values about + * their mean (which equals the mean of y). + * <p> + * This is usually abbreviated SSR or SSM. It is defined as SSM + * <a href="http://www.xycoon.com/SumOfSquares.htm">here</a></p> + * <p> + * <strong>Preconditions</strong>: <ul> + * <li>At least two observations (with at least two different x values) + * must have been added before invoking this method. If this method is + * invoked before a model can be estimated, <code>Double.NaN</code> is + * returned. + * </li></ul></p> + * + * @return sum of squared deviations of predicted y values + */ + public double getRegressionSumSquares() { + return getRegressionSumSquares(getSlope()); + } + + /** + * Returns the sum of squared errors divided by the degrees of freedom, + * usually abbreviated MSE. + * <p> + * If there are fewer than <strong>three</strong> data pairs in the model, + * or if there is no variation in <code>x</code>, this returns + * <code>Double.NaN</code>.</p> + * + * @return sum of squared deviations of y values + */ + public double getMeanSquareError() { + if (n < 3) { + return Double.NaN; + } + return hasIntercept ? (getSumSquaredErrors() / (n - 2)) : (getSumSquaredErrors() / (n - 1)); + } + + /** + * Returns <a href="http://mathworld.wolfram.com/CorrelationCoefficient.html"> + * Pearson's product moment correlation coefficient</a>, + * usually denoted r. + * <p> + * <strong>Preconditions</strong>: <ul> + * <li>At least two observations (with at least two different x values) + * must have been added before invoking this method. If this method is + * invoked before a model can be estimated, <code>Double,NaN</code> is + * returned. 
+ * </li></ul></p> + * + * @return Pearson's r + */ + public double getR() { + double b1 = getSlope(); + double result = FastMath.sqrt(getRSquare()); + if (b1 < 0) { + result = -result; + } + return result; + } + + /** + * Returns the <a href="http://www.xycoon.com/coefficient1.htm"> + * coefficient of determination</a>, + * usually denoted r-square. + * <p> + * <strong>Preconditions</strong>: <ul> + * <li>At least two observations (with at least two different x values) + * must have been added before invoking this method. If this method is + * invoked before a model can be estimated, <code>Double,NaN</code> is + * returned. + * </li></ul></p> + * + * @return r-square + */ + public double getRSquare() { + double ssto = getTotalSumSquares(); + return (ssto - getSumSquaredErrors()) / ssto; + } + + /** + * Returns the <a href="http://www.xycoon.com/standarderrorb0.htm"> + * standard error of the intercept estimate</a>, + * usually denoted s(b0). + * <p> + * If there are fewer that <strong>three</strong> observations in the + * model, or if there is no variation in x, this returns + * <code>Double.NaN</code>.</p> Additionally, a <code>Double.NaN</code> is + * returned when the intercept is constrained to be zero + * + * @return standard error associated with intercept estimate + */ + public double getInterceptStdErr() { + if( !hasIntercept ){ + return Double.NaN; + } + return FastMath.sqrt( + getMeanSquareError() * ((1d / n) + (xbar * xbar) / sumXX)); + } + + /** + * Returns the <a href="http://www.xycoon.com/standerrorb(1).htm">standard + * error of the slope estimate</a>, + * usually denoted s(b1). + * <p> + * If there are fewer that <strong>three</strong> data pairs in the model, + * or if there is no variation in x, this returns <code>Double.NaN</code>. 
+ * </p> + * + * @return standard error associated with slope estimate + */ + public double getSlopeStdErr() { + return FastMath.sqrt(getMeanSquareError() / sumXX); + } + + /** + * Returns the half-width of a 95% confidence interval for the slope + * estimate. + * <p> + * The 95% confidence interval is</p> + * <p> + * <code>(getSlope() - getSlopeConfidenceInterval(), + * getSlope() + getSlopeConfidenceInterval())</code></p> + * <p> + * If there are fewer that <strong>three</strong> observations in the + * model, or if there is no variation in x, this returns + * <code>Double.NaN</code>.</p> + * <p> + * <strong>Usage Note</strong>:<br> + * The validity of this statistic depends on the assumption that the + * observations included in the model are drawn from a + * <a href="http://mathworld.wolfram.com/BivariateNormalDistribution.html"> + * Bivariate Normal Distribution</a>.</p> + * + * @return half-width of 95% confidence interval for the slope estimate + * @throws OutOfRangeException if the confidence interval can not be computed. + */ + public double getSlopeConfidenceInterval() throws OutOfRangeException { + return getSlopeConfidenceInterval(0.05d); + } + + /** + * Returns the half-width of a (100-100*alpha)% confidence interval for + * the slope estimate. 
+ * <p> + * The (100-100*alpha)% confidence interval is </p> + * <p> + * <code>(getSlope() - getSlopeConfidenceInterval(), + * getSlope() + getSlopeConfidenceInterval())</code></p> + * <p> + * To request, for example, a 99% confidence interval, use + * <code>alpha = .01</code></p> + * <p> + * <strong>Usage Note</strong>:<br> + * The validity of this statistic depends on the assumption that the + * observations included in the model are drawn from a + * <a href="http://mathworld.wolfram.com/BivariateNormalDistribution.html"> + * Bivariate Normal Distribution</a>.</p> + * <p> + * <strong> Preconditions:</strong><ul> + * <li>If there are fewer that <strong>three</strong> observations in the + * model, or if there is no variation in x, this returns + * <code>Double.NaN</code>. + * </li> + * <li><code>(0 < alpha < 1)</code>; otherwise an + * <code>OutOfRangeException</code> is thrown. + * </li></ul></p> + * + * @param alpha the desired significance level + * @return half-width of 95% confidence interval for the slope estimate + * @throws OutOfRangeException if the confidence interval can not be computed. + */ + public double getSlopeConfidenceInterval(final double alpha) + throws OutOfRangeException { + if (n < 3) { + return Double.NaN; + } + if (alpha >= 1 || alpha <= 0) { + throw new OutOfRangeException(LocalizedFormats.SIGNIFICANCE_LEVEL, + alpha, 0, 1); + } + // No advertised NotStrictlyPositiveException here - will return NaN above + TDistribution distribution = new TDistribution(n - 2); + return getSlopeStdErr() * + distribution.inverseCumulativeProbability(1d - alpha / 2d); + } + + /** + * Returns the significance level of the slope (equiv) correlation. + * <p> + * Specifically, the returned value is the smallest <code>alpha</code> + * such that the slope confidence interval with significance level + * equal to <code>alpha</code> does not include <code>0</code>. 
+ * On regression output, this is often denoted <code>Prob(|t| > 0)</code> + * </p><p> + * <strong>Usage Note</strong>:<br> + * The validity of this statistic depends on the assumption that the + * observations included in the model are drawn from a + * <a href="http://mathworld.wolfram.com/BivariateNormalDistribution.html"> + * Bivariate Normal Distribution</a>.</p> + * <p> + * If there are fewer that <strong>three</strong> observations in the + * model, or if there is no variation in x, this returns + * <code>Double.NaN</code>.</p> + * + * @return significance level for slope/correlation + * @throws org.apache.commons.math3.exception.MaxCountExceededException + * if the significance level can not be computed. + */ + public double getSignificance() { + if (n < 3) { + return Double.NaN; + } + // No advertised NotStrictlyPositiveException here - will return NaN above + TDistribution distribution = new TDistribution(n - 2); + return 2d * (1.0 - distribution.cumulativeProbability( + FastMath.abs(getSlope()) / getSlopeStdErr())); + } + + // ---------------------Private methods----------------------------------- + + /** + * Returns the intercept of the estimated regression line, given the slope. + * <p> + * Will return <code>NaN</code> if slope is <code>NaN</code>.</p> + * + * @param slope current slope + * @return the intercept of the regression line + */ + private double getIntercept(final double slope) { + if( hasIntercept){ + return (sumY - slope * sumX) / n; + } + return 0.0; + } + + /** + * Computes SSR from b1. + * + * @param slope regression slope estimate + * @return sum of squared deviations of predicted y values + */ + private double getRegressionSumSquares(final double slope) { + return slope * slope * sumXX; + } + + /** + * Performs a regression on data present in buffers and outputs a RegressionResults object. + * + * <p>If there are fewer than 3 observations in the model and {@code hasIntercept} is true + * a {@code NoDataException} is thrown. 
If there is no intercept term, the model must + * contain at least 2 observations.</p> + * + * @return RegressionResults acts as a container of regression output + * @throws ModelSpecificationException if the model is not correctly specified + * @throws NoDataException if there is not sufficient data in the model to + * estimate the regression parameters + */ + public RegressionResults regress() throws ModelSpecificationException, NoDataException { + if (hasIntercept) { + if (n < 3) { + throw new NoDataException(LocalizedFormats.NOT_ENOUGH_DATA_REGRESSION); + } + if (FastMath.abs(sumXX) > Precision.SAFE_MIN) { + final double[] params = new double[] { getIntercept(), getSlope() }; + final double mse = getMeanSquareError(); + final double _syy = sumYY + sumY * sumY / n; + final double[] vcv = new double[] { mse * (xbar * xbar / sumXX + 1.0 / n), -xbar * mse / sumXX, mse / sumXX }; + return new RegressionResults(params, new double[][] { vcv }, true, n, 2, sumY, _syy, getSumSquaredErrors(), true, + false); + } else { + final double[] params = new double[] { sumY / n, Double.NaN }; + // final double mse = getMeanSquareError(); + final double[] vcv = new double[] { ybar / (n - 1.0), Double.NaN, Double.NaN }; + return new RegressionResults(params, new double[][] { vcv }, true, n, 1, sumY, sumYY, getSumSquaredErrors(), true, + false); + } + } else { + if (n < 2) { + throw new NoDataException(LocalizedFormats.NOT_ENOUGH_DATA_REGRESSION); + } + if (!Double.isNaN(sumXX)) { + final double[] vcv = new double[] { getMeanSquareError() / sumXX }; + final double[] params = new double[] { sumXY / sumXX }; + return new RegressionResults(params, new double[][] { vcv }, true, n, 1, sumY, sumYY, getSumSquaredErrors(), false, + false); + } else { + final double[] vcv = new double[] { Double.NaN }; + final double[] params = new double[] { Double.NaN }; + return new RegressionResults(params, new double[][] { vcv }, true, n, 1, Double.NaN, Double.NaN, Double.NaN, false, + false); + } + } + 
} + + /** + * Performs a regression on data present in buffers including only regressors + * indexed in variablesToInclude and outputs a RegressionResults object + * @param variablesToInclude an array of indices of regressors to include + * @return RegressionResults acts as a container of regression output + * @throws MathIllegalArgumentException if the variablesToInclude array is null or zero length + * @throws OutOfRangeException if a requested variable is not present in model + */ + public RegressionResults regress(int[] variablesToInclude) throws MathIllegalArgumentException{ + if( variablesToInclude == null || variablesToInclude.length == 0){ + throw new MathIllegalArgumentException(LocalizedFormats.ARRAY_ZERO_LENGTH_OR_NULL_NOT_ALLOWED); + } + if( variablesToInclude.length > 2 || (variablesToInclude.length > 1 && !hasIntercept) ){ + throw new ModelSpecificationException( + LocalizedFormats.ARRAY_SIZE_EXCEEDS_MAX_VARIABLES, + (variablesToInclude.length > 1 && !hasIntercept) ? 1 : 2); + } + + if( hasIntercept ){ + if( variablesToInclude.length == 2 ){ + if( variablesToInclude[0] == 1 ){ + throw new ModelSpecificationException(LocalizedFormats.NOT_INCREASING_SEQUENCE); + }else if( variablesToInclude[0] != 0 ){ + throw new OutOfRangeException( variablesToInclude[0], 0,1 ); + } + if( variablesToInclude[1] != 1){ + throw new OutOfRangeException( variablesToInclude[0], 0,1 ); + } + return regress(); + }else{ + if( variablesToInclude[0] != 1 && variablesToInclude[0] != 0 ){ + throw new OutOfRangeException( variablesToInclude[0],0,1 ); + } + final double _mean = sumY * sumY / n; + final double _syy = sumYY + _mean; + if( variablesToInclude[0] == 0 ){ + //just the mean + final double[] vcv = new double[]{ sumYY/(((n-1)*n)) }; + final double[] params = new double[]{ ybar }; + return new RegressionResults( + params, new double[][]{vcv}, true, n, 1, + sumY, _syy+_mean, sumYY,true,false); + + }else if( variablesToInclude[0] == 1){ + //final double _syy = sumYY + sumY * 
sumY / ((double) n); + final double _sxx = sumXX + sumX * sumX / n; + final double _sxy = sumXY + sumX * sumY / n; + final double _sse = FastMath.max(0d, _syy - _sxy * _sxy / _sxx); + final double _mse = _sse/((n-1)); + if( !Double.isNaN(_sxx) ){ + final double[] vcv = new double[]{ _mse / _sxx }; + final double[] params = new double[]{ _sxy/_sxx }; + return new RegressionResults( + params, new double[][]{vcv}, true, n, 1, + sumY, _syy, _sse,false,false); + }else{ + final double[] vcv = new double[]{Double.NaN }; + final double[] params = new double[]{ Double.NaN }; + return new RegressionResults( + params, new double[][]{vcv}, true, n, 1, + Double.NaN, Double.NaN, Double.NaN,false,false); + } + } + } + }else{ + if( variablesToInclude[0] != 0 ){ + throw new OutOfRangeException(variablesToInclude[0],0,0); + } + return regress(); + } + + return null; + } +} diff --git a/src/main/java/org/apache/commons/math3/stat/regression/UpdatingMultipleLinearRegression.java b/src/main/java/org/apache/commons/math3/stat/regression/UpdatingMultipleLinearRegression.java new file mode 100644 index 0000000..ebefc31 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/stat/regression/UpdatingMultipleLinearRegression.java @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.math3.stat.regression; + +import org.apache.commons.math3.exception.MathIllegalArgumentException; +import org.apache.commons.math3.exception.NoDataException; + +/** + * An interface for regression models allowing for dynamic updating of the data. + * That is, the entire data set need not be loaded into memory. As observations + * become available, they can be added to the regression model and an updated + * estimate regression statistics can be calculated. + * + * @since 3.0 + */ +public interface UpdatingMultipleLinearRegression { + + /** + * Returns true if a constant has been included false otherwise. + * + * @return true if constant exists, false otherwise + */ + boolean hasIntercept(); + + /** + * Returns the number of observations added to the regression model. + * + * @return Number of observations + */ + long getN(); + + /** + * Adds one observation to the regression model. + * + * @param x the independent variables which form the design matrix + * @param y the dependent or response variable + * @throws ModelSpecificationException if the length of {@code x} does not equal + * the number of independent variables in the model + */ + void addObservation(double[] x, double y) throws ModelSpecificationException; + + /** + * Adds a series of observations to the regression model. The lengths of + * x and y must be the same and x must be rectangular. 
+ * + * @param x a series of observations on the independent variables + * @param y a series of observations on the dependent variable + * The length of x and y must be the same + * @throws ModelSpecificationException if {@code x} is not rectangular, does not match + * the length of {@code y} or does not contain sufficient data to estimate the model + */ + void addObservations(double[][] x, double[] y) throws ModelSpecificationException; + + /** + * Clears internal buffers and resets the regression model. This means all + * data and derived values are initialized + */ + void clear(); + + + /** + * Performs a regression on data present in buffers and outputs a RegressionResults object + * @return RegressionResults acts as a container of regression output + * @throws ModelSpecificationException if the model is not correctly specified + * @throws NoDataException if there is not sufficient data in the model to + * estimate the regression parameters + */ + RegressionResults regress() throws ModelSpecificationException, NoDataException; + + /** + * Performs a regression on data present in buffers including only regressors + * indexed in variablesToInclude and outputs a RegressionResults object + * @param variablesToInclude an array of indices of regressors to include + * @return RegressionResults acts as a container of regression output + * @throws ModelSpecificationException if the model is not correctly specified + * @throws MathIllegalArgumentException if the variablesToInclude array is null or zero length + */ + RegressionResults regress(int[] variablesToInclude) throws ModelSpecificationException, MathIllegalArgumentException; +} diff --git a/src/main/java/org/apache/commons/math3/stat/regression/package-info.java b/src/main/java/org/apache/commons/math3/stat/regression/package-info.java new file mode 100644 index 0000000..fbc0e12 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/stat/regression/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the 
Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
/**
 * Statistical routines involving multivariate data, in particular simple and
 * updating (streaming) linear regression models and their results.
 */
package org.apache.commons.math3.stat.regression;