package org.nd4j.linalg.lossfunctions.impl;

import org.apache.commons.math3.util.Pair;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.BooleanIndexing;
import org.nd4j.linalg.indexing.conditions.Conditions;
import org.nd4j.linalg.lossfunctions.ILossFunction;

/* loaded from: input_file:org/nd4j/linalg/lossfunctions/impl/LossSquaredHinge.class */
/**
 * Squared hinge loss.
 *
 * <p>For every element the loss is {@code max(0, 1 - label * activation(preOutput))^2};
 * the per-element values are then summed along dimension 1 to give one score per row.
 *
 * <p>The class holds no state, so all instances are considered equal.
 */
public class LossSquaredHinge implements ILossFunction {

    /**
     * Computes the element-wise (unclipped, unsquared) hinge margin
     * {@code 1 - label * activation(preOutput)}, optionally scaled by a mask.
     *
     * @param labels       array multiplied element-wise with the activated output
     * @param preOutput    input to the activation function; it is duplicated before the
     *                     transform is executed, so the caller's array is not modified here
     * @param activationFn name of the transform op to look up in the op factory
     * @param mask         optional per-row multiplier, may be {@code null}
     *                     (applied with {@code muliColumnVector}, so presumably a column
     *                     vector with one entry per row - confirm against callers)
     * @return a newly allocated array holding the margins
     */
    public INDArray scoreArray(INDArray labels, INDArray preOutput, String activationFn, INDArray mask) {
        INDArray activated = Nd4j.getExecutioner()
                .execAndReturn(Nd4j.getOpFactory().createTransform(activationFn, preOutput.dup()));
        // label * activation(preOutput), computed in place on the activation result.
        INDArray margin = activated.muli(labels);
        // Reverse subtraction in place: margin = 1 - margin.
        margin.rsubi(1.0);
        if (mask != null) {
            margin.muliColumnVector(mask);
        }
        return margin;
    }

    /**
     * Computes the total loss over all rows.
     *
     * @param labels       see {@link #scoreArray}
     * @param preOutput    see {@link #scoreArray}
     * @param activationFn see {@link #scoreArray}
     * @param mask         see {@link #scoreArray}; may be {@code null}
     * @param average      when {@code true}, the summed score is divided by the size of
     *                     dimension 0 of the per-row score array
     * @return the summed (optionally averaged) loss
     */
    @Override // org.nd4j.linalg.lossfunctions.ILossFunction
    public double computeScore(INDArray labels, INDArray preOutput, String activationFn, INDArray mask, boolean average) {
        // Keep a reference to the per-row scores: needed both for the sum and for the row count.
        INDArray scoreArr = computeScoreArray(labels, preOutput, activationFn, mask);
        double score = scoreArr.sumNumber().doubleValue();
        if (average) {
            score /= scoreArr.size(0);
        }
        return score;
    }

    /**
     * Computes the loss for each row.
     *
     * @param labels       see {@link #scoreArray}
     * @param preOutput    see {@link #scoreArray}
     * @param activationFn see {@link #scoreArray}
     * @param mask         see {@link #scoreArray}; may be {@code null}
     * @return the squared, clipped margins summed along dimension 1
     */
    @Override // org.nd4j.linalg.lossfunctions.ILossFunction
    public INDArray computeScoreArray(INDArray labels, INDArray preOutput, String activationFn, INDArray mask) {
        INDArray margin = scoreArray(labels, preOutput, activationFn, mask);
        // Hinge: negative margins contribute no loss.
        BooleanIndexing.replaceWhere(margin, 0.0, Conditions.lessThan(0.0));
        // Square in place.
        margin.muli(margin);
        return margin.sum(1);
    }

    /**
     * Computes the gradient of the loss with respect to {@code preOutput}.
     *
     * <p>Element-wise this is
     * {@code 2 * margin * (-label) * indicator(margin > 0) * activation'(preOutput)},
     * where {@code margin} is the value returned by {@link #scoreArray}.
     *
     * @param labels       see {@link #scoreArray}
     * @param preOutput    see {@link #scoreArray}
     * @param activationFn see {@link #scoreArray}
     * @param mask         see {@link #scoreArray}; may be {@code null}
     * @return the gradient array
     */
    @Override // org.nd4j.linalg.lossfunctions.ILossFunction
    public INDArray computeGradient(INDArray labels, INDArray preOutput, String activationFn, INDArray mask) {
        INDArray activationDerivative = Nd4j.getExecutioner()
                .execAndReturn(Nd4j.getOpFactory().createTransform(activationFn, preOutput.dup()).derivative());
        INDArray margin = scoreArray(labels, preOutput, activationFn, mask);

        // Indicator of a positive margin: negatives become 0, positives become 1,
        // exact zeros are left at 0.
        INDArray positiveMargin = margin.dup();
        BooleanIndexing.replaceWhere(positiveMargin, 0.0, Conditions.lessThan(0.0));
        BooleanIndexing.replaceWhere(positiveMargin, 1.0, Conditions.greaterThan(0.0));

        INDArray gradient = margin.muli(2);
        gradient.muli(labels.neg());
        gradient.muli(positiveMargin).muli(activationDerivative);
        // The mask was already applied once inside scoreArray; it is applied again here,
        // which matches the score, where the masked margin is squared.
        if (mask != null) {
            gradient.muliColumnVector(mask);
        }
        return gradient;
    }

    /**
     * Computes both the score and the gradient by delegating to
     * {@link #computeScore} and {@link #computeGradient}.
     *
     * @return a pair of (score, gradient)
     */
    @Override // org.nd4j.linalg.lossfunctions.ILossFunction
    public Pair<Double, INDArray> computeGradientAndScore(INDArray labels, INDArray preOutput, String activationFn, INDArray mask, boolean average) {
        double score = computeScore(labels, preOutput, activationFn, mask, average);
        INDArray gradient = computeGradient(labels, preOutput, activationFn, mask);
        return new Pair<>(score, gradient);
    }

    @Override
    public String toString() {
        return "LossSquaredHinge()";
    }

    /**
     * Two instances are equal when both are {@code LossSquaredHinge} and the other
     * instance accepts this one via {@link #canEqual(Object)}.
     */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        return (obj instanceof LossSquaredHinge) && ((LossSquaredHinge) obj).canEqual(this);
    }

    /** Hook that lets subclasses refuse equality with instances of this class. */
    protected boolean canEqual(Object obj) {
        return obj instanceof LossSquaredHinge;
    }

    /** Constant hash: the class has no fields, so all instances hash alike. */
    @Override
    public int hashCode() {
        return 1;
    }
}
