/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.lossfunctions.impl;

import java.util.List;
import java.util.Map;
import onnx.OnnxProto3;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.shade.jackson.annotation.JsonProperty;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;

public class LossFMeasure
extends DifferentialFunction
implements ILossFunction {
    public static final double DEFAULT_BETA = 1.0;
    private final double beta;

    public LossFMeasure() {
        this(1.0);
    }

    public LossFMeasure(@JsonProperty(value="beta") double beta) {
        if (beta <= 0.0) {
            throw new UnsupportedOperationException("Invalid value: beta must be > 0. Got: " + beta);
        }
        this.beta = beta;
    }

    @Override
    public double computeScore(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask, boolean average) {
        double[] d = this.computeScoreNumDenom(labels, preOutput, activationFn, mask, average);
        double numerator = d[0];
        double denominator = d[1];
        if (numerator == 0.0 && denominator == 0.0) {
            return 0.0;
        }
        return 1.0 - numerator / denominator;
    }

    private double[] computeScoreNumDenom(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask, boolean average) {
        INDArray pClass1;
        INDArray pClass0;
        INDArray isNegativeLabel;
        INDArray isPositiveLabel;
        INDArray output = activationFn.getActivation(preOutput.dup(), true);
        int n = labels.size(1);
        if (n != 1 && n != 2) {
            throw new UnsupportedOperationException("For binary classification: expect output size of 1 or 2. Got: " + n);
        }
        if (n == 1) {
            isPositiveLabel = labels;
            isNegativeLabel = Transforms.not(isPositiveLabel);
            pClass0 = output.rsub(1.0);
            pClass1 = output;
        } else {
            isPositiveLabel = labels.getColumn(1);
            isNegativeLabel = labels.getColumn(0);
            pClass0 = output.getColumn(0);
            pClass1 = output.getColumn(1);
        }
        if (mask != null) {
            isPositiveLabel = isPositiveLabel.mulColumnVector(mask);
            isNegativeLabel = isNegativeLabel.mulColumnVector(mask);
        }
        double tp = isPositiveLabel.mul(pClass1).sumNumber().doubleValue();
        double fp = isNegativeLabel.mul(pClass1).sumNumber().doubleValue();
        double fn = isPositiveLabel.mul(pClass0).sumNumber().doubleValue();
        double numerator = (1.0 + this.beta * this.beta) * tp;
        double denominator = (1.0 + this.beta * this.beta) * tp + this.beta * this.beta * fn + fp;
        return new double[]{numerator, denominator};
    }

    @Override
    public INDArray computeScoreArray(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) {
        throw new UnsupportedOperationException("Cannot compute score array for FMeasure loss function: loss is only defined for minibatches");
    }

    @Override
    public INDArray computeGradient(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) {
        INDArray dLdOut;
        double[] d = this.computeScoreNumDenom(labels, preOutput, activationFn, mask, false);
        double numerator = d[0];
        double denominator = d[1];
        if (numerator == 0.0 && denominator == 0.0) {
            return Nd4j.create(preOutput.shape());
        }
        double secondTerm = numerator / (denominator * denominator);
        if (labels.size(1) == 1) {
            dLdOut = labels.mul(1.0 + this.beta * this.beta).divi(denominator).subi(secondTerm);
        } else {
            dLdOut = Nd4j.create(labels.shape());
            dLdOut.getColumn(1).assign(labels.getColumn(1).mul(1.0 + this.beta * this.beta).divi(denominator).subi(secondTerm));
        }
        dLdOut.negi();
        INDArray dLdPreOut = (INDArray)activationFn.backprop(preOutput, dLdOut).getFirst();
        if (mask != null) {
            dLdPreOut.muliColumnVector(mask);
        }
        return dLdPreOut;
    }

    @Override
    public Pair<Double, INDArray> computeGradientAndScore(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask, boolean average) {
        return new Pair((Object)this.computeScore(labels, preOutput, activationFn, mask, average), (Object)this.computeGradient(labels, preOutput, activationFn, mask));
    }

    @Override
    public String name() {
        return "floss";
    }

    @Override
    public String toString() {
        return "LossFMeasure(beta=" + this.beta + ")";
    }

    @Override
    public SDVariable[] outputVariables() {
        return new SDVariable[0];
    }

    @Override
    public SDVariable[] outputVariables(String baseName) {
        return new SDVariable[0];
    }

    @Override
    public List<SDVariable> doDiff(List<SDVariable> f1) {
        return null;
    }

    @Override
    public String opName() {
        return this.name();
    }

    @Override
    public Op.Type opType() {
        return Op.Type.CUSTOM;
    }

    @Override
    public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    }

    @Override
    public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
    }

    @Override
    public String onnxName() {
        throw new NoOpNameFoundException("No onnx op name found for " + this.opName());
    }

    @Override
    public String tensorflowName() {
        throw new NoOpNameFoundException("No tensorflow op name found for " + this.opName());
    }

    public double getBeta() {
        return this.beta;
    }

    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof LossFMeasure)) {
            return false;
        }
        LossFMeasure other = (LossFMeasure)o;
        if (!other.canEqual(this)) {
            return false;
        }
        return Double.compare(this.getBeta(), other.getBeta()) == 0;
    }

    protected boolean canEqual(Object other) {
        return other instanceof LossFMeasure;
    }

    @Override
    public int hashCode() {
        int PRIME = 59;
        int result = 1;
        long $beta = Double.doubleToLongBits(this.getBeta());
        result = result * 59 + (int)($beta >>> 32 ^ $beta);
        return result;
    }
}

