diff options
Diffstat (limited to 'src/main/java/org/apache/commons/math3/ml')
52 files changed, 5983 insertions, 0 deletions
diff --git a/src/main/java/org/apache/commons/math3/ml/clustering/CentroidCluster.java b/src/main/java/org/apache/commons/math3/ml/clustering/CentroidCluster.java new file mode 100644 index 0000000..5cfc7bc --- /dev/null +++ b/src/main/java/org/apache/commons/math3/ml/clustering/CentroidCluster.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.math3.ml.clustering; + +/** + * A Cluster used by centroid-based clustering algorithms. + * <p> + * Defines additionally a cluster center which may not necessarily be a member + * of the original data set. + * + * @param <T> the type of points that can be clustered + * @since 3.2 + */ +public class CentroidCluster<T extends Clusterable> extends Cluster<T> { + + /** Serializable version identifier. */ + private static final long serialVersionUID = -3075288519071812288L; + + /** Center of the cluster. */ + private final Clusterable center; + + /** + * Build a cluster centered at a specified point. + * @param center the point which is to be the center of this cluster + */ + public CentroidCluster(final Clusterable center) { + super(); + this.center = center; + } + + /** + * Get the point chosen to be the center of this cluster. + * @return chosen cluster center + */ + public Clusterable getCenter() { + return center; + } + +} diff --git a/src/main/java/org/apache/commons/math3/ml/clustering/Cluster.java b/src/main/java/org/apache/commons/math3/ml/clustering/Cluster.java new file mode 100644 index 0000000..fa6df94 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/ml/clustering/Cluster.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.commons.math3.ml.clustering; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.List; + +/** + * Cluster holding a set of {@link Clusterable} points. + * @param <T> the type of points that can be clustered + * @since 3.2 + */ +public class Cluster<T extends Clusterable> implements Serializable { + + /** Serializable version identifier. */ + private static final long serialVersionUID = -3442297081515880464L; + + /** The points contained in this cluster. */ + private final List<T> points; + + /** + * Build a cluster centered at a specified point. + */ + public Cluster() { + points = new ArrayList<T>(); + } + + /** + * Add a point to this cluster. + * @param point point to add + */ + public void addPoint(final T point) { + points.add(point); + } + + /** + * Get the points contained in the cluster. + * @return points contained in the cluster + */ + public List<T> getPoints() { + return points; + } + +} diff --git a/src/main/java/org/apache/commons/math3/ml/clustering/Clusterable.java b/src/main/java/org/apache/commons/math3/ml/clustering/Clusterable.java new file mode 100644 index 0000000..e712eb7 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/ml/clustering/Clusterable.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.commons.math3.ml.clustering; + +/** + * Interface for n-dimensional points that can be clustered together. + * @since 3.2 + */ +public interface Clusterable { + + /** + * Gets the n-dimensional point. + * + * @return the point array + */ + double[] getPoint(); +} diff --git a/src/main/java/org/apache/commons/math3/ml/clustering/Clusterer.java b/src/main/java/org/apache/commons/math3/ml/clustering/Clusterer.java new file mode 100644 index 0000000..30e38c6 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/ml/clustering/Clusterer.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.math3.ml.clustering; + +import java.util.Collection; +import java.util.List; + +import org.apache.commons.math3.exception.ConvergenceException; +import org.apache.commons.math3.exception.MathIllegalArgumentException; +import org.apache.commons.math3.ml.distance.DistanceMeasure; + +/** + * Base class for clustering algorithms. + * + * @param <T> the type of points that can be clustered + * @since 3.2 + */ +public abstract class Clusterer<T extends Clusterable> { + + /** The distance measure to use. */ + private DistanceMeasure measure; + + /** + * Build a new clusterer with the given {@link DistanceMeasure}. + * + * @param measure the distance measure to use + */ + protected Clusterer(final DistanceMeasure measure) { + this.measure = measure; + } + + /** + * Perform a cluster analysis on the given set of {@link Clusterable} instances. + * + * @param points the set of {@link Clusterable} instances + * @return a {@link List} of clusters + * @throws MathIllegalArgumentException if points are null or the number of + * data points is not compatible with this clusterer + * @throws ConvergenceException if the algorithm has not yet converged after + * the maximum number of iterations has been exceeded + */ + public abstract List<? extends Cluster<T>> cluster(Collection<T> points) + throws MathIllegalArgumentException, ConvergenceException; + + /** + * Returns the {@link DistanceMeasure} instance used by this clusterer. + * + * @return the distance measure + */ + public DistanceMeasure getDistanceMeasure() { + return measure; + } + + /** + * Calculates the distance between two {@link Clusterable} instances + * with the configured {@link DistanceMeasure}. + * + * @param p1 the first clusterable + * @param p2 the second clusterable + * @return the distance between the two clusterables + */ + protected double distance(final Clusterable p1, final Clusterable p2) { + return measure.compute(p1.getPoint(), p2.getPoint()); + } + +} diff --git a/src/main/java/org/apache/commons/math3/ml/clustering/DBSCANClusterer.java b/src/main/java/org/apache/commons/math3/ml/clustering/DBSCANClusterer.java new file mode 100644 index 0000000..ce3d5cd --- /dev/null +++ b/src/main/java/org/apache/commons/math3/ml/clustering/DBSCANClusterer.java @@ -0,0 +1,233 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.math3.ml.clustering; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import org.apache.commons.math3.exception.NotPositiveException; +import org.apache.commons.math3.exception.NullArgumentException; +import org.apache.commons.math3.ml.distance.DistanceMeasure; +import org.apache.commons.math3.ml.distance.EuclideanDistance; +import org.apache.commons.math3.util.MathUtils; + +/** + * DBSCAN (density-based spatial clustering of applications with noise) algorithm. + * <p> + * The DBSCAN algorithm forms clusters based on the idea of density connectivity, i.e. + * a point p is density connected to another point q, if there exists a chain of + * points p<sub>i</sub>, with i = 1 .. n and p<sub>1</sub> = p and p<sub>n</sub> = q, + * such that each pair <p<sub>i</sub>, p<sub>i+1</sub>> is directly density-reachable. + * A point q is directly density-reachable from point p if it is in the ε-neighborhood + * of this point. + * <p> + * Any point that is not density-reachable from a formed cluster is treated as noise, and + * will thus not be present in the result. + * <p> + * The algorithm requires two parameters: + * <ul> + * <li>eps: the distance that defines the ε-neighborhood of a point + * <li>minPoints: the minimum number of density-connected points required to form a cluster + * </ul> + * + * @param <T> type of the points to cluster + * @see <a href="http://en.wikipedia.org/wiki/DBSCAN">DBSCAN (wikipedia)</a> + * @see <a href="http://www.dbs.ifi.lmu.de/Publikationen/Papers/KDD-96.final.frame.pdf"> + * A Density-Based Algorithm for Discovering Clusters in Large Spatial Databases with Noise</a> + * @since 3.2 + */ +public class DBSCANClusterer<T extends Clusterable> extends Clusterer<T> { + + /** Maximum radius of the neighborhood to be considered. */ + private final double eps; + + /** Minimum number of points needed for a cluster. */ + private final int minPts; + + /** Status of a point during the clustering process. */ + private enum PointStatus { + /** The point has is considered to be noise. */ + NOISE, + /** The point is already part of a cluster. */ + PART_OF_CLUSTER + } + + /** + * Creates a new instance of a DBSCANClusterer. + * <p> + * The euclidean distance will be used as default distance measure. + * + * @param eps maximum radius of the neighborhood to be considered + * @param minPts minimum number of points needed for a cluster + * @throws NotPositiveException if {@code eps < 0.0} or {@code minPts < 0} + */ + public DBSCANClusterer(final double eps, final int minPts) + throws NotPositiveException { + this(eps, minPts, new EuclideanDistance()); + } + + /** + * Creates a new instance of a DBSCANClusterer. + * + * @param eps maximum radius of the neighborhood to be considered + * @param minPts minimum number of points needed for a cluster + * @param measure the distance measure to use + * @throws NotPositiveException if {@code eps < 0.0} or {@code minPts < 0} + */ + public DBSCANClusterer(final double eps, final int minPts, final DistanceMeasure measure) + throws NotPositiveException { + super(measure); + + if (eps < 0.0d) { + throw new NotPositiveException(eps); + } + if (minPts < 0) { + throw new NotPositiveException(minPts); + } + this.eps = eps; + this.minPts = minPts; + } + + /** + * Returns the maximum radius of the neighborhood to be considered. + * @return maximum radius of the neighborhood + */ + public double getEps() { + return eps; + } + + /** + * Returns the minimum number of points needed for a cluster. + * @return minimum number of points needed for a cluster + */ + public int getMinPts() { + return minPts; + } + + /** + * Performs DBSCAN cluster analysis. + * + * @param points the points to cluster + * @return the list of clusters + * @throws NullArgumentException if the data points are null + */ + @Override + public List<Cluster<T>> cluster(final Collection<T> points) throws NullArgumentException { + + // sanity checks + MathUtils.checkNotNull(points); + + final List<Cluster<T>> clusters = new ArrayList<Cluster<T>>(); + final Map<Clusterable, PointStatus> visited = new HashMap<Clusterable, PointStatus>(); + + for (final T point : points) { + if (visited.get(point) != null) { + continue; + } + final List<T> neighbors = getNeighbors(point, points); + if (neighbors.size() >= minPts) { + // DBSCAN does not care about center points + final Cluster<T> cluster = new Cluster<T>(); + clusters.add(expandCluster(cluster, point, neighbors, points, visited)); + } else { + visited.put(point, PointStatus.NOISE); + } + } + + return clusters; + } + + /** + * Expands the cluster to include density-reachable items. + * + * @param cluster Cluster to expand + * @param point Point to add to cluster + * @param neighbors List of neighbors + * @param points the data set + * @param visited the set of already visited points + * @return the expanded cluster + */ + private Cluster<T> expandCluster(final Cluster<T> cluster, + final T point, + final List<T> neighbors, + final Collection<T> points, + final Map<Clusterable, PointStatus> visited) { + cluster.addPoint(point); + visited.put(point, PointStatus.PART_OF_CLUSTER); + + List<T> seeds = new ArrayList<T>(neighbors); + int index = 0; + while (index < seeds.size()) { + final T current = seeds.get(index); + PointStatus pStatus = visited.get(current); + // only check non-visited points + if (pStatus == null) { + final List<T> currentNeighbors = getNeighbors(current, points); + if (currentNeighbors.size() >= minPts) { + seeds = merge(seeds, currentNeighbors); + } + } + + if (pStatus != PointStatus.PART_OF_CLUSTER) { + visited.put(current, PointStatus.PART_OF_CLUSTER); + cluster.addPoint(current); + } + + index++; + } + return cluster; + } + + /** + * Returns a list of density-reachable neighbors of a {@code point}. + * + * @param point the point to look for + * @param points possible neighbors + * @return the List of neighbors + */ + private List<T> getNeighbors(final T point, final Collection<T> points) { + final List<T> neighbors = new ArrayList<T>(); + for (final T neighbor : points) { + if (point != neighbor && distance(neighbor, point) <= eps) { + neighbors.add(neighbor); + } + } + return neighbors; + } + + /** + * Merges two lists together. + * + * @param one first list + * @param two second list + * @return merged lists + */ + private List<T> merge(final List<T> one, final List<T> two) { + final Set<T> oneSet = new HashSet<T>(one); + for (T item : two) { + if (!oneSet.contains(item)) { + one.add(item); + } + } + return one; + } +} diff --git a/src/main/java/org/apache/commons/math3/ml/clustering/DoublePoint.java b/src/main/java/org/apache/commons/math3/ml/clustering/DoublePoint.java new file mode 100644 index 0000000..4fb31f7 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/ml/clustering/DoublePoint.java @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.commons.math3.ml.clustering; + +import java.io.Serializable; +import java.util.Arrays; + +/** + * A simple implementation of {@link Clusterable} for points with double coordinates. + * @since 3.2 + */ +public class DoublePoint implements Clusterable, Serializable { + + /** Serializable version identifier. */ + private static final long serialVersionUID = 3946024775784901369L; + + /** Point coordinates. */ + private final double[] point; + + /** + * Build an instance wrapping an double array. + * <p> + * The wrapped array is referenced, it is <em>not</em> copied. + * + * @param point the n-dimensional point in double space + */ + public DoublePoint(final double[] point) { + this.point = point; + } + + /** + * Build an instance wrapping an integer array. + * <p> + * The wrapped array is copied to an internal double array. + * + * @param point the n-dimensional point in integer space + */ + public DoublePoint(final int[] point) { + this.point = new double[point.length]; + for ( int i = 0; i < point.length; i++) { + this.point[i] = point[i]; + } + } + + /** {@inheritDoc} */ + public double[] getPoint() { + return point; + } + + /** {@inheritDoc} */ + @Override + public boolean equals(final Object other) { + if (!(other instanceof DoublePoint)) { + return false; + } + return Arrays.equals(point, ((DoublePoint) other).point); + } + + /** {@inheritDoc} */ + @Override + public int hashCode() { + return Arrays.hashCode(point); + } + + /** {@inheritDoc} */ + @Override + public String toString() { + return Arrays.toString(point); + } + +} diff --git a/src/main/java/org/apache/commons/math3/ml/clustering/FuzzyKMeansClusterer.java b/src/main/java/org/apache/commons/math3/ml/clustering/FuzzyKMeansClusterer.java new file mode 100644 index 0000000..5f89934 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/ml/clustering/FuzzyKMeansClusterer.java @@ -0,0 +1,426 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.math3.ml.clustering; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.List; + +import org.apache.commons.math3.exception.MathIllegalArgumentException; +import org.apache.commons.math3.exception.MathIllegalStateException; +import org.apache.commons.math3.exception.NumberIsTooSmallException; +import org.apache.commons.math3.linear.MatrixUtils; +import org.apache.commons.math3.linear.RealMatrix; +import org.apache.commons.math3.ml.distance.DistanceMeasure; +import org.apache.commons.math3.ml.distance.EuclideanDistance; +import org.apache.commons.math3.random.JDKRandomGenerator; +import org.apache.commons.math3.random.RandomGenerator; +import org.apache.commons.math3.util.FastMath; +import org.apache.commons.math3.util.MathArrays; +import org.apache.commons.math3.util.MathUtils; + +/** + * Fuzzy K-Means clustering algorithm. + * <p> + * The Fuzzy K-Means algorithm is a variation of the classical K-Means algorithm, with the + * major difference that a single data point is not uniquely assigned to a single cluster. + * Instead, each point i has a set of weights u<sub>ij</sub> which indicate the degree of membership + * to the cluster j. + * <p> + * The algorithm then tries to minimize the objective function: + * <pre> + * J = ∑<sub>i=1..C</sub>∑<sub>k=1..N</sub> u<sub>ik</sub><sup>m</sup>d<sub>ik</sub><sup>2</sup> + * </pre> + * with d<sub>ik</sub> being the distance between data point i and the cluster center k. + * <p> + * The algorithm requires two parameters: + * <ul> + * <li>k: the number of clusters + * <li>fuzziness: determines the level of cluster fuzziness, larger values lead to fuzzier clusters + * </ul> + * Additional, optional parameters: + * <ul> + * <li>maxIterations: the maximum number of iterations + * <li>epsilon: the convergence criteria, default is 1e-3 + * </ul> + * <p> + * The fuzzy variant of the K-Means algorithm is more robust with regard to the selection + * of the initial cluster centers. + * + * @param <T> type of the points to cluster + * @since 3.3 + */ +public class FuzzyKMeansClusterer<T extends Clusterable> extends Clusterer<T> { + + /** The default value for the convergence criteria. */ + private static final double DEFAULT_EPSILON = 1e-3; + + /** The number of clusters. */ + private final int k; + + /** The maximum number of iterations. */ + private final int maxIterations; + + /** The fuzziness factor. */ + private final double fuzziness; + + /** The convergence criteria. */ + private final double epsilon; + + /** Random generator for choosing initial centers. */ + private final RandomGenerator random; + + /** The membership matrix. */ + private double[][] membershipMatrix; + + /** The list of points used in the last call to {@link #cluster(Collection)}. */ + private List<T> points; + + /** The list of clusters resulting from the last call to {@link #cluster(Collection)}. */ + private List<CentroidCluster<T>> clusters; + + /** + * Creates a new instance of a FuzzyKMeansClusterer. + * <p> + * The euclidean distance will be used as default distance measure. + * + * @param k the number of clusters to split the data into + * @param fuzziness the fuzziness factor, must be > 1.0 + * @throws NumberIsTooSmallException if {@code fuzziness <= 1.0} + */ + public FuzzyKMeansClusterer(final int k, final double fuzziness) throws NumberIsTooSmallException { + this(k, fuzziness, -1, new EuclideanDistance()); + } + + /** + * Creates a new instance of a FuzzyKMeansClusterer. + * + * @param k the number of clusters to split the data into + * @param fuzziness the fuzziness factor, must be > 1.0 + * @param maxIterations the maximum number of iterations to run the algorithm for. + * If negative, no maximum will be used. + * @param measure the distance measure to use + * @throws NumberIsTooSmallException if {@code fuzziness <= 1.0} + */ + public FuzzyKMeansClusterer(final int k, final double fuzziness, + final int maxIterations, final DistanceMeasure measure) + throws NumberIsTooSmallException { + this(k, fuzziness, maxIterations, measure, DEFAULT_EPSILON, new JDKRandomGenerator()); + } + + /** + * Creates a new instance of a FuzzyKMeansClusterer. + * + * @param k the number of clusters to split the data into + * @param fuzziness the fuzziness factor, must be > 1.0 + * @param maxIterations the maximum number of iterations to run the algorithm for. + * If negative, no maximum will be used. + * @param measure the distance measure to use + * @param epsilon the convergence criteria (default is 1e-3) + * @param random random generator to use for choosing initial centers + * @throws NumberIsTooSmallException if {@code fuzziness <= 1.0} + */ + public FuzzyKMeansClusterer(final int k, final double fuzziness, + final int maxIterations, final DistanceMeasure measure, + final double epsilon, final RandomGenerator random) + throws NumberIsTooSmallException { + + super(measure); + + if (fuzziness <= 1.0d) { + throw new NumberIsTooSmallException(fuzziness, 1.0, false); + } + this.k = k; + this.fuzziness = fuzziness; + this.maxIterations = maxIterations; + this.epsilon = epsilon; + this.random = random; + + this.membershipMatrix = null; + this.points = null; + this.clusters = null; + } + + /** + * Return the number of clusters this instance will use. + * @return the number of clusters + */ + public int getK() { + return k; + } + + /** + * Returns the fuzziness factor used by this instance. + * @return the fuzziness factor + */ + public double getFuzziness() { + return fuzziness; + } + + /** + * Returns the maximum number of iterations this instance will use. + * @return the maximum number of iterations, or -1 if no maximum is set + */ + public int getMaxIterations() { + return maxIterations; + } + + /** + * Returns the convergence criteria used by this instance. + * @return the convergence criteria + */ + public double getEpsilon() { + return epsilon; + } + + /** + * Returns the random generator this instance will use. + * @return the random generator + */ + public RandomGenerator getRandomGenerator() { + return random; + } + + /** + * Returns the {@code nxk} membership matrix, where {@code n} is the number + * of data points and {@code k} the number of clusters. + * <p> + * The element U<sub>i,j</sub> represents the membership value for data point {@code i} + * to cluster {@code j}. + * + * @return the membership matrix + * @throws MathIllegalStateException if {@link #cluster(Collection)} has not been called before + */ + public RealMatrix getMembershipMatrix() { + if (membershipMatrix == null) { + throw new MathIllegalStateException(); + } + return MatrixUtils.createRealMatrix(membershipMatrix); + } + + /** + * Returns an unmodifiable list of the data points used in the last + * call to {@link #cluster(Collection)}. + * @return the list of data points, or {@code null} if {@link #cluster(Collection)} has + * not been called before. + */ + public List<T> getDataPoints() { + return points; + } + + /** + * Returns the list of clusters resulting from the last call to {@link #cluster(Collection)}. + * @return the list of clusters, or {@code null} if {@link #cluster(Collection)} has + * not been called before. + */ + public List<CentroidCluster<T>> getClusters() { + return clusters; + } + + /** + * Get the value of the objective function. + * @return the objective function evaluation as double value + * @throws MathIllegalStateException if {@link #cluster(Collection)} has not been called before + */ + public double getObjectiveFunctionValue() { + if (points == null || clusters == null) { + throw new MathIllegalStateException(); + } + + int i = 0; + double objFunction = 0.0; + for (final T point : points) { + int j = 0; + for (final CentroidCluster<T> cluster : clusters) { + final double dist = distance(point, cluster.getCenter()); + objFunction += (dist * dist) * FastMath.pow(membershipMatrix[i][j], fuzziness); + j++; + } + i++; + } + return objFunction; + } + + /** + * Performs Fuzzy K-Means cluster analysis. + * + * @param dataPoints the points to cluster + * @return the list of clusters + * @throws MathIllegalArgumentException if the data points are null or the number + * of clusters is larger than the number of data points + */ + @Override + public List<CentroidCluster<T>> cluster(final Collection<T> dataPoints) + throws MathIllegalArgumentException { + + // sanity checks + MathUtils.checkNotNull(dataPoints); + + final int size = dataPoints.size(); + + // number of clusters has to be smaller or equal the number of data points + if (size < k) { + throw new NumberIsTooSmallException(size, k, false); + } + + // copy the input collection to an unmodifiable list with indexed access + points = Collections.unmodifiableList(new ArrayList<T>(dataPoints)); + clusters = new ArrayList<CentroidCluster<T>>(); + membershipMatrix = new double[size][k]; + final double[][] oldMatrix = new double[size][k]; + + // if no points are provided, return an empty list of clusters + if (size == 0) { + return clusters; + } + + initializeMembershipMatrix(); + + // there is at least one point + final int pointDimension = points.get(0).getPoint().length; + for (int i = 0; i < k; i++) { + clusters.add(new CentroidCluster<T>(new DoublePoint(new double[pointDimension]))); + } + + int iteration = 0; + final int max = (maxIterations < 0) ? Integer.MAX_VALUE : maxIterations; + double difference = 0.0; + + do { + saveMembershipMatrix(oldMatrix); + updateClusterCenters(); + updateMembershipMatrix(); + difference = calculateMaxMembershipChange(oldMatrix); + } while (difference > epsilon && ++iteration < max); + + return clusters; + } + + /** + * Update the cluster centers. + */ + private void updateClusterCenters() { + int j = 0; + final List<CentroidCluster<T>> newClusters = new ArrayList<CentroidCluster<T>>(k); + for (final CentroidCluster<T> cluster : clusters) { + final Clusterable center = cluster.getCenter(); + int i = 0; + double[] arr = new double[center.getPoint().length]; + double sum = 0.0; + for (final T point : points) { + final double u = FastMath.pow(membershipMatrix[i][j], fuzziness); + final double[] pointArr = point.getPoint(); + for (int idx = 0; idx < arr.length; idx++) { + arr[idx] += u * pointArr[idx]; + } + sum += u; + i++; + } + MathArrays.scaleInPlace(1.0 / sum, arr); + newClusters.add(new CentroidCluster<T>(new DoublePoint(arr))); + j++; + } + clusters.clear(); + clusters = newClusters; + } + + /** + * Updates the membership matrix and assigns the points to the cluster with + * the highest membership. + */ + private void updateMembershipMatrix() { + for (int i = 0; i < points.size(); i++) { + final T point = points.get(i); + double maxMembership = Double.MIN_VALUE; + int newCluster = -1; + for (int j = 0; j < clusters.size(); j++) { + double sum = 0.0; + final double distA = FastMath.abs(distance(point, clusters.get(j).getCenter())); + + if (distA != 0.0) { + for (final CentroidCluster<T> c : clusters) { + final double distB = FastMath.abs(distance(point, c.getCenter())); + if (distB == 0.0) { + sum = Double.POSITIVE_INFINITY; + break; + } + sum += FastMath.pow(distA / distB, 2.0 / (fuzziness - 1.0)); + } + } + + double membership; + if (sum == 0.0) { + membership = 1.0; + } else if (sum == Double.POSITIVE_INFINITY) { + membership = 0.0; + } else { + membership = 1.0 / sum; + } + membershipMatrix[i][j] = membership; + + if (membershipMatrix[i][j] > maxMembership) { + maxMembership = membershipMatrix[i][j]; + newCluster = j; + } + } + clusters.get(newCluster).addPoint(point); + } + } + + /** + * Initialize the membership matrix with random values. + */ + private void initializeMembershipMatrix() { + for (int i = 0; i < points.size(); i++) { + for (int j = 0; j < k; j++) { + membershipMatrix[i][j] = random.nextDouble(); + } + membershipMatrix[i] = MathArrays.normalizeArray(membershipMatrix[i], 1.0); + } + } + + /** + * Calculate the maximum element-by-element change of the membership matrix + * for the current iteration. + * + * @param matrix the membership matrix of the previous iteration + * @return the maximum membership matrix change + */ + private double calculateMaxMembershipChange(final double[][] matrix) { + double maxMembership = 0.0; + for (int i = 0; i < points.size(); i++) { + for (int j = 0; j < clusters.size(); j++) { + double v = FastMath.abs(membershipMatrix[i][j] - matrix[i][j]); + maxMembership = FastMath.max(v, maxMembership); + } + } + return maxMembership; + } + + /** + * Copy the membership matrix into the provided matrix. + * + * @param matrix the place to store the membership matrix + */ + private void saveMembershipMatrix(final double[][] matrix) { + for (int i = 0; i < points.size(); i++) { + System.arraycopy(membershipMatrix[i], 0, matrix[i], 0, clusters.size()); + } + } + +} diff --git a/src/main/java/org/apache/commons/math3/ml/clustering/KMeansPlusPlusClusterer.java b/src/main/java/org/apache/commons/math3/ml/clustering/KMeansPlusPlusClusterer.java new file mode 100644 index 0000000..2e57fac --- /dev/null +++ b/src/main/java/org/apache/commons/math3/ml/clustering/KMeansPlusPlusClusterer.java @@ -0,0 +1,565 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.commons.math3.ml.clustering; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.List; + +import org.apache.commons.math3.exception.ConvergenceException; +import org.apache.commons.math3.exception.MathIllegalArgumentException; +import org.apache.commons.math3.exception.NumberIsTooSmallException; +import org.apache.commons.math3.exception.util.LocalizedFormats; +import org.apache.commons.math3.ml.distance.DistanceMeasure; +import org.apache.commons.math3.ml.distance.EuclideanDistance; +import org.apache.commons.math3.random.JDKRandomGenerator; +import org.apache.commons.math3.random.RandomGenerator; +import org.apache.commons.math3.stat.descriptive.moment.Variance; +import org.apache.commons.math3.util.MathUtils; + +/** + * Clustering algorithm based on David Arthur and Sergei Vassilvitski k-means++ algorithm. + * @param <T> type of the points to cluster + * @see <a href="http://en.wikipedia.org/wiki/K-means%2B%2B">K-means++ (wikipedia)</a> + * @since 3.2 + */ +public class KMeansPlusPlusClusterer<T extends Clusterable> extends Clusterer<T> { + + /** Strategies to use for replacing an empty cluster. */ + public enum EmptyClusterStrategy { + + /** Split the cluster with largest distance variance. */ + LARGEST_VARIANCE, + + /** Split the cluster with largest number of points. */ + LARGEST_POINTS_NUMBER, + + /** Create a cluster around the point farthest from its centroid. */ + FARTHEST_POINT, + + /** Generate an error. */ + ERROR + + } + + /** The number of clusters. */ + private final int k; + + /** The maximum number of iterations. */ + private final int maxIterations; + + /** Random generator for choosing initial centers. */ + private final RandomGenerator random; + + /** Selected strategy for empty clusters. */ + private final EmptyClusterStrategy emptyStrategy; + + /** Build a clusterer. + * <p> + * The default strategy for handling empty clusters that may appear during + * algorithm iterations is to split the cluster with largest distance variance. + * <p> + * The euclidean distance will be used as default distance measure. + * + * @param k the number of clusters to split the data into + */ + public KMeansPlusPlusClusterer(final int k) { + this(k, -1); + } + + /** Build a clusterer. + * <p> + * The default strategy for handling empty clusters that may appear during + * algorithm iterations is to split the cluster with largest distance variance. + * <p> + * The euclidean distance will be used as default distance measure. + * + * @param k the number of clusters to split the data into + * @param maxIterations the maximum number of iterations to run the algorithm for. + * If negative, no maximum will be used. + */ + public KMeansPlusPlusClusterer(final int k, final int maxIterations) { + this(k, maxIterations, new EuclideanDistance()); + } + + /** Build a clusterer. + * <p> + * The default strategy for handling empty clusters that may appear during + * algorithm iterations is to split the cluster with largest distance variance. + * + * @param k the number of clusters to split the data into + * @param maxIterations the maximum number of iterations to run the algorithm for. + * If negative, no maximum will be used. + * @param measure the distance measure to use + */ + public KMeansPlusPlusClusterer(final int k, final int maxIterations, final DistanceMeasure measure) { + this(k, maxIterations, measure, new JDKRandomGenerator()); + } + + /** Build a clusterer. + * <p> + * The default strategy for handling empty clusters that may appear during + * algorithm iterations is to split the cluster with largest distance variance. + * + * @param k the number of clusters to split the data into + * @param maxIterations the maximum number of iterations to run the algorithm for. + * If negative, no maximum will be used. + * @param measure the distance measure to use + * @param random random generator to use for choosing initial centers + */ + public KMeansPlusPlusClusterer(final int k, final int maxIterations, + final DistanceMeasure measure, + final RandomGenerator random) { + this(k, maxIterations, measure, random, EmptyClusterStrategy.LARGEST_VARIANCE); + } + + /** Build a clusterer. + * + * @param k the number of clusters to split the data into + * @param maxIterations the maximum number of iterations to run the algorithm for. + * If negative, no maximum will be used. + * @param measure the distance measure to use + * @param random random generator to use for choosing initial centers + * @param emptyStrategy strategy to use for handling empty clusters that + * may appear during algorithm iterations + */ + public KMeansPlusPlusClusterer(final int k, final int maxIterations, + final DistanceMeasure measure, + final RandomGenerator random, + final EmptyClusterStrategy emptyStrategy) { + super(measure); + this.k = k; + this.maxIterations = maxIterations; + this.random = random; + this.emptyStrategy = emptyStrategy; + } + + /** + * Return the number of clusters this instance will use. + * @return the number of clusters + */ + public int getK() { + return k; + } + + /** + * Returns the maximum number of iterations this instance will use. + * @return the maximum number of iterations, or -1 if no maximum is set + */ + public int getMaxIterations() { + return maxIterations; + } + + /** + * Returns the random generator this instance will use. + * @return the random generator + */ + public RandomGenerator getRandomGenerator() { + return random; + } + + /** + * Returns the {@link EmptyClusterStrategy} used by this instance. + * @return the {@link EmptyClusterStrategy} + */ + public EmptyClusterStrategy getEmptyClusterStrategy() { + return emptyStrategy; + } + + /** + * Runs the K-means++ clustering algorithm. + * + * @param points the points to cluster + * @return a list of clusters containing the points + * @throws MathIllegalArgumentException if the data points are null or the number + * of clusters is larger than the number of data points + * @throws ConvergenceException if an empty cluster is encountered and the + * {@link #emptyStrategy} is set to {@code ERROR} + */ + @Override + public List<CentroidCluster<T>> cluster(final Collection<T> points) + throws MathIllegalArgumentException, ConvergenceException { + + // sanity checks + MathUtils.checkNotNull(points); + + // number of clusters has to be smaller or equal the number of data points + if (points.size() < k) { + throw new NumberIsTooSmallException(points.size(), k, false); + } + + // create the initial clusters + List<CentroidCluster<T>> clusters = chooseInitialCenters(points); + + // create an array containing the latest assignment of a point to a cluster + // no need to initialize the array, as it will be filled with the first assignment + int[] assignments = new int[points.size()]; + assignPointsToClusters(clusters, points, assignments); + + // iterate through updating the centers until we're done + final int max = (maxIterations < 0) ? Integer.MAX_VALUE : maxIterations; + for (int count = 0; count < max; count++) { + boolean emptyCluster = false; + List<CentroidCluster<T>> newClusters = new ArrayList<CentroidCluster<T>>(); + for (final CentroidCluster<T> cluster : clusters) { + final Clusterable newCenter; + if (cluster.getPoints().isEmpty()) { + switch (emptyStrategy) { + case LARGEST_VARIANCE : + newCenter = getPointFromLargestVarianceCluster(clusters); + break; + case LARGEST_POINTS_NUMBER : + newCenter = getPointFromLargestNumberCluster(clusters); + break; + case FARTHEST_POINT : + newCenter = getFarthestPoint(clusters); + break; + default : + throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS); + } + emptyCluster = true; + } else { + newCenter = centroidOf(cluster.getPoints(), cluster.getCenter().getPoint().length); + } + newClusters.add(new CentroidCluster<T>(newCenter)); + } + int changes = assignPointsToClusters(newClusters, points, assignments); + clusters = newClusters; + + // if there were no more changes in the point-to-cluster assignment + // and there are no empty clusters left, return the current clusters + if (changes == 0 && !emptyCluster) { + return clusters; + } + } + return clusters; + } + + /** + * Adds the given points to the closest {@link Cluster}. + * + * @param clusters the {@link Cluster}s to add the points to + * @param points the points to add to the given {@link Cluster}s + * @param assignments points assignments to clusters + * @return the number of points assigned to different clusters as the iteration before + */ + private int assignPointsToClusters(final List<CentroidCluster<T>> clusters, + final Collection<T> points, + final int[] assignments) { + int assignedDifferently = 0; + int pointIndex = 0; + for (final T p : points) { + int clusterIndex = getNearestCluster(clusters, p); + if (clusterIndex != assignments[pointIndex]) { + assignedDifferently++; + } + + CentroidCluster<T> cluster = clusters.get(clusterIndex); + cluster.addPoint(p); + assignments[pointIndex++] = clusterIndex; + } + + return assignedDifferently; + } + + /** + * Use K-means++ to choose the initial centers. + * + * @param points the points to choose the initial centers from + * @return the initial centers + */ + private List<CentroidCluster<T>> chooseInitialCenters(final Collection<T> points) { + + // Convert to list for indexed access. Make it unmodifiable, since removal of items + // would screw up the logic of this method. + final List<T> pointList = Collections.unmodifiableList(new ArrayList<T> (points)); + + // The number of points in the list. + final int numPoints = pointList.size(); + + // Set the corresponding element in this array to indicate when + // elements of pointList are no longer available. + final boolean[] taken = new boolean[numPoints]; + + // The resulting list of initial centers. + final List<CentroidCluster<T>> resultSet = new ArrayList<CentroidCluster<T>>(); + + // Choose one center uniformly at random from among the data points. + final int firstPointIndex = random.nextInt(numPoints); + + final T firstPoint = pointList.get(firstPointIndex); + + resultSet.add(new CentroidCluster<T>(firstPoint)); + + // Must mark it as taken + taken[firstPointIndex] = true; + + // To keep track of the minimum distance squared of elements of + // pointList to elements of resultSet. + final double[] minDistSquared = new double[numPoints]; + + // Initialize the elements. Since the only point in resultSet is firstPoint, + // this is very easy. + for (int i = 0; i < numPoints; i++) { + if (i != firstPointIndex) { // That point isn't considered + double d = distance(firstPoint, pointList.get(i)); + minDistSquared[i] = d*d; + } + } + + while (resultSet.size() < k) { + + // Sum up the squared distances for the points in pointList not + // already taken. + double distSqSum = 0.0; + + for (int i = 0; i < numPoints; i++) { + if (!taken[i]) { + distSqSum += minDistSquared[i]; + } + } + + // Add one new data point as a center. Each point x is chosen with + // probability proportional to D(x)2 + final double r = random.nextDouble() * distSqSum; + + // The index of the next point to be added to the resultSet. + int nextPointIndex = -1; + + // Sum through the squared min distances again, stopping when + // sum >= r. + double sum = 0.0; + for (int i = 0; i < numPoints; i++) { + if (!taken[i]) { + sum += minDistSquared[i]; + if (sum >= r) { + nextPointIndex = i; + break; + } + } + } + + // If it's not set to >= 0, the point wasn't found in the previous + // for loop, probably because distances are extremely small. Just pick + // the last available point. + if (nextPointIndex == -1) { + for (int i = numPoints - 1; i >= 0; i--) { + if (!taken[i]) { + nextPointIndex = i; + break; + } + } + } + + // We found one. + if (nextPointIndex >= 0) { + + final T p = pointList.get(nextPointIndex); + + resultSet.add(new CentroidCluster<T> (p)); + + // Mark it as taken. + taken[nextPointIndex] = true; + + if (resultSet.size() < k) { + // Now update elements of minDistSquared. We only have to compute + // the distance to the new center to do this. + for (int j = 0; j < numPoints; j++) { + // Only have to worry about the points still not taken. + if (!taken[j]) { + double d = distance(p, pointList.get(j)); + double d2 = d * d; + if (d2 < minDistSquared[j]) { + minDistSquared[j] = d2; + } + } + } + } + + } else { + // None found -- + // Break from the while loop to prevent + // an infinite loop. + break; + } + } + + return resultSet; + } + + /** + * Get a random point from the {@link Cluster} with the largest distance variance. + * + * @param clusters the {@link Cluster}s to search + * @return a random point from the selected cluster + * @throws ConvergenceException if clusters are all empty + */ + private T getPointFromLargestVarianceCluster(final Collection<CentroidCluster<T>> clusters) + throws ConvergenceException { + + double maxVariance = Double.NEGATIVE_INFINITY; + Cluster<T> selected = null; + for (final CentroidCluster<T> cluster : clusters) { + if (!cluster.getPoints().isEmpty()) { + + // compute the distance variance of the current cluster + final Clusterable center = cluster.getCenter(); + final Variance stat = new Variance(); + for (final T point : cluster.getPoints()) { + stat.increment(distance(point, center)); + } + final double variance = stat.getResult(); + + // select the cluster with the largest variance + if (variance > maxVariance) { + maxVariance = variance; + selected = cluster; + } + + } + } + + // did we find at least one non-empty cluster ? + if (selected == null) { + throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS); + } + + // extract a random point from the cluster + final List<T> selectedPoints = selected.getPoints(); + return selectedPoints.remove(random.nextInt(selectedPoints.size())); + + } + + /** + * Get a random point from the {@link Cluster} with the largest number of points + * + * @param clusters the {@link Cluster}s to search + * @return a random point from the selected cluster + * @throws ConvergenceException if clusters are all empty + */ + private T getPointFromLargestNumberCluster(final Collection<? extends Cluster<T>> clusters) + throws ConvergenceException { + + int maxNumber = 0; + Cluster<T> selected = null; + for (final Cluster<T> cluster : clusters) { + + // get the number of points of the current cluster + final int number = cluster.getPoints().size(); + + // select the cluster with the largest number of points + if (number > maxNumber) { + maxNumber = number; + selected = cluster; + } + + } + + // did we find at least one non-empty cluster ? + if (selected == null) { + throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS); + } + + // extract a random point from the cluster + final List<T> selectedPoints = selected.getPoints(); + return selectedPoints.remove(random.nextInt(selectedPoints.size())); + + } + + /** + * Get the point farthest to its cluster center + * + * @param clusters the {@link Cluster}s to search + * @return point farthest to its cluster center + * @throws ConvergenceException if clusters are all empty + */ + private T getFarthestPoint(final Collection<CentroidCluster<T>> clusters) throws ConvergenceException { + + double maxDistance = Double.NEGATIVE_INFINITY; + Cluster<T> selectedCluster = null; + int selectedPoint = -1; + for (final CentroidCluster<T> cluster : clusters) { + + // get the farthest point + final Clusterable center = cluster.getCenter(); + final List<T> points = cluster.getPoints(); + for (int i = 0; i < points.size(); ++i) { + final double distance = distance(points.get(i), center); + if (distance > maxDistance) { + maxDistance = distance; + selectedCluster = cluster; + selectedPoint = i; + } + } + + } + + // did we find at least one non-empty cluster ? + if (selectedCluster == null) { + throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS); + } + + return selectedCluster.getPoints().remove(selectedPoint); + + } + + /** + * Returns the nearest {@link Cluster} to the given point + * + * @param clusters the {@link Cluster}s to search + * @param point the point to find the nearest {@link Cluster} for + * @return the index of the nearest {@link Cluster} to the given point + */ + private int getNearestCluster(final Collection<CentroidCluster<T>> clusters, final T point) { + double minDistance = Double.MAX_VALUE; + int clusterIndex = 0; + int minCluster = 0; + for (final CentroidCluster<T> c : clusters) { + final double distance = distance(point, c.getCenter()); + if (distance < minDistance) { + minDistance = distance; + minCluster = clusterIndex; + } + clusterIndex++; + } + return minCluster; + } + + /** + * Computes the centroid for a set of points. + * + * @param points the set of points + * @param dimension the point dimension + * @return the computed centroid for the set of points + */ + private Clusterable centroidOf(final Collection<T> points, final int dimension) { + final double[] centroid = new double[dimension]; + for (final T p : points) { + final double[] point = p.getPoint(); + for (int i = 0; i < centroid.length; i++) { + centroid[i] += point[i]; + } + } + for (int i = 0; i < centroid.length; i++) { + centroid[i] /= points.size(); + } + return new DoublePoint(centroid); + } + +} diff --git a/src/main/java/org/apache/commons/math3/ml/clustering/MultiKMeansPlusPlusClusterer.java b/src/main/java/org/apache/commons/math3/ml/clustering/MultiKMeansPlusPlusClusterer.java new file mode 100644 index 0000000..796fc7a --- /dev/null +++ b/src/main/java/org/apache/commons/math3/ml/clustering/MultiKMeansPlusPlusClusterer.java @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.commons.math3.ml.clustering; + +import java.util.Collection; +import java.util.List; + +import org.apache.commons.math3.exception.ConvergenceException; +import org.apache.commons.math3.exception.MathIllegalArgumentException; +import org.apache.commons.math3.ml.clustering.evaluation.ClusterEvaluator; +import org.apache.commons.math3.ml.clustering.evaluation.SumOfClusterVariances; + +/** + * A wrapper around a k-means++ clustering algorithm which performs multiple trials + * and returns the best solution. + * @param <T> type of the points to cluster + * @since 3.2 + */ +public class MultiKMeansPlusPlusClusterer<T extends Clusterable> extends Clusterer<T> { + + /** The underlying k-means clusterer. */ + private final KMeansPlusPlusClusterer<T> clusterer; + + /** The number of trial runs. */ + private final int numTrials; + + /** The cluster evaluator to use. */ + private final ClusterEvaluator<T> evaluator; + + /** Build a clusterer. + * @param clusterer the k-means clusterer to use + * @param numTrials number of trial runs + */ + public MultiKMeansPlusPlusClusterer(final KMeansPlusPlusClusterer<T> clusterer, + final int numTrials) { + this(clusterer, numTrials, new SumOfClusterVariances<T>(clusterer.getDistanceMeasure())); + } + + /** Build a clusterer. + * @param clusterer the k-means clusterer to use + * @param numTrials number of trial runs + * @param evaluator the cluster evaluator to use + * @since 3.3 + */ + public MultiKMeansPlusPlusClusterer(final KMeansPlusPlusClusterer<T> clusterer, + final int numTrials, + final ClusterEvaluator<T> evaluator) { + super(clusterer.getDistanceMeasure()); + this.clusterer = clusterer; + this.numTrials = numTrials; + this.evaluator = evaluator; + } + + /** + * Returns the embedded k-means clusterer used by this instance. + * @return the embedded clusterer + */ + public KMeansPlusPlusClusterer<T> getClusterer() { + return clusterer; + } + + /** + * Returns the number of trials this instance will do. + * @return the number of trials + */ + public int getNumTrials() { + return numTrials; + } + + /** + * Returns the {@link ClusterEvaluator} used to determine the "best" clustering. + * @return the used {@link ClusterEvaluator} + * @since 3.3 + */ + public ClusterEvaluator<T> getClusterEvaluator() { + return evaluator; + } + + /** + * Runs the K-means++ clustering algorithm. + * + * @param points the points to cluster + * @return a list of clusters containing the points + * @throws MathIllegalArgumentException if the data points are null or the number + * of clusters is larger than the number of data points + * @throws ConvergenceException if an empty cluster is encountered and the + * underlying {@link KMeansPlusPlusClusterer} has its + * {@link KMeansPlusPlusClusterer.EmptyClusterStrategy} is set to {@code ERROR}. + */ + @Override + public List<CentroidCluster<T>> cluster(final Collection<T> points) + throws MathIllegalArgumentException, ConvergenceException { + + // at first, we have not found any clusters list yet + List<CentroidCluster<T>> best = null; + double bestVarianceSum = Double.POSITIVE_INFINITY; + + // do several clustering trials + for (int i = 0; i < numTrials; ++i) { + + // compute a clusters list + List<CentroidCluster<T>> clusters = clusterer.cluster(points); + + // compute the variance of the current list + final double varianceSum = evaluator.score(clusters); + + if (evaluator.isBetterScore(varianceSum, bestVarianceSum)) { + // this one is the best we have found so far, remember it + best = clusters; + bestVarianceSum = varianceSum; + } + + } + + // return the best clusters list found + return best; + + } + +} diff --git a/src/main/java/org/apache/commons/math3/ml/clustering/evaluation/ClusterEvaluator.java b/src/main/java/org/apache/commons/math3/ml/clustering/evaluation/ClusterEvaluator.java new file mode 100644 index 0000000..2bb8ba3 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/ml/clustering/evaluation/ClusterEvaluator.java @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.commons.math3.ml.clustering.evaluation; + +import java.util.List; + +import org.apache.commons.math3.ml.clustering.CentroidCluster; +import org.apache.commons.math3.ml.clustering.Cluster; +import org.apache.commons.math3.ml.clustering.Clusterable; +import org.apache.commons.math3.ml.clustering.DoublePoint; +import org.apache.commons.math3.ml.distance.DistanceMeasure; +import org.apache.commons.math3.ml.distance.EuclideanDistance; + +/** + * Base class for cluster evaluation methods. + * + * @param <T> type of the clustered points + * @since 3.3 + */ +public abstract class ClusterEvaluator<T extends Clusterable> { + + /** The distance measure to use when evaluating the cluster. */ + private final DistanceMeasure measure; + + /** + * Creates a new cluster evaluator with an {@link EuclideanDistance} + * as distance measure. + */ + public ClusterEvaluator() { + this(new EuclideanDistance()); + } + + /** + * Creates a new cluster evaluator with the given distance measure. + * @param measure the distance measure to use + */ + public ClusterEvaluator(final DistanceMeasure measure) { + this.measure = measure; + } + + /** + * Computes the evaluation score for the given list of clusters. + * @param clusters the clusters to evaluate + * @return the computed score + */ + public abstract double score(List<? extends Cluster<T>> clusters); + + /** + * Returns whether the first evaluation score is considered to be better + * than the second one by this evaluator. + * <p> + * Specific implementations shall override this method if the returned scores + * do not follow the same ordering, i.e. smaller score is better. + * + * @param score1 the first score + * @param score2 the second score + * @return {@code true} if the first score is considered to be better, {@code false} otherwise + */ + public boolean isBetterScore(double score1, double score2) { + return score1 < score2; + } + + /** + * Calculates the distance between two {@link Clusterable} instances + * with the configured {@link DistanceMeasure}. + * + * @param p1 the first clusterable + * @param p2 the second clusterable + * @return the distance between the two clusterables + */ + protected double distance(final Clusterable p1, final Clusterable p2) { + return measure.compute(p1.getPoint(), p2.getPoint()); + } + + /** + * Computes the centroid for a cluster. + * + * @param cluster the cluster + * @return the computed centroid for the cluster, + * or {@code null} if the cluster does not contain any points + */ + protected Clusterable centroidOf(final Cluster<T> cluster) { + final List<T> points = cluster.getPoints(); + if (points.isEmpty()) { + return null; + } + + // in case the cluster is of type CentroidCluster, no need to compute the centroid + if (cluster instanceof CentroidCluster) { + return ((CentroidCluster<T>) cluster).getCenter(); + } + + final int dimension = points.get(0).getPoint().length; + final double[] centroid = new double[dimension]; + for (final T p : points) { + final double[] point = p.getPoint(); + for (int i = 0; i < centroid.length; i++) { + centroid[i] += point[i]; + } + } + for (int i = 0; i < centroid.length; i++) { + centroid[i] /= points.size(); + } + return new DoublePoint(centroid); + } + +} diff --git a/src/main/java/org/apache/commons/math3/ml/clustering/evaluation/SumOfClusterVariances.java b/src/main/java/org/apache/commons/math3/ml/clustering/evaluation/SumOfClusterVariances.java new file mode 100644 index 0000000..b5b249c --- /dev/null +++ b/src/main/java/org/apache/commons/math3/ml/clustering/evaluation/SumOfClusterVariances.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.commons.math3.ml.clustering.evaluation; + +import java.util.List; + +import org.apache.commons.math3.ml.clustering.Cluster; +import org.apache.commons.math3.ml.clustering.Clusterable; +import org.apache.commons.math3.ml.distance.DistanceMeasure; +import org.apache.commons.math3.stat.descriptive.moment.Variance; + +/** + * Computes the sum of intra-cluster distance variances according to the formula: + * <pre> + * \( score = \sum\limits_{i=1}^n \sigma_i^2 \) + * </pre> + * where n is the number of clusters and \( \sigma_i^2 \) is the variance of + * intra-cluster distances of cluster \( c_i \). + * + * @param <T> the type of the clustered points + * @since 3.3 + */ +public class SumOfClusterVariances<T extends Clusterable> extends ClusterEvaluator<T> { + + /** + * + * @param measure the distance measure to use + */ + public SumOfClusterVariances(final DistanceMeasure measure) { + super(measure); + } + + /** {@inheritDoc} */ + @Override + public double score(final List<? extends Cluster<T>> clusters) { + double varianceSum = 0.0; + for (final Cluster<T> cluster : clusters) { + if (!cluster.getPoints().isEmpty()) { + + final Clusterable center = centroidOf(cluster); + + // compute the distance variance of the current cluster + final Variance stat = new Variance(); + for (final T point : cluster.getPoints()) { + stat.increment(distance(point, center)); + } + varianceSum += stat.getResult(); + + } + } + return varianceSum; + } + +} diff --git a/src/main/java/org/apache/commons/math3/ml/clustering/evaluation/package-info.java b/src/main/java/org/apache/commons/math3/ml/clustering/evaluation/package-info.java new file mode 100644 index 0000000..700f566 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/ml/clustering/evaluation/package-info.java @@ -0,0 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/** + * Cluster evaluation methods. + */ +package org.apache.commons.math3.ml.clustering.evaluation; diff --git a/src/main/java/org/apache/commons/math3/ml/clustering/package-info.java b/src/main/java/org/apache/commons/math3/ml/clustering/package-info.java new file mode 100644 index 0000000..02f1d20 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/ml/clustering/package-info.java @@ -0,0 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/** + * Clustering algorithms. + */ +package org.apache.commons.math3.ml.clustering; diff --git a/src/main/java/org/apache/commons/math3/ml/distance/CanberraDistance.java b/src/main/java/org/apache/commons/math3/ml/distance/CanberraDistance.java new file mode 100644 index 0000000..d467c3b --- /dev/null +++ b/src/main/java/org/apache/commons/math3/ml/distance/CanberraDistance.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.math3.ml.distance; + +import org.apache.commons.math3.exception.DimensionMismatchException; +import org.apache.commons.math3.util.FastMath; +import org.apache.commons.math3.util.MathArrays; + +/** + * Calculates the Canberra distance between two points. + * + * @since 3.2 + */ +public class CanberraDistance implements DistanceMeasure { + + /** Serializable version identifier. */ + private static final long serialVersionUID = -6972277381587032228L; + + /** {@inheritDoc} */ + public double compute(double[] a, double[] b) + throws DimensionMismatchException { + MathArrays.checkEqualLength(a, b); + double sum = 0; + for (int i = 0; i < a.length; i++) { + final double num = FastMath.abs(a[i] - b[i]); + final double denom = FastMath.abs(a[i]) + FastMath.abs(b[i]); + sum += num == 0.0 && denom == 0.0 ? 0.0 : num / denom; + } + return sum; + } + +} diff --git a/src/main/java/org/apache/commons/math3/ml/distance/ChebyshevDistance.java b/src/main/java/org/apache/commons/math3/ml/distance/ChebyshevDistance.java new file mode 100644 index 0000000..05dccb5 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/ml/distance/ChebyshevDistance.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.math3.ml.distance; + +import org.apache.commons.math3.exception.DimensionMismatchException; +import org.apache.commons.math3.util.MathArrays; + +/** + * Calculates the L<sub>∞</sub> (max of abs) distance between two points. + * + * @since 3.2 + */ +public class ChebyshevDistance implements DistanceMeasure { + + /** Serializable version identifier. */ + private static final long serialVersionUID = -4694868171115238296L; + + /** {@inheritDoc} */ + public double compute(double[] a, double[] b) + throws DimensionMismatchException { + return MathArrays.distanceInf(a, b); + } + +} diff --git a/src/main/java/org/apache/commons/math3/ml/distance/DistanceMeasure.java b/src/main/java/org/apache/commons/math3/ml/distance/DistanceMeasure.java new file mode 100644 index 0000000..ff9c27f --- /dev/null +++ b/src/main/java/org/apache/commons/math3/ml/distance/DistanceMeasure.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.math3.ml.distance; + +import java.io.Serializable; + +import org.apache.commons.math3.exception.DimensionMismatchException; + +/** + * Interface for distance measures of n-dimensional vectors. + * + * @since 3.2 + */ +public interface DistanceMeasure extends Serializable { + + /** + * Compute the distance between two n-dimensional vectors. + * <p> + * The two vectors are required to have the same dimension. + * + * @param a the first vector + * @param b the second vector + * @return the distance between the two vectors + * @throws DimensionMismatchException if the array lengths differ. + */ + double compute(double[] a, double[] b) throws DimensionMismatchException; +} diff --git a/src/main/java/org/apache/commons/math3/ml/distance/EarthMoversDistance.java b/src/main/java/org/apache/commons/math3/ml/distance/EarthMoversDistance.java new file mode 100644 index 0000000..2518624 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/ml/distance/EarthMoversDistance.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.math3.ml.distance; + +import org.apache.commons.math3.exception.DimensionMismatchException; +import org.apache.commons.math3.util.FastMath; +import org.apache.commons.math3.util.MathArrays; + +/** + * Calculates the Earh Mover's distance (also known as Wasserstein metric) between two distributions. + * + * @see <a href="http://en.wikipedia.org/wiki/Earth_mover's_distance">Earth Mover's distance (Wikipedia)</a> + * + * @since 3.3 + */ +public class EarthMoversDistance implements DistanceMeasure { + + /** Serializable version identifier. */ + private static final long serialVersionUID = -5406732779747414922L; + + /** {@inheritDoc} */ + public double compute(double[] a, double[] b) + throws DimensionMismatchException { + MathArrays.checkEqualLength(a, b); + double lastDistance = 0; + double totalDistance = 0; + for (int i = 0; i < a.length; i++) { + final double currentDistance = (a[i] + lastDistance) - b[i]; + totalDistance += FastMath.abs(currentDistance); + lastDistance = currentDistance; + } + return totalDistance; + } +} diff --git a/src/main/java/org/apache/commons/math3/ml/distance/EuclideanDistance.java b/src/main/java/org/apache/commons/math3/ml/distance/EuclideanDistance.java new file mode 100644 index 0000000..187badc --- /dev/null +++ b/src/main/java/org/apache/commons/math3/ml/distance/EuclideanDistance.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.math3.ml.distance; + +import org.apache.commons.math3.exception.DimensionMismatchException; +import org.apache.commons.math3.util.MathArrays; + +/** + * Calculates the L<sub>2</sub> (Euclidean) distance between two points. + * + * @since 3.2 + */ +public class EuclideanDistance implements DistanceMeasure { + + /** Serializable version identifier. */ + private static final long serialVersionUID = 1717556319784040040L; + + /** {@inheritDoc} */ + public double compute(double[] a, double[] b) + throws DimensionMismatchException { + return MathArrays.distance(a, b); + } + +} diff --git a/src/main/java/org/apache/commons/math3/ml/distance/ManhattanDistance.java b/src/main/java/org/apache/commons/math3/ml/distance/ManhattanDistance.java new file mode 100644 index 0000000..2eebe1b --- /dev/null +++ b/src/main/java/org/apache/commons/math3/ml/distance/ManhattanDistance.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.math3.ml.distance; + +import org.apache.commons.math3.exception.DimensionMismatchException; +import org.apache.commons.math3.util.MathArrays; + +/** + * Calculates the L<sub>1</sub> (sum of abs) distance between two points. + * + * @since 3.2 + */ +public class ManhattanDistance implements DistanceMeasure { + + /** Serializable version identifier. */ + private static final long serialVersionUID = -9108154600539125566L; + + /** {@inheritDoc} */ + public double compute(double[] a, double[] b) + throws DimensionMismatchException { + return MathArrays.distance1(a, b); + } + +} diff --git a/src/main/java/org/apache/commons/math3/ml/distance/package-info.java b/src/main/java/org/apache/commons/math3/ml/distance/package-info.java new file mode 100644 index 0000000..f6d124a --- /dev/null +++ b/src/main/java/org/apache/commons/math3/ml/distance/package-info.java @@ -0,0 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/** + * Common distance measures. + */ +package org.apache.commons.math3.ml.distance; diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/FeatureInitializer.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/FeatureInitializer.java new file mode 100644 index 0000000..1f48d45 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/FeatureInitializer.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.commons.math3.ml.neuralnet; + +/** + * Defines how to assign the first value of a neuron's feature. + * + * @since 3.3 + */ +public interface FeatureInitializer { + /** + * Selects the initial value. + * + * @return the initial value. + */ + double value(); +} diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/FeatureInitializerFactory.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/FeatureInitializerFactory.java new file mode 100644 index 0000000..f5569b1 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/FeatureInitializerFactory.java @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.commons.math3.ml.neuralnet; + +import org.apache.commons.math3.distribution.RealDistribution; +import org.apache.commons.math3.distribution.UniformRealDistribution; +import org.apache.commons.math3.analysis.UnivariateFunction; +import org.apache.commons.math3.analysis.function.Constant; +import org.apache.commons.math3.random.RandomGenerator; + +/** + * Creates functions that will select the initial values of a neuron's + * features. + * + * @since 3.3 + */ +public class FeatureInitializerFactory { + /** Class contains only static methods. */ + private FeatureInitializerFactory() {} + + /** + * Uniform sampling of the given range. + * + * @param min Lower bound of the range. + * @param max Upper bound of the range. + * @param rng Random number generator used to draw samples from a + * uniform distribution. + * @return an initializer such that the features will be initialized with + * values within the given range. + * @throws org.apache.commons.math3.exception.NumberIsTooLargeException + * if {@code min >= max}. + */ + public static FeatureInitializer uniform(final RandomGenerator rng, + final double min, + final double max) { + return randomize(new UniformRealDistribution(rng, min, max), + function(new Constant(0), 0, 0)); + } + + /** + * Uniform sampling of the given range. + * + * @param min Lower bound of the range. + * @param max Upper bound of the range. + * @return an initializer such that the features will be initialized with + * values within the given range. + * @throws org.apache.commons.math3.exception.NumberIsTooLargeException + * if {@code min >= max}. + */ + public static FeatureInitializer uniform(final double min, + final double max) { + return randomize(new UniformRealDistribution(min, max), + function(new Constant(0), 0, 0)); + } + + /** + * Creates an initializer from a univariate function {@code f(x)}. + * The argument {@code x} is set to {@code init} at the first call + * and will be incremented at each call. + * + * @param f Function. + * @param init Initial value. + * @param inc Increment + * @return the initializer. + */ + public static FeatureInitializer function(final UnivariateFunction f, + final double init, + final double inc) { + return new FeatureInitializer() { + /** Argument. */ + private double arg = init; + + /** {@inheritDoc} */ + public double value() { + final double result = f.value(arg); + arg += inc; + return result; + } + }; + } + + /** + * Adds some amount of random data to the given initializer. + * + * @param random Random variable distribution. + * @param orig Original initializer. + * @return an initializer whose {@link FeatureInitializer#value() value} + * method will return {@code orig.value() + random.sample()}. + */ + public static FeatureInitializer randomize(final RealDistribution random, + final FeatureInitializer orig) { + return new FeatureInitializer() { + /** {@inheritDoc} */ + public double value() { + return orig.value() + random.sample(); + } + }; + } +} diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/MapUtils.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/MapUtils.java new file mode 100644 index 0000000..0b7a675 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/MapUtils.java @@ -0,0 +1,326 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.commons.math3.ml.neuralnet; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Comparator; + +import org.apache.commons.math3.exception.NoDataException; +import org.apache.commons.math3.ml.distance.DistanceMeasure; +import org.apache.commons.math3.ml.neuralnet.twod.NeuronSquareMesh2D; +import org.apache.commons.math3.util.Pair; + +/** + * Utilities for network maps. + * + * @since 3.3 + */ +public class MapUtils { + /** + * Class contains only static methods. + */ + private MapUtils() {} + + /** + * Finds the neuron that best matches the given features. + * + * @param features Data. + * @param neurons List of neurons to scan. If the list is empty + * {@code null} will be returned. + * @param distance Distance function. The neuron's features are + * passed as the first argument to {@link DistanceMeasure#compute(double[],double[])}. + * @return the neuron whose features are closest to the given data. + * @throws org.apache.commons.math3.exception.DimensionMismatchException + * if the size of the input is not compatible with the neurons features + * size. + */ + public static Neuron findBest(double[] features, + Iterable<Neuron> neurons, + DistanceMeasure distance) { + Neuron best = null; + double min = Double.POSITIVE_INFINITY; + for (final Neuron n : neurons) { + final double d = distance.compute(n.getFeatures(), features); + if (d < min) { + min = d; + best = n; + } + } + + return best; + } + + /** + * Finds the two neurons that best match the given features. + * + * @param features Data. + * @param neurons List of neurons to scan. If the list is empty + * {@code null} will be returned. + * @param distance Distance function. The neuron's features are + * passed as the first argument to {@link DistanceMeasure#compute(double[],double[])}. + * @return the two neurons whose features are closest to the given data. + * @throws org.apache.commons.math3.exception.DimensionMismatchException + * if the size of the input is not compatible with the neurons features + * size. + */ + public static Pair<Neuron, Neuron> findBestAndSecondBest(double[] features, + Iterable<Neuron> neurons, + DistanceMeasure distance) { + Neuron[] best = { null, null }; + double[] min = { Double.POSITIVE_INFINITY, + Double.POSITIVE_INFINITY }; + for (final Neuron n : neurons) { + final double d = distance.compute(n.getFeatures(), features); + if (d < min[0]) { + // Replace second best with old best. + min[1] = min[0]; + best[1] = best[0]; + + // Store current as new best. + min[0] = d; + best[0] = n; + } else if (d < min[1]) { + // Replace old second best with current. + min[1] = d; + best[1] = n; + } + } + + return new Pair<Neuron, Neuron>(best[0], best[1]); + } + + /** + * Creates a list of neurons sorted in increased order of the distance + * to the given {@code features}. + * + * @param features Data. + * @param neurons List of neurons to scan. If it is empty, an empty array + * will be returned. + * @param distance Distance function. + * @return the neurons, sorted in increasing order of distance in data + * space. + * @throws org.apache.commons.math3.exception.DimensionMismatchException + * if the size of the input is not compatible with the neurons features + * size. + * + * @see #findBest(double[],Iterable,DistanceMeasure) + * @see #findBestAndSecondBest(double[],Iterable,DistanceMeasure) + * + * @since 3.6 + */ + public static Neuron[] sort(double[] features, + Iterable<Neuron> neurons, + DistanceMeasure distance) { + final List<PairNeuronDouble> list = new ArrayList<PairNeuronDouble>(); + + for (final Neuron n : neurons) { + final double d = distance.compute(n.getFeatures(), features); + list.add(new PairNeuronDouble(n, d)); + } + + Collections.sort(list, PairNeuronDouble.COMPARATOR); + + final int len = list.size(); + final Neuron[] sorted = new Neuron[len]; + + for (int i = 0; i < len; i++) { + sorted[i] = list.get(i).getNeuron(); + } + return sorted; + } + + /** + * Computes the <a href="http://en.wikipedia.org/wiki/U-Matrix"> + * U-matrix</a> of a two-dimensional map. + * + * @param map Network. + * @param distance Function to use for computing the average + * distance from a neuron to its neighbours. + * @return the matrix of average distances. + */ + public static double[][] computeU(NeuronSquareMesh2D map, + DistanceMeasure distance) { + final int numRows = map.getNumberOfRows(); + final int numCols = map.getNumberOfColumns(); + final double[][] uMatrix = new double[numRows][numCols]; + + final Network net = map.getNetwork(); + + for (int i = 0; i < numRows; i++) { + for (int j = 0; j < numCols; j++) { + final Neuron neuron = map.getNeuron(i, j); + final Collection<Neuron> neighbours = net.getNeighbours(neuron); + final double[] features = neuron.getFeatures(); + + double d = 0; + int count = 0; + for (Neuron n : neighbours) { + ++count; + d += distance.compute(features, n.getFeatures()); + } + + uMatrix[i][j] = d / count; + } + } + + return uMatrix; + } + + /** + * Computes the "hit" histogram of a two-dimensional map. + * + * @param data Feature vectors. + * @param map Network. + * @param distance Function to use for determining the best matching unit. + * @return the number of hits for each neuron in the map. + */ + public static int[][] computeHitHistogram(Iterable<double[]> data, + NeuronSquareMesh2D map, + DistanceMeasure distance) { + final HashMap<Neuron, Integer> hit = new HashMap<Neuron, Integer>(); + final Network net = map.getNetwork(); + + for (double[] f : data) { + final Neuron best = findBest(f, net, distance); + final Integer count = hit.get(best); + if (count == null) { + hit.put(best, 1); + } else { + hit.put(best, count + 1); + } + } + + // Copy the histogram data into a 2D map. + final int numRows = map.getNumberOfRows(); + final int numCols = map.getNumberOfColumns(); + final int[][] histo = new int[numRows][numCols]; + + for (int i = 0; i < numRows; i++) { + for (int j = 0; j < numCols; j++) { + final Neuron neuron = map.getNeuron(i, j); + final Integer count = hit.get(neuron); + if (count == null) { + histo[i][j] = 0; + } else { + histo[i][j] = count; + } + } + } + + return histo; + } + + /** + * Computes the quantization error. + * The quantization error is the average distance between a feature vector + * and its "best matching unit" (closest neuron). + * + * @param data Feature vectors. + * @param neurons List of neurons to scan. + * @param distance Distance function. + * @return the error. + * @throws NoDataException if {@code data} is empty. + */ + public static double computeQuantizationError(Iterable<double[]> data, + Iterable<Neuron> neurons, + DistanceMeasure distance) { + double d = 0; + int count = 0; + for (double[] f : data) { + ++count; + d += distance.compute(f, findBest(f, neurons, distance).getFeatures()); + } + + if (count == 0) { + throw new NoDataException(); + } + + return d / count; + } + + /** + * Computes the topographic error. + * The topographic error is the proportion of data for which first and + * second best matching units are not adjacent in the map. + * + * @param data Feature vectors. + * @param net Network. + * @param distance Distance function. + * @return the error. + * @throws NoDataException if {@code data} is empty. + */ + public static double computeTopographicError(Iterable<double[]> data, + Network net, + DistanceMeasure distance) { + int notAdjacentCount = 0; + int count = 0; + for (double[] f : data) { + ++count; + final Pair<Neuron, Neuron> p = findBestAndSecondBest(f, net, distance); + if (!net.getNeighbours(p.getFirst()).contains(p.getSecond())) { + // Increment count if first and second best matching units + // are not neighbours. + ++notAdjacentCount; + } + } + + if (count == 0) { + throw new NoDataException(); + } + + return ((double) notAdjacentCount) / count; + } + + /** + * Helper data structure holding a (Neuron, double) pair. + */ + private static class PairNeuronDouble { + /** Comparator. */ + static final Comparator<PairNeuronDouble> COMPARATOR + = new Comparator<PairNeuronDouble>() { + /** {@inheritDoc} */ + public int compare(PairNeuronDouble o1, + PairNeuronDouble o2) { + return Double.compare(o1.value, o2.value); + } + }; + /** Key. */ + private final Neuron neuron; + /** Value. */ + private final double value; + + /** + * @param neuron Neuron. + * @param value Value. + */ + PairNeuronDouble(Neuron neuron, double value) { + this.neuron = neuron; + this.value = value; + } + + /** @return the neuron. */ + public Neuron getNeuron() { + return neuron; + } + + } +} diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/Network.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/Network.java new file mode 100644 index 0000000..4b208a3 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/Network.java @@ -0,0 +1,499 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.commons.math3.ml.neuralnet; + +import java.io.Serializable; +import java.io.ObjectInputStream; +import java.util.NoSuchElementException; +import java.util.List; +import java.util.ArrayList; +import java.util.Set; +import java.util.HashSet; +import java.util.Collection; +import java.util.Iterator; +import java.util.Comparator; +import java.util.Collections; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicLong; +import org.apache.commons.math3.exception.DimensionMismatchException; +import org.apache.commons.math3.exception.MathIllegalStateException; + +/** + * Neural network, composed of {@link Neuron} instances and the links + * between them. + * + * Although updating a neuron's state is thread-safe, modifying the + * network's topology (adding or removing links) is not. + * + * @since 3.3 + */ +public class Network + implements Iterable<Neuron>, + Serializable { + /** Serializable. */ + private static final long serialVersionUID = 20130207L; + /** Neurons. */ + private final ConcurrentHashMap<Long, Neuron> neuronMap + = new ConcurrentHashMap<Long, Neuron>(); + /** Next available neuron identifier. */ + private final AtomicLong nextId; + /** Neuron's features set size. */ + private final int featureSize; + /** Links. */ + private final ConcurrentHashMap<Long, Set<Long>> linkMap + = new ConcurrentHashMap<Long, Set<Long>>(); + + /** + * Comparator that prescribes an order of the neurons according + * to the increasing order of their identifier. + */ + public static class NeuronIdentifierComparator + implements Comparator<Neuron>, + Serializable { + /** Version identifier. */ + private static final long serialVersionUID = 20130207L; + + /** {@inheritDoc} */ + public int compare(Neuron a, + Neuron b) { + final long aId = a.getIdentifier(); + final long bId = b.getIdentifier(); + return aId < bId ? -1 : + aId > bId ? 1 : 0; + } + } + + /** + * Constructor with restricted access, solely used for deserialization. + * + * @param nextId Next available identifier. + * @param featureSize Number of features. + * @param neuronList Neurons. + * @param neighbourIdList Links associated to each of the neurons in + * {@code neuronList}. + * @throws MathIllegalStateException if an inconsistency is detected + * (which probably means that the serialized form has been corrupted). + */ + Network(long nextId, + int featureSize, + Neuron[] neuronList, + long[][] neighbourIdList) { + final int numNeurons = neuronList.length; + if (numNeurons != neighbourIdList.length) { + throw new MathIllegalStateException(); + } + + for (int i = 0; i < numNeurons; i++) { + final Neuron n = neuronList[i]; + final long id = n.getIdentifier(); + if (id >= nextId) { + throw new MathIllegalStateException(); + } + neuronMap.put(id, n); + linkMap.put(id, new HashSet<Long>()); + } + + for (int i = 0; i < numNeurons; i++) { + final long aId = neuronList[i].getIdentifier(); + final Set<Long> aLinks = linkMap.get(aId); + for (Long bId : neighbourIdList[i]) { + if (neuronMap.get(bId) == null) { + throw new MathIllegalStateException(); + } + addLinkToLinkSet(aLinks, bId); + } + } + + this.nextId = new AtomicLong(nextId); + this.featureSize = featureSize; + } + + /** + * @param initialIdentifier Identifier for the first neuron that + * will be added to this network. + * @param featureSize Size of the neuron's features. + */ + public Network(long initialIdentifier, + int featureSize) { + nextId = new AtomicLong(initialIdentifier); + this.featureSize = featureSize; + } + + /** + * Performs a deep copy of this instance. + * Upon return, the copied and original instances will be independent: + * Updating one will not affect the other. + * + * @return a new instance with the same state as this instance. + * @since 3.6 + */ + public synchronized Network copy() { + final Network copy = new Network(nextId.get(), + featureSize); + + + for (Map.Entry<Long, Neuron> e : neuronMap.entrySet()) { + copy.neuronMap.put(e.getKey(), e.getValue().copy()); + } + + for (Map.Entry<Long, Set<Long>> e : linkMap.entrySet()) { + copy.linkMap.put(e.getKey(), new HashSet<Long>(e.getValue())); + } + + return copy; + } + + /** + * {@inheritDoc} + */ + public Iterator<Neuron> iterator() { + return neuronMap.values().iterator(); + } + + /** + * Creates a list of the neurons, sorted in a custom order. + * + * @param comparator {@link Comparator} used for sorting the neurons. + * @return a list of neurons, sorted in the order prescribed by the + * given {@code comparator}. + * @see NeuronIdentifierComparator + */ + public Collection<Neuron> getNeurons(Comparator<Neuron> comparator) { + final List<Neuron> neurons = new ArrayList<Neuron>(); + neurons.addAll(neuronMap.values()); + + Collections.sort(neurons, comparator); + + return neurons; + } + + /** + * Creates a neuron and assigns it a unique identifier. + * + * @param features Initial values for the neuron's features. + * @return the neuron's identifier. + * @throws DimensionMismatchException if the length of {@code features} + * is different from the expected size (as set by the + * {@link #Network(long,int) constructor}). + */ + public long createNeuron(double[] features) { + if (features.length != featureSize) { + throw new DimensionMismatchException(features.length, featureSize); + } + + final long id = createNextId(); + neuronMap.put(id, new Neuron(id, features)); + linkMap.put(id, new HashSet<Long>()); + return id; + } + + /** + * Deletes a neuron. + * Links from all neighbours to the removed neuron will also be + * {@link #deleteLink(Neuron,Neuron) deleted}. + * + * @param neuron Neuron to be removed from this network. + * @throws NoSuchElementException if {@code n} does not belong to + * this network. + */ + public void deleteNeuron(Neuron neuron) { + final Collection<Neuron> neighbours = getNeighbours(neuron); + + // Delete links to from neighbours. + for (Neuron n : neighbours) { + deleteLink(n, neuron); + } + + // Remove neuron. + neuronMap.remove(neuron.getIdentifier()); + } + + /** + * Gets the size of the neurons' features set. + * + * @return the size of the features set. + */ + public int getFeaturesSize() { + return featureSize; + } + + /** + * Adds a link from neuron {@code a} to neuron {@code b}. + * Note: the link is not bi-directional; if a bi-directional link is + * required, an additional call must be made with {@code a} and + * {@code b} exchanged in the argument list. + * + * @param a Neuron. + * @param b Neuron. + * @throws NoSuchElementException if the neurons do not exist in the + * network. + */ + public void addLink(Neuron a, + Neuron b) { + final long aId = a.getIdentifier(); + final long bId = b.getIdentifier(); + + // Check that the neurons belong to this network. + if (a != getNeuron(aId)) { + throw new NoSuchElementException(Long.toString(aId)); + } + if (b != getNeuron(bId)) { + throw new NoSuchElementException(Long.toString(bId)); + } + + // Add link from "a" to "b". + addLinkToLinkSet(linkMap.get(aId), bId); + } + + /** + * Adds a link to neuron {@code id} in given {@code linkSet}. + * Note: no check verifies that the identifier indeed belongs + * to this network. + * + * @param linkSet Neuron identifier. + * @param id Neuron identifier. + */ + private void addLinkToLinkSet(Set<Long> linkSet, + long id) { + linkSet.add(id); + } + + /** + * Deletes the link between neurons {@code a} and {@code b}. + * + * @param a Neuron. + * @param b Neuron. + * @throws NoSuchElementException if the neurons do not exist in the + * network. + */ + public void deleteLink(Neuron a, + Neuron b) { + final long aId = a.getIdentifier(); + final long bId = b.getIdentifier(); + + // Check that the neurons belong to this network. + if (a != getNeuron(aId)) { + throw new NoSuchElementException(Long.toString(aId)); + } + if (b != getNeuron(bId)) { + throw new NoSuchElementException(Long.toString(bId)); + } + + // Delete link from "a" to "b". + deleteLinkFromLinkSet(linkMap.get(aId), bId); + } + + /** + * Deletes a link to neuron {@code id} in given {@code linkSet}. + * Note: no check verifies that the identifier indeed belongs + * to this network. + * + * @param linkSet Neuron identifier. + * @param id Neuron identifier. + */ + private void deleteLinkFromLinkSet(Set<Long> linkSet, + long id) { + linkSet.remove(id); + } + + /** + * Retrieves the neuron with the given (unique) {@code id}. + * + * @param id Identifier. + * @return the neuron associated with the given {@code id}. + * @throws NoSuchElementException if the neuron does not exist in the + * network. + */ + public Neuron getNeuron(long id) { + final Neuron n = neuronMap.get(id); + if (n == null) { + throw new NoSuchElementException(Long.toString(id)); + } + return n; + } + + /** + * Retrieves the neurons in the neighbourhood of any neuron in the + * {@code neurons} list. + * @param neurons Neurons for which to retrieve the neighbours. + * @return the list of neighbours. + * @see #getNeighbours(Iterable,Iterable) + */ + public Collection<Neuron> getNeighbours(Iterable<Neuron> neurons) { + return getNeighbours(neurons, null); + } + + /** + * Retrieves the neurons in the neighbourhood of any neuron in the + * {@code neurons} list. + * The {@code exclude} list allows to retrieve the "concentric" + * neighbourhoods by removing the neurons that belong to the inner + * "circles". + * + * @param neurons Neurons for which to retrieve the neighbours. + * @param exclude Neurons to exclude from the returned list. + * Can be {@code null}. + * @return the list of neighbours. + */ + public Collection<Neuron> getNeighbours(Iterable<Neuron> neurons, + Iterable<Neuron> exclude) { + final Set<Long> idList = new HashSet<Long>(); + + for (Neuron n : neurons) { + idList.addAll(linkMap.get(n.getIdentifier())); + } + if (exclude != null) { + for (Neuron n : exclude) { + idList.remove(n.getIdentifier()); + } + } + + final List<Neuron> neuronList = new ArrayList<Neuron>(); + for (Long id : idList) { + neuronList.add(getNeuron(id)); + } + + return neuronList; + } + + /** + * Retrieves the neighbours of the given neuron. + * + * @param neuron Neuron for which to retrieve the neighbours. + * @return the list of neighbours. + * @see #getNeighbours(Neuron,Iterable) + */ + public Collection<Neuron> getNeighbours(Neuron neuron) { + return getNeighbours(neuron, null); + } + + /** + * Retrieves the neighbours of the given neuron. + * + * @param neuron Neuron for which to retrieve the neighbours. + * @param exclude Neurons to exclude from the returned list. + * Can be {@code null}. + * @return the list of neighbours. + */ + public Collection<Neuron> getNeighbours(Neuron neuron, + Iterable<Neuron> exclude) { + final Set<Long> idList = linkMap.get(neuron.getIdentifier()); + if (exclude != null) { + for (Neuron n : exclude) { + idList.remove(n.getIdentifier()); + } + } + + final List<Neuron> neuronList = new ArrayList<Neuron>(); + for (Long id : idList) { + neuronList.add(getNeuron(id)); + } + + return neuronList; + } + + /** + * Creates a neuron identifier. + * + * @return a value that will serve as a unique identifier. + */ + private Long createNextId() { + return nextId.getAndIncrement(); + } + + /** + * Prevents proxy bypass. + * + * @param in Input stream. + */ + private void readObject(ObjectInputStream in) { + throw new IllegalStateException(); + } + + /** + * Custom serialization. + * + * @return the proxy instance that will be actually serialized. + */ + private Object writeReplace() { + final Neuron[] neuronList = neuronMap.values().toArray(new Neuron[0]); + final long[][] neighbourIdList = new long[neuronList.length][]; + + for (int i = 0; i < neuronList.length; i++) { + final Collection<Neuron> neighbours = getNeighbours(neuronList[i]); + final long[] neighboursId = new long[neighbours.size()]; + int count = 0; + for (Neuron n : neighbours) { + neighboursId[count] = n.getIdentifier(); + ++count; + } + neighbourIdList[i] = neighboursId; + } + + return new SerializationProxy(nextId.get(), + featureSize, + neuronList, + neighbourIdList); + } + + /** + * Serialization. + */ + private static class SerializationProxy implements Serializable { + /** Serializable. */ + private static final long serialVersionUID = 20130207L; + /** Next identifier. */ + private final long nextId; + /** Number of features. */ + private final int featureSize; + /** Neurons. */ + private final Neuron[] neuronList; + /** Links. */ + private final long[][] neighbourIdList; + + /** + * @param nextId Next available identifier. + * @param featureSize Number of features. + * @param neuronList Neurons. + * @param neighbourIdList Links associated to each of the neurons in + * {@code neuronList}. + */ + SerializationProxy(long nextId, + int featureSize, + Neuron[] neuronList, + long[][] neighbourIdList) { + this.nextId = nextId; + this.featureSize = featureSize; + this.neuronList = neuronList; + this.neighbourIdList = neighbourIdList; + } + + /** + * Custom serialization. + * + * @return the {@link Network} for which this instance is the proxy. + */ + private Object readResolve() { + return new Network(nextId, + featureSize, + neuronList, + neighbourIdList); + } + } +} diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/Neuron.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/Neuron.java new file mode 100644 index 0000000..8cae3ea --- /dev/null +++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/Neuron.java @@ -0,0 +1,272 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.commons.math3.ml.neuralnet; + +import java.io.Serializable; +import java.io.ObjectInputStream; +import java.util.concurrent.atomic.AtomicReference; +import java.util.concurrent.atomic.AtomicLong; + +import org.apache.commons.math3.exception.DimensionMismatchException; +import org.apache.commons.math3.util.Precision; + + +/** + * Describes a neuron element of a neural network. + * + * This class aims to be thread-safe. + * + * @since 3.3 + */ +public class Neuron implements Serializable { + /** Serializable. */ + private static final long serialVersionUID = 20130207L; + /** Identifier. */ + private final long identifier; + /** Length of the feature set. */ + private final int size; + /** Neuron data. */ + private final AtomicReference<double[]> features; + /** Number of attempts to update a neuron. */ + private final AtomicLong numberOfAttemptedUpdates = new AtomicLong(0); + /** Number of successful updates of a neuron. */ + private final AtomicLong numberOfSuccessfulUpdates = new AtomicLong(0); + + /** + * Creates a neuron. + * The size of the feature set is fixed to the length of the given + * argument. + * <br/> + * Constructor is package-private: Neurons must be + * {@link Network#createNeuron(double[]) created} by the network + * instance to which they will belong. + * + * @param identifier Identifier (assigned by the {@link Network}). + * @param features Initial values of the feature set. + */ + Neuron(long identifier, + double[] features) { + this.identifier = identifier; + this.size = features.length; + this.features = new AtomicReference<double[]>(features.clone()); + } + + /** + * Performs a deep copy of this instance. + * Upon return, the copied and original instances will be independent: + * Updating one will not affect the other. + * + * @return a new instance with the same state as this instance. + * @since 3.6 + */ + public synchronized Neuron copy() { + final Neuron copy = new Neuron(getIdentifier(), + getFeatures()); + copy.numberOfAttemptedUpdates.set(numberOfAttemptedUpdates.get()); + copy.numberOfSuccessfulUpdates.set(numberOfSuccessfulUpdates.get()); + + return copy; + } + + /** + * Gets the neuron's identifier. + * + * @return the identifier. + */ + public long getIdentifier() { + return identifier; + } + + /** + * Gets the length of the feature set. + * + * @return the number of features. + */ + public int getSize() { + return size; + } + + /** + * Gets the neuron's features. + * + * @return a copy of the neuron's features. + */ + public double[] getFeatures() { + return features.get().clone(); + } + + /** + * Tries to atomically update the neuron's features. + * Update will be performed only if the expected values match the + * current values.<br/> + * In effect, when concurrent threads call this method, the state + * could be modified by one, so that it does not correspond to the + * the state assumed by another. + * Typically, a caller {@link #getFeatures() retrieves the current state}, + * and uses it to compute the new state. + * During this computation, another thread might have done the same + * thing, and updated the state: If the current thread were to proceed + * with its own update, it would overwrite the new state (which might + * already have been used by yet other threads). + * To prevent this, the method does not perform the update when a + * concurrent modification has been detected, and returns {@code false}. + * When this happens, the caller should fetch the new current state, + * redo its computation, and call this method again. + * + * @param expect Current values of the features, as assumed by the caller. + * Update will never succeed if the contents of this array does not match + * the values returned by {@link #getFeatures()}. + * @param update Features's new values. + * @return {@code true} if the update was successful, {@code false} + * otherwise. + * @throws DimensionMismatchException if the length of {@code update} is + * not the same as specified in the {@link #Neuron(long,double[]) + * constructor}. + */ + public boolean compareAndSetFeatures(double[] expect, + double[] update) { + if (update.length != size) { + throw new DimensionMismatchException(update.length, size); + } + + // Get the internal reference. Note that this must not be a copy; + // otherwise the "compareAndSet" below will always fail. + final double[] current = features.get(); + if (!containSameValues(current, expect)) { + // Some other thread already modified the state. + return false; + } + + // Increment attempt counter. + numberOfAttemptedUpdates.incrementAndGet(); + + if (features.compareAndSet(current, update.clone())) { + // The current thread could atomically update the state (attempt succeeded). + numberOfSuccessfulUpdates.incrementAndGet(); + return true; + } else { + // Some other thread came first (attempt failed). + return false; + } + } + + /** + * Retrieves the number of calls to the + * {@link #compareAndSetFeatures(double[],double[]) compareAndSetFeatures} + * method. + * Note that if the caller wants to use this method in combination with + * {@link #getNumberOfSuccessfulUpdates()}, additional synchronization + * may be required to ensure consistency. + * + * @return the number of update attempts. + * @since 3.6 + */ + public long getNumberOfAttemptedUpdates() { + return numberOfAttemptedUpdates.get(); + } + + /** + * Retrieves the number of successful calls to the + * {@link #compareAndSetFeatures(double[],double[]) compareAndSetFeatures} + * method. + * Note that if the caller wants to use this method in combination with + * {@link #getNumberOfAttemptedUpdates()}, additional synchronization + * may be required to ensure consistency. + * + * @return the number of successful updates. + * @since 3.6 + */ + public long getNumberOfSuccessfulUpdates() { + return numberOfSuccessfulUpdates.get(); + } + + /** + * Checks whether the contents of both arrays is the same. + * + * @param current Current values. + * @param expect Expected values. + * @throws DimensionMismatchException if the length of {@code expected} + * is not the same as specified in the {@link #Neuron(long,double[]) + * constructor}. + * @return {@code true} if the arrays contain the same values. + */ + private boolean containSameValues(double[] current, + double[] expect) { + if (expect.length != size) { + throw new DimensionMismatchException(expect.length, size); + } + + for (int i = 0; i < size; i++) { + if (!Precision.equals(current[i], expect[i])) { + return false; + } + } + return true; + } + + /** + * Prevents proxy bypass. + * + * @param in Input stream. + */ + private void readObject(ObjectInputStream in) { + throw new IllegalStateException(); + } + + /** + * Custom serialization. + * + * @return the proxy instance that will be actually serialized. + */ + private Object writeReplace() { + return new SerializationProxy(identifier, + features.get()); + } + + /** + * Serialization. + */ + private static class SerializationProxy implements Serializable { + /** Serializable. */ + private static final long serialVersionUID = 20130207L; + /** Features. */ + private final double[] features; + /** Identifier. */ + private final long identifier; + + /** + * @param identifier Identifier. + * @param features Features. + */ + SerializationProxy(long identifier, + double[] features) { + this.identifier = identifier; + this.features = features; + } + + /** + * Custom serialization. + * + * @return the {@link Neuron} for which this instance is the proxy. + */ + private Object readResolve() { + return new Neuron(identifier, + features); + } + } +} diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/SquareNeighbourhood.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/SquareNeighbourhood.java new file mode 100644 index 0000000..a3c0d95 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/SquareNeighbourhood.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.commons.math3.ml.neuralnet; + +/** + * Defines neighbourhood types. + * + * @since 3.3 + */ +public enum SquareNeighbourhood { + /** + * <a href="http://en.wikipedia.org/wiki/Von_Neumann_neighborhood" + * Von Neumann neighbourhood</a>: in two dimensions, each (internal) + * neuron has four neighbours. + */ + VON_NEUMANN, + /** + * <a href="http://en.wikipedia.org/wiki/Moore_neighborhood" + * Moore neighbourhood</a>: in two dimensions, each (internal) + * neuron has eight neighbours. + */ + MOORE, +} diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/UpdateAction.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/UpdateAction.java new file mode 100644 index 0000000..041d3d6 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/UpdateAction.java @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.commons.math3.ml.neuralnet; + +/** + * Describes how to update the network in response to a training + * sample. + * + * @since 3.3 + */ +public interface UpdateAction { + /** + * Updates the network in response to the sample {@code features}. + * + * @param net Network. + * @param features Training data. + */ + void update(Network net, double[] features); +} diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/oned/NeuronString.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/oned/NeuronString.java new file mode 100644 index 0000000..fad6042 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/oned/NeuronString.java @@ -0,0 +1,238 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.commons.math3.ml.neuralnet.oned; + +import java.io.Serializable; +import java.io.ObjectInputStream; +import org.apache.commons.math3.ml.neuralnet.Network; +import org.apache.commons.math3.ml.neuralnet.FeatureInitializer; +import org.apache.commons.math3.exception.NumberIsTooSmallException; +import org.apache.commons.math3.exception.OutOfRangeException; + +/** + * Neural network with the topology of a one-dimensional line. + * Each neuron defines one point on the line. + * + * @since 3.3 + */ +public class NeuronString implements Serializable { + /** Serial version ID */ + private static final long serialVersionUID = 1L; + /** Underlying network. */ + private final Network network; + /** Number of neurons. */ + private final int size; + /** Wrap. */ + private final boolean wrap; + + /** + * Mapping of the 1D coordinate to the neuron identifiers + * (attributed by the {@link #network} instance). + */ + private final long[] identifiers; + + /** + * Constructor with restricted access, solely used for deserialization. + * + * @param wrap Whether to wrap the dimension (i.e the first and last + * neurons will be linked together). + * @param featuresList Arrays that will initialize the features sets of + * the network's neurons. + * @throws NumberIsTooSmallException if {@code num < 2}. + */ + NeuronString(boolean wrap, + double[][] featuresList) { + size = featuresList.length; + + if (size < 2) { + throw new NumberIsTooSmallException(size, 2, true); + } + + this.wrap = wrap; + + final int fLen = featuresList[0].length; + network = new Network(0, fLen); + identifiers = new long[size]; + + // Add neurons. + for (int i = 0; i < size; i++) { + identifiers[i] = network.createNeuron(featuresList[i]); + } + + // Add links. + createLinks(); + } + + /** + * Creates a one-dimensional network: + * Each neuron not located on the border of the mesh has two + * neurons linked to it. + * <br/> + * The links are bi-directional. + * Neurons created successively are neighbours (i.e. there are + * links between them). + * <br/> + * The topology of the network can also be a circle (if the + * dimension is wrapped). + * + * @param num Number of neurons. + * @param wrap Whether to wrap the dimension (i.e the first and last + * neurons will be linked together). + * @param featureInit Arrays that will initialize the features sets of + * the network's neurons. + * @throws NumberIsTooSmallException if {@code num < 2}. + */ + public NeuronString(int num, + boolean wrap, + FeatureInitializer[] featureInit) { + if (num < 2) { + throw new NumberIsTooSmallException(num, 2, true); + } + + size = num; + this.wrap = wrap; + identifiers = new long[num]; + + final int fLen = featureInit.length; + network = new Network(0, fLen); + + // Add neurons. + for (int i = 0; i < num; i++) { + final double[] features = new double[fLen]; + for (int fIndex = 0; fIndex < fLen; fIndex++) { + features[fIndex] = featureInit[fIndex].value(); + } + identifiers[i] = network.createNeuron(features); + } + + // Add links. + createLinks(); + } + + /** + * Retrieves the underlying network. + * A reference is returned (enabling, for example, the network to be + * trained). + * This also implies that calling methods that modify the {@link Network} + * topology may cause this class to become inconsistent. + * + * @return the network. + */ + public Network getNetwork() { + return network; + } + + /** + * Gets the number of neurons. + * + * @return the number of neurons. + */ + public int getSize() { + return size; + } + + /** + * Retrieves the features set from the neuron at location + * {@code i} in the map. + * + * @param i Neuron index. + * @return the features of the neuron at index {@code i}. + * @throws OutOfRangeException if {@code i} is out of range. + */ + public double[] getFeatures(int i) { + if (i < 0 || + i >= size) { + throw new OutOfRangeException(i, 0, size - 1); + } + + return network.getNeuron(identifiers[i]).getFeatures(); + } + + /** + * Creates the neighbour relationships between neurons. + */ + private void createLinks() { + for (int i = 0; i < size - 1; i++) { + network.addLink(network.getNeuron(i), network.getNeuron(i + 1)); + } + for (int i = size - 1; i > 0; i--) { + network.addLink(network.getNeuron(i), network.getNeuron(i - 1)); + } + if (wrap) { + network.addLink(network.getNeuron(0), network.getNeuron(size - 1)); + network.addLink(network.getNeuron(size - 1), network.getNeuron(0)); + } + } + + /** + * Prevents proxy bypass. + * + * @param in Input stream. + */ + private void readObject(ObjectInputStream in) { + throw new IllegalStateException(); + } + + /** + * Custom serialization. + * + * @return the proxy instance that will be actually serialized. + */ + private Object writeReplace() { + final double[][] featuresList = new double[size][]; + for (int i = 0; i < size; i++) { + featuresList[i] = getFeatures(i); + } + + return new SerializationProxy(wrap, + featuresList); + } + + /** + * Serialization. + */ + private static class SerializationProxy implements Serializable { + /** Serializable. */ + private static final long serialVersionUID = 20130226L; + /** Wrap. */ + private final boolean wrap; + /** Neurons' features. */ + private final double[][] featuresList; + + /** + * @param wrap Whether the dimension is wrapped. + * @param featuresList List of neurons features. + * {@code neuronList}. + */ + SerializationProxy(boolean wrap, + double[][] featuresList) { + this.wrap = wrap; + this.featuresList = featuresList; + } + + /** + * Custom serialization. + * + * @return the {@link Neuron} for which this instance is the proxy. + */ + private Object readResolve() { + return new NeuronString(wrap, + featuresList); + } + } +} diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/oned/package-info.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/oned/package-info.java new file mode 100644 index 0000000..0b47fae --- /dev/null +++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/oned/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * One-dimensional neural networks. + */ + +package org.apache.commons.math3.ml.neuralnet.oned; diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/package-info.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/package-info.java new file mode 100644 index 0000000..d8e907e --- /dev/null +++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Neural networks. + */ + +package org.apache.commons.math3.ml.neuralnet; diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/KohonenTrainingTask.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/KohonenTrainingTask.java new file mode 100644 index 0000000..9aa497d --- /dev/null +++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/KohonenTrainingTask.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.commons.math3.ml.neuralnet.sofm; + +import java.util.Iterator; +import org.apache.commons.math3.ml.neuralnet.Network; + +/** + * Trainer for Kohonen's Self-Organizing Map. + * + * @since 3.3 + */ +public class KohonenTrainingTask implements Runnable { + /** SOFM to be trained. */ + private final Network net; + /** Training data. */ + private final Iterator<double[]> featuresIterator; + /** Update procedure. */ + private final KohonenUpdateAction updateAction; + + /** + * Creates a (sequential) trainer for the given network. + * + * @param net Network to be trained with the SOFM algorithm. + * @param featuresIterator Training data iterator. + * @param updateAction SOFM update procedure. + */ + public KohonenTrainingTask(Network net, + Iterator<double[]> featuresIterator, + KohonenUpdateAction updateAction) { + this.net = net; + this.featuresIterator = featuresIterator; + this.updateAction = updateAction; + } + + /** + * {@inheritDoc} + */ + public void run() { + while (featuresIterator.hasNext()) { + updateAction.update(net, featuresIterator.next()); + } + } +} diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/KohonenUpdateAction.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/KohonenUpdateAction.java new file mode 100644 index 0000000..0618aeb --- /dev/null +++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/KohonenUpdateAction.java @@ -0,0 +1,225 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.commons.math3.ml.neuralnet.sofm; + +import java.util.Collection; +import java.util.HashSet; +import java.util.concurrent.atomic.AtomicLong; + +import org.apache.commons.math3.analysis.function.Gaussian; +import org.apache.commons.math3.linear.ArrayRealVector; +import org.apache.commons.math3.ml.distance.DistanceMeasure; +import org.apache.commons.math3.ml.neuralnet.MapUtils; +import org.apache.commons.math3.ml.neuralnet.Network; +import org.apache.commons.math3.ml.neuralnet.Neuron; +import org.apache.commons.math3.ml.neuralnet.UpdateAction; + +/** + * Update formula for <a href="http://en.wikipedia.org/wiki/Kohonen"> + * Kohonen's Self-Organizing Map</a>. + * <br/> + * The {@link #update(Network,double[]) update} method modifies the + * features {@code w} of the "winning" neuron and its neighbours + * according to the following rule: + * <code> + * w<sub>new</sub> = w<sub>old</sub> + α e<sup>(-d / σ)</sup> * (sample - w<sub>old</sub>) + * </code> + * where + * <ul> + * <li>α is the current <em>learning rate</em>, </li> + * <li>σ is the current <em>neighbourhood size</em>, and</li> + * <li>{@code d} is the number of links to traverse in order to reach + * the neuron from the winning neuron.</li> + * </ul> + * <br/> + * This class is thread-safe as long as the arguments passed to the + * {@link #KohonenUpdateAction(DistanceMeasure,LearningFactorFunction, + * NeighbourhoodSizeFunction) constructor} are instances of thread-safe + * classes. + * <br/> + * Each call to the {@link #update(Network,double[]) update} method + * will increment the internal counter used to compute the current + * values for + * <ul> + * <li>the <em>learning rate</em>, and</li> + * <li>the <em>neighbourhood size</em>.</li> + * </ul> + * Consequently, the function instances that compute those values (passed + * to the constructor of this class) must take into account whether this + * class's instance will be shared by multiple threads, as this will impact + * the training process. + * + * @since 3.3 + */ +public class KohonenUpdateAction implements UpdateAction { + /** Distance function. */ + private final DistanceMeasure distance; + /** Learning factor update function. */ + private final LearningFactorFunction learningFactor; + /** Neighbourhood size update function. */ + private final NeighbourhoodSizeFunction neighbourhoodSize; + /** Number of calls to {@link #update(Network,double[])}. */ + private final AtomicLong numberOfCalls = new AtomicLong(0); + + /** + * @param distance Distance function. + * @param learningFactor Learning factor update function. + * @param neighbourhoodSize Neighbourhood size update function. + */ + public KohonenUpdateAction(DistanceMeasure distance, + LearningFactorFunction learningFactor, + NeighbourhoodSizeFunction neighbourhoodSize) { + this.distance = distance; + this.learningFactor = learningFactor; + this.neighbourhoodSize = neighbourhoodSize; + } + + /** + * {@inheritDoc} + */ + public void update(Network net, + double[] features) { + final long numCalls = numberOfCalls.incrementAndGet() - 1; + final double currentLearning = learningFactor.value(numCalls); + final Neuron best = findAndUpdateBestNeuron(net, + features, + currentLearning); + + final int currentNeighbourhood = neighbourhoodSize.value(numCalls); + // The farther away the neighbour is from the winning neuron, the + // smaller the learning rate will become. + final Gaussian neighbourhoodDecay + = new Gaussian(currentLearning, + 0, + currentNeighbourhood); + + if (currentNeighbourhood > 0) { + // Initial set of neurons only contains the winning neuron. + Collection<Neuron> neighbours = new HashSet<Neuron>(); + neighbours.add(best); + // Winning neuron must be excluded from the neighbours. + final HashSet<Neuron> exclude = new HashSet<Neuron>(); + exclude.add(best); + + int radius = 1; + do { + // Retrieve immediate neighbours of the current set of neurons. + neighbours = net.getNeighbours(neighbours, exclude); + + // Update all the neighbours. + for (Neuron n : neighbours) { + updateNeighbouringNeuron(n, features, neighbourhoodDecay.value(radius)); + } + + // Add the neighbours to the exclude list so that they will + // not be update more than once per training step. + exclude.addAll(neighbours); + ++radius; + } while (radius <= currentNeighbourhood); + } + } + + /** + * Retrieves the number of calls to the {@link #update(Network,double[]) update} + * method. + * + * @return the current number of calls. + */ + public long getNumberOfCalls() { + return numberOfCalls.get(); + } + + /** + * Tries to update a neuron. + * + * @param n Neuron to be updated. + * @param features Training data. + * @param learningRate Learning factor. + * @return {@code true} if the update succeeded, {@code true} if a + * concurrent update has been detected. + */ + private boolean attemptNeuronUpdate(Neuron n, + double[] features, + double learningRate) { + final double[] expect = n.getFeatures(); + final double[] update = computeFeatures(expect, + features, + learningRate); + + return n.compareAndSetFeatures(expect, update); + } + + /** + * Atomically updates the given neuron. + * + * @param n Neuron to be updated. + * @param features Training data. + * @param learningRate Learning factor. + */ + private void updateNeighbouringNeuron(Neuron n, + double[] features, + double learningRate) { + while (true) { + if (attemptNeuronUpdate(n, features, learningRate)) { + break; + } + } + } + + /** + * Searches for the neuron whose features are closest to the given + * sample, and atomically updates its features. + * + * @param net Network. + * @param features Sample data. + * @param learningRate Current learning factor. + * @return the winning neuron. + */ + private Neuron findAndUpdateBestNeuron(Network net, + double[] features, + double learningRate) { + while (true) { + final Neuron best = MapUtils.findBest(features, net, distance); + + if (attemptNeuronUpdate(best, features, learningRate)) { + return best; + } + + // If another thread modified the state of the winning neuron, + // it may not be the best match anymore for the given training + // sample: Hence, the winner search is performed again. + } + } + + /** + * Computes the new value of the features set. + * + * @param current Current values of the features. + * @param sample Training data. + * @param learningRate Learning factor. + * @return the new values for the features. + */ + private double[] computeFeatures(double[] current, + double[] sample, + double learningRate) { + final ArrayRealVector c = new ArrayRealVector(current, false); + final ArrayRealVector s = new ArrayRealVector(sample, false); + // c + learningRate * (s - c) + return s.subtract(c).mapMultiplyToSelf(learningRate).add(c).toArray(); + } +} diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/LearningFactorFunction.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/LearningFactorFunction.java new file mode 100644 index 0000000..ba9d152 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/LearningFactorFunction.java @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.commons.math3.ml.neuralnet.sofm; + +/** + * Provides the learning rate as a function of the number of calls + * already performed during the learning task. + * + * @since 3.3 + */ +public interface LearningFactorFunction { + /** + * Computes the learning rate at the current call. + * + * @param numCall Current step of the training task. + * @return the value of the function at {@code numCall}. + */ + double value(long numCall); +} diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/LearningFactorFunctionFactory.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/LearningFactorFunctionFactory.java new file mode 100644 index 0000000..9165e82 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/LearningFactorFunctionFactory.java @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.commons.math3.ml.neuralnet.sofm; + +import org.apache.commons.math3.ml.neuralnet.sofm.util.ExponentialDecayFunction; +import org.apache.commons.math3.ml.neuralnet.sofm.util.QuasiSigmoidDecayFunction; +import org.apache.commons.math3.exception.OutOfRangeException; + +/** + * Factory for creating instances of {@link LearningFactorFunction}. + * + * @since 3.3 + */ +public class LearningFactorFunctionFactory { + /** Class contains only static methods. */ + private LearningFactorFunctionFactory() {} + + /** + * Creates an exponential decay {@link LearningFactorFunction function}. + * It will compute <code>a e<sup>-x / b</sup></code>, + * where {@code x} is the (integer) independent variable and + * <ul> + * <li><code>a = initValue</code> + * <li><code>b = -numCall / ln(valueAtNumCall / initValue)</code> + * </ul> + * + * @param initValue Initial value, i.e. + * {@link LearningFactorFunction#value(long) value(0)}. + * @param valueAtNumCall Value of the function at {@code numCall}. + * @param numCall Argument for which the function returns + * {@code valueAtNumCall}. + * @return the learning factor function. + * @throws org.apache.commons.math3.exception.OutOfRangeException + * if {@code initValue <= 0} or {@code initValue > 1}. + * @throws org.apache.commons.math3.exception.NotStrictlyPositiveException + * if {@code valueAtNumCall <= 0}. + * @throws org.apache.commons.math3.exception.NumberIsTooLargeException + * if {@code valueAtNumCall >= initValue}. + * @throws org.apache.commons.math3.exception.NotStrictlyPositiveException + * if {@code numCall <= 0}. + */ + public static LearningFactorFunction exponentialDecay(final double initValue, + final double valueAtNumCall, + final long numCall) { + if (initValue <= 0 || + initValue > 1) { + throw new OutOfRangeException(initValue, 0, 1); + } + + return new LearningFactorFunction() { + /** DecayFunction. */ + private final ExponentialDecayFunction decay + = new ExponentialDecayFunction(initValue, valueAtNumCall, numCall); + + /** {@inheritDoc} */ + public double value(long n) { + return decay.value(n); + } + }; + } + + /** + * Creates an sigmoid-like {@code LearningFactorFunction function}. + * The function {@code f} will have the following properties: + * <ul> + * <li>{@code f(0) = initValue}</li> + * <li>{@code numCall} is the inflexion point</li> + * <li>{@code slope = f'(numCall)}</li> + * </ul> + * + * @param initValue Initial value, i.e. + * {@link LearningFactorFunction#value(long) value(0)}. + * @param slope Value of the function derivative at {@code numCall}. + * @param numCall Inflexion point. + * @return the learning factor function. + * @throws org.apache.commons.math3.exception.OutOfRangeException + * if {@code initValue <= 0} or {@code initValue > 1}. + * @throws org.apache.commons.math3.exception.NumberIsTooLargeException + * if {@code slope >= 0}. + * @throws org.apache.commons.math3.exception.NotStrictlyPositiveException + * if {@code numCall <= 0}. + */ + public static LearningFactorFunction quasiSigmoidDecay(final double initValue, + final double slope, + final long numCall) { + if (initValue <= 0 || + initValue > 1) { + throw new OutOfRangeException(initValue, 0, 1); + } + + return new LearningFactorFunction() { + /** DecayFunction. */ + private final QuasiSigmoidDecayFunction decay + = new QuasiSigmoidDecayFunction(initValue, slope, numCall); + + /** {@inheritDoc} */ + public double value(long n) { + return decay.value(n); + } + }; + } +} diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/NeighbourhoodSizeFunction.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/NeighbourhoodSizeFunction.java new file mode 100644 index 0000000..68149f7 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/NeighbourhoodSizeFunction.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.commons.math3.ml.neuralnet.sofm; + +/** + * Provides the network neighbourhood's size as a function of the + * number of calls already performed during the learning task. + * The "neighbourhood" is the set of neurons that can be reached + * by traversing at most the number of links returned by this + * function. + * + * @since 3.3 + */ +public interface NeighbourhoodSizeFunction { + /** + * Computes the neighbourhood size at the current call. + * + * @param numCall Current step of the training task. + * @return the value of the function at {@code numCall}. + */ + int value(long numCall); +} diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/NeighbourhoodSizeFunctionFactory.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/NeighbourhoodSizeFunctionFactory.java new file mode 100644 index 0000000..bdbfa2f --- /dev/null +++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/NeighbourhoodSizeFunctionFactory.java @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.commons.math3.ml.neuralnet.sofm; + +import org.apache.commons.math3.ml.neuralnet.sofm.util.ExponentialDecayFunction; +import org.apache.commons.math3.ml.neuralnet.sofm.util.QuasiSigmoidDecayFunction; +import org.apache.commons.math3.util.FastMath; + +/** + * Factory for creating instances of {@link NeighbourhoodSizeFunction}. + * + * @since 3.3 + */ +public class NeighbourhoodSizeFunctionFactory { + /** Class contains only static methods. */ + private NeighbourhoodSizeFunctionFactory() {} + + /** + * Creates an exponential decay {@link NeighbourhoodSizeFunction function}. + * It will compute <code>a e<sup>-x / b</sup></code>, + * where {@code x} is the (integer) independent variable and + * <ul> + * <li><code>a = initValue</code> + * <li><code>b = -numCall / ln(valueAtNumCall / initValue)</code> + * </ul> + * + * @param initValue Initial value, i.e. + * {@link NeighbourhoodSizeFunction#value(long) value(0)}. + * @param valueAtNumCall Value of the function at {@code numCall}. + * @param numCall Argument for which the function returns + * {@code valueAtNumCall}. + * @return the neighbourhood size function. + * @throws org.apache.commons.math3.exception.NotStrictlyPositiveException + * if {@code initValue <= 0}. + * @throws org.apache.commons.math3.exception.NotStrictlyPositiveException + * if {@code valueAtNumCall <= 0}. + * @throws org.apache.commons.math3.exception.NumberIsTooLargeException + * if {@code valueAtNumCall >= initValue}. + * @throws org.apache.commons.math3.exception.NotStrictlyPositiveException + * if {@code numCall <= 0}. + */ + public static NeighbourhoodSizeFunction exponentialDecay(final double initValue, + final double valueAtNumCall, + final long numCall) { + return new NeighbourhoodSizeFunction() { + /** DecayFunction. */ + private final ExponentialDecayFunction decay + = new ExponentialDecayFunction(initValue, valueAtNumCall, numCall); + + /** {@inheritDoc} */ + public int value(long n) { + return (int) FastMath.rint(decay.value(n)); + } + }; + } + + /** + * Creates an sigmoid-like {@code NeighbourhoodSizeFunction function}. + * The function {@code f} will have the following properties: + * <ul> + * <li>{@code f(0) = initValue}</li> + * <li>{@code numCall} is the inflexion point</li> + * <li>{@code slope = f'(numCall)}</li> + * </ul> + * + * @param initValue Initial value, i.e. + * {@link NeighbourhoodSizeFunction#value(long) value(0)}. + * @param slope Value of the function derivative at {@code numCall}. + * @param numCall Inflexion point. + * @return the neighbourhood size function. + * @throws org.apache.commons.math3.exception.NotStrictlyPositiveException + * if {@code initValue <= 0}. + * @throws org.apache.commons.math3.exception.NumberIsTooLargeException + * if {@code slope >= 0}. + * @throws org.apache.commons.math3.exception.NotStrictlyPositiveException + * if {@code numCall <= 0}. + */ + public static NeighbourhoodSizeFunction quasiSigmoidDecay(final double initValue, + final double slope, + final long numCall) { + return new NeighbourhoodSizeFunction() { + /** DecayFunction. */ + private final QuasiSigmoidDecayFunction decay + = new QuasiSigmoidDecayFunction(initValue, slope, numCall); + + /** {@inheritDoc} */ + public int value(long n) { + return (int) FastMath.rint(decay.value(n)); + } + }; + } +} diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/package-info.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/package-info.java new file mode 100644 index 0000000..60c3c61 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Self Organizing Feature Map. + */ + +package org.apache.commons.math3.ml.neuralnet.sofm; diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/util/ExponentialDecayFunction.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/util/ExponentialDecayFunction.java new file mode 100644 index 0000000..19e7380 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/util/ExponentialDecayFunction.java @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.commons.math3.ml.neuralnet.sofm.util; + +import org.apache.commons.math3.exception.NotStrictlyPositiveException; +import org.apache.commons.math3.exception.NumberIsTooLargeException; +import org.apache.commons.math3.util.FastMath; + +/** + * Exponential decay function: <code>a e<sup>-x / b</sup></code>, + * where {@code x} is the (integer) independent variable. + * <br/> + * Class is immutable. + * + * @since 3.3 + */ +public class ExponentialDecayFunction { + /** Factor {@code a}. */ + private final double a; + /** Factor {@code 1 / b}. */ + private final double oneOverB; + + /** + * Creates an instance. It will be such that + * <ul> + * <li>{@code a = initValue}</li> + * <li>{@code b = -numCall / ln(valueAtNumCall / initValue)}</li> + * </ul> + * + * @param initValue Initial value, i.e. {@link #value(long) value(0)}. + * @param valueAtNumCall Value of the function at {@code numCall}. + * @param numCall Argument for which the function returns + * {@code valueAtNumCall}. + * @throws NotStrictlyPositiveException if {@code initValue <= 0}. + * @throws NotStrictlyPositiveException if {@code valueAtNumCall <= 0}. + * @throws NumberIsTooLargeException if {@code valueAtNumCall >= initValue}. + * @throws NotStrictlyPositiveException if {@code numCall <= 0}. + */ + public ExponentialDecayFunction(double initValue, + double valueAtNumCall, + long numCall) { + if (initValue <= 0) { + throw new NotStrictlyPositiveException(initValue); + } + if (valueAtNumCall <= 0) { + throw new NotStrictlyPositiveException(valueAtNumCall); + } + if (valueAtNumCall >= initValue) { + throw new NumberIsTooLargeException(valueAtNumCall, initValue, false); + } + if (numCall <= 0) { + throw new NotStrictlyPositiveException(numCall); + } + + a = initValue; + oneOverB = -FastMath.log(valueAtNumCall / initValue) / numCall; + } + + /** + * Computes <code>a e<sup>-numCall / b</sup></code>. + * + * @param numCall Current step of the training task. + * @return the value of the function at {@code numCall}. + */ + public double value(long numCall) { + return a * FastMath.exp(-numCall * oneOverB); + } +} diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/util/QuasiSigmoidDecayFunction.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/util/QuasiSigmoidDecayFunction.java new file mode 100644 index 0000000..3d35c17 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/util/QuasiSigmoidDecayFunction.java @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.commons.math3.ml.neuralnet.sofm.util; + +import org.apache.commons.math3.exception.NotStrictlyPositiveException; +import org.apache.commons.math3.exception.NumberIsTooLargeException; +import org.apache.commons.math3.analysis.function.Logistic; + +/** + * Decay function whose shape is similar to a sigmoid. + * <br/> + * Class is immutable. + * + * @since 3.3 + */ +public class QuasiSigmoidDecayFunction { + /** Sigmoid. */ + private final Logistic sigmoid; + /** See {@link #value(long)}. */ + private final double scale; + + /** + * Creates an instance. + * The function {@code f} will have the following properties: + * <ul> + * <li>{@code f(0) = initValue}</li> + * <li>{@code numCall} is the inflexion point</li> + * <li>{@code slope = f'(numCall)}</li> + * </ul> + * + * @param initValue Initial value, i.e. {@link #value(long) value(0)}. + * @param slope Value of the function derivative at {@code numCall}. + * @param numCall Inflexion point. + * @throws NotStrictlyPositiveException if {@code initValue <= 0}. + * @throws NumberIsTooLargeException if {@code slope >= 0}. + * @throws NotStrictlyPositiveException if {@code numCall <= 0}. + */ + public QuasiSigmoidDecayFunction(double initValue, + double slope, + long numCall) { + if (initValue <= 0) { + throw new NotStrictlyPositiveException(initValue); + } + if (slope >= 0) { + throw new NumberIsTooLargeException(slope, 0, false); + } + if (numCall <= 1) { + throw new NotStrictlyPositiveException(numCall); + } + + final double k = initValue; + final double m = numCall; + final double b = 4 * slope / initValue; + final double q = 1; + final double a = 0; + final double n = 1; + sigmoid = new Logistic(k, m, b, q, a, n); + + final double y0 = sigmoid.value(0); + scale = k / y0; + } + + /** + * Computes the value of the learning factor. + * + * @param numCall Current step of the training task. + * @return the value of the function at {@code numCall}. + */ + public double value(long numCall) { + return scale * sigmoid.value(numCall); + } +} diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/util/package-info.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/util/package-info.java new file mode 100644 index 0000000..5078ed2 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/util/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Miscellaneous utilities. + */ + +package org.apache.commons.math3.ml.neuralnet.sofm.util; diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/twod/NeuronSquareMesh2D.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/twod/NeuronSquareMesh2D.java new file mode 100644 index 0000000..5277bc5 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/twod/NeuronSquareMesh2D.java @@ -0,0 +1,628 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.commons.math3.ml.neuralnet.twod; + +import java.util.List; +import java.util.ArrayList; +import java.util.Iterator; +import java.io.Serializable; +import java.io.ObjectInputStream; +import org.apache.commons.math3.ml.neuralnet.Neuron; +import org.apache.commons.math3.ml.neuralnet.Network; +import org.apache.commons.math3.ml.neuralnet.FeatureInitializer; +import org.apache.commons.math3.ml.neuralnet.SquareNeighbourhood; +import org.apache.commons.math3.exception.NumberIsTooSmallException; +import org.apache.commons.math3.exception.OutOfRangeException; +import org.apache.commons.math3.exception.MathInternalError; + +/** + * Neural network with the topology of a two-dimensional surface. + * Each neuron defines one surface element. + * <br/> + * This network is primarily intended to represent a + * <a href="http://en.wikipedia.org/wiki/Kohonen"> + * Self Organizing Feature Map</a>. + * + * @see org.apache.commons.math3.ml.neuralnet.sofm + * @since 3.3 + */ +public class NeuronSquareMesh2D + implements Iterable<Neuron>, + Serializable { + /** Serial version ID */ + private static final long serialVersionUID = 1L; + /** Underlying network. */ + private final Network network; + /** Number of rows. */ + private final int numberOfRows; + /** Number of columns. */ + private final int numberOfColumns; + /** Wrap. */ + private final boolean wrapRows; + /** Wrap. */ + private final boolean wrapColumns; + /** Neighbourhood type. */ + private final SquareNeighbourhood neighbourhood; + /** + * Mapping of the 2D coordinates (in the rectangular mesh) to + * the neuron identifiers (attributed by the {@link #network} + * instance). + */ + private final long[][] identifiers; + + /** + * Horizontal (along row) direction. + * @since 3.6 + */ + public enum HorizontalDirection { + /** Column at the right of the current column. */ + RIGHT, + /** Current column. */ + CENTER, + /** Column at the left of the current column. */ + LEFT, + } + /** + * Vertical (along column) direction. + * @since 3.6 + */ + public enum VerticalDirection { + /** Row above the current row. */ + UP, + /** Current row. */ + CENTER, + /** Row below the current row. */ + DOWN, + } + + /** + * Constructor with restricted access, solely used for deserialization. + * + * @param wrapRowDim Whether to wrap the first dimension (i.e the first + * and last neurons will be linked together). + * @param wrapColDim Whether to wrap the second dimension (i.e the first + * and last neurons will be linked together). + * @param neighbourhoodType Neighbourhood type. + * @param featuresList Arrays that will initialize the features sets of + * the network's neurons. + * @throws NumberIsTooSmallException if {@code numRows < 2} or + * {@code numCols < 2}. + */ + NeuronSquareMesh2D(boolean wrapRowDim, + boolean wrapColDim, + SquareNeighbourhood neighbourhoodType, + double[][][] featuresList) { + numberOfRows = featuresList.length; + numberOfColumns = featuresList[0].length; + + if (numberOfRows < 2) { + throw new NumberIsTooSmallException(numberOfRows, 2, true); + } + if (numberOfColumns < 2) { + throw new NumberIsTooSmallException(numberOfColumns, 2, true); + } + + wrapRows = wrapRowDim; + wrapColumns = wrapColDim; + neighbourhood = neighbourhoodType; + + final int fLen = featuresList[0][0].length; + network = new Network(0, fLen); + identifiers = new long[numberOfRows][numberOfColumns]; + + // Add neurons. + for (int i = 0; i < numberOfRows; i++) { + for (int j = 0; j < numberOfColumns; j++) { + identifiers[i][j] = network.createNeuron(featuresList[i][j]); + } + } + + // Add links. + createLinks(); + } + + /** + * Creates a two-dimensional network composed of square cells: + * Each neuron not located on the border of the mesh has four + * neurons linked to it. + * <br/> + * The links are bi-directional. + * <br/> + * The topology of the network can also be a cylinder (if one + * of the dimensions is wrapped) or a torus (if both dimensions + * are wrapped). + * + * @param numRows Number of neurons in the first dimension. + * @param wrapRowDim Whether to wrap the first dimension (i.e the first + * and last neurons will be linked together). + * @param numCols Number of neurons in the second dimension. + * @param wrapColDim Whether to wrap the second dimension (i.e the first + * and last neurons will be linked together). + * @param neighbourhoodType Neighbourhood type. + * @param featureInit Array of functions that will initialize the + * corresponding element of the features set of each newly created + * neuron. In particular, the size of this array defines the size of + * feature set. + * @throws NumberIsTooSmallException if {@code numRows < 2} or + * {@code numCols < 2}. + */ + public NeuronSquareMesh2D(int numRows, + boolean wrapRowDim, + int numCols, + boolean wrapColDim, + SquareNeighbourhood neighbourhoodType, + FeatureInitializer[] featureInit) { + if (numRows < 2) { + throw new NumberIsTooSmallException(numRows, 2, true); + } + if (numCols < 2) { + throw new NumberIsTooSmallException(numCols, 2, true); + } + + numberOfRows = numRows; + wrapRows = wrapRowDim; + numberOfColumns = numCols; + wrapColumns = wrapColDim; + neighbourhood = neighbourhoodType; + identifiers = new long[numberOfRows][numberOfColumns]; + + final int fLen = featureInit.length; + network = new Network(0, fLen); + + // Add neurons. + for (int i = 0; i < numRows; i++) { + for (int j = 0; j < numCols; j++) { + final double[] features = new double[fLen]; + for (int fIndex = 0; fIndex < fLen; fIndex++) { + features[fIndex] = featureInit[fIndex].value(); + } + identifiers[i][j] = network.createNeuron(features); + } + } + + // Add links. + createLinks(); + } + + /** + * Constructor with restricted access, solely used for making a + * {@link #copy() deep copy}. + * + * @param wrapRowDim Whether to wrap the first dimension (i.e the first + * and last neurons will be linked together). + * @param wrapColDim Whether to wrap the second dimension (i.e the first + * and last neurons will be linked together). + * @param neighbourhoodType Neighbourhood type. + * @param net Underlying network. + * @param idGrid Neuron identifiers. + */ + private NeuronSquareMesh2D(boolean wrapRowDim, + boolean wrapColDim, + SquareNeighbourhood neighbourhoodType, + Network net, + long[][] idGrid) { + numberOfRows = idGrid.length; + numberOfColumns = idGrid[0].length; + wrapRows = wrapRowDim; + wrapColumns = wrapColDim; + neighbourhood = neighbourhoodType; + network = net; + identifiers = idGrid; + } + + /** + * Performs a deep copy of this instance. + * Upon return, the copied and original instances will be independent: + * Updating one will not affect the other. + * + * @return a new instance with the same state as this instance. + * @since 3.6 + */ + public synchronized NeuronSquareMesh2D copy() { + final long[][] idGrid = new long[numberOfRows][numberOfColumns]; + for (int r = 0; r < numberOfRows; r++) { + for (int c = 0; c < numberOfColumns; c++) { + idGrid[r][c] = identifiers[r][c]; + } + } + + return new NeuronSquareMesh2D(wrapRows, + wrapColumns, + neighbourhood, + network.copy(), + idGrid); + } + + /** + * {@inheritDoc} + * @since 3.6 + */ + public Iterator<Neuron> iterator() { + return network.iterator(); + } + + /** + * Retrieves the underlying network. + * A reference is returned (enabling, for example, the network to be + * trained). + * This also implies that calling methods that modify the {@link Network} + * topology may cause this class to become inconsistent. + * + * @return the network. + */ + public Network getNetwork() { + return network; + } + + /** + * Gets the number of neurons in each row of this map. + * + * @return the number of rows. + */ + public int getNumberOfRows() { + return numberOfRows; + } + + /** + * Gets the number of neurons in each column of this map. + * + * @return the number of column. + */ + public int getNumberOfColumns() { + return numberOfColumns; + } + + /** + * Retrieves the neuron at location {@code (i, j)} in the map. + * The neuron at position {@code (0, 0)} is located at the upper-left + * corner of the map. + * + * @param i Row index. + * @param j Column index. + * @return the neuron at {@code (i, j)}. + * @throws OutOfRangeException if {@code i} or {@code j} is + * out of range. + * + * @see #getNeuron(int,int,HorizontalDirection,VerticalDirection) + */ + public Neuron getNeuron(int i, + int j) { + if (i < 0 || + i >= numberOfRows) { + throw new OutOfRangeException(i, 0, numberOfRows - 1); + } + if (j < 0 || + j >= numberOfColumns) { + throw new OutOfRangeException(j, 0, numberOfColumns - 1); + } + + return network.getNeuron(identifiers[i][j]); + } + + /** + * Retrieves the neuron at {@code (location[0], location[1])} in the map. + * The neuron at position {@code (0, 0)} is located at the upper-left + * corner of the map. + * + * @param row Row index. + * @param col Column index. + * @param alongRowDir Direction along the given {@code row} (i.e. an + * offset will be added to the given <em>column</em> index. + * @param alongColDir Direction along the given {@code col} (i.e. an + * offset will be added to the given <em>row</em> index. + * @return the neuron at the requested location, or {@code null} if + * the location is not on the map. + * + * @see #getNeuron(int,int) + */ + public Neuron getNeuron(int row, + int col, + HorizontalDirection alongRowDir, + VerticalDirection alongColDir) { + final int[] location = getLocation(row, col, alongRowDir, alongColDir); + + return location == null ? null : getNeuron(location[0], location[1]); + } + + /** + * Computes the location of a neighbouring neuron. + * It will return {@code null} if the resulting location is not part + * of the map. + * Position {@code (0, 0)} is at the upper-left corner of the map. + * + * @param row Row index. + * @param col Column index. + * @param alongRowDir Direction along the given {@code row} (i.e. an + * offset will be added to the given <em>column</em> index. + * @param alongColDir Direction along the given {@code col} (i.e. an + * offset will be added to the given <em>row</em> index. + * @return an array of length 2 containing the indices of the requested + * location, or {@code null} if that location is not part of the map. + * + * @see #getNeuron(int,int) + */ + private int[] getLocation(int row, + int col, + HorizontalDirection alongRowDir, + VerticalDirection alongColDir) { + final int colOffset; + switch (alongRowDir) { + case LEFT: + colOffset = -1; + break; + case RIGHT: + colOffset = 1; + break; + case CENTER: + colOffset = 0; + break; + default: + // Should never happen. + throw new MathInternalError(); + } + int colIndex = col + colOffset; + if (wrapColumns) { + if (colIndex < 0) { + colIndex += numberOfColumns; + } else { + colIndex %= numberOfColumns; + } + } + + final int rowOffset; + switch (alongColDir) { + case UP: + rowOffset = -1; + break; + case DOWN: + rowOffset = 1; + break; + case CENTER: + rowOffset = 0; + break; + default: + // Should never happen. + throw new MathInternalError(); + } + int rowIndex = row + rowOffset; + if (wrapRows) { + if (rowIndex < 0) { + rowIndex += numberOfRows; + } else { + rowIndex %= numberOfRows; + } + } + + if (rowIndex < 0 || + rowIndex >= numberOfRows || + colIndex < 0 || + colIndex >= numberOfColumns) { + return null; + } else { + return new int[] { rowIndex, colIndex }; + } + } + + /** + * Creates the neighbour relationships between neurons. + */ + private void createLinks() { + // "linkEnd" will store the identifiers of the "neighbours". + final List<Long> linkEnd = new ArrayList<Long>(); + final int iLast = numberOfRows - 1; + final int jLast = numberOfColumns - 1; + for (int i = 0; i < numberOfRows; i++) { + for (int j = 0; j < numberOfColumns; j++) { + linkEnd.clear(); + + switch (neighbourhood) { + + case MOORE: + // Add links to "diagonal" neighbours. + if (i > 0) { + if (j > 0) { + linkEnd.add(identifiers[i - 1][j - 1]); + } + if (j < jLast) { + linkEnd.add(identifiers[i - 1][j + 1]); + } + } + if (i < iLast) { + if (j > 0) { + linkEnd.add(identifiers[i + 1][j - 1]); + } + if (j < jLast) { + linkEnd.add(identifiers[i + 1][j + 1]); + } + } + if (wrapRows) { + if (i == 0) { + if (j > 0) { + linkEnd.add(identifiers[iLast][j - 1]); + } + if (j < jLast) { + linkEnd.add(identifiers[iLast][j + 1]); + } + } else if (i == iLast) { + if (j > 0) { + linkEnd.add(identifiers[0][j - 1]); + } + if (j < jLast) { + linkEnd.add(identifiers[0][j + 1]); + } + } + } + if (wrapColumns) { + if (j == 0) { + if (i > 0) { + linkEnd.add(identifiers[i - 1][jLast]); + } + if (i < iLast) { + linkEnd.add(identifiers[i + 1][jLast]); + } + } else if (j == jLast) { + if (i > 0) { + linkEnd.add(identifiers[i - 1][0]); + } + if (i < iLast) { + linkEnd.add(identifiers[i + 1][0]); + } + } + } + if (wrapRows && + wrapColumns) { + if (i == 0 && + j == 0) { + linkEnd.add(identifiers[iLast][jLast]); + } else if (i == 0 && + j == jLast) { + linkEnd.add(identifiers[iLast][0]); + } else if (i == iLast && + j == 0) { + linkEnd.add(identifiers[0][jLast]); + } else if (i == iLast && + j == jLast) { + linkEnd.add(identifiers[0][0]); + } + } + + // Case falls through since the "Moore" neighbourhood + // also contains the neurons that belong to the "Von + // Neumann" neighbourhood. + + // fallthru (CheckStyle) + case VON_NEUMANN: + // Links to preceding and following "row". + if (i > 0) { + linkEnd.add(identifiers[i - 1][j]); + } + if (i < iLast) { + linkEnd.add(identifiers[i + 1][j]); + } + if (wrapRows) { + if (i == 0) { + linkEnd.add(identifiers[iLast][j]); + } else if (i == iLast) { + linkEnd.add(identifiers[0][j]); + } + } + + // Links to preceding and following "column". + if (j > 0) { + linkEnd.add(identifiers[i][j - 1]); + } + if (j < jLast) { + linkEnd.add(identifiers[i][j + 1]); + } + if (wrapColumns) { + if (j == 0) { + linkEnd.add(identifiers[i][jLast]); + } else if (j == jLast) { + linkEnd.add(identifiers[i][0]); + } + } + break; + + default: + throw new MathInternalError(); // Cannot happen. + } + + final Neuron aNeuron = network.getNeuron(identifiers[i][j]); + for (long b : linkEnd) { + final Neuron bNeuron = network.getNeuron(b); + // Link to all neighbours. + // The reverse links will be added as the loop proceeds. + network.addLink(aNeuron, bNeuron); + } + } + } + } + + /** + * Prevents proxy bypass. + * + * @param in Input stream. + */ + private void readObject(ObjectInputStream in) { + throw new IllegalStateException(); + } + + /** + * Custom serialization. + * + * @return the proxy instance that will be actually serialized. + */ + private Object writeReplace() { + final double[][][] featuresList = new double[numberOfRows][numberOfColumns][]; + for (int i = 0; i < numberOfRows; i++) { + for (int j = 0; j < numberOfColumns; j++) { + featuresList[i][j] = getNeuron(i, j).getFeatures(); + } + } + + return new SerializationProxy(wrapRows, + wrapColumns, + neighbourhood, + featuresList); + } + + /** + * Serialization. + */ + private static class SerializationProxy implements Serializable { + /** Serializable. */ + private static final long serialVersionUID = 20130226L; + /** Wrap. */ + private final boolean wrapRows; + /** Wrap. */ + private final boolean wrapColumns; + /** Neighbourhood type. */ + private final SquareNeighbourhood neighbourhood; + /** Neurons' features. */ + private final double[][][] featuresList; + + /** + * @param wrapRows Whether the row dimension is wrapped. + * @param wrapColumns Whether the column dimension is wrapped. + * @param neighbourhood Neighbourhood type. + * @param featuresList List of neurons features. + * {@code neuronList}. + */ + SerializationProxy(boolean wrapRows, + boolean wrapColumns, + SquareNeighbourhood neighbourhood, + double[][][] featuresList) { + this.wrapRows = wrapRows; + this.wrapColumns = wrapColumns; + this.neighbourhood = neighbourhood; + this.featuresList = featuresList; + } + + /** + * Custom serialization. + * + * @return the {@link Neuron} for which this instance is the proxy. + */ + private Object readResolve() { + return new NeuronSquareMesh2D(wrapRows, + wrapColumns, + neighbourhood, + featuresList); + } + } +} diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/twod/package-info.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/twod/package-info.java new file mode 100644 index 0000000..41535e8 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/twod/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Two-dimensional neural networks. + */ + +package org.apache.commons.math3.ml.neuralnet.twod; diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/twod/util/HitHistogram.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/twod/util/HitHistogram.java new file mode 100644 index 0000000..06cee98 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/twod/util/HitHistogram.java @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.commons.math3.ml.neuralnet.twod.util; + +import org.apache.commons.math3.ml.neuralnet.MapUtils; +import org.apache.commons.math3.ml.neuralnet.Neuron; +import org.apache.commons.math3.ml.neuralnet.twod.NeuronSquareMesh2D; +import org.apache.commons.math3.ml.distance.DistanceMeasure; + +/** + * Computes the hit histogram. + * Each bin will contain the number of data for which the corresponding + * neuron is the best matching unit. + * @since 3.6 + */ +public class HitHistogram implements MapDataVisualization { + /** Distance. */ + private final DistanceMeasure distance; + /** Whether to compute relative bin counts. */ + private final boolean normalizeCount; + + /** + * @param normalizeCount Whether to compute relative bin counts. + * If {@code true}, the data count in each bin will be divided by the total + * number of samples. + * @param distance Distance. + */ + public HitHistogram(boolean normalizeCount, + DistanceMeasure distance) { + this.normalizeCount = normalizeCount; + this.distance = distance; + } + + /** {@inheritDoc} */ + public double[][] computeImage(NeuronSquareMesh2D map, + Iterable<double[]> data) { + final int nR = map.getNumberOfRows(); + final int nC = map.getNumberOfColumns(); + + final LocationFinder finder = new LocationFinder(map); + + // Total number of samples. + int numSamples = 0; + // Hit bins. + final double[][] hit = new double[nR][nC]; + + for (double[] sample : data) { + final Neuron best = MapUtils.findBest(sample, map, distance); + + final LocationFinder.Location loc = finder.getLocation(best); + final int row = loc.getRow(); + final int col = loc.getColumn(); + hit[row][col] += 1; + + ++numSamples; + } + + if (normalizeCount) { + for (int r = 0; r < nR; r++) { + for (int c = 0; c < nC; c++) { + hit[r][c] /= numSamples; + } + } + } + + return hit; + } +} diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/twod/util/LocationFinder.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/twod/util/LocationFinder.java new file mode 100644 index 0000000..e4ece61 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/twod/util/LocationFinder.java @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.commons.math3.ml.neuralnet.twod.util; + +import java.util.Map; +import java.util.HashMap; +import org.apache.commons.math3.ml.neuralnet.Neuron; +import org.apache.commons.math3.ml.neuralnet.twod.NeuronSquareMesh2D; +import org.apache.commons.math3.exception.MathIllegalStateException; + +/** + * Helper class to find the grid coordinates of a neuron. + * @since 3.6 + */ +public class LocationFinder { + /** Identifier to location mapping. */ + private final Map<Long, Location> locations = new HashMap<Long, Location>(); + + /** + * Container holding a (row, column) pair. + */ + public static class Location { + /** Row index. */ + private final int row; + /** Column index. */ + private final int column; + + /** + * @param row Row index. + * @param column Column index. + */ + public Location(int row, + int column) { + this.row = row; + this.column = column; + } + + /** + * @return the row index. + */ + public int getRow() { + return row; + } + + /** + * @return the column index. + */ + public int getColumn() { + return column; + } + } + + /** + * Builds a finder to retrieve the locations of neurons that + * belong to the given {@code map}. + * + * @param map Map. + * + * @throws MathIllegalStateException if the network contains non-unique + * identifiers. This indicates an inconsistent state due to a bug in + * the construction code of the underlying + * {@link org.apache.commons.math3.ml.neuralnet.Network network}. + */ + public LocationFinder(NeuronSquareMesh2D map) { + final int nR = map.getNumberOfRows(); + final int nC = map.getNumberOfColumns(); + + for (int r = 0; r < nR; r++) { + for (int c = 0; c < nC; c++) { + final Long id = map.getNeuron(r, c).getIdentifier(); + if (locations.get(id) != null) { + throw new MathIllegalStateException(); + } + locations.put(id, new Location(r, c)); + } + } + } + + /** + * Retrieves a neuron's grid coordinates. + * + * @param n Neuron. + * @return the (row, column) coordinates of {@code n}, or {@code null} + * if no such neuron belongs to the {@link #LocationFinder(NeuronSquareMesh2D) + * map used to build this instance}. + */ + public Location getLocation(Neuron n) { + return locations.get(n.getIdentifier()); + } +} diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/twod/util/MapDataVisualization.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/twod/util/MapDataVisualization.java new file mode 100644 index 0000000..71fab43 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/twod/util/MapDataVisualization.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.commons.math3.ml.neuralnet.twod.util; + +import org.apache.commons.math3.ml.neuralnet.twod.NeuronSquareMesh2D; + +/** + * Interface for algorithms that compute some metrics of the projection of + * data on a 2D-map. + * @since 3.6 + */ +public interface MapDataVisualization { + /** + * Creates an image of the {@code data} metrics when represented by the + * {@code map}. + * + * @param map Map. + * @param data Data. + * @return a 2D-array (in row major order) representing the metrics. + */ + double[][] computeImage(NeuronSquareMesh2D map, + Iterable<double[]> data); +} diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/twod/util/MapVisualization.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/twod/util/MapVisualization.java new file mode 100644 index 0000000..9304d76 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/twod/util/MapVisualization.java @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.commons.math3.ml.neuralnet.twod.util; + +import org.apache.commons.math3.ml.neuralnet.twod.NeuronSquareMesh2D; + +/** + * Interface for algorithms that compute some property of a 2D-map. + * @since 3.6 + */ +public interface MapVisualization { + /** + * Creates an image of the {@code map}. + * + * @param map Map. + * @return a 2D-array (in row major order) representing the property. + */ + double[][] computeImage(NeuronSquareMesh2D map); +} diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/twod/util/QuantizationError.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/twod/util/QuantizationError.java new file mode 100644 index 0000000..8ec1da3 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/twod/util/QuantizationError.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.commons.math3.ml.neuralnet.twod.util; + +import org.apache.commons.math3.ml.neuralnet.MapUtils; +import org.apache.commons.math3.ml.neuralnet.Neuron; +import org.apache.commons.math3.ml.neuralnet.twod.NeuronSquareMesh2D; +import org.apache.commons.math3.ml.distance.DistanceMeasure; + +/** + * Computes the quantization error histogram. + * Each bin will contain the average of the distances between samples + * mapped to the corresponding unit and the weight vector of that unit. + * @since 3.6 + */ +public class QuantizationError implements MapDataVisualization { + /** Distance. */ + private final DistanceMeasure distance; + + /** + * @param distance Distance. + */ + public QuantizationError(DistanceMeasure distance) { + this.distance = distance; + } + + /** {@inheritDoc} */ + public double[][] computeImage(NeuronSquareMesh2D map, + Iterable<double[]> data) { + final int nR = map.getNumberOfRows(); + final int nC = map.getNumberOfColumns(); + + final LocationFinder finder = new LocationFinder(map); + + // Hit bins. + final int[][] hit = new int[nR][nC]; + // Error bins. + final double[][] error = new double[nR][nC]; + + for (double[] sample : data) { + final Neuron best = MapUtils.findBest(sample, map, distance); + + final LocationFinder.Location loc = finder.getLocation(best); + final int row = loc.getRow(); + final int col = loc.getColumn(); + hit[row][col] += 1; + error[row][col] += distance.compute(sample, best.getFeatures()); + } + + for (int r = 0; r < nR; r++) { + for (int c = 0; c < nC; c++) { + final int count = hit[r][c]; + if (count != 0) { + error[r][c] /= count; + } + } + } + + return error; + } +} diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/twod/util/SmoothedDataHistogram.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/twod/util/SmoothedDataHistogram.java new file mode 100644 index 0000000..b8e552c --- /dev/null +++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/twod/util/SmoothedDataHistogram.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.commons.math3.ml.neuralnet.twod.util; + +import org.apache.commons.math3.ml.neuralnet.MapUtils; +import org.apache.commons.math3.ml.neuralnet.Neuron; +import org.apache.commons.math3.ml.neuralnet.twod.NeuronSquareMesh2D; +import org.apache.commons.math3.ml.distance.DistanceMeasure; +import org.apache.commons.math3.exception.NumberIsTooSmallException; + +/** + * Visualization of high-dimensional data projection on a 2D-map. + * The method is described in + * <quote> + * <em>Using Smoothed Data Histograms for Cluster Visualization in Self-Organizing Maps</em> + * <br> + * by Elias Pampalk, Andreas Rauber and Dieter Merkl. + * </quote> + * @since 3.6 + */ +public class SmoothedDataHistogram implements MapDataVisualization { + /** Smoothing parameter. */ + private final int smoothingBins; + /** Distance. */ + private final DistanceMeasure distance; + /** Normalization factor. */ + private final double membershipNormalization; + + /** + * @param smoothingBins Number of bins. + * @param distance Distance. + */ + public SmoothedDataHistogram(int smoothingBins, + DistanceMeasure distance) { + this.smoothingBins = smoothingBins; + this.distance = distance; + + double sum = 0; + for (int i = 0; i < smoothingBins; i++) { + sum += smoothingBins - i; + } + + this.membershipNormalization = 1d / sum; + } + + /** + * {@inheritDoc} + * + * @throws NumberIsTooSmallException if the size of the {@code map} + * is smaller than the number of {@link #SmoothedDataHistogram(int,DistanceMeasure) + * smoothing bins}. + */ + public double[][] computeImage(NeuronSquareMesh2D map, + Iterable<double[]> data) { + final int nR = map.getNumberOfRows(); + final int nC = map.getNumberOfColumns(); + + final int mapSize = nR * nC; + if (mapSize < smoothingBins) { + throw new NumberIsTooSmallException(mapSize, smoothingBins, true); + } + + final LocationFinder finder = new LocationFinder(map); + + // Histogram bins. + final double[][] histo = new double[nR][nC]; + + for (double[] sample : data) { + final Neuron[] sorted = MapUtils.sort(sample, + map.getNetwork(), + distance); + for (int i = 0; i < smoothingBins; i++) { + final LocationFinder.Location loc = finder.getLocation(sorted[i]); + final int row = loc.getRow(); + final int col = loc.getColumn(); + histo[row][col] += (smoothingBins - i) * membershipNormalization; + } + } + + return histo; + } +} diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/twod/util/TopographicErrorHistogram.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/twod/util/TopographicErrorHistogram.java new file mode 100644 index 0000000..b831de8 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/twod/util/TopographicErrorHistogram.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.commons.math3.ml.neuralnet.twod.util; + +import org.apache.commons.math3.ml.neuralnet.MapUtils; +import org.apache.commons.math3.ml.neuralnet.Neuron; +import org.apache.commons.math3.ml.neuralnet.Network; +import org.apache.commons.math3.ml.neuralnet.twod.NeuronSquareMesh2D; +import org.apache.commons.math3.ml.distance.DistanceMeasure; +import org.apache.commons.math3.util.Pair; + +/** + * Computes the topographic error histogram. + * Each bin will contain the number of data for which the first and + * second best matching units are not adjacent in the map. + * @since 3.6 + */ +public class TopographicErrorHistogram implements MapDataVisualization { + /** Distance. */ + private final DistanceMeasure distance; + /** Whether to compute relative bin counts. */ + private final boolean relativeCount; + + /** + * @param relativeCount Whether to compute relative bin counts. + * If {@code true}, the data count in each bin will be divided by the total + * number of samples mapped to the neuron represented by that bin. + * @param distance Distance. + */ + public TopographicErrorHistogram(boolean relativeCount, + DistanceMeasure distance) { + this.relativeCount = relativeCount; + this.distance = distance; + } + + /** {@inheritDoc} */ + public double[][] computeImage(NeuronSquareMesh2D map, + Iterable<double[]> data) { + final int nR = map.getNumberOfRows(); + final int nC = map.getNumberOfColumns(); + + final Network net = map.getNetwork(); + final LocationFinder finder = new LocationFinder(map); + + // Hit bins. + final int[][] hit = new int[nR][nC]; + // Error bins. + final double[][] error = new double[nR][nC]; + + for (double[] sample : data) { + final Pair<Neuron, Neuron> p = MapUtils.findBestAndSecondBest(sample, map, distance); + final Neuron best = p.getFirst(); + + final LocationFinder.Location loc = finder.getLocation(best); + final int row = loc.getRow(); + final int col = loc.getColumn(); + hit[row][col] += 1; + + if (!net.getNeighbours(best).contains(p.getSecond())) { + // Increment count if first and second best matching units + // are not neighbours. + error[row][col] += 1; + } + } + + if (relativeCount) { + for (int r = 0; r < nR; r++) { + for (int c = 0; c < nC; c++) { + error[r][c] /= hit[r][c]; + } + } + } + + return error; + } +} diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/twod/util/UnifiedDistanceMatrix.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/twod/util/UnifiedDistanceMatrix.java new file mode 100644 index 0000000..aee982a --- /dev/null +++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/twod/util/UnifiedDistanceMatrix.java @@ -0,0 +1,209 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.commons.math3.ml.neuralnet.twod.util; + +import java.util.Collection; +import org.apache.commons.math3.ml.neuralnet.Neuron; +import org.apache.commons.math3.ml.neuralnet.Network; +import org.apache.commons.math3.ml.neuralnet.twod.NeuronSquareMesh2D; +import org.apache.commons.math3.ml.distance.DistanceMeasure; + +/** + * <a href="http://en.wikipedia.org/wiki/U-Matrix">U-Matrix</a> + * visualization of high-dimensional data projection. + * @since 3.6 + */ +public class UnifiedDistanceMatrix implements MapVisualization { + /** Whether to show distance between each pair of neighbouring units. */ + private final boolean individualDistances; + /** Distance. */ + private final DistanceMeasure distance; + + /** + * Simple constructor. + * + * @param individualDistances If {@code true}, the 8 individual + * inter-units distances will be {@link #computeImage(NeuronSquareMesh2D) + * computed}. They will be stored in additional pixels around each of + * the original units of the 2D-map. The additional pixels that lie + * along a "diagonal" are shared by <em>two</em> pairs of units: their + * value will be set to the average distance between the units belonging + * to each of the pairs. The value zero will be stored in the pixel + * corresponding to the location of a unit of the 2D-map. + * <br> + * If {@code false}, only the average distance between a unit and all its + * neighbours will be computed (and stored in the pixel corresponding to + * that unit of the 2D-map). In that case, the number of neighbours taken + * into account depends on the network's + * {@link org.apache.commons.math3.ml.neuralnet.SquareNeighbourhood + * neighbourhood type}. + * @param distance Distance. + */ + public UnifiedDistanceMatrix(boolean individualDistances, + DistanceMeasure distance) { + this.individualDistances = individualDistances; + this.distance = distance; + } + + /** {@inheritDoc} */ + public double[][] computeImage(NeuronSquareMesh2D map) { + if (individualDistances) { + return individualDistances(map); + } else { + return averageDistances(map); + } + } + + /** + * Computes the distances between a unit of the map and its + * neighbours. + * The image will contain more pixels than the number of neurons + * in the given {@code map} because each neuron has 8 neighbours. + * The value zero will be stored in the pixels corresponding to + * the location of a map unit. + * + * @param map Map. + * @return an image representing the individual distances. + */ + private double[][] individualDistances(NeuronSquareMesh2D map) { + final int numRows = map.getNumberOfRows(); + final int numCols = map.getNumberOfColumns(); + + final double[][] uMatrix = new double[numRows * 2 + 1][numCols * 2 + 1]; + + // 1. + // Fill right and bottom slots of each unit's location with the + // distance between the current unit and each of the two neighbours, + // respectively. + for (int i = 0; i < numRows; i++) { + // Current unit's row index in result image. + final int iR = 2 * i + 1; + + for (int j = 0; j < numCols; j++) { + // Current unit's column index in result image. + final int jR = 2 * j + 1; + + final double[] current = map.getNeuron(i, j).getFeatures(); + Neuron neighbour; + + // Right neighbour. + neighbour = map.getNeuron(i, j, + NeuronSquareMesh2D.HorizontalDirection.RIGHT, + NeuronSquareMesh2D.VerticalDirection.CENTER); + if (neighbour != null) { + uMatrix[iR][jR + 1] = distance.compute(current, + neighbour.getFeatures()); + } + + // Bottom-center neighbour. + neighbour = map.getNeuron(i, j, + NeuronSquareMesh2D.HorizontalDirection.CENTER, + NeuronSquareMesh2D.VerticalDirection.DOWN); + if (neighbour != null) { + uMatrix[iR + 1][jR] = distance.compute(current, + neighbour.getFeatures()); + } + } + } + + // 2. + // Fill the bottom-rigth slot of each unit's location with the average + // of the distances between + // * the current unit and its bottom-right neighbour, and + // * the bottom-center neighbour and the right neighbour. + for (int i = 0; i < numRows; i++) { + // Current unit's row index in result image. + final int iR = 2 * i + 1; + + for (int j = 0; j < numCols; j++) { + // Current unit's column index in result image. + final int jR = 2 * j + 1; + + final Neuron current = map.getNeuron(i, j); + final Neuron right = map.getNeuron(i, j, + NeuronSquareMesh2D.HorizontalDirection.RIGHT, + NeuronSquareMesh2D.VerticalDirection.CENTER); + final Neuron bottom = map.getNeuron(i, j, + NeuronSquareMesh2D.HorizontalDirection.CENTER, + NeuronSquareMesh2D.VerticalDirection.DOWN); + final Neuron bottomRight = map.getNeuron(i, j, + NeuronSquareMesh2D.HorizontalDirection.RIGHT, + NeuronSquareMesh2D.VerticalDirection.DOWN); + + final double current2BottomRight = bottomRight == null ? + 0 : + distance.compute(current.getFeatures(), + bottomRight.getFeatures()); + final double right2Bottom = (right == null || + bottom == null) ? + 0 : + distance.compute(right.getFeatures(), + bottom.getFeatures()); + + // Bottom-right slot. + uMatrix[iR + 1][jR + 1] = 0.5 * (current2BottomRight + right2Bottom); + } + } + + // 3. Copy last row into first row. + final int lastRow = uMatrix.length - 1; + uMatrix[0] = uMatrix[lastRow]; + + // 4. + // Copy last column into first column. + final int lastCol = uMatrix[0].length - 1; + for (int r = 0; r < lastRow; r++) { + uMatrix[r][0] = uMatrix[r][lastCol]; + } + + return uMatrix; + } + + /** + * Computes the distances between a unit of the map and its neighbours. + * + * @param map Map. + * @return an image representing the average distances. + */ + private double[][] averageDistances(NeuronSquareMesh2D map) { + final int numRows = map.getNumberOfRows(); + final int numCols = map.getNumberOfColumns(); + final double[][] uMatrix = new double[numRows][numCols]; + + final Network net = map.getNetwork(); + + for (int i = 0; i < numRows; i++) { + for (int j = 0; j < numCols; j++) { + final Neuron neuron = map.getNeuron(i, j); + final Collection<Neuron> neighbours = net.getNeighbours(neuron); + final double[] features = neuron.getFeatures(); + + double d = 0; + int count = 0; + for (Neuron n : neighbours) { + ++count; + d += distance.compute(features, n.getFeatures()); + } + + uMatrix[i][j] = d / count; + } + } + + return uMatrix; + } +} diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/twod/util/package-info.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/twod/util/package-info.java new file mode 100644 index 0000000..cd4aab0 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/twod/util/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Utilities to visualize two-dimensional neural networks. + */ + +package org.apache.commons.math3.ml.neuralnet.twod.util; diff --git a/src/main/java/org/apache/commons/math3/ml/package-info.java b/src/main/java/org/apache/commons/math3/ml/package-info.java new file mode 100644 index 0000000..394aad2 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/ml/package-info.java @@ -0,0 +1,18 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/** Base package for machine learning algorithms. */ +package org.apache.commons.math3.ml; |