package org.apache.spark.ml.regression;

import org.apache.spark.SparkContext;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.ml.PredictionModel;
import org.apache.spark.ml.knn.KNNModelParams;
import org.apache.spark.ml.knn.i;
import org.apache.spark.ml.knn.q;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.param.DoubleParam;
import org.apache.spark.ml.param.IntParam;
import org.apache.spark.ml.param.Param;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.param.StringArrayParam;
import org.apache.spark.ml.param.shared.HasInputCols;
import org.apache.spark.ml.param.shared.HasWeightCol;
import org.apache.spark.rdd.RDD;
import org.apache.spark.rdd.RDD$;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.Row$;
import org.apache.spark.storage.StorageLevel;
import org.apache.spark.storage.StorageLevel$;
import scala.Array$;
import scala.Function1;
import scala.MatchError;
import scala.Option;
import scala.Predef$;
import scala.Serializable;
import scala.Tuple2;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.math.Numeric$DoubleIsFractional$;
import scala.math.Ordering$Long$;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.AbstractFunction0;
import scala.runtime.AbstractFunction1;
import scala.runtime.BoxesRunTime;

/* compiled from: KNNRegression.scala */
@ScalaSignature(bytes = "\u0006\u0001\u0005\u0005c\u0001B\u0001\u0003\u00015\u0011!c\u0013(O%\u0016<'/Z:tS>tWj\u001c3fY*\u00111\u0001B\u0001\u000be\u0016<'/Z:tS>t'BA\u0003\u0007\u0003\tiGN\u0003\u0002\b\u0011\u0005)1\u000f]1sW*\u0011\u0011BC\u0001\u0007CB\f7\r[3\u000b\u0003-\t1a\u001c:h\u0007\u0001\u0019R\u0001\u0001\b\u001bA!\u0002Ba\u0004\t\u001315\tA!\u0003\u0002\u0012\t\ty\u0001K]3eS\u000e$\u0018n\u001c8N_\u0012,G\u000e\u0005\u0002\u0014-5\tAC\u0003\u0002\u0016\t\u00051A.\u001b8bY\u001eL!a\u0006\u000b\u0003\rY+7\r^8s!\tI\u0002!D\u0001\u0003!\tYb$D\u0001\u001d\u0015\tiB!A\u0002l]:L!a\b\u000f\u0003\u001d-se*T8eK2\u0004\u0016M]1ngB\u0011\u0011EJ\u0007\u0002E)\u00111\u0005J\u0001\u0007g\"\f'/\u001a3\u000b\u0005\u0015\"\u0011!\u00029be\u0006l\u0017BA\u0014#\u00051A\u0015m],fS\u001eDGoQ8m!\tIC&D\u0001+\u0015\u0005Y\u0013!B:dC2\f\u0017BA\u0017+\u00051\u0019VM]5bY&T\u0018M\u00197f\u0011!y\u0003A!b\u0001\n\u0003\u0002\u0014aA;jIV\t\u0011\u0007\u0005\u00023k9\u0011\u0011fM\u0005\u0003i)\na\u0001\u0015:fI\u00164\u0017B\u0001\u001c8\u0005\u0019\u0019FO]5oO*\u0011AG\u000b\u0005\ts\u0001\u0011\t\u0011)A\u0005c\u0005!Q/\u001b3!\u0011!Y\u0004A!b\u0001\n\u0003a\u0014a\u0002;paR\u0013X-Z\u000b\u0002{A\u0019a(Q\"\u000e\u0003}R!\u0001\u0011\u0004\u0002\u0013\t\u0014x.\u00193dCN$\u0018B\u0001\"@\u0005%\u0011%o\\1eG\u0006\u001cH\u000f\u0005\u0002\u001c\t&\u0011Q\t\b\u0002\u0005)J,W\r\u0003\u0005H\u0001\t\u0005\t\u0015!\u0003>\u0003!!x\u000e\u001d+sK\u0016\u0004\u0003\u0002C%\u0001\u0005\u000b\u0007I\u0011\u0001&\u0002\u0011M,(\r\u0016:fKN,\u0012a\u0013\t\u0004\u0019>\u001bU\"A'\u000b\u000593\u0011a\u0001:eI&\u0011\u0001+\u0014\u0002\u0004%\u0012#\u0005\u0002\u0003*\u0001\u0005\u0003\u0005\u000b\u0011B&\u0002\u0013M,(\r\u0016:fKN\u0004\u0003B\u0002+\u0001\t\u0003!Q+\u0001\u0004=S:LGO\u0010\u000b\u00051Y;\u0006\fC\u00030'\u0002\u0007\u0011\u0007C\u0003<'\u0002\u0007Q\bC\u0003J'\u0002\u00071\nC\u0003[\u0001\u0011\u00051,\u0001\u0003tKR\\EC\u0001/^\u001b\u0005\u0001\u0001\"\u00020Z\u0001
\u0004y\u0016!\u0002<bYV,\u0007CA\u0015a\u0013\t\t'FA\u0002J]RDQa\u0019\u0001\u0005\u0002\u0011\fQb]3u\u0005V4g-\u001a:TSj,GC\u0001/f\u0011\u0015q&\r1\u0001g!\tIs-\u0003\u0002iU\t1Ai\\;cY\u0016DQA\u001b\u0001\u0005B-\fQ\u0002\u001e:b]N4wN]7J[BdGc\u00017\u0002\u0002A\u0011Q. \b\u0003]jt!a\u001c=\u000f\u0005A<hBA9w\u001d\t\u0011X/D\u0001t\u0015\t!H\"\u0001\u0004=e>|GOP\u0005\u0002\u0017%\u0011\u0011BC\u0005\u0003\u000f!I!!\u001f\u0004\u0002\u0007M\fH.\u0003\u0002|y\u00069\u0001/Y2lC\u001e,'BA=\u0007\u0013\tqxPA\u0005ECR\fgI]1nK*\u00111\u0010 \u0005\b\u0003\u0007I\u0007\u0019AA\u0003\u0003\u001d!\u0017\r^1tKR\u0004D!a\u0002\u0002\u0014A1\u0011\u0011BA\u0006\u0003\u001fi\u0011\u0001`\u0005\u0004\u0003\u001ba(a\u0002#bi\u0006\u001cX\r\u001e\t\u0005\u0003#\t\u0019\u0002\u0004\u0001\u0005\u0019\u0005U\u0011\u0011AA\u0001\u0002\u0003\u0015\t!a\u0006\u0003\u0007}#3'\u0005\u0003\u0002\u001a\u0005}\u0001cA\u0015\u0002\u001c%\u0019\u0011Q\u0004\u0016\u0003\u000f9{G\u000f[5oOB\u0019\u0011&!\t\n\u0007\u0005\r\"FA\u0002B]fDq!a\n\u0001\t\u0003\nI#\u0001\u0003d_BLHc\u0001\r\u0002,!A\u0011QFA\u0013\u0001\u0004\ty#A\u0003fqR\u0014\u0018\r\u0005\u0003\u00022\u0005MR\"\u0001\u0013\n\u0007\u0005UBE\u0001\u0005QCJ\fW.T1q\u0011\u001d\tI\u0004\u0001C)\u0003w\tq\u0001\u001d:fI&\u001cG\u000fF\u0002g\u0003{Aq!a\u0010\u00028\u0001\u0007!#\u0001\u0005gK\u0006$XO]3t\u0001")
/* loaded from: input_file:org/apache/spark/ml/regression/b.class */
/*
 * KNN regression model, decompiled (JADX) from the Scala class compiled from KNNRegression.scala.
 * The short names (b, i, q, a, c, ...) are bytecode/obfuscation names, not source names.
 *
 * The model holds a broadcast tree (see a()) and an RDD of trees (see m627a()) of type q,
 * presumably the KNN search-tree structure (the constructor's error message speaks of "Trees")
 * -- TODO confirm. Neighbour search itself is delegated to the KNNModelParams helpers in i.
 *
 *  - transformImpl(...) appends to every input row the weighted mean of column 0 of that row's
 *    neighbour rows (column 0 is presumably the label -- confirm against i.a).
 *  - predict(Vector) returns the UNweighted mean of column 0 of a single vector's neighbours.
 *
 * NOTE(review): this file is not valid Java as written. Final fields are assigned from the Scala
 * trait "_setter_" methods, Tuple2.mcJD.sp is a Scala-specialised class Java cannot name this way,
 * the trait-impl calls (HasWeightCol.class.getWeightCol) do not resolve, and predict references an
 * undeclared r0. These are decompiler artifacts; treat this file as a reading aid for the bytecode.
 */
public class b extends PredictionModel<Vector, b> implements KNNModelParams, HasWeightCol {
    // Spark ML identifier, returned by uid().
    private final String uid;
    // Broadcast tree, returned by a().
    private final Broadcast<q> c;
    // RDD of trees, returned by m627a(); the constructor requires it to be persisted (storage level != NONE).
    private final RDD<q> a;

    /* renamed from: a, reason: collision with other field name */
    // HasWeightCol "weightCol" param (see weightCol()).
    private final Param<String> f115a;
    // KNNModelParams "neighborsCol" param (see neighborsCol()).
    private final Param<String> b;

    /* renamed from: c, reason: collision with other field name */
    // KNNModelParams "distanceCol" param (see distanceCol()).
    private final Param<String> f116c;

    /* renamed from: a, reason: collision with other field name */
    // KNNModelParams "k" param (see k()).
    private final IntParam f117a;

    /* renamed from: a, reason: collision with other field name */
    // KNNModelParams "maxDistance" param (see maxDistance()).
    private final DoubleParam f118a;

    /* renamed from: b, reason: collision with other field name */
    // KNNModelParams "bufferSize" param (see bufferSize()).
    private final DoubleParam f119b;

    /* renamed from: a, reason: collision with other field name */
    // HasInputCols "inputCols" param (see inputCols()).
    private final StringArrayParam f120a;

    /* compiled from: KNNRegression.scala */
    /* loaded from: input_file:org/apache/spark/ml/regression/b$a.class */
    // Compiled Scala by-name argument: the lazily built message for the constructor's require(...)
    // that rejects an un-persisted tree RDD.
    public final class a extends AbstractFunction0<String> implements Serializable {
        public static final long serialVersionUID = 0;

        /* renamed from: apply, reason: merged with bridge method [inline-methods] */
        public final String m631apply() {
            return "KNNModel is not designed to work with Trees that have not been cached";
        }

        public a(b bVar) {
        }
    }

    /* compiled from: KNNRegression.scala */
    /* renamed from: org.apache.spark.ml.regression.b$b, reason: collision with other inner class name */
    /* loaded from: input_file:org/apache/spark/ml/regression/b$b.class */
    // Neighbour-weight function used by transformImpl when weightCol is empty: every neighbour
    // gets weight 1.0, so the weighted mean in d degenerates to a plain mean.
    public final class C0042b extends AbstractFunction1<Row, Object> implements Serializable {
        public static final long serialVersionUID = 0;

        public final double apply(Row row) {
            return 1.0d;
        }

        // Boxing bridge for the generic Function1.apply(Object).
        public final /* synthetic */ Object apply(Object obj) {
            return BoxesRunTime.boxToDouble(apply((Row) obj));
        }

        public C0042b(b bVar) {
        }
    }

    /* compiled from: KNNRegression.scala */
    /* loaded from: input_file:org/apache/spark/ml/regression/b$c.class */
    // Neighbour-weight function used by transformImpl when weightCol is set: reads the double at
    // index 1 of the neighbour row.
    // NOTE(review): assumes neighbour rows carry the weight at position 1 -- confirm against i.a.
    public final class c extends AbstractFunction1<Row, Object> implements Serializable {
        public static final long serialVersionUID = 0;

        public final double apply(Row row) {
            return row.getDouble(1);
        }

        // Boxing bridge for the generic Function1.apply(Object).
        public final /* synthetic */ Object apply(Object obj) {
            return BoxesRunTime.boxToDouble(apply((Row) obj));
        }

        public c(b bVar) {
        }
    }

    /* compiled from: KNNRegression.scala */
    /* loaded from: input_file:org/apache/spark/ml/regression/b$d.class */
    // Maps (rowId, neighbours) to (rowId, sum(w * label) / sum(w)), where label is column 0 of each
    // neighbour row and w comes from the weight function `a` (C0042b or c). The Double half of each
    // neighbour pair (unzip._2) is not used here.
    // NOTE(review): an empty neighbour array gives 0.0 / 0.0 = NaN, and a zero weight sum gives
    // NaN or +/-Infinity; nothing guards against either.
    public final class d extends AbstractFunction1<Tuple2<Object, Tuple2<Row, Object>[]>, Tuple2<Object, Object>> implements Serializable {
        public static final long serialVersionUID = 0;
        // Per-neighbour weight function chosen on the driver in transformImpl.
        private final Function1 a;

        public final Tuple2<Object, Object> apply(Tuple2<Object, Tuple2<Row, Object>[]> tuple2) {
            // Scala pattern-match failure on a null tuple.
            if (tuple2 == null) {
                throw new MatchError(tuple2);
            }
            long _1$mcJ$sp = tuple2._1$mcJ$sp();
            // Split the neighbour pairs into (rows, doubles); only the rows are used below.
            Tuple2 unzip = Predef$.MODULE$.refArrayOps((Tuple2[]) tuple2._2()).unzip(Predef$.MODULE$.$conforms(), ClassTag$.MODULE$.apply(Row.class), ClassTag$.MODULE$.Double());
            if (unzip == null) {
                throw new MatchError(unzip);
            }
            // d: running sum of weights; d2: running sum of weight * label.
            double d = 0.0d;
            double d2 = 0.0d;
            for (Row row : (Row[]) unzip._1()) {
                double unboxToDouble = BoxesRunTime.unboxToDouble(this.a.apply(row));
                d2 += row.getDouble(0) * unboxToDouble;
                d += unboxToDouble;
            }
            // Specialised (Long, Double) Tuple2 in Scala bytecode; not expressible like this in Java.
            return new Tuple2.mcJD.sp(_1$mcJ$sp, d2 / d);
        }

        public d(b bVar, Function1 function1) {
            this.a = function1;
        }
    }

    /* compiled from: KNNRegression.scala */
    /* loaded from: input_file:org/apache/spark/ml/regression/b$e.class */
    // Extracts column 0 (presumably the label) from a (neighbourRow, _) pair; used by predict.
    public final class e extends AbstractFunction1<Tuple2<Row, Object>, Object> implements Serializable {
        public static final long serialVersionUID = 0;

        public final double apply(Tuple2<Row, Object> tuple2) {
            return ((Row) tuple2._1()).getDouble(0);
        }

        // Boxing bridge for the generic Function1.apply(Object).
        public final /* synthetic */ Object apply(Object obj) {
            return BoxesRunTime.boxToDouble(apply((Tuple2<Row, Object>) obj));
        }

        public e(b bVar) {
        }
    }

    /* compiled from: KNNRegression.scala */
    /* loaded from: input_file:org/apache/spark/ml/regression/b$f.class */
    // Swaps a zipWithIndex pair (row, index) into (index, row) so the input rows can be keyed by
    // position and joined with the predictions in transformImpl.
    public final class f extends AbstractFunction1<Tuple2<Row, Object>, Tuple2<Object, Row>> implements Serializable {
        public static final long serialVersionUID = 0;

        public final Tuple2<Object, Row> apply(Tuple2<Row, Object> tuple2) {
            if (tuple2 == null) {
                throw new MatchError(tuple2);
            }
            return new Tuple2<>(BoxesRunTime.boxToLong(tuple2._2$mcJ$sp()), (Row) tuple2._1());
        }

        public f(b bVar) {
        }
    }

    /* compiled from: KNNRegression.scala */
    /* loaded from: input_file:org/apache/spark/ml/regression/b$g.class */
    // Flattens a join result (index, (row, Option[prediction])) into a new Row made of the original
    // row's values followed by the prediction.
    // NOTE(review): Option.get() throws (NoSuchElementException in Scala) when the left outer join
    // found no prediction for this index, i.e. when transform(...) emitted no entry for the row.
    public final class g extends AbstractFunction1<Tuple2<Object, Tuple2<Row, Option<Object>>>, Row> implements Serializable {
        public static final long serialVersionUID = 0;

        public final Row apply(Tuple2<Object, Tuple2<Row, Option<Object>>> tuple2) {
            Tuple2 tuple22;
            if (tuple2 == null || (tuple22 = (Tuple2) tuple2._2()) == null) {
                throw new MatchError(tuple2);
            }
            return Row$.MODULE$.fromSeq((Seq) ((Row) tuple22._1()).toSeq().$colon$plus(((Option) tuple22._2()).get(), Seq$.MODULE$.canBuildFrom()));
        }

        public g(b bVar) {
        }
    }

    // HasWeightCol: the weight column param. An empty value means "all neighbours weigh 1.0" in transformImpl.
    public final Param<String> weightCol() {
        return this.f115a;
    }

    // Scala trait-initialisation setter invoked from HasWeightCol's $init$ (decompiler artifact:
    // assigns a final field outside the constructor).
    public final void org$apache$spark$ml$param$shared$HasWeightCol$_setter_$weightCol_$eq(Param param) {
        this.f115a = param;
    }

    // Delegates to the HasWeightCol trait implementation (decompiled static trait-impl call).
    public final String getWeightCol() {
        return HasWeightCol.class.getWeightCol(this);
    }

    @Override // org.apache.spark.ml.knn.KNNModelParams
    public Param<String> neighborsCol() {
        return this.b;
    }

    @Override // org.apache.spark.ml.knn.KNNModelParams
    public Param<String> distanceCol() {
        return this.f116c;
    }

    @Override // org.apache.spark.ml.knn.KNNModelParams
    public IntParam k() {
        return this.f117a;
    }

    @Override // org.apache.spark.ml.knn.KNNModelParams
    public DoubleParam maxDistance() {
        return this.f118a;
    }

    @Override // org.apache.spark.ml.knn.KNNModelParams
    public DoubleParam bufferSize() {
        return this.f119b;
    }

    // The KNNModelParams "_setter_" methods below are called by the trait initialiser i.m607a(...)
    // from the constructor; like the HasWeightCol setter, they assign final fields (decompiler artifact).
    @Override // org.apache.spark.ml.knn.KNNModelParams
    public void org$apache$spark$ml$knn$KNNModelParams$_setter_$neighborsCol_$eq(Param param) {
        this.b = param;
    }

    @Override // org.apache.spark.ml.knn.KNNModelParams
    public void org$apache$spark$ml$knn$KNNModelParams$_setter_$distanceCol_$eq(Param param) {
        this.f116c = param;
    }

    @Override // org.apache.spark.ml.knn.KNNModelParams
    public void org$apache$spark$ml$knn$KNNModelParams$_setter_$k_$eq(IntParam intParam) {
        this.f117a = intParam;
    }

    @Override // org.apache.spark.ml.knn.KNNModelParams
    public void org$apache$spark$ml$knn$KNNModelParams$_setter_$maxDistance_$eq(DoubleParam doubleParam) {
        this.f118a = doubleParam;
    }

    @Override // org.apache.spark.ml.knn.KNNModelParams
    public void org$apache$spark$ml$knn$KNNModelParams$_setter_$bufferSize_$eq(DoubleParam doubleParam) {
        this.f119b = doubleParam;
    }

    // The getters below delegate to the KNNModelParams trait implementation object i.
    @Override // org.apache.spark.ml.knn.KNNModelParams
    public String getNeighborsCol() {
        return i.a((KNNModelParams) this);
    }

    @Override // org.apache.spark.ml.knn.KNNModelParams
    public String getDistanceCol() {
        return i.b((KNNModelParams) this);
    }

    @Override // org.apache.spark.ml.knn.KNNModelParams
    public int getK() {
        return i.m604a((KNNModelParams) this);
    }

    @Override // org.apache.spark.ml.knn.KNNModelParams
    public double getMaxDistance() {
        return i.m605a((KNNModelParams) this);
    }

    @Override // org.apache.spark.ml.knn.KNNModelParams
    public double getBufferSize() {
        return i.m606b((KNNModelParams) this);
    }

    // Neighbour search over an RDD of query vectors; delegates to i.a. The result pairs an id
    // (a Long, as unpacked in d) with an array of (neighbourRow, Double) pairs.
    @Override // org.apache.spark.ml.knn.KNNModelParams
    public RDD<Tuple2<Object, Tuple2<Row, Object>[]>> transform(RDD<Vector> rdd, Broadcast<q> broadcast, RDD<q> rdd2) {
        return i.a(this, rdd, broadcast, rdd2);
    }

    // Neighbour search over a Dataset; delegates to i.a. Same result shape as the RDD overload.
    @Override // org.apache.spark.ml.knn.KNNModelParams
    public RDD<Tuple2<Object, Tuple2<Row, Object>[]>> transform(Dataset<?> dataset, Broadcast<q> broadcast, RDD<q> rdd) {
        return i.a(this, dataset, broadcast, rdd);
    }

    public final StringArrayParam inputCols() {
        return this.f120a;
    }

    // Scala trait-initialisation setter invoked from HasInputCols' $init$ (decompiler artifact).
    public final void org$apache$spark$ml$param$shared$HasInputCols$_setter_$inputCols_$eq(StringArrayParam stringArrayParam) {
        this.f120a = stringArrayParam;
    }

    // Delegates to the HasInputCols trait implementation (decompiled static trait-impl call).
    public final String[] getInputCols() {
        return HasInputCols.class.getInputCols(this);
    }

    public String uid() {
        return this.uid;
    }

    // Returns the broadcast tree passed to the constructor.
    public Broadcast<q> a() {
        return this.c;
    }

    /* renamed from: a, reason: collision with other method in class */
    // Returns the RDD of trees passed to the constructor.
    public RDD<q> m627a() {
        return this.a;
    }

    // Sets the "k" param and returns this model (Spark ML set() returns the params object).
    public b a(int i) {
        return set(k(), BoxesRunTime.boxToInteger(i));
    }

    // Sets the "bufferSize" param and returns this model.
    public b a(double d2) {
        return set(bufferSize(), BoxesRunTime.boxToDouble(d2));
    }

    /**
     * Appends a prediction column to every row of {@code dataset}.
     *
     * <p>Steps, in the order the single expression below evaluates them:
     * <ol>
     *   <li>Key each input row by its position: {@code toDF().rdd().zipWithIndex()}, swapped by f.</li>
     *   <li>Run the neighbour search ({@code transform(dataset, a(), m627a())}) and reduce each
     *       neighbour array to a weighted mean with d. The weight function is chosen once on the
     *       driver: constant 1.0 (C0042b) when weightCol is empty, otherwise neighbour column 1 (c).</li>
     *   <li>Left-outer-join rows with predictions on the index, and append the prediction to
     *       each row (g).</li>
     *   <li>Build a DataFrame using {@code transformSchema(dataset.schema())}.</li>
     * </ol>
     *
     * <p>NOTE(review): correctness of the join assumes transform(...) keys its output by the same
     * zipWithIndex position used here -- confirm in i.a. Rows with no matching prediction make g
     * throw on Option.get().
     *
     * @param dataset input rows
     * @return the input rows, each with one extra trailing prediction value
     */
    public Dataset<Row> transformImpl(Dataset<?> dataset) {
        return dataset.sqlContext().createDataFrame(RDD$.MODULE$.rddToPairRDDFunctions(dataset.toDF().rdd().zipWithIndex().map(new f(this), ClassTag$.MODULE$.apply(Tuple2.class)), ClassTag$.MODULE$.Long(), ClassTag$.MODULE$.apply(Row.class), Ordering$Long$.MODULE$).leftOuterJoin(transform(dataset, a(), m627a()).map(new d(this, ((String) $(weightCol())).isEmpty() ? new C0042b(this) : new c(this)), ClassTag$.MODULE$.apply(Tuple2.class))).map(new g(this), ClassTag$.MODULE$.apply(Row.class)), transformSchema(dataset.schema()));
    }

    /* renamed from: a, reason: merged with bridge method [inline-methods] and merged with bridge method [inline-methods] and merged with bridge method [inline-methods] and merged with bridge method [inline-methods] */
    // Copy: a new model sharing the same uid and tree references (the trees are not copied), with
    // param values copied and overridden by paramMap, and the same parent estimator.
    public b m630copy(ParamMap paramMap) {
        return copyValues(new b(uid(), a(), m627a()), paramMap).setParent(parent());
    }

    /**
     * Predicts a single value: the arithmetic mean of column 0 of the neighbour rows found for
     * {@code vector}.
     *
     * <p>Each call parallelizes a one-element RDD on the trees' SparkContext, runs the neighbour
     * search and calls {@code first()}, so every call launches a Spark job.
     *
     * <p>NOTE(review): unlike transformImpl, this ignores weightCol (an unweighted mean), and an
     * empty neighbour array would yield 0.0 / 0 = NaN. The divisor {@code r0} is a decompiler
     * artifact (an undeclared register variable); from the expression it is presumably the
     * mapped neighbour-label array -- confirm against the bytecode.
     *
     * @param vector the query point
     * @return the mean neighbour label
     */
    public double predict(Vector vector) {
        SparkContext context = m627a().context();
        return BoxesRunTime.unboxToDouble(Predef$.MODULE$.doubleArrayOps((double[]) Predef$.MODULE$.refArrayOps((Tuple2[]) ((Tuple2) transform(context.parallelize(Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new Vector[]{vector})), context.parallelize$default$2(), ClassTag$.MODULE$.apply(Vector.class)), a(), m627a()).first())._2()).map(new e(this), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.Double()))).sum(Numeric$DoubleIsFractional$.MODULE$)) / r0.length;
    }

    /**
     * Creates the model.
     *
     * <p>Runs the trait initialisers (HasInputCols, KNNModelParams via i.m607a, HasWeightCol), which
     * populate the param fields through the "_setter_" methods. It then requires the tree RDD to
     * be persisted.
     *
     * @param str       Spark ML uid
     * @param broadcast broadcast tree, returned by a()
     * @param rdd       RDD of trees, returned by m627a(); must not have storage level NONE
     * @throws IllegalArgumentException (Scala Predef.require) if {@code rdd} is not persisted
     */
    public b(String str, Broadcast<q> broadcast, RDD<q> rdd) {
        this.uid = str;
        this.c = broadcast;
        this.a = rdd;
        HasInputCols.class.$init$(this);
        i.m607a((KNNModelParams) this);
        HasWeightCol.class.$init$(this);
        Predef$ predef$ = Predef$.MODULE$;
        StorageLevel storageLevel = rdd.getStorageLevel();
        StorageLevel NONE = StorageLevel$.MODULE$.NONE();
        // Null-safe Scala "storageLevel != NONE": the tree RDD must be cached.
        predef$.require(storageLevel != null ? !storageLevel.equals(NONE) : NONE != null, new a(this));
    }
}
