package mulan.data;

import java.io.Serializable;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Random;
import java.util.logging.Level;
import java.util.logging.Logger;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.meta.FilteredClassifier;
import weka.core.Instances;
import weka.filters.unsupervised.attribute.Remove;

/* loaded from: input_file:mulan/data/ConditionalDependenceIdentifier.class */
/**
 * Identifies pairwise label dependencies by testing whether knowing one label
 * improves the cross-validated accuracy of predicting another.
 *
 * <p>For an ordered pair {@code {target, other}} and every cross-validation fold,
 * two {@link FilteredClassifier}s wrapping a copy of the base learner are trained
 * to predict {@code target}:
 * <ul>
 *   <li>a baseline model, with every other label attribute removed;</li>
 *   <li>a conditional model, with every label attribute except {@code other}
 *       removed, so it may use {@code other} as an input.</li>
 * </ul>
 * A paired t statistic over the per-fold accuracies is the dependence score of
 * the pair.
 */
public class ConditionalDependenceIdentifier implements LabelPairsDependenceIdentifier, Serializable {
    /** Learner that is copied for every model trained by this identifier. */
    private Classifier baseLearner;
    /** Seed of the random generator used to shuffle and split the data. */
    protected int seed;
    /**
     * Baseline models (all labels but the target removed), keyed by the removed
     * attribute indices and a fingerprint of the training fold.
     *
     * <p>The cache is per instance, so models built from different base learners
     * are never mixed. It is transient, so it is rebuilt lazily after
     * deserialization instead of being {@code null}. Always access it through
     * {@link #modelCache()}.
     */
    private transient HashMap<String, FilteredClassifier> existingModels;
    /**
     * Value returned by {@link #getCriticalValue()}.
     *
     * <p>3.25 is approximately the two-sided 1% critical value of Student's t
     * with 9 degrees of freedom (numFolds - 1 for the default of 10 folds). This
     * is presumably why it was chosen; confirm before changing either default.
     */
    private double criticalValue = 3.25d;
    /** Number of cross-validation folds used to score each label pair. */
    private int numFolds = 10;

    /**
     * Creates an identifier that trains copies of the given classifier.
     *
     * @param classifier base learner; it is copied for every model and never trained directly
     */
    public ConditionalDependenceIdentifier(Classifier classifier) {
        this.baseLearner = classifier;
    }

    /**
     * Scores every unordered pair of labels.
     *
     * <p>Both orientations of a pair are tested. The orientation with the higher
     * score is kept; on a tie, the lower-indexed label stays first (as target).
     *
     * @param multiLabelInstances the multi-label dataset to analyse
     * @return one {@link LabelsPair} per unordered label pair
     *         ({@code numLabels * (numLabels - 1) / 2} entries), sorted in reverse
     *         natural order of {@code LabelsPair}
     */
    @Override // mulan.data.LabelPairsDependenceIdentifier
    public LabelsPair[] calculateDependence(MultiLabelInstances multiLabelInstances) {
        int numLabels = multiLabelInstances.getNumLabels();
        LabelsPair[] pairs = new LabelsPair[(numLabels * (numLabels - 1)) / 2];
        int next = 0;
        for (int first = 0; first < numLabels - 1; first++) {
            for (int second = first + 1; second < numLabels; second++) {
                int[] forward = {first, second};
                int[] backward = {second, first};
                double forwardScore = testDependence(forward, multiLabelInstances, this.numFolds);
                double backwardScore = testDependence(backward, multiLabelInstances, this.numFolds);
                pairs[next++] = forwardScore >= backwardScore
                        ? new LabelsPair(forward, forwardScore)
                        : new LabelsPair(backward, backwardScore);
            }
        }
        Arrays.sort(pairs, Collections.reverseOrder());
        return pairs;
    }

    /**
     * Measures how much label {@code labelPair[1]} helps predict label {@code labelPair[0]}.
     *
     * @param labelPair           {target, other} as positions in the label index array
     * @param multiLabelInstances dataset providing the attributes and label indices
     * @param folds               number of cross-validation folds (must be at least 2)
     * @return the paired t statistic (non-negative) when the conditional model is
     *         at least as accurate on average; -1 when the baseline is more accurate,
     *         when {@code folds < 2}, or when training/evaluation fails
     */
    private double testDependence(int[] labelPair, MultiLabelInstances multiLabelInstances, int folds) {
        if (folds < 2) {
            // Zero folds would otherwise produce NaN; one fold makes trainCV throw.
            Logger.getLogger(ConditionalDependenceIdentifier.class.getSimpleName())
                    .log(Level.SEVERE, "Number of folds must be at least 2, got {0}", folds);
            return -1.0d;
        }
        try {
            int numLabels = multiLabelInstances.getNumLabels();
            int[] labelIndices = multiLabelInstances.getLabelIndices();
            int target = labelPair[0];
            int classIndex = labelIndices[target];
            // Baseline: drop every label except the target.
            int[] removedForBaseline = labelsToRemove(labelIndices, numLabels, target, target);
            // Conditional: additionally keep the "other" label as an input.
            int[] removedForConditional = labelsToRemove(labelIndices, numLabels, target, labelPair[1]);

            Instances data = new Instances(multiLabelInstances.getDataSet());
            Random random = new Random(this.seed);
            data.randomize(random);

            double[] baselineAccuracy = new double[folds];
            double[] conditionalAccuracy = new double[folds];
            for (int fold = 0; fold < folds; fold++) {
                Instances train = data.trainCV(folds, fold, random);
                Instances test = data.testCV(folds, fold);
                // Label removal happens inside each FilteredClassifier, so the folds
                // stay unfiltered; only the class attribute has to be set.
                train.setClassIndex(classIndex);
                test.setClassIndex(classIndex);
                // Computed once per fold: toString() serialises the whole fold.
                int fingerprint = train.toString().hashCode();

                FilteredClassifier baseline = cachedModel(removedForBaseline, classIndex, train, fingerprint);
                FilteredClassifier conditional = buildModel(removedForConditional, classIndex, train);
                baselineAccuracy[fold] = accuracy(baseline, train, test);
                conditionalAccuracy[fold] = accuracy(conditional, train, test);
            }
            return applyTtest(baselineAccuracy, conditionalAccuracy);
        } catch (Exception e) {
            // Do not run the t-test on partially filled arrays: unfinished folds
            // would count as 0% accuracy and yield a meaningless statistic.
            Logger.getLogger(ConditionalDependenceIdentifier.class.getSimpleName()).log(Level.SEVERE, (String) null, (Throwable) e);
            return -1.0d;
        }
    }

    /**
     * Collects the attribute indices of all labels except the two kept positions.
     *
     * @param labelIndices attribute index of each label
     * @param numLabels    number of labels to scan
     * @param keepFirst    label position to keep
     * @param keepSecond   second label position to keep (may equal {@code keepFirst})
     * @return attribute indices to remove, in label order
     */
    private static int[] labelsToRemove(int[] labelIndices, int numLabels, int keepFirst, int keepSecond) {
        int[] buffer = new int[numLabels];
        int size = 0;
        for (int k = 0; k < numLabels; k++) {
            if (k != keepFirst && k != keepSecond) {
                buffer[size++] = labelIndices[k];
            }
        }
        return Arrays.copyOf(buffer, size);
    }

    /**
     * Computes the percentage of correctly classified test instances.
     *
     * @param model trained classifier
     * @param train training fold, used as the evaluation header
     * @param test  held-out fold
     * @return percent correct on {@code test}
     * @throws Exception if Weka evaluation fails
     */
    private static double accuracy(Classifier model, Instances train, Instances test) throws Exception {
        Evaluation evaluation = new Evaluation(train);
        evaluation.evaluateModel(model, test, new Object[0]);
        return evaluation.pctCorrect();
    }

    /**
     * Paired t statistic for the per-fold accuracy differences.
     *
     * <p>NOTE(review): when the per-fold differences have zero variance, this
     * returns 0 even if the conditional model is consistently better.
     *
     * @param baseline    per-fold accuracies of the baseline models
     * @param conditional per-fold accuracies of the conditional models
     * @return -1 if the baseline mean is higher, otherwise |t|
     */
    private double applyTtest(double[] baseline, double[] conditional) {
        int n = baseline.length;
        double baselineSum = 0.0d;
        double conditionalSum = 0.0d;
        for (int k = 0; k < n; k++) {
            baselineSum += baseline[k];
            conditionalSum += conditional[k];
        }
        double baselineMean = baselineSum / n;
        double conditionalMean = conditionalSum / n;
        if (baselineMean > conditionalMean) {
            return -1.0d;
        }
        double squaredDeviations = 0.0d;
        for (int k = 0; k < n; k++) {
            double deviation = (baseline[k] - baselineMean) - (conditional[k] - conditionalMean);
            squaredDeviations += deviation * deviation;
        }
        double scale = squaredDeviations != 0.0d ? Math.sqrt((n * (n - 1)) / squaredDeviations) : 0.0d;
        return Math.abs((baselineMean - conditionalMean) * scale);
    }

    /**
     * Returns the cached baseline model for this fold, building and caching it on a miss.
     *
     * <p>The key omits the class index. That is safe because the removed set
     * (all labels but the target) determines the target uniquely.
     *
     * @param removed     attribute indices removed by the model's filter
     * @param classIndex  attribute index of the target label
     * @param train       training fold
     * @param fingerprint hash of {@code train.toString()}
     * @return the trained model
     * @throws Exception if training fails
     */
    private FilteredClassifier cachedModel(int[] removed, int classIndex, Instances train, int fingerprint) throws Exception {
        HashMap<String, FilteredClassifier> cache = modelCache();
        String key = createKey(removed, fingerprint);
        FilteredClassifier model = cache.get(key);
        if (model == null) {
            model = buildModel(removed, classIndex, train);
            cache.put(key, model);
        }
        return model;
    }

    /**
     * Lazily creates the per-instance model cache.
     *
     * <p>The field is transient, so it is null after construction and after deserialization.
     *
     * @return the model cache, never null
     */
    private HashMap<String, FilteredClassifier> modelCache() {
        if (existingModels == null) {
            existingModels = new HashMap<>();
        }
        return existingModels;
    }

    /**
     * Trains a copy of the base learner behind a {@link Remove} filter.
     *
     * @param removed    attribute indices (0-based) removed before learning
     * @param classIndex attribute index of the class; set on {@code instances}
     * @param instances  training data (its class index is modified)
     * @return the trained model
     * @throws Exception if filter setup or training fails
     */
    private FilteredClassifier buildModel(int[] removed, int classIndex, Instances instances) throws Exception {
        instances.setClassIndex(classIndex);
        Remove remove = new Remove();
        remove.setAttributeIndicesArray(removed);
        remove.setInvertSelection(false);
        remove.setInputFormat(instances);
        FilteredClassifier filteredClassifier = new FilteredClassifier();
        filteredClassifier.setClassifier(AbstractClassifier.makeCopy(this.baseLearner));
        filteredClassifier.setFilter(remove);
        filteredClassifier.buildClassifier(instances);
        return filteredClassifier;
    }

    /**
     * Builds a cache key of the form {@code _i1_i2_..._fingerprint}.
     *
     * @param indices     attribute indices
     * @param fingerprint data fingerprint appended at the end
     * @return the cache key
     */
    private String createKey(int[] indices, int fingerprint) {
        StringBuilder sb = new StringBuilder("_");
        for (int index : indices) {
            sb.append(index);
            sb.append("_");
        }
        sb.append(fingerprint);
        return sb.toString();
    }

    /**
     * Sets the critical value returned by {@link #getCriticalValue()}.
     *
     * @param d the critical value
     */
    public void setCriticalValue(double d) {
        this.criticalValue = d;
    }

    /**
     * Returns the critical value.
     *
     * @return the critical value (default 3.25)
     */
    @Override // mulan.data.LabelPairsDependenceIdentifier
    public double getCriticalValue() {
        return this.criticalValue;
    }

    /**
     * Returns the random seed.
     *
     * @return seed used to shuffle and split the data
     */
    public int getSeed() {
        return this.seed;
    }

    /**
     * Sets the random seed.
     *
     * @param i seed used to shuffle and split the data
     */
    public void setSeed(int i) {
        this.seed = i;
    }

    /**
     * Returns the number of cross-validation folds.
     *
     * @return number of cross-validation folds
     */
    public int getNumFolds() {
        return this.numFolds;
    }

    /**
     * Sets the number of cross-validation folds.
     *
     * @param i number of folds; values below 2 make every pair score -1
     */
    public void setNumFolds(int i) {
        this.numFolds = i;
    }
}
