summaryrefslogtreecommitdiff
path: root/src/main/java/org/apache/commons/math3/util/CombinatoricsUtils.java
blob: 9743a6be1cd2e7f1b491f00e629b8f058e163bf5 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.commons.math3.util;

import org.apache.commons.math3.exception.MathArithmeticException;
import org.apache.commons.math3.exception.NotPositiveException;
import org.apache.commons.math3.exception.NumberIsTooLargeException;
import org.apache.commons.math3.exception.util.LocalizedFormats;

import java.util.Iterator;
import java.util.concurrent.atomic.AtomicReference;

/**
 * Combinatorial utilities.
 *
 * @since 3.3
 */
public final class CombinatoricsUtils {

    /**
     * All long-representable factorials: {@code FACTORIALS[i]} holds {@code i!} for {@code 0 <= i
     * <= 20}.
     */
    static final long[] FACTORIALS =
            new long[] {
                1L,
                1L,
                2L,
                6L,
                24L,
                120L,
                720L,
                5040L,
                40320L,
                362880L,
                3628800L,
                39916800L,
                479001600L,
                6227020800L,
                87178291200L,
                1307674368000L,
                20922789888000L,
                355687428096000L,
                6402373705728000L,
                121645100408832000L,
                2432902008176640000L
            };

    /**
     * Lazily built table of Stirling numbers of the second kind: once initialised, entry {@code
     * [n][k]} holds {@code S(n,k)} for the rows that fit in a {@code long}. Holds {@code null}
     * until {@link #stirlingS2(int, int)} is called for the first time.
     */
    static final AtomicReference<long[][]> STIRLING_S2 = new AtomicReference<long[][]>(null);

    /** Private constructor (class contains only static methods). */
    private CombinatoricsUtils() {}

    /**
     * Returns an exact representation of the <a
     * href="http://mathworld.wolfram.com/BinomialCoefficient.html">Binomial Coefficient</a>,
     * "{@code n choose k}", the number of {@code k}-element subsets that can be selected from an
     * {@code n}-element set.
     *
     * <p><strong>Preconditions</strong>:
     *
     * <ul>
     *   <li>{@code 0 <= k <= n } (otherwise an exception is thrown by {@link #checkBinomial(int,
     *       int)})
     *   <li>The result is small enough to fit into a {@code long}. The largest value of {@code n}
     *       for which all coefficients are {@code < Long.MAX_VALUE} is 66. If the computed value
     *       exceeds {@code Long.MAX_VALUE} a {@code MathArithmeticException} is thrown.
     * </ul>
     *
     * @param n the size of the set
     * @param k the size of the subsets to be counted
     * @return {@code n choose k}
     * @throws NotPositiveException if {@code n < 0} or {@code k < 0}.
     * @throws NumberIsTooLargeException if {@code k > n}.
     * @throws MathArithmeticException if the result is too large to be represented by a long
     *     integer.
     */
    public static long binomialCoefficient(final int n, final int k)
            throws NotPositiveException, NumberIsTooLargeException, MathArithmeticException {
        CombinatoricsUtils.checkBinomial(n, k);
        if ((n == k) || (k == 0)) {
            return 1;
        }
        if ((k == 1) || (k == n - 1)) {
            return n;
        }
        // Use symmetry for large k
        if (k > n / 2) {
            return binomialCoefficient(n, n - k);
        }

        // We use the formula
        // (n choose k) = n! / (n-k)! / k!
        // (n choose k) == ((n-k+1)*...*n) / (1*...*k)
        // which could be written
        // (n choose k) == (n-1 choose k-1) * n / k
        long result = 1;
        if (n <= 61) {
            // For n <= 61, the naive implementation cannot overflow.
            int i = n - k + 1;
            for (int j = 1; j <= k; j++) {
                result = result * i / j;
                i++;
            }
        } else if (n <= 66) {
            // For n > 61 but n <= 66, the result cannot overflow,
            // but we must take care not to overflow intermediate values.
            int i = n - k + 1;
            for (int j = 1; j <= k; j++) {
                // We know that (result * i) is divisible by j,
                // but (result * i) may overflow, so we split j:
                // Filter out the gcd, d, so j/d and i/d are integer.
                // result is divisible by (j/d) because (j/d)
                // is relative prime to (i/d) and is a divisor of
                // result * (i/d).
                final long d = ArithmeticUtils.gcd(i, j);
                result = (result / (j / d)) * (i / d);
                i++;
            }
        } else {
            // For n > 66, a result overflow might occur, so we check
            // the multiplication, taking care to not overflow
            // unnecessarily.
            int i = n - k + 1;
            for (int j = 1; j <= k; j++) {
                final long d = ArithmeticUtils.gcd(i, j);
                result = ArithmeticUtils.mulAndCheck(result / (j / d), i / d);
                i++;
            }
        }
        return result;
    }

    /**
     * Returns a {@code double} representation of the <a
     * href="http://mathworld.wolfram.com/BinomialCoefficient.html">Binomial Coefficient</a>,
     * "{@code n choose k}", the number of {@code k}-element subsets that can be selected from an
     * {@code n}-element set.
     *
     * <p><strong>Preconditions</strong>:
     *
     * <ul>
     *   <li>{@code 0 <= k <= n } (otherwise an exception is thrown by {@link #checkBinomial(int,
     *       int)})
     *   <li>The result is small enough to fit into a {@code double}. The largest value of {@code n}
     *       for which all coefficients are less than Double.MAX_VALUE is 1029. If the computed
     *       value exceeds Double.MAX_VALUE, Double.POSITIVE_INFINITY is returned
     * </ul>
     *
     * <p>For {@code n < 67} the value is computed exactly by {@link #binomialCoefficient(int,
     * int)}; above that a floating-point product is accumulated and rounded to the nearest whole
     * number.
     *
     * @param n the size of the set
     * @param k the size of the subsets to be counted
     * @return {@code n choose k}
     * @throws NotPositiveException if {@code n < 0} or {@code k < 0}.
     * @throws NumberIsTooLargeException if {@code k > n}.
     * @throws MathArithmeticException declared because the exact {@code long} computation is
     *     delegated to for {@code n < 67}.
     */
    public static double binomialCoefficientDouble(final int n, final int k)
            throws NotPositiveException, NumberIsTooLargeException, MathArithmeticException {
        CombinatoricsUtils.checkBinomial(n, k);
        if ((n == k) || (k == 0)) {
            return 1d;
        }
        if ((k == 1) || (k == n - 1)) {
            return n;
        }
        // Use symmetry so that the product below has at most n / 2 factors.
        if (k > n / 2) {
            return binomialCoefficientDouble(n, n - k);
        }
        if (n < 67) {
            return binomialCoefficient(n, k);
        }

        double result = 1d;
        for (int i = 1; i <= k; i++) {
            result *= (double) (n - k + i) / (double) i;
        }

        // Round to the nearest whole number.
        return FastMath.floor(result + 0.5);
    }

    /**
     * Returns the natural {@code log} of the <a
     * href="http://mathworld.wolfram.com/BinomialCoefficient.html">Binomial Coefficient</a>,
     * "{@code n choose k}", the number of {@code k}-element subsets that can be selected from an
     * {@code n}-element set.
     *
     * <p><strong>Preconditions</strong>:
     *
     * <ul>
     *   <li>{@code 0 <= k <= n } (otherwise an exception is thrown by {@link #checkBinomial(int,
     *       int)})
     * </ul>
     *
     * @param n the size of the set
     * @param k the size of the subsets to be counted
     * @return {@code log(n choose k)}
     * @throws NotPositiveException if {@code n < 0} or {@code k < 0}.
     * @throws NumberIsTooLargeException if {@code k > n}.
     * @throws MathArithmeticException declared because the exact {@code long} computation is
     *     delegated to for {@code n < 67}.
     */
    public static double binomialCoefficientLog(final int n, final int k)
            throws NotPositiveException, NumberIsTooLargeException, MathArithmeticException {
        CombinatoricsUtils.checkBinomial(n, k);
        if ((n == k) || (k == 0)) {
            return 0;
        }
        if ((k == 1) || (k == n - 1)) {
            return FastMath.log(n);
        }

        /*
         * For values small enough to do exact integer computation,
         * return the log of the exact value
         */
        if (n < 67) {
            return FastMath.log(binomialCoefficient(n, k));
        }

        /*
         * Return the log of binomialCoefficientDouble for values that will not
         * overflow binomialCoefficientDouble
         */
        if (n < 1030) {
            return FastMath.log(binomialCoefficientDouble(n, k));
        }

        // Use symmetry so that the sums below run over at most n / 2 terms.
        if (k > n / 2) {
            return binomialCoefficientLog(n, n - k);
        }

        /*
         * Sum logs for values that could overflow
         */
        double logSum = 0;

        // n!/(n-k)!
        for (int i = n - k + 1; i <= n; i++) {
            logSum += FastMath.log(i);
        }

        // divide by k!
        for (int i = 2; i <= k; i++) {
            logSum -= FastMath.log(i);
        }

        return logSum;
    }

    /**
     * Returns n!. Shorthand for {@code n} <a href="http://mathworld.wolfram.com/Factorial.html">
     * Factorial</a>, the product of the numbers {@code 1,...,n}.
     *
     * <p><strong>Preconditions</strong>:
     *
     * <ul>
     *   <li>{@code n >= 0} (otherwise {@code NotPositiveException} is thrown)
     *   <li>The result is small enough to fit into a {@code long}. The largest value of {@code n}
     *       for which {@code n!} does not exceed {@code Long.MAX_VALUE} is 20. For larger {@code
     *       n} a {@code MathArithmeticException} is thrown.
     * </ul>
     *
     * @param n argument
     * @return {@code n!}
     * @throws NotPositiveException if {@code n < 0}.
     * @throws MathArithmeticException if {@code n > 20}: The factorial value is too large to fit in
     *     a {@code long}.
     */
    public static long factorial(final int n) throws NotPositiveException, MathArithmeticException {
        if (n < 0) {
            throw new NotPositiveException(LocalizedFormats.FACTORIAL_NEGATIVE_PARAMETER, n);
        }
        if (n > 20) {
            throw new MathArithmeticException();
        }
        return FACTORIALS[n];
    }

    /**
     * Compute n!, the <a href="http://mathworld.wolfram.com/Factorial.html">factorial</a> of {@code
     * n} (the product of the numbers 1 to n), as a {@code double}. The result should be small
     * enough to fit into a {@code double}: The largest {@code n} for which {@code n!} does not
     * exceed {@code Double.MAX_VALUE} is 170. If the computed value exceeds {@code
     * Double.MAX_VALUE}, {@code Double.POSITIVE_INFINITY} is returned.
     *
     * @param n Argument.
     * @return {@code n!}
     * @throws NotPositiveException if {@code n < 0}.
     */
    public static double factorialDouble(final int n) throws NotPositiveException {
        if (n < 0) {
            throw new NotPositiveException(LocalizedFormats.FACTORIAL_NEGATIVE_PARAMETER, n);
        }
        if (n < 21) {
            // Exact value from the lookup table.
            return FACTORIALS[n];
        }
        // exp(log(n!)) rounded to the nearest whole number.
        return FastMath.floor(FastMath.exp(CombinatoricsUtils.factorialLog(n)) + 0.5);
    }

    /**
     * Compute the natural logarithm of the factorial of {@code n}.
     *
     * @param n Argument.
     * @return {@code log(n!)}
     * @throws NotPositiveException if {@code n < 0}.
     */
    public static double factorialLog(final int n) throws NotPositiveException {
        if (n < 0) {
            throw new NotPositiveException(LocalizedFormats.FACTORIAL_NEGATIVE_PARAMETER, n);
        }
        if (n < 21) {
            // Log of the exact value from the lookup table.
            return FastMath.log(FACTORIALS[n]);
        }
        // log(n!) = log(2) + ... + log(n)
        double logSum = 0;
        for (int i = 2; i <= n; i++) {
            logSum += FastMath.log(i);
        }
        return logSum;
    }

    /**
     * Returns the <a href="http://mathworld.wolfram.com/StirlingNumberoftheSecondKind.html">
     * Stirling number of the second kind</a>, "{@code S(n,k)}", the number of ways of partitioning
     * an {@code n}-element set into {@code k} non-empty subsets.
     *
     * <p>The preconditions are {@code 0 <= k <= n } (otherwise {@code NotPositiveException} or
     * {@code NumberIsTooLargeException} is thrown).
     *
     * <p>Rows {@code n < 26} are served from a table built on first use. Larger rows are computed
     * on demand without caching, using closed forms for {@code k} in {@code {0, 1, 2, n-1, n}} and
     * an overflow-checked evaluation of the recurrence {@code S(i,j) = j*S(i-1,j) + S(i-1,j-1)}
     * otherwise.
     *
     * @param n the size of the set
     * @param k the number of non-empty subsets
     * @return {@code S(n,k)}
     * @throws NotPositiveException if {@code k < 0}.
     * @throws NumberIsTooLargeException if {@code k > n}.
     * @throws MathArithmeticException if {@code S(n,k)} is too large to be represented by a {@code
     *     long}, which can only happen for {@code n > 25}.
     * @since 3.1
     */
    public static long stirlingS2(final int n, final int k)
            throws NotPositiveException, NumberIsTooLargeException, MathArithmeticException {
        if (k < 0) {
            throw new NotPositiveException(k);
        }
        if (k > n) {
            throw new NumberIsTooLargeException(k, n, true);
        }

        final long[][] table = stirlingS2Table();

        if (n < table.length) {
            // the number is in the small cache
            return table[n][k];
        }

        // n is beyond the cache: compute the number without caching it
        if (k == 0) {
            return 0;
        }
        if (k == 1 || k == n) {
            return 1;
        }
        if (k == n - 1) {
            return binomialCoefficient(n, 2);
        }

        // From here on 2 <= k <= n - 2, hence excess >= 2.
        final int excess = n - k;
        if (excess > 62) {
            // S(n,k) >= S(excess + 2, 2) = 2^(excess + 1) - 1 >= 2^64 - 1 > Long.MAX_VALUE.
            // Rejecting here also keeps the work array below at no more than 63 entries.
            throw stirlingOverflow(n, table.length - 1);
        }
        if (k == 2) {
            // S(n,2) = 2^(n-1) - 1 with n <= 64 here. For n == 64 the shift yields
            // Long.MIN_VALUE and the subtraction wraps to Long.MAX_VALUE = 2^63 - 1,
            // which is the exact value.
            return (1L << (n - 1)) - 1L;
        }

        // Recurrence S(i,j) = j * S(i-1,j) + S(i-1,j-1), restricted to the entries
        // S(j + d, j) with 0 <= d <= excess on which S(n,k) depends. After the
        // iteration for a given j, band[d] holds S(j + d, j). Each of these entries
        // (and each product formed below) is bounded by S(n,k), so an overflow is
        // reported only when the final result itself does not fit in a long.
        final long[] band = new long[excess + 1];
        band[0] = 1L; // S(j,j) = 1 for every j; the other entries start as S(d,0) = 0
        for (int j = 1; j <= k; ++j) {
            for (int d = 1; d <= excess; ++d) {
                // band[d - 1] is already S(j + d - 1, j); band[d] is still S(j + d - 1, j - 1).
                final long scaled = ArithmeticUtils.mulAndCheck((long) j, band[d - 1]);
                final long updated = band[d] + scaled;
                if (updated < 0) {
                    // both operands are non-negative, so a negative sum means overflow
                    throw stirlingOverflow(n, table.length - 1);
                }
                band[d] = updated;
            }
        }
        return band[excess];
    }

    /**
     * Returns the cached table of Stirling numbers of the second kind, building it by direct
     * recurrence and publishing it in {@link #STIRLING_S2} if no table has been stored yet.
     *
     * @return the table, where entry {@code [n][k]} holds {@code S(n,k)} for {@code 0 <= k <= n <
     *     26}
     */
    private static long[][] stirlingS2Table() {
        long[][] table = STIRLING_S2.get();

        if (table == null) {
            // the cache has never been initialized, compute the first numbers
            // by direct recurrence relation

            // as S(26,9) = 11201516780955125625 is larger than Long.MAX_VALUE
            // we must stop computation at row 26
            final int maxIndex = 26;
            table = new long[maxIndex][];
            table[0] = new long[] {1L};
            for (int i = 1; i < table.length; ++i) {
                table[i] = new long[i + 1];
                table[i][0] = 0;
                table[i][1] = 1;
                table[i][i] = 1;
                for (int j = 2; j < i; ++j) {
                    table[i][j] = j * table[i - 1][j] + table[i - 1][j - 1];
                }
            }

            // atomically save the cache; if the reference was set in the meantime,
            // the locally built table is still returned and used for this call
            STIRLING_S2.compareAndSet(null, table);
        }

        return table;
    }

    /**
     * Builds the exception reported when {@code S(n,k)} does not fit in a {@code long}.
     *
     * @param n the row that was requested
     * @param maxCachedRow the largest row index held in the cache, reported as the upper bound
     * @return the exception to throw
     */
    private static MathArithmeticException stirlingOverflow(final int n, final int maxCachedRow) {
        return new MathArithmeticException(
                LocalizedFormats.ARGUMENT_OUTSIDE_DOMAIN, n, 0, maxCachedRow);
    }

    /**
     * Returns an iterator whose range is the k-element subsets of {0, ..., n - 1} represented as
     * {@code int[]} arrays.
     *
     * <p>The arrays returned by the iterator are sorted in descending order and they are visited in
     * lexicographic order with significance from right to left. For example,
     * combinationsIterator(4, 2) returns an Iterator that will generate the following sequence of
     * arrays on successive calls to {@code next()}:
     *
     * <p>{@code [0, 1], [0, 2], [1, 2], [0, 3], [1, 3], [2, 3]}
     *
     * <p>If {@code k == 0} an Iterator containing an empty array is returned and if {@code k == n}
     * an Iterator containing [0, ..., n -1] is returned.
     *
     * @param n Size of the set from which subsets are selected.
     * @param k Size of the subsets to be enumerated.
     * @return an {@link Iterator iterator} over the k-sets in n.
     * @throws NotPositiveException if {@code n < 0}.
     * @throws NumberIsTooLargeException if {@code k > n}.
     */
    public static Iterator<int[]> combinationsIterator(int n, int k) {
        // NOTE(review): the example above lists each array in ascending order, which
        // contradicts "sorted in descending order" - confirm against Combinations.
        return new Combinations(n, k).iterator();
    }

    /**
     * Check binomial preconditions, i.e. {@code 0 <= k <= n}.
     *
     * @param n Size of the set.
     * @param k Size of the subsets to be counted.
     * @throws NotPositiveException if {@code n < 0} or {@code k < 0}.
     * @throws NumberIsTooLargeException if {@code k > n}.
     */
    public static void checkBinomial(final int n, final int k)
            throws NumberIsTooLargeException, NotPositiveException {
        if (n < k) {
            throw new NumberIsTooLargeException(
                    LocalizedFormats.BINOMIAL_INVALID_PARAMETERS_ORDER, k, n, true);
        }
        if (n < 0) {
            throw new NotPositiveException(LocalizedFormats.BINOMIAL_NEGATIVE_PARAMETER, n);
        }
        if (k < 0) {
            // Without this check a negative k with n >= 0 slips through and the
            // binomial methods return a meaningless value instead of failing.
            throw new NotPositiveException(k);
        }
    }
}