package smile.clustering;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.Callable;
import smile.clustering.PartitionClustering;
import smile.data.SparseDataset;
import smile.math.Math;
import smile.math.SparseArray;
import smile.util.MulticoreExecutor;

/* loaded from: input_file:smile/clustering/SIB.class */
/**
 * Sequential Information Bottleneck clustering.
 * <p>
 * Each cluster is represented by a centroid that starts as the mean of its members.
 * Every point is (re)assigned to the centroid with the smallest Jensen-Shannon
 * divergence. A reassignment is applied immediately: both affected centroids are
 * updated incrementally, so later points in the same pass already see the new
 * centroids. Iteration stops after a pass that moves no point, or after
 * {@code maxIter} passes.
 * <p>
 * Both dense ({@code double[][]}) and sparse ({@link SparseDataset}) inputs are
 * supported. The incremental updates clamp negative centroid entries to zero, so the
 * rows are presumably non-negative (probability-like) vectors — TODO confirm the
 * expected input contract with callers.
 */
public class SIB extends PartitionClustering<double[]> {
    /** Sum over all points of the Jensen-Shannon divergence to their assigned centroid. */
    private double distortion;
    /** Cluster centroids; {@code centroids[c]} tracks the (running) mean of cluster {@code c}. */
    private double[][] centroids;

    /**
     * Task that builds one independently seeded SIB clustering, so that several
     * runs can be executed in parallel. Exactly one of {@code data} and
     * {@code sparse} is non-null.
     */
    static class SIBThread implements Callable<SIB> {
        final double[][] data;
        final SparseDataset sparse;
        final int k;
        final int maxIter;

        SIBThread(double[][] data, int k, int maxIter) {
            this.data = data;
            this.sparse = null;
            this.k = k;
            this.maxIter = maxIter;
        }

        SIBThread(SparseDataset sparse, int k, int maxIter) {
            this.data = null;
            this.sparse = sparse;
            this.k = k;
            this.maxIter = maxIter;
        }

        @Override
        public SIB call() {
            return data != null ? new SIB(data, k, maxIter) : new SIB(sparse, k, maxIter);
        }
    }

    /**
     * Clusters dense data, running at most 100 iterations.
     *
     * @param data the rows to cluster.
     * @param k the number of clusters; must be at least 2.
     * @throws IllegalArgumentException if {@code k < 2}.
     */
    public SIB(double[][] data, int k) {
        this(data, k, 100);
    }

    /**
     * Clusters dense data with a single randomly seeded run.
     *
     * @param data the rows to cluster; all rows are assumed to have the length of {@code data[0]}.
     * @param k the number of clusters; must be at least 2.
     * @param maxIter the maximum number of reassignment passes; must be positive.
     * @throws IllegalArgumentException if {@code k < 2} or {@code maxIter <= 0}.
     */
    public SIB(double[][] data, int k, int maxIter) {
        if (k < 2) {
            throw new IllegalArgumentException("Invalid parameter k = " + k);
        }
        if (maxIter <= 0) {
            throw new IllegalArgumentException("Invalid maximum number of iterations: " + maxIter);
        }

        int n = data.length;
        int d = data[0].length;

        this.k = k;
        this.size = new int[k];
        this.centroids = new double[k][d];
        this.y = seed(data, k, PartitionClustering.DistanceMethod.JENSEN_SHANNON_DIVERGENCE);

        // Initial centroids: mean of the points in each seeded cluster.
        // NOTE(review): a cluster left empty by seeding gets 0/0 = NaN entries here.
        for (int i = 0; i < n; i++) {
            int c = y[i];
            size[c]++;
            for (int j = 0; j < d; j++) {
                centroids[c][j] += data[i][j];
            }
        }
        for (int c = 0; c < k; c++) {
            for (int j = 0; j < d; j++) {
                centroids[c][j] /= size[c];
            }
        }

        int moved = n;
        for (int iter = 1; iter <= maxIter && moved > 0; iter++) {
            moved = 0;
            for (int i = 0; i < n; i++) {
                double nearest = Double.MAX_VALUE;
                int best = -1;
                for (int c = 0; c < k; c++) {
                    double divergence = Math.JensenShannonDivergence(data[i], centroids[c]);
                    if (nearest > divergence) {
                        nearest = divergence;
                        best = c;
                    }
                }

                // best stays -1 when no divergence compares below MAX_VALUE (e.g. all NaN);
                // keep the point where it is instead of indexing centroids[-1].
                if (best >= 0 && best != y[i]) {
                    int old = y[i];

                    // Remove the point from its old cluster's mean.
                    if (size[old] > 1) {
                        int remaining = size[old] - 1;
                        for (int j = 0; j < d; j++) {
                            centroids[old][j] = (centroids[old][j] * size[old] - data[i][j]) / remaining;
                            // Guard against small negative values from floating-point cancellation.
                            if (centroids[old][j] < 0.0) {
                                centroids[old][j] = 0.0;
                            }
                        }
                    } else {
                        Arrays.fill(centroids[old], 0.0);
                    }

                    // Add the point to its new cluster's mean.
                    int grown = size[best] + 1;
                    for (int j = 0; j < d; j++) {
                        centroids[best][j] = (centroids[best][j] * size[best] + data[i][j]) / grown;
                    }

                    size[old]--;
                    size[best]++;
                    y[i] = best;
                    moved++;
                }
            }
        }

        distortion = 0.0;
        for (int i = 0; i < n; i++) {
            distortion += Math.JensenShannonDivergence(data[i], centroids[y[i]]);
        }
    }

    /**
     * Clusters dense data several times (in parallel when possible) and keeps the
     * run with the lowest distortion.
     *
     * @param data the rows to cluster.
     * @param k the number of clusters; must be at least 2.
     * @param maxIter the maximum number of reassignment passes per run; must be positive.
     * @param runs the number of independent runs; must be positive.
     * @throws IllegalArgumentException if any parameter is out of range.
     */
    public SIB(double[][] data, int k, int maxIter, int runs) {
        if (k < 2) {
            throw new IllegalArgumentException("Invalid number of clusters: " + k);
        }
        if (maxIter <= 0) {
            throw new IllegalArgumentException("Invalid maximum number of iterations: " + maxIter);
        }
        if (runs <= 0) {
            throw new IllegalArgumentException("Invalid number of runs: " + runs);
        }

        List<SIBThread> tasks = new ArrayList<SIBThread>(runs);
        for (int r = 0; r < runs; r++) {
            tasks.add(new SIBThread(data, k, maxIter));
        }

        SIB best;
        try {
            List<SIB> results = MulticoreExecutor.run(tasks);
            best = results.get(0);
            for (SIB candidate : results) {
                if (candidate.distortion < best.distortion) {
                    best = candidate;
                }
            }
        } catch (Exception e) {
            if (e instanceof InterruptedException) {
                // This catch would otherwise swallow the interrupt; restore it for callers.
                Thread.currentThread().interrupt();
            }
            System.err.println(e);
            // Best-effort fallback: repeat the runs sequentially on this thread.
            best = new SIB(data, k, maxIter);
            for (int r = 1; r < runs; r++) {
                SIB candidate = new SIB(data, k, maxIter);
                if (candidate.distortion < best.distortion) {
                    best = candidate;
                }
            }
        }

        adopt(best);
    }

    /**
     * Picks initial cluster labels for sparse data. The first center is a uniformly
     * random point. Each further center is sampled with probability proportional to
     * a point's current minimum Jensen-Shannon divergence to the centers chosen so
     * far. Every point is labeled with its closest center.
     *
     * @param data the sparse rows.
     * @param k the number of centers to choose.
     * @return the initial label of each row, in {@code [0, k)}.
     */
    private static int[] seed(SparseDataset data, int k) {
        int n = data.size();
        int[] labels = new int[n];
        double[] minDivergence = new double[n];
        Arrays.fill(minDivergence, Double.MAX_VALUE);

        SparseArray center = (SparseArray) data.get(Math.randomInt(n)).x;

        for (int c = 1; c < k; c++) {
            // Relabel points that are closer to the most recently chosen center (c - 1).
            for (int i = 0; i < n; i++) {
                double divergence = Math.JensenShannonDivergence((SparseArray) data.get(i).x, center);
                if (divergence < minDivergence[i]) {
                    minDivergence[i] = divergence;
                    labels[i] = c - 1;
                }
            }

            // Roulette-wheel selection of the next center.
            double cutoff = Math.random() * Math.sum(minDivergence);
            double cumulative = 0.0;
            int index = 0;
            for (; index < n; index++) {
                cumulative += minDivergence[index];
                if (cumulative >= cutoff) {
                    break;
                }
            }
            // If the cumulative sum never reaches the cutoff (e.g. NaN divergences), the loop
            // ends with index == n; clamp it so it stays a valid row index.
            if (index >= n) {
                index = n - 1;
            }
            center = (SparseArray) data.get(index).x;
        }

        // Assign points to the last chosen center (k - 1) where it is closest.
        for (int i = 0; i < n; i++) {
            double divergence = Math.JensenShannonDivergence((SparseArray) data.get(i).x, center);
            if (divergence < minDivergence[i]) {
                minDivergence[i] = divergence;
                labels[i] = k - 1;
            }
        }

        return labels;
    }

    /**
     * Clusters sparse data, running at most 100 iterations.
     *
     * @param data the sparse rows to cluster.
     * @param k the number of clusters; must be at least 2.
     * @throws IllegalArgumentException if {@code k < 2}.
     */
    public SIB(SparseDataset data, int k) {
        this(data, k, 100);
    }

    /**
     * Clusters sparse data with a single randomly seeded run. Centroids are stored
     * densely with {@code data.ncols()} entries.
     *
     * @param data the sparse rows to cluster.
     * @param k the number of clusters; must be at least 2.
     * @param maxIter the maximum number of reassignment passes; must be positive.
     * @throws IllegalArgumentException if {@code k < 2} or {@code maxIter <= 0}.
     */
    public SIB(SparseDataset data, int k, int maxIter) {
        if (k < 2) {
            throw new IllegalArgumentException("Invalid parameter k = " + k);
        }
        if (maxIter <= 0) {
            throw new IllegalArgumentException("Invalid maximum number of iterations: " + maxIter);
        }

        int n = data.size();
        int d = data.ncols();

        this.k = k;
        this.size = new int[k];
        this.centroids = new double[k][d];
        this.y = seed(data, k);

        // Initial centroids: mean of the points in each seeded cluster.
        // NOTE(review): a cluster left empty by seeding gets 0/0 = NaN entries here.
        for (int i = 0; i < n; i++) {
            int c = y[i];
            size[c]++;
            Iterator<SparseArray.Entry> it = ((SparseArray) data.get(i).x).iterator();
            while (it.hasNext()) {
                SparseArray.Entry e = it.next();
                centroids[c][e.i] += e.x;
            }
        }
        for (int c = 0; c < k; c++) {
            for (int j = 0; j < d; j++) {
                centroids[c][j] /= size[c];
            }
        }

        int moved = n;
        for (int iter = 1; iter <= maxIter && moved > 0; iter++) {
            moved = 0;
            for (int i = 0; i < n; i++) {
                SparseArray x = (SparseArray) data.get(i).x;

                double nearest = Double.MAX_VALUE;
                int best = -1;
                for (int c = 0; c < k; c++) {
                    double divergence = Math.JensenShannonDivergence(x, centroids[c]);
                    if (nearest > divergence) {
                        nearest = divergence;
                        best = c;
                    }
                }

                // best stays -1 when no divergence compares below MAX_VALUE (e.g. all NaN);
                // keep the point where it is instead of indexing centroids[-1].
                if (best >= 0 && best != y[i]) {
                    int old = y[i];

                    // Turn both means back into sums, move the point's mass, then re-average.
                    for (int j = 0; j < d; j++) {
                        centroids[best][j] *= size[best];
                        centroids[old][j] *= size[old];
                    }

                    Iterator<SparseArray.Entry> it = x.iterator();
                    while (it.hasNext()) {
                        SparseArray.Entry e = it.next();
                        centroids[best][e.i] += e.x;
                        centroids[old][e.i] -= e.x;
                        // Guard against small negative values from floating-point cancellation.
                        if (centroids[old][e.i] < 0.0) {
                            centroids[old][e.i] = 0.0;
                        }
                    }

                    size[old]--;
                    size[best]++;

                    for (int j = 0; j < d; j++) {
                        centroids[best][j] /= size[best];
                    }
                    if (size[old] > 0) {
                        for (int j = 0; j < d; j++) {
                            centroids[old][j] /= size[old];
                        }
                    } else {
                        // Emptied cluster: reset to zero, matching the dense constructor,
                        // instead of keeping clamped floating-point residue.
                        Arrays.fill(centroids[old], 0.0);
                    }

                    y[i] = best;
                    moved++;
                }
            }
        }

        distortion = 0.0;
        for (int i = 0; i < n; i++) {
            distortion += Math.JensenShannonDivergence((SparseArray) data.get(i).x, centroids[y[i]]);
        }
    }

    /**
     * Clusters sparse data several times (in parallel when possible) and keeps the
     * run with the lowest distortion.
     *
     * @param data the sparse rows to cluster.
     * @param k the number of clusters; must be at least 2.
     * @param maxIter the maximum number of reassignment passes per run; must be positive.
     * @param runs the number of independent runs; must be positive.
     * @throws IllegalArgumentException if any parameter is out of range.
     */
    public SIB(SparseDataset data, int k, int maxIter, int runs) {
        if (k < 2) {
            throw new IllegalArgumentException("Invalid number of clusters: " + k);
        }
        if (maxIter <= 0) {
            throw new IllegalArgumentException("Invalid maximum number of iterations: " + maxIter);
        }
        if (runs <= 0) {
            throw new IllegalArgumentException("Invalid number of runs: " + runs);
        }

        List<SIBThread> tasks = new ArrayList<SIBThread>(runs);
        for (int r = 0; r < runs; r++) {
            tasks.add(new SIBThread(data, k, maxIter));
        }

        SIB best;
        try {
            List<SIB> results = MulticoreExecutor.run(tasks);
            best = results.get(0);
            for (SIB candidate : results) {
                if (candidate.distortion < best.distortion) {
                    best = candidate;
                }
            }
        } catch (Exception e) {
            if (e instanceof InterruptedException) {
                // This catch would otherwise swallow the interrupt; restore it for callers.
                Thread.currentThread().interrupt();
            }
            System.err.println(e);
            // Best-effort fallback: repeat the runs sequentially on this thread.
            best = new SIB(data, k, maxIter);
            for (int r = 1; r < runs; r++) {
                SIB candidate = new SIB(data, k, maxIter);
                if (candidate.distortion < best.distortion) {
                    best = candidate;
                }
            }
        }

        adopt(best);
    }

    /** Copies the clustering state of {@code best} into this instance (arrays are shared, not copied). */
    private void adopt(SIB best) {
        this.k = best.k;
        this.distortion = best.distortion;
        this.centroids = best.centroids;
        this.y = best.y;
        this.size = best.size;
    }

    /**
     * Returns the index of the centroid with the smallest Jensen-Shannon divergence
     * to {@code x}. Ties go to the lower index; 0 is returned if no divergence is
     * below {@code Double.MAX_VALUE}.
     *
     * @param x a dense vector.
     * @return the predicted cluster index.
     */
    @Override
    public int predict(double[] x) {
        double nearest = Double.MAX_VALUE;
        int best = 0;
        for (int c = 0; c < k; c++) {
            double divergence = Math.JensenShannonDivergence(x, centroids[c]);
            if (divergence < nearest) {
                nearest = divergence;
                best = c;
            }
        }
        return best;
    }

    /**
     * Sparse counterpart of {@link #predict(double[])}.
     *
     * @param x a sparse vector.
     * @return the predicted cluster index.
     */
    public int predict(SparseArray x) {
        double nearest = Double.MAX_VALUE;
        int best = 0;
        for (int c = 0; c < k; c++) {
            double divergence = Math.JensenShannonDivergence(x, centroids[c]);
            if (divergence < nearest) {
                nearest = divergence;
                best = c;
            }
        }
        return best;
    }

    /**
     * Returns the distortion of the clustering.
     *
     * @return the sum of Jensen-Shannon divergences of each point to its assigned centroid.
     */
    public double distortion() {
        return distortion;
    }

    /**
     * Returns the cluster centroids.
     *
     * @return the internal centroid array (not a copy); {@code centroids()[c]} belongs to cluster {@code c}.
     */
    public double[][] centroids() {
        return centroids;
    }

    /** Summarizes the distortion and the size and percentage share of each cluster. */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append(String.format("Sequential Information Bottleneck distortion: %.5f\n", distortion));
        sb.append(String.format("Clusters of %d data points of dimension %d:\n", y.length, centroids[0].length));
        for (int c = 0; c < k; c++) {
            // Share of points in tenths of a percent, printed as "xx.y%".
            int permille = (int) Math.round(1000.0 * size[c] / y.length);
            sb.append(String.format("%3d\t%5d (%2d.%1d%%)\n", c, size[c], permille / 10, permille % 10));
        }
        return sb.toString();
    }
}
