package hex;

import hex.Model;
import java.util.ArrayList;
import java.util.Arrays;
import jsr166y.CountedCompleter;
import water.Futures;
import water.H2O;
import water.Job;
import water.Key;
import water.Keyed;
import water.Lockable;
import water.api.schemas3.KeyV3;
import water.fvec.Frame;
import water.fvec.Vec;
import water.rapids.Rapids;
import water.util.FrameUtils;
import water.util.Log;
import water.util.TwoDimTable;

/* loaded from: input_file:hex/PartialDependence.class */
public class PartialDependence extends Lockable<PartialDependence> {
    // Job handed to model scoring, used for progress updates / stop requests, and used as the lock owner.
    public final transient Job _job;
    // Key of the model that is scored for every grid value.
    public Key<Model> _model_id;
    // Key of the frame whose columns are overridden and scored.
    public Key<Frame> _frame_id;
    // When >= 0, only this row of the frame is scored (selected via a Rapids "rows" expression); defaults to -1.
    public long _row_index;
    // Columns for 1D partial dependence; when this and _col_pairs_2dpdp are both null it is filled during parameter checking.
    public String[] _cols;
    // _cols plus every distinct column named in _col_pairs_2dpdp; used for the cardinality check.
    public ArrayList<String> _cols_1d_2d;
    // Index of the weight column in the frame; negative (default -1) means unweighted mean/sigma are used.
    public int _weight_column_index;
    // When true, each value grid gets one extra trailing NaN entry (shown as ".missing(NA)" for categoricals).
    public boolean _add_missing_na;
    // Number of grid values per column (default 20); must be >= 2 and >= the cardinality of any categorical column.
    public int _nbins;
    // Target class names; rejected when the model has <= 2 classes and required otherwise.
    public String[] _targets;
    // Result tables, one per computed (column or column pair, target) combination.
    public TwoDimTable[] _partial_dependence_data;
    // Flat array of user-supplied grid values; sliced per column using _num_user_splits.
    public double[] _user_splits;
    // Per-column slices of _user_splits, index-aligned with _user_cols.
    public double[][] _user_split_per_col;
    // Number of user-supplied grid values for each entry of _user_cols.
    public int[] _num_user_splits;
    // Names of the columns that have user-supplied grid values.
    public String[] _user_cols;
    // Set to true during parameter checking when _user_splits is non-empty.
    public boolean _user_splits_present;
    // Column pairs for 2D partial dependence; element [0] is the first column and [1] the second.
    public String[][] _col_pairs_2dpdp;
    // Number of 2D tables to compute: pairs times the number of predictor columns.
    public int _num_2D_pairs;
    // Number of 1D tables to compute: columns times the number of predictor columns.
    public int _num_1D;
    // Prediction-frame column read for nclasses == 1 (0) and nclasses == 2 (2); not assigned for more classes.
    public int _predictor_column;
    // Prediction-frame column indices that are averaged, one per target.
    public int[] _predictor_columns;

    /* loaded from: input_file:hex/PartialDependence$PartialDependenceDriver.class */
    private class PartialDependenceDriver extends H2O.H2OCountedCompleter<PartialDependenceDriver> {
        // Assertion switch surfaced by the decompiler: assigned in this class's static
        // initialiser and read by compute2() to guard its "_job != null" assertion.
        static final /* synthetic */ boolean $assertionsDisabled;

        /* JADX INFO: Access modifiers changed from: private */
        /* loaded from: input_file:hex/PartialDependence$PartialDependenceDriver$CalculatePdpPerBin.class */
        /**
         * Task that computes the partial dependence statistics for a single grid point.
         *
         * <p>The task overrides one column (or two, for 2D partial dependence) of the input frame
         * with a constant value, scores the model on the result and stores the mean, standard
         * deviation and standard error of the selected prediction column into the shared result
         * arrays at {@code _pdp_row_index}.
         */
        public class CalculatePdpPerBin extends H2O.H2OCountedCompleter<CalculatePdpPerBin> {
            final String _col;
            final String _col2;
            final double _value;
            final double _value2;
            final boolean _workOn2D;
            final int _pdp_row_index;
            final boolean _col1_cat;
            final boolean _col2_cat;
            final double[] _meanResponse;
            final double[] _stddevResponse;
            final double[] _stdErrorOfTheMeanResponse;
            final int _predictorColumn;

            /**
             * @param col                 name of the first column to hold constant
             * @param col2                name of the second column (only read when {@code workOn2D} is true)
             * @param value               constant written into {@code col}
             * @param value2              constant written into {@code col2} (only used when {@code workOn2D} is true)
             * @param col1Categorical     whether {@code col} is categorical, in which case its domain is copied
             * @param col2Categorical     whether {@code col2} is categorical, in which case its domain is copied
             * @param pdpRowIndex         slot of the result arrays this task writes to
             * @param workOn2D            true when both columns are overridden
             * @param meanResponse        shared output array for the mean response
             * @param stddevResponse      shared output array for the response standard deviation
             * @param stdErrorOfTheMean   shared output array for the standard error of the mean
             * @param predictorColumn     index of the prediction-frame column to summarise
             */
            CalculatePdpPerBin(String col, String col2, double value, double value2, boolean col1Categorical, boolean col2Categorical, int pdpRowIndex, boolean workOn2D, double[] meanResponse, double[] stddevResponse, double[] stdErrorOfTheMean, int predictorColumn) {
                this._col = col;
                this._col2 = col2;
                this._value = value;
                this._value2 = value2;
                this._workOn2D = workOn2D;
                this._pdp_row_index = pdpRowIndex;
                this._col1_cat = col1Categorical;
                this._col2_cat = col2Categorical;
                this._meanResponse = meanResponse;
                this._stddevResponse = stddevResponse;
                this._stdErrorOfTheMeanResponse = stdErrorOfTheMean;
                this._predictorColumn = predictorColumn;
            }

            /**
             * Scores the model on the frame with the overridden column(s) and records the statistics.
             *
             * <p>All temporaries created here (the prediction frame, the constant vecs and, when a
             * single row was selected, the row-subset frame) are removed in a {@code finally} block
             * so that they are also cleaned up when scoring fails.
             */
            @Override // water.H2O.H2OCountedCompleter
            public void compute2() {
                final boolean singleRow = PartialDependence.this._row_index >= 0;
                final Frame source = singleRow
                        ? Rapids.exec("(rows " + PartialDependence.this._frame_id + "  " + PartialDependence.this._row_index + ")").getFrame()
                        : PartialDependence.this._frame_id.get();
                Vec constant1 = null;
                Vec constant2 = null;
                Frame predictions = null;
                try {
                    // Shallow copy: shares the source vecs, only the overridden columns are replaced.
                    Frame scoringInput = new Frame(source.names(), source.vecs());
                    constant1 = scoringInput.remove(this._col).makeCon(this._value);
                    if (this._col1_cat) {
                        constant1.setDomain(source.vec(this._col).domain());
                    }
                    scoringInput.add(this._col, constant1);
                    if (this._workOn2D) {
                        constant2 = scoringInput.remove(this._col2).makeCon(this._value2);
                        if (this._col2_cat) {
                            constant2.setDomain(source.vec(this._col2).domain());
                        }
                        scoringInput.add(this._col2, constant2);
                    }

                    predictions = PartialDependence.this._model_id.get().score(scoringInput, Key.make().toString(), PartialDependence.this._job, false);
                    final int slot = this._pdp_row_index;
                    if (predictions == null || predictions.numRows() == 0) {
                        this._meanResponse[slot] = Double.NaN;
                        this._stddevResponse[slot] = Double.NaN;
                        this._stdErrorOfTheMeanResponse[slot] = Double.NaN;
                    } else {
                        final double mean;
                        final double sigma;
                        if (PartialDependence.this._weight_column_index >= 0) {
                            FrameUtils.CalculateWeightMeanSTD weighted = PartialDependenceDriver.this.getWeightedStat(source, predictions, this._predictorColumn);
                            mean = weighted.getWeightedMean();
                            sigma = weighted.getWeightedSigma();
                        } else {
                            Vec response = predictions.vec(this._predictorColumn);
                            mean = response.mean();
                            sigma = response.sigma();
                        }
                        this._meanResponse[slot] = mean;
                        this._stddevResponse[slot] = sigma;
                        this._stdErrorOfTheMeanResponse[slot] = sigma / Math.sqrt(predictions.numRows());
                    }
                } finally {
                    if (predictions != null) {
                        predictions.remove();
                    }
                    if (constant1 != null) {
                        constant1.remove();
                    }
                    if (constant2 != null) {
                        constant2.remove();
                    }
                    if (singleRow) {
                        source.remove();
                    }
                }
                tryComplete();
            }
        }

        /** Creates a driver bound to the enclosing PartialDependence instance; no state of its own is initialised. */
        private PartialDependenceDriver() {
        }

        /**
         * Computes every requested partial dependence table, 1D tables first and 2D tables after.
         *
         * <p>For each (column or column pair, target) combination the value grid is built with
         * {@code extractColValues}, one {@link CalculatePdpPerBin} task is submitted per grid point,
         * and once all tasks finished the results are written into a {@link TwoDimTable} stored in
         * {@code _partial_dependence_data}. The job progress is advanced by one per table and the
         * loop stops early when the job reports a stop request.
         */
        @Override // water.H2O.H2OCountedCompleter
        public void compute2() {
            // Decompiled form of an assertion that the job is set.
            if (!$assertionsDisabled && PartialDependence.this._job == null) {
                throw new AssertionError();
            }
            final PartialDependence pd = PartialDependence.this;
            final Frame frame = pd._frame_id.get();
            // Number of distinct 1D columns. _num_1D is this count multiplied by the number of
            // targets, so it must not be used to translate the column counter into a pair index.
            final int oneDimColumnCount = pd._cols == null ? 0 : pd._cols.length;
            final int tableCount = pd._num_1D + pd._num_2D_pairs;
            pd._partial_dependence_data = new TwoDimTable[tableCount];

            int columnIdx = 0; // counts finished columns / column pairs, not tables
            for (int tableIdx = 0; tableIdx < tableCount; tableIdx++) {
                final boolean oneDim = tableIdx < pd._num_1D;
                final int pairIdx = columnIdx - oneDimColumnCount;
                final String col = oneDim ? pd._cols[columnIdx] : pd._col_pairs_2dpdp[pairIdx][0];
                final String col2 = oneDim ? null : pd._col_pairs_2dpdp[pairIdx][1];
                final int targetIdx = tableIdx % pd._predictor_columns.length;
                final int predictorColumn = pd._predictor_columns[targetIdx];

                Log.debug("Computing partial dependence of model on '" + col + "'" + (pd._targets == null ? "." : " and class " + pd._targets[targetIdx] + "."));

                final Vec vec1 = frame.vec(col);
                final Vec vec2 = oneDim ? null : frame.vec(col2);
                final double[] values1 = pd.extractColValues(col, pd._nbins, vec1);
                final double[] values2 = oneDim ? null : pd.extractColValues(col2, pd._nbins, vec2);
                final int binCount = oneDim ? values1.length : values1.length * values2.length;
                final double[] means = new double[binCount];
                final double[] sigmas = new double[binCount];
                final double[] stdErrors = new double[binCount];
                final boolean categorical1 = vec1.isCategorical();
                final boolean categorical2 = oneDim ? false : vec2.isCategorical();

                // One task per grid point; for 2D the flat bin index enumerates the second column fastest.
                final Futures futures = new Futures();
                for (int bin = 0; bin < binCount; bin++) {
                    final double value1 = oneDim ? values1[bin] : values1[bin / values2.length];
                    final double value2 = oneDim ? -1.0d : values2[bin % values2.length];
                    futures.add(H2O.submitTask(new CalculatePdpPerBin(col, col2, value1, value2, categorical1, categorical2, bin, !oneDim, means, sigmas, stdErrors, predictorColumn)));
                }
                futures.blockForPending();

                final TwoDimTable table;
                if (oneDim) {
                    final String description = pd._row_index < 0
                            ? "Partial Dependence Plot of model " + pd._model_id + " on column '" + col + "'" + (pd._targets == null ? "." : " and class " + pd._targets[targetIdx])
                            : "Partial Dependence Plot of model " + pd._model_id + " on column '" + col + "'" + (pd._targets == null ? "'" : " and class " + pd._targets[targetIdx]) + " for row index" + pd._row_index;
                    table = new TwoDimTable("PartialDependence", description, new String[binCount],
                            new String[]{col, "mean_response", "stddev_response", "std_error_mean_response"},
                            new String[]{categorical1 ? "string" : "double", "double", "double", "double"},
                            new String[]{categorical1 ? "%s" : "%5f", "%5f", "%5f", "%5f"},
                            null);
                } else {
                    final String description = pd._row_index < 0
                            ? "2D Partial Dependence Plot of model " + pd._model_id + " on 1st column '" + col + "' and 2nd column '" + col2 + "'"
                            : "Partial Dependence Plot of model " + pd._model_id + " on columns '" + col + "', '" + col2 + "' for row " + pd._row_index;
                    table = new TwoDimTable("2D-PartialDependence", description, new String[binCount],
                            new String[]{col, col2, "mean_response", "stddev_response", "std_error_mean_response"},
                            new String[]{categorical1 ? "string" : "double", categorical2 ? "string" : "double", "double", "double", "double"},
                            new String[]{categorical1 ? "%s" : "%5f", categorical2 ? "%s" : "%5f", "%5f", "%5f", "%5f"},
                            null);
                }
                pd._partial_dependence_data[tableIdx] = table;

                // Grid value column(s) first, then mean / stddev / standard error.
                final int firstStatColumn = oneDim ? 1 : 2;
                for (int row = 0; row < binCount; row++) {
                    final int idx1 = oneDim ? row : row / values2.length;
                    table.set(row, 0, gridCellValue(vec1, categorical1, values1[idx1]));
                    if (!oneDim) {
                        table.set(row, 1, gridCellValue(vec2, categorical2, values2[row % values2.length]));
                    }
                    table.set(row, firstStatColumn, Double.valueOf(means[row]));
                    table.set(row, firstStatColumn + 1, Double.valueOf(sigmas[row]));
                    table.set(row, firstStatColumn + 2, Double.valueOf(stdErrors[row]));
                }

                // Move to the next column / pair once all targets of the current one are done.
                if (pd._targets == null || (tableIdx + 1) % pd._targets.length == 0) {
                    columnIdx++;
                }
                pd._job.update(1L);
                pd.update(pd._job);
                if (pd._job.stop_requested()) {
                    break;
                }
            }
            tryComplete();
        }

        /**
         * Converts a grid value into the object stored in the first column(s) of a result table:
         * the number itself for non-categorical columns, otherwise the domain label at that index,
         * or {@code ".missing(NA)"} for a NaN grid value when {@code _add_missing_na} is set.
         */
        private Object gridCellValue(Vec vec, boolean categorical, double value) {
            if (!categorical) {
                return Double.valueOf(value);
            }
            if (PartialDependence.this._add_missing_na && Double.isNaN(value)) {
                return ".missing(NA)";
            }
            return vec.domain()[(int) value];
        }

        /**
         * Runs a {@code CalculateWeightMeanSTD} task over one prediction column and the weight column.
         *
         * @param frame  frame that contains the weight column at {@code _weight_column_index}
         * @param frame2 prediction frame
         * @param i      index of the prediction column in {@code frame2}
         * @return the task after it has been executed over both vecs
         */
        public FrameUtils.CalculateWeightMeanSTD getWeightedStat(Frame frame, Frame frame2, int i) {
            final FrameUtils.CalculateWeightMeanSTD task = new FrameUtils.CalculateWeightMeanSTD();
            final Vec predictionVec = frame2.vec(i);
            final Vec weightVec = frame.vec(PartialDependence.this._weight_column_index);
            task.doAll(predictionVec, weightVec);
            return task;
        }

        /** On normal completion, releases the job's lock on the input frame and then on this object. */
        @Override // jsr166y.CountedCompleter
        public void onCompletion(CountedCompleter countedCompleter) {
            final Job owner = PartialDependence.this._job;
            final Frame lockedFrame = PartialDependence.this._frame_id.get();
            lockedFrame.unlock((Key<Job>) owner._key);
            PartialDependence.this.unlock(owner);
        }

        /**
         * On failure, releases the job's lock on the input frame and then on this object.
         *
         * @return always {@code true}
         */
        @Override // jsr166y.CountedCompleter
        public boolean onExceptionalCompletion(Throwable th, CountedCompleter countedCompleter) {
            final Job owner = PartialDependence.this._job;
            final Frame lockedFrame = PartialDependence.this._frame_id.get();
            lockedFrame.unlock((Key<Job>) owner._key);
            PartialDependence.this.unlock(owner);
            return true;
        }

        // Initialises the assertion switch: it is true when assertions are NOT enabled for
        // PartialDependence, which makes the guarded check in compute2() a no-op.
        static {
            $assertionsDisabled = !PartialDependence.class.desiredAssertionStatus();
        }
    }

    /**
     * Creates a partial dependence computation stored under {@code key} and driven by {@code job}.
     * All parameters start at their defaults: whole frame, no weights, 20 bins, no user splits
     * and no 2D column pairs.
     *
     * @param key key under which this object is stored
     * @param job job used for scoring, progress reporting and locking
     */
    public PartialDependence(Key<PartialDependence> key, Job job) {
        super(key);
        this._job = job;

        // Scope of the computation.
        this._row_index = -1L;
        this._weight_column_index = -1;
        this._nbins = 20;
        this._add_missing_na = false;

        // User supplied grid values.
        this._user_splits_present = false;
        this._user_splits = null;
        this._user_cols = null;
        this._num_user_splits = null;
        this._user_split_per_col = null;

        // Work counters and 2D pairs.
        this._col_pairs_2dpdp = null;
        this._num_1D = 0;
        this._num_2D_pairs = 0;
    }

    /**
     * Creates a partial dependence computation together with a new job that targets {@code key}.
     *
     * @param key key under which this object is stored
     */
    public PartialDependence(Key<PartialDependence> key) {
        this(
            key,
            new Job(key, PartialDependence.class.getName(), "PartialDependence")
        );
    }

    /**
     * Validates the parameters, locks this object and the input frame for the job, and runs the
     * driver synchronously on the calling thread.
     *
     * @return this instance
     */
    public PartialDependence execNested() {
        checkSanityAndFillParams();
        delete_and_lock(this._job);
        final Frame input = this._frame_id.get();
        input.write_lock((Key<Job>) this._job._key);
        final PartialDependenceDriver driver = new PartialDependenceDriver();
        driver.compute2();
        return this;
    }

    /**
     * Validates the parameters, locks this object and the input frame for the job, and starts the
     * driver through the job with one unit of work per table to compute.
     *
     * @return the job that runs the computation
     */
    public Job<PartialDependence> execImpl() {
        checkSanityAndFillParams();
        delete_and_lock(this._job);
        final Frame input = this._frame_id.get();
        input.write_lock((Key<Job>) this._job._key);
        final PartialDependenceDriver driver = new PartialDependenceDriver();
        final int totalTables = this._num_1D + this._num_2D_pairs;
        this._job.start(driver, totalTables);
        return this._job;
    }

    /**
     * Looks up {@code str} among the model's class names.
     *
     * @param model model whose output class names are searched
     * @param str   target class name
     * @return position of the first matching class name plus one
     *         (NOTE(review): the offset presumably skips a leading label column of the prediction frame - confirm)
     * @throws IllegalArgumentException if no class name equals {@code str}
     */
    private int findTargetClassPredictorIndex(Model model, String str) {
        final String[] classNames = model._output.classNames();
        for (int pos = 0; pos < classNames.length; pos++) {
            final String candidate = classNames[pos];
            final boolean matches = str == null ? candidate == null : str.equals(candidate);
            if (matches) {
                return pos + 1;
            }
        }
        throw new IllegalArgumentException("Incorrect target class: " + str + ".");
    }

    /**
     * Resolves every target class name to its predictor index via
     * {@code findTargetClassPredictorIndex}, preserving the order of {@code strArr}.
     *
     * @param model  model whose output class names are searched
     * @param strArr target class names
     * @return one index per entry of {@code strArr}
     */
    private int[] findTargetClassPredictorIndices(Model model, String[] strArr) {
        final int[] indices = new int[strArr.length];
        int slot = 0;
        for (String target : strArr) {
            indices[slot++] = findTargetClassPredictorIndex(model, target);
        }
        return indices;
    }

    /**
     * Validates the user-supplied parameters and derives the internal ones.
     *
     * <p>Derived state: {@code _predictor_columns} (and {@code _predictor_column} for models with
     * one or two classes), {@code _cols} when neither columns nor column pairs were given,
     * {@code _cols_1d_2d}, {@code _num_1D}, {@code _num_2D_pairs} and, when user splits are
     * supplied, {@code _user_splits_present} and {@code _user_split_per_col}.
     *
     * @throws IllegalArgumentException if the model or frame cannot be found, the model is not
     *         supervised, the targets do not fit the number of classes, no columns can be
     *         determined, {@code _nbins < 2}, the user split parameters are incomplete, the weight
     *         column is not numeric, a requested column is missing from the frame, or a
     *         categorical column has more levels than {@code _nbins}
     */
    private void checkSanityAndFillParams() {
        final Model model = this._model_id.get();
        if (model == null) {
            throw new IllegalArgumentException("Model not found.");
        }
        if (!model._output.isSupervised()) {
            throw new IllegalArgumentException("Partial dependence plots are only implemented for supervised models");
        }
        final int nclasses = model._output.nclasses();
        if (nclasses <= 2 && this._targets != null) {
            throw new IllegalArgumentException("Targets parameter is available only for multinomial classification.");
        }
        if (nclasses == 1) {
            this._predictor_column = 0;
            this._predictor_columns = new int[]{this._predictor_column};
        } else if (nclasses == 2) {
            this._predictor_column = 2;
            this._predictor_columns = new int[]{this._predictor_column};
        } else {
            if (this._targets == null) {
                throw new IllegalArgumentException("Targets parameter has to be set for multinomial classification.");
            }
            this._predictor_columns = findTargetClassPredictorIndices(model, this._targets);
        }

        // Collect the distinct columns named by the 1D list and the 2D pairs.
        if (this._cols == null && this._col_pairs_2dpdp == null) {
            this._cols_1d_2d = null;
        } else {
            this._cols_1d_2d = new ArrayList<>();
            if (this._cols != null) {
                this._cols_1d_2d.addAll(Arrays.asList(this._cols));
            }
            if (this._col_pairs_2dpdp != null) {
                this._num_2D_pairs = this._col_pairs_2dpdp.length * this._predictor_columns.length;
                // Iterate over the pairs themselves: _num_2D_pairs also counts the targets and
                // therefore exceeds the pair array length when there is more than one target.
                for (String[] pair : this._col_pairs_2dpdp) {
                    if (!this._cols_1d_2d.contains(pair[0])) {
                        this._cols_1d_2d.add(pair[0]);
                    }
                    if (!this._cols_1d_2d.contains(pair[1])) {
                        this._cols_1d_2d.add(pair[1]);
                    }
                }
            }
        }

        final Frame frame = this._frame_id.get();
        if (frame == null) {
            throw new IllegalArgumentException("Frame not found.");
        }

        // Nothing requested: fall back to the model's top features or to all of its output names.
        if (this._cols_1d_2d == null) {
            final boolean hasImportances = model instanceof Model.GetMostImportantFeatures;
            if (hasImportances) {
                this._cols = ((Model.GetMostImportantFeatures) model).getMostImportantFeatures(10);
            } else {
                this._cols = model._output._names;
            }
            if (this._cols == null) {
                throw new IllegalArgumentException("No columns available for partial dependence; specify the columns explicitly.");
            }
            if (hasImportances) {
                Log.info("Selecting the top " + this._cols.length + " features from the model's variable importances.");
            } else {
                Log.info("Selecting all features from the training data.");
            }
            this._cols_1d_2d = new ArrayList<>(Arrays.asList(this._cols));
        }
        this._num_1D = this._cols == null ? 0 : this._cols.length * this._predictor_columns.length;

        if (this._nbins < 2) {
            throw new IllegalArgumentException("_nbins must be >=2.");
        }

        // Slice the flat user split array into one array per user column.
        if (this._user_splits != null && this._user_splits.length > 0) {
            if (this._user_cols == null || this._num_user_splits == null || this._num_user_splits.length < this._user_cols.length) {
                throw new IllegalArgumentException("User splits require the user columns and the number of splits for each of them.");
            }
            this._user_splits_present = true;
            final int userColCount = this._user_cols.length;
            this._user_split_per_col = new double[userColCount][];
            int offset = 0;
            for (int c = 0; c < userColCount; c++) {
                final int splitCount = this._num_user_splits[c];
                this._user_split_per_col[c] = new double[splitCount];
                System.arraycopy(this._user_splits, offset, this._user_split_per_col[c], 0, splitCount);
                offset += splitCount;
            }
        }

        if (this._weight_column_index >= 0) {
            final Vec weights = frame.vec(this._weight_column_index);
            if (!weights.isNumeric() || weights.isCategorical()) {
                throw new IllegalArgumentException("Weight column " + this._weight_column_index + " must be a numerical column.");
            }
        }

        for (String name : this._cols_1d_2d) {
            final Vec vec = frame.vec(name);
            if (vec == null) {
                throw new IllegalArgumentException("Column " + name + " not found in the frame.");
            }
            if (vec.isCategorical() && vec.cardinality() > this._nbins) {
                throw new IllegalArgumentException("Column " + name + "'s cardinality of " + vec.cardinality() + " > nbins of " + this._nbins);
            }
        }
    }

    /**
     * Builds the grid of values at which a column is evaluated.
     *
     * <p>When user splits are present for the column, those values are used. Otherwise the grid
     * is evenly spaced from the vec's minimum to its maximum; for integer vecs whose value range
     * is smaller than {@code _nbins} the number of grid points is reduced to that range. When
     * {@code _add_missing_na} is set, one extra trailing NaN entry is appended.
     *
     * @param str name of the column
     * @param i   requested number of grid points
     * @param vec the column's data
     * @return the grid values, with the optional NaN entry last
     */
    double[] extractColValues(String str, int i, Vec vec) {
        final int naSlots = this._add_missing_na ? 1 : 0;
        final int userIdx = this._user_splits_present ? Arrays.asList(this._user_cols).indexOf(str) : -1;

        int gridSize = i;
        final double[] grid;
        if (userIdx >= 0) {
            gridSize = this._num_user_splits[userIdx];
            grid = new double[gridSize + naSlots];
            System.arraycopy(this._user_split_per_col[userIdx], 0, grid, 0, gridSize);
        } else {
            final double lo = vec.min();
            final double hi = vec.max();
            final double span = hi - lo;
            if (vec.isInt() && span + 1.0d < this._nbins) {
                gridSize = (int) (span + 1.0d);
            }
            grid = new double[gridSize + naSlots];
            final double step = gridSize == 1 ? 0.0d : span / (gridSize - 1);
            for (int k = 0; k < grid.length; k++) {
                grid[k] = lo + (k * step);
            }
        }
        if (naSlots == 1) {
            // The trailing slot represents the missing value.
            grid[gridSize] = Double.NaN;
        }
        Log.debug("Computing PartialDependence for column " + str + " at the following values: ");
        Log.debug(Arrays.toString(grid));
        return grid;
    }

    /**
     * Returns the key schema class associated with this type.
     *
     * @return {@code KeyV3.PartialDependenceKeyV3.class}
     */
    @Override // water.Keyed
    public Class<KeyV3.PartialDependenceKeyV3> makeSchema() {
        return KeyV3.PartialDependenceKeyV3.class;
    }
}
