summaryrefslogtreecommitdiff
path: root/src/main/java/org/apache/commons/math/stat/clustering
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/java/org/apache/commons/math/stat/clustering')
-rw-r--r--src/main/java/org/apache/commons/math/stat/clustering/Cluster.java74
-rw-r--r--src/main/java/org/apache/commons/math/stat/clustering/Clusterable.java46
-rw-r--r--src/main/java/org/apache/commons/math/stat/clustering/EuclideanIntegerPoint.java120
-rw-r--r--src/main/java/org/apache/commons/math/stat/clustering/KMeansPlusPlusClusterer.java333
-rw-r--r--src/main/java/org/apache/commons/math/stat/clustering/package.html20
5 files changed, 593 insertions, 0 deletions
diff --git a/src/main/java/org/apache/commons/math/stat/clustering/Cluster.java b/src/main/java/org/apache/commons/math/stat/clustering/Cluster.java
new file mode 100644
index 0000000..f4913d3
--- /dev/null
+++ b/src/main/java/org/apache/commons/math/stat/clustering/Cluster.java
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.math.stat.clustering;
+
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * Cluster holding a set of {@link Clusterable} points.
+ * @param <T> the type of points that can be clustered
+ * @version $Revision: 771076 $ $Date: 2009-05-03 18:28:48 +0200 (dim. 03 mai 2009) $
+ * @since 2.0
+ */
+public class Cluster<T extends Clusterable<T>> implements Serializable {
+
+ /** Serializable version identifier. */
+ private static final long serialVersionUID = -3442297081515880464L;
+
+ /** The points contained in this cluster. */
+ private final List<T> points;
+
+ /** Center of the cluster. */
+ private final T center;
+
+ /**
+ * Build a cluster centered at a specified point.
+ * @param center the point which is to be the center of this cluster
+ */
+ public Cluster(final T center) {
+ this.center = center;
+ points = new ArrayList<T>();
+ }
+
+ /**
+ * Add a point to this cluster.
+ * @param point point to add
+ */
+ public void addPoint(final T point) {
+ points.add(point);
+ }
+
+ /**
+ * Get the points contained in the cluster.
+ * @return points contained in the cluster
+ */
+ public List<T> getPoints() {
+ return points;
+ }
+
+ /**
+ * Get the point chosen to be the center of this cluster.
+ * @return chosen cluster center
+ */
+ public T getCenter() {
+ return center;
+ }
+
+}
diff --git a/src/main/java/org/apache/commons/math/stat/clustering/Clusterable.java b/src/main/java/org/apache/commons/math/stat/clustering/Clusterable.java
new file mode 100644
index 0000000..65132e6
--- /dev/null
+++ b/src/main/java/org/apache/commons/math/stat/clustering/Clusterable.java
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.math.stat.clustering;
+
+import java.util.Collection;
+
+/**
+ * Interface for points that can be clustered together.
+ * @param <T> the type of point that can be clustered
+ * @version $Revision: 811685 $ $Date: 2009-09-05 19:36:48 +0200 (sam. 05 sept. 2009) $
+ * @since 2.0
+ */
+public interface Clusterable<T> {
+
+ /**
+ * Returns the distance from the given point.
+ *
+ * @param p the point to compute the distance from
+ * @return the distance from the given point
+ */
+ double distanceFrom(T p);
+
+ /**
+ * Returns the centroid of the given Collection of points.
+ *
+ * @param p the Collection of points to compute the centroid of
+ * @return the centroid of the given Collection of Points
+ */
+ T centroidOf(Collection<T> p);
+
+}
diff --git a/src/main/java/org/apache/commons/math/stat/clustering/EuclideanIntegerPoint.java b/src/main/java/org/apache/commons/math/stat/clustering/EuclideanIntegerPoint.java
new file mode 100644
index 0000000..7fec0ff
--- /dev/null
+++ b/src/main/java/org/apache/commons/math/stat/clustering/EuclideanIntegerPoint.java
@@ -0,0 +1,120 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.math.stat.clustering;
+
+import java.io.Serializable;
+import java.util.Collection;
+
+import org.apache.commons.math.util.MathUtils;
+
+/**
+ * A simple implementation of {@link Clusterable} for points with integer coordinates.
+ * @version $Revision: 1042376 $ $Date: 2010-12-05 16:54:55 +0100 (dim. 05 déc. 2010) $
+ * @since 2.0
+ */
+public class EuclideanIntegerPoint implements Clusterable<EuclideanIntegerPoint>, Serializable {
+
+ /** Serializable version identifier. */
+ private static final long serialVersionUID = 3946024775784901369L;
+
+ /** Point coordinates. */
+ private final int[] point;
+
+ /**
+ * Build an instance wrapping an integer array.
+ * <p>The wrapped array is referenced, it is <em>not</em> copied.</p>
+ * @param point the n-dimensional point in integer space
+ */
+ public EuclideanIntegerPoint(final int[] point) {
+ this.point = point;
+ }
+
+ /**
+ * Get the n-dimensional point in integer space.
+ * @return a reference (not a copy!) to the wrapped array
+ */
+ public int[] getPoint() {
+ return point;
+ }
+
+ /** {@inheritDoc} */
+ public double distanceFrom(final EuclideanIntegerPoint p) {
+ return MathUtils.distance(point, p.getPoint());
+ }
+
+ /** {@inheritDoc} */
+ public EuclideanIntegerPoint centroidOf(final Collection<EuclideanIntegerPoint> points) {
+ int[] centroid = new int[getPoint().length];
+ for (EuclideanIntegerPoint p : points) {
+ for (int i = 0; i < centroid.length; i++) {
+ centroid[i] += p.getPoint()[i];
+ }
+ }
+ for (int i = 0; i < centroid.length; i++) {
+ centroid[i] /= points.size();
+ }
+ return new EuclideanIntegerPoint(centroid);
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public boolean equals(final Object other) {
+ if (!(other instanceof EuclideanIntegerPoint)) {
+ return false;
+ }
+ final int[] otherPoint = ((EuclideanIntegerPoint) other).getPoint();
+ if (point.length != otherPoint.length) {
+ return false;
+ }
+ for (int i = 0; i < point.length; i++) {
+ if (point[i] != otherPoint[i]) {
+ return false;
+ }
+ }
+ return true;
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public int hashCode() {
+ int hashCode = 0;
+ for (Integer i : point) {
+ hashCode += i.hashCode() * 13 + 7;
+ }
+ return hashCode;
+ }
+
+ /**
+ * {@inheritDoc}
+ * @since 2.1
+ */
+ @Override
+ public String toString() {
+ final StringBuilder buff = new StringBuilder("(");
+ final int[] coordinates = getPoint();
+ for (int i = 0; i < coordinates.length; i++) {
+ buff.append(coordinates[i]);
+ if (i < coordinates.length - 1) {
+ buff.append(",");
+ }
+ }
+ buff.append(")");
+ return buff.toString();
+ }
+
+}
diff --git a/src/main/java/org/apache/commons/math/stat/clustering/KMeansPlusPlusClusterer.java b/src/main/java/org/apache/commons/math/stat/clustering/KMeansPlusPlusClusterer.java
new file mode 100644
index 0000000..eb61866
--- /dev/null
+++ b/src/main/java/org/apache/commons/math/stat/clustering/KMeansPlusPlusClusterer.java
@@ -0,0 +1,333 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.math.stat.clustering;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.List;
+import java.util.Random;
+
+import org.apache.commons.math.exception.ConvergenceException;
+import org.apache.commons.math.exception.util.LocalizedFormats;
+import org.apache.commons.math.stat.descriptive.moment.Variance;
+
+/**
+ * Clustering algorithm based on David Arthur and Sergei Vassilvitski k-means++ algorithm.
+ * @param <T> type of the points to cluster
+ * @see <a href="http://en.wikipedia.org/wiki/K-means%2B%2B">K-means++ (wikipedia)</a>
+ * @version $Revision: 1054333 $ $Date: 2011-01-02 01:34:58 +0100 (dim. 02 janv. 2011) $
+ * @since 2.0
+ */
+public class KMeansPlusPlusClusterer<T extends Clusterable<T>> {
+
+ /** Strategies to use for replacing an empty cluster. */
+ public static enum EmptyClusterStrategy {
+
+ /** Split the cluster with largest distance variance. */
+ LARGEST_VARIANCE,
+
+ /** Split the cluster with largest number of points. */
+ LARGEST_POINTS_NUMBER,
+
+ /** Create a cluster around the point farthest from its centroid. */
+ FARTHEST_POINT,
+
+ /** Generate an error. */
+ ERROR
+
+ }
+
+ /** Random generator for choosing initial centers. */
+ private final Random random;
+
+ /** Selected strategy for empty clusters. */
+ private final EmptyClusterStrategy emptyStrategy;
+
+ /** Build a clusterer.
+ * <p>
+ * The default strategy for handling empty clusters that may appear during
+ * algorithm iterations is to split the cluster with largest distance variance.
+ * </p>
+ * @param random random generator to use for choosing initial centers
+ */
+ public KMeansPlusPlusClusterer(final Random random) {
+ this(random, EmptyClusterStrategy.LARGEST_VARIANCE);
+ }
+
+ /** Build a clusterer.
+ * @param random random generator to use for choosing initial centers
+ * @param emptyStrategy strategy to use for handling empty clusters that
+ * may appear during algorithm iterations
+ * @since 2.2
+ */
+ public KMeansPlusPlusClusterer(final Random random, final EmptyClusterStrategy emptyStrategy) {
+ this.random = random;
+ this.emptyStrategy = emptyStrategy;
+ }
+
+ /**
+ * Runs the K-means++ clustering algorithm.
+ *
+ * @param points the points to cluster
+ * @param k the number of clusters to split the data into
+ * @param maxIterations the maximum number of iterations to run the algorithm
+ * for. If negative, no maximum will be used
+ * @return a list of clusters containing the points
+ */
+ public List<Cluster<T>> cluster(final Collection<T> points,
+ final int k, final int maxIterations) {
+ // create the initial clusters
+ List<Cluster<T>> clusters = chooseInitialCenters(points, k, random);
+ assignPointsToClusters(clusters, points);
+
+ // iterate through updating the centers until we're done
+ final int max = (maxIterations < 0) ? Integer.MAX_VALUE : maxIterations;
+ for (int count = 0; count < max; count++) {
+ boolean clusteringChanged = false;
+ List<Cluster<T>> newClusters = new ArrayList<Cluster<T>>();
+ for (final Cluster<T> cluster : clusters) {
+ final T newCenter;
+ if (cluster.getPoints().isEmpty()) {
+ switch (emptyStrategy) {
+ case LARGEST_VARIANCE :
+ newCenter = getPointFromLargestVarianceCluster(clusters);
+ break;
+ case LARGEST_POINTS_NUMBER :
+ newCenter = getPointFromLargestNumberCluster(clusters);
+ break;
+ case FARTHEST_POINT :
+ newCenter = getFarthestPoint(clusters);
+ break;
+ default :
+ throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
+ }
+ clusteringChanged = true;
+ } else {
+ newCenter = cluster.getCenter().centroidOf(cluster.getPoints());
+ if (!newCenter.equals(cluster.getCenter())) {
+ clusteringChanged = true;
+ }
+ }
+ newClusters.add(new Cluster<T>(newCenter));
+ }
+ if (!clusteringChanged) {
+ return clusters;
+ }
+ assignPointsToClusters(newClusters, points);
+ clusters = newClusters;
+ }
+ return clusters;
+ }
+
+ /**
+ * Adds the given points to the closest {@link Cluster}.
+ *
+ * @param <T> type of the points to cluster
+ * @param clusters the {@link Cluster}s to add the points to
+ * @param points the points to add to the given {@link Cluster}s
+ */
+ private static <T extends Clusterable<T>> void
+ assignPointsToClusters(final Collection<Cluster<T>> clusters, final Collection<T> points) {
+ for (final T p : points) {
+ Cluster<T> cluster = getNearestCluster(clusters, p);
+ cluster.addPoint(p);
+ }
+ }
+
+ /**
+ * Use K-means++ to choose the initial centers.
+ *
+ * @param <T> type of the points to cluster
+ * @param points the points to choose the initial centers from
+ * @param k the number of centers to choose
+ * @param random random generator to use
+ * @return the initial centers
+ */
+ private static <T extends Clusterable<T>> List<Cluster<T>>
+ chooseInitialCenters(final Collection<T> points, final int k, final Random random) {
+
+ final List<T> pointSet = new ArrayList<T>(points);
+ final List<Cluster<T>> resultSet = new ArrayList<Cluster<T>>();
+
+ // Choose one center uniformly at random from among the data points.
+ final T firstPoint = pointSet.remove(random.nextInt(pointSet.size()));
+ resultSet.add(new Cluster<T>(firstPoint));
+
+ final double[] dx2 = new double[pointSet.size()];
+ while (resultSet.size() < k) {
+ // For each data point x, compute D(x), the distance between x and
+ // the nearest center that has already been chosen.
+ int sum = 0;
+ for (int i = 0; i < pointSet.size(); i++) {
+ final T p = pointSet.get(i);
+ final Cluster<T> nearest = getNearestCluster(resultSet, p);
+ final double d = p.distanceFrom(nearest.getCenter());
+ sum += d * d;
+ dx2[i] = sum;
+ }
+
+ // Add one new data point as a center. Each point x is chosen with
+ // probability proportional to D(x)2
+ final double r = random.nextDouble() * sum;
+ for (int i = 0 ; i < dx2.length; i++) {
+ if (dx2[i] >= r) {
+ final T p = pointSet.remove(i);
+ resultSet.add(new Cluster<T>(p));
+ break;
+ }
+ }
+ }
+
+ return resultSet;
+
+ }
+
+ /**
+ * Get a random point from the {@link Cluster} with the largest distance variance.
+ *
+ * @param clusters the {@link Cluster}s to search
+ * @return a random point from the selected cluster
+ */
+ private T getPointFromLargestVarianceCluster(final Collection<Cluster<T>> clusters) {
+
+ double maxVariance = Double.NEGATIVE_INFINITY;
+ Cluster<T> selected = null;
+ for (final Cluster<T> cluster : clusters) {
+ if (!cluster.getPoints().isEmpty()) {
+
+ // compute the distance variance of the current cluster
+ final T center = cluster.getCenter();
+ final Variance stat = new Variance();
+ for (final T point : cluster.getPoints()) {
+ stat.increment(point.distanceFrom(center));
+ }
+ final double variance = stat.getResult();
+
+ // select the cluster with the largest variance
+ if (variance > maxVariance) {
+ maxVariance = variance;
+ selected = cluster;
+ }
+
+ }
+ }
+
+ // did we find at least one non-empty cluster ?
+ if (selected == null) {
+ throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
+ }
+
+ // extract a random point from the cluster
+ final List<T> selectedPoints = selected.getPoints();
+ return selectedPoints.remove(random.nextInt(selectedPoints.size()));
+
+ }
+
+ /**
+ * Get a random point from the {@link Cluster} with the largest number of points
+ *
+ * @param clusters the {@link Cluster}s to search
+ * @return a random point from the selected cluster
+ */
+ private T getPointFromLargestNumberCluster(final Collection<Cluster<T>> clusters) {
+
+ int maxNumber = 0;
+ Cluster<T> selected = null;
+ for (final Cluster<T> cluster : clusters) {
+
+ // get the number of points of the current cluster
+ final int number = cluster.getPoints().size();
+
+ // select the cluster with the largest number of points
+ if (number > maxNumber) {
+ maxNumber = number;
+ selected = cluster;
+ }
+
+ }
+
+ // did we find at least one non-empty cluster ?
+ if (selected == null) {
+ throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
+ }
+
+ // extract a random point from the cluster
+ final List<T> selectedPoints = selected.getPoints();
+ return selectedPoints.remove(random.nextInt(selectedPoints.size()));
+
+ }
+
+ /**
+ * Get the point farthest to its cluster center
+ *
+ * @param clusters the {@link Cluster}s to search
+ * @return point farthest to its cluster center
+ */
+ private T getFarthestPoint(final Collection<Cluster<T>> clusters) {
+
+ double maxDistance = Double.NEGATIVE_INFINITY;
+ Cluster<T> selectedCluster = null;
+ int selectedPoint = -1;
+ for (final Cluster<T> cluster : clusters) {
+
+ // get the farthest point
+ final T center = cluster.getCenter();
+ final List<T> points = cluster.getPoints();
+ for (int i = 0; i < points.size(); ++i) {
+ final double distance = points.get(i).distanceFrom(center);
+ if (distance > maxDistance) {
+ maxDistance = distance;
+ selectedCluster = cluster;
+ selectedPoint = i;
+ }
+ }
+
+ }
+
+ // did we find at least one non-empty cluster ?
+ if (selectedCluster == null) {
+ throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
+ }
+
+ return selectedCluster.getPoints().remove(selectedPoint);
+
+ }
+
+ /**
+ * Returns the nearest {@link Cluster} to the given point
+ *
+ * @param <T> type of the points to cluster
+ * @param clusters the {@link Cluster}s to search
+ * @param point the point to find the nearest {@link Cluster} for
+ * @return the nearest {@link Cluster} to the given point
+ */
+ private static <T extends Clusterable<T>> Cluster<T>
+ getNearestCluster(final Collection<Cluster<T>> clusters, final T point) {
+ double minDistance = Double.MAX_VALUE;
+ Cluster<T> minCluster = null;
+ for (final Cluster<T> c : clusters) {
+ final double distance = point.distanceFrom(c.getCenter());
+ if (distance < minDistance) {
+ minDistance = distance;
+ minCluster = c;
+ }
+ }
+ return minCluster;
+ }
+
+}
diff --git a/src/main/java/org/apache/commons/math/stat/clustering/package.html b/src/main/java/org/apache/commons/math/stat/clustering/package.html
new file mode 100644
index 0000000..21e9079
--- /dev/null
+++ b/src/main/java/org/apache/commons/math/stat/clustering/package.html
@@ -0,0 +1,20 @@
+<html>
+<!--
+ Licensed to the Apache Software Foundation (ASF) under one or more
+ contributor license agreements. See the NOTICE file distributed with
+ this work for additional information regarding copyright ownership.
+ The ASF licenses this file to You under the Apache License, Version 2.0
+ (the "License"); you may not use this file except in compliance with
+ the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ -->
+ <!-- $Revision: 770979 $ $Date: 2009-05-02 21:34:51 +0200 (sam. 02 mai 2009) $ -->
+ <body>Clustering algorithms</body>
+</html>