/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.clustering.cluster;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Random;
import java.util.concurrent.ExecutorService;
import org.apache.commons.lang3.ArrayUtils;
import org.deeplearning4j.clustering.algorithm.Distance;
import org.deeplearning4j.clustering.cluster.Cluster;
import org.deeplearning4j.clustering.cluster.ClusterSet;
import org.deeplearning4j.clustering.cluster.Point;
import org.deeplearning4j.clustering.cluster.PointClassification;
import org.deeplearning4j.clustering.info.ClusterInfo;
import org.deeplearning4j.clustering.info.ClusterSetInfo;
import org.deeplearning4j.clustering.optimisation.ClusteringOptimizationType;
import org.deeplearning4j.clustering.strategy.OptimisationStrategy;
import org.deeplearning4j.clustering.util.MathUtils;
import org.deeplearning4j.clustering.util.MultiThreadUtils;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.ReduceOp;
import org.nd4j.linalg.api.ops.impl.reduce3.CosineDistance;
import org.nd4j.linalg.api.ops.impl.reduce3.CosineSimilarity;
import org.nd4j.linalg.api.ops.impl.reduce3.Dot;
import org.nd4j.linalg.api.ops.impl.reduce3.EuclideanDistance;
import org.nd4j.linalg.api.ops.impl.reduce3.JaccardDistance;
import org.nd4j.linalg.api.ops.impl.reduce3.ManhattanDistance;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Static helpers used by the clustering algorithms: point classification, cluster-center
 * recomputation, per-cluster distance statistics, cluster splitting and construction of ND4J
 * distance ops.
 *
 * <p>Several methods accept an {@link ExecutorService} for API compatibility. Unless a method
 * says otherwise, the work is performed sequentially on the calling thread and the executor is
 * not used.
 */
public class ClusterUtils {
    private static final Logger log = LoggerFactory.getLogger(ClusterUtils.class);

    /** Upper bound on how many candidate point ids are considered when picking a point to split off. */
    private static final int MAX_SPLIT_CANDIDATES = 3;

    /**
     * Classifies every point against the cluster set and records the resulting distances.
     *
     * <p>For each point, {@link #classifyPoint(ClusterSet, Point)} is called. Points reported as
     * having a new location increment {@code getPointLocationChange()}. Each point's distance from
     * its cluster center is stored in that cluster's {@link ClusterInfo}. A failure for an
     * individual point is logged and that point is skipped.
     *
     * @param clusterSet      the clusters to classify against
     * @param points          the points to classify
     * @param executorService currently unused; classification runs on the calling thread
     * @return a freshly initialized {@link ClusterSetInfo} populated with the classification results
     */
    public static ClusterSetInfo classifyPoints(ClusterSet clusterSet, List<Point> points, ExecutorService executorService) {
        ClusterSetInfo clusterSetInfo = ClusterSetInfo.initialize(clusterSet, true);
        for (Point point : points) {
            try {
                PointClassification result = classifyPoint(clusterSet, point);
                if (result.isNewLocation()) {
                    clusterSetInfo.getPointLocationChange().incrementAndGet();
                }
                clusterSetInfo.getClusterInfo(result.getCluster().getId())
                        .getPointDistancesFromCenter()
                        .put(point.getId(), result.getDistanceFromCenter());
            } catch (Throwable t) {
                log.warn("Error classifying point", t);
            }
        }
        return clusterSetInfo;
    }

    /**
     * Classifies a single point by delegating to {@code clusterSet.classifyPoint(point, false)}.
     *
     * <p>NOTE(review): the meaning of the {@code false} flag is defined by {@link ClusterSet}; confirm there.
     *
     * @param clusterSet the clusters to classify against
     * @param point      the point to classify
     * @return the classification produced by the cluster set
     */
    public static PointClassification classifyPoint(ClusterSet clusterSet, Point point) {
        return clusterSet.classifyPoint(point, false);
    }

    /**
     * Recomputes every cluster's center and refreshes the distance statistics of its info.
     *
     * <p>A failure for an individual cluster is logged and that cluster is skipped.
     *
     * @param clusterSet      the clusters whose centers are recomputed
     * @param clusterSetInfo  holds the per-cluster info whose statistics are refreshed
     * @param executorService currently unused; work runs on the calling thread
     */
    public static void refreshClustersCenters(ClusterSet clusterSet, ClusterSetInfo clusterSetInfo, ExecutorService executorService) {
        List<Cluster> clusters = clusterSet.getClusters();
        int nClusters = clusterSet.getClusterCount();
        for (int i = 0; i < nClusters; i++) {
            Cluster cluster = clusters.get(i);
            try {
                ClusterInfo clusterInfo = clusterSetInfo.getClusterInfo(cluster.getId());
                refreshClusterCenter(cluster, clusterInfo);
                deriveClusterInfoDistanceStatistics(clusterInfo);
            } catch (Throwable t) {
                log.warn("Error refreshing cluster centers", t);
            }
        }
    }

    /**
     * Replaces the cluster's center with the mean of its points.
     *
     * <p>For inverse clusters the point arrays are subtracted rather than added, so the new
     * center is the negated mean. Clusters without points are left unchanged.
     *
     * @param cluster     the cluster whose center is recomputed
     * @param clusterInfo currently unused
     */
    public static void refreshClusterCenter(Cluster cluster, ClusterInfo clusterInfo) {
        List<Point> points = cluster.getPoints();
        int pointsCount = points.size();
        if (pointsCount == 0) {
            return;
        }
        long length = points.get(0).getArray().length();
        Point center = new Point(Nd4j.create(new long[] {length}));
        boolean inverse = cluster.isInverse();
        for (Point point : points) {
            INDArray arr = point.getArray();
            if (inverse) {
                center.getArray().subi(arr);
            } else {
                center.getArray().addi(arr);
            }
        }
        center.getArray().divi((Number) pointsCount);
        cluster.setCenter(center);
    }

    /**
     * Derives summary statistics from the point distances recorded in {@code info}.
     *
     * <p>The following are set:
     * <ul>
     *   <li>the "max" distance (the minimum for inverse clusters);</li>
     *   <li>the total, average and variance of the distances.</li>
     * </ul>
     * This is a no-op when no distances are recorded.
     *
     * @param info the cluster info to update in place
     */
    public static void deriveClusterInfoDistanceStatistics(ClusterInfo info) {
        int pointCount = info.getPointDistancesFromCenter().size();
        if (pointCount == 0) {
            return;
        }
        double[] distances = ArrayUtils.toPrimitive(
                (Double[]) info.getPointDistancesFromCenter().values().toArray(new Double[0]));
        // For inverse clusters larger values are "closer", so the extreme of interest is the minimum.
        double extreme = info.isInverse() ? MathUtils.min(distances) : MathUtils.max(distances);
        double total = MathUtils.sum(distances);
        info.setMaxPointDistanceFromCenter(extreme);
        info.setTotalPointDistanceFromCenter(total);
        info.setAveragePointDistanceFromCenter(total / (double) pointCount);
        info.setPointDistanceFromCenterVariance(MathUtils.variance(distances));
    }

    /**
     * Computes, for every point, its distance to the last cluster in the set, merged with {@code previousDxs}.
     *
     * <p>The last cluster in {@code clusterSet.getClusters()} is presumably the most recently added one.
     * <ul>
     *   <li>Non-inverse sets: the squared distance is used, and the smaller of the new and
     *       previous values is kept.</li>
     *   <li>Inverse sets: the raw distance is used, and the larger value is kept.</li>
     * </ul>
     * A failed distance computation is logged, and that entry keeps the initial value of the
     * freshly created array (presumably zero) before merging.
     *
     * @param clusterSet      the cluster set; its last cluster is measured against
     * @param points          the points to measure
     * @param previousDxs     per-point values from the previous round; must have at least {@code points.size()} entries
     * @param executorService currently unused; work runs on the calling thread
     * @return a new array with one merged value per point
     */
    public static INDArray computeSquareDistancesFromNearestCluster(ClusterSet clusterSet, List<Point> points, INDArray previousDxs, ExecutorService executorService) {
        int pointsCount = points.size();
        INDArray dxs = Nd4j.create(pointsCount);
        List<Cluster> clusters = clusterSet.getClusters();
        Cluster newCluster = clusters.get(clusters.size() - 1);
        boolean inverse = clusterSet.isInverse();
        for (int i = 0; i < pointsCount; i++) {
            try {
                double distance = newCluster.getDistanceToCenter(points.get(i));
                double dist = inverse ? distance : Math.pow(distance, 2.0);
                dxs.putScalar((long) i, dist);
            } catch (Throwable t) {
                log.warn("Error computing squared distance from nearest cluster", t);
            }
        }
        for (int i = 0; i < pointsCount; i++) {
            double previousMinDistance = previousDxs.getDouble((long) i);
            double current = dxs.getDouble((long) i);
            // Keep whichever value corresponds to the nearer cluster.
            boolean keepPrevious = inverse ? current < previousMinDistance : current > previousMinDistance;
            if (keepPrevious) {
                dxs.putScalar((long) i, previousMinDistance);
            }
        }
        return dxs;
    }

    /**
     * Computes running (cumulative) sums of squared distances to the last cluster in the set.
     *
     * <p>Entry {@code i} of the result holds the sum of the squared distances of points
     * {@code 0..i}. {@code previousDxs} is not read.
     *
     * <p>NOTE(review): unlike {@link #computeSquareDistancesFromNearestCluster}, this neither
     * honours {@code clusterSet.isInverse()} nor merges with {@code previousDxs}. Confirm intent.
     *
     * @param clusterSet  the cluster set; its last cluster is measured against
     * @param points      the points to measure
     * @param previousDxs currently unused
     * @return a new array of cumulative squared distances, one entry per point
     */
    public static INDArray computeWeightedProbaDistancesFromNearestCluster(ClusterSet clusterSet, List<Point> points, INDArray previousDxs) {
        int pointsCount = points.size();
        INDArray dxs = Nd4j.create(pointsCount);
        List<Cluster> clusters = clusterSet.getClusters();
        Cluster newCluster = clusters.get(clusters.size() - 1);
        double sum = 0.0;
        for (int i = 0; i < pointsCount; i++) {
            double dist = Math.pow(newCluster.getDistanceToCenter(points.get(i)), 2.0);
            sum += dist;
            dxs.putScalar((long) i, sum);
        }
        return dxs;
    }

    /**
     * Computes cluster-set info using a temporary executor.
     *
     * <p>The executor is shut down even if the computation throws.
     *
     * @param clusterSet the clusters to describe
     * @return per-cluster info plus pairwise center distances
     */
    public static ClusterSetInfo computeClusterSetInfo(ClusterSet clusterSet) {
        ExecutorService executor = MultiThreadUtils.newExecutorService();
        try {
            return computeClusterSetInfo(clusterSet, executor);
        } finally {
            executor.shutdownNow();
        }
    }

    /**
     * Computes per-cluster distance info and the distances between every pair of cluster centers.
     *
     * <p>Each unordered pair {@code (i, k)} with {@code i < k} is stored once, keyed by
     * {@code (from id, to id)}. Failures are logged. When a pair fails, the remaining pairs for
     * that source cluster are skipped.
     *
     * @param clusterSet      the clusters to describe
     * @param executorService currently unused; work runs on the calling thread
     * @return the computed info
     */
    public static ClusterSetInfo computeClusterSetInfo(ClusterSet clusterSet, ExecutorService executorService) {
        ClusterSetInfo info = new ClusterSetInfo(clusterSet.isInverse(), true);
        int clusterCount = clusterSet.getClusterCount();
        List<Cluster> clusters = clusterSet.getClusters();
        Distance distanceFunction = clusterSet.getDistanceFunction();
        for (int i = 0; i < clusterCount; i++) {
            Cluster cluster = clusters.get(i);
            try {
                info.getClustersInfos().put(cluster.getId(), computeClusterInfos(cluster, distanceFunction));
            } catch (Throwable t) {
                log.warn("Error computing cluster set info", t);
            }
        }
        for (int i = 0; i < clusterCount; i++) {
            Cluster fromCluster = clusters.get(i);
            try {
                for (int k = i + 1; k < clusterCount; k++) {
                    Cluster toCluster = clusters.get(k);
                    double distance = computeDistance(distanceFunction,
                            fromCluster.getCenter().getArray(), toCluster.getCenter().getArray());
                    info.getDistancesBetweenClustersCenters().put(fromCluster.getId(), toCluster.getId(), distance);
                }
            } catch (Throwable t) {
                log.warn("Error computing distances", t);
            }
        }
        return info;
    }

    /**
     * Builds a {@link ClusterInfo} holding each point's distance from the cluster center.
     *
     * <p>Total and average distance are also set. The max distance and the variance are not set here.
     *
     * @param cluster          the cluster to describe
     * @param distanceFunction the distance measure to use
     * @return the computed info
     */
    public static ClusterInfo computeClusterInfos(Cluster cluster, Distance distanceFunction) {
        ClusterInfo info = new ClusterInfo(cluster.isInverse(), true);
        List<Point> points = cluster.getPoints();
        INDArray centerArray = cluster.getCenter().getArray();
        for (Point point : points) {
            double distance = computeDistance(distanceFunction, centerArray, point.getArray());
            info.getPointDistancesFromCenter().put(point.getId(), distance);
            info.setTotalPointDistanceFromCenter(info.getTotalPointDistanceFromCenter() + distance);
        }
        if (!points.isEmpty()) {
            info.setAveragePointDistanceFromCenter(info.getTotalPointDistanceFromCenter() / (double) points.size());
        }
        return info;
    }

    /**
     * Applies the split-based optimization selected by {@code optimization}.
     *
     * @param optimization   the optimisation strategy (type and threshold)
     * @param clusterSet     the clusters to split
     * @param clusterSetInfo per-cluster statistics used to choose clusters
     * @param executor       executor used to run the split tasks
     * @return {@code true} if at least one cluster was selected for splitting
     */
    public static boolean applyOptimization(OptimisationStrategy optimization, ClusterSet clusterSet, ClusterSetInfo clusterSetInfo, ExecutorService executor) {
        if (optimization.isClusteringOptimizationType(ClusteringOptimizationType.MINIMIZE_AVERAGE_POINT_TO_CENTER_DISTANCE)) {
            int splitCount = splitClustersWhereAverageDistanceFromCenterGreaterThan(clusterSet, clusterSetInfo,
                    optimization.getClusteringOptimizationValue(), executor);
            return splitCount > 0;
        }
        if (optimization.isClusteringOptimizationType(ClusteringOptimizationType.MINIMIZE_MAXIMUM_POINT_TO_CENTER_DISTANCE)) {
            int splitCount = splitClustersWhereMaximumDistanceFromCenterGreaterThan(clusterSet, clusterSetInfo,
                    optimization.getClusteringOptimizationValue(), executor);
            return splitCount > 0;
        }
        return false;
    }

    /**
     * Returns up to {@code count} clusters ordered from most to least spread out.
     *
     * <p>Clusters are ranked by total point distance from their center: largest first for
     * non-inverse clusters, smallest first for inverse ones. The inverse flag is taken from the
     * first cluster. If {@code count} exceeds the number of clusters, all clusters are returned.
     *
     * @param clusterSet the clusters to rank
     * @param info       statistics for every cluster in {@code clusterSet}
     * @param count      maximum number of clusters to return; must be non-negative
     * @return the selected clusters
     * @throws IllegalArgumentException if {@code count} is negative
     */
    public static List<Cluster> getMostSpreadOutClusters(final ClusterSet clusterSet, final ClusterSetInfo info, int count) {
        if (count < 0) {
            throw new IllegalArgumentException("count must be non-negative, got " + count);
        }
        List<Cluster> clusters = new ArrayList<Cluster>(clusterSet.getClusters());
        final boolean inverse = !clusters.isEmpty() && clusters.get(0).isInverse();
        Collections.sort(clusters, new Comparator<Cluster>() {
            @Override
            public int compare(Cluster o1, Cluster o2) {
                int comp = Double.compare(
                        info.getClusterInfo(o1.getId()).getTotalPointDistanceFromCenter(),
                        info.getClusterInfo(o2.getId()).getTotalPointDistanceFromCenter());
                return inverse ? comp : -comp;
            }
        });
        return clusters.subList(0, Math.min(count, clusters.size()));
    }

    /**
     * Returns the clusters whose average point-to-center distance lies beyond {@code maximumAverageDistance}.
     *
     * <p>"Beyond" means below the threshold for inverse clusters and above it otherwise. Clusters
     * without info are ignored.
     *
     * @param clusterSet             the clusters to inspect
     * @param info                   per-cluster statistics
     * @param maximumAverageDistance the threshold
     * @return the matching clusters, in cluster-set order
     */
    public static List<Cluster> getClustersWhereAverageDistanceFromCenterGreaterThan(ClusterSet clusterSet, ClusterSetInfo info, double maximumAverageDistance) {
        List<Cluster> clusters = new ArrayList<Cluster>();
        for (Cluster cluster : clusterSet.getClusters()) {
            ClusterInfo clusterInfo = info.getClusterInfo(cluster.getId());
            if (clusterInfo != null && isBeyondThreshold(clusterInfo.isInverse(),
                    clusterInfo.getAveragePointDistanceFromCenter(), maximumAverageDistance)) {
                clusters.add(cluster);
            }
        }
        return clusters;
    }

    /**
     * Returns the clusters whose maximum point-to-center distance lies beyond {@code maximumDistance}.
     *
     * <p>"Beyond" means below the threshold for inverse clusters and above it otherwise, matching
     * {@link #getClustersWhereAverageDistanceFromCenterGreaterThan}. Previously, inverse clusters
     * above the threshold were also selected. Clusters without info are ignored.
     *
     * @param clusterSet      the clusters to inspect
     * @param info            per-cluster statistics
     * @param maximumDistance the threshold
     * @return the matching clusters, in cluster-set order
     */
    public static List<Cluster> getClustersWhereMaximumDistanceFromCenterGreaterThan(ClusterSet clusterSet, ClusterSetInfo info, double maximumDistance) {
        List<Cluster> clusters = new ArrayList<Cluster>();
        for (Cluster cluster : clusterSet.getClusters()) {
            ClusterInfo clusterInfo = info.getClusterInfo(cluster.getId());
            if (clusterInfo != null && isBeyondThreshold(clusterInfo.isInverse(),
                    clusterInfo.getMaxPointDistanceFromCenter(), maximumDistance)) {
                clusters.add(cluster);
            }
        }
        return clusters;
    }

    /**
     * Splits up to {@code count} of the most spread-out clusters.
     *
     * @return the number of clusters selected for splitting
     */
    public static int splitMostSpreadOutClusters(ClusterSet clusterSet, ClusterSetInfo clusterSetInfo, int count, ExecutorService executorService) {
        List<Cluster> clustersToSplit = getMostSpreadOutClusters(clusterSet, clusterSetInfo, count);
        splitClusters(clusterSet, clusterSetInfo, clustersToSplit, executorService);
        return clustersToSplit.size();
    }

    /**
     * Splits every cluster whose average point distance lies beyond {@code maxWithinClusterDistance}.
     *
     * @return the number of clusters selected for splitting
     */
    public static int splitClustersWhereAverageDistanceFromCenterGreaterThan(ClusterSet clusterSet, ClusterSetInfo clusterSetInfo, double maxWithinClusterDistance, ExecutorService executorService) {
        List<Cluster> clustersToSplit = getClustersWhereAverageDistanceFromCenterGreaterThan(clusterSet, clusterSetInfo, maxWithinClusterDistance);
        splitClusters(clusterSet, clusterSetInfo, clustersToSplit, maxWithinClusterDistance, executorService);
        return clustersToSplit.size();
    }

    /**
     * Splits every cluster whose maximum point distance lies beyond {@code maxWithinClusterDistance}.
     *
     * @return the number of clusters selected for splitting
     */
    public static int splitClustersWhereMaximumDistanceFromCenterGreaterThan(ClusterSet clusterSet, ClusterSetInfo clusterSetInfo, double maxWithinClusterDistance, ExecutorService executorService) {
        List<Cluster> clustersToSplit = getClustersWhereMaximumDistanceFromCenterGreaterThan(clusterSet, clusterSetInfo, maxWithinClusterDistance);
        splitClusters(clusterSet, clusterSetInfo, clustersToSplit, maxWithinClusterDistance, executorService);
        return clustersToSplit.size();
    }

    /** Splits the {@code count} most populated clusters by moving one random point of each into a new cluster. */
    public static void splitMostPopulatedClusters(ClusterSet clusterSet, ClusterSetInfo clusterSetInfo, int count, ExecutorService executorService) {
        List<Cluster> clustersToSplit = clusterSet.getMostPopulatedClusters(count);
        splitClusters(clusterSet, clusterSetInfo, clustersToSplit, executorService);
    }

    /**
     * For each given cluster, moves one point farther than {@code maxDistance} into a new cluster.
     *
     * <p>The moved point becomes the new cluster's center. It is chosen at random among the first
     * (at most three) ids returned by {@code getPointsFartherFromCenterThan}. Clusters without
     * info, or without such points, are skipped. Other failures are logged.
     *
     * <p>Tasks are handed to {@link MultiThreadUtils#parallelTasks} and may run concurrently.
     * NOTE(review): confirm {@code ClusterSet.addNewClusterWithCenter} is safe to call from several threads.
     */
    public static void splitClusters(final ClusterSet clusterSet, final ClusterSetInfo clusterSetInfo, List<Cluster> clusters, final double maxDistance, ExecutorService executorService) {
        final Random random = new Random();
        List<Runnable> tasks = new ArrayList<Runnable>();
        for (final Cluster cluster : clusters) {
            tasks.add(new Runnable() {
                @Override
                public void run() {
                    try {
                        ClusterInfo clusterInfo = clusterSetInfo.getClusterInfo(cluster.getId());
                        if (clusterInfo == null) {
                            log.debug("Skipping split of cluster {}: no cluster info", cluster.getId());
                            return;
                        }
                        List<String> fartherPoints = clusterInfo.getPointsFartherFromCenterThan(maxDistance);
                        if (fartherPoints.isEmpty()) {
                            // Random.nextInt(0) would throw; there is simply nothing to split off.
                            log.debug("Skipping split of cluster {}: no point farther than {}", cluster.getId(), maxDistance);
                            return;
                        }
                        int rank = Math.min(fartherPoints.size(), MAX_SPLIT_CANDIDATES);
                        String pointId = fartherPoints.get(random.nextInt(rank));
                        Point point = cluster.removePoint(pointId);
                        // NOTE(review): presumably removePoint returns null when the id is absent.
                        if (point == null) {
                            log.debug("Skipping split of cluster {}: point {} not found", cluster.getId(), pointId);
                            return;
                        }
                        clusterSet.addNewClusterWithCenter(point);
                    } catch (Throwable t) {
                        log.warn("Error splitting clusters", t);
                    }
                }
            });
        }
        MultiThreadUtils.parallelTasks(tasks, executorService);
    }

    /**
     * For each given cluster, moves one random point into a new cluster centered on that point.
     *
     * <p>Empty clusters are skipped. Other failures are logged. {@code clusterSetInfo} is not read.
     * Tasks may run concurrently (see the other {@code splitClusters} overload).
     */
    public static void splitClusters(final ClusterSet clusterSet, ClusterSetInfo clusterSetInfo, List<Cluster> clusters, ExecutorService executorService) {
        final Random random = new Random();
        List<Runnable> tasks = new ArrayList<Runnable>();
        for (final Cluster cluster : clusters) {
            tasks.add(new Runnable() {
                @Override
                public void run() {
                    try {
                        List<Point> points = cluster.getPoints();
                        if (points.isEmpty()) {
                            log.debug("Skipping split of empty cluster {}", cluster.getId());
                            return;
                        }
                        Point point = points.remove(random.nextInt(points.size()));
                        clusterSet.addNewClusterWithCenter(point);
                    } catch (Throwable t) {
                        log.warn("Error Splitting clusters (2)", t);
                    }
                }
            });
        }
        MultiThreadUtils.parallelTasks(tasks, executorService);
    }

    /**
     * Creates a distance op between {@code x} and {@code y} reduced along the given dimensions.
     *
     * @throws IllegalStateException if the distance function is not supported
     */
    public static ReduceOp createDistanceFunctionOp(Distance distanceFunction, INDArray x, INDArray y, int... dimensions) {
        ReduceOp op = createDistanceFunctionOp(distanceFunction, x, y);
        op.setDimensions(dimensions);
        return op;
    }

    /**
     * Creates the ND4J reduce3 op implementing {@code distanceFunction} between {@code x} and {@code y}.
     *
     * @throws IllegalStateException if the distance function is not supported
     */
    public static ReduceOp createDistanceFunctionOp(Distance distanceFunction, INDArray x, INDArray y) {
        switch (distanceFunction) {
            case COSINE_DISTANCE:
                return new CosineDistance(x, y, new int[0]);
            case COSINE_SIMILARITY:
                return new CosineSimilarity(x, y, new int[0]);
            case DOT:
                return new Dot(x, y, new int[0]);
            case EUCLIDEAN:
                return new EuclideanDistance(x, y, new int[0]);
            case JACCARD:
                return new JaccardDistance(x, y, new int[0]);
            case MANHATTAN:
                return new ManhattanDistance(x, y, new int[0]);
            default:
                throw new IllegalStateException("Unknown distance function: " + distanceFunction);
        }
    }

    /** Executes the distance op for {@code distanceFunction} between {@code x} and {@code y} and returns its scalar result. */
    private static double computeDistance(Distance distanceFunction, INDArray x, INDArray y) {
        return Nd4j.getExecutioner()
                .execAndReturn(createDistanceFunctionOp(distanceFunction, x, y))
                .getFinalResult()
                .doubleValue();
    }

    /** Returns whether {@code value} lies beyond {@code threshold}: below it when {@code inverse}, above it otherwise. */
    private static boolean isBeyondThreshold(boolean inverse, double value, double threshold) {
        return inverse ? value < threshold : value > threshold;
    }

    private ClusterUtils() {
    }
}

