package org.broadinstitute.hellbender.engine.spark;

import htsjdk.samtools.SAMFileHeader;
import htsjdk.samtools.SAMSequenceDictionary;
import java.io.IOException;
import java.lang.invoke.SerializedLambda;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.broadinstitute.barclay.argparser.Argument;
import org.broadinstitute.barclay.argparser.ArgumentCollection;
import org.broadinstitute.barclay.argparser.CommandLinePluginDescriptor;
import org.broadinstitute.hellbender.cmdline.GATKPlugin.GATKReadFilterPluginDescriptor;
import org.broadinstitute.hellbender.cmdline.StandardArgumentDefinitions;
import org.broadinstitute.hellbender.cmdline.argumentcollections.IntervalArgumentCollection;
import org.broadinstitute.hellbender.cmdline.argumentcollections.OptionalIntervalArgumentCollection;
import org.broadinstitute.hellbender.cmdline.argumentcollections.OptionalReadInputArgumentCollection;
import org.broadinstitute.hellbender.cmdline.argumentcollections.OptionalReferenceInputArgumentCollection;
import org.broadinstitute.hellbender.cmdline.argumentcollections.ReadInputArgumentCollection;
import org.broadinstitute.hellbender.cmdline.argumentcollections.ReferenceInputArgumentCollection;
import org.broadinstitute.hellbender.cmdline.argumentcollections.RequiredIntervalArgumentCollection;
import org.broadinstitute.hellbender.cmdline.argumentcollections.RequiredReadInputArgumentCollection;
import org.broadinstitute.hellbender.cmdline.argumentcollections.RequiredReferenceInputArgumentCollection;
import org.broadinstitute.hellbender.engine.FeatureManager;
import org.broadinstitute.hellbender.engine.TraversalParameters;
import org.broadinstitute.hellbender.engine.datasources.ReferenceMultiSource;
import org.broadinstitute.hellbender.engine.datasources.ReferenceWindowFunctions;
import org.broadinstitute.hellbender.engine.filters.ReadFilter;
import org.broadinstitute.hellbender.engine.filters.WellformedReadFilter;
import org.broadinstitute.hellbender.engine.spark.datasources.ReadsSparkSink;
import org.broadinstitute.hellbender.engine.spark.datasources.ReadsSparkSource;
import org.broadinstitute.hellbender.exceptions.UserException;
import org.broadinstitute.hellbender.utils.SequenceDictionaryUtils;
import org.broadinstitute.hellbender.utils.SerializableFunction;
import org.broadinstitute.hellbender.utils.SimpleInterval;
import org.broadinstitute.hellbender.utils.gcs.BucketUtils;
import org.broadinstitute.hellbender.utils.io.IOUtils;
import org.broadinstitute.hellbender.utils.read.GATKRead;
import org.broadinstitute.hellbender.utils.read.ReadsWriteFormat;

/* loaded from: input_file:org/broadinstitute/hellbender/engine/spark/GATKSparkTool.class */
/**
 * Abstract base class for Spark tools that may consume a reference, a single reads input
 * and a set of intervals.
 *
 * <p>{@link #runPipeline(JavaSparkContext)} initializes the inputs described by the command-line
 * argument collections, validates them, and then hands control to
 * {@link #runTool(JavaSparkContext)}, which concrete tools must implement.
 *
 * <p>Subclasses declare which inputs are mandatory by overriding {@link #requiresReference()},
 * {@link #requiresReads()} and {@link #requiresIntervals()}.
 */
public abstract class GATKSparkTool extends SparkCommandLineProgram {
    private static final long serialVersionUID = 1L;

    /** Long name of the argument bound to {@link #bamPartitionSplitSize}. */
    public static final String BAM_PARTITION_SIZE_LONG_NAME = "bam-partition-size";
    /** Long name of the argument bound to {@link #numReducers}. */
    public static final String NUM_REDUCERS_LONG_NAME = "num-reducers";
    /** Long name of the argument bound to {@link #shardedOutput}. */
    public static final String SHARDED_OUTPUT_LONG_NAME = "sharded-output";

    /** Reference arguments; the required or optional flavour is chosen from {@link #requiresReference()}. */
    @ArgumentCollection
    public final ReferenceInputArgumentCollection referenceArguments;

    /** Reads arguments; the required or optional flavour is chosen from {@link #requiresReads()}. */
    @ArgumentCollection
    public final ReadInputArgumentCollection readArguments;

    /** Interval arguments; the required or optional flavour is chosen from {@link #requiresIntervals()}. */
    @ArgumentCollection
    protected IntervalArgumentCollection intervalArgumentCollection;

    @Argument(doc = "maximum number of bytes to read from a file into each partition of reads. Setting this higher will result in fewer partitions. Note that this will not be equal to the size of the partition in memory. Defaults to 0, which uses the default split size (determined by the Hadoop input format, typically the size of one HDFS block).", fullName = BAM_PARTITION_SIZE_LONG_NAME, optional = true)
    protected long bamPartitionSplitSize;

    @Argument(fullName = StandardArgumentDefinitions.DISABLE_SEQUENCE_DICT_VALIDATION_NAME, shortName = StandardArgumentDefinitions.DISABLE_SEQUENCE_DICT_VALIDATION_NAME, doc = "If specified, do not check the sequence dictionaries from our inputs for compatibility. Use at your own risk!", optional = true)
    private boolean disableSequenceDictionaryValidation;

    @Argument(doc = "For tools that write an output, write the output in multiple pieces (shards)", fullName = SHARDED_OUTPUT_LONG_NAME, optional = true)
    protected boolean shardedOutput;

    @Argument(doc = "For tools that shuffle data or write an output, sets the number of reducers. Defaults to 0, which gives one partition per 10MB of input.", fullName = NUM_REDUCERS_LONG_NAME, optional = true)
    protected int numReducers;

    // Reads state: all three stay null when no reads input was supplied (see initializeReads).
    private ReadsSparkSource readsSource;
    private SAMFileHeader readsHeader;
    private String readInput;

    // Reference state: both stay null when no reference was supplied (see initializeReference).
    private ReferenceMultiSource referenceSource;
    private SAMSequenceDictionary referenceDictionary;

    // Result of editIntervals(); null when no intervals were specified and editIntervals keeps it so.
    private List<SimpleInterval> intervals;

    /** Feature manager for this tool, or {@code null} when it reports itself as empty. */
    protected FeatureManager features;

    /**
     * Creates the argument collections and sets the argument defaults.
     *
     * <p>Note that the overridable {@code requires*()} methods are invoked from this constructor,
     * i.e. before any subclass constructor body has run, so overrides must not rely on subclass
     * instance state.
     */
    public GATKSparkTool() {
        this.referenceArguments = requiresReference() ? new RequiredReferenceInputArgumentCollection() : new OptionalReferenceInputArgumentCollection();
        this.readArguments = requiresReads() ? new RequiredReadInputArgumentCollection() : new OptionalReadInputArgumentCollection();
        this.intervalArgumentCollection = requiresIntervals() ? new RequiredIntervalArgumentCollection() : new OptionalIntervalArgumentCollection();
        this.bamPartitionSplitSize = 0L;
        this.disableSequenceDictionaryValidation = false;
        this.shardedOutput = false;
        this.numReducers = 0;
    }

    /**
     * @return a single-element list holding a read-filter plugin descriptor seeded with
     *         {@link #getDefaultReadFilters()}
     */
    @Override // org.broadinstitute.hellbender.cmdline.CommandLineProgram
    public List<? extends CommandLinePluginDescriptor<?>> getPluginDescriptors() {
        return Collections.singletonList(new GATKReadFilterPluginDescriptor(getDefaultReadFilters()));
    }

    /**
     * @return {@code true} if this tool cannot run without a reference; {@code false} by default
     */
    public boolean requiresReference() {
        return false;
    }

    /**
     * @return {@code true} if this tool cannot run without a reads input; {@code false} by default
     */
    public boolean requiresReads() {
        return false;
    }

    /**
     * @return {@code true} if this tool cannot run without intervals; {@code false} by default
     */
    public boolean requiresIntervals() {
        return false;
    }

    /**
     * @return {@code true} once a reference source has been initialized for this run
     */
    public final boolean hasReference() {
        return referenceSource != null;
    }

    /**
     * @return {@code true} once a reads source has been initialized for this run
     */
    public final boolean hasReads() {
        return readsSource != null;
    }

    /**
     * @return {@code true} if a (possibly edited) interval list is available for this run
     */
    public final boolean hasIntervals() {
        return intervals != null;
    }

    /**
     * Supplies the window function handed to the {@link ReferenceMultiSource} when the reference
     * is initialized. Subclasses may override it to request a different window.
     *
     * @return {@link ReferenceWindowFunctions#IDENTITY_FUNCTION} by default
     */
    public SerializableFunction<GATKRead, SimpleInterval> getReferenceWindowFunction() {
        return ReferenceWindowFunctions.IDENTITY_FUNCTION;
    }

    /**
     * Picks a sequence dictionary, preferring the reference over the reads.
     *
     * @return the reference dictionary if a reference is present, otherwise the dictionary from
     *         the reads header if reads are present, otherwise {@code null}
     */
    public SAMSequenceDictionary getBestAvailableSequenceDictionary() {
        if (hasReference()) {
            return referenceDictionary;
        }
        return hasReads() ? readsHeader.getSequenceDictionary() : null;
    }

    /**
     * @return the sequence dictionary obtained from the reference, or {@code null} if there is no reference
     */
    public SAMSequenceDictionary getReferenceSequenceDictionary() {
        return referenceDictionary;
    }

    /**
     * @return the header obtained for the reads input, or {@code null} if there is no reads input
     */
    public SAMFileHeader getHeaderForReads() {
        return readsHeader;
    }

    /**
     * Loads the reads and keeps only those accepted by {@link #makeReadFilter()}.
     *
     * @return the filtered reads
     * @throws UserException if no reads input is available (see {@link #getUnfilteredReads()})
     */
    public JavaRDD<GATKRead> getReads() {
        // Build the filter first and capture only this local in the lambda, not the tool itself.
        final ReadFilter filter = makeReadFilter();
        return getUnfilteredReads().filter(read -> filter.test(read));
    }

    /**
     * Loads the reads from the single reads input without applying any read filter.
     *
     * <p>Traversal is restricted by the intervals given on the command line if any; otherwise by
     * the list from {@link #getIntervals()} if one is set; otherwise it is unrestricted. Inputs
     * whose name ends in {@code .adam} are loaded through the ADAM path, everything else through
     * the parallel reads path.
     *
     * @return the unfiltered reads
     * @throws UserException if no reads input was provided, or if an ADAM input cannot be read
     * @throws UserException.MissingReference if the input is CRAM and no reference is available
     */
    public JavaRDD<GATKRead> getUnfilteredReads() {
        // Without this guard every path below would fail with a bare NullPointerException.
        if (!hasReads()) {
            throw new UserException("Reads were requested but no reads input was provided to this tool.");
        }

        final TraversalParameters traversalParameters;
        if (intervalArgumentCollection.intervalsSpecified()) {
            traversalParameters = intervalArgumentCollection.getTraversalParameters(getHeaderForReads().getSequenceDictionary());
        } else if (hasIntervals()) {
            // Intervals can exist without being specified by the user when editIntervals() supplies them.
            traversalParameters = new TraversalParameters(getIntervals(), false);
        } else {
            traversalParameters = null;
        }

        if (readInput.endsWith(".adam")) {
            try {
                return readsSource.getADAMReads(readInput, traversalParameters, getHeaderForReads());
            } catch (IOException e) {
                throw new UserException("Failed to read ADAM file " + readInput, e);
            }
        }

        if (hasCramInput() && !hasReference()) {
            throw new UserException.MissingReference("A reference file is required when using CRAM files.");
        }
        final String referenceName = hasReference() ? referenceArguments.getReferenceFileName() : null;
        return readsSource.getParallelReads(readInput, referenceName, traversalParameters, bamPartitionSplitSize);
    }

    /**
     * Writes reads using the header of this tool's reads input.
     *
     * @param javaSparkContext the Spark context to write with
     * @param str the output location
     * @param javaRDD the reads to write
     * @throws UserException.CouldNotCreateOutputFile if writing fails with an {@link IOException}
     */
    public void writeReads(JavaSparkContext javaSparkContext, String str, JavaRDD<GATKRead> javaRDD) {
        writeReads(javaSparkContext, str, javaRDD, readsHeader);
    }

    /**
     * Writes reads with an explicit header, as a single file or as shards depending on the
     * {@code sharded-output} argument, using {@link #getRecommendedNumReducers()} reducers.
     *
     * @param javaSparkContext the Spark context to write with
     * @param str the output location
     * @param javaRDD the reads to write
     * @param sAMFileHeader the header to write
     * @throws UserException.CouldNotCreateOutputFile if writing fails with an {@link IOException}
     */
    public void writeReads(JavaSparkContext javaSparkContext, String str, JavaRDD<GATKRead> javaRDD, SAMFileHeader sAMFileHeader) {
        final String referenceUri = hasReference() ? referenceArguments.getReferencePath().toAbsolutePath().toUri().toString() : null;
        final ReadsWriteFormat format = shardedOutput ? ReadsWriteFormat.SHARDED : ReadsWriteFormat.SINGLE;
        try {
            ReadsSparkSink.writeReads(javaSparkContext, str, referenceUri, javaRDD, sAMFileHeader, format, getRecommendedNumReducers());
        } catch (IOException e) {
            throw new UserException.CouldNotCreateOutputFile(str, "writing failed", e);
        }
    }

    /**
     * @return the user-supplied reducer count when it is non-zero; otherwise one more than the
     *         size reported for {@link #getReadSourceName()} divided by
     *         {@link #getTargetPartitionSize()}
     */
    public int getRecommendedNumReducers() {
        if (numReducers != 0) {
            return numReducers;
        }
        return 1 + (int) (BucketUtils.dirSize(getReadSourceName()) / getTargetPartitionSize());
    }

    /**
     * @return the input size, in bytes, that one reducer should cover (10 MiB by default)
     */
    public int getTargetPartitionSize() {
        return 10 * 1024 * 1024;
    }

    /** @return {@code true} if any of the read files is recognized as a CRAM file */
    private boolean hasCramInput() {
        return readArguments.getReadFiles().stream().anyMatch(IOUtils::isCramFile);
    }

    /**
     * @return the merged read filter for this tool, built against the header of the reads input
     */
    public ReadFilter makeReadFilter() {
        return makeReadFilter(getHeaderForReads());
    }

    /**
     * Builds the merged read filter from the read-filter plugin registered with the command-line parser.
     *
     * <p>NOTE(review): decompiler metadata indicates this method was declared {@code protected}
     * in the original source; the wider visibility is kept so existing callers still compile.
     *
     * @param sAMFileHeader the header handed to the plugin when merging the filters
     * @return the merged read filter
     */
    public ReadFilter makeReadFilter(SAMFileHeader sAMFileHeader) {
        final GATKReadFilterPluginDescriptor readFilterPlugin =
                (GATKReadFilterPluginDescriptor) getCommandLineParser().getPluginDescriptor(GATKReadFilterPluginDescriptor.class);
        return readFilterPlugin.getMergedReadFilter(sAMFileHeader);
    }

    /**
     * @return the read filters enabled by default: a single {@link WellformedReadFilter}
     */
    public List<ReadFilter> getDefaultReadFilters() {
        return Arrays.asList(new WellformedReadFilter());
    }

    /**
     * NOTE(review): decompiler metadata indicates this method was declared {@code protected}
     * in the original source; the wider visibility is kept so existing callers still compile.
     *
     * @return the name of the reads input, or {@code null} if there is none
     */
    public String getReadSourceName() {
        return readInput;
    }

    /**
     * @return the reference source, or {@code null} if there is no reference
     */
    public ReferenceMultiSource getReference() {
        return referenceSource;
    }

    /**
     * @return the intervals for this run after {@link #editIntervals(List)} was applied;
     *         may be {@code null}
     */
    public List<SimpleInterval> getIntervals() {
        return intervals;
    }

    /**
     * Initializes the tool inputs, validates them, then runs the tool.
     *
     * <p>NOTE(review): decompiler metadata indicates this method was declared {@code protected}
     * in the original source; the wider visibility is kept so existing callers still compile.
     *
     * @param javaSparkContext the Spark context for this run
     */
    @Override // org.broadinstitute.hellbender.engine.spark.SparkCommandLineProgram
    public void runPipeline(JavaSparkContext javaSparkContext) {
        initializeToolInputs(javaSparkContext);
        validateToolInputs();
        runTool(javaSparkContext);
    }

    /**
     * Initializes every input. The reference goes first because {@link #initializeReads} passes
     * the reference name along when a reference is present; intervals go last because they need
     * a sequence dictionary from the reference or the reads.
     */
    private void initializeToolInputs(JavaSparkContext javaSparkContext) {
        initializeReference();
        initializeReads(javaSparkContext);
        initializeFeatures();
        initializeIntervals();
    }

    /**
     * Sets up the reads source and header when exactly one reads input was given; does nothing
     * when none was given.
     *
     * @throws UserException if more than one reads input was given
     */
    private void initializeReads(JavaSparkContext javaSparkContext) {
        final List<String> readFileNames = readArguments.getReadFilesNames();
        if (readFileNames.isEmpty()) {
            return;
        }
        if (readFileNames.size() != 1) {
            throw new UserException("Sorry, we only support a single reads input for spark tools for now.");
        }
        readInput = readFileNames.get(0);
        readsSource = new ReadsSparkSource(javaSparkContext, readArguments.getReadValidationStringency());
        readsHeader = readsSource.getHeader(readInput, hasReference() ? referenceArguments.getReferenceFileName() : null);
    }

    /**
     * Sets up the reference source and its sequence dictionary when a reference name was given.
     *
     * @throws UserException.MissingReferenceDictFile if no dictionary is obtained for the reference
     */
    private void initializeReference() {
        final String referenceName = referenceArguments.getReferenceFileName();
        if (referenceName == null) {
            return;
        }
        referenceSource = new ReferenceMultiSource(referenceName, getReferenceWindowFunction());
        // readsHeader is only assigned by initializeReads(), which initializeToolInputs() calls
        // after this method, so on a first run the argument below is null.
        final SAMSequenceDictionary readsDictionary = readsHeader != null ? readsHeader.getSequenceDictionary() : null;
        referenceDictionary = referenceSource.getReferenceSequenceDictionary(readsDictionary);
        if (referenceDictionary == null) {
            throw new UserException.MissingReferenceDictFile(referenceName);
        }
    }

    /** Creates the feature manager for this tool, discarding it again if it is empty. */
    void initializeFeatures() {
        features = new FeatureManager(this);
        if (features.isEmpty()) {
            features = null;
        }
    }

    /**
     * Resolves the command-line intervals against the best available sequence dictionary and
     * then passes the result (or {@code null} if none were specified) through
     * {@link #editIntervals(List)}.
     *
     * @throws UserException if intervals were specified but no sequence dictionary is available
     */
    private void initializeIntervals() {
        if (intervalArgumentCollection.intervalsSpecified()) {
            final SAMSequenceDictionary dictionary = getBestAvailableSequenceDictionary();
            if (dictionary == null) {
                throw new UserException("We require at least one input source that has a sequence dictionary (reference or reads) when intervals are specified");
            }
            intervals = intervalArgumentCollection.getIntervals(dictionary);
        }
        intervals = editIntervals(intervals);
    }

    /**
     * Hook that lets subclasses adjust the intervals before they are stored.
     *
     * @param list the intervals resolved from the command line, or {@code null} if none were specified
     * @return the intervals to use; this default implementation returns {@code list} unchanged
     */
    protected List<SimpleInterval> editIntervals(List<SimpleInterval> list) {
        return list;
    }

    /**
     * Checks the reference dictionary against the reads dictionary, unless validation was
     * disabled on the command line or either input is missing. CRAM inputs use the dedicated
     * CRAM validation.
     */
    private void validateToolInputs() {
        if (disableSequenceDictionaryValidation || !hasReference() || !hasReads()) {
            return;
        }
        final SAMSequenceDictionary readsDictionary = readsHeader.getSequenceDictionary();
        if (hasCramInput()) {
            SequenceDictionaryUtils.validateCRAMDictionaryAgainstReference(referenceDictionary, readsDictionary);
        } else {
            SequenceDictionaryUtils.validateDictionaries("reference", referenceDictionary, "reads", readsDictionary);
        }
    }

    /**
     * Runs the tool itself; invoked by {@link #runPipeline(JavaSparkContext)} after the inputs
     * have been initialized and validated.
     *
     * @param javaSparkContext the Spark context for this run
     */
    protected abstract void runTool(JavaSparkContext javaSparkContext);

    // The decompiled copy of the compiler-synthesized $deserializeLambda$ method was removed: it was
    // not valid Java, and the compiler emits that method itself for the serializable lambda in getReads().
}
