package hex;

import hex.ModelMetricsRegression;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import water.exceptions.H2OIllegalArgumentException;
import water.fvec.Frame;
import water.fvec.Vec;

/* loaded from: input_file:hex/ModelMetricsRegressionCoxPH.class */
/**
 * Regression metrics for Cox proportional hazards models.
 *
 * <p>On top of the regression metrics inherited from {@link ModelMetricsRegression}, this class
 * carries a concordance index together with the pair counts it was derived from: the number of
 * concordant pairs, discordant pairs and pairs whose predicted estimates are tied.
 */
public class ModelMetricsRegressionCoxPH extends ModelMetricsRegression {
    /** Concordance index; {@link #toString()} renders it as "N/A" when it is NaN. */
    private double _concordance;
    /** Number of concordant pairs. */
    private long _concordant;
    /** Number of discordant pairs. */
    private long _discordant;
    /** Number of comparable pairs whose predicted estimates are identical. */
    private long _tied_y;

    /**
     * Metric builder that, in addition to the regression metrics of its parent, computes the
     * concordance statistics of a Cox PH model.
     *
     * @param <T> self type of the builder
     */
    public static class MetricBuilderRegressionCoxPH<T extends MetricBuilderRegressionCoxPH<T>>
            extends ModelMetricsRegression.MetricBuilderRegression<T> {
        private final String startVecName;
        private final String stopVecName;
        private final boolean isStratified;
        private final String[] stratifyBy;

        /** Pair counts accumulated while computing the concordance index. */
        public static class Stats {
            /** Number of comparable pairs (pairs accepted by the comparability rule). */
            final long ntotals;
            /** Number of pairs in which neither time nor estimate is NaN. */
            final long nNotNaN;
            /** Number of comparable pairs whose estimates are ordered consistently with their times. */
            final long nconcordant;
            /** Number of comparable pairs with identical estimates. */
            final long nties;

            Stats() {
                this(0L, 0L, 0L, 0L);
            }

            Stats(long ntotals, long nNotNaN, long nconcordant, long nties) {
                this.ntotals = ntotals;
                this.nNotNaN = nNotNaN;
                this.nconcordant = nconcordant;
                this.nties = nties;
            }

            /**
             * Returns the concordance index {@code (concordant + 0.5 * ties) / comparable}.
             * Because the division is done in floating point, the result is NaN when there is no
             * comparable pair.
             */
            double c() {
                return (nconcordant + (0.5d * nties)) / ntotals;
            }

            /** Returns the number of comparable pairs that are neither concordant nor tied. */
            long discordant() {
                return (ntotals - nconcordant) - nties;
            }

            /** Returns a new {@code Stats} holding the element-wise sum of both counters. */
            Stats plus(Stats other) {
                return new Stats(ntotals + other.ntotals, nNotNaN + other.nNotNaN,
                        nconcordant + other.nconcordant, nties + other.nties);
            }
        }

        /**
         * @param startVecName  name of the start-time column (may be null; the concordance
         *                      computation then uses the stop time alone)
         * @param stopVecName   name of the stop-time column
         * @param isStratified  whether rows are compared only within strata
         * @param stratifyByName names of the columns defining the strata; read only when
         *                      {@code isStratified} is true
         */
        public MetricBuilderRegressionCoxPH(String startVecName, String stopVecName, boolean isStratified,
                                            String[] stratifyByName) {
            this.startVecName = startVecName;
            this.stopVecName = stopVecName;
            this.isStratified = isStratified;
            this.stratifyBy = stratifyByName;
        }

        /**
         * Builds the CoxPH metrics: regression metrics from the parent builder plus the concordance
         * statistics. When {@code model} is non-null, the metrics are also attached to it.
         */
        @Override // hex.ModelMetricsRegression.MetricBuilderRegression, hex.ModelMetrics.MetricBuilder
        public ModelMetricsRegressionCoxPH makeModelMetrics(Model model, Frame frame, Frame adaptedFrame,
                                                            Frame predictions) {
            final ModelMetricsRegression regressionMetrics =
                    super.computeModelMetrics(model, frame, adaptedFrame, predictions);
            final Stats stats = concordance(model, frame, adaptedFrame, predictions);
            final ModelMetricsRegressionCoxPH metrics = new ModelMetricsRegressionCoxPH(model, frame, _count,
                    regressionMetrics.mse(), weightedSigma(), regressionMetrics.mae(), regressionMetrics.rmsle(),
                    regressionMetrics.mean_residual_deviance(), _customMetric,
                    stats.c(), stats.nconcordant, stats.discordant(), stats.nties);
            if (model != null) {
                model.addModelMetrics(metrics);
            }
            return metrics;
        }

        /**
         * Collects the vecs needed for concordance. Start, stop and event (last column) come from
         * the adapted frame, the estimate from the last column of the predictions, and the strata
         * columns from the original frame.
         */
        private Stats concordance(Model model, Frame frame, Frame adaptedFrame, Frame predictions) {
            final Vec startVec = adaptedFrame.vec(startVecName);
            final Vec stopVec = adaptedFrame.vec(stopVecName);
            final Vec eventVec = adaptedFrame.lastVec();
            final List<Vec> strataVecs = isStratified
                    ? Arrays.stream(stratifyBy).map(name -> frame.vec(name)).collect(Collectors.toList())
                    : Collections.<Vec>emptyList();
            final Vec estimateVec = predictions.lastVec();
            return concordance(startVec, stopVec, eventVec, strataVecs, estimateVec);
        }

        /**
         * Decides whether a pair of rows can be compared for concordance.
         *
         * <p>With equal times, the pair is kept only when exactly one of the two rows had the event.
         * With different times, it is kept when both rows had the event, or when the row with the
         * shorter time had the event.
         */
        private static boolean isValidComparison(double time1, double time2, boolean event1, boolean event2) {
            if (time1 == time2) {
                return event1 != event2;
            }
            if (event1 && event2) {
                return true;
            }
            return (event1 && time1 < time2) || (event2 && time2 < time1);
        }

        /**
         * Returns whether a comparable pair with different estimates is concordant. A higher
         * estimate is counted as concordant with a shorter time. On equal times, the row that had
         * the event is treated as the shorter one.
         */
        private static boolean isConcordant(double time1, double time2, boolean event1, boolean event2,
                                            double estimate1, double estimate2) {
            if (estimate1 > estimate2) {
                return time1 < time2 || (time1 == time2 && event1 && !event2);
            }
            return time1 > time2 || (time1 == time2 && !event1 && event2);
        }

        /**
         * Computes the concordance statistics from raw vecs.
         *
         * @param startVec    start times, or null to use {@code stopVec} alone as the time
         * @param stopVec     stop times
         * @param eventVec    event indicator; rows where it reads 0 are treated as censored
         * @param strataVecs  vecs whose combined values define the strata (empty for no strata)
         * @param estimateVec model estimates; its length defines the number of rows
         */
        static Stats concordance(Vec startVec, Vec stopVec, Vec eventVec, List<Vec> strataVecs, Vec estimateVec) {
            final long length = estimateVec.length();
            // Vec.Reader is an inner class of Vec, so every reader is created from its owning vec.
            final Vec.Reader startReader = null == startVec ? null : startVec.new Reader();
            final Vec.Reader stopReader = stopVec.new Reader();
            final Vec.Reader eventReader = eventVec.new Reader();
            final List<Vec.Reader> strataReaders = strataVecs.stream()
                    .map(vec -> vec.new Reader())
                    .collect(Collectors.toList());
            final Vec.Reader estimateReader = estimateVec.new Reader();
            return concordanceStats(startReader, stopReader, eventReader, strataReaders, estimateReader, length);
        }

        /**
         * Groups rows by the tuple of their strata values (read with {@code at8}), computes the
         * pairwise statistics inside each group and sums them.
         *
         * @throws IllegalArgumentException if {@code length} does not fit in an int row index
         */
        private static Stats concordanceStats(Vec.Reader startReader, Vec.Reader stopReader,
                                              Vec.Reader eventReader, List<Vec.Reader> strataReaders,
                                              Vec.Reader estimateReader, long length) {
            if (length < 0 || length > Integer.MAX_VALUE) {
                // Rows are addressed by int below; a narrowing cast would otherwise silently truncate.
                throw new IllegalArgumentException(
                        "Concordance supports at most " + Integer.MAX_VALUE + " rows, got: " + length);
            }
            final Map<List<Long>, List<Integer>> rowsByStratum = IntStream.range(0, (int) length)
                    .boxed()
                    .collect(Collectors.groupingBy(row -> strataReaders.stream()
                            .map(reader -> reader.at8(row))
                            .collect(Collectors.toList())));
            return rowsByStratum.values().stream()
                    .map(rows -> statsForAStrata(startReader, stopReader, eventReader, estimateReader, rows))
                    .reduce(new Stats(), Stats::plus);
        }

        /**
         * Computes the concordance counts over all unordered pairs of the given rows. The row
         * indexes are expected to be distinct, as produced by {@code concordanceStats}.
         *
         * <p>A row's time is {@code stop - start}, or {@code stop} alone when {@code startReader}
         * is null. Pairs where any time or estimate is NaN are skipped.
         */
        public static Stats statsForAStrata(Vec.Reader startReader, Vec.Reader stopReader, Vec.Reader eventReader,
                                            Vec.Reader estimateReader, List<Integer> rows) {
            final int n = rows.size();
            if (n < 2) {
                // No pair to compare, so there is no need to read any row of this stratum.
                return new Stats();
            }
            // Read every row once; the quadratic pair loop then works on plain arrays.
            final double[] times = new double[n];
            final boolean[] events = new boolean[n];
            final double[] estimates = new double[n];
            int k = 0;
            for (int row : rows) {
                times[k] = stopReader.at(row) - (startReader != null ? startReader.at(row) : 0.0d);
                events[k] = 0 != eventReader.at8(row);
                estimates[k] = estimateReader.at(row);
                k++;
            }

            long comparable = 0;
            long notNaN = 0;
            long concordant = 0;
            long tied = 0;
            for (int a = 0; a < n; a++) {
                if (Double.isNaN(times[a]) || Double.isNaN(estimates[a])) {
                    continue;
                }
                for (int b = a + 1; b < n; b++) {
                    if (Double.isNaN(times[b]) || Double.isNaN(estimates[b])) {
                        continue;
                    }
                    notNaN++;
                    if (!isValidComparison(times[a], times[b], events[a], events[b])) {
                        continue;
                    }
                    comparable++;
                    if (estimates[a] == estimates[b]) {
                        tied++;
                    } else if (isConcordant(times[a], times[b], events[a], events[b], estimates[a], estimates[b])) {
                        concordant++;
                    }
                }
            }
            return new Stats(comparable, notNaN, concordant, tied);
        }
    }

    /** Returns the concordance index (NaN when it could not be computed). */
    public double concordance() {
        return _concordance;
    }

    /** Returns the number of concordant pairs. */
    public long concordant() {
        return _concordant;
    }

    /** Returns the number of discordant pairs. */
    public long discordant() {
        return _discordant;
    }

    /** Returns the number of comparable pairs with tied estimates. */
    public long tiedY() {
        return _tied_y;
    }

    /**
     * Creates CoxPH metrics. The first nine arguments are forwarded unchanged to the
     * {@link ModelMetricsRegression} constructor.
     *
     * @param concordance concordance index
     * @param concordant  number of concordant pairs
     * @param discordant  number of discordant pairs
     * @param tiedY       number of comparable pairs with tied estimates
     */
    public ModelMetricsRegressionCoxPH(Model model, Frame frame, long nobs, double mse, double sigma, double mae,
                                       double rmsle, double meanResidualDeviance, CustomMetric customMetric,
                                       double concordance, long concordant, long discordant, long tiedY) {
        super(model, frame, nobs, mse, sigma, mae, rmsle, meanResidualDeviance, customMetric);
        this._concordance = concordance;
        this._concordant = concordant;
        this._discordant = discordant;
        this._tied_y = tiedY;
    }

    /**
     * Looks up the metrics stored for {@code model} on {@code frame}.
     *
     * @throws H2OIllegalArgumentException if no metrics are found, or if the metrics found are not
     *                                     {@code ModelMetricsRegressionCoxPH}
     */
    public static ModelMetricsRegressionCoxPH getFromDKV(Model model, Frame frame) {
        final ModelMetrics metrics = ModelMetrics.getFromDKV(model, frame);
        if (metrics instanceof ModelMetricsRegressionCoxPH) {
            return (ModelMetricsRegressionCoxPH) metrics;
        }
        // metrics may be null when nothing is stored; guard it so callers get the documented exception.
        final Object found = metrics == null ? null : metrics.getClass();
        throw new H2OIllegalArgumentException(
                "Expected to find a RegressionCoxPH ModelMetrics for model: " + model._key
                        + " and frame: " + frame._key,
                "Expected to find a ModelMetricsRegressionCoxPH for model: " + model._key
                        + " and frame: " + frame._key + " but found a: " + found);
    }

    @Override // hex.ModelMetricsRegression, hex.ModelMetricsSupervised, hex.ModelMetrics
    public String toString() {
        final StringBuilder sb = new StringBuilder(super.toString());
        if (Double.isNaN(_concordance)) {
            sb.append(" concordance: N/A\n");
        } else {
            // Printed at float precision.
            sb.append(" concordance: ").append((float) _concordance).append('\n');
        }
        sb.append(" concordant: ").append(_concordant).append('\n');
        sb.append(" discordant: ").append(_discordant).append('\n');
        sb.append(" tied.y: ").append(_tied_y).append('\n');
        return sb.toString();
    }
}
