2022-04-13: K-Means Example

Clustering some data using the KMeans++ algorithm in Apache Commons Math.

Note: use the KMeansPlusPlusClusterer and Clusterable definitions from "org.apache.commons.math3.ml.clustering", not the deprecated ones in "org.apache.commons.math3.stat.clustering".

Example: Iris Dataset

The example used here is the Iris dataset from the UCI Machine Learning Repository: download iris.data.

This note adapts the technique in Reading CSV Files.

There is an additional requirement for the data used within the KMeans++ implementation: each record must implement Clusterable. This interface requires a single method, getPoint, which returns the instance's data as an array of doubles.

Java records do not permit instance fields beyond their components, so the IrisInstance record gains an extra component holding the point, supplied on creation. The fromCSVRecord method is extended to construct this point representation.

import java.util.Optional;
import org.apache.commons.csv.CSVRecord;
import org.apache.commons.math3.ml.clustering.Clusterable;

  record IrisInstance(
      double sepalLength,
      double sepalWidth,
      double petalLength,
      double petalWidth,
      double[] point,                                                             // <1>
      String label
      ) implements Clusterable {                                                  // <2>

    @Override
    public double[] getPoint () {                                                 // <3>
      return point;
    }

    public static Optional<IrisInstance> fromCSVRecord (CSVRecord record) {
      if (record.size() == 5) { // four measurements plus the class label
        try {
          double sepalLength = Double.parseDouble(record.get(0));
          double sepalWidth = Double.parseDouble(record.get(1));
          double petalLength = Double.parseDouble(record.get(2));
          double petalWidth = Double.parseDouble(record.get(3));

          return Optional.of(new IrisInstance(sepalLength,
                sepalWidth, petalLength, petalWidth,
                new double[]{sepalLength, sepalWidth, petalLength, petalWidth},   // <4>
                record.get(4)));

        } catch (NumberFormatException e) {
          // skip records whose measurements are not numbers
        }
      }

      return Optional.empty();
    }
  }
  1. The additional attribute, for the point representation.
  2. Indicate that the record implements the Clusterable interface.
  3. getPoint is required by the Clusterable interface, and simply returns the stored point.
  4. When constructing the new record instance, create the point representation as well.
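
Storing the point as an extra component is one option. An alternative, not the one used above, is to keep the record to its measurements and label and build the array inside getPoint. The sketch below assumes a hypothetical local interface PointSource, declaring the same single double[] getPoint() method as Clusterable, so that it compiles with the JDK alone; with Commons Math on the classpath you would implement Clusterable instead.

```java
import java.util.Arrays;

// Hypothetical stand-in for org.apache.commons.math3.ml.clustering.Clusterable,
// which declares the same single method; replace with Clusterable when using the library.
interface PointSource {
  double[] getPoint();
}

// Variant of IrisInstance without a stored point component.
record IrisPoint(
    double sepalLength,
    double sepalWidth,
    double petalLength,
    double petalWidth,
    String label
    ) implements PointSource {

  // Builds a fresh array on every call, so callers cannot change the record's state.
  @Override
  public double[] getPoint () {
    return new double[]{sepalLength, sepalWidth, petalLength, petalWidth};
  }
}

class IrisPointDemo {
  public static void main (String[] args) {
    var a = new IrisPoint(5.1, 3.5, 1.4, 0.2, "Iris-setosa");
    var b = new IrisPoint(5.1, 3.5, 1.4, 0.2, "Iris-setosa");
    System.out.println(Arrays.toString(a.getPoint())); // [5.1, 3.5, 1.4, 0.2]
    System.out.println(a.equals(b));                   // true: every component compares by value
  }
}
```

The trade-off runs both ways. This version allocates a new array on each call, and the clusterer is likely to call getPoint for every distance computation, which the stored-array version avoids. On the other hand, a record's generated equals compares a double[] component by reference, so two IrisInstance values with identical measurements are not equal(), whereas two IrisPoint values are.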

A List<IrisInstance> named data is then loaded from the CSV file as described in Reading CSV Files.
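
If Commons CSV is not to hand, the same loading step can be sketched with the JDK alone, assuming iris.data is plain comma-separated text without quoted fields. The names IrisRow, parseLine and load are hypothetical; IrisRow mirrors IrisInstance minus the Clusterable interface so the sketch compiles on its own. As with fromCSVRecord, lines that are blank, short, or non-numeric are skipped rather than reported.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.List;
import java.util.Optional;

// Mirrors IrisInstance from the note, minus the Clusterable interface.
record IrisRow(double sepalLength, double sepalWidth, double petalLength,
               double petalWidth, double[] point, String label) {}

class IrisLoader {
  // Hypothetical helper: parses one "a,b,c,d,label" line, returning empty for anything malformed.
  static Optional<IrisRow> parseLine (String line) {
    String[] parts = line.split(",");
    if (parts.length != 5) {
      return Optional.empty(); // blank lines and short rows are skipped
    }
    try {
      double sl = Double.parseDouble(parts[0].trim());
      double sw = Double.parseDouble(parts[1].trim());
      double pl = Double.parseDouble(parts[2].trim());
      double pw = Double.parseDouble(parts[3].trim());
      return Optional.of(new IrisRow(sl, sw, pl, pw,
            new double[]{sl, sw, pl, pw}, parts[4].trim()));
    } catch (NumberFormatException e) {
      return Optional.empty(); // non-numeric measurements are skipped
    }
  }

  // Hypothetical helper: reads every line of the file and keeps the well-formed records.
  static List<IrisRow> load (Path file) throws IOException {
    try (var lines = Files.lines(file)) {
      return lines.map(IrisLoader::parseLine)
                  .flatMap(Optional::stream)
                  .toList();
    }
  }
}
```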

KMeans++ in Apache Commons Math

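The Apache Commons Math library provides the class KMeansPlusPlusClusterer. Properties of the model are set when constructing an instance of this class, including the number of clusters k, the maximum number of iterations (a negative value, the default, means no limit), the distance measure (default EuclideanDistance), the random generator used when choosing the initial centres, and the strategy for handling empty clusters (default EmptyClusterStrategy.LARGEST_VARIANCE).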

As the Iris dataset has three class labels, this example sets the number of clusters to 3, leaving the other properties at their default values.

Calling the cluster method with a collection of Clusterable instances returns a list of CentroidCluster instances. Each centroid cluster holds a "centre" and the list of points assigned to it.

    var model = new KMeansPlusPlusClusterer<IrisInstance>(3);                             // <1>
    var clusters = model.cluster (data);                                                  // <2>

    for (var cluster : clusters) {
      System.out.println ("Centre: " + Arrays.toString (cluster.getCenter().getPoint())); // <3>
      System.out.println ("Cluster has: " + cluster.getPoints().size() + " points");      // <4>
    }
  1. Create a model, providing a target number of clusters.
  2. Use cluster to group the given data into a list of cluster definitions.
  3. Extract and display the centre of each cluster as a point.
  4. The points assigned to each cluster can be retrieved using getPoints.
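
To make the model's behaviour more concrete, here is a minimal from-scratch sketch of the k-means++ technique: pick the first centre uniformly at random, pick each further centre with probability proportional to its squared distance from the nearest centre already chosen, then alternate between assigning every point to its nearest centre and moving each centre to the mean of its points, until the assignments stop changing. It is an illustration only, not the Commons Math implementation (which, for instance, applies a configurable strategy to empty clusters, where this sketch simply keeps the old centre); the class and method names are hypothetical.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;

class KMeansPlusPlusSketch {
  static final int MAX_ITERATIONS = 100; // hypothetical safety cap on Lloyd iterations

  // Squared Euclidean distance between two points of equal dimension.
  static double dist2 (double[] a, double[] b) {
    double s = 0;
    for (int i = 0; i < a.length; i++) {
      double d = a[i] - b[i];
      s += d * d;
    }
    return s;
  }

  // Index of the centre nearest to p (lowest index wins ties).
  static int nearest (double[] p, List<double[]> centres) {
    int best = 0;
    for (int c = 1; c < centres.size(); c++) {
      if (dist2(p, centres.get(c)) < dist2(p, centres.get(best))) best = c;
    }
    return best;
  }

  // k-means++ seeding: each new centre is drawn with probability proportional to D(x)^2.
  static List<double[]> seed (List<double[]> points, int k, Random rng) {
    List<double[]> centres = new ArrayList<>();
    centres.add(points.get(rng.nextInt(points.size())).clone());
    while (centres.size() < k) {
      double[] d2 = new double[points.size()];
      double total = 0;
      for (int i = 0; i < points.size(); i++) {
        d2[i] = dist2(points.get(i), centres.get(nearest(points.get(i), centres)));
        total += d2[i];
      }
      double r = rng.nextDouble() * total;
      int chosen = points.size() - 1; // fallback for rounding at the end of the scan
      for (int i = 0; i < points.size(); i++) {
        r -= d2[i];
        if (r < 0) { chosen = i; break; }
      }
      centres.add(points.get(chosen).clone());
    }
    return centres;
  }

  // Returns the cluster index of each point after Lloyd iterations from k-means++ seeds.
  static int[] cluster (List<double[]> points, int k, Random rng) {
    List<double[]> centres = seed(points, k, rng);
    int[] assignment = new int[points.size()];
    Arrays.fill(assignment, -1);
    int dim = points.get(0).length;
    boolean changed = true;
    for (int iter = 0; iter < MAX_ITERATIONS && changed; iter++) {
      changed = false;
      // Assignment step: move each point to its nearest centre.
      for (int i = 0; i < points.size(); i++) {
        int c = nearest(points.get(i), centres);
        if (c != assignment[i]) { assignment[i] = c; changed = true; }
      }
      // Update step: move each centre to the mean of its points; an empty cluster keeps its old centre.
      for (int c = 0; c < k; c++) {
        double[] sum = new double[dim];
        int count = 0;
        for (int i = 0; i < points.size(); i++) {
          if (assignment[i] != c) continue;
          for (int j = 0; j < dim; j++) sum[j] += points.get(i)[j];
          count++;
        }
        if (count > 0) {
          for (int j = 0; j < dim; j++) sum[j] /= count;
          centres.set(c, sum);
        }
      }
    }
    return assignment;
  }
}
```

Passing a Random with a fixed seed makes a run of this sketch repeatable, since the only randomness is in choosing the initial centres.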

Example output (the clustering is stochastic, so different runs can produce different values):

> java .\KMeansExample.java
Centre: [5.005999999999999, 3.4180000000000006, 1.464, 0.2439999999999999]
Cluster has: 50 points
Centre: [5.88360655737705, 2.740983606557377, 4.388524590163935, 1.4344262295081966]
Cluster has: 61 points
Centre: [6.853846153846153, 3.0769230769230766, 5.715384615384615, 2.053846153846153]
Cluster has: 39 points
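
Since each run may settle on different centres, one common way to compare runs is the within-cluster sum of squared distances: lower values mean tighter clusters. A minimal JDK-only sketch, assuming each cluster is given as its centre plus the list of its points, which is the information CentroidCluster exposes through getCenter().getPoint() and getPoints(); the Cluster record and the method name are hypothetical.

```java
import java.util.List;

class ClusterScore {
  // A cluster as plain arrays: its centre and the points assigned to it (hypothetical shape).
  record Cluster(double[] centre, List<double[]> points) {}

  // Sum over all clusters of the squared Euclidean distance from each point to its centre.
  static double withinClusterSumOfSquares (List<Cluster> clusters) {
    double total = 0;
    for (Cluster c : clusters) {
      for (double[] p : c.points()) {
        for (int j = 0; j < p.length; j++) {
          double d = p[j] - c.centre()[j];
          total += d * d;
        }
      }
    }
    return total;
  }
}
```

Commons Math also provides MultiKMeansPlusPlusClusterer, which wraps a KMeansPlusPlusClusterer, repeats the clustering for a given number of trials and keeps the best result; its Javadoc describes how "best" is scored.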

Page from Peter's Scrapbook, output from a VimWiki on 2024-01-29.