package de.unidue.ltl.evaluation.visualization;

import de.unidue.ltl.evaluation.core.EvaluationData;
import java.awt.Color;
import java.awt.Dimension;
import java.awt.Shape;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.text.DecimalFormat;
import java.util.List;
import org.jfree.chart.ChartPanel;
import org.jfree.chart.ChartUtilities;
import org.jfree.chart.JFreeChart;
import org.jfree.chart.annotations.XYTextAnnotation;
import org.jfree.chart.axis.NumberAxis;
import org.jfree.chart.axis.NumberTickUnit;
import org.jfree.chart.plot.XYPlot;
import org.jfree.chart.renderer.GrayPaintScale;
import org.jfree.chart.renderer.xy.XYBlockRenderer;
import org.jfree.data.xy.DefaultXYZDataset;
import org.jfree.data.xy.XYZDataset;

/* loaded from: input_file:de/unidue/ltl/evaluation/visualization/ConfusionMatrixHeatmap.class */
/**
 * Renders a confusion matrix as a gray-scale heatmap PNG.
 *
 * <p>Each cell (gold, predicted) shows the number of confusions between the two labels divided
 * by the confusion matrix's total entry count. Diagonal cells (gold == predicted) are always shown
 * as 0.00, so the picture only highlights errors.
 */
public class ConfusionMatrixHeatmap {
    /** Output image width contributed by each label, in pixels. */
    private static final int PIXELS_PER_LABEL_X = 50;
    /** Output image height contributed by each label, in pixels. */
    private static final int PIXELS_PER_LABEL_Y = 40;

    private final ConfusionMatrix<String> cf;
    private final List<String> labels;

    /**
     * Builds a heatmap from raw evaluation data by first computing its confusion matrix.
     *
     * @param evaluationData gold/predicted data to visualize
     */
    public ConfusionMatrixHeatmap(EvaluationData<String> evaluationData) {
        this.cf = new ConfusionMatrix<>(evaluationData);
        this.labels = this.cf.getLabels();
    }

    /**
     * Builds a heatmap from an already computed confusion matrix.
     *
     * @param confusionMatrix the matrix to visualize
     */
    public ConfusionMatrixHeatmap(ConfusionMatrix<String> confusionMatrix) {
        this.cf = confusionMatrix;
        this.labels = this.cf.getLabels();
    }

    /**
     * Writes the heatmap as a PNG image to {@code file}.
     *
     * <p>The image is {@value #PIXELS_PER_LABEL_X} x {@value #PIXELS_PER_LABEL_Y} pixels per
     * label. The output stream is always closed, even if encoding fails.
     *
     * @param file destination PNG file; created or overwritten
     * @throws IOException if the file cannot be opened or written
     */
    public void writePlot(File file) throws IOException {
        JFreeChart chart = createChart();
        int labelCount = this.labels.size();
        // try-with-resources: ChartUtilities.writeChartAsPNG does not close the stream it is given.
        try (FileOutputStream out = new FileOutputStream(file)) {
            ChartUtilities.writeChartAsPNG(
                    out, chart, PIXELS_PER_LABEL_X * labelCount, PIXELS_PER_LABEL_Y * labelCount);
        }
    }

    /** Assembles the complete chart: dataset, axes, block renderer and per-cell value labels. */
    private JFreeChart createChart() {
        XYZDataset dataset = getDataset();
        // Gold labels run along x; predicted labels run along y, inverted so label 1 sits at the top.
        XYPlot plot = new XYPlot(
                dataset, createAxis("Gold", false), createAxis("Predicted", true), createRenderer());
        plot.setOutlinePaint(Color.black);
        addCellAnnotations(plot, dataset);

        JFreeChart chart =
                new JFreeChart("Confusion Matrix Heatmap", JFreeChart.DEFAULT_TITLE_FONT, plot, false);
        chart.setBackgroundPaint(Color.white);
        return chart;
    }

    /** Creates a renderer that paints one unit-sized gray block per matrix cell. */
    private XYBlockRenderer createRenderer() {
        XYBlockRenderer renderer = new XYBlockRenderer();
        renderer.setBlockWidth(1.0d);
        renderer.setBlockHeight(1.0d);
        // z values lie in [0, 1]; see getSeries for how they are derived.
        renderer.setPaintScale(new GrayPaintScale(0.0d, 1.0d));
        for (int series = 0; series < this.labels.size(); series++) {
            // Plain blocks only: no series shape and no chart entities (tooltips/hit areas).
            renderer.setSeriesShape(series, (Shape) null);
            renderer.setSeriesCreateEntities(series, false);
        }
        return renderer;
    }

    /**
     * Creates an axis with one integer tick per label (labels are numbered 1..n).
     *
     * @param title axis title
     * @param inverted whether the axis runs in reverse direction
     */
    private NumberAxis createAxis(String title, boolean inverted) {
        NumberAxis axis = new NumberAxis(title);
        axis.setTickUnit(new NumberTickUnit(1.0d));
        // Half a unit of padding on each side so the unit-sized blocks centered on 1..n fit fully.
        axis.setRange(0.5d, this.labels.size() + 0.5d);
        axis.setInverted(inverted);
        return axis;
    }

    /** Writes the confusion rate, formatted to two decimals, at the center of every cell. */
    private void addCellAnnotations(XYPlot plot, XYZDataset dataset) {
        DecimalFormat format = new DecimalFormat();
        format.setMaximumFractionDigits(2);
        format.setMinimumFractionDigits(2);
        int labelCount = this.labels.size();
        for (int series = 0; series < labelCount; series++) {
            for (int item = 0; item < labelCount; item++) {
                // z stores 1 - rate (see getSeries), so undo that to print the rate itself.
                double rate = 1.0d - dataset.getZValue(series, item);
                XYTextAnnotation annotation = new XYTextAnnotation(
                        format.format(rate),
                        dataset.getXValue(series, item),
                        dataset.getYValue(series, item));
                annotation.setPaint(Color.black);
                plot.addAnnotation(annotation);
            }
        }
    }

    /** Builds one dataset series per gold label; see {@link #getSeries}. */
    private XYZDataset getDataset() {
        DefaultXYZDataset dataset = new DefaultXYZDataset();
        // Loop-invariant: presumably the total number of evaluated instances — TODO confirm.
        long total = this.cf.getNumberOfEntries();
        for (int goldIndex = 0; goldIndex < this.labels.size(); goldIndex++) {
            String gold = this.labels.get(goldIndex);
            dataset.addSeries(goldIndex + ": " + gold, getSeries(gold, goldIndex, total));
        }
        return dataset;
    }

    /**
     * Builds the x/y/z arrays for the row of cells belonging to one gold label.
     *
     * @param gold the gold label of this row
     * @param goldIndex zero-based index of {@code gold} in {@code labels}
     * @param total denominator for the confusion rate; a value &lt;= 0 yields rate 0 everywhere
     * @return {@code [x, y, z]} arrays of length {@code labels.size()}: x = goldIndex + 1, y =
     *     predicted index + 1, z = 1 - confusion rate (0 on the diagonal)
     */
    private double[][] getSeries(String gold, int goldIndex, long total) {
        int labelCount = this.labels.size();
        double[][] data = new double[3][labelCount];
        for (int predictedIndex = 0; predictedIndex < labelCount; predictedIndex++) {
            String predicted = this.labels.get(predictedIndex);
            double rate = 0.0d;
            // Diagonal cells are deliberately left at 0; a zero total would make the rate
            // NaN/Infinity. Assumes getNumberOfConfusions(gold, predicted) counts gold-as-predicted
            // instances — TODO confirm argument order against ConfusionMatrix.
            if (!gold.equals(predicted) && total > 0) {
                // Explicit double cast: avoids truncating integer division if the count is integral.
                rate = (double) this.cf.getNumberOfConfusions(gold, predicted) / total;
            }
            data[0][predictedIndex] = goldIndex + 1;
            data[1][predictedIndex] = predictedIndex + 1;
            data[2][predictedIndex] = 1.0d - rate;
        }
        return data;
    }
}
