package hex.tree;

import hex.Model;
import hex.ModelBuilder;
import hex.ModelCategory;
import hex.glm.GLM;
import hex.glm.GLMModel;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;
import water.DKV;
import water.H2O;
import water.Job;
import water.Key;
import water.Keyed;
import water.Scope;
import water.fvec.Frame;
import water.fvec.Vec;

/* loaded from: input_file:hex/tree/PlattScalingHelper.class */
/**
 * Helpers for Platt scaling of binomial models.
 *
 * <p>Calibration fits a binomial {@link GLM} (with intercept and {@code lambda = 0}). Its response
 * is the original response column, and its single predictor {@code "p"} is column 1 of the model's
 * prediction frame, both taken over a separate calibration frame. At prediction time the
 * calibration model's output columns are appended to the prediction frame with a {@code "cal_"}
 * prefix.
 */
public class PlattScalingHelper {

    /**
     * A model builder that supports probability calibration and can store an adapted calibration
     * frame.
     *
     * @param <M> model type produced by the builder
     * @param <P> parameter type of the model
     * @param <O> output type of the model
     */
    public interface ModelBuilderWithCalibration<M extends Model<M, P, O>, P extends Model.Parameters, O extends Model.Output> {
        /** Returns the underlying builder, used to report warnings/errors and to adapt frames. */
        ModelBuilder getModelBuilder();

        /** Returns the calibration frame previously stored with {@link #setCalibrationFrame(Frame)}. */
        Frame getCalibrationFrame();

        /** Stores the calibration frame (after it has been adapted to the training frame). */
        void setCalibrationFrame(Frame frame);
    }

    /** Model output that may carry a Platt-scaling calibration model. */
    public interface OutputWithCalibration {
        /** Returns the category of the model; calibration is only applied to binomial models. */
        ModelCategory getModelCategory();

        /** Returns the calibration GLM, or {@code null} if the model was not calibrated. */
        GLMModel calibrationModel();
    }

    /** Model parameters that expose the calibration settings. */
    public interface ParamsWithCalibration {
        /** Returns the underlying parameters (response and weights column names are read from them). */
        Model.Parameters getParams();

        /** Returns the user-supplied calibration frame, possibly {@code null}. */
        Frame getCalibrationFrame();

        /** Returns {@code true} if model calibration was requested. */
        boolean calibrateModel();
    }

    /**
     * Validates calibration settings and, when a calibration frame is present, stores it on the
     * builder after adapting it to the training frame.
     *
     * <p>Emits a warning when a calibration frame is given without requesting calibration. Emits an
     * error when calibration is requested for a non-binomial model or without a calibration frame.
     *
     * @param modelBuilderWithCalibration builder receiving warnings/errors and the adapted frame
     * @param paramsWithCalibration parameters holding the calibration frame and flag
     * @param z forwarded unchanged as the last argument of {@code init_adaptFrameToTrain}
     *     (presumably its "expensive checks" flag — confirm against {@code ModelBuilder})
     */
    public static void initCalibration(ModelBuilderWithCalibration modelBuilderWithCalibration, ParamsWithCalibration paramsWithCalibration, boolean z) {
        ModelBuilder builder = modelBuilderWithCalibration.getModelBuilder();
        Frame calibrationFrame = paramsWithCalibration.getCalibrationFrame();
        boolean calibrate = paramsWithCalibration.calibrateModel();
        if (calibrationFrame != null) {
            if (!calibrate) {
                builder.warn("_calibration_frame", "Calibration frame was specified but calibration was not requested.");
            }
            modelBuilderWithCalibration.setCalibrationFrame(
                    builder.init_adaptFrameToTrain(calibrationFrame, "Calibration Frame", "_calibration_frame", z));
        }
        if (calibrate) {
            if (builder.nclasses() != 2) {
                builder.error("_calibrate_model", "Model calibration is only currently supported for binomial models.");
            }
            if (calibrationFrame == null) {
                builder.error("_calibrate_model", "Calibration frame was not specified.");
            }
        }
    }

    /**
     * Trains the Platt-scaling GLM for an already trained model.
     *
     * <p>The model scores the builder's calibration frame. A temporary frame is then built with
     * columns {@code "p"} (column 1 of the predictions), {@code "response"} and, when a weights
     * column is configured, {@code "weights"}. An unregularized binomial GLM with intercept is
     * trained on that frame. The scored predictions and the temporary frame key are always cleaned
     * up, even when training fails.
     *
     * @param modelBuilderWithCalibration builder holding the (adapted) calibration frame
     * @param paramsWithCalibration parameters providing the response/weights column names
     * @param job job used for progress reporting and for scoring
     * @param m the trained model to calibrate
     * @return the trained calibration model
     */
    public static <M extends Model<M, P, O>, P extends Model.Parameters, O extends Model.Output> GLMModel buildCalibrationModel(ModelBuilderWithCalibration<M, P, O> modelBuilderWithCalibration, ParamsWithCalibration paramsWithCalibration, Job job, M m) {
        Key<Frame> calibInputKey = Key.make();
        // Enter the scope before the try so that Scope.exit() in finally always pairs with it.
        Scope.enter();
        try {
            job.update(0L, "Calibrating probabilities");
            Frame calib = modelBuilderWithCalibration.getCalibrationFrame();
            Model.Parameters params = paramsWithCalibration.getParams();
            Vec calibWeights = params._weights_column != null ? calib.vec(params._weights_column) : null;
            // Tracked so that Scope.exit() deletes the prediction frame once the GLM is trained.
            Frame calibPredict = Scope.track(m.score(calib, null, job, false));
            Frame calibInput = new Frame(calibInputKey,
                    new String[]{"p", "response"},
                    new Vec[]{calibPredict.vec(1), calib.vec(params._response_column)});
            if (calibWeights != null) {
                calibInput.add("weights", calibWeights);
            }
            DKV.put(calibInput);

            Key<Model> calibModelKey = Key.make();
            Job calibJob = new Job<>(calibModelKey, ModelBuilder.javaName("glm"), "Platt Scaling (GLM)");
            GLM calibBuilder = (GLM) ModelBuilder.make("GLM", calibJob, calibModelKey);
            GLMModel.GLMParameters glmParms = calibBuilder._parms;
            glmParms._intercept = true;
            glmParms._response_column = "response";
            glmParms._train = calibInput._key;
            glmParms._family = GLMModel.GLMParameters.Family.binomial;
            // lambda = 0: no regularization penalty, i.e. a plain logistic fit of response on p.
            glmParms._lambda = new double[]{0.0};
            if (calibWeights != null) {
                glmParms._weights_column = "weights";
            }
            return calibBuilder.trainModel().get();
        } finally {
            Scope.exit();
            // Only the temporary frame header is removed; its Vecs belong to other frames.
            DKV.remove(calibInputKey);
        }
    }

    /**
     * Appends calibrated probabilities to a prediction frame, if the model carries a calibration
     * model.
     *
     * <p>Column 1 of {@code frame} is scored by the calibration GLM. Its output columns 1 and 2 are
     * added to {@code frame} as {@code "cal_" + frame.name(1 + i)}. The frame is write-locked only
     * for the duration of the modification. The temporary input key and the remaining calibration
     * output are always removed.
     *
     * @param frame prediction frame of the base model
     * @param job job whose key is used for locking, may be {@code null}
     * @param outputWithCalibration model output that may hold a calibration model
     * @return {@code frame} unchanged when there is no calibration model, otherwise the updated frame
     * @throws RuntimeException (from {@code H2O.unimpl}) when the model is not binomial
     */
    public static Frame postProcessPredictions(Frame frame, Job job, OutputWithCalibration outputWithCalibration) {
        GLMModel calibrationModel = outputWithCalibration.calibrationModel();
        if (calibrationModel == null) {
            return frame;
        }
        if (outputWithCalibration.getModelCategory() != ModelCategory.Binomial) {
            throw H2O.unimpl("Calibration is only supported for binomial models");
        }
        Key<Job> jobKey = job != null ? job._key : null;
        Key<Frame> calibInputKey = Key.make();
        Frame calibOutput = null;
        boolean locked = false;
        try {
            Frame calibInput = new Frame(calibInputKey, new String[]{"p"}, new Vec[]{frame.vec(1)});
            calibOutput = calibrationModel.score(calibInput);
            assert calibOutput._names.length == 3 : "calibration model output is expected to have 3 columns";
            // Detach the calibrated columns so that calibOutput.remove() below does not delete them.
            Vec[] calibrated = calibOutput.remove(new int[]{1, 2});
            frame.write_lock(jobKey);
            locked = true;
            for (int i = 0; i < calibrated.length; i++) {
                frame.add("cal_" + frame.name(1 + i), calibrated[i]);
            }
            return frame.update(jobKey);
        } finally {
            // Unlock only what this method locked; unlocking an unlocked frame could mask the real error.
            if (locked) {
                frame.unlock(jobKey);
            }
            DKV.remove(calibInputKey);
            if (calibOutput != null) {
                calibOutput.remove();
            }
        }
    }
}
