package org.broadinstitute.hellbender.tools.copynumber;

import htsjdk.samtools.SAMSequenceDictionary;
import htsjdk.variant.variantcontext.writer.VariantContextWriter;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.broadinstitute.barclay.argparser.Argument;
import org.broadinstitute.barclay.argparser.BetaFeature;
import org.broadinstitute.barclay.argparser.CommandLineProgramProperties;
import org.broadinstitute.barclay.help.DocumentedFeature;
import org.broadinstitute.hellbender.cmdline.programgroups.CopyNumberProgramGroup;
import org.broadinstitute.hellbender.engine.GATKTool;
import org.broadinstitute.hellbender.exceptions.GATKException;
import org.broadinstitute.hellbender.exceptions.UserException;
import org.broadinstitute.hellbender.tools.copynumber.formats.collections.AbstractLocatableCollection;
import org.broadinstitute.hellbender.tools.copynumber.formats.collections.BaselineCopyNumberCollection;
import org.broadinstitute.hellbender.tools.copynumber.formats.collections.CopyNumberPosteriorDistributionCollection;
import org.broadinstitute.hellbender.tools.copynumber.formats.collections.IntegerCopyNumberSegmentCollection;
import org.broadinstitute.hellbender.tools.copynumber.formats.collections.SimpleIntervalCollection;
import org.broadinstitute.hellbender.tools.copynumber.formats.metadata.LocatableMetadata;
import org.broadinstitute.hellbender.tools.copynumber.formats.metadata.SampleLocatableMetadata;
import org.broadinstitute.hellbender.tools.copynumber.formats.records.CopyNumberPosteriorDistribution;
import org.broadinstitute.hellbender.tools.copynumber.formats.records.IntervalCopyNumberGenotypingData;
import org.broadinstitute.hellbender.tools.copynumber.gcnv.GermlineCNVIntervalVariantComposer;
import org.broadinstitute.hellbender.tools.copynumber.gcnv.GermlineCNVNamingConstants;
import org.broadinstitute.hellbender.tools.copynumber.gcnv.GermlineCNVSegmentVariantComposer;
import org.broadinstitute.hellbender.tools.copynumber.gcnv.IntegerCopyNumberState;
import org.broadinstitute.hellbender.tools.funcotator.vcfOutput.VcfOutputRenderer;
import org.broadinstitute.hellbender.utils.SimpleInterval;
import org.broadinstitute.hellbender.utils.Utils;
import org.broadinstitute.hellbender.utils.io.IOUtils;
import org.broadinstitute.hellbender.utils.io.Resource;
import org.broadinstitute.hellbender.utils.python.PythonScriptExecutor;

/**
 * Postprocesses the output of GermlineCNVCaller for a single sample and writes VCF files.
 *
 * <p>For the sample selected by {@code --sample-index}, this tool:</p>
 * <ol>
 *     <li>reads the interval list of every calls and model shard, checks that all shards share the same
 *     SAM sequence dictionary, sorts the shards into genomic order, and checks that the sorted calls and model
 *     shards have matching interval lists;</li>
 *     <li>writes a per-interval VCF ({@code --output-genotyped-intervals}) built from the copy-number posterior
 *     and baseline copy-number files of each calls shard;</li>
 *     <li>if {@code --output-genotyped-segments} is given, runs the {@code segment_gcnv_calls.py} Python script
 *     (the {@code gcnvkernel} package is checked for at startup) and writes the resulting segments as a VCF.</li>
 * </ol>
 */
@CommandLineProgramProperties(summary = "Postprocesses the output of GermlineCNVCaller and generates VCF files.", oneLineSummary = "Postprocesses the output of GermlineCNVCaller and generates VCF files.", programGroup = CopyNumberProgramGroup.class)
@DocumentedFeature
@BetaFeature
public final class PostprocessGermlineCNVCalls extends GATKTool {
    private static final Logger logger = LogManager.getLogger(PostprocessGermlineCNVCalls.class);

    public static final String SEGMENT_GERMLINE_CNV_CALLS_PYTHON_SCRIPT = "segment_gcnv_calls.py";
    public static final String CALLS_SHARD_PATH_LONG_NAME = "calls-shard-path";
    public static final String MODEL_SHARD_PATH_LONG_NAME = "model-shard-path";
    public static final String CONTIG_PLOIDY_CALLS_LONG_NAME = "contig-ploidy-calls";
    public static final String SAMPLE_INDEX_LONG_NAME = "sample-index";
    public static final String OUTPUT_INTERVALS_VCF_LONG_NAME = "output-genotyped-intervals";
    public static final String OUTPUT_SEGMENTS_VCF_LONG_NAME = "output-genotyped-segments";
    public static final String AUTOSOMAL_REF_COPY_NUMBER_LONG_NAME = "autosomal-ref-copy-number";
    public static final String ALLOSOMAL_CONTIG_LONG_NAME = "allosomal-contig";

    @Argument(doc = "List of paths to GermlineCNVCaller call directories.", fullName = CALLS_SHARD_PATH_LONG_NAME, minElements = 1)
    private List<File> unsortedCallsShardPaths;

    @Argument(doc = "List of paths to GermlineCNVCaller model directories.", fullName = MODEL_SHARD_PATH_LONG_NAME, minElements = 1)
    private List<File> unsortedModelShardPaths;

    @Argument(doc = "Path to contig-ploidy calls directory (output of DetermineGermlineContigPloidy).", fullName = CONTIG_PLOIDY_CALLS_LONG_NAME)
    private File contigPloidyCallsPath;

    @Argument(doc = "Sample index in the call-set (must be contained in all shards).", fullName = SAMPLE_INDEX_LONG_NAME, minValue = 0.0d)
    private int sampleIndex;

    /** Optional; defaults to an empty list so that building {@link #allosomalContigSet} never sees {@code null}. */
    @Argument(doc = "Contigs to treat as allosomal (i.e. choose their reference copy-number allele according to the sample karyotype).", fullName = ALLOSOMAL_CONTIG_LONG_NAME, optional = true)
    private List<String> allosomalContigList = new ArrayList<>();

    @Argument(doc = "Reference copy-number on autosomal intervals.", fullName = AUTOSOMAL_REF_COPY_NUMBER_LONG_NAME, minValue = 0.0d)
    private int refAutosomalCopyNumber = 2;

    @Argument(doc = "Output intervals VCF file.", fullName = OUTPUT_INTERVALS_VCF_LONG_NAME)
    private File outputIntervalsVCFFile;

    /** When {@code null}, segmentation (and the Python dependency) is skipped. */
    @Argument(doc = "Output segments VCF file.", fullName = OUTPUT_SEGMENTS_VCF_LONG_NAME, optional = true)
    private File outputSegmentsVCFFile = null;

    // State computed in onTraversalStart(); the "sorted" lists are in genomic order of their shards.
    private List<SimpleIntervalCollection> sortedIntervalCollections;
    private String sampleName;
    private int numShards;
    private IntegerCopyNumberState refAutosomalIntegerCopyNumberState;
    private Set<String> allosomalContigSet;
    private List<File> sortedCallsShardPaths;
    private List<File> sortedModelShardPaths;

    /**
     * Checks for the {@code gcnvkernel} Python package, but only when segments output was requested
     * (the package is used only by the segmentation step).
     */
    @Override
    public void onStartup() {
        super.onStartup();
        if (outputSegmentsVCFFile != null) {
            PythonScriptExecutor.checkPythonEnvironmentForPackage("gcnvkernel");
        }
    }

    /**
     * Validates the input shards, sorts them into genomic order, validates the allosomal contigs and
     * determines the (shard-consistent) sample name.
     */
    @Override
    public void onTraversalStart() {
        numShards = unsortedCallsShardPaths.size();
        Utils.validateArg(unsortedModelShardPaths.size() == numShards,
                "The number of input model shards must match the number of input call shards.");

        final List<SimpleIntervalCollection> unsortedCallsIntervalCollections =
                getIntervalCollectionsFromPaths(unsortedCallsShardPaths);
        final List<SimpleIntervalCollection> unsortedModelIntervalCollections =
                getIntervalCollectionsFromPaths(unsortedModelShardPaths);

        // the first calls shard's dictionary is the reference that every calls and model shard must match
        final SAMSequenceDictionary sequenceDictionary =
                unsortedCallsIntervalCollections.get(0).getMetadata().getSequenceDictionary();
        Utils.validateArg(unsortedCallsIntervalCollections.stream()
                        .allMatch(c -> c.getMetadata().getSequenceDictionary().equals(sequenceDictionary)),
                "The SAM sequence dictionary is not the same for all of the call shards.");
        Utils.validateArg(unsortedModelIntervalCollections.stream()
                        .allMatch(c -> c.getMetadata().getSequenceDictionary().equals(sequenceDictionary)),
                "The SAM sequence dictionary is either not the same for all of the model shards, or is different from the SAM sequence dictionary of calls shards.");

        // sort shards (paths and interval collections alike) into genomic order
        final List<Integer> callsShardSortOrder =
                AbstractLocatableCollection.getShardedCollectionSortOrder(unsortedCallsIntervalCollections);
        final List<Integer> modelShardSortOrder =
                AbstractLocatableCollection.getShardedCollectionSortOrder(unsortedModelIntervalCollections);
        sortedCallsShardPaths = reorder(unsortedCallsShardPaths, callsShardSortOrder);
        sortedModelShardPaths = reorder(unsortedModelShardPaths, modelShardSortOrder);
        final List<SimpleIntervalCollection> sortedCallsIntervalCollections =
                reorder(unsortedCallsIntervalCollections, callsShardSortOrder);
        final List<SimpleIntervalCollection> sortedModelIntervalCollections =
                reorder(unsortedModelIntervalCollections, modelShardSortOrder);
        Utils.validateArg(sortedCallsIntervalCollections.equals(sortedModelIntervalCollections),
                "The interval lists found in model and call shards do not match. Make sure that the calls and model paths are provided in matching order.");
        sortedIntervalCollections = sortedCallsIntervalCollections;

        // allosomal contigs must be a subset of the call-set's contigs
        final Set<String> allContigs = sequenceDictionary.getSequences().stream()
                .map(record -> record.getSequenceName())
                .collect(Collectors.toSet());
        allosomalContigSet = new HashSet<>(allosomalContigList);
        if (allosomalContigSet.isEmpty()) {
            logger.warn(String.format("Allosomal contigs were not specified; setting ref copy-number allele to (%d) for all intervals.",
                    refAutosomalCopyNumber));
        } else {
            Utils.validateArg(allContigs.containsAll(allosomalContigSet),
                    String.format("The specified allosomal contigs must be contained in the SAM sequence dictionary of the call-set (specified allosomal contigs: %s, all contigs: %s)",
                            allosomalContigSet.stream().collect(Collectors.joining(", ",
                                    VcfOutputRenderer.START_TRANSCRIPT_DELIMITER, VcfOutputRenderer.END_TRANSCRIPT_DELIMITER)),
                            allContigs.stream().collect(Collectors.joining(", ",
                                    VcfOutputRenderer.START_TRANSCRIPT_DELIMITER, VcfOutputRenderer.END_TRANSCRIPT_DELIMITER))));
        }

        sampleName = getShardSampleName(0);
        Utils.validate(IntStream.range(1, numShards)
                        .mapToObj(this::getShardSampleName)
                        .allMatch(name -> name.equals(sampleName)),
                "The sample name is not the same for all of the shards.");
        refAutosomalIntegerCopyNumberState = new IntegerCopyNumberState(refAutosomalCopyNumber);
    }

    /** Writes the intervals VCF and then (if requested) the segments VCF. */
    @Override
    public void traverse() {
        generateIntervalsVCFFileFromAllShards();
        generateSegmentsVCFFileFromAllShards();
    }

    /** Writes one VCF record per interval, shard by shard in sorted order; the writer is closed even on failure. */
    private void generateIntervalsVCFFileFromAllShards() {
        logger.info("Generating intervals VCF file...");
        try (final VariantContextWriter intervalsVCFWriter = createVCFWriter(outputIntervalsVCFFile)) {
            final GermlineCNVIntervalVariantComposer composer = new GermlineCNVIntervalVariantComposer(
                    intervalsVCFWriter, sampleName, refAutosomalIntegerCopyNumberState, allosomalContigSet);
            composer.composeVariantContextHeader(getDefaultToolVCFHeaderLines());
            for (int shardIndex = 0; shardIndex < numShards; shardIndex++) {
                logger.info(String.format("Analyzing shard %d...", shardIndex));
                composer.writeAll(getShardIntervalCopyNumberPosteriorData(shardIndex));
            }
        }
    }

    /** Loads the interval list ({@code interval_list.tsv}) found in each shard directory, preserving input order. */
    private static List<SimpleIntervalCollection> getIntervalCollectionsFromPaths(final List<File> shardPaths) {
        return shardPaths.stream()
                .map(shardPath -> new SimpleIntervalCollection(getIntervalFileFromShardDirectory(shardPath)))
                .collect(Collectors.toList());
    }

    /**
     * Returns {@code items} rearranged so that element {@code k} of the result is {@code items.get(order.get(k))}.
     */
    private static <T> List<T> reorder(final List<T> items, final List<Integer> order) {
        return order.stream().map(items::get).collect(Collectors.toList());
    }

    /**
     * Reads the first line of the sample-name text file of the given (sorted) calls shard.
     *
     * @throws UserException.BadInput if the file cannot be read or is empty
     */
    private String getShardSampleName(final int shardIndex) {
        final File sampleNameTextFile = getSampleNameTextFile(sortedCallsShardPaths.get(shardIndex), sampleIndex);
        final String name;
        try (final BufferedReader reader = new BufferedReader(new FileReader(sampleNameTextFile))) {
            name = reader.readLine();
        } catch (final IOException e) {
            throw new UserException.BadInput(String.format("Could not read the sample name text file at %s.",
                    sampleNameTextFile.getAbsolutePath()), e);
        }
        // readLine() returns null for an empty file; fail here rather than with an NPE during comparison
        if (name == null) {
            throw new UserException.BadInput(String.format("The sample name text file at %s is empty.",
                    sampleNameTextFile.getAbsolutePath()));
        }
        return name;
    }

    /**
     * Zips the intervals, copy-number posteriors and baseline copy numbers of one (sorted) calls shard.
     * Validates that both per-sample files name the expected sample and have one entry per shard interval.
     */
    private List<IntervalCopyNumberGenotypingData> getShardIntervalCopyNumberPosteriorData(final int shardIndex) {
        final File callsShardPath = sortedCallsShardPaths.get(shardIndex);

        final CopyNumberPosteriorDistributionCollection posteriorCollection =
                new CopyNumberPosteriorDistributionCollection(getSampleCopyNumberPosteriorFile(callsShardPath, sampleIndex));
        final String posteriorSampleName = posteriorCollection.getMetadata().getSampleName();
        Utils.validate(posteriorSampleName.equals(sampleName),
                String.format("Sample name found in the header of copy-number posterior file for shard %d different from the expected sample name (found: %s, expected: %s).",
                        shardIndex, posteriorSampleName, sampleName));

        final BaselineCopyNumberCollection baselineCollection =
                new BaselineCopyNumberCollection(getSampleBaselineCopyNumberFile(callsShardPath, sampleIndex));
        final String baselineSampleName = baselineCollection.getMetadata().getSampleName();
        Utils.validate(baselineSampleName.equals(sampleName),
                String.format("Sample name found in the header of baseline copy-number file for shard %d different from the expected sample name (found: %s, expected: %s).",
                        shardIndex, baselineSampleName, sampleName));

        final List<SimpleInterval> intervals = sortedIntervalCollections.get(shardIndex).getIntervals();
        final List<CopyNumberPosteriorDistribution> posteriors = posteriorCollection.getRecords();
        Utils.validate(intervals.size() == posteriors.size(),
                String.format("The number of entries in the copy-number posterior file for shard %d does not match the number of entries in the shard interval list (posterior list size: %d, interval list size: %d)",
                        shardIndex, posteriors.size(), intervals.size()));
        final List<IntegerCopyNumberState> baselines = baselineCollection.getRecords();
        Utils.validate(intervals.size() == baselines.size(),
                String.format("The number of entries in the baseline copy-number file for shard %d does not match the number of entries in the shard interval list (baseline copy-number list size: %d, interval list size: %d)",
                        shardIndex, baselines.size(), intervals.size()));

        return IntStream.range(0, intervals.size())
                .mapToObj(k -> new IntervalCopyNumberGenotypingData(intervals.get(k), posteriors.get(k), baselines.get(k)))
                .collect(Collectors.toList());
    }

    /** {@code <shard>/<SAMPLE_PREFIX><index>/<COPY_NUMBER_POSTERIOR_FILE_NAME>}. */
    private static File getSampleCopyNumberPosteriorFile(final File shardPath, final int index) {
        return Paths.get(shardPath.getAbsolutePath(), GermlineCNVNamingConstants.SAMPLE_PREFIX + index,
                GermlineCNVNamingConstants.COPY_NUMBER_POSTERIOR_FILE_NAME).toFile();
    }

    /** {@code <shard>/<SAMPLE_PREFIX><index>/<BASELINE_COPY_NUMBER_FILE_NAME>}. */
    private static File getSampleBaselineCopyNumberFile(final File shardPath, final int index) {
        return Paths.get(shardPath.getAbsolutePath(), GermlineCNVNamingConstants.SAMPLE_PREFIX + index,
                GermlineCNVNamingConstants.BASELINE_COPY_NUMBER_FILE_NAME).toFile();
    }

    /** {@code <shard>/<SAMPLE_PREFIX><index>/<SAMPLE_NAME_TXT_FILE>}. */
    private static File getSampleNameTextFile(final File shardPath, final int index) {
        return Paths.get(shardPath.getAbsolutePath(), GermlineCNVNamingConstants.SAMPLE_PREFIX + index,
                GermlineCNVNamingConstants.SAMPLE_NAME_TXT_FILE).toFile();
    }

    /** The interval list stored directly inside a calls or model shard directory. */
    private static File getIntervalFileFromShardDirectory(final File shardPath) {
        return new File(shardPath, "interval_list.tsv");
    }

    /**
     * If a segments output was requested, runs the segmentation Python script into a temporary directory,
     * validates the sample name of its output and writes the segments VCF; otherwise logs and returns.
     */
    private void generateSegmentsVCFFileFromAllShards() {
        if (outputSegmentsVCFFile == null) {
            logger.info("No segments output VCF file was provided -- skipping segmentation.");
            return;
        }
        logger.info("Generating segments VCF file...");
        final File segmentationOutputDir = IOUtils.createTempDir("gcnv-segmented-calls");
        if (!executeSegmentGermlineCNVCallsPythonScript(sampleIndex, contigPloidyCallsPath,
                sortedCallsShardPaths, sortedModelShardPaths, segmentationOutputDir)) {
            throw new UserException("Python return code was non-zero.");
        }

        final IntegerCopyNumberSegmentCollection segmentCollection =
                new IntegerCopyNumberSegmentCollection(getCopyNumberSegmentsFile(segmentationOutputDir, sampleIndex));
        final String segmentsSampleName = ((SampleLocatableMetadata) segmentCollection.getMetadata()).getSampleName();
        Utils.validate(segmentsSampleName.equals(sampleName),
                String.format("Sample name found in the header of copy-number segments file is different from the expected sample name (found: %s, expected: %s).",
                        segmentsSampleName, sampleName));

        try (final VariantContextWriter segmentsVCFWriter = createVCFWriter(outputSegmentsVCFFile)) {
            final GermlineCNVSegmentVariantComposer composer = new GermlineCNVSegmentVariantComposer(
                    segmentsVCFWriter, sampleName, refAutosomalIntegerCopyNumberState, allosomalContigSet);
            composer.composeVariantContextHeader(getDefaultToolVCFHeaderLines());
            composer.writeAll(segmentCollection.getRecords());
        }
    }

    /**
     * Runs {@link #SEGMENT_GERMLINE_CNV_CALLS_PYTHON_SCRIPT} on the sorted shards.
     *
     * @return {@code true} if the script succeeded (as reported by {@link PythonScriptExecutor#executeScript})
     * @throws GATKException.ShouldNeverReachHereException if any argument (or shard path) is {@code null}
     */
    private static boolean executeSegmentGermlineCNVCallsPythonScript(final int index,
                                                                      final File ploidyCallsDir,
                                                                      final List<File> callsShards,
                                                                      final List<File> modelShards,
                                                                      final File outputDir) {
        try {
            Utils.nonNull(ploidyCallsDir);
            Utils.nonNull(callsShards);
            Utils.nonNull(modelShards);
            Utils.nonNull(outputDir);
            callsShards.forEach(Utils::nonNull);
            modelShards.forEach(Utils::nonNull);

            final PythonScriptExecutor executor = new PythonScriptExecutor(true);
            final List<String> arguments = new ArrayList<>();
            arguments.add("--ploidy_calls_path");
            arguments.add(ploidyCallsDir.getAbsolutePath());
            arguments.add("--model_shards");
            modelShards.forEach(shard -> arguments.add(shard.getAbsolutePath()));
            arguments.add("--calls_shards");
            callsShards.forEach(shard -> arguments.add(shard.getAbsolutePath()));
            arguments.add("--output_path");
            arguments.add(outputDir.getAbsolutePath());
            arguments.add("--sample_index");
            arguments.add(String.valueOf(index));
            return executor.executeScript(
                    new Resource(SEGMENT_GERMLINE_CNV_CALLS_PYTHON_SCRIPT, PostprocessGermlineCNVCalls.class),
                    (List<String>) null, arguments);
        } catch (final IllegalArgumentException e) {
            // arguments are validated by the caller, so a null here is a programming error
            throw new GATKException.ShouldNeverReachHereException(e);
        }
    }

    /** {@code <outputDir>/<SAMPLE_PREFIX><index>/<COPY_NUMBER_SEGMENTS_FILE_NAME>}, as written by the Python script. */
    private static File getCopyNumberSegmentsFile(final File outputDir, final int index) {
        return Paths.get(outputDir.getAbsolutePath(), GermlineCNVNamingConstants.SAMPLE_PREFIX + index,
                GermlineCNVNamingConstants.COPY_NUMBER_SEGMENTS_FILE_NAME).toFile();
    }
}
