package org.nd4j.linalg.lossfunctions;

import java.util.HashMap;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.common.base.Preconditions;
import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;

/* loaded from: input_file:org/nd4j/linalg/lossfunctions/SameDiffLoss.class */
/**
 * Base class for loss functions whose score is expressed as a SameDiff graph.
 *
 * <p>Subclasses implement {@link #defineLoss(SameDiff, SDVariable, SDVariable)}. On first use, this
 * class builds the graph with two placeholders ({@code "layerInput"} and {@code "labels"}). The
 * placeholders use the data type of the pre-output array. The class then relies on
 * {@link SameDiff#calculateGradients} to obtain the gradient of the loss with respect to the
 * activated layer output.
 *
 * <p>NOTE(review): the graph is created lazily without synchronization. Instances are presumably
 * not intended to be shared across threads; confirm before relying on that.
 */
public abstract class SameDiffLoss implements ILossFunction {
    /** Placeholder name for the activated layer output fed into the loss graph. */
    private static final String LAYER_INPUT = "layerInput";
    /** Placeholder name for the labels fed into the loss graph. */
    private static final String LABELS = "labels";

    protected transient SameDiff sd;
    protected transient SDVariable scorePerExampleVariable;

    protected SameDiffLoss() {
    }

    /**
     * Defines the loss inside the given SameDiff graph.
     *
     * @param sameDiff   graph to add the loss operations to
     * @param layerInput placeholder holding the activated layer output
     * @param labels     placeholder holding the labels
     * @return the variable holding the score array. Its {@code size(0)} is used as the example
     *     count when {@link #computeScore} averages.
     */
    public abstract SDVariable defineLoss(SameDiff sameDiff, SDVariable layerInput, SDVariable labels);

    /**
     * Builds the SameDiff graph for this loss and its gradient function.
     *
     * <p>If any step fails, {@link #sd} and {@link #scorePerExampleVariable} are reset to
     * {@code null}. A later call can then retry, rather than finding a half-initialized graph.
     *
     * @param dataType data type used for both placeholders
     */
    protected void createSameDiffInstance(DataType dataType) {
        // Assign sd up front, as before, in case subclasses read this.sd while defining the loss.
        this.sd = SameDiff.create();
        boolean initialized = false;
        try {
            SDVariable layerInput = this.sd.placeHolder(LAYER_INPUT, dataType, -1);
            SDVariable labels = this.sd.placeHolder(LABELS, dataType, -1);
            this.scorePerExampleVariable = defineLoss(this.sd, layerInput, labels);
            this.scorePerExampleVariable.markAsLoss();
            this.sd.createGradFunction(LAYER_INPUT);
            initialized = true;
        } finally {
            if (!initialized) {
                // Without this reset, the sd != null check would skip re-initialization forever.
                this.sd = null;
                this.scorePerExampleVariable = null;
            }
        }
    }

    /** Lazily creates the SameDiff graph using the given data type if it does not exist yet. */
    private void ensureSameDiffInstance(DataType dataType) {
        if (this.sd == null) {
            createSameDiffInstance(dataType);
        }
    }

    /**
     * Builds the placeholder map fed to the graph.
     *
     * <p>The activation function is applied to a copy of {@code preOutput}, so the caller's array
     * is not modified.
     */
    private HashMap<String, INDArray> placeholders(INDArray labels, INDArray preOutput, IActivation activationFn) {
        INDArray output = activationFn.getActivation(preOutput.dup(), true);
        HashMap<String, INDArray> values = new HashMap<>();
        values.put(LABELS, labels);
        values.put(LAYER_INPUT, output);
        return values;
    }

    /**
     * Computes the total score by summing {@link #computeScoreArray}.
     *
     * @param average if {@code true}, divide the sum by {@code size(0)} of the score array
     */
    @Override // org.nd4j.linalg.lossfunctions.ILossFunction
    public double computeScore(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask, boolean average) {
        ensureSameDiffInstance(preOutput.dataType());
        INDArray scoreArr = computeScoreArray(labels, preOutput, activationFn, mask);
        double score = scoreArr.sumNumber().doubleValue();
        if (average) {
            // Previously referenced an undefined decompiler temporary; the score array is intended.
            score /= scoreArr.size(0);
        }
        return score;
    }

    /**
     * Evaluates the loss graph and returns the score array, with the mask applied if provided.
     *
     * @throws IllegalArgumentException if {@code labels.size(1) != preOutput.size(1)}
     */
    @Override // org.nd4j.linalg.lossfunctions.ILossFunction
    public INDArray computeScoreArray(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) {
        ensureSameDiffInstance(preOutput.dataType());
        Preconditions.checkArgument(labels.size(1) == preOutput.size(1), "Labels array numColumns (size(1) = %s) does not match output layer number of outputs (nOut = %s)", labels.size(1), preOutput.size(1));
        INDArray scoreArr = this.sd.outputSingle(placeholders(labels, preOutput, activationFn), this.scorePerExampleVariable.name());
        if (mask != null) {
            LossUtil.applyMask(scoreArr, mask);
        }
        return scoreArr;
    }

    /**
     * Computes the gradient of the loss with respect to {@code preOutput}.
     *
     * <p>Gets the gradient at the activated output ({@code "layerInput"}) from SameDiff, then
     * backpropagates it through the activation function. The mask is applied if provided.
     */
    @Override // org.nd4j.linalg.lossfunctions.ILossFunction
    public INDArray computeGradient(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) {
        ensureSameDiffInstance(preOutput.dataType());
        INDArray gradAtActivationOutput = this.sd.calculateGradients(placeholders(labels, preOutput, activationFn), LAYER_INPUT).get(LAYER_INPUT);
        INDArray gradAtInput = activationFn.backprop(preOutput.dup(), gradAtActivationOutput).getFirst();
        if (mask != null) {
            LossUtil.applyMask(gradAtInput, mask);
        }
        return gradAtInput;
    }

    /** Returns the pair (score, gradient) by delegating to {@link #computeScore} and {@link #computeGradient}. */
    @Override // org.nd4j.linalg.lossfunctions.ILossFunction
    public Pair<Double, INDArray> computeGradientAndScore(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask, boolean average) {
        Pair<Double, INDArray> pair = new Pair<>();
        pair.setFirst(computeScore(labels, preOutput, activationFn, mask, average));
        pair.setSecond(computeGradient(labels, preOutput, activationFn, mask));
        return pair;
    }

    /** Returns the simple name of the concrete subclass. */
    @Override // org.nd4j.linalg.lossfunctions.ILossFunction
    public String name() {
        return getClass().getSimpleName();
    }
}
