/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.eval;

import java.beans.ConstructorProperties;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.deeplearning4j.eval.BaseEvaluation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.conditions.Condition;
import org.nd4j.linalg.indexing.conditions.Conditions;

/**
 * Receiver operating characteristic (ROC) evaluation for binary classifiers.
 *
 * <p>The predicted value for the positive class is compared against {@code thresholdSteps + 1}
 * evenly spaced thresholds {@code 0, 1/thresholdSteps, 2/thresholdSteps, ..., 1}. For every
 * threshold the number of true positives and false positives is accumulated, together with the
 * total number of actual positives and actual negatives. From these counts ROC points
 * ({@link #getResults()}), a precision/recall curve ({@link #getPrecisionRecallCurve()}) and the
 * area under the ROC curve ({@link #calculateAUC()}) are derived.
 *
 * <p>Two label/prediction layouts are accepted by {@link #eval(INDArray, INDArray)}: a single
 * column holding the positive-class value, or two columns where column 1 is the positive class
 * and column 0 the negative class.
 */
public class ROC
extends BaseEvaluation<ROC> {
    private final int thresholdSteps;
    private long countActualPositive;
    private long countActualNegative;
    /** Per-threshold counts keyed by threshold; insertion order is increasing threshold. */
    private final Map<Double, CountsForThreshold> counts = new LinkedHashMap<Double, CountsForThreshold>();

    /**
     * Creates an ROC evaluation with {@code thresholdSteps + 1} evenly spaced thresholds in [0, 1].
     *
     * @param thresholdSteps number of equal-width steps between threshold 0 and threshold 1
     * @throws IllegalArgumentException if {@code thresholdSteps < 1}
     */
    public ROC(int thresholdSteps) {
        if (thresholdSteps < 1) {
            // 0 would produce a NaN threshold (0 * Infinity); negatives would produce no thresholds
            // and a NegativeArraySizeException in getResultsAsArray().
            throw new IllegalArgumentException("thresholdSteps must be >= 1, got " + thresholdSteps);
        }
        this.thresholdSteps = thresholdSteps;
        double step = 1.0 / (double)thresholdSteps;
        for (int i = 0; i <= thresholdSteps; ++i) {
            double currThreshold = (double)i * step;
            this.counts.put(currThreshold, new CountsForThreshold(currThreshold));
        }
    }

    /**
     * Accumulates true/false positive counts for every threshold from one batch.
     *
     * <p>If both arrays are rank 3, evaluation is delegated to the inherited
     * {@code evalTimeSeries} and this method returns immediately. Otherwise both arrays must be
     * rank 2 with the same number of columns (1 or 2). With one column, the label is the
     * positive-class indicator and {@code 1 - label} the negative-class indicator. With two
     * columns, column 1 is the positive class and column 0 the negative class.
     *
     * <p>For a threshold {@code t}, an example is counted as predicted-positive when
     * {@code prediction >= t} and {@code 1.0 > t}. At threshold 1.0 nothing is predicted positive.
     *
     * @param labels      label array
     * @param predictions prediction array with the same layout as {@code labels}
     * @throws IllegalArgumentException if the shapes do not satisfy the constraints above
     */
    @Override
    public void eval(INDArray labels, INDArray predictions) {
        if (labels.rank() == 3 && predictions.rank() == 3) {
            // Without this return the rank check below would reject the (already evaluated)
            // 3d arrays and throw.
            this.evalTimeSeries(labels, predictions);
            return;
        }
        if (labels.rank() > 2 || predictions.rank() > 2 || labels.size(1) != predictions.size(1) || labels.size(1) > 2) {
            throw new IllegalArgumentException("Invalid input data shape: labels shape = " + Arrays.toString(labels.shape()) + ", predictions shape = " + Arrays.toString(predictions.shape()) + "; require rank 2 array with size(1) == 1 or 2");
        }
        INDArray positiveActualClassColumn;
        INDArray negativeActualClassColumn;
        INDArray positivePredictedClassColumn;
        if (labels.size(1) == 1) {
            positiveActualClassColumn = labels;
            negativeActualClassColumn = labels.rsub((Number)1.0);
            positivePredictedClassColumn = predictions;
        } else {
            positiveActualClassColumn = labels.getColumn(1);
            negativeActualClassColumn = labels.getColumn(0);
            positivePredictedClassColumn = predictions.getColumn(1);
        }
        // longValue(): the totals are long, and intValue() would silently overflow.
        this.countActualPositive += positiveActualClassColumn.sumNumber().longValue();
        this.countActualNegative += negativeActualClassColumn.sumNumber().longValue();
        for (Map.Entry<Double, CountsForThreshold> entry : this.counts.entrySet()) {
            double currThreshold = entry.getKey();
            Condition condGeq = Conditions.greaterThanOrEqual((Number)currThreshold);
            Condition condLeq = Conditions.lessThanOrEqual((Number)currThreshold);
            // Binarize on a copy so the caller's predictions are not modified: values >= t become
            // 1.0, then anything still <= t (including the 1.0s when t == 1.0) becomes 0.0.
            CompareAndSet op = new CompareAndSet(positivePredictedClassColumn.dup(), 1.0, condGeq);
            INDArray predictedClass1 = Nd4j.getExecutioner().execAndReturn((Op)op);
            op = new CompareAndSet(predictedClass1, 0.0, condLeq);
            predictedClass1 = Nd4j.getExecutioner().execAndReturn((Op)op);
            INDArray isTruePositive = predictedClass1.mul(positiveActualClassColumn);
            INDArray isFalsePositive = predictedClass1.mul(negativeActualClassColumn);
            CountsForThreshold thresholdCounts = entry.getValue();
            thresholdCounts.incrementTruePositive(isTruePositive.sumNumber().longValue());
            thresholdCounts.incrementFalsePositive(isFalsePositive.sumNumber().longValue());
        }
    }

    /**
     * Returns one ROC point per threshold, in increasing threshold order.
     *
     * <p>TPR is {@code TP / actualPositives} and FPR is {@code FP / actualNegatives}. If no
     * actual positives (or negatives) have been seen, the corresponding rate is {@code NaN}.
     *
     * @return list of (threshold, TPR, FPR) values
     */
    public List<ROCValue> getResults() {
        ArrayList<ROCValue> out = new ArrayList<ROCValue>(this.counts.size());
        for (Map.Entry<Double, CountsForThreshold> entry : this.counts.entrySet()) {
            double t = entry.getKey();
            CountsForThreshold c = entry.getValue();
            double tpr = (double)c.getCountTruePositive() / (double)this.countActualPositive;
            double fpr = (double)c.getCountFalsePositive() / (double)this.countActualNegative;
            out.add(new ROCValue(t, tpr, fpr));
        }
        return out;
    }

    /**
     * Returns one precision/recall point per threshold, in increasing threshold order.
     *
     * <p>Precision is defined as 1.0 when there are no predicted positives (TP + FP == 0). Recall
     * is defined as 1.0 when no actual positives have been seen.
     *
     * @return list of (threshold, precision, recall) points
     */
    public List<PrecisionRecallPoint> getPrecisionRecallCurve() {
        ArrayList<PrecisionRecallPoint> out = new ArrayList<PrecisionRecallPoint>(this.counts.size());
        for (CountsForThreshold c : this.counts.values()) {
            long tpCount = c.getCountTruePositive();
            long fpCount = c.getCountFalsePositive();
            double precision = tpCount == 0L && fpCount == 0L ? 1.0 : (double)tpCount / (double)(tpCount + fpCount);
            double recall = this.countActualPositive == 0L ? 1.0 : (double)tpCount / (double)this.countActualPositive;
            out.add(new PrecisionRecallPoint(c.getThreshold(), precision, recall));
        }
        return out;
    }

    /**
     * Returns the ROC curve as a {@code [2][thresholdSteps + 1]} array.
     *
     * @return {@code out[0]} holds the false positive rates and {@code out[1]} the true positive
     *         rates, both in increasing threshold order (same NaN rules as {@link #getResults()})
     */
    public double[][] getResultsAsArray() {
        double[][] out = new double[2][this.thresholdSteps + 1];
        int i = 0;
        for (CountsForThreshold c : this.counts.values()) {
            out[0][i] = (double)c.getCountFalsePositive() / (double)this.countActualNegative;
            out[1][i] = (double)c.getCountTruePositive() / (double)this.countActualPositive;
            ++i;
        }
        return out;
    }

    /**
     * Approximates the area under the ROC curve using the trapezoidal rule over consecutive
     * threshold points.
     *
     * @return the AUC estimate; {@code NaN} if any rate is {@code NaN}
     */
    public double calculateAUC() {
        List<ROCValue> list = this.getResults();
        double auc = 0.0;
        for (int i = 0; i < list.size() - 1; ++i) {
            ROCValue left = list.get(i);
            ROCValue right = list.get(i + 1);
            // FPR decreases as the threshold increases, hence the absolute width.
            double deltaX = Math.abs(right.getFalsePositiveRate() - left.getFalsePositiveRate());
            double avg = (left.getTruePositiveRate() + right.getTruePositiveRate()) / 2.0;
            auc += deltaX * avg;
        }
        return auc;
    }

    /**
     * Adds the counts of {@code other} into this evaluation.
     *
     * @param other evaluation with the same number of threshold steps
     * @throws UnsupportedOperationException if the threshold step counts differ
     */
    @Override
    public void merge(ROC other) {
        if (this.thresholdSteps != other.thresholdSteps) {
            throw new UnsupportedOperationException("Cannot merge ROC instances with different numbers of threshold steps (" + this.thresholdSteps + " vs. " + other.thresholdSteps + ")");
        }
        this.countActualPositive += other.countActualPositive;
        this.countActualNegative += other.countActualNegative;
        // Equal thresholdSteps means both maps were built with identical threshold keys.
        for (Map.Entry<Double, CountsForThreshold> entry : this.counts.entrySet()) {
            CountsForThreshold otherCft = other.counts.get(entry.getKey());
            CountsForThreshold cft = entry.getValue();
            cft.incrementTruePositive(otherCft.countTruePositive);
            cft.incrementFalsePositive(otherCft.countFalsePositive);
        }
    }

    /** @return the number of threshold steps this evaluation was created with */
    public int getThresholdSteps() {
        return this.thresholdSteps;
    }

    /** @return total (summed) positive-class label value seen so far */
    public long getCountActualPositive() {
        return this.countActualPositive;
    }

    /** @return total (summed) negative-class label value seen so far */
    public long getCountActualNegative() {
        return this.countActualNegative;
    }

    /**
     * Returns the per-threshold counts.
     *
     * @return the live internal map (not a copy); modifying it modifies this evaluation
     */
    public Map<Double, CountsForThreshold> getCounts() {
        return this.counts;
    }

    /** Mutable true/false positive counts accumulated for a single threshold. */
    public static class CountsForThreshold
    implements Serializable,
    Cloneable {
        private double threshold;
        private long countTruePositive;
        private long countFalsePositive;

        /**
         * Creates zeroed counts for a threshold.
         *
         * @param threshold the classification threshold
         */
        public CountsForThreshold(double threshold) {
            this(threshold, 0L, 0L);
        }

        /**
         * Creates counts with explicit initial values.
         *
         * @param threshold          the classification threshold
         * @param countTruePositive  initial true positive count
         * @param countFalsePositive initial false positive count
         */
        @ConstructorProperties(value={"threshold", "countTruePositive", "countFalsePositive"})
        public CountsForThreshold(double threshold, long countTruePositive, long countFalsePositive) {
            this.threshold = threshold;
            this.countTruePositive = countTruePositive;
            this.countFalsePositive = countFalsePositive;
        }

        /** @param count amount to add to the true positive count */
        public void incrementTruePositive(long count) {
            this.countTruePositive += count;
        }

        /** @param count amount to add to the false positive count */
        public void incrementFalsePositive(long count) {
            this.countFalsePositive += count;
        }

        /** @return an independent copy with the same threshold and counts */
        @Override
        public CountsForThreshold clone() {
            return new CountsForThreshold(this.threshold, this.countTruePositive, this.countFalsePositive);
        }

        public double getThreshold() {
            return this.threshold;
        }

        public long getCountTruePositive() {
            return this.countTruePositive;
        }

        public long getCountFalsePositive() {
            return this.countFalsePositive;
        }

        public void setThreshold(double threshold) {
            this.threshold = threshold;
        }

        public void setCountTruePositive(long countTruePositive) {
            this.countTruePositive = countTruePositive;
        }

        public void setCountFalsePositive(long countFalsePositive) {
            this.countFalsePositive = countFalsePositive;
        }

        @Override
        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (!(o instanceof CountsForThreshold)) {
                return false;
            }
            CountsForThreshold other = (CountsForThreshold)o;
            if (!other.canEqual(this)) {
                return false;
            }
            if (Double.compare(this.getThreshold(), other.getThreshold()) != 0) {
                return false;
            }
            if (this.getCountTruePositive() != other.getCountTruePositive()) {
                return false;
            }
            return this.getCountFalsePositive() == other.getCountFalsePositive();
        }

        protected boolean canEqual(Object other) {
            return other instanceof CountsForThreshold;
        }

        @Override
        public int hashCode() {
            final int prime = 59;
            int result = 1;
            long thresholdBits = Double.doubleToLongBits(this.getThreshold());
            result = result * prime + (int)(thresholdBits >>> 32 ^ thresholdBits);
            long tp = this.getCountTruePositive();
            result = result * prime + (int)(tp >>> 32 ^ tp);
            long fp = this.getCountFalsePositive();
            result = result * prime + (int)(fp >>> 32 ^ fp);
            return result;
        }

        @Override
        public String toString() {
            return "ROC.CountsForThreshold(threshold=" + this.getThreshold() + ", countTruePositive=" + this.getCountTruePositive() + ", countFalsePositive=" + this.getCountFalsePositive() + ")";
        }
    }

    /** Immutable (threshold, precision, recall) point of a precision/recall curve. */
    public static class PrecisionRecallPoint {
        private final double classiferThreshold;
        private final double precision;
        private final double recall;

        @ConstructorProperties(value={"classiferThreshold", "precision", "recall"})
        public PrecisionRecallPoint(double classiferThreshold, double precision, double recall) {
            this.classiferThreshold = classiferThreshold;
            this.precision = precision;
            this.recall = recall;
        }

        public double getClassiferThreshold() {
            return this.classiferThreshold;
        }

        public double getPrecision() {
            return this.precision;
        }

        public double getRecall() {
            return this.recall;
        }

        @Override
        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (!(o instanceof PrecisionRecallPoint)) {
                return false;
            }
            PrecisionRecallPoint other = (PrecisionRecallPoint)o;
            if (!other.canEqual(this)) {
                return false;
            }
            if (Double.compare(this.getClassiferThreshold(), other.getClassiferThreshold()) != 0) {
                return false;
            }
            if (Double.compare(this.getPrecision(), other.getPrecision()) != 0) {
                return false;
            }
            return Double.compare(this.getRecall(), other.getRecall()) == 0;
        }

        protected boolean canEqual(Object other) {
            return other instanceof PrecisionRecallPoint;
        }

        @Override
        public int hashCode() {
            final int prime = 59;
            int result = 1;
            long thresholdBits = Double.doubleToLongBits(this.getClassiferThreshold());
            result = result * prime + (int)(thresholdBits >>> 32 ^ thresholdBits);
            long precisionBits = Double.doubleToLongBits(this.getPrecision());
            result = result * prime + (int)(precisionBits >>> 32 ^ precisionBits);
            long recallBits = Double.doubleToLongBits(this.getRecall());
            result = result * prime + (int)(recallBits >>> 32 ^ recallBits);
            return result;
        }

        @Override
        public String toString() {
            return "ROC.PrecisionRecallPoint(classiferThreshold=" + this.getClassiferThreshold() + ", precision=" + this.getPrecision() + ", recall=" + this.getRecall() + ")";
        }
    }

    /** Immutable (threshold, true positive rate, false positive rate) point of an ROC curve. */
    public static class ROCValue {
        private final double threshold;
        private final double truePositiveRate;
        private final double falsePositiveRate;

        @ConstructorProperties(value={"threshold", "truePositiveRate", "falsePositiveRate"})
        public ROCValue(double threshold, double truePositiveRate, double falsePositiveRate) {
            this.threshold = threshold;
            this.truePositiveRate = truePositiveRate;
            this.falsePositiveRate = falsePositiveRate;
        }

        public double getThreshold() {
            return this.threshold;
        }

        public double getTruePositiveRate() {
            return this.truePositiveRate;
        }

        public double getFalsePositiveRate() {
            return this.falsePositiveRate;
        }

        @Override
        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (!(o instanceof ROCValue)) {
                return false;
            }
            ROCValue other = (ROCValue)o;
            if (!other.canEqual(this)) {
                return false;
            }
            if (Double.compare(this.getThreshold(), other.getThreshold()) != 0) {
                return false;
            }
            if (Double.compare(this.getTruePositiveRate(), other.getTruePositiveRate()) != 0) {
                return false;
            }
            return Double.compare(this.getFalsePositiveRate(), other.getFalsePositiveRate()) == 0;
        }

        protected boolean canEqual(Object other) {
            return other instanceof ROCValue;
        }

        @Override
        public int hashCode() {
            final int prime = 59;
            int result = 1;
            long thresholdBits = Double.doubleToLongBits(this.getThreshold());
            result = result * prime + (int)(thresholdBits >>> 32 ^ thresholdBits);
            long tprBits = Double.doubleToLongBits(this.getTruePositiveRate());
            result = result * prime + (int)(tprBits >>> 32 ^ tprBits);
            long fprBits = Double.doubleToLongBits(this.getFalsePositiveRate());
            result = result * prime + (int)(fprBits >>> 32 ^ fprBits);
            return result;
        }

        @Override
        public String toString() {
            return "ROC.ROCValue(threshold=" + this.getThreshold() + ", truePositiveRate=" + this.getTruePositiveRate() + ", falsePositiveRate=" + this.getFalsePositiveRate() + ")";
        }
    }
}

