diff options
author | Karl Shaffer <karlshaffer@google.com> | 2023-08-10 22:35:48 +0000 |
---|---|---|
committer | Automerger Merge Worker <android-build-automerger-merge-worker@system.gserviceaccount.com> | 2023-08-10 22:35:48 +0000 |
commit | 5484895ffd3d0c8337d159667cafc127c459f677 (patch) | |
tree | ace24ba4307d4978ee3134f7da671a77ad172da0 /src/main/java/org/apache/commons/math3/ml/neuralnet/sofm | |
parent | bbf9548f049f99fd8e5a593baae983532dd983f4 (diff) | |
parent | b3715644fba79ef08acd9a2e157d078865281767 (diff) | |
download | apache-commons-math-5484895ffd3d0c8337d159667cafc127c459f677.tar.gz |
Check-in commons-math 3.6.1 am: 1354beaf45 am: 0018f64b87 am: b3715644fb
Original change: https://android-review.googlesource.com/c/platform/external/apache-commons-math/+/2702413
Change-Id: I5ad9b2a0822d668b5b6a62933c6d4c1f0b802001
Signed-off-by: Automerger Merge Worker <android-build-automerger-merge-worker@system.gserviceaccount.com>
Diffstat (limited to 'src/main/java/org/apache/commons/math3/ml/neuralnet/sofm')
10 files changed, 793 insertions, 0 deletions
diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/KohonenTrainingTask.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/KohonenTrainingTask.java new file mode 100644 index 0000000..9aa497d --- /dev/null +++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/KohonenTrainingTask.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.commons.math3.ml.neuralnet.sofm; + +import java.util.Iterator; +import org.apache.commons.math3.ml.neuralnet.Network; + +/** + * Trainer for Kohonen's Self-Organizing Map. + * + * @since 3.3 + */ +public class KohonenTrainingTask implements Runnable { + /** SOFM to be trained. */ + private final Network net; + /** Training data. */ + private final Iterator<double[]> featuresIterator; + /** Update procedure. */ + private final KohonenUpdateAction updateAction; + + /** + * Creates a (sequential) trainer for the given network. + * + * @param net Network to be trained with the SOFM algorithm. + * @param featuresIterator Training data iterator. + * @param updateAction SOFM update procedure. + */ + public KohonenTrainingTask(Network net, + Iterator<double[]> featuresIterator, + KohonenUpdateAction updateAction) { + this.net = net; + this.featuresIterator = featuresIterator; + this.updateAction = updateAction; + } + + /** + * {@inheritDoc} + */ + public void run() { + while (featuresIterator.hasNext()) { + updateAction.update(net, featuresIterator.next()); + } + } +} diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/KohonenUpdateAction.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/KohonenUpdateAction.java new file mode 100644 index 0000000..0618aeb --- /dev/null +++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/KohonenUpdateAction.java @@ -0,0 +1,225 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.commons.math3.ml.neuralnet.sofm; + +import java.util.Collection; +import java.util.HashSet; +import java.util.concurrent.atomic.AtomicLong; + +import org.apache.commons.math3.analysis.function.Gaussian; +import org.apache.commons.math3.linear.ArrayRealVector; +import org.apache.commons.math3.ml.distance.DistanceMeasure; +import org.apache.commons.math3.ml.neuralnet.MapUtils; +import org.apache.commons.math3.ml.neuralnet.Network; +import org.apache.commons.math3.ml.neuralnet.Neuron; +import org.apache.commons.math3.ml.neuralnet.UpdateAction; + +/** + * Update formula for <a href="http://en.wikipedia.org/wiki/Kohonen"> + * Kohonen's Self-Organizing Map</a>. + * <br/> + * The {@link #update(Network,double[]) update} method modifies the + * features {@code w} of the "winning" neuron and its neighbours + * according to the following rule: + * <code> + * w<sub>new</sub> = w<sub>old</sub> + α e<sup>(-d / σ)</sup> * (sample - w<sub>old</sub>) + * </code> + * where + * <ul> + * <li>α is the current <em>learning rate</em>, </li> + * <li>σ is the current <em>neighbourhood size</em>, and</li> + * <li>{@code d} is the number of links to traverse in order to reach + * the neuron from the winning neuron.</li> + * </ul> + * <br/> + * This class is thread-safe as long as the arguments passed to the + * {@link #KohonenUpdateAction(DistanceMeasure,LearningFactorFunction, + * NeighbourhoodSizeFunction) constructor} are instances of thread-safe + * classes. + * <br/> + * Each call to the {@link #update(Network,double[]) update} method + * will increment the internal counter used to compute the current + * values for + * <ul> + * <li>the <em>learning rate</em>, and</li> + * <li>the <em>neighbourhood size</em>.</li> + * </ul> + * Consequently, the function instances that compute those values (passed + * to the constructor of this class) must take into account whether this + * class's instance will be shared by multiple threads, as this will impact + * the training process. + * + * @since 3.3 + */ +public class KohonenUpdateAction implements UpdateAction { + /** Distance function. */ + private final DistanceMeasure distance; + /** Learning factor update function. */ + private final LearningFactorFunction learningFactor; + /** Neighbourhood size update function. */ + private final NeighbourhoodSizeFunction neighbourhoodSize; + /** Number of calls to {@link #update(Network,double[])}. */ + private final AtomicLong numberOfCalls = new AtomicLong(0); + + /** + * @param distance Distance function. + * @param learningFactor Learning factor update function. + * @param neighbourhoodSize Neighbourhood size update function. + */ + public KohonenUpdateAction(DistanceMeasure distance, + LearningFactorFunction learningFactor, + NeighbourhoodSizeFunction neighbourhoodSize) { + this.distance = distance; + this.learningFactor = learningFactor; + this.neighbourhoodSize = neighbourhoodSize; + } + + /** + * {@inheritDoc} + */ + public void update(Network net, + double[] features) { + final long numCalls = numberOfCalls.incrementAndGet() - 1; + final double currentLearning = learningFactor.value(numCalls); + final Neuron best = findAndUpdateBestNeuron(net, + features, + currentLearning); + + final int currentNeighbourhood = neighbourhoodSize.value(numCalls); + // The farther away the neighbour is from the winning neuron, the + // smaller the learning rate will become. + final Gaussian neighbourhoodDecay + = new Gaussian(currentLearning, + 0, + currentNeighbourhood); + + if (currentNeighbourhood > 0) { + // Initial set of neurons only contains the winning neuron. + Collection<Neuron> neighbours = new HashSet<Neuron>(); + neighbours.add(best); + // Winning neuron must be excluded from the neighbours. + final HashSet<Neuron> exclude = new HashSet<Neuron>(); + exclude.add(best); + + int radius = 1; + do { + // Retrieve immediate neighbours of the current set of neurons. + neighbours = net.getNeighbours(neighbours, exclude); + + // Update all the neighbours. + for (Neuron n : neighbours) { + updateNeighbouringNeuron(n, features, neighbourhoodDecay.value(radius)); + } + + // Add the neighbours to the exclude list so that they will + // not be update more than once per training step. + exclude.addAll(neighbours); + ++radius; + } while (radius <= currentNeighbourhood); + } + } + + /** + * Retrieves the number of calls to the {@link #update(Network,double[]) update} + * method. + * + * @return the current number of calls. + */ + public long getNumberOfCalls() { + return numberOfCalls.get(); + } + + /** + * Tries to update a neuron. + * + * @param n Neuron to be updated. + * @param features Training data. + * @param learningRate Learning factor. + * @return {@code true} if the update succeeded, {@code true} if a + * concurrent update has been detected. + */ + private boolean attemptNeuronUpdate(Neuron n, + double[] features, + double learningRate) { + final double[] expect = n.getFeatures(); + final double[] update = computeFeatures(expect, + features, + learningRate); + + return n.compareAndSetFeatures(expect, update); + } + + /** + * Atomically updates the given neuron. + * + * @param n Neuron to be updated. + * @param features Training data. + * @param learningRate Learning factor. + */ + private void updateNeighbouringNeuron(Neuron n, + double[] features, + double learningRate) { + while (true) { + if (attemptNeuronUpdate(n, features, learningRate)) { + break; + } + } + } + + /** + * Searches for the neuron whose features are closest to the given + * sample, and atomically updates its features. + * + * @param net Network. + * @param features Sample data. + * @param learningRate Current learning factor. + * @return the winning neuron. + */ + private Neuron findAndUpdateBestNeuron(Network net, + double[] features, + double learningRate) { + while (true) { + final Neuron best = MapUtils.findBest(features, net, distance); + + if (attemptNeuronUpdate(best, features, learningRate)) { + return best; + } + + // If another thread modified the state of the winning neuron, + // it may not be the best match anymore for the given training + // sample: Hence, the winner search is performed again. + } + } + + /** + * Computes the new value of the features set. + * + * @param current Current values of the features. + * @param sample Training data. + * @param learningRate Learning factor. + * @return the new values for the features. + */ + private double[] computeFeatures(double[] current, + double[] sample, + double learningRate) { + final ArrayRealVector c = new ArrayRealVector(current, false); + final ArrayRealVector s = new ArrayRealVector(sample, false); + // c + learningRate * (s - c) + return s.subtract(c).mapMultiplyToSelf(learningRate).add(c).toArray(); + } +} diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/LearningFactorFunction.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/LearningFactorFunction.java new file mode 100644 index 0000000..ba9d152 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/LearningFactorFunction.java @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.commons.math3.ml.neuralnet.sofm; + +/** + * Provides the learning rate as a function of the number of calls + * already performed during the learning task. + * + * @since 3.3 + */ +public interface LearningFactorFunction { + /** + * Computes the learning rate at the current call. + * + * @param numCall Current step of the training task. + * @return the value of the function at {@code numCall}. + */ + double value(long numCall); +} diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/LearningFactorFunctionFactory.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/LearningFactorFunctionFactory.java new file mode 100644 index 0000000..9165e82 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/LearningFactorFunctionFactory.java @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.commons.math3.ml.neuralnet.sofm; + +import org.apache.commons.math3.ml.neuralnet.sofm.util.ExponentialDecayFunction; +import org.apache.commons.math3.ml.neuralnet.sofm.util.QuasiSigmoidDecayFunction; +import org.apache.commons.math3.exception.OutOfRangeException; + +/** + * Factory for creating instances of {@link LearningFactorFunction}. + * + * @since 3.3 + */ +public class LearningFactorFunctionFactory { + /** Class contains only static methods. */ + private LearningFactorFunctionFactory() {} + + /** + * Creates an exponential decay {@link LearningFactorFunction function}. + * It will compute <code>a e<sup>-x / b</sup></code>, + * where {@code x} is the (integer) independent variable and + * <ul> + * <li><code>a = initValue</code> + * <li><code>b = -numCall / ln(valueAtNumCall / initValue)</code> + * </ul> + * + * @param initValue Initial value, i.e. + * {@link LearningFactorFunction#value(long) value(0)}. + * @param valueAtNumCall Value of the function at {@code numCall}. + * @param numCall Argument for which the function returns + * {@code valueAtNumCall}. + * @return the learning factor function. + * @throws org.apache.commons.math3.exception.OutOfRangeException + * if {@code initValue <= 0} or {@code initValue > 1}. + * @throws org.apache.commons.math3.exception.NotStrictlyPositiveException + * if {@code valueAtNumCall <= 0}. + * @throws org.apache.commons.math3.exception.NumberIsTooLargeException + * if {@code valueAtNumCall >= initValue}. + * @throws org.apache.commons.math3.exception.NotStrictlyPositiveException + * if {@code numCall <= 0}. + */ + public static LearningFactorFunction exponentialDecay(final double initValue, + final double valueAtNumCall, + final long numCall) { + if (initValue <= 0 || + initValue > 1) { + throw new OutOfRangeException(initValue, 0, 1); + } + + return new LearningFactorFunction() { + /** DecayFunction. */ + private final ExponentialDecayFunction decay + = new ExponentialDecayFunction(initValue, valueAtNumCall, numCall); + + /** {@inheritDoc} */ + public double value(long n) { + return decay.value(n); + } + }; + } + + /** + * Creates an sigmoid-like {@code LearningFactorFunction function}. + * The function {@code f} will have the following properties: + * <ul> + * <li>{@code f(0) = initValue}</li> + * <li>{@code numCall} is the inflexion point</li> + * <li>{@code slope = f'(numCall)}</li> + * </ul> + * + * @param initValue Initial value, i.e. + * {@link LearningFactorFunction#value(long) value(0)}. + * @param slope Value of the function derivative at {@code numCall}. + * @param numCall Inflexion point. + * @return the learning factor function. + * @throws org.apache.commons.math3.exception.OutOfRangeException + * if {@code initValue <= 0} or {@code initValue > 1}. + * @throws org.apache.commons.math3.exception.NumberIsTooLargeException + * if {@code slope >= 0}. + * @throws org.apache.commons.math3.exception.NotStrictlyPositiveException + * if {@code numCall <= 0}. + */ + public static LearningFactorFunction quasiSigmoidDecay(final double initValue, + final double slope, + final long numCall) { + if (initValue <= 0 || + initValue > 1) { + throw new OutOfRangeException(initValue, 0, 1); + } + + return new LearningFactorFunction() { + /** DecayFunction. */ + private final QuasiSigmoidDecayFunction decay + = new QuasiSigmoidDecayFunction(initValue, slope, numCall); + + /** {@inheritDoc} */ + public double value(long n) { + return decay.value(n); + } + }; + } +} diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/NeighbourhoodSizeFunction.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/NeighbourhoodSizeFunction.java new file mode 100644 index 0000000..68149f7 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/NeighbourhoodSizeFunction.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.commons.math3.ml.neuralnet.sofm; + +/** + * Provides the network neighbourhood's size as a function of the + * number of calls already performed during the learning task. + * The "neighbourhood" is the set of neurons that can be reached + * by traversing at most the number of links returned by this + * function. + * + * @since 3.3 + */ +public interface NeighbourhoodSizeFunction { + /** + * Computes the neighbourhood size at the current call. + * + * @param numCall Current step of the training task. + * @return the value of the function at {@code numCall}. + */ + int value(long numCall); +} diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/NeighbourhoodSizeFunctionFactory.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/NeighbourhoodSizeFunctionFactory.java new file mode 100644 index 0000000..bdbfa2f --- /dev/null +++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/NeighbourhoodSizeFunctionFactory.java @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.commons.math3.ml.neuralnet.sofm; + +import org.apache.commons.math3.ml.neuralnet.sofm.util.ExponentialDecayFunction; +import org.apache.commons.math3.ml.neuralnet.sofm.util.QuasiSigmoidDecayFunction; +import org.apache.commons.math3.util.FastMath; + +/** + * Factory for creating instances of {@link NeighbourhoodSizeFunction}. + * + * @since 3.3 + */ +public class NeighbourhoodSizeFunctionFactory { + /** Class contains only static methods. */ + private NeighbourhoodSizeFunctionFactory() {} + + /** + * Creates an exponential decay {@link NeighbourhoodSizeFunction function}. + * It will compute <code>a e<sup>-x / b</sup></code>, + * where {@code x} is the (integer) independent variable and + * <ul> + * <li><code>a = initValue</code> + * <li><code>b = -numCall / ln(valueAtNumCall / initValue)</code> + * </ul> + * + * @param initValue Initial value, i.e. + * {@link NeighbourhoodSizeFunction#value(long) value(0)}. + * @param valueAtNumCall Value of the function at {@code numCall}. + * @param numCall Argument for which the function returns + * {@code valueAtNumCall}. + * @return the neighbourhood size function. + * @throws org.apache.commons.math3.exception.NotStrictlyPositiveException + * if {@code initValue <= 0}. + * @throws org.apache.commons.math3.exception.NotStrictlyPositiveException + * if {@code valueAtNumCall <= 0}. + * @throws org.apache.commons.math3.exception.NumberIsTooLargeException + * if {@code valueAtNumCall >= initValue}. + * @throws org.apache.commons.math3.exception.NotStrictlyPositiveException + * if {@code numCall <= 0}. + */ + public static NeighbourhoodSizeFunction exponentialDecay(final double initValue, + final double valueAtNumCall, + final long numCall) { + return new NeighbourhoodSizeFunction() { + /** DecayFunction. */ + private final ExponentialDecayFunction decay + = new ExponentialDecayFunction(initValue, valueAtNumCall, numCall); + + /** {@inheritDoc} */ + public int value(long n) { + return (int) FastMath.rint(decay.value(n)); + } + }; + } + + /** + * Creates an sigmoid-like {@code NeighbourhoodSizeFunction function}. + * The function {@code f} will have the following properties: + * <ul> + * <li>{@code f(0) = initValue}</li> + * <li>{@code numCall} is the inflexion point</li> + * <li>{@code slope = f'(numCall)}</li> + * </ul> + * + * @param initValue Initial value, i.e. + * {@link NeighbourhoodSizeFunction#value(long) value(0)}. + * @param slope Value of the function derivative at {@code numCall}. + * @param numCall Inflexion point. + * @return the neighbourhood size function. + * @throws org.apache.commons.math3.exception.NotStrictlyPositiveException + * if {@code initValue <= 0}. + * @throws org.apache.commons.math3.exception.NumberIsTooLargeException + * if {@code slope >= 0}. + * @throws org.apache.commons.math3.exception.NotStrictlyPositiveException + * if {@code numCall <= 0}. + */ + public static NeighbourhoodSizeFunction quasiSigmoidDecay(final double initValue, + final double slope, + final long numCall) { + return new NeighbourhoodSizeFunction() { + /** DecayFunction. */ + private final QuasiSigmoidDecayFunction decay + = new QuasiSigmoidDecayFunction(initValue, slope, numCall); + + /** {@inheritDoc} */ + public int value(long n) { + return (int) FastMath.rint(decay.value(n)); + } + }; + } +} diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/package-info.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/package-info.java new file mode 100644 index 0000000..60c3c61 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Self Organizing Feature Map. + */ + +package org.apache.commons.math3.ml.neuralnet.sofm; diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/util/ExponentialDecayFunction.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/util/ExponentialDecayFunction.java new file mode 100644 index 0000000..19e7380 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/util/ExponentialDecayFunction.java @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.commons.math3.ml.neuralnet.sofm.util; + +import org.apache.commons.math3.exception.NotStrictlyPositiveException; +import org.apache.commons.math3.exception.NumberIsTooLargeException; +import org.apache.commons.math3.util.FastMath; + +/** + * Exponential decay function: <code>a e<sup>-x / b</sup></code>, + * where {@code x} is the (integer) independent variable. + * <br/> + * Class is immutable. + * + * @since 3.3 + */ +public class ExponentialDecayFunction { + /** Factor {@code a}. */ + private final double a; + /** Factor {@code 1 / b}. */ + private final double oneOverB; + + /** + * Creates an instance. It will be such that + * <ul> + * <li>{@code a = initValue}</li> + * <li>{@code b = -numCall / ln(valueAtNumCall / initValue)}</li> + * </ul> + * + * @param initValue Initial value, i.e. {@link #value(long) value(0)}. + * @param valueAtNumCall Value of the function at {@code numCall}. + * @param numCall Argument for which the function returns + * {@code valueAtNumCall}. + * @throws NotStrictlyPositiveException if {@code initValue <= 0}. + * @throws NotStrictlyPositiveException if {@code valueAtNumCall <= 0}. + * @throws NumberIsTooLargeException if {@code valueAtNumCall >= initValue}. + * @throws NotStrictlyPositiveException if {@code numCall <= 0}. + */ + public ExponentialDecayFunction(double initValue, + double valueAtNumCall, + long numCall) { + if (initValue <= 0) { + throw new NotStrictlyPositiveException(initValue); + } + if (valueAtNumCall <= 0) { + throw new NotStrictlyPositiveException(valueAtNumCall); + } + if (valueAtNumCall >= initValue) { + throw new NumberIsTooLargeException(valueAtNumCall, initValue, false); + } + if (numCall <= 0) { + throw new NotStrictlyPositiveException(numCall); + } + + a = initValue; + oneOverB = -FastMath.log(valueAtNumCall / initValue) / numCall; + } + + /** + * Computes <code>a e<sup>-numCall / b</sup></code>. + * + * @param numCall Current step of the training task. + * @return the value of the function at {@code numCall}. + */ + public double value(long numCall) { + return a * FastMath.exp(-numCall * oneOverB); + } +} diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/util/QuasiSigmoidDecayFunction.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/util/QuasiSigmoidDecayFunction.java new file mode 100644 index 0000000..3d35c17 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/util/QuasiSigmoidDecayFunction.java @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.commons.math3.ml.neuralnet.sofm.util; + +import org.apache.commons.math3.exception.NotStrictlyPositiveException; +import org.apache.commons.math3.exception.NumberIsTooLargeException; +import org.apache.commons.math3.analysis.function.Logistic; + +/** + * Decay function whose shape is similar to a sigmoid. + * <br/> + * Class is immutable. + * + * @since 3.3 + */ +public class QuasiSigmoidDecayFunction { + /** Sigmoid. */ + private final Logistic sigmoid; + /** See {@link #value(long)}. */ + private final double scale; + + /** + * Creates an instance. + * The function {@code f} will have the following properties: + * <ul> + * <li>{@code f(0) = initValue}</li> + * <li>{@code numCall} is the inflexion point</li> + * <li>{@code slope = f'(numCall)}</li> + * </ul> + * + * @param initValue Initial value, i.e. {@link #value(long) value(0)}. + * @param slope Value of the function derivative at {@code numCall}. + * @param numCall Inflexion point. + * @throws NotStrictlyPositiveException if {@code initValue <= 0}. + * @throws NumberIsTooLargeException if {@code slope >= 0}. + * @throws NotStrictlyPositiveException if {@code numCall <= 0}. + */ + public QuasiSigmoidDecayFunction(double initValue, + double slope, + long numCall) { + if (initValue <= 0) { + throw new NotStrictlyPositiveException(initValue); + } + if (slope >= 0) { + throw new NumberIsTooLargeException(slope, 0, false); + } + if (numCall <= 1) { + throw new NotStrictlyPositiveException(numCall); + } + + final double k = initValue; + final double m = numCall; + final double b = 4 * slope / initValue; + final double q = 1; + final double a = 0; + final double n = 1; + sigmoid = new Logistic(k, m, b, q, a, n); + + final double y0 = sigmoid.value(0); + scale = k / y0; + } + + /** + * Computes the value of the learning factor. + * + * @param numCall Current step of the training task. + * @return the value of the function at {@code numCall}. + */ + public double value(long numCall) { + return scale * sigmoid.value(numCall); + } +} diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/util/package-info.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/util/package-info.java new file mode 100644 index 0000000..5078ed2 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/util/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Miscellaneous utilities. + */ + +package org.apache.commons.math3.ml.neuralnet.sofm.util; |