package mulan.classifier.transformation;

import mulan.classifier.MultiLabelOutput;
import mulan.data.DataUtils;
import mulan.data.MultiLabelInstances;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.classifiers.meta.FilteredClassifier;
import weka.classifiers.trees.J48;
import weka.core.Attribute;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.TechnicalInformation;
import weka.filters.unsupervised.attribute.Remove;

/* loaded from: input_file:mulan/classifier/transformation/ClassifierChain.class */
/**
 * Classifier Chain (CC) multi-label learner.
 *
 * <p>One model is trained per label, in the order given by {@link #chain}. The model at chain
 * position {@code i} predicts label {@code chain[i]}. It sees every attribute of the training data
 * except the label attributes at chain positions {@code i+1 .. numLabels-1}; those are removed
 * with a {@link Remove} filter. Labels earlier in the chain are therefore available as inputs.
 *
 * <p>At prediction time, each model's predicted value is written into a working copy of the
 * instance before the next model in the chain runs.
 */
public class ClassifierChain extends TransformationBasedMultiLabelLearner {
    /** Label order of the chain: {@code chain[i]} is the (0-based) label modelled at position i. */
    private int[] chain;
    /** {@code ensemble[i]} is the filtered model trained for chain position {@code i}. */
    protected FilteredClassifier[] ensemble;

    @Override
    public String globalInfo() {
        return "Class implementing the Classifier Chain (CC) algorithm.\n\nFor more information, see\n\n" + getTechnicalInformation().toString();
    }

    /**
     * Returns the bibliographic reference for CC.
     *
     * <p>The reference is a journal article (Machine Learning 85(3), 2011). Journal, volume and
     * number fields belong to an {@code ARTICLE} entry, not to an {@code INPROCEEDINGS} entry.
     *
     * @return the technical information describing the CC paper
     */
    @Override
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation info = new TechnicalInformation(TechnicalInformation.Type.ARTICLE);
        info.setValue(TechnicalInformation.Field.AUTHOR, "Read, Jesse and Pfahringer, Bernhard and Holmes, Geoff and Frank, Eibe");
        info.setValue(TechnicalInformation.Field.TITLE, "Classifier Chains for Multi-label Classification");
        info.setValue(TechnicalInformation.Field.VOLUME, "85");
        info.setValue(TechnicalInformation.Field.NUMBER, "3");
        info.setValue(TechnicalInformation.Field.YEAR, "2011");
        info.setValue(TechnicalInformation.Field.PAGES, "335--359");
        info.setValue(TechnicalInformation.Field.JOURNAL, "Machine Learning");
        return info;
    }

    /** Creates a chain with {@link J48} as base classifier and the natural label order. */
    public ClassifierChain() {
        super(new J48());
    }

    /**
     * Creates a chain with an explicit label order.
     *
     * @param classifier the base classifier; {@code buildInternal} trains a copy of it per label
     * @param aChain the label order; must be a permutation of {@code 0 .. numLabels-1}
     *     (validated at build time), or {@code null} for the natural order. The array is copied,
     *     so later changes by the caller have no effect.
     */
    public ClassifierChain(Classifier classifier, int[] aChain) {
        super(classifier);
        this.chain = aChain == null ? null : aChain.clone();
    }

    /**
     * Creates a chain with the given base classifier and the natural label order.
     *
     * @param classifier the base classifier; {@code buildInternal} trains a copy of it per label
     */
    public ClassifierChain(Classifier classifier) {
        super(classifier);
    }

    /**
     * Trains one filtered model per chain position.
     *
     * <p>The class index of the training data set is changed while each model is built. It is
     * restored to its original value afterwards, including when training fails.
     *
     * @param multiLabelInstances the training data
     * @throws IllegalArgumentException if a user-supplied chain is not a permutation of the label
     *     positions of the training data
     * @throws Exception if configuring a filter or building any model fails
     */
    @Override
    protected void buildInternal(MultiLabelInstances multiLabelInstances) throws Exception {
        // Read the label count first: the default chain and the validation below are sized by it.
        this.numLabels = multiLabelInstances.getNumLabels();
        if (this.chain == null || (this.chain.length != this.numLabels && isIdentity(this.chain))) {
            // Either no chain was given, or a natural-order chain of a different size is left over
            // from an earlier build. Use natural label order for the current data.
            this.chain = identityChain(this.numLabels);
        } else {
            validateChain(this.chain, this.numLabels);
        }

        this.ensemble = new FilteredClassifier[this.numLabels];
        Instances dataSet = multiLabelInstances.getDataSet();
        int originalClassIndex = dataSet.classIndex();
        try {
            for (int position = 0; position < this.numLabels; position++) {
                dataSet.setClassIndex(this.labelIndices[this.chain[position]]);
                FilteredClassifier model = new FilteredClassifier();
                model.setClassifier(AbstractClassifier.makeCopy(this.baseClassifier));
                model.setFilter(removeLaterLabels(position, dataSet));
                this.ensemble[position] = model;
                debug("Bulding model " + (position + 1) + "/" + this.numLabels);
                model.buildClassifier(dataSet);
            }
        } finally {
            // Do not leave the caller's data set with the last label as its class attribute.
            dataSet.setClassIndex(originalClassIndex);
        }
    }

    /**
     * Builds the filter for one chain position.
     *
     * <p>The filter removes the label attributes at chain positions {@code position+1 ..
     * numLabels-1}, so the model cannot see labels that come later in the chain.
     *
     * @param position the chain position being trained
     * @param dataSet the training data, used as the filter's input format
     * @return the configured filter
     * @throws Exception if the filter rejects the input format
     */
    private Remove removeLaterLabels(int position, Instances dataSet) throws Exception {
        int[] laterLabelAttributes = new int[this.numLabels - 1 - position];
        for (int k = 0; k < laterLabelAttributes.length; k++) {
            // Walk the chain backwards from its last position.
            laterLabelAttributes[k] = this.labelIndices[this.chain[this.numLabels - 1 - k]];
        }
        Remove remove = new Remove();
        remove.setAttributeIndicesArray(laterLabelAttributes);
        // Configure the selection before setInputFormat, which derives the output format from it.
        remove.setInvertSelection(false);
        remove.setInputFormat(dataSet);
        return remove;
    }

    /** Returns {@code [0, 1, ..., size-1]}, the natural label order. */
    private static int[] identityChain(int size) {
        int[] order = new int[size];
        for (int i = 0; i < size; i++) {
            order[i] = i;
        }
        return order;
    }

    /** Returns whether {@code order[i] == i} for every position {@code i}. */
    private static boolean isIdentity(int[] order) {
        for (int i = 0; i < order.length; i++) {
            if (order[i] != i) {
                return false;
            }
        }
        return true;
    }

    /**
     * Checks that {@code order} is a permutation of {@code 0 .. numLabels-1}.
     *
     * @throws IllegalArgumentException if the length is wrong, or an entry is out of range or
     *     repeated
     */
    private static void validateChain(int[] order, int numLabels) {
        if (order.length != numLabels) {
            throw new IllegalArgumentException(
                    "Chain has " + order.length + " entries but the data has " + numLabels + " labels");
        }
        boolean[] seen = new boolean[numLabels];
        for (int label : order) {
            if (label < 0 || label >= numLabels || seen[label]) {
                throw new IllegalArgumentException(
                        "Chain must be a permutation of 0.." + (numLabels - 1) + "; offending entry: " + label);
            }
            seen[label] = true;
        }
    }

    /**
     * Predicts all labels by running the models in chain order.
     *
     * <p>Each model's predicted value is written into a working copy built from
     * {@code instance}'s values, so later models in the chain receive it as input.
     *
     * @param instance the instance to classify
     * @return the bipartition and, per label, the score the corresponding model assigned to the
     *     label value "1"
     * @throws IllegalStateException if a label attribute has no nominal value "1"
     * @throws Exception if a model fails to produce a distribution; the failure is propagated to
     *     the caller rather than printed and turned into a {@code null} result
     */
    @Override
    protected MultiLabelOutput makePredictionInternal(Instance instance) throws Exception {
        boolean[] bipartition = new boolean[this.numLabels];
        double[] confidences = new double[this.numLabels];
        Instance working = DataUtils.createInstance(instance, instance.weight(), instance.toDoubleArray());
        for (int position = 0; position < this.numLabels; position++) {
            int label = this.chain[position];
            double[] distribution = this.ensemble[position].distributionForInstance(working);
            int predictedIndex = distribution[0] > distribution[1] ? 0 : 1;

            // Look "1" up by name so label attributes declared as {0,1} and as {1,0} both work.
            Attribute classAttribute = this.ensemble[position].getFilter().getOutputFormat().classAttribute();
            int positiveIndex = classAttribute.indexOfValue("1");
            if (positiveIndex < 0) {
                throw new IllegalStateException(
                        "Label attribute '" + classAttribute.name() + "' has no nominal value \"1\"");
            }
            bipartition[label] = classAttribute.value(predictedIndex).equals("1");
            confidences[label] = distribution[positiveIndex];

            // NOTE(review): assumes Remove keeps the label attribute's value order, so this index
            // is valid in the unfiltered instance too.
            working.setValue(this.labelIndices[label], predictedIndex);
        }
        return new MultiLabelOutput(bipartition, confidences);
    }
}
