package mulan.classifier.meta.thresholding;

import java.util.logging.Level;
import java.util.logging.Logger;
import mulan.classifier.InvalidDataException;
import mulan.classifier.MultiLabelLearner;
import mulan.classifier.MultiLabelOutput;
import mulan.classifier.meta.MultiLabelMetaLearner;
import mulan.classifier.transformation.BinaryRelevance;
import mulan.core.MulanRuntimeException;
import mulan.data.InvalidDataFormatException;
import mulan.data.LabelsMetaData;
import mulan.data.MultiLabelInstances;
import mulan.evaluation.measure.BipartitionMeasureBase;
import weka.classifiers.trees.J48;
import weka.core.Instance;
import weka.core.TechnicalInformation;
import weka.core.Utils;

/* loaded from: input_file:mulan/classifier/meta/thresholding/RCut.class */
/**
 * RCut (rank-based cut) thresholding meta-learner.
 * <p>
 * Wraps a multi-label learner that produces a label ranking and turns that ranking into a
 * bipartition: every label whose rank is at most {@code t} is predicted relevant. The cut-off
 * {@code t} is chosen in {@link #buildInternal(MultiLabelInstances)}:
 * <ul>
 * <li>no measure: {@code t} is the training-set label cardinality, rounded;</li>
 * <li>measure and {@code folds == 0}: {@code t} is the threshold with the lowest score from
 * {@code computeThreshold} on the training data itself;</li>
 * <li>measure and {@code folds >= 2}: scores are summed over an internal cross-validation and
 * the lowest-scoring threshold is used.</li>
 * </ul>
 * NOTE(review): {@code computeThreshold} currently never accumulates a score (see its comment),
 * so both measure-based strategies end up with {@code t = 0}.
 */
public class RCut extends MultiLabelMetaLearner {

    /** Rank cut-off: label {@code i} is predicted relevant when {@code ranking[i] <= t}. */
    private int t;
    /** Measure used to tune {@code t}; {@code null} selects the cardinality-based cut-off. */
    private BipartitionMeasureBase measure;
    /** Number of internal cross-validation folds; 0 tunes on the training data directly. */
    private int folds;
    /** Template learner copied for cross-validation; null if copying failed or was never requested. */
    private MultiLabelLearner foldLearner;

    /** Creates an RCut over {@code new BinaryRelevance(new J48())}. */
    public RCut() {
        this(new BinaryRelevance(new J48()));
    }

    /**
     * Creates an RCut whose cut-off is the rounded training label cardinality.
     *
     * @param learner the underlying learner; must produce a ranking
     */
    public RCut(MultiLabelLearner learner) {
        super(learner);
        this.t = 0;
    }

    /**
     * Creates an RCut whose cut-off is tuned with the given measure.
     *
     * @param learner the underlying learner; must produce a ranking
     * @param tuningMeasure measure used to select the cut-off
     * @param numFolds number of internal cross-validation folds; 0 tunes on the training data,
     *        otherwise it must be at least 2
     */
    public RCut(MultiLabelLearner learner, BipartitionMeasureBase tuningMeasure, int numFolds) {
        super(learner);
        this.t = 0;
        this.measure = tuningMeasure;
        this.folds = numFolds;
        try {
            this.foldLearner = learner.makeCopy();
        } catch (Exception e) {
            // Best effort: the instance stays usable for folds == 0; autoTuneThreshold reports
            // the missing copy explicitly instead of failing later with a NullPointerException.
            Logger.getLogger(RCut.class.getName()).log(Level.SEVERE,
                    "Could not copy the base learner; cross-validated threshold tuning is unavailable", e);
        }
    }

    /**
     * Creates an RCut whose cut-off is tuned with the given measure on the training data itself.
     *
     * @param learner the underlying learner; must produce a ranking
     * @param tuningMeasure measure used to select the cut-off
     */
    public RCut(MultiLabelLearner learner, BipartitionMeasureBase tuningMeasure) {
        super(learner);
        this.t = 0;
        this.measure = tuningMeasure;
    }

    /**
     * Tunes {@code t} with internal k-fold cross-validation. A single copy of {@code foldLearner}
     * is rebuilt on each fold's training split and scored on the held-out split for every
     * threshold. The per-threshold scores are summed and the minimising index becomes {@code t}.
     *
     * @param data full training data
     * @param tuningMeasure measure passed to {@code computeThreshold}
     * @param numFolds number of folds, at least 2
     * @throws IllegalArgumentException if {@code numFolds < 2}
     * @throws IllegalStateException if the fold learner could not be copied at construction
     */
    private void autoTuneThreshold(MultiLabelInstances data, BipartitionMeasureBase tuningMeasure, int numFolds) throws InvalidDataFormatException, Exception {
        if (numFolds < 2) {
            throw new IllegalArgumentException("folds should be more than 1");
        }
        if (foldLearner == null) {
            throw new IllegalStateException("Cannot tune the RCut threshold with " + numFolds
                    + " folds: copying the base learner failed at construction time");
        }
        double[] totals = new double[numLabels + 1];
        LabelsMetaData labelsMetaData = data.getLabelsMetaData();
        MultiLabelLearner learner = foldLearner.makeCopy();
        for (int fold = 0; fold < numFolds; fold++) {
            MultiLabelInstances train = new MultiLabelInstances(data.getDataSet().trainCV(numFolds, fold), labelsMetaData);
            MultiLabelInstances test = new MultiLabelInstances(data.getDataSet().testCV(numFolds, fold), labelsMetaData);
            learner.build(train);
            double[] foldScores = computeThreshold(learner, test, tuningMeasure);
            for (int k = 0; k < foldScores.length; k++) {
                totals[k] += foldScores[k];
            }
        }
        t = Utils.minIndex(totals);
    }

    /**
     * Intended to score every candidate cut-off {@code 0..numLabels} on {@code data}, where lower
     * is meant to be better.
     * <p>
     * NOTE(review): the true labels and per-threshold bipartitions are computed, but nothing is
     * ever passed to {@code tuningMeasure} or added to the returned array. Every entry therefore
     * stays 0.0, and the callers' {@code Utils.minIndex} presumably always selects threshold 0
     * (no label predicted). TODO: accumulate a per-threshold score, for example the distance of
     * the measure from its ideal value. Confirm the measure update API of the Mulan version in
     * use first.
     *
     * @param learner trained learner that produces rankings
     * @param data instances to evaluate; instances with missing labels are skipped
     * @param tuningMeasure measure to evaluate with; it is reset first
     * @return one score per threshold, length {@code numLabels + 1} (currently all zeros)
     */
    private double[] computeThreshold(MultiLabelLearner learner, MultiLabelInstances data, BipartitionMeasureBase tuningMeasure) throws Exception {
        double[] scores = new double[numLabels + 1];
        tuningMeasure.reset();
        for (int j = 0; j < data.getNumInstances(); j++) {
            Instance instance = data.getDataSet().instance(j);
            if (data.hasMissingLabels(instance)) {
                continue;
            }
            MultiLabelOutput output = learner.makePrediction(instance);
            boolean[] trueLabels = new boolean[numLabels];
            for (int l = 0; l < numLabels; l++) {
                int labelIndex = labelIndices[l];
                trueLabels[l] = instance.attribute(labelIndex).value((int) instance.value(labelIndex)).equals("1");
            }
            int[] ranking = output.getRanking();
            for (int threshold = 0; threshold <= numLabels; threshold++) {
                boolean[] bipartition = new boolean[numLabels];
                for (int l = 0; l < numLabels; l++) {
                    bipartition[l] = ranking[l] <= threshold;
                }
                // TODO: evaluate `bipartition` against `trueLabels` with tuningMeasure and add
                // the result to scores[threshold].
            }
        }
        return scores;
    }

    /**
     * Trains the base learner, checks that it produces rankings, and then selects {@code t}
     * using the strategy described on the class.
     *
     * @param trainingData training data
     * @throws MulanRuntimeException if the base learner's output for the first training
     *         instance has no ranking
     */
    @Override
    protected void buildInternal(MultiLabelInstances trainingData) throws Exception {
        baseLearner.build(trainingData);
        MultiLabelOutput firstOutput = baseLearner.makePrediction(trainingData.getDataSet().firstInstance());
        if (!firstOutput.hasRanking()) {
            throw new MulanRuntimeException("Learner is not a ranker");
        }
        if (measure == null) {
            // Cut at the rounded training label cardinality. The original code overwrote this
            // with a hard-coded 2, which made the computed value dead.
            t = (int) Math.round(trainingData.getCardinality());
        } else if (folds == 0) {
            t = Utils.minIndex(computeThreshold(baseLearner, trainingData, measure));
        } else {
            autoTuneThreshold(trainingData, measure, folds);
        }
    }

    /** Returns the bibliographic reference for RCut (Yang, SIGIR 2001). */
    @Override
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation info = new TechnicalInformation(TechnicalInformation.Type.INPROCEEDINGS);
        info.setValue(TechnicalInformation.Field.AUTHOR, "Yiming Yang");
        info.setValue(TechnicalInformation.Field.TITLE, "A study of thresholding strategies for text categorization");
        info.setValue(TechnicalInformation.Field.BOOKTITLE, "Proceedings of the 24th annual international ACM SIGIR conference on Research and development in information retrieval");
        info.setValue(TechnicalInformation.Field.PAGES, "137 - 145");
        info.setValue(TechnicalInformation.Field.LOCATION, "New Orleans, Louisiana, United States");
        info.setValue(TechnicalInformation.Field.YEAR, "2001");
        return info;
    }

    /**
     * Predicts a bipartition by marking as relevant every label whose rank from the base
     * learner is at most {@code t}.
     *
     * @param instance instance to classify
     * @return bipartition-only output
     */
    @Override
    protected MultiLabelOutput makePredictionInternal(Instance instance) throws Exception, InvalidDataException {
        int[] ranking = baseLearner.makePrediction(instance).getRanking();
        boolean[] bipartition = new boolean[numLabels];
        for (int l = 0; l < numLabels; l++) {
            bipartition[l] = ranking[l] <= t;
        }
        return new MultiLabelOutput(bipartition);
    }

    /** Sets the debug flag on this learner and on the wrapped base learner. */
    @Override
    public void setDebug(boolean debug) {
        super.setDebug(debug);
        baseLearner.setDebug(debug);
    }

    /** Returns a short description of the learner followed by its technical reference. */
    @Override
    public String globalInfo() {
        return "Classs that implements RCut(Rank-based cut). It selects the k top ranked labels for each instance, where k is a parameter provided by the user or automatically tuned." + getTechnicalInformation().toString();
    }
}
