package org.nd4j.linalg.dimensionalityreduction;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.random.impl.GaussianDistribution;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:org/nd4j/linalg/dimensionalityreduction/RandomProjection.class */
/**
 * Gaussian random projection for dimensionality reduction.
 *
 * <p>Data of shape {@code [rows, columns]} is multiplied on the right by a random matrix of
 * shape {@code [columns, k]} whose entries are drawn from a zero-mean Gaussian, producing a
 * {@code [rows, k]} result. The target dimension {@code k} is either given explicitly
 * ({@link #RandomProjection(int)}) or estimated from a relative error {@code eps} with the
 * Johnson-Lindenstrauss bound applied to the number of rows ({@link #RandomProjection(double)}).
 *
 * <p>The projection matrix is cached and reused for as long as the required shape stays the
 * same, so inputs with matching shapes are projected into the same random subspace. The cache
 * is mutated without synchronization.
 */
public class RandomProjection {
    /** Explicit target dimension; only used when {@link #autoMode} is false. */
    private final int components;
    /** Random generator used to draw the projection matrix. */
    private final Random rng;
    /** Relative error used to estimate the target dimension; only used when {@link #autoMode} is true. */
    private final double eps;
    /** True when the target dimension is estimated from {@link #eps} instead of being given. */
    private final boolean autoMode;
    /** Shape of {@link #_projectionMatrix}; null until a matrix has been generated. */
    private long[] projectionMatrixShape;
    /** Cached projection matrix, regenerated only when the required shape changes. */
    private INDArray _projectionMatrix;

    /**
     * Creates a projection whose target dimension is estimated from a relative error.
     *
     * @param eps relative error, validated to lie in (0, 1) when the dimension is estimated
     * @param rng random generator used to draw the projection matrix
     */
    public RandomProjection(double eps, Random rng) {
        this.rng = rng;
        this.eps = eps;
        this.components = 0;
        this.autoMode = true;
    }

    /**
     * Creates a projection whose target dimension is estimated from a relative error, drawing
     * the projection matrix from {@code Nd4j.getRandom()}.
     *
     * @param eps relative error, validated to lie in (0, 1) when the dimension is estimated
     */
    public RandomProjection(double eps) {
        this(eps, Nd4j.getRandom());
    }

    /**
     * Creates a projection onto an explicitly chosen number of components.
     *
     * @param components target dimension of the projected data
     * @param rng        random generator used to draw the projection matrix
     */
    public RandomProjection(int components, Random rng) {
        this.rng = rng;
        this.components = components;
        this.eps = 0.0d;
        this.autoMode = false;
    }

    /**
     * Creates a projection onto an explicitly chosen number of components, drawing the
     * projection matrix from {@code Nd4j.getRandom()}.
     *
     * @param components target dimension of the projected data
     */
    public RandomProjection(int components) {
        this(components, Nd4j.getRandom());
    }

    /**
     * Computes the Johnson-Lindenstrauss minimum embedding dimension
     * {@code floor(4 ln(n) / (eps^2/2 - eps^3/3))} for every combination of sample count and
     * relative error.
     *
     * @param nSamples       sample counts, each of which must be > 0
     * @param relativeErrors relative errors, each of which must lie in (0, 1)
     * @return the dimensions, ordered by relative error first and then by sample count
     * @throws IllegalArgumentException if either argument is null or empty, a sample count is
     *                                  not positive, or a relative error is outside (0, 1)
     */
    public static List<Integer> johnsonLindenstraussMinDim(int[] nSamples, double... relativeErrors) {
        if (nSamples == null || nSamples.length == 0 || relativeErrors == null || relativeErrors.length == 0) {
            throw new IllegalArgumentException("Johnson-Lindenstrauss dimension estimation requires > 0 components and at least a relative error");
        }
        double[] denominators = jlDenominators(relativeErrors);
        for (int n : nSamples) {
            checkSampleCount(n);
        }
        List<Integer> result = new ArrayList<>(nSamples.length * relativeErrors.length);
        for (double denominator : denominators) {
            for (int n : nSamples) {
                result.add((int) (4.0d * Math.log(n) / denominator));
            }
        }
        return result;
    }

    /**
     * {@code long} variant of {@link #johnsonLindenstraussMinDim(int[], double...)}.
     *
     * @param nSamples       sample counts, each of which must be > 0
     * @param relativeErrors relative errors, each of which must lie in (0, 1)
     * @return the dimensions, ordered by relative error first and then by sample count
     * @throws IllegalArgumentException if either argument is null or empty, a sample count is
     *                                  not positive, or a relative error is outside (0, 1)
     */
    public static List<Long> johnsonLindenstraussMinDim(long[] nSamples, double... relativeErrors) {
        if (nSamples == null || nSamples.length == 0 || relativeErrors == null || relativeErrors.length == 0) {
            throw new IllegalArgumentException("Johnson-Lindenstrauss dimension estimation requires > 0 components and at least a relative error");
        }
        double[] denominators = jlDenominators(relativeErrors);
        for (long n : nSamples) {
            checkSampleCount(n);
        }
        List<Long> result = new ArrayList<>(nSamples.length * relativeErrors.length);
        for (double denominator : denominators) {
            for (long n : nSamples) {
                result.add((long) (4.0d * Math.log(n) / denominator));
            }
        }
        return result;
    }

    /** Single-sample-count convenience overload of {@link #johnsonLindenstraussMinDim(int[], double...)}. */
    public static List<Integer> johnsonLindenStraussMinDim(int nSamples, double... relativeErrors) {
        return johnsonLindenstraussMinDim(new int[]{nSamples}, relativeErrors);
    }

    /** Single-sample-count convenience overload of {@link #johnsonLindenstraussMinDim(long[], double...)}. */
    public static List<Long> johnsonLindenStraussMinDim(long nSamples, double... relativeErrors) {
        return johnsonLindenstraussMinDim(new long[]{nSamples}, relativeErrors);
    }

    /**
     * Validates each relative error and returns its Johnson-Lindenstrauss denominator
     * {@code eps^2/2 - eps^3/3}, in input order.
     */
    private static double[] jlDenominators(double[] relativeErrors) {
        double[] denominators = new double[relativeErrors.length];
        for (int k = 0; k < relativeErrors.length; k++) {
            double e = relativeErrors[k];
            // Written as a negated range test so that NaN is rejected as well.
            if (!(e > 0.0d && e < 1.0d)) {
                throw new IllegalArgumentException("A relative error should be in ]0, 1[");
            }
            denominators[k] = (Math.pow(e, 2.0d) / 2.0d) - (Math.pow(e, 3.0d) / 3.0d);
        }
        return denominators;
    }

    /** Rejects non-positive sample counts, for which {@code Math.log} is undefined or -Infinity. */
    private static void checkSampleCount(long n) {
        if (n <= 0) {
            throw new IllegalArgumentException("The number of samples should be > 0, got " + n);
        }
    }

    /**
     * Draws a matrix of the given shape with i.i.d. Gaussian entries of mean 0 and standard
     * deviation {@code 1/sqrt(shape[1])}. With variance {@code 1/k} for {@code k = shape[1]}
     * output components, the expected squared norm of a projected row equals that of the input
     * row ({@code E||xR||^2 = k * sigma^2 * ||x||^2}).
     */
    private INDArray gaussianRandomMatrix(long[] shape, Random random) {
        Nd4j.checkShapeValues(shape);
        INDArray matrix = Nd4j.create(shape);
        double stdDev = 1.0d / Math.sqrt(shape[1]);
        Nd4j.getExecutioner().exec(new GaussianDistribution(matrix, 0.0d, stdDev), random);
        return matrix;
    }

    /**
     * Returns the cached projection matrix if it has the requested shape. Otherwise it draws a
     * new one, caches it and remembers its shape so later calls with the same shape reuse it.
     */
    private INDArray getProjectionMatrix(long[] shape, Random random) {
        if (this._projectionMatrix == null || !Arrays.equals(this.projectionMatrixShape, shape)) {
            this._projectionMatrix = gaussianRandomMatrix(shape, random);
            // Copy so that later mutation of the caller's array cannot corrupt the cache key.
            this.projectionMatrixShape = shape.clone();
        }
        return this._projectionMatrix;
    }

    /**
     * Computes the projection-matrix shape {@code [shape[1], k]} for data of the given shape
     * ({@code int} variant, not used by the public API).
     */
    private static int[] targetShape(int[] shape, double eps, int components, boolean autoMode) {
        if (shape.length < 2) {
            throw new ND4JIllegalStateException("Random projection requires a rank-2 input, got shape " + Arrays.toString(shape));
        }
        if (!autoMode) {
            return new int[]{shape[1], components};
        }
        int target = johnsonLindenStraussMinDim(shape[0], eps).get(0);
        if (target <= 0 || target > shape[1]) {
            throw new ND4JIllegalStateException(String.format("Estimation led to a target dimension of %d, which is invalid", target));
        }
        return new int[]{shape[1], target};
    }

    /**
     * Computes the projection-matrix shape {@code [shape[1], k]} for data of the given shape.
     * In auto mode, {@code k} is the Johnson-Lindenstrauss dimension for {@code shape[0]} rows
     * and {@code eps}, and it must lie in {@code [1, shape[1]]}. Otherwise {@code k} is
     * {@code components}.
     *
     * @throws ND4JIllegalStateException if the input is not at least rank 2, or the estimated
     *                                   dimension is out of range
     */
    private static long[] targetShape(long[] shape, double eps, int components, boolean autoMode) {
        if (shape.length < 2) {
            throw new ND4JIllegalStateException("Random projection requires a rank-2 input, got shape " + Arrays.toString(shape));
        }
        if (!autoMode) {
            return new long[]{shape[1], components};
        }
        long target = johnsonLindenStraussMinDim(shape[0], eps).get(0);
        if (target <= 0 || target > shape[1]) {
            throw new ND4JIllegalStateException(String.format("Estimation led to a target dimension of %d, which is invalid", target));
        }
        return new long[]{shape[1], target};
    }

    /**
     * Returns the projection-matrix shape that an epsilon-driven projection would use for the
     * given data.
     *
     * @param data input data; rows are treated as samples
     * @param eps  relative error in (0, 1)
     * @return {@code [columns, k]}
     */
    public static long[] targetShape(INDArray data, double eps) {
        return targetShape(data.shape(), eps, -1, true);
    }

    /**
     * Returns the projection-matrix shape {@code [columns, components]} for the given data.
     */
    protected static long[] targetShape(INDArray data, int components) {
        return targetShape(data.shape(), -1.0d, components, false);
    }

    /** Looks up (or generates) the projection matrix matching this instance's configuration and the data's shape. */
    private INDArray projectionMatrixFor(INDArray data) {
        return getProjectionMatrix(targetShape(data.shape(), this.eps, this.components, this.autoMode), this.rng);
    }

    /**
     * Projects the data into a new array via {@code data.mmul(projection)}.
     *
     * @param data input of shape {@code [rows, columns]}
     * @return projected data of shape {@code [rows, k]}
     */
    public INDArray project(INDArray data) {
        return data.mmul(projectionMatrixFor(data));
    }

    /**
     * Projects the data, writing the product into {@code result} via
     * {@code data.mmuli(projection, result)}.
     *
     * @param data   input of shape {@code [rows, columns]}
     * @param result destination array, presumably of shape {@code [rows, k]} (TODO confirm
     *               against ND4J's {@code mmuli} contract)
     * @return the array returned by {@code mmuli}
     */
    public INDArray project(INDArray data, INDArray result) {
        return data.mmuli(projectionMatrixFor(data), result);
    }

    /**
     * In-place projection via {@code data.mmuli(projection)}.
     *
     * <p>NOTE(review): this writes the {@code [rows, k]} product into {@code data} itself, which
     * presumably only works when {@code k == columns}; confirm against ND4J's {@code mmuli}.
     *
     * @param data input of shape {@code [rows, columns]}
     * @return the array returned by {@code mmuli}
     */
    public INDArray projecti(INDArray data) {
        return data.mmuli(projectionMatrixFor(data));
    }

    /**
     * Projects the data, writing the product into {@code result} via
     * {@code data.mmuli(projection, result)}. This is the same operation as
     * {@link #project(INDArray, INDArray)}.
     *
     * @param data   input of shape {@code [rows, columns]}
     * @param result destination array for the product
     * @return the array returned by {@code mmuli}
     */
    public INDArray projecti(INDArray data, INDArray result) {
        return data.mmuli(projectionMatrixFor(data), result);
    }
}
