package mulan.classifier.meta.thresholding;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.logging.Level;
import java.util.logging.Logger;
import mulan.classifier.MultiLabelLearner;
import mulan.classifier.MultiLabelOutput;
import mulan.classifier.meta.MultiLabelMetaLearner;
import mulan.classifier.transformation.BinaryRelevance;
import mulan.data.MultiLabelInstances;
import mulan.evaluation.measure.BipartitionMeasureBase;
import mulan.evaluation.measure.HammingLoss;
import weka.classifiers.trees.J48;
import weka.core.Instance;
import weka.core.TechnicalInformation;
import weka.core.Utils;

/* loaded from: input_file:mulan/classifier/meta/thresholding/SCut.class */
/**
 * SCut ("score-based local optimization") thresholding meta-learner, after Yang (SIGIR 2001).
 *
 * <p>Wraps a multi-label learner that outputs per-label confidences and learns one confidence
 * threshold per label. A label is predicted relevant when its confidence is greater than or equal
 * to its threshold. Thresholds are chosen by coordinate descent so that a user-supplied
 * bipartition measure gets as close as possible to its ideal value. They are tuned either on the
 * training data itself ({@code kFoldsCV == 0}) or averaged over the test folds of a
 * {@code kFoldsCV}-fold cross-validation.
 */
public class SCut extends MultiLabelMetaLearner {
    private static final Logger LOGGER = Logger.getLogger(SCut.class.getName());

    /** Threshold every label starts from before the first sweep. */
    private static final double INITIAL_THRESHOLD = 0.5d;

    /** Relative change in a label's sweep score below which that label counts as converged. */
    private static final double CONVERGENCE_TOLERANCE = 0.001d;

    /**
     * Upper bound on coordinate-descent sweeps. Safety net in case the per-label choices could in
     * principle keep changing and the convergence test would otherwise never pass.
     */
    private static final int MAX_SWEEPS = 100;

    /** Measure whose distance from its ideal value is minimized when choosing thresholds. */
    BipartitionMeasureBase measure;

    /** Number of cross-validation folds used for tuning; 0 tunes on the training data itself. */
    int kFoldsCV;

    /** Learned threshold per label, in the same order as {@code labelIndices}. */
    double[] thresholds;

    /**
     * Creates an SCut over binary relevance with J48, using Hamming loss as the target measure and
     * 3-fold cross-validation for threshold tuning.
     */
    public SCut() {
        this(new BinaryRelevance(new J48()), new HammingLoss(), 3);
    }

    /**
     * Creates an SCut with an explicit number of tuning folds.
     *
     * @param baseLearner learner whose confidences are thresholded
     * @param measure bipartition measure to optimize
     * @param folds number of cross-validation folds; 0 tunes on the full training set
     */
    public SCut(MultiLabelLearner baseLearner, BipartitionMeasureBase measure, int folds) {
        super(baseLearner);
        this.measure = measure;
        this.kFoldsCV = folds;
    }

    /**
     * Creates an SCut that tunes thresholds on the training set itself (no cross-validation).
     *
     * @param baseLearner learner whose confidences are thresholded
     * @param measure bipartition measure to optimize
     */
    public SCut(MultiLabelLearner baseLearner, BipartitionMeasureBase measure) {
        this(baseLearner, measure, 0);
    }

    @Override
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation info = new TechnicalInformation(TechnicalInformation.Type.INPROCEEDINGS);
        info.setValue(TechnicalInformation.Field.AUTHOR, "Yiming Yang");
        info.setValue(TechnicalInformation.Field.TITLE, "A study of thresholding strategies for text categorization");
        info.setValue(TechnicalInformation.Field.BOOKTITLE, "Proceedings of the 24th annual international ACM SIGIR conference on Research and development in information retrieval");
        info.setValue(TechnicalInformation.Field.PAGES, "137 - 145");
        info.setValue(TechnicalInformation.Field.LOCATION, "New Orleans, Louisiana, United States");
        info.setValue(TechnicalInformation.Field.YEAR, "2001");
        return info;
    }

    /**
     * Tunes one threshold per label on {@code data} using {@code learner}'s confidences.
     *
     * <p>The candidate thresholds for a label are its lowest observed confidence and the midpoints
     * between consecutive sorted confidences. Labels are visited in turn. For each label, every
     * candidate is evaluated with the other labels' current thresholds, and the candidate whose
     * measure value is closest to {@code measure.getIdealValue()} is kept. Sweeps repeat until
     * every label's score (the sum of measure values over its candidates) changes by less than
     * {@link #CONVERGENCE_TOLERANCE} relative to the first sweep, or until {@link #MAX_SWEEPS}
     * sweeps have run.
     *
     * @param learner already-built learner used to score {@code data}
     * @param data instances on which thresholds are tuned
     * @return per-label thresholds; all {@link #INITIAL_THRESHOLD} when {@code data} is empty
     * @throws Exception if copying or updating the measure fails
     */
    private double[] computeThresholds(MultiLabelLearner learner, MultiLabelInstances data) throws Exception {
        final int numInstances = data.getNumInstances();
        double[][] confidences = new double[numInstances][numLabels];
        boolean[][] groundTruth = new boolean[numInstances][numLabels];
        collectPredictions(learner, data, confidences, groundTruth);

        double[] tuned = new double[numLabels];
        Arrays.fill(tuned, INITIAL_THRESHOLD);
        if (numInstances == 0) {
            // Nothing to tune on; picking from an empty candidate list would throw.
            return tuned;
        }

        double[][] sortedConfidences = sortPerLabel(confidences);
        BipartitionMeasureBase evaluator = (BipartitionMeasureBase) measure.makeCopy();
        double ideal = measure.getIdealValue();
        double[] distanceToIdeal = new double[numInstances];
        double[] currentScore = new double[numLabels];
        double[] previousScore = new double[numLabels];
        double[] initialScore = new double[numLabels];

        for (int sweep = 0; sweep < MAX_SWEEPS; sweep++) {
            System.arraycopy(currentScore, 0, previousScore, 0, numLabels);
            for (int label = 0; label < numLabels; label++) {
                double scoreSum = 0.0d;
                // Candidate order (descending) fixes the floating-point summation order of scoreSum.
                for (int c = numInstances - 1; c >= 0; c--) {
                    tuned[label] = candidateThreshold(sortedConfidences[label], c);
                    double value = evaluate(evaluator, confidences, groundTruth, tuned);
                    distanceToIdeal[c] = Math.abs(ideal - value);
                    scoreSum += value;
                }
                tuned[label] = candidateThreshold(sortedConfidences[label], Utils.minIndex(distanceToIdeal));
                currentScore[label] = scoreSum;
                if (sweep == 0) {
                    initialScore[label] = scoreSum;
                }
            }
            // The first sweep has no previous score to compare against, so it never converges
            // (unless there are no labels at all).
            boolean converged = true;
            for (int label = 0; label < numLabels && converged; label++) {
                converged = sweep > 0
                        && hasConverged(currentScore[label], previousScore[label], initialScore[label]);
            }
            if (converged) {
                return tuned;
            }
        }
        LOGGER.log(Level.WARNING,
                "SCut threshold search did not converge within {0} sweeps; using the last thresholds",
                MAX_SWEEPS);
        return tuned;
    }

    /**
     * Fills {@code confidences} with the learner's per-label confidences. Fills {@code groundTruth}
     * with whether each label attribute's nominal value is {@code "1"}. A failed prediction is
     * logged and leaves that instance's row as allocated (all zeros), and the row still takes part
     * in tuning.
     *
     * @throws IllegalStateException if the learner returns no confidences for an instance
     */
    private void collectPredictions(MultiLabelLearner learner, MultiLabelInstances data,
            double[][] confidences, boolean[][] groundTruth) {
        for (int i = 0; i < confidences.length; i++) {
            Instance instance = data.getDataSet().instance(i);
            try {
                confidences[i] = learner.makePrediction(instance).getConfidences();
            } catch (Exception e) {
                LOGGER.log(Level.SEVERE, "Prediction failed for instance " + i + "; its confidences stay at 0", e);
            }
            if (confidences[i] == null) {
                throw new IllegalStateException("SCut requires a base learner that outputs confidences");
            }
            for (int label = 0; label < numLabels; label++) {
                int attributeIndex = labelIndices[label];
                groundTruth[i][label] = data.getDataSet().attribute(attributeIndex)
                        .value((int) instance.value(attributeIndex)).equals("1");
            }
        }
    }

    /** Returns each label's confidences over all instances, sorted in ascending order. */
    private double[][] sortPerLabel(double[][] confidences) {
        double[][] sorted = new double[numLabels][confidences.length];
        for (int label = 0; label < numLabels; label++) {
            for (int i = 0; i < confidences.length; i++) {
                sorted[label][i] = confidences[i][label];
            }
            Arrays.sort(sorted[label]);
        }
        return sorted;
    }

    /**
     * Scores a full threshold vector. Resets {@code evaluator}, feeds it the thresholded
     * bipartition of every instance, and returns the resulting measure value.
     */
    private double evaluate(BipartitionMeasureBase evaluator, double[][] confidences,
            boolean[][] groundTruth, double[] labelThresholds) throws Exception {
        evaluator.reset();
        for (int i = 0; i < confidences.length; i++) {
            boolean[] bipartition = new boolean[numLabels];
            for (int label = 0; label < numLabels; label++) {
                bipartition[label] = confidences[i][label] >= labelThresholds[label];
            }
            evaluator.update(new MultiLabelOutput(bipartition), groundTruth[i]);
        }
        return evaluator.getValue();
    }

    /**
     * Returns candidate threshold {@code c}. For {@code c == 0} this is the lowest score; otherwise
     * it is the midpoint of sorted scores {@code c - 1} and {@code c}.
     */
    private static double candidateThreshold(double[] sortedScores, int c) {
        if (c == 0) {
            return sortedScores[0];
        }
        return (sortedScores[c] + sortedScores[c - 1]) / 2.0d;
    }

    /**
     * Tests whether a label's sweep score has settled. Its change since the previous sweep must be
     * below {@link #CONVERGENCE_TOLERANCE}, relative to the first sweep's score. When that reference
     * score is 0 the relative change is undefined: 0/0 is NaN, which never compares below the
     * tolerance and would keep the search from terminating. In that case the absolute change is
     * used instead.
     */
    private static boolean hasConverged(double current, double previous, double initial) {
        double change = Math.abs(current - previous);
        double scale = Math.abs(initial);
        return scale > 0.0d ? change / scale < CONVERGENCE_TOLERANCE : change < CONVERGENCE_TOLERANCE;
    }

    /**
     * Builds the base learner and learns the per-label thresholds.
     *
     * <p>With {@code kFoldsCV == 0}, the thresholds are tuned on the training set with the learner
     * built on it. Otherwise each fold trains a copy of the base learner and tunes thresholds on
     * the held-out fold. The per-fold thresholds are averaged, and the base learner is finally
     * rebuilt on the full training set.
     */
    @Override
    protected void buildInternal(MultiLabelInstances trainingSet) throws Exception {
        if (kFoldsCV == 0) {
            baseLearner.build(trainingSet);
            thresholds = computeThresholds(baseLearner, trainingSet);
            return;
        }
        double[] summed = new double[numLabels];
        for (int fold = 0; fold < kFoldsCV; fold++) {
            MultiLabelInstances foldTrain = new MultiLabelInstances(
                    trainingSet.getDataSet().trainCV(kFoldsCV, fold), trainingSet.getLabelsMetaData());
            MultiLabelInstances foldTest = new MultiLabelInstances(
                    trainingSet.getDataSet().testCV(kFoldsCV, fold), trainingSet.getLabelsMetaData());
            MultiLabelLearner foldLearner = baseLearner.makeCopy();
            foldLearner.build(foldTrain);
            double[] foldThresholds = computeThresholds(foldLearner, foldTest);
            for (int label = 0; label < numLabels; label++) {
                summed[label] += foldThresholds[label];
            }
        }
        for (int label = 0; label < numLabels; label++) {
            summed[label] /= kFoldsCV;
        }
        // Publish only the fully averaged thresholds, never a partial sum.
        thresholds = summed;
        baseLearner.build(trainingSet);
    }

    /**
     * Predicts with the base learner and applies the learned per-label thresholds. If the base
     * output carries no confidences, every label is predicted irrelevant and the returned
     * confidences are all zero.
     */
    @Override
    public MultiLabelOutput makePredictionInternal(Instance instance) throws Exception {
        MultiLabelOutput baseOutput = baseLearner.makePrediction(instance);
        boolean[] bipartition = new boolean[numLabels];
        double[] confidences = new double[numLabels];
        if (baseOutput.hasConfidences()) {
            confidences = baseOutput.getConfidences();
            for (int label = 0; label < numLabels; label++) {
                bipartition[label] = confidences[label] >= thresholds[label];
            }
        }
        return new MultiLabelOutput(bipartition, confidences);
    }

    @Override
    public String globalInfo() {
        return "Class that implements the SCut method (Score-based local  optimization). It computes a separate threshold for each label based on improving a user defined performance measure.For more information, see\n\n" + getTechnicalInformation().toString();
    }
}
