package org.nd4j.linalg.lossfunctions.impl;

import java.util.List;
import java.util.Map;
import onnx.OnnxProto3;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.TransformOp;
import org.nd4j.linalg.api.ops.impl.transforms.OldSoftMax;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.linalg.lossfunctions.LossUtil;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.shade.jackson.annotation.JsonInclude;
import org.nd4j.shade.jackson.annotation.JsonProperty;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;

@JsonInclude(JsonInclude.Include.NON_NULL)
/* loaded from: input_file:org/nd4j/linalg/lossfunctions/impl/LossMixtureDensity.class */
public class LossMixtureDensity extends DifferentialFunction implements ILossFunction {
    /** Number of mixture components; assigned by the two-argument constructor and returned by {@link #getNMixtures()}. */
    private int mMixtures;
    /** Width of a single label vector; assigned by the two-argument constructor and returned by {@link #getLabelWidth()}. */
    private int mLabelWidth;
    /** Square root of 2*pi (6.283185307179586), used to normalise the densities in phi and computeGradient. */
    private static final double SQRT_TWO_PI = Math.sqrt(6.283185307179586d);

    /**
     * Fluent builder for {@link LossMixtureDensity}.
     * <p>
     * Both settings start at zero and {@link #build()} rejects any value that
     * is not strictly positive, so each of them has to be set before building.
     */
    public static class Builder {
        private int mGaussians = 0;
        private int mLabelWidth = 0;

        private Builder() {
        }

        /**
         * Sets the number of gaussians (mixtures) to fit.
         *
         * @param i number of mixtures; must be positive by the time {@link #build()} runs
         * @return this builder
         */
        public Builder gaussians(int i) {
            mGaussians = i;
            return this;
        }

        /**
         * Sets the size of the label vectors.
         *
         * @param i label width; must be positive by the time {@link #build()} runs
         * @return this builder
         */
        public Builder labelWidth(int i) {
            mLabelWidth = i;
            return this;
        }

        /**
         * Validates the configuration and creates the loss function.
         *
         * @return a new {@link LossMixtureDensity} with the configured mixtures and label width
         * @throws IllegalArgumentException if the number of gaussians or the label width is not positive
         */
        public LossMixtureDensity build() {
            // The mixture count is checked first, then the label width.
            requirePositive(mGaussians, "Mixture density cost function must specify the number of mixtures to fit");
            requirePositive(mLabelWidth, "Mixture density cost function must specify the size of the labels vectors");
            return new LossMixtureDensity(mGaussians, mLabelWidth);
        }

        /** Throws {@link IllegalArgumentException} with {@code message} unless {@code value} is greater than zero. */
        private static void requirePositive(int value, String message) {
            if (value <= 0) {
                throw new IllegalArgumentException(message);
            }
        }
    }

    /**
     * Holder for the three arrays produced by
     * {@link LossMixtureDensity#extractComponents(INDArray)}: alpha, mu and sigma.
     * <p>
     * Equality, hashing and the string form are all derived from those three
     * arrays; any of them may be {@code null}.
     */
    public static class MixtureDensityComponents {
        private INDArray alpha;
        private INDArray mu;
        private INDArray sigma;

        /** @return the alpha array, possibly {@code null} */
        public INDArray getAlpha() {
            return alpha;
        }

        /** @return the mu array, possibly {@code null} */
        public INDArray getMu() {
            return mu;
        }

        /** @return the sigma array, possibly {@code null} */
        public INDArray getSigma() {
            return sigma;
        }

        /** @param iNDArray new alpha array */
        public void setAlpha(INDArray iNDArray) {
            alpha = iNDArray;
        }

        /** @param iNDArray new mu array */
        public void setMu(INDArray iNDArray) {
            mu = iNDArray;
        }

        /** @param iNDArray new sigma array */
        public void setSigma(INDArray iNDArray) {
            sigma = iNDArray;
        }

        /**
         * Two instances are equal when the other side accepts the comparison
         * through {@link #canEqual(Object)} and alpha, mu and sigma are pairwise
         * equal (both {@code null}, or equal according to {@code equals}).
         */
        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (!(obj instanceof MixtureDensityComponents)) {
                return false;
            }
            final MixtureDensityComponents other = (MixtureDensityComponents) obj;
            // Short-circuiting keeps the comparison order: canEqual, alpha, mu, sigma.
            return other.canEqual(this)
                    && sameArray(getAlpha(), other.getAlpha())
                    && sameArray(getMu(), other.getMu())
                    && sameArray(getSigma(), other.getSigma());
        }

        /** Hook that lets a subclass refuse equality with instances of this type. */
        protected boolean canEqual(Object obj) {
            return obj instanceof MixtureDensityComponents;
        }

        /** Combines the hashes of alpha, mu and sigma with multiplier 59; a {@code null} array contributes 43. */
        @Override
        public int hashCode() {
            int result = 1;
            result = result * 59 + arrayHash(getAlpha());
            result = result * 59 + arrayHash(getMu());
            result = result * 59 + arrayHash(getSigma());
            return result;
        }

        @Override
        public String toString() {
            final StringBuilder text = new StringBuilder("LossMixtureDensity.MixtureDensityComponents(alpha=");
            text.append(getAlpha());
            text.append(", mu=").append(getMu());
            text.append(", sigma=").append(getSigma());
            text.append(")");
            return text.toString();
        }

        /** Null-safe equality: a {@code null} array only equals another {@code null}. */
        private static boolean sameArray(INDArray mine, INDArray theirs) {
            if (mine == null) {
                return theirs == null;
            }
            return mine.equals(theirs);
        }

        /** Hash contribution of one array, using 43 as the stand-in for {@code null}. */
        private static int arrayHash(INDArray array) {
            return array == null ? 43 : array.hashCode();
        }
    }

    /**
     * No-argument constructor. It assigns nothing, so the mixture count and
     * label width keep their default value of 0.
     * NOTE(review): presumably kept for framework/reflective instantiation — confirm against callers.
     */
    public LossMixtureDensity() {
    }

    /**
     * Creates a loss with an explicit configuration. It is private;
     * {@link Builder#build()} calls it after validating both values.
     *
     * @param i  number of mixtures (JSON property {@code "mixtures"})
     * @param i2 width of each label vector (JSON property {@code "labelWidth"})
     */
    private LossMixtureDensity(@JsonProperty("mixtures") int i, @JsonProperty("labelWidth") int i2) {
        mLabelWidth = i2;
        mMixtures = i;
    }

    /**
     * Splits a network output into its mixture-density parts.
     * <p>
     * Along dimension 1 the output is read as three consecutive column ranges:
     * <ul>
     *   <li>{@code [0, mixtures)} - alpha, passed through {@code OldSoftMax};</li>
     *   <li>{@code [mixtures, 2*mixtures)} - sigma, passed through {@code Transforms.exp};</li>
     *   <li>{@code [2*mixtures, (labelWidth+2)*mixtures)} - mu, reshaped to
     *       {@code (size(0), mixtures, labelWidth)}.</li>
     * </ul>
     *
     * @param iNDArray the network output; its size along dimension 1 must be
     *                 {@code (labelWidth + 2) * mixtures}
     * @return the alpha, mu and sigma arrays
     * @throws IllegalArgumentException if the output width does not match the configuration
     */
    public MixtureDensityComponents extractComponents(INDArray iNDArray) {
        final int actualWidth = iNDArray.size(1);
        final int expectedWidth = (mLabelWidth + 2) * mMixtures;
        if (actualWidth != expectedWidth) {
            throw new IllegalArgumentException("Network output size " + actualWidth + " must be (labels+2)*mixtures where labels = " + mLabelWidth + " and mixtures = " + mMixtures);
        }

        // Slice the three column ranges first, then apply the transforms.
        final INDArray rawAlpha = iNDArray.get(NDArrayIndex.all(), NDArrayIndex.interval(0, mMixtures));
        final INDArray rawSigma = iNDArray.get(NDArrayIndex.all(), NDArrayIndex.interval(mMixtures, 2 * mMixtures));
        final INDArray mu = iNDArray
                .get(NDArrayIndex.all(), NDArrayIndex.interval(2 * mMixtures, expectedWidth))
                .reshape(iNDArray.size(0), mMixtures, mLabelWidth);

        final MixtureDensityComponents components = new MixtureDensityComponents();
        components.setMu(mu);
        components.setAlpha(Nd4j.getExecutioner().execAndReturn((TransformOp) new OldSoftMax(rawAlpha)));
        components.setSigma(Transforms.exp(rawSigma));
        return components;
    }

    /**
     * Computes the total score: the sum of every entry of the array returned by
     * {@link #computeScoreArray(INDArray, INDArray, IActivation, INDArray)}.
     *
     * @param iNDArray    labels, forwarded to {@code computeScoreArray}
     * @param iNDArray2   network pre-activation output, forwarded to {@code computeScoreArray}
     * @param iActivation activation function, forwarded to {@code computeScoreArray}
     * @param iNDArray3   optional mask, forwarded to {@code computeScoreArray}; may be {@code null}
     * @param z           when {@code true}, the sum is divided by the score array's size along dimension 0
     * @return the summed (and optionally averaged) score
     */
    @Override // org.nd4j.linalg.lossfunctions.ILossFunction
    public double computeScore(INDArray iNDArray, INDArray iNDArray2, IActivation iActivation, INDArray iNDArray3, boolean z) {
        // Keep the score array in a local: it is needed both for the sum and,
        // when averaging, for its size along dimension 0.
        final INDArray scoreArray = computeScoreArray(iNDArray, iNDArray2, iActivation, iNDArray3);
        double score = scoreArray.sumNumber().doubleValue();
        if (z) {
            score /= scoreArray.size(0);
        }
        return score;
    }

    /**
     * Computes the negative log likelihood array for the given labels.
     * <p>
     * The activation is applied to a copy ({@code dup()}) of the pre-activation
     * output, the result is split with {@link #extractComponents(INDArray)}, and
     * the mask (if any) is applied to the resulting scores.
     *
     * @param iNDArray    labels
     * @param iNDArray2   network pre-activation output; only a copy of it is passed to the activation
     * @param iActivation activation function
     * @param iNDArray3   optional mask; may be {@code null}
     * @return the (optionally masked) negative log likelihood array
     */
    @Override
    public INDArray computeScoreArray(INDArray iNDArray, INDArray iNDArray2, IActivation iActivation, INDArray iNDArray3) {
        final INDArray activated = iActivation.getActivation(iNDArray2.dup(), false);
        final MixtureDensityComponents mdc = extractComponents(activated);
        final INDArray scores = negativeLogLikelihood(iNDArray, mdc.getAlpha(), mdc.getMu(), mdc.getSigma());
        if (iNDArray3 == null) {
            return scores;
        }
        LossUtil.applyMask(scores, iNDArray3);
        return scores;
    }

    /**
     * Computes the gradient of the loss and passes it back through the activation function.
     * <p>
     * The intermediate gradient has one column block per component, laid out
     * like the network output: {@code [0, mixtures)} for alpha,
     * {@code [mixtures, 2*mixtures)} for sigma and
     * {@code [2*mixtures, (labelWidth+2)*mixtures)} for mu.
     *
     * @param iNDArray    labels
     * @param iNDArray2   network pre-activation output; a copy is activated, the original is handed to {@code backprop}
     * @param iActivation activation function
     * @param iNDArray3   optional mask applied to the final gradient; may be {@code null}
     * @return the first element of the pair returned by {@code iActivation.backprop}, masked if a mask was given
     */
    @Override
    public INDArray computeGradient(INDArray iNDArray, INDArray iNDArray2, IActivation iActivation, INDArray iNDArray3) {
        final int nSamples = iNDArray.size(0);
        final MixtureDensityComponents mdc = extractComponents(iActivation.getActivation(iNDArray2.dup(), false));
        final INDArray alpha = mdc.getAlpha();
        final INDArray sigma = mdc.getSigma();

        final INDArray gradient = Nd4j.zeros(nSamples, iNDArray2.columns());
        final INDArray labelsMinusMu = labelsMinusMu(iNDArray, mdc.getMu());
        final INDArray squaredDistance = labelsMinusMu.mul(labelsMinusMu).sum(2);

        // pi = alpha / (sigma * sqrt(2*pi))^labelWidth * exp(-squaredDistance / (2 * variance)),
        // normalised so that every row sums to one.
        final INDArray variance = sigma.mul(sigma);
        final INDArray minusTwoVariance = variance.mul((Number) 2).negi();
        final INDArray normalPart = alpha.div(Transforms.pow(sigma.mul(Double.valueOf(SQRT_TWO_PI)), Integer.valueOf(mLabelWidth)));
        final INDArray exponent = squaredDistance.div(minusTwoVariance);
        // Subtract the row-wise maximum before exponentiating; the shift cancels
        // when each row is divided by its own sum below.
        exponent.subiColumnVector(exponent.max(1));
        final INDArray pi = Transforms.exp(exponent).muli(normalPart);
        pi.diviColumnVector(pi.sum(1));

        // Alpha block: alpha - pi.
        final INDArray dLdZAlpha = alpha.sub(pi);
        // Sigma block: -(squaredDistance / variance - labelWidth) * pi.
        final INDArray dLdZSigma = squaredDistance.div(variance).subi(Integer.valueOf(mLabelWidth)).muli((Number) (-1)).muli(pi);

        // Mu block: -(labels - mu) * pi / variance, filled one label dimension at a time.
        // The in-place ops run on whatever get(...) returns; labelsMinusMu is not read again afterwards.
        INDArray dLdZMu = Nd4j.create(nSamples, mMixtures, mLabelWidth);
        for (int k = 0; k < mLabelWidth; k++) {
            final INDArrayIndex[] target = {NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(k)};
            final INDArray slice = labelsMinusMu
                    .get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(k))
                    .muli(pi)
                    .divi(variance)
                    .negi();
            dLdZMu.put(target, slice);
        }
        dLdZMu = dLdZMu.reshape(nSamples, mMixtures * mLabelWidth);

        // Write the three blocks into the gradient holder.
        gradient.put(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(0, mMixtures)}, dLdZAlpha);
        gradient.put(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(mMixtures, mMixtures * 2)}, dLdZSigma);
        gradient.put(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(mMixtures * 2, (mLabelWidth + 2) * mMixtures)}, dLdZMu);

        final INDArray backpropagated = (INDArray) iActivation.backprop(iNDArray2, gradient).getFirst();
        if (iNDArray3 != null) {
            LossUtil.applyMask(backpropagated, iNDArray3);
        }
        return backpropagated;
    }

    /**
     * Computes the score and the gradient in one call.
     *
     * @param iNDArray    labels
     * @param iNDArray2   network pre-activation output
     * @param iActivation activation function
     * @param iNDArray3   optional mask; may be {@code null}
     * @param z           whether the score is averaged, as in {@code computeScore}
     * @return a pair of (score, gradient)
     */
    @Override
    public Pair<Double, INDArray> computeGradientAndScore(INDArray iNDArray, INDArray iNDArray2, IActivation iActivation, INDArray iNDArray3, boolean z) {
        // The score is evaluated before the gradient, which hands the
        // pre-activation output itself to the activation's backprop.
        final double score = computeScore(iNDArray, iNDArray2, iActivation, iNDArray3, z);
        final INDArray gradient = computeGradient(iNDArray, iNDArray2, iActivation, iNDArray3);
        return new Pair<>(Double.valueOf(score), gradient);
    }

    /**
     * Returns the fixed identifier of this loss function.
     *
     * @return the constant string {@code "lossmixturedensity"}
     */
    @Override // org.nd4j.linalg.lossfunctions.ILossFunction
    public String name() {
        return "lossmixturedensity";
    }

    /**
     * Computes {@code -log(sum over dimension 1 of (alpha * phi))}, where phi is
     * evaluated on the squared differences between labels and mu, summed along dimension 2.
     *
     * @param iNDArray  labels
     * @param iNDArray2 alpha; multiplied into the densities
     * @param iNDArray3 mu; subtracted from the labels
     * @param iNDArray4 sigma; passed to {@code phi}
     * @return the negated logarithm of the alpha-weighted density sum
     */
    private INDArray negativeLogLikelihood(INDArray iNDArray, INDArray iNDArray2, INDArray iNDArray3, INDArray iNDArray4) {
        final INDArray difference = labelsMinusMu(iNDArray, iNDArray3);
        final INDArray squaredDistance = difference.mul(difference).sum(2);
        final INDArray weightedDensity = phi(squaredDistance, iNDArray4).muli(iNDArray2);
        return Transforms.log(weightedDensity.sum(1)).negi();
    }

    /**
     * Builds a {@code (labels.size(0), mixtures, labels.size(1))} array that
     * holds one copy of the labels per mixture, then subtracts mu from it in place.
     *
     * @param iNDArray  labels
     * @param iNDArray2 mu; subtracted from the repeated labels
     * @return the newly allocated array after the subtraction
     */
    private INDArray labelsMinusMu(INDArray iNDArray, INDArray iNDArray2) {
        final INDArray repeatedLabels = Nd4j.zeros(iNDArray.size(0), mMixtures, iNDArray.size(1));
        // Write the labels into every mixture slot along dimension 1.
        for (int mixture = 0; mixture < mMixtures; mixture++) {
            final INDArrayIndex[] slot = {NDArrayIndex.all(), NDArrayIndex.point(mixture), NDArrayIndex.all()};
            repeatedLabels.put(slot, iNDArray);
        }
        repeatedLabels.subi(iNDArray2);
        return repeatedLabels;
    }

    /**
     * Evaluates {@code exp(-d / (2 * sigma^2)) / (sigma * sqrt(2*pi))^labelWidth}.
     * <p>
     * Note that {@code iNDArray} is divided in place ({@code divi}) before the
     * exponential is taken, so the caller's array is modified.
     *
     * @param iNDArray  the value {@code d} in the formula above (callers pass squared label-minus-mu distances)
     * @param iNDArray2 sigma
     * @return the density values
     */
    private INDArray phi(INDArray iNDArray, INDArray iNDArray2) {
        final INDArray minusTwoVariance = iNDArray2.mul(iNDArray2).muli((Number) 2).negi();
        final INDArray kernel = Transforms.exp(iNDArray.divi(minusTwoVariance));
        final INDArray normaliser = Transforms.pow(iNDArray2.mul(Double.valueOf(SQRT_TWO_PI)), Double.valueOf(mLabelWidth));
        return kernel.divi(normaliser);
    }

    /**
     * @return the configured number of mixtures (JSON property {@code "mixtures"})
     */
    @JsonProperty("mixtures")
    public int getNMixtures() {
        return mMixtures;
    }

    /**
     * @return the configured label vector width (JSON property {@code "labelWidth"})
     */
    @JsonProperty("labelWidth")
    public int getLabelWidth() {
        return mLabelWidth;
    }

    /**
     * @return a description of the form {@code LossMixtureDensity(mixtures=M, labels=W)}
     */
    @Override
    public String toString() {
        final StringBuilder text = new StringBuilder("LossMixtureDensity(mixtures=");
        text.append(mMixtures);
        text.append(", labels=").append(mLabelWidth);
        text.append(")");
        return text.toString();
    }

    /**
     * Creates a fresh {@link Builder}; this is the public way to obtain one,
     * since the builder's constructor is private.
     *
     * @return a new builder with the number of gaussians and the label width both at 0
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * This loss function exposes no output variables.
     *
     * @return a new, empty {@code SDVariable} array on every call
     */
    @Override // org.nd4j.autodiff.functions.DifferentialFunction
    public SDVariable[] outputVariables() {
        return new SDVariable[0];
    }

    /**
     * This loss function exposes no output variables; the argument is ignored.
     *
     * @param str unused
     * @return a new, empty {@code SDVariable} array on every call
     */
    @Override // org.nd4j.autodiff.functions.DifferentialFunction
    public SDVariable[] outputVariables(String str) {
        return new SDVariable[0];
    }

    /**
     * Not implemented for this loss function: the argument is ignored.
     * NOTE(review): this returns {@code null} rather than an empty list — confirm that
     * no caller iterates the result.
     *
     * @param list unused
     * @return always {@code null}
     */
    @Override // org.nd4j.autodiff.functions.DifferentialFunction
    public List<SDVariable> doDiff(List<SDVariable> list) {
        return null;
    }

    /**
     * @return the op name, which is the same string as {@link #name()}
     */
    @Override // org.nd4j.autodiff.functions.DifferentialFunction
    public String opName() {
        return name();
    }

    /**
     * @return always {@code Op.Type.CUSTOM}
     */
    @Override // org.nd4j.autodiff.functions.DifferentialFunction
    public Op.Type opType() {
        return Op.Type.CUSTOM;
    }

    /**
     * Intentionally empty: all arguments are ignored and no state is changed.
     */
    @Override // org.nd4j.autodiff.functions.DifferentialFunction
    public void initFromTensorFlow(NodeDef nodeDef, SameDiff sameDiff, Map<String, AttrValue> map, GraphDef graphDef) {
    }

    /**
     * Intentionally empty: all arguments are ignored and no state is changed.
     */
    @Override // org.nd4j.autodiff.functions.DifferentialFunction
    public void initFromOnnx(OnnxProto3.NodeProto nodeProto, SameDiff sameDiff, Map<String, OnnxProto3.AttributeProto> map, OnnxProto3.GraphProto graphProto) {
    }

    /**
     * There is no ONNX op name for this loss function, so this method always throws.
     *
     * @return never returns normally
     * @throws NoOpNameFoundException always, with a message that includes {@link #opName()}
     */
    @Override
    public String onnxName() {
        final String message = "No onnx op name found for " + opName();
        throw new NoOpNameFoundException(message);
    }

    /**
     * There is no TensorFlow op name for this loss function, so this method always throws.
     *
     * @return never returns normally
     * @throws NoOpNameFoundException always, with a message that includes {@link #opName()}
     */
    @Override
    public String tensorflowName() {
        final String message = "No tensorflow op name found for " + opName();
        throw new NoOpNameFoundException(message);
    }

    /**
     * Two loss functions are equal when the other side accepts the comparison
     * through {@link #canEqual(Object)} and both the mixture count and the
     * label width match.
     *
     * @param obj the object to compare with
     * @return {@code true} if {@code obj} is an equal {@code LossMixtureDensity}
     */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (!(obj instanceof LossMixtureDensity)) {
            return false;
        }
        final LossMixtureDensity other = (LossMixtureDensity) obj;
        if (!other.canEqual(this)) {
            return false;
        }
        if (mMixtures != other.mMixtures) {
            return false;
        }
        return mLabelWidth == other.mLabelWidth;
    }

    /**
     * Hook used by {@link #equals(Object)}: reports whether {@code obj} is a
     * {@code LossMixtureDensity}, so a subclass can override it to refuse equality.
     *
     * @param obj the object that wants to compare itself with this one
     * @return {@code true} if {@code obj} is an instance of {@code LossMixtureDensity}
     */
    protected boolean canEqual(Object obj) {
        return obj instanceof LossMixtureDensity;
    }

    /**
     * Hash built from the mixture count and the label width with multiplier 59,
     * consistent with {@link #equals(Object)}.
     *
     * @return {@code ((59 + mixtures) * 59) + labelWidth}
     */
    @Override
    public int hashCode() {
        int result = 1;
        result = result * 59 + mMixtures;
        result = result * 59 + mLabelWidth;
        return result;
    }
}
