package org.broadinstitute.hellbender.tools.walkers.vqsr;

import htsjdk.variant.variantcontext.VariantContext;
import htsjdk.variant.variantcontext.VariantContextBuilder;
import htsjdk.variant.variantcontext.writer.VariantContextWriter;
import htsjdk.variant.vcf.VCFHeader;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.UnsupportedEncodingException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Scanner;
import java.util.TreeSet;
import java.util.concurrent.TimeUnit;
import java.util.stream.StreamSupport;
import org.broadinstitute.barclay.argparser.Advanced;
import org.broadinstitute.barclay.argparser.Argument;
import org.broadinstitute.barclay.argparser.CommandLineProgramProperties;
import org.broadinstitute.barclay.argparser.ExperimentalFeature;
import org.broadinstitute.barclay.argparser.Hidden;
import org.broadinstitute.hellbender.cmdline.ExomeStandardArgumentDefinitions;
import org.broadinstitute.hellbender.engine.FeatureContext;
import org.broadinstitute.hellbender.engine.FeatureInput;
import org.broadinstitute.hellbender.engine.ReadsContext;
import org.broadinstitute.hellbender.engine.ReferenceContext;
import org.broadinstitute.hellbender.engine.VariantWalker;
import org.broadinstitute.hellbender.exceptions.GATKException;
import org.broadinstitute.hellbender.tools.GetSampleName;
import org.broadinstitute.hellbender.tools.copynumber.ModelSegments;
import org.broadinstitute.hellbender.tools.spark.sv.utils.SVFastqUtils;
import org.broadinstitute.hellbender.tools.walkers.genotyper.StandardCallerArgumentCollection;
import org.broadinstitute.hellbender.utils.codecs.gencode.GencodeGtfFeature;
import org.broadinstitute.hellbender.utils.python.StreamingPythonScriptExecutor;
import org.broadinstitute.hellbender.utils.read.ReadUtils;
import org.broadinstitute.hellbender.utils.runtime.AsynchronousStreamWriterService;
import org.broadinstitute.hellbender.utils.variant.GATKVCFConstants;
import org.broadinstitute.hellbender.utils.variant.GATKVCFHeaderLines;
import picard.cmdline.programgroups.VariantEvaluationProgramGroup;

@CommandLineProgramProperties(summary = NeuralNetInference.USAGE_SUMMARY, oneLineSummary = NeuralNetInference.USAGE_ONE_LINE_SUMMARY, programGroup = VariantEvaluationProgramGroup.class)
@ExperimentalFeature
/* loaded from: input_file:org/broadinstitute/hellbender/tools/walkers/vqsr/NeuralNetInference.class */
public class NeuralNetInference extends VariantWalker {
    /** Platform line separator; appended to every command string sent to the Python process. */
    private static final String NL = String.format("%n", new Object[0]);

    /** One-line tool description referenced by the {@code @CommandLineProgramProperties} annotation. */
    static final String USAGE_ONE_LINE_SUMMARY = "Apply 1d Convolutional Neural Net to filter annotated variants";

    /** Longer tool description referenced by the {@code @CommandLineProgramProperties} annotation. */
    static final String USAGE_SUMMARY = "Annotate a VCF with scores from 1d Convolutional Neural Network (CNN).The CNN will look at the reference sequence and variant annotations to determine a Log Odds Score for each variant.";

    // Column positions within a tab-separated score-file line. They appear to mirror the
    // literal indices 0..4 used when the score file is read back; not referenced elsewhere
    // in the visible code.
    private static final int CONTIG_INDEX = 0;
    private static final int POS_INDEX = 1;
    private static final int REF_INDEX = 2;
    private static final int ALT_INDEX = 3;
    private static final int KEY_INDEX = 4;

    /** Java-side write end of the FIFO through which variant records are streamed to Python. */
    private FileOutputStream fifoWriter;

    /** Temporary file that the Python process writes scores to; read back after traversal. */
    private File scoreFile;

    @Argument(fullName = "output", shortName = "O", doc = "Output file")
    private String outputFile = null;

    @Argument(fullName = "architecture", shortName = "a", doc = "Neural Net architecture and weights hd5 file", optional = false)
    private String architecture = null;

    // NOTE(review): the argument name, minimum value and default are expressed through
    // constants of unrelated classes; this looks like a decompiler constant-matching
    // artifact -- confirm the intended literal values against the original source.
    @Argument(fullName = ModelSegments.WINDOW_SIZE_LONG_NAME, shortName = "ws", doc = "Neural Net input window size", minValue = StandardCallerArgumentCollection.DEFAULT_CONTAMINATION_FRACTION, optional = true)
    private int windowSize = ReadUtils.SAM_SECOND_OF_PAIR_FLAG;

    // NOTE(review): default taken from an unrelated SAM-flag constant -- presumably a
    // decompiler artifact; confirm the intended numeric default (must lie within 1..4096).
    @Advanced
    @Argument(fullName = "inference-batch-size", shortName = "ibs", doc = "Size of batches for python to do inference on.", minValue = 1.0d, maxValue = 4096.0d, optional = true)
    private int inferenceBatchSize = ReadUtils.SAM_NOT_PRIMARY_ALIGNMENT_FLAG;

    // NOTE(review): default taken from an unrelated SAM-flag constant -- presumably a
    // decompiler artifact; confirm the intended numeric default (must lie within 1..8192).
    @Advanced
    @Argument(fullName = "transfer-batch-size", shortName = "tbs", doc = "Size of data to queue for python streaming.", minValue = 1.0d, maxValue = 8192.0d, optional = true)
    private int transferBatchSize = ReadUtils.SAM_READ_FAILS_VENDOR_QUALITY_CHECK_FLAG;

    @Hidden
    @Argument(fullName = "enable-journal", shortName = "journal", doc = "Enable streaming process journal.", optional = true)
    private boolean enableJournal = false;

    @Hidden
    @Argument(fullName = "keep-temp-file", shortName = "ktf", doc = "Keep the temporary file that python writes scores to.", optional = true)
    private boolean keepTempFile = false;

    /** Executor that owns the streaming Python process used for scoring. */
    final StreamingPythonScriptExecutor pythonExecutor = new StreamingPythonScriptExecutor(true);

    /** Background writer that pushes a batch of records into the FIFO; created in onTraversalStart. */
    private AsynchronousStreamWriterService<String> asyncWriter = null;

    /** Records accumulated for the batch currently being filled; replaced in onTraversalStart and after each hand-off. */
    private List<String> batchList = new ArrayList(this.inferenceBatchSize);

    /** Number of records currently held in {@link #batchList}. */
    private int curBatchSize = 0;

    // Reference window extents passed to ReferenceContext.setWindow in apply().
    // NOTE(review): these initializers run at construction time, i.e. with whatever value
    // windowSize has at that moment -- a window size supplied on the command line is
    // presumably assigned later and would not be picked up by these initial values.
    private int windowEnd = this.windowSize / 2;
    private int windowStart = (this.windowSize / 2) - 1;

    /** True while a batch handed to the asynchronous writer has not yet been waited for. */
    private boolean waitforBatchCompletion = false;

    /**
     * Cross-checks the two batch-size arguments after individual argument parsing.
     *
     * <p>Decompiler note: the access modifier of this override was changed from protected.
     *
     * @return a single-element array with an error message when the inference batch size
     *         exceeds the transfer batch size; {@code null} when the arguments are consistent
     */
    @Override // org.broadinstitute.hellbender.cmdline.CommandLineProgram
    public String[] customCommandLineValidation() {
        final boolean batchSizesConsistent = inferenceBatchSize <= transferBatchSize;
        return batchSizesConsistent
                ? null
                : new String[] {"Inference batch size must be less than or equal to transfer batch size."};
    }

    /**
     * Declares that this tool cannot run without a reference.
     *
     * @return always {@code true}; {@code apply} reads reference bases around every variant
     */
    @Override // org.broadinstitute.hellbender.engine.GATKTool
    public boolean requiresReference() {
        return true;
    }

    /**
     * Starts the Python process and prepares both data channels before traversal begins.
     *
     * <p>Steps, in order: derive the reference window extents from the final
     * {@code windowSize}; start the Python executor; open the FIFO on both ends (Python reads,
     * Java writes); create the asynchronous FIFO writer and an empty batch list; create the
     * temporary score file and have Python open it; finally import the Python modules and
     * load the model named by {@code architecture}.
     *
     * @throws GATKException if the FIFO cannot be opened for writing or the temporary score
     *         file cannot be created
     */
    @Override // org.broadinstitute.hellbender.engine.GATKTool
    public void onTraversalStart() {
        // The field initializers for the window extents run during construction, before any
        // command-line value can have been assigned to windowSize. Recompute them here so a
        // user-supplied window size is actually reflected in the reference window.
        windowEnd = windowSize / 2;
        windowStart = (windowSize / 2) - 1;

        pythonExecutor.start(Collections.emptyList(), enableJournal);

        // Ask Python to open the read end asynchronously, then open the write end here.
        // Presumably asynchronous because opening one end of a FIFO blocks until the other
        // end is opened as well.
        final File fifoFile = pythonExecutor.getFIFOForWrite();
        pythonExecutor.sendAsynchronousCommand(String.format("fifoFile = open('%s', 'r')" + NL, fifoFile.getAbsolutePath()));
        try {
            fifoWriter = new FileOutputStream(fifoFile);
        } catch (IOException e) {
            throw new GATKException("Failure opening FIFO for writing", e);
        }

        // Collect the output of the asynchronous open command before issuing further commands.
        pythonExecutor.getAccumulatedOutput();
        asyncWriter = pythonExecutor.getAsynchronousStreamWriterService(fifoWriter, AsynchronousStreamWriterService.stringSerializer);
        batchList = new ArrayList<>(transferBatchSize);

        try {
            scoreFile = File.createTempFile(outputFile, ".temp");
            if (!keepTempFile) {
                scoreFile.deleteOnExit();
            }
            pythonExecutor.sendSynchronousCommand(String.format("tempFile = open('%s', 'w+')" + NL, scoreFile.getAbsolutePath()));
            pythonExecutor.sendSynchronousCommand("from keras.models import load_model" + NL);
            pythonExecutor.sendSynchronousCommand("import vqsr_cnn" + NL);
            pythonExecutor.sendSynchronousCommand(String.format("model = load_model('%s', custom_objects=vqsr_cnn.get_metric_dict())", architecture) + NL);
            logger.info("Loaded CNN architecture:" + architecture);
        } catch (IOException e) {
            throw new GATKException("Error when creating temp file and initializing python executor.", e);
        }
    }

    /**
     * Handles one variant: widens the reference context to the configured window and queues
     * the variant for transfer to the Python process.
     *
     * @param variantContext   the variant being visited
     * @param readsContext     reads overlapping the variant (not used here)
     * @param referenceContext reference bases around the variant; its window is modified
     * @param featureContext   other features overlapping the variant (not used here)
     */
    @Override // org.broadinstitute.hellbender.engine.VariantWalkerBase
    public void apply(VariantContext variantContext, ReadsContext readsContext, ReferenceContext referenceContext, FeatureContext featureContext) {
        final int leadingBases = windowStart;
        final int trailingBases = windowEnd;
        referenceContext.setWindow(leadingBases, trailingBases);
        transferToPythonViaFifo(variantContext, referenceContext);
    }

    /**
     * Serializes one variant into a tab-separated record and appends it to the current batch,
     * first handing the batch to Python when it has reached {@code transferBatchSize}.
     *
     * <p>The record consists of: the variant data string, the first {@code windowSize}
     * reference bases of the current window decoded as UTF-8, the INFO string, and a variant
     * type label, terminated by a newline.
     *
     * @param variantContext   the variant to serialize
     * @param referenceContext reference context whose window supplies the bases
     */
    private void transferToPythonViaFifo(VariantContext variantContext, ReferenceContext referenceContext) {
        final String variantData = getVariantDataString(variantContext);
        // copyOfRange zero-pads when fewer than windowSize bases are available.
        final byte[] windowBases = Arrays.copyOfRange(referenceContext.getBases(), 0, windowSize);
        // A Charset constant cannot fail to resolve, unlike a charset name string.
        final String referenceWindow = new String(windowBases, StandardCharsets.UTF_8);
        final String variantInfo = getVariantInfoString(variantContext);
        // NOTE(review): the SNP label comes from an unrelated argument-definition constant,
        // presumably a decompiler artifact for a plain string literal -- confirm.
        final String variantType = variantContext.isSNP()
                ? ExomeStandardArgumentDefinitions.SNP_FILE_SHORT_NAME
                : variantContext.isIndel() ? "INDEL" : "OTHER";
        final String record = String.format("%s\t%s\t%s\t%s\n", variantData, referenceWindow, variantInfo, variantType);

        if (curBatchSize == transferBatchSize) {
            // Only one batch may be in flight: wait for the previous one before starting the next.
            if (waitforBatchCompletion) {
                asyncWriter.waitForPreviousBatchCompletion(1L, TimeUnit.MINUTES);
                waitforBatchCompletion = false;
                pythonExecutor.getAccumulatedOutput();
            }
            executePythonCommand();
            waitforBatchCompletion = true;
            curBatchSize = 0;
            // Fresh list: the previous one is still being consumed by the asynchronous writer.
            batchList = new ArrayList<>(transferBatchSize);
        }

        batchList.add(record);
        curBatchSize++;
    }

    /**
     * Builds the tab-separated identity of a variant: contig, start position, reference
     * bases, and the string form of the alternate allele list.
     *
     * <p>The position is rendered with {@link Integer#toString(int)} rather than a
     * {@code %d} format specifier. {@code %d} output is localized, whereas the score file is
     * later matched against {@code Integer.toString(start)}, so both sides must agree
     * regardless of the default locale.
     *
     * @param variantContext the variant to describe
     * @return the four fields joined by tab characters
     */
    private String getVariantDataString(VariantContext variantContext) {
        return String.join("\t",
                variantContext.getContig(),
                Integer.toString(variantContext.getStart()),
                variantContext.getReference().getBaseString(),
                variantContext.getAlternateAlleles().toString());
    }

    /**
     * Flattens the variant's attributes into a single string of
     * {@code key<separator>value;} entries, in the iteration order of the attribute key set.
     *
     * <p>Each value is rendered with {@code toString()} and then stripped of the
     * key/value splitter sequence and of square brackets.
     *
     * @param variantContext the variant whose attributes are serialized
     * @return the concatenated entries, or an empty string when there are no attributes
     */
    private String getVariantInfoString(VariantContext variantContext) {
        // StringBuilder avoids re-copying the accumulated text on every iteration.
        final StringBuilder info = new StringBuilder();
        for (final String key : variantContext.getAttributes().keySet()) {
            final String value = variantContext.getAttribute(key).toString()
                    .replace(GencodeGtfFeature.EXTRA_FIELD_KEY_VALUE_SPLITTER, "")
                    .replace("[", "")
                    .replace("]", "");
            info.append(key)
                    .append(FeatureInput.FEATURE_ARGUMENT_KEY_VALUE_SEPARATOR)
                    .append(value)
                    .append(";");
        }
        return info.toString();
    }

    /**
     * Finishes scoring after the last variant: waits for any batch still in flight, sends
     * the final partial batch, closes the Python-side files, terminates the Python process,
     * and writes the annotated output VCF.
     *
     * @return {@code Boolean.TRUE}
     */
    @Override // org.broadinstitute.hellbender.engine.GATKTool
    public Object onTraversalSuccess() {
        if (waitforBatchCompletion) {
            awaitBatchAndCollectOutput();
        }

        final boolean hasUnsentRecords = curBatchSize > 0;
        if (hasUnsentRecords) {
            executePythonCommand();
            awaitBatchAndCollectOutput();
        }

        final String closeScoreFile = "tempFile.close()" + NL;
        final String closeFifo = "fifoFile.close()" + NL;
        pythonExecutor.sendSynchronousCommand(closeScoreFile);
        pythonExecutor.sendSynchronousCommand(closeFifo);
        pythonExecutor.terminate();

        writeOutputVCFWithScores();
        return Boolean.TRUE;
    }

    /** Waits up to one minute for the in-flight batch, then collects the Python output. */
    private void awaitBatchAndCollectOutput() {
        asyncWriter.waitForPreviousBatchCompletion(1L, TimeUnit.MINUTES);
        pythonExecutor.getAccumulatedOutput();
    }

    /**
     * Tells Python to score the current batch and starts writing that batch into the FIFO.
     *
     * <p>The command is sent asynchronously before the batch write is started. The numeric
     * arguments are appended by string concatenation instead of a {@code %d} format
     * specifier: {@code %d} output is localized, and this text is executed as Python source,
     * so the digits must not depend on the default locale.
     */
    private void executePythonCommand() {
        final String scoreCommand = "vqsr_cnn.score_and_write_batch(model, tempFile, fifoFile, "
                + curBatchSize + ", " + inferenceBatchSize + ")" + NL;
        pythonExecutor.sendAsynchronousCommand(scoreCommand);
        asyncWriter.startAsynchronousBatchWrite(batchList);
    }

    /**
     * Re-reads the driving variants and the score file in lock-step and writes every variant
     * to the output VCF with its score attached under the CNN 1D INFO key.
     *
     * <p>Both the score-file scanner and the VCF writer are managed by try-with-resources,
     * so they are closed on every exit path (writer first, then scanner).
     *
     * @throws GATKException if the score file cannot be opened, ends early, contains a
     *         malformed line, or does not match the variant being annotated
     */
    private void writeOutputVCFWithScores() {
        try (Scanner scoreScanner = new Scanner(scoreFile);
             VariantContextWriter vcfWriter = createVCFWriter(new File(outputFile))) {
            scoreScanner.useDelimiter("\\n");
            writeVCFHeader(vcfWriter);
            StreamSupport.stream(getSpliteratorForDrivingVariants(), false)
                    .filter(makeVariantFilter())
                    .forEach(variant -> vcfWriter.add(annotateWithNextScore(variant, scoreScanner)));
        } catch (IOException e) {
            throw new GATKException("Error when trying to write annotated VCF.", e);
        }
    }

    /** Consumes the next score-file line, verifies it describes {@code variant}, and returns the variant with the score attached. */
    private VariantContext annotateWithNextScore(VariantContext variant, Scanner scoreScanner) {
        if (!scoreScanner.hasNextLine()) {
            throw new GATKException("Score file ended before all variants were scored. First unscored variant:" + variant.toStringWithoutGenotypes());
        }
        final String scoreLine = scoreScanner.nextLine();
        final String[] fields = scoreLine.split(SVFastqUtils.HEADER_FIELD_SEPARATOR_REGEXP);
        if (fields.length <= KEY_INDEX) {
            throw new GATKException("Malformed score file line, expected at least " + (KEY_INDEX + 1) + " fields:" + scoreLine);
        }

        final boolean sameSite = variant.getContig().equals(fields[CONTIG_INDEX])
                && Integer.toString(variant.getStart()).equals(fields[POS_INDEX])
                && variant.getReference().getBaseString().equals(fields[REF_INDEX])
                && variant.getAlternateAlleles().toString().equals(fields[ALT_INDEX]);
        if (!sameSite) {
            throw new GATKException(("Score file out of sync with original VCF. Score file has:" + scoreLine) + "\n But VCF has:" + variant.toStringWithoutGenotypes());
        }

        final VariantContextBuilder builder = new VariantContextBuilder(variant);
        builder.attribute(GATKVCFConstants.CNN_1D_KEY, fields[KEY_INDEX]);
        return builder.make();
    }

    /**
     * Writes the output header: the input header's metadata lines plus the CNN 1D INFO line,
     * the standard VQSR header lines and the default tool header lines, with the input's
     * genotype sample names in sorted order.
     *
     * @param variantContextWriter the writer that receives the header
     */
    private void writeVCFHeader(VariantContextWriter variantContextWriter) {
        final VCFHeader inputHeader = getHeaderForVariants();
        final TreeSet<String> sortedSamples = new TreeSet<>(inputHeader.getGenotypeSamples());

        // Lines are added in the same order as before; a set keeps the first of equal entries.
        final HashSet headerLines = new HashSet(inputHeader.getMetaDataInSortedOrder());
        headerLines.add(GATKVCFHeaderLines.getInfoLine(GATKVCFConstants.CNN_1D_KEY));
        VariantRecalibrationUtils.addVQSRStandardHeaderLines(headerLines);
        headerLines.addAll(getDefaultToolVCFHeaderLines());

        final VCFHeader outputHeader = new VCFHeader(headerLines, sortedSamples);
        variantContextWriter.writeHeader(outputHeader);
    }
}
