package org.deeplearning4j.nn.updater;

import java.util.HashMap;
import java.util.Map;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.Updater;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.Gradient;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.learning.GradientUpdater;
import org.nd4j.linalg.ops.transforms.Transforms;

/* loaded from: input_file:org/deeplearning4j/nn/updater/BaseUpdater.class */
/**
 * Base {@link Updater} that runs each variable's gradient through a per-variable
 * {@link GradientUpdater} and then applies regularization, mini-batch scaling and an optional
 * unit-norm constraint via {@link #postApply(Layer, INDArray, String)}.
 */
public abstract class BaseUpdater implements Updater {
    /** Name of the bias parameter; it is excluded from L1/L2 regularization. */
    private static final String BIAS_KEY = "b";

    /**
     * Per-variable gradient updaters keyed by variable name. Presumably populated by subclass
     * {@code init} implementations; confirm against subclasses.
     */
    protected Map<String, GradientUpdater> updaterForVariable = new HashMap<String, GradientUpdater>();

    /**
     * Replaces every gradient held by {@code gradient} with its updated form.
     *
     * <p>For each variable, the {@link GradientUpdater} returned by
     * {@link #init(String, INDArray, Layer)} produces a new gradient. That gradient is adjusted
     * in place by {@link #postApply(Layer, INDArray, String)` and then stored back under the
     * same variable name.
     *
     * @param layer the layer the gradients belong to
     * @param gradient the gradients to update; modified via {@code setGradientFor}
     * @param iteration forwarded unchanged to {@link GradientUpdater#getGradient}
     */
    @Override
    public void update(Layer layer, Gradient gradient, int iteration) {
        for (Map.Entry<String, INDArray> entry : gradient.gradientForVariable().entrySet()) {
            String variable = entry.getKey();
            INDArray rawGradient = entry.getValue();
            GradientUpdater updater = init(variable, rawGradient, layer);
            INDArray updated = updater.getGradient(rawGradient, iteration);
            postApply(layer, updated, variable);
            // NOTE(review): this writes into the map currently being iterated. That is only safe
            // if setGradientFor replaces an existing key's value without a structural
            // modification; confirm.
            gradient.setGradientFor(variable, updated);
        }
    }

    /**
     * Adjusts {@code gradient} in place, in this order:
     * <ol>
     *   <li>adds {@code l2 * param} when regularization is on and L2 &gt; 0 (bias excluded);</li>
     *   <li>adds {@code l1 * sign(param)} when regularization is on and L1 &gt; 0 (bias
     *       excluded);</li>
     *   <li>divides by the layer's input mini-batch size when mini-batch mode is on;</li>
     *   <li>rescales to unit L2 norm when that constraint is on. An all-zero gradient is left
     *       unchanged instead of being turned into NaN by a 0/0 division.</li>
     * </ol>
     *
     * @param layer the layer owning the parameter; supplies the configuration and parameter values
     * @param gradient the gradient to modify in place
     * @param param the name of the parameter this gradient belongs to
     */
    public void postApply(Layer layer, INDArray gradient, String param) {
        NeuralNetConfiguration conf = layer.conf();

        // Constant-first comparison: a null name is treated as "not the bias" instead of throwing.
        if (conf.isUseRegularization() && !BIAS_KEY.equals(param)) {
            double l2 = conf.getL2();
            double l1 = conf.getL1();
            // Only look up the parameter values when a regularization term actually needs them.
            if (l2 > 0.0 || l1 > 0.0) {
                INDArray params = layer.getParam(param);
                if (l2 > 0.0) {
                    gradient.addi(params.mul(l2));
                }
                if (l1 > 0.0) {
                    gradient.addi(Transforms.sign(params).muli(l1));
                }
            }
        }

        if (conf.isMiniBatch()) {
            gradient.divi(layer.getInputMiniBatchSize());
        }

        if (conf.isConstrainGradientToUnitNorm()) {
            // Integer.MAX_VALUE as the dimension presumably means "reduce over the whole array"
            // (ND4J convention), which yields a scalar norm. Confirm for the ND4J version in use.
            double norm = gradient.norm2(new int[]{Integer.MAX_VALUE}).getDouble(0);
            // A zero gradient has no direction; dividing it by its zero norm would produce NaN.
            if (norm > 0.0) {
                gradient.divi(norm);
            }
        }
    }

    /** Initializes updater state; the exact semantics are defined by subclasses. */
    public abstract void init();

    /**
     * Returns the {@link GradientUpdater} to apply to the named variable's gradient.
     *
     * @param variable the variable (parameter) name
     * @param gradient the variable's current raw gradient
     * @param layer the layer the variable belongs to
     * @return the updater used by {@link #update(Layer, Gradient, int)} for this variable
     */
    public abstract GradientUpdater init(String variable, INDArray gradient, Layer layer);
}
