package ai.catboost.spark;

import ai.catboost.spark.impl.Master;
import ai.catboost.spark.impl.Master$;
import ai.catboost.spark.impl.Workers;
import ai.catboost.spark.params.DatasetParamsTrait;
import ai.catboost.spark.params.Helpers$;
import ai.catboost.spark.params.QuantizationParams;
import ai.catboost.spark.params.TrainingParamsTrait;
import java.time.Duration;
import java.util.concurrent.ExecutorCompletionService;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import org.apache.spark.ml.PredictionModel;
import org.apache.spark.ml.Predictor;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.util.DefaultParamsWritable;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.SparkSession;
import org.json4s.JsonAST;
import org.json4s.jackson.JsonMethods$;
import ru.yandex.catboost.spark.catboost4j_spark.core.src.native_impl.TFullModel;
import scala.Array$;
import scala.MatchError;
import scala.Predef$;
import scala.Tuple3;
import scala.collection.immutable.Nil$;
import scala.collection.mutable.ArrayOps;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.RichInt$;

/* compiled from: CatBoostPredictor.scala */
@ScalaSignature(bytes = "\u0006\u0001\u00055c!C\u0004\t!\u0003\r\taDA \u0011\u0015\u0001\u0005\u0001\"\u0001B\u0011\u0015)\u0005\u0001\"\u0005G\u0011\u0015)\u0007A\"\u0005g\u0011\u0015Y\b\u0001\"\u0015}\u0011\u001d\tY\u0002\u0001C\u0001\u0003;A\u0011\"a\n\u0001#\u0003%\t!!\u000b\u0003-\r\u000bGOQ8pgR\u0004&/\u001a3jGR|'\u000f\u0016:bSRT!!\u0003\u0006\u0002\u000bM\u0004\u0018M]6\u000b\u0005-a\u0011\u0001C2bi\n|wn\u001d;\u000b\u00035\t!!Y5\u0004\u0001U\u0019\u0001\u0003\n\u0018\u0014\t\u0001\tBG\u000f\t\u0006%ia\"%L\u0007\u0002')\u0011A#F\u0001\u0003[2T!!\u0003\f\u000b\u0005]A\u0012AB1qC\u000eDWMC\u0001\u001a\u0003\ry'oZ\u0005\u00037M\u0011\u0011\u0002\u0015:fI&\u001cGo\u001c:\u0011\u0005u\u0001S\"\u0001\u0010\u000b\u0005}\u0019\u0012A\u00027j]\u0006dw-\u0003\u0002\"=\t1a+Z2u_J\u0004\"a\t\u0013\r\u0001\u0011)Q\u0005\u0001b\u0001M\t9A*Z1s]\u0016\u0014\u0018CA\u0014\u0012!\tA3&D\u0001*\u0015\u0005Q\u0013!B:dC2\f\u0017B\u0001\u0017*\u0005\u001dqu\u000e\u001e5j]\u001e\u0004\"a\t\u0018\u0005\u000b=\u0002!\u0019\u0001\u0019\u0003\u000b5{G-\u001a7\u0012\u0005\u001d\n\u0004\u0003\u0002\n395J!aM\n\u0003\u001fA\u0013X\rZ5di&|g.T8eK2\u0004\"!\u000e\u001d\u000e\u0003YR!a\u000e\u0005\u0002\rA\f'/Y7t\u0013\tIdG\u0001\nECR\f7/\u001a;QCJ\fWn\u001d+sC&$\bCA\u001e?\u001b\u0005a$BA\u001f\u0014\u0003\u0011)H/\u001b7\n\u0005}b$!\u0006#fM\u0006,H\u000e\u001e)be\u0006l7o\u0016:ji\u0006\u0014G.Z\u0001\u0007I%t\u0017\u000e\u001e\u0013\u0015\u0003\t\u0003\"\u0001K\"\n\u0005\u0011K#\u0001B+oSR\f\u0001\u0004\u001d:faJ|7-Z:t\u0005\u00164wN]3Ue\u0006Lg.\u001b8h)\r9\u0015m\u0019\t\u0006Q!Se*U\u0005\u0003\u0013&\u0012a\u0001V;qY\u0016\u001c\u0004CA&M\u001b\u0005A\u0011BA'\t\u0005\u0011\u0001vn\u001c7\u0011\u0007!z%*\u0003\u0002QS\t)\u0011I\u001d:bsB\u0011!K\u0018\b\u0003'ns!\u0001V-\u000f\u0005UCV\"\u0001,\u000b\u0005]s\u0011A\u0002\u001fs_>$h(C\u0001\u001a\u0013\tQ\u0006$\u0001\u0004kg>tGg]\u0005\u00039v\u000bq\u0001]1dW\u0006<WM\u0003\u0002[1%\u0011q\f\u0019\u0002\b\u0015>\u0013'.Z2u\u0015\taV\fC\
u0003c\u0005\u0001\u0007!*\u0001\nrk\u0006tG/\u001b>fIR\u0013\u0018-\u001b8Q_>d\u0007\"\u00023\u0003\u0001\u0004q\u0015AE9vC:$\u0018N_3e\u000bZ\fG\u000eU8pYN\f1b\u0019:fCR,Wj\u001c3fYR\u0011Qf\u001a\u0005\u0006Q\u000e\u0001\r![\u0001\nMVdG.T8eK2\u0004\"A[=\u000e\u0003-T!\u0001\\7\u0002\u00179\fG/\u001b<f?&l\u0007\u000f\u001c\u0006\u0003]>\f1a\u001d:d\u0015\t\u0001\u0018/\u0001\u0003d_J,'B\u0001:t\u0003A\u0019\u0017\r\u001e2p_N$HG[0ta\u0006\u00148N\u0003\u0002\ni*\u00111\"\u001e\u0006\u0003m^\fa!_1oI\u0016D(\"\u0001=\u0002\u0005I,\u0018B\u0001>l\u0005)!f)\u001e7m\u001b>$W\r\\\u0001\u0006iJ\f\u0017N\u001c\u000b\u0003[uDQA \u0003A\u0002}\fq\u0001Z1uCN,G\u000f\r\u0003\u0002\u0002\u0005=\u0001CBA\u0002\u0003\u0013\ti!\u0004\u0002\u0002\u0006)\u0019\u0011qA\u000b\u0002\u0007M\fH.\u0003\u0003\u0002\f\u0005\u0015!a\u0002#bi\u0006\u001cX\r\u001e\t\u0004G\u0005=AaCA\t{\u0006\u0005\t\u0011!B\u0001\u0003'\u00111a\u0018\u00132#\r9\u0013Q\u0003\t\u0004Q\u0005]\u0011bAA\rS\t\u0019\u0011I\\=\u0002\u0007\u0019LG\u000fF\u0003.\u0003?\t\u0019\u0003\u0003\u0004\u0002\"\u0015\u0001\rAS\u0001\niJ\f\u0017N\u001c)p_2D\u0001\"!\n\u0006!\u0003\u0005\rAT\u0001\nKZ\fG\u000eU8pYN\fQBZ5uI\u0011,g-Y;mi\u0012\u0012TCAA\u0016U\rq\u0015QF\u0016\u0003\u0003_\u0001B!!\r\u0002<5\u0011\u00111\u0007\u0006\u0005\u0003k\t9$A\u0005v]\u000eDWmY6fI*\u0019\u0011\u0011H\u0015\u0002\u0015\u0005tgn\u001c;bi&|g.\u0003\u0003\u0002>\u0005M\"!E;oG\",7m[3e-\u0006\u0014\u0018.\u00198dKJ1\u0011\u0011IA#\u0003\u000f2a!a\u0011\u0001\u0001\u0005}\"\u0001\u0004\u001fsK\u001aLg.Z7f]Rt\u0004\u0003B&\u0001E5\u00022!NA%\u0013\r\tYE\u000e\u0002\u0014)J\f\u0017N\\5oOB\u000b'/Y7t)J\f\u0017\u000e\u001e")
/* loaded from: input_file:ai/catboost/spark/CatBoostPredictorTrait.class */
public interface CatBoostPredictorTrait<Learner extends Predictor<Vector, Learner, Model>, Model extends PredictionModel<Vector, Model>> extends DatasetParamsTrait, DefaultParamsWritable {
    default Tuple3<Pool, Pool[], JsonAST.JObject> preprocessBeforeTraining(Pool pool, Pool[] poolArr) {
        return new Tuple3<>(pool, poolArr, Helpers$.MODULE$.sparkMlParamsToCatBoostJsonParams(this, Helpers$.MODULE$.sparkMlParamsToCatBoostJsonParams$default$2()));
    }

    Model createModel(TFullModel tFullModel);

    default Model train(Dataset<?> dataset) {
        Pool pool = new Pool(dataset);
        copyValues(pool, copyValues$default$2());
        return fit(pool, fit$default$2());
    }

    default Model fit(Pool pool, Pool[] poolArr) {
        Pool repartition;
        Helpers$.MODULE$.checkParamsCompatibility(getClass().getName(), this, "trainPool", pool);
        RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), poolArr.length).foreach$mVc$sp(i -> {
            Helpers$.MODULE$.checkParamsCompatibility(this.getClass().getName(), this, new StringBuilder(10).append("evalPool #").append(i).toString(), poolArr[i]);
        });
        SparkSession sparkSession = pool.data().sparkSession();
        int unboxToInt = BoxesRunTime.unboxToInt(get(((TrainingParamsTrait) this).sparkPartitionCount()).getOrElse(() -> {
            return SparkHelpers$.MODULE$.getWorkerCount(sparkSession);
        }));
        if (pool.isQuantized()) {
            repartition = pool;
        } else {
            QuantizationParams quantizationParams = new QuantizationParams();
            copyValues(quantizationParams, copyValues$default$2());
            Pool quantize = pool.quantize(quantizationParams);
            repartition = quantize.repartition(unboxToInt, quantize.repartition$default$2());
        }
        Pool pool2 = repartition;
        Tuple3<Pool, Pool[], JsonAST.JObject> preprocessBeforeTraining = preprocessBeforeTraining(pool2, (Pool[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(poolArr)).map(pool3 -> {
            return pool3.isQuantized() ? pool3 : pool3.quantize(pool2.quantizedFeaturesInfo());
        }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Pool.class))));
        if (preprocessBeforeTraining == null) {
            throw new MatchError(preprocessBeforeTraining);
        }
        Tuple3 tuple3 = new Tuple3((Pool) preprocessBeforeTraining._1(), (Pool[]) preprocessBeforeTraining._2(), (JsonAST.JObject) preprocessBeforeTraining._3());
        Pool pool4 = (Pool) tuple3._1();
        Pool[] poolArr2 = (Pool[]) tuple3._2();
        JsonAST.JObject jObject = (JsonAST.JObject) tuple3._3();
        Master apply = Master$.MODULE$.apply(pool4, poolArr2, JsonMethods$.MODULE$.compact(jObject));
        TrainingDriver trainingDriver = new TrainingDriver(0, unboxToInt, workerInfoArr -> {
            apply.trainCallback(workerInfoArr);
            return BoxedUnit.UNIT;
        }, (Duration) getOrDefault(((TrainingParamsTrait) this).workerInitializationTimeout()));
        int listeningPort = trainingDriver.getListeningPort();
        ExecutorCompletionService executorCompletionService = new ExecutorCompletionService(Executors.newFixedThreadPool(2));
        Future<BoxedUnit> submit = executorCompletionService.submit(trainingDriver, BoxedUnit.UNIT);
        Future<BoxedUnit> submit2 = executorCompletionService.submit(new Workers(sparkSession, listeningPort, pool4, jObject), BoxedUnit.UNIT);
        Future take = executorCompletionService.take();
        if (take != null ? !take.equals(submit2) : submit2 != null) {
            ai.catboost.spark.impl.Helpers$.MODULE$.checkOneFutureAndWaitForOther(submit, submit2, "master");
        } else {
            ai.catboost.spark.impl.Helpers$.MODULE$.checkOneFutureAndWaitForOther(submit2, submit, "workers");
        }
        return createModel(apply.nativeModelResult());
    }

    default Pool[] fit$default$2() {
        return (Pool[]) Array$.MODULE$.apply(Nil$.MODULE$, ClassTag$.MODULE$.apply(Pool.class));
    }

    static void $init$(CatBoostPredictorTrait catBoostPredictorTrait) {
    }
}
