package org.apache.spark.ml.odkl;

import com.github.fommil.netlib.BLAS;
import java.io.IOException;
import java.util.concurrent.ThreadLocalRandom;
import odkl.analysis.spark.util.RDDOperations$;
import odkl.analysis.spark.util.collection.CompactBuffer;
import org.apache.spark.annotation.DeveloperApi;
import org.apache.spark.ml.Model;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.odkl.CRRSamplerParams;
import org.apache.spark.ml.odkl.HasGroupByColumns;
import org.apache.spark.ml.odkl.HasNetlibBlas;
import org.apache.spark.ml.param.DoubleParam;
import org.apache.spark.ml.param.IntParam;
import org.apache.spark.ml.param.Param;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.param.ParamPair;
import org.apache.spark.ml.param.Params;
import org.apache.spark.ml.param.StringArrayParam;
import org.apache.spark.ml.param.shared.HasInputCol;
import org.apache.spark.ml.param.shared.HasLabelCol;
import org.apache.spark.ml.util.DefaultParamsWritable;
import org.apache.spark.ml.util.Identifiable$;
import org.apache.spark.ml.util.MLWritable;
import org.apache.spark.ml.util.MLWriter;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.Row$;
import org.apache.spark.sql.functions$;
import org.apache.spark.sql.types.StructType;
import scala.Array$;
import scala.MatchError;
import scala.None$;
import scala.Option;
import scala.Predef$;
import scala.Some;
import scala.StringContext;
import scala.Tuple3;
import scala.collection.Iterator;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.package$;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxesRunTime;

/* compiled from: CRRSampler.scala */
@ScalaSignature(bytes = "\u0006\u0001\u0005%c\u0001B\u0001\u0003\u00015\u0011qb\u0011*S'\u0006l\u0007\u000f\\3s\u001b>$W\r\u001c\u0006\u0003\u0007\u0011\tAa\u001c3lY*\u0011QAB\u0001\u0003[2T!a\u0002\u0005\u0002\u000bM\u0004\u0018M]6\u000b\u0005%Q\u0011AB1qC\u000eDWMC\u0001\f\u0003\ry'oZ\u0002\u0001'\u0015\u0001a\u0002\u0006\u000e\u001e!\ry\u0001CE\u0007\u0002\t%\u0011\u0011\u0003\u0002\u0002\u0006\u001b>$W\r\u001c\t\u0003'\u0001i\u0011A\u0001\t\u0003+ai\u0011A\u0006\u0006\u0003/\u0011\tA!\u001e;jY&\u0011\u0011D\u0006\u0002\u0016\t\u00164\u0017-\u001e7u!\u0006\u0014\u0018-\\:Xe&$\u0018M\u00197f!\t\u00192$\u0003\u0002\u001d\u0005\t\u00012I\u0015*TC6\u0004H.\u001a:QCJ\fWn\u001d\t\u0003'yI!a\b\u0002\u0003\u001b!\u000b7OT3uY&\u0014'\t\\1t\u0011!\t\u0003A!b\u0001\n\u0003\u0012\u0013aA;jIV\t1\u0005\u0005\u0002%U9\u0011Q\u0005K\u0007\u0002M)\tq%A\u0003tG\u0006d\u0017-\u0003\u0002*M\u00051\u0001K]3eK\u001aL!a\u000b\u0017\u0003\rM#(/\u001b8h\u0015\tIc\u0005\u0003\u0005/\u0001\t\u0005\t\u0015!\u0003$\u0003\u0011)\u0018\u000e\u001a\u0011\t\u000bA\u0002A\u0011A\u0019\u0002\rqJg.\u001b;?)\t\u0011\"\u0007C\u0003\"_\u0001\u00071\u0005C\u00031\u0001\u0011\u0005A\u0007F\u0001\u0013\u0011\u00151\u0004\u0001\"\u00118\u0003%!(/\u00198tM>\u0014X\u000e\u0006\u00029\u0019B\u0011\u0011(\u0013\b\u0003u\u0019s!a\u000f#\u000f\u0005q\u001aeBA\u001fC\u001d\tq\u0014)D\u0001@\u0015\t\u0001E\"\u0001\u0004=e>|GOP\u0005\u0002\u0017%\u0011\u0011BC\u0005\u0003\u000f!I!!\u0012\u0004\u0002\u0007M\fH.\u0003\u0002H\u0011\u00069\u0001/Y2lC\u001e,'BA#\u0007\u0013\tQ5JA\u0005ECR\fgI]1nK*\u0011q\t\u0013\u0005\u0006\u001bV\u0002\rAT\u0001\bI\u0006$\u0018m]3ua\tyU\u000bE\u0002Q#Nk\u0011\u0001S\u0005\u0003%\"\u0013q\u0001R1uCN,G\u000f\u0005\u0002U+2\u0001A!\u0003,M\u0003\u0003\u0005\tQ!\u0001X\u0005\ryF%M\t\u00031n\u0003\"!J-\n\u0005i3#a\u0002(pi\"Lgn\u001a\t\u0003KqK!!\u0018\u0014\u0003\u0007\u0005s\u0017\u0010C\u0003`\u0001\u0011\u0005\u0001-\u0001\u0006tC6\u0004H.\u001a*poN$B!\u00197ogB\u0019!MZ5\u000f\u0005\r
,gB\u0001 e\u0013\u00059\u0013BA$'\u0013\t9\u0007N\u0001\u0005Ji\u0016\u0014\u0018\r^8s\u0015\t9e\u0005\u0005\u0002QU&\u00111\u000e\u0013\u0002\u0004%><\b\"B7_\u0001\u0004\t\u0017\u0001\u0002:poNDQa\u001c0A\u0002A\f!\u0002\\1cK2Le\u000eZ3y!\t)\u0013/\u0003\u0002sM\t\u0019\u0011J\u001c;\t\u000bQt\u0006\u0019\u00019\u0002\u0019\u0019,\u0017\r^;sK&sG-\u001a=\t\u000bY\u0004A\u0011A<\u0002\u001b\r|WO\u001c;feN\u000bW\u000e\u001d7f)\u0011\t\u0007P_@\t\u000be,\b\u0019A1\u0002\rM|WO]2f\u0011\u0015YX\u000f1\u0001}\u0003-\u0019w.\u001e8uKJ\u0004\u0016M\u001d;\u0011\u0007\tl\u0018.\u0003\u0002\u007fQ\n\u00191+Z9\t\u000bQ,\b\u0019\u00019\t\u000f\u0005\r\u0001\u0001\"\u0001\u0002\u0006\u0005Q\u0001/Y5s'\u0006l\u0007\u000f\\3\u0015\u000f%\f9!a\u0003\u0002\u000e!9\u0011\u0011BA\u0001\u0001\u0004I\u0017AB:b[BdW\r\u0003\u0004|\u0003\u0003\u0001\r\u0001 \u0005\u0007i\u0006\u0005\u0001\u0019\u00019\t\u000f\u0005E\u0001\u0001\"\u0011\u0002\u0014\u0005!1m\u001c9z)\r\u0011\u0012Q\u0003\u0005\t\u0003/\ty\u00011\u0001\u0002\u001a\u0005)Q\r\u001f;sCB!\u00111DA\u0011\u001b\t\tiBC\u0002\u0002 \u0011\tQ\u0001]1sC6LA!a\t\u0002\u001e\tA\u0001+\u0019:b[6\u000b\u0007\u000fC\u0004\u0002(\u0001!\t%!\u000b\u0002\u001fQ\u0014\u0018M\\:g_Jl7k\u00195f[\u0006$B!a\u000b\u00028A!\u0011QFA\u001a\u001b\t\tyCC\u0002\u00022!\u000bQ\u0001^=qKNLA!!\u000e\u00020\tQ1\u000b\u001e:vGR$\u0016\u0010]3\t\u0011\u0005e\u0012Q\u0005a\u0001\u0003W\taa]2iK6\f\u0007\u0006BA\u0013\u0003{\u0001B!a\u0010\u0002F5\u0011\u0011\u0011\t\u0006\u0004\u0003\u00072\u0011AC1o]>$\u0018\r^5p]&!\u0011qIA!\u00051!UM^3m_B,'/\u00119j\u0001")
/* loaded from: input_file:org/apache/spark/ml/odkl/CRRSamplerModel.class */
/*
 * Spark ML model that samples the rows of a dataset group by group.
 *
 * Rows are grouped within partitions by the configured group-by columns (or by a
 * temporary constant key when none are configured). Each group is then passed through
 * sampleRows, which applies the group/item sample rates and, when rankingPower is
 * positive, combines rows with rows drawn from a counterpart buffer.
 *
 * This source is decompiler output of a Scala class: the *_setter_$..._$eq methods are
 * the trait-initialisation hooks invoked from the trait $init$ calls in the constructor,
 * and the CRRSamplerModel$$anonfun$N classes are compiled Scala closures whose bodies
 * are not part of this file.
 */
public class CRRSamplerModel extends Model<CRRSamplerModel> implements DefaultParamsWritable, CRRSamplerParams, HasNetlibBlas {
    /** Unique identifier of this model instance; assigned once in the constructor. */
    private final String uid;

    // The fields below are populated through the trait "_setter_" methods rather than
    // by direct constructor assignment, so they cannot be declared final in Java source
    // (assigning a final field outside a constructor/initializer does not compile).
    private DoubleParam groupSampleRate;
    private DoubleParam itemSampleRate;
    private DoubleParam rankingPower;
    private IntParam shuffleToPartitions;
    private Param<String> labelCol;
    private StringArrayParam groupByColumns;
    private Param<String> inputCol;

    /** Delegates to the {@code HasNetlibBlas} trait implementation. */
    @Override // org.apache.spark.ml.odkl.HasNetlibBlas
    public BLAS f2jBLAS() {
        return HasNetlibBlas.Cclass.f2jBLAS(this);
    }

    /** Delegates to the {@code HasNetlibBlas} trait implementation. */
    @Override // org.apache.spark.ml.odkl.HasNetlibBlas
    public BLAS blas() {
        return HasNetlibBlas.Cclass.blas(this);
    }

    /** Delegates to the {@code HasNetlibBlas} trait implementation. */
    @Override // org.apache.spark.ml.odkl.HasNetlibBlas
    public void dscal(double d, double[] dArr) {
        HasNetlibBlas.Cclass.dscal(this, d, dArr);
    }

    /** Delegates to the {@code HasNetlibBlas} trait implementation (array/array variant). */
    @Override // org.apache.spark.ml.odkl.HasNetlibBlas
    public void axpy(double d, double[] dArr, double[] dArr2) {
        HasNetlibBlas.Cclass.axpy(this, d, dArr, dArr2);
    }

    /** Delegates to the {@code HasNetlibBlas} trait implementation (vector/array variant). */
    @Override // org.apache.spark.ml.odkl.HasNetlibBlas
    public void axpy(double d, Vector vector, double[] dArr) {
        HasNetlibBlas.Cclass.axpy(this, d, vector, dArr);
    }

    /** Delegates to the {@code HasNetlibBlas} trait implementation. */
    @Override // org.apache.spark.ml.odkl.HasNetlibBlas
    public void copy(double[] dArr, double[] dArr2) {
        HasNetlibBlas.Cclass.copy(this, dArr, dArr2);
    }

    /** @return the param holding the group sample rate */
    @Override // org.apache.spark.ml.odkl.CRRSamplerParams
    public DoubleParam groupSampleRate() {
        return this.groupSampleRate;
    }

    /** @return the param holding the item sample rate */
    @Override // org.apache.spark.ml.odkl.CRRSamplerParams
    public DoubleParam itemSampleRate() {
        return this.itemSampleRate;
    }

    /** @return the param holding the ranking power */
    @Override // org.apache.spark.ml.odkl.CRRSamplerParams
    public DoubleParam rankingPower() {
        return this.rankingPower;
    }

    /** @return the param holding the optional number of partitions to shuffle to */
    @Override // org.apache.spark.ml.odkl.CRRSamplerParams
    public IntParam shuffleToPartitions() {
        return this.shuffleToPartitions;
    }

    /** Trait-initialisation hook: stores the {@code groupSampleRate} param. */
    @Override // org.apache.spark.ml.odkl.CRRSamplerParams
    public void org$apache$spark$ml$odkl$CRRSamplerParams$_setter_$groupSampleRate_$eq(DoubleParam doubleParam) {
        this.groupSampleRate = doubleParam;
    }

    /** Trait-initialisation hook: stores the {@code itemSampleRate} param. */
    @Override // org.apache.spark.ml.odkl.CRRSamplerParams
    public void org$apache$spark$ml$odkl$CRRSamplerParams$_setter_$itemSampleRate_$eq(DoubleParam doubleParam) {
        this.itemSampleRate = doubleParam;
    }

    /** Trait-initialisation hook: stores the {@code rankingPower} param. */
    @Override // org.apache.spark.ml.odkl.CRRSamplerParams
    public void org$apache$spark$ml$odkl$CRRSamplerParams$_setter_$rankingPower_$eq(DoubleParam doubleParam) {
        this.rankingPower = doubleParam;
    }

    /** Trait-initialisation hook: stores the {@code shuffleToPartitions} param. */
    @Override // org.apache.spark.ml.odkl.CRRSamplerParams
    public void org$apache$spark$ml$odkl$CRRSamplerParams$_setter_$shuffleToPartitions_$eq(IntParam intParam) {
        this.shuffleToPartitions = intParam;
    }

    /** Delegates to the {@code CRRSamplerParams} trait implementation. */
    @Override // org.apache.spark.ml.odkl.CRRSamplerParams
    public CRRSamplerParams setGroupSampleRate(double d) {
        return CRRSamplerParams.Cclass.setGroupSampleRate(this, d);
    }

    /** Delegates to the {@code CRRSamplerParams} trait implementation. */
    @Override // org.apache.spark.ml.odkl.CRRSamplerParams
    public CRRSamplerParams setItemSampleRate(double d) {
        return CRRSamplerParams.Cclass.setItemSampleRate(this, d);
    }

    /** Delegates to the {@code CRRSamplerParams} trait implementation. */
    @Override // org.apache.spark.ml.odkl.CRRSamplerParams
    public CRRSamplerParams setRankingPower(double d) {
        return CRRSamplerParams.Cclass.setRankingPower(this, d);
    }

    /** Delegates to the {@code CRRSamplerParams} trait implementation. */
    @Override // org.apache.spark.ml.odkl.CRRSamplerParams
    public CRRSamplerParams setShufflerToPartitions(int i) {
        return CRRSamplerParams.Cclass.setShufflerToPartitions(this, i);
    }

    /** @return the param naming the label column */
    public final Param<String> labelCol() {
        return this.labelCol;
    }

    /** Trait-initialisation hook: stores the {@code labelCol} param. */
    public final void org$apache$spark$ml$param$shared$HasLabelCol$_setter_$labelCol_$eq(Param param) {
        this.labelCol = param;
    }

    // NOTE(review): "HasLabelCol.class.<method>" (and the similar calls below) is how the
    // decompiler rendered a call into a Scala trait implementation class; it is not valid
    // Java as written. Left untouched because the correct target name cannot be confirmed
    // from this file.
    public final String getLabelCol() {
        return HasLabelCol.class.getLabelCol(this);
    }

    /** @return the param listing the columns used to group rows */
    @Override // org.apache.spark.ml.odkl.HasGroupByColumns
    public final StringArrayParam groupByColumns() {
        return this.groupByColumns;
    }

    /** Trait-initialisation hook: stores the {@code groupByColumns} param. */
    @Override // org.apache.spark.ml.odkl.HasGroupByColumns
    public final void org$apache$spark$ml$odkl$HasGroupByColumns$_setter_$groupByColumns_$eq(StringArrayParam stringArrayParam) {
        this.groupByColumns = stringArrayParam;
    }

    /** Delegates to the {@code HasGroupByColumns} trait implementation. */
    @Override // org.apache.spark.ml.odkl.HasGroupByColumns
    public HasGroupByColumns setGroupByColumns(Seq<String> seq) {
        return HasGroupByColumns.Cclass.setGroupByColumns(this, seq);
    }

    /** @return the param naming the input column */
    public final Param<String> inputCol() {
        return this.inputCol;
    }

    /** Trait-initialisation hook: stores the {@code inputCol} param. */
    public final void org$apache$spark$ml$param$shared$HasInputCol$_setter_$inputCol_$eq(Param param) {
        this.inputCol = param;
    }

    // NOTE(review): decompiler rendering of a trait implementation call; see getLabelCol.
    public final String getInputCol() {
        return HasInputCol.class.getInputCol(this);
    }

    // NOTE(review): decompiler rendering of a trait implementation call; see getLabelCol.
    public MLWriter write() {
        return DefaultParamsWritable.class.write(this);
    }

    // NOTE(review): decompiler rendering of a trait implementation call; see getLabelCol.
    public void save(String str) throws IOException {
        MLWritable.class.save(this, str);
    }

    /** @return the unique identifier of this model instance */
    public String uid() {
        return this.uid;
    }

    /**
     * Builds the name of the temporary grouping-key column, {@code <uid>_tmpKey}.
     */
    private String temporaryKeyColumn() {
        return uid() + "_tmpKey";
    }

    /**
     * Samples the rows of {@code dataset} group by group.
     *
     * <p>The grouping key is chosen as follows:
     * <ul>
     *   <li>no group-by columns defined (or an empty list): a temporary column
     *       {@code <uid>_tmpKey} holding a constant literal is added, so all rows share
     *       one key;</li>
     *   <li>exactly one group-by column: that column is used directly;</li>
     *   <li>several group-by columns: they are combined into a struct stored in the
     *       temporary {@code <uid>_tmpKey} column.</li>
     * </ul>
     * Rows are then grouped within partitions by that key, each group is expanded by a
     * compiled closure (presumably calling {@link #sampleRows}; its body is not visible
     * here), and a new DataFrame with the keyed dataset's schema is created.
     *
     * @param dataset the dataset to sample
     * @return the sampled dataset
     */
    public Dataset<Row> transform(Dataset<?> dataset) {
        // Raw types are kept on purpose: the keyed dataset is either the wildcard-typed
        // input itself or a DataFrame derived from it.
        final Dataset keyed;
        final int keyIndex;
        final Option temporaryKey;

        if (!isDefined(groupByColumns()) || ((String[]) $(groupByColumns())).length <= 0) {
            String keyColumn = temporaryKeyColumn();
            keyed = dataset.withColumn(keyColumn, functions$.MODULE$.lit(keyColumn));
            keyIndex = keyed.schema().fieldIndex(keyColumn);
            temporaryKey = new Some(keyColumn);
        } else if (((String[]) $(groupByColumns())).length == 1) {
            keyed = dataset;
            keyIndex = dataset.schema().fieldIndex((String) Predef$.MODULE$.refArrayOps((Object[]) $(groupByColumns())).head());
            temporaryKey = None$.MODULE$;
        } else {
            Column struct = functions$.MODULE$.struct(Predef$.MODULE$.wrapRefArray((Object[]) Predef$.MODULE$.refArrayOps((Object[]) $(groupByColumns())).map(new CRRSamplerModel$$anonfun$6(this, dataset), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Column.class)))));
            String keyColumn = temporaryKeyColumn();
            keyed = dataset.withColumn(keyColumn, struct);
            keyIndex = keyed.schema().fieldIndex(keyColumn);
            temporaryKey = new Some(keyColumn);
        }

        Dataset sampled = dataset.sqlContext().createDataFrame(
                RDDOperations$.MODULE$
                        .ImplicitRDDDecorator(keyed.toDF().rdd(), ClassTag$.MODULE$.apply(Row.class))
                        .groupWithinPartitionsBy(new CRRSamplerModel$$anonfun$7(this, keyIndex), ClassTag$.MODULE$.Any())
                        .flatMap(
                                new CRRSamplerModel$$anonfun$8(
                                        this,
                                        keyed.schema().fieldIndex((String) $(inputCol())),
                                        keyed.schema().fieldIndex((String) $(labelCol()))),
                                ClassTag$.MODULE$.apply(Row.class)),
                keyed.schema());

        // Presumably drops the temporary key column when one was added — the closures'
        // bodies are not visible here; TODO confirm.
        Dataset result = (Dataset) temporaryKey
                .map(new CRRSamplerModel$$anonfun$9(this, sampled))
                .getOrElse(new CRRSamplerModel$$anonfun$10(this, sampled));

        // Presumably repartitions when shuffleToPartitions is set and returns the result
        // unchanged otherwise — TODO confirm against the closures.
        return (Dataset) get(shuffleToPartitions())
                .map(new CRRSamplerModel$$anonfun$transform$1(this, result))
                .getOrElse(new CRRSamplerModel$$anonfun$transform$2(this, result));
    }

    /**
     * Samples the rows of a single group.
     *
     * <p>When {@code groupSampleRate} is below 1, the whole group is dropped if a
     * uniform random draw exceeds the rate. When {@code rankingPower} is not positive
     * the rows are only passed through {@link #counterSample} with an empty counterpart.
     * Otherwise the rows are split into two buffers by a compiled closure; if either
     * buffer is empty nothing is emitted, else the first buffer is sampled against the
     * second and the result is concatenated with a second, lazily evaluated iterator.
     *
     * @param iterator rows of the group
     * @param i column index handed to the splitting closure (presumably the label
     *          column index — verify against the caller)
     * @param i2 column index forwarded to {@link #counterSample} (presumably the
     *           feature vector column index — verify against the caller)
     * @return the sampled rows, possibly empty
     */
    public Iterator<Row> sampleRows(Iterator<Row> iterator, int i, int i2) {
        double groupRate = BoxesRunTime.unboxToDouble($(groupSampleRate()));
        // The random draw only happens when the rate is actually below 1.
        if (groupRate < 1.0d && ThreadLocalRandom.current().nextDouble() > groupRate) {
            return package$.MODULE$.Iterator().empty();
        }
        if (BoxesRunTime.unboxToDouble($(rankingPower())) <= 0.0d) {
            return counterSample(iterator, Seq$.MODULE$.empty(), i2);
        }

        CompactBuffer firstPart = new CompactBuffer(ClassTag$.MODULE$.apply(Row.class));
        CompactBuffer secondPart = new CompactBuffer(ClassTag$.MODULE$.apply(Row.class));
        // The closure distributes every row into one of the two buffers.
        iterator.foreach(new CRRSamplerModel$$anonfun$sampleRows$1(this, i, firstPart, secondPart));

        if (firstPart.isEmpty() || secondPart.isEmpty()) {
            return package$.MODULE$.Iterator().empty();
        }
        return counterSample(firstPart.iterator(), secondPart, i2)
                .$plus$plus(new CRRSamplerModel$$anonfun$sampleRows$2(this, i2, firstPart, secondPart));
    }

    /**
     * Applies item-level sampling and, depending on {@code rankingPower}, maps the
     * remaining rows against a counterpart collection.
     *
     * <p>Rows are filtered by a compiled closure when {@code itemSampleRate} is below 1.
     * With a non-positive {@code rankingPower} the (possibly filtered) rows are returned
     * as is; otherwise they are mapped by one of two compiled closures, selected by
     * whether {@code rankingPower} is below 1.
     *
     * @param iterator rows to sample
     * @param seq counterpart rows handed to the mapping closures
     * @param i column index handed to the mapping closures
     * @return the resulting row iterator
     */
    public Iterator<Row> counterSample(Iterator<Row> iterator, Seq<Row> seq, int i) {
        Iterator<Row> candidates = iterator;
        if (BoxesRunTime.unboxToDouble($(itemSampleRate())) < 1.0d) {
            candidates = iterator.filter(new CRRSamplerModel$$anonfun$11(this));
        }

        double power = BoxesRunTime.unboxToDouble($(rankingPower()));
        if (power <= 0.0d) {
            return candidates;
        }
        if (power < 1.0d) {
            return candidates.map(new CRRSamplerModel$$anonfun$counterSample$1(this, seq, i));
        }
        return candidates.map(new CRRSamplerModel$$anonfun$counterSample$2(this, seq, i));
    }

    /**
     * Builds a new row from {@code row} and a uniformly chosen counterpart row.
     *
     * <p>A random element of {@code seq} is selected and its value at column {@code i}
     * is read as a {@code Vector}. The output has as many fields as {@code row}; each
     * field is produced by a compiled closure that receives {@code row}, {@code i} and
     * that vector.
     *
     * @param row the row to pair
     * @param seq counterpart rows to draw from; must not be empty, because the random
     *            index is drawn with {@code seq.size()} as the exclusive bound
     * @param i index of the vector column
     * @return the combined row
     */
    public Row pairSample(Row row, Seq<Row> seq, int i) {
        int width = row.length();
        Row counterpart = (Row) seq.apply(ThreadLocalRandom.current().nextInt(seq.size()));
        Vector counterpartVector = (Vector) counterpart.getAs(i);
        Object[] values = (Object[]) Array$.MODULE$.tabulate(
                width,
                new CRRSamplerModel$$anonfun$12(this, row, i, counterpartVector),
                ClassTag$.MODULE$.Any());
        return Row$.MODULE$.fromSeq(Predef$.MODULE$.genericWrapArray(values));
    }

    /* renamed from: copy, reason: merged with bridge method [inline-methods] */
    /**
     * Creates a copy of this model via {@code defaultCopy} with the given extra params.
     *
     * @param paramMap extra params passed to {@code defaultCopy}
     * @return the copied model
     */
    public CRRSamplerModel m148copy(ParamMap paramMap) {
        return defaultCopy(paramMap);
    }

    /**
     * Returns the input schema unchanged.
     *
     * @param structType the input schema
     * @return {@code structType} itself
     */
    @DeveloperApi
    public StructType transformSchema(StructType structType) {
        return structType;
    }

    /**
     * Creates a model with the given uid, runs the trait initialisers and sets the
     * defaults {@code groupSampleRate = 1.0} and {@code itemSampleRate = 1.0}.
     *
     * @param str unique identifier of the model
     */
    public CRRSamplerModel(String str) {
        this.uid = str;
        // Trait initialisers, in the order emitted by the Scala compiler; they populate
        // the param fields through the "_setter_" methods above.
        MLWritable.class.$init$(this);
        DefaultParamsWritable.class.$init$(this);
        HasInputCol.class.$init$(this);
        org$apache$spark$ml$odkl$HasGroupByColumns$_setter_$groupByColumns_$eq(new StringArrayParam((Params) this, "groupByColumns", "Grouping criteria for the evaluation."));
        HasLabelCol.class.$init$(this);
        CRRSamplerParams.Cclass.$init$(this);
        HasNetlibBlas.Cclass.$init$(this);
        setDefault(Predef$.MODULE$.wrapRefArray(new ParamPair[]{groupSampleRate().$minus$greater(BoxesRunTime.boxToDouble(1.0d)), itemSampleRate().$minus$greater(BoxesRunTime.boxToDouble(1.0d))}));
    }

    /** Creates a model with a random uid prefixed by {@code crrSampler}. */
    public CRRSamplerModel() {
        this(Identifiable$.MODULE$.randomUID("crrSampler"));
    }
}
