package mulan.classifier.transformation;

import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Random;
import java.util.logging.Level;
import java.util.logging.Logger;
import mulan.classifier.MultiLabelOutput;
import mulan.data.DataUtils;
import mulan.data.MultiLabelInstances;
import mulan.data.Statistics;
import mulan.transformations.BinaryRelevanceTransformation;
import weka.attributeSelection.ASEvaluation;
import weka.attributeSelection.AttributeSelection;
import weka.attributeSelection.Ranker;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.classifiers.lazy.IBk;
import weka.classifiers.meta.FilteredClassifier;
import weka.classifiers.trees.J48;
import weka.core.Attribute;
import weka.core.EuclideanDistance;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.TechnicalInformation;
import weka.core.Utils;
import weka.core.neighboursearch.LinearNNSearch;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Remove;

/* loaded from: input_file:mulan/classifier/transformation/MultiLabelStacking.class */
/**
 * Implementation of the (BR)^2 / multi-label stacking method.
 *
 * <p>A first ("base") level holds one binary model per label. Its out-of-sample
 * predictions on the training data become the input attributes of a second
 * ("meta") level, which again holds one binary model per label. Each meta model
 * may be restricted to a subset of the base-level outputs, chosen either by phi
 * correlation or by a Weka attribute evaluator.
 *
 * <p>When the base classifier is an {@link IBk}, the base level is not trained
 * explicitly; label frequencies among the nearest neighbours are used instead.
 */
public class MultiLabelStacking extends TransformationBasedMultiLabelLearner implements Serializable {
    private static final long serialVersionUID = 1;

    /** Seed of the random generator used to shuffle base-level data before cross-validation. */
    private static final long CV_SEED = 1L;

    /** Prototype classifier copied once per label for the meta level. */
    private Classifier metaClassifier;
    /** Per-label binary training sets of the base level (emptied after training, header kept). */
    private Instances[] baseLevelData;
    /** Per-label training sets of the meta level (emptied after training, header kept). */
    private Instances[] metaLevelData;
    /** Per-label base-level classifiers. */
    private Classifier[] baseLevelEnsemble;
    /** Per-label meta-level classifiers (wrapped by {@link #metaLevelFilteredEnsemble}). */
    private Classifier[] metaLevelEnsemble;
    /** Per-label meta-level classifiers combined with the attribute filter that prunes their input. */
    private FilteredClassifier[] metaLevelFilteredEnsemble;
    /** Number of cross-validation folds used to produce the base-level predictions. */
    private int numFolds;
    /** Private copy of the training data. */
    protected Instances train;
    /** Base-level predictions, indexed as [training instance][label]. */
    private double[][] baseLevelPredictions;
    /** Whether base-level predictions are min-max normalised per label during training. */
    private boolean normalize;
    /** Per-label maximum of the base-level predictions (only populated when normalising). */
    private double[] maxProb;
    /** Per-label minimum of the base-level predictions (only populated when normalising). */
    private double[] minProb;
    /** Whether the original attributes are part of the meta-level data. */
    private boolean includeAttrs;
    /** Fraction of the labels whose base-level outputs each meta model may use. */
    private double metaPercentage;
    /** Number of base-level outputs kept per meta model, derived from {@link #metaPercentage}. */
    private int topkCorrelated;
    /** Per-label attribute indices kept by the meta-level filter (unused when nothing is pruned). */
    private int[][] selectedAttributes;
    /** Attribute evaluator used for pruning; {@code null} selects phi-correlation pruning. */
    private ASEvaluation eval;
    /** Nearest-neighbour search used when the base classifier is an {@link IBk}. */
    private LinearNNSearch lnn;
    /** When {@code true}, {@link #buildInternal(MultiLabelInstances)} does nothing. */
    private boolean partialBuild;

    /**
     * Returns a textual description of this learner, including its technical reference.
     *
     * @return the description
     */
    @Override
    public String globalInfo() {
        return "This class is an implementation of the (BR)^2 or Multi-Label stacking method.\n\nFor more information, see\n\n" + getTechnicalInformation().toString();
    }

    /**
     * Returns the bibliographic reference of the implemented method.
     *
     * @return the technical information
     */
    @Override
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation info = new TechnicalInformation(TechnicalInformation.Type.INPROCEEDINGS);
        info.setValue(TechnicalInformation.Field.AUTHOR, "Grigorios Tsoumakas, Anastasios Dimou, Eleftherios Spyromitros, Vasileios Mezaris, Ioannis Kompatsiaris, Ioannis Vlahavas");
        info.setValue(TechnicalInformation.Field.TITLE, "Correlation-Based Pruning of Stacked Binary Relevance Models for Multi-Label Learning");
        info.setValue(TechnicalInformation.Field.BOOKTITLE, "Proc. ECML/PKDD 2009 Workshop on Learning from Multi-Label Data (MLD'09)");
        info.setValue(TechnicalInformation.Field.YEAR, "2009");
        info.setValue(TechnicalInformation.Field.PAGES, "101-116");
        info.setValue(TechnicalInformation.Field.LOCATION, "Bled, Slovenia");
        return info;
    }

    /** Creates a stacking learner that uses {@link J48} at both levels. */
    public MultiLabelStacking() {
        this(new J48(), new J48());
    }

    /**
     * Creates a stacking learner with 10 folds, no pruning (meta percentage 1.0),
     * no normalisation and without the original attributes at the meta level.
     *
     * @param classifier  the base-level classifier
     * @param classifier2 the meta-level classifier
     */
    public MultiLabelStacking(Classifier classifier, Classifier classifier2) {
        super(classifier);
        this.lnn = null;
        this.metaClassifier = classifier2;
        this.numFolds = 10;
        this.metaPercentage = 1.0d;
        this.eval = null;
        this.normalize = false;
        this.includeAttrs = false;
        this.partialBuild = false;
    }

    /**
     * Builds the base level, initialises the meta level and builds it, unless
     * partial building was requested, in which case nothing is done here.
     *
     * @param multiLabelInstances the training data
     * @throws Exception if any of the underlying build steps fails
     */
    @Override
    protected void buildInternal(MultiLabelInstances multiLabelInstances) throws Exception {
        if (this.partialBuild) {
            // NOTE(review): presumably the caller then invokes the public build steps itself - confirm.
            return;
        }
        if (this.baseClassifier instanceof IBk) {
            buildBaseLevelKNN(multiLabelInstances);
        } else {
            buildBaseLevel(multiLabelInstances);
        }
        initializeMetaLevel(multiLabelInstances, this.metaClassifier, this.includeAttrs, this.metaPercentage, this.eval);
        buildMetaLevel();
    }

    /**
     * Prepares the meta level: copies the meta classifier once per label and, when
     * fewer than all labels are to be used, selects the base-level outputs each
     * meta model keeps.
     *
     * <p>The attribute-evaluator path reads {@link #train}, so a base level must
     * have been built before this method is called with a non-null evaluator.
     *
     * @param multiLabelInstances the training data (used for the phi-correlation statistics)
     * @param classifier          the meta-level classifier prototype
     * @param z                   whether the original attributes are included at the meta level
     * @param d                   fraction of labels to keep per meta model; the resulting count is
     *                            {@code floor(d * numLabels)}, raised to 1 if smaller
     * @param aSEvaluation        attribute evaluator used for ranking, or {@code null} for phi correlation
     * @throws Exception if copying the classifier or selecting attributes fails
     */
    public void initializeMetaLevel(MultiLabelInstances multiLabelInstances, Classifier classifier, boolean z, double d, ASEvaluation aSEvaluation) throws Exception {
        this.metaClassifier = classifier;
        this.metaLevelEnsemble = AbstractClassifier.makeCopies(classifier, this.numLabels);
        this.metaLevelData = new Instances[this.numLabels];
        this.metaLevelFilteredEnsemble = new FilteredClassifier[this.numLabels];
        this.includeAttrs = z;
        this.topkCorrelated = (int) Math.floor(d * this.numLabels);
        if (this.topkCorrelated < 1) {
            debug("Too small percentage, selecting k=1");
            this.topkCorrelated = 1;
        }
        if (this.topkCorrelated >= this.numLabels) {
            // Every base-level output is kept, so there is nothing to select.
            return;
        }

        // One row per label; each row is filled by the chosen selection strategy.
        this.selectedAttributes = new int[this.numLabels][];
        if (aSEvaluation == null) {
            Statistics statistics = new Statistics();
            statistics.calculatePhi(multiLabelInstances);
            for (int label = 0; label < this.numLabels; label++) {
                this.selectedAttributes[label] = statistics.topPhiCorrelatedLabels(label, this.topkCorrelated);
            }
            return;
        }

        AttributeSelection attributeSelection = new AttributeSelection();
        Ranker ranker = new Ranker();
        ranker.setNumToSelect(this.topkCorrelated);
        attributeSelection.setEvaluator(aSEvaluation);
        attributeSelection.setSearch(ranker);
        for (int target = 0; target < this.numLabels; target++) {
            // Temporary data set: all labels as inputs plus a copy of the target label as class.
            ArrayList<Attribute> attributes = new ArrayList<Attribute>(this.numLabels + 1);
            for (int label = 0; label < this.numLabels; label++) {
                attributes.add(this.train.attribute(this.labelIndices[label]));
            }
            attributes.add(this.train.attribute(this.labelIndices[target]).copy("meta"));
            Instances selectionData = new Instances("Meta format", attributes, 0);
            selectionData.setClassIndex(this.numLabels);
            for (int row = 0; row < this.train.numInstances(); row++) {
                double[] values = new double[this.numLabels + 1];
                for (int label = 0; label < this.numLabels; label++) {
                    values[label] = parsedLabelValue(row, this.labelIndices[label]);
                }
                values[this.numLabels] = parsedLabelValue(row, this.labelIndices[target]);
                Instance selectionInstance = DataUtils.createInstance(this.train.instance(row), 1.0d, values);
                selectionInstance.setDataset(selectionData);
                selectionData.add(selectionInstance);
            }
            attributeSelection.SelectAttributes(selectionData);
            this.selectedAttributes[target] = attributeSelection.selectedAttributes();
            selectionData.delete();
        }
    }

    /**
     * Builds the base level with an explicit classifier per label.
     *
     * <p>For every label the binary data set is shuffled, stratified and
     * cross-validated; the probability of value "1" predicted for each held-out
     * instance is stored in {@link #baseLevelPredictions}. Afterwards the label's
     * classifier is rebuilt on the full binary data set. If normalisation is
     * enabled, the stored predictions are rescaled once all labels are processed.
     *
     * @param multiLabelInstances the training data
     * @throws Exception if transformation, filtering or training fails
     */
    public void buildBaseLevel(MultiLabelInstances multiLabelInstances) throws Exception {
        this.train = new Instances(multiLabelInstances.getDataSet());
        this.baseLevelData = new Instances[this.numLabels];
        this.baseLevelEnsemble = AbstractClassifier.makeCopies(this.baseClassifier, this.numLabels);
        if (this.normalize) {
            this.maxProb = new double[this.numLabels];
            this.minProb = new double[this.numLabels];
            Arrays.fill(this.minProb, 1.0d);
        }
        this.baseLevelPredictions = new double[this.train.numInstances()][this.numLabels];

        for (int label = 0; label < this.numLabels; label++) {
            debug("Label: " + label);
            Instances binaryData = BinaryRelevanceTransformation.transformInstances(this.train, this.labelIndices, this.labelIndices[label]);
            // The extra first attribute remembers each instance's original position,
            // so predictions can be stored in training order after shuffling.
            this.baseLevelData[label] = new Instances(attachIndexes(binaryData));
            Random random = new Random(CV_SEED);
            this.baseLevelData[label].randomize(random);
            this.baseLevelData[label].stratify(this.numFolds);
            int positiveIndex = this.baseLevelData[label].classAttribute().indexOfValue("1");

            debug("Creating meta-data");
            for (int fold = 0; fold < this.numFolds; fold++) {
                debug("Label=" + label + ", Fold=" + fold);
                Instances foldTrain = this.baseLevelData[label].trainCV(this.numFolds, fold, random);
                FilteredClassifier foldClassifier = new FilteredClassifier();
                foldClassifier.setClassifier(this.baseLevelEnsemble[label]);
                // Hide the index attribute from the learner.
                Remove dropIndex = new Remove();
                dropIndex.setAttributeIndices("first");
                dropIndex.setInputFormat(foldTrain);
                foldClassifier.setFilter(dropIndex);
                foldClassifier.buildClassifier(foldTrain);

                Instances foldTest = this.baseLevelData[label].testCV(this.numFolds, fold);
                for (int t = 0; t < foldTest.numInstances(); t++) {
                    Instance testInstance = foldTest.instance(t);
                    double positiveProb = foldClassifier.distributionForInstance(testInstance)[positiveIndex];
                    this.baseLevelPredictions[(int) testInstance.value(0)][label] = positiveProb;
                    if (this.normalize) {
                        if (positiveProb > this.maxProb[label]) {
                            this.maxProb[label] = positiveProb;
                        }
                        if (positiveProb < this.minProb[label]) {
                            this.minProb[label] = positiveProb;
                        }
                    }
                }
            }

            this.baseLevelData[label] = detachIndexes(this.baseLevelData[label]);
            debug("Building base classifier on full data");
            this.baseLevelEnsemble[label].buildClassifier(this.baseLevelData[label]);
            // Free the instances; the header is still needed at prediction time.
            this.baseLevelData[label].delete();
        }

        if (this.normalize) {
            normalizePredictions();
        }
    }

    /**
     * Trains one filtered meta-level classifier per label on the base-level
     * predictions (optionally preceded by the original feature values).
     *
     * <p>Requires {@link #train}, {@link #baseLevelPredictions} and the structures
     * created by {@code initializeMetaLevel} to be in place.
     *
     * @throws Exception if filtering or training fails
     */
    public void buildMetaLevel() throws Exception {
        debug("Building the ensemle of the meta level classifiers");
        for (int label = 0; label < this.numLabels; label++) {
            ArrayList<Attribute> attributes = new ArrayList<Attribute>();
            if (this.includeAttrs) {
                for (int a = 0; a < this.train.numAttributes(); a++) {
                    attributes.add(this.train.attribute(a));
                }
            } else {
                for (int l = 0; l < this.numLabels; l++) {
                    attributes.add(this.train.attribute(this.labelIndices[l]));
                }
            }
            attributes.add(this.train.attribute(this.labelIndices[label]).copy("meta"));

            Instances metaData = new Instances("Meta format", attributes, 0);
            metaData.setClassIndex(metaData.numAttributes() - 1);
            this.metaLevelData[label] = metaData;

            for (int row = 0; row < this.train.numInstances(); row++) {
                double[] values = new double[metaData.numAttributes()];
                if (this.includeAttrs) {
                    for (int f = 0; f < this.featureIndices.length; f++) {
                        values[f] = this.train.instance(row).value(this.featureIndices[f]);
                    }
                    // Predictions occupy the last numLabels slots before the class value.
                    System.arraycopy(this.baseLevelPredictions[row], 0, values, this.train.numAttributes() - this.numLabels, this.numLabels);
                } else {
                    System.arraycopy(this.baseLevelPredictions[row], 0, values, 0, this.numLabels);
                }
                int classSlot = values.length - 1;
                values[classSlot] = parsedLabelValue(row, this.labelIndices[label]);
                Instance metaInstance = DataUtils.createInstance(this.train.instance(row), 1.0d, values);
                metaInstance.setDataset(metaData);
                // Re-set the class by name so it maps to the right nominal index.
                metaInstance.setClassValue(values[classSlot] > 0.5d ? "1" : "0");
                metaData.add(metaInstance);
            }

            FilteredClassifier filtered = new FilteredClassifier();
            filtered.setClassifier(this.metaLevelEnsemble[label]);
            Remove keepSelected = new Remove();
            if (this.topkCorrelated < this.numLabels) {
                keepSelected.setAttributeIndicesArray(this.selectedAttributes[label]);
            } else {
                keepSelected.setAttributeIndices("first-last");
            }
            // Inverted selection: the listed attributes are the ones that are kept.
            keepSelected.setInvertSelection(true);
            keepSelected.setInputFormat(metaData);
            filtered.setFilter(keepSelected);
            this.metaLevelFilteredEnsemble[label] = filtered;

            debug("Building classifier for meta training set" + label);
            filtered.buildClassifier(metaData);
            // Free the instances; the header is still needed at prediction time.
            metaData.delete();
        }
    }

    /**
     * Builds the base level for an {@link IBk} base classifier.
     *
     * <p>No classifier is trained. For each training instance the k nearest
     * neighbours are retrieved (distance computed with the label attributes
     * excluded, identical instances skipped) and the fraction of neighbours
     * having value "1" for a label is stored as that label's base-level
     * prediction. This path does not normalise the predictions.
     *
     * @param multiLabelInstances the training data
     * @throws Exception if the neighbour search fails
     */
    public void buildBaseLevelKNN(MultiLabelInstances multiLabelInstances) throws Exception {
        this.train = new Instances(multiLabelInstances.getDataSet());
        EuclideanDistance distance = new EuclideanDistance();
        distance.setDontNormalize(false);

        // Comma-separated, 1-based label indices; inverted below so labels are ignored.
        StringBuilder labelRange = new StringBuilder();
        for (int label = 0; label < this.numLabels - 1; label++) {
            labelRange.append(this.labelIndices[label] + 1).append(",");
        }
        labelRange.append(this.labelIndices[this.numLabels - 1] + 1);
        distance.setAttributeIndices(labelRange.toString());
        distance.setInvertSelection(true);

        this.lnn = new LinearNNSearch();
        this.lnn.setSkipIdentical(true);
        this.lnn.setDistanceFunction(distance);
        this.lnn.setInstances(this.train);
        this.lnn.setMeasurePerformance(false);

        this.baseLevelPredictions = new double[this.train.numInstances()][this.numLabels];
        // buildInternal only routes here when the base classifier is an IBk.
        int k = ((IBk) this.baseClassifier).getKNN();
        for (int row = 0; row < this.train.numInstances(); row++) {
            Instances neighbours = new Instances(this.lnn.kNearestNeighbours(this.train.instance(row), k));
            for (int label = 0; label < this.numLabels; label++) {
                int labelIndex = this.labelIndices[label];
                Attribute labelAttribute = this.train.attribute(labelIndex);
                double positives = 0.0d;
                for (int n = 0; n < k; n++) {
                    if (labelAttribute.value((int) neighbours.instance(n).value(labelIndex)).equals("1")) {
                        positives += 1.0d;
                    }
                }
                this.baseLevelPredictions[row][label] = positives / k;
            }
        }
    }

    /**
     * Min-max normalises the stored base-level predictions per label:
     * {@code (p - min) / (max - min)}.
     *
     * <p>Labels whose observed range is not positive (all predictions equal, or no
     * prediction recorded) are left untouched to avoid dividing by zero.
     *
     * <p>NOTE(review): base-level outputs are not normalised at prediction time;
     * confirm whether that asymmetry is intended.
     */
    private void normalizePredictions() {
        for (double[] predictions : this.baseLevelPredictions) {
            for (int label = 0; label < this.numLabels; label++) {
                double range = this.maxProb[label] - this.minProb[label];
                if (range > 0.0d) {
                    predictions[label] = (predictions[label] - this.minProb[label]) / range;
                }
            }
        }
    }

    /**
     * Predicts the labels of an instance by feeding the base-level outputs
     * (optionally preceded by the instance's feature values) to the meta-level
     * classifiers.
     *
     * @param instance the instance to predict
     * @return the bipartition and per-label confidences, or {@code null} if a
     *         meta-level classifier throws (the exception is printed to standard output)
     * @throws Exception if the base level fails
     */
    @Override
    protected MultiLabelOutput makePredictionInternal(Instance instance) throws Exception {
        boolean[] bipartition = new boolean[this.numLabels];
        double[] confidences = new double[this.numLabels];
        double[] basePredictions = new double[this.numLabels];

        if (this.baseClassifier instanceof IBk) {
            int k = ((IBk) this.baseClassifier).getKNN();
            Instances neighbours = new Instances(this.lnn.kNearestNeighbours(instance, k));
            for (int label = 0; label < this.numLabels; label++) {
                int labelIndex = this.labelIndices[label];
                Attribute labelAttribute = this.train.attribute(labelIndex);
                double positives = 0.0d;
                for (int n = 0; n < k; n++) {
                    String labelValue = labelAttribute.value((int) neighbours.instance(n).value(labelIndex));
                    if (Utils.eq(Double.parseDouble(labelValue), 1.0d)) {
                        positives += 1.0d;
                    }
                }
                basePredictions[label] = positives / k;
            }
        } else {
            for (int label = 0; label < this.numLabels; label++) {
                Instance binaryInstance = BinaryRelevanceTransformation.transformInstance(instance, this.labelIndices, this.labelIndices[label]);
                binaryInstance.setDataset(this.baseLevelData[label]);
                int positiveIndex = this.baseLevelData[label].classAttribute().indexOfValue("1");
                basePredictions[label] = this.baseLevelEnsemble[label].distributionForInstance(binaryInstance)[positiveIndex];
            }
        }

        // Same layout as the meta-level training rows; the last slot is the class value.
        double[] metaValues;
        if (this.includeAttrs) {
            metaValues = new double[instance.numAttributes() + 1];
            for (int f = 0; f < this.featureIndices.length; f++) {
                metaValues[f] = instance.value(this.featureIndices[f]);
            }
            System.arraycopy(basePredictions, 0, metaValues, instance.numAttributes() - this.numLabels, basePredictions.length);
        } else {
            metaValues = new double[this.numLabels + 1];
            System.arraycopy(basePredictions, 0, metaValues, 0, basePredictions.length);
        }

        for (int label = 0; label < this.numLabels; label++) {
            metaValues[metaValues.length - 1] = 0.0d;
            try {
                Instance metaInstance = DataUtils.createInstance(instance, 1.0d, metaValues);
                double[] distribution = this.metaLevelFilteredEnsemble[label].distributionForInstance(metaInstance);
                int predicted = distribution[0] > distribution[1] ? 0 : 1;
                Attribute classAttribute = this.metaLevelData[label].classAttribute();
                bipartition[label] = classAttribute.value(predicted).equals("1");
                confidences[label] = distribution[classAttribute.indexOfValue("1")];
            } catch (Exception e) {
                System.out.println(e);
                return null;
            }
        }
        return new MultiLabelOutput(bipartition, confidences);
    }

    /**
     * Returns a copy of the data with a numeric "Index" attribute prepended whose
     * value is each instance's position in the given data set.
     *
     * @param instances the data to copy
     * @return the indexed copy, with the class index shifted by one
     */
    protected Instances attachIndexes(Instances instances) {
        ArrayList<Attribute> attributes = new ArrayList<Attribute>(instances.numAttributes() + 1);
        attributes.add(new Attribute("Index"));
        for (int a = 0; a < instances.numAttributes(); a++) {
            attributes.add(instances.attribute(a));
        }
        Instances indexed = new Instances("Meta format", attributes, 0);
        for (int row = 0; row < instances.numInstances(); row++) {
            Instance copy = (Instance) instances.instance(row).copy();
            // Detach from the old header before changing the attribute count.
            copy.setDataset((Instances) null);
            copy.insertAttributeAt(0);
            copy.setValue(0, row);
            indexed.add(copy);
        }
        indexed.setClassIndex(instances.classIndex() + 1);
        return indexed;
    }

    /**
     * Removes the first attribute of the given data, i.e. the index column added
     * by {@link #attachIndexes(Instances)}.
     *
     * @param instances the indexed data
     * @return the data without its first attribute
     * @throws Exception if the filter cannot be applied
     */
    protected Instances detachIndexes(Instances instances) throws Exception {
        Remove dropFirst = new Remove();
        dropFirst.setAttributeIndices("first");
        dropFirst.setInputFormat(instances);
        return Filter.useFilter(instances, dropFirst);
    }

    /**
     * Serialises this learner to a file. I/O failures are logged at
     * {@link Level#SEVERE} and not propagated.
     *
     * @param str path of the file to write
     */
    public void saveObject(String str) {
        // try-with-resources closes both streams even when writing fails.
        try (FileOutputStream fileOut = new FileOutputStream(str);
                ObjectOutputStream objectOut = new ObjectOutputStream(fileOut)) {
            objectOut.writeObject(this);
        } catch (IOException e) {
            Logger.getLogger(MultiLabelStacking.class.getName()).log(Level.SEVERE, null, e);
        }
    }

    /**
     * Enables or disables normalisation of the base-level predictions.
     *
     * @param z {@code true} to normalise
     */
    public void setNormalize(boolean z) {
        this.normalize = z;
    }

    /**
     * Sets whether the original attributes are included at the meta level.
     *
     * @param z {@code true} to include them
     */
    public void setIncludeAttrs(boolean z) {
        this.includeAttrs = z;
    }

    /**
     * Sets the fraction of labels whose base-level outputs each meta model uses.
     *
     * @param d the fraction
     */
    public void setMetaPercentage(double d) {
        this.metaPercentage = d;
    }

    /**
     * Sets the attribute evaluator used for pruning the meta-level input.
     *
     * @param aSEvaluation the evaluator, or {@code null} for phi-correlation pruning
     */
    public void setEval(ASEvaluation aSEvaluation) {
        this.eval = aSEvaluation;
    }

    /**
     * Replaces the meta-level classifier and re-creates the per-label copies.
     *
     * @param classifier the new meta-level classifier prototype
     * @throws Exception if the classifier cannot be copied
     */
    public void setMetaAlgorithm(Classifier classifier) throws Exception {
        this.metaClassifier = classifier;
        this.metaLevelEnsemble = AbstractClassifier.makeCopies(classifier, this.numLabels);
    }

    /**
     * Sets whether {@code buildInternal} should skip all build steps.
     *
     * @param z {@code true} to skip
     */
    public void setPartialBuild(boolean z) {
        this.partialBuild = z;
    }

    /**
     * Returns the number of base-level outputs kept per meta model.
     *
     * @return the count computed by {@code initializeMetaLevel}
     */
    public int getTopkCorrelated() {
        return this.topkCorrelated;
    }

    /**
     * Reads a nominal attribute of a training instance and parses its string
     * value as a number.
     *
     * @param row            index of the instance in {@link #train}
     * @param attributeIndex index of the attribute to read
     * @return the parsed value
     */
    private double parsedLabelValue(int row, int attributeIndex) {
        int valueIndex = (int) this.train.instance(row).value(attributeIndex);
        return Double.parseDouble(this.train.attribute(attributeIndex).value(valueIndex));
    }
}
