package org.apache.spark.ml.odkl;

import java.io.IOException;
import odkl.analysis.spark.util.RDDOperations$;
import org.apache.spark.annotation.DeveloperApi;
import org.apache.spark.ml.Estimator;
import org.apache.spark.ml.odkl.CRRSamplerParams;
import org.apache.spark.ml.odkl.HasGroupByColumns;
import org.apache.spark.ml.param.DoubleParam;
import org.apache.spark.ml.param.IntParam;
import org.apache.spark.ml.param.Param;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.param.ParamMap$;
import org.apache.spark.ml.param.ParamPair;
import org.apache.spark.ml.param.Params;
import org.apache.spark.ml.param.StringArrayParam;
import org.apache.spark.ml.param.shared.HasInputCol;
import org.apache.spark.ml.param.shared.HasLabelCol;
import org.apache.spark.ml.util.DefaultParamsWritable;
import org.apache.spark.ml.util.Identifiable$;
import org.apache.spark.ml.util.MLWritable;
import org.apache.spark.ml.util.MLWriter;
import org.apache.spark.rdd.RDD$;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.functions$;
import org.apache.spark.sql.types.StructType;
import scala.Array$;
import scala.MatchError;
import scala.Predef$;
import scala.StringContext;
import scala.Tuple2;
import scala.collection.Seq;
import scala.math.Numeric$LongIsIntegral$;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxesRunTime;

/* compiled from: CRRSampler.scala */
@ScalaSignature(bytes = "\u0006\u000194A!\u0001\u0002\u0001\u001b\t\u00192I\u0015*TC6\u0004H.\u001a:FgRLW.\u0019;pe*\u00111\u0001B\u0001\u0005_\u0012\\GN\u0003\u0002\u0006\r\u0005\u0011Q\u000e\u001c\u0006\u0003\u000f!\tQa\u001d9be.T!!\u0003\u0006\u0002\r\u0005\u0004\u0018m\u00195f\u0015\u0005Y\u0011aA8sO\u000e\u00011\u0003\u0002\u0001\u000f-q\u00012a\u0004\t\u0013\u001b\u0005!\u0011BA\t\u0005\u0005%)5\u000f^5nCR|'\u000f\u0005\u0002\u0014)5\t!!\u0003\u0002\u0016\u0005\ty1I\u0015*TC6\u0004H.\u001a:N_\u0012,G\u000e\u0005\u0002\u001855\t\u0001D\u0003\u0002\u001a\t\u0005!Q\u000f^5m\u0013\tY\u0002DA\u000bEK\u001a\fW\u000f\u001c;QCJ\fWn],sSR\f'\r\\3\u0011\u0005Mi\u0012B\u0001\u0010\u0003\u0005A\u0019%KU*b[BdWM\u001d)be\u0006l7\u000f\u0003\u0005!\u0001\t\u0015\r\u0011\"\u0011\"\u0003\r)\u0018\u000eZ\u000b\u0002EA\u00111%\u000b\b\u0003I\u001dj\u0011!\n\u0006\u0002M\u0005)1oY1mC&\u0011\u0001&J\u0001\u0007!J,G-\u001a4\n\u0005)Z#AB*ue&twM\u0003\u0002)K!AQ\u0006\u0001B\u0001B\u0003%!%\u0001\u0003vS\u0012\u0004\u0003\"B\u0018\u0001\t\u0003\u0001\u0014A\u0002\u001fj]&$h\b\u0006\u00022eA\u00111\u0003\u0001\u0005\u0006A9\u0002\rA\t\u0005\bi\u0001\u0011\r\u0011\"\u00016\u0003I)\u0007\u0010]3di\u0016$g*^7TC6\u0004H.Z:\u0016\u0003Y\u0002\"a\u000e\u001e\u000e\u0003aR!!\u000f\u0003\u0002\u000bA\f'/Y7\n\u0005mB$\u0001C%oiB\u000b'/Y7\t\ru\u0002\u0001\u0015!\u00037\u0003M)\u0007\u0010]3di\u0016$g*^7TC6\u0004H.Z:!\u0011\u0015y\u0004\u0001\"\u0001A\u0003i\u0019X\r^#ya\u0016\u001cG/\u001a3Ok6\u0014WM](g'\u0006l\u0007\u000f\\3t)\t\t%)D\u0001\u0001\u0011\u0015\u0019e\b1\u0001E\u0003\u00151\u0018\r\\;f!\t!S)\u0003\u0002GK\t\u0019\u0011J\u001c;\t\u000b=\u0002A\u0011\u0001%\u0015\u0003EBQA\u0013\u0001\u0005B-\u000b1AZ5u)\t\u0011B\nC\u0003N\u0013\u0002\u0007a*A\u0004eCR\f7/\u001a;\u0011\u0005=\u0013V\"\u0001)\u000b\u0005E3\u0011aA:rY&\u00111\u000b\u0015\u0002\n\t\u0006$\u0018M\u0012:b[\u0016DQ!\u0016\u0001\u0005BY\u000bAaY8qsR\u0011\u0011g\u0016\u0005\u00061R\u0003\r!W\u0001\u0006Kb$(/\u0019\t\u0003
oiK!a\u0017\u001d\u0003\u0011A\u000b'/Y7NCBDQ!\u0018\u0001\u0005By\u000bq\u0002\u001e:b]N4wN]7TG\",W.\u0019\u000b\u0003?\u0016\u0004\"\u0001Y2\u000e\u0003\u0005T!A\u0019)\u0002\u000bQL\b/Z:\n\u0005\u0011\f'AC*ueV\u001cG\u000fV=qK\")a\r\u0018a\u0001?\u000611o\u00195f[\u0006D#\u0001\u00185\u0011\u0005%dW\"\u00016\u000b\u0005-4\u0011AC1o]>$\u0018\r^5p]&\u0011QN\u001b\u0002\r\t\u00164X\r\\8qKJ\f\u0005/\u001b")
/* loaded from: input_file:org/apache/spark/ml/odkl/CRRSamplerEstimator.class */
/*
 * Spark ML Estimator that fits a CRRSamplerModel.
 *
 * fit() computes a base count from the input, scales it by groupSampleRate, and stores
 * min(1.0, expectedNumSamples / scaledCount) as the itemSampleRate of the returned model.
 *
 * NOTE(review): this file is decompiler output of Scala bytecode (CRRSampler.scala), not
 * hand-written Java. It is not compilable as Java in several places:
 *   - private final fields are assigned from the trait "_setter_" methods below;
 *   - calls of the form Trait.class.method(this) are the decompiler's rendering of Scala
 *     trait implementation classes (presumably Trait$class).
 * Treat CRRSampler.scala as the source of truth; edit that file, not this one.
 */
public class CRRSamplerEstimator extends Estimator<CRRSamplerModel> implements DefaultParamsWritable, CRRSamplerParams {
    // Unique identifier of this pipeline stage (passed to, or generated by, the constructors).
    private final String uid;
    // Target number of samples; fit() derives the model's itemSampleRate from it.
    private final IntParam expectedNumSamples;
    // The param handles below are populated through the trait-initializer setter methods
    // (org$apache$...$_setter_$..._$eq). groupByColumns is set explicitly in the constructor;
    // the others are presumably set from the trait $init$ calls there (bodies not visible).
    private final DoubleParam groupSampleRate;
    private final DoubleParam itemSampleRate;
    private final DoubleParam rankingPower;
    private final IntParam shuffleToPartitions;
    private final Param<String> labelCol;
    private final StringArrayParam groupByColumns;
    private final Param<String> inputCol;

    /** Param handle for the group sampling rate; defaults to 1.0 (set in the constructor). */
    @Override // org.apache.spark.ml.odkl.CRRSamplerParams
    public DoubleParam groupSampleRate() {
        return this.groupSampleRate;
    }

    /** Param handle for the per-item sampling rate; fit() writes the computed rate into the model's copy of it. */
    @Override // org.apache.spark.ml.odkl.CRRSamplerParams
    public DoubleParam itemSampleRate() {
        return this.itemSampleRate;
    }

    /** Param handle for the ranking power; fit() takes the grouped counting path only when it is > 0. */
    @Override // org.apache.spark.ml.odkl.CRRSamplerParams
    public DoubleParam rankingPower() {
        return this.rankingPower;
    }

    /** Param handle for shuffleToPartitions; not read in this class (presumably consumed by CRRSamplerModel). */
    @Override // org.apache.spark.ml.odkl.CRRSamplerParams
    public IntParam shuffleToPartitions() {
        return this.shuffleToPartitions;
    }

    // Scala trait-field initializer for groupSampleRate; presumably called from
    // CRRSamplerParams.Cclass.$init$ during construction. Not for general use.
    @Override // org.apache.spark.ml.odkl.CRRSamplerParams
    public void org$apache$spark$ml$odkl$CRRSamplerParams$_setter_$groupSampleRate_$eq(DoubleParam doubleParam) {
        this.groupSampleRate = doubleParam;
    }

    // Scala trait-field initializer for itemSampleRate (see note on the setter above).
    @Override // org.apache.spark.ml.odkl.CRRSamplerParams
    public void org$apache$spark$ml$odkl$CRRSamplerParams$_setter_$itemSampleRate_$eq(DoubleParam doubleParam) {
        this.itemSampleRate = doubleParam;
    }

    // Scala trait-field initializer for rankingPower (see note on the setter above).
    @Override // org.apache.spark.ml.odkl.CRRSamplerParams
    public void org$apache$spark$ml$odkl$CRRSamplerParams$_setter_$rankingPower_$eq(DoubleParam doubleParam) {
        this.rankingPower = doubleParam;
    }

    // Scala trait-field initializer for shuffleToPartitions (see note on the setter above).
    @Override // org.apache.spark.ml.odkl.CRRSamplerParams
    public void org$apache$spark$ml$odkl$CRRSamplerParams$_setter_$shuffleToPartitions_$eq(IntParam intParam) {
        this.shuffleToPartitions = intParam;
    }

    /** Sets groupSampleRate; delegates to the CRRSamplerParams trait implementation. */
    @Override // org.apache.spark.ml.odkl.CRRSamplerParams
    public CRRSamplerParams setGroupSampleRate(double d) {
        return CRRSamplerParams.Cclass.setGroupSampleRate(this, d);
    }

    /** Sets itemSampleRate; delegates to the CRRSamplerParams trait implementation. */
    @Override // org.apache.spark.ml.odkl.CRRSamplerParams
    public CRRSamplerParams setItemSampleRate(double d) {
        return CRRSamplerParams.Cclass.setItemSampleRate(this, d);
    }

    /** Sets rankingPower; delegates to the CRRSamplerParams trait implementation. */
    @Override // org.apache.spark.ml.odkl.CRRSamplerParams
    public CRRSamplerParams setRankingPower(double d) {
        return CRRSamplerParams.Cclass.setRankingPower(this, d);
    }

    /**
     * Sets shuffleToPartitions; delegates to the CRRSamplerParams trait implementation.
     * The method name ("Shuffler") differs from the param name but is part of the public API.
     */
    @Override // org.apache.spark.ml.odkl.CRRSamplerParams
    public CRRSamplerParams setShufflerToPartitions(int i) {
        return CRRSamplerParams.Cclass.setShufflerToPartitions(this, i);
    }

    /** Param handle for the label column name (from Spark's HasLabelCol). */
    public final Param<String> labelCol() {
        return this.labelCol;
    }

    // Scala trait-field initializer for labelCol; presumably called from HasLabelCol's $init$.
    public final void org$apache$spark$ml$param$shared$HasLabelCol$_setter_$labelCol_$eq(Param param) {
        this.labelCol = param;
    }

    /** Returns the label column name via the HasLabelCol trait implementation. */
    public final String getLabelCol() {
        return HasLabelCol.class.getLabelCol(this);
    }

    /** Param handle for the grouping column names; fit() groups by these when rankingPower > 0. */
    @Override // org.apache.spark.ml.odkl.HasGroupByColumns
    public final StringArrayParam groupByColumns() {
        return this.groupByColumns;
    }

    // Trait-field initializer for groupByColumns; called explicitly from the constructor.
    @Override // org.apache.spark.ml.odkl.HasGroupByColumns
    public final void org$apache$spark$ml$odkl$HasGroupByColumns$_setter_$groupByColumns_$eq(StringArrayParam stringArrayParam) {
        this.groupByColumns = stringArrayParam;
    }

    /** Sets the grouping columns; delegates to the HasGroupByColumns trait implementation. */
    @Override // org.apache.spark.ml.odkl.HasGroupByColumns
    public HasGroupByColumns setGroupByColumns(Seq<String> seq) {
        return HasGroupByColumns.Cclass.setGroupByColumns(this, seq);
    }

    /** Param handle for the input column name (from Spark's HasInputCol); not read in this class. */
    public final Param<String> inputCol() {
        return this.inputCol;
    }

    // Scala trait-field initializer for inputCol; presumably called from HasInputCol's $init$.
    public final void org$apache$spark$ml$param$shared$HasInputCol$_setter_$inputCol_$eq(Param param) {
        this.inputCol = param;
    }

    /** Returns the input column name via the HasInputCol trait implementation. */
    public final String getInputCol() {
        return HasInputCol.class.getInputCol(this);
    }

    /** Returns a writer for this stage, delegating to DefaultParamsWritable. */
    public MLWriter write() {
        return DefaultParamsWritable.class.write(this);
    }

    /** Saves this stage to {@code str}, delegating to MLWritable; may throw IOException. */
    public void save(String str) throws IOException {
        MLWritable.class.save(this, str);
    }

    /** Unique identifier of this pipeline stage. */
    public String uid() {
        return this.uid;
    }

    /** Param handle for the target number of samples (described as required; no default is set here). */
    public IntParam expectedNumSamples() {
        return this.expectedNumSamples;
    }

    /**
     * Sets the expectedNumSamples param.
     *
     * @param i target number of samples (validated by CRRSamplerEstimator$$anonfun$5; rule not visible here)
     * @return this estimator, for chaining
     */
    public CRRSamplerEstimator setExpectedNumberOfSamples(int i) {
        return set(expectedNumSamples(), BoxesRunTime.boxToInteger(i));
    }

    /**
     * Fits the sampler. In the bytecode this is {@code fit}; the decompiler renamed it.
     *
     * <p>Steps:
     * <ol>
     *   <li>Compute a base count {@code d}. If rankingPower &lt;= 0, or groupByColumns is unset or
     *       empty, {@code d} is the plain row count. Otherwise:
     *       <ul>
     *         <li>rows are keyed by the single group column, or by a struct of all group columns
     *             stored in a temporary column named {@code <uid>_tmpKey};</li>
     *         <li>rows are grouped with RDDOperations.groupWithinPartitionsBy;</li>
     *         <li>each group is mapped to a long by CRRSamplerEstimator$$anonfun$15, which is given
     *             the label column index;</li>
     *         <li>the longs are summed.</li>
     *       </ul></li>
     *   <li>Scale {@code d} by groupSampleRate.</li>
     *   <li>The model's itemSampleRate = min(1.0, expectedNumSamples / scaled count).</li>
     * </ol>
     *
     * <p>NOTE(review): if the scaled count is 0 (empty input or groupSampleRate == 0), the division
     * gives Infinity and the rate clamps to 1.0. If expectedNumSamples is also 0, it gives NaN,
     * which Math.min propagates. Whether the expectedNumSamples validator excludes 0 is not
     * visible here.
     *
     * @param dataFrame input data; in the grouped path it must contain the group columns and the
     *                  label column (looked up with StructType.fieldIndex)
     * @return a new CRRSamplerModel with this estimator's param values copied in, itemSampleRate
     *         overridden by the computed rate, and this estimator set as its parent
     */
    /* renamed from: fit, reason: merged with bridge method [inline-methods] */
    public CRRSamplerModel m126fit(DataFrame dataFrame) {
        double count;
        Tuple2 tuple2;
        // Grouped counting applies only when ranking is enabled and grouping columns are configured.
        if (BoxesRunTime.unboxToDouble($(rankingPower())) <= 0 || !isDefined(groupByColumns()) || ((String[]) $(groupByColumns())).length <= 0) {
            count = dataFrame.count();
        } else {
            // Build (frame, keyColumnIndex): use the single group column directly, or pack several
            // into a struct column "<uid>_tmpKey". withColumn returns a new frame, so the
            // caller's DataFrame is not modified.
            if (((String[]) $(groupByColumns())).length == 1) {
                tuple2 = new Tuple2(dataFrame, BoxesRunTime.boxToInteger(dataFrame.schema().fieldIndex((String) Predef$.MODULE$.refArrayOps((Object[]) $(groupByColumns())).head())));
            } else {
                Column struct = functions$.MODULE$.struct(Predef$.MODULE$.wrapRefArray((Object[]) Predef$.MODULE$.refArrayOps((Object[]) $(groupByColumns())).map(new CRRSamplerEstimator$$anonfun$13(this, dataFrame), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Column.class)))));
                String s = new StringContext(Predef$.MODULE$.wrapRefArray(new String[]{"", "_tmpKey"})).s(Predef$.MODULE$.genericWrapArray(new Object[]{uid()}));
                DataFrame withColumn = dataFrame.withColumn(s, struct);
                tuple2 = new Tuple2(withColumn, BoxesRunTime.boxToInteger(withColumn.schema().fieldIndex(s)));
            }
            // Scala pattern-match boilerplate: tuple2 is always non-null here, so the MatchError is unreachable.
            Tuple2 tuple22 = tuple2;
            if (tuple22 == null) {
                throw new MatchError(tuple22);
            }
            Tuple2 tuple23 = new Tuple2((DataFrame) tuple22._1(), BoxesRunTime.boxToInteger(tuple22._2$mcI$sp()));
            DataFrame dataFrame2 = (DataFrame) tuple23._1();
            // Group rows by key (anonfun$14 gets the key index), map each group to a long
            // (anonfun$15 gets the label index), then sum as doubles.
            // NOTE(review): going by its name, groupWithinPartitionsBy groups only within each
            // partition, and anonfun$15 presumably yields the per-group sample count.
            // Confirm both in RDDOperations / CRRSampler.scala.
            count = RDD$.MODULE$.numericRDDToDoubleRDDFunctions(RDDOperations$.MODULE$.ImplicitRDDDecorator(dataFrame2.rdd(), ClassTag$.MODULE$.apply(Row.class)).groupWithinPartitionsBy(new CRRSamplerEstimator$$anonfun$14(this, tuple23._2$mcI$sp()), ClassTag$.MODULE$.Any()).map(new CRRSamplerEstimator$$anonfun$15(this, dataFrame2.schema().fieldIndex((String) $(labelCol()))), ClassTag$.MODULE$.Long()), Numeric$LongIsIntegral$.MODULE$).sum();
        }
        double d = count;
        // Expected count after group-level sampling.
        double unboxToDouble = d * BoxesRunTime.unboxToDouble($(groupSampleRate()));
        // Item rate needed to reach expectedNumSamples, capped at 1.0 (int / double -> floating-point division).
        double min = Math.min(1.0d, BoxesRunTime.unboxToInt($(expectedNumSamples())) / unboxToDouble);
        logInfo(new CRRSamplerEstimator$$anonfun$fit$1(this, d, unboxToDouble, min));
        // Copy this estimator's params into a fresh model, overriding itemSampleRate with the computed rate.
        return (CRRSamplerModel) copyValues(new CRRSamplerModel(), ParamMap$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new ParamPair[]{itemSampleRate().$minus$greater(BoxesRunTime.boxToDouble(min))}))).setParent(this);
    }

    /**
     * Copies this estimator with {@code paramMap} applied, using Spark's defaultCopy.
     * In the bytecode this is {@code copy}; the decompiler renamed it.
     */
    /* renamed from: copy, reason: merged with bridge method [inline-methods] and merged with bridge method [inline-methods] and merged with bridge method [inline-methods] */
    public CRRSamplerEstimator m125copy(ParamMap paramMap) {
        return defaultCopy(paramMap);
    }

    /** Returns the input schema unchanged. */
    @DeveloperApi
    public StructType transformSchema(StructType structType) {
        return structType;
    }

    /**
     * Creates the estimator with the given uid.
     *
     * <p>The statement order is compiler-generated trait initialization and must be kept.
     * For example, the final setDefault call reads groupSampleRate(), which is presumably
     * populated by the earlier CRRSamplerParams init. Only groupSampleRate gets a default here
     * (1.0); expectedNumSamples has none.
     */
    public CRRSamplerEstimator(String str) {
        this.uid = str;
        MLWritable.class.$init$(this);
        DefaultParamsWritable.class.$init$(this);
        HasInputCol.class.$init$(this);
        org$apache$spark$ml$odkl$HasGroupByColumns$_setter_$groupByColumns_$eq(new StringArrayParam((Params) this, "groupByColumns", "Grouping criteria for the evaluation."));
        HasLabelCol.class.$init$(this);
        CRRSamplerParams.Cclass.$init$(this);
        this.expectedNumSamples = new IntParam(this, "expectedNumSamples", "The expected number of samples in the result. Required.", new CRRSamplerEstimator$$anonfun$5(this));
        setDefault(Predef$.MODULE$.wrapRefArray(new ParamPair[]{groupSampleRate().$minus$greater(BoxesRunTime.boxToDouble(1.0d))}));
    }

    /** Creates the estimator with a random uid prefixed "crrSamplerEstimator". */
    public CRRSamplerEstimator() {
        this(Identifiable$.MODULE$.randomUID("crrSamplerEstimator"));
    }
}
