/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.eval;

import java.io.Serializable;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import org.deeplearning4j.berkeley.Counter;
import org.deeplearning4j.eval.ConfusionMatrix;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Accumulates classification statistics and derives summary scores from them.
 *
 * <p>The statistics are a confusion matrix plus per-class true/false positive/negative
 * counts. From them the class derives accuracy, (macro-averaged) precision, recall and F1.
 *
 * <p>For matrix input, the actual and predicted class of each row is the column index of the
 * row's maximum value (ties resolve to the lowest index).
 *
 * <p>No synchronization is performed; all counters are plain mutable fields.
 */
public class Evaluation<T extends Comparable<? super T>>
implements Serializable {
    private Counter<Integer> truePositives = new Counter<Integer>();
    private Counter<Integer> falsePositives = new Counter<Integer>();
    private Counter<Integer> trueNegatives = new Counter<Integer>();
    private Counter<Integer> falseNegatives = new Counter<Integer>();
    private ConfusionMatrix<Integer> confusion;
    private int numRowCounter = 0;
    private List<Integer> labelsList = new ArrayList<Integer>();
    private Map<Integer, String> labelsMap = new HashMap<Integer, String>();
    private static final Logger log = LoggerFactory.getLogger(Evaluation.class);
    /** Value returned by the single-argument precision/recall when the ratio is 0/0. */
    private static final double DEFAULT_EDGE_VALUE = 0.0;

    /**
     * Creates an evaluator with no labels.
     *
     * <p>The confusion matrix is created on the first call to
     * {@link #eval(INDArray, INDArray)}.
     */
    public Evaluation() {
    }

    /**
     * Creates an evaluator for classes {@code 0 .. numClasses-1}.
     *
     * <p>The confusion matrix is initialised immediately, so {@link #eval(int, int)} can be
     * used right away.
     *
     * @param numClasses number of classes
     */
    public Evaluation(int numClasses) {
        for (int i = 0; i < numClasses; ++i) {
            this.labelsList.add(i);
        }
        this.confusion = new ConfusionMatrix<Integer>(this.labelsList);
    }

    /**
     * Creates an evaluator whose class {@code i} is displayed as {@code labels.get(i)}.
     *
     * <p>The confusion matrix is not initialised here; it is created on the first call to
     * {@link #eval(INDArray, INDArray)}.
     *
     * @param labels display names, indexed by class
     */
    public Evaluation(List<String> labels) {
        int i = 0;
        for (String label : labels) {
            this.labelsMap.put(i, label);
            ++i;
        }
    }

    /**
     * Creates an evaluator using the given class-index to display-name map.
     *
     * <p>The map is stored by reference, not copied. The confusion matrix is created on the
     * first call to {@link #eval(INDArray, INDArray)}.
     *
     * @param labels display names keyed by class index
     */
    public Evaluation(Map<Integer, String> labels) {
        this.labelsMap = labels;
    }

    /**
     * Evaluates a mini-batch.
     *
     * <p>Each row of {@code realOutcomes} and {@code guesses} is reduced to its arg-max
     * column. The resulting (actual, predicted) pair is recorded.
     *
     * <p>If no confusion matrix exists yet, one is created with one class per column of
     * {@code realOutcomes}.
     *
     * @param realOutcomes labels, one row per example
     * @param guesses      network output, one row per example
     * @throws IllegalArgumentException if the two arrays differ in length; no state is
     *                                  modified in that case
     */
    public void eval(INDArray realOutcomes, INDArray guesses) {
        // Validate before touching any state so a rejected call leaves the evaluator intact.
        if (realOutcomes.length() != guesses.length()) {
            throw new IllegalArgumentException("Unable to evaluate. Outcome matrices not same length");
        }
        if (this.confusion == null) {
            log.warn("Creating confusion matrix based on classes passed in . Will assume the label distribution passed in is indicative of the overall dataset");
            List<Integer> classes = new ArrayList<Integer>();
            for (int i = 0; i < realOutcomes.columns(); ++i) {
                classes.add(i);
            }
            this.confusion = new ConfusionMatrix<Integer>(classes);
        }
        this.numRowCounter += realOutcomes.shape()[0];
        for (int i = 0; i < realOutcomes.rows(); ++i) {
            int actual = argMax(realOutcomes.getRow(i));
            int predicted = argMax(guesses.getRow(i));
            this.countPrediction(actual, predicted);
        }
    }

    /**
     * Evaluates time-series output.
     *
     * <p>Rank-2 input is passed straight to {@link #eval(INDArray, INDArray)}. Rank-3 input
     * (dimension 1 holding the class scores) is permuted and reshaped so that every
     * (example, time step) pair becomes one row. That 2-D input is then evaluated.
     *
     * @param labels    labels, rank 2 or rank 3
     * @param predicted predictions, same shape as {@code labels}
     * @throws IllegalArgumentException if {@code labels} is neither rank 2 nor rank 3, or the
     *                                  shapes differ
     */
    public void evalTimeSeries(INDArray labels, INDArray predicted) {
        if (labels.rank() == 2 && predicted.rank() == 2) {
            this.eval(labels, predicted);
            // Without this return the rank-3 check below would always reject 2-D input.
            return;
        }
        if (labels.rank() != 3) {
            throw new IllegalArgumentException("Invalid input: labels are not rank 3 (rank=" + labels.rank() + ")");
        }
        if (!Arrays.equals(labels.shape(), predicted.shape())) {
            throw new IllegalArgumentException("Labels and predicted have different shapes: labels=" + Arrays.toString(labels.shape()) + ", predicted=" + Arrays.toString(predicted.shape()));
        }
        if (labels.ordering() == 'f') {
            labels = Shape.toOffsetZeroCopy((INDArray)labels, (char)'c');
        }
        if (predicted.ordering() == 'f') {
            predicted = Shape.toOffsetZeroCopy((INDArray)predicted, (char)'c');
        }
        int[] shape = labels.shape();
        labels = labels.permute(new int[]{0, 2, 1});
        labels = labels.reshape(shape[0] * shape[2], shape[1]);
        predicted = predicted.permute(new int[]{0, 2, 1});
        predicted = predicted.reshape(shape[0] * shape[2], shape[1]);
        this.eval(labels, predicted);
    }

    /**
     * Evaluates rank-3 time-series output, considering only unmasked steps.
     *
     * <p>A step is unmasked when {@code outputMask.getDouble(ex, t) != 0}. The mask is
     * indexed as {@code [example, timeStep]}; the labels and predictions are indexed as
     * {@code [example, class, timeStep]}.
     *
     * @param labels     rank-3 labels
     * @param predicted  rank-3 predictions
     * @param outputMask rank-2 mask, or {@code null} to evaluate every step
     */
    public void evalTimeSeries(INDArray labels, INDArray predicted, INDArray outputMask) {
        if (outputMask == null) {
            this.evalTimeSeries(labels, predicted);
            return;
        }
        int numExamples = outputMask.size(0);
        int numSteps = outputMask.size(1);
        // Count with the same "!= 0" test used when copying rows below. Summing the mask
        // would mis-size the buffers for masks that are not strictly 0/1.
        int totalOutputExamples = 0;
        for (int ex = 0; ex < numExamples; ++ex) {
            for (int t = 0; t < numSteps; ++t) {
                if (outputMask.getDouble(ex, t) != 0.0) {
                    ++totalOutputExamples;
                }
            }
        }
        if (totalOutputExamples == 0) {
            // Everything is masked out: nothing to evaluate.
            return;
        }
        int outSize = labels.size(1);
        INDArray labels2d = Nd4j.create((int)totalOutputExamples, (int)outSize);
        INDArray predicted2d = Nd4j.create((int)totalOutputExamples, (int)outSize);
        int rowCount = 0;
        for (int ex = 0; ex < numExamples; ++ex) {
            for (int t = 0; t < numSteps; ++t) {
                if (outputMask.getDouble(ex, t) == 0.0) continue;
                labels2d.putRow(rowCount, labels.get(new INDArrayIndex[]{NDArrayIndex.point((int)ex), NDArrayIndex.all(), NDArrayIndex.point((int)t)}));
                predicted2d.putRow(rowCount, predicted.get(new INDArrayIndex[]{NDArrayIndex.point((int)ex), NDArrayIndex.all(), NDArrayIndex.point((int)t)}));
                ++rowCount;
            }
        }
        this.eval(labels2d, predicted2d);
    }

    /**
     * Records a single example.
     *
     * @param predictedIdx class predicted by the model
     * @param actualIdx    true class
     * @throws UnsupportedOperationException if the confusion matrix has not been initialised
     *                                       (see {@link #Evaluation(int)})
     */
    public void eval(int predictedIdx, int actualIdx) {
        if (this.confusion == null) {
            throw new UnsupportedOperationException("Cannot evaluate single example without initializing confusion matrix first");
        }
        ++this.numRowCounter;
        // countPrediction takes (actual, predicted), matching addToConfusion(real, guess).
        this.countPrediction(actualIdx, predictedIdx);
    }

    /**
     * Returns the column index of the first maximal entry of {@code row}.
     *
     * <p>Only strictly greater values replace the current maximum. Ties therefore keep the
     * lowest index, and NaN entries never win.
     */
    private static int argMax(INDArray row) {
        double max = row.getDouble(0);
        int maxIdx = 0;
        for (int col = 1; col < row.columns(); ++col) {
            double value = row.getDouble(col);
            if (value > max) {
                max = value;
                maxIdx = col;
            }
        }
        return maxIdx;
    }

    /**
     * Updates the confusion matrix and TP/FP/FN/TN counters for one (actual, predicted) pair.
     *
     * <p>Every class that is neither the actual nor the predicted class gets one true negative.
     */
    private void countPrediction(int actual, int predicted) {
        this.addToConfusion(actual, predicted);
        if (actual == predicted) {
            this.incrementTruePositives(predicted);
        } else {
            this.incrementFalseNegatives(actual);
            this.incrementFalsePositives(predicted);
        }
        for (Integer clazz : this.confusion.getClasses()) {
            if (clazz != actual && clazz != predicted) {
                this.trueNegatives.incrementCount(clazz, 1.0);
            }
        }
    }

    /**
     * Builds a human-readable report.
     *
     * <p>The report contains:
     * <ul>
     *   <li>one line per non-zero confusion-matrix cell;</li>
     *   <li>warnings for classes excluded from the averaged precision or recall;</li>
     *   <li>accuracy, precision, recall and F1, formatted to at most four decimal places.</li>
     * </ul>
     *
     * <p>Requires that the confusion matrix has been initialised.
     */
    public String stats() {
        StringBuilder builder = new StringBuilder().append("\n");
        StringBuilder warnings = new StringBuilder();
        List<Integer> classes = this.confusion.getClasses();
        for (Integer clazz : classes) {
            String actual = this.resolveLabelForClass(clazz);
            for (Integer clazz2 : classes) {
                int count = this.confusion.getCount(clazz, clazz2);
                if (count == 0) continue;
                String predictedLabel = this.resolveLabelForClass(clazz2);
                builder.append(String.format("Examples labeled as %s classified by model as %s: %d times\n", actual, predictedLabel, count));
            }
            // A class with no true positives may have an undefined (0/0) precision or recall.
            if (this.truePositives.getCount(clazz) == 0.0) {
                if (this.falsePositives.getCount(clazz) == 0.0) {
                    warnings.append(String.format("Warning: class %s was never predicted by the model. This class was excluded from the average precision\n", actual));
                }
                if (this.falseNegatives.getCount(clazz) == 0.0) {
                    warnings.append(String.format("Warning: class %s has never appeared as a true label. This class was excluded from the average recall\n", actual));
                }
            }
        }
        builder.append("\n");
        builder.append(warnings);
        DecimalFormat df = new DecimalFormat("#.####");
        builder.append("\n==========================Scores========================================");
        builder.append("\n Accuracy:  ").append(df.format(this.accuracy()));
        builder.append("\n Precision: ").append(df.format(this.precision()));
        builder.append("\n Recall:    ").append(df.format(this.recall()));
        builder.append("\n F1 Score:  ").append(df.format(this.f1()));
        builder.append("\n========================================================================");
        return builder.toString();
    }

    /** Returns the display label for {@code clazz}, falling back to its number if unnamed. */
    private String resolveLabelForClass(Integer clazz) {
        String label = this.labelsMap.get(clazz);
        if (label == null || label.isEmpty()) {
            label = clazz.toString();
        }
        return label;
    }

    /**
     * Returns the precision of one class, or {@code 0.0} if it was never predicted.
     *
     * @param classLabel class index
     */
    public double precision(Integer classLabel) {
        return this.precision(classLabel, DEFAULT_EDGE_VALUE);
    }

    /**
     * Returns the precision TP / (TP + FP) of one class.
     *
     * @param classLabel class index
     * @param edgeCase   value returned when TP and FP are both zero
     */
    public double precision(Integer classLabel, double edgeCase) {
        double tpCount = this.truePositives.getCount(classLabel);
        double fpCount = this.falsePositives.getCount(classLabel);
        if (tpCount == 0.0 && fpCount == 0.0) {
            return edgeCase;
        }
        return tpCount / (tpCount + fpCount);
    }

    /**
     * Returns the macro-averaged precision over all classes whose precision is defined.
     *
     * <p>Returns NaN if no class has a defined precision.
     */
    public double precision() {
        double precisionAcc = 0.0;
        int classCount = 0;
        for (Integer classLabel : this.confusion.getClasses()) {
            // -1 is a safe sentinel: a real precision is always within [0, 1].
            double precision = this.precision(classLabel, -1.0);
            if (precision == -1.0) continue;
            precisionAcc += precision;
            ++classCount;
        }
        return precisionAcc / (double)classCount;
    }

    /**
     * Returns the recall of one class, or {@code 0.0} if it never appeared as a label.
     *
     * @param classLabel class index
     */
    public double recall(Integer classLabel) {
        return this.recall(classLabel, DEFAULT_EDGE_VALUE);
    }

    /**
     * Returns the recall TP / (TP + FN) of one class.
     *
     * @param classLabel class index
     * @param edgeCase   value returned when TP and FN are both zero
     */
    public double recall(Integer classLabel, double edgeCase) {
        double tpCount = this.truePositives.getCount(classLabel);
        double fnCount = this.falseNegatives.getCount(classLabel);
        if (tpCount == 0.0 && fnCount == 0.0) {
            return edgeCase;
        }
        return tpCount / (tpCount + fnCount);
    }

    /**
     * Returns the macro-averaged recall over all classes whose recall is defined.
     *
     * <p>Returns NaN if no class has a defined recall.
     */
    public double recall() {
        double recallAcc = 0.0;
        int classCount = 0;
        for (Integer classLabel : this.confusion.getClasses()) {
            // -1 is a safe sentinel: a real recall is always within [0, 1].
            double recall = this.recall(classLabel, -1.0);
            if (recall == -1.0) continue;
            recallAcc += recall;
            ++classCount;
        }
        return recallAcc / (double)classCount;
    }

    /**
     * Returns the F1 score of one class.
     *
     * <p>This is the harmonic mean of its precision and recall, or {@code 0.0} if either is
     * zero.
     */
    public double f1(Integer classLabel) {
        double precision = this.precision(classLabel);
        double recall = this.recall(classLabel);
        if (precision == 0.0 || recall == 0.0) {
            return 0.0;
        }
        return 2.0 * (precision * recall / (precision + recall));
    }

    /**
     * Returns the harmonic mean of the macro-averaged precision and recall.
     *
     * <p>Returns {@code 0.0} if either average is zero.
     */
    public double f1() {
        double precision = this.precision();
        double recall = this.recall();
        if (precision == 0.0 || recall == 0.0) {
            return 0.0;
        }
        return 2.0 * (precision * recall / (precision + recall));
    }

    /** Returns total true positives divided by rows evaluated (NaN if nothing was evaluated). */
    public double accuracy() {
        return this.truePositives() / this.getNumRowCounter();
    }

    /** Returns the true-positive count summed over all classes. */
    public double truePositives() {
        return this.truePositives.totalCount();
    }

    /** Returns the true-negative count summed over all classes. */
    public double trueNegatives() {
        return this.trueNegatives.totalCount();
    }

    /** Returns the false-positive count summed over all classes. */
    public double falsePositives() {
        return this.falsePositives.totalCount();
    }

    /** Returns the false-negative count summed over all classes. */
    public double falseNegatives() {
        return this.falseNegatives.totalCount();
    }

    /** Returns total true negatives plus total false positives. */
    public double negative() {
        return this.trueNegatives() + this.falsePositives();
    }

    /** Returns total true positives plus total false negatives. */
    public double positive() {
        return this.truePositives() + this.falseNegatives();
    }

    /** Adds one true positive for {@code classLabel}. */
    public void incrementTruePositives(Integer classLabel) {
        this.truePositives.incrementCount(classLabel, 1.0);
    }

    /** Adds one true negative for {@code classLabel}. */
    public void incrementTrueNegatives(Integer classLabel) {
        this.trueNegatives.incrementCount(classLabel, 1.0);
    }

    /** Adds one false negative for {@code classLabel}. */
    public void incrementFalseNegatives(Integer classLabel) {
        this.falseNegatives.incrementCount(classLabel, 1.0);
    }

    /** Adds one false positive for {@code classLabel}. */
    public void incrementFalsePositives(Integer classLabel) {
        this.falsePositives.incrementCount(classLabel, 1.0);
    }

    /**
     * Records one (real, guess) pair in the confusion matrix.
     *
     * @param real  true class
     * @param guess predicted class
     */
    public void addToConfusion(Integer real, Integer guess) {
        this.confusion.add(real, guess);
    }

    /** Returns the number of examples whose true class is {@code clazz}. */
    public int classCount(Integer clazz) {
        return this.confusion.getActualTotal(clazz);
    }

    /** Returns the number of rows/examples evaluated so far. */
    public double getNumRowCounter() {
        return this.numRowCounter;
    }

    /** Returns the display name registered for {@code clazz}, or {@code null} if none. */
    public String getClassLabel(Integer clazz) {
        return this.labelsMap.get(clazz);
    }

    /** Returns the underlying confusion matrix (may be {@code null} before any evaluation). */
    public ConfusionMatrix<Integer> getConfusionMatrix() {
        return this.confusion;
    }
}

