package com.github.chen0040.sparkml.gp;

import com.github.chen0040.gp.commons.BasicObservation;
import com.github.chen0040.gp.lgp.LGP;
import com.github.chen0040.gp.lgp.gp.Population;
import com.github.chen0040.gp.lgp.program.Program;
import java.lang.invoke.SerializedLambda;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.broadcast.Broadcast;
import scala.Tuple2;

/* loaded from: input_file:com/github/chen0040/sparkml/gp/SparkLGP.class */
/**
 * Linear genetic programming ({@link LGP}) whose cost evaluation is distributed with Apache Spark.
 *
 * <p>A program's cost is the arithmetic mean of a user-supplied per-observation cost taken over
 * every element of a cached {@link JavaRDD} of observations. The program being scored is shipped
 * to the executors through a {@link Broadcast} variable, which is destroyed after each evaluation.
 *
 * <p>Typical use: call {@link #setPerObservationCostEvaluator(Function)}, then {@link #fit(JavaRDD)}.
 */
public class SparkLGP extends LGP {

    /** Cached observations every program is scored against; {@code null} until set. */
    private JavaRDD<BasicObservation> observationRdd;

    /**
     * Computes the cost of one program on one observation. It is passed to {@code JavaRDD.map}, so
     * it runs on the Spark executors (Spark's {@code Function} is {@code Serializable}).
     */
    private Function<Tuple2<Program, BasicObservation>, Double> perObservationCostEvaluator;

    /**
     * Sets the function that scores a single (program, observation) pair.
     *
     * @param function per-observation cost; the program's overall cost is the mean of its results
     */
    public void setPerObservationCostEvaluator(Function<Tuple2<Program, BasicObservation>, Double> function) {
        this.perObservationCostEvaluator = function;
    }

    /**
     * Uses {@code observations} as the evaluation data, caching it without changing its partitioning.
     *
     * @param observations observations to score programs against
     */
    public void setObservationRdd(JavaRDD<BasicObservation> observations) {
        setObservationRdd(observations, -1);
    }

    /**
     * Uses {@code observations} as the evaluation data, optionally coalescing it first, and caches it.
     *
     * @param observations observations to score programs against
     * @param numPartitions target partition count for {@code coalesce}; any non-positive value
     *     (conventionally {@code -1}) keeps the RDD's existing partitioning
     */
    public void setObservationRdd(JavaRDD<BasicObservation> observations, int numPartitions) {
        // coalesce rejects non-positive counts, so treat them all as "leave partitioning alone".
        JavaRDD<BasicObservation> partitioned =
                numPartitions > 0 ? observations.coalesce(numPartitions) : observations;
        this.observationRdd = partitioned.cache();
    }

    /**
     * Scores {@code program} as its mean per-observation cost over the observation RDD.
     *
     * <p>Structural introns are marked first. The copy returned by {@code makeEffectiveCopy()} is
     * then broadcast to the executors (presumably the program with introns stripped; confirm in
     * {@code Program}). The broadcast is destroyed even if the Spark job fails.
     *
     * @param program the program to score; its intron markings are updated as a side effect
     * @return sum of per-observation costs divided by the number of observations
     * @throws IllegalStateException if no observation RDD or per-observation evaluator has been set
     * @throws UnsupportedOperationException if the observation RDD is empty (raised by Spark's reduce)
     */
    public double evaluateCost(Program program) {
        if (observationRdd == null) {
            throw new IllegalStateException("observation RDD has not been set; call setObservationRdd or fit first");
        }
        if (perObservationCostEvaluator == null) {
            throw new IllegalStateException("per-observation cost evaluator has not been set");
        }

        program.markStructuralIntrons(this);
        JavaSparkContext sparkContext = JavaSparkContext.fromSparkContext(observationRdd.context());
        Broadcast<Program> programBroadcast = sparkContext.broadcast(program.makeEffectiveCopy());
        try {
            // Only the broadcast handle is captured here, never `this`, so the closure stays serializable.
            // The count is a Long so that very large RDDs cannot overflow it.
            Tuple2<Double, Long> sumAndCount = observationRdd
                    .map(observation -> new Tuple2<Program, BasicObservation>(programBroadcast.getValue(), observation))
                    .map(perObservationCostEvaluator)
                    .map(cost -> new Tuple2<Double, Long>(cost, 1L))
                    .reduce((left, right) -> new Tuple2<Double, Long>(left._1() + right._1(), left._2() + right._2()));
            return sumAndCount._1() / sumAndCount._2();
        } finally {
            programBroadcast.destroy();
        }
    }

    /**
     * Evolves a fresh population against {@code observations} until it reports termination.
     *
     * <p>If {@code getDisplayEvery()} is positive, progress is printed to standard output every
     * {@code getDisplayEvery()} generations.
     *
     * @param observations observations to score programs against; cached without repartitioning
     * @return the population's global best program after evolution stops
     */
    public Program fit(JavaRDD<BasicObservation> observations) {
        setObservationRdd(observations);
        long startMillis = System.currentTimeMillis();
        Population population = newPopulation();
        population.initialize();
        while (!population.isTerminated()) {
            population.evolve();
            if (getDisplayEvery() > 0 && population.getCurrentGeneration() % getDisplayEvery() == 0) {
                long elapsedSeconds = (System.currentTimeMillis() - startMillis) / 1000;
                System.out.println("Generation: " + population.getCurrentGeneration() + " (Pop: " + population.size() + "), elapsed: " + elapsedSeconds + " seconds");
                System.out.println("Global Cost: " + population.getGlobalBestProgram().getCost() + "\tCurrent Cost: " + population.getCostInCurrentGeneration());
            }
        }
        return population.getGlobalBestProgram();
    }
}
