summaryrefslogtreecommitdiff
path: root/src/main/java/org/apache/commons/math3/ml
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/java/org/apache/commons/math3/ml')
-rw-r--r--src/main/java/org/apache/commons/math3/ml/clustering/CentroidCluster.java53
-rw-r--r--src/main/java/org/apache/commons/math3/ml/clustering/Cluster.java60
-rw-r--r--src/main/java/org/apache/commons/math3/ml/clustering/Clusterable.java32
-rw-r--r--src/main/java/org/apache/commons/math3/ml/clustering/Clusterer.java80
-rw-r--r--src/main/java/org/apache/commons/math3/ml/clustering/DBSCANClusterer.java233
-rw-r--r--src/main/java/org/apache/commons/math3/ml/clustering/DoublePoint.java86
-rw-r--r--src/main/java/org/apache/commons/math3/ml/clustering/FuzzyKMeansClusterer.java426
-rw-r--r--src/main/java/org/apache/commons/math3/ml/clustering/KMeansPlusPlusClusterer.java565
-rw-r--r--src/main/java/org/apache/commons/math3/ml/clustering/MultiKMeansPlusPlusClusterer.java135
-rw-r--r--src/main/java/org/apache/commons/math3/ml/clustering/evaluation/ClusterEvaluator.java122
-rw-r--r--src/main/java/org/apache/commons/math3/ml/clustering/evaluation/SumOfClusterVariances.java69
-rw-r--r--src/main/java/org/apache/commons/math3/ml/clustering/evaluation/package-info.java20
-rw-r--r--src/main/java/org/apache/commons/math3/ml/clustering/package-info.java20
-rw-r--r--src/main/java/org/apache/commons/math3/ml/distance/CanberraDistance.java46
-rw-r--r--src/main/java/org/apache/commons/math3/ml/distance/ChebyshevDistance.java38
-rw-r--r--src/main/java/org/apache/commons/math3/ml/distance/DistanceMeasure.java41
-rw-r--r--src/main/java/org/apache/commons/math3/ml/distance/EarthMoversDistance.java48
-rw-r--r--src/main/java/org/apache/commons/math3/ml/distance/EuclideanDistance.java38
-rw-r--r--src/main/java/org/apache/commons/math3/ml/distance/ManhattanDistance.java38
-rw-r--r--src/main/java/org/apache/commons/math3/ml/distance/package-info.java20
-rw-r--r--src/main/java/org/apache/commons/math3/ml/neuralnet/FeatureInitializer.java32
-rw-r--r--src/main/java/org/apache/commons/math3/ml/neuralnet/FeatureInitializerFactory.java114
-rw-r--r--src/main/java/org/apache/commons/math3/ml/neuralnet/MapUtils.java326
-rw-r--r--src/main/java/org/apache/commons/math3/ml/neuralnet/Network.java499
-rw-r--r--src/main/java/org/apache/commons/math3/ml/neuralnet/Neuron.java272
-rw-r--r--src/main/java/org/apache/commons/math3/ml/neuralnet/SquareNeighbourhood.java38
-rw-r--r--src/main/java/org/apache/commons/math3/ml/neuralnet/UpdateAction.java34
-rw-r--r--src/main/java/org/apache/commons/math3/ml/neuralnet/oned/NeuronString.java238
-rw-r--r--src/main/java/org/apache/commons/math3/ml/neuralnet/oned/package-info.java22
-rw-r--r--src/main/java/org/apache/commons/math3/ml/neuralnet/package-info.java22
-rw-r--r--src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/KohonenTrainingTask.java59
-rw-r--r--src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/KohonenUpdateAction.java225
-rw-r--r--src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/LearningFactorFunction.java34
-rw-r--r--src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/LearningFactorFunctionFactory.java117
-rw-r--r--src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/NeighbourhoodSizeFunction.java37
-rw-r--r--src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/NeighbourhoodSizeFunctionFactory.java107
-rw-r--r--src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/package-info.java22
-rw-r--r--src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/util/ExponentialDecayFunction.java83
-rw-r--r--src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/util/QuasiSigmoidDecayFunction.java87
-rw-r--r--src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/util/package-info.java22
-rw-r--r--src/main/java/org/apache/commons/math3/ml/neuralnet/twod/NeuronSquareMesh2D.java628
-rw-r--r--src/main/java/org/apache/commons/math3/ml/neuralnet/twod/package-info.java22
-rw-r--r--src/main/java/org/apache/commons/math3/ml/neuralnet/twod/util/HitHistogram.java83
-rw-r--r--src/main/java/org/apache/commons/math3/ml/neuralnet/twod/util/LocationFinder.java105
-rw-r--r--src/main/java/org/apache/commons/math3/ml/neuralnet/twod/util/MapDataVisualization.java38
-rw-r--r--src/main/java/org/apache/commons/math3/ml/neuralnet/twod/util/MapVisualization.java34
-rw-r--r--src/main/java/org/apache/commons/math3/ml/neuralnet/twod/util/QuantizationError.java76
-rw-r--r--src/main/java/org/apache/commons/math3/ml/neuralnet/twod/util/SmoothedDataHistogram.java97
-rw-r--r--src/main/java/org/apache/commons/math3/ml/neuralnet/twod/util/TopographicErrorHistogram.java91
-rw-r--r--src/main/java/org/apache/commons/math3/ml/neuralnet/twod/util/UnifiedDistanceMatrix.java209
-rw-r--r--src/main/java/org/apache/commons/math3/ml/neuralnet/twod/util/package-info.java22
-rw-r--r--src/main/java/org/apache/commons/math3/ml/package-info.java18
52 files changed, 5983 insertions, 0 deletions
diff --git a/src/main/java/org/apache/commons/math3/ml/clustering/CentroidCluster.java b/src/main/java/org/apache/commons/math3/ml/clustering/CentroidCluster.java
new file mode 100644
index 0000000..5cfc7bc
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/ml/clustering/CentroidCluster.java
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.commons.math3.ml.clustering;
+
+/**
+ * A Cluster used by centroid-based clustering algorithms.
+ * <p>
+ * Defines additionally a cluster center which may not necessarily be a member
+ * of the original data set.
+ *
+ * @param <T> the type of points that can be clustered
+ * @since 3.2
+ */
+public class CentroidCluster<T extends Clusterable> extends Cluster<T> {
+
+ /** Serializable version identifier. */
+ private static final long serialVersionUID = -3075288519071812288L;
+
+ /** Center of the cluster. */
+ private final Clusterable center;
+
+ /**
+ * Build a cluster centered at a specified point.
+ * @param center the point which is to be the center of this cluster
+ */
+ public CentroidCluster(final Clusterable center) {
+ super();
+ this.center = center;
+ }
+
+ /**
+ * Get the point chosen to be the center of this cluster.
+ * @return chosen cluster center
+ */
+ public Clusterable getCenter() {
+ return center;
+ }
+
+}
diff --git a/src/main/java/org/apache/commons/math3/ml/clustering/Cluster.java b/src/main/java/org/apache/commons/math3/ml/clustering/Cluster.java
new file mode 100644
index 0000000..fa6df94
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/ml/clustering/Cluster.java
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.math3.ml.clustering;
+
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * Cluster holding a set of {@link Clusterable} points.
+ * @param <T> the type of points that can be clustered
+ * @since 3.2
+ */
+public class Cluster<T extends Clusterable> implements Serializable {
+
+ /** Serializable version identifier. */
+ private static final long serialVersionUID = -3442297081515880464L;
+
+ /** The points contained in this cluster. */
+ private final List<T> points;
+
+ /**
+ * Build a cluster centered at a specified point.
+ */
+ public Cluster() {
+ points = new ArrayList<T>();
+ }
+
+ /**
+ * Add a point to this cluster.
+ * @param point point to add
+ */
+ public void addPoint(final T point) {
+ points.add(point);
+ }
+
+ /**
+ * Get the points contained in the cluster.
+ * @return points contained in the cluster
+ */
+ public List<T> getPoints() {
+ return points;
+ }
+
+}
diff --git a/src/main/java/org/apache/commons/math3/ml/clustering/Clusterable.java b/src/main/java/org/apache/commons/math3/ml/clustering/Clusterable.java
new file mode 100644
index 0000000..e712eb7
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/ml/clustering/Clusterable.java
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.math3.ml.clustering;
+
+/**
+ * Interface for n-dimensional points that can be clustered together.
+ * @since 3.2
+ */
+public interface Clusterable {
+
+ /**
+ * Gets the n-dimensional point.
+ *
+ * @return the point array
+ */
+ double[] getPoint();
+}
diff --git a/src/main/java/org/apache/commons/math3/ml/clustering/Clusterer.java b/src/main/java/org/apache/commons/math3/ml/clustering/Clusterer.java
new file mode 100644
index 0000000..30e38c6
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/ml/clustering/Clusterer.java
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.commons.math3.ml.clustering;
+
+import java.util.Collection;
+import java.util.List;
+
+import org.apache.commons.math3.exception.ConvergenceException;
+import org.apache.commons.math3.exception.MathIllegalArgumentException;
+import org.apache.commons.math3.ml.distance.DistanceMeasure;
+
+/**
+ * Base class for clustering algorithms.
+ *
+ * @param <T> the type of points that can be clustered
+ * @since 3.2
+ */
+public abstract class Clusterer<T extends Clusterable> {
+
+ /** The distance measure to use. */
+ private DistanceMeasure measure;
+
+ /**
+ * Build a new clusterer with the given {@link DistanceMeasure}.
+ *
+ * @param measure the distance measure to use
+ */
+ protected Clusterer(final DistanceMeasure measure) {
+ this.measure = measure;
+ }
+
+ /**
+ * Perform a cluster analysis on the given set of {@link Clusterable} instances.
+ *
+ * @param points the set of {@link Clusterable} instances
+ * @return a {@link List} of clusters
+ * @throws MathIllegalArgumentException if points are null or the number of
+ * data points is not compatible with this clusterer
+ * @throws ConvergenceException if the algorithm has not yet converged after
+ * the maximum number of iterations has been exceeded
+ */
+ public abstract List<? extends Cluster<T>> cluster(Collection<T> points)
+ throws MathIllegalArgumentException, ConvergenceException;
+
+ /**
+ * Returns the {@link DistanceMeasure} instance used by this clusterer.
+ *
+ * @return the distance measure
+ */
+ public DistanceMeasure getDistanceMeasure() {
+ return measure;
+ }
+
+ /**
+ * Calculates the distance between two {@link Clusterable} instances
+ * with the configured {@link DistanceMeasure}.
+ *
+ * @param p1 the first clusterable
+ * @param p2 the second clusterable
+ * @return the distance between the two clusterables
+ */
+ protected double distance(final Clusterable p1, final Clusterable p2) {
+ return measure.compute(p1.getPoint(), p2.getPoint());
+ }
+
+}
diff --git a/src/main/java/org/apache/commons/math3/ml/clustering/DBSCANClusterer.java b/src/main/java/org/apache/commons/math3/ml/clustering/DBSCANClusterer.java
new file mode 100644
index 0000000..ce3d5cd
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/ml/clustering/DBSCANClusterer.java
@@ -0,0 +1,233 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.commons.math3.ml.clustering;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+import org.apache.commons.math3.exception.NotPositiveException;
+import org.apache.commons.math3.exception.NullArgumentException;
+import org.apache.commons.math3.ml.distance.DistanceMeasure;
+import org.apache.commons.math3.ml.distance.EuclideanDistance;
+import org.apache.commons.math3.util.MathUtils;
+
+/**
+ * DBSCAN (density-based spatial clustering of applications with noise) algorithm.
+ * <p>
+ * The DBSCAN algorithm forms clusters based on the idea of density connectivity, i.e.
+ * a point p is density connected to another point q, if there exists a chain of
+ * points p<sub>i</sub>, with i = 1 .. n and p<sub>1</sub> = p and p<sub>n</sub> = q,
+ * such that each pair &lt;p<sub>i</sub>, p<sub>i+1</sub>&gt; is directly density-reachable.
+ * A point q is directly density-reachable from point p if it is in the &epsilon;-neighborhood
+ * of this point.
+ * <p>
+ * Any point that is not density-reachable from a formed cluster is treated as noise, and
+ * will thus not be present in the result.
+ * <p>
+ * The algorithm requires two parameters:
+ * <ul>
+ * <li>eps: the distance that defines the &epsilon;-neighborhood of a point
+ * <li>minPoints: the minimum number of density-connected points required to form a cluster
+ * </ul>
+ *
+ * @param <T> type of the points to cluster
+ * @see <a href="http://en.wikipedia.org/wiki/DBSCAN">DBSCAN (wikipedia)</a>
+ * @see <a href="http://www.dbs.ifi.lmu.de/Publikationen/Papers/KDD-96.final.frame.pdf">
+ * A Density-Based Algorithm for Discovering Clusters in Large Spatial Databases with Noise</a>
+ * @since 3.2
+ */
+public class DBSCANClusterer<T extends Clusterable> extends Clusterer<T> {
+
+ /** Maximum radius of the neighborhood to be considered. */
+ private final double eps;
+
+ /** Minimum number of points needed for a cluster. */
+ private final int minPts;
+
+ /** Status of a point during the clustering process. */
+ private enum PointStatus {
+ /** The point has is considered to be noise. */
+ NOISE,
+ /** The point is already part of a cluster. */
+ PART_OF_CLUSTER
+ }
+
+ /**
+ * Creates a new instance of a DBSCANClusterer.
+ * <p>
+ * The euclidean distance will be used as default distance measure.
+ *
+ * @param eps maximum radius of the neighborhood to be considered
+ * @param minPts minimum number of points needed for a cluster
+ * @throws NotPositiveException if {@code eps < 0.0} or {@code minPts < 0}
+ */
+ public DBSCANClusterer(final double eps, final int minPts)
+ throws NotPositiveException {
+ this(eps, minPts, new EuclideanDistance());
+ }
+
+ /**
+ * Creates a new instance of a DBSCANClusterer.
+ *
+ * @param eps maximum radius of the neighborhood to be considered
+ * @param minPts minimum number of points needed for a cluster
+ * @param measure the distance measure to use
+ * @throws NotPositiveException if {@code eps < 0.0} or {@code minPts < 0}
+ */
+ public DBSCANClusterer(final double eps, final int minPts, final DistanceMeasure measure)
+ throws NotPositiveException {
+ super(measure);
+
+ if (eps < 0.0d) {
+ throw new NotPositiveException(eps);
+ }
+ if (minPts < 0) {
+ throw new NotPositiveException(minPts);
+ }
+ this.eps = eps;
+ this.minPts = minPts;
+ }
+
+ /**
+ * Returns the maximum radius of the neighborhood to be considered.
+ * @return maximum radius of the neighborhood
+ */
+ public double getEps() {
+ return eps;
+ }
+
+ /**
+ * Returns the minimum number of points needed for a cluster.
+ * @return minimum number of points needed for a cluster
+ */
+ public int getMinPts() {
+ return minPts;
+ }
+
+ /**
+ * Performs DBSCAN cluster analysis.
+ *
+ * @param points the points to cluster
+ * @return the list of clusters
+ * @throws NullArgumentException if the data points are null
+ */
+ @Override
+ public List<Cluster<T>> cluster(final Collection<T> points) throws NullArgumentException {
+
+ // sanity checks
+ MathUtils.checkNotNull(points);
+
+ final List<Cluster<T>> clusters = new ArrayList<Cluster<T>>();
+ final Map<Clusterable, PointStatus> visited = new HashMap<Clusterable, PointStatus>();
+
+ for (final T point : points) {
+ if (visited.get(point) != null) {
+ continue;
+ }
+ final List<T> neighbors = getNeighbors(point, points);
+ if (neighbors.size() >= minPts) {
+ // DBSCAN does not care about center points
+ final Cluster<T> cluster = new Cluster<T>();
+ clusters.add(expandCluster(cluster, point, neighbors, points, visited));
+ } else {
+ visited.put(point, PointStatus.NOISE);
+ }
+ }
+
+ return clusters;
+ }
+
+ /**
+ * Expands the cluster to include density-reachable items.
+ *
+ * @param cluster Cluster to expand
+ * @param point Point to add to cluster
+ * @param neighbors List of neighbors
+ * @param points the data set
+ * @param visited the set of already visited points
+ * @return the expanded cluster
+ */
+ private Cluster<T> expandCluster(final Cluster<T> cluster,
+ final T point,
+ final List<T> neighbors,
+ final Collection<T> points,
+ final Map<Clusterable, PointStatus> visited) {
+ cluster.addPoint(point);
+ visited.put(point, PointStatus.PART_OF_CLUSTER);
+
+ List<T> seeds = new ArrayList<T>(neighbors);
+ int index = 0;
+ while (index < seeds.size()) {
+ final T current = seeds.get(index);
+ PointStatus pStatus = visited.get(current);
+ // only check non-visited points
+ if (pStatus == null) {
+ final List<T> currentNeighbors = getNeighbors(current, points);
+ if (currentNeighbors.size() >= minPts) {
+ seeds = merge(seeds, currentNeighbors);
+ }
+ }
+
+ if (pStatus != PointStatus.PART_OF_CLUSTER) {
+ visited.put(current, PointStatus.PART_OF_CLUSTER);
+ cluster.addPoint(current);
+ }
+
+ index++;
+ }
+ return cluster;
+ }
+
+ /**
+ * Returns a list of density-reachable neighbors of a {@code point}.
+ *
+ * @param point the point to look for
+ * @param points possible neighbors
+ * @return the List of neighbors
+ */
+ private List<T> getNeighbors(final T point, final Collection<T> points) {
+ final List<T> neighbors = new ArrayList<T>();
+ for (final T neighbor : points) {
+ if (point != neighbor && distance(neighbor, point) <= eps) {
+ neighbors.add(neighbor);
+ }
+ }
+ return neighbors;
+ }
+
+ /**
+ * Merges two lists together.
+ *
+ * @param one first list
+ * @param two second list
+ * @return merged lists
+ */
+ private List<T> merge(final List<T> one, final List<T> two) {
+ final Set<T> oneSet = new HashSet<T>(one);
+ for (T item : two) {
+ if (!oneSet.contains(item)) {
+ one.add(item);
+ }
+ }
+ return one;
+ }
+}
diff --git a/src/main/java/org/apache/commons/math3/ml/clustering/DoublePoint.java b/src/main/java/org/apache/commons/math3/ml/clustering/DoublePoint.java
new file mode 100644
index 0000000..4fb31f7
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/ml/clustering/DoublePoint.java
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.math3.ml.clustering;
+
+import java.io.Serializable;
+import java.util.Arrays;
+
+/**
+ * A simple implementation of {@link Clusterable} for points with double coordinates.
+ * @since 3.2
+ */
+public class DoublePoint implements Clusterable, Serializable {
+
+ /** Serializable version identifier. */
+ private static final long serialVersionUID = 3946024775784901369L;
+
+ /** Point coordinates. */
+ private final double[] point;
+
+ /**
+ * Build an instance wrapping an double array.
+ * <p>
+ * The wrapped array is referenced, it is <em>not</em> copied.
+ *
+ * @param point the n-dimensional point in double space
+ */
+ public DoublePoint(final double[] point) {
+ this.point = point;
+ }
+
+ /**
+ * Build an instance wrapping an integer array.
+ * <p>
+ * The wrapped array is copied to an internal double array.
+ *
+ * @param point the n-dimensional point in integer space
+ */
+ public DoublePoint(final int[] point) {
+ this.point = new double[point.length];
+ for ( int i = 0; i < point.length; i++) {
+ this.point[i] = point[i];
+ }
+ }
+
+ /** {@inheritDoc} */
+ public double[] getPoint() {
+ return point;
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public boolean equals(final Object other) {
+ if (!(other instanceof DoublePoint)) {
+ return false;
+ }
+ return Arrays.equals(point, ((DoublePoint) other).point);
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public int hashCode() {
+ return Arrays.hashCode(point);
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public String toString() {
+ return Arrays.toString(point);
+ }
+
+}
diff --git a/src/main/java/org/apache/commons/math3/ml/clustering/FuzzyKMeansClusterer.java b/src/main/java/org/apache/commons/math3/ml/clustering/FuzzyKMeansClusterer.java
new file mode 100644
index 0000000..5f89934
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/ml/clustering/FuzzyKMeansClusterer.java
@@ -0,0 +1,426 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.commons.math3.ml.clustering;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.List;
+
+import org.apache.commons.math3.exception.MathIllegalArgumentException;
+import org.apache.commons.math3.exception.MathIllegalStateException;
+import org.apache.commons.math3.exception.NumberIsTooSmallException;
+import org.apache.commons.math3.linear.MatrixUtils;
+import org.apache.commons.math3.linear.RealMatrix;
+import org.apache.commons.math3.ml.distance.DistanceMeasure;
+import org.apache.commons.math3.ml.distance.EuclideanDistance;
+import org.apache.commons.math3.random.JDKRandomGenerator;
+import org.apache.commons.math3.random.RandomGenerator;
+import org.apache.commons.math3.util.FastMath;
+import org.apache.commons.math3.util.MathArrays;
+import org.apache.commons.math3.util.MathUtils;
+
+/**
+ * Fuzzy K-Means clustering algorithm.
+ * <p>
+ * The Fuzzy K-Means algorithm is a variation of the classical K-Means algorithm, with the
+ * major difference that a single data point is not uniquely assigned to a single cluster.
+ * Instead, each point i has a set of weights u<sub>ij</sub> which indicate the degree of membership
+ * to the cluster j.
+ * <p>
+ * The algorithm then tries to minimize the objective function:
+ * <pre>
+ * J = &#8721;<sub>i=1..C</sub>&#8721;<sub>k=1..N</sub> u<sub>ik</sub><sup>m</sup>d<sub>ik</sub><sup>2</sup>
+ * </pre>
+ * with d<sub>ik</sub> being the distance between data point i and the cluster center k.
+ * <p>
+ * The algorithm requires two parameters:
+ * <ul>
+ * <li>k: the number of clusters
+ * <li>fuzziness: determines the level of cluster fuzziness, larger values lead to fuzzier clusters
+ * </ul>
+ * Additional, optional parameters:
+ * <ul>
+ * <li>maxIterations: the maximum number of iterations
+ * <li>epsilon: the convergence criteria, default is 1e-3
+ * </ul>
+ * <p>
+ * The fuzzy variant of the K-Means algorithm is more robust with regard to the selection
+ * of the initial cluster centers.
+ *
+ * @param <T> type of the points to cluster
+ * @since 3.3
+ */
+public class FuzzyKMeansClusterer<T extends Clusterable> extends Clusterer<T> {
+
+ /** The default value for the convergence criteria. */
+ private static final double DEFAULT_EPSILON = 1e-3;
+
+ /** The number of clusters. */
+ private final int k;
+
+ /** The maximum number of iterations. */
+ private final int maxIterations;
+
+ /** The fuzziness factor. */
+ private final double fuzziness;
+
+ /** The convergence criteria. */
+ private final double epsilon;
+
+ /** Random generator for choosing initial centers. */
+ private final RandomGenerator random;
+
+ /** The membership matrix. */
+ private double[][] membershipMatrix;
+
+ /** The list of points used in the last call to {@link #cluster(Collection)}. */
+ private List<T> points;
+
+ /** The list of clusters resulting from the last call to {@link #cluster(Collection)}. */
+ private List<CentroidCluster<T>> clusters;
+
+ /**
+ * Creates a new instance of a FuzzyKMeansClusterer.
+ * <p>
+ * The euclidean distance will be used as default distance measure.
+ *
+ * @param k the number of clusters to split the data into
+ * @param fuzziness the fuzziness factor, must be &gt; 1.0
+ * @throws NumberIsTooSmallException if {@code fuzziness <= 1.0}
+ */
+ public FuzzyKMeansClusterer(final int k, final double fuzziness) throws NumberIsTooSmallException {
+ this(k, fuzziness, -1, new EuclideanDistance());
+ }
+
+ /**
+ * Creates a new instance of a FuzzyKMeansClusterer.
+ *
+ * @param k the number of clusters to split the data into
+ * @param fuzziness the fuzziness factor, must be &gt; 1.0
+ * @param maxIterations the maximum number of iterations to run the algorithm for.
+ * If negative, no maximum will be used.
+ * @param measure the distance measure to use
+ * @throws NumberIsTooSmallException if {@code fuzziness <= 1.0}
+ */
+ public FuzzyKMeansClusterer(final int k, final double fuzziness,
+ final int maxIterations, final DistanceMeasure measure)
+ throws NumberIsTooSmallException {
+ this(k, fuzziness, maxIterations, measure, DEFAULT_EPSILON, new JDKRandomGenerator());
+ }
+
+ /**
+ * Creates a new instance of a FuzzyKMeansClusterer.
+ *
+ * @param k the number of clusters to split the data into
+ * @param fuzziness the fuzziness factor, must be &gt; 1.0
+ * @param maxIterations the maximum number of iterations to run the algorithm for.
+ * If negative, no maximum will be used.
+ * @param measure the distance measure to use
+ * @param epsilon the convergence criteria (default is 1e-3)
+ * @param random random generator to use for choosing initial centers
+ * @throws NumberIsTooSmallException if {@code fuzziness <= 1.0}
+ */
+ public FuzzyKMeansClusterer(final int k, final double fuzziness,
+ final int maxIterations, final DistanceMeasure measure,
+ final double epsilon, final RandomGenerator random)
+ throws NumberIsTooSmallException {
+
+ super(measure);
+
+ if (fuzziness <= 1.0d) {
+ throw new NumberIsTooSmallException(fuzziness, 1.0, false);
+ }
+ this.k = k;
+ this.fuzziness = fuzziness;
+ this.maxIterations = maxIterations;
+ this.epsilon = epsilon;
+ this.random = random;
+
+ this.membershipMatrix = null;
+ this.points = null;
+ this.clusters = null;
+ }
+
+ /**
+ * Return the number of clusters this instance will use.
+ * @return the number of clusters
+ */
+ public int getK() {
+ return k;
+ }
+
+ /**
+ * Returns the fuzziness factor used by this instance.
+ * @return the fuzziness factor
+ */
+ public double getFuzziness() {
+ return fuzziness;
+ }
+
+ /**
+ * Returns the maximum number of iterations this instance will use.
+ * @return the maximum number of iterations, or -1 if no maximum is set
+ */
+ public int getMaxIterations() {
+ return maxIterations;
+ }
+
+ /**
+ * Returns the convergence criteria used by this instance.
+ * @return the convergence criteria
+ */
+ public double getEpsilon() {
+ return epsilon;
+ }
+
+ /**
+ * Returns the random generator this instance will use.
+ * @return the random generator
+ */
+ public RandomGenerator getRandomGenerator() {
+ return random;
+ }
+
+ /**
+ * Returns the {@code nxk} membership matrix, where {@code n} is the number
+ * of data points and {@code k} the number of clusters.
+ * <p>
+ * The element U<sub>i,j</sub> represents the membership value for data point {@code i}
+ * to cluster {@code j}.
+ *
+ * @return the membership matrix
+ * @throws MathIllegalStateException if {@link #cluster(Collection)} has not been called before
+ */
+ public RealMatrix getMembershipMatrix() {
+ if (membershipMatrix == null) {
+ throw new MathIllegalStateException();
+ }
+ return MatrixUtils.createRealMatrix(membershipMatrix);
+ }
+
+ /**
+ * Returns an unmodifiable list of the data points used in the last
+ * call to {@link #cluster(Collection)}.
+ * @return the list of data points, or {@code null} if {@link #cluster(Collection)} has
+ * not been called before.
+ */
+ public List<T> getDataPoints() {
+ return points;
+ }
+
+ /**
+ * Returns the list of clusters resulting from the last call to {@link #cluster(Collection)}.
+ * @return the list of clusters, or {@code null} if {@link #cluster(Collection)} has
+ * not been called before.
+ */
+ public List<CentroidCluster<T>> getClusters() {
+ return clusters;
+ }
+
+ /**
+ * Get the value of the objective function.
+ * @return the objective function evaluation as double value
+ * @throws MathIllegalStateException if {@link #cluster(Collection)} has not been called before
+ */
+ public double getObjectiveFunctionValue() {
+ if (points == null || clusters == null) {
+ throw new MathIllegalStateException();
+ }
+
+ int i = 0;
+ double objFunction = 0.0;
+ for (final T point : points) {
+ int j = 0;
+ for (final CentroidCluster<T> cluster : clusters) {
+ final double dist = distance(point, cluster.getCenter());
+ objFunction += (dist * dist) * FastMath.pow(membershipMatrix[i][j], fuzziness);
+ j++;
+ }
+ i++;
+ }
+ return objFunction;
+ }
+
+ /**
+ * Performs Fuzzy K-Means cluster analysis.
+ *
+ * @param dataPoints the points to cluster
+ * @return the list of clusters
+ * @throws MathIllegalArgumentException if the data points are null or the number
+ * of clusters is larger than the number of data points
+ */
+ @Override
+ public List<CentroidCluster<T>> cluster(final Collection<T> dataPoints)
+ throws MathIllegalArgumentException {
+
+ // sanity checks
+ MathUtils.checkNotNull(dataPoints);
+
+ final int size = dataPoints.size();
+
+ // number of clusters has to be smaller or equal the number of data points
+ if (size < k) {
+ throw new NumberIsTooSmallException(size, k, false);
+ }
+
+ // copy the input collection to an unmodifiable list with indexed access
+ points = Collections.unmodifiableList(new ArrayList<T>(dataPoints));
+ clusters = new ArrayList<CentroidCluster<T>>();
+ membershipMatrix = new double[size][k];
+ final double[][] oldMatrix = new double[size][k];
+
+ // if no points are provided, return an empty list of clusters
+ if (size == 0) {
+ return clusters;
+ }
+
+ initializeMembershipMatrix();
+
+ // there is at least one point
+ final int pointDimension = points.get(0).getPoint().length;
+ for (int i = 0; i < k; i++) {
+ clusters.add(new CentroidCluster<T>(new DoublePoint(new double[pointDimension])));
+ }
+
+ int iteration = 0;
+ final int max = (maxIterations < 0) ? Integer.MAX_VALUE : maxIterations;
+ double difference = 0.0;
+
+ do {
+ saveMembershipMatrix(oldMatrix);
+ updateClusterCenters();
+ updateMembershipMatrix();
+ difference = calculateMaxMembershipChange(oldMatrix);
+ } while (difference > epsilon && ++iteration < max);
+
+ return clusters;
+ }
+
+ /**
+ * Update the cluster centers.
+ */
+ private void updateClusterCenters() {
+ int j = 0;
+ final List<CentroidCluster<T>> newClusters = new ArrayList<CentroidCluster<T>>(k);
+ for (final CentroidCluster<T> cluster : clusters) {
+ final Clusterable center = cluster.getCenter();
+ int i = 0;
+ double[] arr = new double[center.getPoint().length];
+ double sum = 0.0;
+ for (final T point : points) {
+ final double u = FastMath.pow(membershipMatrix[i][j], fuzziness);
+ final double[] pointArr = point.getPoint();
+ for (int idx = 0; idx < arr.length; idx++) {
+ arr[idx] += u * pointArr[idx];
+ }
+ sum += u;
+ i++;
+ }
+ MathArrays.scaleInPlace(1.0 / sum, arr);
+ newClusters.add(new CentroidCluster<T>(new DoublePoint(arr)));
+ j++;
+ }
+ clusters.clear();
+ clusters = newClusters;
+ }
+
+ /**
+ * Updates the membership matrix and assigns the points to the cluster with
+ * the highest membership.
+ */
+ private void updateMembershipMatrix() {
+ for (int i = 0; i < points.size(); i++) {
+ final T point = points.get(i);
+ double maxMembership = Double.MIN_VALUE;
+ int newCluster = -1;
+ for (int j = 0; j < clusters.size(); j++) {
+ double sum = 0.0;
+ final double distA = FastMath.abs(distance(point, clusters.get(j).getCenter()));
+
+ if (distA != 0.0) {
+ for (final CentroidCluster<T> c : clusters) {
+ final double distB = FastMath.abs(distance(point, c.getCenter()));
+ if (distB == 0.0) {
+ sum = Double.POSITIVE_INFINITY;
+ break;
+ }
+ sum += FastMath.pow(distA / distB, 2.0 / (fuzziness - 1.0));
+ }
+ }
+
+ double membership;
+ if (sum == 0.0) {
+ membership = 1.0;
+ } else if (sum == Double.POSITIVE_INFINITY) {
+ membership = 0.0;
+ } else {
+ membership = 1.0 / sum;
+ }
+ membershipMatrix[i][j] = membership;
+
+ if (membershipMatrix[i][j] > maxMembership) {
+ maxMembership = membershipMatrix[i][j];
+ newCluster = j;
+ }
+ }
+ clusters.get(newCluster).addPoint(point);
+ }
+ }
+
+ /**
+ * Initialize the membership matrix with random values.
+ */
+ private void initializeMembershipMatrix() {
+ for (int i = 0; i < points.size(); i++) {
+ for (int j = 0; j < k; j++) {
+ membershipMatrix[i][j] = random.nextDouble();
+ }
+ membershipMatrix[i] = MathArrays.normalizeArray(membershipMatrix[i], 1.0);
+ }
+ }
+
+ /**
+ * Calculate the maximum element-by-element change of the membership matrix
+ * for the current iteration.
+ *
+ * @param matrix the membership matrix of the previous iteration
+ * @return the maximum membership matrix change
+ */
+ private double calculateMaxMembershipChange(final double[][] matrix) {
+ double maxMembership = 0.0;
+ for (int i = 0; i < points.size(); i++) {
+ for (int j = 0; j < clusters.size(); j++) {
+ double v = FastMath.abs(membershipMatrix[i][j] - matrix[i][j]);
+ maxMembership = FastMath.max(v, maxMembership);
+ }
+ }
+ return maxMembership;
+ }
+
+ /**
+ * Copy the membership matrix into the provided matrix.
+ *
+ * @param matrix the place to store the membership matrix
+ */
+ private void saveMembershipMatrix(final double[][] matrix) {
+ for (int i = 0; i < points.size(); i++) {
+ System.arraycopy(membershipMatrix[i], 0, matrix[i], 0, clusters.size());
+ }
+ }
+
+}
diff --git a/src/main/java/org/apache/commons/math3/ml/clustering/KMeansPlusPlusClusterer.java b/src/main/java/org/apache/commons/math3/ml/clustering/KMeansPlusPlusClusterer.java
new file mode 100644
index 0000000..2e57fac
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/ml/clustering/KMeansPlusPlusClusterer.java
@@ -0,0 +1,565 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.math3.ml.clustering;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.List;
+
+import org.apache.commons.math3.exception.ConvergenceException;
+import org.apache.commons.math3.exception.MathIllegalArgumentException;
+import org.apache.commons.math3.exception.NumberIsTooSmallException;
+import org.apache.commons.math3.exception.util.LocalizedFormats;
+import org.apache.commons.math3.ml.distance.DistanceMeasure;
+import org.apache.commons.math3.ml.distance.EuclideanDistance;
+import org.apache.commons.math3.random.JDKRandomGenerator;
+import org.apache.commons.math3.random.RandomGenerator;
+import org.apache.commons.math3.stat.descriptive.moment.Variance;
+import org.apache.commons.math3.util.MathUtils;
+
+/**
+ * Clustering algorithm based on David Arthur and Sergei Vassilvitski k-means++ algorithm.
+ * @param <T> type of the points to cluster
+ * @see <a href="http://en.wikipedia.org/wiki/K-means%2B%2B">K-means++ (wikipedia)</a>
+ * @since 3.2
+ */
+public class KMeansPlusPlusClusterer<T extends Clusterable> extends Clusterer<T> {
+
+ /** Strategies to use for replacing an empty cluster. */
+ public enum EmptyClusterStrategy {
+
+ /** Split the cluster with largest distance variance. */
+ LARGEST_VARIANCE,
+
+ /** Split the cluster with largest number of points. */
+ LARGEST_POINTS_NUMBER,
+
+ /** Create a cluster around the point farthest from its centroid. */
+ FARTHEST_POINT,
+
+ /** Generate an error. */
+ ERROR
+
+ }
+
+ /** The number of clusters. */
+ private final int k;
+
+ /** The maximum number of iterations. */
+ private final int maxIterations;
+
+ /** Random generator for choosing initial centers. */
+ private final RandomGenerator random;
+
+ /** Selected strategy for empty clusters. */
+ private final EmptyClusterStrategy emptyStrategy;
+
+ /** Build a clusterer.
+ * <p>
+ * The default strategy for handling empty clusters that may appear during
+ * algorithm iterations is to split the cluster with largest distance variance.
+ * <p>
+ * The euclidean distance will be used as default distance measure.
+ *
+ * @param k the number of clusters to split the data into
+ */
+ public KMeansPlusPlusClusterer(final int k) {
+ this(k, -1);
+ }
+
+ /** Build a clusterer.
+ * <p>
+ * The default strategy for handling empty clusters that may appear during
+ * algorithm iterations is to split the cluster with largest distance variance.
+ * <p>
+ * The euclidean distance will be used as default distance measure.
+ *
+ * @param k the number of clusters to split the data into
+ * @param maxIterations the maximum number of iterations to run the algorithm for.
+ * If negative, no maximum will be used.
+ */
+ public KMeansPlusPlusClusterer(final int k, final int maxIterations) {
+ this(k, maxIterations, new EuclideanDistance());
+ }
+
+ /** Build a clusterer.
+ * <p>
+ * The default strategy for handling empty clusters that may appear during
+ * algorithm iterations is to split the cluster with largest distance variance.
+ *
+ * @param k the number of clusters to split the data into
+ * @param maxIterations the maximum number of iterations to run the algorithm for.
+ * If negative, no maximum will be used.
+ * @param measure the distance measure to use
+ */
+ public KMeansPlusPlusClusterer(final int k, final int maxIterations, final DistanceMeasure measure) {
+ this(k, maxIterations, measure, new JDKRandomGenerator());
+ }
+
+ /** Build a clusterer.
+ * <p>
+ * The default strategy for handling empty clusters that may appear during
+ * algorithm iterations is to split the cluster with largest distance variance.
+ *
+ * @param k the number of clusters to split the data into
+ * @param maxIterations the maximum number of iterations to run the algorithm for.
+ * If negative, no maximum will be used.
+ * @param measure the distance measure to use
+ * @param random random generator to use for choosing initial centers
+ */
+ public KMeansPlusPlusClusterer(final int k, final int maxIterations,
+ final DistanceMeasure measure,
+ final RandomGenerator random) {
+ this(k, maxIterations, measure, random, EmptyClusterStrategy.LARGEST_VARIANCE);
+ }
+
+ /** Build a clusterer.
+ *
+ * @param k the number of clusters to split the data into
+ * @param maxIterations the maximum number of iterations to run the algorithm for.
+ * If negative, no maximum will be used.
+ * @param measure the distance measure to use
+ * @param random random generator to use for choosing initial centers
+ * @param emptyStrategy strategy to use for handling empty clusters that
+ * may appear during algorithm iterations
+ */
+ public KMeansPlusPlusClusterer(final int k, final int maxIterations,
+ final DistanceMeasure measure,
+ final RandomGenerator random,
+ final EmptyClusterStrategy emptyStrategy) {
+ super(measure);
+ this.k = k;
+ this.maxIterations = maxIterations;
+ this.random = random;
+ this.emptyStrategy = emptyStrategy;
+ }
+
+ /**
+ * Return the number of clusters this instance will use.
+ * @return the number of clusters
+ */
+ public int getK() {
+ return k;
+ }
+
+ /**
+ * Returns the maximum number of iterations this instance will use.
+ * @return the maximum number of iterations, or -1 if no maximum is set
+ */
+ public int getMaxIterations() {
+ return maxIterations;
+ }
+
+ /**
+ * Returns the random generator this instance will use.
+ * @return the random generator
+ */
+ public RandomGenerator getRandomGenerator() {
+ return random;
+ }
+
+ /**
+ * Returns the {@link EmptyClusterStrategy} used by this instance.
+ * @return the {@link EmptyClusterStrategy}
+ */
+ public EmptyClusterStrategy getEmptyClusterStrategy() {
+ return emptyStrategy;
+ }
+
+ /**
+ * Runs the K-means++ clustering algorithm.
+ *
+ * @param points the points to cluster
+ * @return a list of clusters containing the points
+ * @throws MathIllegalArgumentException if the data points are null or the number
+ * of clusters is larger than the number of data points
+ * @throws ConvergenceException if an empty cluster is encountered and the
+ * {@link #emptyStrategy} is set to {@code ERROR}
+ */
+ @Override
+ public List<CentroidCluster<T>> cluster(final Collection<T> points)
+ throws MathIllegalArgumentException, ConvergenceException {
+
+ // sanity checks
+ MathUtils.checkNotNull(points);
+
+ // number of clusters has to be smaller or equal the number of data points
+ if (points.size() < k) {
+ throw new NumberIsTooSmallException(points.size(), k, false);
+ }
+
+ // create the initial clusters
+ List<CentroidCluster<T>> clusters = chooseInitialCenters(points);
+
+ // create an array containing the latest assignment of a point to a cluster
+ // no need to initialize the array, as it will be filled with the first assignment
+ int[] assignments = new int[points.size()];
+ assignPointsToClusters(clusters, points, assignments);
+
+ // iterate through updating the centers until we're done
+ final int max = (maxIterations < 0) ? Integer.MAX_VALUE : maxIterations;
+ for (int count = 0; count < max; count++) {
+ boolean emptyCluster = false;
+ List<CentroidCluster<T>> newClusters = new ArrayList<CentroidCluster<T>>();
+ for (final CentroidCluster<T> cluster : clusters) {
+ final Clusterable newCenter;
+ if (cluster.getPoints().isEmpty()) {
+ switch (emptyStrategy) {
+ case LARGEST_VARIANCE :
+ newCenter = getPointFromLargestVarianceCluster(clusters);
+ break;
+ case LARGEST_POINTS_NUMBER :
+ newCenter = getPointFromLargestNumberCluster(clusters);
+ break;
+ case FARTHEST_POINT :
+ newCenter = getFarthestPoint(clusters);
+ break;
+ default :
+ throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
+ }
+ emptyCluster = true;
+ } else {
+ newCenter = centroidOf(cluster.getPoints(), cluster.getCenter().getPoint().length);
+ }
+ newClusters.add(new CentroidCluster<T>(newCenter));
+ }
+ int changes = assignPointsToClusters(newClusters, points, assignments);
+ clusters = newClusters;
+
+ // if there were no more changes in the point-to-cluster assignment
+ // and there are no empty clusters left, return the current clusters
+ if (changes == 0 && !emptyCluster) {
+ return clusters;
+ }
+ }
+ return clusters;
+ }
+
+ /**
+ * Adds the given points to the closest {@link Cluster}.
+ *
+ * @param clusters the {@link Cluster}s to add the points to
+ * @param points the points to add to the given {@link Cluster}s
+ * @param assignments points assignments to clusters
+ * @return the number of points assigned to different clusters as the iteration before
+ */
+ private int assignPointsToClusters(final List<CentroidCluster<T>> clusters,
+ final Collection<T> points,
+ final int[] assignments) {
+ int assignedDifferently = 0;
+ int pointIndex = 0;
+ for (final T p : points) {
+ int clusterIndex = getNearestCluster(clusters, p);
+ if (clusterIndex != assignments[pointIndex]) {
+ assignedDifferently++;
+ }
+
+ CentroidCluster<T> cluster = clusters.get(clusterIndex);
+ cluster.addPoint(p);
+ assignments[pointIndex++] = clusterIndex;
+ }
+
+ return assignedDifferently;
+ }
+
+ /**
+ * Use K-means++ to choose the initial centers.
+ *
+ * @param points the points to choose the initial centers from
+ * @return the initial centers
+ */
+ private List<CentroidCluster<T>> chooseInitialCenters(final Collection<T> points) {
+
+ // Convert to list for indexed access. Make it unmodifiable, since removal of items
+ // would screw up the logic of this method.
+ final List<T> pointList = Collections.unmodifiableList(new ArrayList<T> (points));
+
+ // The number of points in the list.
+ final int numPoints = pointList.size();
+
+ // Set the corresponding element in this array to indicate when
+ // elements of pointList are no longer available.
+ final boolean[] taken = new boolean[numPoints];
+
+ // The resulting list of initial centers.
+ final List<CentroidCluster<T>> resultSet = new ArrayList<CentroidCluster<T>>();
+
+ // Choose one center uniformly at random from among the data points.
+ final int firstPointIndex = random.nextInt(numPoints);
+
+ final T firstPoint = pointList.get(firstPointIndex);
+
+ resultSet.add(new CentroidCluster<T>(firstPoint));
+
+ // Must mark it as taken
+ taken[firstPointIndex] = true;
+
+ // To keep track of the minimum distance squared of elements of
+ // pointList to elements of resultSet.
+ final double[] minDistSquared = new double[numPoints];
+
+ // Initialize the elements. Since the only point in resultSet is firstPoint,
+ // this is very easy.
+ for (int i = 0; i < numPoints; i++) {
+ if (i != firstPointIndex) { // That point isn't considered
+ double d = distance(firstPoint, pointList.get(i));
+ minDistSquared[i] = d*d;
+ }
+ }
+
+ while (resultSet.size() < k) {
+
+ // Sum up the squared distances for the points in pointList not
+ // already taken.
+ double distSqSum = 0.0;
+
+ for (int i = 0; i < numPoints; i++) {
+ if (!taken[i]) {
+ distSqSum += minDistSquared[i];
+ }
+ }
+
+ // Add one new data point as a center. Each point x is chosen with
+ // probability proportional to D(x)2
+ final double r = random.nextDouble() * distSqSum;
+
+ // The index of the next point to be added to the resultSet.
+ int nextPointIndex = -1;
+
+ // Sum through the squared min distances again, stopping when
+ // sum >= r.
+ double sum = 0.0;
+ for (int i = 0; i < numPoints; i++) {
+ if (!taken[i]) {
+ sum += minDistSquared[i];
+ if (sum >= r) {
+ nextPointIndex = i;
+ break;
+ }
+ }
+ }
+
+ // If it's not set to >= 0, the point wasn't found in the previous
+ // for loop, probably because distances are extremely small. Just pick
+ // the last available point.
+ if (nextPointIndex == -1) {
+ for (int i = numPoints - 1; i >= 0; i--) {
+ if (!taken[i]) {
+ nextPointIndex = i;
+ break;
+ }
+ }
+ }
+
+ // We found one.
+ if (nextPointIndex >= 0) {
+
+ final T p = pointList.get(nextPointIndex);
+
+ resultSet.add(new CentroidCluster<T> (p));
+
+ // Mark it as taken.
+ taken[nextPointIndex] = true;
+
+ if (resultSet.size() < k) {
+ // Now update elements of minDistSquared. We only have to compute
+ // the distance to the new center to do this.
+ for (int j = 0; j < numPoints; j++) {
+ // Only have to worry about the points still not taken.
+ if (!taken[j]) {
+ double d = distance(p, pointList.get(j));
+ double d2 = d * d;
+ if (d2 < minDistSquared[j]) {
+ minDistSquared[j] = d2;
+ }
+ }
+ }
+ }
+
+ } else {
+ // None found --
+ // Break from the while loop to prevent
+ // an infinite loop.
+ break;
+ }
+ }
+
+ return resultSet;
+ }
+
+ /**
+ * Get a random point from the {@link Cluster} with the largest distance variance.
+ *
+ * @param clusters the {@link Cluster}s to search
+ * @return a random point from the selected cluster
+ * @throws ConvergenceException if clusters are all empty
+ */
+ private T getPointFromLargestVarianceCluster(final Collection<CentroidCluster<T>> clusters)
+ throws ConvergenceException {
+
+ double maxVariance = Double.NEGATIVE_INFINITY;
+ Cluster<T> selected = null;
+ for (final CentroidCluster<T> cluster : clusters) {
+ if (!cluster.getPoints().isEmpty()) {
+
+ // compute the distance variance of the current cluster
+ final Clusterable center = cluster.getCenter();
+ final Variance stat = new Variance();
+ for (final T point : cluster.getPoints()) {
+ stat.increment(distance(point, center));
+ }
+ final double variance = stat.getResult();
+
+ // select the cluster with the largest variance
+ if (variance > maxVariance) {
+ maxVariance = variance;
+ selected = cluster;
+ }
+
+ }
+ }
+
+ // did we find at least one non-empty cluster ?
+ if (selected == null) {
+ throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
+ }
+
+ // extract a random point from the cluster
+ final List<T> selectedPoints = selected.getPoints();
+ return selectedPoints.remove(random.nextInt(selectedPoints.size()));
+
+ }
+
+ /**
+ * Get a random point from the {@link Cluster} with the largest number of points
+ *
+ * @param clusters the {@link Cluster}s to search
+ * @return a random point from the selected cluster
+ * @throws ConvergenceException if clusters are all empty
+ */
+ private T getPointFromLargestNumberCluster(final Collection<? extends Cluster<T>> clusters)
+ throws ConvergenceException {
+
+ int maxNumber = 0;
+ Cluster<T> selected = null;
+ for (final Cluster<T> cluster : clusters) {
+
+ // get the number of points of the current cluster
+ final int number = cluster.getPoints().size();
+
+ // select the cluster with the largest number of points
+ if (number > maxNumber) {
+ maxNumber = number;
+ selected = cluster;
+ }
+
+ }
+
+ // did we find at least one non-empty cluster ?
+ if (selected == null) {
+ throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
+ }
+
+ // extract a random point from the cluster
+ final List<T> selectedPoints = selected.getPoints();
+ return selectedPoints.remove(random.nextInt(selectedPoints.size()));
+
+ }
+
+ /**
+ * Get the point farthest to its cluster center
+ *
+ * @param clusters the {@link Cluster}s to search
+ * @return point farthest to its cluster center
+ * @throws ConvergenceException if clusters are all empty
+ */
+ private T getFarthestPoint(final Collection<CentroidCluster<T>> clusters) throws ConvergenceException {
+
+ double maxDistance = Double.NEGATIVE_INFINITY;
+ Cluster<T> selectedCluster = null;
+ int selectedPoint = -1;
+ for (final CentroidCluster<T> cluster : clusters) {
+
+ // get the farthest point
+ final Clusterable center = cluster.getCenter();
+ final List<T> points = cluster.getPoints();
+ for (int i = 0; i < points.size(); ++i) {
+ final double distance = distance(points.get(i), center);
+ if (distance > maxDistance) {
+ maxDistance = distance;
+ selectedCluster = cluster;
+ selectedPoint = i;
+ }
+ }
+
+ }
+
+ // did we find at least one non-empty cluster ?
+ if (selectedCluster == null) {
+ throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
+ }
+
+ return selectedCluster.getPoints().remove(selectedPoint);
+
+ }
+
+ /**
+ * Returns the nearest {@link Cluster} to the given point
+ *
+ * @param clusters the {@link Cluster}s to search
+ * @param point the point to find the nearest {@link Cluster} for
+ * @return the index of the nearest {@link Cluster} to the given point
+ */
+ private int getNearestCluster(final Collection<CentroidCluster<T>> clusters, final T point) {
+ double minDistance = Double.MAX_VALUE;
+ int clusterIndex = 0;
+ int minCluster = 0;
+ for (final CentroidCluster<T> c : clusters) {
+ final double distance = distance(point, c.getCenter());
+ if (distance < minDistance) {
+ minDistance = distance;
+ minCluster = clusterIndex;
+ }
+ clusterIndex++;
+ }
+ return minCluster;
+ }
+
+ /**
+ * Computes the centroid for a set of points.
+ *
+ * @param points the set of points
+ * @param dimension the point dimension
+ * @return the computed centroid for the set of points
+ */
+ private Clusterable centroidOf(final Collection<T> points, final int dimension) {
+ final double[] centroid = new double[dimension];
+ for (final T p : points) {
+ final double[] point = p.getPoint();
+ for (int i = 0; i < centroid.length; i++) {
+ centroid[i] += point[i];
+ }
+ }
+ for (int i = 0; i < centroid.length; i++) {
+ centroid[i] /= points.size();
+ }
+ return new DoublePoint(centroid);
+ }
+
+}
diff --git a/src/main/java/org/apache/commons/math3/ml/clustering/MultiKMeansPlusPlusClusterer.java b/src/main/java/org/apache/commons/math3/ml/clustering/MultiKMeansPlusPlusClusterer.java
new file mode 100644
index 0000000..796fc7a
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/ml/clustering/MultiKMeansPlusPlusClusterer.java
@@ -0,0 +1,135 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.math3.ml.clustering;
+
+import java.util.Collection;
+import java.util.List;
+
+import org.apache.commons.math3.exception.ConvergenceException;
+import org.apache.commons.math3.exception.MathIllegalArgumentException;
+import org.apache.commons.math3.ml.clustering.evaluation.ClusterEvaluator;
+import org.apache.commons.math3.ml.clustering.evaluation.SumOfClusterVariances;
+
+/**
+ * A wrapper around a k-means++ clustering algorithm which performs multiple trials
+ * and returns the best solution.
+ * @param <T> type of the points to cluster
+ * @since 3.2
+ */
+public class MultiKMeansPlusPlusClusterer<T extends Clusterable> extends Clusterer<T> {
+
+ /** The underlying k-means clusterer. */
+ private final KMeansPlusPlusClusterer<T> clusterer;
+
+ /** The number of trial runs. */
+ private final int numTrials;
+
+ /** The cluster evaluator to use. */
+ private final ClusterEvaluator<T> evaluator;
+
+ /** Build a clusterer.
+ * @param clusterer the k-means clusterer to use
+ * @param numTrials number of trial runs
+ */
+ public MultiKMeansPlusPlusClusterer(final KMeansPlusPlusClusterer<T> clusterer,
+ final int numTrials) {
+ this(clusterer, numTrials, new SumOfClusterVariances<T>(clusterer.getDistanceMeasure()));
+ }
+
+ /** Build a clusterer.
+ * @param clusterer the k-means clusterer to use
+ * @param numTrials number of trial runs
+ * @param evaluator the cluster evaluator to use
+ * @since 3.3
+ */
+ public MultiKMeansPlusPlusClusterer(final KMeansPlusPlusClusterer<T> clusterer,
+ final int numTrials,
+ final ClusterEvaluator<T> evaluator) {
+ super(clusterer.getDistanceMeasure());
+ this.clusterer = clusterer;
+ this.numTrials = numTrials;
+ this.evaluator = evaluator;
+ }
+
+ /**
+ * Returns the embedded k-means clusterer used by this instance.
+ * @return the embedded clusterer
+ */
+ public KMeansPlusPlusClusterer<T> getClusterer() {
+ return clusterer;
+ }
+
+ /**
+ * Returns the number of trials this instance will do.
+ * @return the number of trials
+ */
+ public int getNumTrials() {
+ return numTrials;
+ }
+
+ /**
+ * Returns the {@link ClusterEvaluator} used to determine the "best" clustering.
+ * @return the used {@link ClusterEvaluator}
+ * @since 3.3
+ */
+ public ClusterEvaluator<T> getClusterEvaluator() {
+ return evaluator;
+ }
+
+ /**
+ * Runs the K-means++ clustering algorithm.
+ *
+ * @param points the points to cluster
+ * @return a list of clusters containing the points
+ * @throws MathIllegalArgumentException if the data points are null or the number
+ * of clusters is larger than the number of data points
+ * @throws ConvergenceException if an empty cluster is encountered and the
+ * underlying {@link KMeansPlusPlusClusterer} has its
+ * {@link KMeansPlusPlusClusterer.EmptyClusterStrategy} is set to {@code ERROR}.
+ */
+ @Override
+ public List<CentroidCluster<T>> cluster(final Collection<T> points)
+ throws MathIllegalArgumentException, ConvergenceException {
+
+ // at first, we have not found any clusters list yet
+ List<CentroidCluster<T>> best = null;
+ double bestVarianceSum = Double.POSITIVE_INFINITY;
+
+ // do several clustering trials
+ for (int i = 0; i < numTrials; ++i) {
+
+ // compute a clusters list
+ List<CentroidCluster<T>> clusters = clusterer.cluster(points);
+
+ // compute the variance of the current list
+ final double varianceSum = evaluator.score(clusters);
+
+ if (evaluator.isBetterScore(varianceSum, bestVarianceSum)) {
+ // this one is the best we have found so far, remember it
+ best = clusters;
+ bestVarianceSum = varianceSum;
+ }
+
+ }
+
+ // return the best clusters list found
+ return best;
+
+ }
+
+}
diff --git a/src/main/java/org/apache/commons/math3/ml/clustering/evaluation/ClusterEvaluator.java b/src/main/java/org/apache/commons/math3/ml/clustering/evaluation/ClusterEvaluator.java
new file mode 100644
index 0000000..2bb8ba3
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/ml/clustering/evaluation/ClusterEvaluator.java
@@ -0,0 +1,122 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.math3.ml.clustering.evaluation;
+
+import java.util.List;
+
+import org.apache.commons.math3.ml.clustering.CentroidCluster;
+import org.apache.commons.math3.ml.clustering.Cluster;
+import org.apache.commons.math3.ml.clustering.Clusterable;
+import org.apache.commons.math3.ml.clustering.DoublePoint;
+import org.apache.commons.math3.ml.distance.DistanceMeasure;
+import org.apache.commons.math3.ml.distance.EuclideanDistance;
+
+/**
+ * Base class for cluster evaluation methods.
+ *
+ * @param <T> type of the clustered points
+ * @since 3.3
+ */
+public abstract class ClusterEvaluator<T extends Clusterable> {
+
+ /** The distance measure to use when evaluating the cluster. */
+ private final DistanceMeasure measure;
+
+ /**
+ * Creates a new cluster evaluator with an {@link EuclideanDistance}
+ * as distance measure.
+ */
+ public ClusterEvaluator() {
+ this(new EuclideanDistance());
+ }
+
+ /**
+ * Creates a new cluster evaluator with the given distance measure.
+ * @param measure the distance measure to use
+ */
+ public ClusterEvaluator(final DistanceMeasure measure) {
+ this.measure = measure;
+ }
+
+ /**
+ * Computes the evaluation score for the given list of clusters.
+ * @param clusters the clusters to evaluate
+ * @return the computed score
+ */
+ public abstract double score(List<? extends Cluster<T>> clusters);
+
+ /**
+ * Returns whether the first evaluation score is considered to be better
+ * than the second one by this evaluator.
+ * <p>
+ * Specific implementations shall override this method if the returned scores
+ * do not follow the same ordering, i.e. smaller score is better.
+ *
+ * @param score1 the first score
+ * @param score2 the second score
+ * @return {@code true} if the first score is considered to be better, {@code false} otherwise
+ */
+ public boolean isBetterScore(double score1, double score2) {
+ return score1 < score2;
+ }
+
+ /**
+ * Calculates the distance between two {@link Clusterable} instances
+ * with the configured {@link DistanceMeasure}.
+ *
+ * @param p1 the first clusterable
+ * @param p2 the second clusterable
+ * @return the distance between the two clusterables
+ */
+ protected double distance(final Clusterable p1, final Clusterable p2) {
+ return measure.compute(p1.getPoint(), p2.getPoint());
+ }
+
+ /**
+ * Computes the centroid for a cluster.
+ *
+ * @param cluster the cluster
+ * @return the computed centroid for the cluster,
+ * or {@code null} if the cluster does not contain any points
+ */
+ protected Clusterable centroidOf(final Cluster<T> cluster) {
+ final List<T> points = cluster.getPoints();
+ if (points.isEmpty()) {
+ return null;
+ }
+
+ // in case the cluster is of type CentroidCluster, no need to compute the centroid
+ if (cluster instanceof CentroidCluster) {
+ return ((CentroidCluster<T>) cluster).getCenter();
+ }
+
+ final int dimension = points.get(0).getPoint().length;
+ final double[] centroid = new double[dimension];
+ for (final T p : points) {
+ final double[] point = p.getPoint();
+ for (int i = 0; i < centroid.length; i++) {
+ centroid[i] += point[i];
+ }
+ }
+ for (int i = 0; i < centroid.length; i++) {
+ centroid[i] /= points.size();
+ }
+ return new DoublePoint(centroid);
+ }
+
+}
diff --git a/src/main/java/org/apache/commons/math3/ml/clustering/evaluation/SumOfClusterVariances.java b/src/main/java/org/apache/commons/math3/ml/clustering/evaluation/SumOfClusterVariances.java
new file mode 100644
index 0000000..b5b249c
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/ml/clustering/evaluation/SumOfClusterVariances.java
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.math3.ml.clustering.evaluation;
+
+import java.util.List;
+
+import org.apache.commons.math3.ml.clustering.Cluster;
+import org.apache.commons.math3.ml.clustering.Clusterable;
+import org.apache.commons.math3.ml.distance.DistanceMeasure;
+import org.apache.commons.math3.stat.descriptive.moment.Variance;
+
+/**
+ * Computes the sum of intra-cluster distance variances according to the formula:
+ * <pre>
+ * \( score = \sum\limits_{i=1}^n \sigma_i^2 \)
+ * </pre>
+ * where n is the number of clusters and \( \sigma_i^2 \) is the variance of
+ * intra-cluster distances of cluster \( c_i \).
+ *
+ * @param <T> the type of the clustered points
+ * @since 3.3
+ */
+public class SumOfClusterVariances<T extends Clusterable> extends ClusterEvaluator<T> {
+
+ /**
+ *
+ * @param measure the distance measure to use
+ */
+ public SumOfClusterVariances(final DistanceMeasure measure) {
+ super(measure);
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public double score(final List<? extends Cluster<T>> clusters) {
+ double varianceSum = 0.0;
+ for (final Cluster<T> cluster : clusters) {
+ if (!cluster.getPoints().isEmpty()) {
+
+ final Clusterable center = centroidOf(cluster);
+
+ // compute the distance variance of the current cluster
+ final Variance stat = new Variance();
+ for (final T point : cluster.getPoints()) {
+ stat.increment(distance(point, center));
+ }
+ varianceSum += stat.getResult();
+
+ }
+ }
+ return varianceSum;
+ }
+
+}
diff --git a/src/main/java/org/apache/commons/math3/ml/clustering/evaluation/package-info.java b/src/main/java/org/apache/commons/math3/ml/clustering/evaluation/package-info.java
new file mode 100644
index 0000000..700f566
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/ml/clustering/evaluation/package-info.java
@@ -0,0 +1,20 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+/**
+ * Cluster evaluation methods.
+ */
+package org.apache.commons.math3.ml.clustering.evaluation;
diff --git a/src/main/java/org/apache/commons/math3/ml/clustering/package-info.java b/src/main/java/org/apache/commons/math3/ml/clustering/package-info.java
new file mode 100644
index 0000000..02f1d20
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/ml/clustering/package-info.java
@@ -0,0 +1,20 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+/**
+ * Clustering algorithms.
+ */
+package org.apache.commons.math3.ml.clustering;
diff --git a/src/main/java/org/apache/commons/math3/ml/distance/CanberraDistance.java b/src/main/java/org/apache/commons/math3/ml/distance/CanberraDistance.java
new file mode 100644
index 0000000..d467c3b
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/ml/distance/CanberraDistance.java
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.commons.math3.ml.distance;
+
+import org.apache.commons.math3.exception.DimensionMismatchException;
+import org.apache.commons.math3.util.FastMath;
+import org.apache.commons.math3.util.MathArrays;
+
+/**
+ * Calculates the Canberra distance between two points.
+ *
+ * @since 3.2
+ */
+public class CanberraDistance implements DistanceMeasure {
+
+ /** Serializable version identifier. */
+ private static final long serialVersionUID = -6972277381587032228L;
+
+ /** {@inheritDoc} */
+ public double compute(double[] a, double[] b)
+ throws DimensionMismatchException {
+ MathArrays.checkEqualLength(a, b);
+ double sum = 0;
+ for (int i = 0; i < a.length; i++) {
+ final double num = FastMath.abs(a[i] - b[i]);
+ final double denom = FastMath.abs(a[i]) + FastMath.abs(b[i]);
+ sum += num == 0.0 && denom == 0.0 ? 0.0 : num / denom;
+ }
+ return sum;
+ }
+
+}
diff --git a/src/main/java/org/apache/commons/math3/ml/distance/ChebyshevDistance.java b/src/main/java/org/apache/commons/math3/ml/distance/ChebyshevDistance.java
new file mode 100644
index 0000000..05dccb5
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/ml/distance/ChebyshevDistance.java
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.commons.math3.ml.distance;
+
+import org.apache.commons.math3.exception.DimensionMismatchException;
+import org.apache.commons.math3.util.MathArrays;
+
+/**
+ * Calculates the L<sub>&infin;</sub> (max of abs) distance between two points.
+ *
+ * @since 3.2
+ */
+public class ChebyshevDistance implements DistanceMeasure {
+
+ /** Serializable version identifier. */
+ private static final long serialVersionUID = -4694868171115238296L;
+
+ /** {@inheritDoc} */
+ public double compute(double[] a, double[] b)
+ throws DimensionMismatchException {
+ return MathArrays.distanceInf(a, b);
+ }
+
+}
diff --git a/src/main/java/org/apache/commons/math3/ml/distance/DistanceMeasure.java b/src/main/java/org/apache/commons/math3/ml/distance/DistanceMeasure.java
new file mode 100644
index 0000000..ff9c27f
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/ml/distance/DistanceMeasure.java
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.commons.math3.ml.distance;
+
+import java.io.Serializable;
+
+import org.apache.commons.math3.exception.DimensionMismatchException;
+
+/**
+ * Interface for distance measures of n-dimensional vectors.
+ *
+ * @since 3.2
+ */
+public interface DistanceMeasure extends Serializable {
+
+ /**
+ * Compute the distance between two n-dimensional vectors.
+ * <p>
+ * The two vectors are required to have the same dimension.
+ *
+ * @param a the first vector
+ * @param b the second vector
+ * @return the distance between the two vectors
+ * @throws DimensionMismatchException if the array lengths differ.
+ */
+ double compute(double[] a, double[] b) throws DimensionMismatchException;
+}
diff --git a/src/main/java/org/apache/commons/math3/ml/distance/EarthMoversDistance.java b/src/main/java/org/apache/commons/math3/ml/distance/EarthMoversDistance.java
new file mode 100644
index 0000000..2518624
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/ml/distance/EarthMoversDistance.java
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.commons.math3.ml.distance;
+
+import org.apache.commons.math3.exception.DimensionMismatchException;
+import org.apache.commons.math3.util.FastMath;
+import org.apache.commons.math3.util.MathArrays;
+
+/**
+ * Calculates the Earh Mover's distance (also known as Wasserstein metric) between two distributions.
+ *
+ * @see <a href="http://en.wikipedia.org/wiki/Earth_mover's_distance">Earth Mover's distance (Wikipedia)</a>
+ *
+ * @since 3.3
+ */
+public class EarthMoversDistance implements DistanceMeasure {
+
+ /** Serializable version identifier. */
+ private static final long serialVersionUID = -5406732779747414922L;
+
+ /** {@inheritDoc} */
+ public double compute(double[] a, double[] b)
+ throws DimensionMismatchException {
+ MathArrays.checkEqualLength(a, b);
+ double lastDistance = 0;
+ double totalDistance = 0;
+ for (int i = 0; i < a.length; i++) {
+ final double currentDistance = (a[i] + lastDistance) - b[i];
+ totalDistance += FastMath.abs(currentDistance);
+ lastDistance = currentDistance;
+ }
+ return totalDistance;
+ }
+}
diff --git a/src/main/java/org/apache/commons/math3/ml/distance/EuclideanDistance.java b/src/main/java/org/apache/commons/math3/ml/distance/EuclideanDistance.java
new file mode 100644
index 0000000..187badc
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/ml/distance/EuclideanDistance.java
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.commons.math3.ml.distance;
+
+import org.apache.commons.math3.exception.DimensionMismatchException;
+import org.apache.commons.math3.util.MathArrays;
+
+/**
+ * Calculates the L<sub>2</sub> (Euclidean) distance between two points.
+ *
+ * @since 3.2
+ */
+public class EuclideanDistance implements DistanceMeasure {
+
+ /** Serializable version identifier. */
+ private static final long serialVersionUID = 1717556319784040040L;
+
+ /** {@inheritDoc} */
+ public double compute(double[] a, double[] b)
+ throws DimensionMismatchException {
+ return MathArrays.distance(a, b);
+ }
+
+}
diff --git a/src/main/java/org/apache/commons/math3/ml/distance/ManhattanDistance.java b/src/main/java/org/apache/commons/math3/ml/distance/ManhattanDistance.java
new file mode 100644
index 0000000..2eebe1b
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/ml/distance/ManhattanDistance.java
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.commons.math3.ml.distance;
+
+import org.apache.commons.math3.exception.DimensionMismatchException;
+import org.apache.commons.math3.util.MathArrays;
+
+/**
+ * Calculates the L<sub>1</sub> (sum of abs) distance between two points.
+ *
+ * @since 3.2
+ */
+public class ManhattanDistance implements DistanceMeasure {
+
+ /** Serializable version identifier. */
+ private static final long serialVersionUID = -9108154600539125566L;
+
+ /** {@inheritDoc} */
+ public double compute(double[] a, double[] b)
+ throws DimensionMismatchException {
+ return MathArrays.distance1(a, b);
+ }
+
+}
diff --git a/src/main/java/org/apache/commons/math3/ml/distance/package-info.java b/src/main/java/org/apache/commons/math3/ml/distance/package-info.java
new file mode 100644
index 0000000..f6d124a
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/ml/distance/package-info.java
@@ -0,0 +1,20 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+/**
+ * Common distance measures.
+ */
+package org.apache.commons.math3.ml.distance;
diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/FeatureInitializer.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/FeatureInitializer.java
new file mode 100644
index 0000000..1f48d45
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/FeatureInitializer.java
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.math3.ml.neuralnet;
+
+/**
+ * Defines how to assign the first value of a neuron's feature.
+ *
+ * @since 3.3
+ */
+public interface FeatureInitializer {
+ /**
+ * Selects the initial value.
+ *
+ * @return the initial value.
+ */
+ double value();
+}
diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/FeatureInitializerFactory.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/FeatureInitializerFactory.java
new file mode 100644
index 0000000..f5569b1
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/FeatureInitializerFactory.java
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.math3.ml.neuralnet;
+
+import org.apache.commons.math3.distribution.RealDistribution;
+import org.apache.commons.math3.distribution.UniformRealDistribution;
+import org.apache.commons.math3.analysis.UnivariateFunction;
+import org.apache.commons.math3.analysis.function.Constant;
+import org.apache.commons.math3.random.RandomGenerator;
+
+/**
+ * Creates functions that will select the initial values of a neuron's
+ * features.
+ *
+ * @since 3.3
+ */
+public class FeatureInitializerFactory {
+ /** Class contains only static methods. */
+ private FeatureInitializerFactory() {}
+
+ /**
+ * Uniform sampling of the given range.
+ *
+ * @param min Lower bound of the range.
+ * @param max Upper bound of the range.
+ * @param rng Random number generator used to draw samples from a
+ * uniform distribution.
+ * @return an initializer such that the features will be initialized with
+ * values within the given range.
+ * @throws org.apache.commons.math3.exception.NumberIsTooLargeException
+ * if {@code min >= max}.
+ */
+ public static FeatureInitializer uniform(final RandomGenerator rng,
+ final double min,
+ final double max) {
+ return randomize(new UniformRealDistribution(rng, min, max),
+ function(new Constant(0), 0, 0));
+ }
+
+ /**
+ * Uniform sampling of the given range.
+ *
+ * @param min Lower bound of the range.
+ * @param max Upper bound of the range.
+ * @return an initializer such that the features will be initialized with
+ * values within the given range.
+ * @throws org.apache.commons.math3.exception.NumberIsTooLargeException
+ * if {@code min >= max}.
+ */
+ public static FeatureInitializer uniform(final double min,
+ final double max) {
+ return randomize(new UniformRealDistribution(min, max),
+ function(new Constant(0), 0, 0));
+ }
+
+ /**
+ * Creates an initializer from a univariate function {@code f(x)}.
+ * The argument {@code x} is set to {@code init} at the first call
+ * and will be incremented at each call.
+ *
+ * @param f Function.
+ * @param init Initial value.
+ * @param inc Increment
+ * @return the initializer.
+ */
+ public static FeatureInitializer function(final UnivariateFunction f,
+ final double init,
+ final double inc) {
+ return new FeatureInitializer() {
+ /** Argument. */
+ private double arg = init;
+
+ /** {@inheritDoc} */
+ public double value() {
+ final double result = f.value(arg);
+ arg += inc;
+ return result;
+ }
+ };
+ }
+
+ /**
+ * Adds some amount of random data to the given initializer.
+ *
+ * @param random Random variable distribution.
+ * @param orig Original initializer.
+ * @return an initializer whose {@link FeatureInitializer#value() value}
+ * method will return {@code orig.value() + random.sample()}.
+ */
+ public static FeatureInitializer randomize(final RealDistribution random,
+ final FeatureInitializer orig) {
+ return new FeatureInitializer() {
+ /** {@inheritDoc} */
+ public double value() {
+ return orig.value() + random.sample();
+ }
+ };
+ }
+}
diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/MapUtils.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/MapUtils.java
new file mode 100644
index 0000000..0b7a675
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/MapUtils.java
@@ -0,0 +1,326 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.math3.ml.neuralnet;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Comparator;
+
+import org.apache.commons.math3.exception.NoDataException;
+import org.apache.commons.math3.ml.distance.DistanceMeasure;
+import org.apache.commons.math3.ml.neuralnet.twod.NeuronSquareMesh2D;
+import org.apache.commons.math3.util.Pair;
+
+/**
+ * Utilities for network maps.
+ *
+ * @since 3.3
+ */
+public class MapUtils {
+ /**
+ * Class contains only static methods.
+ */
+ private MapUtils() {}
+
+ /**
+ * Finds the neuron that best matches the given features.
+ *
+ * @param features Data.
+ * @param neurons List of neurons to scan. If the list is empty
+ * {@code null} will be returned.
+ * @param distance Distance function. The neuron's features are
+ * passed as the first argument to {@link DistanceMeasure#compute(double[],double[])}.
+ * @return the neuron whose features are closest to the given data.
+ * @throws org.apache.commons.math3.exception.DimensionMismatchException
+ * if the size of the input is not compatible with the neurons features
+ * size.
+ */
+ public static Neuron findBest(double[] features,
+ Iterable<Neuron> neurons,
+ DistanceMeasure distance) {
+ Neuron best = null;
+ double min = Double.POSITIVE_INFINITY;
+ for (final Neuron n : neurons) {
+ final double d = distance.compute(n.getFeatures(), features);
+ if (d < min) {
+ min = d;
+ best = n;
+ }
+ }
+
+ return best;
+ }
+
+ /**
+ * Finds the two neurons that best match the given features.
+ *
+ * @param features Data.
+ * @param neurons List of neurons to scan. If the list is empty
+ * {@code null} will be returned.
+ * @param distance Distance function. The neuron's features are
+ * passed as the first argument to {@link DistanceMeasure#compute(double[],double[])}.
+ * @return the two neurons whose features are closest to the given data.
+ * @throws org.apache.commons.math3.exception.DimensionMismatchException
+ * if the size of the input is not compatible with the neurons features
+ * size.
+ */
+ public static Pair<Neuron, Neuron> findBestAndSecondBest(double[] features,
+ Iterable<Neuron> neurons,
+ DistanceMeasure distance) {
+ Neuron[] best = { null, null };
+ double[] min = { Double.POSITIVE_INFINITY,
+ Double.POSITIVE_INFINITY };
+ for (final Neuron n : neurons) {
+ final double d = distance.compute(n.getFeatures(), features);
+ if (d < min[0]) {
+ // Replace second best with old best.
+ min[1] = min[0];
+ best[1] = best[0];
+
+ // Store current as new best.
+ min[0] = d;
+ best[0] = n;
+ } else if (d < min[1]) {
+ // Replace old second best with current.
+ min[1] = d;
+ best[1] = n;
+ }
+ }
+
+ return new Pair<Neuron, Neuron>(best[0], best[1]);
+ }
+
+ /**
+ * Creates a list of neurons sorted in increased order of the distance
+ * to the given {@code features}.
+ *
+ * @param features Data.
+ * @param neurons List of neurons to scan. If it is empty, an empty array
+ * will be returned.
+ * @param distance Distance function.
+ * @return the neurons, sorted in increasing order of distance in data
+ * space.
+ * @throws org.apache.commons.math3.exception.DimensionMismatchException
+ * if the size of the input is not compatible with the neurons features
+ * size.
+ *
+ * @see #findBest(double[],Iterable,DistanceMeasure)
+ * @see #findBestAndSecondBest(double[],Iterable,DistanceMeasure)
+ *
+ * @since 3.6
+ */
+ public static Neuron[] sort(double[] features,
+ Iterable<Neuron> neurons,
+ DistanceMeasure distance) {
+ final List<PairNeuronDouble> list = new ArrayList<PairNeuronDouble>();
+
+ for (final Neuron n : neurons) {
+ final double d = distance.compute(n.getFeatures(), features);
+ list.add(new PairNeuronDouble(n, d));
+ }
+
+ Collections.sort(list, PairNeuronDouble.COMPARATOR);
+
+ final int len = list.size();
+ final Neuron[] sorted = new Neuron[len];
+
+ for (int i = 0; i < len; i++) {
+ sorted[i] = list.get(i).getNeuron();
+ }
+ return sorted;
+ }
+
+ /**
+ * Computes the <a href="http://en.wikipedia.org/wiki/U-Matrix">
+ * U-matrix</a> of a two-dimensional map.
+ *
+ * @param map Network.
+ * @param distance Function to use for computing the average
+ * distance from a neuron to its neighbours.
+ * @return the matrix of average distances.
+ */
+ public static double[][] computeU(NeuronSquareMesh2D map,
+ DistanceMeasure distance) {
+ final int numRows = map.getNumberOfRows();
+ final int numCols = map.getNumberOfColumns();
+ final double[][] uMatrix = new double[numRows][numCols];
+
+ final Network net = map.getNetwork();
+
+ for (int i = 0; i < numRows; i++) {
+ for (int j = 0; j < numCols; j++) {
+ final Neuron neuron = map.getNeuron(i, j);
+ final Collection<Neuron> neighbours = net.getNeighbours(neuron);
+ final double[] features = neuron.getFeatures();
+
+ double d = 0;
+ int count = 0;
+ for (Neuron n : neighbours) {
+ ++count;
+ d += distance.compute(features, n.getFeatures());
+ }
+
+ uMatrix[i][j] = d / count;
+ }
+ }
+
+ return uMatrix;
+ }
+
+ /**
+ * Computes the "hit" histogram of a two-dimensional map.
+ *
+ * @param data Feature vectors.
+ * @param map Network.
+ * @param distance Function to use for determining the best matching unit.
+ * @return the number of hits for each neuron in the map.
+ */
+ public static int[][] computeHitHistogram(Iterable<double[]> data,
+ NeuronSquareMesh2D map,
+ DistanceMeasure distance) {
+ final HashMap<Neuron, Integer> hit = new HashMap<Neuron, Integer>();
+ final Network net = map.getNetwork();
+
+ for (double[] f : data) {
+ final Neuron best = findBest(f, net, distance);
+ final Integer count = hit.get(best);
+ if (count == null) {
+ hit.put(best, 1);
+ } else {
+ hit.put(best, count + 1);
+ }
+ }
+
+ // Copy the histogram data into a 2D map.
+ final int numRows = map.getNumberOfRows();
+ final int numCols = map.getNumberOfColumns();
+ final int[][] histo = new int[numRows][numCols];
+
+ for (int i = 0; i < numRows; i++) {
+ for (int j = 0; j < numCols; j++) {
+ final Neuron neuron = map.getNeuron(i, j);
+ final Integer count = hit.get(neuron);
+ if (count == null) {
+ histo[i][j] = 0;
+ } else {
+ histo[i][j] = count;
+ }
+ }
+ }
+
+ return histo;
+ }
+
+ /**
+ * Computes the quantization error.
+ * The quantization error is the average distance between a feature vector
+ * and its "best matching unit" (closest neuron).
+ *
+ * @param data Feature vectors.
+ * @param neurons List of neurons to scan.
+ * @param distance Distance function.
+ * @return the error.
+ * @throws NoDataException if {@code data} is empty.
+ */
+ public static double computeQuantizationError(Iterable<double[]> data,
+ Iterable<Neuron> neurons,
+ DistanceMeasure distance) {
+ double d = 0;
+ int count = 0;
+ for (double[] f : data) {
+ ++count;
+ d += distance.compute(f, findBest(f, neurons, distance).getFeatures());
+ }
+
+ if (count == 0) {
+ throw new NoDataException();
+ }
+
+ return d / count;
+ }
+
+ /**
+ * Computes the topographic error.
+ * The topographic error is the proportion of data for which first and
+ * second best matching units are not adjacent in the map.
+ *
+ * @param data Feature vectors.
+ * @param net Network.
+ * @param distance Distance function.
+ * @return the error.
+ * @throws NoDataException if {@code data} is empty.
+ */
+ public static double computeTopographicError(Iterable<double[]> data,
+ Network net,
+ DistanceMeasure distance) {
+ int notAdjacentCount = 0;
+ int count = 0;
+ for (double[] f : data) {
+ ++count;
+ final Pair<Neuron, Neuron> p = findBestAndSecondBest(f, net, distance);
+ if (!net.getNeighbours(p.getFirst()).contains(p.getSecond())) {
+ // Increment count if first and second best matching units
+ // are not neighbours.
+ ++notAdjacentCount;
+ }
+ }
+
+ if (count == 0) {
+ throw new NoDataException();
+ }
+
+ return ((double) notAdjacentCount) / count;
+ }
+
+ /**
+ * Helper data structure holding a (Neuron, double) pair.
+ */
+ private static class PairNeuronDouble {
+ /** Comparator. */
+ static final Comparator<PairNeuronDouble> COMPARATOR
+ = new Comparator<PairNeuronDouble>() {
+ /** {@inheritDoc} */
+ public int compare(PairNeuronDouble o1,
+ PairNeuronDouble o2) {
+ return Double.compare(o1.value, o2.value);
+ }
+ };
+ /** Key. */
+ private final Neuron neuron;
+ /** Value. */
+ private final double value;
+
+ /**
+ * @param neuron Neuron.
+ * @param value Value.
+ */
+ PairNeuronDouble(Neuron neuron, double value) {
+ this.neuron = neuron;
+ this.value = value;
+ }
+
+ /** @return the neuron. */
+ public Neuron getNeuron() {
+ return neuron;
+ }
+
+ }
+}
diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/Network.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/Network.java
new file mode 100644
index 0000000..4b208a3
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/Network.java
@@ -0,0 +1,499 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.math3.ml.neuralnet;
+
+import java.io.Serializable;
+import java.io.ObjectInputStream;
+import java.util.NoSuchElementException;
+import java.util.List;
+import java.util.ArrayList;
+import java.util.Set;
+import java.util.HashSet;
+import java.util.Collection;
+import java.util.Iterator;
+import java.util.Comparator;
+import java.util.Collections;
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.atomic.AtomicLong;
+import org.apache.commons.math3.exception.DimensionMismatchException;
+import org.apache.commons.math3.exception.MathIllegalStateException;
+
+/**
+ * Neural network, composed of {@link Neuron} instances and the links
+ * between them.
+ *
+ * Although updating a neuron's state is thread-safe, modifying the
+ * network's topology (adding or removing links) is not.
+ *
+ * @since 3.3
+ */
+public class Network
+ implements Iterable<Neuron>,
+ Serializable {
+ /** Serializable. */
+ private static final long serialVersionUID = 20130207L;
+ /** Neurons. */
+ private final ConcurrentHashMap<Long, Neuron> neuronMap
+ = new ConcurrentHashMap<Long, Neuron>();
+ /** Next available neuron identifier. */
+ private final AtomicLong nextId;
+ /** Neuron's features set size. */
+ private final int featureSize;
+ /** Links. */
+ private final ConcurrentHashMap<Long, Set<Long>> linkMap
+ = new ConcurrentHashMap<Long, Set<Long>>();
+
+ /**
+ * Comparator that prescribes an order of the neurons according
+ * to the increasing order of their identifier.
+ */
+ public static class NeuronIdentifierComparator
+ implements Comparator<Neuron>,
+ Serializable {
+ /** Version identifier. */
+ private static final long serialVersionUID = 20130207L;
+
+ /** {@inheritDoc} */
+ public int compare(Neuron a,
+ Neuron b) {
+ final long aId = a.getIdentifier();
+ final long bId = b.getIdentifier();
+ return aId < bId ? -1 :
+ aId > bId ? 1 : 0;
+ }
+ }
+
+ /**
+ * Constructor with restricted access, solely used for deserialization.
+ *
+ * @param nextId Next available identifier.
+ * @param featureSize Number of features.
+ * @param neuronList Neurons.
+ * @param neighbourIdList Links associated to each of the neurons in
+ * {@code neuronList}.
+ * @throws MathIllegalStateException if an inconsistency is detected
+ * (which probably means that the serialized form has been corrupted).
+ */
+ Network(long nextId,
+ int featureSize,
+ Neuron[] neuronList,
+ long[][] neighbourIdList) {
+ final int numNeurons = neuronList.length;
+ if (numNeurons != neighbourIdList.length) {
+ throw new MathIllegalStateException();
+ }
+
+ for (int i = 0; i < numNeurons; i++) {
+ final Neuron n = neuronList[i];
+ final long id = n.getIdentifier();
+ if (id >= nextId) {
+ throw new MathIllegalStateException();
+ }
+ neuronMap.put(id, n);
+ linkMap.put(id, new HashSet<Long>());
+ }
+
+ for (int i = 0; i < numNeurons; i++) {
+ final long aId = neuronList[i].getIdentifier();
+ final Set<Long> aLinks = linkMap.get(aId);
+ for (Long bId : neighbourIdList[i]) {
+ if (neuronMap.get(bId) == null) {
+ throw new MathIllegalStateException();
+ }
+ addLinkToLinkSet(aLinks, bId);
+ }
+ }
+
+ this.nextId = new AtomicLong(nextId);
+ this.featureSize = featureSize;
+ }
+
+ /**
+ * @param initialIdentifier Identifier for the first neuron that
+ * will be added to this network.
+ * @param featureSize Size of the neuron's features.
+ */
+ public Network(long initialIdentifier,
+ int featureSize) {
+ nextId = new AtomicLong(initialIdentifier);
+ this.featureSize = featureSize;
+ }
+
+ /**
+ * Performs a deep copy of this instance.
+ * Upon return, the copied and original instances will be independent:
+ * Updating one will not affect the other.
+ *
+ * @return a new instance with the same state as this instance.
+ * @since 3.6
+ */
+ public synchronized Network copy() {
+ final Network copy = new Network(nextId.get(),
+ featureSize);
+
+
+ for (Map.Entry<Long, Neuron> e : neuronMap.entrySet()) {
+ copy.neuronMap.put(e.getKey(), e.getValue().copy());
+ }
+
+ for (Map.Entry<Long, Set<Long>> e : linkMap.entrySet()) {
+ copy.linkMap.put(e.getKey(), new HashSet<Long>(e.getValue()));
+ }
+
+ return copy;
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ public Iterator<Neuron> iterator() {
+ return neuronMap.values().iterator();
+ }
+
+ /**
+ * Creates a list of the neurons, sorted in a custom order.
+ *
+ * @param comparator {@link Comparator} used for sorting the neurons.
+ * @return a list of neurons, sorted in the order prescribed by the
+ * given {@code comparator}.
+ * @see NeuronIdentifierComparator
+ */
+ public Collection<Neuron> getNeurons(Comparator<Neuron> comparator) {
+ final List<Neuron> neurons = new ArrayList<Neuron>();
+ neurons.addAll(neuronMap.values());
+
+ Collections.sort(neurons, comparator);
+
+ return neurons;
+ }
+
+ /**
+ * Creates a neuron and assigns it a unique identifier.
+ *
+ * @param features Initial values for the neuron's features.
+ * @return the neuron's identifier.
+ * @throws DimensionMismatchException if the length of {@code features}
+ * is different from the expected size (as set by the
+ * {@link #Network(long,int) constructor}).
+ */
+ public long createNeuron(double[] features) {
+ if (features.length != featureSize) {
+ throw new DimensionMismatchException(features.length, featureSize);
+ }
+
+ final long id = createNextId();
+ neuronMap.put(id, new Neuron(id, features));
+ linkMap.put(id, new HashSet<Long>());
+ return id;
+ }
+
+ /**
+ * Deletes a neuron.
+ * Links from all neighbours to the removed neuron will also be
+ * {@link #deleteLink(Neuron,Neuron) deleted}.
+ *
+ * @param neuron Neuron to be removed from this network.
+ * @throws NoSuchElementException if {@code n} does not belong to
+ * this network.
+ */
+ public void deleteNeuron(Neuron neuron) {
+ final Collection<Neuron> neighbours = getNeighbours(neuron);
+
+ // Delete links to from neighbours.
+ for (Neuron n : neighbours) {
+ deleteLink(n, neuron);
+ }
+
+ // Remove neuron.
+ neuronMap.remove(neuron.getIdentifier());
+ }
+
+ /**
+ * Gets the size of the neurons' features set.
+ *
+ * @return the size of the features set.
+ */
+ public int getFeaturesSize() {
+ return featureSize;
+ }
+
+ /**
+ * Adds a link from neuron {@code a} to neuron {@code b}.
+ * Note: the link is not bi-directional; if a bi-directional link is
+ * required, an additional call must be made with {@code a} and
+ * {@code b} exchanged in the argument list.
+ *
+ * @param a Neuron.
+ * @param b Neuron.
+ * @throws NoSuchElementException if the neurons do not exist in the
+ * network.
+ */
+ public void addLink(Neuron a,
+ Neuron b) {
+ final long aId = a.getIdentifier();
+ final long bId = b.getIdentifier();
+
+ // Check that the neurons belong to this network.
+ if (a != getNeuron(aId)) {
+ throw new NoSuchElementException(Long.toString(aId));
+ }
+ if (b != getNeuron(bId)) {
+ throw new NoSuchElementException(Long.toString(bId));
+ }
+
+ // Add link from "a" to "b".
+ addLinkToLinkSet(linkMap.get(aId), bId);
+ }
+
+ /**
+ * Adds a link to neuron {@code id} in given {@code linkSet}.
+ * Note: no check verifies that the identifier indeed belongs
+ * to this network.
+ *
+ * @param linkSet Neuron identifier.
+ * @param id Neuron identifier.
+ */
+ private void addLinkToLinkSet(Set<Long> linkSet,
+ long id) {
+ linkSet.add(id);
+ }
+
+ /**
+ * Deletes the link between neurons {@code a} and {@code b}.
+ *
+ * @param a Neuron.
+ * @param b Neuron.
+ * @throws NoSuchElementException if the neurons do not exist in the
+ * network.
+ */
+ public void deleteLink(Neuron a,
+ Neuron b) {
+ final long aId = a.getIdentifier();
+ final long bId = b.getIdentifier();
+
+ // Check that the neurons belong to this network.
+ if (a != getNeuron(aId)) {
+ throw new NoSuchElementException(Long.toString(aId));
+ }
+ if (b != getNeuron(bId)) {
+ throw new NoSuchElementException(Long.toString(bId));
+ }
+
+ // Delete link from "a" to "b".
+ deleteLinkFromLinkSet(linkMap.get(aId), bId);
+ }
+
+ /**
+ * Deletes a link to neuron {@code id} in given {@code linkSet}.
+ * Note: no check verifies that the identifier indeed belongs
+ * to this network.
+ *
+ * @param linkSet Neuron identifier.
+ * @param id Neuron identifier.
+ */
+ private void deleteLinkFromLinkSet(Set<Long> linkSet,
+ long id) {
+ linkSet.remove(id);
+ }
+
+ /**
+ * Retrieves the neuron with the given (unique) {@code id}.
+ *
+ * @param id Identifier.
+ * @return the neuron associated with the given {@code id}.
+ * @throws NoSuchElementException if the neuron does not exist in the
+ * network.
+ */
+ public Neuron getNeuron(long id) {
+ final Neuron n = neuronMap.get(id);
+ if (n == null) {
+ throw new NoSuchElementException(Long.toString(id));
+ }
+ return n;
+ }
+
+ /**
+ * Retrieves the neurons in the neighbourhood of any neuron in the
+ * {@code neurons} list.
+ * @param neurons Neurons for which to retrieve the neighbours.
+ * @return the list of neighbours.
+ * @see #getNeighbours(Iterable,Iterable)
+ */
+ public Collection<Neuron> getNeighbours(Iterable<Neuron> neurons) {
+ return getNeighbours(neurons, null);
+ }
+
+ /**
+ * Retrieves the neurons in the neighbourhood of any neuron in the
+ * {@code neurons} list.
+ * The {@code exclude} list allows to retrieve the "concentric"
+ * neighbourhoods by removing the neurons that belong to the inner
+ * "circles".
+ *
+ * @param neurons Neurons for which to retrieve the neighbours.
+ * @param exclude Neurons to exclude from the returned list.
+ * Can be {@code null}.
+ * @return the list of neighbours.
+ */
+ public Collection<Neuron> getNeighbours(Iterable<Neuron> neurons,
+ Iterable<Neuron> exclude) {
+ final Set<Long> idList = new HashSet<Long>();
+
+ for (Neuron n : neurons) {
+ idList.addAll(linkMap.get(n.getIdentifier()));
+ }
+ if (exclude != null) {
+ for (Neuron n : exclude) {
+ idList.remove(n.getIdentifier());
+ }
+ }
+
+ final List<Neuron> neuronList = new ArrayList<Neuron>();
+ for (Long id : idList) {
+ neuronList.add(getNeuron(id));
+ }
+
+ return neuronList;
+ }
+
+ /**
+ * Retrieves the neighbours of the given neuron.
+ *
+ * @param neuron Neuron for which to retrieve the neighbours.
+ * @return the list of neighbours.
+ * @see #getNeighbours(Neuron,Iterable)
+ */
+ public Collection<Neuron> getNeighbours(Neuron neuron) {
+ return getNeighbours(neuron, null);
+ }
+
+ /**
+ * Retrieves the neighbours of the given neuron.
+ *
+ * @param neuron Neuron for which to retrieve the neighbours.
+ * @param exclude Neurons to exclude from the returned list.
+ * Can be {@code null}.
+ * @return the list of neighbours.
+ */
+ public Collection<Neuron> getNeighbours(Neuron neuron,
+ Iterable<Neuron> exclude) {
+ final Set<Long> idList = linkMap.get(neuron.getIdentifier());
+ if (exclude != null) {
+ for (Neuron n : exclude) {
+ idList.remove(n.getIdentifier());
+ }
+ }
+
+ final List<Neuron> neuronList = new ArrayList<Neuron>();
+ for (Long id : idList) {
+ neuronList.add(getNeuron(id));
+ }
+
+ return neuronList;
+ }
+
+ /**
+ * Creates a neuron identifier.
+ *
+ * @return a value that will serve as a unique identifier.
+ */
+ private Long createNextId() {
+ return nextId.getAndIncrement();
+ }
+
+ /**
+ * Prevents proxy bypass.
+ *
+ * @param in Input stream.
+ */
+ private void readObject(ObjectInputStream in) {
+ throw new IllegalStateException();
+ }
+
+ /**
+ * Custom serialization.
+ *
+ * @return the proxy instance that will be actually serialized.
+ */
+ private Object writeReplace() {
+ final Neuron[] neuronList = neuronMap.values().toArray(new Neuron[0]);
+ final long[][] neighbourIdList = new long[neuronList.length][];
+
+ for (int i = 0; i < neuronList.length; i++) {
+ final Collection<Neuron> neighbours = getNeighbours(neuronList[i]);
+ final long[] neighboursId = new long[neighbours.size()];
+ int count = 0;
+ for (Neuron n : neighbours) {
+ neighboursId[count] = n.getIdentifier();
+ ++count;
+ }
+ neighbourIdList[i] = neighboursId;
+ }
+
+ return new SerializationProxy(nextId.get(),
+ featureSize,
+ neuronList,
+ neighbourIdList);
+ }
+
+ /**
+ * Serialization.
+ */
+ private static class SerializationProxy implements Serializable {
+ /** Serializable. */
+ private static final long serialVersionUID = 20130207L;
+ /** Next identifier. */
+ private final long nextId;
+ /** Number of features. */
+ private final int featureSize;
+ /** Neurons. */
+ private final Neuron[] neuronList;
+ /** Links. */
+ private final long[][] neighbourIdList;
+
+ /**
+ * @param nextId Next available identifier.
+ * @param featureSize Number of features.
+ * @param neuronList Neurons.
+ * @param neighbourIdList Links associated to each of the neurons in
+ * {@code neuronList}.
+ */
+ SerializationProxy(long nextId,
+ int featureSize,
+ Neuron[] neuronList,
+ long[][] neighbourIdList) {
+ this.nextId = nextId;
+ this.featureSize = featureSize;
+ this.neuronList = neuronList;
+ this.neighbourIdList = neighbourIdList;
+ }
+
+ /**
+ * Custom serialization.
+ *
+ * @return the {@link Network} for which this instance is the proxy.
+ */
+ private Object readResolve() {
+ return new Network(nextId,
+ featureSize,
+ neuronList,
+ neighbourIdList);
+ }
+ }
+}
diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/Neuron.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/Neuron.java
new file mode 100644
index 0000000..8cae3ea
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/Neuron.java
@@ -0,0 +1,272 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.math3.ml.neuralnet;
+
+import java.io.Serializable;
+import java.io.ObjectInputStream;
+import java.util.concurrent.atomic.AtomicReference;
+import java.util.concurrent.atomic.AtomicLong;
+
+import org.apache.commons.math3.exception.DimensionMismatchException;
+import org.apache.commons.math3.util.Precision;
+
+
+/**
+ * Describes a neuron element of a neural network.
+ *
+ * This class aims to be thread-safe.
+ *
+ * @since 3.3
+ */
+public class Neuron implements Serializable {
+ /** Serializable. */
+ private static final long serialVersionUID = 20130207L;
+ /** Identifier. */
+ private final long identifier;
+ /** Length of the feature set. */
+ private final int size;
+ /** Neuron data. */
+ private final AtomicReference<double[]> features;
+ /** Number of attempts to update a neuron. */
+ private final AtomicLong numberOfAttemptedUpdates = new AtomicLong(0);
+ /** Number of successful updates of a neuron. */
+ private final AtomicLong numberOfSuccessfulUpdates = new AtomicLong(0);
+
+ /**
+ * Creates a neuron.
+ * The size of the feature set is fixed to the length of the given
+ * argument.
+ * <br/>
+ * Constructor is package-private: Neurons must be
+ * {@link Network#createNeuron(double[]) created} by the network
+ * instance to which they will belong.
+ *
+ * @param identifier Identifier (assigned by the {@link Network}).
+ * @param features Initial values of the feature set.
+ */
+ Neuron(long identifier,
+ double[] features) {
+ this.identifier = identifier;
+ this.size = features.length;
+ this.features = new AtomicReference<double[]>(features.clone());
+ }
+
+ /**
+ * Performs a deep copy of this instance.
+ * Upon return, the copied and original instances will be independent:
+ * Updating one will not affect the other.
+ *
+ * @return a new instance with the same state as this instance.
+ * @since 3.6
+ */
+ public synchronized Neuron copy() {
+ final Neuron copy = new Neuron(getIdentifier(),
+ getFeatures());
+ copy.numberOfAttemptedUpdates.set(numberOfAttemptedUpdates.get());
+ copy.numberOfSuccessfulUpdates.set(numberOfSuccessfulUpdates.get());
+
+ return copy;
+ }
+
+ /**
+ * Gets the neuron's identifier.
+ *
+ * @return the identifier.
+ */
+ public long getIdentifier() {
+ return identifier;
+ }
+
+ /**
+ * Gets the length of the feature set.
+ *
+ * @return the number of features.
+ */
+ public int getSize() {
+ return size;
+ }
+
+ /**
+ * Gets the neuron's features.
+ *
+ * @return a copy of the neuron's features.
+ */
+ public double[] getFeatures() {
+ return features.get().clone();
+ }
+
+ /**
+ * Tries to atomically update the neuron's features.
+ * Update will be performed only if the expected values match the
+ * current values.<br/>
+ * In effect, when concurrent threads call this method, the state
+ * could be modified by one, so that it does not correspond to the
+ * the state assumed by another.
+ * Typically, a caller {@link #getFeatures() retrieves the current state},
+ * and uses it to compute the new state.
+ * During this computation, another thread might have done the same
+ * thing, and updated the state: If the current thread were to proceed
+ * with its own update, it would overwrite the new state (which might
+ * already have been used by yet other threads).
+ * To prevent this, the method does not perform the update when a
+ * concurrent modification has been detected, and returns {@code false}.
+ * When this happens, the caller should fetch the new current state,
+ * redo its computation, and call this method again.
+ *
+ * @param expect Current values of the features, as assumed by the caller.
+ * Update will never succeed if the contents of this array does not match
+ * the values returned by {@link #getFeatures()}.
+ * @param update Features's new values.
+ * @return {@code true} if the update was successful, {@code false}
+ * otherwise.
+ * @throws DimensionMismatchException if the length of {@code update} is
+ * not the same as specified in the {@link #Neuron(long,double[])
+ * constructor}.
+ */
+ public boolean compareAndSetFeatures(double[] expect,
+ double[] update) {
+ if (update.length != size) {
+ throw new DimensionMismatchException(update.length, size);
+ }
+
+ // Get the internal reference. Note that this must not be a copy;
+ // otherwise the "compareAndSet" below will always fail.
+ final double[] current = features.get();
+ if (!containSameValues(current, expect)) {
+ // Some other thread already modified the state.
+ return false;
+ }
+
+ // Increment attempt counter.
+ numberOfAttemptedUpdates.incrementAndGet();
+
+ if (features.compareAndSet(current, update.clone())) {
+ // The current thread could atomically update the state (attempt succeeded).
+ numberOfSuccessfulUpdates.incrementAndGet();
+ return true;
+ } else {
+ // Some other thread came first (attempt failed).
+ return false;
+ }
+ }
+
+ /**
+ * Retrieves the number of calls to the
+ * {@link #compareAndSetFeatures(double[],double[]) compareAndSetFeatures}
+ * method.
+ * Note that if the caller wants to use this method in combination with
+ * {@link #getNumberOfSuccessfulUpdates()}, additional synchronization
+ * may be required to ensure consistency.
+ *
+ * @return the number of update attempts.
+ * @since 3.6
+ */
+ public long getNumberOfAttemptedUpdates() {
+ return numberOfAttemptedUpdates.get();
+ }
+
+ /**
+ * Retrieves the number of successful calls to the
+ * {@link #compareAndSetFeatures(double[],double[]) compareAndSetFeatures}
+ * method.
+ * Note that if the caller wants to use this method in combination with
+ * {@link #getNumberOfAttemptedUpdates()}, additional synchronization
+ * may be required to ensure consistency.
+ *
+ * @return the number of successful updates.
+ * @since 3.6
+ */
+ public long getNumberOfSuccessfulUpdates() {
+ return numberOfSuccessfulUpdates.get();
+ }
+
+ /**
+ * Checks whether the contents of both arrays is the same.
+ *
+ * @param current Current values.
+ * @param expect Expected values.
+ * @throws DimensionMismatchException if the length of {@code expected}
+ * is not the same as specified in the {@link #Neuron(long,double[])
+ * constructor}.
+ * @return {@code true} if the arrays contain the same values.
+ */
+ private boolean containSameValues(double[] current,
+ double[] expect) {
+ if (expect.length != size) {
+ throw new DimensionMismatchException(expect.length, size);
+ }
+
+ for (int i = 0; i < size; i++) {
+ if (!Precision.equals(current[i], expect[i])) {
+ return false;
+ }
+ }
+ return true;
+ }
+
+ /**
+ * Prevents proxy bypass.
+ *
+ * @param in Input stream.
+ */
+ private void readObject(ObjectInputStream in) {
+ throw new IllegalStateException();
+ }
+
+ /**
+ * Custom serialization.
+ *
+ * @return the proxy instance that will be actually serialized.
+ */
+ private Object writeReplace() {
+ return new SerializationProxy(identifier,
+ features.get());
+ }
+
+ /**
+ * Serialization.
+ */
+ private static class SerializationProxy implements Serializable {
+ /** Serializable. */
+ private static final long serialVersionUID = 20130207L;
+ /** Features. */
+ private final double[] features;
+ /** Identifier. */
+ private final long identifier;
+
+ /**
+ * @param identifier Identifier.
+ * @param features Features.
+ */
+ SerializationProxy(long identifier,
+ double[] features) {
+ this.identifier = identifier;
+ this.features = features;
+ }
+
+ /**
+ * Custom serialization.
+ *
+ * @return the {@link Neuron} for which this instance is the proxy.
+ */
+ private Object readResolve() {
+ return new Neuron(identifier,
+ features);
+ }
+ }
+}
diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/SquareNeighbourhood.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/SquareNeighbourhood.java
new file mode 100644
index 0000000..a3c0d95
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/SquareNeighbourhood.java
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.math3.ml.neuralnet;
+
+/**
+ * Defines neighbourhood types.
+ *
+ * @since 3.3
+ */
+public enum SquareNeighbourhood {
+ /**
+ * <a href="http://en.wikipedia.org/wiki/Von_Neumann_neighborhood"
+ * Von Neumann neighbourhood</a>: in two dimensions, each (internal)
+ * neuron has four neighbours.
+ */
+ VON_NEUMANN,
+ /**
+ * <a href="http://en.wikipedia.org/wiki/Moore_neighborhood"
+ * Moore neighbourhood</a>: in two dimensions, each (internal)
+ * neuron has eight neighbours.
+ */
+ MOORE,
+}
diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/UpdateAction.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/UpdateAction.java
new file mode 100644
index 0000000..041d3d6
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/UpdateAction.java
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.math3.ml.neuralnet;
+
+/**
+ * Describes how to update the network in response to a training
+ * sample.
+ *
+ * @since 3.3
+ */
+public interface UpdateAction {
+ /**
+ * Updates the network in response to the sample {@code features}.
+ *
+ * @param net Network.
+ * @param features Training data.
+ */
+ void update(Network net, double[] features);
+}
diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/oned/NeuronString.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/oned/NeuronString.java
new file mode 100644
index 0000000..fad6042
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/oned/NeuronString.java
@@ -0,0 +1,238 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.math3.ml.neuralnet.oned;
+
+import java.io.Serializable;
+import java.io.ObjectInputStream;
+import org.apache.commons.math3.ml.neuralnet.Network;
+import org.apache.commons.math3.ml.neuralnet.FeatureInitializer;
+import org.apache.commons.math3.exception.NumberIsTooSmallException;
+import org.apache.commons.math3.exception.OutOfRangeException;
+
+/**
+ * Neural network with the topology of a one-dimensional line.
+ * Each neuron defines one point on the line.
+ *
+ * @since 3.3
+ */
+public class NeuronString implements Serializable {
+ /** Serial version ID */
+ private static final long serialVersionUID = 1L;
+ /** Underlying network. */
+ private final Network network;
+ /** Number of neurons. */
+ private final int size;
+ /** Wrap. */
+ private final boolean wrap;
+
+ /**
+ * Mapping of the 1D coordinate to the neuron identifiers
+ * (attributed by the {@link #network} instance).
+ */
+ private final long[] identifiers;
+
+ /**
+ * Constructor with restricted access, solely used for deserialization.
+ *
+ * @param wrap Whether to wrap the dimension (i.e the first and last
+ * neurons will be linked together).
+ * @param featuresList Arrays that will initialize the features sets of
+ * the network's neurons.
+ * @throws NumberIsTooSmallException if {@code num < 2}.
+ */
+ NeuronString(boolean wrap,
+ double[][] featuresList) {
+ size = featuresList.length;
+
+ if (size < 2) {
+ throw new NumberIsTooSmallException(size, 2, true);
+ }
+
+ this.wrap = wrap;
+
+ final int fLen = featuresList[0].length;
+ network = new Network(0, fLen);
+ identifiers = new long[size];
+
+ // Add neurons.
+ for (int i = 0; i < size; i++) {
+ identifiers[i] = network.createNeuron(featuresList[i]);
+ }
+
+ // Add links.
+ createLinks();
+ }
+
+ /**
+ * Creates a one-dimensional network:
+ * Each neuron not located on the border of the mesh has two
+ * neurons linked to it.
+ * <br/>
+ * The links are bi-directional.
+ * Neurons created successively are neighbours (i.e. there are
+ * links between them).
+ * <br/>
+ * The topology of the network can also be a circle (if the
+ * dimension is wrapped).
+ *
+ * @param num Number of neurons.
+ * @param wrap Whether to wrap the dimension (i.e the first and last
+ * neurons will be linked together).
+ * @param featureInit Arrays that will initialize the features sets of
+ * the network's neurons.
+ * @throws NumberIsTooSmallException if {@code num < 2}.
+ */
+ public NeuronString(int num,
+ boolean wrap,
+ FeatureInitializer[] featureInit) {
+ if (num < 2) {
+ throw new NumberIsTooSmallException(num, 2, true);
+ }
+
+ size = num;
+ this.wrap = wrap;
+ identifiers = new long[num];
+
+ final int fLen = featureInit.length;
+ network = new Network(0, fLen);
+
+ // Add neurons.
+ for (int i = 0; i < num; i++) {
+ final double[] features = new double[fLen];
+ for (int fIndex = 0; fIndex < fLen; fIndex++) {
+ features[fIndex] = featureInit[fIndex].value();
+ }
+ identifiers[i] = network.createNeuron(features);
+ }
+
+ // Add links.
+ createLinks();
+ }
+
+ /**
+ * Retrieves the underlying network.
+ * A reference is returned (enabling, for example, the network to be
+ * trained).
+ * This also implies that calling methods that modify the {@link Network}
+ * topology may cause this class to become inconsistent.
+ *
+ * @return the network.
+ */
+ public Network getNetwork() {
+ return network;
+ }
+
+ /**
+ * Gets the number of neurons.
+ *
+ * @return the number of neurons.
+ */
+ public int getSize() {
+ return size;
+ }
+
+ /**
+ * Retrieves the features set from the neuron at location
+ * {@code i} in the map.
+ *
+ * @param i Neuron index.
+ * @return the features of the neuron at index {@code i}.
+ * @throws OutOfRangeException if {@code i} is out of range.
+ */
+ public double[] getFeatures(int i) {
+ if (i < 0 ||
+ i >= size) {
+ throw new OutOfRangeException(i, 0, size - 1);
+ }
+
+ return network.getNeuron(identifiers[i]).getFeatures();
+ }
+
+ /**
+ * Creates the neighbour relationships between neurons.
+ */
+ private void createLinks() {
+ for (int i = 0; i < size - 1; i++) {
+ network.addLink(network.getNeuron(i), network.getNeuron(i + 1));
+ }
+ for (int i = size - 1; i > 0; i--) {
+ network.addLink(network.getNeuron(i), network.getNeuron(i - 1));
+ }
+ if (wrap) {
+ network.addLink(network.getNeuron(0), network.getNeuron(size - 1));
+ network.addLink(network.getNeuron(size - 1), network.getNeuron(0));
+ }
+ }
+
+ /**
+ * Prevents proxy bypass.
+ *
+ * @param in Input stream.
+ */
+ private void readObject(ObjectInputStream in) {
+ throw new IllegalStateException();
+ }
+
+ /**
+ * Custom serialization.
+ *
+ * @return the proxy instance that will be actually serialized.
+ */
+ private Object writeReplace() {
+ final double[][] featuresList = new double[size][];
+ for (int i = 0; i < size; i++) {
+ featuresList[i] = getFeatures(i);
+ }
+
+ return new SerializationProxy(wrap,
+ featuresList);
+ }
+
+ /**
+ * Serialization.
+ */
+ private static class SerializationProxy implements Serializable {
+ /** Serializable. */
+ private static final long serialVersionUID = 20130226L;
+ /** Wrap. */
+ private final boolean wrap;
+ /** Neurons' features. */
+ private final double[][] featuresList;
+
+ /**
+ * @param wrap Whether the dimension is wrapped.
+ * @param featuresList List of neurons features.
+ * {@code neuronList}.
+ */
+ SerializationProxy(boolean wrap,
+ double[][] featuresList) {
+ this.wrap = wrap;
+ this.featuresList = featuresList;
+ }
+
+ /**
+ * Custom serialization.
+ *
+ * @return the {@link Neuron} for which this instance is the proxy.
+ */
+ private Object readResolve() {
+ return new NeuronString(wrap,
+ featuresList);
+ }
+ }
+}
diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/oned/package-info.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/oned/package-info.java
new file mode 100644
index 0000000..0b47fae
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/oned/package-info.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * One-dimensional neural networks.
+ */
+
+package org.apache.commons.math3.ml.neuralnet.oned;
diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/package-info.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/package-info.java
new file mode 100644
index 0000000..d8e907e
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/package-info.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Neural networks.
+ */
+
+package org.apache.commons.math3.ml.neuralnet;
diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/KohonenTrainingTask.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/KohonenTrainingTask.java
new file mode 100644
index 0000000..9aa497d
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/KohonenTrainingTask.java
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.math3.ml.neuralnet.sofm;
+
+import java.util.Iterator;
+import org.apache.commons.math3.ml.neuralnet.Network;
+
+/**
+ * Trainer for Kohonen's Self-Organizing Map.
+ *
+ * @since 3.3
+ */
+public class KohonenTrainingTask implements Runnable {
+ /** SOFM to be trained. */
+ private final Network net;
+ /** Training data. */
+ private final Iterator<double[]> featuresIterator;
+ /** Update procedure. */
+ private final KohonenUpdateAction updateAction;
+
+ /**
+ * Creates a (sequential) trainer for the given network.
+ *
+ * @param net Network to be trained with the SOFM algorithm.
+ * @param featuresIterator Training data iterator.
+ * @param updateAction SOFM update procedure.
+ */
+ public KohonenTrainingTask(Network net,
+ Iterator<double[]> featuresIterator,
+ KohonenUpdateAction updateAction) {
+ this.net = net;
+ this.featuresIterator = featuresIterator;
+ this.updateAction = updateAction;
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ public void run() {
+ while (featuresIterator.hasNext()) {
+ updateAction.update(net, featuresIterator.next());
+ }
+ }
+}
diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/KohonenUpdateAction.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/KohonenUpdateAction.java
new file mode 100644
index 0000000..0618aeb
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/KohonenUpdateAction.java
@@ -0,0 +1,225 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.math3.ml.neuralnet.sofm;
+
+import java.util.Collection;
+import java.util.HashSet;
+import java.util.concurrent.atomic.AtomicLong;
+
+import org.apache.commons.math3.analysis.function.Gaussian;
+import org.apache.commons.math3.linear.ArrayRealVector;
+import org.apache.commons.math3.ml.distance.DistanceMeasure;
+import org.apache.commons.math3.ml.neuralnet.MapUtils;
+import org.apache.commons.math3.ml.neuralnet.Network;
+import org.apache.commons.math3.ml.neuralnet.Neuron;
+import org.apache.commons.math3.ml.neuralnet.UpdateAction;
+
+/**
+ * Update formula for <a href="http://en.wikipedia.org/wiki/Kohonen">
+ * Kohonen's Self-Organizing Map</a>.
+ * <br/>
+ * The {@link #update(Network,double[]) update} method modifies the
+ * features {@code w} of the "winning" neuron and its neighbours
+ * according to the following rule:
+ * <code>
+ * w<sub>new</sub> = w<sub>old</sub> + &alpha; e<sup>(-d / &sigma;)</sup> * (sample - w<sub>old</sub>)
+ * </code>
+ * where
+ * <ul>
+ * <li>&alpha; is the current <em>learning rate</em>, </li>
+ * <li>&sigma; is the current <em>neighbourhood size</em>, and</li>
+ * <li>{@code d} is the number of links to traverse in order to reach
+ * the neuron from the winning neuron.</li>
+ * </ul>
+ * <br/>
+ * This class is thread-safe as long as the arguments passed to the
+ * {@link #KohonenUpdateAction(DistanceMeasure,LearningFactorFunction,
+ * NeighbourhoodSizeFunction) constructor} are instances of thread-safe
+ * classes.
+ * <br/>
+ * Each call to the {@link #update(Network,double[]) update} method
+ * will increment the internal counter used to compute the current
+ * values for
+ * <ul>
+ * <li>the <em>learning rate</em>, and</li>
+ * <li>the <em>neighbourhood size</em>.</li>
+ * </ul>
+ * Consequently, the function instances that compute those values (passed
+ * to the constructor of this class) must take into account whether this
+ * class's instance will be shared by multiple threads, as this will impact
+ * the training process.
+ *
+ * @since 3.3
+ */
+public class KohonenUpdateAction implements UpdateAction {
+ /** Distance function. */
+ private final DistanceMeasure distance;
+ /** Learning factor update function. */
+ private final LearningFactorFunction learningFactor;
+ /** Neighbourhood size update function. */
+ private final NeighbourhoodSizeFunction neighbourhoodSize;
+ /** Number of calls to {@link #update(Network,double[])}. */
+ private final AtomicLong numberOfCalls = new AtomicLong(0);
+
+ /**
+ * @param distance Distance function.
+ * @param learningFactor Learning factor update function.
+ * @param neighbourhoodSize Neighbourhood size update function.
+ */
+ public KohonenUpdateAction(DistanceMeasure distance,
+ LearningFactorFunction learningFactor,
+ NeighbourhoodSizeFunction neighbourhoodSize) {
+ this.distance = distance;
+ this.learningFactor = learningFactor;
+ this.neighbourhoodSize = neighbourhoodSize;
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ public void update(Network net,
+ double[] features) {
+ final long numCalls = numberOfCalls.incrementAndGet() - 1;
+ final double currentLearning = learningFactor.value(numCalls);
+ final Neuron best = findAndUpdateBestNeuron(net,
+ features,
+ currentLearning);
+
+ final int currentNeighbourhood = neighbourhoodSize.value(numCalls);
+ // The farther away the neighbour is from the winning neuron, the
+ // smaller the learning rate will become.
+ final Gaussian neighbourhoodDecay
+ = new Gaussian(currentLearning,
+ 0,
+ currentNeighbourhood);
+
+ if (currentNeighbourhood > 0) {
+ // Initial set of neurons only contains the winning neuron.
+ Collection<Neuron> neighbours = new HashSet<Neuron>();
+ neighbours.add(best);
+ // Winning neuron must be excluded from the neighbours.
+ final HashSet<Neuron> exclude = new HashSet<Neuron>();
+ exclude.add(best);
+
+ int radius = 1;
+ do {
+ // Retrieve immediate neighbours of the current set of neurons.
+ neighbours = net.getNeighbours(neighbours, exclude);
+
+ // Update all the neighbours.
+ for (Neuron n : neighbours) {
+ updateNeighbouringNeuron(n, features, neighbourhoodDecay.value(radius));
+ }
+
+ // Add the neighbours to the exclude list so that they will
+ // not be update more than once per training step.
+ exclude.addAll(neighbours);
+ ++radius;
+ } while (radius <= currentNeighbourhood);
+ }
+ }
+
+ /**
+ * Retrieves the number of calls to the {@link #update(Network,double[]) update}
+ * method.
+ *
+ * @return the current number of calls.
+ */
+ public long getNumberOfCalls() {
+ return numberOfCalls.get();
+ }
+
+ /**
+ * Tries to update a neuron.
+ *
+ * @param n Neuron to be updated.
+ * @param features Training data.
+ * @param learningRate Learning factor.
+ * @return {@code true} if the update succeeded, {@code true} if a
+ * concurrent update has been detected.
+ */
+ private boolean attemptNeuronUpdate(Neuron n,
+ double[] features,
+ double learningRate) {
+ final double[] expect = n.getFeatures();
+ final double[] update = computeFeatures(expect,
+ features,
+ learningRate);
+
+ return n.compareAndSetFeatures(expect, update);
+ }
+
+ /**
+ * Atomically updates the given neuron.
+ *
+ * @param n Neuron to be updated.
+ * @param features Training data.
+ * @param learningRate Learning factor.
+ */
+ private void updateNeighbouringNeuron(Neuron n,
+ double[] features,
+ double learningRate) {
+ while (true) {
+ if (attemptNeuronUpdate(n, features, learningRate)) {
+ break;
+ }
+ }
+ }
+
+ /**
+ * Searches for the neuron whose features are closest to the given
+ * sample, and atomically updates its features.
+ *
+ * @param net Network.
+ * @param features Sample data.
+ * @param learningRate Current learning factor.
+ * @return the winning neuron.
+ */
+ private Neuron findAndUpdateBestNeuron(Network net,
+ double[] features,
+ double learningRate) {
+ while (true) {
+ final Neuron best = MapUtils.findBest(features, net, distance);
+
+ if (attemptNeuronUpdate(best, features, learningRate)) {
+ return best;
+ }
+
+ // If another thread modified the state of the winning neuron,
+ // it may not be the best match anymore for the given training
+ // sample: Hence, the winner search is performed again.
+ }
+ }
+
+ /**
+ * Computes the new value of the features set.
+ *
+ * @param current Current values of the features.
+ * @param sample Training data.
+ * @param learningRate Learning factor.
+ * @return the new values for the features.
+ */
+ private double[] computeFeatures(double[] current,
+ double[] sample,
+ double learningRate) {
+ final ArrayRealVector c = new ArrayRealVector(current, false);
+ final ArrayRealVector s = new ArrayRealVector(sample, false);
+ // c + learningRate * (s - c)
+ return s.subtract(c).mapMultiplyToSelf(learningRate).add(c).toArray();
+ }
+}
diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/LearningFactorFunction.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/LearningFactorFunction.java
new file mode 100644
index 0000000..ba9d152
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/LearningFactorFunction.java
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.math3.ml.neuralnet.sofm;
+
+/**
+ * Provides the learning rate as a function of the number of calls
+ * already performed during the learning task.
+ *
+ * @since 3.3
+ */
+public interface LearningFactorFunction {
+ /**
+ * Computes the learning rate at the current call.
+ *
+ * @param numCall Current step of the training task.
+ * @return the value of the function at {@code numCall}.
+ */
+ double value(long numCall);
+}
diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/LearningFactorFunctionFactory.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/LearningFactorFunctionFactory.java
new file mode 100644
index 0000000..9165e82
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/LearningFactorFunctionFactory.java
@@ -0,0 +1,117 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.math3.ml.neuralnet.sofm;
+
+import org.apache.commons.math3.ml.neuralnet.sofm.util.ExponentialDecayFunction;
+import org.apache.commons.math3.ml.neuralnet.sofm.util.QuasiSigmoidDecayFunction;
+import org.apache.commons.math3.exception.OutOfRangeException;
+
+/**
+ * Factory for creating instances of {@link LearningFactorFunction}.
+ *
+ * @since 3.3
+ */
+public class LearningFactorFunctionFactory {
+ /** Class contains only static methods. */
+ private LearningFactorFunctionFactory() {}
+
+ /**
+ * Creates an exponential decay {@link LearningFactorFunction function}.
+ * It will compute <code>a e<sup>-x / b</sup></code>,
+ * where {@code x} is the (integer) independent variable and
+ * <ul>
+ * <li><code>a = initValue</code>
+ * <li><code>b = -numCall / ln(valueAtNumCall / initValue)</code>
+ * </ul>
+ *
+ * @param initValue Initial value, i.e.
+ * {@link LearningFactorFunction#value(long) value(0)}.
+ * @param valueAtNumCall Value of the function at {@code numCall}.
+ * @param numCall Argument for which the function returns
+ * {@code valueAtNumCall}.
+ * @return the learning factor function.
+ * @throws org.apache.commons.math3.exception.OutOfRangeException
+ * if {@code initValue <= 0} or {@code initValue > 1}.
+ * @throws org.apache.commons.math3.exception.NotStrictlyPositiveException
+ * if {@code valueAtNumCall <= 0}.
+ * @throws org.apache.commons.math3.exception.NumberIsTooLargeException
+ * if {@code valueAtNumCall >= initValue}.
+ * @throws org.apache.commons.math3.exception.NotStrictlyPositiveException
+ * if {@code numCall <= 0}.
+ */
+ public static LearningFactorFunction exponentialDecay(final double initValue,
+ final double valueAtNumCall,
+ final long numCall) {
+ if (initValue <= 0 ||
+ initValue > 1) {
+ throw new OutOfRangeException(initValue, 0, 1);
+ }
+
+ return new LearningFactorFunction() {
+ /** DecayFunction. */
+ private final ExponentialDecayFunction decay
+ = new ExponentialDecayFunction(initValue, valueAtNumCall, numCall);
+
+ /** {@inheritDoc} */
+ public double value(long n) {
+ return decay.value(n);
+ }
+ };
+ }
+
+ /**
+ * Creates an sigmoid-like {@code LearningFactorFunction function}.
+ * The function {@code f} will have the following properties:
+ * <ul>
+ * <li>{@code f(0) = initValue}</li>
+ * <li>{@code numCall} is the inflexion point</li>
+ * <li>{@code slope = f'(numCall)}</li>
+ * </ul>
+ *
+ * @param initValue Initial value, i.e.
+ * {@link LearningFactorFunction#value(long) value(0)}.
+ * @param slope Value of the function derivative at {@code numCall}.
+ * @param numCall Inflexion point.
+ * @return the learning factor function.
+ * @throws org.apache.commons.math3.exception.OutOfRangeException
+ * if {@code initValue <= 0} or {@code initValue > 1}.
+ * @throws org.apache.commons.math3.exception.NumberIsTooLargeException
+ * if {@code slope >= 0}.
+ * @throws org.apache.commons.math3.exception.NotStrictlyPositiveException
+ * if {@code numCall <= 0}.
+ */
+ public static LearningFactorFunction quasiSigmoidDecay(final double initValue,
+ final double slope,
+ final long numCall) {
+ if (initValue <= 0 ||
+ initValue > 1) {
+ throw new OutOfRangeException(initValue, 0, 1);
+ }
+
+ return new LearningFactorFunction() {
+ /** DecayFunction. */
+ private final QuasiSigmoidDecayFunction decay
+ = new QuasiSigmoidDecayFunction(initValue, slope, numCall);
+
+ /** {@inheritDoc} */
+ public double value(long n) {
+ return decay.value(n);
+ }
+ };
+ }
+}
diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/NeighbourhoodSizeFunction.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/NeighbourhoodSizeFunction.java
new file mode 100644
index 0000000..68149f7
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/NeighbourhoodSizeFunction.java
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.math3.ml.neuralnet.sofm;
+
+/**
+ * Provides the network neighbourhood's size as a function of the
+ * number of calls already performed during the learning task.
+ * The "neighbourhood" is the set of neurons that can be reached
+ * by traversing at most the number of links returned by this
+ * function.
+ *
+ * @since 3.3
+ */
+public interface NeighbourhoodSizeFunction {
+ /**
+ * Computes the neighbourhood size at the current call.
+ *
+ * @param numCall Current step of the training task.
+ * @return the value of the function at {@code numCall}.
+ */
+ int value(long numCall);
+}
diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/NeighbourhoodSizeFunctionFactory.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/NeighbourhoodSizeFunctionFactory.java
new file mode 100644
index 0000000..bdbfa2f
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/NeighbourhoodSizeFunctionFactory.java
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.math3.ml.neuralnet.sofm;
+
+import org.apache.commons.math3.ml.neuralnet.sofm.util.ExponentialDecayFunction;
+import org.apache.commons.math3.ml.neuralnet.sofm.util.QuasiSigmoidDecayFunction;
+import org.apache.commons.math3.util.FastMath;
+
+/**
+ * Factory for creating instances of {@link NeighbourhoodSizeFunction}.
+ *
+ * @since 3.3
+ */
+public class NeighbourhoodSizeFunctionFactory {
+ /** Class contains only static methods. */
+ private NeighbourhoodSizeFunctionFactory() {}
+
+ /**
+ * Creates an exponential decay {@link NeighbourhoodSizeFunction function}.
+ * It will compute <code>a e<sup>-x / b</sup></code>,
+ * where {@code x} is the (integer) independent variable and
+ * <ul>
+ * <li><code>a = initValue</code>
+ * <li><code>b = -numCall / ln(valueAtNumCall / initValue)</code>
+ * </ul>
+ *
+ * @param initValue Initial value, i.e.
+ * {@link NeighbourhoodSizeFunction#value(long) value(0)}.
+ * @param valueAtNumCall Value of the function at {@code numCall}.
+ * @param numCall Argument for which the function returns
+ * {@code valueAtNumCall}.
+ * @return the neighbourhood size function.
+ * @throws org.apache.commons.math3.exception.NotStrictlyPositiveException
+ * if {@code initValue <= 0}.
+ * @throws org.apache.commons.math3.exception.NotStrictlyPositiveException
+ * if {@code valueAtNumCall <= 0}.
+ * @throws org.apache.commons.math3.exception.NumberIsTooLargeException
+ * if {@code valueAtNumCall >= initValue}.
+ * @throws org.apache.commons.math3.exception.NotStrictlyPositiveException
+ * if {@code numCall <= 0}.
+ */
+ public static NeighbourhoodSizeFunction exponentialDecay(final double initValue,
+ final double valueAtNumCall,
+ final long numCall) {
+ return new NeighbourhoodSizeFunction() {
+ /** DecayFunction. */
+ private final ExponentialDecayFunction decay
+ = new ExponentialDecayFunction(initValue, valueAtNumCall, numCall);
+
+ /** {@inheritDoc} */
+ public int value(long n) {
+ return (int) FastMath.rint(decay.value(n));
+ }
+ };
+ }
+
+ /**
+ * Creates an sigmoid-like {@code NeighbourhoodSizeFunction function}.
+ * The function {@code f} will have the following properties:
+ * <ul>
+ * <li>{@code f(0) = initValue}</li>
+ * <li>{@code numCall} is the inflexion point</li>
+ * <li>{@code slope = f'(numCall)}</li>
+ * </ul>
+ *
+ * @param initValue Initial value, i.e.
+ * {@link NeighbourhoodSizeFunction#value(long) value(0)}.
+ * @param slope Value of the function derivative at {@code numCall}.
+ * @param numCall Inflexion point.
+ * @return the neighbourhood size function.
+ * @throws org.apache.commons.math3.exception.NotStrictlyPositiveException
+ * if {@code initValue <= 0}.
+ * @throws org.apache.commons.math3.exception.NumberIsTooLargeException
+ * if {@code slope >= 0}.
+ * @throws org.apache.commons.math3.exception.NotStrictlyPositiveException
+ * if {@code numCall <= 0}.
+ */
+ public static NeighbourhoodSizeFunction quasiSigmoidDecay(final double initValue,
+ final double slope,
+ final long numCall) {
+ return new NeighbourhoodSizeFunction() {
+ /** DecayFunction. */
+ private final QuasiSigmoidDecayFunction decay
+ = new QuasiSigmoidDecayFunction(initValue, slope, numCall);
+
+ /** {@inheritDoc} */
+ public int value(long n) {
+ return (int) FastMath.rint(decay.value(n));
+ }
+ };
+ }
+}
diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/package-info.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/package-info.java
new file mode 100644
index 0000000..60c3c61
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/package-info.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Self Organizing Feature Map.
+ */
+
+package org.apache.commons.math3.ml.neuralnet.sofm;
diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/util/ExponentialDecayFunction.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/util/ExponentialDecayFunction.java
new file mode 100644
index 0000000..19e7380
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/util/ExponentialDecayFunction.java
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.math3.ml.neuralnet.sofm.util;
+
+import org.apache.commons.math3.exception.NotStrictlyPositiveException;
+import org.apache.commons.math3.exception.NumberIsTooLargeException;
+import org.apache.commons.math3.util.FastMath;
+
+/**
+ * Exponential decay function: <code>a e<sup>-x / b</sup></code>,
+ * where {@code x} is the (integer) independent variable.
+ * <br/>
+ * Class is immutable.
+ *
+ * @since 3.3
+ */
+public class ExponentialDecayFunction {
+ /** Factor {@code a}. */
+ private final double a;
+ /** Factor {@code 1 / b}. */
+ private final double oneOverB;
+
+ /**
+ * Creates an instance. It will be such that
+ * <ul>
+ * <li>{@code a = initValue}</li>
+ * <li>{@code b = -numCall / ln(valueAtNumCall / initValue)}</li>
+ * </ul>
+ *
+ * @param initValue Initial value, i.e. {@link #value(long) value(0)}.
+ * @param valueAtNumCall Value of the function at {@code numCall}.
+ * @param numCall Argument for which the function returns
+ * {@code valueAtNumCall}.
+ * @throws NotStrictlyPositiveException if {@code initValue <= 0}.
+ * @throws NotStrictlyPositiveException if {@code valueAtNumCall <= 0}.
+ * @throws NumberIsTooLargeException if {@code valueAtNumCall >= initValue}.
+ * @throws NotStrictlyPositiveException if {@code numCall <= 0}.
+ */
+ public ExponentialDecayFunction(double initValue,
+ double valueAtNumCall,
+ long numCall) {
+ if (initValue <= 0) {
+ throw new NotStrictlyPositiveException(initValue);
+ }
+ if (valueAtNumCall <= 0) {
+ throw new NotStrictlyPositiveException(valueAtNumCall);
+ }
+ if (valueAtNumCall >= initValue) {
+ throw new NumberIsTooLargeException(valueAtNumCall, initValue, false);
+ }
+ if (numCall <= 0) {
+ throw new NotStrictlyPositiveException(numCall);
+ }
+
+ a = initValue;
+ oneOverB = -FastMath.log(valueAtNumCall / initValue) / numCall;
+ }
+
+ /**
+ * Computes <code>a e<sup>-numCall / b</sup></code>.
+ *
+ * @param numCall Current step of the training task.
+ * @return the value of the function at {@code numCall}.
+ */
+ public double value(long numCall) {
+ return a * FastMath.exp(-numCall * oneOverB);
+ }
+}
diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/util/QuasiSigmoidDecayFunction.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/util/QuasiSigmoidDecayFunction.java
new file mode 100644
index 0000000..3d35c17
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/util/QuasiSigmoidDecayFunction.java
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.math3.ml.neuralnet.sofm.util;
+
+import org.apache.commons.math3.exception.NotStrictlyPositiveException;
+import org.apache.commons.math3.exception.NumberIsTooLargeException;
+import org.apache.commons.math3.analysis.function.Logistic;
+
+/**
+ * Decay function whose shape is similar to a sigmoid.
+ * <br/>
+ * Class is immutable.
+ *
+ * @since 3.3
+ */
+public class QuasiSigmoidDecayFunction {
+ /** Sigmoid. */
+ private final Logistic sigmoid;
+ /** See {@link #value(long)}. */
+ private final double scale;
+
+ /**
+ * Creates an instance.
+ * The function {@code f} will have the following properties:
+ * <ul>
+ * <li>{@code f(0) = initValue}</li>
+ * <li>{@code numCall} is the inflexion point</li>
+ * <li>{@code slope = f'(numCall)}</li>
+ * </ul>
+ *
+ * @param initValue Initial value, i.e. {@link #value(long) value(0)}.
+ * @param slope Value of the function derivative at {@code numCall}.
+ * @param numCall Inflexion point.
+ * @throws NotStrictlyPositiveException if {@code initValue <= 0}.
+ * @throws NumberIsTooLargeException if {@code slope >= 0}.
+ * @throws NotStrictlyPositiveException if {@code numCall <= 0}.
+ */
+ public QuasiSigmoidDecayFunction(double initValue,
+ double slope,
+ long numCall) {
+ if (initValue <= 0) {
+ throw new NotStrictlyPositiveException(initValue);
+ }
+ if (slope >= 0) {
+ throw new NumberIsTooLargeException(slope, 0, false);
+ }
+ if (numCall <= 1) {
+ throw new NotStrictlyPositiveException(numCall);
+ }
+
+ final double k = initValue;
+ final double m = numCall;
+ final double b = 4 * slope / initValue;
+ final double q = 1;
+ final double a = 0;
+ final double n = 1;
+ sigmoid = new Logistic(k, m, b, q, a, n);
+
+ final double y0 = sigmoid.value(0);
+ scale = k / y0;
+ }
+
+ /**
+ * Computes the value of the learning factor.
+ *
+ * @param numCall Current step of the training task.
+ * @return the value of the function at {@code numCall}.
+ */
+ public double value(long numCall) {
+ return scale * sigmoid.value(numCall);
+ }
+}
diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/util/package-info.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/util/package-info.java
new file mode 100644
index 0000000..5078ed2
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/util/package-info.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Miscellaneous utilities.
+ */
+
+package org.apache.commons.math3.ml.neuralnet.sofm.util;
diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/twod/NeuronSquareMesh2D.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/twod/NeuronSquareMesh2D.java
new file mode 100644
index 0000000..5277bc5
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/twod/NeuronSquareMesh2D.java
@@ -0,0 +1,628 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.math3.ml.neuralnet.twod;
+
+import java.util.List;
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.io.Serializable;
+import java.io.ObjectInputStream;
+import org.apache.commons.math3.ml.neuralnet.Neuron;
+import org.apache.commons.math3.ml.neuralnet.Network;
+import org.apache.commons.math3.ml.neuralnet.FeatureInitializer;
+import org.apache.commons.math3.ml.neuralnet.SquareNeighbourhood;
+import org.apache.commons.math3.exception.NumberIsTooSmallException;
+import org.apache.commons.math3.exception.OutOfRangeException;
+import org.apache.commons.math3.exception.MathInternalError;
+
+/**
+ * Neural network with the topology of a two-dimensional surface.
+ * Each neuron defines one surface element.
+ * <br/>
+ * This network is primarily intended to represent a
+ * <a href="http://en.wikipedia.org/wiki/Kohonen">
+ * Self Organizing Feature Map</a>.
+ *
+ * @see org.apache.commons.math3.ml.neuralnet.sofm
+ * @since 3.3
+ */
+public class NeuronSquareMesh2D
+ implements Iterable<Neuron>,
+ Serializable {
+ /** Serial version ID */
+ private static final long serialVersionUID = 1L;
+ /** Underlying network. */
+ private final Network network;
+ /** Number of rows. */
+ private final int numberOfRows;
+ /** Number of columns. */
+ private final int numberOfColumns;
+ /** Wrap. */
+ private final boolean wrapRows;
+ /** Wrap. */
+ private final boolean wrapColumns;
+ /** Neighbourhood type. */
+ private final SquareNeighbourhood neighbourhood;
+ /**
+ * Mapping of the 2D coordinates (in the rectangular mesh) to
+ * the neuron identifiers (attributed by the {@link #network}
+ * instance).
+ */
+ private final long[][] identifiers;
+
+ /**
+ * Horizontal (along row) direction.
+ * @since 3.6
+ */
+ public enum HorizontalDirection {
+ /** Column at the right of the current column. */
+ RIGHT,
+ /** Current column. */
+ CENTER,
+ /** Column at the left of the current column. */
+ LEFT,
+ }
+ /**
+ * Vertical (along column) direction.
+ * @since 3.6
+ */
+ public enum VerticalDirection {
+ /** Row above the current row. */
+ UP,
+ /** Current row. */
+ CENTER,
+ /** Row below the current row. */
+ DOWN,
+ }
+
+ /**
+ * Constructor with restricted access, solely used for deserialization.
+ *
+ * @param wrapRowDim Whether to wrap the first dimension (i.e the first
+ * and last neurons will be linked together).
+ * @param wrapColDim Whether to wrap the second dimension (i.e the first
+ * and last neurons will be linked together).
+ * @param neighbourhoodType Neighbourhood type.
+ * @param featuresList Arrays that will initialize the features sets of
+ * the network's neurons.
+ * @throws NumberIsTooSmallException if {@code numRows < 2} or
+ * {@code numCols < 2}.
+ */
+ NeuronSquareMesh2D(boolean wrapRowDim,
+ boolean wrapColDim,
+ SquareNeighbourhood neighbourhoodType,
+ double[][][] featuresList) {
+ numberOfRows = featuresList.length;
+ numberOfColumns = featuresList[0].length;
+
+ if (numberOfRows < 2) {
+ throw new NumberIsTooSmallException(numberOfRows, 2, true);
+ }
+ if (numberOfColumns < 2) {
+ throw new NumberIsTooSmallException(numberOfColumns, 2, true);
+ }
+
+ wrapRows = wrapRowDim;
+ wrapColumns = wrapColDim;
+ neighbourhood = neighbourhoodType;
+
+ final int fLen = featuresList[0][0].length;
+ network = new Network(0, fLen);
+ identifiers = new long[numberOfRows][numberOfColumns];
+
+ // Add neurons.
+ for (int i = 0; i < numberOfRows; i++) {
+ for (int j = 0; j < numberOfColumns; j++) {
+ identifiers[i][j] = network.createNeuron(featuresList[i][j]);
+ }
+ }
+
+ // Add links.
+ createLinks();
+ }
+
+ /**
+ * Creates a two-dimensional network composed of square cells:
+ * Each neuron not located on the border of the mesh has four
+ * neurons linked to it.
+ * <br/>
+ * The links are bi-directional.
+ * <br/>
+ * The topology of the network can also be a cylinder (if one
+ * of the dimensions is wrapped) or a torus (if both dimensions
+ * are wrapped).
+ *
+ * @param numRows Number of neurons in the first dimension.
+ * @param wrapRowDim Whether to wrap the first dimension (i.e the first
+ * and last neurons will be linked together).
+ * @param numCols Number of neurons in the second dimension.
+ * @param wrapColDim Whether to wrap the second dimension (i.e the first
+ * and last neurons will be linked together).
+ * @param neighbourhoodType Neighbourhood type.
+ * @param featureInit Array of functions that will initialize the
+ * corresponding element of the features set of each newly created
+ * neuron. In particular, the size of this array defines the size of
+ * feature set.
+ * @throws NumberIsTooSmallException if {@code numRows < 2} or
+ * {@code numCols < 2}.
+ */
+ public NeuronSquareMesh2D(int numRows,
+ boolean wrapRowDim,
+ int numCols,
+ boolean wrapColDim,
+ SquareNeighbourhood neighbourhoodType,
+ FeatureInitializer[] featureInit) {
+ if (numRows < 2) {
+ throw new NumberIsTooSmallException(numRows, 2, true);
+ }
+ if (numCols < 2) {
+ throw new NumberIsTooSmallException(numCols, 2, true);
+ }
+
+ numberOfRows = numRows;
+ wrapRows = wrapRowDim;
+ numberOfColumns = numCols;
+ wrapColumns = wrapColDim;
+ neighbourhood = neighbourhoodType;
+ identifiers = new long[numberOfRows][numberOfColumns];
+
+ final int fLen = featureInit.length;
+ network = new Network(0, fLen);
+
+ // Add neurons.
+ for (int i = 0; i < numRows; i++) {
+ for (int j = 0; j < numCols; j++) {
+ final double[] features = new double[fLen];
+ for (int fIndex = 0; fIndex < fLen; fIndex++) {
+ features[fIndex] = featureInit[fIndex].value();
+ }
+ identifiers[i][j] = network.createNeuron(features);
+ }
+ }
+
+ // Add links.
+ createLinks();
+ }
+
+ /**
+ * Constructor with restricted access, solely used for making a
+ * {@link #copy() deep copy}.
+ *
+ * @param wrapRowDim Whether to wrap the first dimension (i.e the first
+ * and last neurons will be linked together).
+ * @param wrapColDim Whether to wrap the second dimension (i.e the first
+ * and last neurons will be linked together).
+ * @param neighbourhoodType Neighbourhood type.
+ * @param net Underlying network.
+ * @param idGrid Neuron identifiers.
+ */
+ private NeuronSquareMesh2D(boolean wrapRowDim,
+ boolean wrapColDim,
+ SquareNeighbourhood neighbourhoodType,
+ Network net,
+ long[][] idGrid) {
+ numberOfRows = idGrid.length;
+ numberOfColumns = idGrid[0].length;
+ wrapRows = wrapRowDim;
+ wrapColumns = wrapColDim;
+ neighbourhood = neighbourhoodType;
+ network = net;
+ identifiers = idGrid;
+ }
+
+ /**
+ * Performs a deep copy of this instance.
+ * Upon return, the copied and original instances will be independent:
+ * Updating one will not affect the other.
+ *
+ * @return a new instance with the same state as this instance.
+ * @since 3.6
+ */
+ public synchronized NeuronSquareMesh2D copy() {
+ final long[][] idGrid = new long[numberOfRows][numberOfColumns];
+ for (int r = 0; r < numberOfRows; r++) {
+ for (int c = 0; c < numberOfColumns; c++) {
+ idGrid[r][c] = identifiers[r][c];
+ }
+ }
+
+ return new NeuronSquareMesh2D(wrapRows,
+ wrapColumns,
+ neighbourhood,
+ network.copy(),
+ idGrid);
+ }
+
+ /**
+ * {@inheritDoc}
+ * @since 3.6
+ */
+ public Iterator<Neuron> iterator() {
+ return network.iterator();
+ }
+
+ /**
+ * Retrieves the underlying network.
+ * A reference is returned (enabling, for example, the network to be
+ * trained).
+ * This also implies that calling methods that modify the {@link Network}
+ * topology may cause this class to become inconsistent.
+ *
+ * @return the network.
+ */
+ public Network getNetwork() {
+ return network;
+ }
+
+ /**
+ * Gets the number of neurons in each row of this map.
+ *
+ * @return the number of rows.
+ */
+ public int getNumberOfRows() {
+ return numberOfRows;
+ }
+
+ /**
+ * Gets the number of neurons in each column of this map.
+ *
+ * @return the number of column.
+ */
+ public int getNumberOfColumns() {
+ return numberOfColumns;
+ }
+
+ /**
+ * Retrieves the neuron at location {@code (i, j)} in the map.
+ * The neuron at position {@code (0, 0)} is located at the upper-left
+ * corner of the map.
+ *
+ * @param i Row index.
+ * @param j Column index.
+ * @return the neuron at {@code (i, j)}.
+ * @throws OutOfRangeException if {@code i} or {@code j} is
+ * out of range.
+ *
+ * @see #getNeuron(int,int,HorizontalDirection,VerticalDirection)
+ */
+ public Neuron getNeuron(int i,
+ int j) {
+ if (i < 0 ||
+ i >= numberOfRows) {
+ throw new OutOfRangeException(i, 0, numberOfRows - 1);
+ }
+ if (j < 0 ||
+ j >= numberOfColumns) {
+ throw new OutOfRangeException(j, 0, numberOfColumns - 1);
+ }
+
+ return network.getNeuron(identifiers[i][j]);
+ }
+
+ /**
+ * Retrieves the neuron at {@code (location[0], location[1])} in the map.
+ * The neuron at position {@code (0, 0)} is located at the upper-left
+ * corner of the map.
+ *
+ * @param row Row index.
+ * @param col Column index.
+ * @param alongRowDir Direction along the given {@code row} (i.e. an
+ * offset will be added to the given <em>column</em> index.
+ * @param alongColDir Direction along the given {@code col} (i.e. an
+ * offset will be added to the given <em>row</em> index.
+ * @return the neuron at the requested location, or {@code null} if
+ * the location is not on the map.
+ *
+ * @see #getNeuron(int,int)
+ */
+ public Neuron getNeuron(int row,
+ int col,
+ HorizontalDirection alongRowDir,
+ VerticalDirection alongColDir) {
+ final int[] location = getLocation(row, col, alongRowDir, alongColDir);
+
+ return location == null ? null : getNeuron(location[0], location[1]);
+ }
+
+ /**
+ * Computes the location of a neighbouring neuron.
+ * It will return {@code null} if the resulting location is not part
+ * of the map.
+ * Position {@code (0, 0)} is at the upper-left corner of the map.
+ *
+ * @param row Row index.
+ * @param col Column index.
+ * @param alongRowDir Direction along the given {@code row} (i.e. an
+ * offset will be added to the given <em>column</em> index.
+ * @param alongColDir Direction along the given {@code col} (i.e. an
+ * offset will be added to the given <em>row</em> index.
+ * @return an array of length 2 containing the indices of the requested
+ * location, or {@code null} if that location is not part of the map.
+ *
+ * @see #getNeuron(int,int)
+ */
+ private int[] getLocation(int row,
+ int col,
+ HorizontalDirection alongRowDir,
+ VerticalDirection alongColDir) {
+ final int colOffset;
+ switch (alongRowDir) {
+ case LEFT:
+ colOffset = -1;
+ break;
+ case RIGHT:
+ colOffset = 1;
+ break;
+ case CENTER:
+ colOffset = 0;
+ break;
+ default:
+ // Should never happen.
+ throw new MathInternalError();
+ }
+ int colIndex = col + colOffset;
+ if (wrapColumns) {
+ if (colIndex < 0) {
+ colIndex += numberOfColumns;
+ } else {
+ colIndex %= numberOfColumns;
+ }
+ }
+
+ final int rowOffset;
+ switch (alongColDir) {
+ case UP:
+ rowOffset = -1;
+ break;
+ case DOWN:
+ rowOffset = 1;
+ break;
+ case CENTER:
+ rowOffset = 0;
+ break;
+ default:
+ // Should never happen.
+ throw new MathInternalError();
+ }
+ int rowIndex = row + rowOffset;
+ if (wrapRows) {
+ if (rowIndex < 0) {
+ rowIndex += numberOfRows;
+ } else {
+ rowIndex %= numberOfRows;
+ }
+ }
+
+ if (rowIndex < 0 ||
+ rowIndex >= numberOfRows ||
+ colIndex < 0 ||
+ colIndex >= numberOfColumns) {
+ return null;
+ } else {
+ return new int[] { rowIndex, colIndex };
+ }
+ }
+
+ /**
+ * Creates the neighbour relationships between neurons.
+ */
+ private void createLinks() {
+ // "linkEnd" will store the identifiers of the "neighbours".
+ final List<Long> linkEnd = new ArrayList<Long>();
+ final int iLast = numberOfRows - 1;
+ final int jLast = numberOfColumns - 1;
+ for (int i = 0; i < numberOfRows; i++) {
+ for (int j = 0; j < numberOfColumns; j++) {
+ linkEnd.clear();
+
+ switch (neighbourhood) {
+
+ case MOORE:
+ // Add links to "diagonal" neighbours.
+ if (i > 0) {
+ if (j > 0) {
+ linkEnd.add(identifiers[i - 1][j - 1]);
+ }
+ if (j < jLast) {
+ linkEnd.add(identifiers[i - 1][j + 1]);
+ }
+ }
+ if (i < iLast) {
+ if (j > 0) {
+ linkEnd.add(identifiers[i + 1][j - 1]);
+ }
+ if (j < jLast) {
+ linkEnd.add(identifiers[i + 1][j + 1]);
+ }
+ }
+ if (wrapRows) {
+ if (i == 0) {
+ if (j > 0) {
+ linkEnd.add(identifiers[iLast][j - 1]);
+ }
+ if (j < jLast) {
+ linkEnd.add(identifiers[iLast][j + 1]);
+ }
+ } else if (i == iLast) {
+ if (j > 0) {
+ linkEnd.add(identifiers[0][j - 1]);
+ }
+ if (j < jLast) {
+ linkEnd.add(identifiers[0][j + 1]);
+ }
+ }
+ }
+ if (wrapColumns) {
+ if (j == 0) {
+ if (i > 0) {
+ linkEnd.add(identifiers[i - 1][jLast]);
+ }
+ if (i < iLast) {
+ linkEnd.add(identifiers[i + 1][jLast]);
+ }
+ } else if (j == jLast) {
+ if (i > 0) {
+ linkEnd.add(identifiers[i - 1][0]);
+ }
+ if (i < iLast) {
+ linkEnd.add(identifiers[i + 1][0]);
+ }
+ }
+ }
+ if (wrapRows &&
+ wrapColumns) {
+ if (i == 0 &&
+ j == 0) {
+ linkEnd.add(identifiers[iLast][jLast]);
+ } else if (i == 0 &&
+ j == jLast) {
+ linkEnd.add(identifiers[iLast][0]);
+ } else if (i == iLast &&
+ j == 0) {
+ linkEnd.add(identifiers[0][jLast]);
+ } else if (i == iLast &&
+ j == jLast) {
+ linkEnd.add(identifiers[0][0]);
+ }
+ }
+
+ // Case falls through since the "Moore" neighbourhood
+ // also contains the neurons that belong to the "Von
+ // Neumann" neighbourhood.
+
+ // fallthru (CheckStyle)
+ case VON_NEUMANN:
+ // Links to preceding and following "row".
+ if (i > 0) {
+ linkEnd.add(identifiers[i - 1][j]);
+ }
+ if (i < iLast) {
+ linkEnd.add(identifiers[i + 1][j]);
+ }
+ if (wrapRows) {
+ if (i == 0) {
+ linkEnd.add(identifiers[iLast][j]);
+ } else if (i == iLast) {
+ linkEnd.add(identifiers[0][j]);
+ }
+ }
+
+ // Links to preceding and following "column".
+ if (j > 0) {
+ linkEnd.add(identifiers[i][j - 1]);
+ }
+ if (j < jLast) {
+ linkEnd.add(identifiers[i][j + 1]);
+ }
+ if (wrapColumns) {
+ if (j == 0) {
+ linkEnd.add(identifiers[i][jLast]);
+ } else if (j == jLast) {
+ linkEnd.add(identifiers[i][0]);
+ }
+ }
+ break;
+
+ default:
+ throw new MathInternalError(); // Cannot happen.
+ }
+
+ final Neuron aNeuron = network.getNeuron(identifiers[i][j]);
+ for (long b : linkEnd) {
+ final Neuron bNeuron = network.getNeuron(b);
+ // Link to all neighbours.
+ // The reverse links will be added as the loop proceeds.
+ network.addLink(aNeuron, bNeuron);
+ }
+ }
+ }
+ }
+
+ /**
+ * Prevents proxy bypass.
+ *
+ * @param in Input stream.
+ */
+ private void readObject(ObjectInputStream in) {
+ throw new IllegalStateException();
+ }
+
+ /**
+ * Custom serialization.
+ *
+ * @return the proxy instance that will be actually serialized.
+ */
+ private Object writeReplace() {
+ final double[][][] featuresList = new double[numberOfRows][numberOfColumns][];
+ for (int i = 0; i < numberOfRows; i++) {
+ for (int j = 0; j < numberOfColumns; j++) {
+ featuresList[i][j] = getNeuron(i, j).getFeatures();
+ }
+ }
+
+ return new SerializationProxy(wrapRows,
+ wrapColumns,
+ neighbourhood,
+ featuresList);
+ }
+
+ /**
+ * Serialization.
+ */
+ private static class SerializationProxy implements Serializable {
+ /** Serializable. */
+ private static final long serialVersionUID = 20130226L;
+ /** Wrap. */
+ private final boolean wrapRows;
+ /** Wrap. */
+ private final boolean wrapColumns;
+ /** Neighbourhood type. */
+ private final SquareNeighbourhood neighbourhood;
+ /** Neurons' features. */
+ private final double[][][] featuresList;
+
+ /**
+ * @param wrapRows Whether the row dimension is wrapped.
+ * @param wrapColumns Whether the column dimension is wrapped.
+ * @param neighbourhood Neighbourhood type.
+ * @param featuresList List of neurons features.
+ * {@code neuronList}.
+ */
+ SerializationProxy(boolean wrapRows,
+ boolean wrapColumns,
+ SquareNeighbourhood neighbourhood,
+ double[][][] featuresList) {
+ this.wrapRows = wrapRows;
+ this.wrapColumns = wrapColumns;
+ this.neighbourhood = neighbourhood;
+ this.featuresList = featuresList;
+ }
+
+ /**
+ * Custom serialization.
+ *
+ * @return the {@link Neuron} for which this instance is the proxy.
+ */
+ private Object readResolve() {
+ return new NeuronSquareMesh2D(wrapRows,
+ wrapColumns,
+ neighbourhood,
+ featuresList);
+ }
+ }
+}
diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/twod/package-info.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/twod/package-info.java
new file mode 100644
index 0000000..41535e8
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/twod/package-info.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Two-dimensional neural networks.
+ */
+
+package org.apache.commons.math3.ml.neuralnet.twod;
diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/twod/util/HitHistogram.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/twod/util/HitHistogram.java
new file mode 100644
index 0000000..06cee98
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/twod/util/HitHistogram.java
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.math3.ml.neuralnet.twod.util;
+
+import org.apache.commons.math3.ml.neuralnet.MapUtils;
+import org.apache.commons.math3.ml.neuralnet.Neuron;
+import org.apache.commons.math3.ml.neuralnet.twod.NeuronSquareMesh2D;
+import org.apache.commons.math3.ml.distance.DistanceMeasure;
+
+/**
+ * Computes the hit histogram.
+ * Each bin will contain the number of data for which the corresponding
+ * neuron is the best matching unit.
+ * @since 3.6
+ */
+public class HitHistogram implements MapDataVisualization {
+ /** Distance. */
+ private final DistanceMeasure distance;
+ /** Whether to compute relative bin counts. */
+ private final boolean normalizeCount;
+
+ /**
+ * @param normalizeCount Whether to compute relative bin counts.
+ * If {@code true}, the data count in each bin will be divided by the total
+ * number of samples.
+ * @param distance Distance.
+ */
+ public HitHistogram(boolean normalizeCount,
+ DistanceMeasure distance) {
+ this.normalizeCount = normalizeCount;
+ this.distance = distance;
+ }
+
+ /** {@inheritDoc} */
+ public double[][] computeImage(NeuronSquareMesh2D map,
+ Iterable<double[]> data) {
+ final int nR = map.getNumberOfRows();
+ final int nC = map.getNumberOfColumns();
+
+ final LocationFinder finder = new LocationFinder(map);
+
+ // Total number of samples.
+ int numSamples = 0;
+ // Hit bins.
+ final double[][] hit = new double[nR][nC];
+
+ for (double[] sample : data) {
+ final Neuron best = MapUtils.findBest(sample, map, distance);
+
+ final LocationFinder.Location loc = finder.getLocation(best);
+ final int row = loc.getRow();
+ final int col = loc.getColumn();
+ hit[row][col] += 1;
+
+ ++numSamples;
+ }
+
+ if (normalizeCount) {
+ for (int r = 0; r < nR; r++) {
+ for (int c = 0; c < nC; c++) {
+ hit[r][c] /= numSamples;
+ }
+ }
+ }
+
+ return hit;
+ }
+}
diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/twod/util/LocationFinder.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/twod/util/LocationFinder.java
new file mode 100644
index 0000000..e4ece61
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/twod/util/LocationFinder.java
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.math3.ml.neuralnet.twod.util;
+
+import java.util.Map;
+import java.util.HashMap;
+import org.apache.commons.math3.ml.neuralnet.Neuron;
+import org.apache.commons.math3.ml.neuralnet.twod.NeuronSquareMesh2D;
+import org.apache.commons.math3.exception.MathIllegalStateException;
+
+/**
+ * Helper class to find the grid coordinates of a neuron.
+ * @since 3.6
+ */
+public class LocationFinder {
+ /** Identifier to location mapping. */
+ private final Map<Long, Location> locations = new HashMap<Long, Location>();
+
+ /**
+ * Container holding a (row, column) pair.
+ */
+ public static class Location {
+ /** Row index. */
+ private final int row;
+ /** Column index. */
+ private final int column;
+
+ /**
+ * @param row Row index.
+ * @param column Column index.
+ */
+ public Location(int row,
+ int column) {
+ this.row = row;
+ this.column = column;
+ }
+
+ /**
+ * @return the row index.
+ */
+ public int getRow() {
+ return row;
+ }
+
+ /**
+ * @return the column index.
+ */
+ public int getColumn() {
+ return column;
+ }
+ }
+
+ /**
+ * Builds a finder to retrieve the locations of neurons that
+ * belong to the given {@code map}.
+ *
+ * @param map Map.
+ *
+ * @throws MathIllegalStateException if the network contains non-unique
+ * identifiers. This indicates an inconsistent state due to a bug in
+ * the construction code of the underlying
+ * {@link org.apache.commons.math3.ml.neuralnet.Network network}.
+ */
+ public LocationFinder(NeuronSquareMesh2D map) {
+ final int nR = map.getNumberOfRows();
+ final int nC = map.getNumberOfColumns();
+
+ for (int r = 0; r < nR; r++) {
+ for (int c = 0; c < nC; c++) {
+ final Long id = map.getNeuron(r, c).getIdentifier();
+ if (locations.get(id) != null) {
+ throw new MathIllegalStateException();
+ }
+ locations.put(id, new Location(r, c));
+ }
+ }
+ }
+
+ /**
+ * Retrieves a neuron's grid coordinates.
+ *
+ * @param n Neuron.
+ * @return the (row, column) coordinates of {@code n}, or {@code null}
+ * if no such neuron belongs to the {@link #LocationFinder(NeuronSquareMesh2D)
+ * map used to build this instance}.
+ */
+ public Location getLocation(Neuron n) {
+ return locations.get(n.getIdentifier());
+ }
+}
diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/twod/util/MapDataVisualization.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/twod/util/MapDataVisualization.java
new file mode 100644
index 0000000..71fab43
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/twod/util/MapDataVisualization.java
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.math3.ml.neuralnet.twod.util;
+
+import org.apache.commons.math3.ml.neuralnet.twod.NeuronSquareMesh2D;
+
+/**
+ * Interface for algorithms that compute some metrics of the projection of
+ * data on a 2D-map.
+ * @since 3.6
+ */
+public interface MapDataVisualization {
+ /**
+ * Creates an image of the {@code data} metrics when represented by the
+ * {@code map}.
+ *
+ * @param map Map.
+ * @param data Data.
+ * @return a 2D-array (in row major order) representing the metrics.
+ */
+ double[][] computeImage(NeuronSquareMesh2D map,
+ Iterable<double[]> data);
+}
diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/twod/util/MapVisualization.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/twod/util/MapVisualization.java
new file mode 100644
index 0000000..9304d76
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/twod/util/MapVisualization.java
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.math3.ml.neuralnet.twod.util;
+
+import org.apache.commons.math3.ml.neuralnet.twod.NeuronSquareMesh2D;
+
+/**
+ * Interface for algorithms that compute some property of a 2D-map.
+ * @since 3.6
+ */
+public interface MapVisualization {
+ /**
+ * Creates an image of the {@code map}.
+ *
+ * @param map Map.
+ * @return a 2D-array (in row major order) representing the property.
+ */
+ double[][] computeImage(NeuronSquareMesh2D map);
+}
diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/twod/util/QuantizationError.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/twod/util/QuantizationError.java
new file mode 100644
index 0000000..8ec1da3
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/twod/util/QuantizationError.java
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.math3.ml.neuralnet.twod.util;
+
+import org.apache.commons.math3.ml.neuralnet.MapUtils;
+import org.apache.commons.math3.ml.neuralnet.Neuron;
+import org.apache.commons.math3.ml.neuralnet.twod.NeuronSquareMesh2D;
+import org.apache.commons.math3.ml.distance.DistanceMeasure;
+
+/**
+ * Computes the quantization error histogram.
+ * Each bin will contain the average of the distances between samples
+ * mapped to the corresponding unit and the weight vector of that unit.
+ * @since 3.6
+ */
+public class QuantizationError implements MapDataVisualization {
+ /** Distance. */
+ private final DistanceMeasure distance;
+
+ /**
+ * @param distance Distance.
+ */
+ public QuantizationError(DistanceMeasure distance) {
+ this.distance = distance;
+ }
+
+ /** {@inheritDoc} */
+ public double[][] computeImage(NeuronSquareMesh2D map,
+ Iterable<double[]> data) {
+ final int nR = map.getNumberOfRows();
+ final int nC = map.getNumberOfColumns();
+
+ final LocationFinder finder = new LocationFinder(map);
+
+ // Hit bins.
+ final int[][] hit = new int[nR][nC];
+ // Error bins.
+ final double[][] error = new double[nR][nC];
+
+ for (double[] sample : data) {
+ final Neuron best = MapUtils.findBest(sample, map, distance);
+
+ final LocationFinder.Location loc = finder.getLocation(best);
+ final int row = loc.getRow();
+ final int col = loc.getColumn();
+ hit[row][col] += 1;
+ error[row][col] += distance.compute(sample, best.getFeatures());
+ }
+
+ for (int r = 0; r < nR; r++) {
+ for (int c = 0; c < nC; c++) {
+ final int count = hit[r][c];
+ if (count != 0) {
+ error[r][c] /= count;
+ }
+ }
+ }
+
+ return error;
+ }
+}
diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/twod/util/SmoothedDataHistogram.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/twod/util/SmoothedDataHistogram.java
new file mode 100644
index 0000000..b8e552c
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/twod/util/SmoothedDataHistogram.java
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.math3.ml.neuralnet.twod.util;
+
+import org.apache.commons.math3.ml.neuralnet.MapUtils;
+import org.apache.commons.math3.ml.neuralnet.Neuron;
+import org.apache.commons.math3.ml.neuralnet.twod.NeuronSquareMesh2D;
+import org.apache.commons.math3.ml.distance.DistanceMeasure;
+import org.apache.commons.math3.exception.NumberIsTooSmallException;
+
+/**
+ * Visualization of high-dimensional data projection on a 2D-map.
+ * The method is described in
+ * <quote>
+ * <em>Using Smoothed Data Histograms for Cluster Visualization in Self-Organizing Maps</em>
+ * <br>
+ * by Elias Pampalk, Andreas Rauber and Dieter Merkl.
+ * </quote>
+ * @since 3.6
+ */
+public class SmoothedDataHistogram implements MapDataVisualization {
+ /** Smoothing parameter. */
+ private final int smoothingBins;
+ /** Distance. */
+ private final DistanceMeasure distance;
+ /** Normalization factor. */
+ private final double membershipNormalization;
+
+ /**
+ * @param smoothingBins Number of bins.
+ * @param distance Distance.
+ */
+ public SmoothedDataHistogram(int smoothingBins,
+ DistanceMeasure distance) {
+ this.smoothingBins = smoothingBins;
+ this.distance = distance;
+
+ double sum = 0;
+ for (int i = 0; i < smoothingBins; i++) {
+ sum += smoothingBins - i;
+ }
+
+ this.membershipNormalization = 1d / sum;
+ }
+
+ /**
+ * {@inheritDoc}
+ *
+ * @throws NumberIsTooSmallException if the size of the {@code map}
+ * is smaller than the number of {@link #SmoothedDataHistogram(int,DistanceMeasure)
+ * smoothing bins}.
+ */
+ public double[][] computeImage(NeuronSquareMesh2D map,
+ Iterable<double[]> data) {
+ final int nR = map.getNumberOfRows();
+ final int nC = map.getNumberOfColumns();
+
+ final int mapSize = nR * nC;
+ if (mapSize < smoothingBins) {
+ throw new NumberIsTooSmallException(mapSize, smoothingBins, true);
+ }
+
+ final LocationFinder finder = new LocationFinder(map);
+
+ // Histogram bins.
+ final double[][] histo = new double[nR][nC];
+
+ for (double[] sample : data) {
+ final Neuron[] sorted = MapUtils.sort(sample,
+ map.getNetwork(),
+ distance);
+ for (int i = 0; i < smoothingBins; i++) {
+ final LocationFinder.Location loc = finder.getLocation(sorted[i]);
+ final int row = loc.getRow();
+ final int col = loc.getColumn();
+ histo[row][col] += (smoothingBins - i) * membershipNormalization;
+ }
+ }
+
+ return histo;
+ }
+}
diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/twod/util/TopographicErrorHistogram.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/twod/util/TopographicErrorHistogram.java
new file mode 100644
index 0000000..b831de8
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/twod/util/TopographicErrorHistogram.java
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.math3.ml.neuralnet.twod.util;
+
+import org.apache.commons.math3.ml.neuralnet.MapUtils;
+import org.apache.commons.math3.ml.neuralnet.Neuron;
+import org.apache.commons.math3.ml.neuralnet.Network;
+import org.apache.commons.math3.ml.neuralnet.twod.NeuronSquareMesh2D;
+import org.apache.commons.math3.ml.distance.DistanceMeasure;
+import org.apache.commons.math3.util.Pair;
+
+/**
+ * Computes the topographic error histogram.
+ * Each bin will contain the number of data for which the first and
+ * second best matching units are not adjacent in the map.
+ * @since 3.6
+ */
+public class TopographicErrorHistogram implements MapDataVisualization {
+ /** Distance. */
+ private final DistanceMeasure distance;
+ /** Whether to compute relative bin counts. */
+ private final boolean relativeCount;
+
+ /**
+ * @param relativeCount Whether to compute relative bin counts.
+ * If {@code true}, the data count in each bin will be divided by the total
+ * number of samples mapped to the neuron represented by that bin.
+ * @param distance Distance.
+ */
+ public TopographicErrorHistogram(boolean relativeCount,
+ DistanceMeasure distance) {
+ this.relativeCount = relativeCount;
+ this.distance = distance;
+ }
+
+ /** {@inheritDoc} */
+ public double[][] computeImage(NeuronSquareMesh2D map,
+ Iterable<double[]> data) {
+ final int nR = map.getNumberOfRows();
+ final int nC = map.getNumberOfColumns();
+
+ final Network net = map.getNetwork();
+ final LocationFinder finder = new LocationFinder(map);
+
+ // Hit bins.
+ final int[][] hit = new int[nR][nC];
+ // Error bins.
+ final double[][] error = new double[nR][nC];
+
+ for (double[] sample : data) {
+ final Pair<Neuron, Neuron> p = MapUtils.findBestAndSecondBest(sample, map, distance);
+ final Neuron best = p.getFirst();
+
+ final LocationFinder.Location loc = finder.getLocation(best);
+ final int row = loc.getRow();
+ final int col = loc.getColumn();
+ hit[row][col] += 1;
+
+ if (!net.getNeighbours(best).contains(p.getSecond())) {
+ // Increment count if first and second best matching units
+ // are not neighbours.
+ error[row][col] += 1;
+ }
+ }
+
+ if (relativeCount) {
+ for (int r = 0; r < nR; r++) {
+ for (int c = 0; c < nC; c++) {
+ error[r][c] /= hit[r][c];
+ }
+ }
+ }
+
+ return error;
+ }
+}
diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/twod/util/UnifiedDistanceMatrix.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/twod/util/UnifiedDistanceMatrix.java
new file mode 100644
index 0000000..aee982a
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/twod/util/UnifiedDistanceMatrix.java
@@ -0,0 +1,209 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.math3.ml.neuralnet.twod.util;
+
+import java.util.Collection;
+import org.apache.commons.math3.ml.neuralnet.Neuron;
+import org.apache.commons.math3.ml.neuralnet.Network;
+import org.apache.commons.math3.ml.neuralnet.twod.NeuronSquareMesh2D;
+import org.apache.commons.math3.ml.distance.DistanceMeasure;
+
+/**
+ * <a href="http://en.wikipedia.org/wiki/U-Matrix">U-Matrix</a>
+ * visualization of high-dimensional data projection.
+ * @since 3.6
+ */
+public class UnifiedDistanceMatrix implements MapVisualization {
+ /** Whether to show distance between each pair of neighbouring units. */
+ private final boolean individualDistances;
+ /** Distance. */
+ private final DistanceMeasure distance;
+
+ /**
+ * Simple constructor.
+ *
+ * @param individualDistances If {@code true}, the 8 individual
+ * inter-units distances will be {@link #computeImage(NeuronSquareMesh2D)
+ * computed}. They will be stored in additional pixels around each of
+ * the original units of the 2D-map. The additional pixels that lie
+ * along a "diagonal" are shared by <em>two</em> pairs of units: their
+ * value will be set to the average distance between the units belonging
+ * to each of the pairs. The value zero will be stored in the pixel
+ * corresponding to the location of a unit of the 2D-map.
+ * <br>
+ * If {@code false}, only the average distance between a unit and all its
+ * neighbours will be computed (and stored in the pixel corresponding to
+ * that unit of the 2D-map). In that case, the number of neighbours taken
+ * into account depends on the network's
+ * {@link org.apache.commons.math3.ml.neuralnet.SquareNeighbourhood
+ * neighbourhood type}.
+ * @param distance Distance.
+ */
+ public UnifiedDistanceMatrix(boolean individualDistances,
+ DistanceMeasure distance) {
+ this.individualDistances = individualDistances;
+ this.distance = distance;
+ }
+
+ /** {@inheritDoc} */
+ public double[][] computeImage(NeuronSquareMesh2D map) {
+ if (individualDistances) {
+ return individualDistances(map);
+ } else {
+ return averageDistances(map);
+ }
+ }
+
+ /**
+ * Computes the distances between a unit of the map and its
+ * neighbours.
+ * The image will contain more pixels than the number of neurons
+ * in the given {@code map} because each neuron has 8 neighbours.
+ * The value zero will be stored in the pixels corresponding to
+ * the location of a map unit.
+ *
+ * @param map Map.
+ * @return an image representing the individual distances.
+ */
+ private double[][] individualDistances(NeuronSquareMesh2D map) {
+ final int numRows = map.getNumberOfRows();
+ final int numCols = map.getNumberOfColumns();
+
+ final double[][] uMatrix = new double[numRows * 2 + 1][numCols * 2 + 1];
+
+ // 1.
+ // Fill right and bottom slots of each unit's location with the
+ // distance between the current unit and each of the two neighbours,
+ // respectively.
+ for (int i = 0; i < numRows; i++) {
+ // Current unit's row index in result image.
+ final int iR = 2 * i + 1;
+
+ for (int j = 0; j < numCols; j++) {
+ // Current unit's column index in result image.
+ final int jR = 2 * j + 1;
+
+ final double[] current = map.getNeuron(i, j).getFeatures();
+ Neuron neighbour;
+
+ // Right neighbour.
+ neighbour = map.getNeuron(i, j,
+ NeuronSquareMesh2D.HorizontalDirection.RIGHT,
+ NeuronSquareMesh2D.VerticalDirection.CENTER);
+ if (neighbour != null) {
+ uMatrix[iR][jR + 1] = distance.compute(current,
+ neighbour.getFeatures());
+ }
+
+ // Bottom-center neighbour.
+ neighbour = map.getNeuron(i, j,
+ NeuronSquareMesh2D.HorizontalDirection.CENTER,
+ NeuronSquareMesh2D.VerticalDirection.DOWN);
+ if (neighbour != null) {
+ uMatrix[iR + 1][jR] = distance.compute(current,
+ neighbour.getFeatures());
+ }
+ }
+ }
+
+ // 2.
+ // Fill the bottom-rigth slot of each unit's location with the average
+ // of the distances between
+ // * the current unit and its bottom-right neighbour, and
+ // * the bottom-center neighbour and the right neighbour.
+ for (int i = 0; i < numRows; i++) {
+ // Current unit's row index in result image.
+ final int iR = 2 * i + 1;
+
+ for (int j = 0; j < numCols; j++) {
+ // Current unit's column index in result image.
+ final int jR = 2 * j + 1;
+
+ final Neuron current = map.getNeuron(i, j);
+ final Neuron right = map.getNeuron(i, j,
+ NeuronSquareMesh2D.HorizontalDirection.RIGHT,
+ NeuronSquareMesh2D.VerticalDirection.CENTER);
+ final Neuron bottom = map.getNeuron(i, j,
+ NeuronSquareMesh2D.HorizontalDirection.CENTER,
+ NeuronSquareMesh2D.VerticalDirection.DOWN);
+ final Neuron bottomRight = map.getNeuron(i, j,
+ NeuronSquareMesh2D.HorizontalDirection.RIGHT,
+ NeuronSquareMesh2D.VerticalDirection.DOWN);
+
+ final double current2BottomRight = bottomRight == null ?
+ 0 :
+ distance.compute(current.getFeatures(),
+ bottomRight.getFeatures());
+ final double right2Bottom = (right == null ||
+ bottom == null) ?
+ 0 :
+ distance.compute(right.getFeatures(),
+ bottom.getFeatures());
+
+ // Bottom-right slot.
+ uMatrix[iR + 1][jR + 1] = 0.5 * (current2BottomRight + right2Bottom);
+ }
+ }
+
+ // 3. Copy last row into first row.
+ final int lastRow = uMatrix.length - 1;
+ uMatrix[0] = uMatrix[lastRow];
+
+ // 4.
+ // Copy last column into first column.
+ final int lastCol = uMatrix[0].length - 1;
+ for (int r = 0; r < lastRow; r++) {
+ uMatrix[r][0] = uMatrix[r][lastCol];
+ }
+
+ return uMatrix;
+ }
+
+ /**
+ * Computes the distances between a unit of the map and its neighbours.
+ *
+ * @param map Map.
+ * @return an image representing the average distances.
+ */
+ private double[][] averageDistances(NeuronSquareMesh2D map) {
+ final int numRows = map.getNumberOfRows();
+ final int numCols = map.getNumberOfColumns();
+ final double[][] uMatrix = new double[numRows][numCols];
+
+ final Network net = map.getNetwork();
+
+ for (int i = 0; i < numRows; i++) {
+ for (int j = 0; j < numCols; j++) {
+ final Neuron neuron = map.getNeuron(i, j);
+ final Collection<Neuron> neighbours = net.getNeighbours(neuron);
+ final double[] features = neuron.getFeatures();
+
+ double d = 0;
+ int count = 0;
+ for (Neuron n : neighbours) {
+ ++count;
+ d += distance.compute(features, n.getFeatures());
+ }
+
+ uMatrix[i][j] = d / count;
+ }
+ }
+
+ return uMatrix;
+ }
+}
diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/twod/util/package-info.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/twod/util/package-info.java
new file mode 100644
index 0000000..cd4aab0
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/twod/util/package-info.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Utilities to visualize two-dimensional neural networks.
+ */
+
+package org.apache.commons.math3.ml.neuralnet.twod.util;
diff --git a/src/main/java/org/apache/commons/math3/ml/package-info.java b/src/main/java/org/apache/commons/math3/ml/package-info.java
new file mode 100644
index 0000000..394aad2
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/ml/package-info.java
@@ -0,0 +1,18 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+/** Base package for machine learning algorithms. */
+package org.apache.commons.math3.ml;