package mulan.classifier.meta.thresholding;

import java.util.Arrays;
import java.util.logging.Level;
import java.util.logging.Logger;
import mulan.classifier.InvalidDataException;
import mulan.classifier.MultiLabelLearner;
import mulan.classifier.MultiLabelOutput;
import mulan.classifier.meta.MultiLabelMetaLearner;
import mulan.classifier.transformation.BinaryRelevance;
import mulan.core.MulanRuntimeException;
import mulan.data.LabelsMetaData;
import mulan.data.MultiLabelInstances;
import mulan.evaluation.measure.BipartitionMeasureBase;
import mulan.evaluation.measure.HammingLoss;
import weka.classifiers.trees.J48;
import weka.core.Instance;
import weka.core.TechnicalInformation;
import weka.core.Utils;

/* loaded from: input_file:mulan/classifier/meta/thresholding/OneThreshold.class */
/**
 * Meta learner that estimates a single confidence threshold shared by all labels and
 * examples, and turns the base learner's confidences into a bipartition with it.
 *
 * <p>The threshold is found by a two-stage grid search. First a coarse grid over [0, 1] is
 * tried with step 0.1. Then a finer grid with step 0.01 is tried within +/-0.05 of the best
 * coarse value. The candidate chosen is the one whose bipartition measure value is closest
 * to that measure's ideal value. When a number of folds is given, a threshold is estimated on
 * each held-out fold (using a learner trained on the remaining folds) and the mean is used.
 * Otherwise the threshold is estimated on the training data with the trained base learner.
 */
public class OneThreshold extends MultiLabelMetaLearner {
    /** Step of the first (coarse) search stage over [0, 1]. */
    private static final double COARSE_STEP = 0.1d;
    /** Step of the second (fine) search stage. */
    private static final double FINE_STEP = 0.01d;
    /** Half-width of the fine search window around the coarse-stage threshold. */
    private static final double FINE_HALF_WIDTH = 0.05d;

    /** Threshold applied to every label confidence; set by {@link #buildInternal}. */
    private double threshold;
    /** Measure whose value the threshold search tries to bring closest to its ideal value. */
    private final BipartitionMeasureBase measure;
    /** Number of cross-validation folds, or 0 to estimate the threshold on the training set. */
    private final int folds;
    /**
     * Copy of the learner taken at construction time, used as a template for the per-fold
     * learners. It is null when folds is 0, or when copying failed.
     */
    private final MultiLabelLearner foldLearner;

    /** Creates an instance using BinaryRelevance(J48), Hamming loss and 3 folds. */
    public OneThreshold() {
        this(new BinaryRelevance(new J48()), new HammingLoss(), 3);
    }

    /**
     * Creates an instance that estimates the threshold with internal cross-validation.
     *
     * @param multiLabelLearner the underlying multi-label learner
     * @param bipartitionMeasureBase the measure to optimize
     * @param numFolds number of cross-validation folds; must be at least 2
     * @throws IllegalArgumentException if {@code numFolds < 2}
     */
    public OneThreshold(MultiLabelLearner multiLabelLearner, BipartitionMeasureBase bipartitionMeasureBase, int numFolds) {
        super(multiLabelLearner);
        if (numFolds < 2) {
            throw new IllegalArgumentException("folds should be more than 1");
        }
        this.measure = bipartitionMeasureBase;
        this.folds = numFolds;
        MultiLabelLearner copy = null;
        try {
            copy = multiLabelLearner.makeCopy();
        } catch (Exception e) {
            // Keep the constructor non-throwing for callers; buildInternal reports the failure.
            Logger.getLogger(OneThreshold.class.getName()).log(Level.SEVERE,
                    "Could not copy the learner used for cross-validation folds", e);
        }
        this.foldLearner = copy;
    }

    /**
     * Creates an instance that estimates the threshold on the training set itself.
     *
     * @param multiLabelLearner the underlying multi-label learner
     * @param bipartitionMeasureBase the measure to optimize
     */
    public OneThreshold(MultiLabelLearner multiLabelLearner, BipartitionMeasureBase bipartitionMeasureBase) {
        super(multiLabelLearner);
        this.folds = 0;
        this.measure = bipartitionMeasureBase;
        this.foldLearner = null;
    }

    /**
     * Evaluates the evenly spaced candidate thresholds {@code min, min + step, ..., max} and
     * returns the best one.
     *
     * @param learner trained learner whose confidences are thresholded
     * @param data instances to evaluate on; instances with missing labels are skipped
     * @param measure prototype of the measure to optimize (copied per candidate)
     * @param min smallest candidate threshold
     * @param step distance between consecutive candidates
     * @param max largest candidate threshold
     * @return the candidate whose measure value is closest to the measure's ideal value
     *         (the first such candidate on ties)
     * @throws Exception if prediction or measure copying fails
     */
    private double computeThreshold(MultiLabelLearner learner, MultiLabelInstances data, BipartitionMeasureBase measure, double min, double step, double max) throws Exception {
        final int numCandidates = (int) Math.rint(((max - min) / step) + 1.0d);

        BipartitionMeasureBase[] candidateMeasures = new BipartitionMeasureBase[numCandidates];
        for (int k = 0; k < numCandidates; k++) {
            candidateMeasures[k] = (BipartitionMeasureBase) measure.makeCopy();
            candidateMeasures[k].reset();
        }
        // A candidate whose measure rejected an update is excluded from the selection.
        boolean[] failed = new boolean[numCandidates];

        for (int i = 0; i < data.getNumInstances(); i++) {
            Instance instance = data.getDataSet().instance(i);
            if (data.hasMissingLabels(instance)) {
                continue;
            }
            double[] confidences = learner.makePrediction(instance).getConfidences();
            boolean[] truth = groundTruth(instance);
            // Candidates are derived from an integer index rather than an accumulated sum, so
            // floating-point drift can neither skip the last candidate nor overrun the arrays.
            for (int k = 0; k < numCandidates; k++) {
                try {
                    candidateMeasures[k].update(new MultiLabelOutput(applyThreshold(confidences, min + k * step)), truth);
                } catch (MulanRuntimeException e) {
                    failed[k] = true;
                }
            }
        }

        double[] distances = new double[numCandidates];
        for (int k = 0; k < numCandidates; k++) {
            distances[k] = failed[k]
                    ? Double.MAX_VALUE
                    : Math.abs(measure.getIdealValue() - candidateMeasures[k].getValue());
        }
        return min + (Utils.minIndex(distances) * step);
    }

    /**
     * Runs the two-stage (coarse, then fine) threshold search.
     *
     * @param learner trained learner whose confidences are thresholded
     * @param data instances to evaluate on
     * @param measure measure to optimize
     * @return the threshold selected by the fine stage
     * @throws Exception if prediction or measure copying fails
     */
    private double computeThreshold(MultiLabelLearner learner, MultiLabelInstances data, BipartitionMeasureBase measure) throws Exception {
        double coarse = computeThreshold(learner, data, measure, 0.0d, COARSE_STEP, 1.0d);
        debug("1st stage threshold = " + coarse);
        double fine = computeThreshold(learner, data, measure, coarse - FINE_HALF_WIDTH, FINE_STEP, coarse + FINE_HALF_WIDTH);
        debug("2nd stage threshold = " + fine);
        return fine;
    }

    /**
     * Converts confidences into a bipartition: a label is relevant iff its confidence is
     * greater than or equal to {@code t}.
     */
    private boolean[] applyThreshold(double[] confidences, double t) {
        boolean[] bipartition = new boolean[this.numLabels];
        for (int j = 0; j < this.numLabels; j++) {
            bipartition[j] = confidences[j] >= t;
        }
        return bipartition;
    }

    /** Reads the true label set of an instance: a label is relevant iff its nominal value is "1". */
    private boolean[] groundTruth(Instance instance) {
        boolean[] truth = new boolean[this.numLabels];
        for (int j = 0; j < this.numLabels; j++) {
            int attributeIndex = this.labelIndices[j];
            truth[j] = instance.attribute(attributeIndex).value((int) instance.value(attributeIndex)).equals("1");
        }
        return truth;
    }

    /**
     * Trains the base learner on all the data, then estimates the threshold. With folds, the
     * threshold is the mean of the per-fold thresholds. Without folds, it is estimated on the
     * training data.
     *
     * @throws IllegalStateException if cross-validation is requested but the fold learner
     *         could not be copied at construction time
     */
    @Override // mulan.classifier.MultiLabelLearnerBase
    protected void buildInternal(MultiLabelInstances multiLabelInstances) throws Exception {
        this.baseLearner.build(multiLabelInstances);
        if (this.folds == 0) {
            this.threshold = computeThreshold(this.baseLearner, multiLabelInstances, this.measure);
            return;
        }
        if (this.foldLearner == null) {
            throw new IllegalStateException("Cannot run " + this.folds
                    + "-fold threshold estimation: the learner could not be copied at construction time");
        }
        LabelsMetaData labelsMetaData = multiLabelInstances.getLabelsMetaData();
        double[] foldThresholds = new double[this.folds];
        for (int fold = 0; fold < this.folds; fold++) {
            MultiLabelInstances train = new MultiLabelInstances(multiLabelInstances.getDataSet().trainCV(this.folds, fold), labelsMetaData);
            MultiLabelInstances test = new MultiLabelInstances(multiLabelInstances.getDataSet().testCV(this.folds, fold), labelsMetaData);
            MultiLabelLearner learner = this.foldLearner.makeCopy();
            learner.build(train);
            foldThresholds[fold] = computeThreshold(learner, test, this.measure);
        }
        this.threshold = Utils.mean(foldThresholds);
    }

    /**
     * Predicts with the base learner and thresholds its confidences with the learned threshold.
     * The returned output keeps the base learner's confidences.
     */
    @Override // mulan.classifier.MultiLabelLearnerBase
    protected MultiLabelOutput makePredictionInternal(Instance instance) throws Exception, InvalidDataException {
        double[] confidences = this.baseLearner.makePrediction(instance).getConfidences();
        return new MultiLabelOutput(applyThreshold(confidences, this.threshold), confidences);
    }

    @Override // mulan.classifier.MultiLabelLearnerBase
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation technicalInformation = new TechnicalInformation(TechnicalInformation.Type.INPROCEEDINGS);
        technicalInformation.setValue(TechnicalInformation.Field.AUTHOR, "Read, Jesse and Pfahringer, Bernhard and Holmes, Geoff");
        technicalInformation.setValue(TechnicalInformation.Field.YEAR, "2008");
        technicalInformation.setValue(TechnicalInformation.Field.TITLE, "Multi-label Classification Using Ensembles of Pruned Sets");
        technicalInformation.setValue(TechnicalInformation.Field.BOOKTITLE, "Data Mining, 2008. ICDM '08. Eighth IEEE International Conference on");
        technicalInformation.setValue(TechnicalInformation.Field.PAGES, "995-1000");
        technicalInformation.setValue(TechnicalInformation.Field.LOCATION, "Pisa, Italy");
        return technicalInformation;
    }

    /** @return the threshold estimated by the last call to build (0.0 before building) */
    public double getThreshold() {
        return this.threshold;
    }

    @Override // mulan.classifier.MultiLabelLearnerBase
    public String globalInfo() {
        return "Class that estimates a single threshold for all labels and examples. For more information, see\n\n" + getTechnicalInformation().toString();
    }
}
