package ai.h2o.sparkling.ml.features;

import ai.h2o.sparkling.ml.features.H2OTargetEncoderModelUtils;
import ai.h2o.sparkling.ml.models.H2OTargetEncoderBase;
import ai.h2o.sparkling.ml.models.H2OTargetEncoderModel;
import ai.h2o.sparkling.ml.params.H2OAlgoParamsHelper$;
import ai.h2o.sparkling.ml.params.H2OTargetEncoderParams;
import ai.h2o.sparkling.ml.params.NullableStringParam;
import ai.h2o.targetencoding.BlendingParams;
import ai.h2o.targetencoding.TargetEncoder;
import ai.h2o.targetencoding.TargetEncoderBuilder;
import ai.h2o.targetencoding.TargetEncoderModel;
import hex.Model;
import java.io.IOException;
import org.apache.spark.h2o.H2OContext$;
import org.apache.spark.ml.Estimator;
import org.apache.spark.ml.param.BooleanParam;
import org.apache.spark.ml.param.DoubleParam;
import org.apache.spark.ml.param.LongParam;
import org.apache.spark.ml.param.Param;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.param.StringArrayParam;
import org.apache.spark.ml.util.DefaultParamsWritable;
import org.apache.spark.ml.util.Identifiable$;
import org.apache.spark.ml.util.MLReader;
import org.apache.spark.ml.util.MLWritable;
import org.apache.spark.ml.util.MLWriter;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.SparkSession$;
import org.apache.spark.sql.types.StructType;
import scala.Array$;
import scala.Predef$;
import scala.collection.Seq$;
import scala.collection.generic.GenericTraversableTemplate;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxesRunTime;
import water.fvec.Frame;
import water.fvec.H2OFrame;

/* compiled from: H2OTargetEncoder.scala */
@ScalaSignature(bytes = "\u0006\u0001\u0005\u0015f\u0001B\u0001\u0003\u00015\u0011\u0001\u0003\u0013\u001aP)\u0006\u0014x-\u001a;F]\u000e|G-\u001a:\u000b\u0005\r!\u0011\u0001\u00034fCR,(/Z:\u000b\u0005\u00151\u0011AA7m\u0015\t9\u0001\"A\u0005ta\u0006\u00148\u000e\\5oO*\u0011\u0011BC\u0001\u0004QJz'\"A\u0006\u0002\u0005\u0005L7\u0001A\n\u0006\u00019y\"\u0005\u000b\t\u0004\u001f]IR\"\u0001\t\u000b\u0005\u0015\t\"B\u0001\n\u0014\u0003\u0015\u0019\b/\u0019:l\u0015\t!R#\u0001\u0004ba\u0006\u001c\u0007.\u001a\u0006\u0002-\u0005\u0019qN]4\n\u0005a\u0001\"!C#ti&l\u0017\r^8s!\tQR$D\u0001\u001c\u0015\taB!\u0001\u0004n_\u0012,Gn]\u0005\u0003=m\u0011Q\u0003\u0013\u001aP)\u0006\u0014x-\u001a;F]\u000e|G-\u001a:N_\u0012,G\u000e\u0005\u0002\u001bA%\u0011\u0011e\u0007\u0002\u0015\u0011JzE+\u0019:hKR,enY8eKJ\u0014\u0015m]3\u0011\u0005\r2S\"\u0001\u0013\u000b\u0005\u0015\u0002\u0012\u0001B;uS2L!a\n\u0013\u0003+\u0011+g-Y;miB\u000b'/Y7t/JLG/\u00192mKB\u0011\u0011FK\u0007\u0002\u0005%\u00111F\u0001\u0002\u001b\u0011JzE+\u0019:hKR,enY8eKJlu\u000eZ3m+RLGn\u001d\u0005\t[\u0001\u0011)\u0019!C!]\u0005\u0019Q/\u001b3\u0016\u0003=\u0002\"\u0001\r\u001c\u000f\u0005E\"T\"\u0001\u001a\u000b\u0003M\nQa]2bY\u0006L!!\u000e\u001a\u0002\rA\u0013X\rZ3g\u0013\t9\u0004H\u0001\u0004TiJLgn\u001a\u0006\u0003kIB\u0001B\u000f\u0001\u0003\u0002\u0003\u0006IaL\u0001\u0005k&$\u0007\u0005C\u0003=\u0001\u0011\u0005Q(\u0001\u0004=S:LGO\u0010\u000b\u0003}}\u0002\"!\u000b\u0001\t\u000b5Z\u0004\u0019A\u0018\t\u000bq\u0002A\u0011A!\u0015\u0003yBQa\u0011\u0001\u0005B\u0011\u000b1AZ5u)\tIR\tC\u0003G\u0005\u0002\u0007q)A\u0004eCR\f7/\u001a;1\u0005!\u0003\u0006cA%M\u001d6\t!J\u0003\u0002L#\u0005\u00191/\u001d7\n\u00055S%a\u0002#bi\u0006\u001cX\r\u001e\t\u0003\u001fBc\u0001\u0001B\u0005R\u000b\u0006\u0005\t\u0011!B\u0001%\n\u0019q\fJ\u0019\u0012\u0005M3\u0006CA\u0019U\u0013\t)&GA\u0004O_RD\u0017N\\4\u0011\u0005E:\u0016B\u0001-3\u0005\r\te.\u001f\u0005\u00065\u0002!IaW\u0001\u0019iJ\f\u0017N\u001c+be\u001e,G/\u00128d_\u0012L
gnZ'pI\u0016dGc\u0001/coB\u0011Q\fY\u0007\u0002=*\u0011q\fC\u0001\u000fi\u0006\u0014x-\u001a;f]\u000e|G-\u001b8h\u0013\t\tgL\u0001\nUCJ<W\r^#oG>$WM]'pI\u0016d\u0007\"B2Z\u0001\u0004!\u0017!\u0004;sC&t\u0017N\\4Ge\u0006lW\r\u0005\u0002fi:\u0011a-\u001d\b\u0003OBt!\u0001[8\u000f\u0005%tgB\u00016n\u001b\u0005Y'B\u00017\r\u0003\u0019a$o\\8u}%\ta#\u0003\u0002\u0015+%\u0011!cE\u0005\u0003\u0013EI!A]:\u0002\u000fA\f7m[1hK*\u0011\u0011\"E\u0005\u0003kZ\u0014QA\u0012:b[\u0016T!A]:\t\u000baL\u0006\u0019A=\u0002\u001d%<gn\u001c:fI\u000e{G.^7ogB\u0019\u0011G_\u0018\n\u0005m\u0014$!B!se\u0006L\b\"B?\u0001\t\u0003r\u0018\u0001B2paf$\"AP@\t\u000f\u0005\u0005A\u00101\u0001\u0002\u0004\u0005)Q\r\u001f;sCB!\u0011QAA\u0006\u001b\t\t9AC\u0002\u0002\nA\tQ\u0001]1sC6LA!!\u0004\u0002\b\tA\u0001+\u0019:b[6\u000b\u0007\u000fC\u0004\u0002\u0012\u0001!\t!a\u0005\u0002\u0015M,GOR8mI\u000e{G\u000e\u0006\u0003\u0002\u0016\u0005]Q\"\u0001\u0001\t\u000f\u0005e\u0011q\u0002a\u0001_\u0005)a/\u00197vK\"9\u0011Q\u0004\u0001\u0005\u0002\u0005}\u0011aC:fi2\u000b'-\u001a7D_2$B!!\u0006\u0002\"!9\u0011\u0011DA\u000e\u0001\u0004y\u0003bBA\u0013\u0001\u0011\u0005\u0011qE\u0001\rg\u0016$\u0018J\u001c9vi\u000e{Gn\u001d\u000b\u0005\u0003+\tI\u0003C\u0004\u0002,\u0005\r\u0002\u0019A=\u0002\rY\fG.^3t\u0011\u001d\ty\u0003\u0001C\u0001\u0003c\t!c]3u\u0011>dGm\\;u'R\u0014\u0018\r^3hsR!\u0011QCA\u001a\u0011\u001d\tI\"!\fA\u0002=Bq!a\u000e\u0001\t\u0003\tI$\u0001\u000btKR\u0014E.\u001a8eK\u0012\feoZ#oC\ndW\r\u001a\u000b\u0005\u0003+\tY\u0004\u0003\u0005\u0002\u001a\u0005U\u0002\u0019AA\u001f!\r\t\u0014qH\u0005\u0004\u0003\u0003\u0012$a\u0002\"p_2,\u0017M\u001c\u0005\b\u0003\u000b\u0002A\u0011AA$\u0003q\u0019X\r\u001e\"mK:$W\rZ!wO&sg\r\\3di&|g\u000eU8j]R$B!!\u0006\u0002J!A\u0011\u0011DA\"\u0001\u0004\tY\u0005E\u00022\u0003\u001bJ1!a\u00143\u0005\u0019!u.\u001e2mK\"9\u00111\u000b\u0001\u0005\u0002\u0005U\u0013AF:fi\ncWM\u001c3fI\u00063xmU7p_RD\u0017N\\4\u0015\t\u0005U\u0011q\u000b\u0005\t\u00033\t\t\u00061\u0001\u0002L!
9\u00111\f\u0001\u0005\u0002\u0005u\u0013\u0001C:fi:{\u0017n]3\u0015\t\u0005U\u0011q\f\u0005\t\u00033\tI\u00061\u0001\u0002L!9\u00111\r\u0001\u0005\u0002\u0005\u0015\u0014\u0001D:fi:{\u0017n]3TK\u0016$G\u0003BA\u000b\u0003OB\u0001\"!\u0007\u0002b\u0001\u0007\u0011\u0011\u000e\t\u0004c\u0005-\u0014bAA7e\t!Aj\u001c8h\u000f\u001d\t\tH\u0001E\u0001\u0003g\n\u0001\u0003\u0013\u001aP)\u0006\u0014x-\u001a;F]\u000e|G-\u001a:\u0011\u0007%\n)H\u0002\u0004\u0002\u0005!\u0005\u0011qO\n\t\u0003k\nI(a \u0002\u0006B\u0019\u0011'a\u001f\n\u0007\u0005u$G\u0001\u0004B]f\u0014VM\u001a\t\u0005G\u0005\u0005e(C\u0002\u0002\u0004\u0012\u0012Q\u0003R3gCVdG\u000fU1sC6\u001c(+Z1eC\ndW\rE\u00022\u0003\u000fK1!!#3\u00051\u0019VM]5bY&T\u0018M\u00197f\u0011\u001da\u0014Q\u000fC\u0001\u0003\u001b#\"!a\u001d\t\u0015\u0005E\u0015QOA\u0001\n\u0013\t\u0019*A\u0006sK\u0006$'+Z:pYZ,GCAAK!\u0011\t9*!)\u000e\u0005\u0005e%\u0002BAN\u0003;\u000bA\u0001\\1oO*\u0011\u0011qT\u0001\u0005U\u00064\u0018-\u0003\u0003\u0002$\u0006e%AB(cU\u0016\u001cG\u000f")
/* loaded from: input_file:ai/h2o/sparkling/ml/features/H2OTargetEncoder.class */
/**
 * Spark ML {@link Estimator} that trains an H2O target-encoding model on a Spark {@link Dataset}
 * and wraps the trained model in an {@link H2OTargetEncoderModel}.
 *
 * <p>This is the Java view of a class compiled from {@code H2OTargetEncoder.scala}. Parameter
 * fields are initialised by the Scala trait initialisers called from the constructor. Presumably
 * they do this through the {@code ..._setter_$..._$eq} methods; verify against the trait.
 * Most getters delegate to the trait implementation classes.
 */
public class H2OTargetEncoder extends Estimator<H2OTargetEncoderModel> implements H2OTargetEncoderBase, DefaultParamsWritable, H2OTargetEncoderModelUtils {
    // Unique identifier of this estimator; it is also reused as the uid of the fitted model.
    private final String uid;
    // NOTE(review): the fields below are declared final but are assigned by the trait setter
    // methods further down. That mirrors the Scala-compiled bytecode; it is not valid Java source
    // as written.
    private final NullableStringParam foldCol;
    private final Param<String> labelCol;
    private final StringArrayParam inputCols;
    private final Param<String> holdoutStrategy;
    private final BooleanParam blendedAvgEnabled;
    private final DoubleParam blendedAvgInflectionPoint;
    private final DoubleParam blendedAvgSmoothing;
    private final DoubleParam noise;
    private final LongParam noiseSeed;

    /** Loads a previously saved estimator from {@code str}; delegates to the companion object. */
    public static Object load(String str) {
        return H2OTargetEncoder$.MODULE$.load(str);
    }

    /** Returns the reader used to load saved estimators; delegates to the companion object. */
    public static MLReader<H2OTargetEncoder> read() {
        return H2OTargetEncoder$.MODULE$.read();
    }

    @Override // ai.h2o.sparkling.ml.features.H2OTargetEncoderModelUtils
    public void convertRelevantColumnsToCategorical(Frame frame) {
        H2OTargetEncoderModelUtils.Cclass.convertRelevantColumnsToCategorical(this, frame);
    }

    public MLWriter write() {
        return DefaultParamsWritable.class.write(this);
    }

    public void save(String str) throws IOException {
        MLWritable.class.save(this, str);
    }

    public StructType transformSchema(StructType structType) {
        return H2OTargetEncoderBase.class.transformSchema(this, structType);
    }

    public final NullableStringParam foldCol() {
        return this.foldCol;
    }

    public final Param<String> labelCol() {
        return this.labelCol;
    }

    public final StringArrayParam inputCols() {
        return this.inputCols;
    }

    public final Param<String> holdoutStrategy() {
        return this.holdoutStrategy;
    }

    public final BooleanParam blendedAvgEnabled() {
        return this.blendedAvgEnabled;
    }

    public final DoubleParam blendedAvgInflectionPoint() {
        return this.blendedAvgInflectionPoint;
    }

    public final DoubleParam blendedAvgSmoothing() {
        return this.blendedAvgSmoothing;
    }

    public final DoubleParam noise() {
        return this.noise;
    }

    public final LongParam noiseSeed() {
        return this.noiseSeed;
    }

    // Scala trait field setters. The odd names are the compiler-generated mangling for the
    // vals declared in H2OTargetEncoderParams.

    public final void ai$h2o$sparkling$ml$params$H2OTargetEncoderParams$_setter_$foldCol_$eq(NullableStringParam nullableStringParam) {
        this.foldCol = nullableStringParam;
    }

    public final void ai$h2o$sparkling$ml$params$H2OTargetEncoderParams$_setter_$labelCol_$eq(Param param) {
        this.labelCol = param;
    }

    public final void ai$h2o$sparkling$ml$params$H2OTargetEncoderParams$_setter_$inputCols_$eq(StringArrayParam stringArrayParam) {
        this.inputCols = stringArrayParam;
    }

    public final void ai$h2o$sparkling$ml$params$H2OTargetEncoderParams$_setter_$holdoutStrategy_$eq(Param param) {
        this.holdoutStrategy = param;
    }

    public final void ai$h2o$sparkling$ml$params$H2OTargetEncoderParams$_setter_$blendedAvgEnabled_$eq(BooleanParam booleanParam) {
        this.blendedAvgEnabled = booleanParam;
    }

    public final void ai$h2o$sparkling$ml$params$H2OTargetEncoderParams$_setter_$blendedAvgInflectionPoint_$eq(DoubleParam doubleParam) {
        this.blendedAvgInflectionPoint = doubleParam;
    }

    public final void ai$h2o$sparkling$ml$params$H2OTargetEncoderParams$_setter_$blendedAvgSmoothing_$eq(DoubleParam doubleParam) {
        this.blendedAvgSmoothing = doubleParam;
    }

    public final void ai$h2o$sparkling$ml$params$H2OTargetEncoderParams$_setter_$noise_$eq(DoubleParam doubleParam) {
        this.noise = doubleParam;
    }

    public final void ai$h2o$sparkling$ml$params$H2OTargetEncoderParams$_setter_$noiseSeed_$eq(LongParam longParam) {
        this.noiseSeed = longParam;
    }

    public String getFoldCol() {
        return H2OTargetEncoderParams.class.getFoldCol(this);
    }

    public String getLabelCol() {
        return H2OTargetEncoderParams.class.getLabelCol(this);
    }

    public String[] getInputCols() {
        return H2OTargetEncoderParams.class.getInputCols(this);
    }

    public String[] getOutputCols() {
        return H2OTargetEncoderParams.class.getOutputCols(this);
    }

    public String getHoldoutStrategy() {
        return H2OTargetEncoderParams.class.getHoldoutStrategy(this);
    }

    public boolean getBlendedAvgEnabled() {
        return H2OTargetEncoderParams.class.getBlendedAvgEnabled(this);
    }

    public double getBlendedAvgInflectionPoint() {
        return H2OTargetEncoderParams.class.getBlendedAvgInflectionPoint(this);
    }

    public double getBlendedAvgSmoothing() {
        return H2OTargetEncoderParams.class.getBlendedAvgSmoothing(this);
    }

    public double getNoise() {
        return H2OTargetEncoderParams.class.getNoise(this);
    }

    public long getNoiseSeed() {
        return H2OTargetEncoderParams.class.getNoiseSeed(this);
    }

    public String uid() {
        return this.uid;
    }

    /**
     * Trains a target-encoding model on {@code dataset}.
     *
     * <p>The steps are:
     * <ol>
     *   <li>Convert the dataset to an H2O frame through the H2OContext of the active
     *       SparkSession; both are obtained with {@code getOrCreate}.</li>
     *   <li>Convert the relevant columns to categoricals.</li>
     *   <li>Train the H2O model, ignoring every column that is not an input, fold or label
     *       column.</li>
     *   <li>Wrap the result in an {@link H2OTargetEncoderModel} that shares this estimator's uid,
     *       has this estimator as its parent, and carries a copy of this estimator's params.</li>
     * </ol>
     *
     * @param dataset training data; must contain the configured input and label columns
     * @return the fitted model
     */
    public H2OTargetEncoderModel fit(Dataset<?> dataset) {
        // NOTE(review): this H2O frame is not visibly deleted after training; confirm whether it
        // is cleaned up elsewhere or leaks in the H2O cluster.
        H2OFrame asH2OFrame = H2OContext$.MODULE$.getOrCreate(SparkSession$.MODULE$.builder().getOrCreate()).asH2OFrame(dataset.toDF());
        convertRelevantColumnsToCategorical(asH2OFrame);
        TargetEncoderModel trainedModel = trainTargetEncodingModel(asH2OFrame, columnsIgnoredByTraining(dataset));
        H2OTargetEncoderModel model = (H2OTargetEncoderModel) new H2OTargetEncoderModel(uid(), trainedModel).setParent(this);
        return copyValues(model, copyValues$default$2());
    }

    /**
     * Returns the columns of {@code dataset} that H2O should ignore during training.
     *
     * <p>This is every dataset column except the input columns, the fold column and the label
     * column. The fold and label names go through two compiler-generated Scala closures
     * ({@code $$anonfun$1} with map, then {@code $$anonfun$2} with flatten). Presumably these drop
     * an unset (null) fold column; verify against the Scala source.
     */
    private String[] columnsIgnoredByTraining(Dataset<?> dataset) {
        String[] usedColumns = (String[]) Predef$.MODULE$.refArrayOps(getInputCols()).$plus$plus(
                ((GenericTraversableTemplate) Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new String[]{getFoldCol(), getLabelCol()}))
                        .map(new H2OTargetEncoder$$anonfun$1(this), Seq$.MODULE$.canBuildFrom()))
                        .flatten(new H2OTargetEncoder$$anonfun$2(this)),
                Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(String.class)));
        return (String[]) Predef$.MODULE$.refArrayOps(dataset.columns()).diff(Predef$.MODULE$.wrapRefArray(usedColumns));
    }

    /**
     * Builds and synchronously trains an H2O {@link TargetEncoderModel} on {@code frame}.
     *
     * <p>Only the blending flag, the blending parameters, and the label, fold and ignored columns
     * are passed to H2O here. Holdout strategy, noise and noise seed are not set on these
     * parameters. Presumably {@link H2OTargetEncoderModel} applies them at transform time after
     * {@code copyValues}; verify.
     *
     * @param frame training frame already registered in H2O (its key is used)
     * @param strArr names of the columns H2O should ignore
     * @return the trained H2O target-encoder model
     * @throws RuntimeException if H2O rejects the label as multi-class; the original H2O
     *     exception is attached as the cause
     */
    private TargetEncoderModel trainTargetEncodingModel(Frame frame, String[] strArr) {
        try {
            TargetEncoderModel.TargetEncoderParameters targetEncoderParameters = new TargetEncoderModel.TargetEncoderParameters();
            targetEncoderParameters._blending = getBlendedAvgEnabled();
            targetEncoderParameters._blending_parameters = new BlendingParams(getBlendedAvgInflectionPoint(), getBlendedAvgSmoothing());
            ((Model.Parameters) targetEncoderParameters)._response_column = getLabelCol();
            ((Model.Parameters) targetEncoderParameters)._fold_column = getFoldCol();
            ((Model.Parameters) targetEncoderParameters)._ignored_columns = strArr;
            targetEncoderParameters.setTrain(frame._key);
            TargetEncoderBuilder targetEncoderBuilder = new TargetEncoderBuilder(targetEncoderParameters);
            // Block until the training job finishes.
            targetEncoderBuilder.trainModel().get();
            return targetEncoderBuilder.getTargetEncoderModel();
        } catch (Throwable th) {
            // Translate H2O's multi-class rejection into a user-facing message.
            // getMessage() may be null, so guard before calling contains().
            // The original exception is kept as the cause so its stack trace is not lost.
            String message = th.getMessage();
            if ((th instanceof IllegalStateException) && message != null && message.contains("We do not support multi-class target case")) {
                throw new RuntimeException("The label column can not contain more than two unique values.", th);
            }
            throw th;
        }
    }

    /* renamed from: copy, reason: merged with bridge method [inline-methods] and merged with bridge method [inline-methods] and merged with bridge method [inline-methods] */
    /** Returns a copy of this estimator with {@code paramMap} applied on top of its params. */
    public H2OTargetEncoder m60copy(ParamMap paramMap) {
        return defaultCopy(paramMap);
    }

    public H2OTargetEncoder setFoldCol(String str) {
        return set(foldCol(), str);
    }

    public H2OTargetEncoder setLabelCol(String str) {
        return set(labelCol(), str);
    }

    public H2OTargetEncoder setInputCols(String[] strArr) {
        return set(inputCols(), strArr);
    }

    /**
     * Sets the holdout (data-leakage handling) strategy.
     *
     * <p>The value is first normalised by {@code H2OAlgoParamsHelper.getValidatedEnumValue}
     * against {@link TargetEncoder.DataLeakageHandlingStrategy}. Presumably that call rejects
     * unknown names; verify.
     */
    public H2OTargetEncoder setHoldoutStrategy(String str) {
        return set(holdoutStrategy(), H2OAlgoParamsHelper$.MODULE$.getValidatedEnumValue(str, ClassTag$.MODULE$.apply(TargetEncoder.DataLeakageHandlingStrategy.class)));
    }

    public H2OTargetEncoder setBlendedAvgEnabled(boolean z) {
        return set(blendedAvgEnabled(), BoxesRunTime.boxToBoolean(z));
    }

    public H2OTargetEncoder setBlendedAvgInflectionPoint(double d) {
        return set(blendedAvgInflectionPoint(), BoxesRunTime.boxToDouble(d));
    }

    /**
     * Sets the blending smoothing value.
     *
     * @throws IllegalArgumentException (from Scala {@code Predef.require}) if {@code d <= 0}
     */
    public H2OTargetEncoder setBlendedAvgSmoothing(double d) {
        Predef$.MODULE$.require(d > 0.0d, new H2OTargetEncoder$$anonfun$setBlendedAvgSmoothing$1(this));
        return set(blendedAvgSmoothing(), BoxesRunTime.boxToDouble(d));
    }

    /**
     * Sets the noise amount.
     *
     * @throws IllegalArgumentException (from Scala {@code Predef.require}) if {@code d < 0}
     */
    public H2OTargetEncoder setNoise(double d) {
        Predef$.MODULE$.require(d >= 0.0d, new H2OTargetEncoder$$anonfun$setNoise$1(this));
        return set(noise(), BoxesRunTime.boxToDouble(d));
    }

    public H2OTargetEncoder setNoiseSeed(long j) {
        return set(noiseSeed(), BoxesRunTime.boxToLong(j));
    }

    /* renamed from: fit, reason: collision with other method in class */
    /** Raw-typed bridge to {@link #fit(Dataset)} generated by the Scala compiler. */
    public /* bridge */ /* synthetic */ org.apache.spark.ml.Model m61fit(Dataset dataset) {
        return fit((Dataset<?>) dataset);
    }

    /**
     * Creates an estimator with the given uid.
     *
     * <p>The trait initialisers run in declaration order. The params trait comes first, so the
     * param fields exist before the other traits initialise.
     */
    public H2OTargetEncoder(String str) {
        this.uid = str;
        H2OTargetEncoderParams.class.$init$(this);
        H2OTargetEncoderBase.class.$init$(this);
        MLWritable.class.$init$(this);
        DefaultParamsWritable.class.$init$(this);
        H2OTargetEncoderModelUtils.Cclass.$init$(this);
    }

    /** Creates an estimator with a random uid prefixed by {@code "H2OTargetEncoder"}. */
    public H2OTargetEncoder() {
        this(Identifiable$.MODULE$.randomUID("H2OTargetEncoder"));
    }
}
