package ai.catboost.spark.impl;

import ai.catboost.CatBoostError;
import ai.catboost.spark.CatBoostWorkersConnectionLostException;
import ai.catboost.spark.Pool;
import ai.catboost.spark.PoolFilesPaths;
import ai.catboost.spark.SparkHelpers$;
import ai.catboost.spark.WorkerInfo;
import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.io.PrintWriter;
import java.lang.ProcessBuilder;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.OpenOption;
import java.nio.file.Path;
import java.nio.file.attribute.FileAttribute;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import org.apache.commons.io.FileUtils;
import org.apache.spark.internal.Logging;
import org.apache.spark.sql.SparkSession;
import org.slf4j.Logger;
import ru.yandex.catboost.spark.catboost4j_spark.core.src.native_impl.TFullModel;
import ru.yandex.catboost.spark.catboost4j_spark.core.src.native_impl.native_impl;
import scala.Array$;
import scala.Function0;
import scala.MatchError;
import scala.Option;
import scala.Predef$;
import scala.Some;
import scala.Tuple2;
import scala.collection.mutable.ArrayBuffer;
import scala.collection.mutable.ArrayBuffer$;
import scala.collection.mutable.ArrayOps;
import scala.concurrent.Await$;
import scala.concurrent.Future;
import scala.concurrent.duration.Duration$;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BooleanRef;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.util.control.Breaks$;
import sun.net.util.IPAddressUtil;

/* compiled from: Master.scala */
@ScalaSignature(bytes = "\u0006\u0001\u0005]tA\u0002\f\u0018\u0011\u0003IrD\u0002\u0004\"/!\u0005\u0011D\t\u0005\u0006S\u0005!\ta\u000b\u0005\u0006Y\u0005!\t!\f\u0005\n\u00033\n\u0011\u0013!C\u0001\u00037B\u0011\"!\u001d\u0002#\u0003%\t!a\u001d\u0007\u000b\u0005:\u0002!G\u0018\t\u0011i1!Q1A\u0005\u0002mB\u0001B\u0011\u0004\u0003\u0002\u0003\u0006I\u0001\u0010\u0005\t\u0007\u001a\u0011)\u0019!C\u0001\t\"A\u0001K\u0002B\u0001B\u0003%Q\t\u0003\u0005R\r\t\u0015\r\u0011\"\u0001E\u0011!\u0011fA!A!\u0002\u0013)\u0005\u0002C*\u0007\u0005\u0003\u0007I\u0011\u0001+\t\u0011\u00154!\u00111A\u0005\u0002\u0019D\u0001\u0002\u001c\u0004\u0003\u0002\u0003\u0006K!\u0016\u0005\t[\u001a\u0011\t\u0019!C\u0001]\"Q\u00111\u0001\u0004\u0003\u0002\u0004%\t!!\u0002\t\u0013\u0005%aA!A!B\u0013y\u0007BB\u0015\u0007\t\u0003\tY\u0001C\u0004\u0002\u0018\u0019!I!!\u0007\t\u000f\u0005}b\u0001\"\u0001\u0002B\u0005)2)\u0019;C_>\u001cH/T1ti\u0016\u0014xK]1qa\u0016\u0014(B\u0001\r\u001a\u0003\u0011IW\u000e\u001d7\u000b\u0005iY\u0012!B:qCJ\\'B\u0001\u000f\u001e\u0003!\u0019\u0017\r\u001e2p_N$(\"\u0001\u0010\u0002\u0005\u0005L\u0007C\u0001\u0011\u0002\u001b\u00059\"!F\"bi\n{wn\u001d;NCN$XM],sCB\u0004XM]\n\u0003\u0003\r\u0002\"\u0001J\u0014\u000e\u0003\u0015R\u0011AJ\u0001\u0006g\u000e\fG.Y\u0005\u0003Q\u0015\u0012a!\u00118z%\u00164\u0017A\u0002\u001fj]&$hh\u0001\u0001\u0015\u0003}\tQ!\u00199qYf$\u0012BLA#\u0003\u001f\n)&a\u0016\u0011\u0005\u000121c\u0001\u0004$aA\u0011\u0011'O\u0007\u0002e)\u00111\u0007N\u0001\tS:$XM\u001d8bY*\u0011!$\u000e\u0006\u0003m]\na!\u00199bG\",'\"\u0001\u001d\u0002\u0007=\u0014x-\u0003\u0002;e\t9Aj\\4hS:<W#\u0001\u001f\u0011\u0005u\u0002U\"\u0001 
\u000b\u0005}\"\u0014aA:rY&\u0011\u0011I\u0010\u0002\r'B\f'o[*fgNLwN\\\u0001\u0007gB\f'o\u001b\u0011\u0002C\r\fGOQ8pgRT5o\u001c8QCJ\fWn\u001d$pe6\u000b7\u000f^3s'R\u0014\u0018N\\4\u0016\u0003\u0015\u0003\"AR'\u000f\u0005\u001d[\u0005C\u0001%&\u001b\u0005I%B\u0001&+\u0003\u0019a$o\\8u}%\u0011A*J\u0001\u0007!J,G-\u001a4\n\u00059{%AB*ue&twM\u0003\u0002MK\u0005\u00113-\u0019;C_>\u001cHOS:p]B\u000b'/Y7t\r>\u0014X*Y:uKJ\u001cFO]5oO\u0002\n\u0001\u0006\u001d:fG>l\u0007/\u001e;fI>sG.\u001b8f\u0007R\u0014X*\u001a;b\t\u0006$\u0018-Q:Kg>t7\u000b\u001e:j]\u001e\f\u0011\u0006\u001d:fG>l\u0007/\u001e;fI>sG.\u001b8f\u0007R\u0014X*\u001a;b\t\u0006$\u0018-Q:Kg>t7\u000b\u001e:j]\u001e\u0004\u0013\u0001E:bm\u0016$\u0007k\\8mg\u001a+H/\u001e:f+\u0005)\u0006c\u0001,Z76\tqK\u0003\u0002YK\u0005Q1m\u001c8dkJ\u0014XM\u001c;\n\u0005i;&A\u0002$viV\u0014X\r\u0005\u0003%9z\u0013\u0017BA/&\u0005\u0019!V\u000f\u001d7feA\u0011q\fY\u0007\u00023%\u0011\u0011-\u0007\u0002\u000f!>|GNR5mKN\u0004\u0016\r\u001e5t!\r!3MX\u0005\u0003I\u0016\u0012Q!\u0011:sCf\fAc]1wK\u0012\u0004vn\u001c7t\rV$XO]3`I\u0015\fHCA4k!\t!\u0003.\u0003\u0002jK\t!QK\\5u\u0011\u001dYg\"!AA\u0002U\u000b1\u0001\u001f\u00132\u0003E\u0019\u0018M^3e!>|Gn\u001d$viV\u0014X\rI\u0001\u0012]\u0006$\u0018N^3N_\u0012,GNU3tk2$X#A8\u0011\u0005A|X\"A9\u000b\u0005I\u001c\u0018a\u00038bi&4XmX5na2T!\u0001^;\u0002\u0007M\u00148M\u0003\u0002wo\u0006!1m\u001c:f\u0015\tA\u00180\u0001\tdCR\u0014wn\\:ui)|6\u000f]1sW*\u0011!D\u001f\u0006\u00039mT!\u0001`?\u0002\re\fg\u000eZ3y\u0015\u0005q\u0018A\u0001:v\u0013\r\t\t!\u001d\u0002\u000b)\u001a+H\u000e\\'pI\u0016d\u0017!\u00068bi&4X-T8eK2\u0014Vm];mi~#S-\u001d\u000b\u0004O\u0006\u001d\u0001bB6\u0012\u0003\u0003\u0005\ra\\\u0001\u0013]\u0006$\u0018N^3N_\u0012,GNU3tk2$\b\u0005F\u0006/\u0003\u001b\ty!!\u0005\u0002\u0014\u0005U\u0001\"\u0002\u000e\u0014\u0001\u0004a\u0004\"B\"\u0014\u0001\u0004)\u0005\"B)\u0014\u0001\u0004)\u0005bB*\u0014!\u0003\u0005\r!\u0016\u0005\b[N\u0001\n\u00111\u0001p\u0003M\u0019\u0018M^3I_N
$8\u000fT5tiR{g)\u001b7f)\u00159\u00171DA\u001a\u0011\u001d\ti\u0002\u0006a\u0001\u0003?\tQ\u0002[8tiN4\u0015\u000e\\3QCRD\u0007\u0003BA\u0011\u0003_i!!a\t\u000b\t\u0005\u0015\u0012qE\u0001\u0005M&dWM\u0003\u0003\u0002*\u0005-\u0012a\u00018j_*\u0011\u0011QF\u0001\u0005U\u00064\u0018-\u0003\u0003\u00022\u0005\r\"\u0001\u0002)bi\"Dq!!\u000e\u0015\u0001\u0004\t9$A\u0006x_J\\WM]:J]\u001a|\u0007\u0003\u0002\u0013d\u0003s\u00012aXA\u001e\u0013\r\ti$\u0007\u0002\u000b/>\u00148.\u001a:J]\u001a|\u0017!\u0004;sC&t7)\u00197mE\u0006\u001c7\u000eF\u0002h\u0003\u0007Bq!!\u000e\u0016\u0001\u0004\t9\u0004C\u0004\u0002H\r\u0001\r!!\u0013\u0002+A\u0014X\r\u001d:pG\u0016\u001c8/\u001a3Ue\u0006Lg\u000eU8pYB\u0019q,a\u0013\n\u0007\u00055\u0013D\u0001\u0003Q_>d\u0007bBA)\u0007\u0001\u0007\u00111K\u0001\u0016aJ,\u0007O]8dKN\u001cX\rZ#wC2\u0004vn\u001c7t!\u0011!3-!\u0013\t\u000b\r\u001b\u0001\u0019A#\t\u000bE\u001b\u0001\u0019A#\u00027\u0011bWm]:j]&$He\u001a:fCR,'\u000f\n3fM\u0006,H\u000e\u001e\u00135+\t\tiFK\u0002V\u0003?Z#!!\u0019\u0011\t\u0005\r\u0014QN\u0007\u0003\u0003KRA!a\u001a\u0002j\u0005IQO\\2iK\u000e\\W\r\u001a\u0006\u0004\u0003W*\u0013AC1o]>$\u0018\r^5p]&!\u0011qNA3\u0005E)hn\u00195fG.,GMV1sS\u0006t7-Z\u0001\u001cI1,7o]5oSR$sM]3bi\u0016\u0014H\u0005Z3gCVdG\u000fJ\u001b\u0016\u0005\u0005U$fA8\u0002`\u0001")
/* loaded from: input_file:ai/catboost/spark/impl/CatBoostMasterWrapper.class */
public class CatBoostMasterWrapper implements Logging {
    private final SparkSession spark;
    private final String catBoostJsonParamsForMasterString;
    private final String precomputedOnlineCtrMetaDataAsJsonString;
    private Future<Tuple2<PoolFilesPaths, PoolFilesPaths[]>> savedPoolsFuture;
    private TFullModel nativeModelResult;
    private transient Logger org$apache$spark$internal$Logging$$log_;

    public static CatBoostMasterWrapper apply(Pool pool, Pool[] poolArr, String str, String str2) {
        return CatBoostMasterWrapper$.MODULE$.apply(pool, poolArr, str, str2);
    }

    public String logName() {
        return Logging.logName$(this);
    }

    public Logger log() {
        return Logging.log$(this);
    }

    public void logInfo(Function0<String> function0) {
        Logging.logInfo$(this, function0);
    }

    public void logDebug(Function0<String> function0) {
        Logging.logDebug$(this, function0);
    }

    public void logTrace(Function0<String> function0) {
        Logging.logTrace$(this, function0);
    }

    public void logWarning(Function0<String> function0) {
        Logging.logWarning$(this, function0);
    }

    public void logError(Function0<String> function0) {
        Logging.logError$(this, function0);
    }

    public void logInfo(Function0<String> function0, Throwable th) {
        Logging.logInfo$(this, function0, th);
    }

    public void logDebug(Function0<String> function0, Throwable th) {
        Logging.logDebug$(this, function0, th);
    }

    public void logTrace(Function0<String> function0, Throwable th) {
        Logging.logTrace$(this, function0, th);
    }

    public void logWarning(Function0<String> function0, Throwable th) {
        Logging.logWarning$(this, function0, th);
    }

    public void logError(Function0<String> function0, Throwable th) {
        Logging.logError$(this, function0, th);
    }

    public boolean isTraceEnabled() {
        return Logging.isTraceEnabled$(this);
    }

    public void initializeLogIfNecessary(boolean z) {
        Logging.initializeLogIfNecessary$(this, z);
    }

    public boolean initializeLogIfNecessary(boolean z, boolean z2) {
        return Logging.initializeLogIfNecessary$(this, z, z2);
    }

    public boolean initializeLogIfNecessary$default$2() {
        return Logging.initializeLogIfNecessary$default$2$(this);
    }

    public Logger org$apache$spark$internal$Logging$$log_() {
        return this.org$apache$spark$internal$Logging$$log_;
    }

    public void org$apache$spark$internal$Logging$$log__$eq(Logger logger) {
        this.org$apache$spark$internal$Logging$$log_ = logger;
    }

    public SparkSession spark() {
        return this.spark;
    }

    public String catBoostJsonParamsForMasterString() {
        return this.catBoostJsonParamsForMasterString;
    }

    public String precomputedOnlineCtrMetaDataAsJsonString() {
        return this.precomputedOnlineCtrMetaDataAsJsonString;
    }

    public Future<Tuple2<PoolFilesPaths, PoolFilesPaths[]>> savedPoolsFuture() {
        return this.savedPoolsFuture;
    }

    public void savedPoolsFuture_$eq(Future<Tuple2<PoolFilesPaths, PoolFilesPaths[]>> future) {
        this.savedPoolsFuture = future;
    }

    public TFullModel nativeModelResult() {
        return this.nativeModelResult;
    }

    public void nativeModelResult_$eq(TFullModel tFullModel) {
        this.nativeModelResult = tFullModel;
    }

    private void saveHostsListToFile(Path path, WorkerInfo[] workerInfoArr) {
        PrintWriter printWriter = new PrintWriter(path.toFile());
        try {
            new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(workerInfoArr)).foreach(workerInfo -> {
                $anonfun$saveHostsListToFile$1(printWriter, workerInfo);
                return BoxedUnit.UNIT;
            });
        } finally {
            printWriter.close();
        }
    }

    public void trainCallback(WorkerInfo[] workerInfoArr) {
        if (nativeModelResult() != null) {
            throw new CatBoostError("[Internal error] trainCallback is called again despite nativeModelResult already assigned");
        }
        Path createTempDirectory = Files.createTempDirectory("catboost_train", new FileAttribute[0]);
        Path resolve = createTempDirectory.resolve("worker_hosts.txt");
        saveHostsListToFile(resolve, workerInfoArr);
        Path resolve2 = createTempDirectory.resolve("result_model.cbm");
        Path resolve3 = createTempDirectory.resolve("json_params");
        Files.write(resolve3, catBoostJsonParamsForMasterString().getBytes(StandardCharsets.UTF_8), new OpenOption[0]);
        Path path = null;
        if (precomputedOnlineCtrMetaDataAsJsonString() != null) {
            path = createTempDirectory.resolve("precomputed_online_ctr_metadata");
            Files.write(path, precomputedOnlineCtrMetaDataAsJsonString().getBytes(StandardCharsets.UTF_8), new OpenOption[0]);
        } else {
            BoxedUnit boxedUnit = BoxedUnit.UNIT;
        }
        ArrayBuffer apply = ArrayBuffer$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new String[]{"--node-type", "Master", "--thread-count", BoxesRunTime.boxToInteger(SparkHelpers$.MODULE$.getThreadCountForDriver(spark())).toString(), "--params-file", resolve3.toString(), "--file-with-hosts", resolve.toString(), "--hosts-already-contain-loaded-data", "--has-time", "--max-ctr-complexity", "1", "--final-ctr-computation-mode", "Skip", "--model-file", resolve2.toString()}));
        Option<Object> driverNativeMemoryLimit = SparkHelpers$.MODULE$.getDriverNativeMemoryLimit(spark());
        if (driverNativeMemoryLimit.isDefined()) {
            apply.$plus$eq("--used-ram-limit", driverNativeMemoryLimit.get().toString(), Predef$.MODULE$.wrapRefArray(new String[0]));
        } else {
            BoxedUnit boxedUnit2 = BoxedUnit.UNIT;
        }
        if (precomputedOnlineCtrMetaDataAsJsonString() != null) {
            apply.$plus$eq("--precomputed-data-meta", path.toString(), Predef$.MODULE$.wrapRefArray(new String[0]));
        } else {
            BoxedUnit boxedUnit3 = BoxedUnit.UNIT;
        }
        log().info("Wait until Dataset data parts are ready.");
        Tuple2 tuple2 = (Tuple2) Await$.MODULE$.result(savedPoolsFuture(), Duration$.MODULE$.Inf());
        if (tuple2 == null) {
            throw new MatchError(tuple2);
        }
        Tuple2 tuple22 = new Tuple2((PoolFilesPaths) tuple2._1(), (PoolFilesPaths[]) tuple2._2());
        PoolFilesPaths poolFilesPaths = (PoolFilesPaths) tuple22._1();
        PoolFilesPaths[] poolFilesPathsArr = (PoolFilesPaths[]) tuple22._2();
        log().info("Dataset data parts are ready. Start CatBoost Master process.");
        apply.$plus$eq("--learn-set", new StringBuilder(30).append("spark-quantized://master-part:").append(poolFilesPaths.mainData().toString()).toString(), Predef$.MODULE$.wrapRefArray(new String[0]));
        if (poolFilesPaths.pairsData().isDefined()) {
            apply.$plus$eq("--learn-pairs", new StringBuilder(23).append("dsv-grouped-with-idx://").append(poolFilesPaths.pairsData().get().toString()).toString(), Predef$.MODULE$.wrapRefArray(new String[0]));
        } else {
            BoxedUnit boxedUnit4 = BoxedUnit.UNIT;
        }
        if (new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(poolFilesPathsArr)).isEmpty()) {
            BoxedUnit boxedUnit5 = BoxedUnit.UNIT;
        } else {
            apply.$plus$eq("--test-set", new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(poolFilesPathsArr)).map(poolFilesPaths2 -> {
                return new StringBuilder(30).append("spark-quantized://master-part:").append(poolFilesPaths2.mainData()).toString();
            }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(String.class))))).mkString(","), Predef$.MODULE$.wrapRefArray(new String[0]));
            if (poolFilesPaths.pairsData().isDefined()) {
                apply.$plus$eq("--test-pairs", new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(poolFilesPathsArr)).map(poolFilesPaths3 -> {
                    return new StringBuilder(23).append("dsv-grouped-with-idx://").append(poolFilesPaths3.pairsData().get().toString()).toString();
                }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(String.class))))).mkString(","), Predef$.MODULE$.wrapRefArray(new String[0]));
            } else {
                BoxedUnit boxedUnit6 = BoxedUnit.UNIT;
            }
            if (precomputedOnlineCtrMetaDataAsJsonString() != null) {
                apply.$plus$eq("--test-precomputed-set", new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(poolFilesPathsArr)).map(poolFilesPaths4 -> {
                    return new StringBuilder(30).append("spark-quantized://master-part:").append(poolFilesPaths4.estimatedCtrData().get().toString()).toString();
                }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(String.class))))).mkString(","), Predef$.MODULE$.wrapRefArray(new String[0]));
            } else {
                BoxedUnit boxedUnit7 = BoxedUnit.UNIT;
            }
        }
        Process apply2 = RunClassInNewProcess$.MODULE$.apply(MasterApp$.MODULE$.getClass(), RunClassInNewProcess$.MODULE$.apply$default$2(), new Some<>(apply.toArray(ClassTag$.MODULE$.apply(String.class))), false, RunClassInNewProcess$.MODULE$.apply$default$5(), new Some<>(ProcessBuilder.Redirect.INHERIT), new Some<>(ProcessBuilder.Redirect.PIPE));
        Pattern compile = Pattern.compile("^FAIL.*(got unexpected network error, no retries rest|reply isn't OK)$");
        BooleanRef create = BooleanRef.create(false);
        BufferedReader bufferedReader = new BufferedReader(new InputStreamReader(apply2.getErrorStream()));
        try {
            Breaks$.MODULE$.breakable(() -> {
                while (true) {
                    String readLine = bufferedReader.readLine();
                    if (readLine == null) {
                        throw Breaks$.MODULE$.break();
                    }
                    System.err.println(new StringBuilder(18).append("[CatBoost Master] ").append(readLine).toString());
                    if (compile.matcher(readLine).matches()) {
                        create.elem = true;
                    }
                }
            });
            bufferedReader.close();
            int waitFor = apply2.waitFor();
            if (waitFor != 0) {
                if (!create.elem) {
                    throw new CatBoostError(new StringBuilder(49).append("CatBoost Master process failed: exited with code ").append(waitFor).toString());
                }
                throw new CatBoostWorkersConnectionLostException("");
            }
            log().info("CatBoost Master process finished successfully.");
            log().info("Trained model: start loading");
            nativeModelResult_$eq(native_impl.ReadModelWrapper(resolve2.toString()));
            log().info("Trained model: finish loading");
            FileUtils.deleteDirectory(createTempDirectory.toFile());
        } catch (Throwable th) {
            bufferedReader.close();
            throw th;
        }
    }

    public static final /* synthetic */ void $anonfun$saveHostsListToFile$1(PrintWriter printWriter, WorkerInfo workerInfo) {
        if (workerInfo.partitionSize() > 0) {
            if (IPAddressUtil.isIPv6LiteralAddress(workerInfo.host())) {
                printWriter.println(new StringBuilder(3).append("[").append(workerInfo.host()).append("]:").append(workerInfo.port()).toString());
            } else {
                printWriter.println(new StringBuilder(1).append(workerInfo.host()).append(":").append(workerInfo.port()).toString());
            }
        }
    }

    public CatBoostMasterWrapper(SparkSession sparkSession, String str, String str2, Future<Tuple2<PoolFilesPaths, PoolFilesPaths[]>> future, TFullModel tFullModel) {
        this.spark = sparkSession;
        this.catBoostJsonParamsForMasterString = str;
        this.precomputedOnlineCtrMetaDataAsJsonString = str2;
        this.savedPoolsFuture = future;
        this.nativeModelResult = tFullModel;
        Logging.$init$(this);
    }
}
