package org.deeplearning4j.eval.curves;

import com.google.common.base.Preconditions;
import java.util.Arrays;
import org.nd4j.shade.jackson.annotation.JsonProperty;

/**
 * Precision-recall curve. For each threshold it stores the precision and recall achieved there,
 * the true-positive / false-positive / false-negative counts at that threshold, and the total
 * number of examples (from which true negatives are derived).
 *
 * <p>When plotted, recall is the X axis ({@link #getX()}) and precision is the Y axis
 * ({@link #getY()}). The area under the curve is computed lazily by {@link #calculateAUPRC()} and
 * cached. The cache is cleared whenever the threshold, precision or recall arrays are replaced
 * through their setters.
 *
 * <p>NOTE(review): threshold lookups use {@link Arrays#binarySearch(double[], double)}, so the
 * {@code threshold} array is assumed to be sorted ascending. This class does not validate that.
 */
public class PrecisionRecallCurve extends BaseCurve {
    /** Threshold of each curve point (assumed ascending, see class docs). */
    private double[] threshold;
    /** Precision at each point; index-aligned with {@code threshold}. */
    private double[] precision;
    /** Recall at each point; index-aligned with {@code threshold}. */
    private double[] recall;
    /** True-positive count at each point. */
    private int[] tpCount;
    /** False-positive count at each point. */
    private int[] fpCount;
    /** False-negative count at each point. */
    private int[] fnCount;
    /** Total number of examples; TN at a point is {@code totalCount - (TP + FP + FN)}. */
    private int totalCount;
    /** Cached area under the curve, or {@code null} if not yet computed. Excluded from equals/hashCode. */
    private Double area;

    /** Confusion-matrix counts at a single point of the curve. */
    public static class Confusion {
        private final Point point;
        private final int tpCount;
        private final int fpCount;
        private final int fnCount;
        private final int tnCount;

        /**
         * @param point   the curve point these counts belong to
         * @param tpCount true positives at that point
         * @param fpCount false positives at that point
         * @param fnCount false negatives at that point
         * @param tnCount true negatives at that point
         */
        public Confusion(Point point, int tpCount, int fpCount, int fnCount, int tnCount) {
            this.point = point;
            this.tpCount = tpCount;
            this.fpCount = fpCount;
            this.fnCount = fnCount;
            this.tnCount = tnCount;
        }

        public Point getPoint() {
            return this.point;
        }

        public int getTpCount() {
            return this.tpCount;
        }

        public int getFpCount() {
            return this.fpCount;
        }

        public int getFnCount() {
            return this.fnCount;
        }

        public int getTnCount() {
            return this.tnCount;
        }

        @Override
        public boolean equals(Object obj) {
            if (obj == this) {
                return true;
            }
            if (!(obj instanceof Confusion)) {
                return false;
            }
            Confusion other = (Confusion) obj;
            if (!other.canEqual(this)) {
                return false;
            }
            Point thisPoint = getPoint();
            Point otherPoint = other.getPoint();
            boolean samePoint = thisPoint == null ? otherPoint == null : thisPoint.equals(otherPoint);
            return samePoint
                    && getTpCount() == other.getTpCount()
                    && getFpCount() == other.getFpCount()
                    && getFnCount() == other.getFnCount()
                    && getTnCount() == other.getTnCount();
        }

        /** Symmetric-equality hook so subclasses can opt out of equality with this type. */
        protected boolean canEqual(Object obj) {
            return obj instanceof Confusion;
        }

        @Override
        public int hashCode() {
            Point p = getPoint();
            int result = 1;
            result = result * 59 + (p == null ? 43 : p.hashCode());
            result = result * 59 + getTpCount();
            result = result * 59 + getFpCount();
            result = result * 59 + getFnCount();
            result = result * 59 + getTnCount();
            return result;
        }

        @Override
        public String toString() {
            return "PrecisionRecallCurve.Confusion(point=" + getPoint() + ", tpCount=" + getTpCount() + ", fpCount=" + getFpCount() + ", fnCount=" + getFnCount() + ", tnCount=" + getTnCount() + ")";
        }
    }

    /** A single point of the curve: its index plus threshold, precision and recall at that index. */
    public static class Point {
        private final int idx;
        private final double threshold;
        private final double precision;
        private final double recall;

        /**
         * @param idx       index of this point within the curve arrays
         * @param threshold threshold at that index
         * @param precision precision at that index
         * @param recall    recall at that index
         */
        public Point(int idx, double threshold, double precision, double recall) {
            this.idx = idx;
            this.threshold = threshold;
            this.precision = precision;
            this.recall = recall;
        }

        public int getIdx() {
            return this.idx;
        }

        public double getThreshold() {
            return this.threshold;
        }

        public double getPrecision() {
            return this.precision;
        }

        public double getRecall() {
            return this.recall;
        }

        @Override
        public boolean equals(Object obj) {
            if (obj == this) {
                return true;
            }
            if (!(obj instanceof Point)) {
                return false;
            }
            Point other = (Point) obj;
            // Double.compare gives bitwise-style equality (NaN == NaN, 0.0 != -0.0), consistent with hashCode.
            return other.canEqual(this)
                    && getIdx() == other.getIdx()
                    && Double.compare(getThreshold(), other.getThreshold()) == 0
                    && Double.compare(getPrecision(), other.getPrecision()) == 0
                    && Double.compare(getRecall(), other.getRecall()) == 0;
        }

        /** Symmetric-equality hook so subclasses can opt out of equality with this type. */
        protected boolean canEqual(Object obj) {
            return obj instanceof Point;
        }

        @Override
        public int hashCode() {
            int result = 1;
            result = result * 59 + getIdx();
            result = result * 59 + hashDouble(getThreshold());
            result = result * 59 + hashDouble(getPrecision());
            result = result * 59 + hashDouble(getRecall());
            return result;
        }

        /** Folds the 64 bits of a double into an int (same value as {@code Double.hashCode(double)}). */
        private static int hashDouble(double value) {
            long bits = Double.doubleToLongBits(value);
            return (int) ((bits >>> 32) ^ bits);
        }

        @Override
        public String toString() {
            return "PrecisionRecallCurve.Point(idx=" + getIdx() + ", threshold=" + getThreshold() + ", precision=" + getPrecision() + ", recall=" + getRecall() + ")";
        }
    }

    /**
     * Creates a curve from index-aligned per-point arrays. The arrays are stored by reference, not copied.
     *
     * @param threshold  threshold of each point (assumed ascending)
     * @param precision  precision at each point
     * @param recall     recall at each point
     * @param tpCount    true-positive count at each point
     * @param fpCount    false-positive count at each point
     * @param fnCount    false-negative count at each point
     * @param totalCount total number of examples
     */
    public PrecisionRecallCurve(@JsonProperty("threshold") double[] threshold, @JsonProperty("precision") double[] precision, @JsonProperty("recall") double[] recall, @JsonProperty("tpCount") int[] tpCount, @JsonProperty("fpCount") int[] fpCount, @JsonProperty("fnCount") int[] fnCount, @JsonProperty("totalCount") int totalCount) {
        this.threshold = threshold;
        this.precision = precision;
        this.recall = recall;
        this.tpCount = tpCount;
        this.fpCount = fpCount;
        this.fnCount = fnCount;
        this.totalCount = totalCount;
    }

    /** @return the number of points on the curve (length of the threshold array) */
    @Override // org.deeplearning4j.eval.curves.BaseCurve
    public int numPoints() {
        return this.threshold.length;
    }

    /** @return the recall array, used as the X axis (returned by reference) */
    @Override // org.deeplearning4j.eval.curves.BaseCurve
    public double[] getX() {
        return this.recall;
    }

    /** @return the precision array, used as the Y axis (returned by reference) */
    @Override // org.deeplearning4j.eval.curves.BaseCurve
    public double[] getY() {
        return this.precision;
    }

    /** @return a plot title that includes the area under the curve, formatted to 4 decimal places */
    @Override // org.deeplearning4j.eval.curves.BaseCurve
    public String getTitle() {
        return "Precision-Recall Curve (Area=" + format(calculateAUPRC(), 4) + ")";
    }

    /**
     * @param i point index
     * @return the threshold at point {@code i}
     * @throws IllegalArgumentException if {@code i} is out of range
     */
    public double getThreshold(int i) {
        Preconditions.checkArgument(i >= 0 && i < this.threshold.length, "Invalid index: %s", i);
        return this.threshold[i];
    }

    /**
     * @param i point index
     * @return the precision at point {@code i}
     * @throws IllegalArgumentException if {@code i} is out of range
     */
    public double getPrecision(int i) {
        Preconditions.checkArgument(i >= 0 && i < this.precision.length, "Invalid index: %s", i);
        return this.precision[i];
    }

    /**
     * @param i point index
     * @return the recall at point {@code i}
     * @throws IllegalArgumentException if {@code i} is out of range
     */
    public double getRecall(int i) {
        Preconditions.checkArgument(i >= 0 && i < this.recall.length, "Invalid index: %s", i);
        return this.recall[i];
    }

    /**
     * Returns the area under the precision-recall curve. The area is computed by
     * {@code calculateArea()} on first use and cached afterwards.
     *
     * @return the area under the curve
     */
    public double calculateAUPRC() {
        if (this.area == null) {
            this.area = calculateArea();
        }
        return this.area;
    }

    /**
     * Returns the point whose threshold equals {@code thresholdValue}. If there is no exact match, it
     * returns the point with the next-higher threshold. If {@code thresholdValue} is above every
     * threshold on the curve, it returns the last point (highest threshold).
     *
     * @param thresholdValue threshold to look up
     * @return the matching or next-higher point
     */
    public Point getPointAtThreshold(double thresholdValue) {
        int idx = Arrays.binarySearch(this.threshold, thresholdValue);
        if (idx < 0) {
            // Not found: binarySearch returned -(insertionPoint) - 1, and the insertion point is the
            // index of the first element greater than the key.
            idx = -idx - 1;
        }
        // The insertion point equals length when the key exceeds every threshold. Clamp to the last
        // point instead of indexing past the end of the arrays.
        idx = Math.min(idx, this.threshold.length - 1);
        return pointAt(idx);
    }

    /**
     * Returns the lowest-index point whose precision is at least {@code minPrecision}, or the last
     * point if there is none.
     *
     * @param minPrecision minimum required precision
     * @return the selected point
     */
    public Point getPointAtPrecision(double minPrecision) {
        for (int i = 0; i < this.precision.length; i++) {
            if (this.precision[i] >= minPrecision) {
                return pointAt(i);
            }
        }
        return pointAt(this.threshold.length - 1);
    }

    /**
     * Returns the point used for a target recall. The method takes the highest-index point whose
     * recall is at least {@code minRecall}. Among lower-index points with exactly the same recall
     * value, it prefers higher precision, and ties go to the lower index. If no point reaches
     * {@code minRecall}, it returns point 0.
     *
     * @param minRecall minimum required recall
     * @return the selected point
     */
    public Point getPointAtRecall(double minRecall) {
        Point found = null;
        for (int i = this.recall.length - 1; i >= 0; i--) {
            if (this.recall[i] < minRecall) {
                continue;
            }
            if (found == null || (this.recall[i] == found.getRecall() && this.precision[i] >= found.getPrecision())) {
                found = pointAt(i);
            }
        }
        return found != null ? found : pointAt(0);
    }

    /**
     * Returns the confusion matrix at the point selected by {@link #getPointAtThreshold(double)}.
     *
     * @param thresholdValue threshold to look up
     * @return TP/FP/FN/TN counts at that point
     */
    public Confusion getConfusionMatrixAtThreshold(double thresholdValue) {
        return confusionAt(getPointAtThreshold(thresholdValue));
    }

    /**
     * Returns the confusion matrix at point index {@code i}. The lookup uses the index directly, so
     * duplicate thresholds cannot redirect it to a different point.
     *
     * @param i point index
     * @return TP/FP/FN/TN counts at point {@code i}
     * @throws IllegalArgumentException if {@code i} is out of range
     */
    public Confusion getConfusionMatrixAtPoint(int i) {
        Preconditions.checkArgument(i >= 0 && i < this.threshold.length, "Invalid index: %s", i);
        return confusionAt(pointAt(i));
    }

    /** Builds the {@link Point} for index {@code i} from the per-point arrays. */
    private Point pointAt(int i) {
        return new Point(i, this.threshold[i], this.precision[i], this.recall[i]);
    }

    /** Builds the confusion matrix for {@code point}, deriving TN from {@code totalCount}. */
    private Confusion confusionAt(Point point) {
        int i = point.getIdx();
        int tp = this.tpCount[i];
        int fp = this.fpCount[i];
        int fn = this.fnCount[i];
        return new Confusion(point, tp, fp, fn, this.totalCount - (tp + fp + fn));
    }

    /** Deserializes a curve from JSON. */
    public static PrecisionRecallCurve fromJson(String str) {
        return (PrecisionRecallCurve) fromJson(str, PrecisionRecallCurve.class);
    }

    /** Deserializes a curve from YAML. */
    public static PrecisionRecallCurve fromYaml(String str) {
        return (PrecisionRecallCurve) fromYaml(str, PrecisionRecallCurve.class);
    }

    public double[] getThreshold() {
        return this.threshold;
    }

    public double[] getPrecision() {
        return this.precision;
    }

    public double[] getRecall() {
        return this.recall;
    }

    public int[] getTpCount() {
        return this.tpCount;
    }

    public int[] getFpCount() {
        return this.fpCount;
    }

    public int[] getFnCount() {
        return this.fnCount;
    }

    public int getTotalCount() {
        return this.totalCount;
    }

    /** @return the cached area, or {@code null} if it has not been computed or set */
    public Double getArea() {
        return this.area;
    }

    /** Replaces the threshold array and conservatively clears the cached area. */
    public void setThreshold(double[] threshold) {
        this.threshold = threshold;
        this.area = null;
    }

    /** Replaces the precision array and clears the cached area, which depends on it. */
    public void setPrecision(double[] precision) {
        this.precision = precision;
        this.area = null;
    }

    /** Replaces the recall array and clears the cached area, which depends on it. */
    public void setRecall(double[] recall) {
        this.recall = recall;
        this.area = null;
    }

    public void setTpCount(int[] tpCount) {
        this.tpCount = tpCount;
    }

    public void setFpCount(int[] fpCount) {
        this.fpCount = fpCount;
    }

    public void setFnCount(int[] fnCount) {
        this.fnCount = fnCount;
    }

    public void setTotalCount(int totalCount) {
        this.totalCount = totalCount;
    }

    public void setArea(Double area) {
        this.area = area;
    }

    @Override
    public String toString() {
        return "PrecisionRecallCurve(threshold=" + Arrays.toString(getThreshold()) + ", precision=" + Arrays.toString(getPrecision()) + ", recall=" + Arrays.toString(getRecall()) + ", tpCount=" + Arrays.toString(getTpCount()) + ", fpCount=" + Arrays.toString(getFpCount()) + ", fnCount=" + Arrays.toString(getFnCount()) + ", totalCount=" + getTotalCount() + ", area=" + getArea() + ")";
    }

    /** Equality over all per-point arrays and {@code totalCount}; the cached {@code area} is excluded. */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof PrecisionRecallCurve)) {
            return false;
        }
        PrecisionRecallCurve other = (PrecisionRecallCurve) obj;
        return other.canEqual(this)
                && Arrays.equals(getThreshold(), other.getThreshold())
                && Arrays.equals(getPrecision(), other.getPrecision())
                && Arrays.equals(getRecall(), other.getRecall())
                && Arrays.equals(getTpCount(), other.getTpCount())
                && Arrays.equals(getFpCount(), other.getFpCount())
                && Arrays.equals(getFnCount(), other.getFnCount())
                && getTotalCount() == other.getTotalCount();
    }

    /** Symmetric-equality hook so subclasses can opt out of equality with this type. */
    protected boolean canEqual(Object obj) {
        return obj instanceof PrecisionRecallCurve;
    }

    /** Hash consistent with {@link #equals(Object)}; the cached {@code area} is excluded. */
    @Override
    public int hashCode() {
        int result = 1;
        result = result * 59 + Arrays.hashCode(getThreshold());
        result = result * 59 + Arrays.hashCode(getPrecision());
        result = result * 59 + Arrays.hashCode(getRecall());
        result = result * 59 + Arrays.hashCode(getTpCount());
        result = result * 59 + Arrays.hashCode(getFpCount());
        result = result * 59 + Arrays.hashCode(getFnCount());
        return result * 59 + getTotalCount();
    }
}
