package org.apache.spark.ml.odkl;

import org.apache.commons.lang3.StringUtils;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.spark.annotation.DeveloperApi;
import org.apache.spark.ml.Estimator;
import org.apache.spark.ml.Model;
import org.apache.spark.ml.odkl.ForkedModelParams;
import org.apache.spark.ml.odkl.ModelWithSummary;
import org.apache.spark.ml.param.Param;
import org.apache.spark.ml.param.ParamPair;
import org.apache.spark.ml.util.DefaultParamsReader$;
import org.apache.spark.ml.util.Identifiable;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.types.StructType;
import scala.Predef$;
import scala.StringContext;
import scala.Tuple2;
import scala.collection.GenSeq$;
import scala.collection.GenTraversableOnce;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.TraversableOnce;
import scala.collection.mutable.StringBuilder;
import scala.collection.parallel.ParSeq;
import scala.collection.parallel.TaskSupport;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;

/* compiled from: ForkedEstimator.scala */
@ScalaSignature(bytes = "\u0006\u0001\t]a!B\u0001\u0003\u0003\u0003i!a\u0004$pe.,G-R:uS6\fGo\u001c:\u000b\u0005\r!\u0011\u0001B8eW2T!!\u0002\u0004\u0002\u00055d'BA\u0004\t\u0003\u0015\u0019\b/\u0019:l\u0015\tI!\"\u0001\u0004ba\u0006\u001c\u0007.\u001a\u0006\u0002\u0017\u0005\u0019qN]4\u0004\u0001U!a\"\f#\u0016'\u0011\u0001qBI\u0013\u0011\u0007A\t2#D\u0001\u0005\u0013\t\u0011BAA\u0005FgRLW.\u0019;peB\u0011A#\u0006\u0007\u0001\t\u00151\u0002A1\u0001\u0018\u0005!iu\u000eZ3m\u001fV$\u0018C\u0001\r\u001f!\tIB$D\u0001\u001b\u0015\u0005Y\u0012!B:dC2\f\u0017BA\u000f\u001b\u0005\u001dqu\u000e\u001e5j]\u001e\u00042a\b\u0011\u0014\u001b\u0005\u0011\u0011BA\u0011\u0003\u0005Aiu\u000eZ3m/&$\bnU;n[\u0006\u0014\u0018\u0010E\u0002 GMI!\u0001\n\u0002\u0003+M+X.\\1sSj\f'\r\\3FgRLW.\u0019;peB\u0011qDJ\u0005\u0003O\t\u0011\u0011CR8sW\u0016$Wj\u001c3fYB\u000b'/Y7t\u0011!I\u0003A!b\u0001\n\u0003Q\u0013A\u00028fgR,G-F\u0001,!\ry2\u0005\f\t\u0003)5\"QA\f\u0001C\u0002=\u0012q!T8eK2Le.\u0005\u0002\u0019aA\u0019q\u0004\t\u0017\t\u0011I\u0002!\u0011!Q\u0001\n-\nqA\\3ti\u0016$\u0007\u0005\u0003\u00055\u0001\t\u0015\r\u0011\"\u00116\u0003\r)\u0018\u000eZ\u000b\u0002mA\u0011qG\u000f\b\u00033aJ!!\u000f\u000e\u0002\rA\u0013X\rZ3g\u0013\tYDH\u0001\u0004TiJLgn\u001a\u0006\u0003siA\u0001B\u0010\u0001\u0003\u0002\u0003\u0006IAN\u0001\u0005k&$\u0007\u0005C\u0003A\u0001\u0011\u0005\u0011)\u0001\u0004=S:LGO\u0010\u000b\u0004\u0005*[\u0005#B\u0010\u0001Y\r\u001b\u0002C\u0001\u000bE\t\u0015)\u0005A1\u0001G\u0005-1uN]3LKf$\u0016\u0010]3\u0012\u0005a9\u0005CA\rI\u0013\tI%DA\u0002B]fDQ!K A\u0002-BQ\u0001N 
A\u0002YBq!\u0014\u0001C\u0002\u0013\u0015a*A\u0007ue\u0006Lg\u000eU1sC2dW\r\\\u000b\u0002\u001fB\u0019\u0001kU+\u000e\u0003ES!A\u0015\u0003\u0002\u000bA\f'/Y7\n\u0005Q\u000b&!\u0002)be\u0006l\u0007CA\rW\u0013\t9&DA\u0004C_>dW-\u00198\t\re\u0003\u0001\u0015!\u0004P\u00039!(/Y5o!\u0006\u0014\u0018\r\u001c7fY\u0002Bqa\u0017\u0001C\u0002\u0013\u0015a*\u0001\u0006dC\u000eDWMR8sWNDa!\u0018\u0001!\u0002\u001by\u0015aC2bG\",gi\u001c:lg\u0002Bqa\u0018\u0001C\u0002\u0013\u0015\u0001-A\tqCRDgi\u001c:UK6\u0004Xj\u001c3fYN,\u0012!\u0019\t\u0004!N3\u0004BB2\u0001A\u00035\u0011-\u0001\nqCRDgi\u001c:UK6\u0004Xj\u001c3fYN\u0004\u0003\"B3\u0001\t\u00031\u0017\u0001E:fiR\u0013\u0018-\u001b8QCJ\fG\u000e\\3m)\t9\u0007.D\u0001\u0001\u0011\u0015IG\r1\u0001V\u0003\u00151\u0018\r\\;f\u0011\u0015Y\u0007\u0001\"\u0001m\u00035\u0019X\r^\"bG\",gi\u001c:lgR\u0011q-\u001c\u0005\u0006S*\u0004\r!\u0016\u0005\u0006_\u0002!\t\u0001]\u0001\u0015g\u0016$\b+\u0019;i\r>\u0014H+Z7q\u001b>$W\r\\:\u0015\u0005\u001d\f\b\"B5o\u0001\u00041\u0004\"B:\u0001\r#!\u0018aC2sK\u0006$XMR8sWN$2!^A\u0015!\u00111h0a\u0001\u000f\u0005]dhB\u0001=|\u001b\u0005I(B\u0001>\r\u0003\u0019a$o\\8u}%\t1$\u0003\u0002~5\u00059\u0001/Y2lC\u001e,\u0017bA@\u0002\u0002\t\u00191+Z9\u000b\u0005uT\u0002CB\r\u0002\u0006\r\u000bI!C\u0002\u0002\bi\u0011a\u0001V;qY\u0016\u0014\u0004\u0003BA\u0006\u0003GqA!!\u0004\u0002 
9!\u0011qBA\u000e\u001d\u0011\t\t\"!\u0007\u000f\t\u0005M\u0011q\u0003\b\u0004q\u0006U\u0011\"A\u0006\n\u0005%Q\u0011BA\u0004\t\u0013\r\tiBB\u0001\u0004gFd\u0017bA?\u0002\")\u0019\u0011Q\u0004\u0004\n\t\u0005\u0015\u0012q\u0005\u0002\n\t\u0006$\u0018M\u0012:b[\u0016T1!`A\u0011\u0011\u001d\tYC\u001da\u0001\u0003[\tq\u0001Z1uCN,G\u000f\r\u0003\u00020\u0005e\u0002CBA\u0019\u0003g\t9$\u0004\u0002\u0002\"%!\u0011QGA\u0011\u0005\u001d!\u0015\r^1tKR\u00042\u0001FA\u001d\t-\tY$!\u000b\u0002\u0002\u0003\u0005)\u0011\u0001$\u0003\u0007}#\u0013\u0007C\u0004\u0002@\u00011\t\"!\u0011\u0002\u00175,'oZ3N_\u0012,Gn\u001d\u000b\u0006'\u0005\r\u0013Q\n\u0005\t\u0003\u000b\ni\u00041\u0001\u0002H\u0005Q1/\u001d7D_:$X\r\u001f;\u0011\t\u0005E\u0012\u0011J\u0005\u0005\u0003\u0017\n\tC\u0001\u0006T#2\u001buN\u001c;fqRD\u0001\"a\u0014\u0002>\u0001\u0007\u0011\u0011K\u0001\u0007[>$W\r\\:\u0011\tYt\u00181\u000b\t\u00063\u0005\u00151\t\f\u0005\b\u0003/\u0002A\u0011IA-\u0003\r1\u0017\u000e\u001e\u000b\u0004'\u0005m\u0003\u0002CA\u0016\u0003+\u0002\r!!\u00181\t\u0005}\u00131\r\t\u0007\u0003c\t\u0019$!\u0019\u0011\u0007Q\t\u0019\u0007B\u0006\u0002f\u0005m\u0013\u0011!A\u0001\u0006\u00031%aA0%e!9\u0011\u0011\u000e\u0001\u0005\u0002\u0005-\u0014a\u00024ji\u001a{'o\u001b\u000b\t\u0003'\ni'!\u001d\u0002��!9\u0011qNA4\u0001\u0004Y\u0013!C3ti&l\u0017\r^8s\u0011!\t\u0019(a\u001aA\u0002\u0005U\u0014!C<i_2,G)\u0019;ba\u0011\t9(a\u001f\u0011\r\u0005E\u00121GA=!\r!\u00121\u0010\u0003\f\u0003{\n\t(!A\u0001\u0002\u000b\u0005aIA\u0002`IMB\u0001\"!!\u0002h\u0001\u0007\u00111A\u0001\fa\u0006\u0014H/[1m\t\u0006$\u0018\rC\u0004\u0002\u0006\u0002!\t%a\"\u0002\u001fQ\u0014\u0018M\\:g_Jl7k\u00195f[\u0006$B!!#\u0002\u0016B!\u00111RAI\u001b\t\tiI\u0003\u0003\u0002\u0010\u0006\u0005\u0012!\u0002;za\u0016\u001c\u0018\u0002BAJ\u0003\u001b\u0013!b\u0015;sk\u000e$H+\u001f9f\u0011!\t9*a!A\u0002\u0005%\u0015AB:dQ\u0016l\u0017\r\u000b\u0003\u0002\u0004\u0006m\u0005\u0003BAO\u0003Gk!!a(\u000b\u0007\u0005\u0005f!\u0001\u0006b]:|
G/\u0019;j_:LA!!*\u0002 \naA)\u001a<fY>\u0004XM]!qS\u001e9\u0011\u0011\u0016\u0002\t\u0002\u0005-\u0016a\u0004$pe.,G-R:uS6\fGo\u001c:\u0011\u0007}\tiK\u0002\u0004\u0002\u0005!\u0005\u0011qV\n\u0007\u0003[\u000b\t,a.\u0011\u0007e\t\u0019,C\u0002\u00026j\u0011a!\u00118z%\u00164\u0007cA\r\u0002:&\u0019\u00111\u0018\u000e\u0003\u0019M+'/[1mSj\f'\r\\3\t\u000f\u0001\u000bi\u000b\"\u0001\u0002@R\u0011\u00111\u0016\u0005\u000b\u0003\u0007\fi\u000b1A\u0005\n\u0005\u0015\u0017a\u0003;bg.\u001cV\u000f\u001d9peR,\"!a2\u0011\u000be\tI-!4\n\u0007\u0005-'D\u0001\u0004PaRLwN\u001c\t\u0005\u0003\u001f\fI.\u0004\u0002\u0002R*!\u00111[Ak\u0003!\u0001\u0018M]1mY\u0016d'bAAl5\u0005Q1m\u001c7mK\u000e$\u0018n\u001c8\n\t\u0005m\u0017\u0011\u001b\u0002\f)\u0006\u001c8nU;qa>\u0014H\u000f\u0003\u0006\u0002`\u00065\u0006\u0019!C\u0005\u0003C\fq\u0002^1tWN+\b\u000f]8si~#S-\u001d\u000b\u0005\u0003G\fI\u000fE\u0002\u001a\u0003KL1!a:\u001b\u0005\u0011)f.\u001b;\t\u0015\u0005-\u0018Q\\A\u0001\u0002\u0004\t9-A\u0002yIEB\u0011\"a<\u0002.\u0002\u0006K!a2\u0002\u0019Q\f7o[*vaB|'\u000f\u001e\u0011\t\u0011\u0005M\u0018Q\u0016C\u0001\u0003k\fabZ3u)\u0006\u001c8nU;qa>\u0014H/\u0006\u0002\u0002N\"A\u0011\u0011`AW\t\u0003\tY0\u0001\btKR$\u0016m]6TkB\u0004xN\u001d;\u0015\t\u0005\r\u0018Q \u0005\t\u0003\u007f\f9\u00101\u0001\u0002N\u000691/\u001e9q_J$\bB\u0003B\u0002\u0003[\u000b\t\u0011\"\u0003\u0003\u0006\u0005Y!/Z1e%\u0016\u001cx\u000e\u001c<f)\t\u00119\u0001\u0005\u0003\u0003\n\tMQB\u0001B\u0006\u0015\u0011\u0011iAa\u0004\u0002\t1\fgn\u001a\u0006\u0003\u0005#\tAA[1wC&!!Q\u0003B\u0006\u0005\u0019y%M[3di\u0002")
/* loaded from: input_file:org/apache/spark/ml/odkl/ForkedEstimator.class */
/*
 * Estimator that splits its input into keyed "forks" (see createForks), obtains one nested
 * model per fork and combines them into a single output model (see mergeModels).
 *
 * Type parameters:
 *   ModelIn     - type of the model produced by the nested estimator for one fork
 *   ForeKeyType - type of the key identifying a fork
 *   ModelOut    - type of the merged model returned by fit
 *
 * NOTE(review): this file is decompiled from Scala. The ForkedEstimator$$anonfun$... classes are
 * compiler-generated closures whose bodies are not visible here, so comments about what they do
 * are marked as assumptions.
 */
public abstract class ForkedEstimator<ModelIn extends ModelWithSummary<ModelIn>, ForeKeyType, ModelOut extends ModelWithSummary<ModelOut>> extends Estimator<ModelOut> implements SummarizableEstimator<ModelOut>, ForkedModelParams {
    /** Estimator handed in through the constructor; also used to answer transformSchema. */
    private final SummarizableEstimator<ModelIn> nested;
    /** Unique identifier of this stage, as passed to the constructor. */
    private final String uid;
    /** Param "trainParallel": when true, fit maps over the parallel view of the forks. */
    private final Param<Object> trainParallel;
    /** Param "cacheForks": when true, fit runs an extra map/reduce/count pass over the forks first. */
    private final Param<Object> cacheForks;
    /** Param "pathForTempModels": base directory under which per-fork models are saved and looked up. */
    private final Param<String> pathForTempModels;
    /**
     * Param "propagatedKeyColumn". Deliberately not final: it is initialised through the
     * ForkedModelParams setter below rather than by a direct constructor assignment.
     */
    private Param<String> propagatedKeyColumn;

    /** Forwards to the companion object's setTaskSupport. */
    public static void setTaskSupport(TaskSupport taskSupport) {
        ForkedEstimator$.MODULE$.setTaskSupport(taskSupport);
    }

    /** Forwards to the companion object's getTaskSupport; fit applies the result to the parallel fork collection. */
    public static TaskSupport getTaskSupport() {
        return ForkedEstimator$.MODULE$.getTaskSupport();
    }

    /** @return the "propagatedKeyColumn" param */
    @Override // org.apache.spark.ml.odkl.ForkedModelParams
    public final Param<String> propagatedKeyColumn() {
        return this.propagatedKeyColumn;
    }

    /** Setter generated for the ForkedModelParams trait field; invoked once from the constructor. */
    @Override // org.apache.spark.ml.odkl.ForkedModelParams
    public final void org$apache$spark$ml$odkl$ForkedModelParams$_setter_$propagatedKeyColumn_$eq(Param param) {
        this.propagatedKeyColumn = param;
    }

    /** Delegates to the ForkedModelParams trait implementation. */
    @Override // org.apache.spark.ml.odkl.ForkedModelParams
    public ForkedModelParams setPropagatedKeyColumn(String str) {
        return ForkedModelParams.Cclass.setPropagatedKeyColumn(this, str);
    }

    /** Delegates to the ForkedModelParams trait implementation. */
    @Override // org.apache.spark.ml.odkl.ForkedModelParams
    public Dataset<Row> mayBePropagateKey(Dataset<Row> dataset, Object obj) {
        return ForkedModelParams.Cclass.mayBePropagateKey(this, dataset, obj);
    }

    /** @return the estimator supplied to the constructor */
    public SummarizableEstimator<ModelIn> nested() {
        return this.nested;
    }

    /** @return the identifier supplied to the constructor */
    public String uid() {
        return this.uid;
    }

    /** @return the "trainParallel" param (default false) */
    public final Param<Object> trainParallel() {
        return this.trainParallel;
    }

    /** @return the "cacheForks" param (default false) */
    public final Param<Object> cacheForks() {
        return this.cacheForks;
    }

    /** @return the "pathForTempModels" param (no default) */
    public final Param<String> pathForTempModels() {
        return this.pathForTempModels;
    }

    /**
     * Sets the "trainParallel" param.
     *
     * @param z new value
     * @return the result of {@code set}, cast to this estimator type for chaining
     */
    public ForkedEstimator<ModelIn, ForeKeyType, ModelOut> setTrainParallel(boolean z) {
        return (ForkedEstimator<ModelIn, ForeKeyType, ModelOut>) set(trainParallel(), BoxesRunTime.boxToBoolean(z));
    }

    /**
     * Sets the "cacheForks" param.
     *
     * @param z new value
     * @return the result of {@code set}, cast to this estimator type for chaining
     */
    public ForkedEstimator<ModelIn, ForeKeyType, ModelOut> setCacheForks(boolean z) {
        return (ForkedEstimator<ModelIn, ForeKeyType, ModelOut>) set(cacheForks(), BoxesRunTime.boxToBoolean(z));
    }

    /**
     * Sets the "pathForTempModels" param, or clears it when the given value is blank
     * (as decided by {@code StringUtils.isNotBlank}).
     *
     * @param str base path for temporary models; blank or null clears the param
     * @return the result of {@code set}/{@code clear}, cast to this estimator type for chaining
     */
    public ForkedEstimator<ModelIn, ForeKeyType, ModelOut> setPathForTempModels(String str) {
        if (StringUtils.isNotBlank(str)) {
            return (ForkedEstimator<ModelIn, ForeKeyType, ModelOut>) set(pathForTempModels(), str);
        }
        return (ForkedEstimator<ModelIn, ForeKeyType, ModelOut>) clear(pathForTempModels());
    }

    /**
     * Splits the dataset into the forks to be trained.
     *
     * @param dataset the dataset passed to fit
     * @return pairs of (fork key, data for that fork)
     */
    public abstract Seq<Tuple2<ForeKeyType, Dataset<Row>>> createForks(Dataset<?> dataset);

    /**
     * Combines the per-fork models into the final model.
     *
     * @param sQLContext SQL context of the dataset passed to fit
     * @param seq        pairs of (fork key, model for that fork)
     * @return the merged model
     */
    public abstract ModelOut mergeModels(SQLContext sQLContext, Seq<Tuple2<ForeKeyType, ModelIn>> seq);

    /**
     * Creates the forks, maps every fork to a (key, model) pair, merges the pairs and returns the
     * merged model with this estimator set as its parent.
     *
     * @param dataset data to train on
     * @return the merged model
     */
    public ModelOut fit(Dataset<?> dataset) {
        // Kept raw: the Scala CanBuildFrom arguments below do not line up with Java's invariant generics.
        Seq forks = createForks(dataset);
        if (BoxesRunTime.unboxToBoolean($(cacheForks()))) {
            // Map every fork, reduce the mapped values into one Dataset and count it.
            // NOTE(review): presumably this caches each fork and unions them so that one job
            // materializes all of them - confirm against the closure sources.
            Object prepared = forks.map(new ForkedEstimator$$anonfun$fit$1(this), Seq$.MODULE$.canBuildFrom());
            Object combined = ((TraversableOnce) prepared).reduce(new ForkedEstimator$$anonfun$fit$2(this));
            ((Dataset) combined).count();
        }
        try {
            SQLContext sqlContext = dataset.sqlContext();
            // NOTE(review): the mapping closure presumably delegates to fitFork - confirm.
            Object trained;
            if (BoxesRunTime.unboxToBoolean($(trainParallel()))) {
                ParSeq parallelForks = (ParSeq) forks.par();
                // Use the task support configured on the companion object.
                parallelForks.tasksupport_$eq(ForkedEstimator$.MODULE$.getTaskSupport());
                trained = parallelForks.map(new ForkedEstimator$$anonfun$1(this, dataset), GenSeq$.MODULE$.canBuildFrom());
            } else {
                trained = forks.map(new ForkedEstimator$$anonfun$1(this, dataset), GenSeq$.MODULE$.canBuildFrom());
            }
            Tuple2[] models = (Tuple2[]) ((GenTraversableOnce) trained).toArray(ClassTag$.MODULE$.apply(Tuple2.class));
            ModelOut merged = (ModelOut) ((Model) mergeModels(sqlContext, Predef$.MODULE$.wrapRefArray(models))).setParent(this);
            // Hand the key column name over to the merged model when it supports the same params.
            if (isDefined(propagatedKeyColumn()) && (merged instanceof ForkedModelParams)) {
                ((ForkedModelParams) merged).setPropagatedKeyColumn((String) $(propagatedKeyColumn()));
            }
            // After a successful merge, register the temporary model directory for deletion on exit.
            if (isDefined(pathForTempModels())) {
                FileSystem.get(dataset.sqlContext().sparkContext().hadoopConfiguration()).deleteOnExit(new Path((String) $(pathForTempModels())));
            }
            return merged;
        } finally {
            // Runs whether or not training succeeded.
            // NOTE(review): presumably un-persists the forks cached above - confirm.
            if (BoxesRunTime.unboxToBoolean($(cacheForks()))) {
                forks.foreach(new ForkedEstimator$$anonfun$fit$3(this));
            }
        }
    }

    /**
     * Produces the model for a single fork.
     *
     * <p>When "pathForTempModels" is defined and a file already exists at this fork's temporary
     * path, the model is loaded from there and no training happens. Otherwise the given estimator
     * is fitted and, when "pathForTempModels" is defined, the resulting model is saved to that path.
     *
     * @param summarizableEstimator estimator to fit on the fork
     * @param dataset               the whole dataset; used here only to reach the Spark context
     * @param tuple2                pair of (fork key, data for that fork)
     * @return pair of (fork key, loaded or freshly trained model)
     */
    public Tuple2<ForeKeyType, ModelIn> fitFork(SummarizableEstimator<ModelIn> summarizableEstimator, Dataset<?> dataset, Tuple2<ForeKeyType, Dataset<Row>> tuple2) {
        logInfo(new ForkedEstimator$$anonfun$fitFork$1(this, tuple2));
        if (isDefined(pathForTempModels())) {
            String savedModelPath = tempModelPath(tuple2);
            if (FileSystem.get(dataset.sqlContext().sparkContext().hadoopConfiguration()).exists(new Path(savedModelPath))) {
                logInfo(new ForkedEstimator$$anonfun$fitFork$2(this, tuple2, savedModelPath));
                return new Tuple2<>(tuple2._1(), DefaultParamsReader$.MODULE$.loadParamsInstance(savedModelPath, dataset.sqlContext().sparkContext()));
            }
        }
        // Training data depends on whether "propagatedKeyColumn" is set.
        // NOTE(review): presumably the key is added as a column when it is set, and the fork's
        // data is used unchanged otherwise - confirm against the closure sources.
        Dataset trainingData = (Dataset) get(propagatedKeyColumn()).map(new ForkedEstimator$$anonfun$2(this, tuple2)).getOrElse(new ForkedEstimator$$anonfun$3(this, tuple2));
        ModelIn model = (ModelIn) ((Estimator) summarizableEstimator).fit(trainingData);
        Tuple2<ForeKeyType, ModelIn> fitted = new Tuple2<>(tuple2._1(), model);
        logInfo(new ForkedEstimator$$anonfun$fitFork$3(this, tuple2));
        if (isDefined(pathForTempModels())) {
            ((ModelWithSummary) model).m267write().save(tempModelPath(tuple2));
        }
        return fitted;
    }

    /** Builds the temporary model location for a fork: {@code <pathForTempModels>/key=<fork key>}. */
    private String tempModelPath(Tuple2<ForeKeyType, Dataset<Row>> fork) {
        return (String) $(pathForTempModels()) + "/key=" + fork._1().toString();
    }

    /**
     * Delegates schema transformation to the nested estimator.
     *
     * @param structType input schema
     * @return whatever the nested estimator reports for that schema
     */
    @DeveloperApi
    public StructType transformSchema(StructType structType) {
        return nested().transformSchema(structType);
    }

    /* renamed from: fit, reason: collision with other method in class */
    public /* bridge */ /* synthetic */ Model m258fit(Dataset dataset) {
        return fit((Dataset<?>) dataset);
    }

    /**
     * Stores the nested estimator and uid, creates the four params and registers {@code false}
     * as the default for both "trainParallel" and "cacheForks".
     *
     * @param summarizableEstimator estimator to train on every fork
     * @param str                   unique identifier of this stage
     */
    public ForkedEstimator(SummarizableEstimator<ModelIn> summarizableEstimator, String str) {
        this.nested = summarizableEstimator;
        this.uid = str;
        org$apache$spark$ml$odkl$ForkedModelParams$_setter_$propagatedKeyColumn_$eq(new Param((Identifiable) this, "propagatedKeyColumn", "If provided, value of the key the fork is created for is added to the data as a ne column with this name"));
        this.trainParallel = new Param<>(this, "trainParallel", "Whenever to train different parts in parallel");
        this.cacheForks = new Param<>(this, "cacheForks", "Useful to reduce IO when training in parallel. If set caches and materializes forks using a single job.");
        this.pathForTempModels = new Param<>(this, "pathForTempModels", "Used for incremental training. Persist models when trained and skips training if valid model found.");
        setDefault(Predef$.MODULE$.wrapRefArray(new ParamPair[]{trainParallel().$minus$greater(BoxesRunTime.boxToBoolean(false)), cacheForks().$minus$greater(BoxesRunTime.boxToBoolean(false))}));
    }
}
