package org.broadinstitute.hellbender.engine.spark;

import com.google.common.base.Function;
import com.google.common.collect.AbstractIterator;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterators;
import com.google.common.collect.Lists;
import com.google.common.collect.PeekingIterator;
import htsjdk.samtools.SAMSequenceDictionary;
import htsjdk.samtools.SAMSequenceRecord;
import htsjdk.samtools.util.Locatable;
import htsjdk.samtools.util.OverlapDetector;
import java.lang.invoke.SerializedLambda;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.apache.spark.Partitioner;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.FlatMapFunction2;
import org.apache.spark.api.java.function.MapFunction;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.rdd.PartitionCoalescer;
import org.broadinstitute.hellbender.engine.Shard;
import org.broadinstitute.hellbender.engine.ShardBoundary;
import org.broadinstitute.hellbender.engine.ShardBoundaryShard;
import org.broadinstitute.hellbender.exceptions.UserException;
import org.broadinstitute.hellbender.utils.IntervalUtils;
import org.broadinstitute.hellbender.utils.SimpleInterval;
import org.broadinstitute.hellbender.utils.Utils;
import scala.Option;
import scala.Tuple2;
import scala.math.Ordering;
import scala.reflect.ClassTag$;

/* loaded from: input_file:org/broadinstitute/hellbender/engine/spark/SparkSharder.class */
/**
 * Utilities for grouping an RDD of {@link Locatable}s into {@link Shard}s, one shard per {@link ShardBoundary}.
 * <p>
 * Two strategies are available through
 * {@link #shard(JavaSparkContext, JavaRDD, Class, SAMSequenceDictionary, List, int, boolean)}:
 * <ul>
 *     <li>shuffle: broadcast an {@link OverlapDetector} over the (padded) boundaries, key every locatable by each
 *     boundary it overlaps, then {@code groupByKey};</li>
 *     <li>no shuffle (default): compute, for each RDD partition, the genomic extent its records may cover, coalesce
 *     neighbouring partitions with a {@link RangePartitionCoalescer} so each boundary's records are local, and walk
 *     locatables and boundaries in lockstep.</li>
 * </ul>
 * NOTE(review): the no-shuffle path appears to assume that the input RDD and the boundary list are both sorted in
 * sequence-dictionary order; this is not validated here.
 */
public class SparkSharder {

    /**
     * Spark {@link Partitioner} that sends each record to the partition whose index equals the record's
     * {@link Integer} key. Keys are expected to lie in {@code [0, numPartitions)}.
     */
    public static class KeyPartitioner extends Partitioner {
        private static final long serialVersionUID = 1;
        private final int numPartitions;

        /**
         * @param numPartitions the number of partitions reported to Spark
         */
        public KeyPartitioner(int numPartitions) {
            this.numPartitions = numPartitions;
        }

        @Override
        public int numPartitions() {
            return numPartitions;
        }

        @Override
        public int getPartition(Object key) {
            return (Integer) key;
        }
    }

    /**
     * A {@link Locatable} paired with the index of an RDD partition. The {@link Locatable} methods delegate to the
     * wrapped locatable. The wrapped locatable may be {@code null}; this is used as a marker for an empty partition.
     * Note that the {@link Locatable} accessors will throw on such a marker.
     *
     * @param <L> type of the wrapped locatable
     */
    public static class PartitionLocatable<L extends Locatable> implements Locatable {
        private static final long serialVersionUID = 1;
        private final int partitionIndex;
        private final L interval;

        /**
         * @param partitionIndex index of the partition this locatable is associated with
         * @param interval       the wrapped locatable (may be {@code null})
         */
        public PartitionLocatable(int partitionIndex, L interval) {
            this.partitionIndex = partitionIndex;
            this.interval = interval;
        }

        public int getPartitionIndex() {
            return partitionIndex;
        }

        public L getLocatable() {
            return interval;
        }

        @Override
        public String getContig() {
            return interval.getContig();
        }

        @Override
        public int getStart() {
            return interval.getStart();
        }

        @Override
        public int getEnd() {
            return interval.getEnd();
        }

        @Override
        public String toString() {
            return "PartitionLocatable{partitionIndex=" + partitionIndex + ", interval='" + interval + "'}";
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (obj == null || getClass() != obj.getClass()) {
                return false;
            }
            final PartitionLocatable<?> that = (PartitionLocatable<?>) obj;
            if (partitionIndex != that.partitionIndex) {
                return false;
            }
            // null-safe: empty-partition markers wrap a null locatable
            return interval == null ? that.interval == null : interval.equals(that.interval);
        }

        @Override
        public int hashCode() {
            return (31 * partitionIndex) + (interval == null ? 0 : interval.hashCode());
        }
    }

    /**
     * Groups {@code locatables} into one {@link Shard} per boundary without shuffling. Equivalent to
     * {@link #shard(JavaSparkContext, JavaRDD, Class, SAMSequenceDictionary, List, int, boolean)} with
     * {@code useShuffle == false}.
     */
    public static <L extends Locatable> JavaRDD<Shard<L>> shard(JavaSparkContext ctx, JavaRDD<L> locatables,
                                                               Class<L> locatableClass,
                                                               SAMSequenceDictionary sequenceDictionary,
                                                               List<ShardBoundary> intervals,
                                                               int maxLocatableLength) {
        return shard(ctx, locatables, locatableClass, sequenceDictionary, intervals, maxLocatableLength, false);
    }

    /**
     * Groups {@code locatables} into one {@link Shard} per boundary. A locatable is assigned to every boundary whose
     * padded interval it overlaps.
     *
     * @param ctx                the Spark context
     * @param locatables         the records to shard
     * @param locatableClass     runtime class of the records (needed to build a {@code ClassTag})
     * @param sequenceDictionary dictionary defining contig order and lengths
     * @param intervals          the shard boundaries
     * @param maxLocatableLength maximum allowed span of a mapped record in the no-shuffle path; longer records
     *                           cause a {@link UserException}
     * @param useShuffle         if {@code true}, use the broadcast + {@code groupByKey} strategy
     * @return an RDD of shards
     */
    public static <L extends Locatable> JavaRDD<Shard<L>> shard(JavaSparkContext ctx, JavaRDD<L> locatables,
                                                               Class<L> locatableClass,
                                                               SAMSequenceDictionary sequenceDictionary,
                                                               List<ShardBoundary> intervals,
                                                               int maxLocatableLength, boolean useShuffle) {
        // Expose each boundary's *padded* coordinates through the Locatable interface, so that the overlap
        // computations below (OverlapDetector, IntervalUtils.overlaps, toRightOf) use the padded interval.
        final List<ShardBoundary> paddedIntervals = intervals.stream()
                .<ShardBoundary>map(boundary -> new ShardBoundary(boundary.getInterval(), boundary.getPaddedInterval()) {
                    private static final long serialVersionUID = 1;

                    @Override
                    public String getContig() {
                        return getPaddedInterval().getContig();
                    }

                    @Override
                    public int getStart() {
                        return getPaddedInterval().getStart();
                    }

                    @Override
                    public int getEnd() {
                        return getPaddedInterval().getEnd();
                    }
                })
                .collect(Collectors.toList());

        if (!useShuffle) {
            return joinOverlapping(ctx, locatables, locatableClass, sequenceDictionary, paddedIntervals, maxLocatableLength,
                    new MapFunction<Tuple2<ShardBoundary, Iterable<L>>, Shard<L>>() {
                        private static final long serialVersionUID = 1;

                        @Override
                        public Shard<L> call(Tuple2<ShardBoundary, Iterable<L>> boundaryAndLocatables) {
                            return new ShardBoundaryShard<L>(boundaryAndLocatables._1(), boundaryAndLocatables._2());
                        }
                    });
        }

        final Broadcast<OverlapDetector<ShardBoundary>> overlapDetectorBroadcast =
                ctx.broadcast(OverlapDetector.create(paddedIntervals));
        return locatables
                .flatMapToPair(locatable -> overlapDetectorBroadcast.getValue().getOverlaps(locatable).stream()
                        .map(boundary -> new Tuple2<ShardBoundary, L>(boundary, locatable))
                        .collect(Collectors.toList())
                        .iterator())
                .groupByKey()
                .map(boundaryAndLocatables -> new ShardBoundaryShard<L>(boundaryAndLocatables._1(), boundaryAndLocatables._2()));
    }

    /**
     * Shuffle-free join of {@code locatables} with {@code intervals}. Each interval is paired with the locatables
     * overlapping it, and {@code f} is applied to every such pair.
     */
    private static <L extends Locatable, I extends Locatable, T> JavaRDD<T> joinOverlapping(JavaSparkContext ctx,
                                                                                            JavaRDD<L> locatables,
                                                                                            Class<L> locatableClass,
                                                                                            SAMSequenceDictionary sequenceDictionary,
                                                                                            List<I> intervals,
                                                                                            int maxLocatableLength,
                                                                                            MapFunction<Tuple2<I, Iterable<L>>, T> f) {
        return joinOverlapping(ctx, locatables, locatableClass, sequenceDictionary, intervals, maxLocatableLength,
                (locatablesInPartition, intervalsInPartition) -> Iterators.transform(
                        locatablesPerShard(locatablesInPartition, intervalsInPartition, sequenceDictionary, maxLocatableLength),
                        new Function<Tuple2<I, Iterable<L>>, T>() {
                            @Nullable
                            @Override
                            public T apply(@Nullable Tuple2<I, Iterable<L>> input) {
                                try {
                                    return f.call(input);
                                } catch (Exception e) {
                                    // MapFunction.call declares Exception; Iterator/Function cannot
                                    throw new RuntimeException(e);
                                }
                            }
                        }));
    }

    /**
     * Shuffle-free join of {@code locatables} with {@code intervals}.
     * <p>
     * Each interval is assigned to the lowest-index partition whose read extent overlaps it. If none overlaps, it
     * goes to the partition used for the previous interval. Partitions are then coalesced so that partition
     * {@code p} also contains the records of every later partition up to the furthest one any of its intervals
     * overlaps. Finally, the coalesced records are zipped with the intervals, partition by partition, through
     * {@code f}.
     */
    private static <L extends Locatable, I extends Locatable, T> JavaRDD<T> joinOverlapping(JavaSparkContext ctx,
                                                                                            JavaRDD<L> locatables,
                                                                                            Class<L> locatableClass,
                                                                                            SAMSequenceDictionary sequenceDictionary,
                                                                                            List<I> intervals,
                                                                                            int maxLocatableLength,
                                                                                            FlatMapFunction2<Iterator<L>, Iterator<I>, T> f) {
        final int numPartitions = locatables.getNumPartitions();
        final List<PartitionLocatable<SimpleInterval>> partitionReadExtents =
                computePartitionReadExtents(locatables, sequenceDictionary, maxLocatableLength);

        // maxEndPartitionIndexes.get(p): last partition whose records must be pulled into partition p
        final List<Integer> maxEndPartitionIndexes = new ArrayList<>(numPartitions);
        for (int p = 0; p < numPartitions; p++) {
            maxEndPartitionIndexes.add(p);
        }

        final OverlapDetector<PartitionLocatable<SimpleInterval>> extentDetector = OverlapDetector.create(partitionReadExtents);
        final List<PartitionLocatable<I>> indexedIntervals = new ArrayList<>(intervals.size());
        int previousPartitionIndex = 0;
        for (final I interval : intervals) {
            final int[] overlappingPartitions = extentDetector.getOverlaps(interval).stream()
                    .mapToInt(PartitionLocatable::getPartitionIndex)
                    .toArray();
            if (overlappingPartitions.length == 0) {
                // no partition's extent overlaps: keep the interval with the previous interval's partition
                indexedIntervals.add(new PartitionLocatable<>(previousPartitionIndex, interval));
            } else {
                Arrays.sort(overlappingPartitions);
                final int firstPartition = overlappingPartitions[0];
                final int lastPartition = overlappingPartitions[overlappingPartitions.length - 1];
                indexedIntervals.add(new PartitionLocatable<>(firstPartition, interval));
                if (lastPartition > maxEndPartitionIndexes.get(firstPartition)) {
                    maxEndPartitionIndexes.set(firstPartition, lastPartition);
                }
                previousPartitionIndex = firstPartition;
            }
        }

        final JavaRDD<L> coalescedLocatables =
                coalesce(locatables, locatableClass, new RangePartitionCoalescer(maxEndPartitionIndexes));
        final JavaRDD<I> intervalsByPartition = ctx.parallelize(indexedIntervals)
                .mapToPair(indexed -> new Tuple2<Integer, I>(indexed.getPartitionIndex(), indexed.getLocatable()))
                .partitionBy(new KeyPartitioner(numPartitions))
                .values();
        return coalescedLocatables.zipPartitions(intervalsByPartition, f);
    }

    /**
     * Walks {@code locatables} and {@code shards} in lockstep and yields, for each shard, the locatables overlapping
     * it.
     * <p>
     * Locatables are consumed while they are not to the right of the current shard. Each consumed locatable is
     * attached to the current shard and/or the next shard if it overlaps them; no further shards are considered.
     * A mapped locatable (non-null contig) longer than {@code maxLocatableLength} raises a {@link UserException}.
     *
     * @return an iterator of (shard, overlapping locatables) pairs, one per shard; empty if {@code shards} is empty
     */
    static <L extends Locatable, I extends Locatable> Iterator<Tuple2<I, Iterable<L>>> locatablesPerShard(Iterator<L> it,
                                                                                                       Iterator<I> it2,
                                                                                                       final SAMSequenceDictionary sequenceDictionary,
                                                                                                       final int maxLocatableLength) {
        if (!it2.hasNext()) {
            return Collections.emptyIterator();
        }
        final PeekingIterator<L> locatables = Iterators.peekingIterator(it);
        final PeekingIterator<I> shards = Iterators.peekingIterator(it2);
        return new AbstractIterator<Tuple2<I, Iterable<L>>>() {
            private I currentShard = shards.next();
            private I nextShard = shards.hasNext() ? shards.next() : null;
            private List<L> currentLocatables = Lists.newArrayList();
            private List<L> nextLocatables = Lists.newArrayList();

            @Override
            protected Tuple2<I, Iterable<L>> computeNext() {
                if (currentShard == null) {
                    return endOfData();
                }
                while (locatables.hasNext() && !toRightOf(currentShard, locatables.peek(), sequenceDictionary)) {
                    final L locatable = locatables.next();
                    if (locatable.getContig() != null) {
                        final int size = (locatable.getEnd() - locatable.getStart()) + 1;
                        if (size > maxLocatableLength) {
                            throw new UserException(String.format("Max size of locatable exceeded. Max size is %s, but locatable size is %s. Try increasing shard size and/or padding. Locatable: %s", maxLocatableLength, size, locatable));
                        }
                    }
                    if (IntervalUtils.overlaps(currentShard, locatable)) {
                        currentLocatables.add(locatable);
                    }
                    if (nextShard != null && IntervalUtils.overlaps(nextShard, locatable)) {
                        nextLocatables.add(locatable);
                    }
                }
                final Tuple2<I, Iterable<L>> result = new Tuple2<>(currentShard, currentLocatables);
                // advance: the next shard becomes current and inherits the locatables already matched to it
                currentShard = nextShard;
                nextShard = shards.hasNext() ? shards.next() : null;
                currentLocatables = nextLocatables;
                nextLocatables = Lists.newArrayList();
                return result;
            }
        };
    }

    /**
     * Returns {@code true} if {@code l} lies strictly after {@code i}: either on a later contig, or on the same
     * contig with {@code l} starting after {@code i} ends. Contig order is the dictionary's sequence index.
     */
    public static <I extends Locatable, L extends Locatable> boolean toRightOf(I i, L l, SAMSequenceDictionary sequenceDictionary) {
        final int intervalContigIndex = sequenceDictionary.getSequenceIndex(i.getContig());
        final int locatableContigIndex = sequenceDictionary.getSequenceIndex(l.getContig());
        return (intervalContigIndex == locatableContigIndex && i.getEnd() < l.getStart()) // same contig, strictly after
                || intervalContigIndex < locatableContigIndex;                            // later contig
    }

    /**
     * Computes, for each non-empty partition, the genomic extent its records may cover. The extent is derived only
     * from the first record of that partition and the first record of the next non-empty partition.
     * <p>
     * If both first records are on the same contig, the extent runs from this partition's first start to the next
     * partition's first start plus {@code maxLocatableLength}. Otherwise it covers:
     * <ul>
     *     <li>the rest of this contig;</li>
     *     <li>every whole contig in between;</li>
     *     <li>the next contig up to the next first start plus {@code maxLocatableLength}.</li>
     * </ul>
     * The last non-empty partition extends to the end of the dictionary.
     *
     * @return extents tagged with the index of the partition they belong to (a partition may have several)
     * @throws IllegalStateException (via {@code Utils.validate}) if a first record's contig is not in the dictionary
     */
    static <L extends Locatable> List<PartitionLocatable<SimpleInterval>> computePartitionReadExtents(JavaRDD<L> locatables,
                                                                                                      SAMSequenceDictionary sequenceDictionary,
                                                                                                      int maxLocatableLength) {
        // Only the first record of every partition is read; empty partitions yield a wrapper around null.
        final List<PartitionLocatable<L>> firstPerPartition = locatables.mapPartitions(it -> {
            final L first = it.hasNext() ? it.next() : null;
            return ImmutableList.of(new PartitionLocatable<L>(-1, first)).iterator();
        }).collect();

        // Tag with the real partition index and drop empty partitions.
        final List<PartitionLocatable<L>> splitPoints = new ArrayList<>();
        for (int p = 0; p < firstPerPartition.size(); p++) {
            final L first = firstPerPartition.get(p).getLocatable();
            if (first != null) {
                splitPoints.add(new PartitionLocatable<>(p, first));
            }
        }

        final List<PartitionLocatable<SimpleInterval>> extents = new ArrayList<>();
        for (int s = 0; s < splitPoints.size(); s++) {
            final PartitionLocatable<L> splitPoint = splitPoints.get(s);
            final int partitionIndex = splitPoint.getPartitionIndex();
            final L current = splitPoint.getLocatable();
            final int currentContigIndex = sequenceDictionary.getSequenceIndex(current.getContig());
            Utils.validate(currentContigIndex != -1, "Contig not found in sequence dictionary: " + current.getContig());

            final L next;
            final int nextContigIndex;
            if (s < splitPoints.size() - 1) {
                next = splitPoints.get(s + 1).getLocatable();
                nextContigIndex = sequenceDictionary.getSequenceIndex(next.getContig());
                Utils.validate(nextContigIndex != -1, "Contig not found in sequence dictionary: " + next.getContig());
            } else {
                // last non-empty partition: sentinel one past the final contig, so it runs to the dictionary end
                next = null;
                nextContigIndex = sequenceDictionary.getSequences().size();
            }

            if (currentContigIndex == nextContigIndex) {
                // same contig (only possible when next != null, since the sentinel is never a valid index)
                addPartitionReadExtent(extents, partitionIndex, current.getContig(), current.getStart(), next.getStart() + maxLocatableLength);
            } else {
                // remainder of the current contig
                final SAMSequenceRecord currentSequence = sequenceDictionary.getSequence(current.getContig());
                Utils.validate(currentSequence != null, "Contig not found in sequence dictionary: " + current.getContig());
                addPartitionReadExtent(extents, partitionIndex, current.getContig(), current.getStart(), currentSequence.getSequenceLength());
                // whole contigs strictly between current and next
                for (int contigIndex = currentContigIndex + 1; contigIndex < nextContigIndex; contigIndex++) {
                    final SAMSequenceRecord sequence = sequenceDictionary.getSequence(contigIndex);
                    Utils.validate(sequence != null, "Contig index not found in sequence dictionary: " + contigIndex);
                    addPartitionReadExtent(extents, partitionIndex, sequence.getSequenceName(), 1, sequence.getSequenceLength());
                }
                // beginning of the next contig, up to where the next partition's first record could still reach
                if (next != null) {
                    addPartitionReadExtent(extents, partitionIndex, next.getContig(), 1, next.getStart() + maxLocatableLength);
                }
            }
        }
        return extents;
    }

    /** Appends {@code contig:start-end}, tagged with {@code partitionIndex}, to {@code extents}. */
    private static void addPartitionReadExtent(List<PartitionLocatable<SimpleInterval>> extents, int partitionIndex,
                                               String contig, int start, int end) {
        extents.add(new PartitionLocatable<>(partitionIndex, new SimpleInterval(contig, start, end)));
    }

    /**
     * Coalesces {@code rdd} into the same number of partitions without a shuffle, using a custom
     * {@link PartitionCoalescer} to decide which parent partitions feed each output partition.
     */
    private static <T> JavaRDD<T> coalesce(JavaRDD<T> rdd, Class<T> cls, PartitionCoalescer partitionCoalescer) {
        return new JavaRDD<>(
                rdd.rdd().coalesce(rdd.getNumPartitions(), false, Option.apply(partitionCoalescer), (Ordering<T>) null),
                ClassTag$.MODULE$.apply(cls));
    }
}
