package org.deeplearning4j.nn.layers;

import java.util.ArrayList;
import java.util.Map;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.BasePretrainNetwork;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.LossFunction;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.lossfunctions.LossCalculation;
import org.nd4j.linalg.lossfunctions.LossFunctions;

/* loaded from: input_file:org/deeplearning4j/nn/layers/BasePretrainNetwork.class */
/**
 * Base class for layers that carry a visible-bias parameter ({@code "vb"}) next to the
 * weight ({@code "W"}) and bias ({@code "b"}) parameters.
 *
 * <p>Several methods here come in two flavours: one that covers every parameter and one
 * that skips {@code "vb"}. The error message in {@link #setParams(INDArray)} calls these
 * the "pretrain" and "backprop" parameter sets respectively.
 *
 * <p>NOTE(review): this source was produced by a decompiler; parameter names are the
 * decompiler's and some methods were originally declared {@code protected} (noted on the
 * methods concerned). Visibility is left as found so existing callers keep compiling.
 *
 * @param <LayerConfT> configuration type of the concrete layer
 */
public abstract class BasePretrainNetwork<LayerConfT extends org.deeplearning4j.nn.conf.layers.BasePretrainNetwork> extends BaseLayer<LayerConfT> {
    /** Parameter key of the weight matrix. */
    private static final String WEIGHT_KEY = "W";

    /** Parameter key of the (hidden) bias. */
    private static final String BIAS_KEY = "b";

    /** Parameter key of the visible bias; excluded from the backprop parameter set. */
    private static final String VISIBLE_BIAS_KEY = "vb";

    /**
     * Creates the layer from its configuration.
     *
     * @param neuralNetConfiguration configuration forwarded to {@link BaseLayer}
     */
    public BasePretrainNetwork(NeuralNetConfiguration neuralNetConfiguration) {
        super(neuralNetConfiguration);
    }

    /**
     * Creates the layer from its configuration and an initial input.
     *
     * @param neuralNetConfiguration configuration forwarded to {@link BaseLayer}
     * @param iNDArray               array forwarded to {@link BaseLayer}
     *                               (presumably the layer input - confirm against BaseLayer)
     */
    public BasePretrainNetwork(NeuralNetConfiguration neuralNetConfiguration, INDArray iNDArray) {
        super(neuralNetConfiguration, iNDArray);
    }

    /**
     * Returns a randomly masked copy of the given input.
     *
     * <p>A mask with the input's shape is drawn from a binomial distribution with one
     * trial and success probability {@code 1 - d}, then multiplied element-wise by the
     * input. With a single trial each mask entry is expected to be 0 or 1, so roughly a
     * fraction {@code d} of the entries end up zeroed.
     *
     * @param iNDArray input to corrupt; the product is written into the freshly sampled
     *                 mask, not into this array
     * @param d        corruption level; {@code 1 - d} is used as the keep probability
     * @return the mask multiplied by the input
     */
    public INDArray getCorruptedInput(INDArray iNDArray, double d) {
        double keepProbability = 1.0d - d;
        INDArray mask = Nd4j.getDistributions().createBinomial(1, keepProbability).sample(iNDArray.shape());
        // In-place multiply on the mask, so the result lives in the sampled array.
        mask.muli(iNDArray);
        return mask;
    }

    /**
     * Bundles the three per-parameter gradients into a {@link Gradient}.
     *
     * <p>Decompiler note: this method was originally declared {@code protected}.
     *
     * @param iNDArray  gradient stored under {@code "W"}
     * @param iNDArray2 gradient stored under {@code "vb"}
     * @param iNDArray3 gradient stored under {@code "b"}
     * @return a new {@link DefaultGradient} holding the three entries
     */
    public Gradient createGradient(INDArray iNDArray, INDArray iNDArray2, INDArray iNDArray3) {
        DefaultGradient gradient = new DefaultGradient();
        Map<String, INDArray> byVariable = gradient.gradientForVariable();
        // Insertion order (vb, b, W) is kept as found in case the backing map is ordered.
        byVariable.put(VISIBLE_BIAS_KEY, iNDArray2);
        byVariable.put(BIAS_KEY, iNDArray3);
        byVariable.put(WEIGHT_KEY, iNDArray);
        return gradient;
    }

    /**
     * Counts the parameters of this layer.
     *
     * @param z when {@code false}, the count is delegated to {@link BaseLayer};
     *          when {@code true}, every entry of {@code paramTable()} except
     *          {@code "vb"} is counted
     * @return the number of parameter values
     */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Model
    public int numParams(boolean z) {
        if (!z) {
            return super.numParams(z);
        }
        int total = 0;
        for (String key : paramTable().keySet()) {
            if (!VISIBLE_BIAS_KEY.equals(key)) {
                total += getParam(key).length();
            }
        }
        return total;
    }

    /**
     * Samples the hidden units given the visible units.
     *
     * @param iNDArray the visible-side values to condition on
     * @return a pair of arrays defined by the subclass
     *         (presumably mean/probability and sample - confirm in implementations)
     */
    public abstract Pair<INDArray, INDArray> sampleHiddenGivenVisible(INDArray iNDArray);

    /**
     * Samples the visible units given the hidden units.
     *
     * @param iNDArray the hidden-side values to condition on
     * @return a pair of arrays defined by the subclass
     *         (presumably mean/probability and sample - confirm in implementations)
     */
    public abstract Pair<INDArray, INDArray> sampleVisibleGivenHidden(INDArray iNDArray);

    /**
     * Computes the layer score from the given output and stores it in {@code this.score}.
     *
     * <p>The layer's own input is used as the label. With the {@code CUSTOM} loss
     * function the configured custom loss op is created through the ND4J op factory and
     * executed; otherwise the score comes from {@link LossCalculation}, including the
     * L1/L2 terms and mini-batch settings of this layer.
     *
     * <p>Decompiler note: this method was originally declared {@code protected}.
     *
     * @param iNDArray the output ("z") to score against {@code this.input}
     */
    @Override // org.deeplearning4j.nn.layers.BaseLayer
    public void setScoreWithZ(INDArray iNDArray) {
        // The cast is kept from the decompiled source; it is a no-op when layerConf()
        // already returns LayerConfT.
        org.deeplearning4j.nn.conf.layers.BasePretrainNetwork pretrainConf =
                (org.deeplearning4j.nn.conf.layers.BasePretrainNetwork) layerConf();
        LossFunctions.LossFunction lossFunction = pretrainConf.getLossFunction();

        if (lossFunction == LossFunctions.LossFunction.CUSTOM) {
            LossFunction customLoss = Nd4j.getOpFactory()
                    .createLossFunction(pretrainConf.getCustomLossFunction(), this.input, iNDArray);
            customLoss.exec();
            this.score = customLoss.getFinalResult().doubleValue();
            return;
        }

        this.score = LossCalculation.builder()
                .l1(calcL1())
                .l2(calcL2())
                .labels(this.input)
                .z(iNDArray)
                .lossFunction(lossFunction)
                .miniBatch(this.conf.isMiniBatch())
                .miniBatchSize(this.input.size(0))
                .useRegularization(this.conf.isUseRegularization())
                .build()
                .score();
    }

    /**
     * Returns the backprop parameters as one flattened array.
     *
     * @return every entry of {@code this.params} except {@code "vb"}, flattened in
     *         {@code 'f'} order in the iteration order of the parameter map
     */
    public INDArray paramsBackprop() {
        ArrayList<INDArray> backpropParams = new ArrayList<INDArray>(2);
        for (Map.Entry<String, INDArray> entry : this.params.entrySet()) {
            if (!VISIBLE_BIAS_KEY.equals(entry.getKey())) {
                backpropParams.add(entry.getValue());
            }
        }
        return Nd4j.toFlattened('f', backpropParams);
    }

    /**
     * Counts the backprop parameters.
     *
     * @return the summed length of every entry of {@code this.params} except {@code "vb"}
     */
    public int numParamsBackprop() {
        int total = 0;
        for (Map.Entry<String, INDArray> entry : this.params.entrySet()) {
            if (!VISIBLE_BIAS_KEY.equals(entry.getKey())) {
                total += entry.getValue().length();
            }
        }
        return total;
    }

    /**
     * Sets the layer parameters from a flattened array.
     *
     * <p>The array must have either the full parameter length or the length without
     * {@code "vb"}. In the first case every parameter is assigned; in the second case
     * {@code "vb"} is left untouched. Values are consumed from row 0 of the given array,
     * one parameter after another, and reshaped in {@code 'f'} order.
     *
     * <p>NOTE(review): expected lengths are summed over {@code conf.variables()} while
     * values are assigned in the order of {@code this.params.keySet()}; this assumes
     * both expose the same parameter names - confirm.
     *
     * @param iNDArray flattened parameter values
     * @throws IllegalArgumentException if the array length matches neither parameter set
     * @throws IllegalStateException    if a slice taken for a parameter does not have
     *                                  that parameter's length
     */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Model
    public void setParams(INDArray iNDArray) {
        int pretrainLength = 0;
        int backpropLength = 0;
        for (String variable : this.conf.variables()) {
            int variableLength = getParam(variable).length();
            pretrainLength += variableLength;
            if (!VISIBLE_BIAS_KEY.equals(variable)) {
                backpropLength += variableLength;
            }
        }

        int suppliedLength = iNDArray.length();
        boolean pretrain = suppliedLength == pretrainLength;
        if (!pretrain && suppliedLength != backpropLength) {
            throw new IllegalArgumentException("Unable to set parameters: must be of length " + pretrainLength + " for pretrain,  or " + backpropLength + " for backprop. Is: " + suppliedLength);
        }

        int offset = 0;
        for (String key : this.params.keySet()) {
            if (!pretrain && VISIBLE_BIAS_KEY.equals(key)) {
                // Backprop-sized input carries no values for the visible bias.
                continue;
            }
            INDArray current = getParam(key);
            int currentLength = current.length();
            INDArray slice = iNDArray.get(new INDArrayIndex[]{NDArrayIndex.point(0), NDArrayIndex.interval(offset, offset + currentLength)});
            if (currentLength != slice.length()) {
                throw new IllegalStateException("Parameter " + key + " should have been of length " + currentLength + " but was " + slice.length());
            }
            setParam(key, slice.reshape('f', current.shape()));
            offset += currentLength;
        }
    }
}
