package org.broadinstitute.hellbender.tools.copynumber.utils.segmentation;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.function.BiFunction;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.DefaultRealMatrixChangingVisitor;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.SingularValueDecomposition;
import org.apache.commons.math3.random.RandomGenerator;
import org.apache.commons.math3.random.RandomGeneratorFactory;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.broadinstitute.hellbender.tools.copynumber.utils.optimization.PersistenceOptimizer;
import org.broadinstitute.hellbender.utils.IndexRange;
import org.broadinstitute.hellbender.utils.MathUtils;
import org.broadinstitute.hellbender.utils.Utils;
import org.broadinstitute.hellbender.utils.param.ParamUtils;

/**
 * Finds changepoints in a list of data points using a kernel-based segmentation cost.
 *
 * <p>The algorithm proceeds in three stages (see {@link #findChangepoints}):</p>
 * <ol>
 *     <li>A low-rank (Nystr&ouml;m-style) approximation to the kernel matrix of the data is built from a random
 *     subsample of the data points, yielding a "reduced observation matrix" with one row per data point.</li>
 *     <li>For each requested window size {@code w}, a local changepoint cost is computed at every index {@code i}
 *     as {@code C(i-w+1, i) + C(i+1, i+w) - C(i-w+1, i+w)}, where {@code C} is the segment cost and indices wrap
 *     around periodically; local minima of these costs become changepoint candidates.</li>
 *     <li>Candidates are pruned by backward selection (repeatedly merging the pair of adjacent segments whose merge
 *     increases the total cost least), and the number of changepoints is chosen by minimizing the total segment
 *     cost plus a penalty on the number of changepoints.</li>
 * </ol>
 *
 * <p>The random subsampling uses a fixed seed, so it is reproducible across calls.</p>
 *
 * @param <DATA> type of the data points
 */
public final class KernelSegmenter<DATA> {
    private static final Logger logger = LogManager.getLogger(KernelSegmenter.class);

    //fixed seed for the subsampling performed by the kernel approximation, so that the subsample is reproducible
    private static final int RANDOM_SEED = 1216;
    //small constant guarding against division by zero (inverse singular values, changepoint penalty)
    private static final double EPSILON = 1E-10;

    private final List<DATA> data;

    /**
     * Order of the changepoints returned by {@link #findChangepoints}.
     */
    public enum ChangepointSortOrder {
        /**
         * Reverse of the order in which backward selection removed changepoints: the changepoint that was
         * removed last (i.e., survived longest) comes first.
         */
        BACKWARD_SELECTION,
        /** Ascending order of data-point index. */
        INDEX
    }

    /**
     * Cost of a segment {@code [start, end]} (inclusive; wraps around the end of the data when {@code start > end}),
     * together with the sufficient statistics needed to update it incrementally.
     */
    private static final class Cost {
        private final double D;     //sum of the kernel-approximation diagonal over the segment
        private final double[] W;   //sum of the rows of the reduced observation matrix over the segment
        private final double V;     //squared norm of W (given that the diagonal holds squared row norms)
        private final double C;     //segment cost: D - V / (number of points in the segment)

        private Cost(final double d, final double[] w, final double v, final double c) {
            this.D = d;
            this.W = w;
            this.V = v;
            this.C = c;
        }
    }

    /**
     * Mutable counterpart of {@link Cost} for a segment whose endpoints slide along the data: points are removed
     * and added one at a time while D, W and V are kept up to date.
     */
    private static final class RollingSegment {
        private double d;
        private final double[] w;
        private double v;

        private RollingSegment(final Cost cost) {
            this.d = cost.D;
            this.w = Arrays.copyOf(cost.W, cost.W.length);
            this.v = cost.V;
        }

        //removes the data point with the given index from the segment
        private void removePoint(final int index,
                                 final RealMatrix reducedObservationMatrix,
                                 final double[] kernelApproximationDiagonal) {
            d -= kernelApproximationDiagonal[index];
            double dot = 0.;
            for (int j = 0; j < w.length; j++) {
                final double x = reducedObservationMatrix.getEntry(index, j);
                dot += x * w[j];    //uses W before this component is updated
                w[j] -= x;
            }
            //||W - x||^2 = ||W||^2 - 2 x.W + ||x||^2
            v = v - 2. * dot + kernelApproximationDiagonal[index];
        }

        //adds the data point with the given index to the segment
        private void addPoint(final int index,
                              final RealMatrix reducedObservationMatrix,
                              final double[] kernelApproximationDiagonal) {
            d += kernelApproximationDiagonal[index];
            double dot = 0.;
            for (int j = 0; j < w.length; j++) {
                final double x = reducedObservationMatrix.getEntry(index, j);
                dot += x * w[j];    //uses W before this component is updated
                w[j] += x;
            }
            //||W + x||^2 = ||W||^2 + 2 x.W + ||x||^2
            v = v + 2. * dot + kernelApproximationDiagonal[index];
        }
    }

    /**
     * A segment {@code [start, end]} of data-point indices together with its cost.
     */
    private static final class Segment {
        private final int start;
        private final int end;
        private final double cost;

        private Segment(final int start, final int end, final double cost) {
            this.start = start;
            this.end = end;
            this.cost = cost;
        }

        private Segment(final int start,
                        final int end,
                        final RealMatrix reducedObservationMatrix,
                        final double[] kernelApproximationDiagonal) {
            this(start, end, calculateSegmentCost(start, end, reducedObservationMatrix, kernelApproximationDiagonal).C);
        }
    }

    /**
     * @param data  data points to segment; copied defensively into an unmodifiable list
     */
    public KernelSegmenter(final List<DATA> data) {
        this.data = Collections.unmodifiableList(new ArrayList<>(Utils.nonNull(data)));
    }

    /**
     * Finds up to {@code maxNumChangepoints} changepoints in the data.
     *
     * <p>Argument validation is delegated to {@link ParamUtils} and {@link Utils#validateArg}. If no changepoints
     * are requested or there is no data, a warning is logged and an empty list is returned.</p>
     *
     * @param maxNumChangepoints                    maximum number of changepoints to return (non-negative)
     * @param kernel                                kernel function evaluated on pairs of data points
     * @param kernelApproximationDimension          number of data points subsampled for the low-rank kernel
     *                                              approximation (positive; capped at the number of data points)
     * @param windowSizes                           distinct, positive window sizes used to find changepoint candidates
     * @param numChangepointsPenaltyLinearFactor    non-negative factor of the linear term of the changepoint penalty
     * @param numChangepointsPenaltyLogLinearFactor non-negative factor of the log-linear term of the changepoint penalty
     * @param changepointSortOrder                  order of the returned changepoints
     * @return a new mutable list of changepoint indices; a changepoint at index {@code i} ends a segment at {@code i}
     */
    public List<Integer> findChangepoints(final int maxNumChangepoints,
                                          final BiFunction<DATA, DATA, Double> kernel,
                                          final int kernelApproximationDimension,
                                          final List<Integer> windowSizes,
                                          final double numChangepointsPenaltyLinearFactor,
                                          final double numChangepointsPenaltyLogLinearFactor,
                                          final ChangepointSortOrder changepointSortOrder) {
        Utils.nonNull(kernel);
        Utils.nonNull(windowSizes);
        Utils.nonNull(changepointSortOrder);
        ParamUtils.isPositiveOrZero(maxNumChangepoints, "Maximum number of changepoints must be non-negative.");
        ParamUtils.isPositive(kernelApproximationDimension, "Dimension of kernel approximation must be positive.");
        Utils.validateArg(!windowSizes.isEmpty(), "At least one window size must be provided.");
        Utils.validateArg(windowSizes.stream().allMatch(ws -> ws > 0), "Window sizes must all be positive.");
        Utils.validateArg(windowSizes.stream().distinct().count() == (long) windowSizes.size(),
                "Window sizes must all be unique.");
        ParamUtils.isPositiveOrZero(numChangepointsPenaltyLinearFactor,
                "Linear factor for the penalty on the number of changepoints per chromosome must be non-negative.");
        ParamUtils.isPositiveOrZero(numChangepointsPenaltyLogLinearFactor,
                "Log-linear factor for the penalty on the number of changepoints per chromosome must be non-negative.");

        if (maxNumChangepoints == 0) {
            logger.warn("No changepoints were requested, returning an empty list...");
            return Collections.emptyList();
        }
        if (data.isEmpty()) {
            logger.warn("No data points were provided, returning an empty list...");
            return Collections.emptyList();
        }

        logger.debug(String.format("Finding up to %d changepoints in %d data points...", maxNumChangepoints, data.size()));
        final RandomGenerator rng = RandomGeneratorFactory.createRandomGenerator(new Random(RANDOM_SEED));

        logger.debug("Calculating low-rank approximation to kernel matrix...");
        final RealMatrix reducedObservationMatrix =
                calculateReducedObservationMatrix(rng, data, kernel, kernelApproximationDimension);
        final double[] kernelApproximationDiagonal = calculateKernelApproximationDiagonal(reducedObservationMatrix);

        logger.debug(String.format("Finding changepoint candidates for all window sizes %s...", windowSizes.toString()));
        final List<Integer> changepointCandidates = findChangepointCandidates(
                data, reducedObservationMatrix, kernelApproximationDiagonal, maxNumChangepoints, windowSizes);

        logger.debug("Performing backward model selection on changepoint candidates...");
        final List<Integer> changepoints = selectChangepoints(
                changepointCandidates, maxNumChangepoints,
                numChangepointsPenaltyLinearFactor, numChangepointsPenaltyLogLinearFactor,
                reducedObservationMatrix, kernelApproximationDiagonal);

        final List<Integer> result = new ArrayList<>(changepoints);
        if (changepointSortOrder == ChangepointSortOrder.INDEX) {
            Collections.sort(result);
        }
        //BACKWARD_SELECTION: keep the order produced by selectChangepoints
        return result;
    }

    /**
     * Builds a low-rank approximation to the kernel matrix of {@code data}.
     *
     * <p>Up to {@code kernelApproximationDimension} points are subsampled (with replacement) from the data, the
     * kernel matrix {@code K_ss} of the subsample is decomposed by SVD as {@code U S U^T}, and the returned matrix is
     * {@code K_ns U S^(-1/2)}, where {@code K_ns} is the kernel matrix between all data points and the subsample.
     * Row {@code i} of the result serves as a feature vector for data point {@code i}.</p>
     *
     * @return a (number of data points) x (subsample size) matrix
     */
    private static <DATA> RealMatrix calculateReducedObservationMatrix(final RandomGenerator rng,
                                                                       final List<DATA> data,
                                                                       final BiFunction<DATA, DATA, Double> kernel,
                                                                       final int kernelApproximationDimension) {
        if (kernelApproximationDimension > data.size()) {
            logger.warn(String.format("Specified dimension of the kernel approximation (%d) exceeds the number of data points (%d) to segment; using all data points to calculate kernel matrix.",
                    kernelApproximationDimension, data.size()));
        }

        final int numSubsample = Math.min(kernelApproximationDimension, data.size());
        logger.debug(String.format("Subsampling %d points from data to find kernel approximation...", numSubsample));
        final List<DATA> dataSubsample;
        if (numSubsample == data.size()) {
            dataSubsample = data;
        } else {
            dataSubsample = IntStream.range(0, numSubsample)
                    .mapToObj(i -> data.get(rng.nextInt(data.size())))
                    .collect(Collectors.toList());
        }

        logger.debug(String.format("Calculating kernel matrix of subsampled data (%d x %d)...", numSubsample, numSubsample));
        final Array2DRowRealMatrix subKernelMatrix = new Array2DRowRealMatrix(numSubsample, numSubsample);
        for (int i = 0; i < numSubsample; i++) {
            for (int j = 0; j < i; j++) {
                //the kernel matrix is symmetric, so each off-diagonal entry is evaluated once
                final double value = kernel.apply(dataSubsample.get(i), dataSubsample.get(j));
                subKernelMatrix.setEntry(i, j, value);
                subKernelMatrix.setEntry(j, i, value);
            }
            subKernelMatrix.setEntry(i, i, kernel.apply(dataSubsample.get(i), dataSubsample.get(i)));
        }

        logger.debug(String.format("Performing SVD of kernel matrix of subsampled data (%d x %d)...", numSubsample, numSubsample));
        final SingularValueDecomposition svd = new SingularValueDecomposition(subKernelMatrix);

        logger.debug(String.format("Calculating reduced observation matrix (%d x %d)...", data.size(), numSubsample));
        //EPSILON regularizes the inversion of (near-)zero singular values
        final double[] invSqrtSingularValues = Arrays.stream(svd.getSingularValues())
                .map(Math::sqrt)
                .map(s -> 1. / (s + EPSILON))
                .toArray();
        final RealMatrix u = svd.getU();
        final Array2DRowRealMatrix scaledU = new Array2DRowRealMatrix(numSubsample, numSubsample);
        scaledU.walkInOptimizedOrder(new DefaultRealMatrixChangingVisitor() {
            @Override
            public double visit(final int row, final int column, final double value) {
                return u.getEntry(row, column) * invSqrtSingularValues[column];
            }
        });
        final Array2DRowRealMatrix kernelMatrixVersusSubsample = new Array2DRowRealMatrix(data.size(), numSubsample);
        kernelMatrixVersusSubsample.walkInOptimizedOrder(new DefaultRealMatrixChangingVisitor() {
            @Override
            public double visit(final int row, final int column, final double value) {
                return kernel.apply(data.get(row), dataSubsample.get(column));
            }
        });
        return kernelMatrixVersusSubsample.multiply(scaledU);
    }

    /**
     * Returns the diagonal of the approximate kernel matrix, i.e., the squared Euclidean norm of each row of
     * {@code reducedObservationMatrix}.
     */
    private static double[] calculateKernelApproximationDiagonal(final RealMatrix reducedObservationMatrix) {
        return new IndexRange(0, reducedObservationMatrix.getRowDimension())
                .mapToDouble(i -> MathUtils.square(reducedObservationMatrix.getRowVector(i).getNorm()));
    }

    /**
     * Collects changepoint candidates: for each window size, up to {@code maxNumChangepoints} local minima of the
     * local changepoint costs, excluding the first and last data points. Window sizes with
     * {@code 2 * windowSize > data.size()} are skipped with a warning. Candidates from different window sizes may
     * repeat.
     */
    private static <DATA> List<Integer> findChangepointCandidates(final List<DATA> data,
                                                                  final RealMatrix reducedObservationMatrix,
                                                                  final double[] kernelApproximationDiagonal,
                                                                  final int maxNumChangepoints,
                                                                  final List<Integer> windowSizes) {
        //not presized: windowSizes.size() * maxNumChangepoints can overflow or request a huge array
        final List<Integer> changepointCandidates = new ArrayList<>();
        for (final int windowSize : windowSizes) {
            logger.debug(String.format("Calculating local changepoints costs for window size %d...", windowSize));
            //long arithmetic so that very large window sizes cannot overflow the comparison
            if (2L * windowSize > data.size()) {
                logger.warn(String.format("Number of points needed to calculate local changepoint costs (2 * window size = %d) exceeds number of data points (%d).  Local changepoint costs will not be calculated for this window size.",
                        2 * windowSize, data.size()));
                continue;
            }

            final double[] windowCosts =
                    calculateWindowCosts(reducedObservationMatrix, kernelApproximationDiagonal, windowSize);

            logger.debug(String.format("Finding local minima of local changepoint costs for window size %d...", windowSize));
            //NOTE(review): presumably ordered by decreasing persistence, so the truncation below keeps the most
            //prominent minima -- confirm against PersistenceOptimizer
            final List<Integer> windowCostLocalMinima =
                    new ArrayList<>(new PersistenceOptimizer(windowCosts).getMinimaIndices());
            //window costs wrap around periodically, so minima at the first and last data points are discarded
            windowCostLocalMinima.remove(Integer.valueOf(0));
            windowCostLocalMinima.remove(Integer.valueOf(data.size() - 1));
            changepointCandidates.addAll(
                    windowCostLocalMinima.subList(0, Math.min(maxNumChangepoints, windowCostLocalMinima.size())));
        }

        if (changepointCandidates.isEmpty()) {
            logger.warn("No changepoint candidates were found.  The specified window sizes may be inappropriate, or there may be insufficient data points.");
        }
        return changepointCandidates;
    }

    /**
     * Performs backward selection on the (deduplicated) changepoint candidates and picks the number of changepoints
     * that minimizes total segment cost plus the changepoint penalty.
     *
     * <p>Starting from the segmentation induced by all candidates, the pair of adjacent segments whose merge
     * increases the total cost least is merged repeatedly until a single segment remains. The changepoint removed by
     * each merge is prepended to the result, so the first {@code k} entries are the changepoints of the
     * segmentation with {@code k} changepoints.</p>
     *
     * @return the selected changepoints, in {@link ChangepointSortOrder#BACKWARD_SELECTION} order
     */
    private static List<Integer> selectChangepoints(final List<Integer> changepointCandidates,
                                                    final int maxNumChangepoints,
                                                    final double numChangepointsPenaltyLinearFactor,
                                                    final double numChangepointsPenaltyLogLinearFactor,
                                                    final RealMatrix reducedObservationMatrix,
                                                    final double[] kernelApproximationDiagonal) {
        final int numData = reducedObservationMatrix.getRowDimension();

        //segment i spans [startIndices[i], endIndices[i]]; each candidate c ends one segment and c + 1 starts the next
        final List<Integer> sortedCandidates = changepointCandidates.stream()
                .sorted()
                .distinct()
                .collect(Collectors.toList());
        final List<Integer> startIndices = sortedCandidates.stream()
                .map(c -> Math.min(c + 1, numData - 1))
                .collect(Collectors.toCollection(ArrayList::new));
        startIndices.add(0, 0);
        final List<Integer> endIndices = new ArrayList<>(sortedCandidates);
        endIndices.add(numData - 1);

        final int numSegments = startIndices.size();
        final List<Segment> segments = IntStream.range(0, numSegments)
                .mapToObj(i -> new Segment(startIndices.get(i), endIndices.get(i),
                        reducedObservationMatrix, kernelApproximationDiagonal))
                .collect(Collectors.toCollection(ArrayList::new));

        //after the loop below, totalCosts.get(k) is the total cost of the segmentation with k changepoints
        final List<Double> totalCosts = new ArrayList<>();
        totalCosts.add(sumOfCosts(segments));

        //for each pair of adjacent segments (i, i + 1): the sum of their costs, the cost of their merge, and the
        //difference (sum - merge); the largest difference identifies the merge that increases the total cost least
        final List<Double> pairCosts = IntStream.range(0, numSegments - 1)
                .mapToObj(i -> segments.get(i).cost + segments.get(i + 1).cost)
                .collect(Collectors.toCollection(ArrayList::new));
        final List<Double> mergedCosts = IntStream.range(0, numSegments - 1)
                .mapToObj(i -> new Segment(startIndices.get(i), endIndices.get(i + 1),
                        reducedObservationMatrix, kernelApproximationDiagonal).cost)
                .collect(Collectors.toCollection(ArrayList::new));
        final List<Double> costDifferences = IntStream.range(0, numSegments - 1)
                .mapToObj(i -> pairCosts.get(i) - mergedCosts.get(i))
                .collect(Collectors.toCollection(ArrayList::new));

        final List<Integer> changepoints = new ArrayList<>(sortedCandidates.size());
        for (int iteration = 0; iteration < numSegments - 1; iteration++) {
            final int mergeIndex = costDifferences.indexOf(Collections.max(costDifferences));
            final double mergedCost = mergedCosts.get(mergeIndex);
            final int mergedStart = segments.get(mergeIndex).start;
            final int removedChangepoint = segments.get(mergeIndex).end;
            final int mergedEnd = segments.get(mergeIndex + 1).end;

            //replace segments mergeIndex and mergeIndex + 1 by their merge
            segments.remove(mergeIndex);
            segments.remove(mergeIndex);
            segments.add(mergeIndex, new Segment(mergedStart, mergedEnd, mergedCost));
            pairCosts.remove(mergeIndex);
            mergedCosts.remove(mergeIndex);
            costDifferences.remove(mergeIndex);

            //refresh the pair statistics involving the new merged segment
            if (mergeIndex > 0) {
                final int pairIndex = mergeIndex - 1;
                pairCosts.set(pairIndex, segments.get(pairIndex).cost + segments.get(mergeIndex).cost);
                mergedCosts.set(pairIndex, new Segment(segments.get(pairIndex).start, mergedEnd,
                        reducedObservationMatrix, kernelApproximationDiagonal).cost);
                costDifferences.set(pairIndex, pairCosts.get(pairIndex) - mergedCosts.get(pairIndex));
            }
            if (mergeIndex < segments.size() - 1) {
                pairCosts.set(mergeIndex, segments.get(mergeIndex).cost + segments.get(mergeIndex + 1).cost);
                mergedCosts.set(mergeIndex, new Segment(mergedStart, segments.get(mergeIndex + 1).end,
                        reducedObservationMatrix, kernelApproximationDiagonal).cost);
                costDifferences.set(mergeIndex, pairCosts.get(mergeIndex) - mergedCosts.get(mergeIndex));
            }

            totalCosts.add(0, sumOfCosts(segments));
            changepoints.add(0, removedChangepoint);
        }

        //penalties are only evaluated for selectable counts, so a huge maxNumChangepoints neither overflows
        //nor allocates a huge list
        final int maxNumSelectable = Math.min(maxNumChangepoints, changepoints.size());
        final List<Double> penalizedCosts = IntStream.rangeClosed(0, maxNumSelectable)
                .mapToObj(k -> totalCosts.get(k) + calculateChangepointPenalty(k,
                        numChangepointsPenaltyLinearFactor, numChangepointsPenaltyLogLinearFactor, numData))
                .collect(Collectors.toList());
        final int numSelected = penalizedCosts.indexOf(Collections.min(penalizedCosts));
        logger.info(String.format("Found %d changepoints after applying the changepoint penalty.", numSelected));
        return new ArrayList<>(changepoints.subList(0, numSelected));
    }

    //sum of segment costs (DoubleStream.sum, which uses compensated summation)
    private static double sumOfCosts(final List<Segment> segments) {
        return segments.stream().mapToDouble(s -> s.cost).sum();
    }

    /**
     * Returns {@code linearFactor * k + logLinearFactor * k * ln(numData / k)} for {@code k = numChangepoints};
     * EPSILON keeps the logarithm finite at {@code k = 0}, where the penalty is 0.
     */
    private static double calculateChangepointPenalty(final int numChangepoints,
                                                      final double numChangepointsPenaltyLinearFactor,
                                                      final double numChangepointsPenaltyLogLinearFactor,
                                                      final int numData) {
        return numChangepointsPenaltyLinearFactor * numChangepoints
                + numChangepointsPenaltyLogLinearFactor * numChangepoints
                * Math.log(numData / (numChangepoints + EPSILON));
    }

    /**
     * Computes the cost of the segment {@code [start, end]} (inclusive). When {@code start > end} the segment wraps
     * around: it covers {@code start..N-1} followed by {@code 0..end}, where N is the number of data points.
     */
    private static Cost calculateSegmentCost(final int start,
                                             final int end,
                                             final RealMatrix reducedObservationMatrix,
                                             final double[] kernelApproximationDiagonal) {
        final int numData = reducedObservationMatrix.getRowDimension();
        final int numColumns = reducedObservationMatrix.getColumnDimension();

        double d = kernelApproximationDiagonal[start];
        final double[] w = Arrays.copyOf(reducedObservationMatrix.getRow(start), numColumns);
        double v = Arrays.stream(w).map(x -> x * x).sum();

        //number of points after start: start+1..end, or start+1..N-1 then 0..end when wrapping
        final int numAdditionalPoints = start <= end ? end - start : numData - start + end;
        for (int step = 1; step <= numAdditionalPoints; step++) {
            final int index = (start + step) % numData;
            d += kernelApproximationDiagonal[index];
            double dot = 0.;
            for (int j = 0; j < numColumns; j++) {
                dot += reducedObservationMatrix.getEntry(index, j) * w[j];
                w[j] += reducedObservationMatrix.getEntry(index, j);
            }
            //||W + x||^2 = ||W||^2 + 2 x.W + ||x||^2
            v += 2. * dot + kernelApproximationDiagonal[index];
        }
        return new Cost(d, w, v, d - v / (numAdditionalPoints + 1));
    }

    /**
     * Computes, for every index {@code i}, the local changepoint cost
     * {@code C(i-w+1, i) + C(i+1, i+w) - C(i-w+1, i+w)} with {@code w = windowSize} and periodic indices. The costs
     * at index 0 are computed directly; the three windows are then slid one point at a time with incremental updates.
     * Callers ensure {@code 2 * windowSize <= N}.
     */
    private static double[] calculateWindowCosts(final RealMatrix reducedObservationMatrix,
                                                 final double[] kernelApproximationDiagonal,
                                                 final int windowSize) {
        final int numData = reducedObservationMatrix.getRowDimension();

        //windows for a changepoint at index 0: left = [1-w, 0], right = [1, w], total = [1-w, w] (mod N)
        int leftIndex = (numData - windowSize + 1) % numData;
        int rightIndex = windowSize % numData;
        final Cost leftCost = calculateSegmentCost(leftIndex, 0, reducedObservationMatrix, kernelApproximationDiagonal);
        final Cost rightCost = calculateSegmentCost(1, rightIndex, reducedObservationMatrix, kernelApproximationDiagonal);
        final Cost totalCost = calculateSegmentCost(leftIndex, rightIndex, reducedObservationMatrix, kernelApproximationDiagonal);

        final double[] windowCosts = new double[numData];
        windowCosts[0] = (leftCost.C + rightCost.C) - totalCost.C;

        final RollingSegment left = new RollingSegment(leftCost);
        final RollingSegment right = new RollingSegment(rightCost);
        final RollingSegment total = new RollingSegment(totalCost);
        final double windowSizeReciprocal = 1. / windowSize;
        //the final iteration (nextCenterIndex = 0) recomputes windowCosts[0] via the incremental updates
        for (int centerIndex = 0; centerIndex < numData; centerIndex++) {
            final int nextCenterIndex = (centerIndex + 1) % numData;
            final int nextRightIndex = (rightIndex + 1) % numData;

            left.removePoint(leftIndex, reducedObservationMatrix, kernelApproximationDiagonal);
            left.addPoint(nextCenterIndex, reducedObservationMatrix, kernelApproximationDiagonal);
            final double leftC = left.d - left.v * windowSizeReciprocal;

            right.removePoint(nextCenterIndex, reducedObservationMatrix, kernelApproximationDiagonal);
            right.addPoint(nextRightIndex, reducedObservationMatrix, kernelApproximationDiagonal);
            final double rightC = right.d - right.v * windowSizeReciprocal;

            total.removePoint(leftIndex, reducedObservationMatrix, kernelApproximationDiagonal);
            total.addPoint(nextRightIndex, reducedObservationMatrix, kernelApproximationDiagonal);
            //the total window holds 2 * windowSize points
            final double totalC = total.d - 0.5 * total.v * windowSizeReciprocal;

            windowCosts[nextCenterIndex] = (leftC + rightC) - totalC;

            leftIndex = (leftIndex + 1) % numData;
            rightIndex = nextRightIndex;
        }
        return windowCosts;
    }
}
