package org.apache.spark.ml.h2o.algos;

import hex.FrameSplitter;
import hex.Model;
import hex.Model.Parameters;
import hex.genmodel.utils.DistributionFamily;
import java.io.IOException;
import org.apache.spark.annotation.DeveloperApi;
import org.apache.spark.h2o.H2OContext;
import org.apache.spark.ml.Estimator;
import org.apache.spark.ml.Model;
import org.apache.spark.ml.h2o.algos.params.H2OAlgoParams;
import org.apache.spark.ml.h2o.models.H2OModelParams;
import org.apache.spark.ml.param.BooleanParam;
import org.apache.spark.ml.param.DoubleParam;
import org.apache.spark.ml.param.IntParam;
import org.apache.spark.ml.param.LongParam;
import org.apache.spark.ml.param.Param;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.param.Params;
import org.apache.spark.ml.param.StringArrayParam;
import org.apache.spark.ml.util.MLWritable;
import org.apache.spark.ml.util.MLWriter;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.functions$;
import org.apache.spark.sql.types.StructType;
import scala.Array$;
import scala.Function0;
import scala.Function1;
import scala.Option;
import scala.Predef$;
import scala.collection.Seq;
import scala.reflect.ClassTag;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import water.DKV;
import water.H2O;
import water.Key;
import water.fvec.Frame;
import water.fvec.H2OFrame;
import water.support.H2OFrameSupport$;

/* compiled from: H2OAlgorithm.scala */
@ScalaSignature(bytes = "\u0006\u0001\u0005\u0005f!B\u0001\u0003\u0003\u0003y!\u0001\u0004%3\u001f\u0006cwm\u001c:ji\"l'BA\u0002\u0005\u0003\u0015\tGnZ8t\u0015\t)a!A\u0002ie=T!a\u0002\u0005\u0002\u00055d'BA\u0005\u000b\u0003\u0015\u0019\b/\u0019:l\u0015\tYA\"\u0001\u0004ba\u0006\u001c\u0007.\u001a\u0006\u0002\u001b\u0005\u0019qN]4\u0004\u0001U\u0019\u0001\u0003M\f\u0014\t\u0001\t2%\u000b\t\u0004%M)R\"\u0001\u0004\n\u0005Q1!!C#ti&l\u0017\r^8s!\t1r\u0003\u0004\u0001\u0005\u000ba\u0001!\u0019A\r\u0003\u00035\u000b\"A\u0007\u0011\u0011\u0005mqR\"\u0001\u000f\u000b\u0003u\tQa]2bY\u0006L!a\b\u000f\u0003\u000f9{G\u000f[5oOB\u0019!#I\u000b\n\u0005\t2!!B'pI\u0016d\u0007C\u0001\u0013(\u001b\u0005)#B\u0001\u0014\u0007\u0003\u0011)H/\u001b7\n\u0005!*#AC'M/JLG/\u00192mKB\u0019!&L\u0018\u000e\u0003-R!\u0001\f\u0002\u0002\rA\f'/Y7t\u0013\tq3FA\u0007Ie=\u000bEnZ8QCJ\fWn\u001d\t\u0003-A\"Q!\r\u0001C\u0002I\u0012\u0011\u0001U\t\u00035M\u0002\"\u0001\u000e\u001e\u000f\u0005UBT\"\u0001\u001c\u000b\u0003]\n1\u0001[3y\u0013\tId'A\u0003N_\u0012,G.\u0003\u0002<y\tQ\u0001+\u0019:b[\u0016$XM]:\u000b\u0005e2\u0004\u0002\u0003 \u0001\u0005\u0003\u0005\u000b\u0011B 
\u0002\u0015A\f'/Y7fi\u0016\u00148\u000fE\u0002\u001c\u0001>J!!\u0011\u000f\u0003\r=\u0003H/[8o\u0011!\u0019\u0005AaA!\u0002\u0017!\u0015AC3wS\u0012,gnY3%cA\u0019Q\tS\u0018\u000e\u0003\u0019S!a\u0012\u000f\u0002\u000fI,g\r\\3di&\u0011\u0011J\u0012\u0002\t\u00072\f7o\u001d+bO\"A1\n\u0001B\u0002B\u0003-A*\u0001\u0006fm&$WM\\2fII\u00022!\u0012%\u0016\u0011!q\u0005A!A!\u0002\u0017y\u0015A\u00015d!\t\u0001&+D\u0001R\u0015\t)\u0001\"\u0003\u0002T#\nQ\u0001JM(D_:$X\r\u001f;\t\u0011U\u0003!\u0011!Q\u0001\fY\u000b!b]9m\u0007>tG/\u001a=u!\t9&,D\u0001Y\u0015\tI\u0006\"A\u0002tc2L!a\u0017-\u0003\u0015M\u000bFjQ8oi\u0016DH\u000fC\u0003^\u0001\u0011\u0005a,\u0001\u0004=S:LGO\u0010\u000b\u0003?\u001a$R\u0001\u00192dI\u0016\u0004B!\u0019\u00010+5\t!\u0001C\u0003D9\u0002\u000fA\tC\u0003L9\u0002\u000fA\nC\u0003O9\u0002\u000fq\nC\u0003V9\u0002\u000fa\u000bC\u0003?9\u0002\u0007q\bC\u0003i\u0001\u0011\u0005\u0013.A\u0002gSR$\"!\u00066\t\u000b-<\u0007\u0019\u00017\u0002\u000f\u0011\fG/Y:fiB\u0012Q.\u001d\t\u0004/:\u0004\u0018BA8Y\u0005\u001d!\u0015\r^1tKR\u0004\"AF9\u0005\u0013IT\u0017\u0011!A\u0001\u0006\u0003\u0019(aA0%cE\u0011!\u0004\u001e\t\u00037UL!A\u001e\u000f\u0003\u0007\u0005s\u0017\u0010C\u0003y\u0001\u0019\u0005\u00110\u0001\u0006ue\u0006Lg.T8eK2$2A_A\u0004%\rYX# \u0004\u0005y\u0002\u0001!P\u0001\u0007=e\u00164\u0017N\\3nK:$h\bE\u0002\u007f\u0003\u0007i\u0011a 
\u0006\u0004\u0003\u0003!\u0011AB7pI\u0016d7/C\u0002\u0002\u0006}\u0014a\u0002\u0013\u001aP\u001b>$W\r\u001c)be\u0006l7\u000fC\u0003-o\u0002\u0007q\u0006C\u0004\u0002\f\u0001!\t%!\u0004\u0002\u001fQ\u0014\u0018M\\:g_Jl7k\u00195f[\u0006$B!a\u0004\u0002\u001cA!\u0011\u0011CA\f\u001b\t\t\u0019BC\u0002\u0002\u0016a\u000bQ\u0001^=qKNLA!!\u0007\u0002\u0014\tQ1\u000b\u001e:vGR$\u0016\u0010]3\t\u0011\u0005u\u0011\u0011\u0002a\u0001\u0003\u001f\taa]2iK6\f\u0007\u0006BA\u0005\u0003C\u0001B!a\t\u0002*5\u0011\u0011Q\u0005\u0006\u0004\u0003OA\u0011AC1o]>$\u0018\r^5p]&!\u00111FA\u0013\u00051!UM^3m_B,'/\u00119j\u0011\u001d\ty\u0003\u0001C!\u0003c\tAaY8qsR!\u00111GA\u001b\u001b\u0005\u0001\u0001\u0002CA\u001c\u0003[\u0001\r!!\u000f\u0002\u000b\u0015DHO]1\u0011\t\u0005m\u0012\u0011I\u0007\u0003\u0003{Q1!a\u0010\u0007\u0003\u0015\u0001\u0018M]1n\u0013\u0011\t\u0019%!\u0010\u0003\u0011A\u000b'/Y7NCBDq!a\u0012\u0001\t\u0003\nI%A\u0003xe&$X-\u0006\u0002\u0002LA\u0019A%!\u0014\n\u0007\u0005=SE\u0001\u0005N\u0019^\u0013\u0018\u000e^3sQ\u0019\t)%a\u0015\u0002ZA!\u00111EA+\u0013\u0011\t9&!\n\u0003\u000bMKgnY3\"\u0005\u0005m\u0013!B\u0019/m9\u0002\u0004bBA0\u0001\u0011%\u0011\u0011M\u0001\u0006gBd\u0017\u000e\u001e\u000b\u0007\u0003G\n\t)a#\u0011\u000bm\t)'!\u001b\n\u0007\u0005\u001dDDA\u0003BeJ\f\u0017\u0010\u0005\u0004\u0002l\u0005E\u0014QO\u0007\u0003\u0003[R!!a\u001c\u0002\u000b]\fG/\u001a:\n\t\u0005M\u0014Q\u000e\u0002\u0004\u0017\u0016L\b\u0003BA<\u0003{j!!!\u001f\u000b\t\u0005m\u0014QN\u0001\u0005MZ,7-\u0003\u0003\u0002��\u0005e$!\u0002$sC6,\u0007\u0002CAB\u0003;\u0002\r!!\"\u0002\u0005\u0019\u0014\b\u0003BA<\u0003\u000fKA!!#\u0002z\tA\u0001JM(Ge\u0006lW\r\u0003\u0004O\u0003;\u0002\ra\u0014\u0005\b\u0003\u001f\u0003a\u0011AAI\u0003=!WMZ1vYR4\u0015\u000e\\3OC6,WCAAJ!\u0011\t)*a'\u000f\u0007m\t9*C\u0002\u0002\u001ar\ta\u0001\u0015:fI\u00164\u0017\u0002BAO\u0003?\u0013aa\u0015;sS:<'bAAM9\u0001")
/* loaded from: input_file:org/apache/spark/ml/h2o/algos/H2OAlgorithm.class */
/**
 * Base class for Spark ML estimators backed by an H2O algorithm.
 *
 * <p>{@link #fit(Dataset)} converts the configured feature columns plus the prediction column of
 * the input {@link Dataset} into an {@link H2OFrame}. It optionally splits that frame into
 * training and validation frames according to the {@code ratio} parameter, then delegates the
 * actual training to {@link #trainModel}. Subclasses supply {@link #trainModel} and
 * {@link #defaultFileName}.
 *
 * <p>This source was decompiled from the Scala class {@code H2OAlgorithm}. The methods that forward
 * to {@code H2OAlgoParams.Cclass} are the implementation bodies of the Scala trait
 * {@code H2OAlgoParams}.
 *
 * @param <P> the H2O algorithm parameter type
 * @param <M> the Spark ML model type produced by {@link #fit(Dataset)}
 */
public abstract class H2OAlgorithm<P extends Model.Parameters, M extends org.apache.spark.ml.Model<M>> extends Estimator<M> implements MLWritable, H2OAlgoParams<P> {
    /** H2O context used by {@link #fit(Dataset)} to convert Spark data into an H2O frame. */
    private final H2OContext hc;
    /** Current H2O algorithm parameters, exposed through {@link #parameters()}. */
    private P parameters;
    // The Param fields below are assigned by the H2OAlgoParams `_setter_` methods further down,
    // not by the constructor, so they cannot be declared final in Java source.
    private DoubleParam ratio;
    private Param<String> predictionCol;
    private StringArrayParam featuresCols;
    private BooleanParam allStringColumnsToCategorical;

    // ---------------------------------------------------------------------------------------
    // H2OAlgoParams accessors and trait-field setters.
    // ---------------------------------------------------------------------------------------

    @Override // org.apache.spark.ml.h2o.algos.params.H2OAlgoParams
    public P parameters() {
        return this.parameters;
    }

    @Override // org.apache.spark.ml.h2o.algos.params.H2OAlgoParams
    public void parameters_$eq(P p) {
        this.parameters = p;
    }

    @Override // org.apache.spark.ml.h2o.algos.params.H2OAlgoParams
    public final DoubleParam ratio() {
        return this.ratio;
    }

    @Override // org.apache.spark.ml.h2o.algos.params.H2OAlgoParams
    public final Param<String> predictionCol() {
        return this.predictionCol;
    }

    @Override // org.apache.spark.ml.h2o.algos.params.H2OAlgoParams
    public final StringArrayParam featuresCols() {
        return this.featuresCols;
    }

    @Override // org.apache.spark.ml.h2o.algos.params.H2OAlgoParams
    public final BooleanParam allStringColumnsToCategorical() {
        return this.allStringColumnsToCategorical;
    }

    @Override // org.apache.spark.ml.h2o.algos.params.H2OAlgoParams
    public /* synthetic */ H2OAlgoParams org$apache$spark$ml$h2o$algos$params$H2OAlgoParams$$super$set(Param param, Object obj) {
        return (H2OAlgoParams) Params.class.set(this, param, obj);
    }

    @Override // org.apache.spark.ml.h2o.algos.params.H2OAlgoParams
    public final void org$apache$spark$ml$h2o$algos$params$H2OAlgoParams$_setter_$ratio_$eq(DoubleParam doubleParam) {
        this.ratio = doubleParam;
    }

    @Override // org.apache.spark.ml.h2o.algos.params.H2OAlgoParams
    public final void org$apache$spark$ml$h2o$algos$params$H2OAlgoParams$_setter_$predictionCol_$eq(Param param) {
        this.predictionCol = param;
    }

    @Override // org.apache.spark.ml.h2o.algos.params.H2OAlgoParams
    public final void org$apache$spark$ml$h2o$algos$params$H2OAlgoParams$_setter_$featuresCols_$eq(StringArrayParam stringArrayParam) {
        this.featuresCols = stringArrayParam;
    }

    @Override // org.apache.spark.ml.h2o.algos.params.H2OAlgoParams
    public final void org$apache$spark$ml$h2o$algos$params$H2OAlgoParams$_setter_$allStringColumnsToCategorical_$eq(BooleanParam booleanParam) {
        this.allStringColumnsToCategorical = booleanParam;
    }

    // ---------------------------------------------------------------------------------------
    // H2OAlgoParams methods: each one forwards to the trait implementation in
    // H2OAlgoParams.Cclass, passing this instance as the receiver.
    // ---------------------------------------------------------------------------------------

    @Override // org.apache.spark.ml.h2o.algos.params.H2OAlgoParams
    public P getParams() {
        return (P) H2OAlgoParams.Cclass.getParams(this);
    }

    @Override // org.apache.spark.ml.h2o.algos.params.H2OAlgoParams
    public void setParams(P p) {
        H2OAlgoParams.Cclass.setParams(this, p);
    }

    @Override // org.apache.spark.ml.h2o.algos.params.H2OAlgoParams
    public void setParams(Function1<P, BoxedUnit> function1) {
        H2OAlgoParams.Cclass.setParams(this, function1);
    }

    @Override // org.apache.spark.ml.h2o.algos.params.H2OAlgoParams
    public String doc(String str) {
        return H2OAlgoParams.Cclass.doc(this, str);
    }

    @Override // org.apache.spark.ml.h2o.algos.params.H2OAlgoParams
    public BooleanParam booleanParam(String str) {
        return H2OAlgoParams.Cclass.booleanParam(this, str);
    }

    @Override // org.apache.spark.ml.h2o.algos.params.H2OAlgoParams
    public IntParam intParam(String str) {
        return H2OAlgoParams.Cclass.intParam(this, str);
    }

    @Override // org.apache.spark.ml.h2o.algos.params.H2OAlgoParams
    public LongParam longParam(String str) {
        return H2OAlgoParams.Cclass.longParam(this, str);
    }

    @Override // org.apache.spark.ml.h2o.algos.params.H2OAlgoParams
    public DoubleParam doubleParam(String str) {
        return H2OAlgoParams.Cclass.doubleParam(this, str);
    }

    @Override // org.apache.spark.ml.h2o.algos.params.H2OAlgoParams
    public <T> Param<T> param(String str) {
        return H2OAlgoParams.Cclass.param(this, str);
    }

    @Override // org.apache.spark.ml.h2o.algos.params.H2OAlgoParams
    public double getTrainRatio() {
        return H2OAlgoParams.Cclass.getTrainRatio(this);
    }

    @Override // org.apache.spark.ml.h2o.algos.params.H2OAlgoParams
    public Object setTrainRatio(double d) {
        return H2OAlgoParams.Cclass.setTrainRatio(this, d);
    }

    @Override // org.apache.spark.ml.h2o.algos.params.H2OAlgoParams
    public String getPredictionsCol() {
        return H2OAlgoParams.Cclass.getPredictionsCol(this);
    }

    @Override // org.apache.spark.ml.h2o.algos.params.H2OAlgoParams
    public Object setPredictionsCol(String str) {
        return H2OAlgoParams.Cclass.setPredictionsCol(this, str);
    }

    @Override // org.apache.spark.ml.h2o.algos.params.H2OAlgoParams
    public final String[] getFeaturesCols() {
        return H2OAlgoParams.Cclass.getFeaturesCols(this);
    }

    @Override // org.apache.spark.ml.h2o.algos.params.H2OAlgoParams
    public Object setFeaturesCols(String str, Seq<String> seq) {
        return H2OAlgoParams.Cclass.setFeaturesCols(this, str, seq);
    }

    @Override // org.apache.spark.ml.h2o.algos.params.H2OAlgoParams
    public Object setFeaturesCols(String[] strArr) {
        return H2OAlgoParams.Cclass.setFeaturesCols(this, strArr);
    }

    @Override // org.apache.spark.ml.h2o.algos.params.H2OAlgoParams
    public Object setFeaturesCol(String str) {
        return H2OAlgoParams.Cclass.setFeaturesCol(this, str);
    }

    @Override // org.apache.spark.ml.h2o.algos.params.H2OAlgoParams
    public Object setAllStringColumnsToCategorical(boolean z) {
        return H2OAlgoParams.Cclass.setAllStringColumnsToCategorical(this, z);
    }

    @Override // org.apache.spark.ml.h2o.algos.params.H2OAlgoParams
    public boolean getAllStringColumnsToCategorical() {
        return H2OAlgoParams.Cclass.getAllStringColumnsToCategorical(this);
    }

    @Override // org.apache.spark.ml.h2o.algos.params.H2OAlgoParams
    public final <T> Object set(Param<T> param, T t, Function0<BoxedUnit> function0) {
        return H2OAlgoParams.Cclass.set(this, param, t, function0);
    }

    /** Saves this estimator to {@code str} via the MLWritable trait implementation. */
    @Override
    public void save(String str) throws IOException {
        MLWritable.class.save(this, str);
    }

    /**
     * Trains an H2O model on {@code dataset}.
     *
     * <p>If no feature columns are configured, they are first set from the dataset's columns,
     * filtered by the predicate compiled into {@code H2OAlgorithm$$anonfun$fit$1}.
     * NOTE(review): presumably that predicate drops the prediction column; confirm.
     *
     * <p>The feature columns plus the prediction column are converted to an {@link H2OFrame}.
     * When {@code ratio < 1.0} the frame is split: the first part becomes the training frame and
     * the second part, if any, the validation frame. Otherwise the whole frame is used for training.
     * For the {@code bernoulli} and {@code multinomial} distributions, a prediction column that is
     * not yet categorical is replaced by its categorical version before training.
     *
     * @param dataset input data containing the feature and prediction columns
     * @return the trained model, with its feature and prediction columns copied from this estimator
     */
    @Override
    public M fit(Dataset<?> dataset) {
        if (getFeaturesCols().length == 0) {
            setFeaturesCols((String[]) Predef$.MODULE$.refArrayOps(dataset.columns()).filter(new H2OAlgorithm$$anonfun$fit$1(this)));
        }

        // Columns handed to H2O: every feature column (mapped by H2OAlgorithm$$anonfun$1),
        // followed by the prediction column.
        Object[] featureColumns = (Object[]) Predef$.MODULE$.refArrayOps(getFeaturesCols()).map(new H2OAlgorithm$$anonfun$1(this), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Column.class)));
        Column[] selectedColumns = (Column[]) Predef$.MODULE$.refArrayOps(featureColumns).$plus$plus(Predef$.MODULE$.refArrayOps(new Column[]{functions$.MODULE$.col(getPredictionsCol())}), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Column.class)));
        H2OFrame input = this.hc.asH2OFrame(dataset.select(Predef$.MODULE$.wrapRefArray(selectedColumns)).toDF());

        if (BoxesRunTime.unboxToDouble($(ratio())) < 1.0d) {
            Key<Frame>[] parts = split(input, this.hc);
            getParams()._train = parts[0];
            if (parts.length > 1) {
                getParams()._valid = parts[1];
            }
        } else {
            getParams()._train = input._key;
        }

        Frame train = getParams()._train.get();
        if (getAllStringColumnsToCategorical()) {
            H2OFrameSupport$.MODULE$.allStringVecToCategorical(train);
        }

        // Both the bernoulli and the multinomial distribution get a categorical response column.
        DistributionFamily distribution = getParams()._distribution;
        boolean categoricalResponse = distribution == DistributionFamily.bernoulli
                || distribution == DistributionFamily.multinomial;
        if (categoricalResponse && !train.vec(getPredictionsCol()).isCategorical()) {
            // replace(...) returns the previous vec, which is removed to free it.
            train.replace(train.find(getPredictionsCol()), train.vec(getPredictionsCol()).toCategoricalVec()).remove();
        }
        DKV.put(train);

        M model = trainModel(getParams());
        // NOTE(review): assumes every model returned by trainModel implements H2OModelParams.
        ((H2OModelParams) model).setFeaturesCols((String[]) $(featuresCols()));
        ((H2OModelParams) model).setPredictionsCol((String) $(predictionCol()));
        return model;
    }

    /**
     * Trains the concrete H2O model once {@link #fit(Dataset)} has prepared the parameters.
     *
     * @param p algorithm parameters whose {@code _train} (and possibly {@code _valid}) are set
     * @return the trained Spark ML model
     */
    public abstract M trainModel(P p);

    /**
     * Validates {@code structType} and returns it unchanged.
     *
     * <p>Two requirements are checked with predicates compiled into anonymous function classes
     * that are not visible here. First, the schema must contain a field accepted by
     * {@code transformSchema$2}. Second, no configured feature column may be matched by
     * {@code transformSchema$4}. NOTE(review): presumably these check that the prediction column
     * exists and is not also listed as a feature column; confirm.
     */
    @Override
    @DeveloperApi
    public StructType transformSchema(StructType structType) {
        Predef$.MODULE$.require(Predef$.MODULE$.refArrayOps(structType.fields()).exists(new H2OAlgorithm$$anonfun$transformSchema$2(this)), new H2OAlgorithm$$anonfun$transformSchema$1(this));
        Predef$.MODULE$.require(!Predef$.MODULE$.refArrayOps(getFeaturesCols()).exists(new H2OAlgorithm$$anonfun$transformSchema$4(this)), new H2OAlgorithm$$anonfun$transformSchema$3(this));
        return structType;
    }

    /* renamed from: copy, reason: merged with bridge method [inline-methods] and merged with bridge method [inline-methods] and merged with bridge method [inline-methods] */
    /** Copies this estimator with {@code paramMap} applied, using Spark's default copy. */
    public H2OAlgorithm<P, M> m7copy(ParamMap paramMap) {
        return (H2OAlgorithm) defaultCopy(paramMap);
    }

    /** Returns the writer used to persist this estimator. */
    @Override
    public MLWriter write() {
        return new H2OAlgorithmWriter(this);
    }

    /**
     * Splits {@code h2OFrame} with H2O's {@link FrameSplitter} using the configured {@code ratio}.
     * The splitter is submitted to H2O, and this method waits on {@code getResult()} before
     * returning.
     *
     * @param h2OFrame frame to split
     * @param h2OContext currently unused
     * @return the destination keys, {@code "train"} first and {@code "valid"} second
     */
    private Key<Frame>[] split(H2OFrame h2OFrame, H2OContext h2OContext) {
        // Java cannot create an array of the parameterized type Key<Frame> directly, so build a
        // raw Key[] and cast it; every element is a frame key.
        // NOTE(review): the key names are fixed, so separate fits reuse the same "train"/"valid"
        // keys; consider unique names if estimators may run concurrently.
        @SuppressWarnings("unchecked")
        Key<Frame>[] keys = (Key<Frame>[]) new Key[]{Key.make("train"), Key.make("valid")};
        double[] ratios = {BoxesRunTime.unboxToDouble($(ratio()))};
        FrameSplitter splitter = new FrameSplitter(h2OFrame, ratios, keys, (Key) null);
        H2O.submitTask(splitter);
        splitter.getResult();
        return keys;
    }

    /** Returns the default file name used by this algorithm; supplied by subclasses. */
    public abstract String defaultFileName();

    /**
     * Creates the estimator.
     *
     * @param option initial algorithm parameters, applied via {@link #setParams} when defined
     * @param classTag Scala class tag for {@code P} (not used in this constructor body)
     * @param classTag2 Scala class tag for {@code M} (not used in this constructor body)
     * @param h2OContext context used to convert Spark data into H2O frames
     * @param sQLContext Spark SQL context (not used in this constructor body)
     */
    public H2OAlgorithm(Option<P> option, ClassTag<P> classTag, ClassTag<M> classTag2, H2OContext h2OContext, SQLContext sQLContext) {
        this.hc = h2OContext;
        MLWritable.class.$init$(this);
        H2OAlgoParams.Cclass.$init$(this);
        if (option.isDefined()) {
            // option.get() is already a P, so this resolves to setParams(P).
            setParams(option.get());
        }
    }
}
