package org.broadinstitute.hellbender.tools.spark;

import com.google.common.base.Stopwatch;
import htsjdk.samtools.SAMFileHeader;
import htsjdk.samtools.SAMSequenceDictionary;
import java.io.File;
import java.io.IOException;
import java.lang.invoke.SerializedLambda;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.TimeUnit;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.broadcast.Broadcast;
import org.broadinstitute.barclay.argparser.Argument;
import org.broadinstitute.barclay.argparser.ArgumentCollection;
import org.broadinstitute.barclay.argparser.CommandLineProgramProperties;
import org.broadinstitute.barclay.argparser.ExperimentalFeature;
import org.broadinstitute.barclay.help.DocumentedFeature;
import org.broadinstitute.hellbender.cmdline.argumentcollections.IntervalArgumentCollection;
import org.broadinstitute.hellbender.cmdline.argumentcollections.OptionalIntervalArgumentCollection;
import org.broadinstitute.hellbender.cmdline.argumentcollections.ReferenceInputArgumentCollection;
import org.broadinstitute.hellbender.cmdline.argumentcollections.RequiredReadInputArgumentCollection;
import org.broadinstitute.hellbender.cmdline.argumentcollections.RequiredReferenceInputArgumentCollection;
import org.broadinstitute.hellbender.engine.ContextShard;
import org.broadinstitute.hellbender.engine.datasources.ReferenceMultiSource;
import org.broadinstitute.hellbender.engine.datasources.VariantsSource;
import org.broadinstitute.hellbender.engine.filters.ReadFilter;
import org.broadinstitute.hellbender.engine.spark.AddContextDataToReadSparkOptimized;
import org.broadinstitute.hellbender.engine.spark.SparkCommandLineProgram;
import org.broadinstitute.hellbender.engine.spark.datasources.ReadsSparkSource;
import org.broadinstitute.hellbender.exceptions.UserException;
import org.broadinstitute.hellbender.tools.spark.transforms.bqsr.BaseRecalibratorEngineSparkWrapper;
import org.broadinstitute.hellbender.tools.walkers.bqsr.BaseRecalibrator;
import org.broadinstitute.hellbender.utils.IntervalUtils;
import org.broadinstitute.hellbender.utils.SequenceDictionaryUtils;
import org.broadinstitute.hellbender.utils.gcs.BucketUtils;
import org.broadinstitute.hellbender.utils.io.IOUtils;
import org.broadinstitute.hellbender.utils.recalibration.BaseRecalibrationEngine;
import org.broadinstitute.hellbender.utils.recalibration.RecalibrationArgumentCollection;
import org.broadinstitute.hellbender.utils.recalibration.RecalibrationTables;
import org.broadinstitute.hellbender.utils.recalibration.covariates.StandardCovariateList;
import picard.cmdline.programgroups.ReadDataManipulationProgramGroup;

/**
 * Experimental sharded Spark implementation of the first Base Quality Score Recalibration (BQSR) pass.
 *
 * <p>Reads, reference context and known variant sites are grouped into {@link ContextShard}s by
 * {@link AddContextDataToReadSparkOptimized}. Each Spark partition is turned into partial {@link RecalibrationTables}
 * by a {@link BaseRecalibratorEngineSparkWrapper}. The partial tables are merged with {@code treeAggregate}, finalized,
 * and written as a textual report to {@code --output}.
 */
@CommandLineProgramProperties(
        summary = BaseRecalibratorSparkSharded.USAGE_SUMMARY,
        oneLineSummary = BaseRecalibratorSparkSharded.USAGE_ONE_LINE_SUMMARY,
        programGroup = ReadDataManipulationProgramGroup.class)
@DocumentedFeature
@ExperimentalFeature
public class BaseRecalibratorSparkSharded extends SparkCommandLineProgram {
    private static final long serialVersionUID = 1L;

    static final String USAGE_ONE_LINE_SUMMARY = "BaseRecalibrator on Spark (experimental sharded implementation)";
    static final String USAGE_SUMMARY = "Experimental sharded implementation of the first pass of the Base Quality Score Recalibration (BQSR) -- Generates recalibration table based on various user-specified covariates (such as read group, reported quality score, machine cycle, and nucleotide context).";

    /**
     * Block-compressed VCF suffixes that are preserved when a cloud-hosted known-sites file is copied locally.
     * Without them, a gzipped VCF would be given a plain {@code .vcf} name.
     */
    private static final String[] COMPRESSED_VCF_SUFFIXES = {".vcf.gz", ".vcf.bgz"};

    /** Suffix used for local copies of cloud-hosted known-sites files that are not block-compressed. */
    private static final String DEFAULT_VCF_SUFFIX = ".vcf";

    @Argument(doc = "the known variants. Must be local.", fullName = BaseRecalibrator.KNOWN_SITES_ARG_FULL_NAME, optional = false)
    private List<String> knownVariants;

    @ArgumentCollection(doc = "all the command line arguments for BQSR and its covariates")
    private final RecalibrationArgumentCollection bqsrArgs = new RecalibrationArgumentCollection();

    @ArgumentCollection
    private final RequiredReadInputArgumentCollection readArguments = new RequiredReadInputArgumentCollection();

    @ArgumentCollection
    private final IntervalArgumentCollection intervalArgumentCollection = new OptionalIntervalArgumentCollection();

    @ArgumentCollection
    private final ReferenceInputArgumentCollection referenceArguments = new RequiredReferenceInputArgumentCollection();

    @Argument(doc = "Path to save the final recalibration tables to.", shortName = "O", fullName = "output", optional = false)
    private String outputTablesPath = null;

    /**
     * Builds the recalibration tables for the single reads input and writes the textual report.
     *
     * @param ctx the Spark context used to load reads, broadcast metadata and run the aggregation
     * @throws UserException if other than exactly one reads input is given
     * @throws UserException.CouldNotCreateOutputFile if the report cannot be written to {@code outputTablesPath}
     */
    @Override
    protected void runPipeline(final JavaSparkContext ctx) {
        if (readArguments.getReadFilesNames().size() != 1) {
            throw new UserException("Sorry, we only support a single reads input for now.");
        }
        final String readsFileName = readArguments.getReadFilesNames().get(0);
        final String referenceFileName = referenceArguments.getReferenceFileName();

        final ReferenceMultiSource reference =
                new ReferenceMultiSource(referenceFileName, BaseRecalibrationEngine.BQSR_REFERENCE_WINDOW_FUNCTION);
        final SAMFileHeader readsHeader = new ReadsSparkSource(ctx, readArguments.getReadValidationStringency())
                .getHeader(readsFileName, referenceFileName);
        final SAMSequenceDictionary readsDictionary = readsHeader.getSequenceDictionary();
        final SAMSequenceDictionary referenceDictionary = reference.getReferenceSequenceDictionary(readsDictionary);
        final ReadFilter readFilter = ReadFilter.fromList(BaseRecalibrator.getStandardBQSRReadFilterList(), readsHeader);
        SequenceDictionaryUtils.validateDictionaries("reference", referenceDictionary, "reads", readsDictionary);

        final Broadcast<SAMFileHeader> headerBroadcast = ctx.broadcast(readsHeader);
        final Broadcast<SAMSequenceDictionary> referenceDictionaryBroadcast = ctx.broadcast(referenceDictionary);

        final JavaRDD<ContextShard> shards = AddContextDataToReadSparkOptimized.add(
                ctx,
                intervalArgumentCollection.intervalsSpecified()
                        ? intervalArgumentCollection.getIntervals(readsDictionary)
                        : IntervalUtils.getAllIntervalsForReference(readsDictionary),
                readsFileName,
                VariantsSource.getVariantsList(hackilyCopyFromGCSIfNecessary(knownVariants)),
                readFilter,
                reference);

        final BaseRecalibratorEngineSparkWrapper recalibrator =
                new BaseRecalibratorEngineSparkWrapper(headerBroadcast, referenceDictionaryBroadcast, bqsrArgs);
        // The partition function captures only the wrapper local, not this command-line program instance.
        final JavaRDD<RecalibrationTables> partialTables =
                shards.mapPartitions(shardIterator -> recalibrator.apply(shardIterator));

        final RecalibrationTables emptyTables =
                new RecalibrationTables(new StandardCovariateList(bqsrArgs, readsHeader));
        // Tree depth is roughly log2(number of partitions), never less than 1.
        final int treeDepth = Math.max(1, (int) (Math.log(partialTables.partitions().size()) / Math.log(2.0d)));
        final RecalibrationTables recalibrationTables = partialTables.treeAggregate(
                emptyTables,
                RecalibrationTables::inPlaceCombine,
                RecalibrationTables::inPlaceCombine,
                treeDepth);

        BaseRecalibrationEngine.finalizeRecalibrationTables(recalibrationTables);
        try {
            BaseRecalibratorEngineSparkWrapper.saveTextualReport(outputTablesPath, readsHeader, recalibrationTables, bqsrArgs);
        } catch (final IOException e) {
            throw new UserException.CouldNotCreateOutputFile(new File(outputTablesPath), e);
        }
    }

    /**
     * Replaces every cloud-storage known-sites path with the path of a local temporary copy; other paths pass through
     * unchanged. The order of the input list is preserved.
     *
     * @param variantSources known-sites paths as supplied on the command line
     * @return a new list with cloud-storage URLs replaced by local temporary file paths
     * @throws UserException.CouldNotReadInputFile if copying a cloud file fails
     */
    private List<String> hackilyCopyFromGCSIfNecessary(final List<String> variantSources) {
        final Stopwatch copyTimer = Stopwatch.createStarted();
        boolean copiedAny = false;
        final List<String> localSources = new ArrayList<>(variantSources.size());
        for (final String source : variantSources) {
            if (!BucketUtils.isCloudStorageUrl(source)) {
                localSources.add(source);
                continue;
            }
            if (!copiedAny) {
                logger.info("(HACK): copying the GCS variant file to local just so we can read it back.");
                copiedAny = true;
            }
            final String localCopy =
                    IOUtils.createTempFile("knownVariants-0", localCopySuffix(source)).getAbsolutePath();
            try {
                BucketUtils.copyFile(source, localCopy);
            } catch (final IOException e) {
                throw new UserException.CouldNotReadInputFile(source, e);
            }
            localSources.add(localCopy);
        }
        copyTimer.stop();
        if (copiedAny) {
            logger.info("Copying the vcf took " + copyTimer.elapsed(TimeUnit.MILLISECONDS) + " ms.");
        }
        return localSources;
    }

    /**
     * Chooses the file suffix for the local copy of {@code source}.
     *
     * <p>A block-compressed VCF suffix on {@code source} is kept, so that readers which detect compression from the
     * file extension still decompress the copy. Any other source is named {@code .vcf}.
     *
     * @param source the cloud-storage URL being copied
     * @return the suffix for the temporary local file
     */
    private static String localCopySuffix(final String source) {
        for (final String suffix : COMPRESSED_VCF_SUFFIXES) {
            if (source.endsWith(suffix)) {
                return suffix;
            }
        }
        return DEFAULT_VCF_SUFFIX;
    }
}
