package ml.dmlc.xgboost4j.scala.flink;

import ml.dmlc.xgboost4j.LabeledPoint;
import ml.dmlc.xgboost4j.LabeledPoint$;
import ml.dmlc.xgboost4j.scala.Booster;
import ml.dmlc.xgboost4j.scala.DMatrix;
import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
import org.apache.flink.api.scala.DataSet;
import org.apache.flink.ml.math.Vector;
import org.apache.flink.ml.math.package$;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import scala.MatchError;
import scala.Predef$;
import scala.Serializable;
import scala.Tuple2;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.TraversableOnce;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.ScalaRunTime$;

/* compiled from: XGBoostModel.scala */
@ScalaSignature(bytes = "\u0006\u0001q3AAB\u0004\u0001%!A1\u0004\u0001B\u0001B\u0003%A\u0004C\u0003!\u0001\u0011\u0005\u0011\u0005C\u0003&\u0001\u0011\u0005a\u0005C\u00038\u0001\u0011\u0005\u0001\bC\u00038\u0001\u0011\u0005QI\u0001\u0007Y\u000f\n{wn\u001d;N_\u0012,GN\u0003\u0002\t\u0013\u0005)a\r\\5oW*\u0011!bC\u0001\u0006g\u000e\fG.\u0019\u0006\u0003\u00195\t\u0011\u0002_4c_>\u001cH\u000f\u000e6\u000b\u00059y\u0011\u0001\u00023nY\u000eT\u0011\u0001E\u0001\u0003[2\u001c\u0001aE\u0002\u0001'a\u0001\"\u0001\u0006\f\u000e\u0003UQ\u0011AC\u0005\u0003/U\u0011a!\u00118z%\u00164\u0007C\u0001\u000b\u001a\u0013\tQRC\u0001\u0007TKJL\u0017\r\\5{C\ndW-A\u0004c_>\u001cH/\u001a:\u0011\u0005uqR\"A\u0005\n\u0005}I!a\u0002\"p_N$XM]\u0001\u0007y%t\u0017\u000e\u001e \u0015\u0005\t\"\u0003CA\u0012\u0001\u001b\u00059\u0001\"B\u000e\u0003\u0001\u0004a\u0012!F:bm\u0016lu\u000eZ3m\u0003ND\u0015\rZ8pa\u001aKG.\u001a\u000b\u0003O)\u0002\"\u0001\u0006\u0015\n\u0005%*\"\u0001B+oSRDQaK\u0002A\u00021\n\u0011\"\\8eK2\u0004\u0016\r\u001e5\u0011\u00055\"dB\u0001\u00183!\tyS#D\u00011\u0015\t\t\u0014#\u0001\u0004=e>|GOP\u0005\u0003gU\ta\u0001\u0015:fI\u00164\u0017BA\u001b7\u0005\u0019\u0019FO]5oO*\u00111'F\u0001\baJ,G-[2u)\tI\u0004\tE\u0002\u0015uqJ!aO\u000b\u0003\u000b\u0005\u0013(/Y=\u0011\u0007QQT\b\u0005\u0002\u0015}%\u0011q(\u0006\u0002\u0006\r2|\u0017\r\u001e\u0005\u0006\u0003\u0012\u0001\rAQ\u0001\bi\u0016\u001cHoU3u!\ti2)\u0003\u0002E\u0013\t9A)T1ue&DHC\u0001$S!\r9\u0005\u000bP\u0007\u0002\u0011*\u0011!\"\u0013\u0006\u0003\u0015.\u000b1!\u00199j\u0015\tAAJ\u0003\u0002N\u001d\u00061\u0011\r]1dQ\u0016T\u0011aT\u0001\u0004_J<\u0017BA)I\u0005\u001d!\u0015\r^1TKRDQaU\u0003A\u0002Q\u000bA\u0001Z1uCB\u0019q\tU+\u0011\u0005YSV\"A,\u000b\u0005aK\u0016\u0001B7bi\"T!\u0001E&\n\u0005m;&A\u0002,fGR|'\u000f")
/* loaded from: input_file:ml/dmlc/xgboost4j/scala/flink/XGBoostModel.class */
/**
 * Flink-facing wrapper around an XGBoost {@link Booster}: saves the model through Hadoop's
 * {@link FileSystem} API and runs prediction either on a single {@link DMatrix} or, partition by
 * partition, on a Flink {@link DataSet} of {@link Vector}s.
 *
 * <p>This is decompiler output for {@code XGBoostModel.scala}; the Scala file is the authoritative
 * source. Several constructs below (the undeclared {@code r0}/{@code r3} temporaries, the missing
 * {@code throws IOException}) are decompilation artifacts and are not valid Java as written.
 *
 * <p>Implements {@code scala.Serializable}; presumably this is so the instance can be serialized
 * with the {@code mapPartition} closure in {@code predict(DataSet)}, which reads
 * {@code this.booster} and therefore captures {@code this}.
 */
public class XGBoostModel implements Serializable {
    /** The Booster to which every save and predict call in this class delegates. */
    private final Booster booster;

    /**
     * Writes the booster's model to {@code str} through the Hadoop {@link FileSystem} API.
     *
     * <p>The file system is resolved with {@code FileSystem.get(new Configuration())}, i.e. from a
     * newly constructed Hadoop {@link Configuration}. The target is opened with
     * {@code FileSystem.create(Path)}, and the resulting stream is passed to
     * {@code Booster.saveModel}.
     *
     * <p>NOTE(review): this method never closes the stream returned by {@code create}. Confirm that
     * {@code Booster.saveModel(OutputStream)} closes it, including when writing fails; otherwise
     * the underlying file handle may leak.
     *
     * <p>NOTE(review): {@code FileSystem.get} and {@code create} throw {@code IOException}. The
     * Scala compiler does not enforce checked exceptions, which is why no {@code throws} clause
     * appears here.
     *
     * @param str path of the file to write, interpreted by Hadoop's {@link Path}
     */
    public void saveModelAsHadoopFile(String str) {
        this.booster.saveModel(FileSystem.get(new Configuration()).create(new Path(str)));
    }

    /**
     * Predicts on an already-built {@link DMatrix} by delegating to
     * {@code booster.predict(dMatrix, true, 0)}.
     *
     * <p>NOTE(review): the second argument is hard-coded to {@code true}, whereas
     * {@code predict(DataSet)} passes {@code booster.predict$default$2()}. If that parameter is
     * XGBoost's output-margin flag, this overload returns raw margins while the DataSet overload
     * returns the default (transformed) output. Confirm the difference is intended.
     *
     * @param dMatrix the data to score
     * @return the booster's {@code float[][]} result (presumably one inner array per row)
     */
    public float[][] predict(DMatrix dMatrix) {
        return this.booster.predict(dMatrix, true, 0);
    }

    /**
     * Scores a Flink {@link DataSet} of {@link Vector}s one partition at a time.
     *
     * <p>For each partition:
     * <ol>
     *   <li>every {@code Vector} is converted to a {@link LabeledPoint} built from its indices and
     *       values (see the converter lambda below);</li>
     *   <li>the converted points are loaded into a single {@link DMatrix} with {@code null} cache
     *       info;</li>
     *   <li>{@code booster.predict} is called with its default second and third arguments;</li>
     *   <li>the resulting {@code float[][]} is wrapped as a Scala collection, so each inner
     *       {@code float[]} becomes one output element.</li>
     * </ol>
     *
     * <p>The trailing {@code PrimitiveArrayTypeInfo} / {@code ClassTag} arguments are presumably
     * the implicit type evidence that the Scala compiler supplied for the {@code float[]} result
     * type.
     *
     * @param dataSet feature vectors to score
     * @return a DataSet with one {@code float[]} per predicted row
     */
    public DataSet<float[]> predict(DataSet<Vector> dataSet) {
        return dataSet.mapPartition(iterator -> {
            // Converter Vector -> LabeledPoint.
            // NOTE(review): "r0" (and "r3" below) are undeclared decompiler temporaries that appear
            // to name this same lambda; these lines are not valid Java source.
            r0 = vector -> {
                // Split the vector's entries (presumably (index, value) pairs, given that _1 becomes
                // an int[] and _2 a float[] below) into a Seq of indices and a Seq of values.
                Tuple2 unzip = package$.MODULE$.RichVector(vector).toSeq().unzip(Predef$.MODULE$.$conforms());
                // Decompiled form of a Scala tuple pattern match (val (a, b) = ...): a null result
                // raises MatchError.
                if (unzip == null) {
                    throw new MatchError(unzip);
                }
                Tuple2 tuple2 = new Tuple2((Seq) unzip._1(), (Seq) unzip._2());
                // Arguments, in order: 0.0f (presumably a placeholder label, since none is needed
                // to predict), vector.size(), the indices as int[], and the values narrowed to
                // float[]. The remaining parameters use LabeledPoint's Scala default arguments.
                return new LabeledPoint(0.0f, vector.size(), (int[]) ((Seq) tuple2._1()).toArray(ClassTag$.MODULE$.Int()), (float[]) ((TraversableOnce) ((Seq) tuple2._2()).map(d -> {
                    // NOTE(review): in the Scala source this is a Double -> Float conversion. Read
                    // literally as Java, casting a boxed Object to float only succeeds for a Float,
                    // so do not hand-port this line as-is.
                    return (float) d;
                }, Seq$.MODULE$.canBuildFrom())).toArray(ClassTag$.MODULE$.Float()), LabeledPoint$.MODULE$.apply$default$5(), LabeledPoint$.MODULE$.apply$default$6(), LabeledPoint$.MODULE$.apply$default$7());
            };
            // Build one DMatrix from the whole partition (cache info = null), predict with the
            // booster's default arguments, and expose the float[][] rows as a Scala collection.
            // NOTE(review): the DMatrix is not explicitly released here; confirm whether it holds
            // native resources that should be disposed.
            return Predef$.MODULE$.wrapRefArray(this.booster.predict(new DMatrix(iterator.map(vector2 -> {
                return (LabeledPoint) r3.apply(vector2);
            }), (String) null), this.booster.predict$default$2(), this.booster.predict$default$3()));
        }, PrimitiveArrayTypeInfo.getInfoFor(float[].class), ClassTag$.MODULE$.apply(ScalaRunTime$.MODULE$.arrayClass(Float.TYPE)));
    }

    /**
     * Wraps the given booster. The reference is stored as-is, without copying.
     *
     * @param booster the Booster that all saves and predictions in this class delegate to
     */
    public XGBoostModel(Booster booster) {
        this.booster = booster;
    }
}
