package org.apache.spark.ml.tuning;

import org.apache.spark.ml.Estimator;
import org.apache.spark.ml.evaluation.Evaluator;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.sql.Dataset;
import org.incal.spark_ml.IncalSparkMLException;
import scala.Function1;
import scala.Function4;
import scala.MatchError;
import scala.Predef$;
import scala.Serializable;
import scala.StringContext;
import scala.Tuple2;
import scala.Tuple3;
import scala.collection.Seq;
import scala.runtime.AbstractFunction1;
import scala.runtime.BoxesRunTime;
import scala.runtime.IntRef;

/* compiled from: ForwardChainingCrossValidator.scala */
/* loaded from: input_file:org/apache/spark/ml/tuning/ForwardChainingCrossValidator$$anonfun$fit$1.class */
/**
 * Per-fold body of {@code ForwardChainingCrossValidator.fit}, compiled from a Scala closure.
 *
 * <p>For each {@code ((first, second, third), index)} element it is applied to, this function:
 * <ol>
 *   <li>caches the first dataset (fitted on) and the second dataset (used for predictions);</li>
 *   <li>fits {@code est$1} once with all param maps in {@code epm$1};</li>
 *   <li>for each of the first {@code numModels$1} fitted models, computes predictions via
 *       {@code calcTestPredictions$1}, passes them through {@code processor$1}, scores them with
 *       {@code eval$1} and adds the score to the matching slot of {@code metrics$1}.</li>
 * </ol>
 *
 * <p>Both cached datasets are unpersisted even if fitting or evaluation throws.
 */
public final class ForwardChainingCrossValidator$$anonfun$fit$1 extends AbstractFunction1<Tuple2<Tuple3<Dataset<?>, Dataset<?>, Dataset<?>>, Object>, Dataset<?>> implements Serializable {
    public static final long serialVersionUID = 0;
    /** Enclosing validator; used here only for debug logging. */
    private final /* synthetic */ ForwardChainingCrossValidator $outer;
    /** Estimator fitted once per fold with all param maps. */
    private final Estimator est$1;
    /** Evaluator producing one metric per model and fold. */
    private final Evaluator eval$1;
    /** Applied to the predictions before they are evaluated. */
    private final Function1 processor$1;
    /** Param maps passed to the estimator; index i corresponds to model i. */
    public final ParamMap[] epm$1;
    /** Number of fitted models to evaluate per fold. */
    private final int numModels$1;
    /** Accumulator of summed metrics, one slot per model; mutated in place. */
    private final double[] metrics$1;
    /** Computes predictions from (model, second dataset, third dataset, param map). */
    private final Function4 calcTestPredictions$1;

    /**
     * Runs one forward-chaining fold and adds each model's score to {@code metrics$1}.
     *
     * @param tuple2 a {@code ((first, second, third), index)} pair. The index is only passed to
     *     the debug-log message (presumably the fold number; confirm against the caller). The
     *     third dataset is only forwarded to {@code calcTestPredictions$1}.
     * @return the second dataset after it has been unpersisted
     * @throws IncalSparkMLException if some model produces no predictions
     * @throws MatchError if {@code tuple2} or its first element is {@code null}
     */
    @Override
    public final Dataset<?> apply(Tuple2<Tuple3<Dataset<?>, Dataset<?>, Dataset<?>>, Object> tuple2) {
        if (tuple2 == null) {
            throw new MatchError(tuple2);
        }
        Tuple3<Dataset<?>, Dataset<?>, Dataset<?>> datasets = tuple2._1();
        if (datasets == null) {
            throw new MatchError(tuple2);
        }
        int index = tuple2._2$mcI$sp();
        Dataset<?> trainingData = datasets._1();
        Dataset<?> validationData = datasets._2();
        Dataset<?> extraData = datasets._3();

        trainingData.cache();
        validationData.cache();
        Dataset<?> unpersisted;
        try {
            this.$outer.logDebug(new ForwardChainingCrossValidator$$anonfun$fit$1$$anonfun$apply$2(this, index));
            Seq models;
            try {
                models = this.est$1.fit(trainingData, this.epm$1);
            } finally {
                // The first dataset is only needed for fitting; release it even if fit throws.
                trainingData.unpersist();
            }
            // Kept as an IntRef because the debug-message closure below receives it as one.
            IntRef modelIndex = IntRef.create(0);
            while (modelIndex.elem < this.numModels$1) {
                int m = modelIndex.elem;
                Dataset<?> predictions = (Dataset<?>) this.calcTestPredictions$1.apply(
                        models.apply(m), validationData, extraData, this.epm$1[m]);
                if (predictions.count() == 0) {
                    throw new IncalSparkMLException(
                            "Got no validation predictions for a forward-chaining cross validation. Perhaps the validation set with "
                                    + validationData.count() + " rows is too short.");
                }
                double metric = this.eval$1.evaluate((Dataset<?>) this.processor$1.apply(predictions));
                this.$outer.logDebug(new ForwardChainingCrossValidator$$anonfun$fit$1$$anonfun$apply$3(this, modelIndex, metric));
                this.metrics$1[m] += metric;
                modelIndex.elem++;
            }
        } finally {
            // Release the second dataset on both the success and the failure path.
            unpersisted = validationData.unpersist();
        }
        return unpersisted;
    }

    /**
     * Captures the state of the enclosing {@code fit} call.
     *
     * @throws NullPointerException if {@code forwardChainingCrossValidator} is {@code null}
     */
    public ForwardChainingCrossValidator$$anonfun$fit$1(ForwardChainingCrossValidator forwardChainingCrossValidator, Estimator estimator, Evaluator evaluator, Function1 function1, ParamMap[] paramMapArr, int i, double[] dArr, Function4 function4) {
        if (forwardChainingCrossValidator == null) {
            throw new NullPointerException();
        }
        this.$outer = forwardChainingCrossValidator;
        this.est$1 = estimator;
        this.eval$1 = evaluator;
        this.processor$1 = function1;
        this.epm$1 = paramMapArr;
        this.numModels$1 = i;
        this.metrics$1 = dArr;
        this.calcTestPredictions$1 = function4;
    }
}
