/*
 * Decompiled with CFR 0.152.
 */
package hex;

import hex.CustomMetric;
import hex.Model;
import hex.ModelMetrics;
import hex.ModelMetricsRegression;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import water.exceptions.H2OIllegalArgumentException;
import water.fvec.Frame;
import water.fvec.Vec;

public class ModelMetricsRegressionCoxPH
extends ModelMetricsRegression {
    /** Concordance value handed to the constructor; may be NaN (see {@link #toString()}). */
    private double _concordance;
    /** Concordant count handed to the constructor. */
    private long _concordant;
    /** Discordant count handed to the constructor. */
    private long _discordant;
    /** "tied.y" count handed to the constructor. */
    private long _tied_y;

    /** @return the stored concordance value */
    public double concordance() {
        return _concordance;
    }

    /** @return the stored concordant count */
    public long concordant() {
        return _concordant;
    }

    /** @return the stored discordant count */
    public long discordant() {
        return _discordant;
    }

    /** @return the stored "tied.y" count */
    public long tiedY() {
        return _tied_y;
    }

    /**
     * Creates CoxPH regression metrics.
     *
     * <p>The first nine arguments are forwarded unchanged to the {@link ModelMetricsRegression}
     * constructor; the remaining four are stored on this instance.
     *
     * @param concordance concordance value to store
     * @param concordant  concordant count to store
     * @param discordant  discordant count to store
     * @param tied_y      "tied.y" count to store
     */
    public ModelMetricsRegressionCoxPH(Model model, Frame frame, long nobs, double mse, double sigma, double mae, double rmsle, double meanResidualDeviance, CustomMetric customMetric, double concordance, long concordant, long discordant, long tied_y) {
        super(model, frame, nobs, mse, sigma, mae, rmsle, meanResidualDeviance, customMetric);
        _tied_y = tied_y;
        _discordant = discordant;
        _concordant = concordant;
        _concordance = concordance;
    }

    /**
     * Fetches the metrics registered for the given model/frame pair via
     * {@link ModelMetrics#getFromDKV(Model, Frame)} and returns them as CoxPH metrics.
     *
     * @param model model whose metrics are requested
     * @param frame frame the metrics belong to
     * @return the looked-up metrics, cast to {@link ModelMetricsRegressionCoxPH}
     * @throws H2OIllegalArgumentException if the lookup yields nothing or yields metrics of
     *         another type
     */
    public static ModelMetricsRegressionCoxPH getFromDKV(Model model, Frame frame) {
        ModelMetrics mm = ModelMetrics.getFromDKV(model, frame);
        if (!(mm instanceof ModelMetricsRegressionCoxPH)) {
            // instanceof is also false for null, so guard the getClass() call: a missing entry
            // must surface as the intended H2OIllegalArgumentException, not a NullPointerException.
            String found = mm == null ? "null" : String.valueOf(mm.getClass());
            throw new H2OIllegalArgumentException("Expected to find a Regression ModelMetrics for model: " + model._key.toString() + " and frame: " + frame._key.toString(), "Expected to find a ModelMetricsRegression for model: " + model._key.toString() + " and frame: " + frame._key.toString() + " but found a: " + found);
        }
        return (ModelMetricsRegressionCoxPH)mm;
    }

    /**
     * Renders the superclass description followed by one line each for concordance,
     * concordant, discordant and tied.y. A NaN concordance is printed as {@code N/A};
     * otherwise it is printed with float precision.
     */
    @Override
    public String toString() {
        final String concordanceText = Double.isNaN(_concordance)
            ? "N/A"
            : String.valueOf((float) _concordance);
        return new StringBuilder()
            .append(super.toString())
            .append(" concordance: ").append(concordanceText).append("\n")
            .append(" concordant: ").append(_concordant).append("\n")
            .append(" discordant: ").append(_discordant).append("\n")
            .append(" tied.y: ").append(_tied_y).append("\n")
            .toString();
    }

    public static class MetricBuilderRegressionCoxPH<T extends MetricBuilderRegressionCoxPH<T>>
    extends ModelMetricsRegression.MetricBuilderRegression<T> {
        /** Name of the start-time column looked up in the adapted frame. */
        private final String startVecName;
        /** Name of the stop-time column looked up in the adapted frame. */
        private final String stopVecName;
        /** Whether pairs are formed only within groups defined by the {@link #stratifyBy} columns. */
        private final boolean isStratified;
        /** Names of the stratification columns; only read when {@link #isStratified} is true. */
        private final String[] stratifyBy;

        /**
         * Creates a builder that remembers which columns carry the timing and strata information.
         *
         * @param startVecName    name of the start-time column
         * @param stopVecName     name of the stop-time column
         * @param isStratified    whether stratification columns should be used
         * @param stratifyByName  names of the stratification columns
         */
        public MetricBuilderRegressionCoxPH(String startVecName, String stopVecName, boolean isStratified, String[] stratifyByName) {
            this.stratifyBy = stratifyByName;
            this.isStratified = isStratified;
            this.stopVecName = stopVecName;
            this.startVecName = startVecName;
        }

        /**
         * Builds CoxPH metrics by combining the regression metrics computed by the superclass
         * with the pairwise concordance statistics, and registers the result on the model
         * when a model is given.
         *
         * @param m            model the metrics belong to; may be null, in which case nothing is registered
         * @param f            original frame
         * @param adaptedFrame adapted frame (source of the time and event columns)
         * @param preds        predictions frame (its last column is used as the estimate)
         * @return the newly created metrics object
         */
        @Override
        public ModelMetricsRegressionCoxPH makeModelMetrics(Model m, Frame f, Frame adaptedFrame, Frame preds) {
            final ModelMetricsRegression regression = super.computeModelMetrics(m, f, adaptedFrame, preds);
            final Stats pairStats = concordance(m, f, adaptedFrame, preds);

            final long observations = _count;
            final double mse = regression.mse();
            final double sigma = weightedSigma();
            final double mae = regression.mae();
            final double rmsle = regression.rmsle();
            final double meanResidualDeviance = regression.mean_residual_deviance();

            final ModelMetricsRegressionCoxPH result = new ModelMetricsRegressionCoxPH(
                m, f, observations, mse, sigma, mae, rmsle, meanResidualDeviance, _customMetric,
                pairStats.c(), pairStats.nconcordant, pairStats.discordant(), pairStats.nties);

            if (m == null) {
                return result;
            }
            m.addModelMetrics(result);
            return result;
        }

        /**
         * Collects the columns needed for the concordance computation and delegates to the
         * static {@code concordance} overload.
         *
         * <p>Start/stop columns are looked up by name in {@code adaptFrm}; the event column is the
         * last column of {@code adaptFrm}; the estimate is the last column of {@code scored}.
         * Stratification columns, when enabled, are looked up by name in {@code fr}.
         *
         * @param m        unused here
         * @param fr       frame the stratification columns are read from
         * @param adaptFrm frame the start, stop and event columns are read from
         * @param scored   frame whose last column holds the estimates
         * @return pairwise concordance statistics
         */
        private Stats concordance(Model m, Frame fr, Frame adaptFrm, Frame scored) {
            final Vec start = adaptFrm.vec(startVecName);
            final Vec stop = adaptFrm.vec(stopVecName);
            final Vec status = adaptFrm.lastVec();
            final Vec estimate = scored.lastVec();

            final List<Vec> strata;
            if (isStratified) {
                strata = Arrays.stream(stratifyBy)
                    .map(columnName -> fr.vec(columnName))
                    .collect(Collectors.toList());
            } else {
                strata = Collections.emptyList();
            }
            return MetricBuilderRegressionCoxPH.concordance(start, stop, status, strata, estimate);
        }

        /**
         * Decides whether two observations form a usable pair.
         *
         * <p>With equal times the pair is usable only when exactly one of the two had an event.
         * Otherwise it is usable when both had an event, or when the observation with the
         * strictly smaller time had an event.
         *
         * @param time1  time of the first observation
         * @param time2  time of the second observation
         * @param event1 whether the first observation had an event
         * @param event2 whether the second observation had an event
         * @return true if the pair can be compared
         */
        private static boolean isValidComparison(double time1, double time2, boolean event1, boolean event2) {
            if (time1 == time2) {
                return event1 ^ event2;
            }
            return (event1 && event2)
                || (event1 && time1 < time2)
                || (event2 && time2 < time1);
        }

        /**
         * Wraps the given columns in {@link Vec.Reader}s and computes the concordance statistics
         * over as many rows as {@code estimateVec} has.
         *
         * @param startVec    start-time column, or null (a start of 0 is then used downstream)
         * @param stopVec     stop-time column
         * @param eventVec    event indicator column
         * @param strataVecs  stratification columns (may be empty)
         * @param estimateVec estimate column
         * @return pairwise concordance statistics
         */
        static Stats concordance(Vec startVec, Vec stopVec, Vec eventVec, List<Vec> strataVecs, Vec estimateVec) {
            final long rowCount = estimateVec.length();
            final Vec.Reader startReader = startVec == null ? null : new Vec.Reader(startVec);
            final Vec.Reader stopReader = new Vec.Reader(stopVec);
            final Vec.Reader eventReader = new Vec.Reader(eventVec);
            final List<Vec.Reader> strataReaders = strataVecs.stream()
                .map(Vec.Reader::new)
                .collect(Collectors.toList());
            final Vec.Reader estimateReader = new Vec.Reader(estimateVec);
            return MetricBuilderRegressionCoxPH.concordanceStats(startReader, stopReader, eventReader, strataReaders, estimateReader, rowCount);
        }

        /**
         * Groups row indices {@code 0..length-1} by the tuple of their strata values and sums the
         * per-group statistics. With no strata columns every row lands in the same group.
         *
         * @param startVec    reader for the start-time column, or null
         * @param stopVec     reader for the stop-time column
         * @param eventVec    reader for the event indicator column
         * @param strataVecs  readers for the stratification columns (may be empty)
         * @param estimateVec reader for the estimate column
         * @param length      number of rows to process; must fit in an int
         * @return statistics summed over all groups
         * @throws IllegalArgumentException if {@code length} is negative or exceeds {@link Integer#MAX_VALUE}
         */
        private static Stats concordanceStats(Vec.Reader startVec, Vec.Reader stopVec, Vec.Reader eventVec, List<Vec.Reader> strataVecs, Vec.Reader estimateVec, long length) {
            // Checked explicitly rather than with `assert`: assertions are normally disabled at
            // runtime, and the (int) cast below would otherwise silently truncate the row count.
            if (length < 0L || length > Integer.MAX_VALUE) {
                throw new IllegalArgumentException("Row count must be between 0 and " + Integer.MAX_VALUE + " but was " + length);
            }
            final Map<List<Long>, List<Integer>> rowsByStrata = IntStream.range(0, (int) length)
                .boxed()
                .collect(Collectors.groupingBy(row -> strataVecs.stream()
                    .map(v -> v.at8(row.intValue()))
                    .collect(Collectors.toList())));
            return rowsByStrata.values().stream()
                .map(rows -> MetricBuilderRegressionCoxPH.statsForAStrata(startVec, stopVec, eventVec, estimateVec, rows))
                .reduce(new Stats(), Stats::plus);
        }

        /**
         * Counts pair statistics over all row pairs {@code (i, j)} with {@code i < j} taken from
         * {@code indexes}.
         *
         * <p>For each pair the time is {@code stop - start} (start is 0 when {@code startVec} is
         * null) and an event value of 0 means "no event". Pairs with a NaN time or estimate are
         * skipped entirely. Remaining pairs count towards {@code nNotNaN}; those accepted by
         * {@code isValidComparison} count towards {@code ntotals}. A valid pair with equal
         * estimates is a tie. Otherwise it is concordant when the row with the larger estimate
         * has the strictly smaller time, or, at equal times, is the one that had the event.
         *
         * @param startVec    reader for the start-time column, or null
         * @param stopVec     reader for the stop-time column
         * @param eventVec    reader for the event indicator column
         * @param estimateVec reader for the estimate column
         * @param indexes     row indices belonging to one group
         * @return counts for this group
         */
        private static Stats statsForAStrata(Vec.Reader startVec, Vec.Reader stopVec, Vec.Reader eventVec, Vec.Reader estimateVec, List<Integer> indexes) {
            long validPairs = 0L;
            long completePairs = 0L;
            long concordantPairs = 0L;
            long tiedPairs = 0L;
            for (int first : indexes) {
                for (int second : indexes) {
                    // Each unordered pair is visited once, with the smaller row index first.
                    if (second <= first) {
                        continue;
                    }
                    final double time1 = stopVec.at(first) - (startVec == null ? 0.0 : startVec.at(first));
                    final double time2 = stopVec.at(second) - (startVec == null ? 0.0 : startVec.at(second));
                    final boolean hadEvent1 = eventVec.at8(first) != 0L;
                    final boolean hadEvent2 = eventVec.at8(second) != 0L;
                    final double score1 = estimateVec.at(first);
                    final double score2 = estimateVec.at(second);

                    final boolean anyMissing = Double.isNaN(time1) || Double.isNaN(time2)
                        || Double.isNaN(score1) || Double.isNaN(score2);
                    if (anyMissing) {
                        continue;
                    }
                    completePairs++;

                    if (!MetricBuilderRegressionCoxPH.isValidComparison(time1, time2, hadEvent1, hadEvent2)) {
                        continue;
                    }
                    validPairs++;

                    if (score1 == score2) {
                        tiedPairs++;
                    } else {
                        final boolean sameTime = time1 == time2;
                        final boolean firstEndsSooner = time1 < time2 || (sameTime && hadEvent1 && !hadEvent2);
                        final boolean secondEndsSooner = time1 > time2 || (sameTime && !hadEvent1 && hadEvent2);
                        final boolean agrees = score1 > score2 ? firstEndsSooner : secondEndsSooner;
                        if (agrees) {
                            concordantPairs++;
                        }
                    }
                }
            }
            return new Stats(validPairs, completePairs, concordantPairs, tiedPairs);
        }

        /**
         * Immutable bundle of pair counts produced by the concordance computation.
         * Instances are combined with {@link #plus(Stats)}.
         */
        static class Stats {
            /** Number of pairs accepted as valid comparisons. */
            final long ntotals;
            /** Number of pairs without NaN in time or estimate. */
            final long nNotNaN;
            /** Number of valid pairs counted as concordant. */
            final long nconcordant;
            /** Number of valid pairs with equal estimates. */
            final long nties;

            /** Creates an all-zero instance, usable as the identity for {@link #plus(Stats)}. */
            Stats() {
                this(0L, 0L, 0L, 0L);
            }

            /** Creates an instance holding exactly the given counts. */
            Stats(long ntotals, long nNotNaN, long nconcordant, long nties) {
                this.nties = nties;
                this.nconcordant = nconcordant;
                this.nNotNaN = nNotNaN;
                this.ntotals = ntotals;
            }

            /**
             * @return {@code (nconcordant + nties / 2) / ntotals}; NaN when all three counts are zero
             */
            double c() {
                final double weighted = nconcordant + 0.5 * nties;
                return weighted / ntotals;
            }

            /** @return valid pairs that are neither concordant nor tied */
            long discordant() {
                final long accountedFor = nconcordant;
                return ntotals - accountedFor - nties;
            }

            /**
             * @param s2 counts to add
             * @return a new instance holding the field-wise sums of this and {@code s2}
             */
            Stats plus(Stats s2) {
                final long totals = ntotals + s2.ntotals;
                final long notNaN = nNotNaN + s2.nNotNaN;
                final long concordant = nconcordant + s2.nconcordant;
                final long ties = nties + s2.nties;
                return new Stats(totals, notNaN, concordant, ties);
            }
        }
    }
}

