package moa.classifiers.meta;

import com.github.javacliparser.FlagOption;
import com.github.javacliparser.FloatOption;
import com.github.javacliparser.IntOption;
import com.yahoo.labs.samoa.instances.Instance;
import java.math.BigInteger;
import java.util.Arrays;
import moa.classifiers.AbstractClassifier;
import moa.classifiers.Classifier;
import moa.classifiers.MultiClassClassifier;
import moa.classifiers.core.driftdetection.ADWIN;
import moa.classifiers.trees.LimAttHoeffdingTree;
import moa.core.Measurement;
import moa.core.Utils;
import moa.options.ClassOption;

/* loaded from: input_file:moa/classifiers/meta/LimAttClassifier.class */
/**
 * Stacking ensemble of attribute-restricted base learners.
 *
 * <p>On the first training instance one base learner is created for every
 * combination of {@code numberAttributes} attributes (the last attribute index
 * is excluded from the combinations). Each member is paired with an
 * {@link ADWIN} monitor fed with its 0/1 error. The members' votes are turned
 * into log-odds and combined by one sigmoid perceptron per class, whose weights
 * live in {@link #weightAttribute}.
 */
public class LimAttClassifier extends AbstractClassifier implements MultiClassClassifier {
    private static final long serialVersionUID = 1L;

    /** Ensemble members; allocated on the first training instance after a reset. */
    protected Classifier[] ensemble;

    /** One error monitor per ensemble member, same indexing as {@link #ensemble}. */
    protected ADWIN[] ADError;

    /** Number of times a change detection led to a member replacement round. */
    protected int numberOfChangesDetected;

    /** Not referenced anywhere else in this class. */
    protected int[][] matrixCodes;

    /**
     * Perceptron weights indexed {@code [class][member]}; the last column of each
     * row (index {@code ensemble.length}) is the bias term.
     */
    protected double[][] weightAttribute;

    /** When true, the perceptron weights are re-created on the next training call. */
    protected boolean reset;

    public ClassOption baseLearnerOption = new ClassOption("baseLearner", 'l', "Classifier to train.", Classifier.class, "trees.LimAttHoeffdingTree");
    public IntOption numAttributesOption = new IntOption("numAttributes", 'n', "The number of attributes to use per model.", 1, 1, Integer.MAX_VALUE);
    public FloatOption weightShrinkOption = new FloatOption("weightShrink", 'w', "The number to multiply the weight misclassified counts.", 0.5d, 0.0d, Float.MAX_VALUE);
    public FloatOption deltaAdwinOption = new FloatOption("deltaAdwin", 'a', "Delta of Adwin change detection", 0.002d, 0.0d, 1.0d);
    public FloatOption oddsOffsetOption = new FloatOption("oddsOffset", 'o', "Offset for odds to avoid probabilities that are zero.", 0.001d, 0.0d, Float.MAX_VALUE);
    public FlagOption pruneOption = new FlagOption("prune", 'x', "Enable pruning.");
    public FlagOption bigTreesOption = new FlagOption("bigTrees", 'b', "Use m-n attributes on the trees.");
    public IntOption numEnsemblePruningOption = new IntOption("numEnsemblePruning", 'm', "The pruned number of classifiers to use to predict.", 10, 1, Integer.MAX_VALUE);
    public FlagOption adwinReplaceWorstClassifierOption = new FlagOption("adwinReplaceWorstClassifier", 'z', "When one Adwin detects change, replace worst classifier.");

    /** Not referenced anywhere else in this class. */
    protected boolean initMatrixCodes = false;

    /** When true, the ensemble is (re)built on the next training instance. */
    protected boolean initClassifiers = false;

    /** Number of attributes given to each ensemble member. */
    protected int numberAttributes = 1;

    /** Instance counter used to decay the perceptron learning rate. */
    protected int numInstances = 0;

    public FloatOption learningRatioOption = new FloatOption("learningRatio", 'r', "Learning ratio", 1.0d);
    public FloatOption penaltyFactorOption = new FloatOption("lambda", 'p', "Lambda", 0.0d);
    public IntOption initialNumInstancesOption = new IntOption("initialNumInstances", 'i', "initialNumInstances", 10);

    /**
     * Enumerates, in lexicographic order, every combination of {@code r} indices
     * drawn from {@code 0 .. n-1}.
     */
    public class CombinationGenerator {
        private int[] a;
        private int n;
        private int r;
        private BigInteger numLeft;
        private BigInteger total;

        /**
         * @param n size of the index pool; must be at least 1
         * @param r size of each combination; must satisfy {@code 0 <= r <= n}
         * @throws IllegalArgumentException if the arguments are out of range
         */
        public CombinationGenerator(int n, int r) {
            if (r > n) {
                throw new IllegalArgumentException("r (" + r + ") must not exceed n (" + n + ")");
            }
            if (n < 1) {
                throw new IllegalArgumentException("n must be at least 1 but was " + n);
            }
            if (r < 0) {
                // Previously surfaced as a NegativeArraySizeException from the array allocation.
                throw new IllegalArgumentException("r must not be negative but was " + r);
            }
            this.n = n;
            this.r = r;
            this.a = new int[r];
            // n! / (r! * (n - r)!)
            this.total = getFactorial(n).divide(getFactorial(r).multiply(getFactorial(n - r)));
            reset();
        }

        /** Rewinds the generator to the first combination {@code 0, 1, ..., r-1}. */
        public void reset() {
            for (int i = 0; i < this.a.length; i++) {
                this.a[i] = i;
            }
            // BigInteger is immutable, so sharing the reference is safe.
            this.numLeft = this.total;
        }

        /** @return how many combinations have not been returned yet */
        public BigInteger getNumLeft() {
            return this.numLeft;
        }

        /** @return true while {@link #getNext()} can still produce a combination */
        public boolean hasMore() {
            return this.numLeft.signum() > 0;
        }

        /** @return the total number of combinations */
        public BigInteger getTotal() {
            return this.total;
        }

        private BigInteger getFactorial(int value) {
            BigInteger product = BigInteger.ONE;
            for (int k = value; k > 1; k--) {
                product = product.multiply(BigInteger.valueOf(k));
            }
            return product;
        }

        /**
         * Returns the next combination as a fresh array.
         *
         * @return a copy of the current combination
         * @throws IllegalStateException if every combination has already been returned
         */
        public int[] getNext() {
            if (!hasMore()) {
                // Without this guard the scan below runs off the front of the array.
                throw new IllegalStateException("No combinations left");
            }
            // The very first call returns the initial state untouched.
            if (!this.numLeft.equals(this.total)) {
                // Find the right-most position that has not reached its maximum value.
                int pos = this.r - 1;
                while (this.a[pos] == (this.n - this.r) + pos) {
                    pos--;
                }
                this.a[pos] = this.a[pos] + 1;
                // Everything to the right restarts just above the bumped position.
                for (int j = pos + 1; j < this.r; j++) {
                    this.a[j] = (this.a[pos] + j) - pos;
                }
            }
            this.numLeft = this.numLeft.subtract(BigInteger.ONE);
            return this.a.clone();
        }
    }

    @Override
    public String getPurposeString() {
        return "Ensemble Combining Restricted Hoeffding Trees using Stacking";
    }

    /** Flags both the ensemble and the perceptron weights for lazy re-creation. */
    @Override
    public void resetLearningImpl() {
        this.initClassifiers = true;
        this.reset = true;
    }

    /**
     * Trains on one instance: builds the ensemble if pending, computes each member's
     * log-odds for the instance, feeds the ADWIN monitors and replaces members on
     * detected error increase, updates the perceptron, then trains every member.
     */
    @Override
    public void trainOnInstanceImpl(Instance instance) {
        int numClasses = instance.numClasses();
        if (this.initClassifiers) {
            buildEnsemble(instance);
        }

        Instance copy = instance.copy();
        // One row per member plus a trailing row; the perceptron treats
        // (rows - 1) as the number of inputs.
        double[][] logOdds = new double[this.ensemble.length + 1][numClasses];
        for (int m = 0; m < this.ensemble.length; m++) {
            fillLogOdds(this.ensemble[m], instance, numClasses, logOdds[m]);
        }

        if (this.adwinReplaceWorstClassifierOption.isSet()) {
            boolean errorIncreased = false;
            for (int m = 0; m < this.ensemble.length; m++) {
                boolean correct = this.ensemble[m].correctlyClassifies(copy);
                double before = this.ADError[m].getEstimation();
                // Every monitor must receive its input, so no short-circuiting across members.
                if (this.ADError[m].setInput(correct ? 0.0d : 1.0d) && this.ADError[m].getEstimation() > before) {
                    errorIncreased = true;
                }
            }
            if (errorIncreased) {
                this.numberOfChangesDetected++;
                double worstError = 0.0d;
                int worst = -1;
                for (int m = 0; m < this.ensemble.length; m++) {
                    double estimate = this.ADError[m].getEstimation();
                    if (worstError < estimate) {
                        worstError = estimate;
                        worst = m;
                    }
                }
                if (worst != -1) {
                    replaceMember(worst, numClasses);
                }
            }
        } else {
            for (int m = 0; m < this.ensemble.length; m++) {
                boolean correct = this.ensemble[m].correctlyClassifies(copy);
                double before = this.ADError[m].getEstimation();
                if (this.ADError[m].setInput(correct ? 0.0d : 1.0d)) {
                    // Any detected change restarts the learning-rate schedule.
                    this.numInstances = this.initialNumInstancesOption.getValue();
                    if (this.ADError[m].getEstimation() > before) {
                        this.numberOfChangesDetected++;
                        replaceMember(m, numClasses);
                    }
                }
            }
        }

        trainOnInstanceImplPerceptron(numClasses, (int) instance.classValue(), logOdds);
        for (int m = 0; m < this.ensemble.length; m++) {
            this.ensemble[m].trainOnInstance(instance);
        }
    }

    /** Creates one base learner and one ADWIN monitor per attribute combination. */
    private void buildEnsemble(Instance instance) {
        int candidateAttributes = instance.numAttributes() - 1;
        this.numberAttributes = this.numAttributesOption.getValue();
        if (this.bigTreesOption.isSet()) {
            this.numberAttributes = candidateAttributes - this.numAttributesOption.getValue();
        }
        CombinationGenerator combinations = new CombinationGenerator(candidateAttributes, this.numberAttributes);
        if (combinations.getTotal().bitLength() > 31) {
            // intValue() would silently truncate and yield a wrong ensemble size.
            throw new IllegalStateException("Too many attribute combinations for an ensemble: " + combinations.getTotal());
        }
        this.ensemble = new Classifier[combinations.getTotal().intValue()];

        Classifier template = (Classifier) getPreparedClassOption(this.baseLearnerOption);
        template.resetLearning();
        for (int m = 0; m < this.ensemble.length; m++) {
            this.ensemble[m] = template.copy();
        }

        this.ADError = new ADWIN[this.ensemble.length];
        for (int m = 0; m < this.ensemble.length; m++) {
            this.ADError[m] = new ADWIN(this.deltaAdwinOption.getValue());
        }
        this.numberOfChangesDetected = 0;

        // Only the restricted tree learner accepts an explicit attribute list.
        if (template instanceof LimAttHoeffdingTree) {
            int m = 0;
            while (combinations.hasMore()) {
                ((LimAttHoeffdingTree) this.ensemble[m]).setlistAttributes(combinations.getNext());
                m++;
            }
        }
        this.initClassifiers = false;
    }

    /** Resets one member, gives it a fresh ADWIN monitor and zeroes its perceptron weights. */
    private void replaceMember(int index, int numClasses) {
        this.ensemble[index].resetLearning();
        this.ADError[index] = new ADWIN(this.deltaAdwinOption.getValue());
        // While a reset is pending the weights are about to be re-created anyway
        // (and may not exist yet), so there is nothing to zero.
        if (!this.reset && this.weightAttribute != null) {
            for (int c = 0; c < numClasses; c++) {
                this.weightAttribute[c][index] = 0.0d;
            }
        }
    }

    /**
     * Writes the log-odds of {@code member}'s normalised votes into {@code logOdds}.
     * Each normalised vote is shifted by the odds offset so the logarithm stays finite.
     * Only the first {@code votes.length} entries are written; the rest keep their
     * previous value. The array returned by the member is left unmodified.
     */
    private void fillLogOdds(Classifier member, Instance instance, int numClasses, double[] logOdds) {
        double offset = this.oddsOffsetOption.getValue();
        double[] votes = member.getVotesForInstance(instance);
        double voteSum = Utils.sum(votes);
        // Votes that cannot be normalised are treated as all-zero.
        boolean usable = !Double.isNaN(voteSum) && voteSum > 0.0d;

        double[] shifted = new double[numClasses];
        Arrays.fill(shifted, offset);
        double total = numClasses * offset;
        for (int c = 0; c < votes.length; c++) {
            double share = usable ? votes[c] / voteSum : 0.0d;
            shifted[c] = shifted[c] + share;
            total += share;
        }
        for (int c = 0; c < votes.length; c++) {
            logOdds[c] = Math.log(shifted[c] / (total - shifted[c]));
        }
    }

    /**
     * Returns the per-class perceptron outputs for the instance.
     *
     * <p>With pruning enabled only the members with the largest summed weights take
     * part; the requested count is capped at the ensemble size.
     *
     * @return an empty array while the ensemble has not been built, otherwise one
     *         value per class
     */
    @Override
    public double[] getVotesForInstance(Instance instance) {
        if (this.initClassifiers || this.ensemble == null) {
            return new double[0];
        }
        int numClasses = instance.numClasses();
        boolean prune = this.pruneOption.isSet();
        int size = this.ensemble.length;
        if (prune) {
            // Asking for more members than exist used to index the array negatively.
            size = Math.min(this.numEnsemblePruningOption.getValue(), this.ensemble.length);
        }

        int[] selected = new int[size];
        if (prune) {
            double[] totals = new double[this.ensemble.length];
            for (int c = 0; c < numClasses; c++) {
                for (int m = 0; m < this.ensemble.length; m++) {
                    totals[m] = totals[m] + this.weightAttribute[c][m];
                }
            }
            // Sort a copy: the unsorted totals are needed to map values back to members.
            double[] sorted = totals.clone();
            Arrays.sort(sorted);
            double cutoff = sorted[sorted.length - size];
            int filled = 0;
            for (int m = 0; m < totals.length && filled < size; m++) {
                if (totals[m] >= cutoff) {
                    selected[filled] = m;
                    filled++;
                }
            }
        } else {
            for (int m = 0; m < size; m++) {
                selected[m] = m;
            }
        }

        double[][] logOdds = new double[size + 1][numClasses];
        for (int s = 0; s < size; s++) {
            fillLogOdds(this.ensemble[selected[s]], instance, numClasses, logOdds[s]);
        }
        return getVotesForInstancePerceptron(logOdds, selected, numClasses);
    }

    @Override
    public boolean isRandomizable() {
        return true;
    }

    /** Intentionally produces no description. */
    @Override
    public void getModelDescription(StringBuilder sb, int i) {
    }

    /** Reports the ensemble size (0 before it is built) and the change-detection count. */
    @Override
    protected Measurement[] getModelMeasurementsImpl() {
        Measurement[] measurements = new Measurement[2];
        measurements[0] = new Measurement("ensemble size", this.ensemble != null ? this.ensemble.length : 0.0d);
        measurements[1] = new Measurement("change detections", this.numberOfChangesDetected);
        return measurements;
    }

    /** @return a shallow copy of the ensemble, or an empty array before it is built */
    @Override
    public Classifier[] getSubClassifiers() {
        if (this.ensemble == null) {
            return new Classifier[0];
        }
        return this.ensemble.clone();
    }

    /**
     * Performs one gradient step of the per-class sigmoid perceptrons.
     *
     * <p>The step size is {@code 2 * learningRatio / (numInstances + inputs + 2)};
     * the per-class error term is {@code (target - p) * p * (1 - p)} and each weight
     * is additionally shrunk by {@code lambda * weight}.
     *
     * @param numClasses number of classes
     * @param trueClass  index of the instance's class
     * @param votes      log-odds with one row per input plus one trailing row
     */
    public void trainOnInstanceImplPerceptron(int numClasses, int trueClass, double[][] votes) {
        int inputs = votes.length - 1;
        if (this.reset) {
            this.reset = false;
            this.weightAttribute = new double[numClasses][votes.length];
            for (int c = 0; c < numClasses; c++) {
                // Input weights start uniform; the bias column stays at 0.
                for (int m = 0; m < inputs; m++) {
                    this.weightAttribute[c][m] = 1.0d / (votes.length - 1.0d);
                }
            }
            this.numInstances = this.initialNumInstancesOption.getValue();
        }
        double learningRate = (this.learningRatioOption.getValue() * 2.0d) / ((this.numInstances + inputs) + 2.0d);
        double lambda = this.penaltyFactorOption.getValue();
        this.numInstances++;

        // All predictions are taken before any weight changes.
        double[] predicted = new double[numClasses];
        for (int c = 0; c < numClasses; c++) {
            predicted[c] = prediction(votes, c);
        }
        for (int c = 0; c < numClasses; c++) {
            double target = (c == trueClass) ? 1.0d : 0.0d;
            double delta = (target - predicted[c]) * predicted[c] * (1.0d - predicted[c]);
            double[] weights = this.weightAttribute[c];
            // Bounded by the vote matrix so it always matches the weight layout.
            for (int m = 0; m < inputs; m++) {
                weights[m] = weights[m] + (learningRate * ((delta * votes[m][c]) - (lambda * weights[m])));
            }
            weights[inputs] = weights[inputs] + (learningRate * delta);
        }
    }

    /**
     * Sigmoid output for one class using only the selected members.
     *
     * @param votes           log-odds, one row per selected member plus one trailing row
     * @param bestClassifiers ensemble index of the member behind each row of {@code votes}
     * @param classVal        class index
     */
    public double predictionPruning(double[][] votes, int[] bestClassifiers, int classVal) {
        double[] weights = this.weightAttribute[classVal];
        double activation = 0.0d;
        for (int s = 0; s < votes.length - 1; s++) {
            activation += weights[bestClassifiers[s]] * votes[s][classVal];
        }
        // The bias is the last weight column; votes.length - 1 would be a member
        // weight whenever fewer members than the full ensemble are selected.
        return 1.0d / (1.0d + Math.exp(-(activation + weights[weights.length - 1])));
    }

    /**
     * Sigmoid output for one class using every row of {@code votes} but the last as input.
     *
     * @param votes    log-odds, one row per member plus one trailing row
     * @param classVal class index
     */
    public double prediction(double[][] votes, int classVal) {
        double activation = 0.0d;
        for (int m = 0; m < votes.length - 1; m++) {
            activation += this.weightAttribute[classVal][m] * votes[m][classVal];
        }
        return 1.0d / (1.0d + Math.exp(-(activation + this.weightAttribute[classVal][votes.length - 1])));
    }

    /**
     * Evaluates the perceptron for every class.
     *
     * @return one output per class, or all zeros while the weights await re-creation
     */
    public double[] getVotesForInstancePerceptron(double[][] votes, int[] bestClassifiers, int numClasses) {
        double[] result = new double[numClasses];
        if (!this.reset) {
            for (int c = 0; c < result.length; c++) {
                result[c] = predictionPruning(votes, bestClassifiers, c);
            }
        }
        return result;
    }
}
