package org.deeplearning4j.clustering;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Accumulation;
import org.nd4j.linalg.api.ops.impl.accum.distances.EuclideanDistance;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:org/deeplearning4j/clustering/ClusterSet.class */
/**
 * A mutable collection of {@link Cluster}s together with helpers for assigning points to the
 * cluster whose center is nearest.
 *
 * <p>NOTE(review): a distance function class can be stored on this object, but every distance
 * computed by this class uses {@link EuclideanDistance}; the stored class is only exposed through
 * its getter and setter. Confirm whether honoring it is intended before relying on it.
 */
public class ClusterSet {
    /**
     * Starting "best distance" for the nearest-cluster search. Same value as the literal
     * {@code 3.4028234663852886E38d}: {@code Float.MAX_VALUE} widened to {@code double}.
     */
    private static final double INITIAL_BEST_DISTANCE = Float.MAX_VALUE;

    private Class<? extends Accumulation> distanceFunction;
    private List<Cluster> clusters = new ArrayList<Cluster>();

    /** Creates an empty cluster set with no distance function configured. */
    public ClusterSet() {
    }

    /**
     * Creates a cluster set containing one cluster per row of the given matrix, each row being
     * passed to the {@link Cluster} constructor.
     *
     * @param iNDArray matrix whose rows seed the clusters
     */
    public ClusterSet(INDArray iNDArray) {
        int rowCount = iNDArray.rows();
        for (int row = 0; row < rowCount; row++) {
            this.clusters.add(new Cluster(iNDArray.getRow(row)));
        }
    }

    /**
     * Creates an empty cluster set and records the given distance function class.
     *
     * @param cls distance function class to store
     */
    public ClusterSet(Class<? extends Accumulation> cls) {
        this.distanceFunction = cls;
    }

    /**
     * Appends a new cluster built from the given array to this set.
     *
     * @param iNDArray value handed to the {@link Cluster} constructor
     */
    public void addNewClusterWithCenter(INDArray iNDArray) {
        getClusters().add(new Cluster(iNDArray));
    }

    /**
     * Builds a matrix with one row per cluster, row {@code i} holding the center of cluster
     * {@code i}. The column count is taken from the first cluster's center.
     *
     * @return a newly created matrix of cluster centers
     * @throws IllegalStateException if this set contains no clusters
     */
    public INDArray getCenters() {
        if (this.clusters.isEmpty()) {
            // Fail with a clear message instead of the IndexOutOfBoundsException from get(0).
            throw new IllegalStateException("Cannot build a center matrix: the cluster set is empty");
        }
        int clusterCount = this.clusters.size();
        INDArray centers = Nd4j.create(clusterCount, this.clusters.get(0).getCenter().columns());
        for (int row = 0; row < clusterCount; row++) {
            centers.putRow(row, this.clusters.get(row).getCenter());
        }
        return centers;
    }

    /**
     * Adds the point to its nearest cluster, passing {@code true} as the flag of
     * {@code Cluster.addPoint}.
     *
     * @param iNDArray point to add
     * @throws IllegalStateException if no nearest cluster can be determined
     */
    public void addPoint(INDArray iNDArray) {
        requireNearestCluster(iNDArray).addPoint(iNDArray, true);
    }

    /**
     * Adds the point to its nearest cluster.
     *
     * @param iNDArray point to add
     * @param z flag forwarded unchanged to {@code Cluster.addPoint}
     * @throws IllegalStateException if no nearest cluster can be determined
     */
    public void addPoint(INDArray iNDArray, boolean z) {
        requireNearestCluster(iNDArray).addPoint(iNDArray, z);
    }

    /**
     * Adds every point of the list to its nearest cluster, passing {@code true} as the flag.
     *
     * @param list points to add
     */
    public void addPoints(List<INDArray> list) {
        addPoints(list, true);
    }

    /**
     * Adds every point of the list, in iteration order, to its nearest cluster.
     *
     * @param list points to add
     * @param z flag forwarded unchanged to {@code Cluster.addPoint} for each point
     */
    public void addPoints(List<INDArray> list, boolean z) {
        for (INDArray point : list) {
            addPoint(point, z);
        }
    }

    /**
     * Returns the cluster nearest to the given point.
     *
     * @param iNDArray point to classify
     * @return the nearest cluster, or {@code null} if none could be determined
     */
    public Cluster classify(INDArray iNDArray) {
        return classify(iNDArray, this.distanceFunction);
    }

    /**
     * Returns the cluster nearest to the given point.
     *
     * <p>NOTE(review): {@code cls} is currently not used; the lookup is delegated to
     * {@link #nearestCluster(INDArray)}.
     *
     * @param iNDArray point to classify
     * @param cls distance function class (ignored by this implementation)
     * @return the nearest cluster, or {@code null} if none could be determined
     */
    public Cluster classify(INDArray iNDArray, Class<? extends Accumulation> cls) {
        return nearestCluster(iNDArray);
    }

    /**
     * Scans all clusters and returns the one whose center has the smallest distance to the point.
     * Clusters with a {@code null} center are skipped, and on ties the first cluster encountered
     * wins because the comparison is strict.
     *
     * @param iNDArray point to look up
     * @return the nearest cluster, or {@code null} when no cluster has a non-null center with a
     *         distance below {@code Float.MAX_VALUE}
     */
    protected Cluster nearestCluster(INDArray iNDArray) {
        Cluster nearest = null;
        double bestDistance = INITIAL_BEST_DISTANCE;
        for (Cluster candidate : getClusters()) {
            INDArray center = candidate.getCenter();
            if (center == null) {
                continue;
            }
            double distance = getDistance(center, iNDArray);
            if (distance < bestDistance) {
                bestDistance = distance;
                nearest = candidate;
            }
        }
        return nearest;
    }

    /**
     * Looks up the nearest cluster and fails with a descriptive exception when there is none,
     * instead of letting callers hit a bare {@code NullPointerException}.
     */
    private Cluster requireNearestCluster(INDArray point) {
        Cluster nearest = nearestCluster(point);
        if (nearest == null) {
            throw new IllegalStateException("No cluster with a usable center is available for the given point");
        }
        return nearest;
    }

    /** Executes an {@link EuclideanDistance} op on the two arrays and returns its result. */
    private double getDistance(INDArray iNDArray, INDArray iNDArray2) {
        return Nd4j.getExecutioner().execAndReturn(new EuclideanDistance(iNDArray, iNDArray2)).currentResult().doubleValue();
    }

    /**
     * Returns the distance between the point and the center of its nearest cluster.
     *
     * @param iNDArray point to measure from
     * @return distance to the nearest cluster's center
     * @throws IllegalStateException if no nearest cluster can be determined
     */
    public double getDistanceFromNearestCluster(INDArray iNDArray) {
        return getDistance(requireNearestCluster(iNDArray).getCenter(), iNDArray);
    }

    /**
     * Returns the number of clusters in this set.
     *
     * @return the size of the cluster list, or {@code 0} when the list is {@code null}
     */
    public int getClusterCount() {
        List<Cluster> current = getClusters();
        return current == null ? 0 : current.size();
    }

    /** Calls {@code removePoints()} on every cluster in this set. */
    public void removePoints() {
        for (Cluster cluster : getClusters()) {
            cluster.removePoints();
        }
    }

    /**
     * Returns the backing cluster list itself (not a copy), so changes made by the caller are
     * visible to this set.
     *
     * @return the cluster list
     */
    public List<Cluster> getClusters() {
        return this.clusters;
    }

    /**
     * Replaces the backing cluster list with the given list (stored by reference).
     *
     * @param list new cluster list
     */
    public void setClusters(List<Cluster> list) {
        this.clusters = list;
    }

    /**
     * Returns the stored distance function class.
     *
     * @return the distance function class, possibly {@code null}
     */
    public Class<? extends Accumulation> getDistanceFunction() {
        return this.distanceFunction;
    }

    /**
     * Stores the given distance function class.
     *
     * @param cls distance function class to store
     */
    public void setDistanceFunction(Class<? extends Accumulation> cls) {
        this.distanceFunction = cls;
    }
}
