package hex.grid;

import hex.Model;
import hex.Model.Parameters;
import hex.ModelParametersBuilderFactory;
import hex.ScoreKeeper;
import hex.ScoringInfo;
import hex.grid.HyperSpaceSearchCriteria;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Random;
import java.util.Set;
import java.util.function.Consumer;
import java.util.stream.Stream;
import water.exceptions.H2OIllegalArgumentException;
import water.util.ArrayUtils;
import water.util.PojoUtils;

/* loaded from: input_file:hex/grid/HyperSpaceWalker.class */
public interface HyperSpaceWalker<MP extends Model.Parameters, C extends HyperSpaceSearchCriteria> {

    /**
     * Shared state and helpers for the concrete hyper-space walkers.
     *
     * <p>The hyper-parameter map given to the constructor may contain an entry under {@link #SUBSPACES}
     * whose value is an array of hyper-parameter maps. That entry is removed from the map (the caller's map
     * is mutated). When the space is walked, each subspace is merged with the remaining ("shared")
     * hyper-parameters.
     */
    public static abstract class BaseWalker<MP extends Model.Parameters, C extends HyperSpaceSearchCriteria> implements HyperSpaceWalker<MP, C> {
        /** Key of the optional hyper-parameter entry that holds an array of subspace maps. */
        public static final String SUBSPACES = "subspaces";

        protected final C _search_criteria;
        final transient ModelParametersBuilderFactory<MP> _paramsBuilderFactory;
        /** User-supplied model parameters; every generated parameter set starts from a copy of these. */
        final MP _params;
        /** A freshly constructed instance of {@code _params}' class, used to detect explicitly set values. */
        final MP _defaultParams;
        /** Shared hyper-parameters; the {@link #SUBSPACES} entry has already been removed. */
        protected final Map<String, Object[]> _hyperParams;
        /**
         * Hyper-parameter names captured BEFORE {@link #SUBSPACES} was removed, so this may still contain
         * {@code "subspaces"}; {@link HyperSpaceWalker#getAllHyperParamNames()} strips it.
         */
        protected final String[] _hyperParamNames;
        /** Size of the hyper-space as computed by {@link #computeMaxSizeOfHyperSpace()}. */
        protected final long _maxHyperSpaceSize;
        /** Counter used by the random walker to derive per-model seeds. */
        long model_number = 0;
        /** Subspace maps taken from the {@link #SUBSPACES} entry; empty when there is none. */
        protected Map<String, Object[]>[] _hyperParamSubspaces;
        /** Names of all hyper-parameters appearing in any subspace, in subspace order (duplicates kept). */
        protected final String[] _hyperParamNamesSubspace;

        /** Creates the walker matching the strategy of the given search criteria. */
        public static class WalkerFactory<MP extends Model.Parameters, C extends HyperSpaceSearchCriteria> {
            /**
             * @param params               base model parameters
             * @param hyperParams          hyper-parameter map (its {@link #SUBSPACES} entry is removed)
             * @param paramsBuilderFactory factory used to apply hyper-parameter values
             * @param searchCriteria       criteria whose {@code strategy()} selects the walker
             * @return a Cartesian or random-discrete walker
             * @throws H2OIllegalArgumentException for any other strategy
             */
            public static <MP extends Model.Parameters, C extends HyperSpaceSearchCriteria> HyperSpaceWalker<MP, ? extends HyperSpaceSearchCriteria> create(MP params, Map<String, Object[]> hyperParams, ModelParametersBuilderFactory<MP> paramsBuilderFactory, C searchCriteria) {
                HyperSpaceSearchCriteria.Strategy strategy = searchCriteria.strategy();
                switch (strategy) {
                    case Cartesian:
                        return new CartesianWalker<>(params, hyperParams, paramsBuilderFactory, (HyperSpaceSearchCriteria.CartesianSearchCriteria) searchCriteria);
                    case RandomDiscrete:
                        return new RandomDiscreteValueWalker<>(params, hyperParams, paramsBuilderFactory, (HyperSpaceSearchCriteria.RandomDiscreteValueSearchCriteria) searchCriteria);
                    default:
                        throw new H2OIllegalArgumentException("strategy", "GridSearch", strategy);
                }
            }
        }

        /**
         * @param params               base model parameters
         * @param hyperParams          hyper-parameter map; its {@link #SUBSPACES} entry, if any, is removed
         * @param paramsBuilderFactory factory used to apply hyper-parameter values to parameter copies
         * @param searchCriteria       search criteria
         * @throws H2OIllegalArgumentException if default parameters cannot be instantiated, a hyper-parameter
         *         has an empty value list, or a value is set ambiguously
         */
        public BaseWalker(MP params, Map<String, Object[]> hyperParams, ModelParametersBuilderFactory<MP> paramsBuilderFactory, C searchCriteria) {
            this._params = params;
            this._hyperParams = hyperParams;
            this._paramsBuilderFactory = paramsBuilderFactory;
            this._hyperParamNames = hyperParams.keySet().toArray(new String[0]);
            // Subspaces must be read after _hyperParams is assigned and before the entry is removed.
            this._hyperParamSubspaces = extractSubspaces();
            this._hyperParamNamesSubspace = extractSubspaceNames();
            this._hyperParams.remove(SUBSPACES);
            this._search_criteria = searchCriteria;
            this._maxHyperSpaceSize = computeMaxSizeOfHyperSpace();
            this._defaultParams = instantiateDefaults(params);
            // Validation runs outside the instantiation try/catch so its precise messages reach the caller.
            validateParams(this._hyperParams, false);
            for (Map<String, Object[]> subspace : this._hyperParamSubspaces) {
                validateParams(subspace, true);
            }
        }

        /** Creates a default-valued instance of {@code params}' class through its no-arg constructor. */
        @SuppressWarnings("unchecked")
        private static <P extends Model.Parameters> P instantiateDefaults(P params) {
            try {
                return (P) params.getClass().getDeclaredConstructor().newInstance();
            } catch (Exception e) {
                throw new H2OIllegalArgumentException("Failed to instantiate a new Model.Parameters object to get the default values.");
            }
        }

        @Override // hex.grid.HyperSpaceWalker
        public C search_criteria() {
            return this._search_criteria;
        }

        /** The base walker never stops a grid early. */
        @Override // hex.grid.HyperSpaceWalker
        public boolean stopEarly(Model model, ScoringInfo[] scoringInfoArr) {
            return false;
        }

        @Override // hex.grid.HyperSpaceWalker
        public String[] getHyperParamNames() {
            return this._hyperParamNames;
        }

        @Override // hex.grid.HyperSpaceWalker
        public String[] getAllHyperParamNamesInSubspaces() {
            return this._hyperParamNamesSubspace;
        }

        @Override // hex.grid.HyperSpaceWalker
        public Map<String, Object[]> getHyperParams() {
            return this._hyperParams;
        }

        @Override // hex.grid.HyperSpaceWalker
        public long getMaxHyperSpaceSize() {
            return this._maxHyperSpaceSize;
        }

        @Override // hex.grid.HyperSpaceWalker
        public MP getParams() {
            return this._params;
        }

        @Override // hex.grid.HyperSpaceWalker
        public ModelParametersBuilderFactory<MP> getParametersBuilderFactory() {
            return this._paramsBuilderFactory;
        }

        /** Returns the subspace maps stored under {@link #SUBSPACES}, or an empty array when absent/null. */
        @SuppressWarnings("unchecked")
        private Map<String, Object[]>[] extractSubspaces() {
            Object[] subspaces = this._hyperParams.get(SUBSPACES);
            return subspaces == null ? new Map[0] : (Map<String, Object[]>[]) subspaces;
        }

        /** Flattens the key sets of all subspaces into one array (duplicates are kept). */
        private String[] extractSubspaceNames() {
            return Arrays.stream(this._hyperParamSubspaces)
                    .flatMap(subspace -> subspace.keySet().stream())
                    .toArray(String[]::new);
        }

        /**
         * Applies hyper-parameter values to {@code mp} through the parameters builder.
         * The name {@code "valid"} is translated to {@code "validation_frame"}.
         *
         * @param mp     parameters to modify (callers pass a fresh copy)
         * @param objArr values, parallel to {@code strArr}
         * @param strArr hyper-parameter names
         * @return the result of the builder's {@code build()}
         */
        protected MP getModelParams(MP mp, Object[] objArr, String[] strArr) {
            ModelParametersBuilderFactory.ModelParametersBuilder<MP> builder = this._paramsBuilderFactory.get(mp);
            for (int i = 0; i < strArr.length; i++) {
                String name = "valid".equals(strArr[i]) ? "validation_frame" : strArr[i];
                builder.set(name, objArr[i]);
            }
            return builder.build();
        }

        /** Returns a copy of {@link #_params} typed as {@code MP}. */
        @SuppressWarnings("unchecked")
        final MP newParamsCopy() {
            return (MP) this._params.m381clone();
        }

        /**
         * Counts the points of the hyper-space: the product of the shared value-list lengths, multiplied by the
         * sum of the per-subspace products when subspaces exist. Expects {@link #SUBSPACES} to be removed from
         * {@link #_hyperParams} already. No overflow protection is applied.
         */
        protected long computeMaxSizeOfHyperSpace() {
            long subspaceTotal = 0;
            for (Map<String, Object[]> subspace : this._hyperParamSubspaces) {
                subspaceTotal += countCombinations(subspace);
            }
            long shared = countCombinations(this._hyperParams);
            return subspaceTotal == 0 ? shared : shared * subspaceTotal;
        }

        /** Product of the lengths of all value arrays in {@code space} (1 for an empty map). */
        private static long countCombinations(Map<String, Object[]> space) {
            long combinations = 1;
            for (Object[] values : space.values()) {
                combinations *= values.length;
            }
            return combinations;
        }

        /**
         * Returns a new map containing {@code map} overlaid with {@code map2}; returns {@code map} itself
         * (not a copy) when {@code map2} is null.
         */
        protected Map<String, Object[]> mergeHashMaps(Map<String, Object[]> map, Map<String, Object[]> map2) {
            if (map2 == null) {
                return map;
            }
            Map<String, Object[]> merged = new HashMap<>(map);
            merged.putAll(map2);
            return merged;
        }

        /** Picks {@code map.get(strArr[i])[iArr[i]]} for each position {@code i} of {@code iArr}. */
        protected Object[] hypers(Map<String, Object[]> map, String[] strArr, int[] iArr) {
            Object[] values = new Object[strArr.length];
            for (int i = 0; i < iArr.length; i++) {
                values[i] = map.get(strArr[i])[iArr[i]];
            }
            return values;
        }

        /**
         * Hash of an index permutation plus subspace number. Kept for compatibility; note that different
         * permutations can collide (e.g. value lists of sizes 2 and 62: (1,0) and (0,1)).
         */
        protected int integerHash(Map<String, Object[]> map, String[] strArr, int[] iArr, int i) {
            int[] hashMe = new int[iArr.length + 1];
            for (int k = 0; k < iArr.length; k++) {
                hashMe[k] = iArr[k] * map.get(strArr[k]).length;
            }
            hashMe[iArr.length] = i;
            // Arrays.hashCode(int[]) equals Arrays.deepHashCode of the boxed array, without boxing.
            return Arrays.hashCode(hashMe);
        }

        /**
         * Rejects empty value lists, parameters set both in a subspace and in the shared map
         * ({@code isSubspace}), and parameters also set explicitly on the model parameters. Seeds are exempt
         * from the last two checks.
         */
        private void validateParams(Map<String, Object[]> params, boolean isSubspace) {
            for (Map.Entry<String, Object[]> entry : params.entrySet()) {
                String key = entry.getKey();
                if (entry.getValue().length == 0) {
                    throw new H2OIllegalArgumentException("Grid search hyperparameter value list is empty for hyperparameter: " + key);
                }
                if ("seed".equals(key) || "_seed".equals(key)) {
                    continue;
                }
                if (isSubspace && this._hyperParams.containsKey(key)) {
                    throw new H2OIllegalArgumentException("Grid search model parameter '" + key + "' is set in both the subspaces and in the hyperparameters map.  This is ambiguous; set it in one place or the other, not both.");
                }
                validateParamVals(key);
            }
        }

        /** Throws if the model parameters hold a non-default value for hyper-parameter {@code key}. */
        private void validateParamVals(String key) {
            String field = (key.startsWith("_") ? "" : "_") + key;
            Object defaultVal = PojoUtils.getFieldValue(this._defaultParams, field, PojoUtils.FieldNaming.CONSISTENT);
            Object actualVal = PojoUtils.getFieldValue(this._params, field, PojoUtils.FieldNaming.CONSISTENT);
            boolean conflict;
            if (defaultVal == null) {
                conflict = actualVal != null;
            } else if (actualVal == null) {
                conflict = false;
            } else if (defaultVal.getClass().isArray()) {
                conflict = !PojoUtils.arraysEquals(defaultVal, actualVal);
            } else {
                conflict = !defaultVal.equals(actualVal);
            }
            if (conflict) {
                throw ambiguousWithModelParams(key);
            }
        }

        private static H2OIllegalArgumentException ambiguousWithModelParams(String key) {
            return new H2OIllegalArgumentException("Grid search model parameter '" + key + "' is set in both the model parameters and in the hyperparameters map.  This is ambiguous; set it in one place or the other, not both.");
        }
    }

    /**
     * Enumerates every combination: all shared x first-subspace combinations, then those of the next
     * subspace, and so on. Within a map, the first hyper-parameter name varies fastest.
     */
    public static class CartesianWalker<MP extends Model.Parameters> extends BaseWalker<MP, HyperSpaceSearchCriteria.CartesianSearchCriteria> {
        public CartesianWalker(MP mp, Map<String, Object[]> map, ModelParametersBuilderFactory<MP> modelParametersBuilderFactory, HyperSpaceSearchCriteria.CartesianSearchCriteria cartesianSearchCriteria) {
            super(mp, map, modelParametersBuilderFactory, cartesianSearchCriteria);
        }

        @Override // hex.grid.HyperSpaceWalker
        public HyperSpaceIterator<MP> iterator() {
            return new HyperSpaceIterator<MP>() {
                // -1 means "no subspaces"; otherwise index of the subspace currently being enumerated.
                private int _currentSubspace = _hyperParamSubspaces.length == 0 ? -1 : 0;
                private Map<String, Object[]> _currentHyperParams = _hyperParamSubspaces.length == 0
                        ? _hyperParams
                        : mergeHashMaps(_hyperParams, _hyperParamSubspaces[0]);
                private String[] _currentHyperParamNames = _currentHyperParams.keySet().toArray(new String[0]);
                // null until the first call to nextModelParameters().
                private int[] _currentHyperparamIndices = null;
                // Set once the space is used up, so the iterator never restarts.
                private boolean _exhausted = false;

                @Override // hex.grid.HyperSpaceWalker.HyperSpaceIterator
                public MP nextModelParameters() {
                    if (_exhausted) {
                        throw new NoSuchElementException("No more elements to explore in hyper-space!");
                    }
                    int[] next = _currentHyperparamIndices == null
                            ? new int[_currentHyperParamNames.length]
                            : nextModelIndices(_currentHyperparamIndices);
                    if (next == null && _currentSubspace < _hyperParamSubspaces.length - 1) {
                        // Current subspace is done: move to the next one and start at all-zero indices.
                        _currentSubspace++;
                        _currentHyperParams = mergeHashMaps(_hyperParams, _hyperParamSubspaces[_currentSubspace]);
                        _currentHyperParamNames = _currentHyperParams.keySet().toArray(new String[0]);
                        next = new int[_currentHyperParamNames.length];
                    }
                    if (next == null) {
                        _exhausted = true;
                        throw new NoSuchElementException("No more elements to explore in hyper-space!");
                    }
                    _currentHyperparamIndices = next;
                    return getModelParams(newParamsCopy(), hypers(_currentHyperParams, _currentHyperParamNames, _currentHyperparamIndices), _currentHyperParamNames);
                }

                @Override // hex.grid.HyperSpaceWalker.HyperSpaceIterator
                public boolean hasNext() {
                    if (_exhausted) {
                        return false;
                    }
                    if (_currentHyperparamIndices == null) {
                        return true;
                    }
                    // nextModelIndices mutates its argument, so probe on a copy.
                    boolean moreInSubspace = nextModelIndices(_currentHyperparamIndices.clone()) != null;
                    return moreInSubspace || _currentSubspace != _hyperParamSubspaces.length - 1;
                }

                @Override // hex.grid.HyperSpaceWalker.HyperSpaceIterator
                public void onModelFailure(Model model, Consumer<Object[]> consumer) {
                    consumer.accept(hypers(_currentHyperParams, _currentHyperParamNames, _currentHyperparamIndices));
                }

                /**
                 * Advances {@code indices} in place like an odometer (first position fastest). Returns the same
                 * array, or null without modifying it when every position is at its last value.
                 */
                private int[] nextModelIndices(int[] indices) {
                    int pos = 0;
                    while (pos < indices.length && indices[pos] + 1 >= _currentHyperParams.get(_currentHyperParamNames[pos]).length) {
                        pos++;
                    }
                    if (pos == indices.length) {
                        return null;
                    }
                    Arrays.fill(indices, 0, pos, 0);
                    indices[pos]++;
                    return indices;
                }
            };
        }
    }

    /** Iterator over the model parameter sets produced by a {@link HyperSpaceWalker}. */
    public interface HyperSpaceIterator<MP extends Model.Parameters> {
        /**
         * @return parameters for the next point of the hyper-space
         * @throws NoSuchElementException when the space is exhausted
         */
        MP nextModelParameters();

        /** @return whether {@link #nextModelParameters()} can yield another point */
        boolean hasNext();

        /** Reports the hyper-parameter values of the most recently returned point to {@code consumer}. */
        void onModelFailure(Model model, Consumer<Object[]> consumer);
    }

    /**
     * Samples points uniformly at random (after picking a random subspace, if any). Points are not repeated.
     * Stops when the space, the {@code max_models} budget, or the sampling attempts are exhausted.
     */
    public static class RandomDiscreteValueWalker<MP extends Model.Parameters> extends BaseWalker<MP, HyperSpaceSearchCriteria.RandomDiscreteValueSearchCriteria> {
        /** Minimum number of draws attempted before declaring the space exhausted. */
        private static final double MIN_NUMBER_OF_SAMPLES = 10000.0d;
        private final Random _random;
        /** True when the user left the model seed at its default and the grid seed differs from that default. */
        private final boolean _set_model_seed_from_search_seed;

        public RandomDiscreteValueWalker(MP mp, Map<String, Object[]> map, ModelParametersBuilderFactory<MP> modelParametersBuilderFactory, HyperSpaceSearchCriteria.RandomDiscreteValueSearchCriteria randomDiscreteValueSearchCriteria) {
            super(mp, map, modelParametersBuilderFactory, randomDiscreteValueSearchCriteria);
            long defaultSeed = this._defaultParams._seed;
            long modelSeed = this._params._seed;
            long gridSeed = randomDiscreteValueSearchCriteria.seed();
            this._set_model_seed_from_search_seed = defaultSeed == modelSeed && defaultSeed != gridSeed;
            // A grid seed equal to the default seed value yields an unseeded (non-reproducible) generator.
            this._random = gridSeed == defaultSeed ? new Random() : new Random(gridSeed);
        }

        @Override // hex.grid.HyperSpaceWalker.BaseWalker, hex.grid.HyperSpaceWalker
        public boolean stopEarly(Model model, ScoringInfo[] scoringInfoArr) {
            return ScoreKeeper.stopEarly(ScoringInfo.scoreKeepers(scoringInfoArr), search_criteria().stopping_rounds(), ScoreKeeper.ProblemType.forSupervised(model._output.isClassifier()), search_criteria().stopping_metric(), search_criteria().stopping_tolerance(), "grid's best", true);
        }

        @Override // hex.grid.HyperSpaceWalker
        public HyperSpaceIterator<MP> iterator() {
            return new HyperSpaceIterator<MP>() {
                private Map<String, Object[]> _currentHyperParams = _hyperParams;
                // Derived from _hyperParams (not _hyperParamNames) so a leftover "subspaces" name is never used.
                private String[] _currentHyperParamNames = _hyperParams.keySet().toArray(new String[0]);
                // Collision-free keys of points already handed out (integerHash can collide).
                private final Set<String> _visitedPermutations = new LinkedHashSet<>();
                private int[] _currentHyperparamIndices = null;
                private int _currentPermutationNum = 0;
                private int _currentSubspace = -1;
                private boolean _exhausted = false;

                @Override // hex.grid.HyperSpaceWalker.HyperSpaceIterator
                public MP nextModelParameters() {
                    _currentHyperparamIndices = nextModelIndices();
                    if (_currentHyperparamIndices == null) {
                        throw new NoSuchElementException("No more elements to explore in hyper-space!");
                    }
                    _visitedPermutations.add(permutationKey(_currentHyperparamIndices));
                    _currentPermutationNum++;
                    MP params = getModelParams(newParamsCopy(), hypers(_currentHyperParams, _currentHyperParamNames, _currentHyperparamIndices), _currentHyperParamNames);
                    HyperSpaceSearchCriteria.RandomDiscreteValueSearchCriteria criteria = search_criteria();
                    if (criteria != null && criteria.strategy() == HyperSpaceSearchCriteria.Strategy.RandomDiscrete && _set_model_seed_from_search_seed) {
                        params._seed = criteria.seed() + model_number++;
                    }
                    return params;
                }

                @Override // hex.grid.HyperSpaceWalker.HyperSpaceIterator
                public boolean hasNext() {
                    long maxModels = search_criteria().max_models();
                    return _visitedPermutations.size() < _maxHyperSpaceSize
                            && (maxModels == 0 || _currentPermutationNum < maxModels)
                            && !_exhausted;
                }

                /** Failed models do not count toward max_models, but their point is not retried. */
                @Override // hex.grid.HyperSpaceWalker.HyperSpaceIterator
                public void onModelFailure(Model model, Consumer<Object[]> consumer) {
                    _currentPermutationNum--;
                    consumer.accept(hypers(_currentHyperParams, _currentHyperParamNames, _currentHyperparamIndices));
                }

                /** Identifies a point by its subspace and value indices. */
                private String permutationKey(int[] indices) {
                    return _currentSubspace + ":" + Arrays.toString(indices);
                }

                /**
                 * Draws random points until an unvisited one is found. Gives up after
                 * max(MIN_NUMBER_OF_SAMPLES, space size) attempts, marking the iterator exhausted.
                 */
                private int[] nextModelIndices() {
                    int[] indices = new int[_currentHyperParamNames.length];
                    double maxAttempts = Math.max(MIN_NUMBER_OF_SAMPLES, _maxHyperSpaceSize);
                    // long counter: an int would wrap (and never terminate) for spaces beyond Integer.MAX_VALUE.
                    for (long attempt = 0; attempt < maxAttempts; attempt++) {
                        if (_hyperParamSubspaces.length != 0) {
                            _currentSubspace = _random.nextInt(_hyperParamSubspaces.length);
                            _currentHyperParams = mergeHashMaps(_hyperParams, _hyperParamSubspaces[_currentSubspace]);
                            _currentHyperParamNames = _currentHyperParams.keySet().toArray(new String[0]);
                            indices = new int[_currentHyperParamNames.length];
                        }
                        for (int i = 0; i < _currentHyperParamNames.length; i++) {
                            indices[i] = _random.nextInt(_currentHyperParams.get(_currentHyperParamNames[i]).length);
                        }
                        if (!_visitedPermutations.contains(permutationKey(indices))) {
                            return indices;
                        }
                    }
                    _exhausted = true;
                    return null;
                }
            };
        }

        /** The random walk is unbounded in advance, so no finite estimate is given. */
        @Override // hex.grid.HyperSpaceWalker
        public long estimateGridWork(long j) {
            return Long.MAX_VALUE;
        }
    }

    C search_criteria();

    /** @return whether the grid search should stop after the given model */
    boolean stopEarly(Model model, ScoringInfo[] scoringInfoArr);

    HyperSpaceIterator<MP> iterator();

    String[] getHyperParamNames();

    String[] getAllHyperParamNamesInSubspaces();

    /**
     * Returns the shared hyper-parameter names. When subspaces exist, the {@code "subspaces"} key is
     * replaced by the names found inside the subspaces.
     */
    default String[] getAllHyperParamNames() {
        String[] subspaceNames = getAllHyperParamNamesInSubspaces();
        if (subspaceNames.length == 0) {
            return getHyperParamNames();
        }
        return ArrayUtils.append(ArrayUtils.remove(getHyperParamNames(), BaseWalker.SUBSPACES), subspaceNames);
    }

    Map<String, Object[]> getHyperParams();

    long getMaxHyperSpaceSize();

    MP getParams();

    /**
     * Estimates total work by walking the whole space and summing {@code progressUnits()} per point,
     * multiplied by {@code nfolds + 1} when cross-validation is on. Returns {@link Long#MAX_VALUE} when the
     * space size or {@code j} (max models) is not positive; {@code j} is not otherwise used as a cap.
     */
    default long estimateGridWork(long j) {
        if (getMaxHyperSpaceSize() <= 0 || j <= 0) {
            return Long.MAX_VALUE;
        }
        long gridWork = 0;
        HyperSpaceIterator<MP> it = iterator();
        while (it.hasNext()) {
            try {
                MP params = it.nextModelParameters();
                int modelsPerPoint = params._nfolds > 0 ? params._nfolds + 1 : 1;
                gridWork += modelsPerPoint * params.progressUnits();
            } catch (Throwable ignored) {
                // Invalid hyper-parameter combinations are skipped; they contribute no work.
            }
        }
        return gridWork;
    }

    ModelParametersBuilderFactory<MP> getParametersBuilderFactory();
}
