package ml.dmlc.xgboost4j.java.flink;

import java.io.IOException;
import java.io.OutputStream;
import java.io.Serializable;
import java.util.Arrays;
import java.util.Iterator;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;
import ml.dmlc.xgboost4j.LabeledPoint;
import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.XGBoostError;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.flink.api.common.functions.MapPartitionFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.ml.linalg.SparseVector;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.util.Collector;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:ml/dmlc/xgboost4j/java/flink/XGBoostModel.class */
public class XGBoostModel implements Serializable {
    private static final Logger logger = LoggerFactory.getLogger(XGBoostModel.class);
    private final Booster booster;
    private final PredictorFunction predictorFunction;

    /* loaded from: input_file:ml/dmlc/xgboost4j/java/flink/XGBoostModel$PredictorFunction.class */
    private static class PredictorFunction implements MapPartitionFunction<Vector, Float[]> {
        private final Booster booster;

        public PredictorFunction(Booster booster) {
            this.booster = booster;
        }

        public void mapPartition(Iterable<Vector> iterable, Collector<Float[]> collector) throws Exception {
            Iterator it = StreamSupport.stream(iterable.spliterator(), false).map((v0) -> {
                return v0.toSparse();
            }).map(PredictorFunction::fromVector).iterator();
            if (!it.hasNext()) {
                XGBoostModel.logger.debug("Empty partition");
                return;
            }
            Stream map = Arrays.stream(this.booster.predict(new DMatrix(it, (String) null), true, 2)).map(ArrayUtils::toObject);
            collector.getClass();
            map.forEach((v1) -> {
                r1.collect(v1);
            });
        }

        private static LabeledPoint fromVector(SparseVector sparseVector) {
            int[] iArr = sparseVector.indices;
            double[] dArr = sparseVector.values;
            int length = dArr.length;
            float[] fArr = new float[length];
            for (int i = 0; i < length; i++) {
                fArr[i] = (float) dArr[i];
            }
            return new LabeledPoint(0.0f, sparseVector.size(), iArr, fArr);
        }
    }

    public XGBoostModel(Booster booster) {
        this.booster = booster;
        this.predictorFunction = new PredictorFunction(booster);
    }

    public void saveModelAsHadoopFile(String str) throws IOException, XGBoostError {
        this.booster.saveModel(FileSystem.get(new Configuration()).create(new Path(str)));
    }

    public byte[] toByteArray(String str) throws XGBoostError {
        return this.booster.toByteArray(str);
    }

    public void saveModelAsHadoopFile(String str, String str2) throws IOException, XGBoostError {
        this.booster.saveModel(FileSystem.get(new Configuration()).create(new Path(str)), str2);
    }

    public float[][] predict(DMatrix dMatrix) throws XGBoostError {
        return this.booster.predict(dMatrix, true, 0);
    }

    public DataSet<Float[]> predict(DataSet<Vector> dataSet) {
        return dataSet.mapPartition(this.predictorFunction);
    }
}
