package hex.segments;

import hex.Model;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicLong;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import water.DKV;
import water.H2O;
import water.Iced;
import water.Job;
import water.Key;
import water.MRTask;
import water.fvec.Frame;
import water.rapids.ast.prims.mungers.AstGroup;
import water.util.Log;

/* loaded from: input_file:hex/segments/SegmentModelsBuilder.class */
public class SegmentModelsBuilder {
    private static final AtomicLong nextSegmentModelsNum = new AtomicLong(0);
    private final SegmentModelsParameters _parms;
    private final Model.Parameters _blueprint_parms;

    /**
     * Cluster-wide task that runs per-segment model building on every node ({@code doAllNodes}).
     * On each node, {@code _parallelism} workers build models; the statistics of all workers and
     * all nodes are merged into {@link #_stats}.
     */
    private static class MultiNodeRunner extends MRTask<MultiNodeRunner> {
        final LocalSequentialSegmentModelsBuilder _builder;
        final SegmentModels _segment_models;
        final int _parallelism;
        SegmentModelsStats _stats;

        private MultiNodeRunner(LocalSequentialSegmentModelsBuilder builder, SegmentModels segmentModels, int parallelism) {
            this._builder = builder;
            this._segment_models = segmentModels;
            this._parallelism = parallelism;
        }

        /** Builds this node's share of the segment models and records the local statistics. */
        @Override // water.MRTask
        public void setupLocal() {
            if (this._parallelism == 1) {
                this._stats = this._builder.buildModels(this._segment_models);
            } else {
                this._stats = buildModelsConcurrently();
            }
            Log.info("Finished per-segment model building on node ", H2O.SELF, "; summary: ", this._stats);
        }

        /**
         * Runs {@code _parallelism} workers on a dedicated thread pool and merges their statistics.
         * Each worker operates on its own clone of the local builder.
         */
        private SegmentModelsStats buildModelsConcurrently() {
            ExecutorService pool = Executors.newFixedThreadPool(this._parallelism);
            try {
                Callable<SegmentModelsStats> worker = () -> this._builder.m3719clone().buildModels(this._segment_models);
                // Submit every worker before waiting on any of them. A lazy stream of the form
                // generate -> submit -> get pushes one element at a time through the pipeline,
                // which awaits each task before submitting the next and serializes the work.
                List<Future<SegmentModelsStats>> futures = new ArrayList<>(this._parallelism);
                for (int i = 0; i < this._parallelism; i++) {
                    futures.add(pool.submit(worker));
                }
                // Left fold, same merge order as Stream.reduce: first.reduce(second), then third, ...
                SegmentModelsStats merged = awaitStats(futures.get(0));
                for (int i = 1; i < futures.size(); i++) {
                    merged.reduce(awaitStats(futures.get(i)));
                }
                return merged;
            } finally {
                // Release the pool's threads once the submitted work is done; without this the
                // pool (and its non-daemon threads) would be leaked on every invocation.
                pool.shutdown();
            }
        }

        /** Waits for a worker and unwraps its result, converting failures into unchecked exceptions. */
        private static SegmentModelsStats awaitStats(Future<SegmentModelsStats> future) {
            try {
                return future.get();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt(); // preserve the interrupt status for callers
                throw new RuntimeException("Failed to build segment-models", e);
            } catch (ExecutionException e) {
                throw new RuntimeException("Failed to build segment-models", e);
            }
        }

        /** Merges the statistics collected on another node into this task's statistics. */
        @Override // water.MRTask
        public void reduce(MultiNodeRunner multiNodeRunner) {
            this._stats.reduce(multiNodeRunner._stats);
        }
    }

    /**
     * Driver task executed by the {@link Job}: locks the input frames, launches the cluster-wide
     * {@link MultiNodeRunner} and cleans up temporary state afterwards.
     */
    private class SegmentModelsBuilderTask extends H2O.H2OCountedCompleter<SegmentModelsBuilderTask> {
        private final Job<SegmentModels> _job;
        private final Frame _segments;
        private final Frame _full_train;
        private final Frame _full_valid;
        // NOTE(review): presumably the DKV key of the shared counter used by WorkAllocator to hand
        // out segments; it is removed from the DKV during cleanup.
        private final Key _counter_key;
        private final int _parallelism;

        private SegmentModelsBuilderTask(Job<SegmentModels> job, Frame segments, Key<Frame> trainKey, Key<Frame> validKey, int parallelism) {
            this._job = job;
            // _segments must be assigned before reorderColumns, which reads its column names
            this._segments = segments;
            this._full_train = reorderColumns(trainKey);
            this._full_valid = reorderColumns(validKey);
            this._counter_key = Key.make();
            this._parallelism = parallelism;
        }

        @Override // water.H2O.H2OCountedCompleter
        public void compute2() {
            try {
                SegmentModelsBuilder.this._blueprint_parms.read_lock_frames(this._job);
                WorkAllocator allocator = new WorkAllocator(this._counter_key, this._segments.numRows());
                LocalSequentialSegmentModelsBuilder localBuilder = new LocalSequentialSegmentModelsBuilder(
                        this._job, SegmentModelsBuilder.this._blueprint_parms, this._segments,
                        this._full_train, this._full_valid, allocator);
                SegmentModels segmentModels = SegmentModels.make(this._job._result, this._segments);
                SegmentModelsStats stats = new MultiNodeRunner(localBuilder, segmentModels, this._parallelism).doAllNodes()._stats;
                Log.info("Finished per-segment model building; summary: ", stats);
            } finally {
                cleanup();
            }
            // Called after cleanup so that a failure here can no longer trigger a second unlock.
            tryComplete();
        }

        /**
         * Releases the read locks, removes the segments frame if it is a temporary (key-less)
         * frame, and removes the work counter from the DKV. Each step runs even if a previous
         * step throws.
         */
        private void cleanup() {
            try {
                SegmentModelsBuilder.this._blueprint_parms.read_unlock_frames(this._job);
            } finally {
                try {
                    if (this._segments._key == null) {
                        this._segments.remove();
                    }
                } finally {
                    DKV.remove(this._counter_key);
                }
            }
        }

        /**
         * Returns a new Frame with the segment-by columns first, followed by all remaining
         * columns of the frame stored under {@code key}; the result keeps the original key.
         *
         * @return {@code null} when {@code key} is {@code null}
         * @throws IllegalStateException if {@code key} does not point to an existing Frame
         */
        private Frame reorderColumns(Key<Frame> key) {
            if (key == null) {
                return null;
            }
            Frame frame = key.get();
            if (frame == null) {
                throw new IllegalStateException("Key " + key + " doesn't point to an existing Frame.");
            }
            Frame copy = new Frame(frame);
            String[] segmentColumns = this._segments.names();
            Frame reordered = new Frame(segmentColumns, copy.vecs(segmentColumns)).add(copy.remove(segmentColumns));
            reordered._key = frame._key;
            return reordered;
        }
    }

    /** User-facing settings for segment-model building. */
    public static class SegmentModelsParameters extends Iced<SegmentModelsParameters> {
        /** Destination key of the result; a unique key is generated when {@code null}. */
        Key<SegmentModels> _segment_models_id;
        /** Optional frame listing the segments; when {@code null} they are derived from the training frame. */
        Key<Frame> _segments;
        /** Segment-by columns; when {@code null}, all columns of the segments frame are used for validation. */
        String[] _segment_columns;
        /** Number of concurrent workers per node; must be positive. */
        int _parallelism = 1;
    }

    /**
     * @param segmentModelsParameters segmentation settings
     * @param parameters blueprint model parameters applied to every segment
     */
    public SegmentModelsBuilder(SegmentModelsParameters segmentModelsParameters, Model.Parameters parameters) {
        this._parms = segmentModelsParameters;
        this._blueprint_parms = parameters;
    }

    /**
     * Validates the settings, prepares the segments frame and starts the model-building job.
     *
     * @return the started job
     * @throws IllegalArgumentException if {@code _parallelism} is not positive
     * @throws IllegalStateException if a referenced frame or column is missing or invalid
     */
    public Job<SegmentModels> buildSegmentModels() {
        if (this._parms._parallelism <= 0) {
            throw new IllegalArgumentException("Parameter `parallelism` has to be a positive number, received=" + this._parms._parallelism);
        }
        Frame segments = this._parms._segments != null
                ? validateSegmentsFrame(this._parms._segments, this._parms._segment_columns)
                : makeSegmentsFrame(this._blueprint_parms._train, this._parms._segment_columns);
        Job<SegmentModels> job = new Job<>(makeDestKey(), SegmentModels.class.getName(), this._blueprint_parms.algoName());
        SegmentModelsBuilderTask task;
        try {
            task = new SegmentModelsBuilderTask(job, segments, this._blueprint_parms._train, this._blueprint_parms._valid, this._parms._parallelism);
        } catch (RuntimeException e) {
            // The task normally removes a temporary (key-less) segments frame when it finishes;
            // if it cannot even be constructed, that frame must be removed here instead.
            if (segments._key == null) {
                segments.remove();
            }
            throw e;
        }
        return job.start(task, segments.numRows());
    }

    /**
     * Derives the segments by grouping the frame under {@code key} on the given columns
     * (no aggregations).
     */
    private Frame makeSegmentsFrame(Key<Frame> key, String[] strArr) {
        Frame frame = validateSegmentsFrame(key, strArr);
        return new AstGroup().performGroupingWithAggregations(frame, frame.find(strArr), new AstGroup.AGG[0]).getFrame();
    }

    /** Returns the user-specified destination key, or a freshly generated unique one. */
    private Key<SegmentModels> makeDestKey() {
        return this._parms._segment_models_id != null ? this._parms._segment_models_id : Key.make(H2O.calcNextUniqueObjectId("segment_models", nextSegmentModelsNum, this._blueprint_parms.algoName()));
    }

    /**
     * Checks that the frame exists and that every segment-by column exists and is categorical or
     * integer. When {@code strArr} is {@code null}, all columns of the frame are checked.
     *
     * @return the frame stored under {@code key}
     * @throws IllegalStateException if the frame or a column is missing, or a column has an invalid type
     */
    private static Frame validateSegmentsFrame(Key<Frame> key, String[] strArr) {
        Frame frame = key.get();
        if (frame == null) {
            throw new IllegalStateException("Frame `" + key + "` doesn't exist.");
        }
        String[] columns = strArr != null ? strArr : frame.names();
        // NOTE(review): relies on Frame.vec(String) returning null for unknown column names.
        List<String> missing = Stream.of(columns)
                .filter(column -> frame.vec(column) == null)
                .collect(Collectors.toList());
        if (!missing.isEmpty()) {
            throw new IllegalStateException("Columns to segment-by were not found in frame `" + key + "`: " + missing);
        }
        List<String> invalid = Stream.of(columns)
                .filter(column -> !(frame.vec(column).isCategorical() || frame.vec(column).isInt()))
                .collect(Collectors.toList());
        if (invalid.isEmpty()) {
            return frame;
        }
        throw new IllegalStateException("Columns to segment-by can only be categorical and integer of type, invalid columns: " + invalid);
    }
}
