diff options
Diffstat (limited to 'src/main/java/org/apache/commons/math3/ml/clustering')
13 files changed, 1901 insertions, 0 deletions
diff --git a/src/main/java/org/apache/commons/math3/ml/clustering/CentroidCluster.java b/src/main/java/org/apache/commons/math3/ml/clustering/CentroidCluster.java new file mode 100644 index 0000000..5cfc7bc --- /dev/null +++ b/src/main/java/org/apache/commons/math3/ml/clustering/CentroidCluster.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.math3.ml.clustering; + +/** + * A Cluster used by centroid-based clustering algorithms. + * <p> + * Defines additionally a cluster center which may not necessarily be a member + * of the original data set. + * + * @param <T> the type of points that can be clustered + * @since 3.2 + */ +public class CentroidCluster<T extends Clusterable> extends Cluster<T> { + + /** Serializable version identifier. */ + private static final long serialVersionUID = -3075288519071812288L; + + /** Center of the cluster. */ + private final Clusterable center; + + /** + * Build a cluster centered at a specified point. + * @param center the point which is to be the center of this cluster + */ + public CentroidCluster(final Clusterable center) { + super(); + this.center = center; + } + + /** + * Get the point chosen to be the center of this cluster. + * @return chosen cluster center + */ + public Clusterable getCenter() { + return center; + } + +} diff --git a/src/main/java/org/apache/commons/math3/ml/clustering/Cluster.java b/src/main/java/org/apache/commons/math3/ml/clustering/Cluster.java new file mode 100644 index 0000000..fa6df94 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/ml/clustering/Cluster.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.commons.math3.ml.clustering; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.List; + +/** + * Cluster holding a set of {@link Clusterable} points. + * @param <T> the type of points that can be clustered + * @since 3.2 + */ +public class Cluster<T extends Clusterable> implements Serializable { + + /** Serializable version identifier. */ + private static final long serialVersionUID = -3442297081515880464L; + + /** The points contained in this cluster. */ + private final List<T> points; + + /** + * Build a cluster centered at a specified point. + */ + public Cluster() { + points = new ArrayList<T>(); + } + + /** + * Add a point to this cluster. + * @param point point to add + */ + public void addPoint(final T point) { + points.add(point); + } + + /** + * Get the points contained in the cluster. + * @return points contained in the cluster + */ + public List<T> getPoints() { + return points; + } + +} diff --git a/src/main/java/org/apache/commons/math3/ml/clustering/Clusterable.java b/src/main/java/org/apache/commons/math3/ml/clustering/Clusterable.java new file mode 100644 index 0000000..e712eb7 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/ml/clustering/Clusterable.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.commons.math3.ml.clustering; + +/** + * Interface for n-dimensional points that can be clustered together. + * @since 3.2 + */ +public interface Clusterable { + + /** + * Gets the n-dimensional point. + * + * @return the point array + */ + double[] getPoint(); +} diff --git a/src/main/java/org/apache/commons/math3/ml/clustering/Clusterer.java b/src/main/java/org/apache/commons/math3/ml/clustering/Clusterer.java new file mode 100644 index 0000000..30e38c6 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/ml/clustering/Clusterer.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.math3.ml.clustering; + +import java.util.Collection; +import java.util.List; + +import org.apache.commons.math3.exception.ConvergenceException; +import org.apache.commons.math3.exception.MathIllegalArgumentException; +import org.apache.commons.math3.ml.distance.DistanceMeasure; + +/** + * Base class for clustering algorithms. + * + * @param <T> the type of points that can be clustered + * @since 3.2 + */ +public abstract class Clusterer<T extends Clusterable> { + + /** The distance measure to use. */ + private DistanceMeasure measure; + + /** + * Build a new clusterer with the given {@link DistanceMeasure}. + * + * @param measure the distance measure to use + */ + protected Clusterer(final DistanceMeasure measure) { + this.measure = measure; + } + + /** + * Perform a cluster analysis on the given set of {@link Clusterable} instances. + * + * @param points the set of {@link Clusterable} instances + * @return a {@link List} of clusters + * @throws MathIllegalArgumentException if points are null or the number of + * data points is not compatible with this clusterer + * @throws ConvergenceException if the algorithm has not yet converged after + * the maximum number of iterations has been exceeded + */ + public abstract List<? extends Cluster<T>> cluster(Collection<T> points) + throws MathIllegalArgumentException, ConvergenceException; + + /** + * Returns the {@link DistanceMeasure} instance used by this clusterer. + * + * @return the distance measure + */ + public DistanceMeasure getDistanceMeasure() { + return measure; + } + + /** + * Calculates the distance between two {@link Clusterable} instances + * with the configured {@link DistanceMeasure}. + * + * @param p1 the first clusterable + * @param p2 the second clusterable + * @return the distance between the two clusterables + */ + protected double distance(final Clusterable p1, final Clusterable p2) { + return measure.compute(p1.getPoint(), p2.getPoint()); + } + +} diff --git a/src/main/java/org/apache/commons/math3/ml/clustering/DBSCANClusterer.java b/src/main/java/org/apache/commons/math3/ml/clustering/DBSCANClusterer.java new file mode 100644 index 0000000..ce3d5cd --- /dev/null +++ b/src/main/java/org/apache/commons/math3/ml/clustering/DBSCANClusterer.java @@ -0,0 +1,233 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.math3.ml.clustering; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import org.apache.commons.math3.exception.NotPositiveException; +import org.apache.commons.math3.exception.NullArgumentException; +import org.apache.commons.math3.ml.distance.DistanceMeasure; +import org.apache.commons.math3.ml.distance.EuclideanDistance; +import org.apache.commons.math3.util.MathUtils; + +/** + * DBSCAN (density-based spatial clustering of applications with noise) algorithm. + * <p> + * The DBSCAN algorithm forms clusters based on the idea of density connectivity, i.e. + * a point p is density connected to another point q, if there exists a chain of + * points p<sub>i</sub>, with i = 1 .. n and p<sub>1</sub> = p and p<sub>n</sub> = q, + * such that each pair <p<sub>i</sub>, p<sub>i+1</sub>> is directly density-reachable. + * A point q is directly density-reachable from point p if it is in the ε-neighborhood + * of this point. + * <p> + * Any point that is not density-reachable from a formed cluster is treated as noise, and + * will thus not be present in the result. + * <p> + * The algorithm requires two parameters: + * <ul> + * <li>eps: the distance that defines the ε-neighborhood of a point + * <li>minPoints: the minimum number of density-connected points required to form a cluster + * </ul> + * + * @param <T> type of the points to cluster + * @see <a href="http://en.wikipedia.org/wiki/DBSCAN">DBSCAN (wikipedia)</a> + * @see <a href="http://www.dbs.ifi.lmu.de/Publikationen/Papers/KDD-96.final.frame.pdf"> + * A Density-Based Algorithm for Discovering Clusters in Large Spatial Databases with Noise</a> + * @since 3.2 + */ +public class DBSCANClusterer<T extends Clusterable> extends Clusterer<T> { + + /** Maximum radius of the neighborhood to be considered. */ + private final double eps; + + /** Minimum number of points needed for a cluster. */ + private final int minPts; + + /** Status of a point during the clustering process. */ + private enum PointStatus { + /** The point has is considered to be noise. */ + NOISE, + /** The point is already part of a cluster. */ + PART_OF_CLUSTER + } + + /** + * Creates a new instance of a DBSCANClusterer. + * <p> + * The euclidean distance will be used as default distance measure. + * + * @param eps maximum radius of the neighborhood to be considered + * @param minPts minimum number of points needed for a cluster + * @throws NotPositiveException if {@code eps < 0.0} or {@code minPts < 0} + */ + public DBSCANClusterer(final double eps, final int minPts) + throws NotPositiveException { + this(eps, minPts, new EuclideanDistance()); + } + + /** + * Creates a new instance of a DBSCANClusterer. + * + * @param eps maximum radius of the neighborhood to be considered + * @param minPts minimum number of points needed for a cluster + * @param measure the distance measure to use + * @throws NotPositiveException if {@code eps < 0.0} or {@code minPts < 0} + */ + public DBSCANClusterer(final double eps, final int minPts, final DistanceMeasure measure) + throws NotPositiveException { + super(measure); + + if (eps < 0.0d) { + throw new NotPositiveException(eps); + } + if (minPts < 0) { + throw new NotPositiveException(minPts); + } + this.eps = eps; + this.minPts = minPts; + } + + /** + * Returns the maximum radius of the neighborhood to be considered. + * @return maximum radius of the neighborhood + */ + public double getEps() { + return eps; + } + + /** + * Returns the minimum number of points needed for a cluster. + * @return minimum number of points needed for a cluster + */ + public int getMinPts() { + return minPts; + } + + /** + * Performs DBSCAN cluster analysis. + * + * @param points the points to cluster + * @return the list of clusters + * @throws NullArgumentException if the data points are null + */ + @Override + public List<Cluster<T>> cluster(final Collection<T> points) throws NullArgumentException { + + // sanity checks + MathUtils.checkNotNull(points); + + final List<Cluster<T>> clusters = new ArrayList<Cluster<T>>(); + final Map<Clusterable, PointStatus> visited = new HashMap<Clusterable, PointStatus>(); + + for (final T point : points) { + if (visited.get(point) != null) { + continue; + } + final List<T> neighbors = getNeighbors(point, points); + if (neighbors.size() >= minPts) { + // DBSCAN does not care about center points + final Cluster<T> cluster = new Cluster<T>(); + clusters.add(expandCluster(cluster, point, neighbors, points, visited)); + } else { + visited.put(point, PointStatus.NOISE); + } + } + + return clusters; + } + + /** + * Expands the cluster to include density-reachable items. + * + * @param cluster Cluster to expand + * @param point Point to add to cluster + * @param neighbors List of neighbors + * @param points the data set + * @param visited the set of already visited points + * @return the expanded cluster + */ + private Cluster<T> expandCluster(final Cluster<T> cluster, + final T point, + final List<T> neighbors, + final Collection<T> points, + final Map<Clusterable, PointStatus> visited) { + cluster.addPoint(point); + visited.put(point, PointStatus.PART_OF_CLUSTER); + + List<T> seeds = new ArrayList<T>(neighbors); + int index = 0; + while (index < seeds.size()) { + final T current = seeds.get(index); + PointStatus pStatus = visited.get(current); + // only check non-visited points + if (pStatus == null) { + final List<T> currentNeighbors = getNeighbors(current, points); + if (currentNeighbors.size() >= minPts) { + seeds = merge(seeds, currentNeighbors); + } + } + + if (pStatus != PointStatus.PART_OF_CLUSTER) { + visited.put(current, PointStatus.PART_OF_CLUSTER); + cluster.addPoint(current); + } + + index++; + } + return cluster; + } + + /** + * Returns a list of density-reachable neighbors of a {@code point}. + * + * @param point the point to look for + * @param points possible neighbors + * @return the List of neighbors + */ + private List<T> getNeighbors(final T point, final Collection<T> points) { + final List<T> neighbors = new ArrayList<T>(); + for (final T neighbor : points) { + if (point != neighbor && distance(neighbor, point) <= eps) { + neighbors.add(neighbor); + } + } + return neighbors; + } + + /** + * Merges two lists together. + * + * @param one first list + * @param two second list + * @return merged lists + */ + private List<T> merge(final List<T> one, final List<T> two) { + final Set<T> oneSet = new HashSet<T>(one); + for (T item : two) { + if (!oneSet.contains(item)) { + one.add(item); + } + } + return one; + } +} diff --git a/src/main/java/org/apache/commons/math3/ml/clustering/DoublePoint.java b/src/main/java/org/apache/commons/math3/ml/clustering/DoublePoint.java new file mode 100644 index 0000000..4fb31f7 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/ml/clustering/DoublePoint.java @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.commons.math3.ml.clustering; + +import java.io.Serializable; +import java.util.Arrays; + +/** + * A simple implementation of {@link Clusterable} for points with double coordinates. + * @since 3.2 + */ +public class DoublePoint implements Clusterable, Serializable { + + /** Serializable version identifier. */ + private static final long serialVersionUID = 3946024775784901369L; + + /** Point coordinates. */ + private final double[] point; + + /** + * Build an instance wrapping an double array. + * <p> + * The wrapped array is referenced, it is <em>not</em> copied. + * + * @param point the n-dimensional point in double space + */ + public DoublePoint(final double[] point) { + this.point = point; + } + + /** + * Build an instance wrapping an integer array. + * <p> + * The wrapped array is copied to an internal double array. + * + * @param point the n-dimensional point in integer space + */ + public DoublePoint(final int[] point) { + this.point = new double[point.length]; + for ( int i = 0; i < point.length; i++) { + this.point[i] = point[i]; + } + } + + /** {@inheritDoc} */ + public double[] getPoint() { + return point; + } + + /** {@inheritDoc} */ + @Override + public boolean equals(final Object other) { + if (!(other instanceof DoublePoint)) { + return false; + } + return Arrays.equals(point, ((DoublePoint) other).point); + } + + /** {@inheritDoc} */ + @Override + public int hashCode() { + return Arrays.hashCode(point); + } + + /** {@inheritDoc} */ + @Override + public String toString() { + return Arrays.toString(point); + } + +} diff --git a/src/main/java/org/apache/commons/math3/ml/clustering/FuzzyKMeansClusterer.java b/src/main/java/org/apache/commons/math3/ml/clustering/FuzzyKMeansClusterer.java new file mode 100644 index 0000000..5f89934 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/ml/clustering/FuzzyKMeansClusterer.java @@ -0,0 +1,426 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.math3.ml.clustering; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.List; + +import org.apache.commons.math3.exception.MathIllegalArgumentException; +import org.apache.commons.math3.exception.MathIllegalStateException; +import org.apache.commons.math3.exception.NumberIsTooSmallException; +import org.apache.commons.math3.linear.MatrixUtils; +import org.apache.commons.math3.linear.RealMatrix; +import org.apache.commons.math3.ml.distance.DistanceMeasure; +import org.apache.commons.math3.ml.distance.EuclideanDistance; +import org.apache.commons.math3.random.JDKRandomGenerator; +import org.apache.commons.math3.random.RandomGenerator; +import org.apache.commons.math3.util.FastMath; +import org.apache.commons.math3.util.MathArrays; +import org.apache.commons.math3.util.MathUtils; + +/** + * Fuzzy K-Means clustering algorithm. + * <p> + * The Fuzzy K-Means algorithm is a variation of the classical K-Means algorithm, with the + * major difference that a single data point is not uniquely assigned to a single cluster. + * Instead, each point i has a set of weights u<sub>ij</sub> which indicate the degree of membership + * to the cluster j. + * <p> + * The algorithm then tries to minimize the objective function: + * <pre> + * J = ∑<sub>i=1..C</sub>∑<sub>k=1..N</sub> u<sub>ik</sub><sup>m</sup>d<sub>ik</sub><sup>2</sup> + * </pre> + * with d<sub>ik</sub> being the distance between data point i and the cluster center k. + * <p> + * The algorithm requires two parameters: + * <ul> + * <li>k: the number of clusters + * <li>fuzziness: determines the level of cluster fuzziness, larger values lead to fuzzier clusters + * </ul> + * Additional, optional parameters: + * <ul> + * <li>maxIterations: the maximum number of iterations + * <li>epsilon: the convergence criteria, default is 1e-3 + * </ul> + * <p> + * The fuzzy variant of the K-Means algorithm is more robust with regard to the selection + * of the initial cluster centers. + * + * @param <T> type of the points to cluster + * @since 3.3 + */ +public class FuzzyKMeansClusterer<T extends Clusterable> extends Clusterer<T> { + + /** The default value for the convergence criteria. */ + private static final double DEFAULT_EPSILON = 1e-3; + + /** The number of clusters. */ + private final int k; + + /** The maximum number of iterations. */ + private final int maxIterations; + + /** The fuzziness factor. */ + private final double fuzziness; + + /** The convergence criteria. */ + private final double epsilon; + + /** Random generator for choosing initial centers. */ + private final RandomGenerator random; + + /** The membership matrix. */ + private double[][] membershipMatrix; + + /** The list of points used in the last call to {@link #cluster(Collection)}. */ + private List<T> points; + + /** The list of clusters resulting from the last call to {@link #cluster(Collection)}. */ + private List<CentroidCluster<T>> clusters; + + /** + * Creates a new instance of a FuzzyKMeansClusterer. + * <p> + * The euclidean distance will be used as default distance measure. + * + * @param k the number of clusters to split the data into + * @param fuzziness the fuzziness factor, must be > 1.0 + * @throws NumberIsTooSmallException if {@code fuzziness <= 1.0} + */ + public FuzzyKMeansClusterer(final int k, final double fuzziness) throws NumberIsTooSmallException { + this(k, fuzziness, -1, new EuclideanDistance()); + } + + /** + * Creates a new instance of a FuzzyKMeansClusterer. + * + * @param k the number of clusters to split the data into + * @param fuzziness the fuzziness factor, must be > 1.0 + * @param maxIterations the maximum number of iterations to run the algorithm for. + * If negative, no maximum will be used. + * @param measure the distance measure to use + * @throws NumberIsTooSmallException if {@code fuzziness <= 1.0} + */ + public FuzzyKMeansClusterer(final int k, final double fuzziness, + final int maxIterations, final DistanceMeasure measure) + throws NumberIsTooSmallException { + this(k, fuzziness, maxIterations, measure, DEFAULT_EPSILON, new JDKRandomGenerator()); + } + + /** + * Creates a new instance of a FuzzyKMeansClusterer. + * + * @param k the number of clusters to split the data into + * @param fuzziness the fuzziness factor, must be > 1.0 + * @param maxIterations the maximum number of iterations to run the algorithm for. + * If negative, no maximum will be used. + * @param measure the distance measure to use + * @param epsilon the convergence criteria (default is 1e-3) + * @param random random generator to use for choosing initial centers + * @throws NumberIsTooSmallException if {@code fuzziness <= 1.0} + */ + public FuzzyKMeansClusterer(final int k, final double fuzziness, + final int maxIterations, final DistanceMeasure measure, + final double epsilon, final RandomGenerator random) + throws NumberIsTooSmallException { + + super(measure); + + if (fuzziness <= 1.0d) { + throw new NumberIsTooSmallException(fuzziness, 1.0, false); + } + this.k = k; + this.fuzziness = fuzziness; + this.maxIterations = maxIterations; + this.epsilon = epsilon; + this.random = random; + + this.membershipMatrix = null; + this.points = null; + this.clusters = null; + } + + /** + * Return the number of clusters this instance will use. + * @return the number of clusters + */ + public int getK() { + return k; + } + + /** + * Returns the fuzziness factor used by this instance. + * @return the fuzziness factor + */ + public double getFuzziness() { + return fuzziness; + } + + /** + * Returns the maximum number of iterations this instance will use. + * @return the maximum number of iterations, or -1 if no maximum is set + */ + public int getMaxIterations() { + return maxIterations; + } + + /** + * Returns the convergence criteria used by this instance. + * @return the convergence criteria + */ + public double getEpsilon() { + return epsilon; + } + + /** + * Returns the random generator this instance will use. + * @return the random generator + */ + public RandomGenerator getRandomGenerator() { + return random; + } + + /** + * Returns the {@code nxk} membership matrix, where {@code n} is the number + * of data points and {@code k} the number of clusters. + * <p> + * The element U<sub>i,j</sub> represents the membership value for data point {@code i} + * to cluster {@code j}. + * + * @return the membership matrix + * @throws MathIllegalStateException if {@link #cluster(Collection)} has not been called before + */ + public RealMatrix getMembershipMatrix() { + if (membershipMatrix == null) { + throw new MathIllegalStateException(); + } + return MatrixUtils.createRealMatrix(membershipMatrix); + } + + /** + * Returns an unmodifiable list of the data points used in the last + * call to {@link #cluster(Collection)}. + * @return the list of data points, or {@code null} if {@link #cluster(Collection)} has + * not been called before. + */ + public List<T> getDataPoints() { + return points; + } + + /** + * Returns the list of clusters resulting from the last call to {@link #cluster(Collection)}. + * @return the list of clusters, or {@code null} if {@link #cluster(Collection)} has + * not been called before. + */ + public List<CentroidCluster<T>> getClusters() { + return clusters; + } + + /** + * Get the value of the objective function. + * @return the objective function evaluation as double value + * @throws MathIllegalStateException if {@link #cluster(Collection)} has not been called before + */ + public double getObjectiveFunctionValue() { + if (points == null || clusters == null) { + throw new MathIllegalStateException(); + } + + int i = 0; + double objFunction = 0.0; + for (final T point : points) { + int j = 0; + for (final CentroidCluster<T> cluster : clusters) { + final double dist = distance(point, cluster.getCenter()); + objFunction += (dist * dist) * FastMath.pow(membershipMatrix[i][j], fuzziness); + j++; + } + i++; + } + return objFunction; + } + + /** + * Performs Fuzzy K-Means cluster analysis. + * + * @param dataPoints the points to cluster + * @return the list of clusters + * @throws MathIllegalArgumentException if the data points are null or the number + * of clusters is larger than the number of data points + */ + @Override + public List<CentroidCluster<T>> cluster(final Collection<T> dataPoints) + throws MathIllegalArgumentException { + + // sanity checks + MathUtils.checkNotNull(dataPoints); + + final int size = dataPoints.size(); + + // number of clusters has to be smaller or equal the number of data points + if (size < k) { + throw new NumberIsTooSmallException(size, k, false); + } + + // copy the input collection to an unmodifiable list with indexed access + points = Collections.unmodifiableList(new ArrayList<T>(dataPoints)); + clusters = new ArrayList<CentroidCluster<T>>(); + membershipMatrix = new double[size][k]; + final double[][] oldMatrix = new double[size][k]; + + // if no points are provided, return an empty list of clusters + if (size == 0) { + return clusters; + } + + initializeMembershipMatrix(); + + // there is at least one point + final int pointDimension = points.get(0).getPoint().length; + for (int i = 0; i < k; i++) { + clusters.add(new CentroidCluster<T>(new DoublePoint(new double[pointDimension]))); + } + + int iteration = 0; + final int max = (maxIterations < 0) ? Integer.MAX_VALUE : maxIterations; + double difference = 0.0; + + do { + saveMembershipMatrix(oldMatrix); + updateClusterCenters(); + updateMembershipMatrix(); + difference = calculateMaxMembershipChange(oldMatrix); + } while (difference > epsilon && ++iteration < max); + + return clusters; + } + + /** + * Update the cluster centers. + */ + private void updateClusterCenters() { + int j = 0; + final List<CentroidCluster<T>> newClusters = new ArrayList<CentroidCluster<T>>(k); + for (final CentroidCluster<T> cluster : clusters) { + final Clusterable center = cluster.getCenter(); + int i = 0; + double[] arr = new double[center.getPoint().length]; + double sum = 0.0; + for (final T point : points) { + final double u = FastMath.pow(membershipMatrix[i][j], fuzziness); + final double[] pointArr = point.getPoint(); + for (int idx = 0; idx < arr.length; idx++) { + arr[idx] += u * pointArr[idx]; + } + sum += u; + i++; + } + MathArrays.scaleInPlace(1.0 / sum, arr); + newClusters.add(new CentroidCluster<T>(new DoublePoint(arr))); + j++; + } + clusters.clear(); + clusters = newClusters; + } + + /** + * Updates the membership matrix and assigns the points to the cluster with + * the highest membership. + */ + private void updateMembershipMatrix() { + for (int i = 0; i < points.size(); i++) { + final T point = points.get(i); + double maxMembership = Double.MIN_VALUE; + int newCluster = -1; + for (int j = 0; j < clusters.size(); j++) { + double sum = 0.0; + final double distA = FastMath.abs(distance(point, clusters.get(j).getCenter())); + + if (distA != 0.0) { + for (final CentroidCluster<T> c : clusters) { + final double distB = FastMath.abs(distance(point, c.getCenter())); + if (distB == 0.0) { + sum = Double.POSITIVE_INFINITY; + break; + } + sum += FastMath.pow(distA / distB, 2.0 / (fuzziness - 1.0)); + } + } + + double membership; + if (sum == 0.0) { + membership = 1.0; + } else if (sum == Double.POSITIVE_INFINITY) { + membership = 0.0; + } else { + membership = 1.0 / sum; + } + membershipMatrix[i][j] = membership; + + if (membershipMatrix[i][j] > maxMembership) { + maxMembership = membershipMatrix[i][j]; + newCluster = j; + } + } + clusters.get(newCluster).addPoint(point); + } + } + + /** + * Initialize the membership matrix with random values. + */ + private void initializeMembershipMatrix() { + for (int i = 0; i < points.size(); i++) { + for (int j = 0; j < k; j++) { + membershipMatrix[i][j] = random.nextDouble(); + } + membershipMatrix[i] = MathArrays.normalizeArray(membershipMatrix[i], 1.0); + } + } + + /** + * Calculate the maximum element-by-element change of the membership matrix + * for the current iteration. + * + * @param matrix the membership matrix of the previous iteration + * @return the maximum membership matrix change + */ + private double calculateMaxMembershipChange(final double[][] matrix) { + double maxMembership = 0.0; + for (int i = 0; i < points.size(); i++) { + for (int j = 0; j < clusters.size(); j++) { + double v = FastMath.abs(membershipMatrix[i][j] - matrix[i][j]); + maxMembership = FastMath.max(v, maxMembership); + } + } + return maxMembership; + } + + /** + * Copy the membership matrix into the provided matrix. + * + * @param matrix the place to store the membership matrix + */ + private void saveMembershipMatrix(final double[][] matrix) { + for (int i = 0; i < points.size(); i++) { + System.arraycopy(membershipMatrix[i], 0, matrix[i], 0, clusters.size()); + } + } + +} diff --git a/src/main/java/org/apache/commons/math3/ml/clustering/KMeansPlusPlusClusterer.java b/src/main/java/org/apache/commons/math3/ml/clustering/KMeansPlusPlusClusterer.java new file mode 100644 index 0000000..2e57fac --- /dev/null +++ b/src/main/java/org/apache/commons/math3/ml/clustering/KMeansPlusPlusClusterer.java @@ -0,0 +1,565 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.commons.math3.ml.clustering; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.List; + +import org.apache.commons.math3.exception.ConvergenceException; +import org.apache.commons.math3.exception.MathIllegalArgumentException; +import org.apache.commons.math3.exception.NumberIsTooSmallException; +import org.apache.commons.math3.exception.util.LocalizedFormats; +import org.apache.commons.math3.ml.distance.DistanceMeasure; +import org.apache.commons.math3.ml.distance.EuclideanDistance; +import org.apache.commons.math3.random.JDKRandomGenerator; +import org.apache.commons.math3.random.RandomGenerator; +import org.apache.commons.math3.stat.descriptive.moment.Variance; +import org.apache.commons.math3.util.MathUtils; + +/** + * Clustering algorithm based on David Arthur and Sergei Vassilvitski k-means++ algorithm. + * @param <T> type of the points to cluster + * @see <a href="http://en.wikipedia.org/wiki/K-means%2B%2B">K-means++ (wikipedia)</a> + * @since 3.2 + */ +public class KMeansPlusPlusClusterer<T extends Clusterable> extends Clusterer<T> { + + /** Strategies to use for replacing an empty cluster. */ + public enum EmptyClusterStrategy { + + /** Split the cluster with largest distance variance. */ + LARGEST_VARIANCE, + + /** Split the cluster with largest number of points. */ + LARGEST_POINTS_NUMBER, + + /** Create a cluster around the point farthest from its centroid. */ + FARTHEST_POINT, + + /** Generate an error. */ + ERROR + + } + + /** The number of clusters. */ + private final int k; + + /** The maximum number of iterations. */ + private final int maxIterations; + + /** Random generator for choosing initial centers. */ + private final RandomGenerator random; + + /** Selected strategy for empty clusters. */ + private final EmptyClusterStrategy emptyStrategy; + + /** Build a clusterer. + * <p> + * The default strategy for handling empty clusters that may appear during + * algorithm iterations is to split the cluster with largest distance variance. + * <p> + * The euclidean distance will be used as default distance measure. + * + * @param k the number of clusters to split the data into + */ + public KMeansPlusPlusClusterer(final int k) { + this(k, -1); + } + + /** Build a clusterer. + * <p> + * The default strategy for handling empty clusters that may appear during + * algorithm iterations is to split the cluster with largest distance variance. + * <p> + * The euclidean distance will be used as default distance measure. + * + * @param k the number of clusters to split the data into + * @param maxIterations the maximum number of iterations to run the algorithm for. + * If negative, no maximum will be used. + */ + public KMeansPlusPlusClusterer(final int k, final int maxIterations) { + this(k, maxIterations, new EuclideanDistance()); + } + + /** Build a clusterer. + * <p> + * The default strategy for handling empty clusters that may appear during + * algorithm iterations is to split the cluster with largest distance variance. + * + * @param k the number of clusters to split the data into + * @param maxIterations the maximum number of iterations to run the algorithm for. + * If negative, no maximum will be used. + * @param measure the distance measure to use + */ + public KMeansPlusPlusClusterer(final int k, final int maxIterations, final DistanceMeasure measure) { + this(k, maxIterations, measure, new JDKRandomGenerator()); + } + + /** Build a clusterer. + * <p> + * The default strategy for handling empty clusters that may appear during + * algorithm iterations is to split the cluster with largest distance variance. + * + * @param k the number of clusters to split the data into + * @param maxIterations the maximum number of iterations to run the algorithm for. + * If negative, no maximum will be used. + * @param measure the distance measure to use + * @param random random generator to use for choosing initial centers + */ + public KMeansPlusPlusClusterer(final int k, final int maxIterations, + final DistanceMeasure measure, + final RandomGenerator random) { + this(k, maxIterations, measure, random, EmptyClusterStrategy.LARGEST_VARIANCE); + } + + /** Build a clusterer. + * + * @param k the number of clusters to split the data into + * @param maxIterations the maximum number of iterations to run the algorithm for. + * If negative, no maximum will be used. + * @param measure the distance measure to use + * @param random random generator to use for choosing initial centers + * @param emptyStrategy strategy to use for handling empty clusters that + * may appear during algorithm iterations + */ + public KMeansPlusPlusClusterer(final int k, final int maxIterations, + final DistanceMeasure measure, + final RandomGenerator random, + final EmptyClusterStrategy emptyStrategy) { + super(measure); + this.k = k; + this.maxIterations = maxIterations; + this.random = random; + this.emptyStrategy = emptyStrategy; + } + + /** + * Return the number of clusters this instance will use. + * @return the number of clusters + */ + public int getK() { + return k; + } + + /** + * Returns the maximum number of iterations this instance will use. + * @return the maximum number of iterations, or -1 if no maximum is set + */ + public int getMaxIterations() { + return maxIterations; + } + + /** + * Returns the random generator this instance will use. + * @return the random generator + */ + public RandomGenerator getRandomGenerator() { + return random; + } + + /** + * Returns the {@link EmptyClusterStrategy} used by this instance. + * @return the {@link EmptyClusterStrategy} + */ + public EmptyClusterStrategy getEmptyClusterStrategy() { + return emptyStrategy; + } + + /** + * Runs the K-means++ clustering algorithm. + * + * @param points the points to cluster + * @return a list of clusters containing the points + * @throws MathIllegalArgumentException if the data points are null or the number + * of clusters is larger than the number of data points + * @throws ConvergenceException if an empty cluster is encountered and the + * {@link #emptyStrategy} is set to {@code ERROR} + */ + @Override + public List<CentroidCluster<T>> cluster(final Collection<T> points) + throws MathIllegalArgumentException, ConvergenceException { + + // sanity checks + MathUtils.checkNotNull(points); + + // number of clusters has to be smaller or equal the number of data points + if (points.size() < k) { + throw new NumberIsTooSmallException(points.size(), k, false); + } + + // create the initial clusters + List<CentroidCluster<T>> clusters = chooseInitialCenters(points); + + // create an array containing the latest assignment of a point to a cluster + // no need to initialize the array, as it will be filled with the first assignment + int[] assignments = new int[points.size()]; + assignPointsToClusters(clusters, points, assignments); + + // iterate through updating the centers until we're done + final int max = (maxIterations < 0) ? Integer.MAX_VALUE : maxIterations; + for (int count = 0; count < max; count++) { + boolean emptyCluster = false; + List<CentroidCluster<T>> newClusters = new ArrayList<CentroidCluster<T>>(); + for (final CentroidCluster<T> cluster : clusters) { + final Clusterable newCenter; + if (cluster.getPoints().isEmpty()) { + switch (emptyStrategy) { + case LARGEST_VARIANCE : + newCenter = getPointFromLargestVarianceCluster(clusters); + break; + case LARGEST_POINTS_NUMBER : + newCenter = getPointFromLargestNumberCluster(clusters); + break; + case FARTHEST_POINT : + newCenter = getFarthestPoint(clusters); + break; + default : + throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS); + } + emptyCluster = true; + } else { + newCenter = centroidOf(cluster.getPoints(), cluster.getCenter().getPoint().length); + } + newClusters.add(new CentroidCluster<T>(newCenter)); + } + int changes = assignPointsToClusters(newClusters, points, assignments); + clusters = newClusters; + + // if there were no more changes in the point-to-cluster assignment + // and there are no empty clusters left, return the current clusters + if (changes == 0 && !emptyCluster) { + return clusters; + } + } + return clusters; + } + + /** + * Adds the given points to the closest {@link Cluster}. + * + * @param clusters the {@link Cluster}s to add the points to + * @param points the points to add to the given {@link Cluster}s + * @param assignments points assignments to clusters + * @return the number of points assigned to different clusters as the iteration before + */ + private int assignPointsToClusters(final List<CentroidCluster<T>> clusters, + final Collection<T> points, + final int[] assignments) { + int assignedDifferently = 0; + int pointIndex = 0; + for (final T p : points) { + int clusterIndex = getNearestCluster(clusters, p); + if (clusterIndex != assignments[pointIndex]) { + assignedDifferently++; + } + + CentroidCluster<T> cluster = clusters.get(clusterIndex); + cluster.addPoint(p); + assignments[pointIndex++] = clusterIndex; + } + + return assignedDifferently; + } + + /** + * Use K-means++ to choose the initial centers. + * + * @param points the points to choose the initial centers from + * @return the initial centers + */ + private List<CentroidCluster<T>> chooseInitialCenters(final Collection<T> points) { + + // Convert to list for indexed access. Make it unmodifiable, since removal of items + // would screw up the logic of this method. + final List<T> pointList = Collections.unmodifiableList(new ArrayList<T> (points)); + + // The number of points in the list. + final int numPoints = pointList.size(); + + // Set the corresponding element in this array to indicate when + // elements of pointList are no longer available. + final boolean[] taken = new boolean[numPoints]; + + // The resulting list of initial centers. + final List<CentroidCluster<T>> resultSet = new ArrayList<CentroidCluster<T>>(); + + // Choose one center uniformly at random from among the data points. + final int firstPointIndex = random.nextInt(numPoints); + + final T firstPoint = pointList.get(firstPointIndex); + + resultSet.add(new CentroidCluster<T>(firstPoint)); + + // Must mark it as taken + taken[firstPointIndex] = true; + + // To keep track of the minimum distance squared of elements of + // pointList to elements of resultSet. + final double[] minDistSquared = new double[numPoints]; + + // Initialize the elements. Since the only point in resultSet is firstPoint, + // this is very easy. + for (int i = 0; i < numPoints; i++) { + if (i != firstPointIndex) { // That point isn't considered + double d = distance(firstPoint, pointList.get(i)); + minDistSquared[i] = d*d; + } + } + + while (resultSet.size() < k) { + + // Sum up the squared distances for the points in pointList not + // already taken. + double distSqSum = 0.0; + + for (int i = 0; i < numPoints; i++) { + if (!taken[i]) { + distSqSum += minDistSquared[i]; + } + } + + // Add one new data point as a center. Each point x is chosen with + // probability proportional to D(x)2 + final double r = random.nextDouble() * distSqSum; + + // The index of the next point to be added to the resultSet. + int nextPointIndex = -1; + + // Sum through the squared min distances again, stopping when + // sum >= r. + double sum = 0.0; + for (int i = 0; i < numPoints; i++) { + if (!taken[i]) { + sum += minDistSquared[i]; + if (sum >= r) { + nextPointIndex = i; + break; + } + } + } + + // If it's not set to >= 0, the point wasn't found in the previous + // for loop, probably because distances are extremely small. Just pick + // the last available point. + if (nextPointIndex == -1) { + for (int i = numPoints - 1; i >= 0; i--) { + if (!taken[i]) { + nextPointIndex = i; + break; + } + } + } + + // We found one. + if (nextPointIndex >= 0) { + + final T p = pointList.get(nextPointIndex); + + resultSet.add(new CentroidCluster<T> (p)); + + // Mark it as taken. + taken[nextPointIndex] = true; + + if (resultSet.size() < k) { + // Now update elements of minDistSquared. We only have to compute + // the distance to the new center to do this. + for (int j = 0; j < numPoints; j++) { + // Only have to worry about the points still not taken. + if (!taken[j]) { + double d = distance(p, pointList.get(j)); + double d2 = d * d; + if (d2 < minDistSquared[j]) { + minDistSquared[j] = d2; + } + } + } + } + + } else { + // None found -- + // Break from the while loop to prevent + // an infinite loop. + break; + } + } + + return resultSet; + } + + /** + * Get a random point from the {@link Cluster} with the largest distance variance. + * + * @param clusters the {@link Cluster}s to search + * @return a random point from the selected cluster + * @throws ConvergenceException if clusters are all empty + */ + private T getPointFromLargestVarianceCluster(final Collection<CentroidCluster<T>> clusters) + throws ConvergenceException { + + double maxVariance = Double.NEGATIVE_INFINITY; + Cluster<T> selected = null; + for (final CentroidCluster<T> cluster : clusters) { + if (!cluster.getPoints().isEmpty()) { + + // compute the distance variance of the current cluster + final Clusterable center = cluster.getCenter(); + final Variance stat = new Variance(); + for (final T point : cluster.getPoints()) { + stat.increment(distance(point, center)); + } + final double variance = stat.getResult(); + + // select the cluster with the largest variance + if (variance > maxVariance) { + maxVariance = variance; + selected = cluster; + } + + } + } + + // did we find at least one non-empty cluster ? + if (selected == null) { + throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS); + } + + // extract a random point from the cluster + final List<T> selectedPoints = selected.getPoints(); + return selectedPoints.remove(random.nextInt(selectedPoints.size())); + + } + + /** + * Get a random point from the {@link Cluster} with the largest number of points + * + * @param clusters the {@link Cluster}s to search + * @return a random point from the selected cluster + * @throws ConvergenceException if clusters are all empty + */ + private T getPointFromLargestNumberCluster(final Collection<? extends Cluster<T>> clusters) + throws ConvergenceException { + + int maxNumber = 0; + Cluster<T> selected = null; + for (final Cluster<T> cluster : clusters) { + + // get the number of points of the current cluster + final int number = cluster.getPoints().size(); + + // select the cluster with the largest number of points + if (number > maxNumber) { + maxNumber = number; + selected = cluster; + } + + } + + // did we find at least one non-empty cluster ? + if (selected == null) { + throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS); + } + + // extract a random point from the cluster + final List<T> selectedPoints = selected.getPoints(); + return selectedPoints.remove(random.nextInt(selectedPoints.size())); + + } + + /** + * Get the point farthest to its cluster center + * + * @param clusters the {@link Cluster}s to search + * @return point farthest to its cluster center + * @throws ConvergenceException if clusters are all empty + */ + private T getFarthestPoint(final Collection<CentroidCluster<T>> clusters) throws ConvergenceException { + + double maxDistance = Double.NEGATIVE_INFINITY; + Cluster<T> selectedCluster = null; + int selectedPoint = -1; + for (final CentroidCluster<T> cluster : clusters) { + + // get the farthest point + final Clusterable center = cluster.getCenter(); + final List<T> points = cluster.getPoints(); + for (int i = 0; i < points.size(); ++i) { + final double distance = distance(points.get(i), center); + if (distance > maxDistance) { + maxDistance = distance; + selectedCluster = cluster; + selectedPoint = i; + } + } + + } + + // did we find at least one non-empty cluster ? + if (selectedCluster == null) { + throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS); + } + + return selectedCluster.getPoints().remove(selectedPoint); + + } + + /** + * Returns the nearest {@link Cluster} to the given point + * + * @param clusters the {@link Cluster}s to search + * @param point the point to find the nearest {@link Cluster} for + * @return the index of the nearest {@link Cluster} to the given point + */ + private int getNearestCluster(final Collection<CentroidCluster<T>> clusters, final T point) { + double minDistance = Double.MAX_VALUE; + int clusterIndex = 0; + int minCluster = 0; + for (final CentroidCluster<T> c : clusters) { + final double distance = distance(point, c.getCenter()); + if (distance < minDistance) { + minDistance = distance; + minCluster = clusterIndex; + } + clusterIndex++; + } + return minCluster; + } + + /** + * Computes the centroid for a set of points. + * + * @param points the set of points + * @param dimension the point dimension + * @return the computed centroid for the set of points + */ + private Clusterable centroidOf(final Collection<T> points, final int dimension) { + final double[] centroid = new double[dimension]; + for (final T p : points) { + final double[] point = p.getPoint(); + for (int i = 0; i < centroid.length; i++) { + centroid[i] += point[i]; + } + } + for (int i = 0; i < centroid.length; i++) { + centroid[i] /= points.size(); + } + return new DoublePoint(centroid); + } + +} diff --git a/src/main/java/org/apache/commons/math3/ml/clustering/MultiKMeansPlusPlusClusterer.java b/src/main/java/org/apache/commons/math3/ml/clustering/MultiKMeansPlusPlusClusterer.java new file mode 100644 index 0000000..796fc7a --- /dev/null +++ b/src/main/java/org/apache/commons/math3/ml/clustering/MultiKMeansPlusPlusClusterer.java @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.commons.math3.ml.clustering; + +import java.util.Collection; +import java.util.List; + +import org.apache.commons.math3.exception.ConvergenceException; +import org.apache.commons.math3.exception.MathIllegalArgumentException; +import org.apache.commons.math3.ml.clustering.evaluation.ClusterEvaluator; +import org.apache.commons.math3.ml.clustering.evaluation.SumOfClusterVariances; + +/** + * A wrapper around a k-means++ clustering algorithm which performs multiple trials + * and returns the best solution. + * @param <T> type of the points to cluster + * @since 3.2 + */ +public class MultiKMeansPlusPlusClusterer<T extends Clusterable> extends Clusterer<T> { + + /** The underlying k-means clusterer. */ + private final KMeansPlusPlusClusterer<T> clusterer; + + /** The number of trial runs. */ + private final int numTrials; + + /** The cluster evaluator to use. */ + private final ClusterEvaluator<T> evaluator; + + /** Build a clusterer. + * @param clusterer the k-means clusterer to use + * @param numTrials number of trial runs + */ + public MultiKMeansPlusPlusClusterer(final KMeansPlusPlusClusterer<T> clusterer, + final int numTrials) { + this(clusterer, numTrials, new SumOfClusterVariances<T>(clusterer.getDistanceMeasure())); + } + + /** Build a clusterer. + * @param clusterer the k-means clusterer to use + * @param numTrials number of trial runs + * @param evaluator the cluster evaluator to use + * @since 3.3 + */ + public MultiKMeansPlusPlusClusterer(final KMeansPlusPlusClusterer<T> clusterer, + final int numTrials, + final ClusterEvaluator<T> evaluator) { + super(clusterer.getDistanceMeasure()); + this.clusterer = clusterer; + this.numTrials = numTrials; + this.evaluator = evaluator; + } + + /** + * Returns the embedded k-means clusterer used by this instance. + * @return the embedded clusterer + */ + public KMeansPlusPlusClusterer<T> getClusterer() { + return clusterer; + } + + /** + * Returns the number of trials this instance will do. + * @return the number of trials + */ + public int getNumTrials() { + return numTrials; + } + + /** + * Returns the {@link ClusterEvaluator} used to determine the "best" clustering. + * @return the used {@link ClusterEvaluator} + * @since 3.3 + */ + public ClusterEvaluator<T> getClusterEvaluator() { + return evaluator; + } + + /** + * Runs the K-means++ clustering algorithm. + * + * @param points the points to cluster + * @return a list of clusters containing the points + * @throws MathIllegalArgumentException if the data points are null or the number + * of clusters is larger than the number of data points + * @throws ConvergenceException if an empty cluster is encountered and the + * underlying {@link KMeansPlusPlusClusterer} has its + * {@link KMeansPlusPlusClusterer.EmptyClusterStrategy} is set to {@code ERROR}. + */ + @Override + public List<CentroidCluster<T>> cluster(final Collection<T> points) + throws MathIllegalArgumentException, ConvergenceException { + + // at first, we have not found any clusters list yet + List<CentroidCluster<T>> best = null; + double bestVarianceSum = Double.POSITIVE_INFINITY; + + // do several clustering trials + for (int i = 0; i < numTrials; ++i) { + + // compute a clusters list + List<CentroidCluster<T>> clusters = clusterer.cluster(points); + + // compute the variance of the current list + final double varianceSum = evaluator.score(clusters); + + if (evaluator.isBetterScore(varianceSum, bestVarianceSum)) { + // this one is the best we have found so far, remember it + best = clusters; + bestVarianceSum = varianceSum; + } + + } + + // return the best clusters list found + return best; + + } + +} diff --git a/src/main/java/org/apache/commons/math3/ml/clustering/evaluation/ClusterEvaluator.java b/src/main/java/org/apache/commons/math3/ml/clustering/evaluation/ClusterEvaluator.java new file mode 100644 index 0000000..2bb8ba3 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/ml/clustering/evaluation/ClusterEvaluator.java @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.commons.math3.ml.clustering.evaluation; + +import java.util.List; + +import org.apache.commons.math3.ml.clustering.CentroidCluster; +import org.apache.commons.math3.ml.clustering.Cluster; +import org.apache.commons.math3.ml.clustering.Clusterable; +import org.apache.commons.math3.ml.clustering.DoublePoint; +import org.apache.commons.math3.ml.distance.DistanceMeasure; +import org.apache.commons.math3.ml.distance.EuclideanDistance; + +/** + * Base class for cluster evaluation methods. + * + * @param <T> type of the clustered points + * @since 3.3 + */ +public abstract class ClusterEvaluator<T extends Clusterable> { + + /** The distance measure to use when evaluating the cluster. */ + private final DistanceMeasure measure; + + /** + * Creates a new cluster evaluator with an {@link EuclideanDistance} + * as distance measure. + */ + public ClusterEvaluator() { + this(new EuclideanDistance()); + } + + /** + * Creates a new cluster evaluator with the given distance measure. + * @param measure the distance measure to use + */ + public ClusterEvaluator(final DistanceMeasure measure) { + this.measure = measure; + } + + /** + * Computes the evaluation score for the given list of clusters. + * @param clusters the clusters to evaluate + * @return the computed score + */ + public abstract double score(List<? extends Cluster<T>> clusters); + + /** + * Returns whether the first evaluation score is considered to be better + * than the second one by this evaluator. + * <p> + * Specific implementations shall override this method if the returned scores + * do not follow the same ordering, i.e. smaller score is better. + * + * @param score1 the first score + * @param score2 the second score + * @return {@code true} if the first score is considered to be better, {@code false} otherwise + */ + public boolean isBetterScore(double score1, double score2) { + return score1 < score2; + } + + /** + * Calculates the distance between two {@link Clusterable} instances + * with the configured {@link DistanceMeasure}. + * + * @param p1 the first clusterable + * @param p2 the second clusterable + * @return the distance between the two clusterables + */ + protected double distance(final Clusterable p1, final Clusterable p2) { + return measure.compute(p1.getPoint(), p2.getPoint()); + } + + /** + * Computes the centroid for a cluster. + * + * @param cluster the cluster + * @return the computed centroid for the cluster, + * or {@code null} if the cluster does not contain any points + */ + protected Clusterable centroidOf(final Cluster<T> cluster) { + final List<T> points = cluster.getPoints(); + if (points.isEmpty()) { + return null; + } + + // in case the cluster is of type CentroidCluster, no need to compute the centroid + if (cluster instanceof CentroidCluster) { + return ((CentroidCluster<T>) cluster).getCenter(); + } + + final int dimension = points.get(0).getPoint().length; + final double[] centroid = new double[dimension]; + for (final T p : points) { + final double[] point = p.getPoint(); + for (int i = 0; i < centroid.length; i++) { + centroid[i] += point[i]; + } + } + for (int i = 0; i < centroid.length; i++) { + centroid[i] /= points.size(); + } + return new DoublePoint(centroid); + } + +} diff --git a/src/main/java/org/apache/commons/math3/ml/clustering/evaluation/SumOfClusterVariances.java b/src/main/java/org/apache/commons/math3/ml/clustering/evaluation/SumOfClusterVariances.java new file mode 100644 index 0000000..b5b249c --- /dev/null +++ b/src/main/java/org/apache/commons/math3/ml/clustering/evaluation/SumOfClusterVariances.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.commons.math3.ml.clustering.evaluation; + +import java.util.List; + +import org.apache.commons.math3.ml.clustering.Cluster; +import org.apache.commons.math3.ml.clustering.Clusterable; +import org.apache.commons.math3.ml.distance.DistanceMeasure; +import org.apache.commons.math3.stat.descriptive.moment.Variance; + +/** + * Computes the sum of intra-cluster distance variances according to the formula: + * <pre> + * \( score = \sum\limits_{i=1}^n \sigma_i^2 \) + * </pre> + * where n is the number of clusters and \( \sigma_i^2 \) is the variance of + * intra-cluster distances of cluster \( c_i \). + * + * @param <T> the type of the clustered points + * @since 3.3 + */ +public class SumOfClusterVariances<T extends Clusterable> extends ClusterEvaluator<T> { + + /** + * + * @param measure the distance measure to use + */ + public SumOfClusterVariances(final DistanceMeasure measure) { + super(measure); + } + + /** {@inheritDoc} */ + @Override + public double score(final List<? extends Cluster<T>> clusters) { + double varianceSum = 0.0; + for (final Cluster<T> cluster : clusters) { + if (!cluster.getPoints().isEmpty()) { + + final Clusterable center = centroidOf(cluster); + + // compute the distance variance of the current cluster + final Variance stat = new Variance(); + for (final T point : cluster.getPoints()) { + stat.increment(distance(point, center)); + } + varianceSum += stat.getResult(); + + } + } + return varianceSum; + } + +} diff --git a/src/main/java/org/apache/commons/math3/ml/clustering/evaluation/package-info.java b/src/main/java/org/apache/commons/math3/ml/clustering/evaluation/package-info.java new file mode 100644 index 0000000..700f566 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/ml/clustering/evaluation/package-info.java @@ -0,0 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/** + * Cluster evaluation methods. + */ +package org.apache.commons.math3.ml.clustering.evaluation; diff --git a/src/main/java/org/apache/commons/math3/ml/clustering/package-info.java b/src/main/java/org/apache/commons/math3/ml/clustering/package-info.java new file mode 100644 index 0000000..02f1d20 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/ml/clustering/package-info.java @@ -0,0 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/** + * Clustering algorithms. + */ +package org.apache.commons.math3.ml.clustering; |