package hex.ensemble;

import hex.Model;
import hex.ModelBuilder;
import hex.ModelCategory;
import hex.StackedEnsembleModel;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import water.DKV;
import water.Job;
import water.Key;
import water.exceptions.H2OIllegalArgumentException;
import water.exceptions.H2OModelBuilderIllegalArgumentException;
import water.fvec.Frame;
import water.util.Log;

/* loaded from: input_file:hex/ensemble/StackedEnsemble.class */
/**
 * Model builder for {@link StackedEnsembleModel}.
 *
 * <p>Training builds a "level one" frame whose columns are the predictions of the base models
 * listed in {@code StackedEnsembleParameters._base_models}. Those columns are followed by the
 * optional metalearner fold column and the response column. A metalearner is then trained on that
 * frame. The builder supports regression, binomial and multinomial problems and reports MOJO
 * support.
 */
public class StackedEnsemble extends ModelBuilder<StackedEnsembleModel, StackedEnsembleModel.StackedEnsembleParameters, StackedEnsembleModel.StackedEnsembleOutput> {

    /** Name of the predicted-label column in base-model prediction frames. */
    private static final String PREDICT_COLUMN = "predict";

    /** Driver created by the most recent {@link #trainModelImpl()} call. */
    StackedEnsembleDriver _driver;

    /** Ensemble model under construction; assigned at the start of {@code computeImpl()}. */
    protected StackedEnsembleModel _model;

    /** Performs the training work: level-one frame construction and metalearner fitting. */
    public class StackedEnsembleDriver extends ModelBuilder<StackedEnsembleModel, StackedEnsembleModel.StackedEnsembleParameters, StackedEnsembleModel.StackedEnsembleOutput>.Driver {

        private StackedEnsembleDriver() {
            super();
        }

        /** Returns the parameters stored on the ensemble model under construction. */
        private StackedEnsembleModel.StackedEnsembleParameters modelParms() {
            return (StackedEnsembleModel.StackedEnsembleParameters) StackedEnsemble.this._model._parms;
        }

        /**
         * Assembles the level-one frame, stores it in the DKV and returns it.
         *
         * <p>The frame reuses (does not copy) Vecs from {@code baseModelPredictions} and
         * {@code actuals}. Its columns are, in order:
         * <ol>
         *   <li>the prediction column(s) of each base model;</li>
         *   <li>the metalearner fold column, if one is configured;</li>
         *   <li>the response column.</li>
         * </ol>
         * Entries whose model or prediction frame is null are logged and skipped.
         *
         * @param levelOneKey key name for the new frame; {@code null} means
         *     {@code "levelone_<model key>"}
         * @param baseModels base models, parallel to {@code baseModelPredictions}
         * @param baseModelPredictions prediction frames, one per base model
         * @param actuals frame supplying the response and metalearner fold columns
         * @return the level-one frame, unlocked and stored in the DKV
         * @throws H2OIllegalArgumentException if either array is null or empty, or if their
         *     lengths differ
         */
        private Frame prepareLevelOneFrame(String levelOneKey, Model[] baseModels, Frame[] baseModelPredictions, Frame actuals) {
            if (null == baseModels) {
                throw new H2OIllegalArgumentException("Base models array is null.");
            }
            if (null == baseModelPredictions) {
                throw new H2OIllegalArgumentException("Base model predictions array is null.");
            }
            if (baseModels.length == 0) {
                throw new H2OIllegalArgumentException("Base models array is empty.");
            }
            if (baseModelPredictions.length == 0) {
                throw new H2OIllegalArgumentException("Base model predictions array is empty.");
            }
            if (baseModels.length != baseModelPredictions.length) {
                throw new H2OIllegalArgumentException("Base models and prediction arrays are different lengths.");
            }
            if (null == levelOneKey) {
                levelOneKey = "levelone_" + StackedEnsemble.this._model._key.toString();
            }
            Frame levelOne = new Frame((Key<Frame>) Key.make(levelOneKey));
            for (int i = 0; i < baseModels.length; i++) {
                Model baseModel = baseModels[i];
                Frame preds = baseModelPredictions[i];
                if (null == baseModel) {
                    Log.warn("Failed to find base model; skipping: " + baseModels[i]);
                } else if (null == preds) {
                    // preds is null on this branch, so only the model can be named; touching
                    // preds._key here would throw instead of skipping.
                    Log.warn("Failed to find base model " + baseModel + " predictions; skipping.");
                } else {
                    StackedEnsemble.addModelPredictionsToLevelOneFrame(baseModel, preds, levelOne);
                }
            }

            StackedEnsembleModel.StackedEnsembleParameters parms = modelParms();
            if (parms._metalearner_fold_column != null) {
                levelOne.add(parms._metalearner_fold_column, actuals.vec(parms._metalearner_fold_column));
            }
            levelOne.add(StackedEnsemble.this._model.responseColumn, actuals.vec(StackedEnsemble.this._model.responseColumn));

            Frame previous = (Frame) DKV.getGet(levelOne._key);
            if (previous != null) {
                // Detach all Vecs from the stale frame stored under the same key before it is
                // replaced. Presumably this keeps the delete below from removing Vecs that are
                // still shared with other frames. NOTE(review): confirm against Frame.remove_impl.
                previous.removeAll();
                previous.write_lock(StackedEnsemble.this._job);
                previous.update(StackedEnsemble.this._job);
                previous.unlock(StackedEnsemble.this._job);
            }
            levelOne.delete_and_lock(StackedEnsemble.this._job);
            levelOne.unlock(StackedEnsemble.this._job);
            Log.info("Finished creating \"level one\" frame for stacking: " + levelOne.toString());
            DKV.put(levelOne);
            return levelOne;
        }

        /**
         * Builds the training level-one frame from each base model's cross-validation holdout
         * predictions.
         *
         * @param parms parameters whose {@code _base_models} are stacked
         * @return the level-one frame keyed {@code "levelone_training_<model key>"}
         * @throws H2OIllegalArgumentException if a base model, its holdout-prediction frame id,
         *     or the holdout-prediction frame itself cannot be found
         */
        private Frame prepareTrainingLevelOneFrame(StackedEnsembleModel.StackedEnsembleParameters parms) {
            String levelOneKey = "levelone_training_" + StackedEnsemble.this._model._key.toString();
            List<Model> baseModels = new ArrayList<>();
            List<Frame> baseModelPreds = new ArrayList<>();
            for (Key<Model> baseKey : parms._base_models) {
                Model baseModel = (Model) DKV.getGet(baseKey);
                if (null == baseModel) {
                    throw new H2OIllegalArgumentException("Failed to find base model: " + baseKey);
                }
                if (null == baseModel._output._cross_validation_holdout_predictions_frame_id) {
                    throw new H2OIllegalArgumentException("Failed to find the xval predictions frame id. . .  Looks like keep_cross_validation_predictions wasn't set when building the models.");
                }
                Frame xvalPreds = (Frame) DKV.getGet(baseModel._output._cross_validation_holdout_predictions_frame_id);
                if (null == xvalPreds) {
                    throw new H2OIllegalArgumentException("Failed to find the xval predictions frame. . .  Looks like keep_cross_validation_predictions wasn't set when building the models, or the frame was deleted.");
                }
                baseModels.add(baseModel);
                baseModelPreds.add(dropPredictColumnIfMultinomial(baseModel, xvalPreds));
            }
            return prepareLevelOneFrame(levelOneKey, baseModels.toArray(new Model[0]), baseModelPreds.toArray(new Frame[0]), modelParms().train());
        }

        /**
         * Builds a prediction-frame key from the model and frame checksums. The model and frame
         * keys are accepted but not used in the key.
         */
        private Key<Frame> buildPredsKey(Key modelKey, long modelChecksum, Key frameKey, long frameChecksum) {
            return Key.make("preds_" + modelChecksum + "_on_" + frameChecksum);
        }

        /** Returns the prediction key for {@code model} scored on {@code frame}, or null if either is null. */
        private Key<Frame> buildPredsKey(Model model, Frame frame) {
            if (frame == null || model == null) {
                return null;
            }
            return buildPredsKey(model._key, model.checksum(), frame._key, frame.checksum());
        }

        /**
         * Scores every base model on {@code validFrame} and builds the validation level-one
         * frame from the results.
         *
         * <p>Afterwards, each full scoring result is passed to
         * {@link Frame#deleteTempFrameAndItsNonSharedVecs}. The original scored frame is used
         * even when only a subframe fed the level-one frame (multinomial). This lets the
         * {@code predict} column and the scored frame's own key be cleaned up as well.
         *
         * @param levelOneKey key name for the level-one frame
         * @param baseModelKeys keys of the base models to score
         * @param validFrame validation frame to score; it also supplies the response and fold columns
         * @return the validation level-one frame
         * @throws H2OIllegalArgumentException if a base model cannot be found
         */
        private Frame prepareValidationLevelOneFrame(String levelOneKey, Key<Model>[] baseModelKeys, Frame validFrame) {
            List<Model> baseModels = new ArrayList<>();
            List<Frame> scoredFrames = new ArrayList<>();
            List<Frame> levelOneInputs = new ArrayList<>();
            for (Key<Model> baseKey : baseModelKeys) {
                Model baseModel = (Model) DKV.getGet(baseKey);
                if (null == baseModel) {
                    throw new H2OIllegalArgumentException("Failed to find base model: " + baseKey);
                }
                Frame scored = baseModel.score(validFrame, buildPredsKey(baseModel, validFrame).toString());
                baseModels.add(baseModel);
                scoredFrames.add(scored);
                levelOneInputs.add(dropPredictColumnIfMultinomial(baseModel, scored));
            }
            Frame levelOne = prepareLevelOneFrame(levelOneKey, baseModels.toArray(new Model[0]), levelOneInputs.toArray(new Frame[0]), validFrame);
            for (Frame scored : scoredFrames) {
                Frame.deleteTempFrameAndItsNonSharedVecs(scored, levelOne);
            }
            return levelOne;
        }

        /**
         * Validates parameters, creates and locks the ensemble model, builds the level-one
         * frames and trains the requested metalearner.
         *
         * @throws H2OIllegalArgumentException if the metalearner algorithm is missing or unsupported
         * @throws UnsupportedOperationException for an algorithm without a training branch
         */
        @Override // hex.ModelBuilder.Driver
        public void computeImpl() {
            StackedEnsemble.this.init(true);
            if (StackedEnsemble.this.error_count() > 0) {
                throw H2OModelBuilderIllegalArgumentException.makeFromBuilder(StackedEnsemble.this);
            }
            StackedEnsemble.this._model = new StackedEnsembleModel(StackedEnsemble.this.dest(), (StackedEnsembleModel.StackedEnsembleParameters) StackedEnsemble.this._parms, new StackedEnsembleModel.StackedEnsembleOutput(StackedEnsemble.this));
            StackedEnsemble.this._model.delete_and_lock(StackedEnsemble.this._job);
            StackedEnsemble.this._model.checkAndInheritModelProperties();
            StackedEnsembleModel.StackedEnsembleParameters parms = modelParms();

            // Validate the algorithm before building any level-one frames, so an invalid
            // request does not create frames first.
            StackedEnsembleModel.StackedEnsembleParameters.MetalearnerAlgorithm requestedAlgo = parms._metalearner_algorithm;
            StackedEnsembleModel.StackedEnsembleParameters.MetalearnerAlgorithm actualAlgo = StackedEnsemble.this.getActualMetalearnerAlgo(requestedAlgo);
            if (actualAlgo == null) {
                throw new H2OIllegalArgumentException("Invalid `metalearner_algorithm`. Passed in " + requestedAlgo + " but must be one of 'AUTO', 'glm', 'gbm', 'drf', or 'deeplearning'.");
            }

            Frame levelOneTraining = prepareTrainingLevelOneFrame((StackedEnsembleModel.StackedEnsembleParameters) StackedEnsemble.this._parms);
            Frame levelOneValidation = null;
            if (parms.valid() != null) {
                levelOneValidation = prepareValidationLevelOneFrame("levelone_validation_" + StackedEnsemble.this._model._key.toString(), parms._base_models, parms.valid());
            }

            // Raw Key/Job: the precise type parameters expected by Metalearner are not visible here.
            Key metalearnerKey = Key.make("metalearner_" + requestedAlgo + "_" + StackedEnsemble.this._model._key);
            Job metalearnerJob = new Job(metalearnerKey, ModelBuilder.javaName(actualAlgo.toString()), "StackingEnsemble metalearner (" + requestedAlgo + ")");
            Metalearner metalearner = new Metalearner(levelOneTraining, levelOneValidation, parms._metalearner_parameters, StackedEnsemble.this._model, StackedEnsemble.this._job, metalearnerKey, metalearnerJob, (StackedEnsembleModel.StackedEnsembleParameters) StackedEnsemble.this._parms, parms._metalearner_parameters != null, parms._seed);

            switch (requestedAlgo) {
                case AUTO:
                    metalearner.computeAutoMetalearner();
                    return;
                case gbm:
                    metalearner.computeGBMMetalearner();
                    return;
                case drf:
                    metalearner.computeDRFMetalearner();
                    return;
                case glm:
                    metalearner.computeGLMMetalearner();
                    return;
                case deeplearning:
                    metalearner.computeDeepLearningMetalearner();
                    return;
                default:
                    throw new UnsupportedOperationException("Unknown meta-learner algo:" + requestedAlgo);
            }
        }
    }

    /** Creates a builder for {@code stackedEnsembleParameters} and runs the non-expensive {@code init(false)} checks. */
    public StackedEnsemble(StackedEnsembleModel.StackedEnsembleParameters stackedEnsembleParameters) {
        super(stackedEnsembleParameters);
        init(false);
    }

    /** Creates a builder with default parameters, passing {@code z} through to the superclass constructor. */
    public StackedEnsemble(boolean z) {
        super(new StackedEnsembleModel.StackedEnsembleParameters(), z);
    }

    @Override // hex.ModelBuilder
    public ModelCategory[] can_build() {
        return new ModelCategory[]{ModelCategory.Regression, ModelCategory.Binomial, ModelCategory.Multinomial};
    }

    @Override // hex.ModelBuilder
    public ModelBuilder.BuilderVisibility builderVisibility() {
        return ModelBuilder.BuilderVisibility.Stable;
    }

    @Override // hex.ModelBuilder
    public boolean isSupervised() {
        return true;
    }

    /** Creates the training driver and remembers it in {@link #_driver}. */
    @Override // hex.ModelBuilder
    public StackedEnsembleDriver trainModelImpl() {
        StackedEnsembleDriver driver = new StackedEnsembleDriver();
        this._driver = driver;
        return driver;
    }

    @Override // hex.ModelBuilder
    public boolean haveMojo() {
        return true;
    }

    /**
     * Returns {@code predictions} without its {@code "predict"} column when {@code model} is a
     * multinomial classifier, and otherwise returns {@code predictions} unchanged.
     */
    private static Frame dropPredictColumnIfMultinomial(Model model, Frame predictions) {
        if (!model._output.isMultinomialClassifier()) {
            return predictions;
        }
        List<String> names = new ArrayList<>(Arrays.asList(predictions.names()));
        names.remove(PREDICT_COLUMN);
        return predictions.subframe(names.toArray(new String[0]));
    }

    /**
     * Appends the prediction columns of {@code model} to the level-one frame {@code frame2}.
     * What gets added depends on the model type:
     * <ul>
     *   <li>binomial: the column at index 2 of {@code frame}, named after the model key
     *       (presumably the positive-class probability);</li>
     *   <li>multinomial: every column of {@code frame} under its existing name (callers in this
     *       class remove {@code "predict"} first). NOTE(review): two multinomial base models
     *       would share column names here; confirm how {@code Frame.add} handles that;</li>
     *   <li>other supervised models: the {@code "predict"} column, named after the model key.</li>
     * </ul>
     *
     * @param model base model that produced {@code frame}
     * @param frame the model's predictions
     * @param frame2 level-one frame to append to
     * @throws H2OIllegalArgumentException for autoencoders and unsupervised models
     */
    public static void addModelPredictionsToLevelOneFrame(Model model, Frame frame, Frame frame2) {
        if (model._output.isBinomialClassifier()) {
            frame2.add(model._key.toString(), frame.vec(2));
        } else if (model._output.isMultinomialClassifier()) {
            frame2.add(frame);
        } else {
            if (model._output.isAutoencoder()) {
                throw new H2OIllegalArgumentException("Don't yet know how to stack autoencoders: " + model._key);
            }
            if (!model._output.isSupervised()) {
                throw new H2OIllegalArgumentException("Don't yet know how to stack unsupervised models: " + model._key);
            }
            frame2.add(model._key.toString(), frame.vec(PREDICT_COLUMN));
        }
    }

    /**
     * Resolves the metalearner algorithm that will actually be trained.
     *
     * @param metalearnerAlgorithm requested algorithm; may be null
     * @return {@code glm} for {@code AUTO}; the input itself for gbm, drf, glm and deeplearning;
     *     {@code null} for a null input or any other value
     */
    public StackedEnsembleModel.StackedEnsembleParameters.MetalearnerAlgorithm getActualMetalearnerAlgo(StackedEnsembleModel.StackedEnsembleParameters.MetalearnerAlgorithm metalearnerAlgorithm) {
        if (metalearnerAlgorithm == null) {
            // A switch on a null enum throws NullPointerException; report "unsupported" instead.
            return null;
        }
        switch (metalearnerAlgorithm) {
            case AUTO:
                return StackedEnsembleModel.StackedEnsembleParameters.MetalearnerAlgorithm.glm;
            case gbm:
            case drf:
            case glm:
            case deeplearning:
                return metalearnerAlgorithm;
            default:
                return null;
        }
    }
}
