/*
 * Decompiled with CFR 0.152.
 */
package hex.grid;

import hex.Model;
import hex.ModelBuilder;
import hex.ModelParametersBuilderFactory;
import hex.ParallelModelBuilder;
import hex.ScoreKeeper;
import hex.ScoringInfo;
import hex.grid.Grid;
import hex.grid.HyperSpaceSearchCriteria;
import hex.grid.HyperSpaceWalker;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import jsr166y.CountedCompleter;
import water.DKV;
import water.H2O;
import water.Job;
import water.Key;
import water.KeySnapshot;
import water.Keyed;
import water.Value;
import water.exceptions.H2OConcurrentModificationException;
import water.exceptions.H2OIllegalArgumentException;
import water.fvec.Frame;
import water.util.Log;
import water.util.PojoUtils;

public final class GridSearch<MP extends Model.Parameters>
extends Keyed<GridSearch> {
    /** Key of the grid this search creates or appends to; also handed to the job constructor as its first argument. */
    public final Key<Grid> _result;
    /** Job created in the constructor and started by {@code start()}; its description is "&lt;algo name&gt; Grid Search". */
    public final Job<Grid> _job;
    /** Requested parallelism: {@code start()} runs the sequential search for 1 and the parallel search for values above 1. */
    public final int _parallelism;
    /** Source of the base parameters, hyper-parameter names, search criteria and the iterator over parameter combinations. */
    private final transient HyperSpaceWalker<MP, ?> _hyperSpaceWalker;
    // Field names passed to Model.Parameters.checksum(...) wherever parameter sets are compared in this class.
    // NOTE(review): presumably these fields are left out of the hash - confirm against checksum().
    private static final Set<String> IGNORED_FIELDS_PARAM_HASH = Collections.singleton("_export_checkpoints_dir");
    /** Parallelism value that {@code getParallelismLevel(int)} replaces with {@code getAdaptiveParallelism()}. */
    public static final int ADAPTIVE_PARALLELISM_LEVEL = 0;
    /** Parallelism value for which {@code start()} builds models one after another via the sequential search. */
    public static final int SEQUENTIAL_MODEL_BUILDING = 1;

    /**
     * Wires up a grid search; the search itself is kicked off by {@code start()}.
     *
     * @param gkey             key of the resulting grid, also used as the job's first constructor argument
     * @param hyperSpaceWalker walker over the hyper space; must not be null (checked with an assertion)
     * @param parallelism      parallelism level stored in {@code _parallelism}
     */
    private GridSearch(Key<Grid> gkey, HyperSpaceWalker<MP, ?> hyperSpaceWalker, int parallelism) {
        assert (hyperSpaceWalker != null) : "Grid search needs to know how to walk around hyper space!";
        _hyperSpaceWalker = hyperSpaceWalker;
        _parallelism = parallelism;
        _result = gkey;
        // The job description is derived from the algorithm name of the walker's base parameters.
        final Model.Parameters baseParams = (Model.Parameters)hyperSpaceWalker.getParams();
        final String jobDescription = baseParams.algoName() + " Grid Search";
        _job = new Job<Grid>(gkey, Grid.class.getName(), jobDescription);
    }

    /**
     * Locks (or creates) the target grid, estimates the total work and starts the grid search job.
     *
     * <p>If an object already exists under {@code _result} it must be a {@link Grid}. When that grid
     * has a training frame, the frame of the current parameters must have the same key and checksum.
     * The existing grid is then write-locked; otherwise a new grid is created and locked.
     *
     * <p>The work estimate walks the whole hyper space once, but only when the space size and the
     * configured maximum number of models are both positive; otherwise {@code Long.MAX_VALUE} is used.
     *
     * @return the job returned by {@code _job.start(...)}
     * @throws H2OIllegalArgumentException if the key is taken by a non-Grid object, or the existing
     *         grid was trained on a different training frame
     */
    Job<Grid> start() {
        final long gridSize = this._hyperSpaceWalker.getMaxHyperSpaceSize();
        Log.info("Starting gridsearch: estimated size of search space = " + gridSize);

        final Grid<MP> grid;
        Keyed keyed = (Keyed)DKV.getGet(this._result);
        if (keyed != null) {
            if (!(keyed instanceof Grid)) {
                throw new H2OIllegalArgumentException("Name conflict: tried to create a Grid using the ID of a non-Grid object that's already in H2O: " + this._job._result + "; it is a: " + keyed.getClass());
            }
            grid = (Grid<MP>)keyed;
            grid.clearNonRelatedFailures();
            Frame specTrainFrame = ((Model.Parameters)this._hyperSpaceWalker.getParams()).train();
            Frame oldTrainFrame = grid.getTrainingFrame();
            // Appending is only allowed when the training input is the very same frame (key and content).
            if (oldTrainFrame != null
                    && (!specTrainFrame._key.equals(oldTrainFrame._key) || specTrainFrame.checksum() != oldTrainFrame.checksum())) {
                throw new H2OIllegalArgumentException("training_frame", "grid", "Cannot append new models to a grid with different training input");
            }
            grid.write_lock(this._job);
        } else {
            grid = new Grid<MP>(this._result, this._hyperSpaceWalker.getParams(), this._hyperSpaceWalker.getHyperParamNames(), this._hyperSpaceWalker.getParametersBuilderFactory().getFieldNamingStrategy());
            grid.delete_and_lock(this._job);
        }

        Model model = null;
        HyperSpaceWalker.HyperSpaceIterator<MP> it = this._hyperSpaceWalker.iterator();
        long gridWork = 0L;
        if (gridSize > 0L && this.maxModels() > 0L) {
            while (it.hasNext(model)) {
                try {
                    MP parms = it.nextModelParameters(model);
                    // With cross-validation each combination trains one model per fold plus the main model.
                    gridWork += (long)(((Model.Parameters)parms)._nfolds > 0 ? ((Model.Parameters)parms)._nfolds + 1 : 1) * ((Model.Parameters)parms).progressUnits();
                }
                catch (Throwable ignored) {
                    // Deliberately best-effort: a combination whose parameters cannot be produced
                    // simply contributes nothing to the work estimate. The search itself walks the
                    // hyper space again with its own iterator and handles failures there.
                }
            }
        } else {
            gridWork = Long.MAX_VALUE;
        }

        return this._job.start(new H2O.H2OCountedCompleter(){

            @Override
            public void compute2() {
                if (GridSearch.this._parallelism == 1) {
                    GridSearch.this.gridSearch(grid);
                } else if (GridSearch.this._parallelism > 1) {
                    GridSearch.this.parallelGridSearch(grid);
                } else {
                    // Neither search routine runs on this path, so nothing else would release the
                    // lock taken above; release it here before failing the job.
                    grid.unlock(GridSearch.this._job);
                    throw new IllegalArgumentException(String.format("Grid search parallelism level must be >= 1. Give value is '%d'.", GridSearch.this._parallelism));
                }
                this.tryComplete();
            }

            @Override
            public boolean onExceptionalCompletion(Throwable ex, CountedCompleter caller) {
                Log.warn("GridSearch job " + GridSearch.this._job._description + " completed with exception: " + ex);
                return true;
            }
        }, gridWork, this.maxRuntimeSecs());
    }

    /**
     * Reports the size of the hyper space as given by the walker.
     *
     * @return the value of {@code getMaxHyperSpaceSize()} on the hyper space walker
     */
    public long getModelCount() {
        final long maxHyperSpaceSize = this._hyperSpaceWalker.getMaxHyperSpaceSize();
        return maxHyperSpaceSize;
    }

    /**
     * Reads the model-count limit from the search criteria.
     *
     * @return the stopping criteria's max-models value, or 0 when there are no stopping criteria
     */
    private long maxModels() {
        final HyperSpaceSearchCriteria criteria = (HyperSpaceSearchCriteria)this._hyperSpaceWalker.search_criteria();
        if (criteria.stoppingCriteria() == null) {
            return 0L;
        }
        return (long)criteria.stoppingCriteria().getMaxModels();
    }

    /**
     * Reads the runtime limit from the search criteria.
     *
     * @return the stopping criteria's max runtime in seconds, or 0.0 when there are no stopping criteria
     */
    private double maxRuntimeSecs() {
        final HyperSpaceSearchCriteria criteria = (HyperSpaceSearchCriteria)this._hyperSpaceWalker.search_criteria();
        if (criteria.stoppingCriteria() == null) {
            return 0.0;
        }
        return criteria.stoppingCriteria().getMaxRuntimeSecs();
    }

    /**
     * Computes how much of the job's runtime budget is left.
     *
     * @return seconds until start time plus max runtime is reached (negative once exceeded), or
     *         {@code Double.MAX_VALUE} when there is no job or the job has no positive runtime limit
     */
    private double remainingTimeSecs() {
        if (this._job == null || this._job._max_runtime_msecs <= 0L) {
            return Double.MAX_VALUE;
        }
        final long deadlineMillis = this._job.start_time() + this._job._max_runtime_msecs;
        return (double)(deadlineMillis - System.currentTimeMillis()) / 1000.0;
    }

    /**
     * Picks the metric used to sort the grid's scoring infos.
     *
     * @return the stopping criteria's stopping metric, or {@code AUTO} when there are no stopping criteria
     */
    private ScoreKeeper.StoppingMetric sortingMetric() {
        final HyperSpaceSearchCriteria criteria = (HyperSpaceSearchCriteria)this._hyperSpaceWalker.search_criteria();
        if (criteria.stoppingCriteria() == null) {
            return ScoreKeeper.StoppingMetric.AUTO;
        }
        return criteria.stoppingCriteria().getStoppingMetric();
    }

    /**
     * Builds the grid's models in parallel.
     *
     * <p>Up to {@code _parallelism} parameter sets are taken from the iterator up front; those whose
     * checksum is not yet known to the grid are handed to a {@link ParallelModelBuilder}. Further
     * models are supplied by the {@code ModelFeeder} callback, which shares the same iterator.
     * Afterwards the grid is updated and saved (if an export directory is configured).
     *
     * <p>The grid is unlocked in a {@code finally} block so that a failure while creating, running
     * or joining the builders cannot leave it locked, matching the sequential search.
     *
     * @param grid the grid, already locked by {@code _job}, that receives the models
     */
    private void parallelGridSearch(Grid<MP> grid) {
        try {
            HyperSpaceWalker.HyperSpaceIterator<MP> iterator = this._hyperSpaceWalker.iterator();
            ModelFeeder modelFeeder = new ModelFeeder(iterator, grid);
            ParallelModelBuilder parallelModelBuilder = new ParallelModelBuilder(modelFeeder);

            List<ModelBuilder> startModels = new ArrayList<ModelBuilder>();
            for (Model.Parameters nextModelParameters : this.initialModelParameters(this._parallelism, iterator)) {
                long checksum = nextModelParameters.checksum(IGNORED_FIELDS_PARAM_HASH);
                // Skip combinations the grid already has a model key for.
                if (grid.getModelKey(checksum) != null) {
                    continue;
                }
                startModels.add((ModelBuilder)ModelBuilder.make(nextModelParameters));
            }

            if (!startModels.isEmpty()) {
                parallelModelBuilder.run(startModels);
                parallelModelBuilder.join();
            }
            grid.update(this._job);
            this.attemptGridSave(grid);
        }
        finally {
            grid.unlock(this._job);
        }
    }

    /**
     * Draws the first parameter sets from the given iterator.
     *
     * @param numParams maximum number of parameter sets to draw
     * @param iterator  iterator to draw from; it is advanced by this call
     * @return at most {@code numParams} parameter sets, fewer if the iterator runs out first
     */
    public List<MP> initialModelParameters(int numParams, HyperSpaceWalker.HyperSpaceIterator<MP> iterator) {
        final ArrayList<MP> drawn = new ArrayList<MP>(numParams);
        while (drawn.size() < numParams && iterator.hasNext(null)) {
            final MP next = iterator.nextModelParameters(null);
            drawn.add(next);
        }
        return drawn;
    }

    /**
     * Sequential grid search: builds one model per parameter combination produced by the hyper
     * space iterator and records successes and failures in the given grid.
     *
     * <p>After every combination - whether it succeeded or failed - the job progress is advanced
     * by one unit, the grid is updated and a save is attempted. The grid is unlocked when the
     * method exits, normally or exceptionally.
     *
     * @param grid the grid, already locked by {@code _job}, that receives the models
     * @throws Job.JobCancelledException if a stop was requested on the job (also thrown by
     *         {@code reconcileMaxRuntime} once the grid's time limit has expired)
     */
    private void gridSearch(Grid<MP> grid) {
        Model model = null;
        // Model keys are the grid key followed by "_model_" and a running index.
        String protoModelKey = grid._key + "_model_";
        try {
            HyperSpaceWalker.HyperSpaceIterator<MP> it = this._hyperSpaceWalker.iterator();
            // Numbering continues after the models the grid already holds.
            int counter = grid.getModelCount();
            while (it.hasNext(model)) {
                if (this._job.stop_requested()) {
                    throw new Job.JobCancelledException();
                }
                Model.Parameters params = null;
                try {
                    params = (Model.Parameters)it.nextModelParameters(model);
                    // Caps the model's runtime by the grid's remaining time; throws when the grid's time is up.
                    this.reconcileMaxRuntime(grid._key, params);
                    try {
                        // The time stamp is taken before the build starts.
                        ScoringInfo scoringInfo = new ScoringInfo();
                        scoringInfo.time_stamp_ms = System.currentTimeMillis();
                        model = this.buildModel(params, grid, ++counter, protoModelKey);
                        if (model != null) {
                            model.fillScoringInfo(scoringInfo);
                            grid.setScoringInfos(ScoringInfo.prependScoringInfo(scoringInfo, grid.getScoringInfos()));
                            ScoringInfo.sort(grid.getScoringInfos(), this.sortingMetric());
                        }
                    }
                    catch (RuntimeException e) {
                        // A failed build does not abort the search: it is recorded on the grid and the
                        // loop moves on. Cancellation exceptions are recorded too, just not logged.
                        if (!Job.isCancelledException(e)) {
                            Log.warn("Grid search: model builder for parameters " + params + " failed! Exception: ", e);
                            Log.err(e);
                        }
                        grid.appendFailedModelParameters((Key<Model>)(model != null ? model._key : null), params, (Throwable)e);
                    }
                }
                catch (IllegalArgumentException e) {
                    // Reached when producing the parameters (or reconciling the runtime) fails; the
                    // iterator is told about the failure and reports the raw hyper-parameter values.
                    Log.warn("Grid search: construction of model parameters failed! Exception: ", e);
                    Model failedModel = model;
                    it.onModelFailure(failedModel, failedHyperParams -> grid.appendFailedModelParameters((Key<Model>)(failedModel != null ? failedModel._key : null), (Object[])failedHyperParams, e));
                }
                finally {
                    // Runs for every combination, including when an exception is propagating.
                    this._job.update(1L);
                    grid.update(this._job);
                    this.attemptGridSave(grid);
                }
                // Early stopping is only evaluated once a model and scoring history exist.
                if (model == null || grid.getScoringInfos() == null || !this._hyperSpaceWalker.stopEarly(model, grid.getScoringInfos())) continue;
                Log.info("Convergence detected based on simple moving average of the loss function. Grid building completed.");
                break;
            }
            Log.info("For grid: " + grid._key + " built: " + grid.getModelCount() + " models.");
        }
        finally {
            grid.unlock(this._job);
        }
    }

    /**
     * Aligns a model's runtime limit with the time the grid search has left.
     *
     * <p>A model with its own positive limit keeps the smaller of that limit and the remaining
     * grid time; a model without one receives the remaining grid time as its limit.
     *
     * @param gridKey key of the grid, used for logging only
     * @param params  model parameters whose {@code _max_runtime_secs} is overwritten in place
     * @throws Job.JobCancelledException if the job has a runtime limit and it has already expired
     */
    private void reconcileMaxRuntime(Key<Grid<MP>> gridKey, Model.Parameters params) {
        final double gridLimitSecs = (double)this._job._max_runtime_msecs / 1000.0;
        final double remainingSecs = this.remainingTimeSecs();

        final boolean gridIsTimeLimited = gridLimitSecs > 0.0;
        if (gridIsTimeLimited) {
            Log.info("Grid time is limited to: " + gridLimitSecs + " for grid: " + gridKey + ". Remaining time is: " + remainingSecs);
            if (remainingSecs < 0.0) {
                Log.info("Grid max_runtime_secs of " + gridLimitSecs + " secs has expired; stopping early.");
                throw new Job.JobCancelledException();
            }
        }

        final boolean modelHasOwnLimit = params._max_runtime_secs > 0.0;
        if (!modelHasOwnLimit) {
            params._max_runtime_secs = remainingSecs;
            Log.info("Due to the grid time limit, changing model max runtime to: " + params._max_runtime_secs + " secs.");
            return;
        }
        final double previousLimit = params._max_runtime_secs;
        params._max_runtime_secs = Math.min(params._max_runtime_secs, remainingSecs);
        Log.info("Due to the grid time limit, changing model max runtime from: " + previousLimit + " secs to: " + params._max_runtime_secs + " secs.");
    }

    /**
     * Exports the grid to the checkpoint directory configured on the base parameters, if any.
     *
     * <p>Saving is best-effort: an {@link IOException} is logged as a warning, including the
     * exception itself so the cause is not lost, and is not propagated.
     *
     * @param grid the grid to export
     */
    private void attemptGridSave(Grid grid) {
        String whereToExport = ((Model.Parameters)this._hyperSpaceWalker.getParams())._export_checkpoints_dir;
        if (whereToExport == null) {
            // No export directory configured: nothing to do.
            return;
        }
        try {
            grid.exportBinary(whereToExport);
        }
        catch (IOException e) {
            Log.warn(String.format("Could not save grid '%s' to location '%s'", grid._key.toString(), whereToExport), e);
        }
    }

    /**
     * Returns a model for the given parameters, training one only when no matching model exists.
     *
     * <p>Lookup order: (1) a model key the grid already holds for the parameter checksum, provided
     * its value is still in the DKV; (2) any model found in a global key snapshot whose parameters
     * have the same checksum, which is then registered with the grid; (3) a newly trained model
     * stored under {@code protoModelKey + paramsIdx}.
     *
     * @param params        parameters of the model to obtain
     * @param grid          grid in which the model key is registered
     * @param paramsIdx     running index appended to {@code protoModelKey} for a new model's key
     * @param protoModelKey prefix of newly created model keys
     * @return the existing or newly trained model
     */
    private Model buildModel(MP params, Grid<MP> grid, int paramsIdx, String protoModelKey) {
        final long checksum = ((Model.Parameters)params).checksum(IGNORED_FIELDS_PARAM_HASH);

        // 1) The grid may already know a model for this checksum.
        Key<Model> knownKey = grid.getModelKey(checksum);
        if (knownKey != null) {
            if (DKV.get(knownKey) != null) {
                Log.info("GridSearch.buildModel(): model with these parameters already exists, skipping; checksum: " + checksum);
                return knownKey.get();
            }
            Log.info("GridSearch.buildModel(): model with these parameters was built but removed, rebuilding; checksum: " + checksum);
        }

        // 2) Otherwise look for any model in the DKV that was built with equivalent parameters.
        KeySnapshot.KVFilter sameChecksumFilter = new KeySnapshot.KVFilter(){

            @Override
            public boolean filter(KeySnapshot.KeyInfo k) {
                if (!Value.isSubclassOf(k._type, Model.class)) {
                    return false;
                }
                Model candidate = (Model)k._key.get();
                if (candidate == null) {
                    return false;
                }
                if (candidate._parms == null) {
                    return false;
                }
                try {
                    return ((Model.Parameters)candidate._parms).checksum(IGNORED_FIELDS_PARAM_HASH) == checksum;
                }
                catch (H2OConcurrentModificationException e) {
                    Log.warn("GridSearch encountered concurrent modification while searching DKV", e);
                    return false;
                }
                catch (RuntimeException e) {
                    // A concurrent modification anywhere in the cause chain is tolerated;
                    // every other runtime failure is rethrown unchanged.
                    for (Throwable cause = e.getCause(); cause != null; cause = cause.getCause()) {
                        if (cause instanceof H2OConcurrentModificationException) {
                            Log.warn("GridSearch encountered concurrent modification while searching DKV", e);
                            return false;
                        }
                    }
                    throw e;
                }
            }
        };
        Key[] modelKeys = KeySnapshot.globalSnapshot().filter(sameChecksumFilter).keys();
        if (modelKeys.length > 0) {
            grid.putModel(checksum, modelKeys[0]);
            return (Model)modelKeys[0].get();
        }

        // 3) Nothing reusable: train a new model under a fresh key and register it.
        Key<Model> freshKey = Key.make(protoModelKey + paramsIdx);
        assert (grid.getModel(params) == null);
        Model trained = ModelBuilder.trainModelNested(this._job, freshKey, params, null);
        grid.putModel(checksum, freshKey);
        return trained;
    }

    /**
     * Builds a default grid key of the form {@code Grid_<modelName>_<frame key><unique id>}.
     *
     * @param modelName name placed after the "Grid_" prefix
     * @param fr        frame being grid-searched over; it and its key must not be null
     * @return the newly made key
     * @throws IllegalArgumentException if {@code fr} or its key is null
     */
    protected static Key<Grid> gridKeyName(String modelName, Frame fr) {
        final boolean frameHasKey = fr != null && fr._key != null;
        if (!frameHasKey) {
            throw new IllegalArgumentException("The frame being grid-searched over must have a Key");
        }
        final String keyName = "Grid_" + modelName + "_" + fr._key.toString() + H2O.calcNextUniqueModelId("");
        return Key.make(keyName);
    }

    /**
     * Starts a grid search with an explicit parameters builder factory and search criteria.
     *
     * @param destKey              key for the resulting grid, or null to have one generated
     * @param params               base model parameters
     * @param hyperParams          hyper-parameter names mapped to the values to explore
     * @param paramsBuilderFactory factory passed to the walker factory
     * @param searchCriteria       search criteria passed to the walker factory
     * @param parallelism          parallelism level of the search
     * @return the started grid search job
     */
    public static <MP extends Model.Parameters> Job<Grid> startGridSearch(Key<Grid> destKey, MP params, Map<String, Object[]> hyperParams, ModelParametersBuilderFactory<MP> paramsBuilderFactory, HyperSpaceSearchCriteria searchCriteria, int parallelism) {
        // Create the walker first, then delegate to the walker-based overload.
        final HyperSpaceWalker<MP, ?> walker = HyperSpaceWalker.BaseWalker.WalkerFactory.create(params, hyperParams, paramsBuilderFactory, searchCriteria);
        return GridSearch.startGridSearch(destKey, walker, parallelism);
    }

    /**
     * Starts a sequential Cartesian grid search using the simple parameters builder.
     *
     * <p>The builder factory is instantiated with its type argument rather than as a raw type, so
     * the delegation is type-checked, and the parallelism is spelled with the named constant.
     *
     * @param destKey     key for the resulting grid, or null to have one generated
     * @param params      base model parameters
     * @param hyperParams hyper-parameter names mapped to the values to explore
     * @return the started grid search job
     */
    public static <MP extends Model.Parameters> Job<Grid> startGridSearch(Key<Grid> destKey, MP params, Map<String, Object[]> hyperParams) {
        return GridSearch.startGridSearch(destKey, params, hyperParams, new SimpleParametersBuilderFactory<MP>(), new HyperSpaceSearchCriteria.CartesianSearchCriteria(), SEQUENTIAL_MODEL_BUILDING);
    }

    /**
     * Starts a Cartesian grid search using the simple parameters builder and the given parallelism.
     *
     * <p>The builder factory is instantiated with its type argument rather than as a raw type, so
     * the delegation is type-checked.
     *
     * @param destKey     key for the resulting grid, or null to have one generated
     * @param params      base model parameters
     * @param hyperParams hyper-parameter names mapped to the values to explore
     * @param parallelism parallelism level of the search
     * @return the started grid search job
     */
    public static <MP extends Model.Parameters> Job<Grid> startGridSearch(Key<Grid> destKey, MP params, Map<String, Object[]> hyperParams, int parallelism) {
        return GridSearch.startGridSearch(destKey, params, hyperParams, new SimpleParametersBuilderFactory<MP>(), new HyperSpaceSearchCriteria.CartesianSearchCriteria(), parallelism);
    }

    /**
     * Starts a grid search driven by the given hyper space walker.
     *
     * @param destKey          key for the resulting grid; when null a key is derived from the
     *                         algorithm name and training frame of the walker's parameters
     * @param hyperSpaceWalker walker over the hyper space
     * @param parallelism      parallelism level of the search
     * @return the started grid search job
     */
    public static <MP extends Model.Parameters> Job<Grid> startGridSearch(Key<Grid> destKey, HyperSpaceWalker<MP, ?> hyperSpaceWalker, int parallelism) {
        final MP params = hyperSpaceWalker.getParams();
        final Key<Grid> gridKey;
        if (destKey != null) {
            gridKey = destKey;
        } else {
            gridKey = GridSearch.gridKeyName(((Model.Parameters)params).algoName(), ((Model.Parameters)params).train());
        }
        final GridSearch<MP> search = new GridSearch<MP>(gridKey, hyperSpaceWalker, parallelism);
        return search.start();
    }

    /**
     * Converts a user-supplied parallelism value into the level actually used.
     *
     * @param parallelism requested parallelism; must not be negative
     * @return {@code getAdaptiveParallelism()} for {@code ADAPTIVE_PARALLELISM_LEVEL} (0),
     *         otherwise the value itself
     * @throws IllegalArgumentException if {@code parallelism} is negative
     */
    public static int getParallelismLevel(int parallelism) {
        if (parallelism < 0) {
            throw new IllegalArgumentException(String.format("Grid search parallelism level must be >= 0. Give value is '%d'.", parallelism));
        }
        return parallelism == ADAPTIVE_PARALLELISM_LEVEL ? GridSearch.getAdaptiveParallelism() : parallelism;
    }

    /**
     * Computes the parallelism used when the adaptive level is requested.
     *
     * @return twice the value of {@code H2O.NUMCPUS}
     */
    public static int getAdaptiveParallelism() {
        final int buildersPerCpu = 2;
        return buildersPerCpu * H2O.NUMCPUS;
    }

    /**
     * Factory for builders that set fields directly on a parameters object, using the
     * {@code CONSISTENT} field naming strategy.
     */
    public static class SimpleParametersBuilderFactory<MP extends Model.Parameters>
    implements ModelParametersBuilderFactory<MP> {

        /**
         * Creates a builder operating on the given parameters object.
         *
         * @param initialParams parameters object the builder will modify
         * @return a new builder wrapping {@code initialParams}
         */
        @Override
        public ModelParametersBuilderFactory.ModelParametersBuilder<MP> get(MP initialParams) {
            final SimpleParamsBuilder<MP> builder = new SimpleParamsBuilder<MP>(initialParams);
            return builder;
        }

        /** @return the {@code CONSISTENT} naming strategy */
        @Override
        public PojoUtils.FieldNaming getFieldNamingStrategy() {
            return PojoUtils.FieldNaming.CONSISTENT;
        }

        /**
         * Builder that applies every value to the parameters object it was given; this class makes
         * no copy, so {@code build()} hands back that same instance.
         */
        public static class SimpleParamsBuilder<MP extends Model.Parameters>
        implements ModelParametersBuilderFactory.ModelParametersBuilder<MP> {
            /** The parameters object being populated. */
            private final MP target;

            /**
             * @param initialParams parameters object to populate
             */
            public SimpleParamsBuilder(MP initialParams) {
                target = initialParams;
            }

            /**
             * Sets one field on the wrapped parameters object.
             *
             * @param name  name of the field to set
             * @param value value to assign
             * @return this builder, for chaining
             */
            @Override
            public ModelParametersBuilderFactory.ModelParametersBuilder<MP> set(String name, Object value) {
                PojoUtils.setField(target, name, value, PojoUtils.FieldNaming.CONSISTENT);
                return this;
            }

            /** @return the wrapped parameters object */
            @Override
            public MP build() {
                return target;
            }
        }
    }

    /**
     * Callback for {@link ParallelModelBuilder} used by the parallel grid search: it records each
     * finished or failed model on the grid and then tries to schedule the next parameter set.
     *
     * <p>All access to the grid and to the shared hyper space iterator made from the callbacks
     * happens while holding {@code parallelSearchGridLock}. The lock is always acquired before
     * entering the guarded {@code try} block, so the {@code finally} never unlocks a lock that was
     * not obtained.
     */
    private class ModelFeeder<MP extends Model.Parameters, D extends ModelFeeder>
    extends ParallelModelBuilder.ParallelModelBuilderCallback<D> {
        /** Iterator over the remaining parameter combinations. */
        private final HyperSpaceWalker.HyperSpaceIterator<MP> hyperspaceIterator;
        /** Grid that receives models and failures. */
        private final Grid grid;
        /** Guards the grid and the iterator against concurrent callbacks. */
        private final Lock parallelSearchGridLock = new ReentrantLock();

        /**
         * @param hyperspaceIterator iterator supplying the parameter sets still to be built
         * @param grid               grid to record results in
         */
        public ModelFeeder(HyperSpaceWalker.HyperSpaceIterator<MP> hyperspaceIterator, Grid grid) {
            this.hyperspaceIterator = hyperspaceIterator;
            this.grid = grid;
        }

        /**
         * Records a finished model (scoring info, model key, job progress, grid update and save)
         * and then attempts to start the next model.
         *
         * @param finishedModel        the model that was built
         * @param parallelModelBuilder builder to hand the next model to
         */
        @Override
        public void onBuildSuccess(Model finishedModel, ParallelModelBuilder parallelModelBuilder) {
            this.parallelSearchGridLock.lock();
            try {
                this.constructScoringInfo(finishedModel);
                this.grid.putModel(((Model.Parameters)finishedModel._parms).checksum(IGNORED_FIELDS_PARAM_HASH), finishedModel._key);
                GridSearch.this._job.update(1L);
                this.grid.update(GridSearch.this._job);
                GridSearch.this.attemptGridSave(this.grid);
            }
            finally {
                this.parallelSearchGridLock.unlock();
            }
            this.attemptBuildNextModel(parallelModelBuilder, finishedModel);
        }

        /**
         * Records a failed build on the grid and then attempts to start the next model.
         *
         * @param modelBuildFailure    parameters and throwable of the failed build
         * @param parallelModelBuilder builder to hand the next model to
         */
        @Override
        public void onBuildFailure(ParallelModelBuilder.ModelBuildFailure modelBuildFailure, ParallelModelBuilder parallelModelBuilder) {
            this.parallelSearchGridLock.lock();
            try {
                this.grid.appendFailedModelParameters(null, modelBuildFailure.getParameters(), modelBuildFailure.getThrowable());
            }
            finally {
                this.parallelSearchGridLock.unlock();
            }
            this.attemptBuildNextModel(parallelModelBuilder, null);
        }

        /**
         * Schedules the next not-yet-built parameter set, or tells the builder that no more models
         * will come. A next model is only scheduled when parameters are available, time remains,
         * no stop was requested on the job and the walker does not signal early stopping.
         *
         * @param parallelModelBuilder builder to run the next model on
         * @param previousModel        the model that just finished, or null after a failure
         */
        private void attemptBuildNextModel(ParallelModelBuilder parallelModelBuilder, Model previousModel) {
            this.parallelSearchGridLock.lock();
            try {
                MP nextModelParams = this.getNextModelParams(this.hyperspaceIterator, previousModel, this.grid);
                if (nextModelParams != null && this.isThereEnoughTime() && !GridSearch.this._job.stop_requested() && !GridSearch.this._hyperSpaceWalker.stopEarly(previousModel, this.grid.getScoringInfos())) {
                    GridSearch.this.reconcileMaxRuntime(this.grid._key, nextModelParams);
                    parallelModelBuilder.run(Collections.singletonList(ModelBuilder.make(nextModelParams)));
                } else {
                    parallelModelBuilder.noMoreModels();
                }
            }
            finally {
                this.parallelSearchGridLock.unlock();
            }
        }

        /**
         * Prepends a scoring info for the model to the grid's scoring infos and re-sorts them by
         * the search's sorting metric. The time stamp is the moment this method runs.
         *
         * @param model the finished model
         */
        private void constructScoringInfo(Model model) {
            ScoringInfo scoringInfo = new ScoringInfo();
            scoringInfo.time_stamp_ms = System.currentTimeMillis();
            model.fillScoringInfo(scoringInfo);
            this.grid.setScoringInfos(ScoringInfo.prependScoringInfo(scoringInfo, this.grid.getScoringInfos()));
            ScoringInfo.sort(this.grid.getScoringInfos(), GridSearch.this.sortingMetric());
        }

        /**
         * @return true while the grid search has remaining time; logs when the time has expired
         */
        private boolean isThereEnoughTime() {
            boolean enoughTime = GridSearch.this.remainingTimeSecs() > 0.0;
            if (!enoughTime) {
                Log.info("Grid max_runtime_secs of " + GridSearch.this.maxRuntimeSecs() + " secs has expired; stopping early.");
            }
            return enoughTime;
        }

        /**
         * Advances the iterator to the next parameter set the grid has no model key for.
         *
         * @param hyperSpaceIterator iterator to advance
         * @param model              previous model, passed through to the iterator (may be null)
         * @param grid               grid used to detect already-built parameter sets
         * @return the next parameter set to build, or null when the iterator is exhausted
         */
        private MP getNextModelParams(HyperSpaceWalker.HyperSpaceIterator<MP> hyperSpaceIterator, Model model, Grid grid) {
            Model.Parameters params = null;
            while (params == null && hyperSpaceIterator.hasNext(model)) {
                params = (Model.Parameters)hyperSpaceIterator.nextModelParameters(model);
                Key<Model> modelKey = grid.getModelKey(params.checksum(IGNORED_FIELDS_PARAM_HASH));
                if (modelKey == null) continue;
                // Already present in the grid: discard and keep looking.
                params = null;
            }
            return (MP)params;
        }
    }
}

