package com.github.chen0040.nbc;

import com.github.chen0040.data.frame.DataFrame;
import com.github.chen0040.data.frame.DataRow;
import com.github.chen0040.data.utils.CountRepository;
import com.github.chen0040.data.utils.discretizers.KMeansDiscretizer;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/github/chen0040/nbc/NBC.class */
/**
 * Naive Bayes classifier trained from a {@link DataFrame}.
 *
 * <p>Training ({@link #fit(DataFrame)}) records support counts in a {@link CountRepository}
 * for events of the form {@code "ClassLabel=<label>"} and {@code "<column>=<value>"}.
 * Values of the columns returned by {@code DataRow.getColumnNames()} are first mapped through a
 * {@link KMeansDiscretizer}; values of the columns returned by
 * {@code DataRow.getCategoricalColumnNames()} are used as-is.
 *
 * <p>Scoring ({@link #getScores(DataRow)}) multiplies the class probability by the conditional
 * probability of every feature event given the class.
 */
public class NBC implements Cloneable {
    private static final Logger logger = LoggerFactory.getLogger(NBC.class);

    /** Prefix of the event name under which a class label is counted in the model. */
    private static final String CLASS_LABEL_PREFIX = "ClassLabel=";

    private CountRepository model = new CountRepository();
    private KMeansDiscretizer inputDiscretizer = new KMeansDiscretizer();
    // Not final so that copy()/clone() can give each instance its own list.
    private List<String> classLabels = new ArrayList<>();

    /**
     * Tells whether a row can be used for training.
     *
     * @param dataRow the row to inspect
     * @return {@code true} when the row declares at least one categorical target column
     */
    protected boolean isValidTrainingSample(DataRow dataRow) {
        return !dataRow.getCategoricalTargetColumnNames().isEmpty();
    }

    /**
     * Copies the trained state of another classifier into this one.
     *
     * <p>The count model is duplicated via {@code makeCopy()} and the class labels are copied
     * into a new list. The discretizer is shared by reference; {@link #fit(DataFrame)} installs
     * a new discretizer instead of refitting the existing one, so fitting either instance later
     * does not refit the one held by the other.
     * NOTE(review): sharing assumes {@code KMeansDiscretizer.discretize} does not mutate the
     * discretizer - confirm.
     *
     * @param nbc the classifier to copy from
     */
    public void copy(NBC nbc) {
        this.model = nbc.model == null ? null : nbc.model.makeCopy();
        // Without the labels a copy would score nothing: getScores iterates over classLabels.
        this.classLabels = new ArrayList<>(nbc.classLabels);
        this.inputDiscretizer = nbc.inputDiscretizer;
    }

    /**
     * Creates a copy of this classifier with its own count model and class-label list.
     *
     * @return the cloned classifier
     * @throws CloneNotSupportedException if a superclass refuses to clone
     */
    @Override
    public Object clone() throws CloneNotSupportedException {
        NBC nbc = (NBC) super.clone();
        // super.clone() is shallow; copy() replaces the shared model and label list.
        nbc.copy(this);
        return nbc;
    }

    /**
     * Returns the support-count model built by the last call to {@link #fit(DataFrame)}.
     *
     * @return the internal count repository (not a copy)
     */
    public CountRepository getModel() {
        return this.model;
    }

    /**
     * Predicts the class label with the highest score for the given row.
     *
     * @param dataRow the row to classify
     * @return the best-scoring label, or {@code null} when no class labels are known
     *         (for example before {@link #fit(DataFrame)} has been called)
     */
    public String classify(DataRow dataRow) {
        String bestLabel = null;
        // Start below every possible score so that a label is still chosen when all scores are 0.
        double bestScore = Double.NEGATIVE_INFINITY;
        for (Map.Entry<String, Double> entry : getScores(dataRow).entrySet()) {
            double score = entry.getValue();
            if (score > bestScore) {
                bestScore = score;
                bestLabel = entry.getKey();
            }
        }
        return bestLabel;
    }

    /**
     * Computes a score per known class label for the given row.
     *
     * <p>Each score is the model's probability of the class event multiplied by the conditional
     * probability, given that class, of every categorical value and every discretized value of
     * the row. Scores are not normalized across classes.
     *
     * @param dataRow the row to score
     * @return a map from class label to score; empty when no class labels are known
     */
    public HashMap<String, Double> getScores(DataRow dataRow) {
        HashMap<String, Double> scores = new HashMap<>();
        List<String> categoricalColumns = dataRow.getCategoricalColumnNames();
        List<String> numericColumns = dataRow.getColumnNames();
        for (String label : this.classLabels) {
            String classEvent = CLASS_LABEL_PREFIX + label;
            double probability = this.model.getProbability(classEvent);
            for (String column : categoricalColumns) {
                String featureEvent = column + "=" + dataRow.getCategoricalCell(column);
                probability *= this.model.getConditionalProbability(classEvent, featureEvent);
            }
            for (String column : numericColumns) {
                String featureEvent = column + "="
                        + this.inputDiscretizer.discretize(dataRow.getCell(column), column);
                probability *= this.model.getConditionalProbability(classEvent, featureEvent);
            }
            scores.put(label, probability);
        }
        return scores;
    }

    /**
     * Rebuilds the list of distinct class labels from the valid training rows of the frame.
     *
     * @param dataFrame the training data
     */
    private void initializeClassLabels(DataFrame dataFrame) {
        Set<String> distinctLabels = new HashSet<>();
        int rowCount = dataFrame.rowCount();
        for (int i = 0; i < rowCount; i++) {
            DataRow row = dataFrame.row(i);
            if (isValidTrainingSample(row)) {
                distinctLabels.add(row.categoricalTarget());
            }
        }
        this.classLabels.clear();
        this.classLabels.addAll(distinctLabels);
    }

    /**
     * Trains the classifier on the given data, discarding any previously learned state.
     *
     * <p>Rows for which {@link #isValidTrainingSample(DataRow)} is {@code false} are skipped,
     * matching the rows used to collect the class labels.
     *
     * @param dataFrame the training data
     */
    public void fit(DataFrame dataFrame) {
        // Use a fresh discretizer per fit so an instance shared through copy()/clone() is
        // never refitted behind the other classifier's back.
        KMeansDiscretizer discretizer = new KMeansDiscretizer();
        discretizer.fit(dataFrame);
        this.inputDiscretizer = discretizer;
        this.model = new CountRepository();
        initializeClassLabels(dataFrame);

        int rowCount = dataFrame.rowCount();
        for (int i = 0; i < rowCount; i++) {
            DataRow row = dataFrame.row(i);
            if (!isValidTrainingSample(row)) {
                continue;
            }
            String classEvent = CLASS_LABEL_PREFIX + row.categoricalTarget();

            // NOTE(review): the class-level and global counts below are incremented once per
            // feature rather than once per row. Presumably this only rescales every class's
            // score by the same factor when all rows share the same columns - confirm against
            // CountRepository before changing it.
            List<String> categoricalColumns = row.getCategoricalColumnNames();
            for (String column : categoricalColumns) {
                String featureEvent = column + "=" + row.getCategoricalCell(column);
                this.model.addSupportCount(new String[]{classEvent, featureEvent});
                this.model.addSupportCount(new String[]{classEvent});
                this.model.addSupportCount(new String[0]);
            }

            List<String> numericColumns = row.getColumnNames();
            for (String column : numericColumns) {
                String featureEvent = column + "="
                        + this.inputDiscretizer.discretize(row.getCell(column), column);
                this.model.addSupportCount(new String[]{classEvent, featureEvent});
                this.model.addSupportCount(new String[]{classEvent});
                this.model.addSupportCount(new String[0]);
            }
        }
    }
}
