package ai.catboost.spark;

import ai.catboost.spark.impl.CtrFeatures$;
import ai.catboost.spark.impl.CtrsContext;
import ai.catboost.spark.impl.Master;
import ai.catboost.spark.impl.Master$;
import ai.catboost.spark.impl.Workers;
import ai.catboost.spark.params.DatasetParamsTrait;
import ai.catboost.spark.params.Helpers$;
import ai.catboost.spark.params.QuantizationParams;
import ai.catboost.spark.params.TrainingParamsTrait;
import java.time.Duration;
import java.util.concurrent.ExecutorCompletionService;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import org.apache.spark.internal.Logging;
import org.apache.spark.ml.PredictionModel;
import org.apache.spark.ml.Predictor;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.util.DefaultParamsWritable;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.SparkSession;
import org.json4s.JsonAST;
import org.json4s.jackson.JsonMethods$;
import ru.yandex.catboost.spark.catboost4j_spark.core.src.native_impl.TFullModel;
import ru.yandex.catboost.spark.catboost4j_spark.core.src.native_impl.native_impl;
import scala.Array$;
import scala.MatchError;
import scala.Predef$;
import scala.Tuple3;
import scala.Tuple4;
import scala.collection.immutable.Nil$;
import scala.collection.mutable.ArrayOps;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.IntRef;
import scala.runtime.RichInt$;

/* compiled from: CatBoostPredictor.scala */
@ScalaSignature(bytes = "\u0006\u0001\u00055d!\u0003\u0005\n!\u0003\r\t\u0001EA0\u0011\u0015\t\u0005\u0001\"\u0001C\u0011\u00151\u0005\u0001\"\u0005H\u0011\u0015q\u0007\u0001\"\u0005p\u0011\u0015)\bA\"\u0005w\u0011\u001d\t9\u0002\u0001C)\u00033Aq!a\u000f\u0001\t\u0003\ti\u0004C\u0005\u0002H\u0001\t\n\u0011\"\u0001\u0002J\t12)\u0019;C_>\u001cH\u000f\u0015:fI&\u001cGo\u001c:Ue\u0006LGO\u0003\u0002\u000b\u0017\u0005)1\u000f]1sW*\u0011A\"D\u0001\tG\u0006$(m\\8ti*\ta\"\u0001\u0002bS\u000e\u0001QcA\t&_M!\u0001AE\u001b<!\u0015\u00192$H\u0012/\u001b\u0005!\"BA\u000b\u0017\u0003\tiGN\u0003\u0002\u000b/)\u0011\u0001$G\u0001\u0007CB\f7\r[3\u000b\u0003i\t1a\u001c:h\u0013\taBCA\u0005Qe\u0016$\u0017n\u0019;peB\u0011a$I\u0007\u0002?)\u0011\u0001\u0005F\u0001\u0007Y&t\u0017\r\\4\n\u0005\tz\"A\u0002,fGR|'\u000f\u0005\u0002%K1\u0001A!\u0002\u0014\u0001\u0005\u00049#a\u0002'fCJtWM]\t\u0003QI\u0001\"!\u000b\u0017\u000e\u0003)R\u0011aK\u0001\u0006g\u000e\fG.Y\u0005\u0003[)\u0012qAT8uQ&tw\r\u0005\u0002%_\u0011)\u0001\u0007\u0001b\u0001c\t)Qj\u001c3fYF\u0011\u0001F\r\t\u0005'Mjb&\u0003\u00025)\ty\u0001K]3eS\u000e$\u0018n\u001c8N_\u0012,G\u000e\u0005\u00027s5\tqG\u0003\u00029\u0013\u00051\u0001/\u0019:b[NL!AO\u001c\u0003%\u0011\u000bG/Y:fiB\u000b'/Y7t)J\f\u0017\u000e\u001e\t\u0003y}j\u0011!\u0010\u0006\u0003}Q\tA!\u001e;jY&\u0011\u0001)\u0010\u0002\u0016\t\u00164\u0017-\u001e7u!\u0006\u0014\u0018-\\:Xe&$\u0018M\u00197f\u0003\u0019!\u0013N\\5uIQ\t1\t\u0005\u0002*\t&\u0011QI\u000b\u0002\u0005+:LG/A\fbI\u0012,5\u000f^5nCR,Gm\u0011;s\r\u0016\fG/\u001e:fgR!\u0001\n\u0017.]!\u0015I\u0013jS(S\u0013\tQ%F\u0001\u0004UkBdWm\r\t\u0003\u00196k\u0011!C\u0005\u0003\u001d&\u0011A\u0001U8pYB\u0019\u0011\u0006U&\n\u0005ES#!B!se\u0006L\bCA*W\u001b\u0005!&BA+\n\u0003\u0011IW\u000e\u001d7\n\u0005]#&aC\"ueN\u001cuN\u001c;fqRDQ!\u0017\u0002A\u0002-\u000b!#];b]RL'0\u001a3Ue\u0006Lg\u000eU8pY\")1L\u0001a\u0001\u001f\u0006\u0011\u0012/^1oi&TX\rZ#wC2\u0004vn\u001c7t\u0011\u0015i&\u00011\u0001_\u0003I\u0019\u0017\r\
u001e\"p_N$(j]8o!\u0006\u0014\u0018-\\:\u0011\u0005}[gB\u00011i\u001d\t\tgM\u0004\u0002cK6\t1M\u0003\u0002e\u001f\u00051AH]8pizJ\u0011AG\u0005\u0003Of\taA[:p]R\u001a\u0018BA5k\u0003\u001d\u0001\u0018mY6bO\u0016T!aZ\r\n\u00051l'a\u0002&PE*,7\r\u001e\u0006\u0003S*\f\u0001\u0004\u001d:faJ|7-Z:t\u0005\u00164wN]3Ue\u0006Lg.\u001b8h)\r\u00018\u000f\u001e\t\u0007SE\\uJ\u0018*\n\u0005IT#A\u0002+va2,G\u0007C\u0003Z\u0007\u0001\u00071\nC\u0003\\\u0007\u0001\u0007q*A\u0006de\u0016\fG/Z'pI\u0016dGC\u0001\u0018x\u0011\u0015AH\u00011\u0001z\u0003%1W\u000f\u001c7N_\u0012,G\u000eE\u0002{\u0003'i\u0011a\u001f\u0006\u0003yv\f1B\\1uSZ,w,[7qY*\u0011ap`\u0001\u0004gJ\u001c'\u0002BA\u0001\u0003\u0007\tAaY8sK*!\u0011QAA\u0004\u0003A\u0019\u0017\r\u001e2p_N$HG[0ta\u0006\u00148NC\u0002\u000b\u0003\u0013Q1\u0001DA\u0006\u0015\u0011\ti!a\u0004\u0002\re\fg\u000eZ3y\u0015\t\t\t\"\u0001\u0002sk&\u0019\u0011QC>\u0003\u0015Q3U\u000f\u001c7N_\u0012,G.A\u0003ue\u0006Lg\u000eF\u0002/\u00037Aq!!\b\u0006\u0001\u0004\ty\"A\u0004eCR\f7/\u001a;1\t\u0005\u0005\u0012q\u0006\t\u0007\u0003G\tI#!\f\u000e\u0005\u0005\u0015\"bAA\u0014-\u0005\u00191/\u001d7\n\t\u0005-\u0012Q\u0005\u0002\b\t\u0006$\u0018m]3u!\r!\u0013q\u0006\u0003\r\u0003c\tY\"!A\u0001\u0002\u000b\u0005\u00111\u0007\u0002\u0004?\u0012\n\u0014c\u0001\u0015\u00026A\u0019\u0011&a\u000e\n\u0007\u0005e\"FA\u0002B]f\f1AZ5u)\u0015q\u0013qHA\"\u0011\u0019\t\tE\u0002a\u0001\u0017\u0006IAO]1j]B{w\u000e\u001c\u0005\t\u0003\u000b2\u0001\u0013!a\u0001\u001f\u0006IQM^1m!>|Gn]\u0001\u000eM&$H\u0005Z3gCVdG\u000f\n\u001a\u0016\u0005\u0005-#fA(\u0002N-\u0012\u0011q\n\t\u0005\u0003#\nY&\u0004\u0002\u0002T)!\u0011QKA,\u0003%)hn\u00195fG.,GMC\u0002\u0002Z)\n!\"\u00198o_R\fG/[8o\u0013\u0011\ti&a\u0015\u0003#Ut7\r[3dW\u0016$g+\u0019:jC:\u001cWM\u0005\u0004\u0002b\u0005\u0015\u0014q\r\u0004\u0007\u0003G\u0002\u0001!a\u0018\u0003\u0019q\u0012XMZ5oK6,g\u000e\u001e \u0011\t1\u00031E\f\t\u0004m\u0005%\u0014bAA6o\t\u0019BK]1j]&tw\rU1sC6\u001cHK]1ji\u0002")
/* loaded from: input_file:ai/catboost/spark/CatBoostPredictorTrait.class */
/*
 * Shared training logic for CatBoost Spark ML predictors.
 *
 * Implementations supply createModel(TFullModel); the default methods quantize the input pools
 * if needed, optionally add estimated CTR features, and run training by executing a
 * TrainingDriver (backed by a Master) and a Workers job concurrently.
 *
 * Type parameters: Learner is the concrete Spark ML predictor type, Model is the concrete
 * Spark ML model type produced by training.
 */
public interface CatBoostPredictorTrait<Learner extends Predictor<Vector, Learner, Model>, Model extends PredictionModel<Vector, Model>> extends DatasetParamsTrait, DefaultParamsWritable {

    /**
     * Adds CTR features as estimated features when the largest categorical-feature unique-value
     * count on the training pool exceeds the one-hot max size.
     *
     * @param pool training pool (its {@code quantizedFeaturesInfo()} is read, so it is expected
     *     to be quantized)
     * @param poolArr evaluation pools
     * @param jObject CatBoost JSON training params
     * @return (train pool, eval pools, CTR context); when no CTRs are added the input pools are
     *     returned unchanged and the CTR context is {@code null}
     */
    default Tuple3<Pool, Pool[], CtrsContext> addEstimatedCtrFeatures(Pool pool, Pool[] poolArr, JsonAST.JObject jObject) {
        int maxCatFeatureUniqueValues =
            native_impl.CalcMaxCategoricalFeaturesUniqueValuesCountOnLearn(pool.quantizedFeaturesInfo().__deref__());
        int oneHotMaxSize = native_impl.GetOneHotMaxSize(
            maxCatFeatureUniqueValues, pool.isDefined(pool.labelCol()), JsonMethods$.MODULE$.compact(jObject));
        if (maxCatFeatureUniqueValues > oneHotMaxSize) {
            return CtrFeatures$.MODULE$.addCtrsAsEstimated(pool, poolArr, (TrainingParamsTrait) this, oneHotMaxSize);
        }
        // No categorical feature exceeds the one-hot max size: nothing to add.
        // The typed cast is required for Tuple3<..., CtrsContext> inference.
        return new Tuple3<>(pool, poolArr, (CtrsContext) null);
    }

    /**
     * Converts this estimator's Spark ML params to CatBoost JSON params and adds estimated CTR
     * features if needed.
     *
     * @param pool training pool
     * @param poolArr evaluation pools
     * @return (train pool, eval pools, CatBoost JSON params, CTR context or {@code null})
     * @throws MatchError if {@link #addEstimatedCtrFeatures} returns {@code null}
     */
    default Tuple4<Pool, Pool[], JsonAST.JObject, CtrsContext> preprocessBeforeTraining(Pool pool, Pool[] poolArr) {
        JsonAST.JObject catBoostJsonParams = Helpers$.MODULE$.sparkMlParamsToCatBoostJsonParams(
            this, Helpers$.MODULE$.sparkMlParamsToCatBoostJsonParams$default$2());
        Tuple3<Pool, Pool[], CtrsContext> withCtrs = addEstimatedCtrFeatures(pool, poolArr, catBoostJsonParams);
        if (withCtrs == null) {
            // Mirrors the failing case of the original Scala tuple pattern match.
            throw new MatchError(withCtrs);
        }
        return new Tuple4<>(withCtrs._1(), withCtrs._2(), catBoostJsonParams, withCtrs._3());
    }

    /**
     * Builds the concrete Spark ML model from a trained native CatBoost model.
     *
     * @param tFullModel trained native model
     * @return the Spark ML model wrapper
     */
    Model createModel(TFullModel tFullModel);

    /**
     * Spark ML training entry point: wraps {@code dataset} into a {@link Pool}, copies this
     * estimator's params into it and trains with no evaluation pools.
     *
     * @param dataset training data
     * @return the trained model
     */
    default Model train(Dataset<?> dataset) {
        Pool pool = new Pool(dataset);
        copyValues(pool, copyValues$default$2());
        return fit(pool, fit$default$2());
    }

    /**
     * Trains a model on {@code pool}, using {@code poolArr} as evaluation pools.
     *
     * <p>Non-quantized pools are quantized first. The train pool uses this estimator's
     * quantization params; eval pools use the train pool's {@code quantizedFeaturesInfo()}.
     * Training then runs a {@link TrainingDriver} and a {@link Workers} job concurrently on a
     * private two-thread executor, which is always shut down before returning.
     *
     * @param pool training pool
     * @param poolArr evaluation pools (may be empty, must not be {@code null})
     * @return the trained model
     */
    default Model fit(Pool pool, Pool[] poolArr) {
        String className = getClass().getName();
        Helpers$.MODULE$.checkParamsCompatibility(className, this, "trainPool", pool);
        for (int i = 0; i < poolArr.length; i++) {
            Helpers$.MODULE$.checkParamsCompatibility(className, this, "evalPool #" + i, poolArr[i]);
        }

        SparkSession sparkSession = pool.data().sparkSession();
        Logging logging = (Logging) this;

        Pool quantizedTrainPool;
        if (pool.isQuantized()) {
            quantizedTrainPool = pool;
        } else {
            QuantizationParams quantizationParams = new QuantizationParams();
            copyValues(quantizationParams, copyValues$default$2());
            logging.logInfo(() -> "fit. schedule quantization for train dataset");
            quantizedTrainPool = pool.quantize(quantizationParams);
        }

        // Eval pools are quantized with the train pool's quantizedFeaturesInfo.
        Pool[] quantizedEvalPools = new Pool[poolArr.length];
        for (int i = 0; i < poolArr.length; i++) {
            Pool evalPool = poolArr[i];
            if (evalPool.isQuantized()) {
                quantizedEvalPools[i] = evalPool;
            } else {
                int evalPoolIndex = i;
                logging.logInfo(() -> "fit. schedule quantization for eval dataset #" + evalPoolIndex);
                quantizedEvalPools[i] = evalPool.quantize(quantizedTrainPool.quantizedFeaturesInfo());
            }
        }

        Tuple4<Pool, Pool[], JsonAST.JObject, CtrsContext> preprocessed =
            preprocessBeforeTraining(quantizedTrainPool, quantizedEvalPools);
        if (preprocessed == null) {
            throw new MatchError(preprocessed);
        }
        Pool trainPool = preprocessed._1();
        Pool[] evalPools = preprocessed._2();
        JsonAST.JObject catBoostJsonParams = preprocessed._3();
        CtrsContext ctrsContext = preprocessed._4();

        // Explicit sparkPartitionCount param wins; otherwise use the Spark worker count.
        int partitionCount = BoxesRunTime.unboxToInt(
            get(((TrainingParamsTrait) this).sparkPartitionCount())
                .getOrElse(() -> SparkHelpers$.MODULE$.getWorkerCount(sparkSession)));
        logging.logInfo(() -> "fit. partitionCount=" + partitionCount);

        String precomputedOnlineCtrMetaDataAsJsonString =
            (ctrsContext != null) ? ctrsContext.precomputedOnlineCtrMetaDataAsJsonString() : null;
        Master master = Master$.MODULE$.apply(
            trainPool, evalPools, JsonMethods$.MODULE$.compact(catBoostJsonParams), precomputedOnlineCtrMetaDataAsJsonString);
        TrainingDriver trainingDriver = new TrainingDriver(
            0,
            partitionCount,
            workerInfoArr -> {
                master.trainCallback(workerInfoArr);
                return BoxedUnit.UNIT;
            },
            (Duration) getOrDefault(((TrainingParamsTrait) this).workerInitializationTimeout()));
        int listeningPort = trainingDriver.getListeningPort();
        logging.logInfo(() -> "fit. TrainingDriver listening port = " + listeningPort);
        logging.logInfo(() -> "fit. Training started");

        // One thread runs the TrainingDriver, the other runs the Workers job.
        ExecutorService executor = Executors.newFixedThreadPool(2);
        try {
            ExecutorCompletionService<BoxedUnit> completionService = new ExecutorCompletionService<>(executor);
            Future<BoxedUnit> masterFuture = completionService.submit(trainingDriver, BoxedUnit.UNIT);
            Future<BoxedUnit> workersFuture = completionService.submit(
                new Workers(
                    sparkSession,
                    partitionCount,
                    listeningPort,
                    trainPool,
                    catBoostJsonParams,
                    precomputedOnlineCtrMetaDataAsJsonString,
                    master.savedPoolsFuture()),
                BoxedUnit.UNIT);

            // Check whichever side completes first, then wait for the other one.
            Future<BoxedUnit> firstCompleted = completionService.take();
            if (firstCompleted.equals(workersFuture)) {
                ai.catboost.spark.impl.Helpers$.MODULE$.checkOneFutureAndWaitForOther(workersFuture, masterFuture, "workers");
            } else {
                ai.catboost.spark.impl.Helpers$.MODULE$.checkOneFutureAndWaitForOther(masterFuture, workersFuture, "master");
            }
        } finally {
            // Threads from Executors.newFixedThreadPool are non-daemon and live until shutdown;
            // without this every fit() call would leak two threads.
            executor.shutdown();
        }
        logging.logInfo(() -> "fit. Training finished");

        TFullModel nativeModel;
        if (ctrsContext != null) {
            logging.logInfo(() -> "fit. Add CtrProvider to model");
            nativeModel = CtrFeatures$.MODULE$.addCtrProviderToModel(master.nativeModelResult(), ctrsContext, trainPool, evalPools);
        } else {
            nativeModel = master.nativeModelResult();
        }
        return createModel(nativeModel);
    }

    /**
     * Default value of {@code fit}'s eval-pools argument.
     *
     * @return a new empty array
     */
    default Pool[] fit$default$2() {
        return new Pool[0];
    }

    /** Scala trait initializer; this trait has no state to initialize. */
    static void $init$(CatBoostPredictorTrait catBoostPredictorTrait) {
    }
}
