package moa.classifiers.rules.multilabel.functions;

import com.github.javacliparser.FlagOption;
import com.github.javacliparser.FloatOption;
import com.github.javacliparser.IntOption;
import com.yahoo.labs.samoa.instances.MultiLabelInstance;
import com.yahoo.labs.samoa.instances.MultiLabelPrediction;
import com.yahoo.labs.samoa.instances.Prediction;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.Random;
import moa.classifiers.AbstractMultiLabelLearner;
import moa.classifiers.MultiTargetRegressor;
import moa.classifiers.rules.core.Utils;
import moa.core.Measurement;

/* loaded from: input_file:moa/classifiers/rules/multilabel/functions/StackedPredictor.class */
/**
 * Two-layer stacked linear model for multi-target regression, usable as an AMRules function.
 *
 * <p>The first layer is a linear model from the standardized numeric input attributes to each
 * standardized output. Unless {@link #skipStackingOption} is set, a second linear layer maps the
 * first-layer predictions to each standardized output, which lets the model exploit dependencies
 * among outputs. Both layers are trained online with a weighted delta rule. After each update, the
 * weights of an output column are divided by their L1 norm whenever that norm exceeds the number of
 * (non-bias) inputs to the layer.
 *
 * <p>Inputs and outputs are standardized with running means and standard deviations computed from
 * the weighted training instances seen so far.
 */
public class StackedPredictor extends AbstractMultiLabelLearner implements MultiTargetRegressor, AMRulesFunction {
    private static final long serialVersionUID = 1;

    /** Standard deviations at or below this value are treated as zero (values are only centered). */
    private final double SD_THRESHOLD = 1.0E-7d;

    public FlagOption constantLearningRatioDecayOption = new FlagOption("learningRatio_Decay_set_constant", 'd', "Learning Ratio Decay in Perceptron set to be constant. (The next parameter).");
    public FloatOption learningRatioOption = new FloatOption("learningRatio", 'l', "Learning Ratio to use for training the 1st layer.", 0.025d);
    public FloatOption learningRatio2ndLayerOption = new FloatOption("learningRatio2ndLayer", 'n', "Learning Ratio to use in the second layer.", 0.001d);
    public FloatOption learningRateDecayOption = new FloatOption("learningRateDecay", 'm', " Learning Rate decay to use for training the 1st layer.", 0.001d);
    public FlagOption skipStackingOption = new FlagOption("skipStackingOption", 's', "Predicts the outputs of the first layer (no dependence among output is computed)");
    public IntOption randomSeedOption = new IntOption("randomSeed", 'r', "Seed for random behaviour of the classifier.", 1);
    public FlagOption printWeightsOption = new FlagOption("printWeights", 'p', "Outputs the 2nd layer weights as measurements.");

    /** Whether the model has been initialized from a first training instance. */
    private boolean hasStarted;
    /** Total weight of the training instances seen so far. */
    private double count;
    /** Weighted sum of each numeric input attribute (positions follow {@link #numericIndices}). */
    private double[] inAttrSum;
    /** Weighted sum of squares of each numeric input attribute (positions follow {@link #numericIndices}). */
    private double[] inAttrSquaredSum;
    /** Weighted sum of each learned output. */
    private double[] outAttrSum;
    /** Weighted sum of squares of each learned output. */
    private double[] outAttrSquaredSum;
    /** First-layer weights: one row per numeric input plus a trailing bias row; one column per output. */
    private double[][] layer1Weights;
    /** Second-layer weights: one row per first-layer output plus a trailing bias row; one column per output. */
    private double[][] layer2Weights;
    /** Learning rate currently applied to the first layer. */
    double currentLearningRate;
    /** Indices of the numeric input attributes, fixed when the first training instance is seen. */
    LinkedList<Integer> numericIndices;

    @Override // moa.learners.Learner
    public boolean isRandomizable() {
        return true;
    }

    /**
     * Restores the first-layer learning rate to {@link #learningRatioOption} while keeping the learned
     * weights and statistics. When the decay is not constant, the rate is recomputed from
     * {@link #count} on the next training call anyway.
     */
    @Override // moa.classifiers.rules.multilabel.functions.AMRulesFunction
    public void resetWithMemory() {
        this.currentLearningRate = this.learningRatioOption.getValue();
    }

    /**
     * Updates the running statistics with the instance, then performs one delta-rule step on the
     * first layer and, unless stacking is skipped, on the second layer.
     *
     * <p>Both layers are trained against predictions computed with the weights as they were
     * <em>before</em> this update.
     *
     * @param multiLabelInstance training instance; its weight scales every update
     */
    @Override // moa.classifiers.AbstractMultiLabelLearner, moa.classifiers.MultiLabelLearner
    public void trainOnInstanceImpl(MultiLabelInstance multiLabelInstance) {
        int numOutputAttributes = multiLabelInstance.numOutputAttributes();
        if (!this.hasStarted) {
            initializeModel(multiLabelInstance, numOutputAttributes);
        }
        double weight = multiLabelInstance.weight();
        updateStatistics(multiLabelInstance, numOutputAttributes, weight);

        double[] normalizedInput = getNormalizedInput(multiLabelInstance);
        double[] predict1stLayer = predict1stLayer(normalizedInput);
        double[] predict2ndLayer = this.skipStackingOption.isSet() ? null : predict2ndLayer(predict1stLayer);

        if (!this.constantLearningRatioDecayOption.isSet()) {
            this.currentLearningRate = this.learningRatioOption.getValue()
                    / (1.0d + (this.count * this.learningRateDecayOption.getValue()));
        }
        double[] normalizedOutput = getNormalizedOutput(multiLabelInstance);

        int numInputs = this.numericIndices.size();
        for (int out = 0; out < numOutputAttributes; out++) {
            updateColumn(this.layer1Weights, out, normalizedInput, numInputs,
                    this.currentLearningRate, normalizedOutput[out] - predict1stLayer[out], weight);
        }
        if (this.skipStackingOption.isSet()) {
            return;
        }
        double secondLayerRate = this.learningRatio2ndLayerOption.getValue();
        for (int out = 0; out < numOutputAttributes; out++) {
            updateColumn(this.layer2Weights, out, predict1stLayer, numOutputAttributes,
                    secondLayerRate, normalizedOutput[out] - predict2ndLayer[out], weight);
        }
    }

    /**
     * Records the numeric input attributes of the first instance and allocates the statistics and
     * weight matrices. First-layer weights are drawn uniformly from [-1, 1). The second layer starts
     * as the identity, so it initially passes the first-layer predictions through unchanged.
     */
    private void initializeModel(MultiLabelInstance instance, int numOutputAttributes) {
        this.hasStarted = true;
        this.numericIndices = new LinkedList<>();
        for (int i = 0; i < instance.numInputAttributes(); i++) {
            if (instance.inputAttribute(i).isNumeric()) {
                this.numericIndices.add(i);
            }
        }
        int numInputs = this.numericIndices.size();
        this.inAttrSum = new double[numInputs];
        this.inAttrSquaredSum = new double[numInputs];
        this.outAttrSum = new double[numOutputAttributes];
        this.outAttrSquaredSum = new double[numOutputAttributes];
        this.layer1Weights = new double[numInputs + 1][numOutputAttributes];
        this.layer2Weights = new double[numOutputAttributes + 1][numOutputAttributes];
        // Draw order (output-major, then rows) is kept so seeded runs stay reproducible.
        for (int out = 0; out < numOutputAttributes; out++) {
            for (int row = 0; row <= numInputs; row++) {
                this.layer1Weights[row][out] = (2.0d * this.classifierRandom.nextDouble()) - 1.0d;
            }
            this.layer2Weights[out][out] = 1.0d;
        }
    }

    /** Adds the weighted instance to the running count, sums and sums of squares. */
    private void updateStatistics(MultiLabelInstance instance, int numOutputAttributes, double weight) {
        this.count += weight;
        int pos = 0;
        for (int attIndex : this.numericIndices) {
            double value = instance.valueInputAttribute(attIndex);
            this.inAttrSum[pos] += value * weight;
            this.inAttrSquaredSum[pos] += value * value * weight;
            pos++;
        }
        for (int out = 0; out < numOutputAttributes; out++) {
            double value = instance.valueOutputAttribute(out);
            this.outAttrSum[out] += value * weight;
            this.outAttrSquaredSum[out] += value * value * weight;
        }
    }

    /**
     * Applies one weighted delta-rule step to a column of a weight matrix. Afterwards the column is
     * divided by its L1 norm (bias included) if that norm exceeds {@code numInputs}.
     *
     * @param weights   matrix with {@code numInputs} input rows followed by one bias row
     * @param column    index of the output whose weights are updated
     * @param inputs    layer inputs; only the first {@code numInputs} entries are read
     * @param numInputs number of non-bias rows
     * @param rate      learning rate
     * @param error     target minus prediction for this output
     * @param weight    instance weight
     */
    private static void updateColumn(double[][] weights, int column, double[] inputs, int numInputs,
            double rate, double error, double weight) {
        double l1Norm = 0.0d;
        for (int row = 0; row < numInputs; row++) {
            weights[row][column] += rate * error * inputs[row] * weight;
            l1Norm += Math.abs(weights[row][column]);
        }
        weights[numInputs][column] += rate * error * weight;
        l1Norm += Math.abs(weights[numInputs][column]);
        if (l1Norm > numInputs) {
            for (int row = 0; row <= numInputs; row++) {
                weights[row][column] /= l1Norm;
            }
        }
    }

    /**
     * Predicts every learned output for the instance.
     *
     * @return one single-vote entry per output, in the original output scale; {@code null} if the
     *         model has not been trained yet
     */
    @Override // moa.classifiers.AbstractMultiLabelLearner, moa.classifiers.MultiLabelLearner
    public Prediction getPredictionForInstance(MultiLabelInstance multiLabelInstance) {
        if (!this.hasStarted) {
            return null;
        }
        int numOutputs = this.outAttrSum.length;
        MultiLabelPrediction prediction = new MultiLabelPrediction(numOutputs);
        double[] firstLayer = predict1stLayer(getNormalizedInput(multiLabelInstance));
        double[] normalized = this.skipStackingOption.isSet() ? firstLayer : predict2ndLayer(firstLayer);
        double[] output = getDenormalizedOutput(normalized);
        for (int i = 0; i < numOutputs; i++) {
            prediction.setVotes(i, new double[]{output[i]});
        }
        return prediction;
    }

    /** Discards all statistics and weights and reseeds the random generator. */
    @Override // moa.classifiers.AbstractClassifier
    public void resetLearningImpl() {
        this.hasStarted = false;
        this.count = 0.0d;
        this.inAttrSum = null;
        this.inAttrSquaredSum = null;
        this.outAttrSum = null;
        this.outAttrSquaredSum = null;
        this.layer1Weights = null;
        this.layer2Weights = null;
        this.numericIndices = null;
        this.currentLearningRate = this.learningRatioOption.getValue();
        this.classifierRandom = new Random();
        this.classifierRandom.setSeed(this.randomSeedOption.getValue());
    }

    /**
     * Standardizes the numeric input attributes of the instance with the running statistics.
     *
     * @return one value per entry of {@link #numericIndices}, in the same order
     */
    protected double[] getNormalizedInput(MultiLabelInstance multiLabelInstance) {
        double[] normalized = new double[this.numericIndices.size()];
        int pos = 0;
        for (int attIndex : this.numericIndices) {
            normalized[pos] = standardize(multiLabelInstance.valueInputAttribute(attIndex),
                    this.inAttrSum[pos], this.inAttrSquaredSum[pos]);
            pos++;
        }
        return normalized;
    }

    /** Standardizes the output values of the instance with the running output statistics. */
    protected double[] getNormalizedOutput(MultiLabelInstance multiLabelInstance) {
        int numOutputAttributes = multiLabelInstance.numOutputAttributes();
        double[] normalized = new double[numOutputAttributes];
        for (int i = 0; i < numOutputAttributes; i++) {
            normalized[i] = standardize(multiLabelInstance.valueOutputAttribute(i),
                    this.outAttrSum[i], this.outAttrSquaredSum[i]);
        }
        return normalized;
    }

    /**
     * Maps standardized output values back to the original scale. This is the inverse of
     * {@link #getNormalizedOutput}.
     *
     * @param normalizedOutput standardized values, one per learned output
     */
    protected double[] getDenormalizedOutput(double[] normalizedOutput) {
        double[] result = new double[normalizedOutput.length];
        for (int i = 0; i < normalizedOutput.length; i++) {
            double mean = this.outAttrSum[i] / this.count;
            double sd = Utils.computeSD(this.outAttrSquaredSum[i], this.outAttrSum[i], this.count);
            result[i] = sd > SD_THRESHOLD ? (normalizedOutput[i] * sd) + mean : normalizedOutput[i] + mean;
        }
        return result;
    }

    /** Centers a value on the running mean and, if the running SD is significant, scales by it. */
    private double standardize(double value, double sum, double squaredSum) {
        double mean = sum / this.count;
        double sd = Utils.computeSD(squaredSum, sum, this.count);
        double centered = value - mean;
        return sd > SD_THRESHOLD ? centered / sd : centered;
    }

    /** Forward pass of the first layer on standardized numeric inputs. */
    private double[] predict1stLayer(double[] normalizedInput) {
        return applyLayer(this.layer1Weights, normalizedInput, this.numericIndices.size(), this.outAttrSum.length);
    }

    /** Forward pass of the second layer on first-layer predictions. */
    private double[] predict2ndLayer(double[] firstLayerOutput) {
        return applyLayer(this.layer2Weights, firstLayerOutput, firstLayerOutput.length, firstLayerOutput.length);
    }

    /**
     * Computes {@code out[o] = sum_j inputs[j] * weights[j][o] + weights[numInputs][o]} for every
     * output {@code o}.
     */
    private static double[] applyLayer(double[][] weights, double[] inputs, int numInputs, int numOutputs) {
        double[] out = new double[numOutputs];
        for (int o = 0; o < numOutputs; o++) {
            double sum = 0.0d;
            for (int j = 0; j < numInputs; j++) {
                sum += inputs[j] * weights[j][o];
            }
            out[o] = sum + weights[numInputs][o];
        }
        return out;
    }

    /**
     * Exposes the second-layer weights (including biases) as measurements when
     * {@link #printWeightsOption} is set.
     *
     * @return the measurements, or {@code null} if the option is unset or nothing has been learned yet
     */
    @Override // moa.classifiers.AbstractClassifier
    protected Measurement[] getModelMeasurementsImpl() {
        // Before the first training instance there are no weights to report.
        if (!this.printWeightsOption.isSet() || this.layer2Weights == null) {
            return null;
        }
        int numOutputs = this.layer2Weights.length - 1;
        Measurement[] measurements = new Measurement[(numOutputs + 1) * numOutputs];
        int k = 0;
        for (int out = 0; out < numOutputs; out++) {
            for (int src = 0; src < numOutputs; src++) {
                measurements[k++] = new Measurement("W Out" + (src + 1) + ": Out" + (out + 1), this.layer2Weights[src][out]);
            }
            measurements[k++] = new Measurement("W Bias: Out" + (out + 1), this.layer2Weights[numOutputs][out]);
        }
        return measurements;
    }

    /** No textual model description is produced. */
    @Override // moa.classifiers.AbstractClassifier
    public void getModelDescription(StringBuilder sb, int i) {
    }

    /**
     * Restricts the model to a subset of its current outputs, keeping the learned statistics and
     * weights of the selected ones.
     *
     * @param outputsToLearn indices (in the current output numbering) of the outputs to keep; their
     *                       order defines the new output numbering
     */
    @Override // moa.classifiers.rules.multilabel.functions.AMRulesFunction
    public void selectOutputsToLearn(int[] outputsToLearn) {
        int numSelected = outputsToLearn.length;
        int numLayer1Rows = this.layer1Weights.length;
        int oldBiasRow = this.layer2Weights.length - 1;
        double[] newOutAttrSum = new double[numSelected];
        double[] newOutAttrSquaredSum = new double[numSelected];
        double[][] newLayer1Weights = new double[numLayer1Rows][numSelected];
        // The second layer is square in the outputs plus a bias row; its size must not depend on
        // the number of inputs.
        double[][] newLayer2Weights = new double[numSelected + 1][numSelected];
        for (int i = 0; i < numSelected; i++) {
            int oldOut = outputsToLearn[i];
            newOutAttrSum[i] = this.outAttrSum[oldOut];
            newOutAttrSquaredSum[i] = this.outAttrSquaredSum[oldOut];
            for (int row = 0; row < numLayer1Rows; row++) {
                newLayer1Weights[row][i] = this.layer1Weights[row][oldOut];
            }
            for (int j = 0; j < numSelected; j++) {
                newLayer2Weights[j][i] = this.layer2Weights[outputsToLearn[j]][oldOut];
            }
            newLayer2Weights[numSelected][i] = this.layer2Weights[oldBiasRow][oldOut];
        }
        this.outAttrSum = newOutAttrSum;
        this.outAttrSquaredSum = newOutAttrSquaredSum;
        this.layer1Weights = newLayer1Weights;
        this.layer2Weights = newLayer2Weights;
    }
}
