package org.broadinstitute.hellbender.tools.walkers.vqsr.scalable;

import com.google.common.collect.Streams;
import java.io.File;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.commons.math3.stat.descriptive.moment.Variance;
import org.broadinstitute.barclay.argparser.Argument;
import org.broadinstitute.barclay.argparser.BetaFeature;
import org.broadinstitute.barclay.argparser.CommandLineProgramProperties;
import org.broadinstitute.barclay.help.DocumentedFeature;
import org.broadinstitute.hellbender.cmdline.CommandLineProgram;
import org.broadinstitute.hellbender.exceptions.GATKException;
import org.broadinstitute.hellbender.exceptions.UserException;
import org.broadinstitute.hellbender.tools.walkers.vqsr.scalable.data.LabeledVariantAnnotationsData;
import org.broadinstitute.hellbender.tools.walkers.vqsr.scalable.data.VariantType;
import org.broadinstitute.hellbender.tools.walkers.vqsr.scalable.modeling.BGMMVariantAnnotationsModel;
import org.broadinstitute.hellbender.tools.walkers.vqsr.scalable.modeling.BGMMVariantAnnotationsScorer;
import org.broadinstitute.hellbender.tools.walkers.vqsr.scalable.modeling.PythonVariantAnnotationsModel;
import org.broadinstitute.hellbender.tools.walkers.vqsr.scalable.modeling.PythonVariantAnnotationsScorer;
import org.broadinstitute.hellbender.tools.walkers.vqsr.scalable.modeling.VariantAnnotationsModel;
import org.broadinstitute.hellbender.tools.walkers.vqsr.scalable.modeling.VariantAnnotationsModelBackend;
import org.broadinstitute.hellbender.tools.walkers.vqsr.scalable.modeling.VariantAnnotationsScorer;
import org.broadinstitute.hellbender.utils.Utils;
import org.broadinstitute.hellbender.utils.io.IOUtils;
import org.broadinstitute.hellbender.utils.io.Resource;
import org.broadinstitute.hellbender.utils.python.PythonScriptExecutor;
import picard.cmdline.programgroups.VariantFilteringProgramGroup;

/**
 * Command-line tool that trains a model for scoring variant calls based on site-level annotations.
 *
 * <p>For every {@link VariantType} requested via the {@code mode} argument, the tool:</p>
 * <ol>
 *   <li>subsets the input annotations to the training sites of that variant type;</li>
 *   <li>optionally subsets a second, unlabeled annotations file to the sites of that variant type;</li>
 *   <li>trains and serializes a model using the selected {@link VariantAnnotationsModelBackend};</li>
 *   <li>scores the training sites, the calibration sites (when any exist) and the unlabeled sites
 *       (when provided), writing each set of scores to an HDF5 file named
 *       {@code <output prefix>.<variant type><suffix>}.</li>
 * </ol>
 */
@CommandLineProgramProperties(
        summary = "Trains a model for scoring variant calls based on site-level annotations.",
        oneLineSummary = "Trains a model for scoring variant calls based on site-level annotations",
        programGroup = VariantFilteringProgramGroup.class)
@DocumentedFeature
@BetaFeature
public final class TrainVariantAnnotationsModel extends CommandLineProgram {
    public static final String MODE_LONG_NAME = "mode";
    public static final String ANNOTATIONS_HDF5_LONG_NAME = "annotations-hdf5";
    public static final String UNLABELED_ANNOTATIONS_HDF5_LONG_NAME = "unlabeled-annotations-hdf5";
    public static final String MODEL_BACKEND_LONG_NAME = "model-backend";
    public static final String PYTHON_SCRIPT_LONG_NAME = "python-script";
    public static final String HYPERPARAMETERS_JSON_LONG_NAME = "hyperparameters-json";
    public static final String ISOLATION_FOREST_PYTHON_SCRIPT = "isolation-forest.py";
    public static final String ISOLATION_FOREST_HYPERPARAMETERS_JSON = "isolation-forest-hyperparameters.json";
    public static final String TRAINING_SCORES_HDF5_SUFFIX = ".trainingScores.hdf5";
    public static final String CALIBRATION_SCORES_HDF5_SUFFIX = ".calibrationScores.hdf5";
    public static final String UNLABELED_SCORES_HDF5_SUFFIX = ".unlabeledScores.hdf5";

    /** Labeled annotations; training and calibration sites are both drawn from this file. */
    @Argument(fullName = ANNOTATIONS_HDF5_LONG_NAME, doc = "HDF5 file containing annotations extracted with ExtractVariantAnnotations.")
    private File inputAnnotationsFile;

    /** Optional unlabeled annotations; its presence switches the tool to {@link AvailableLabelsMode#POSITIVE_UNLABELED}. */
    @Argument(fullName = UNLABELED_ANNOTATIONS_HDF5_LONG_NAME, doc = "HDF5 file containing annotations extracted with ExtractVariantAnnotations. If specified, a positive-unlabeled modeling approach will be used; otherwise, a positive-only modeling approach will be used.", optional = true)
    private File inputUnlabeledAnnotationsFile;

    /** User-supplied script; for the PYTHON_IFOREST backend this field is overwritten with the bundled script. */
    @Argument(fullName = PYTHON_SCRIPT_LONG_NAME, doc = "Python script used for specifying a custom scoring backend. If provided, model-backend must also be set to PYTHON_SCRIPT.", optional = true)
    private File pythonScriptFile;

    /** Hyperparameters; for the PYTHON_IFOREST backend a bundled default is substituted when this is not set. */
    @Argument(fullName = HYPERPARAMETERS_JSON_LONG_NAME, doc = "JSON file containing hyperparameters. Optional if the PYTHON_IFOREST backend is used (if not specified, a default set of hyperparameters will be used); otherwise required.", optional = true)
    private File hyperparametersJSONFile;

    /** Prefix shared by every output file; a per-variant-type tag and a suffix are appended to it. */
    @Argument(fullName = "output", shortName = "O", doc = "Output prefix.")
    private String outputPrefix;

    /** Derived in {@link #validateArgumentsAndSetModes()} from whether an unlabeled annotations file was given. */
    private AvailableLabelsMode availableLabelsMode;

    @Argument(fullName = MODEL_BACKEND_LONG_NAME, doc = "Backend to use for training models. JAVA_BGMM will use a pure Java implementation (ported from Python scikit-learn) of the Bayesian Gaussian Mixture Model. PYTHON_IFOREST will use the Python scikit-learn implementation of the IsolationForest method and will require that the corresponding Python dependencies are present in the environment. PYTHON_SCRIPT will use the script specified by the python-script argument. See the tool documentation for more details.")
    private VariantAnnotationsModelBackend modelBackend = VariantAnnotationsModelBackend.PYTHON_IFOREST;

    @Argument(fullName = MODE_LONG_NAME, doc = "Variant types for which to train models. Duplicate values will be ignored.", minElements = 1)
    public List<VariantType> variantTypes = new ArrayList<>(Arrays.asList(VariantType.SNP, VariantType.INDEL));

    /** Which kinds of sites are available for modeling. */
    public enum AvailableLabelsMode {
        /** Only the labeled annotations file was provided. */
        POSITIVE_ONLY,
        /** Both the labeled and the unlabeled annotations files were provided. */
        POSITIVE_UNLABELED
    }

    /**
     * Validates the arguments and then trains and scores one model per requested variant type.
     *
     * @return always {@code null}
     */
    @Override
    protected Object doWork() {
        validateArgumentsAndSetModes();
        logger.info("Starting training...");
        // Walking the enum (rather than the user-supplied list) processes each requested type
        // exactly once, in declaration order, regardless of duplicates in the argument.
        for (final VariantType variantType : VariantType.values()) {
            if (variantTypes.contains(variantType)) {
                doModelingWorkForVariantType(variantType);
            }
        }
        logger.info(String.format("%s complete.", getClass().getSimpleName()));
        return null;
    }

    /**
     * Checks the input files and backend-specific arguments, and sets {@link #availableLabelsMode}.
     *
     * <p>For the PYTHON_IFOREST backend this also replaces {@link #pythonScriptFile} with the bundled
     * isolation-forest script and, when no hyperparameters were supplied, {@link #hyperparametersJSONFile}
     * with the bundled defaults.</p>
     *
     * @throws IllegalArgumentException if an argument combination is invalid for the selected backend
     *         (raised by {@code Utils.validateArg}; presumably also by {@code IOUtils.canReadFile} - TODO confirm)
     */
    private void validateArgumentsAndSetModes() {
        IOUtils.canReadFile(inputAnnotationsFile);

        final boolean hasUnlabeledAnnotations = inputUnlabeledAnnotationsFile != null;
        availableLabelsMode = hasUnlabeledAnnotations
                ? AvailableLabelsMode.POSITIVE_UNLABELED
                : AvailableLabelsMode.POSITIVE_ONLY;
        if (hasUnlabeledAnnotations) {
            IOUtils.canReadFile(inputUnlabeledAnnotationsFile);
            final List<String> annotationNames = LabeledVariantAnnotationsData.readAnnotationNames(inputAnnotationsFile);
            final List<String> unlabeledAnnotationNames = LabeledVariantAnnotationsData.readAnnotationNames(inputUnlabeledAnnotationsFile);
            Utils.validateArg(annotationNames.equals(unlabeledAnnotationNames),
                    "Annotation names must be identical for positive and unlabeled annotations.");
        }

        switch (modelBackend) {
            case JAVA_BGMM:
                Utils.validateArg(pythonScriptFile == null,
                        "Python script should not be provided when using JAVA_BGMM backend.");
                // The hyperparameters are required for this backend; fail with a descriptive message
                // instead of relying on the readability check to reject a null file.
                Utils.validateArg(hyperparametersJSONFile != null,
                        "Hyperparameters JSON must be provided when using JAVA_BGMM backend.");
                IOUtils.canReadFile(hyperparametersJSONFile);
                logger.info("Running in JAVA_BGMM mode...");
                break;
            case PYTHON_IFOREST:
                Utils.validateArg(pythonScriptFile == null,
                        "Python script should not be provided when using PYTHON_IFOREST backend.");
                pythonScriptFile = IOUtils.writeTempResource(
                        new Resource(ISOLATION_FOREST_PYTHON_SCRIPT, TrainVariantAnnotationsModel.class));
                if (hyperparametersJSONFile == null) {
                    hyperparametersJSONFile = IOUtils.writeTempResource(
                            new Resource(ISOLATION_FOREST_HYPERPARAMETERS_JSON, TrainVariantAnnotationsModel.class));
                }
                IOUtils.canReadFile(hyperparametersJSONFile);
                PythonScriptExecutor.checkPythonEnvironmentForPackage("argparse");
                PythonScriptExecutor.checkPythonEnvironmentForPackage("h5py");
                PythonScriptExecutor.checkPythonEnvironmentForPackage("numpy");
                PythonScriptExecutor.checkPythonEnvironmentForPackage("sklearn");
                PythonScriptExecutor.checkPythonEnvironmentForPackage("dill");
                logger.info("Running in PYTHON_IFOREST mode...");
                break;
            case PYTHON_SCRIPT:
                // Mirrors the hyperparameters check below so that a missing script is reported clearly.
                Utils.validateArg(pythonScriptFile != null,
                        "Python script must be provided when using PYTHON_SCRIPT backend.");
                Utils.validateArg(hyperparametersJSONFile != null,
                        "Hyperparameters JSON must be provided when using PYTHON_SCRIPT backend.");
                IOUtils.canReadFile(pythonScriptFile);
                IOUtils.canReadFile(hyperparametersJSONFile);
                logger.info("Running in PYTHON_SCRIPT mode...");
                break;
            default:
                throw new GATKException.ShouldNeverReachHereException("Unknown model-backend mode.");
        }
    }

    /**
     * Trains, serializes and applies the model for a single variant type.
     *
     * @param variantType the variant type to model; any type other than {@link VariantType#SNP} selects
     *                    the sites whose "snp" label is false
     * @throws UserException.BadInput if no training sites (or, in positive-unlabeled mode, no unlabeled sites)
     *         of the requested type are present
     */
    private void doModelingWorkForVariantType(final VariantType variantType) {
        final List<String> annotationNames = LabeledVariantAnnotationsData.readAnnotationNames(inputAnnotationsFile);
        final double[][] annotations = LabeledVariantAnnotationsData.readAnnotations(inputAnnotationsFile);
        final List<Boolean> isTraining = LabeledVariantAnnotationsData.readLabel(inputAnnotationsFile, LabeledVariantAnnotationsData.TRAINING_LABEL);
        final List<Boolean> isCalibration = LabeledVariantAnnotationsData.readLabel(inputAnnotationsFile, LabeledVariantAnnotationsData.CALIBRATION_LABEL);
        final List<Boolean> isSnp = LabeledVariantAnnotationsData.readLabel(inputAnnotationsFile, "snp");
        final List<Boolean> isVariantType = selectVariantType(isSnp, variantType);

        final List<Boolean> isTrainingAndVariantType = elementwiseAnd(isTraining, isVariantType);
        final int numTrainingAndVariantType = numPassingFilter(isTrainingAndVariantType);

        final String variantTypeString = variantType.toString();
        // Locale.ROOT keeps the file-name tag stable (e.g. ".indel") regardless of the JVM's default
        // locale; the default-locale overload would yield a dotless i under a Turkish locale.
        final String outputPrefixTag = "." + variantTypeString.toLowerCase(Locale.ROOT);

        if (numTrainingAndVariantType <= 0) {
            throw new UserException.BadInput(String.format(
                    "Attempted to train %s model, but no suitable training sites were found in the provided annotations.",
                    variantTypeString));
        }
        logger.info(String.format("Training %s model with %d training sites x %d annotations %s...",
                variantTypeString, numTrainingAndVariantType, annotationNames.size(), annotationNames));
        final File trainingAnnotationsFile = LabeledVariantAnnotationsData.subsetAnnotationsToTemporaryFile(
                annotationNames, annotations, isTrainingAndVariantType);

        // Both stay at their defaults in positive-only mode.
        File unlabeledAnnotationsFile = null;
        int numUnlabeledVariantType = 0;
        if (availableLabelsMode == AvailableLabelsMode.POSITIVE_UNLABELED) {
            final double[][] unlabeledAnnotations = LabeledVariantAnnotationsData.readAnnotations(inputUnlabeledAnnotationsFile);
            final List<Boolean> unlabeledIsSnp = LabeledVariantAnnotationsData.readLabel(inputUnlabeledAnnotationsFile, "snp");
            final List<Boolean> isUnlabeledVariantType = selectVariantType(unlabeledIsSnp, variantType);
            numUnlabeledVariantType = numPassingFilter(isUnlabeledVariantType);
            if (numUnlabeledVariantType <= 0) {
                throw new UserException.BadInput(String.format(
                        "Attempted to train %s model, but no suitable unlabeled sites were found in the provided annotations.",
                        variantTypeString));
            }
            logger.info(String.format("Training %s model with %d unlabeled sites x %d annotations %s...",
                    variantTypeString, numUnlabeledVariantType, annotationNames.size(), annotationNames));
            unlabeledAnnotationsFile = LabeledVariantAnnotationsData.subsetAnnotationsToTemporaryFile(
                    annotationNames, unlabeledAnnotations, isUnlabeledVariantType);
        }

        trainAndSerializeModel(trainingAnnotationsFile, unlabeledAnnotationsFile, outputPrefixTag);
        logger.info(String.format("%s model trained and serialized with output prefix \"%s\".",
                variantTypeString, outputPrefix + outputPrefixTag));

        if (modelBackend == VariantAnnotationsModelBackend.JAVA_BGMM) {
            BGMMVariantAnnotationsScorer.preprocessAnnotationsWithBGMMAndWriteHDF5(
                    annotationNames, outputPrefix + outputPrefixTag, trainingAnnotationsFile, logger);
        }

        logger.info(String.format("Scoring %d %s training sites...", numTrainingAndVariantType, variantTypeString));
        final File trainingScoresFile = score(trainingAnnotationsFile, outputPrefixTag, TRAINING_SCORES_HDF5_SUFFIX);
        logger.info(String.format("%s training scores written to %s.",
                variantTypeString, trainingScoresFile.getAbsolutePath()));

        final List<Boolean> isCalibrationAndVariantType = elementwiseAnd(isCalibration, isVariantType);
        final int numCalibrationAndVariantType = numPassingFilter(isCalibrationAndVariantType);
        if (numCalibrationAndVariantType > 0) {
            logger.info(String.format("Scoring %d %s calibration sites...",
                    numCalibrationAndVariantType, variantTypeString));
            final File calibrationAnnotationsFile = LabeledVariantAnnotationsData.subsetAnnotationsToTemporaryFile(
                    annotationNames, annotations, isCalibrationAndVariantType);
            final File calibrationScoresFile = score(calibrationAnnotationsFile, outputPrefixTag, CALIBRATION_SCORES_HDF5_SUFFIX);
            logger.info(String.format("%s calibration scores written to %s.",
                    variantTypeString, calibrationScoresFile.getAbsolutePath()));
        } else {
            logger.warn(String.format("No %s calibration sites were available.", variantTypeString));
        }

        if (availableLabelsMode == AvailableLabelsMode.POSITIVE_UNLABELED && unlabeledAnnotationsFile != null) {
            logger.info(String.format("Scoring %d %s unlabeled sites...", numUnlabeledVariantType, variantTypeString));
            final File unlabeledScoresFile = score(unlabeledAnnotationsFile, outputPrefixTag, UNLABELED_SCORES_HDF5_SUFFIX);
            logger.info(String.format("%s unlabeled scores written to %s.",
                    variantTypeString, unlabeledScoresFile.getAbsolutePath()));
        }
    }

    /**
     * Returns the per-site mask for the requested variant type: the "snp" mask itself for
     * {@link VariantType#SNP}, otherwise a new list holding its element-wise negation.
     */
    private static List<Boolean> selectVariantType(final List<Boolean> isSnp, final VariantType variantType) {
        if (variantType == VariantType.SNP) {
            return isSnp;
        }
        return isSnp.stream().map(snp -> !snp).collect(Collectors.toList());
    }

    /** Returns the element-wise logical AND of two masks, paired positionally via {@code Streams.zip}. */
    private static List<Boolean> elementwiseAnd(final List<Boolean> first, final List<Boolean> second) {
        return Streams.zip(first.stream(), second.stream(), (a, b) -> a && b).collect(Collectors.toList());
    }

    /** Counts the {@code true} entries of a mask. */
    private static int numPassingFilter(List<Boolean> list) {
        return (int) list.stream().filter(Boolean::booleanValue).count();
    }

    /**
     * Validates the subset annotations and then trains and serializes a model with the selected backend.
     *
     * @param file  training annotations
     * @param file2 unlabeled annotations, or {@code null} in positive-only mode
     * @param str   per-variant-type tag appended to the output prefix
     */
    private void trainAndSerializeModel(File file, File file2, String str) {
        readAndValidateAnnotations(file, str);
        if (file2 != null) {
            readAndValidateAnnotations(file2, str);
        }

        final VariantAnnotationsModel model;
        switch (modelBackend) {
            case JAVA_BGMM:
                model = new BGMMVariantAnnotationsModel(hyperparametersJSONFile);
                break;
            case PYTHON_IFOREST:
            case PYTHON_SCRIPT:
                // Both Python backends use the same model class; they differ only in which script
                // pythonScriptFile refers to (set during argument validation).
                model = new PythonVariantAnnotationsModel(pythonScriptFile, hyperparametersJSONFile);
                break;
            default:
                throw new GATKException.ShouldNeverReachHereException("Unknown model mode.");
        }
        model.trainAndSerialize(file, file2, outputPrefix + str);
    }

    /**
     * Reads annotations from the given file and checks that they are usable for training.
     *
     * <p>Logs a warning for every annotation whose variance over all data points is exactly zero, and
     * rejects the input if any annotation is NaN for every data point.</p>
     *
     * @param file annotations file to check
     * @param str  per-variant-type tag appended to the output prefix; used only in messages
     * @throws IllegalArgumentException if there are no annotation names, no data points, or the number of
     *         names differs from the number of features in the first data point
     * @throws UserException.BadInput if at least one annotation is entirely missing (all NaN)
     */
    private void readAndValidateAnnotations(File file, String str) {
        final List<String> annotationNames = LabeledVariantAnnotationsData.readAnnotationNames(file);
        final double[][] annotations = LabeledVariantAnnotationsData.readAnnotations(file);

        final int numAnnotationNames = annotationNames.size();
        final int numPoints = annotations.length;
        Utils.validateArg(numAnnotationNames > 0, "Number of annotation names must be positive.");
        Utils.validateArg(numPoints > 0, "Number of data points must be positive.");
        final int numFeatures = annotations[0].length;
        Utils.validateArg(numAnnotationNames == numFeatures,
                "Number of annotation names must match the number of features in the annotation data.");

        final List<String> allMissingAnnotationNames = new ArrayList<>(numFeatures);
        for (int featureIndex = 0; featureIndex < numFeatures; featureIndex++) {
            // Gather the column for this feature: rows are data points, columns are annotations.
            final double[] featureValues = new double[numPoints];
            for (int pointIndex = 0; pointIndex < numPoints; pointIndex++) {
                featureValues[pointIndex] = annotations[pointIndex][featureIndex];
            }

            if (new Variance().evaluate(featureValues) == 0.0d) {
                logger.warn(String.format(
                        "All values of the annotation %s are identical in the training data for the %s model.",
                        annotationNames.get(featureIndex), outputPrefix + str));
            }
            if (Arrays.stream(featureValues).allMatch(value -> Double.isNaN(value))) {
                allMissingAnnotationNames.add(annotationNames.get(featureIndex));
            }
        }

        if (!allMissingAnnotationNames.isEmpty()) {
            throw new UserException.BadInput(String.format("All values of the following annotations are missing in the training data for the %s model: %s. Consider repeating the extraction step without specifying these annotations. ", outputPrefix + str, allMissingAnnotationNames));
        }
    }

    /**
     * Scores the given annotations with the previously serialized model and writes the scores to
     * {@code outputPrefix + str + str2}.
     *
     * @param file annotations to score
     * @param str  per-variant-type tag appended to the output prefix
     * @param str2 suffix of the scores file
     * @return the scores file that was passed to the scorer as its output
     */
    private File score(File file, String str, String str2) {
        final VariantAnnotationsScorer scorer;
        switch (modelBackend) {
            case JAVA_BGMM:
                scorer = BGMMVariantAnnotationsScorer.deserialize(new File(outputPrefix + str + ".bgmmScorer.ser"));
                break;
            case PYTHON_IFOREST:
            case PYTHON_SCRIPT:
                scorer = new PythonVariantAnnotationsScorer(pythonScriptFile, new File(outputPrefix + str + ".scorer.pkl"));
                break;
            default:
                throw new GATKException.ShouldNeverReachHereException("Unknown model mode.");
        }
        final File outputScoresFile = new File(outputPrefix + str + str2);
        scorer.score(file, outputScoresFile);
        return outputScoresFile;
    }
}
