diff options
Diffstat (limited to 'src/main/java/org/apache/commons/math3/analysis/FunctionUtils.java')
-rw-r--r-- | src/main/java/org/apache/commons/math3/analysis/FunctionUtils.java | 823 |
1 files changed, 823 insertions, 0 deletions
diff --git a/src/main/java/org/apache/commons/math3/analysis/FunctionUtils.java b/src/main/java/org/apache/commons/math3/analysis/FunctionUtils.java new file mode 100644 index 0000000..b3cdb9e --- /dev/null +++ b/src/main/java/org/apache/commons/math3/analysis/FunctionUtils.java @@ -0,0 +1,823 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.commons.math3.analysis; + +import org.apache.commons.math3.analysis.differentiation.DerivativeStructure; +import org.apache.commons.math3.analysis.differentiation.MultivariateDifferentiableFunction; +import org.apache.commons.math3.analysis.differentiation.MultivariateDifferentiableVectorFunction; +import org.apache.commons.math3.analysis.differentiation.UnivariateDifferentiableFunction; +import org.apache.commons.math3.analysis.function.Identity; +import org.apache.commons.math3.exception.DimensionMismatchException; +import org.apache.commons.math3.exception.NotStrictlyPositiveException; +import org.apache.commons.math3.exception.NumberIsTooLargeException; +import org.apache.commons.math3.exception.util.LocalizedFormats; + +/** + * Utilities for manipulating function objects. + * + * @since 3.0 + */ +public class FunctionUtils { + /** Class only contains static methods. */ + private FunctionUtils() {} + + /** + * Composes functions. + * + * <p>The functions in the argument list are composed sequentially, in the given order. For + * example, compose(f1,f2,f3) acts like f1(f2(f3(x))). + * + * @param f List of functions. + * @return the composite function. + */ + public static UnivariateFunction compose(final UnivariateFunction... f) { + return new UnivariateFunction() { + /** {@inheritDoc} */ + public double value(double x) { + double r = x; + for (int i = f.length - 1; i >= 0; i--) { + r = f[i].value(r); + } + return r; + } + }; + } + + /** + * Composes functions. + * + * <p>The functions in the argument list are composed sequentially, in the given order. For + * example, compose(f1,f2,f3) acts like f1(f2(f3(x))). + * + * @param f List of functions. + * @return the composite function. + * @since 3.1 + */ + public static UnivariateDifferentiableFunction compose( + final UnivariateDifferentiableFunction... f) { + return new UnivariateDifferentiableFunction() { + + /** {@inheritDoc} */ + public double value(final double t) { + double r = t; + for (int i = f.length - 1; i >= 0; i--) { + r = f[i].value(r); + } + return r; + } + + /** {@inheritDoc} */ + public DerivativeStructure value(final DerivativeStructure t) { + DerivativeStructure r = t; + for (int i = f.length - 1; i >= 0; i--) { + r = f[i].value(r); + } + return r; + } + }; + } + + /** + * Composes functions. + * + * <p>The functions in the argument list are composed sequentially, in the given order. For + * example, compose(f1,f2,f3) acts like f1(f2(f3(x))). + * + * @param f List of functions. + * @return the composite function. + * @deprecated as of 3.1 replaced by {@link #compose(UnivariateDifferentiableFunction...)} + */ + @Deprecated + public static DifferentiableUnivariateFunction compose( + final DifferentiableUnivariateFunction... f) { + return new DifferentiableUnivariateFunction() { + /** {@inheritDoc} */ + public double value(double x) { + double r = x; + for (int i = f.length - 1; i >= 0; i--) { + r = f[i].value(r); + } + return r; + } + + /** {@inheritDoc} */ + public UnivariateFunction derivative() { + return new UnivariateFunction() { + /** {@inheritDoc} */ + public double value(double x) { + double p = 1; + double r = x; + for (int i = f.length - 1; i >= 0; i--) { + p *= f[i].derivative().value(r); + r = f[i].value(r); + } + return p; + } + }; + } + }; + } + + /** + * Adds functions. + * + * @param f List of functions. + * @return a function that computes the sum of the functions. + */ + public static UnivariateFunction add(final UnivariateFunction... f) { + return new UnivariateFunction() { + /** {@inheritDoc} */ + public double value(double x) { + double r = f[0].value(x); + for (int i = 1; i < f.length; i++) { + r += f[i].value(x); + } + return r; + } + }; + } + + /** + * Adds functions. + * + * @param f List of functions. + * @return a function that computes the sum of the functions. + * @since 3.1 + */ + public static UnivariateDifferentiableFunction add( + final UnivariateDifferentiableFunction... f) { + return new UnivariateDifferentiableFunction() { + + /** {@inheritDoc} */ + public double value(final double t) { + double r = f[0].value(t); + for (int i = 1; i < f.length; i++) { + r += f[i].value(t); + } + return r; + } + + /** + * {@inheritDoc} + * + * @throws DimensionMismatchException if functions are not consistent with each other + */ + public DerivativeStructure value(final DerivativeStructure t) + throws DimensionMismatchException { + DerivativeStructure r = f[0].value(t); + for (int i = 1; i < f.length; i++) { + r = r.add(f[i].value(t)); + } + return r; + } + }; + } + + /** + * Adds functions. + * + * @param f List of functions. + * @return a function that computes the sum of the functions. + * @deprecated as of 3.1 replaced by {@link #add(UnivariateDifferentiableFunction...)} + */ + @Deprecated + public static DifferentiableUnivariateFunction add( + final DifferentiableUnivariateFunction... f) { + return new DifferentiableUnivariateFunction() { + /** {@inheritDoc} */ + public double value(double x) { + double r = f[0].value(x); + for (int i = 1; i < f.length; i++) { + r += f[i].value(x); + } + return r; + } + + /** {@inheritDoc} */ + public UnivariateFunction derivative() { + return new UnivariateFunction() { + /** {@inheritDoc} */ + public double value(double x) { + double r = f[0].derivative().value(x); + for (int i = 1; i < f.length; i++) { + r += f[i].derivative().value(x); + } + return r; + } + }; + } + }; + } + + /** + * Multiplies functions. + * + * @param f List of functions. + * @return a function that computes the product of the functions. + */ + public static UnivariateFunction multiply(final UnivariateFunction... f) { + return new UnivariateFunction() { + /** {@inheritDoc} */ + public double value(double x) { + double r = f[0].value(x); + for (int i = 1; i < f.length; i++) { + r *= f[i].value(x); + } + return r; + } + }; + } + + /** + * Multiplies functions. + * + * @param f List of functions. + * @return a function that computes the product of the functions. + * @since 3.1 + */ + public static UnivariateDifferentiableFunction multiply( + final UnivariateDifferentiableFunction... f) { + return new UnivariateDifferentiableFunction() { + + /** {@inheritDoc} */ + public double value(final double t) { + double r = f[0].value(t); + for (int i = 1; i < f.length; i++) { + r *= f[i].value(t); + } + return r; + } + + /** {@inheritDoc} */ + public DerivativeStructure value(final DerivativeStructure t) { + DerivativeStructure r = f[0].value(t); + for (int i = 1; i < f.length; i++) { + r = r.multiply(f[i].value(t)); + } + return r; + } + }; + } + + /** + * Multiplies functions. + * + * @param f List of functions. + * @return a function that computes the product of the functions. + * @deprecated as of 3.1 replaced by {@link #multiply(UnivariateDifferentiableFunction...)} + */ + @Deprecated + public static DifferentiableUnivariateFunction multiply( + final DifferentiableUnivariateFunction... f) { + return new DifferentiableUnivariateFunction() { + /** {@inheritDoc} */ + public double value(double x) { + double r = f[0].value(x); + for (int i = 1; i < f.length; i++) { + r *= f[i].value(x); + } + return r; + } + + /** {@inheritDoc} */ + public UnivariateFunction derivative() { + return new UnivariateFunction() { + /** {@inheritDoc} */ + public double value(double x) { + double sum = 0; + for (int i = 0; i < f.length; i++) { + double prod = f[i].derivative().value(x); + for (int j = 0; j < f.length; j++) { + if (i != j) { + prod *= f[j].value(x); + } + } + sum += prod; + } + return sum; + } + }; + } + }; + } + + /** + * Returns the univariate function {@code h(x) = combiner(f(x), g(x)).} + * + * @param combiner Combiner function. + * @param f Function. + * @param g Function. + * @return the composite function. + */ + public static UnivariateFunction combine( + final BivariateFunction combiner, + final UnivariateFunction f, + final UnivariateFunction g) { + return new UnivariateFunction() { + /** {@inheritDoc} */ + public double value(double x) { + return combiner.value(f.value(x), g.value(x)); + } + }; + } + + /** + * Returns a MultivariateFunction h(x[]) defined by + * + * <pre> <code> + * h(x[]) = combiner(...combiner(combiner(initialValue,f(x[0])),f(x[1]))...),f(x[x.length-1])) + * </code></pre> + * + * @param combiner Combiner function. + * @param f Function. + * @param initialValue Initial value. + * @return a collector function. + */ + public static MultivariateFunction collector( + final BivariateFunction combiner, + final UnivariateFunction f, + final double initialValue) { + return new MultivariateFunction() { + /** {@inheritDoc} */ + public double value(double[] point) { + double result = combiner.value(initialValue, f.value(point[0])); + for (int i = 1; i < point.length; i++) { + result = combiner.value(result, f.value(point[i])); + } + return result; + } + }; + } + + /** + * Returns a MultivariateFunction h(x[]) defined by + * + * <pre> <code> + * h(x[]) = combiner(...combiner(combiner(initialValue,x[0]),x[1])...),x[x.length-1]) + * </code></pre> + * + * @param combiner Combiner function. + * @param initialValue Initial value. + * @return a collector function. + */ + public static MultivariateFunction collector( + final BivariateFunction combiner, final double initialValue) { + return collector(combiner, new Identity(), initialValue); + } + + /** + * Creates a unary function by fixing the first argument of a binary function. + * + * @param f Binary function. + * @param fixed value to which the first argument of {@code f} is set. + * @return the unary function h(x) = f(fixed, x) + */ + public static UnivariateFunction fix1stArgument(final BivariateFunction f, final double fixed) { + return new UnivariateFunction() { + /** {@inheritDoc} */ + public double value(double x) { + return f.value(fixed, x); + } + }; + } + + /** + * Creates a unary function by fixing the second argument of a binary function. + * + * @param f Binary function. + * @param fixed value to which the second argument of {@code f} is set. + * @return the unary function h(x) = f(x, fixed) + */ + public static UnivariateFunction fix2ndArgument(final BivariateFunction f, final double fixed) { + return new UnivariateFunction() { + /** {@inheritDoc} */ + public double value(double x) { + return f.value(x, fixed); + } + }; + } + + /** + * Samples the specified univariate real function on the specified interval. + * + * <p>The interval is divided equally into {@code n} sections and sample points are taken from + * {@code min} to {@code max - (max - min) / n}; therefore {@code f} is not sampled at the upper + * bound {@code max}. + * + * @param f Function to be sampled + * @param min Lower bound of the interval (included). + * @param max Upper bound of the interval (excluded). + * @param n Number of sample points. + * @return the array of samples. + * @throws NumberIsTooLargeException if the lower bound {@code min} is greater than, or equal to + * the upper bound {@code max}. + * @throws NotStrictlyPositiveException if the number of sample points {@code n} is negative. + */ + public static double[] sample(UnivariateFunction f, double min, double max, int n) + throws NumberIsTooLargeException, NotStrictlyPositiveException { + + if (n <= 0) { + throw new NotStrictlyPositiveException( + LocalizedFormats.NOT_POSITIVE_NUMBER_OF_SAMPLES, Integer.valueOf(n)); + } + if (min >= max) { + throw new NumberIsTooLargeException(min, max, false); + } + + final double[] s = new double[n]; + final double h = (max - min) / n; + for (int i = 0; i < n; i++) { + s[i] = f.value(min + i * h); + } + return s; + } + + /** + * Convert a {@link UnivariateDifferentiableFunction} into a {@link + * DifferentiableUnivariateFunction}. + * + * @param f function to convert + * @return converted function + * @deprecated this conversion method is temporary in version 3.1, as the {@link + * DifferentiableUnivariateFunction} interface itself is deprecated + */ + @Deprecated + public static DifferentiableUnivariateFunction toDifferentiableUnivariateFunction( + final UnivariateDifferentiableFunction f) { + return new DifferentiableUnivariateFunction() { + + /** {@inheritDoc} */ + public double value(final double x) { + return f.value(x); + } + + /** {@inheritDoc} */ + public UnivariateFunction derivative() { + return new UnivariateFunction() { + /** {@inheritDoc} */ + public double value(final double x) { + return f.value(new DerivativeStructure(1, 1, 0, x)).getPartialDerivative(1); + } + }; + } + }; + } + + /** + * Convert a {@link DifferentiableUnivariateFunction} into a {@link + * UnivariateDifferentiableFunction}. + * + * <p>Note that the converted function is able to handle {@link DerivativeStructure} up to order + * one. If the function is called with higher order, a {@link NumberIsTooLargeException} is + * thrown. + * + * @param f function to convert + * @return converted function + * @deprecated this conversion method is temporary in version 3.1, as the {@link + * DifferentiableUnivariateFunction} interface itself is deprecated + */ + @Deprecated + public static UnivariateDifferentiableFunction toUnivariateDifferential( + final DifferentiableUnivariateFunction f) { + return new UnivariateDifferentiableFunction() { + + /** {@inheritDoc} */ + public double value(final double x) { + return f.value(x); + } + + /** + * {@inheritDoc} + * + * @exception NumberIsTooLargeException if derivation order is greater than 1 + */ + public DerivativeStructure value(final DerivativeStructure t) + throws NumberIsTooLargeException { + switch (t.getOrder()) { + case 0: + return new DerivativeStructure( + t.getFreeParameters(), 0, f.value(t.getValue())); + case 1: + { + final int parameters = t.getFreeParameters(); + final double[] derivatives = new double[parameters + 1]; + derivatives[0] = f.value(t.getValue()); + final double fPrime = f.derivative().value(t.getValue()); + int[] orders = new int[parameters]; + for (int i = 0; i < parameters; ++i) { + orders[i] = 1; + derivatives[i + 1] = fPrime * t.getPartialDerivative(orders); + orders[i] = 0; + } + return new DerivativeStructure(parameters, 1, derivatives); + } + default: + throw new NumberIsTooLargeException(t.getOrder(), 1, true); + } + } + }; + } + + /** + * Convert a {@link MultivariateDifferentiableFunction} into a {@link + * DifferentiableMultivariateFunction}. + * + * @param f function to convert + * @return converted function + * @deprecated this conversion method is temporary in version 3.1, as the {@link + * DifferentiableMultivariateFunction} interface itself is deprecated + */ + @Deprecated + public static DifferentiableMultivariateFunction toDifferentiableMultivariateFunction( + final MultivariateDifferentiableFunction f) { + return new DifferentiableMultivariateFunction() { + + /** {@inheritDoc} */ + public double value(final double[] x) { + return f.value(x); + } + + /** {@inheritDoc} */ + public MultivariateFunction partialDerivative(final int k) { + return new MultivariateFunction() { + /** {@inheritDoc} */ + public double value(final double[] x) { + + final int n = x.length; + + // delegate computation to underlying function + final DerivativeStructure[] dsX = new DerivativeStructure[n]; + for (int i = 0; i < n; ++i) { + if (i == k) { + dsX[i] = new DerivativeStructure(1, 1, 0, x[i]); + } else { + dsX[i] = new DerivativeStructure(1, 1, x[i]); + } + } + final DerivativeStructure y = f.value(dsX); + + // extract partial derivative + return y.getPartialDerivative(1); + } + }; + } + + /** {@inheritDoc} */ + public MultivariateVectorFunction gradient() { + return new MultivariateVectorFunction() { + /** {@inheritDoc} */ + public double[] value(final double[] x) { + + final int n = x.length; + + // delegate computation to underlying function + final DerivativeStructure[] dsX = new DerivativeStructure[n]; + for (int i = 0; i < n; ++i) { + dsX[i] = new DerivativeStructure(n, 1, i, x[i]); + } + final DerivativeStructure y = f.value(dsX); + + // extract gradient + final double[] gradient = new double[n]; + final int[] orders = new int[n]; + for (int i = 0; i < n; ++i) { + orders[i] = 1; + gradient[i] = y.getPartialDerivative(orders); + orders[i] = 0; + } + + return gradient; + } + }; + } + }; + } + + /** + * Convert a {@link DifferentiableMultivariateFunction} into a {@link + * MultivariateDifferentiableFunction}. + * + * <p>Note that the converted function is able to handle {@link DerivativeStructure} elements + * that all have the same number of free parameters and order, and with order at most 1. If the + * function is called with inconsistent numbers of free parameters or higher order, a {@link + * DimensionMismatchException} or a {@link NumberIsTooLargeException} will be thrown. + * + * @param f function to convert + * @return converted function + * @deprecated this conversion method is temporary in version 3.1, as the {@link + * DifferentiableMultivariateFunction} interface itself is deprecated + */ + @Deprecated + public static MultivariateDifferentiableFunction toMultivariateDifferentiableFunction( + final DifferentiableMultivariateFunction f) { + return new MultivariateDifferentiableFunction() { + + /** {@inheritDoc} */ + public double value(final double[] x) { + return f.value(x); + } + + /** + * {@inheritDoc} + * + * @exception NumberIsTooLargeException if derivation order is higher than 1 + * @exception DimensionMismatchException if numbers of free parameters are inconsistent + */ + public DerivativeStructure value(final DerivativeStructure[] t) + throws DimensionMismatchException, NumberIsTooLargeException { + + // check parameters and orders limits + final int parameters = t[0].getFreeParameters(); + final int order = t[0].getOrder(); + final int n = t.length; + if (order > 1) { + throw new NumberIsTooLargeException(order, 1, true); + } + + // check all elements in the array are consistent + for (int i = 0; i < n; ++i) { + if (t[i].getFreeParameters() != parameters) { + throw new DimensionMismatchException(t[i].getFreeParameters(), parameters); + } + + if (t[i].getOrder() != order) { + throw new DimensionMismatchException(t[i].getOrder(), order); + } + } + + // delegate computation to underlying function + final double[] point = new double[n]; + for (int i = 0; i < n; ++i) { + point[i] = t[i].getValue(); + } + final double value = f.value(point); + final double[] gradient = f.gradient().value(point); + + // merge value and gradient into one DerivativeStructure + final double[] derivatives = new double[parameters + 1]; + derivatives[0] = value; + final int[] orders = new int[parameters]; + for (int i = 0; i < parameters; ++i) { + orders[i] = 1; + for (int j = 0; j < n; ++j) { + derivatives[i + 1] += gradient[j] * t[j].getPartialDerivative(orders); + } + orders[i] = 0; + } + + return new DerivativeStructure(parameters, order, derivatives); + } + }; + } + + /** + * Convert a {@link MultivariateDifferentiableVectorFunction} into a {@link + * DifferentiableMultivariateVectorFunction}. + * + * @param f function to convert + * @return converted function + * @deprecated this conversion method is temporary in version 3.1, as the {@link + * DifferentiableMultivariateVectorFunction} interface itself is deprecated + */ + @Deprecated + public static DifferentiableMultivariateVectorFunction + toDifferentiableMultivariateVectorFunction( + final MultivariateDifferentiableVectorFunction f) { + return new DifferentiableMultivariateVectorFunction() { + + /** {@inheritDoc} */ + public double[] value(final double[] x) { + return f.value(x); + } + + /** {@inheritDoc} */ + public MultivariateMatrixFunction jacobian() { + return new MultivariateMatrixFunction() { + /** {@inheritDoc} */ + public double[][] value(final double[] x) { + + final int n = x.length; + + // delegate computation to underlying function + final DerivativeStructure[] dsX = new DerivativeStructure[n]; + for (int i = 0; i < n; ++i) { + dsX[i] = new DerivativeStructure(n, 1, i, x[i]); + } + final DerivativeStructure[] y = f.value(dsX); + + // extract Jacobian + final double[][] jacobian = new double[y.length][n]; + final int[] orders = new int[n]; + for (int i = 0; i < y.length; ++i) { + for (int j = 0; j < n; ++j) { + orders[j] = 1; + jacobian[i][j] = y[i].getPartialDerivative(orders); + orders[j] = 0; + } + } + + return jacobian; + } + }; + } + }; + } + + /** + * Convert a {@link DifferentiableMultivariateVectorFunction} into a {@link + * MultivariateDifferentiableVectorFunction}. + * + * <p>Note that the converted function is able to handle {@link DerivativeStructure} elements + * that all have the same number of free parameters and order, and with order at most 1. If the + * function is called with inconsistent numbers of free parameters or higher order, a {@link + * DimensionMismatchException} or a {@link NumberIsTooLargeException} will be thrown. + * + * @param f function to convert + * @return converted function + * @deprecated this conversion method is temporary in version 3.1, as the {@link + * DifferentiableMultivariateFunction} interface itself is deprecated + */ + @Deprecated + public static MultivariateDifferentiableVectorFunction + toMultivariateDifferentiableVectorFunction( + final DifferentiableMultivariateVectorFunction f) { + return new MultivariateDifferentiableVectorFunction() { + + /** {@inheritDoc} */ + public double[] value(final double[] x) { + return f.value(x); + } + + /** + * {@inheritDoc} + * + * @exception NumberIsTooLargeException if derivation order is higher than 1 + * @exception DimensionMismatchException if numbers of free parameters are inconsistent + */ + public DerivativeStructure[] value(final DerivativeStructure[] t) + throws DimensionMismatchException, NumberIsTooLargeException { + + // check parameters and orders limits + final int parameters = t[0].getFreeParameters(); + final int order = t[0].getOrder(); + final int n = t.length; + if (order > 1) { + throw new NumberIsTooLargeException(order, 1, true); + } + + // check all elements in the array are consistent + for (int i = 0; i < n; ++i) { + if (t[i].getFreeParameters() != parameters) { + throw new DimensionMismatchException(t[i].getFreeParameters(), parameters); + } + + if (t[i].getOrder() != order) { + throw new DimensionMismatchException(t[i].getOrder(), order); + } + } + + // delegate computation to underlying function + final double[] point = new double[n]; + for (int i = 0; i < n; ++i) { + point[i] = t[i].getValue(); + } + final double[] value = f.value(point); + final double[][] jacobian = f.jacobian().value(point); + + // merge value and Jacobian into a DerivativeStructure array + final DerivativeStructure[] merged = new DerivativeStructure[value.length]; + for (int k = 0; k < merged.length; ++k) { + final double[] derivatives = new double[parameters + 1]; + derivatives[0] = value[k]; + final int[] orders = new int[parameters]; + for (int i = 0; i < parameters; ++i) { + orders[i] = 1; + for (int j = 0; j < n; ++j) { + derivatives[i + 1] += + jacobian[k][j] * t[j].getPartialDerivative(orders); + } + orders[i] = 0; + } + merged[k] = new DerivativeStructure(parameters, order, derivatives); + } + + return merged; + } + }; + } +} |