package org.deeplearning4j.clustering.algorithm;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.concurrent.ExecutorService;
import org.apache.commons.lang3.ArrayUtils;
import org.deeplearning4j.clustering.cluster.Cluster;
import org.deeplearning4j.clustering.cluster.ClusterSet;
import org.deeplearning4j.clustering.cluster.ClusterUtils;
import org.deeplearning4j.clustering.cluster.Point;
import org.deeplearning4j.clustering.info.ClusterSetInfo;
import org.deeplearning4j.clustering.iteration.IterationHistory;
import org.deeplearning4j.clustering.iteration.IterationInfo;
import org.deeplearning4j.clustering.strategy.ClusteringStrategy;
import org.deeplearning4j.clustering.strategy.ClusteringStrategyType;
import org.deeplearning4j.clustering.strategy.OptimisationStrategy;
import org.deeplearning4j.clustering.util.MultiThreadUtils;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/clustering/algorithm/BaseClusteringAlgorithm.class */
/**
 * Generic iterative clustering driver parameterised by a {@link ClusteringStrategy}.
 *
 * <p>{@link #applyTo(List)} seeds an initial set of cluster centers from the input points. It then
 * repeats the following steps:
 *
 * <ul>
 *   <li>removes points from the clusters;
 *   <li>reassigns every point and refreshes the centers;
 *   <li>applies the strategy's adjustments (empty-cluster removal, cluster splitting,
 *       optimisation).
 * </ul>
 *
 * <p>Iteration stops once the strategy's termination condition is satisfied and the most recent
 * iteration reports no strategy adjustment.
 *
 * <p>{@link #applyTo(List)} overwrites this instance's per-run fields (history, cluster set, input
 * points) without synchronization. A single instance should therefore not run concurrent
 * {@code applyTo} calls.
 */
public class BaseClusteringAlgorithm implements ClusteringAlgorithm, Serializable {
    private static final Logger log = LoggerFactory.getLogger(BaseClusteringAlgorithm.class);
    private static final long serialVersionUID = 338231277453149972L;

    /** Supplies distance function, initial cluster count, termination condition and adjustments. */
    private ClusteringStrategy clusteringStrategy;
    /** Per-iteration cluster-set statistics for the current run, keyed by iteration number. */
    private IterationHistory iterationHistory;
    /** Iteration counter for the current run; 0 is the initial-center seeding step. */
    private int currentIteration;
    /** Clusters being built by the current run; returned from {@link #applyTo(List)}. */
    private ClusterSet clusterSet;
    /** Points passed to the current {@link #applyTo(List)} call. */
    private List<Point> initialPoints;
    /**
     * Executor handed to the parallel {@link ClusterUtils} helpers. It is transient (not
     * serialized) and is not set by the no-arg constructor, so {@link #applyTo(List)} recreates it
     * when it is missing.
     */
    private transient ExecutorService exec;

    /**
     * Creates an algorithm driven by the given strategy, with its own executor from
     * {@link MultiThreadUtils#newExecutorService()}.
     *
     * @param clusteringStrategy strategy controlling initialisation, adjustments and termination
     */
    public BaseClusteringAlgorithm(ClusteringStrategy clusteringStrategy) {
        this.currentIteration = 0;
        this.clusteringStrategy = clusteringStrategy;
        this.exec = MultiThreadUtils.newExecutorService();
    }

    /**
     * Static factory equivalent to {@code new BaseClusteringAlgorithm(clusteringStrategy)}.
     *
     * @param clusteringStrategy strategy controlling initialisation, adjustments and termination
     * @return a new algorithm instance
     */
    public static BaseClusteringAlgorithm setup(ClusteringStrategy clusteringStrategy) {
        return new BaseClusteringAlgorithm(clusteringStrategy);
    }

    /**
     * Clusters the given points, discarding any state from a previous run on this instance.
     *
     * @param points points to cluster; must be non-null and non-empty
     * @return the resulting cluster set
     * @throws NullPointerException if {@code points} is null
     * @throws IllegalArgumentException if {@code points} is empty
     * @throws IllegalStateException if initial centers cannot be selected from the computed distances
     *     (for example NaN or negative values)
     */
    @Override // org.deeplearning4j.clustering.algorithm.ClusteringAlgorithm
    public ClusterSet applyTo(List<Point> points) {
        ensureExecutor();
        resetState(points);
        initClusters();
        iterations();
        return this.clusterSet;
    }

    /**
     * Recreates the executor if it is absent. This happens after deserialization (the field is
     * transient) and on instances built with the no-arg constructor. Without it, a null executor
     * would be handed to every {@link ClusterUtils} call.
     */
    private void ensureExecutor() {
        if (this.exec == null) {
            this.exec = MultiThreadUtils.newExecutorService();
        }
    }

    /** Starts a fresh run: new history, iteration 0, no clusters, and the given input points. */
    private void resetState(List<Point> points) {
        this.iterationHistory = new IterationHistory();
        this.currentIteration = 0;
        this.clusterSet = null;
        this.initialPoints = points;
    }

    /**
     * Runs clustering iterations while either of these holds:
     *
     * <ul>
     *   <li>a termination condition exists and is not yet satisfied, or
     *   <li>the most recent iteration reports that a strategy adjustment was applied.
     * </ul>
     *
     * <p>With no termination condition, a further iteration runs only when the previous one applied
     * an adjustment.
     */
    private void iterations() {
        int completedIterations = 0;
        while ((this.clusteringStrategy.getTerminationCondition() != null
                        && !this.clusteringStrategy.getTerminationCondition().isSatisfied(this.iterationHistory))
                        || this.iterationHistory.getMostRecentIterationInfo().isStrategyApplied()) {
            this.currentIteration++;
            removePoints();
            classifyPoints();
            applyClusteringStrategy();
            completedIterations++;
            log.info("Completed clustering iteration {}", completedIterations);
        }
    }

    /**
     * Assigns every input point to a cluster, recomputes the cluster centers from that assignment,
     * and records the resulting statistics under the current iteration number.
     */
    protected void classifyPoints() {
        ClusterSetInfo clusterSetInfo = ClusterUtils.classifyPoints(this.clusterSet, this.initialPoints, this.exec);
        ClusterUtils.refreshClustersCenters(this.clusterSet, clusterSetInfo, this.exec);
        this.iterationHistory.getIterationsInfos().put(this.currentIteration,
                        new IterationInfo(this.currentIteration, clusterSetInfo));
    }

    /**
     * Seeds the cluster set with up to {@code getInitialClusterCount()} centers taken from the input
     * points, and records iteration 0 in the history.
     *
     * <p>Seeding procedure:
     *
     * <ul>
     *   <li>The first center is a uniformly random point.
     *   <li>Each further center is chosen by drawing a threshold uniformly from
     *       {@code [0, max distance)}.
     *   <li>The <em>first</em> remaining point (in list order) whose nearest-center value is at least
     *       that threshold is taken as the next center.
     * </ul>
     *
     * <p>Points far from existing centers are therefore more likely to be eligible. NOTE(review):
     * this is not D²-weighted k-means++ sampling; ties are biased toward earlier points. It also
     * uses an unseeded {@link Random}, so results are not reproducible across runs.
     *
     * @throws IllegalArgumentException if there are no input points
     * @throws IllegalStateException if no remaining point reaches the drawn threshold (e.g. NaN or
     *     negative distance values); the original loop would spin forever in that case
     */
    protected void initClusters() {
        log.info("Generating initial clusters");
        List<Point> points = new ArrayList<>(this.initialPoints);
        if (points.isEmpty()) {
            throw new IllegalArgumentException("Cannot generate initial clusters: no points were supplied");
        }

        Random random = new Random();
        this.clusterSet = new ClusterSet(this.clusteringStrategy.getDistanceFunction(),
                        this.clusteringStrategy.inverseDistanceCalculation());
        this.clusterSet.addNewClusterWithCenter(points.remove(random.nextInt(points.size())));

        int initialClusterCount = this.clusteringStrategy.getInitialClusterCount().intValue();

        // dxs[i] tracks the value for points[i] relative to its nearest center so far. It starts at
        // the extreme value for the distance direction, presumably so the first call to
        // computeSquareDistancesFromNearestCluster overwrites it — confirm against ClusterUtils.
        double initialDistance = this.clusteringStrategy.inverseDistanceCalculation() ? -Double.MAX_VALUE : Double.MAX_VALUE;
        INDArray dxs = Nd4j.create(points.size());
        dxs.addi(initialDistance);

        while (this.clusterSet.getClusterCount() < initialClusterCount && !points.isEmpty()) {
            dxs = ClusterUtils.computeSquareDistancesFromNearestCluster(this.clusterSet, points, dxs, this.exec);
            double maxDistance = dxs.maxNumber().doubleValue();
            double threshold = random.nextFloat() * maxDistance;

            int chosen = -1;
            for (int i = 0; i < dxs.length(); i++) {
                if (dxs.getDouble(i) >= threshold) {
                    chosen = i;
                    break;
                }
            }
            // For finite, non-negative values the element equal to the max always qualifies
            // (threshold <= max). Reaching here means NaN or negative values; fail instead of looping.
            if (chosen < 0) {
                throw new IllegalStateException("Unable to select an initial cluster center: no point has a distance >= "
                                + threshold + " (max distance " + maxDistance
                                + "); the distance function may be returning NaN or negative values");
            }

            this.clusterSet.addNewClusterWithCenter(points.remove(chosen));
            // Keep dxs aligned with the shrunken points list.
            dxs = Nd4j.create(ArrayUtils.remove(dxs.data().asDouble(), chosen));
        }

        this.iterationHistory.getIterationsInfos().put(this.currentIteration,
                        new IterationInfo(this.currentIteration, ClusterUtils.computeClusterSetInfo(this.clusterSet)));
    }

    /**
     * Applies the strategy's adjustments to the latest iteration, if the strategy says they are
     * applicable now.
     *
     * <p>Steps:
     *
     * <ul>
     *   <li>If empty clusters are not allowed, empty clusters are removed.
     *   <li>If any were removed, the iteration is marked as adjusted. For fixed-cluster-count
     *       strategies the most spread-out clusters are then split to restore the initial count.
     *   <li>Finally, optimisation-type strategies run {@link #optimize()}.
     * </ul>
     */
    protected void applyClusteringStrategy() {
        if (!isStrategyApplicableNow()) {
            return;
        }

        ClusterSetInfo clusterSetInfo = this.iterationHistory.getMostRecentClusterSetInfo();
        if (!this.clusteringStrategy.isAllowEmptyClusters()) {
            int removedCount = removeEmptyClusters(clusterSetInfo);
            if (removedCount > 0) {
                this.iterationHistory.getMostRecentIterationInfo().setStrategyApplied(true);

                int targetCount = this.clusteringStrategy.getInitialClusterCount().intValue();
                if (this.clusteringStrategy.isStrategyOfType(ClusteringStrategyType.FIXED_CLUSTER_COUNT)
                                && this.clusterSet.getClusterCount() < targetCount) {
                    int splitCount = ClusterUtils.splitMostSpreadOutClusters(this.clusterSet, clusterSetInfo,
                                    targetCount - this.clusterSet.getClusterCount(), this.exec);
                    if (splitCount > 0) {
                        this.iterationHistory.getMostRecentIterationInfo().setStrategyApplied(true);
                    }
                }
            }
        }

        if (this.clusteringStrategy.isStrategyOfType(ClusteringStrategyType.OPTIMIZATION)) {
            optimize();
        }
    }

    /**
     * Runs the strategy's optimisation on the current clusters. The latest iteration is marked as
     * adjusted according to the result of {@link ClusterUtils#applyOptimization}.
     */
    protected void optimize() {
        boolean applied = ClusterUtils.applyOptimization((OptimisationStrategy) this.clusteringStrategy, this.clusterSet,
                        this.iterationHistory.getMostRecentClusterSetInfo(), this.exec);
        this.iterationHistory.getMostRecentIterationInfo().setStrategyApplied(applied);
    }

    /**
     * Returns whether adjustments may run now. This requires all of the following: an optimisation
     * is defined, the history is non-empty, and the strategy reports that the optimisation is
     * applicable for this history.
     */
    private boolean isStrategyApplicableNow() {
        return this.clusteringStrategy.isOptimizationDefined() && this.iterationHistory.getIterationCount() != 0
                        && this.clusteringStrategy.isOptimizationApplicableNow(this.iterationHistory);
    }

    /**
     * Removes empty clusters from the cluster set and drops their entries from {@code clusterSetInfo}.
     *
     * @param clusterSetInfo statistics to keep in sync with the cluster set
     * @return number of clusters removed
     */
    protected int removeEmptyClusters(ClusterSetInfo clusterSetInfo) {
        List<Cluster> removed = this.clusterSet.removeEmptyClusters();
        clusterSetInfo.removeClusterInfos(removed);
        return removed.size();
    }

    /** Removes points from the clusters (via {@code ClusterSet.removePoints()}) before reclassifying. */
    protected void removePoints() {
        this.clusterSet.removePoints();
    }

    /**
     * No-arg constructor for subclasses and serialization frameworks. It sets neither a strategy nor
     * an executor; the executor is created lazily by {@link #applyTo(List)}.
     */
    protected BaseClusteringAlgorithm() {
        this.currentIteration = 0;
    }
}
