package com.github.chen0040.sparkml.gp;

import com.github.chen0040.gp.commons.BasicObservation;
import com.github.chen0040.gp.treegp.TreeGP;
import com.github.chen0040.gp.treegp.program.Solution;
import java.lang.invoke.SerializedLambda;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.broadcast.Broadcast;
import scala.Tuple2;

/* loaded from: input_file:com/github/chen0040/sparkml/gp/SparkTreeGP.class */
/**
 * A {@link TreeGP} whose cost function is evaluated on Spark.
 *
 * <p>Each candidate {@link Solution} is broadcast to the executors and paired with every
 * observation in {@link #observationRdd}. The configured per-observation evaluator turns each
 * pair into a cost, and the costs are summed on the cluster.
 *
 * <p>Note: the compiler-synthetic {@code $deserializeLambda$} that a decompiler emits for this
 * class is intentionally not written here. {@code javac} generates it automatically for the
 * serializable lambdas below.
 */
public class SparkTreeGP extends TreeGP {
    /**
     * Observations used for cost evaluation.
     *
     * <p>Kept cached because it is re-scanned on every {@link #evaluateCost(Solution)} call.
     */
    private JavaRDD<BasicObservation> observationRdd;

    /**
     * Maps a (solution, observation) pair to that observation's cost contribution.
     *
     * <p>It is passed to {@code JavaRDD.map}, so it runs on Spark executors and must be
     * serializable. Spark's {@link Function} interface extends {@code Serializable}.
     */
    private Function<Tuple2<Solution, BasicObservation>, Double> perObservationCostEvaluator;

    /**
     * Sets the function that computes the cost of one solution on one observation.
     *
     * @param function per-observation cost function. Its results are summed by
     *     {@link #evaluateCost(Solution)}.
     */
    public void setPerObservationCostEvaluator(Function<Tuple2<Solution, BasicObservation>, Double> function) {
        this.perObservationCostEvaluator = function;
    }

    /**
     * Sets and caches the observation RDD, keeping its existing partitioning.
     *
     * @param javaRDD observations to evaluate candidate solutions against
     */
    public void setObservationRdd(JavaRDD<BasicObservation> javaRDD) {
        setObservationRdd(javaRDD, -1);
    }

    /**
     * Sets and caches the observation RDD, optionally coalescing it first.
     *
     * @param javaRDD observations to evaluate candidate solutions against
     * @param numPartitions target partition count passed to {@code coalesce} when positive.
     *     Any non-positive value (conventionally -1) keeps the RDD's current partitioning.
     *     Spark's {@code coalesce} rejects non-positive counts, so those are treated as
     *     "no change".
     */
    public void setObservationRdd(JavaRDD<BasicObservation> javaRDD, int numPartitions) {
        JavaRDD<BasicObservation> source = numPartitions > 0 ? javaRDD.coalesce(numPartitions) : javaRDD;
        this.observationRdd = source.cache();
    }

    /**
     * Returns the number of trees each solution needs, taken from the output count of the first
     * observation.
     *
     * @return {@code outputCount()} of the first observation in the RDD
     * @throws IllegalStateException if no observation RDD has been set
     */
    public int getTreeCountPerSolution() {
        return requireObservationRdd().first().outputCount();
    }

    /**
     * Computes the total cost of {@code solution} as the sum of the per-observation costs over
     * the observation RDD.
     *
     * <p>The solution is shipped to executors as a broadcast variable. The broadcast is always
     * destroyed afterwards, even when the Spark job fails, so that failed evaluations do not
     * leak broadcast blocks.
     *
     * @param solution candidate program to score
     * @return summed cost. This is {@code 0.0} when the RDD contains no observations.
     * @throws IllegalStateException if the observation RDD or the per-observation evaluator
     *     has not been set
     */
    public double evaluateCost(Solution solution) {
        JavaRDD<BasicObservation> observations = requireObservationRdd();
        Function<Tuple2<Solution, BasicObservation>, Double> evaluator = this.perObservationCostEvaluator;
        if (evaluator == null) {
            throw new IllegalStateException(
                    "perObservationCostEvaluator has not been set; call setPerObservationCostEvaluator first");
        }

        Broadcast<Solution> broadcast =
                JavaSparkContext.fromSparkContext(observations.context()).broadcast(solution);
        try {
            // The lambda captures only the broadcast handle, not `this`, so the enclosing
            // TreeGP instance is not serialized into the Spark closure.
            // fold(0.0, ...) matches reduce for non-empty data (adding 0.0 is exact) but
            // returns 0.0 instead of throwing on an empty RDD.
            return observations
                    .map(observation -> new Tuple2<Solution, BasicObservation>(broadcast.getValue(), observation))
                    .map(evaluator)
                    .fold(0.0, Double::sum);
        } finally {
            broadcast.destroy();
        }
    }

    /** Returns the configured observation RDD, or fails clearly if it was never set. */
    private JavaRDD<BasicObservation> requireObservationRdd() {
        if (observationRdd == null) {
            throw new IllegalStateException("observationRdd has not been set; call setObservationRdd first");
        }
        return observationRdd;
    }
}
