diff options
Diffstat (limited to 'src/main/java/org/apache/commons/math3/optimization/BaseMultivariateVectorMultiStartOptimizer.java')
-rw-r--r-- | src/main/java/org/apache/commons/math3/optimization/BaseMultivariateVectorMultiStartOptimizer.java | 207 |
1 files changed, 207 insertions, 0 deletions
diff --git a/src/main/java/org/apache/commons/math3/optimization/BaseMultivariateVectorMultiStartOptimizer.java b/src/main/java/org/apache/commons/math3/optimization/BaseMultivariateVectorMultiStartOptimizer.java new file mode 100644 index 0000000..7b6523a --- /dev/null +++ b/src/main/java/org/apache/commons/math3/optimization/BaseMultivariateVectorMultiStartOptimizer.java @@ -0,0 +1,207 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.commons.math3.optimization; + +import org.apache.commons.math3.analysis.MultivariateVectorFunction; +import org.apache.commons.math3.exception.ConvergenceException; +import org.apache.commons.math3.exception.MathIllegalStateException; +import org.apache.commons.math3.exception.NotStrictlyPositiveException; +import org.apache.commons.math3.exception.NullArgumentException; +import org.apache.commons.math3.exception.util.LocalizedFormats; +import org.apache.commons.math3.random.RandomVectorGenerator; + +import java.util.Arrays; +import java.util.Comparator; + +/** + * Base class for all implementations of a multi-start optimizer. + * + * <p>This interface is mainly intended to enforce the internal coherence of Commons-Math. Users of + * the API are advised to base their code on {@link + * DifferentiableMultivariateVectorMultiStartOptimizer}. + * + * @param <FUNC> Type of the objective function to be optimized. + * @deprecated As of 3.1 (to be removed in 4.0). + * @since 3.0 + */ +@Deprecated +public class BaseMultivariateVectorMultiStartOptimizer<FUNC extends MultivariateVectorFunction> + implements BaseMultivariateVectorOptimizer<FUNC> { + /** Underlying classical optimizer. */ + private final BaseMultivariateVectorOptimizer<FUNC> optimizer; + + /** Maximal number of evaluations allowed. */ + private int maxEvaluations; + + /** Number of evaluations already performed for all starts. */ + private int totalEvaluations; + + /** Number of starts to go. */ + private int starts; + + /** Random generator for multi-start. */ + private RandomVectorGenerator generator; + + /** Found optima. */ + private PointVectorValuePair[] optima; + + /** + * Create a multi-start optimizer from a single-start optimizer. + * + * @param optimizer Single-start optimizer to wrap. + * @param starts Number of starts to perform. If {@code starts == 1}, the {@link + * #optimize(int,MultivariateVectorFunction,double[],double[],double[]) optimize} will + * return the same solution as {@code optimizer} would. + * @param generator Random vector generator to use for restarts. + * @throws NullArgumentException if {@code optimizer} or {@code generator} is {@code null}. + * @throws NotStrictlyPositiveException if {@code starts < 1}. + */ + protected BaseMultivariateVectorMultiStartOptimizer( + final BaseMultivariateVectorOptimizer<FUNC> optimizer, + final int starts, + final RandomVectorGenerator generator) { + if (optimizer == null || generator == null) { + throw new NullArgumentException(); + } + if (starts < 1) { + throw new NotStrictlyPositiveException(starts); + } + + this.optimizer = optimizer; + this.starts = starts; + this.generator = generator; + } + + /** + * Get all the optima found during the last call to {@link + * #optimize(int,MultivariateVectorFunction,double[],double[],double[]) optimize}. The optimizer + * stores all the optima found during a set of restarts. The {@link + * #optimize(int,MultivariateVectorFunction,double[],double[],double[]) optimize} method returns + * the best point only. This method returns all the points found at the end of each starts, + * including the best one already returned by the {@link + * #optimize(int,MultivariateVectorFunction,double[],double[],double[]) optimize} method. <br> + * The returned array as one element for each start as specified in the constructor. It is + * ordered with the results from the runs that did converge first, sorted from best to worst + * objective value (i.e. in ascending order if minimizing and in descending order if + * maximizing), followed by and null elements corresponding to the runs that did not converge. + * This means all elements will be null if the {@link + * #optimize(int,MultivariateVectorFunction,double[],double[],double[]) optimize} method did + * throw a {@link ConvergenceException}). This also means that if the first element is not + * {@code null}, it is the best point found across all starts. + * + * @return array containing the optima + * @throws MathIllegalStateException if {@link + * #optimize(int,MultivariateVectorFunction,double[],double[],double[]) optimize} has not + * been called. + */ + public PointVectorValuePair[] getOptima() { + if (optima == null) { + throw new MathIllegalStateException(LocalizedFormats.NO_OPTIMUM_COMPUTED_YET); + } + return optima.clone(); + } + + /** {@inheritDoc} */ + public int getMaxEvaluations() { + return maxEvaluations; + } + + /** {@inheritDoc} */ + public int getEvaluations() { + return totalEvaluations; + } + + /** {@inheritDoc} */ + public ConvergenceChecker<PointVectorValuePair> getConvergenceChecker() { + return optimizer.getConvergenceChecker(); + } + + /** {@inheritDoc} */ + public PointVectorValuePair optimize( + int maxEval, final FUNC f, double[] target, double[] weights, double[] startPoint) { + maxEvaluations = maxEval; + RuntimeException lastException = null; + optima = new PointVectorValuePair[starts]; + totalEvaluations = 0; + + // Multi-start loop. + for (int i = 0; i < starts; ++i) { + + // CHECKSTYLE: stop IllegalCatch + try { + optima[i] = + optimizer.optimize( + maxEval - totalEvaluations, + f, + target, + weights, + i == 0 ? startPoint : generator.nextVector()); + } catch (ConvergenceException oe) { + optima[i] = null; + } catch (RuntimeException mue) { + lastException = mue; + optima[i] = null; + } + // CHECKSTYLE: resume IllegalCatch + + totalEvaluations += optimizer.getEvaluations(); + } + + sortPairs(target, weights); + + if (optima[0] == null) { + throw lastException; // cannot be null if starts >=1 + } + + // Return the found point given the best objective function value. + return optima[0]; + } + + /** + * Sort the optima from best to worst, followed by {@code null} elements. + * + * @param target Target value for the objective functions at optimum. + * @param weights Weights for the least-squares cost computation. + */ + private void sortPairs(final double[] target, final double[] weights) { + Arrays.sort( + optima, + new Comparator<PointVectorValuePair>() { + /** {@inheritDoc} */ + public int compare( + final PointVectorValuePair o1, final PointVectorValuePair o2) { + if (o1 == null) { + return (o2 == null) ? 0 : 1; + } else if (o2 == null) { + return -1; + } + return Double.compare(weightedResidual(o1), weightedResidual(o2)); + } + + private double weightedResidual(final PointVectorValuePair pv) { + final double[] value = pv.getValueRef(); + double sum = 0; + for (int i = 0; i < value.length; ++i) { + final double ri = value[i] - target[i]; + sum += weights[i] * ri * ri; + } + return sum; + } + }); + } +} |