package com.googlecode.clearnlp.classification.algorithm;

import com.carrotsearch.hppc.IntArrayList;
import com.googlecode.clearnlp.classification.prediction.IntPrediction;
import com.googlecode.clearnlp.classification.train.AbstractTrainSpace;
import com.googlecode.clearnlp.util.UTArray;
import com.googlecode.clearnlp.util.triple.Triple;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Random;

/* loaded from: input_file:com/googlecode/clearnlp/classification/algorithm/AdaGrad.class */
/**
 * Multi-class linear classifier trainer using AdaGrad-style per-weight step sizes with a
 * margin-based (hinge-like) update rule.
 *
 * <p>For each training instance the highest- and second-highest-scoring labels are found. If the
 * top label is wrong, weights move toward the gold label and away from the predicted label. If
 * the top label is right but beats the runner-up by less than {@code 1.0}, weights move toward
 * the gold label and away from the runner-up. Each step is
 * {@code alpha / (rho + sqrt(G))} (times the feature value, when instances carry values), where
 * {@code G} is the accumulated squared feature value for that weight.
 */
public class AdaGrad extends AbstractAlgorithm {
    /** Number of passes over the training instances. */
    protected int n_iter;
    /** Random generator used to shuffle instance order on every pass. */
    protected Random r_rand;
    /** Base learning rate (numerator of each step size). */
    protected double d_alpha;
    /** Smoothing term added to {@code sqrt(G)} in the denominator of each step size. */
    protected double d_rho;

    /**
     * Creates a trainer.
     *
     * @param iterations number of training passes
     * @param alpha base learning rate
     * @param rho smoothing term added to the denominator of each step size
     * @param random random generator used to shuffle instances each pass
     */
    public AdaGrad(int iterations, double alpha, double rho, Random random) {
        this.n_iter = iterations;
        this.r_rand = random;
        this.d_alpha = alpha;
        this.d_rho = rho;
    }

    /**
     * Trains a new, zero-initialized weight array on the given space.
     *
     * @param abstractTrainSpace training data and dimensions
     * @param i ignored by this implementation
     * @return a weight array of length {@code featureSize * labelSize}
     */
    @Override
    public double[] getWeight(AbstractTrainSpace abstractTrainSpace, int i) {
        double[] weights =
                new double[abstractTrainSpace.getFeatureSize() * abstractTrainSpace.getLabelSize()];
        updateWeight(abstractTrainSpace, weights);
        return weights;
    }

    /**
     * Trains the weight array obtained from {@code abstractTrainSpace.getModel().getWeights()}.
     * NOTE(review): this updates the model in place only if {@code getWeights()} returns the
     * model's live array rather than a copy — confirm.
     *
     * @param abstractTrainSpace training data whose model weights are trained
     */
    public void updateWeight(AbstractTrainSpace abstractTrainSpace) {
        updateWeight(abstractTrainSpace, abstractTrainSpace.getModel().getWeights());
    }

    /**
     * Runs {@link #n_iter} shuffled passes over all training instances, updating {@code weights}
     * in place, and prints the training accuracy of each pass to standard output.
     *
     * @param abstractTrainSpace training instances: gold labels, feature indices and (when
     *     {@code hasWeight()} is true) per-feature values
     * @param weights weight array of length {@code featureSize * labelSize}; modified in place
     */
    public void updateWeight(AbstractTrainSpace abstractTrainSpace, double[] weights) {
        final int featureSize = abstractTrainSpace.getFeatureSize();
        final int labelSize = abstractTrainSpace.getLabelSize();
        final int instanceSize = abstractTrainSpace.getInstanceSize();
        final boolean hasValues = abstractTrainSpace.hasWeight();

        IntArrayList ys = abstractTrainSpace.getYs();
        ArrayList<int[]> xs = abstractTrainSpace.getXs();
        ArrayList<double[]> vs = abstractTrainSpace.getVs();

        // AdaGrad accumulator G: sum of squared gradients, one slot per weight.
        double[] gradients = new double[featureSize * labelSize];

        for (int iter = 0; iter < this.n_iter; iter++) {
            int[] order = getShuffledIndices(instanceSize);
            // NOTE(review): G is reset at the start of every pass, so step sizes restart each
            // epoch (textbook AdaGrad accumulates over the whole run). Kept to preserve results.
            Arrays.fill(gradients, 0.0d);
            int correct = 0;

            for (int idx : order) {
                int gold = ys.get(idx);
                int[] indices = xs.get(idx);
                double[] values = hasValues ? vs.get(idx) : null;

                Triple<IntPrediction, IntPrediction, IntPrediction> predictions =
                        getPredictions(labelSize, gold, indices, values, weights);
                IntPrediction best = predictions.o1;
                IntPrediction second = predictions.o2;

                if (best.label != gold) {
                    // Misclassified: push toward gold, away from the wrongly predicted label.
                    updateCounts(labelSize, gradients, gold, best.label, indices, values);
                    updateWeights(labelSize, gradients, gold, best.label, indices, values, weights);
                } else {
                    // Counted as correct even when the margin below is violated, so the
                    // printed figure is true training accuracy.
                    correct++;
                    if (best.score - second.score < 1.0d) {
                        // Correct but inside the margin: separate gold from the runner-up.
                        updateCounts(labelSize, gradients, gold, second.label, indices, values);
                        updateWeights(labelSize, gradients, gold, second.label, indices, values, weights);
                    }
                }
            }

            // Guard against 0/0 (NaN) when there are no training instances.
            double accuracy = instanceSize > 0 ? (100.0d * correct) / instanceSize : 0.0d;
            System.out.printf("- %3d: acc = %7.4f\n", iter + 1, accuracy);
        }
    }

    /**
     * Returns a random permutation of {@code 0..size-1} produced by a Fisher-Yates style pass
     * driven by {@link #r_rand} (assuming {@code UTArray.swap} exchanges the two entries).
     *
     * @param size number of indices
     * @return shuffled indices
     */
    private int[] getShuffledIndices(int size) {
        int[] indices = new int[size];
        for (int k = 0; k < size; k++) {
            indices[k] = k;
        }
        for (int k = 0; k < size; k++) {
            UTArray.swap(indices, k, k + this.r_rand.nextInt(size - k));
        }
        return indices;
    }

    /**
     * Returns the single highest-scoring label, where every non-gold label starts with a score of
     * {@code 1.0} and the gold label starts at {@code 0.0} (i.e. scores are augmented by a 0/1
     * loss). Ties keep the lowest label index.
     *
     * @param labelSize number of labels
     * @param gold gold label index
     * @param indices feature indices of the instance
     * @param values per-feature values aligned with {@code indices}, or {@code null} if every
     *     feature has value 1
     * @param weights weight array
     * @return the best-scoring label and its score
     */
    protected IntPrediction getPrediction(
            int labelSize, int gold, int[] indices, double[] values, double[] weights) {
        double[] scores = new double[labelSize];
        Arrays.fill(scores, 1.0d);
        scores[gold] = 0.0d;
        accumulateScores(scores, labelSize, indices, values, weights);

        IntPrediction best = new IntPrediction(0, scores[0]);
        for (int label = 1; label < labelSize; label++) {
            if (best.score < scores[label]) {
                best.set(label, scores[label]);
            }
        }
        return best;
    }

    /**
     * Scores every label and returns {@code (best, second best, gold)} predictions.
     *
     * <p>A tie between labels 0 and 1 makes label 1 the best. Labels from 2 upward replace the
     * best or runner-up only when strictly greater.
     *
     * @param labelSize number of labels; must be at least 2
     * @param gold gold label index
     * @param indices feature indices of the instance
     * @param values per-feature values aligned with {@code indices}, or {@code null} if every
     *     feature has value 1
     * @param weights weight array
     * @return triple of (best, second-best, gold-label) predictions
     * @throws IllegalArgumentException if {@code labelSize < 2}
     */
    protected Triple<IntPrediction, IntPrediction, IntPrediction> getPredictions(
            int labelSize, int gold, int[] indices, double[] values, double[] weights) {
        if (labelSize < 2) {
            throw new IllegalArgumentException(
                    "At least two labels are required, but labelSize = " + labelSize);
        }
        double[] scores = new double[labelSize];
        accumulateScores(scores, labelSize, indices, values, weights);

        IntPrediction best;
        IntPrediction second;
        if (scores[0] > scores[1]) {
            best = new IntPrediction(0, scores[0]);
            second = new IntPrediction(1, scores[1]);
        } else {
            best = new IntPrediction(1, scores[1]);
            second = new IntPrediction(0, scores[0]);
        }
        for (int label = 2; label < labelSize; label++) {
            double score = scores[label];
            if (best.score < score) {
                second.set(best.label, best.score);
                best.set(label, score);
            } else if (second.score < score) {
                second.set(label, score);
            }
        }
        return new Triple<>(best, second, new IntPrediction(gold, scores[gold]));
    }

    /**
     * Adds each feature's weighted contribution to {@code scores} for every label. Features are
     * visited in order, so each label's sum is accumulated feature by feature.
     *
     * @param scores per-label scores to add into (length {@code labelSize})
     * @param labelSize number of labels
     * @param indices feature indices of the instance
     * @param values per-feature values aligned with {@code indices}, or {@code null} for all ones
     * @param weights weight array
     */
    private void accumulateScores(
            double[] scores, int labelSize, int[] indices, double[] values, double[] weights) {
        for (int f = 0; f < indices.length; f++) {
            int feature = indices[f];
            double value = (values == null) ? 1.0d : values[f];
            for (int label = 0; label < labelSize; label++) {
                scores[label] += weights[getWeightIndex(labelSize, label, feature)] * value;
            }
        }
    }

    /**
     * Adds this update's squared gradient to the accumulator for both the gold and the competing
     * label: {@code 1.0} per feature when {@code values} is {@code null}, otherwise
     * {@code value * value}.
     *
     * @param labelSize number of labels
     * @param gradients AdaGrad accumulator, modified in place
     * @param gold gold label index
     * @param other competing label index
     * @param indices feature indices of the instance
     * @param values per-feature values aligned with {@code indices}, or {@code null}
     */
    protected void updateCounts(
            int labelSize, double[] gradients, int gold, int other, int[] indices, double[] values) {
        for (int f = 0; f < indices.length; f++) {
            double squared = (values == null) ? 1.0d : values[f] * values[f];
            gradients[getWeightIndex(labelSize, gold, indices[f])] += squared;
            gradients[getWeightIndex(labelSize, other, indices[f])] += squared;
        }
    }

    /**
     * Increases the gold label's weights and decreases the competing label's weights for every
     * feature of the instance, each by its AdaGrad step size (times the feature value when
     * {@code values} is non-null).
     *
     * @param labelSize number of labels
     * @param gradients AdaGrad accumulator (read only here)
     * @param gold gold label index
     * @param other competing label index
     * @param indices feature indices of the instance
     * @param values per-feature values aligned with {@code indices}, or {@code null}
     * @param weights weight array, modified in place
     */
    protected void updateWeights(
            int labelSize,
            double[] gradients,
            int gold,
            int other,
            int[] indices,
            double[] values,
            double[] weights) {
        for (int f = 0; f < indices.length; f++) {
            int feature = indices[f];
            double value = (values == null) ? 1.0d : values[f];
            weights[getWeightIndex(labelSize, gold, feature)] +=
                    getUpdate(labelSize, gradients, gold, feature) * value;
            weights[getWeightIndex(labelSize, other, feature)] -=
                    getUpdate(labelSize, gradients, other, feature) * value;
        }
    }

    /**
     * Returns the AdaGrad step size {@code alpha / (rho + sqrt(G))} for one weight.
     *
     * @param labelSize number of labels
     * @param gradients AdaGrad accumulator
     * @param label label index of the weight
     * @param feature feature index of the weight
     * @return step size for that weight
     */
    protected double getUpdate(int labelSize, double[] gradients, int label, int feature) {
        return this.d_alpha
                / (this.d_rho + Math.sqrt(gradients[getWeightIndex(labelSize, label, feature)]));
    }
}
