package org.nd4j.linalg.lossfunctions;

import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.TransformOp;
import org.nd4j.linalg.api.ops.impl.transforms.LogSoftMax;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.BooleanIndexing;
import org.nd4j.linalg.indexing.conditions.Conditions;
import org.nd4j.linalg.indexing.conditions.Or;
import org.nd4j.linalg.indexing.functions.StableNumber;
import org.nd4j.linalg.indexing.functions.Value;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.linalg.ops.transforms.Transforms;

/**
 * Computes loss scores for one {@link LossFunctions.LossFunction} from labels and network output.
 *
 * <p>{@link #score()} returns a single aggregated scalar and {@link #scoreExamples()} returns one
 * value per row. Both are derived from an element-wise score array built by
 * {@link #scoreArray()}, to which an optional per-row mask is applied.
 *
 * <p>NOTE(review): {@code labels}, {@code z} and {@code preOut} appear to be 2-D with one row per
 * example (per-example scores use {@code sum(1)} and the mask is applied with
 * {@code muliColumnVector}) — confirm against callers.
 *
 * <p>Instances are mutable (every field has a setter) and perform no synchronization.
 */
public class LossCalculation {
    /** Target values. */
    private INDArray labels;
    /** Network output (post-activation) being scored against {@link #labels}. */
    private INDArray z;
    /** Added verbatim to the score when {@link #useRegularization} is set. */
    private double l1;
    /** Added verbatim to the score when {@link #useRegularization} is set. */
    private double l2;
    /** Loss function selecting how the element-wise score array is built. */
    private LossFunctions.LossFunction lossFunction;
    /** Whether {@code l1 + l2} is added to the score. */
    private boolean useRegularization;
    /** Whether {@link #score()} divides the total by {@link #miniBatchSize}. */
    private boolean miniBatch;
    /** Divisor used by {@link #score()} when {@link #miniBatch} is set. */
    private int miniBatchSize;
    /** Activation name; {@code "softmax"} enables the log-softmax path for MCXENT/NLL. */
    private String activationFn;
    /** Pre-activation output; only read for MCXENT/NLL with a softmax activation. */
    private INDArray preOut;
    /** Optional per-row mask multiplied into the score array; {@code null} disables masking. */
    private INDArray mask;

    /** Fluent builder for {@link LossCalculation}; obtain one via {@link LossCalculation#builder()}. */
    public static class LossCalculationBuilder {
        private INDArray labels;
        private INDArray z;
        private double l1;
        private double l2;
        private LossFunctions.LossFunction lossFunction;
        private boolean useRegularization;
        private boolean miniBatch;
        private int miniBatchSize;
        private String activationFn;
        private INDArray preOut;
        private INDArray mask;

        LossCalculationBuilder() {
        }

        public LossCalculationBuilder labels(INDArray labels) {
            this.labels = labels;
            return this;
        }

        public LossCalculationBuilder z(INDArray z) {
            this.z = z;
            return this;
        }

        public LossCalculationBuilder l1(double l1) {
            this.l1 = l1;
            return this;
        }

        public LossCalculationBuilder l2(double l2) {
            this.l2 = l2;
            return this;
        }

        public LossCalculationBuilder lossFunction(LossFunctions.LossFunction lossFunction) {
            this.lossFunction = lossFunction;
            return this;
        }

        public LossCalculationBuilder useRegularization(boolean useRegularization) {
            this.useRegularization = useRegularization;
            return this;
        }

        public LossCalculationBuilder miniBatch(boolean miniBatch) {
            this.miniBatch = miniBatch;
            return this;
        }

        public LossCalculationBuilder miniBatchSize(int miniBatchSize) {
            this.miniBatchSize = miniBatchSize;
            return this;
        }

        public LossCalculationBuilder activationFn(String activationFn) {
            this.activationFn = activationFn;
            return this;
        }

        public LossCalculationBuilder preOut(INDArray preOut) {
            this.preOut = preOut;
            return this;
        }

        public LossCalculationBuilder mask(INDArray mask) {
            this.mask = mask;
            return this;
        }

        /** Creates a {@link LossCalculation} from the values set so far. */
        public LossCalculation build() {
            return new LossCalculation(labels, z, l1, l2, lossFunction, useRegularization,
                    miniBatch, miniBatchSize, activationFn, preOut, mask);
        }

        @Override
        public String toString() {
            return "LossCalculation.LossCalculationBuilder(labels=" + this.labels + ", z=" + this.z + ", l1=" + this.l1 + ", l2=" + this.l2 + ", lossFunction=" + this.lossFunction + ", useRegularization=" + this.useRegularization + ", miniBatch=" + this.miniBatch + ", miniBatchSize=" + this.miniBatchSize + ", activationFn=" + this.activationFn + ", preOut=" + this.preOut + ", mask=" + this.mask + ")";
        }
    }

    LossCalculation(INDArray labels, INDArray z, double l1, double l2,
            LossFunctions.LossFunction lossFunction, boolean useRegularization, boolean miniBatch,
            int miniBatchSize, String activationFn, INDArray preOut, INDArray mask) {
        this.labels = labels;
        this.z = z;
        this.l1 = l1;
        this.l2 = l2;
        this.lossFunction = lossFunction;
        this.useRegularization = useRegularization;
        this.miniBatch = miniBatch;
        this.miniBatchSize = miniBatchSize;
        this.activationFn = activationFn;
        this.preOut = preOut;
        this.mask = mask;
    }

    /** Returns a new, empty builder. */
    public static LossCalculationBuilder builder() {
        return new LossCalculationBuilder();
    }

    /**
     * Returns the total loss as a single scalar.
     *
     * <p>The element-wise score array is summed and multiplied by {@link #lossScale()}. When
     * {@code useRegularization} is set, {@code l1 + l2} is added. When {@code miniBatch} is set,
     * the result is divided by {@code miniBatchSize}.
     *
     * <p>NOTE(review): {@code miniBatch == true} with {@code miniBatchSize == 0} yields an
     * infinite/NaN score; presumably callers always supply a positive size.
     *
     * @return the aggregated loss
     * @throws IllegalStateException if the loss function is {@code CUSTOM}
     */
    public double score() {
        // scoreArray() is evaluated before lossScale(), so CUSTOM still fails with its own message.
        double total = scoreArray().sumNumber().doubleValue() * lossScale();
        if (useRegularization) {
            total += l1 + l2;
        }
        if (miniBatch) {
            total /= miniBatchSize;
        }
        return total;
    }

    /**
     * Returns one loss value per row: the row sums of the score array, multiplied by
     * {@link #lossScale()}. When regularization is enabled and non-zero, {@code l1 + l2} is
     * added to every row. No mini-batch division is applied.
     *
     * @return a new array of per-row scores
     * @throws IllegalStateException if the loss function is {@code CUSTOM}
     */
    public INDArray scoreExamples() {
        INDArray perExample = scoreArray().sum(1);
        double scale = lossScale();
        if (scale != 1.0) {
            perExample.muli(scale);
        }
        double regularization = l1 + l2;
        if (useRegularization && regularization != 0.0) {
            perExample.addi(regularization);
        }
        return perExample;
    }

    /**
     * Returns the factor applied to the summed score array.
     *
     * <p>The factor is {@code -1} for losses whose score array holds log-likelihood terms (so the
     * result is a positive loss), {@code 0.5} for MSE, and {@code 1} otherwise.
     */
    private double lossScale() {
        switch (lossFunction) {
            case MCXENT:
            case NEGATIVELOGLIKELIHOOD:
            case RECONSTRUCTION_CROSSENTROPY:
            case XENT:
                return -1.0;
            case MSE:
                return 0.5;
            default:
                return 1.0;
        }
    }

    /** Builds the element-wise score array and applies the optional per-row mask in place. */
    private INDArray scoreArray() {
        INDArray scores = unmaskedScoreArray();
        if (mask != null) {
            scores.muliColumnVector(mask);
        }
        return scores;
    }

    /**
     * Builds the element-wise score array for {@link #lossFunction} without masking.
     *
     * @throws IllegalStateException for {@code CUSTOM}, which cannot be scored here
     * @throws RuntimeException for any loss function not handled below
     */
    private INDArray unmaskedScoreArray() {
        switch (lossFunction) {
            case MCXENT:
            case NEGATIVELOGLIKELIHOOD:
                return labels.mul(logProbabilities());
            case RECONSTRUCTION_CROSSENTROPY:
            case XENT:
                return binaryLogLikelihood();
            case MSE:
            case SQUARED_LOSS:
                return squaredError();
            case RMSE_XENT:
                // NOTE(review): sqrt of the element-wise squared error is |labels - z|.
                return Transforms.sqrt(squaredError());
            case EXPLL:
                return z.sub(labels.mul(logZ(z)));
            case CUSTOM:
                throw new IllegalStateException("Unable to score custom operation. Please define an alternative mechanism");
            default:
                throw new RuntimeException("Unknown loss function: " + lossFunction);
        }
    }

    /**
     * Returns element-wise log-probabilities for the multi-class losses.
     *
     * <p>When pre-activations are available and the activation is softmax, log-softmax is applied
     * to a copy of {@code preOut}, leaving {@code preOut} itself unchanged. This is presumably
     * done for numerical stability versus {@code log(softmax)}. Otherwise the result is
     * {@code log(z)} with non-finite values replaced.
     */
    private INDArray logProbabilities() {
        if (preOut != null && "softmax".equals(activationFn)) {
            return Nd4j.getExecutioner().execAndReturn((TransformOp) new LogSoftMax(preOut.dup()));
        }
        return logZ(z);
    }

    /**
     * Returns the element-wise binary log-likelihood
     * {@code labels * log(z) + (1 - labels) * log(1 - z)}.
     *
     * <p>Non-finite logarithms (e.g. from {@code z == 0} or {@code z == 1}) are replaced by
     * {@link #logZ(INDArray)}.
     */
    private INDArray binaryLogLikelihood() {
        INDArray logZ = logZ(z);
        INDArray logOneMinusZ = logZ(z.rsub(1));
        INDArray oneMinusLabels = labels.rsub(1);
        return labels.mul(logZ).addi(oneMinusLabels.muli(logOneMinusZ));
    }

    /** Returns a new array holding {@code (labels - z)^2} element-wise. */
    private INDArray squaredError() {
        INDArray diff = labels.sub(z);
        return diff.muli(diff);
    }

    /**
     * Returns {@code log(input)} as a new array with NaN/infinite entries replaced.
     *
     * <p>For FLOAT and DOUBLE data, non-finite entries become a {@link StableNumber} of the
     * matching type. For INT data they become {@code -Integer.MAX_VALUE}.
     *
     * @throws RuntimeException for any other data type
     */
    private static INDArray logZ(INDArray input) {
        INDArray log = Transforms.log(input, true);
        Or nonFinite = new Or(Conditions.isNan(), Conditions.isInfinite());
        DataBuffer.Type dataType = log.data().dataType();
        switch (dataType) {
            case FLOAT:
                BooleanIndexing.applyWhere(log, nonFinite, new StableNumber(StableNumber.Type.FLOAT));
                break;
            case DOUBLE:
                BooleanIndexing.applyWhere(log, nonFinite, new StableNumber(StableNumber.Type.DOUBLE));
                break;
            case INT:
                BooleanIndexing.applyWhere(log, nonFinite, new Value(-Integer.MAX_VALUE));
                break;
            default:
                throw new RuntimeException("unsupported data type: " + dataType);
        }
        return log;
    }

    public INDArray getLabels() {
        return this.labels;
    }

    public INDArray getZ() {
        return this.z;
    }

    public double getL1() {
        return this.l1;
    }

    public double getL2() {
        return this.l2;
    }

    public LossFunctions.LossFunction getLossFunction() {
        return this.lossFunction;
    }

    public boolean isUseRegularization() {
        return this.useRegularization;
    }

    public boolean isMiniBatch() {
        return this.miniBatch;
    }

    public int getMiniBatchSize() {
        return this.miniBatchSize;
    }

    public String getActivationFn() {
        return this.activationFn;
    }

    public INDArray getPreOut() {
        return this.preOut;
    }

    public INDArray getMask() {
        return this.mask;
    }

    public void setLabels(INDArray labels) {
        this.labels = labels;
    }

    public void setZ(INDArray z) {
        this.z = z;
    }

    public void setL1(double l1) {
        this.l1 = l1;
    }

    public void setL2(double l2) {
        this.l2 = l2;
    }

    public void setLossFunction(LossFunctions.LossFunction lossFunction) {
        this.lossFunction = lossFunction;
    }

    public void setUseRegularization(boolean useRegularization) {
        this.useRegularization = useRegularization;
    }

    public void setMiniBatch(boolean miniBatch) {
        this.miniBatch = miniBatch;
    }

    public void setMiniBatchSize(int miniBatchSize) {
        this.miniBatchSize = miniBatchSize;
    }

    public void setActivationFn(String activationFn) {
        this.activationFn = activationFn;
    }

    public void setPreOut(INDArray preOut) {
        this.preOut = preOut;
    }

    public void setMask(INDArray mask) {
        this.mask = mask;
    }

    /** Field-by-field equality. Array fields are compared with their own {@code equals}. */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof LossCalculation)) {
            return false;
        }
        LossCalculation other = (LossCalculation) obj;
        return other.canEqual(this)
                && nullSafeEquals(getLabels(), other.getLabels())
                && nullSafeEquals(getZ(), other.getZ())
                && Double.compare(getL1(), other.getL1()) == 0
                && Double.compare(getL2(), other.getL2()) == 0
                && nullSafeEquals(getLossFunction(), other.getLossFunction())
                && isUseRegularization() == other.isUseRegularization()
                && isMiniBatch() == other.isMiniBatch()
                && getMiniBatchSize() == other.getMiniBatchSize()
                && nullSafeEquals(getActivationFn(), other.getActivationFn())
                && nullSafeEquals(getPreOut(), other.getPreOut())
                && nullSafeEquals(getMask(), other.getMask());
    }

    /** Symmetry hook for {@link #equals(Object)} in subclasses. */
    protected boolean canEqual(Object obj) {
        return obj instanceof LossCalculation;
    }

    @Override
    public int hashCode() {
        final int prime = 59;
        int result = 1;
        result = result * prime + hashOf(getLabels());
        result = result * prime + hashOf(getZ());
        result = result * prime + hashOfDouble(getL1());
        result = result * prime + hashOfDouble(getL2());
        result = result * prime + hashOf(getLossFunction());
        result = result * prime + (isUseRegularization() ? 79 : 97);
        result = result * prime + (isMiniBatch() ? 79 : 97);
        result = result * prime + getMiniBatchSize();
        result = result * prime + hashOf(getActivationFn());
        result = result * prime + hashOf(getPreOut());
        return result * prime + hashOf(getMask());
    }

    /**
     * Null-safe equality that always delegates to {@code a.equals(b)} for non-null {@code a}.
     * There is deliberately no identity short-circuit, matching the original semantics.
     */
    private static boolean nullSafeEquals(Object a, Object b) {
        return a == null ? b == null : a.equals(b);
    }

    /** Hash of {@code value}, or 0 when it is {@code null}. */
    private static int hashOf(Object value) {
        return value == null ? 0 : value.hashCode();
    }

    /** Folds the 64-bit IEEE representation of {@code value} into an int hash. */
    private static int hashOfDouble(double value) {
        long bits = Double.doubleToLongBits(value);
        return (int) ((bits >>> 32) ^ bits);
    }

    @Override
    public String toString() {
        return "LossCalculation(labels=" + getLabels() + ", z=" + getZ() + ", l1=" + getL1() + ", l2=" + getL2() + ", lossFunction=" + getLossFunction() + ", useRegularization=" + isUseRegularization() + ", miniBatch=" + isMiniBatch() + ", miniBatchSize=" + getMiniBatchSize() + ", activationFn=" + getActivationFn() + ", preOut=" + getPreOut() + ", mask=" + getMask() + ")";
    }
}
