package org.apache.spark.ml.odkl;

import com.github.fommil.netlib.BLAS;
import java.io.IOException;
import java.util.concurrent.ThreadLocalRandom;
import odkl.analysis.spark.util.RDDOperations$;
import odkl.analysis.spark.util.collection.CompactBuffer;
import org.apache.spark.annotation.DeveloperApi;
import org.apache.spark.ml.Model;
import org.apache.spark.ml.odkl.CRRSamplerParams;
import org.apache.spark.ml.odkl.HasGroupByColumns;
import org.apache.spark.ml.odkl.HasNetlibBlas;
import org.apache.spark.ml.param.DoubleParam;
import org.apache.spark.ml.param.IntParam;
import org.apache.spark.ml.param.Param;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.param.ParamPair;
import org.apache.spark.ml.param.Params;
import org.apache.spark.ml.param.StringArrayParam;
import org.apache.spark.ml.param.shared.HasInputCol;
import org.apache.spark.ml.param.shared.HasLabelCol;
import org.apache.spark.ml.util.DefaultParamsWritable;
import org.apache.spark.ml.util.Identifiable$;
import org.apache.spark.ml.util.MLWritable;
import org.apache.spark.ml.util.MLWriter;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.Row$;
import org.apache.spark.sql.functions$;
import org.apache.spark.sql.types.StructType;
import scala.Array$;
import scala.MatchError;
import scala.None$;
import scala.Option;
import scala.Predef$;
import scala.Some;
import scala.StringContext;
import scala.Tuple3;
import scala.collection.Iterator;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.package$;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxesRunTime;

/* compiled from: CRRSampler.scala */
@ScalaSignature(bytes = "\u0006\u0001\u0005Ua\u0001B\u0001\u0003\u00015\u0011qb\u0011*S'\u0006l\u0007\u000f\\3s\u001b>$W\r\u001c\u0006\u0003\u0007\u0011\tAa\u001c3lY*\u0011QAB\u0001\u0003[2T!a\u0002\u0005\u0002\u000bM\u0004\u0018M]6\u000b\u0005%Q\u0011AB1qC\u000eDWMC\u0001\f\u0003\ry'oZ\u0002\u0001'\u0015\u0001a\u0002\u0006\u000e\u001e!\ry\u0001CE\u0007\u0002\t%\u0011\u0011\u0003\u0002\u0002\u0006\u001b>$W\r\u001c\t\u0003'\u0001i\u0011A\u0001\t\u0003+ai\u0011A\u0006\u0006\u0003/\u0011\tA!\u001e;jY&\u0011\u0011D\u0006\u0002\u0016\t\u00164\u0017-\u001e7u!\u0006\u0014\u0018-\\:Xe&$\u0018M\u00197f!\t\u00192$\u0003\u0002\u001d\u0005\t\u00012I\u0015*TC6\u0004H.\u001a:QCJ\fWn\u001d\t\u0003'yI!a\b\u0002\u0003\u001b!\u000b7OT3uY&\u0014'\t\\1t\u0011!\t\u0003A!b\u0001\n\u0003\u0012\u0013aA;jIV\t1\u0005\u0005\u0002%U9\u0011Q\u0005K\u0007\u0002M)\tq%A\u0003tG\u0006d\u0017-\u0003\u0002*M\u00051\u0001K]3eK\u001aL!a\u000b\u0017\u0003\rM#(/\u001b8h\u0015\tIc\u0005\u0003\u0005/\u0001\t\u0005\t\u0015!\u0003$\u0003\u0011)\u0018\u000e\u001a\u0011\t\u000bA\u0002A\u0011A\u0019\u0002\rqJg.\u001b;?)\t\u0011\"\u0007C\u0003\"_\u0001\u00071\u0005C\u00031\u0001\u0011\u0005A\u0007F\u0001\u0013\u0011\u00151\u0004\u0001\"\u00118\u0003%!(/\u00198tM>\u0014X\u000e\u0006\u00029}A\u0011\u0011\bP\u0007\u0002u)\u00111HB\u0001\u0004gFd\u0017BA\u001f;\u0005%!\u0015\r^1Ge\u0006lW\rC\u0003@k\u0001\u0007\u0001(A\u0004eCR\f7/\u001a;\t\u000b\u0005\u0003A\u0011\u0001\"\u0002\u0015M\fW\u000e\u001d7f%><8\u000f\u0006\u0003D%RK\u0006c\u0001#M\u001f:\u0011QI\u0013\b\u0003\r&k\u0011a\u0012\u0006\u0003\u00112\ta\u0001\u0010:p_Rt\u0014\"A\u0014\n\u0005-3\u0013a\u00029bG.\fw-Z\u0005\u0003\u001b:\u0013\u0001\"\u0013;fe\u0006$xN\u001d\u0006\u0003\u0017\u001a\u0002\"!\u000f)\n\u0005ES$a\u0001*po\")1\u000b\u0011a\u0001\u0007\u0006!!o\\<t\u0011\u0015)\u0006\t1\u0001W\u0003)a\u0017MY3m\u0013:$W\r\u001f\t\u0003K]K!\u0001\u0017\u0014\u0003\u0007%sG\u000fC\u0003[\u0001\u0002\u0007a+\u0001\u0007gK\u0006$XO]3J]\u0012,\u0007\u0010C\
u0003]\u0001\u0011\u0005Q,A\u0007d_VtG/\u001a:TC6\u0004H.\u001a\u000b\u0005\u0007z\u0003W\rC\u0003`7\u0002\u00071)\u0001\u0004t_V\u00148-\u001a\u0005\u0006Cn\u0003\rAY\u0001\fG>,h\u000e^3s!\u0006\u0014H\u000fE\u0002EG>K!\u0001\u001a(\u0003\u0007M+\u0017\u000fC\u0003[7\u0002\u0007a\u000bC\u0003h\u0001\u0011\u0005\u0001.\u0001\u0006qC&\u00148+Y7qY\u0016$BaT5lY\")!N\u001aa\u0001\u001f\u000611/Y7qY\u0016DQ!\u00194A\u0002\tDQA\u00174A\u0002YCQA\u001c\u0001\u0005B=\fAaY8qsR\u0011!\u0003\u001d\u0005\u0006c6\u0004\rA]\u0001\u0006Kb$(/\u0019\t\u0003gZl\u0011\u0001\u001e\u0006\u0003k\u0012\tQ\u0001]1sC6L!a\u001e;\u0003\u0011A\u000b'/Y7NCBDQ!\u001f\u0001\u0005Bi\fq\u0002\u001e:b]N4wN]7TG\",W.\u0019\u000b\u0004w\u0006\r\u0001C\u0001?��\u001b\u0005i(B\u0001@;\u0003\u0015!\u0018\u0010]3t\u0013\r\t\t! \u0002\u000b'R\u0014Xo\u0019;UsB,\u0007BBA\u0003q\u0002\u000710\u0001\u0004tG\",W.\u0019\u0015\u0004q\u0006%\u0001\u0003BA\u0006\u0003#i!!!\u0004\u000b\u0007\u0005=a!\u0001\u0006b]:|G/\u0019;j_:LA!a\u0005\u0002\u000e\taA)\u001a<fY>\u0004XM]!qS\u0002")
/* loaded from: input_file:org/apache/spark/ml/odkl/CRRSamplerModel.class */
/**
 * Spark ML {@link Model} that subsamples rows group by group. When {@code rankingPower} is
 * positive, it also rewrites sampled rows by pairing each one with a randomly chosen row from
 * the other half of the same group (see {@link #sampleRows} and {@link #pairSample}).
 *
 * <p>Processing outline (see {@link #transform}):
 * <ol>
 *   <li>Build a per-row group key:
 *       <ul>
 *         <li>a constant temporary column when no group columns are set;</li>
 *         <li>the single group column as-is;</li>
 *         <li>otherwise, a temporary struct of all group columns.</li>
 *       </ul></li>
 *   <li>Group rows by that key within each partition and run {@link #sampleRows} on every
 *       group.</li>
 *   <li>Post-process the result. NOTE(review): this presumably drops the temporary key column
 *       and repartitions when {@code shuffleToPartitions} is set; the anonymous-function bodies
 *       are not visible here.</li>
 * </ol>
 *
 * <p>NOTE(review): this class is decompiler output, not hand-written Java, and it is not valid
 * Java as-is:
 * <ul>
 *   <li>the {@code final} param fields are assigned from trait setter methods;</li>
 *   <li>calls such as {@code HasLabelCol.class.getLabelCol(this)} presumably target Scala's
 *       {@code Trait$class} implementation classes.</li>
 * </ul>
 * Make behavioral changes in CRRSampler.scala instead. "CRR" presumably stands for
 * Combined Regression and Ranking (Sculley, 2010) — confirm.
 */
public class CRRSamplerModel extends Model<CRRSamplerModel> implements DefaultParamsWritable, CRRSamplerParams, HasNetlibBlas {
    /** Identifier of this instance, given to or generated by the constructors. */
    private final String uid;
    // Param handles. The constructor fills them in through the trait initializers. Those call
    // the *_setter_$*_$eq methods below, except groupByColumns, which is assigned explicitly.
    private final DoubleParam groupSampleRate;
    private final DoubleParam itemSampleRate;
    private final DoubleParam rankingPower;
    private final IntParam shuffleToPartitions;
    private final Param<String> labelCol;
    private final StringArrayParam groupByColumns;
    private final Param<String> inputCol;

    // ---- HasNetlibBlas: each method delegates to the trait implementation class. ----

    @Override // org.apache.spark.ml.odkl.HasNetlibBlas
    public BLAS f2jBLAS() {
        return HasNetlibBlas.Cclass.f2jBLAS(this);
    }

    @Override // org.apache.spark.ml.odkl.HasNetlibBlas
    public BLAS blas() {
        return HasNetlibBlas.Cclass.blas(this);
    }

    @Override // org.apache.spark.ml.odkl.HasNetlibBlas
    public void dscal(double d, double[] dArr) {
        HasNetlibBlas.Cclass.dscal(this, d, dArr);
    }

    @Override // org.apache.spark.ml.odkl.HasNetlibBlas
    public void axpy(double d, double[] dArr, double[] dArr2) {
        HasNetlibBlas.Cclass.axpy(this, d, dArr, dArr2);
    }

    @Override // org.apache.spark.ml.odkl.HasNetlibBlas
    public void axpy(double d, Vector vector, double[] dArr) {
        HasNetlibBlas.Cclass.axpy(this, d, vector, dArr);
    }

    @Override // org.apache.spark.ml.odkl.HasNetlibBlas
    public void copy(double[] dArr, double[] dArr2) {
        HasNetlibBlas.Cclass.copy(this, dArr, dArr2);
    }

    // ---- CRRSamplerParams: param accessors, trait-init setters and fluent setters. ----

    @Override // org.apache.spark.ml.odkl.CRRSamplerParams
    public DoubleParam groupSampleRate() {
        return this.groupSampleRate;
    }

    @Override // org.apache.spark.ml.odkl.CRRSamplerParams
    public DoubleParam itemSampleRate() {
        return this.itemSampleRate;
    }

    @Override // org.apache.spark.ml.odkl.CRRSamplerParams
    public DoubleParam rankingPower() {
        return this.rankingPower;
    }

    @Override // org.apache.spark.ml.odkl.CRRSamplerParams
    public IntParam shuffleToPartitions() {
        return this.shuffleToPartitions;
    }

    // Trait-initialization setters called by CRRSamplerParams.Cclass.$init$ in the constructor.
    @Override // org.apache.spark.ml.odkl.CRRSamplerParams
    public void org$apache$spark$ml$odkl$CRRSamplerParams$_setter_$groupSampleRate_$eq(DoubleParam doubleParam) {
        this.groupSampleRate = doubleParam;
    }

    @Override // org.apache.spark.ml.odkl.CRRSamplerParams
    public void org$apache$spark$ml$odkl$CRRSamplerParams$_setter_$itemSampleRate_$eq(DoubleParam doubleParam) {
        this.itemSampleRate = doubleParam;
    }

    @Override // org.apache.spark.ml.odkl.CRRSamplerParams
    public void org$apache$spark$ml$odkl$CRRSamplerParams$_setter_$rankingPower_$eq(DoubleParam doubleParam) {
        this.rankingPower = doubleParam;
    }

    @Override // org.apache.spark.ml.odkl.CRRSamplerParams
    public void org$apache$spark$ml$odkl$CRRSamplerParams$_setter_$shuffleToPartitions_$eq(IntParam intParam) {
        this.shuffleToPartitions = intParam;
    }

    @Override // org.apache.spark.ml.odkl.CRRSamplerParams
    public CRRSamplerParams setGroupSampleRate(double d) {
        return CRRSamplerParams.Cclass.setGroupSampleRate(this, d);
    }

    @Override // org.apache.spark.ml.odkl.CRRSamplerParams
    public CRRSamplerParams setItemSampleRate(double d) {
        return CRRSamplerParams.Cclass.setItemSampleRate(this, d);
    }

    @Override // org.apache.spark.ml.odkl.CRRSamplerParams
    public CRRSamplerParams setRankingPower(double d) {
        return CRRSamplerParams.Cclass.setRankingPower(this, d);
    }

    // NOTE(review): "Shuffler" (not "Shuffle") is part of the trait's public name; keep as is.
    @Override // org.apache.spark.ml.odkl.CRRSamplerParams
    public CRRSamplerParams setShufflerToPartitions(int i) {
        return CRRSamplerParams.Cclass.setShufflerToPartitions(this, i);
    }

    // ---- HasLabelCol ----

    public final Param<String> labelCol() {
        return this.labelCol;
    }

    public final void org$apache$spark$ml$param$shared$HasLabelCol$_setter_$labelCol_$eq(Param param) {
        this.labelCol = param;
    }

    public final String getLabelCol() {
        return HasLabelCol.class.getLabelCol(this);
    }

    // ---- HasGroupByColumns ----

    @Override // org.apache.spark.ml.odkl.HasGroupByColumns
    public final StringArrayParam groupByColumns() {
        return this.groupByColumns;
    }

    @Override // org.apache.spark.ml.odkl.HasGroupByColumns
    public final void org$apache$spark$ml$odkl$HasGroupByColumns$_setter_$groupByColumns_$eq(StringArrayParam stringArrayParam) {
        this.groupByColumns = stringArrayParam;
    }

    @Override // org.apache.spark.ml.odkl.HasGroupByColumns
    public HasGroupByColumns setGroupByColumns(Seq<String> seq) {
        return HasGroupByColumns.Cclass.setGroupByColumns(this, seq);
    }

    // ---- HasInputCol (the column pairSample reads as an mllib Vector) ----

    public final Param<String> inputCol() {
        return this.inputCol;
    }

    public final void org$apache$spark$ml$param$shared$HasInputCol$_setter_$inputCol_$eq(Param param) {
        this.inputCol = param;
    }

    public final String getInputCol() {
        return HasInputCol.class.getInputCol(this);
    }

    // ---- Persistence (DefaultParamsWritable / MLWritable delegates) ----

    public MLWriter write() {
        return DefaultParamsWritable.class.write(this);
    }

    public void save(String str) throws IOException {
        MLWritable.class.save(this, str);
    }

    /** @return the identifier this instance was constructed with */
    public String uid() {
        return this.uid;
    }

    /**
     * Samples the input group by group.
     *
     * @param dataFrame input data. It must contain the {@code inputCol} and {@code labelCol}
     *     columns; {@code StructType.fieldIndex} throws if either is missing.
     * @return the sampled rows. Note that {@code createDataFrame} reuses the schema of the
     *     (possibly key-augmented) frame. The result is then post-processed as described on the
     *     last two statements.
     */
    public DataFrame transform(DataFrame dataFrame) {
        // Resolve a triple: (frame to process, index of the group-key column, optional name of
        // a temporary key column that was added).
        Tuple3 tuple3;
        if (!isDefined(groupByColumns()) || ((String[]) $(groupByColumns())).length <= 0) {
            // No grouping configured: add a constant column named "<uid>_tmpKey" whose value is
            // its own name, so the whole input forms a single group.
            String s = new StringContext(Predef$.MODULE$.wrapRefArray(new String[]{"", "_tmpKey"})).s(Predef$.MODULE$.genericWrapArray(new Object[]{uid()}));
            DataFrame withColumn = dataFrame.withColumn(s, functions$.MODULE$.lit(s));
            tuple3 = new Tuple3(withColumn, BoxesRunTime.boxToInteger(withColumn.schema().fieldIndex(s)), new Some(s));
        } else if (((String[]) $(groupByColumns())).length == 1) {
            // Exactly one group column: use it directly as the key; no temporary column.
            tuple3 = new Tuple3(dataFrame, BoxesRunTime.boxToInteger(dataFrame.schema().fieldIndex((String) Predef$.MODULE$.refArrayOps((Object[]) $(groupByColumns())).head())), None$.MODULE$);
        } else {
            // Several group columns: pack them into a struct stored in a temporary
            // "<uid>_tmpKey" column. NOTE(review): anonfun$6 presumably maps each column name to
            // dataFrame.apply(name) — confirm.
            Column struct = functions$.MODULE$.struct(Predef$.MODULE$.wrapRefArray((Object[]) Predef$.MODULE$.refArrayOps((Object[]) $(groupByColumns())).map(new CRRSamplerModel$$anonfun$6(this, dataFrame), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Column.class)))));
            String s2 = new StringContext(Predef$.MODULE$.wrapRefArray(new String[]{"", "_tmpKey"})).s(Predef$.MODULE$.genericWrapArray(new Object[]{uid()}));
            DataFrame withColumn2 = dataFrame.withColumn(s2, struct);
            tuple3 = new Tuple3(withColumn2, BoxesRunTime.boxToInteger(withColumn2.schema().fieldIndex(s2)), new Some(s2));
        }
        // Leftover from a Scala tuple pattern match; tuple3 is always non-null at this point.
        Tuple3 tuple32 = tuple3;
        if (tuple32 == null) {
            throw new MatchError(tuple32);
        }
        Tuple3 tuple33 = new Tuple3((DataFrame) tuple32._1(), BoxesRunTime.boxToInteger(BoxesRunTime.unboxToInt(tuple32._2())), (Option) tuple32._3());
        DataFrame dataFrame2 = (DataFrame) tuple33._1();
        int unboxToInt = BoxesRunTime.unboxToInt(tuple33._2());
        Option option = (Option) tuple33._3();
        // Group rows by key within each RDD partition, then sample each group.
        // NOTE(review): anonfun$7 presumably extracts the key at index unboxToInt. anonfun$8 is
        // built with (inputCol index, labelCol index) and presumably calls
        // sampleRows(group, labelIndex, inputIndex) — confirm in the Scala source.
        // Judging by its name, groupWithinPartitionsBy does not shuffle. If so, rows of one group
        // that sit in different partitions are sampled as separate groups.
        DataFrame createDataFrame = dataFrame.sqlContext().createDataFrame(RDDOperations$.MODULE$.ImplicitRDDDecorator(dataFrame2.rdd(), ClassTag$.MODULE$.apply(Row.class)).groupWithinPartitionsBy(new CRRSamplerModel$$anonfun$7(this, unboxToInt), ClassTag$.MODULE$.Any()).flatMap(new CRRSamplerModel$$anonfun$8(this, dataFrame2.schema().fieldIndex((String) $(inputCol())), dataFrame2.schema().fieldIndex((String) $(labelCol()))), ClassTag$.MODULE$.apply(Row.class)), dataFrame2.schema());
        // If a temporary key column was added, anonfun$9 runs (presumably it drops that column).
        // Otherwise anonfun$10 runs (presumably it returns the frame unchanged).
        DataFrame dataFrame3 = (DataFrame) option.map(new CRRSamplerModel$$anonfun$9(this, createDataFrame)).getOrElse(new CRRSamplerModel$$anonfun$10(this, createDataFrame));
        // If shuffleToPartitions is set, anonfun$transform$1 runs (presumably a repartition).
        // Otherwise anonfun$transform$2 runs (presumably it returns the frame as is).
        return (DataFrame) get(shuffleToPartitions()).map(new CRRSamplerModel$$anonfun$transform$1(this, dataFrame3)).getOrElse(new CRRSamplerModel$$anonfun$transform$2(this, dataFrame3));
    }

    /**
     * Samples the rows of a single group.
     *
     * <ul>
     *   <li>If {@code groupSampleRate < 1}, the whole group is dropped with probability
     *       {@code 1 - groupSampleRate}: it is kept only when a uniform draw is
     *       {@code <= groupSampleRate}.</li>
     *   <li>If {@code rankingPower <= 0}, rows only go through item-level sampling in
     *       {@link #counterSample}, with no pairing.</li>
     *   <li>Otherwise anonfun$sampleRows$1 splits the rows into two buffers, presumably by the
     *       label at {@code i} — confirm. A group with either buffer empty yields no rows. Rows
     *       of the first buffer are counter-sampled against the second. The result is then
     *       concatenated, via a lazily evaluated thunk, with anonfun$sampleRows$2, which
     *       presumably does the reverse direction.</li>
     * </ul>
     *
     * @param iterator rows of one group
     * @param i index of the column used to split rows into the two buffers (presumably the label)
     * @param i2 index of the vector column forwarded to {@link #counterSample}/{@link #pairSample}
     * @return the sampled rows, paired with counterparts when rankingPower is positive
     */
    public Iterator<Row> sampleRows(Iterator<Row> iterator, int i, int i2) {
        // Group-level sampling: drop the entire group with probability 1 - groupSampleRate.
        if (BoxesRunTime.unboxToDouble($(groupSampleRate())) < 1 && ThreadLocalRandom.current().nextDouble() > BoxesRunTime.unboxToDouble($(groupSampleRate()))) {
            return package$.MODULE$.Iterator().empty();
        }
        // No ranking pairs requested: item-level sampling only (the counterpart pool is unused).
        if (BoxesRunTime.unboxToDouble($(rankingPower())) <= 0) {
            return counterSample(iterator, Seq$.MODULE$.empty(), i2);
        }
        CompactBuffer compactBuffer = new CompactBuffer(ClassTag$.MODULE$.apply(Row.class));
        CompactBuffer compactBuffer2 = new CompactBuffer(ClassTag$.MODULE$.apply(Row.class));
        iterator.foreach(new CRRSamplerModel$$anonfun$sampleRows$1(this, i, compactBuffer, compactBuffer2));
        // Pairing needs both buffers to be non-empty; otherwise the group contributes nothing.
        return (compactBuffer.isEmpty() || compactBuffer2.isEmpty()) ? package$.MODULE$.Iterator().empty() : counterSample(compactBuffer.iterator(), compactBuffer2, i2).$plus$plus(new CRRSamplerModel$$anonfun$sampleRows$2(this, i2, compactBuffer, compactBuffer2));
    }

    /**
     * Applies item-level sampling. When {@code rankingPower} is positive, it also maps each
     * surviving row using the counterpart pool {@code seq}.
     *
     * <ul>
     *   <li>{@code itemSampleRate < 1}: rows are filtered by anonfun$11, presumably a random
     *       keep with probability itemSampleRate.</li>
     *   <li>{@code rankingPower <= 0}: the (filtered) rows are returned unchanged.</li>
     *   <li>{@code 0 < rankingPower < 1}: rows are mapped by anonfun$counterSample$1, which
     *       presumably pairs only some rows.</li>
     *   <li>{@code rankingPower >= 1}: rows are mapped by anonfun$counterSample$2, which
     *       presumably always pairs via {@link #pairSample}.</li>
     * </ul>
     *
     * @param iterator rows to sample
     * @param seq counterpart rows; not used when rankingPower is {@code <= 0}
     * @param i index of the vector column used for pairing
     * @return the filtered, possibly rewritten, rows
     */
    public Iterator<Row> counterSample(Iterator<Row> iterator, Seq<Row> seq, int i) {
        Iterator<Row> filter = BoxesRunTime.unboxToDouble($(itemSampleRate())) < ((double) 1) ? iterator.filter(new CRRSamplerModel$$anonfun$11(this)) : iterator;
        return BoxesRunTime.unboxToDouble($(rankingPower())) <= ((double) 0) ? filter : BoxesRunTime.unboxToDouble($(rankingPower())) < ((double) 1) ? filter.map(new CRRSamplerModel$$anonfun$counterSample$1(this, seq, i)) : filter.map(new CRRSamplerModel$$anonfun$counterSample$2(this, seq, i));
    }

    /**
     * Builds a new row from {@code row} and one counterpart drawn uniformly at random from
     * {@code seq}.
     *
     * <p>The counterpart's column {@code i} is read as an mllib {@link Vector}. The output row
     * has the same length as {@code row}, and anonfun$12 produces each cell.
     * NOTE(review): anonfun$12 presumably copies every column except {@code i}, which receives a
     * combination of the two vectors — confirm in the Scala source.
     *
     * @param row the row to rewrite
     * @param seq pool of counterpart rows; must not be empty
     * @param i index of the vector column
     * @return the newly built row
     * @throws IllegalArgumentException if {@code seq} is empty ({@code ThreadLocalRandom.nextInt(0)})
     * @throws ClassCastException if the counterpart's column {@code i} is not an mllib Vector
     */
    public Row pairSample(Row row, Seq<Row> seq, int i) {
        return Row$.MODULE$.fromSeq(Predef$.MODULE$.genericWrapArray((Object[]) Array$.MODULE$.tabulate(row.length(), new CRRSamplerModel$$anonfun$12(this, row, i, (Vector) ((Row) seq.apply(ThreadLocalRandom.current().nextInt(seq.size()))).getAs(i)), ClassTag$.MODULE$.Any())));
    }

    /* renamed from: copy, reason: merged with bridge method [inline-methods] and merged with bridge method [inline-methods] and merged with bridge method [inline-methods] and merged with bridge method [inline-methods] */
    /**
     * Scala {@code copy(ParamMap)}, renamed by the decompiler. It delegates to Spark's
     * {@code defaultCopy}, which applies {@code paramMap} on top of this model's param values.
     */
    public CRRSamplerModel m148copy(ParamMap paramMap) {
        return defaultCopy(paramMap);
    }

    /**
     * Returns {@code structType} unchanged. No columns are validated, added or removed at the
     * schema level; any temporary key column is presumably dropped inside {@link #transform}.
     */
    @DeveloperApi
    public StructType transformSchema(StructType structType) {
        return structType;
    }

    /**
     * Creates a model with the given uid.
     *
     * <p>Construction steps:
     * <ul>
     *   <li>Initialize the mixed-in traits; their initializers populate the param fields through
     *       the setter methods above.</li>
     *   <li>Create the {@code groupByColumns} param explicitly.</li>
     *   <li>Set the defaults groupSampleRate = 1.0 and itemSampleRate = 1.0. Both paths only
     *       sample when the rate is below 1, so by default no group- or item-level sampling
     *       happens.</li>
     * </ul>
     * NOTE(review): rankingPower gets no default here; presumably CRRSamplerParams'
     * initializer sets one — confirm. The groupByColumns description text mentions
     * "evaluation", which looks copied from an evaluator.
     *
     * @param str unique identifier of this instance
     */
    public CRRSamplerModel(String str) {
        this.uid = str;
        MLWritable.class.$init$(this);
        DefaultParamsWritable.class.$init$(this);
        HasInputCol.class.$init$(this);
        org$apache$spark$ml$odkl$HasGroupByColumns$_setter_$groupByColumns_$eq(new StringArrayParam((Params) this, "groupByColumns", "Grouping criteria for the evaluation."));
        HasLabelCol.class.$init$(this);
        CRRSamplerParams.Cclass.$init$(this);
        HasNetlibBlas.Cclass.$init$(this);
        setDefault(Predef$.MODULE$.wrapRefArray(new ParamPair[]{groupSampleRate().$minus$greater(BoxesRunTime.boxToDouble(1.0d)), itemSampleRate().$minus$greater(BoxesRunTime.boxToDouble(1.0d))}));
    }

    /** Creates a model with a random uid prefixed by {@code "crrSampler"}. */
    public CRRSamplerModel() {
        this(Identifiable$.MODULE$.randomUID("crrSampler"));
    }
}
