package hex.grid;

import hex.Model;
import hex.Model.Parameters;
import hex.ModelBuilder;
import hex.ModelParametersBuilderFactory;
import hex.ParallelModelBuilder;
import hex.ScoreKeeper;
import hex.ScoringInfo;
import hex.grid.HyperSpaceSearchCriteria;
import hex.grid.HyperSpaceWalker;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import jsr166y.CountedCompleter;
import water.DKV;
import water.H2O;
import water.Job;
import water.Key;
import water.KeySnapshot;
import water.Keyed;
import water.Value;
import water.exceptions.H2OConcurrentModificationException;
import water.exceptions.H2OIllegalArgumentException;
import water.fvec.Frame;
import water.util.Log;
import water.util.PojoUtils;

/* loaded from: input_file:hex/grid/GridSearch.class */
/**
 * Drives a hyper-parameter grid search as an H2O {@link Job}.
 *
 * <p>A {@link HyperSpaceWalker} enumerates parameter combinations. One model is built per
 * combination whose parameter checksum is not already present in the target {@link Grid}. The
 * resulting model keys are recorded in that grid, which lives in the DKV under {@link #_result}.
 * When {@link #_parallelism} equals {@link #SEQUENTIAL_MODEL_BUILDING}, models are built by
 * {@link #gridSearch(Grid)}. When it is larger, they are built concurrently through a
 * {@link ParallelModelBuilder} by {@link #parallelGridSearch(Grid)}.
 *
 * <p>NOTE(review): this source was recovered by a decompiler. {@link #gridSearch(Grid)} could not
 * be decompiled and currently always throws {@link UnsupportedOperationException}.
 *
 * @param <MP> type of the model parameters the search varies
 */
public final class GridSearch<MP extends Model.Parameters> extends Keyed<GridSearch> {
    /** Key under which the resulting {@link Grid} is stored in the DKV. */
    public final Key<Grid> _result;
    /** Job tracking this grid search; created in the constructor with {@link #_result} as its key. */
    public final Job<Grid> _job;
    /** Number of models to build concurrently; {@link #SEQUENTIAL_MODEL_BUILDING} means one at a time. */
    public final int _parallelism;
    private final transient HyperSpaceWalker<MP, ?> _hyperSpaceWalker;

    /**
     * Parameter fields excluded from the checksum used to identify a model inside the grid. As a
     * result, changing only the checkpoint export directory does not make a combination look new.
     */
    private static final Set<String> IGNORED_FIELDS_PARAM_HASH = Collections.singleton("_export_checkpoints_dir");

    /** Parallelism level asking {@link #getParallelismLevel(int)} to choose a level adaptively. */
    public static final int ADAPTIVE_PARALLELISM_LEVEL = 0;
    /** Parallelism level selecting sequential (one model at a time) model building. */
    public static final int SEQUENTIAL_MODEL_BUILDING = 1;

    /**
     * Callback for {@link ParallelModelBuilder} that records each finished model in the grid and
     * hands the builder the next parameter combination from the hyper-space iterator.
     *
     * <p>Every callback runs under the same {@link ReentrantLock}. Grid updates and iterator
     * advances are therefore serialized, presumably because builds finish on concurrent threads.
     * The lock is reentrant, so {@link #attemptBuildNextModel} may re-acquire it while it is
     * already held.
     */
    public class ModelFeeder<MP extends Model.Parameters, D extends ModelFeeder> extends ParallelModelBuilder.ParallelModelBuilderCallback<D> {
        private final HyperSpaceWalker.HyperSpaceIterator<MP> hyperspaceIterator;
        private final Grid grid;
        private final Lock parallelSearchGridLock = new ReentrantLock();

        /**
         * @param hyperSpaceIterator iterator supplying the parameter combinations still to build
         * @param grid grid that receives the finished models and failures
         */
        public ModelFeeder(HyperSpaceWalker.HyperSpaceIterator<MP> hyperSpaceIterator, Grid grid) {
            this.hyperspaceIterator = hyperSpaceIterator;
            this.grid = grid;
        }

        /**
         * Stores the finished model in the grid, updates job progress and scoring history, saves
         * the grid when a checkpoint directory is configured, and schedules the next model.
         */
        @Override // hex.ParallelModelBuilder.ParallelModelBuilderCallback
        public void onBuildSuccess(Model model, ParallelModelBuilder parallelModelBuilder) {
            // Acquire before entering try: unlocking a lock that was never acquired would throw.
            parallelSearchGridLock.lock();
            try {
                constructScoringInfo(model);
                grid.putModel(model._parms.checksum(GridSearch.IGNORED_FIELDS_PARAM_HASH), model._key);
                GridSearch.this._job.update(1L);
                grid.update(GridSearch.this._job);
                GridSearch.this.attemptGridSave(grid);
                attemptBuildNextModel(parallelModelBuilder, model);
            } finally {
                parallelSearchGridLock.unlock();
            }
        }

        /**
         * Records the failed parameters and their error in the grid, then schedules the next model.
         */
        @Override // hex.ParallelModelBuilder.ParallelModelBuilderCallback
        public void onBuildFailure(ParallelModelBuilder.ModelBuildFailure modelBuildFailure, ParallelModelBuilder parallelModelBuilder) {
            parallelSearchGridLock.lock();
            try {
                // No model key exists for a failed build; pass the failing parameters as-is.
                grid.appendFailedModelParameters((Key<Model>) null, modelBuildFailure.getParameters(), modelBuildFailure.getThrowable());
                attemptBuildNextModel(parallelModelBuilder, null);
            } finally {
                parallelSearchGridLock.unlock();
            }
        }

        /**
         * Submits the next not-yet-built parameter combination, or tells the builder that no more
         * models will come. The latter happens when the space is exhausted, time has run out, a
         * stop was requested, or the walker asks to stop early.
         *
         * @param model the model that just finished, or {@code null} after a failure
         */
        private void attemptBuildNextModel(ParallelModelBuilder parallelModelBuilder, Model model) {
            parallelSearchGridLock.lock();
            try {
                final MP nextModelParams = getNextModelParams(hyperspaceIterator, model, grid);
                final boolean done = nextModelParams == null
                        || !isThereEnoughTime()
                        || GridSearch.this._job.stop_requested()
                        || GridSearch.this._hyperSpaceWalker.stopEarly(model, grid.getScoringInfos());
                if (done) {
                    parallelModelBuilder.noMoreModels();
                } else {
                    GridSearch.this.reconcileMaxRuntime(grid._key, nextModelParams);
                    parallelModelBuilder.run(Collections.singletonList(ModelBuilder.make(nextModelParams)));
                }
            } finally {
                parallelSearchGridLock.unlock();
            }
        }

        /** Prepends a scoring snapshot of {@code model} to the grid's history and re-sorts it. */
        private void constructScoringInfo(Model model) {
            final ScoringInfo scoringInfo = new ScoringInfo();
            scoringInfo.time_stamp_ms = System.currentTimeMillis();
            model.fillScoringInfo(scoringInfo);
            grid.setScoringInfos(ScoringInfo.prependScoringInfo(scoringInfo, grid.getScoringInfos()));
            ScoringInfo.sort(grid.getScoringInfos(), GridSearch.this.sortingMetric());
        }

        /** Returns {@code true} while the grid's time budget is not exhausted; logs when it is. */
        private boolean isThereEnoughTime() {
            final boolean enoughTime = GridSearch.this.remainingTimeSecs() > 0.0d;
            if (!enoughTime) {
                Log.info("Grid max_runtime_secs of " + GridSearch.this.maxRuntimeSecs() + " secs has expired; stopping early.");
            }
            return enoughTime;
        }

        /**
         * Advances the iterator past combinations whose checksum is already in the grid.
         *
         * @return the next combination to build, or {@code null} when the iterator is exhausted
         */
        private MP getNextModelParams(HyperSpaceWalker.HyperSpaceIterator<MP> hyperSpaceIterator, Model model, Grid grid) {
            while (hyperSpaceIterator.hasNext(model)) {
                final MP candidate = hyperSpaceIterator.nextModelParameters(model);
                if (grid.getModelKey(candidate.checksum(GridSearch.IGNORED_FIELDS_PARAM_HASH)) == null) {
                    return candidate;
                }
            }
            return null;
        }
    }

    /**
     * Factory whose builders set fields directly on the supplied parameters object with
     * {@link PojoUtils#setField} using {@link PojoUtils.FieldNaming#CONSISTENT} naming.
     */
    public static class SimpleParametersBuilderFactory<MP extends Model.Parameters> implements ModelParametersBuilderFactory<MP> {

        /** Builder that mutates, and finally returns, the very parameters object it wraps. */
        public static class SimpleParamsBuilder<MP extends Model.Parameters> implements ModelParametersBuilderFactory.ModelParametersBuilder<MP> {
            private final MP params;

            public SimpleParamsBuilder(MP mp) {
                this.params = mp;
            }

            @Override // hex.ModelParametersBuilderFactory.ModelParametersBuilder
            public ModelParametersBuilderFactory.ModelParametersBuilder<MP> set(String str, Object obj) {
                PojoUtils.setField(params, str, obj, PojoUtils.FieldNaming.CONSISTENT);
                return this;
            }

            @Override // hex.ModelParametersBuilderFactory.ModelParametersBuilder
            public MP build() {
                return params;
            }
        }

        @Override // hex.ModelParametersBuilderFactory
        public ModelParametersBuilderFactory.ModelParametersBuilder<MP> get(MP mp) {
            return new SimpleParamsBuilder<>(mp);
        }

        @Override // hex.ModelParametersBuilderFactory
        public PojoUtils.FieldNaming getFieldNamingStrategy() {
            return PojoUtils.FieldNaming.CONSISTENT;
        }
    }

    private GridSearch(Key<Grid> key, HyperSpaceWalker<MP, ?> hyperSpaceWalker, int i) {
        assert hyperSpaceWalker != null : "Grid search needs to know how to walk around hyper space!";
        this._hyperSpaceWalker = hyperSpaceWalker;
        this._result = key;
        this._job = new Job<>(key, Grid.class.getName(), hyperSpaceWalker.getParams().algoName() + " Grid Search");
        this._parallelism = i;
    }

    /**
     * Opens (creating if needed) and locks the target grid, estimates the total work, and starts
     * the job. Models are built sequentially or in parallel depending on {@link #_parallelism}.
     *
     * @return the started job
     * @throws H2OIllegalArgumentException if {@link #_result} names a non-grid object, or the
     *         existing grid was trained on a different training frame
     */
    Job<Grid> start() {
        final long maxHyperSpaceSize = _hyperSpaceWalker.getMaxHyperSpaceSize();
        Log.info("Starting gridsearch: estimated size of search space = " + maxHyperSpaceSize);
        final Grid grid = openAndLockGrid();
        final long gridWork = estimateGridWork(maxHyperSpaceSize);
        return _job.start(new H2O.H2OCountedCompleter() { // from class: hex.grid.GridSearch.1
            @Override // water.H2O.H2OCountedCompleter
            public void compute2() {
                final int parallelism = GridSearch.this._parallelism;
                if (parallelism == GridSearch.SEQUENTIAL_MODEL_BUILDING) {
                    GridSearch.this.gridSearch(grid);
                } else if (parallelism > GridSearch.SEQUENTIAL_MODEL_BUILDING) {
                    GridSearch.this.parallelGridSearch(grid);
                } else {
                    throw new IllegalArgumentException(String.format("Grid search parallelism level must be >= 1. Give value is '%d'.", parallelism));
                }
                tryComplete();
            }

            @Override // jsr166y.CountedCompleter
            public boolean onExceptionalCompletion(Throwable th2, CountedCompleter countedCompleter) {
                Log.warn("GridSearch job " + GridSearch.this._job._description + " completed with exception: " + th2);
                return true;
            }
        }, gridWork, maxRuntimeSecs());
    }

    /**
     * Returns the grid stored under {@link #_result}, write-locked for {@link #_job}. If no object
     * is stored there, a fresh grid is created and locked.
     */
    private Grid openAndLockGrid() {
        final Keyed existing = (Keyed) DKV.getGet(_result);
        if (existing == null) {
            final Grid created = new Grid(_result, _hyperSpaceWalker.getParams(), _hyperSpaceWalker.getHyperParamNames(), _hyperSpaceWalker.getParametersBuilderFactory().getFieldNamingStrategy());
            created.delete_and_lock(_job);
            return created;
        }
        if (!(existing instanceof Grid)) {
            throw new H2OIllegalArgumentException("Name conflict: tried to create a Grid using the ID of a non-Grid object that's already in H2O: " + _job._result + "; it is a: " + existing.getClass());
        }
        final Grid grid = (Grid) existing;
        grid.clearNonRelatedFailures();
        final Frame train = _hyperSpaceWalker.getParams().train();
        final Frame gridTrain = grid.getTrainingFrame();
        // Appending is only allowed with the same training frame (same key and same content).
        if (gridTrain != null && (!train._key.equals(gridTrain._key) || train.checksum() != gridTrain.checksum())) {
            throw new H2OIllegalArgumentException("training_frame", "grid", "Cannot append new models to a grid with different training input");
        }
        grid.write_lock(_job);
        return grid;
    }

    /**
     * Estimates the job's total work units by walking the whole hyper-space. The estimate is
     * {@link Long#MAX_VALUE} when the space size or the max-models limit is unknown/non-positive.
     * Each combination counts {@code progressUnits()} once per model it builds: one main model,
     * plus one per fold when {@code _nfolds > 0}.
     */
    private long estimateGridWork(long maxHyperSpaceSize) {
        final HyperSpaceWalker.HyperSpaceIterator<MP> it = _hyperSpaceWalker.iterator();
        if (maxHyperSpaceSize <= 0 || maxModels() <= 0) {
            return Long.MAX_VALUE;
        }
        long work = 0;
        while (it.hasNext(null)) {
            try {
                final MP params = it.nextModelParameters(null);
                work += (params._nfolds > 0 ? params._nfolds + 1 : 1) * params.progressUnits();
            } catch (Throwable ignored) {
                // Best effort: a combination that cannot be materialized contributes no work to
                // the estimate. NOTE(review): presumably such invalid combinations are reported
                // when the grid actually tries to build them — confirm.
            }
        }
        return work;
    }

    /** Returns the walker's maximum hyper-space size (not the number of models already built). */
    public long getModelCount() {
        return _hyperSpaceWalker.getMaxHyperSpaceSize();
    }

    /** Maximum number of models from the stopping criteria, or 0 when no criteria are set. */
    private long maxModels() {
        if (_hyperSpaceWalker.search_criteria().stoppingCriteria() == null) {
            return 0L;
        }
        return _hyperSpaceWalker.search_criteria().stoppingCriteria().getMaxModels();
    }

    /** Grid max runtime in seconds from the stopping criteria, or 0 when no criteria are set. */
    public double maxRuntimeSecs() {
        if (_hyperSpaceWalker.search_criteria().stoppingCriteria() == null) {
            return 0.0d;
        }
        return _hyperSpaceWalker.search_criteria().stoppingCriteria().getMaxRuntimeSecs();
    }

    /**
     * Seconds left before the job's max runtime elapses, or {@link Double#MAX_VALUE} when the job
     * has no runtime limit. The value becomes negative once the limit has been exceeded.
     */
    public double remainingTimeSecs() {
        if (_job == null || _job._max_runtime_msecs <= 0) {
            return Double.MAX_VALUE;
        }
        return ((_job.start_time() + _job._max_runtime_msecs) - System.currentTimeMillis()) / 1000.0d;
    }

    /** Metric used to order the grid's scoring history; {@code AUTO} when no criteria are set. */
    public ScoreKeeper.StoppingMetric sortingMetric() {
        return _hyperSpaceWalker.search_criteria().stoppingCriteria() == null
                ? ScoreKeeper.StoppingMetric.AUTO
                : _hyperSpaceWalker.search_criteria().stoppingCriteria().getStoppingMetric();
    }

    /**
     * Builds the grid's models concurrently. Up to {@link #_parallelism} not-yet-built
     * combinations are started first; a {@link ModelFeeder} then submits one new model each time a
     * build finishes. The grid is updated and saved at the end, and it is always unlocked, even
     * when the search fails.
     */
    public void parallelGridSearch(Grid<MP> grid) {
        try {
            final HyperSpaceWalker.HyperSpaceIterator<MP> it = _hyperSpaceWalker.iterator();
            final ParallelModelBuilder parallelModelBuilder = new ParallelModelBuilder(new ModelFeeder(it, grid));
            final List<ModelBuilder> startModels = initialModelBuilders(grid, it);
            if (!startModels.isEmpty()) {
                parallelModelBuilder.run(startModels);
                parallelModelBuilder.join();
            }
            grid.update(_job);
            attemptGridSave(grid);
        } finally {
            grid.unlock(_job);
        }
    }

    /**
     * Draws combinations from {@code it} until {@link #_parallelism} builders have been collected
     * or the iterator is exhausted. Combinations already present in {@code grid} are skipped
     * rather than consuming a slot, so appending to an existing grid still starts a full batch.
     */
    private List<ModelBuilder> initialModelBuilders(Grid<MP> grid, HyperSpaceWalker.HyperSpaceIterator<MP> it) {
        final List<ModelBuilder> builders = new ArrayList<>(_parallelism);
        while (builders.size() < _parallelism && it.hasNext(null)) {
            final MP params = it.nextModelParameters(null);
            if (grid.getModelKey(params.checksum(IGNORED_FIELDS_PARAM_HASH)) == null) {
                builders.add(ModelBuilder.make(params));
            }
        }
        return builders;
    }

    /**
     * Returns at most {@code i} parameter combinations taken from the start of the iterator
     * (fewer if it runs out). Combinations already present in a grid are not filtered out.
     */
    public List<MP> initialModelParameters(int i, HyperSpaceWalker.HyperSpaceIterator<MP> hyperSpaceIterator) {
        final List<MP> params = new ArrayList<>(i);
        for (int n = 0; n < i && hyperSpaceIterator.hasNext(null); n++) {
            params.add(hyperSpaceIterator.nextModelParameters(null));
        }
        return params;
    }

    /**
     * Builds the grid's models one at a time. It is selected when {@link #_parallelism} equals
     * {@link #SEQUENTIAL_MODEL_BUILDING}.
     *
     * <p>NOTE(review): the body of this method could not be decompiled. Recovered fragments show
     * that the original logged {@code "For grid: <key> built: <count> models."} and called
     * {@code grid.unlock(_job)} before returning. It presumably also used {@link #buildModel} and
     * {@link #reconcileMaxRuntime}, which have no other callers here — confirm against the
     * original source. Until the body is restored, this method always throws.
     *
     * @throws UnsupportedOperationException always
     */
    public void gridSearch(Grid<MP> grid) {
        throw new UnsupportedOperationException("Method not decompiled: hex.grid.GridSearch.gridSearch(hex.grid.Grid):void");
    }

    /**
     * Caps a model's max runtime by the grid's remaining time budget.
     *
     * <p>When the grid has no time limit, the parameters are left untouched. Previously, an
     * unlimited model was rewritten to {@link Double#MAX_VALUE} in that case. When the grid has a
     * limit, an unlimited model ({@code <= 0}) receives the remaining time, and a limited model
     * receives the smaller of its own limit and the remaining time.
     *
     * @throws Job.JobCancelledException if the grid's time budget is already used up
     */
    public void reconcileMaxRuntime(Key<Grid<MP>> key, Model.Parameters parameters) {
        final double gridMaxRuntimeSecs = _job._max_runtime_msecs / 1000.0d;
        if (gridMaxRuntimeSecs <= 0.0d) {
            return; // no grid time limit: the model keeps its own max_runtime_secs
        }
        final double remainingSecs = remainingTimeSecs();
        Log.info("Grid time is limited to: " + gridMaxRuntimeSecs + " for grid: " + key + ". Remaining time is: " + remainingSecs);
        // Zero remaining time counts as expired; otherwise the model would get 0 (= unlimited).
        if (remainingSecs <= 0.0d) {
            Log.info("Grid max_runtime_secs of " + gridMaxRuntimeSecs + " secs has expired; stopping early.");
            throw new Job.JobCancelledException();
        }
        if (parameters._max_runtime_secs <= 0.0d) {
            parameters._max_runtime_secs = remainingSecs;
            Log.info("Due to the grid time limit, changing model max runtime to: " + parameters._max_runtime_secs + " secs.");
        } else {
            final double previousSecs = parameters._max_runtime_secs;
            parameters._max_runtime_secs = Math.min(parameters._max_runtime_secs, remainingSecs);
            Log.info("Due to the grid time limit, changing model max runtime from: " + previousSecs + " secs to: " + parameters._max_runtime_secs + " secs.");
        }
    }

    /**
     * Exports the grid in binary form to the walker's {@code _export_checkpoints_dir}, if one is
     * set. This is best effort: an I/O failure is logged together with its cause, not rethrown.
     */
    public void attemptGridSave(Grid grid) {
        final String exportDir = _hyperSpaceWalker.getParams()._export_checkpoints_dir;
        if (exportDir == null) {
            return;
        }
        try {
            grid.exportBinary(exportDir);
        } catch (IOException e) {
            Log.warn(String.format("Could not save grid '%s' to location '%s'", grid._key.toString(), exportDir), e);
        }
    }

    /**
     * Returns a model for {@code mp}. Candidates are checked in this order:
     * <ol>
     *   <li>the model already registered in the grid, if it still exists in the DKV;</li>
     *   <li>any model in the DKV whose parameter checksum matches, which is then registered in the
     *       grid;</li>
     *   <li>a newly trained model under key {@code str + i}, which is then registered in the
     *       grid.</li>
     * </ol>
     */
    private Model buildModel(MP mp, Grid<MP> grid, int i, String str) {
        final long checksum = mp.checksum(IGNORED_FIELDS_PARAM_HASH);
        final Key<Model> modelKey = grid.getModelKey(checksum);
        if (modelKey != null) {
            if (DKV.get(modelKey) != null) {
                Log.info("GridSearch.buildModel(): model with these parameters already exists, skipping; checksum: " + checksum);
                return modelKey.get();
            }
            Log.info("GridSearch.buildModel(): model with these parameters was built but removed, rebuilding; checksum: " + checksum);
        }
        // Scan all DKV models for one trained with equivalent parameters.
        final Key[] keys = KeySnapshot.globalSnapshot().filter(new KeySnapshot.KVFilter() { // from class: hex.grid.GridSearch.2
            @Override // water.KeySnapshot.KVFilter
            public boolean filter(KeySnapshot.KeyInfo keyInfo) {
                if (!Value.isSubclassOf(keyInfo._type, Model.class)) {
                    return false;
                }
                final Model model = (Model) keyInfo._key.get();
                if (model == null || model._parms == null) {
                    return false;
                }
                try {
                    return model._parms.checksum(GridSearch.IGNORED_FIELDS_PARAM_HASH) == checksum;
                } catch (H2OConcurrentModificationException e) {
                    Log.warn("GridSearch encountered concurrent modification while searching DKV", e);
                    return false;
                } catch (RuntimeException e2) {
                    // Only concurrent modification (possibly wrapped) is tolerated; anything else propagates.
                    if (!isCausedByConcurrentModification(e2)) {
                        throw e2;
                    }
                    Log.warn("GridSearch encountered concurrent modification while searching DKV", e2);
                    return false;
                }
            }
        }).keys();
        if (keys.length > 0) {
            grid.putModel(checksum, keys[0]);
            return (Model) keys[0].get();
        }
        final Key<Model> newKey = Key.make(str + i);
        assert grid.getModel(mp) == null;
        final Model trained = ModelBuilder.trainModelNested(_job, newKey, mp, null);
        grid.putModel(checksum, newKey);
        return trained;
    }

    /** Returns {@code true} if any cause of {@code t} (excluding {@code t} itself) is an {@link H2OConcurrentModificationException}. */
    private static boolean isCausedByConcurrentModification(Throwable t) {
        for (Throwable cause = t.getCause(); cause != null; cause = cause.getCause()) {
            if (cause instanceof H2OConcurrentModificationException) {
                return true;
            }
        }
        return false;
    }

    /**
     * Builds a unique grid key of the form {@code Grid_<algo>_<frameKey><uniqueId>}.
     *
     * @throws IllegalArgumentException if {@code frame} or its key is {@code null}
     */
    protected static Key<Grid> gridKeyName(String str, Frame frame) {
        if (frame == null || frame._key == null) {
            throw new IllegalArgumentException("The frame being grid-searched over must have a Key");
        }
        return Key.make("Grid_" + str + "_" + frame._key.toString() + H2O.calcNextUniqueModelId(""));
    }

    /** Starts a grid search over {@code map}, walked as described by {@code hyperSpaceSearchCriteria}. */
    public static <MP extends Model.Parameters> Job<Grid> startGridSearch(Key<Grid> key, MP mp, Map<String, Object[]> map, ModelParametersBuilderFactory<MP> modelParametersBuilderFactory, HyperSpaceSearchCriteria hyperSpaceSearchCriteria, int i) {
        return startGridSearch(key, HyperSpaceWalker.BaseWalker.WalkerFactory.create(mp, map, modelParametersBuilderFactory, hyperSpaceSearchCriteria), i);
    }

    /** Starts a sequential Cartesian grid search using {@link SimpleParametersBuilderFactory}. */
    public static <MP extends Model.Parameters> Job<Grid> startGridSearch(Key<Grid> key, MP mp, Map<String, Object[]> map) {
        return startGridSearch(key, mp, map, new SimpleParametersBuilderFactory<>(), new HyperSpaceSearchCriteria.CartesianSearchCriteria(), SEQUENTIAL_MODEL_BUILDING);
    }

    /** Starts a Cartesian grid search with parallelism {@code i} using {@link SimpleParametersBuilderFactory}. */
    public static <MP extends Model.Parameters> Job<Grid> startGridSearch(Key<Grid> key, MP mp, Map<String, Object[]> map, int i) {
        return startGridSearch(key, mp, map, new SimpleParametersBuilderFactory<>(), new HyperSpaceSearchCriteria.CartesianSearchCriteria(), i);
    }

    /**
     * Starts a grid search driven by {@code hyperSpaceWalker}. A {@code null} key is replaced by
     * one generated from the algorithm name and the training frame.
     */
    public static <MP extends Model.Parameters> Job<Grid> startGridSearch(Key<Grid> key, HyperSpaceWalker<MP, ?> hyperSpaceWalker, int i) {
        final MP params = hyperSpaceWalker.getParams();
        final Key<Grid> gridKey = key != null ? key : gridKeyName(params.algoName(), params.train());
        return new GridSearch<>(gridKey, hyperSpaceWalker, i).start();
    }

    /**
     * Resolves a user-supplied parallelism level.
     *
     * @return {@link #getAdaptiveParallelism()} for {@link #ADAPTIVE_PARALLELISM_LEVEL}, otherwise {@code i}
     * @throws IllegalArgumentException if {@code i} is negative
     */
    public static int getParallelismLevel(int i) {
        if (i < 0) {
            throw new IllegalArgumentException(String.format("Grid search parallelism level must be >= 0. Give value is '%d'.", Integer.valueOf(i)));
        }
        return i == ADAPTIVE_PARALLELISM_LEVEL ? getAdaptiveParallelism() : i;
    }

    /** Adaptive parallelism: twice {@code H2O.NUMCPUS}. */
    public static int getAdaptiveParallelism() {
        return 2 * H2O.NUMCPUS;
    }
}
