package edu.stanford.nlp.neural;

import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.util.CollectionUtils;
import java.io.File;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.function.DoubleUnaryOperator;
import java.util.function.Predicate;
import org.ejml.simple.SimpleMatrix;

/* loaded from: input_file:edu/stanford/nlp/neural/NeuralUtils.class */
public class NeuralUtils {
    private NeuralUtils() {
    }

    public static SimpleMatrix loadTextMatrix(String str) {
        return convertTextMatrix(IOUtils.slurpFileNoExceptions(str));
    }

    public static SimpleMatrix loadTextMatrix(File file) {
        return convertTextMatrix(IOUtils.slurpFileNoExceptions(file));
    }

    public static SimpleMatrix convertTextMatrix(String str) {
        List filterAsList = CollectionUtils.filterAsList(Arrays.asList(str.split("\n")), new Predicate<String>() { // from class: edu.stanford.nlp.neural.NeuralUtils.1
            private static final long serialVersionUID = 1;

            @Override // java.util.function.Predicate
            public boolean test(String str2) {
                return str2.trim().length() > 0;
            }
        });
        int size = filterAsList.size();
        int length = ((String) filterAsList.get(0)).trim().split("\\s+").length;
        double[][] dArr = new double[size][length];
        for (int i = 0; i < size; i++) {
            String[] split = ((String) filterAsList.get(i)).trim().split("\\s+");
            if (split.length != length) {
                throw new RuntimeException("Unexpected row length in line " + i);
            }
            for (int i2 = 0; i2 < length; i2++) {
                dArr[i][i2] = Double.valueOf(split[i2]).doubleValue();
            }
        }
        return new SimpleMatrix(dArr);
    }

    public static double cosine(SimpleMatrix simpleMatrix, SimpleMatrix simpleMatrix2) {
        return dot(simpleMatrix, simpleMatrix2) / (simpleMatrix.normF() * simpleMatrix2.normF());
    }

    public static double dot(SimpleMatrix simpleMatrix, SimpleMatrix simpleMatrix2) {
        if (simpleMatrix.numRows() == 1) {
            return simpleMatrix.mult(simpleMatrix2.transpose()).get(0);
        }
        if (simpleMatrix.numCols() == 1) {
            return simpleMatrix.transpose().mult(simpleMatrix2).get(0);
        }
        throw new AssertionError("Error in neural.Utils.dot: vector1 is a matrix " + simpleMatrix.numRows() + " x " + simpleMatrix.numCols());
    }

    @SafeVarargs
    public static void vectorToParams(double[] dArr, Iterator<SimpleMatrix>... itArr) {
        int i = 0;
        for (Iterator<SimpleMatrix> it : itArr) {
            while (it.hasNext()) {
                SimpleMatrix next = it.next();
                int numElements = next.getNumElements();
                for (int i2 = 0; i2 < numElements; i2++) {
                    next.set(i2, dArr[i]);
                    i++;
                }
            }
        }
        if (i != dArr.length) {
            throw new AssertionError("Did not entirely use the theta vector");
        }
    }

    @SafeVarargs
    public static double[] paramsToVector(int i, Iterator<SimpleMatrix>... itArr) {
        double[] dArr = new double[i];
        int i2 = 0;
        for (Iterator<SimpleMatrix> it : itArr) {
            while (it.hasNext()) {
                SimpleMatrix next = it.next();
                int numElements = next.getNumElements();
                for (int i3 = 0; i3 < numElements; i3++) {
                    dArr[i2] = next.get(i3);
                    i2++;
                }
            }
        }
        if (i2 != i) {
            throw new AssertionError("Did not entirely fill the theta vector: expected " + i + " used " + i2);
        }
        return dArr;
    }

    @SafeVarargs
    public static double[] paramsToVector(double d, int i, Iterator<SimpleMatrix>... itArr) {
        double[] dArr = new double[i];
        int i2 = 0;
        for (Iterator<SimpleMatrix> it : itArr) {
            while (it.hasNext()) {
                SimpleMatrix next = it.next();
                int numElements = next.getNumElements();
                for (int i3 = 0; i3 < numElements; i3++) {
                    dArr[i2] = next.get(i3) * d;
                    i2++;
                }
            }
        }
        if (i2 != i) {
            throw new AssertionError("Did not entirely fill the theta vector: expected " + i + " used " + i2);
        }
        return dArr;
    }

    public static double sigmoid(double d) {
        return 1.0d / (1.0d + Math.exp(-d));
    }

    public static SimpleMatrix softmax(SimpleMatrix simpleMatrix) {
        SimpleMatrix simpleMatrix2 = new SimpleMatrix(simpleMatrix);
        for (int i = 0; i < simpleMatrix2.numRows(); i++) {
            for (int i2 = 0; i2 < simpleMatrix2.numCols(); i2++) {
                simpleMatrix2.set(i, i2, Math.exp(simpleMatrix2.get(i, i2)));
            }
        }
        return simpleMatrix2.scale(1.0d / simpleMatrix2.elementSum());
    }

    public static SimpleMatrix elementwiseApplyLog(SimpleMatrix simpleMatrix) {
        SimpleMatrix simpleMatrix2 = new SimpleMatrix(simpleMatrix);
        for (int i = 0; i < simpleMatrix2.numRows(); i++) {
            for (int i2 = 0; i2 < simpleMatrix2.numCols(); i2++) {
                simpleMatrix2.set(i, i2, Math.log(simpleMatrix2.get(i, i2)));
            }
        }
        return simpleMatrix2;
    }

    public static SimpleMatrix elementwiseApplyTanh(SimpleMatrix simpleMatrix) {
        SimpleMatrix simpleMatrix2 = new SimpleMatrix(simpleMatrix);
        for (int i = 0; i < simpleMatrix2.numRows(); i++) {
            for (int i2 = 0; i2 < simpleMatrix2.numCols(); i2++) {
                simpleMatrix2.set(i, i2, Math.tanh(simpleMatrix2.get(i, i2)));
            }
        }
        return simpleMatrix2;
    }

    public static SimpleMatrix elementwiseApplyTanhDerivative(SimpleMatrix simpleMatrix) {
        SimpleMatrix simpleMatrix2 = new SimpleMatrix(simpleMatrix.numRows(), simpleMatrix.numCols());
        simpleMatrix2.set(1.0d);
        return simpleMatrix2.minus(simpleMatrix.elementMult(simpleMatrix));
    }

    public static SimpleMatrix concatenateWithBias(SimpleMatrix... simpleMatrixArr) {
        int i = 0;
        for (SimpleMatrix simpleMatrix : simpleMatrixArr) {
            i += simpleMatrix.numRows();
        }
        SimpleMatrix simpleMatrix2 = new SimpleMatrix(i + 1, 1);
        int i2 = 0;
        for (SimpleMatrix simpleMatrix3 : simpleMatrixArr) {
            simpleMatrix2.insertIntoThis(i2, 0, simpleMatrix3);
            i2 += simpleMatrix3.numRows();
        }
        simpleMatrix2.set(i2, 0, 1.0d);
        return simpleMatrix2;
    }

    public static SimpleMatrix concatenate(SimpleMatrix... simpleMatrixArr) {
        int i = 0;
        for (SimpleMatrix simpleMatrix : simpleMatrixArr) {
            i += simpleMatrix.numRows();
        }
        SimpleMatrix simpleMatrix2 = new SimpleMatrix(i, 1);
        int i2 = 0;
        for (SimpleMatrix simpleMatrix3 : simpleMatrixArr) {
            simpleMatrix2.insertIntoThis(i2, 0, simpleMatrix3);
            i2 += simpleMatrix3.numRows();
        }
        return simpleMatrix2;
    }

    public static SimpleMatrix randomGaussian(int i, int i2, Random random) {
        SimpleMatrix simpleMatrix = new SimpleMatrix(i, i2);
        for (int i3 = 0; i3 < i; i3++) {
            for (int i4 = 0; i4 < i2; i4++) {
                simpleMatrix.set(i3, i4, random.nextGaussian());
            }
        }
        return simpleMatrix;
    }

    public static boolean isZero(SimpleMatrix simpleMatrix) {
        int numElements = simpleMatrix.getNumElements();
        for (int i = 0; i < numElements; i++) {
            if (simpleMatrix.get(i) != 0.0d) {
                return false;
            }
        }
        return true;
    }
}
