package ml.dmlc.xgboost4j.java.example.flink;

import java.lang.invoke.SerializedLambda;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.HashMap;
import ml.dmlc.xgboost4j.java.flink.XGBoost;
import ml.dmlc.xgboost4j.java.flink.XGBoostModel;
import org.apache.flink.api.common.typeinfo.TypeHint;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.operators.MapOperator;
import org.apache.flink.api.java.operators.SingleInputUdfOperator;
import org.apache.flink.api.java.tuple.Tuple13;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.utils.DataSetUtils;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.ml.linalg.Vectors;

/* loaded from: input_file:ml/dmlc/xgboost4j/java/example/flink/DistTrainWithFlinkExample.class */
/**
 * Example of distributed XGBoost training on Apache Flink.
 *
 * <p>Reads a CSV file and numbers its rows with {@link DataSetUtils#zipWithIndex}. The first
 * {@code percentage}% of rows (by that index) are used to train a {@code binary:logistic} model.
 * The trained model is then used to predict the remaining rows.
 */
public class DistTrainWithFlinkExample {

    /**
     * Splits the CSV at {@code trainPath} into train/test sets, trains a model and predicts the
     * test set.
     *
     * @param env        Flink execution environment used to read the data
     * @param trainPath  path of the CSV file to read
     * @param percentage share of rows, in percent, used for training. Rows whose index is below
     *                   {@code round(count * percentage / 100)} train the model; the rest are
     *                   predicted.
     * @return the trained model paired with its predictions for the test rows
     * @throws IllegalArgumentException if {@code percentage} is outside {@code [0, 100]}
     * @throws Exception if counting the data set or training the model fails
     */
    static Tuple2<XGBoostModel, DataSet<Float[]>> runPrediction(
            ExecutionEnvironment env, Path trainPath, int percentage) throws Exception {
        if (percentage < 0 || percentage > 100) {
            throw new IllegalArgumentException(
                    "percentage must be in [0, 100], got " + percentage);
        }
        final DataSet<Tuple2<Long, Tuple2<Vector, Double>>> indexed =
                DataSetUtils.zipWithIndex(parseCsv(env, trainPath));
        // count() runs the pipeline to obtain the number of rows.
        final long trainCount = Math.round(indexed.count() * 0.01d * percentage);

        // Java lambdas erase their generic result type, so Flink needs it declared via returns().
        final DataSet<Tuple2<Vector, Double>> trainData = indexed
                .filter(row -> row.f0 < trainCount)
                .map(row -> row.f1)
                .returns(TypeInformation.of(new TypeHint<Tuple2<Vector, Double>>() {}));
        final DataSet<Vector> testData = indexed
                .filter(row -> row.f0 >= trainCount)
                .map(row -> row.f1.f0)
                .returns(TypeInformation.of(new TypeHint<Vector>() {}));

        final HashMap<String, Object> params = new HashMap<>(3);
        params.put("eta", 0.1d);
        params.put("max_depth", 2);
        params.put("objective", "binary:logistic");
        final int boostingRounds = 2;

        final XGBoostModel model = XGBoost.train(trainData, params, boostingRounds);
        return new Tuple2<>(model, model.predict(testData));
    }

    /**
     * Reads the CSV at {@code path} and converts each row into a (features, label) pair.
     *
     * <p>The header line is skipped. Columns are parsed as: Double, String, three Doubles, then
     * eight Integers.
     */
    private static MapOperator<Tuple13<Double, String, Double, Double, Double, Integer, Integer,
            Integer, Integer, Integer, Integer, Integer, Integer>, Tuple2<Vector, Double>> parseCsv(
            ExecutionEnvironment env, Path path) {
        return env.readCsvFile(path.toString())
                .ignoreFirstLine()
                .types(Double.class, String.class, Double.class, Double.class, Double.class,
                        Integer.class, Integer.class, Integer.class, Integer.class,
                        Integer.class, Integer.class, Integer.class, Integer.class)
                .map(DistTrainWithFlinkExample::mapFunction);
    }

    /**
     * Converts one parsed CSV row into a dense feature vector and a binary label.
     *
     * <p>Columns {@code f2..f12} become the 11 features. Column {@code f0} is not used.
     * The label is {@code 1.0} when column {@code f1} contains the substring {@code "inf"},
     * otherwise {@code 0.0}.
     */
    private static Tuple2<Vector, Double> mapFunction(Tuple13<Double, String, Double, Double, Double,
            Integer, Integer, Integer, Integer, Integer, Integer, Integer, Integer> row) {
        final DenseVector features = Vectors.dense(new double[] {
                row.f2, row.f3, row.f4,
                row.f5, row.f6, row.f7, row.f8, row.f9, row.f10, row.f11, row.f12});
        // NOTE(review): presumably f1 is a categorical column whose "inf..." values mark the
        // positive class — confirm against the dataset schema.
        final double label = row.f1.contains("inf") ? 1.0d : 0.0d;
        return new Tuple2<>(features, label);
    }

    /**
     * Entry point.
     *
     * <p>The optional first argument is the directory that contains
     * {@code veterans_lung_cancer.csv}; it defaults to the current directory. The method trains
     * on 70% of the rows and prints the number of predictions made for the rest.
     */
    public static void main(String[] args) throws Exception {
        final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
        final Path dataDir = Paths.get(Arrays.stream(args).findFirst().orElse("."));
        final Tuple2<XGBoostModel, DataSet<Float[]>> result =
                runPrediction(env, dataDir.resolve("veterans_lung_cancer.csv"), 70);
        System.out.println(result.f1.collect().size());
    }
}
