/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.clustering.algorithm;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutorService;
import org.apache.commons.lang3.ArrayUtils;
import org.deeplearning4j.clustering.algorithm.ClusteringAlgorithm;
import org.deeplearning4j.clustering.algorithm.Distance;
import org.deeplearning4j.clustering.cluster.Cluster;
import org.deeplearning4j.clustering.cluster.ClusterSet;
import org.deeplearning4j.clustering.cluster.ClusterUtils;
import org.deeplearning4j.clustering.cluster.Point;
import org.deeplearning4j.clustering.info.ClusterSetInfo;
import org.deeplearning4j.clustering.iteration.IterationHistory;
import org.deeplearning4j.clustering.iteration.IterationInfo;
import org.deeplearning4j.clustering.strategy.ClusteringStrategy;
import org.deeplearning4j.clustering.strategy.ClusteringStrategyType;
import org.deeplearning4j.clustering.strategy.OptimisationStrategy;
import org.deeplearning4j.clustering.util.MultiThreadUtils;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Base {@link ClusteringAlgorithm} implementation.
 *
 * <p>{@link #applyTo(List)} seeds the initial cluster centers from the input points, then
 * repeatedly clears the clusters, reassigns every point, refreshes the centers and applies the
 * configured {@link ClusteringStrategy}. This continues while the strategy's termination condition
 * is unmet, or while the most recent iteration reports that the strategy changed the cluster set.
 *
 * <p>Per-run state (iteration history, cluster set, input points) is kept in unsynchronized
 * instance fields, so a single instance should not run {@link #applyTo(List)} concurrently.
 */
public class BaseClusteringAlgorithm
implements ClusteringAlgorithm,
Serializable {
    private static final Logger log = LoggerFactory.getLogger(BaseClusteringAlgorithm.class);
    private static final long serialVersionUID = 338231277453149972L;

    /** Supplies the distance function, initial cluster count, termination condition and optimisation. */
    private ClusteringStrategy clusteringStrategy;
    /** Per-iteration snapshots of cluster-set info for the current run. */
    private IterationHistory iterationHistory;
    /** Index of the iteration in progress; 0 is the seeding step performed by initClusters. */
    private int currentIteration = 0;
    /** Clusters produced by the current run; returned from applyTo. */
    private ClusterSet clusterSet;
    /** Points supplied to the current run of applyTo. */
    private List<Point> initialPoints;
    /** Executor handed to ClusterUtils; transient, so always access it through executor(). */
    private transient ExecutorService exec;
    /** Whether initClusters uses distance-proportional (k-means++) seeding. */
    private boolean useKmeansPlusPlus;

    /**
     * Creates an algorithm driven by the given strategy.
     *
     * @param clusteringStrategy strategy supplying the distance function, initial cluster count,
     *     termination condition and optional optimisation
     * @param useKmeansPlusPlus if {@code true}, each new initial center is sampled with probability
     *     proportional to its value from {@link ClusterUtils#computeSquareDistancesFromNearestCluster};
     *     otherwise the first point whose value reaches a random threshold below the maximum is used
     */
    protected BaseClusteringAlgorithm(ClusteringStrategy clusteringStrategy, boolean useKmeansPlusPlus) {
        this.clusteringStrategy = clusteringStrategy;
        this.exec = MultiThreadUtils.newExecutorService();
        this.useKmeansPlusPlus = useKmeansPlusPlus;
    }

    /**
     * Static factory equivalent to the protected constructor.
     *
     * @param clusteringStrategy strategy driving the algorithm
     * @param useKmeansPlusPlus whether to use k-means++ seeding
     * @return a new algorithm instance
     */
    public static BaseClusteringAlgorithm setup(ClusteringStrategy clusteringStrategy, boolean useKmeansPlusPlus) {
        return new BaseClusteringAlgorithm(clusteringStrategy, useKmeansPlusPlus);
    }

    /**
     * Clusters the given points.
     *
     * @param points points to cluster; must be non-empty
     * @return the resulting cluster set (also retained by this instance)
     * @throws IllegalArgumentException if {@code points} is empty
     */
    @Override
    public ClusterSet applyTo(List<Point> points) {
        resetState(points);
        initClusters(this.useKmeansPlusPlus);
        iterations();
        return this.clusterSet;
    }

    /** Discards the state of any previous run and records the points for this run. */
    private void resetState(List<Point> points) {
        this.iterationHistory = new IterationHistory();
        this.currentIteration = 0;
        this.clusterSet = null;
        this.initialPoints = points;
    }

    /**
     * Returns the executor passed to {@link ClusterUtils}, creating it if needed. The field is
     * transient (null after deserialization) and is never set by the no-arg constructor.
     */
    private ExecutorService executor() {
        if (this.exec == null) {
            this.exec = MultiThreadUtils.newExecutorService();
        }
        return this.exec;
    }

    /**
     * Runs clustering iterations. Iteration continues while a termination condition exists and is
     * not yet satisfied, or while the most recent iteration reports that the strategy was applied.
     */
    private void iterations() {
        int iterationCount = 0;
        while ((this.clusteringStrategy.getTerminationCondition() != null
                        && !this.clusteringStrategy.getTerminationCondition().isSatisfied(this.iterationHistory))
                || this.iterationHistory.getMostRecentIterationInfo().isStrategyApplied()) {
            ++this.currentIteration;
            removePoints();
            classifyPoints();
            applyClusteringStrategy();
            log.trace("Completed clustering iteration {}", ++iterationCount);
        }
    }

    /**
     * Assigns every input point to a cluster, refreshes the cluster centers, and records the
     * resulting info under the current iteration number.
     */
    protected void classifyPoints() {
        ExecutorService executor = executor();
        ClusterSetInfo clusterSetInfo = ClusterUtils.classifyPoints(this.clusterSet, this.initialPoints, executor);
        ClusterUtils.refreshClustersCenters(this.clusterSet, clusterSetInfo, executor);
        this.iterationHistory.getIterationsInfos().put(this.currentIteration,
                new IterationInfo(this.currentIteration, clusterSetInfo));
    }

    /**
     * Seeds the cluster set with up to {@code getInitialClusterCount()} centers chosen from the
     * input points, then records the initial cluster-set info as iteration 0.
     *
     * <p>The first center is a uniformly random point. Each subsequent center is picked from the
     * remaining points using their distance to the nearest existing center. With
     * {@code kMeansPlusPlus}, the pick is proportional to that distance (cumulative-sum sampling).
     * Without it, the first point whose distance reaches a random fraction of the maximum
     * distance is picked.
     *
     * @param kMeansPlusPlus whether to use distance-proportional (k-means++) sampling
     * @throws IllegalArgumentException if there are no input points
     */
    protected void initClusters(boolean kMeansPlusPlus) {
        log.info("Generating initial clusters");
        Preconditions.checkArgument(!this.initialPoints.isEmpty(), "Cannot cluster an empty list of points");
        List<Point> points = new ArrayList<>(this.initialPoints);
        Random random = Nd4j.getRandom();
        Distance distanceFn = this.clusteringStrategy.getDistanceFunction();
        int initialClusterCount = this.clusteringStrategy.getInitialClusterCount();
        boolean inverse = this.clusteringStrategy.inverseDistanceCalculation();

        this.clusterSet = new ClusterSet(distanceFn, inverse,
                new long[]{initialClusterCount, points.get(0).getArray().length()});
        this.clusterSet.addNewClusterWithCenter(points.remove(random.nextInt(points.size())));

        // dxs[i]: distance from points[i] to its nearest center, seeded with the "worst" value so
        // the first computeSquareDistancesFromNearestCluster call replaces every entry.
        INDArray dxs = Nd4j.create(points.size());
        dxs.addi(inverse ? -Double.MAX_VALUE : Double.MAX_VALUE);

        ExecutorService executor = executor();
        while (this.clusterSet.getClusterCount() < initialClusterCount && !points.isEmpty()) {
            dxs = ClusterUtils.computeSquareDistancesFromNearestCluster(this.clusterSet, points, dxs, executor);
            double r = kMeansPlusPlus
                    ? random.nextDouble() * Nd4j.sum(dxs).getDouble(0L)
                    : random.nextFloat() * dxs.maxNumber().doubleValue();

            int chosen = selectNextCenterIndex(dxs, r, kMeansPlusPlus, distanceFn);
            if (chosen < 0) {
                // No candidate reached the threshold (e.g. NaN distances): resample, as before.
                continue;
            }
            this.clusterSet.addNewClusterWithCenter(points.remove(chosen));
            if (!points.isEmpty()) {
                // Keep dxs aligned with the remaining points.
                dxs = Nd4j.create(ArrayUtils.remove(dxs.data().asDouble(), chosen));
            }
        }

        ClusterSetInfo initialClusterSetInfo = ClusterUtils.computeClusterSetInfo(this.clusterSet);
        this.iterationHistory.getIterationsInfos().put(this.currentIteration,
                new IterationInfo(this.currentIteration, initialClusterSetInfo));
    }

    /**
     * Picks the index of the next initial center from the per-point distances.
     *
     * @param dxs distance of each remaining point to its nearest center; every value must be >= 0
     * @param r random draw: in [0, sum(dxs)) for k-means++, in [0, max(dxs)) otherwise
     * @param kMeansPlusPlus if {@code true}, return the first index whose cumulative distance
     *     exceeds {@code r}; otherwise the first index whose own distance is >= {@code r}
     * @param distanceFn distance function, used only in the error message
     * @return the chosen index, or -1 if no index qualifies (threshold mode only)
     * @throws IllegalStateException if a negative distance is encountered
     */
    private static int selectNextCenterIndex(INDArray dxs, double r, boolean kMeansPlusPlus, Distance distanceFn) {
        double cumulative = 0.0;
        int lastPositive = -1;
        for (int i = 0; i < dxs.length(); i++) {
            double distance = dxs.getDouble(i);
            Preconditions.checkState(distance >= 0.0,
                    "Encountered negative distance: distance function is not valid? Distance function must return values >= 0, got distance %s for function %s",
                    distance, distanceFn);
            if (kMeansPlusPlus) {
                // Inverse-CDF sampling: index i is chosen with probability dxs[i] / sum(dxs).
                cumulative += distance;
                if (distance > 0.0) {
                    lastPositive = i;
                }
                if (cumulative > r) {
                    return i;
                }
            } else if (distance >= r) {
                return i;
            }
        }
        if (kMeansPlusPlus) {
            // All distances zero (r == 0), or rounding left the cumulative sum just below r.
            // Fall back deterministically so seeding always makes progress.
            return lastPositive >= 0 ? lastPositive : 0;
        }
        return -1;
    }

    /**
     * Applies the clustering strategy to the most recent iteration, if it is applicable now.
     *
     * <p>If empty clusters are not allowed, empty clusters are removed and the iteration is marked
     * as strategy-applied. For fixed-cluster-count strategies, the most spread-out clusters are then
     * split to restore the initial count. Optimisation strategies additionally run
     * {@link #optimize()}.
     */
    protected void applyClusteringStrategy() {
        if (!isStrategyApplicableNow()) {
            return;
        }
        ClusterSetInfo clusterSetInfo = this.iterationHistory.getMostRecentClusterSetInfo();
        if (!this.clusteringStrategy.isAllowEmptyClusters()) {
            int removedCount = removeEmptyClusters(clusterSetInfo);
            if (removedCount > 0) {
                this.iterationHistory.getMostRecentIterationInfo().setStrategyApplied(true);
                int targetCount = this.clusteringStrategy.getInitialClusterCount();
                if (this.clusteringStrategy.isStrategyOfType(ClusteringStrategyType.FIXED_CLUSTER_COUNT)
                        && this.clusterSet.getClusterCount() < targetCount) {
                    // The flag is already true here, so the split count needs no further handling.
                    ClusterUtils.splitMostSpreadOutClusters(this.clusterSet, clusterSetInfo,
                            targetCount - this.clusterSet.getClusterCount(), executor());
                }
            }
        }
        if (this.clusteringStrategy.isStrategyOfType(ClusteringStrategyType.OPTIMIZATION)) {
            optimize();
        }
    }

    /**
     * Runs the strategy's optimisation on the most recent iteration and records whether it was
     * applied. The strategy must be an {@link OptimisationStrategy}; otherwise the cast fails.
     */
    protected void optimize() {
        ClusterSetInfo clusterSetInfo = this.iterationHistory.getMostRecentClusterSetInfo();
        OptimisationStrategy optimization = (OptimisationStrategy) this.clusteringStrategy;
        boolean applied = ClusterUtils.applyOptimization(optimization, this.clusterSet, clusterSetInfo, executor());
        // NOTE(review): this overwrites a `true` set by empty-cluster removal in
        // applyClusteringStrategy when the optimisation itself is not applied — confirm intended.
        this.iterationHistory.getMostRecentIterationInfo().setStrategyApplied(applied);
    }

    /**
     * Returns whether the strategy should run now: an optimisation must be defined, at least one
     * iteration must be recorded, and the strategy must accept the current history.
     */
    private boolean isStrategyApplicableNow() {
        return this.clusteringStrategy.isOptimizationDefined()
                && this.iterationHistory.getIterationCount() != 0
                && this.clusteringStrategy.isOptimizationApplicableNow(this.iterationHistory);
    }

    /**
     * Removes empty clusters from the cluster set and drops their entries from {@code clusterSetInfo}.
     *
     * @param clusterSetInfo info to keep consistent with the cluster set
     * @return the number of clusters removed
     */
    protected int removeEmptyClusters(ClusterSetInfo clusterSetInfo) {
        List<Cluster> removedClusters = this.clusterSet.removeEmptyClusters();
        clusterSetInfo.removeClusterInfos(removedClusters);
        return removedClusters.size();
    }

    /** Clears the point assignments of every cluster before reclassification. */
    protected void removePoints() {
        this.clusterSet.removePoints();
    }

    /**
     * No-arg constructor (e.g. for serialization frameworks). Leaves the strategy unset; the
     * executor is created lazily on first use.
     */
    protected BaseClusteringAlgorithm() {
    }
}

