package org.deeplearning4j.gradientcheck;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.layers.IOutputLayer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.graph.GraphVertex;
import org.deeplearning4j.nn.conf.graph.LayerVertex;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.updater.UpdaterCreator;
import org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.MultiDataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/gradientcheck/GradientCheckUtil.class */
public class GradientCheckUtil {
    private static Logger log = LoggerFactory.getLogger(GradientCheckUtil.class);

    private GradientCheckUtil() {
    }

    public static boolean checkGradients(MultiLayerNetwork multiLayerNetwork, double d, double d2, double d3, boolean z, boolean z2, INDArray iNDArray, INDArray iNDArray2) {
        if (d <= 0.0d || d > 0.1d) {
            throw new IllegalArgumentException("Invalid epsilon: expect epsilon in range (0,0.1], usually 1e-4 or so");
        }
        if (d2 <= 0.0d || d2 > 0.25d) {
            throw new IllegalArgumentException("Invalid maxRelativeError: " + d2);
        }
        if (!(multiLayerNetwork.getOutputLayer() instanceof IOutputLayer)) {
            throw new IllegalArgumentException("Cannot check backprop gradients without OutputLayer");
        }
        for (NeuralNetConfiguration neuralNetConfiguration : multiLayerNetwork.getLayerWiseConfigurations().getConfs()) {
            Updater updater = neuralNetConfiguration.getLayer().getUpdater();
            if (updater == Updater.SGD) {
                double learningRate = neuralNetConfiguration.getLayer().getLearningRate();
                if (learningRate != 1.0d) {
                    throw new IllegalStateException("When using SGD updater, must also use lr=1.0 for layer 0; got " + updater + " with lr=" + learningRate);
                }
            } else if (updater != Updater.NONE) {
                throw new IllegalStateException("Must have Updater.NONE (or SGD + lr=1.0) for layer 0; got " + updater);
            }
        }
        multiLayerNetwork.setInput(iNDArray);
        multiLayerNetwork.setLabels(iNDArray2);
        multiLayerNetwork.computeGradientAndScore();
        Pair<Gradient, Double> gradientAndScore = multiLayerNetwork.gradientAndScore();
        UpdaterCreator.getUpdater(multiLayerNetwork).update(multiLayerNetwork, gradientAndScore.getFirst(), 0, multiLayerNetwork.batchSize());
        INDArray dup = gradientAndScore.getFirst().gradient().dup();
        int length = multiLayerNetwork.params().dup().length();
        Map<String, INDArray> paramTable = multiLayerNetwork.paramTable();
        ArrayList arrayList = new ArrayList(paramTable.keySet());
        int[] iArr = new int[arrayList.size()];
        iArr[0] = paramTable.get(arrayList.get(0)).length();
        for (int i = 1; i < iArr.length; i++) {
            iArr[i] = iArr[i - 1] + paramTable.get(arrayList.get(i)).length();
        }
        int i2 = 0;
        double d4 = 0.0d;
        DataSet dataSet = new DataSet(iNDArray, iNDArray2);
        int i3 = 0;
        INDArray params = multiLayerNetwork.params();
        for (int i4 = 0; i4 < length; i4++) {
            if (i4 >= iArr[i3]) {
                i3++;
            }
            String str = (String) arrayList.get(i3);
            double d5 = params.getDouble(i4);
            params.putScalar(i4, d5 + d);
            double score = multiLayerNetwork.score(dataSet, true);
            params.putScalar(i4, d5 - d);
            double score2 = multiLayerNetwork.score(dataSet, true);
            params.putScalar(i4, d5);
            double d6 = (score - score2) / (2.0d * d);
            if (Double.isNaN(d6)) {
                throw new IllegalStateException("Numerical gradient was NaN for parameter " + i4 + " of " + length);
            }
            double d7 = dup.getDouble(i4);
            double abs = Math.abs(d7 - d6) / (Math.abs(d6) + Math.abs(d7));
            if (d7 == 0.0d && d6 == 0.0d) {
                abs = 0.0d;
            }
            if (abs > d4) {
                d4 = abs;
            }
            if (abs > d2 || Double.isNaN(abs)) {
                double abs2 = Math.abs(d7 - d6);
                if (abs2 < d3) {
                    log.info("Param " + i4 + " (" + str + ") passed: grad= " + d7 + ", numericalGrad= " + d6 + ", relError= " + abs + "; absolute error = " + abs2 + " < minAbsoluteError = " + d3);
                } else {
                    if (z) {
                        log.info("Param " + i4 + " (" + str + ") FAILED: grad= " + d7 + ", numericalGrad= " + d6 + ", relError= " + abs + ", scorePlus=" + score + ", scoreMinus= " + score2);
                    }
                    if (z2) {
                        return false;
                    }
                    i2++;
                }
            } else if (z) {
                log.info("Param " + i4 + " (" + str + ") passed: grad= " + d7 + ", numericalGrad= " + d6 + ", relError= " + abs);
            }
        }
        if (z) {
            log.info("GradientCheckUtil.checkGradients(): " + length + " params checked, " + (length - i2) + " passed, " + i2 + " failed. Largest relative error = " + d4);
        }
        return i2 == 0;
    }

    public static boolean checkGradients(ComputationGraph computationGraph, double d, double d2, double d3, boolean z, boolean z2, INDArray[] iNDArrayArr, INDArray[] iNDArrayArr2) {
        if (d <= 0.0d || d > 0.1d) {
            throw new IllegalArgumentException("Invalid epsilon: expect epsilon in range (0,0.1], usually 1e-4 or so");
        }
        if (d2 <= 0.0d || d2 > 0.25d) {
            throw new IllegalArgumentException("Invalid maxRelativeError: " + d2);
        }
        if (computationGraph.getNumInputArrays() != iNDArrayArr.length) {
            throw new IllegalArgumentException("Invalid input arrays: expect " + computationGraph.getNumInputArrays() + " inputs");
        }
        if (computationGraph.getNumOutputArrays() != iNDArrayArr2.length) {
            throw new IllegalArgumentException("Invalid labels arrays: expect " + computationGraph.getNumOutputArrays() + " outputs");
        }
        for (String str : computationGraph.getConfiguration().getVertices().keySet()) {
            GraphVertex graphVertex = computationGraph.getConfiguration().getVertices().get(str);
            if (graphVertex instanceof LayerVertex) {
                LayerVertex layerVertex = (LayerVertex) graphVertex;
                Updater updater = layerVertex.getLayerConf().getLayer().getUpdater();
                if (updater == Updater.SGD) {
                    if (layerVertex.getLayerConf().getLayer().getLearningRate() != 1.0d) {
                        throw new IllegalStateException("When using SGD updater, must also use lr=1.0 for layer \"" + str + "\"; got " + updater);
                    }
                } else if (updater != Updater.NONE) {
                    throw new IllegalStateException("Must have Updater.NONE (or SGD + lr=1.0) for layer \"" + str + "\"; got " + updater);
                }
            }
        }
        for (int i = 0; i < iNDArrayArr.length; i++) {
            computationGraph.setInput(i, iNDArrayArr[i]);
        }
        for (int i2 = 0; i2 < iNDArrayArr2.length; i2++) {
            computationGraph.setLabel(i2, iNDArrayArr2[i2]);
        }
        computationGraph.computeGradientAndScore();
        Pair<Gradient, Double> gradientAndScore = computationGraph.gradientAndScore();
        new ComputationGraphUpdater(computationGraph).update(computationGraph, gradientAndScore.getFirst(), 0, computationGraph.batchSize());
        INDArray dup = gradientAndScore.getFirst().gradient().dup();
        int length = computationGraph.params().dup().length();
        int i3 = 0;
        double d4 = 0.0d;
        MultiDataSet multiDataSet = new MultiDataSet(iNDArrayArr, iNDArrayArr2);
        INDArray params = computationGraph.params();
        for (int i4 = 0; i4 < length; i4++) {
            double d5 = params.getDouble(i4);
            params.putScalar(i4, d5 + d);
            double score = computationGraph.score((org.nd4j.linalg.dataset.api.MultiDataSet) multiDataSet);
            params.putScalar(i4, d5 - d);
            double score2 = computationGraph.score((org.nd4j.linalg.dataset.api.MultiDataSet) multiDataSet);
            params.putScalar(i4, d5);
            double d6 = (score - score2) / (2.0d * d);
            if (Double.isNaN(d6)) {
                throw new IllegalStateException("Numerical gradient was NaN for parameter " + i4 + " of " + length);
            }
            double d7 = dup.getDouble(i4);
            double abs = Math.abs(d7 - d6) / (Math.abs(d6) + Math.abs(d7));
            if (d7 == 0.0d && d6 == 0.0d) {
                abs = 0.0d;
            }
            if (abs > d4) {
                d4 = abs;
            }
            if (abs > d2 || Double.isNaN(abs)) {
                double abs2 = Math.abs(d7 - d6);
                if (abs2 < d3) {
                    log.info("Param " + i4 + " passed: grad= " + d7 + ", numericalGrad= " + d6 + ", relError= " + abs + "; absolute error = " + abs2 + " < minAbsoluteError = " + d3);
                } else {
                    if (z) {
                        log.info("Param " + i4 + " FAILED: grad= " + d7 + ", numericalGrad= " + d6 + ", relError= " + abs + ", scorePlus=" + score + ", scoreMinus= " + score2);
                    }
                    if (z2) {
                        return false;
                    }
                    i3++;
                }
            } else if (z) {
                log.info("Param " + i4 + " passed: grad= " + d7 + ", numericalGrad= " + d6 + ", relError= " + abs);
            }
        }
        if (z) {
            log.info("GradientCheckUtil.checkGradients(): " + length + " params checked, " + (length - i3) + " passed, " + i3 + " failed. Largest relative error = " + d4);
        }
        return i3 == 0;
    }

    public static boolean checkGradientsPretrainLayer(Layer layer, double d, double d2, double d3, boolean z, boolean z2, INDArray iNDArray, int i) {
        if (d <= 0.0d || d > 0.1d) {
            throw new IllegalArgumentException("Invalid epsilon: expect epsilon in range (0,0.1], usually 1e-4 or so");
        }
        if (d2 <= 0.0d || d2 > 0.25d) {
            throw new IllegalArgumentException("Invalid maxRelativeError: " + d2);
        }
        layer.setInput(iNDArray);
        Nd4j.getRandom().setSeed(i);
        layer.computeGradientAndScore();
        Pair<Gradient, Double> gradientAndScore = layer.gradientAndScore();
        UpdaterCreator.getUpdater(layer).update(layer, gradientAndScore.getFirst(), 0, layer.batchSize());
        INDArray dup = gradientAndScore.getFirst().gradient().dup();
        int length = layer.params().dup().length();
        Map<String, INDArray> paramTable = layer.paramTable();
        ArrayList arrayList = new ArrayList(paramTable.keySet());
        int[] iArr = new int[arrayList.size()];
        iArr[0] = paramTable.get(arrayList.get(0)).length();
        for (int i2 = 1; i2 < iArr.length; i2++) {
            iArr[i2] = iArr[i2 - 1] + paramTable.get(arrayList.get(i2)).length();
        }
        int i3 = 0;
        double d4 = 0.0d;
        int i4 = 0;
        INDArray params = layer.params();
        for (int i5 = 0; i5 < length; i5++) {
            if (i5 >= iArr[i4]) {
                i4++;
            }
            String str = (String) arrayList.get(i4);
            double d5 = params.getDouble(i5);
            params.putScalar(i5, d5 + d);
            Nd4j.getRandom().setSeed(i);
            layer.computeGradientAndScore();
            double score = layer.score();
            params.putScalar(i5, d5 - d);
            Nd4j.getRandom().setSeed(i);
            layer.computeGradientAndScore();
            double score2 = layer.score();
            params.putScalar(i5, d5);
            double d6 = (score - score2) / (2.0d * d);
            if (Double.isNaN(d6)) {
                throw new IllegalStateException("Numerical gradient was NaN for parameter " + i5 + " of " + length);
            }
            double d7 = dup.getDouble(i5);
            double abs = Math.abs(d7 - d6) / (Math.abs(d6) + Math.abs(d7));
            if (d7 == 0.0d && d6 == 0.0d) {
                abs = 0.0d;
            }
            if (abs > d4) {
                d4 = abs;
            }
            if (abs > d2 || Double.isNaN(abs)) {
                double abs2 = Math.abs(d7 - d6);
                if (abs2 < d3) {
                    log.info("Param " + i5 + " (" + str + ") passed: grad= " + d7 + ", numericalGrad= " + d6 + ", relError= " + abs + "; absolute error = " + abs2 + " < minAbsoluteError = " + d3);
                } else {
                    if (z) {
                        log.info("Param " + i5 + " (" + str + ") FAILED: grad= " + d7 + ", numericalGrad= " + d6 + ", relError= " + abs + ", scorePlus=" + score + ", scoreMinus= " + score2);
                    }
                    if (z2) {
                        return false;
                    }
                    i3++;
                }
            } else if (z) {
                log.info("Param " + i5 + " (" + str + ") passed: grad= " + d7 + ", numericalGrad= " + d6 + ", relError= " + abs);
            }
        }
        if (z) {
            log.info("GradientCheckUtil.checkGradients(): " + length + " params checked, " + (length - i3) + " passed, " + i3 + " failed. Largest relative error = " + d4);
        }
        return i3 == 0;
    }
}
