package org.deeplearning4j.nn.params;

import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;
import org.deeplearning4j.eval.EvaluationBinary;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.CenterLossOutputLayer;
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

/**
 * Parameter initializer for {@link CenterLossOutputLayer}.
 *
 * <p>In addition to a weight matrix and a bias vector, this initializer manages a center-loss
 * parameter stored under {@link #CENTER_KEY}. All three are carved out of a single flattened
 * parameter array (row 0, see {@code NDArrayIndex.point(0)} below) with the layout:
 *
 * <pre>
 *   [ W : nIn*nOut | b : nOut | cL : nOut*nIn ]
 * </pre>
 *
 * <p>A shared instance is available via {@link #getInstance()}.
 */
public class CenterLossParamInitializer extends DefaultParamInitializer {
    private static final CenterLossParamInitializer INSTANCE = new CenterLossParamInitializer();

    /** Key of the weight parameter in the parameter and gradient maps. */
    public static final String WEIGHT_KEY = "W";
    /** Key of the bias parameter in the parameter and gradient maps. */
    public static final String BIAS_KEY = "b";
    /** Key of the center-loss parameter in the parameter and gradient maps. */
    public static final String CENTER_KEY = "cL";

    /**
     * Returns the shared instance of this initializer.
     *
     * @return the statically created instance
     */
    public static CenterLossParamInitializer getInstance() {
        return INSTANCE;
    }

    /**
     * Returns the total number of parameters: weights ({@code nIn*nOut}) + bias ({@code nOut}) +
     * center-loss matrix ({@code nIn*nOut}).
     *
     * @param neuralNetConfiguration configuration whose layer is a {@link FeedForwardLayer}
     * @return total parameter count
     * @throws ArithmeticException if the count does not fit in an {@code int}
     */
    @Override
    public int numParams(NeuralNetConfiguration neuralNetConfiguration) {
        FeedForwardLayer layerConf = (FeedForwardLayer) neuralNetConfiguration.getLayer();
        return segmentEnds(layerConf.getNIn(), layerConf.getNOut())[2];
    }

    /**
     * Splits the flattened parameter array into weight, bias and center-loss parameters,
     * registers their keys as variables on the configuration, and returns them.
     *
     * @param neuralNetConfiguration configuration whose layer is a {@link CenterLossOutputLayer}
     * @param iNDArray flattened parameter array; row 0 is sliced according to the class layout
     * @param z whether the parameters should be (re)initialized
     * @return insertion-ordered, synchronized map of {@code W}, {@code b} and {@code cL}
     * @throws ArithmeticException if the parameter offsets do not fit in an {@code int}
     */
    @Override
    public Map<String, INDArray> init(NeuralNetConfiguration neuralNetConfiguration, INDArray iNDArray, boolean z) {
        Map<String, INDArray> params = Collections.synchronizedMap(new LinkedHashMap<String, INDArray>());
        CenterLossOutputLayer layerConf = (CenterLossOutputLayer) neuralNetConfiguration.getLayer();
        int nIn = layerConf.getNIn();
        int nOut = layerConf.getNOut();

        int[] ends = segmentEnds(nIn, nOut);
        int weightEnd = ends[0];
        int biasEnd = ends[1];
        int centerEnd = ends[2];

        INDArray weightView = iNDArray.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, weightEnd));
        INDArray biasView = iNDArray.get(NDArrayIndex.point(0), NDArrayIndex.interval(weightEnd, biasEnd));
        // The center-loss segment is exposed as an nOut x nIn matrix in 'c' (row-major) order.
        INDArray centerView = iNDArray.get(NDArrayIndex.point(0), NDArrayIndex.interval(biasEnd, centerEnd))
                .reshape('c', nOut, nIn);

        params.put(WEIGHT_KEY, createWeightMatrix(neuralNetConfiguration, weightView, z));
        params.put(BIAS_KEY, createBias(neuralNetConfiguration, biasView, z));
        params.put(CENTER_KEY, createCenterLossMatrix(neuralNetConfiguration, centerView, z));
        neuralNetConfiguration.addVariable(WEIGHT_KEY);
        neuralNetConfiguration.addVariable(BIAS_KEY);
        neuralNetConfiguration.addVariable(CENTER_KEY);
        return params;
    }

    /**
     * Splits a flattened gradient array into per-parameter gradients using the same layout as
     * {@link #init}.
     *
     * @param neuralNetConfiguration configuration whose layer is a {@link CenterLossOutputLayer}
     * @param iNDArray flattened gradient array; row 0 is sliced according to the class layout
     * @return insertion-ordered map: {@code W} reshaped to nIn x nOut ('f' order), {@code b} as
     *     the raw slice, {@code cL} reshaped to nOut x nIn ('c' order)
     * @throws ArithmeticException if the parameter offsets do not fit in an {@code int}
     */
    @Override
    public Map<String, INDArray> getGradientsFromFlattened(NeuralNetConfiguration neuralNetConfiguration, INDArray iNDArray) {
        CenterLossOutputLayer layerConf = (CenterLossOutputLayer) neuralNetConfiguration.getLayer();
        int nIn = layerConf.getNIn();
        int nOut = layerConf.getNOut();

        int[] ends = segmentEnds(nIn, nOut);
        int weightEnd = ends[0];
        int biasEnd = ends[1];
        int centerEnd = ends[2];

        INDArray weightGrad = iNDArray.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, weightEnd))
                .reshape('f', nIn, nOut);
        INDArray biasGrad = iNDArray.get(NDArrayIndex.point(0), NDArrayIndex.interval(weightEnd, biasEnd));
        INDArray centerGrad = iNDArray.get(NDArrayIndex.point(0), NDArrayIndex.interval(biasEnd, centerEnd))
                .reshape('c', nOut, nIn);

        Map<String, INDArray> grads = new LinkedHashMap<>();
        grads.put(WEIGHT_KEY, weightGrad);
        grads.put(BIAS_KEY, biasGrad);
        grads.put(CENTER_KEY, centerGrad);
        return grads;
    }

    /**
     * Returns the center-loss parameter array, zero-filling it in place first when
     * {@code z} is {@code true}.
     *
     * @param neuralNetConfiguration layer configuration (not used by this implementation)
     * @param iNDArray the center-loss parameter array
     * @param z whether to initialize the array
     * @return {@code iNDArray} itself
     */
    protected INDArray createCenterLossMatrix(NeuralNetConfiguration neuralNetConfiguration, INDArray iNDArray, boolean z) {
        if (z) {
            iNDArray.assign(0.0);
        }
        return iNDArray;
    }

    /**
     * Computes the exclusive end offsets of the weight, bias and center-loss segments within the
     * flattened parameter array, failing fast instead of silently wrapping on int overflow.
     *
     * @return {@code {weightEnd, biasEnd, centerEnd}}; {@code centerEnd} is the total size
     * @throws ArithmeticException if any offset overflows an {@code int}
     */
    private static int[] segmentEnds(int nIn, int nOut) {
        int weightEnd = Math.multiplyExact(nIn, nOut);
        int biasEnd = Math.addExact(weightEnd, nOut);
        int centerEnd = Math.addExact(biasEnd, Math.multiplyExact(nIn, nOut));
        return new int[] {weightEnd, biasEnd, centerEnd};
    }
}
