package weka.distributed;

import java.awt.Color;
import java.awt.Dimension;
import java.awt.Graphics2D;
import java.awt.Image;
import java.awt.image.BufferedImage;
import java.io.IOException;
import java.io.OutputStream;
import java.io.Serializable;
import java.util.List;
import javax.imageio.ImageIO;
import org.tc33.jheatchart.HeatChart;
import weka.core.Attribute;
import weka.core.Instances;
import weka.core.Utils;
import weka.core.matrix.Matrix;
import weka.core.stats.ArffSummaryNumericMetric;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Remove;

/* loaded from: input_file:weka/distributed/CorrelationMatrixRowReduceTask.class */
public class CorrelationMatrixRowReduceTask implements Serializable {
    private static final long serialVersionUID = -2314677791571204717L;

    /* JADX WARN: Type inference failed for: r0v4, types: [double[], double[][]] */
    public static Image getHeatMapForMatrix(Matrix matrix, List<String> list) {
        double[][] array = matrix.getArray();
        ?? r0 = new double[array.length];
        for (int i = 0; i < array.length; i++) {
            r0[(array.length - 1) - i] = array[i];
        }
        String[] strArr = new String[list.size()];
        String[] strArr2 = new String[list.size()];
        for (int i2 = 0; i2 < list.size(); i2++) {
            strArr[i2] = list.get(i2);
            strArr2[(list.size() - 1) - i2] = list.get(i2);
        }
        HeatChart heatChart = new HeatChart(r0, true);
        heatChart.setTitle("Correlation matrix heat map");
        heatChart.setCellSize(new Dimension(30, 30));
        heatChart.setHighValueColour(Color.RED);
        heatChart.setLowValueColour(Color.BLUE);
        heatChart.setXValues(strArr);
        heatChart.setYValues(strArr2);
        return heatChart.getChartImage();
    }

    public static void writeHeatMapImage(Image image, OutputStream outputStream) throws IOException {
        ImageIO.write((BufferedImage) image, "png", outputStream);
        outputStream.flush();
        outputStream.close();
    }

    public double[] aggregate(int i, List<double[]> list, List<int[]> list2, Instances instances, boolean z, boolean z2, boolean z3) throws DistributedWekaException {
        StringBuilder sb = new StringBuilder();
        Instances stripSummaryAtts = CSVToARFFHeaderReduceTask.stripSummaryAtts(instances);
        if (stripSummaryAtts.classIndex() >= 0 && z3) {
            sb.append("" + (stripSummaryAtts.classIndex() + 1)).append(",");
        }
        for (int i2 = 0; i2 < stripSummaryAtts.numAttributes(); i2++) {
            if (!stripSummaryAtts.attribute(i2).isNumeric()) {
                sb.append("" + (i2 + 1)).append(",");
            }
        }
        if (sb.length() > 0) {
            Remove remove = new Remove();
            sb.deleteCharAt(sb.length() - 1);
            remove.setAttributeIndices(sb.toString());
            remove.setInvertSelection(false);
            try {
                remove.setInputFormat(stripSummaryAtts);
                stripSummaryAtts = Filter.useFilter(stripSummaryAtts, remove);
            } catch (Exception e) {
                throw new DistributedWekaException(e);
            }
        }
        Attribute attribute = stripSummaryAtts.attribute(i);
        Attribute attribute2 = instances.attribute(CSVToARFFHeaderMapTask.ARFF_SUMMARY_ATTRIBUTE_PREFIX + attribute.name());
        if (attribute2 == null) {
            throw new DistributedWekaException("Was unable to find the summary stats attribute for original attribute '" + attribute.name() + "' corresponding to matrix row number: " + i);
        }
        if (list.size() == 0) {
            throw new DistributedWekaException("Nothing to aggregate!");
        }
        double[] dArr = (double[]) list.get(0).clone();
        for (int i3 = 1; i3 < list.size(); i3++) {
            double[] dArr2 = list.get(i3);
            for (int i4 = 0; i4 < dArr2.length; i4++) {
                int i5 = i4;
                dArr[i5] = dArr[i5] + dArr2[i4];
            }
        }
        int[] iArr = null;
        if (!z) {
            iArr = (int[]) list2.get(0).clone();
            for (int i6 = 1; i6 < list2.size(); i6++) {
                int[] iArr2 = list2.get(i6);
                for (int i7 = 0; i7 < iArr2.length; i7++) {
                    int i8 = i7;
                    iArr[i8] = iArr[i8] + iArr2[i7];
                }
            }
        }
        double[] attributeToStatsArray = CSVToARFFHeaderReduceTask.attributeToStatsArray(attribute2);
        for (int i9 = 0; i9 < dArr.length; i9++) {
            double[] attributeToStatsArray2 = CSVToARFFHeaderReduceTask.attributeToStatsArray(instances.attribute(CSVToARFFHeaderMapTask.ARFF_SUMMARY_ATTRIBUTE_PREFIX + stripSummaryAtts.attribute(i9).name()));
            double d = z ? attributeToStatsArray[ArffSummaryNumericMetric.COUNT.ordinal()] + attributeToStatsArray[ArffSummaryNumericMetric.MISSING.ordinal()] : iArr[i9];
            if (z2) {
                if (d > 1.0d) {
                    int i10 = i9;
                    dArr[i10] = dArr[i10] / (d - 1.0d);
                } else if (d == 1.0d) {
                    dArr[i9] = Double.POSITIVE_INFINITY;
                } else {
                    dArr[i9] = Utils.missingValue();
                }
            } else if (i != i9 && d > 1.0d) {
                double d2 = attributeToStatsArray[ArffSummaryNumericMetric.STDDEV.ordinal()];
                double d3 = attributeToStatsArray2[ArffSummaryNumericMetric.STDDEV.ordinal()];
                if (d2 * d3 > 0.0d) {
                    dArr[i9] = dArr[i9] / (((d - 1.0d) * d2) * d3);
                }
            } else if (z) {
                dArr[i9] = 1.0d;
            } else {
                dArr[i9] = Utils.missingValue();
            }
        }
        return dArr;
    }
}
