package org.broadinstitute.hellbender.tools.walkers.vqsr.scalable;

import com.google.common.primitives.Doubles;
import htsjdk.variant.variantcontext.Allele;
import htsjdk.variant.variantcontext.VariantContext;
import htsjdk.variant.variantcontext.VariantContextBuilder;
import htsjdk.variant.vcf.VCFFilterHeaderLine;
import htsjdk.variant.vcf.VCFHeader;
import htsjdk.variant.vcf.VCFHeaderLine;
import htsjdk.variant.vcf.VCFHeaderLineType;
import htsjdk.variant.vcf.VCFInfoHeaderLine;
import java.io.File;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.TreeSet;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.commons.lang3.tuple.Triple;
import org.broadinstitute.barclay.argparser.Argument;
import org.broadinstitute.barclay.argparser.BetaFeature;
import org.broadinstitute.barclay.argparser.CommandLineProgramProperties;
import org.broadinstitute.barclay.help.DocumentedFeature;
import org.broadinstitute.hellbender.engine.FeatureContext;
import org.broadinstitute.hellbender.engine.ReadsContext;
import org.broadinstitute.hellbender.engine.ReferenceContext;
import org.broadinstitute.hellbender.exceptions.GATKException;
import org.broadinstitute.hellbender.exceptions.UserException;
import org.broadinstitute.hellbender.tools.copynumber.arguments.CopyNumberArgumentValidationUtils;
import org.broadinstitute.hellbender.tools.walkers.vqsr.scalable.data.LabeledVariantAnnotationsData;
import org.broadinstitute.hellbender.tools.walkers.vqsr.scalable.data.VariantType;
import org.broadinstitute.hellbender.tools.walkers.vqsr.scalable.modeling.BGMMVariantAnnotationsScorer;
import org.broadinstitute.hellbender.tools.walkers.vqsr.scalable.modeling.PythonVariantAnnotationsScorer;
import org.broadinstitute.hellbender.tools.walkers.vqsr.scalable.modeling.VariantAnnotationsModelBackend;
import org.broadinstitute.hellbender.tools.walkers.vqsr.scalable.modeling.VariantAnnotationsScorer;
import org.broadinstitute.hellbender.utils.Utils;
import org.broadinstitute.hellbender.utils.io.IOUtils;
import org.broadinstitute.hellbender.utils.io.Resource;
import org.broadinstitute.hellbender.utils.python.PythonScriptExecutor;
import picard.cmdline.programgroups.VariantFilteringProgramGroup;

@CommandLineProgramProperties(
        summary = "Scores variant calls in a VCF file based on site-level annotations using a previously trained model.",
        oneLineSummary = "Scores variant calls in a VCF file based on site-level annotations using a previously trained model",
        programGroup = VariantFilteringProgramGroup.class
)
@DocumentedFeature
@BetaFeature
/**
 * Scores variant calls in a VCF using per-variant-type scorers previously produced by
 * {@link TrainVariantAnnotationsModel}.
 *
 * <p>The tool makes two passes over the input variants:</p>
 * <ol>
 *     <li>Pass 0 collects the annotations of the extracted variants. After this pass the annotations are written to
 *     HDF5, scored by the SNP and/or INDEL scorer, and the resulting scores are written to
 *     {@code <output-prefix>.scores.hdf5}.</li>
 *     <li>Pass 1 writes every variant to the output VCF. Extracted variants are annotated with their label flags,
 *     their score, an optional SNP flag and, when calibration scores are available for their type, a calibration
 *     sensitivity; they are hard filtered when that sensitivity is at or above the per-type threshold. Variants that
 *     were not extracted are written unchanged.</li>
 * </ol>
 */
public class ScoreVariantAnnotations extends LabeledVariantAnnotationsWalker {
    public static final String MODEL_PREFIX_LONG_NAME = "model-prefix";
    public static final String MODEL_BACKEND_LONG_NAME = "model-backend";
    public static final String PYTHON_SCRIPT_LONG_NAME = "python-script";
    public static final String SNP_CALIBRATION_SENSITIVITY_THRESHOLD_LONG_NAME = "snp-calibration-sensitivity-threshold";
    public static final String INDEL_CALIBRATION_SENSITIVITY_THRESHOLD_LONG_NAME = "indel-calibration-sensitivity-threshold";
    public static final String SNP_KEY_LONG_NAME = "snp-key";
    public static final String SCORE_KEY_LONG_NAME = "score-key";
    public static final String CALIBRATION_SENSITIVITY_KEY_LONG_NAME = "calibration-sensitivity-key";
    public static final String LOW_SCORE_FILTER_NAME_LONG_NAME = "low-score-filter-name";
    public static final String DOUBLE_FORMAT_LONG_NAME = "double-format";
    public static final String DEFAULT_SNP_KEY = "snp";
    public static final String DEFAULT_SCORE_KEY = "SCORE";
    public static final String DEFAULT_CALIBRATION_SENSITIVITY_KEY = "CALIBRATION_SENSITIVITY";
    public static final String DEFAULT_LOW_SCORE_FILTER_NAME = "LOW_SCORE";
    public static final String DEFAULT_DOUBLE_FORMAT = "%.4f";
    public static final String SCORES_HDF5_SUFFIX = ".scores.hdf5";

    // Name of the SNP label in the annotations HDF5 written by this walker's base class;
    // NOTE(review): presumably mirrors a constant in LabeledVariantAnnotationsData — confirm and reference it directly.
    private static final String SNP_LABEL = "snp";

    @Argument(fullName = MODEL_PREFIX_LONG_NAME, doc = "Prefix for model files. This should be identical to the output prefix specified in TrainVariantAnnotationsModel.")
    private String modelPrefix;

    @Argument(fullName = MODEL_BACKEND_LONG_NAME, doc = "Backend to use for scoring. JAVA_BGMM will use a pure Java implementation (ported from Python scikit-learn) of the Bayesian Gaussian Mixture Model. PYTHON_IFOREST will use the Python scikit-learn implementation of the IsolationForest method and will require that the corresponding Python dependencies are present in the environment. PYTHON_SCRIPT will use the script specified by the python-script argument. See the tool documentation for more details.")
    private VariantAnnotationsModelBackend modelBackend = VariantAnnotationsModelBackend.PYTHON_IFOREST;

    @Argument(fullName = PYTHON_SCRIPT_LONG_NAME, doc = "Python script used for specifying a custom scoring backend. If provided, model-backend must also be set to PYTHON_SCRIPT.", optional = true)
    private File pythonScriptFile;

    @Argument(fullName = SNP_CALIBRATION_SENSITIVITY_THRESHOLD_LONG_NAME, doc = "If specified, SNPs with scores corresponding to a calibration sensitivity that is greater than or equal to this threshold will be hard filtered.", optional = true, minValue = 0.0d, maxValue = 1.0d)
    private Double snpCalibrationSensitivityThreshold;

    @Argument(fullName = INDEL_CALIBRATION_SENSITIVITY_THRESHOLD_LONG_NAME, doc = "If specified, indels with scores corresponding to a calibration sensitivity that is greater than or equal to this threshold will be hard filtered.", optional = true, minValue = 0.0d, maxValue = 1.0d)
    private Double indelCalibrationSensitivityThreshold;

    @Argument(fullName = SNP_KEY_LONG_NAME, doc = "Annotation flag to use for labeling sites as SNPs in output. Set this to \"null\" to omit these labels.")
    private String snpKey = DEFAULT_SNP_KEY;

    @Argument(fullName = SCORE_KEY_LONG_NAME, doc = "Annotation key to use for score values in output.")
    private String scoreKey = DEFAULT_SCORE_KEY;

    @Argument(fullName = CALIBRATION_SENSITIVITY_KEY_LONG_NAME, doc = "Annotation key to use for calibration-sensitivity values in output.")
    private String calibrationSensitivityKey = DEFAULT_CALIBRATION_SENSITIVITY_KEY;

    @Argument(fullName = LOW_SCORE_FILTER_NAME_LONG_NAME, doc = "Name to use for low-score filter in output.")
    private String lowScoreFilterName = DEFAULT_LOW_SCORE_FILTER_NAME;

    @Argument(fullName = DOUBLE_FORMAT_LONG_NAME, doc = "Format string to use for formatting score and calibration-sensitivity values in output.")
    private String doubleFormat = DEFAULT_DOUBLE_FORMAT;

    // Output HDF5 holding one score per extracted row (NaN for rows of a type that was not scored).
    private File outputScoresFile;

    // Scorers loaded from the model files; either may be null if the corresponding model file is absent.
    private VariantAnnotationsScorer snpScorer;
    private VariantAnnotationsScorer indelScorer;

    // Score -> calibration-sensitivity converters; null when no calibration scores exist for the type.
    private Function<Double, Double> snpCalibrationSensitivityConverter;
    private Function<Double, Double> indelCalibrationSensitivityConverter;

    // Consumed in lockstep during pass 1: one entry per extracted site (or per alt allele in allele-specific mode).
    private Iterator<Double> scoresIterator;
    private Iterator<Boolean> isSNPIterator;

    /**
     * Pass 0 extracts and scores annotations; pass 1 writes the annotated VCF.
     */
    @Override
    protected int numberOfPasses() {
        return 2;
    }

    /**
     * Validates arguments, loads the scorers for the selected backend and the optional calibration converters,
     * and validates the scores output file.
     *
     * @throws IllegalArgumentException if the python-script argument is inconsistent with the backend
     * @throws UserException.BadInput if required model or calibration files are missing for the requested options
     */
    @Override
    public void afterOnTraversalStart() {
        Utils.nonNull(scoreKey);
        Utils.nonNull(calibrationSensitivityKey);
        Utils.nonNull(lowScoreFilterName);
        Utils.nonNull(doubleFormat);

        switch (modelBackend) {
            case JAVA_BGMM:
                Utils.validateArg(pythonScriptFile == null, "Python script should not be provided when using JAVA_BGMM backend.");
                logger.info("Running in JAVA_BGMM mode...");
                snpScorer = deserializeScorerFromSerFiles(VariantType.SNP);
                indelScorer = deserializeScorerFromSerFiles(VariantType.INDEL);
                break;
            case PYTHON_IFOREST:
                Utils.validateArg(pythonScriptFile == null, "Python script should not be provided when using PYTHON_IFOREST backend.");
                // the bundled IsolationForest script is extracted to a temp file and used as the scoring script
                pythonScriptFile = IOUtils.writeTempResource(new Resource(TrainVariantAnnotationsModel.ISOLATION_FOREST_PYTHON_SCRIPT, TrainVariantAnnotationsModel.class));
                for (final String pythonPackage : Arrays.asList("argparse", "h5py", "numpy", "sklearn", "dill")) {
                    PythonScriptExecutor.checkPythonEnvironmentForPackage(pythonPackage);
                }
                logger.info("Running in PYTHON_IFOREST mode...");
                snpScorer = deserializeScorerFromPklFiles(VariantType.SNP);
                indelScorer = deserializeScorerFromPklFiles(VariantType.INDEL);
                break;
            case PYTHON_SCRIPT:
                // fail with a clear message rather than inside canReadFile when the script was never supplied
                Utils.validateArg(pythonScriptFile != null, "Python script must be provided when using PYTHON_SCRIPT backend.");
                IOUtils.canReadFile(pythonScriptFile);
                logger.info("Running in PYTHON_SCRIPT mode...");
                snpScorer = deserializeScorerFromPklFiles(VariantType.SNP);
                indelScorer = deserializeScorerFromPklFiles(VariantType.INDEL);
                break;
            default:
                throw new GATKException.ShouldNeverReachHereException("Unknown model-backend mode.");
        }

        if (snpScorer == null && indelScorer == null) {
            throw new UserException.BadInput(String.format("At least one serialized scorer must be present in the model files with the prefix %s.", modelPrefix));
        }
        if (variantTypesToExtract.contains(VariantType.SNP) && snpScorer == null) {
            throw new UserException.BadInput(String.format("SNPs were indicated for extraction via the %s argument, but no serialized SNP scorer was available in the model files with the prefix %s.", "mode", modelPrefix));
        }
        if (variantTypesToExtract.contains(VariantType.INDEL) && indelScorer == null) {
            throw new UserException.BadInput(String.format("INDELs were indicated for extraction via the %s argument, but no serialized INDEL scorer was available in the model files with the prefix %s.", "mode", modelPrefix));
        }

        snpCalibrationSensitivityConverter = readCalibrationScoresAndCreateConverter(VariantType.SNP);
        indelCalibrationSensitivityConverter = readCalibrationScoresAndCreateConverter(VariantType.INDEL);
        if (snpCalibrationSensitivityConverter == null && snpCalibrationSensitivityThreshold != null) {
            throw new UserException.BadInput(String.format("The %s argument was specified, but no SNP calibration scores were provided in the model files with the prefix %s.", SNP_CALIBRATION_SENSITIVITY_THRESHOLD_LONG_NAME, modelPrefix));
        }
        if (indelCalibrationSensitivityConverter == null && indelCalibrationSensitivityThreshold != null) {
            throw new UserException.BadInput(String.format("The %s argument was specified, but no INDEL calibration scores were provided in the model files with the prefix %s.", INDEL_CALIBRATION_SENSITIVITY_THRESHOLD_LONG_NAME, modelPrefix));
        }

        outputScoresFile = new File(outputPrefix + SCORES_HDF5_SUFFIX);
        CopyNumberArgumentValidationUtils.validateOutputFiles(outputScoresFile);
    }

    /**
     * Pass 0: adds extracted variants to {@code data}. Pass 1: writes extracted variants with scores and
     * writes all other variants unchanged.
     */
    @Override
    protected void nthPassApply(final VariantContext variant,
                                final ReadsContext readsContext,
                                final ReferenceContext referenceContext,
                                final FeatureContext featureContext,
                                final int n) {
        final List<Triple<List<Allele>, VariantType, TreeSet<String>>> metadata =
                extractVariantMetadata(variant, featureContext, true);
        final boolean isExtracted = !metadata.isEmpty();
        if (n == 0 && isExtracted) {
            addExtractedVariantToData(data, variant, metadata);
        }
        if (n == 1) {
            if (isExtracted) {
                writeExtractedVariantToVCF(variant, metadata);
            } else {
                vcfWriter.add(variant);
            }
        }
    }

    /**
     * After pass 0, writes annotations, scores them and prepares the score/SNP-label iterators for pass 1.
     * After pass 1, checks that every score was consumed and closes the VCF writer.
     *
     * @throws IllegalStateException if scores or SNP labels remain after pass 1
     */
    @Override
    protected void afterNthPass(final int n) {
        if (n == 0) {
            writeAnnotationsToHDF5();
            if (data.size() > 0) {
                // annotations are read back from outputAnnotationsFile below, so the in-memory copy can be dropped
                data.clear();
                readAnnotationsAndWriteScoresToHDF5();
                scoresIterator = Arrays.stream(VariantAnnotationsScorer.readScores(outputScoresFile)).iterator();
                isSNPIterator = LabeledVariantAnnotationsData.readLabel(outputAnnotationsFile, SNP_LABEL).iterator();
            } else {
                logger.warn("Found no input variants for scoring. This may be because the specified genomic region contains no input variants of the requested type(s). The scores HDF5 file will not be generated.");
                scoresIterator = Collections.emptyIterator();
                isSNPIterator = Collections.emptyIterator();
            }
        }
        if (n == 1) {
            // both iterators are advanced together in writeExtractedVariantToVCF, so both must be exhausted
            if (scoresIterator.hasNext() || isSNPIterator.hasNext()) {
                throw new IllegalStateException("Traversals of scores and variants (or alleles, in allele-specific mode) were not correctly synchronized.");
            }
            if (vcfWriter != null) {
                vcfWriter.close();
            }
        }
    }

    /**
     * Builds the path {@code <model-prefix>.<variant type in lower case><suffix>} of a per-type model file.
     */
    private File getModelFile(final VariantType variantType, final String suffix) {
        // NOTE(review): toLowerCase() uses the default locale; presumably this mirrors how TrainVariantAnnotationsModel
        // names its outputs, so switching to Locale.ROOT should be done in both tools together.
        return new File(modelPrefix + "." + variantType.toString().toLowerCase() + suffix);
    }

    /**
     * @return a Python-backed scorer for the given type, or null if its {@code .scorer.pkl} file is not readable
     */
    private VariantAnnotationsScorer deserializeScorerFromPklFiles(final VariantType variantType) {
        final File scorerPklFile = getModelFile(variantType, ".scorer.pkl");
        return scorerPklFile.canRead()
                ? new PythonVariantAnnotationsScorer(pythonScriptFile, scorerPklFile)
                : null;
    }

    /**
     * @return a deserialized BGMM scorer for the given type, or null if its {@code .bgmmScorer.ser} file is not readable
     */
    private VariantAnnotationsScorer deserializeScorerFromSerFiles(final VariantType variantType) {
        final File scorerSerFile = getModelFile(variantType, ".bgmmScorer.ser");
        return scorerSerFile.canRead()
                ? BGMMVariantAnnotationsScorer.deserialize(scorerSerFile)
                : null;
    }

    /**
     * @return a score-to-calibration-sensitivity converter built from the type's {@code .calibrationScores.hdf5}
     *         file, or null if that file is not readable
     */
    private Function<Double, Double> readCalibrationScoresAndCreateConverter(final VariantType variantType) {
        final File calibrationScoresFile = getModelFile(variantType, ".calibrationScores.hdf5");
        return calibrationScoresFile.canRead()
                ? VariantAnnotationsScorer.createScoreToCalibrationSensitivityConverter(VariantAnnotationsScorer.readScores(calibrationScoresFile))
                : null;
    }

    /**
     * Reads the extracted annotations back from HDF5, scores each requested variant type with its scorer and
     * writes one score per row to {@link #outputScoresFile}. Rows of a type that is not scored keep NaN.
     */
    private void readAnnotationsAndWriteScoresToHDF5() {
        final List<String> annotationNames = LabeledVariantAnnotationsData.readAnnotationNames(outputAnnotationsFile);
        final List<Boolean> isSNP = LabeledVariantAnnotationsData.readLabel(outputAnnotationsFile, SNP_LABEL);
        final double[][] allAnnotations = LabeledVariantAnnotationsData.readAnnotations(outputAnnotationsFile);
        final List<Double> allScores = new ArrayList<>(Collections.nCopies(allAnnotations.length, Double.NaN));
        if (variantTypesToExtract.contains(VariantType.SNP)) {
            logger.info("Scoring SNP variants...");
            scoreVariantTypeAndSetElementsOfAllScores(annotationNames, allAnnotations, isSNP, snpScorer, allScores);
        }
        if (variantTypesToExtract.contains(VariantType.INDEL)) {
            logger.info("Scoring INDEL variants...");
            final List<Boolean> isIndel = isSNP.stream().map(x -> !x).collect(Collectors.toList());
            scoreVariantTypeAndSetElementsOfAllScores(annotationNames, allAnnotations, isIndel, indelScorer, allScores);
        }
        VariantAnnotationsScorer.writeScores(outputScoresFile, Doubles.toArray(allScores));
        logger.info(String.format("Scores written to %s.", outputScoresFile.getAbsolutePath()));
    }

    /**
     * Scores the rows selected by {@code isVariantType} with {@code variantTypeScorer} and writes each score into
     * the corresponding position of {@code allScores}.
     *
     * @param annotationNames   names of the annotation columns
     * @param allAnnotations    annotation matrix with one row per extracted site/allele
     * @param isVariantType     per-row selector; same length as {@code allScores}
     * @param variantTypeScorer scorer for the selected rows
     * @param allScores         destination list, modified in place at the selected positions
     * @throws GATKException if the scorer returns a different number of scores than rows it was given
     */
    private static void scoreVariantTypeAndSetElementsOfAllScores(final List<String> annotationNames,
                                                                  final double[][] allAnnotations,
                                                                  final List<Boolean> isVariantType,
                                                                  final VariantAnnotationsScorer variantTypeScorer,
                                                                  final List<Double> allScores) {
        final long numVariantType = isVariantType.stream().filter(Boolean::booleanValue).count();
        if (numVariantType == 0) {
            // nothing of this type to score; avoid handing the scorer an empty annotation matrix
            return;
        }
        final File variantTypeAnnotationsFile = LabeledVariantAnnotationsData.subsetAnnotationsToTemporaryFile(annotationNames, allAnnotations, isVariantType);
        final File variantTypeScoresFile = IOUtils.createTempFile("temp", SCORES_HDF5_SUFFIX);
        variantTypeScorer.score(variantTypeAnnotationsFile, variantTypeScoresFile);
        final double[] variantTypeScores = VariantAnnotationsScorer.readScores(variantTypeScoresFile);
        if (variantTypeScores.length != numVariantType) {
            throw new GATKException(String.format(
                    "The scorer returned %d scores, but %d variants were provided for scoring.",
                    variantTypeScores.length, numVariantType));
        }
        int scoreIndex = 0;
        for (int i = 0; i < allScores.size(); i++) {
            if (isVariantType.get(i)) {
                allScores.set(i, variantTypeScores[scoreIndex++]);
            }
        }
    }

    /**
     * Writes an extracted variant annotated with its labels, score, optional SNP flag and, when available,
     * calibration sensitivity (plus the low-score filter if the sensitivity meets the per-type threshold).
     * In allele-specific mode the allele with the maximum score determines the reported score and type.
     */
    @Override
    void writeExtractedVariantToVCF(final VariantContext vc,
                                    final List<Allele> altAlleles,
                                    final Set<String> labels) {
        final VariantContextBuilder builder = new VariantContextBuilder(vc);
        labels.forEach(label -> builder.attribute(label, true));

        final List<Double> scores = useASAnnotations
                ? altAlleles.stream().map(a -> scoresIterator.next()).collect(Collectors.toList())
                : Collections.singletonList(scoresIterator.next());
        final double score = Collections.max(scores);
        final int scoreIndex = scores.indexOf(score);
        builder.attribute(scoreKey, formatDouble(score));

        final List<Boolean> isSNPs = useASAnnotations
                ? altAlleles.stream().map(a -> isSNPIterator.next()).collect(Collectors.toList())
                : Collections.singletonList(isSNPIterator.next());
        final boolean isSNP = isSNPs.get(scoreIndex);
        if (snpKey != null) {
            builder.attribute(snpKey, isSNP);
        }

        final Function<Double, Double> calibrationSensitivityConverter = isSNP
                ? snpCalibrationSensitivityConverter
                : indelCalibrationSensitivityConverter;
        if (calibrationSensitivityConverter != null) {
            final double calibrationSensitivity = calibrationSensitivityConverter.apply(score);
            builder.attribute(calibrationSensitivityKey, formatDouble(calibrationSensitivity));
            final Double threshold = isSNP ? snpCalibrationSensitivityThreshold : indelCalibrationSensitivityThreshold;
            if (threshold != null && calibrationSensitivity >= threshold) {
                builder.filter(lowScoreFilterName);
            }
        }
        vcfWriter.add(builder.make());
    }

    /**
     * Formats a value with the user-configurable {@link #doubleFormat}.
     */
    private String formatDouble(final double value) {
        return String.format(doubleFormat, value);
    }

    /**
     * Builds the output header: the input header lines plus the score, calibration-sensitivity, low-score filter,
     * optional SNP flag, label flags and the default tool header lines.
     */
    @Override
    VCFHeader constructVCFHeader(final List<String> sortedLabels) {
        final VCFHeader inputHeader = getHeaderForVariants();
        final Set<VCFHeaderLine> headerLines = new HashSet<>(inputHeader.getMetaDataInSortedOrder());
        headerLines.add(new VCFInfoHeaderLine(scoreKey, 1, VCFHeaderLineType.Float, "Score according to the model applied by ScoreVariantAnnotations"));
        headerLines.add(new VCFInfoHeaderLine(calibrationSensitivityKey, 1, VCFHeaderLineType.Float, String.format("Calibration sensitivity corresponding to the value of %s", scoreKey)));
        headerLines.add(new VCFFilterHeaderLine(lowScoreFilterName, "Low score (corresponding to high calibration sensitivity)"));
        // Flag INFO fields must declare Number=0 per the VCF spec
        if (snpKey != null) {
            headerLines.add(new VCFInfoHeaderLine(snpKey, 0, VCFHeaderLineType.Flag, "This site was considered a SNP during filtering"));
        }
        for (final String label : sortedLabels) {
            headerLines.add(new VCFInfoHeaderLine(label, 0, VCFHeaderLineType.Flag, String.format(LabeledVariantAnnotationsWalker.RESOURCE_LABEL_INFO_HEADER_LINE_FORMAT_STRING, label)));
        }
        headerLines.addAll(getDefaultToolVCFHeaderLines());
        return new VCFHeader(headerLines, inputHeader.getGenotypeSamples());
    }

    @Override
    public Object onTraversalSuccess() {
        logger.info(String.format("%s complete.", getClass().getSimpleName()));
        return null;
    }
}
