package com.googlecode.clearnlp.classification.algorithm;

import com.carrotsearch.hppc.IntArrayList;
import com.googlecode.clearnlp.classification.train.AbstractTrainSpace;
import com.googlecode.clearnlp.util.UTArray;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Random;

/* loaded from: input_file:com/googlecode/clearnlp/classification/algorithm/AdaGradLR.class */
/**
 * Multi-class logistic regression trained by stochastic gradient steps with AdaGrad-style
 * per-coordinate step sizes.
 *
 * <p>Weights live in one flat array of length {@code featureSize * labelSize}, addressed through
 * {@code getWeightIndex(labelSize, label, feature)}. Each epoch visits the training instances in a
 * random order. For every instance it does three things:
 * <ol>
 *   <li>computes per-label gradients;</li>
 *   <li>adds their squares to an accumulator;</li>
 *   <li>moves each touched weight by {@code alpha / (rho + sqrt(accumulated)) * gradient}.</li>
 * </ol>
 * Training mutates the supplied weight array in place.
 */
public class AdaGradLR extends AbstractAlgorithm {
    /** Number of passes (epochs) over the training instances. */
    protected int n_iter;
    /** Random generator used to shuffle the instance order at the start of each epoch. */
    protected Random r_rand;
    /** Learning rate: numerator of the per-coordinate step size. */
    protected double d_alpha;
    /** Smoothing term added to the square root of the accumulated squared gradients. */
    protected double d_rho;

    /**
     * Creates an AdaGrad logistic-regression trainer.
     *
     * @param iterations number of epochs over the training data
     * @param alpha learning rate (numerator of the step size)
     * @param rho smoothing term added to the step-size denominator
     * @param random random generator used to shuffle instances each epoch
     */
    public AdaGradLR(int iterations, double alpha, double rho, Random random) {
        this.n_iter = iterations;
        this.r_rand = random;
        this.d_alpha = alpha;
        this.d_rho = rho;
    }

    /**
     * Trains a fresh weight vector, starting from all zeros, on the given training space.
     *
     * @param space training instances together with their feature and label dimensions
     * @param unused ignored by this implementation; all labels are trained jointly
     * @return a new weight array of length {@code featureSize * labelSize}
     */
    @Override
    public double[] getWeight(AbstractTrainSpace space, int unused) {
        double[] weights = new double[space.getFeatureSize() * space.getLabelSize()];
        updateWeight(space, weights);
        return weights;
    }

    /**
     * Continues training from the weights currently held by the space's model, in place.
     *
     * @param space training instances; the array returned by {@code space.getModel().getWeights()}
     *     is modified
     */
    public void updateWeight(AbstractTrainSpace space) {
        updateWeight(space, space.getModel().getWeights());
    }

    /**
     * Runs {@code n_iter} epochs of AdaGrad updates over the space's instances, modifying
     * {@code weights} in place.
     *
     * @param space training instances (labels, feature indices and, optionally, feature values)
     * @param weights weight array of length {@code featureSize * labelSize}, updated in place
     */
    public void updateWeight(AbstractTrainSpace space, double[] weights) {
        final int featureSize = space.getFeatureSize();
        final int labelSize = space.getLabelSize();
        final int instanceSize = space.getInstanceSize();
        // Feature values are only read when hasWeight() is true. This is a property of the whole
        // space, so it is checked once rather than once per instance.
        final boolean hasValues = space.hasWeight();

        IntArrayList ys = space.getYs();
        ArrayList<int[]> xs = space.getXs();
        ArrayList<double[]> vs = space.getVs();
        double[] squaredGradientSums = new double[featureSize * labelSize];

        for (int epoch = 0; epoch < n_iter; epoch++) {
            int[] order = getShuffledIndices(instanceSize);
            // NOTE(review): the accumulator is reset every epoch, so step sizes restart each pass.
            // Textbook AdaGrad accumulates across all updates. Kept as-is because changing it
            // would change trained models.
            Arrays.fill(squaredGradientSums, 0.0d);

            for (int k = 0; k < instanceSize; k++) {
                int instance = order[k];
                int gold = ys.get(instance);
                int[] features = xs.get(instance);
                double[] values = hasValues ? vs.get(instance) : null;

                double[] gradients = getGradients(labelSize, gold, features, values, weights);
                updateCounts(labelSize, squaredGradientSums, gradients, features, values);
                updateWeights(labelSize, squaredGradientSums, gradients, features, values, weights);
            }
        }
    }

    /**
     * Returns a random permutation of {@code 0 .. size-1}.
     *
     * <p>The shuffle swaps position {@code k} with a random position in {@code [k, size)}. This is
     * a Fisher-Yates shuffle, assuming {@code UTArray.swap} exchanges the two entries.
     *
     * @param size number of indices to permute
     * @return the shuffled indices
     */
    protected int[] getShuffledIndices(int size) {
        int[] indices = new int[size];
        for (int k = 0; k < size; k++) {
            indices[k] = k;
        }
        // The final iteration always swaps an element with itself, but it still consumes a value
        // from r_rand. It is kept so the random sequence, and hence training, stays reproducible.
        for (int k = 0; k < size; k++) {
            UTArray.swap(indices, k, k + r_rand.nextInt(size - k));
        }
        return indices;
    }

    /**
     * Computes the per-label gradient for one instance.
     *
     * <p>Assuming {@code normalize} turns the raw scores into label probabilities {@code p}, the
     * result is {@code 1 - p[gold]} for the gold label and {@code -p[label]} for every other
     * label.
     *
     * @param labelSize number of labels
     * @param gold index of the correct label
     * @param features feature indices of the instance
     * @param values feature values parallel to {@code features}, or {@code null} when every
     *     feature value is implicitly 1
     * @param weights current weight array
     * @return per-label gradients, of length {@code labelSize}
     */
    protected double[] getGradients(
            int labelSize, int gold, int[] features, double[] values, double[] weights) {
        double[] gradients = getScores(labelSize, features, values, weights);
        normalize(gradients);
        for (int label = 0; label < labelSize; label++) {
            gradients[label] = -gradients[label];
        }
        gradients[gold] += 1.0d;
        return gradients;
    }

    /**
     * Returns the raw (unnormalized) score of each label for one instance.
     *
     * <p>Each score is the sum over the instance's features of the label's weight, multiplied by
     * the feature value when {@code values} is non-null.
     */
    private double[] getScores(int labelSize, int[] features, double[] values, double[] weights) {
        double[] scores = new double[labelSize];
        if (values == null) {
            for (int feature : features) {
                for (int label = 0; label < labelSize; label++) {
                    scores[label] += weights[getWeightIndex(labelSize, label, feature)];
                }
            }
            return scores;
        }
        for (int j = 0; j < features.length; j++) {
            for (int label = 0; label < labelSize; label++) {
                scores[label] += weights[getWeightIndex(labelSize, label, features[j])] * values[j];
            }
        }
        return scores;
    }

    /**
     * Adds the squared gradient of every weight touched by this instance to {@code counts}.
     *
     * <p>The gradient of weight (label, feature j) is {@code gradients[label] * value_j}, so its
     * square is {@code value_j^2 * gradients[label]^2}. With {@code values == null}, each value is
     * taken as 1.
     *
     * @param labelSize number of labels
     * @param counts accumulated squared gradients, updated in place
     * @param gradients per-label gradients for the instance
     * @param features feature indices of the instance
     * @param values feature values parallel to {@code features}, or {@code null}
     */
    protected void updateCounts(
            int labelSize, double[] counts, double[] gradients, int[] features, double[] values) {
        double[] squaredGradients = new double[labelSize];
        for (int label = 0; label < labelSize; label++) {
            squaredGradients[label] = gradients[label] * gradients[label];
        }
        if (values == null) {
            for (int feature : features) {
                for (int label = 0; label < labelSize; label++) {
                    counts[getWeightIndex(labelSize, label, feature)] += squaredGradients[label];
                }
            }
            return;
        }
        for (int j = 0; j < features.length; j++) {
            double squaredValue = values[j] * values[j];
            for (int label = 0; label < labelSize; label++) {
                counts[getWeightIndex(labelSize, label, features[j])] +=
                        squaredValue * squaredGradients[label];
            }
        }
    }

    /**
     * Applies one AdaGrad step to every weight touched by this instance.
     *
     * <p>Each weight moves by {@code getUpdate(...) * gradients[label]}, multiplied by the feature
     * value when {@code values} is non-null. The step size already reflects this instance's
     * squared gradient, because callers update {@code counts} first.
     *
     * @param labelSize number of labels
     * @param counts accumulated squared gradients
     * @param gradients per-label gradients for the instance
     * @param features feature indices of the instance
     * @param values feature values parallel to {@code features}, or {@code null}
     * @param weights weight array, updated in place
     */
    protected void updateWeights(int labelSize, double[] counts, double[] gradients,
            int[] features, double[] values, double[] weights) {
        if (values != null) {
            for (int j = 0; j < features.length; j++) {
                int feature = features[j];
                for (int label = 0; label < labelSize; label++) {
                    weights[getWeightIndex(labelSize, label, feature)] +=
                            getUpdate(labelSize, counts, label, feature) * gradients[label] * values[j];
                }
            }
            return;
        }
        for (int feature : features) {
            for (int label = 0; label < labelSize; label++) {
                weights[getWeightIndex(labelSize, label, feature)] +=
                        getUpdate(labelSize, counts, label, feature) * gradients[label];
            }
        }
    }

    /**
     * Returns the AdaGrad step size {@code alpha / (rho + sqrt(counts[index]))} for one weight.
     *
     * @param labelSize number of labels
     * @param counts accumulated squared gradients
     * @param label label index
     * @param feature feature index
     * @return the per-coordinate learning rate
     */
    protected double getUpdate(int labelSize, double[] counts, int label, int feature) {
        return d_alpha / (d_rho + Math.sqrt(counts[getWeightIndex(labelSize, label, feature)]));
    }
}
