package org.broadinstitute.hellbender.tools.spark.pathseq;

import htsjdk.samtools.SAMFileHeader;
import java.io.IOException;
import java.lang.invoke.SerializedLambda;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.storage.StorageLevel;
import org.broadinstitute.barclay.argparser.Argument;
import org.broadinstitute.barclay.argparser.ArgumentCollection;
import org.broadinstitute.barclay.argparser.CommandLineProgramProperties;
import org.broadinstitute.barclay.help.DocumentedFeature;
import org.broadinstitute.hellbender.cmdline.GATKPlugin.GATKReadFilterPluginDescriptor;
import org.broadinstitute.hellbender.cmdline.programgroups.MetagenomicsProgramGroup;
import org.broadinstitute.hellbender.engine.spark.GATKSparkTool;
import org.broadinstitute.hellbender.engine.spark.datasources.ReadsSparkSink;
import org.broadinstitute.hellbender.exceptions.GATKException;
import org.broadinstitute.hellbender.exceptions.UserException;
import org.broadinstitute.hellbender.tools.spark.pathseq.loggers.PSFilterEmptyLogger;
import org.broadinstitute.hellbender.tools.spark.pathseq.loggers.PSFilterFileLogger;
import org.broadinstitute.hellbender.tools.spark.pathseq.loggers.PSFilterLogger;
import org.broadinstitute.hellbender.tools.spark.pathseq.loggers.PSScoreFileLogger;
import org.broadinstitute.hellbender.utils.read.GATKRead;
import org.broadinstitute.hellbender.utils.read.ReadsWriteFormat;
import scala.Tuple2;

/**
 * Combined PathSeq tool that runs read filtering, microbe reference alignment and abundance scoring
 * in a single Spark job.
 *
 * <p>Order of operations in {@link #runTool}:
 * <ol>
 *   <li>filter the input reads with {@link PSFilter}, which yields separate paired and unpaired RDDs;</li>
 *   <li>count both RDDs and repartition them to about {@code --pipeline-reads-per-partition} reads each;</li>
 *   <li>align both sets with {@link PSBwaAlignerSpark};</li>
 *   <li>score the aligned reads with {@link PSScorer};</li>
 *   <li>optionally write score metrics, and optionally write the scored reads to {@code --output}.</li>
 * </ol>
 */
@CommandLineProgramProperties(summary = "Combined tool that performs all PathSeq steps: read filtering, microbe reference alignment and abundance scoring", oneLineSummary = "Combined tool that performs all steps: read filtering, microbe reference alignment, and abundance scoring", programGroup = MetagenomicsProgramGroup.class)
@DocumentedFeature
public class PathSeqPipelineSpark extends GATKSparkTool {
    private static final long serialVersionUID = 1L;
    public static final String READS_PER_PARTITION_LONG_NAME = "pipeline-reads-per-partition";
    public static final String READS_PER_PARTITION_SHORT_NAME = "pipeline-reads-per-partition";

    @ArgumentCollection
    public PSFilterArgumentCollection filterArgs = new PSFilterArgumentCollection();

    @ArgumentCollection
    public PSBwaArgumentCollection bwaArgs = new PSBwaArgumentCollection();

    @ArgumentCollection
    public PSScoreArgumentCollection scoreArgs = new PSScoreArgumentCollection();

    @Argument(doc = "Output BAM", fullName = "output", shortName = "O", optional = true)
    public String outputPath = null;

    @Argument(doc = "Number of reads per partition to use for alignment and scoring.", fullName = READS_PER_PARTITION_LONG_NAME, shortName = READS_PER_PARTITION_SHORT_NAME, optional = true, minValue = 100.0d)
    public int readsPerPartition = 5000;

    @Argument(doc = "Number of reads per partition for output. Use this to control the number of sharded BAMs (not --num-reducers).", fullName = "readsPerPartitionOutput", optional = true, minValue = 100.0d, minRecommendedValue = 100000.0d)
    public int readsPerPartitionOutput = 1000000;

    /**
     * Repartitions a paired-read RDD without separating mates.
     *
     * <p>Each partition is first grouped into two-read lists (see {@link #pairPartitionReads}). The
     * lists are then redistributed across {@code numPartitions} partitions and flattened back into
     * individual reads.
     *
     * @param pairedReads   reads whose partitions hold consecutive mates with identical names
     * @param numPartitions target number of partitions
     * @param numReads      total number of reads in {@code pairedReads}; used only to presize buffers
     * @return the repartitioned reads, with each mate pair still adjacent
     */
    private static JavaRDD<GATKRead> repartitionPairedReads(final JavaRDD<GATKRead> pairedReads, final int numPartitions, final long numReads) {
        final int readsPerPartitionEstimate = 1 + (int) (numReads / numPartitions);
        return pairedReads
                .mapPartitions(iter -> pairPartitionReads(iter, readsPerPartitionEstimate))
                .repartition(numPartitions)
                .flatMap(List::iterator);
    }

    /**
     * Groups consecutive reads of a partition into two-element lists (mate pairs).
     *
     * @param it                 reads of one partition; mates must be adjacent
     * @param readsPerPartition  expected number of reads, used to presize the result list
     * @return an iterator over the pairs, in input order
     * @throws GATKException if the partition holds an odd number of reads, or if two adjacent reads
     *                       forming a pair do not have the same name
     */
    public static Iterator<List<GATKRead>> pairPartitionReads(final Iterator<GATKRead> it, final int readsPerPartition) {
        final List<List<GATKRead>> pairs = new ArrayList<>(Math.max(0, readsPerPartition / 2));
        while (it.hasNext()) {
            final GATKRead first = it.next();
            if (!it.hasNext()) {
                throw new GATKException("Odd number of read pairs in paired reads partition");
            }
            final GATKRead second = it.next();
            if (!first.getName().equals(second.getName())) {
                throw new GATKException("Pair did not have the same name in a paired reads partition");
            }
            final List<GATKRead> pair = new ArrayList<>(2);
            pair.add(first);
            pair.add(second);
            pairs.add(pair);
        }
        return pairs.iterator();
    }

    @Override
    public boolean requiresReads() {
        return true;
    }

    /**
     * Runs filtering, alignment, scoring and (optionally) output writing.
     *
     * <p>Both the {@link PSFilter} and the {@link PSBwaAlignerSpark} are closed in {@code finally}
     * blocks, so they are released even if a later step fails. The metrics loggers use
     * try-with-resources, so each is closed exactly once.
     *
     * @throws UserException.BadInput if {@code --num-reducers} was set
     * @throws UserException.CouldNotCreateOutputFile if writing {@code --output} fails
     */
    @Override
    protected void runTool(final JavaSparkContext javaSparkContext) {
        filterArgs.doReadFilterArgumentWarnings(
                (GATKReadFilterPluginDescriptor) getCommandLineParser().getPluginDescriptor(GATKReadFilterPluginDescriptor.class),
                logger);
        final SAMFileHeader header = PSUtils.checkAndClearHeaderSequences(getHeaderForReads(), filterArgs, logger);

        // Output sharding is controlled only through --readsPerPartitionOutput.
        if (numReducers > 0) {
            throw new UserException.BadInput("Use --readsPerPartitionOutput instead of --num-reducers.");
        }

        // Filter step. The filter logger is closed right after doFilter, before any counting.
        final PSFilter filter = new PSFilter(javaSparkContext, filterArgs, header);
        final JavaRDD<GATKRead> filteredPairedReads;
        final JavaRDD<GATKRead> filteredUnpairedReads;
        final long numPairedReads;
        final long numUnpairedReads;
        try {
            final Tuple2<JavaRDD<GATKRead>, JavaRDD<GATKRead>> filterResult;
            try (final PSFilterLogger filterLogger = filterArgs.filterMetricsFileUri != null
                    ? new PSFilterFileLogger(getMetricsFile(), filterArgs.filterMetricsFileUri)
                    : new PSFilterEmptyLogger()) {
                filterResult = filter.doFilter(getReads(), filterLogger);
            }
            filteredPairedReads = filterResult._1();
            filteredUnpairedReads = filterResult._2();
            // count() is a Spark action, so the filtered RDDs are evaluated here, while the filter
            // is still open. NOTE(review): later stages may re-evaluate this lineage unless doFilter
            // persists its output; confirm.
            numPairedReads = filteredPairedReads.count();
            numUnpairedReads = filteredUnpairedReads.count();
        } finally {
            filter.close();
        }
        final long numTotalReads = numPairedReads + numUnpairedReads;

        // Rebalance to roughly readsPerPartition reads per partition, keeping mates together.
        final int numPairedPartitions = 1 + (int) (numPairedReads / readsPerPartition);
        final int numUnpairedPartitions = 1 + (int) (numUnpairedReads / readsPerPartition);
        final JavaRDD<GATKRead> pairedReads = repartitionPairedReads(filteredPairedReads, numPairedPartitions, numPairedReads);
        final JavaRDD<GATKRead> unpairedReads = filteredUnpairedReads.repartition(numUnpairedPartitions);

        final PSBwaAlignerSpark aligner = new PSBwaAlignerSpark(javaSparkContext, bwaArgs);
        try {
            // Adds the microbe reference sequences to the header before it is broadcast.
            PSBwaUtils.addReferenceSequencesToHeader(header, bwaArgs.referencePath, getReferenceWindowFunction());
            final Broadcast<SAMFileHeader> headerBroadcast = javaSparkContext.broadcast(header);
            final JavaRDD<GATKRead> alignedPairedReads = aligner.doBwaAlignment(pairedReads, true, headerBroadcast);
            final JavaRDD<GATKRead> alignedUnpairedReads = aligner.doBwaAlignment(unpairedReads, false, headerBroadcast);

            // Cache the aligned reads in serialized form, spilling to disk if needed; they feed scoring.
            alignedPairedReads.persist(StorageLevel.MEMORY_AND_DISK_SER());
            alignedUnpairedReads.persist(StorageLevel.MEMORY_AND_DISK_SER());

            final JavaRDD<GATKRead> scoredReads =
                    new PSScorer(scoreArgs).scoreReads(javaSparkContext, alignedPairedReads, alignedUnpairedReads, header);
            final SAMFileHeader outputHeader = PSBwaUtils.removeUnmappedHeaderSequences(header, scoredReads, logger);

            if (scoreArgs.scoreMetricsFileUri != null) {
                try (final PSScoreFileLogger scoreLogger = new PSScoreFileLogger(getMetricsFile(), scoreArgs.scoreMetricsFileUri)) {
                    scoreLogger.logReadCounts(scoredReads);
                }
            }

            if (outputPath != null) {
                writeOutputReads(javaSparkContext, scoredReads, outputHeader, numTotalReads);
            }
        } finally {
            aligner.close();
        }
    }

    /**
     * Writes the scored reads to {@link #outputPath}.
     *
     * <p>The reads are coalesced, without a shuffle, into
     * {@code max(1, numTotalReads / readsPerPartitionOutput)} partitions. That count is also passed
     * to the sink as its number of output partitions.
     *
     * @throws UserException.CouldNotCreateOutputFile if the sink reports an {@link IOException}
     */
    private void writeOutputReads(final JavaSparkContext javaSparkContext, final JavaRDD<GATKRead> reads,
                                  final SAMFileHeader header, final long numTotalReads) {
        final int numPartitions = Math.max(1, (int) (numTotalReads / readsPerPartitionOutput));
        try {
            ReadsSparkSink.writeReads(javaSparkContext, outputPath, null, reads.coalesce(numPartitions, false), header,
                    shardedOutput ? ReadsWriteFormat.SHARDED : ReadsWriteFormat.SINGLE, numPartitions, shardedPartsDir);
        } catch (final IOException e) {
            throw new UserException.CouldNotCreateOutputFile(outputPath, "writing failed", e);
        }
    }
}
