// KMeansClusterer.java

/*
 * Copyright (c) 2017, WSO2 Inc. (http://www.wso2.org) All Rights Reserved.
 *
 * WSO2 Inc. licenses this file to you under the Apache License,
 * Version 2.0 (the "License"); you may not use this file except
 * in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.wso2.extension.siddhi.execution.streamingml.clustering.kmeans.util;

import org.apache.log4j.Logger;
import org.wso2.extension.siddhi.execution.streamingml.util.MathUtil;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;

/**
 * Utility class containing the mathematical logic needed to perform k-means clustering:
 * the initial clustering of a batch, the incremental (decayed) update of an already trained
 * model, and nearest-centroid look-ups. All methods are static and operate on the
 * {@link KMeansModel} handed in by the caller.
 */
public class KMeansClusterer {
    private static final Logger logger = Logger.getLogger(KMeansClusterer.class.getName());

    /**
     * Retraining intervals smaller than this are processed on the calling thread; larger ones are
     * submitted to the executor service.
     */
    private static final int MIN_BATCH_SIZE_TO_TRIGGER_SEPARATE_THREAD = 20; //TODO: test and tune to optimum value

    /**
     * Scale factor used to round centroid coordinates to four decimal places.
     */
    private static final double ROUNDING_SCALE = 10000.0;

    /**
     * Entry point for training. If the model has not been trained yet, an initial clustering is
     * run on the batch, the batch is cleared and the model is marked as trained. Otherwise the
     * batch is used to update the existing model via periodic training.
     *
     * @param dataPointsArray         the batch of data points to train on
     * @param numberOfEventsToRetrain retraining interval; decides whether a separate thread is used
     * @param decayRate               weight given to the new batch when updating existing centroids
     * @param executorService         executor that receives the training task for large intervals
     * @param model                   the model to build or update
     * @param numberOfClusters        the number of clusters (k)
     * @param maximumIterations       upper bound on the number of refinement iterations
     * @param dimensionality          number of coordinates of each data point
     */
    public static void train(LinkedList<DataPoint> dataPointsArray, int numberOfEventsToRetrain, double decayRate,
            ExecutorService executorService, KMeansModel model, int numberOfClusters,
            int maximumIterations, int dimensionality) {
        if (!model.isTrained()) {
            cluster(dataPointsArray, model, numberOfClusters, maximumIterations, dimensionality);
            dataPointsArray.clear();
            model.setTrained();
        } else {
            periodicTraining(numberOfEventsToRetrain, decayRate, executorService, dataPointsArray, model,
                    numberOfClusters, maximumIterations, dimensionality);
        }
    }

    /**
     * Updates an already trained model with a new batch. Small retraining intervals are handled
     * synchronously (and the batch is cleared afterwards); otherwise a {@code Trainer} is submitted
     * to the executor service. On that second path the batch is not cleared by this method.
     */
    private static void periodicTraining(int numberOfEventsToRetrain, double decayRate,
            ExecutorService executorService, LinkedList<DataPoint> dataPointsArray, KMeansModel model,
            int numberOfClusters, int maximumIterations, int dimensionality) {
        if (numberOfEventsToRetrain < MIN_BATCH_SIZE_TO_TRIGGER_SEPARATE_THREAD) {
            if (logger.isDebugEnabled()) {
                logger.debug("Traditional training");
            }
            updateCluster(dataPointsArray, decayRate, model, numberOfClusters, maximumIterations, dimensionality);
            dataPointsArray.clear();
        } else {
            if (logger.isDebugEnabled()) {
                logger.debug("Seperate thread training");
            }
            Trainer trainer = new Trainer(dataPointsArray, decayRate, model, numberOfClusters, maximumIterations,
                    dimensionality);
            // The returned handle is intentionally never awaited by this method.
            Future<?> trainingTask = executorService.submit(trainer);
        }
    }


    /**
     * Performs the initial clustering. Centroids are first seeded from the batch; if the batch is
     * non-empty and the model holds exactly {@code numberOfClusters} centroids, assignment and
     * centroid recalculation are repeated until the centroids stop changing or
     * {@code maximumIterations} is reached.
     */
    private static void cluster(List<DataPoint> dataPointsArray, KMeansModel model, int numberOfClusters,
            int maximumIterations, int dimensionality) {
        if (logger.isDebugEnabled()) {
            logger.debug("initial Clustering");
        }
        buildModel(dataPointsArray, model, numberOfClusters);

        if (dataPointsArray.isEmpty() || model.size() != numberOfClusters) {
            return;
        }

        int iter = 0;
        while (iter < maximumIterations) {
            if (logger.isDebugEnabled()) {
                logger.debug("Current model : \n" + model.getModelInfo() + "\nclustering iteration : " + iter);
            }
            assignToCluster(dataPointsArray, model);
            if (logger.isDebugEnabled()) {
                logger.debug("Current model : \n" + model.getModelInfo());
            }
            List<Cluster> newClusterList = calculateNewClusters(model, dimensionality);

            boolean centroidShifted = !model.getClusterList().equals(newClusterList);
            if (logger.isDebugEnabled()) {
                logger.debug("previous model : " + printClusterList(model.getClusterList()) + "\nnew model : " +
                        printClusterList(newClusterList) + "\ncentroid shifted?" + centroidShifted);
            }
            if (!centroidShifted) {
                break;
            }
            model.setClusterList(newClusterList);
            iter++;
        }
    }

    /**
     * Renders the centroid coordinates of every cluster in the list as one concatenated string.
     * Used for debug logging only.
     */
    private static String printClusterList(List<Cluster> clusterList) {
        StringBuilder s = new StringBuilder();
        for (Cluster c : clusterList) {
            s.append(Arrays.toString(c.getCentroid().getCoordinates()));
        }
        return s.toString();
    }

    /**
     * Seeds the model with centroids taken from the batch: data points are added, in order, as
     * long as the model does not already contain them and fewer than {@code numberOfClusters}
     * entries are present.
     */
    private static void buildModel(List<DataPoint> dataPointsArray, KMeansModel model, int numberOfClusters) {
        int distinctCount = model.size();
        for (DataPoint currentDataPoint : dataPointsArray) {
            if (distinctCount >= numberOfClusters) {
                break;
            }
            // A fresh DataPoint is created so the model does not hold the batch's own object.
            DataPoint coordinatesOfCurrentDataPoint = new DataPoint();
            coordinatesOfCurrentDataPoint.setCoordinates(currentDataPoint.getCoordinates());
            if (!model.contains(coordinatesOfCurrentDataPoint)) {
                model.add(coordinatesOfCurrentDataPoint);
                distinctCount++;
            }
        }
    }

    /**
     * After the first clustering this method can be used to incrementally update the centroid list.
     * It takes the new batch of data points and the decay rate, runs k-means on the batch starting
     * from the current centroids, and then blends the result with the centroids the model had at
     * the start of the call:
     * <pre>newAvg = (1 - decayRate) * oldAvg + decayRate * batchAvg</pre>
     * rounded to four decimal places. A cluster that received no data points in the last assignment
     * pass keeps its old centroid. Nothing happens if the batch is empty or if the model cannot be
     * brought to exactly {@code numberOfClusters} centroids.
     *
     * @param dataPointsArray   the new batch of data points
     * @param decayRate         weight given to the batch centroids in the blend
     * @param model             the model to update
     * @param numberOfClusters  the number of clusters (k)
     * @param maximumIterations upper bound on the number of refinement iterations
     * @param dimensionality    number of coordinates of each data point
     */
    static void updateCluster(List<DataPoint> dataPointsArray, double decayRate, KMeansModel model,
                              int numberOfClusters, int maximumIterations, int dimensionality) {
        if (logger.isDebugEnabled()) {
            logger.debug("Updating cluster");
            logger.debug("model at the start of this update : ");
            logger.debug(model.getModelInfo());
        }
        if (dataPointsArray.isEmpty()) {
            return;
        }
        //when number of elements in centroid list is less than numberOfClusters
        if (model.size() < numberOfClusters) {
            buildModel(dataPointsArray, model, numberOfClusters);
        }
        if (model.size() != numberOfClusters) {
            return;
        }

        // oldClusterList remembers the centroids at the start of the update;
        // intermediateClusterList tracks the centroids produced by the iterations below.
        List<Cluster> oldClusterList = new ArrayList<>(numberOfClusters);
        List<Cluster> intermediateClusterList = new LinkedList<>();
        for (int i = 0; i < numberOfClusters; i++) {
            DataPoint oldCentroid = new DataPoint();
            DataPoint intermediateCentroid = new DataPoint();
            oldCentroid.setCoordinates(model.getCoordinatesOfCentroidOfCluster(i));
            intermediateCentroid.setCoordinates(model.getCoordinatesOfCentroidOfCluster(i));
            oldClusterList.add(new Cluster(oldCentroid));
            intermediateClusterList.add(new Cluster(intermediateCentroid));
        }

        // Membership is recorded right after each assignment pass. The clusters installed through
        // setClusterList below are the ones built by calculateNewClusters, which never receive
        // members in this class, so querying the model after the loop would be unreliable when the
        // loop ends by reaching maximumIterations.
        boolean[] clusterHasMembers = new boolean[numberOfClusters];

        int iter = 0;
        while (iter < maximumIterations) {
            assignToCluster(dataPointsArray, model);
            recordMembership(model, clusterHasMembers);
            List<Cluster> newClusterList = calculateNewClusters(model, dimensionality);
            boolean centroidShifted = !intermediateClusterList.equals(newClusterList);
            if (logger.isDebugEnabled()) {
                StringBuilder s = new StringBuilder();
                for (DataPoint c : dataPointsArray) {
                    s.append(Arrays.toString(c.getCoordinates()));
                }
                logger.debug("current iteration : " + iter + "\ndata points array\n" + s.toString() +
                        "\nCluster list : \n" + printClusterList(intermediateClusterList) +
                        "\nnew cluster list \n" + printClusterList(newClusterList) + "\nCentroid shifted? = "
                        + centroidShifted + "\n");
            }
            if (!centroidShifted) {
                break;
            }
            model.setClusterList(newClusterList);
            for (int i = 0; i < numberOfClusters; i++) {
                Cluster b = newClusterList.get(i);
                intermediateClusterList.get(i).getCentroid().setCoordinates(b.getCentroid().getCoordinates());
            }
            iter++;
        }
        if (logger.isDebugEnabled()) {
            logger.debug("old cluster list :\n" + printClusterList(oldClusterList));
        }

        for (int i = 0; i < numberOfClusters; i++) {
            DataPoint centroid = intermediateClusterList.get(i).getCentroid();
            double[] oldCoordinates = oldClusterList.get(i).getCentroid().getCoordinates();
            if (clusterHasMembers[i]) {
                double[] newCoordinates = centroid.getCoordinates();
                double[] weightedCoordinates = new double[dimensionality];
                for (int j = 0; j < dimensionality; j++) {
                    weightedCoordinates[j] = roundToFourDecimals(
                            (1 - decayRate) * oldCoordinates[j] + decayRate * newCoordinates[j]);
                }
                centroid.setCoordinates(weightedCoordinates);
            } else {
                // No point of this batch landed in the cluster: leave its centroid where it was.
                centroid.setCoordinates(oldCoordinates);
            }
        }
        model.setClusterList(intermediateClusterList);
        if (logger.isDebugEnabled()) {
            logger.debug("weighted centroid list\n" + printClusterList(model.getClusterList()));
        }
    }

    /**
     * Stores, for each cluster index, whether the model's cluster currently holds at least one
     * data point.
     */
    private static void recordMembership(KMeansModel model, boolean[] clusterHasMembers) {
        List<Cluster> clusters = model.getClusterList();
        for (int i = 0; i < clusterHasMembers.length; i++) {
            clusterHasMembers[i] = clusters.get(i).getDataPointsInCluster().size() != 0;
        }
    }

    /**
     * Rounds a value to four decimal places using {@link Math#round(double)}.
     */
    private static double roundToFourDecimals(double value) {
        return Math.round(value * ROUNDING_SCALE) / ROUNDING_SCALE;
    }

    /**
     * Clears the current cluster members of the model and then adds every data point of the batch
     * to its nearest cluster.
     */
    private static void assignToCluster(List<DataPoint> dataPointsArray, KMeansModel model) {
        // Checked once so the per-point message is only built when it will actually be logged.
        boolean debugEnabled = logger.isDebugEnabled();
        if (debugEnabled) {
            logger.debug("Running function assignToCluster");
        }
        model.clearClusterMembers();
        for (DataPoint currentDataPoint : dataPointsArray) {
            Cluster associatedCluster = findAssociatedCluster(currentDataPoint, model);
            if (debugEnabled) {
                logger.debug("Associated cluster of " + Arrays.toString(currentDataPoint.getCoordinates()) + " is " +
                        Arrays.toString(associatedCluster.getCentroid().getCoordinates()));
            }
            associatedCluster.addToCluster(currentDataPoint);
        }
    }

    /**
     * Finds the cluster whose centroid is nearest to the given data point, using
     * {@code MathUtil.euclideanDistance}. On ties the cluster with the lowest index wins.
     *
     * @return the cluster nearest to the input DataPoint
     */
    private static Cluster findAssociatedCluster(DataPoint currentDatapoint, KMeansModel model) {
        double[] coordinates = currentDatapoint.getCoordinates();
        List<Cluster> clusters = model.getClusterList();
        double minDistance = MathUtil.euclideanDistance(model.getCoordinatesOfCentroidOfCluster(0), coordinates);
        Cluster associatedCluster = clusters.get(0);
        // Cluster 0 is the initial candidate, so the scan starts at index 1.
        for (int i = 1; i < model.size(); i++) {
            Cluster cluster = clusters.get(i);
            double dist = MathUtil.euclideanDistance(cluster.getCentroid().getCoordinates(), coordinates);
            if (dist < minDistance) {
                minDistance = dist;
                associatedCluster = cluster;
            }
        }
        return associatedCluster;
    }

    /**
     * Similar to findAssociatedCluster, but returns an Object[] whose first element is the distance
     * to the closest centroid and whose remaining elements are the coordinates of that centroid
     * (all boxed as {@code Double}).
     *
     * @param currentDatapoint the data point to look up
     * @param model            the model holding the clusters
     * @return an Object[] array as described above
     */
    public static Object[] getAssociatedCentroidInfo(DataPoint currentDatapoint, KMeansModel model) {
        Cluster associatedCluster = findAssociatedCluster(currentDatapoint, model);
        double[] centroidCoordinates = associatedCluster.getCentroid().getCoordinates();
        double minDistance = MathUtil.euclideanDistance(currentDatapoint.getCoordinates(), centroidCoordinates);

        Object[] associatedCentroidInfo = new Object[centroidCoordinates.length + 1];
        associatedCentroidInfo[0] = minDistance;
        for (int i = 0; i < centroidCoordinates.length; i++) {
            associatedCentroidInfo[i + 1] = centroidCoordinates[i];
        }
        return associatedCentroidInfo;
    }

    /**
     * After data points have been assigned to their closest centroids, this method calculates the
     * new centroids as the mean of the assigned points, rounded to four decimal places. A cluster
     * with no assigned points keeps a copy of its current centroid; averaging it would divide by
     * zero and, because {@code Math.round(NaN)} is 0, silently move the centroid to the origin.
     *
     * @return a list with one new {@code Cluster} (holding only the centroid) per cluster of the
     *         model, in the same order
     */
    private static List<Cluster> calculateNewClusters(KMeansModel model, int dimensionality) {
        List<Cluster> newClusterList = new LinkedList<>();

        for (Cluster currentCluster : model.getClusterList()) {
            int numberOfMembers = currentCluster.getDataPointsInCluster().size();
            double[] newCoordinates;
            if (numberOfMembers == 0) {
                newCoordinates = currentCluster.getCentroid().getCoordinates().clone();
            } else {
                newCoordinates = new double[dimensionality];
                for (DataPoint member : currentCluster.getDataPointsInCluster()) {
                    double[] memberCoordinates = member.getCoordinates();
                    for (int i = 0; i < dimensionality; i++) {
                        newCoordinates[i] += memberCoordinates[i];
                    }
                }
                for (int i = 0; i < dimensionality; i++) {
                    newCoordinates[i] = roundToFourDecimals(newCoordinates[i] / numberOfMembers);
                }
            }

            DataPoint newCentroid = new DataPoint();
            newCentroid.setCoordinates(newCoordinates);
            newClusterList.add(new Cluster(newCentroid));
        }
        return newClusterList;
    }
}