package org.deeplearning4j.util;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.Set;
import org.apache.commons.math3.linear.CholeskyDecomposition;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.random.RandomGenerator;
import org.deeplearning4j.berkeley.Counter;

/**
 * Assorted static numeric helpers: value normalization and clamping, simple linear regression on
 * interleaved (x, y) coordinates, error and distance measures, tf-idf weights, rounding, and
 * random sampling.
 *
 * <p>{@link #slope} and {@link #choleskyFromMatrix} are instance methods, so the class keeps its
 * public no-argument constructor.
 */
public class MathUtils {
    /** Natural logarithm of 2; {@link #log2(double)} divides by it to convert ln to log base 2. */
    public static double log2 = Math.log(2.0d);

    /**
     * Tolerance used by {@link #sm} and {@link #gr}, and as the clamping margin in
     * {@link #probToLogOdds}. NOTE(review): public and mutable; reassigning it changes those
     * methods globally.
     */
    public static double SMALL = 1.0E-6d;

    /**
     * Linearly maps {@code val} from [min, max] onto [0, 1]; values outside the range map outside
     * [0, 1].
     *
     * @throws IllegalArgumentException if {@code max < min}. {@code max == min} is not rejected and
     *     yields NaN or an infinity.
     */
    public static double normalize(double val, double min, double max) {
        if (max < min) {
            throw new IllegalArgumentException("Max must be greater than min");
        }
        return (val - min) / (max - min);
    }

    /** Restricts {@code value} to the inclusive range [min, max]. */
    public static int clamp(int value, int min, int max) {
        if (value < min) {
            value = min;
        }
        if (value > max) {
            value = max;
        }
        return value;
    }

    /**
     * Maps {@code value} from [min, max] to one of {@code binCount} equal-width bins, returning a
     * bin index clamped to [0, binCount - 1].
     */
    public static int discretize(double value, double min, double max, int binCount) {
        return clamp((int) (binCount * normalize(value, min, max)), 0, binCount - 1);
    }

    /**
     * Returns the smallest power of two that is {@code >= v}, for {@code 1 <= v <= 2^62}.
     * Returns 0 for {@code v == 0}; results for negative inputs or inputs above 2^62 are not
     * meaningful (the result overflows).
     */
    public static long nextPowOf2(long v) {
        long bits = v - 1;
        // Smear the highest set bit into every lower position, then add one.
        bits |= bits >> 1;
        bits |= bits >> 2;
        bits |= bits >> 4;
        bits |= bits >> 8;
        bits |= bits >> 16;
        bits |= bits >> 32; // needed for values that do not fit in 32 bits
        return bits + 1;
    }

    /**
     * Draws a binomial sample by counting successes in {@code n} Bernoulli trials with success
     * probability {@code p}. Returns 0 when {@code p} is outside [0, 1].
     */
    public static int binomial(RandomGenerator rng, int n, double p) {
        if (p < 0.0d || p > 1.0d) {
            return 0;
        }
        int successes = 0;
        for (int trial = 0; trial < n; trial++) {
            if (rng.nextDouble() < p) {
                successes++;
            }
        }
        return successes;
    }

    /** Returns a uniformly distributed double in [min, max) drawn from {@code rng}. */
    public static double uniform(Random rng, double min, double max) {
        return (rng.nextDouble() * (max - min)) + min;
    }

    /**
     * Computes {@code 1 - ssError(predicted, target) / ssTotal(residuals, target)} where
     * {@code predicted[i] = target[i] - residuals[i]}. Despite the name, this is an R^2-style
     * score rather than a Pearson correlation coefficient.
     */
    public static double correlation(double[] residuals, double[] targetAttribute) {
        double[] predictedValues = new double[residuals.length];
        for (int i = 0; i < predictedValues.length; i++) {
            predictedValues[i] = targetAttribute[i] - residuals[i];
        }
        return 1.0d - (ssError(predictedValues, targetAttribute) / ssTotal(residuals, targetAttribute));
    }

    /** Logistic function {@code 1 / (1 + e^-x)}. */
    public static double sigmoid(double x) {
        return 1.0d / (1.0d + Math.exp(-x));
    }

    /**
     * Sum of squared deviations of {@code residuals} from the mean of {@code targetAttribute}.
     */
    public static double ssReg(double[] residuals, double[] targetAttribute) {
        double targetMean = sum(targetAttribute) / targetAttribute.length;
        double total = 0.0d;
        for (double residual : residuals) {
            total += Math.pow(residual - targetMean, 2.0d);
        }
        return total;
    }

    /**
     * Sum of squared differences {@code (target[i] - predicted[i])^2} over the length of
     * {@code predictedValues}.
     */
    public static double ssError(double[] predictedValues, double[] targetAttribute) {
        double total = 0.0d;
        for (int i = 0; i < predictedValues.length; i++) {
            total += Math.pow(targetAttribute[i] - predictedValues[i], 2.0d);
        }
        return total;
    }

    /**
     * Cosine similarity between the character-frequency vectors of the first two strings.
     *
     * @return a value in [0, 1]; 0 when {@code strings} is null, has fewer than two elements, or
     *     either string is empty
     */
    public static double stringSimilarity(String... strings) {
        if (strings == null || strings.length < 2) {
            return 0.0d;
        }
        Counter<String> first = charCounts(strings[0]);
        Counter<String> second = charCounts(strings[1]);
        Set<String> firstKeys = first.keySet();
        Set<String> secondKeys = second.keySet();

        double dot = 0.0d;
        double firstNormSq = 0.0d;
        double secondNormSq = 0.0d;
        for (String key : firstKeys) {
            double count = first.getCount(key);
            firstNormSq += count * count;
            if (secondKeys.contains(key)) {
                dot += count * second.getCount(key);
            }
        }
        for (String key : secondKeys) {
            double count = second.getCount(key);
            secondNormSq += count * count;
        }
        double denominator = Math.sqrt(firstNormSq * secondNormSq);
        // An empty string has a zero vector; report no similarity instead of 0/0 = NaN.
        return denominator == 0.0d ? 0.0d : dot / denominator;
    }

    /** Counts occurrences of each character of {@code s}, keyed by the one-character string. */
    private static Counter<String> charCounts(String s) {
        Counter<String> counts = new Counter<String>();
        for (int i = 0; i < s.length(); i++) {
            counts.incrementCount(String.valueOf(s.charAt(i)), 1.0d);
        }
        return counts;
    }

    /**
     * Returns the sum of squares of {@code vector} (0 for null). NOTE: despite the name, this is
     * the squared Euclidean norm; no square root is taken.
     */
    public static double vectorLength(double[] vector) {
        if (vector == null) {
            return 0.0d;
        }
        double total = 0.0d;
        for (double component : vector) {
            total += Math.pow(component, 2.0d);
        }
        return total;
    }

    /**
     * Inverse document frequency {@code log10(totalDocs / numTimesWordAppearedInADocument)}.
     *
     * @return 0 when either argument is not positive (avoids an infinite or NaN idf for unseen
     *     terms)
     */
    public static double idf(double totalDocs, double numTimesWordAppearedInADocument) {
        if (totalDocs > 0.0d && numTimesWordAppearedInADocument > 0.0d) {
            return Math.log10(totalDocs / numTimesWordAppearedInADocument);
        }
        return 0.0d;
    }

    /** Log-scaled term frequency {@code 1 + log10(count)}, or 0 when {@code count <= 0}. */
    public static double tf(int count) {
        if (count > 0) {
            return 1.0d + Math.log10(count);
        }
        return 0.0d;
    }

    /** Product of a term frequency and an inverse document frequency. */
    public static double tfidf(double tf, double idf) {
        return tf * idf;
    }

    /** Zero-based alphabet index of a lowercase ASCII letter, or -1 for any other character. */
    private static int charForLetter(char c) {
        return (c >= 'a' && c <= 'z') ? c - 'a' : -1;
    }

    /** {@code ssReg(residuals, target) + ssError(residuals, target)}. */
    public static double ssTotal(double[] residuals, double[] targetAttribute) {
        return ssReg(residuals, targetAttribute) + ssError(residuals, targetAttribute);
    }

    /** Sum of all elements. */
    public static double sum(double[] nums) {
        double total = 0.0d;
        for (double num : nums) {
            total += num;
        }
        return total;
    }

    /**
     * Interleaves two equally sized arrays into {@code [x0, y0, x1, y1, ...]}.
     *
     * @throws IllegalArgumentException if the arrays differ in length
     */
    public static double[] mergeCoords(double[] x, double[] y) {
        if (x.length != y.length) {
            throw new IllegalArgumentException("Sample sizes must be the same for each data set.");
        }
        double[] merged = new double[x.length + y.length];
        for (int i = 0; i < x.length; i++) {
            merged[2 * i] = x[i];
            merged[2 * i + 1] = y[i];
        }
        return merged;
    }

    /**
     * Interleaves two equally sized lists into {@code [x0, y0, x1, y1, ...]}.
     *
     * @throws IllegalArgumentException if the lists differ in size
     */
    public static List<Double> mergeCoords(List<Double> x, List<Double> y) {
        if (x.size() != y.size()) {
            throw new IllegalArgumentException("Sample sizes must be the same for each data set.");
        }
        List<Double> merged = new ArrayList<Double>(x.size() * 2);
        for (int i = 0; i < x.size(); i++) {
            merged.add(x.get(i));
            merged.add(y.get(i));
        }
        return merged;
    }

    /**
     * Least-squares line fit over interleaved {@code [x0, y0, x1, y1, ...]} coordinates.
     *
     * @return an array whose element 0 is the intercept and element 1 the slope; its length is
     *     {@code max(2, vector.size())} and any extra elements are zero
     */
    public static double[] weightsFor(List<Double> vector) {
        List<double[]> coords = coordSplit(vector);
        return linearFit(coords.get(0), coords.get(1), vector.size());
    }

    /**
     * Sum of squared residuals of the line {@code y = w1 * x + w0}.
     */
    public static double squaredLoss(double[] x, double[] y, double w0, double w1) {
        double loss = 0.0d;
        for (int i = 0; i < x.length; i++) {
            loss += Math.pow(y[i] - ((w1 * x[i]) + w0), 2.0d);
        }
        return loss;
    }

    /**
     * Ordinary-least-squares slope
     * {@code (n*sum(xy) - sum(x)*sum(y)) / (n*sum(x^2) - sum(x)^2)}.
     */
    public static double w_1(double[] x, double[] y, int n) {
        double sumX = sum(x);
        return ((n * sumOfProducts(x, y)) - (sumX * sum(y)))
                / ((n * sumOfSquares(x)) - Math.pow(sumX, 2.0d));
    }

    /** Ordinary-least-squares intercept {@code (sum(y) - w_1 * sum(x)) / n}. */
    public static double w_0(double[] x, double[] y, int n) {
        return (sum(y) - (w_1(x, y, n) * sum(x))) / n;
    }

    /**
     * Least-squares line fit over interleaved {@code [x0, y0, x1, y1, ...]} coordinates.
     *
     * @return an array whose element 0 is the intercept and element 1 the slope; its length is
     *     {@code max(2, vector.length)} and any extra elements are zero
     */
    public static double[] weightsFor(double[] vector) {
        List<double[]> coords = coordSplit(vector);
        return linearFit(coords.get(0), coords.get(1), vector.length);
    }

    /** Shared fit for both {@code weightsFor} overloads; pads the result to {@code resultLength}. */
    private static double[] linearFit(double[] x, double[] y, int resultLength) {
        double meanX = sum(x) / x.length;
        double meanY = sum(y) / y.length;
        double slope = sumOfMeanDifferences(x, y) / sumOfMeanDifferencesOnePoint(x);
        double intercept = meanY - (slope * meanX);
        // Historical contract: the result is as long as the input vector; never shorter than 2.
        double[] weights = new double[Math.max(2, resultLength)];
        weights[0] = intercept;
        weights[1] = slope;
        return weights;
    }

    /** Returns {@code actual - prediction}. */
    public static double errorFor(double actual, double prediction) {
        return actual - prediction;
    }

    /** Sum of {@code (x[i] - mean(x)) * (y[i] - mean(y))} over the length of {@code x}. */
    public static double sumOfMeanDifferences(double[] x, double[] y) {
        double meanX = sum(x) / x.length;
        double meanY = sum(y) / y.length;
        double total = 0.0d;
        for (int i = 0; i < x.length; i++) {
            total += (x[i] - meanX) * (y[i] - meanY);
        }
        return total;
    }

    /** Sum of squared deviations of {@code vector} from its own mean. */
    public static double sumOfMeanDifferencesOnePoint(double[] vector) {
        double mean = sum(vector) / vector.length;
        double total = 0.0d;
        for (double value : vector) {
            total += Math.pow(value - mean, 2.0d);
        }
        return total;
    }

    /** Product of all elements, or 0 for a null or empty array. */
    public static double times(double[] nums) {
        if (nums == null || nums.length == 0) {
            return 0.0d;
        }
        double product = 1.0d;
        for (double num : nums) {
            product *= num;
        }
        return product;
    }

    /**
     * Sum over positions {@code i} of the product {@code nums[0][i] * nums[1][i] * ...}, e.g. the
     * dot product for two arrays. Only the first {@code min(length)} positions are used.
     *
     * @return 0 when {@code nums} is null or empty
     */
    public static double sumOfProducts(double[]... nums) {
        if (nums == null || nums.length < 1) {
            return 0.0d;
        }
        int columns = Integer.MAX_VALUE;
        for (double[] row : nums) {
            columns = Math.min(columns, row.length);
        }
        double total = 0.0d;
        // Iterate over element positions (columns), not over the number of arrays.
        for (int i = 0; i < columns; i++) {
            total += times(column(i, nums));
        }
        return total;
    }

    /** Collects element {@code i} of every array in {@code nums}. */
    private static double[] column(int i, double[]... nums) throws IllegalArgumentException {
        double[] col = new double[nums.length];
        for (int row = 0; row < nums.length; row++) {
            col[row] = nums[row][i];
        }
        return col;
    }

    /**
     * Splits interleaved {@code [x0, y0, x1, y1, ...]} into a two-element list
     * {@code [xs, ys]} (even indices first, odd indices second).
     *
     * @return null for null input
     * @throws IllegalArgumentException if the length is odd
     */
    public static List<double[]> coordSplit(double[] vector) {
        if (vector == null) {
            return null;
        }
        if (vector.length % 2 != 0) {
            throw new IllegalArgumentException(
                    "Coordinate vector must contain an even number of values, got " + vector.length);
        }
        int pairs = vector.length / 2;
        double[] xs = new double[pairs];
        double[] ys = new double[pairs];
        for (int i = 0; i < pairs; i++) {
            xs[i] = vector[2 * i];
            ys[i] = vector[2 * i + 1];
        }
        List<double[]> split = new ArrayList<double[]>(2);
        split.add(xs);
        split.add(ys);
        return split;
    }

    /**
     * Splits {@code arr} into consecutive sublists of exactly {@code chunk} elements; a trailing
     * partial chunk is dropped. The sublists are views backed by {@code arr}.
     *
     * @throws IllegalArgumentException if {@code chunk <= 0}
     */
    public static List<List<Double>> partitionVariable(List<Double> arr, int chunk) {
        if (chunk <= 0) {
            throw new IllegalArgumentException("Chunk size must be positive, got " + chunk);
        }
        List<List<Double>> partitions = new ArrayList<List<Double>>();
        // Written as a subtraction to avoid int overflow of start + chunk.
        for (int start = 0; arr.size() - start >= chunk; start += chunk) {
            partitions.add(arr.subList(start, start + chunk));
        }
        return partitions;
    }

    /**
     * List overload of {@link #coordSplit(double[])}.
     *
     * @return null for null input
     * @throws IllegalArgumentException if the size is odd
     */
    public static List<double[]> coordSplit(List<Double> vector) {
        if (vector == null) {
            return null;
        }
        double[] values = new double[vector.size()];
        for (int i = 0; i < values.length; i++) {
            values[i] = vector.get(i).doubleValue();
        }
        return coordSplit(values);
    }

    /**
     * Returns the elements at odd indices (1, 3, 5, ...), or null for null input. NOTE(review):
     * {@link #coordSplit} treats even indices as x, so the two disagree on which axis is which.
     */
    public static double[] xVals(double[] vector) {
        if (vector == null) {
            return null;
        }
        double[] xs = new double[vector.length / 2];
        for (int i = 0; i < xs.length; i++) {
            xs[i] = vector[2 * i + 1];
        }
        return xs;
    }

    /**
     * Returns the elements at even indices (0, 2, 4, ...), or null for null input. See the note
     * on {@link #xVals}.
     */
    public static double[] yVals(double[] vector) {
        if (vector == null) {
            return null;
        }
        // Even indices number ceil(n / 2), so odd-length input still fits.
        double[] ys = new double[(vector.length + 1) / 2];
        for (int i = 0; i < ys.length; i++) {
            ys[i] = vector[2 * i];
        }
        return ys;
    }

    /** Sum of squares of all elements. */
    public static double sumOfSquares(double[] vector) {
        double total = 0.0d;
        for (double value : vector) {
            total += Math.pow(value, 2.0d);
        }
        return total;
    }

    /** Square of {@link #correlation}; the {@code n} argument is unused. */
    public static double determinationCoefficient(double[] y1, double[] y2, int n) {
        return Math.pow(correlation(y1, y2), 2.0d);
    }

    /**
     * Base-2 logarithm. Returns 0 for 0 (rather than -Infinity) so that {@code 0 * log2(0)} terms
     * vanish in {@link #information}.
     */
    public static double log2(double a) {
        if (a == 0.0d) {
            return 0.0d;
        }
        return Math.log(a) / log2;
    }

    /** Slope {@code (y2 - y1) / (x2 - x1)} of the line through (x1, y1) and (x2, y2). */
    public double slope(double x1, double x2, double y1, double y2) {
        return (y2 - y1) / (x2 - x1);
    }

    /**
     * Root-mean-square error {@code sqrt(sum((real[i] - predicted[i])^2) / n)} over the length of
     * {@code real}.
     *
     * @return 0 for an empty {@code real} array
     */
    public static double rootMeansSquaredError(double[] real, double[] predicted) {
        int n = real.length;
        if (n == 0) {
            return 0.0d;
        }
        double sumOfSquaredErrors = 0.0d;
        for (int i = 0; i < n; i++) {
            double diff = real[i] - predicted[i];
            sumOfSquaredErrors += diff * diff;
        }
        return Math.sqrt(sumOfSquaredErrors / n);
    }

    /**
     * Returns {@code sum(p * ln(p))} over {@code vector}, i.e. the negated Shannon entropy in nats
     * (note the sign). Zero entries contribute 0, following the convention 0 * ln 0 = 0.
     *
     * @return 0 for a null or empty vector
     */
    public static double entropy(double[] vector) {
        if (vector == null || vector.length < 1) {
            return 0.0d;
        }
        double total = 0.0d;
        for (double p : vector) {
            if (p != 0.0d) {
                total += p * Math.log(p);
            }
        }
        return total;
    }

    /** Returns 1 when the arguments are equal (by {@code ==}), otherwise 0. */
    public static int kroneckerDelta(double i, double j) {
        return i == j ? 1 : 0;
    }

    /**
     * Adjusted R^2: {@code 1 - (1 - R^2) * (n - 1) / (n - k - 1)}.
     *
     * @param rSquared the unadjusted R^2
     * @param numRegressors k
     * @param numDataPoints n
     */
    public static double adjustedrSquared(double rSquared, int numRegressors, int numDataPoints) {
        // Floating-point division: the integer version truncated the correction factor.
        double correction = (double) (numDataPoints - 1) / (numDataPoints - numRegressors - 1);
        return 1.0d - ((1.0d - rSquared) * correction);
    }

    /** Divides every element by the array's sum, in place, and returns the same array. */
    public static double[] normalizeToOne(double[] doubles) {
        normalize(doubles, sum(doubles));
        return doubles;
    }

    /** Smallest element; throws {@code ArrayIndexOutOfBoundsException} for an empty array. */
    public static double min(double[] doubles) {
        double smallest = doubles[0];
        for (double value : doubles) {
            if (value < smallest) {
                smallest = value;
            }
        }
        return smallest;
    }

    /** Largest element; throws {@code ArrayIndexOutOfBoundsException} for an empty array. */
    public static double max(double[] doubles) {
        double largest = doubles[0];
        for (double value : doubles) {
            if (value > largest) {
                largest = value;
            }
        }
        return largest;
    }

    /**
     * Divides every element of {@code doubles} by {@code sum}, in place.
     *
     * @throws IllegalArgumentException if {@code sum} is NaN or zero
     */
    public static void normalize(double[] doubles, double sum) {
        if (Double.isNaN(sum)) {
            throw new IllegalArgumentException("Can't normalize array. Sum is NaN.");
        }
        if (sum == 0.0d) {
            throw new IllegalArgumentException("Can't normalize array. Sum is zero.");
        }
        for (int i = 0; i < doubles.length; i++) {
            doubles[i] /= sum;
        }
    }

    /**
     * Converts natural-log scores into a probability distribution (softmax). The maximum is
     * subtracted before exponentiating to avoid overflow.
     */
    public static double[] logs2probs(double[] logs) {
        double maxLog = logs[maxIndex(logs)];
        double total = 0.0d;
        double[] probs = new double[logs.length];
        for (int i = 0; i < logs.length; i++) {
            probs[i] = Math.exp(logs[i] - maxLog);
            total += probs[i];
        }
        normalize(probs, total);
        return probs;
    }

    /** Shannon entropy in bits, {@code -sum(p * log2(p))}; zero entries contribute 0. */
    public static double information(double[] probabilities) {
        double total = 0.0d;
        for (double p : probabilities) {
            total += (-1.0d) * log2(p) * p;
        }
        return total;
    }

    /** Index of the first maximal element; returns 0 for an empty array. */
    public static int maxIndex(double[] doubles) {
        double best = 0.0d;
        int bestIndex = 0;
        for (int i = 0; i < doubles.length; i++) {
            if (i == 0 || doubles[i] > best) {
                bestIndex = i;
                best = doubles[i];
            }
        }
        return bestIndex;
    }

    /**
     * {@code n * (n-1) * (n-2) * ...} over the positive factors; 1 for 0 and 1. Inputs with no
     * positive factor below them (e.g. negatives, or values in (0, 1]) are returned unchanged.
     */
    public static double factorial(double n) {
        if (n == 1.0d || n == 0.0d) {
            return 1.0d;
        }
        double result = n;
        for (double factor = n - 1.0d; factor > 0.0d; factor -= 1.0d) {
            result *= factor;
        }
        return result;
    }

    /**
     * Log-odds {@code ln(p / (1 - p))} of a probability that is first squeezed into
     * [SMALL, 1 - SMALL] so the result stays finite.
     *
     * @throws IllegalArgumentException if {@code prob} is outside [0, 1] beyond {@link #SMALL}
     */
    public static double probToLogOdds(double prob) {
        if (gr(prob, 1.0d) || sm(prob, 0.0d)) {
            throw new IllegalArgumentException("probToLogOdds: probability must be in [0,1] " + prob);
        }
        double p = SMALL + ((1.0d - (2.0d * SMALL)) * prob);
        return Math.log(p / (1.0d - p));
    }

    /** Rounds to the nearest int, with halves rounded away from zero. */
    public static int round(double value) {
        return value > 0.0d ? (int) (value + 0.5d) : -((int) (Math.abs(value) + 0.5d));
    }

    /** Number of ordered selections {@code n! / (n - r)!}. */
    public static double permutation(double n, double r) {
        return factorial(n) / factorial(n - r);
    }

    /** Binomial coefficient {@code n! / (r! * (n - r)!)}. */
    public static double combination(double n, double r) {
        return factorial(n) / (factorial(r) * factorial(n - r));
    }

    /** {@code sqrt(a^2 + b^2)}, computed by scaling by the larger magnitude to limit overflow. */
    public static double hypotenuse(double a, double b) {
        double result;
        if (Math.abs(a) > Math.abs(b)) {
            double ratio = b / a;
            result = Math.abs(a) * Math.sqrt(1.0d + (ratio * ratio));
        } else if (b != 0.0d) {
            double ratio = a / b;
            result = Math.abs(b) * Math.sqrt(1.0d + (ratio * ratio));
        } else {
            result = 0.0d;
        }
        return result;
    }

    /**
     * Rounds randomly: returns the magnitude's floor or floor + 1, choosing floor + 1 with
     * probability equal to the fractional part; the sign of {@code value} is preserved.
     */
    public static int probRound(double value, Random rand) {
        if (value >= 0.0d) {
            double lower = Math.floor(value);
            return rand.nextDouble() < value - lower ? ((int) lower) + 1 : (int) lower;
        }
        double lower = Math.floor(Math.abs(value));
        return rand.nextDouble() < Math.abs(value) - lower ? -(((int) lower) + 1) : -((int) lower);
    }

    /** Rounds {@code value} to {@code afterDecimalPoint} decimal places. */
    public static double roundDouble(double value, int afterDecimalPoint) {
        double mask = Math.pow(10.0d, afterDecimalPoint);
        return Math.round(value * mask) / mask;
    }

    /** Rounds {@code value} to {@code afterDecimalPoint} decimal places. */
    public static float roundFloat(float value, int afterDecimalPoint) {
        float mask = (float) Math.pow(10.0d, afterDecimalPoint);
        return Math.round(value * mask) / mask;
    }

    /** Binomial probability {@code C(n, k) * p^k * (1 - p)^(n - k)}. */
    public static double bernoullis(double n, double k, double successProb) {
        return combination(n, k) * Math.pow(successProb, k) * Math.pow(1.0d - successProb, n - k);
    }

    /** True if {@code a} is smaller than {@code b} by more than {@link #SMALL}. */
    public static boolean sm(double a, double b) {
        return b - a > SMALL;
    }

    /** True if {@code a} is greater than {@code b} by more than {@link #SMALL}. */
    public static boolean gr(double a, double b) {
        return a - b > SMALL;
    }

    /**
     * Parses a delimited string of numbers. {@code sep} is passed to {@link String#split}, so it
     * is a regular expression.
     */
    public static double[] fromString(String data, String sep) {
        String[] parts = data.split(sep);
        double[] values = new double[parts.length];
        for (int i = 0; i < parts.length; i++) {
            values[i] = Double.parseDouble(parts[i]);
        }
        return values;
    }

    /** Arithmetic mean, or 0 for an empty array. */
    public static double mean(double[] vector) {
        if (vector.length == 0) {
            return 0.0d;
        }
        double total = 0.0d;
        for (double value : vector) {
            total += value;
        }
        return total / vector.length;
    }

    /** Wraps {@code m} in a commons-math {@link CholeskyDecomposition}. */
    public CholeskyDecomposition choleskyFromMatrix(RealMatrix m) throws Exception {
        return new CholeskyDecomposition(m);
    }

    /**
     * Parses a string of binary digits into an int.
     *
     * <p>If any decimal digit is neither 0 nor 1, two messages are printed to standard output and
     * -1 is returned. NOTE(review): printing from a utility is questionable; an exception would
     * be cleaner. Non-numeric input throws {@link NumberFormatException} from {@link
     * Long#parseLong}.
     */
    public static int toDecimal(String binary) {
        long remaining = Long.parseLong(binary);
        while (remaining > 0) {
            long digit = remaining % 10;
            remaining /= 10;
            if (digit != 0 && digit != 1) {
                System.out.println("This is not a binary number.");
                System.out.println("Please try once again.");
                return -1;
            }
        }
        return Integer.parseInt(binary, 2);
    }

    /**
     * Interleaves the binary digits (most significant first) of each value's int truncation,
     * round-robin, and parses the resulting bit string. Values with shorter binary forms drop out
     * once exhausted.
     *
     * @return 0 for an empty array. Throws {@link NumberFormatException} if the interleaved
     *     string exceeds 31 bits.
     */
    public static int distanceFinderZValue(double[] d) {
        StringBuilder interleaved = new StringBuilder();
        List<String> remaining = new ArrayList<String>(d.length);
        for (double value : d) {
            remaining.add(Integer.toBinaryString((int) value));
        }
        while (!remaining.isEmpty()) {
            int i = 0;
            while (i < remaining.size()) {
                String bits = remaining.get(i);
                if (bits.isEmpty()) {
                    // Do not advance i: the next element has shifted into this slot.
                    remaining.remove(i);
                } else {
                    interleaved.append(bits.charAt(0));
                    remaining.set(i, bits.substring(1));
                    i++;
                }
            }
        }
        if (interleaved.length() == 0) {
            return 0;
        }
        return Integer.parseInt(interleaved.toString(), 2);
    }

    /**
     * Sum of squared component differences over the length of {@code p}. NOTE: this is the
     * squared Euclidean distance; no square root is taken.
     */
    public static double euclideanDistance(double[] p, double[] q) {
        double total = 0.0d;
        for (int i = 0; i < p.length; i++) {
            total += Math.pow(q[i] - p[i], 2.0d);
        }
        return total;
    }

    /**
     * Sum of squared component differences over the length of {@code p}. NOTE: this is the
     * squared Euclidean distance; no square root is taken.
     */
    public static double euclideanDistance(float[] p, float[] q) {
        double total = 0.0d;
        for (int i = 0; i < p.length; i++) {
            total += Math.pow(q[i] - p[i], 2.0d);
        }
        return total;
    }

    /** Returns {@code l} doubles drawn uniformly from [0, 1) using a fresh {@link Random}. */
    public static double[] generateUniform(int l) {
        double[] values = new double[l];
        Random random = new Random();
        for (int i = 0; i < l; i++) {
            values[i] = random.nextDouble();
        }
        return values;
    }

    /** Sum of absolute component differences over the length of {@code p}. */
    public static double manhattanDistance(double[] p, double[] q) {
        double total = 0.0d;
        for (int i = 0; i < p.length; i++) {
            total += Math.abs(p[i] - q[i]);
        }
        return total;
    }

    /**
     * Samples {@code l} values with replacement: each draw picks a random row, then a random
     * element of that row (via {@link #randomNumberBetween(double, double)}).
     */
    public static double[] sampleDoublesInInterval(double[][] doubles, int l) {
        double[] sample = new double[l];
        for (int i = 0; i < l; i++) {
            int row = randomNumberBetween(0.0d, doubles.length - 1);
            // Bound by the chosen row's length; the end of randomNumberBetween is inclusive.
            int col = randomNumberBetween(0.0d, doubles[row].length - 1);
            sample[i] = doubles[row][col];
        }
        return sample;
    }

    /**
     * Random int in [begin, end] (both inclusive for integral bounds), using {@link Math#random}.
     *
     * @throws IllegalArgumentException if {@code begin > end}
     */
    public static int randomNumberBetween(double begin, double end) {
        if (begin > end) {
            throw new IllegalArgumentException("Begin must not be greater than end");
        }
        return ((int) begin) + ((int) (Math.random() * ((end - begin) + 1.0d)));
    }

    /**
     * Random int in [begin, end] (both inclusive for integral bounds), drawn from {@code rng}.
     *
     * @throws IllegalArgumentException if {@code begin > end}
     */
    public static int randomNumberBetween(double begin, double end, RandomGenerator rng) {
        if (begin > end) {
            throw new IllegalArgumentException("Begin must not be greater than end");
        }
        return ((int) begin) + ((int) (rng.nextDouble() * ((end - begin) + 1.0d)));
    }

    /** Random float in [begin, end), using {@link Math#random}. */
    public static float randomFloatBetween(float begin, float end) {
        return begin + (((float) Math.random()) * (end - begin));
    }

    /** Random double in [begin, end), using {@link Math#random}. */
    public static double randomDoubleBetween(double begin, double end) {
        return begin + (Math.random() * (end - begin));
    }
}
