package smile.projection;

import smile.math.Math;

/**
 * Generalized Hebbian Algorithm (GHA): an incremental learner for a linear
 * projection {@code W} of size {@code p x n}, adapted one sample at a time by
 * {@link #learn(double[])}.
 * <p>
 * For a sample {@code x}, with {@code y = W x}, each call applies for every
 * output row {@code i} and input column {@code j}:
 * <pre>
 *     W[i][j] += r * y[i] * (x[j] - sum_{k &lt;= i} W[k][j] * y[k])
 * </pre>
 * <p>
 * Instances are not thread-safe: {@code learn} mutates the projection matrix
 * in place and reuses internal scratch buffers.
 */
public class GHA implements Projection<double[]> {
    /** Dimension of the feature (output) space: the number of rows of the projection. */
    private final int p;
    /** Dimension of the input space: the number of columns of the projection. */
    private final int n;
    /** Learning rate used by {@link #learn(double[])}. */
    private double r;
    /** Projection matrix, p rows by n columns; updated in place by {@link #learn(double[])}. */
    private final double[][] projection;
    /** Scratch buffer of length p, reused by {@link #learn(double[])}. */
    private final double[] y;
    /** Scratch buffer of length n, reused by {@link #learn(double[])}. */
    private final double[] wy;

    /**
     * Creates a learner whose projection is initialized with small random weights.
     *
     * @param n the dimension of the input space; must be at least 2.
     * @param p the dimension of the feature space; must be in {@code [1, n]}.
     * @param r the learning rate.
     * @throws IllegalArgumentException if {@code n} or {@code p} is out of range.
     */
    public GHA(int n, int p, double r) {
        if (n < 2) {
            throw new IllegalArgumentException("Invalid dimension of input space: " + n);
        }
        if (p < 1 || p > n) {
            throw new IllegalArgumentException("Invalid dimension of feature space: " + p);
        }
        this.n = n;
        this.p = p;
        this.r = r;
        this.y = new double[p];
        this.wy = new double[n];
        this.projection = new double[p][n];
        // Non-zero starting weights are required: with W = 0, y = W x is 0 and
        // every update r * y[i] * delta would be 0, so learning would never start.
        for (double[] row : projection) {
            for (int j = 0; j < n; j++) {
                row[j] = 0.1 * Math.random();
            }
        }
    }

    /**
     * Creates a learner that starts from the given projection matrix.
     * The array is used directly (not copied), so {@link #learn(double[])}
     * updates the caller's matrix in place.
     *
     * @param w the initial projection, one row per feature; all rows must have the same length.
     * @param r the learning rate.
     * @throws IllegalArgumentException if {@code w} has no rows or is not rectangular.
     */
    public GHA(double[][] w, double r) {
        if (w.length == 0) {
            throw new IllegalArgumentException("Empty projection matrix");
        }
        int cols = w[0].length;
        for (int i = 1; i < w.length; i++) {
            if (w[i].length != cols) {
                throw new IllegalArgumentException(String.format(
                        "Projection matrix row %d has %d columns, expected: %d", i, w[i].length, cols));
            }
        }
        this.p = w.length;
        this.n = cols;
        this.r = r;
        this.y = new double[p];
        this.wy = new double[n];
        this.projection = w;
    }

    /**
     * Returns the projection matrix. This is the live internal array, not a copy;
     * later calls to {@link #learn(double[])} keep modifying it.
     *
     * @return the p x n projection matrix.
     */
    public double[][] getProjection() {
        return projection;
    }

    /**
     * Returns the learning rate.
     *
     * @return the current learning rate.
     */
    public double getLearningRate() {
        return r;
    }

    /**
     * Sets the learning rate used by subsequent calls to {@link #learn(double[])}.
     *
     * @param r the new learning rate.
     */
    public void setLearningRate(double r) {
        this.r = r;
    }

    /**
     * Projects a single vector into the feature space.
     *
     * @param x an input vector of length n.
     * @return a new array of length p.
     * @throws IllegalArgumentException if {@code x.length != n}.
     */
    @Override
    public double[] project(double[] x) {
        checkInputSize(x);
        double[] out = new double[p];
        Math.ax(projection, x, out);
        return out;
    }

    /**
     * Projects each row of {@code x} into the feature space.
     *
     * @param x input vectors, each of length n; may be empty.
     * @return a new array with one length-p row per input row.
     * @throws IllegalArgumentException if any row's length differs from n.
     */
    @Override
    public double[][] project(double[][] x) {
        // Validate every row up front: a short row would fail inside Math.ax and a
        // long row would be silently truncated.
        for (double[] row : x) {
            checkInputSize(row);
        }
        double[][] out = new double[x.length][p];
        for (int i = 0; i < x.length; i++) {
            Math.ax(projection, x[i], out[i]);
        }
        return out;
    }

    /**
     * Updates the projection with one training sample.
     *
     * @param x a training sample of length n.
     * @return {@code Math.squaredDistance(x, W^T W x)}, computed with the updated
     *         weights; this is the squared reconstruction error, assuming
     *         {@code Math.ax} and {@code Math.atx} compute {@code A x} and
     *         {@code A^T x} (TODO confirm).
     * @throws IllegalArgumentException if {@code x.length != n}.
     * @throws IllegalStateException if a weight becomes infinite or NaN,
     *         which means the learning rate is too high.
     */
    public double learn(double[] x) {
        checkInputSize(x);
        Math.ax(projection, x, y);
        // NOTE(review): the rows k < i have already been updated in this call by the
        // time row i's residual is computed, whereas textbook Sanger's rule uses the
        // pre-update weights for all rows. Kept as-is; confirm this is intended.
        for (int i = 0; i < p; i++) {
            double[] wi = projection[i];
            double gain = r * y[i];
            for (int j = 0; j < n; j++) {
                double delta = x[j];
                for (int k = 0; k <= i; k++) {
                    delta -= projection[k][j] * y[k];
                }
                wi[j] += gain * delta;
                // NaN must be caught too: once y overflows, inf * 0 and inf - inf
                // yield NaN, which would otherwise propagate silently.
                if (Double.isNaN(wi[j]) || Double.isInfinite(wi[j])) {
                    throw new IllegalStateException("GHA lost convergence. Lower learning rate?");
                }
            }
        }
        // Recompute y with the updated weights before measuring reconstruction error.
        Math.ax(projection, x, y);
        Math.atx(projection, y, wy);
        return Math.squaredDistance(x, wy);
    }

    /** Throws {@link IllegalArgumentException} unless {@code x} has exactly n elements. */
    private void checkInputSize(double[] x) {
        if (x.length != n) {
            throw new IllegalArgumentException(String.format(
                    "Invalid input vector size: %d, expected: %d", x.length, n));
        }
    }
}
