package weka.distributed;

import distributed.core.DistributedJobConfig;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import weka.classifiers.Classifier;
import weka.clusterers.Clusterer;
import weka.core.Attribute;
import weka.core.BatchPredictor;
import weka.core.DenseInstance;
import weka.core.Environment;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.SparseInstance;
import weka.core.Utils;

/* loaded from: input_file:weka/distributed/WekaScoringMapTask.class */
/**
 * Map-side helper that scores incoming instances with a trained Weka {@link Classifier} or
 * {@link Clusterer}.
 *
 * <p>Incoming rows are remapped, by attribute name, onto the header the model was trained with
 * (see {@link #setModel(Object, Instances, Instances)}). Model attributes that are absent from
 * the incoming data, or whose type differs, are presented to the model as missing values. Models
 * implementing {@link BatchPredictor} are fed through {@link #processInstanceBatchPredictor} and
 * {@link #finalizeBatchPrediction()}; all other models through {@link #processInstance}.
 */
public class WekaScoringMapTask implements Serializable {
    private static final long serialVersionUID = 146378352037860956L;

    /** Batch size used when a batch predictor does not specify one. */
    private static final int DEFAULT_BATCH_SIZE = 1000;

    /** Wrapper around the classifier or clusterer used for scoring. */
    protected ScoringModel m_model;
    /** Buffer of mapped instances awaiting prediction; only non-null for batch predictors. */
    protected Instances m_batchScoringData;
    /** True if the model header contains at least one string attribute. */
    protected boolean m_isUsingStringAttributes;
    /** For each model attribute index: the matching incoming attribute index, or -1. */
    protected int[] m_attributeMap;
    /** NOTE(review): not read or written anywhere in this class. */
    protected boolean m_equalHeaders;
    /** Number of buffered instances that triggers a batch prediction. */
    protected int m_batchSize = DEFAULT_BATCH_SIZE;
    /** Model attribute name -> description of why it is missing/mismatched in incoming data. */
    protected Map<String, String> m_missingMismatch = new HashMap<>();

    /** {@link ScoringModel} backed by a Weka {@link Classifier}. */
    public static class ClassifierScoringModel extends ScoringModel {
        private static final long serialVersionUID = -5823090343185762045L;

        // Assigned through setModel() from the ScoringModel constructor. It must not get an
        // initializer: that would run after super() and overwrite the assigned model.
        protected Classifier m_model;

        public ClassifierScoringModel(Object obj) {
            super(obj);
        }

        @Override
        public void setModel(Object obj) {
            this.m_model = (Classifier) obj;
        }

        @Override
        public Object getModel() {
            return this.m_model;
        }

        /**
         * Returns the values of a nominal class attribute, in header order. The list is built
         * once and cached.
         *
         * @return the class labels, or {@code null} if no header is set, the header has no class
         *     attribute, or the class attribute is not nominal
         */
        @Override
        public List<String> getPredictionLabels() {
            if (this.m_modelHeader == null || this.m_modelHeader.classIndex() < 0) {
                return null;
            }
            Attribute classAtt = this.m_modelHeader.classAttribute();
            if (classAtt.isNominal() && this.m_predictionLabels == null) {
                List<String> labels = new ArrayList<>(classAtt.numValues());
                for (int i = 0; i < classAtt.numValues(); i++) {
                    labels.add(classAtt.value(i));
                }
                this.m_predictionLabels = labels;
            }
            return this.m_predictionLabels;
        }

        @Override
        public double[] distributionForInstance(Instance instance) throws Exception {
            return this.m_model.distributionForInstance(instance);
        }
    }

    /** {@link ScoringModel} backed by a Weka {@link Clusterer}. */
    public static class ClustererScoringModel extends ScoringModel {
        private static final long serialVersionUID = 6415571646466462751L;

        // Assigned through setModel() from the ScoringModel constructor. It must not get an
        // initializer: that would run after super() and overwrite the assigned model.
        protected Clusterer m_model;

        public ClustererScoringModel(Object obj) {
            super(obj);
        }

        @Override
        public void setModel(Object obj) {
            this.m_model = (Clusterer) obj;
        }

        @Override
        public Object getModel() {
            return this.m_model;
        }

        /**
         * Returns the labels {@code "Cluster_0" .. "Cluster_{k-1}"}, where k is the clusterer's
         * number of clusters. The list is built once and cached.
         *
         * @return the cluster labels
         * @throws DistributedWekaException if the clusterer cannot report its number of clusters
         */
        @Override
        public List<String> getPredictionLabels() throws DistributedWekaException {
            if (this.m_predictionLabels == null) {
                int numClusters;
                try {
                    // numberOfClusters() is the only call here that can throw.
                    numClusters = this.m_model.numberOfClusters();
                } catch (Exception e) {
                    throw new DistributedWekaException(e);
                }
                List<String> labels = new ArrayList<>(Math.max(numClusters, 0));
                for (int i = 0; i < numClusters; i++) {
                    labels.add("Cluster_" + i);
                }
                this.m_predictionLabels = labels;
            }
            return this.m_predictionLabels;
        }

        @Override
        public double[] distributionForInstance(Instance instance) throws Exception {
            return this.m_model.distributionForInstance(instance);
        }

        public int numberOfClusters() throws Exception {
            return this.m_model.numberOfClusters();
        }
    }

    /** Uniform scoring facade over a Weka classifier or clusterer plus its training header. */
    public static abstract class ScoringModel implements Serializable {
        private static final long serialVersionUID = 6418927792442398048L;

        /** Header (structure) of the data the model was trained on. */
        protected Instances m_modelHeader;
        /** Cached prediction labels; populated lazily by subclasses. */
        protected List<String> m_predictionLabels;

        /**
         * Creates the wrapper and hands {@code obj} to {@link #setModel(Object)}.
         *
         * <p>Note: this calls an overridable method from the constructor, so subclass fields are
         * not yet initialized when {@code setModel} runs.
         */
        public ScoringModel(Object obj) {
            setModel(obj);
        }

        /**
         * Returns true if the wrapped model implements {@link BatchPredictor}. The method name
         * is misspelled but kept for compatibility.
         */
        public boolean isBatchPredicor() {
            return getModel() instanceof BatchPredictor;
        }

        /**
         * Scores a whole batch. Only valid when {@link #isBatchPredicor()} is true; otherwise
         * the cast throws {@link ClassCastException}.
         */
        public double[][] distributionsForInstances(Instances instances) throws Exception {
            return ((BatchPredictor) getModel()).distributionsForInstances(instances);
        }

        public void setHeader(Instances instances) {
            this.m_modelHeader = instances;
        }

        public Instances getHeader() {
            return this.m_modelHeader;
        }

        public abstract void setModel(Object obj);

        public abstract Object getModel();

        public abstract List<String> getPredictionLabels() throws DistributedWekaException;

        public abstract double[] distributionForInstance(Instance instance) throws Exception;

        /**
         * Wraps {@code obj} in the matching scoring model.
         *
         * @return a classifier or clusterer scorer, or {@code null} if {@code obj} is neither
         */
        public static ScoringModel createScorer(Object obj) throws Exception {
            if (obj instanceof Classifier) {
                return new ClassifierScoringModel(obj);
            }
            if (obj instanceof Clusterer) {
                return new ClustererScoringModel(obj);
            }
            return null;
        }
    }

    /**
     * Builds {@link #m_attributeMap} by matching model attributes to incoming attributes by
     * name and type, and records each missing or mismatched attribute in
     * {@link #m_missingMismatch}.
     *
     * @param instances the model (training) header
     * @param instances2 the incoming data header
     * @throws DistributedWekaException if more than half of the model's non-class attributes
     *     are missing or mismatched
     */
    protected void buildAttributeMap(Instances instances, Instances instances2) throws DistributedWekaException {
        int numModelAtts = instances.numAttributes();
        int classIndex = instances.classIndex();
        this.m_attributeMap = new int[numModelAtts];
        int problemCount = 0;
        for (int i = 0; i < numModelAtts; i++) {
            Attribute modelAtt = instances.attribute(i);
            Attribute incomingAtt = instances2.attribute(modelAtt.name());
            String problem = null;
            if (incomingAtt == null) {
                problem = "missing from incoming data";
            } else if (modelAtt.type() != incomingAtt.type()) {
                problem = "type mismatch - model: " + Attribute.typeToString(modelAtt)
                        + " != incoming: " + Attribute.typeToString(incomingAtt);
            }
            if (problem == null) {
                this.m_attributeMap[i] = incomingAtt.index();
            } else {
                this.m_attributeMap[i] = -1;
                this.m_missingMismatch.put(modelAtt.name(), problem);
                // Data being scored need not carry the class: mapIncomingFieldsToModelFields
                // always sets the class value to missing, and the class is excluded from the
                // denominator below. So it must not count as a problem either.
                if (i != classIndex) {
                    problemCount++;
                }
            }
        }
        int numNonClassAtts = numModelAtts - (classIndex >= 0 ? 1 : 0);
        if (problemCount > numNonClassAtts / 2) {
            throw new DistributedWekaException("More than 50% of the attributes that the model is expecting to see are either missing or have a type mismatch in the incoming data.");
        }
    }

    /**
     * Installs the model to score with and prepares the incoming->model attribute mapping.
     *
     * @param obj a Weka {@link Classifier} or {@link Clusterer} (optionally a
     *     {@link BatchPredictor})
     * @param instances the header of the data the model was trained on
     * @param instances2 the header of the incoming data to be scored
     * @throws DistributedWekaException if either header is null, the model type is
     *     unsupported, the batch size cannot be parsed, or too many attributes do not match
     */
    public void setModel(Object obj, Instances instances, Instances instances2) throws DistributedWekaException {
        this.m_missingMismatch.clear();
        if (instances2 == null || instances == null) {
            throw new DistributedWekaException("Can't continue without a header for the model and incoming data");
        }
        try {
            this.m_isUsingStringAttributes = instances.checkForStringAttributes();
            this.m_model = ScoringModel.createScorer(obj);
            if (this.m_model == null) {
                throw new DistributedWekaException("Unsupported model type: "
                        + (obj == null ? "null" : obj.getClass().getName())
                        + " - expected a Classifier or Clusterer");
            }
            this.m_model.setHeader(instances);
            // Drop any buffer left over from a previously installed model.
            this.m_batchScoringData = null;
            if (this.m_model.isBatchPredicor()) {
                this.m_batchScoringData = new Instances(instances, 0);
                String batchSize = ((BatchPredictor) obj).getBatchSize();
                if (DistributedJobConfig.isEmpty(batchSize)) {
                    this.m_batchSize = DEFAULT_BATCH_SIZE;
                } else {
                    // The batch size may contain environment variables; resolve them first.
                    this.m_batchSize = Integer.parseInt(Environment.getSystemWide().substitute(batchSize));
                }
            }
            buildAttributeMap(instances, instances2);
        } catch (DistributedWekaException e) {
            throw e;
        } catch (Exception e) {
            throw new DistributedWekaException(e);
        }
    }

    /**
     * Converts an incoming instance into an instance in the model's header format.
     *
     * <p>Unmapped or missing fields, and nominal labels unknown to the model, become missing
     * values. The class value (if the model header has a class) is always set to missing. The
     * result is sparse if the input is sparse, and dense otherwise.
     */
    protected Instance mapIncomingFieldsToModelFields(Instance instance) {
        Instances header = this.m_model.getHeader();
        int numAtts = header.numAttributes();
        double[] vals = new double[numAtts];
        for (int i = 0; i < numAtts; i++) {
            int incomingIndex = this.m_attributeMap[i];
            if (incomingIndex < 0 || instance.isMissing(incomingIndex)) {
                vals[i] = Utils.missingValue();
                continue;
            }
            Attribute attribute = header.attribute(i);
            if (attribute.isNumeric()) {
                vals[i] = instance.value(incomingIndex);
            } else if (attribute.isNominal()) {
                int valueIndex = attribute.indexOfValue(instance.stringValue(incomingIndex));
                vals[i] = valueIndex < 0 ? Utils.missingValue() : valueIndex;
            } else if (attribute.isString()) {
                // The model header's string attribute holds a single value (index 0) that is
                // overwritten for every instance.
                // NOTE(review): when instances are buffered for batch prediction, all buffered
                // rows presumably end up sharing the most recent string value — confirm.
                vals[i] = 0.0d;
                attribute.setStringValue(instance.stringValue(incomingIndex));
            }
            // Any other attribute type is left at 0.0.
        }
        if (header.classIndex() >= 0) {
            vals[header.classIndex()] = Utils.missingValue();
        }
        Instance result = instance instanceof SparseInstance
                ? new SparseInstance(instance.weight(), vals)
                : new DenseInstance(instance.weight(), vals);
        result.setDataset(header);
        return result;
    }

    /**
     * Scores a single incoming instance with the (non-batch) model.
     *
     * @return the model's distribution for the mapped instance
     * @throws DistributedWekaException wrapping any exception raised by the model
     */
    public double[] processInstance(Instance instance) throws DistributedWekaException {
        try {
            return this.m_model.distributionForInstance(mapIncomingFieldsToModelFields(instance));
        } catch (Exception e) {
            throw new DistributedWekaException(e);
        }
    }

    /**
     * Buffers an incoming instance for batch prediction.
     *
     * @return the predictions for the whole buffered batch once it reaches the batch size (the
     *     buffer is then cleared), or {@code null} while the batch is still filling
     * @throws DistributedWekaException if no batch buffer exists (the model is not a batch
     *     predictor, or no model has been set), or if prediction fails
     */
    public double[][] processInstanceBatchPredictor(Instance instance) throws DistributedWekaException {
        if (this.m_batchScoringData == null) {
            throw new DistributedWekaException("No batch scoring buffer available - the model is not a BatchPredictor or setModel() has not been called");
        }
        this.m_batchScoringData.add(mapIncomingFieldsToModelFields(instance));
        // Use '<' rather than '!=' so that a non-positive batch size flushes every instance
        // instead of never flushing.
        if (this.m_batchScoringData.numInstances() < this.m_batchSize) {
            return null;
        }
        return predictBufferedBatch();
    }

    /**
     * Scores whatever remains in the batch buffer.
     *
     * @return predictions for the remaining instances, or {@code null} if nothing is buffered
     * @throws DistributedWekaException if prediction fails
     */
    public double[][] finalizeBatchPrediction() throws DistributedWekaException {
        if (this.m_batchScoringData == null || this.m_batchScoringData.numInstances() <= 0) {
            return null;
        }
        return predictBufferedBatch();
    }

    /**
     * Scores the buffered batch, checks that one prediction was returned per instance, and
     * then clears the buffer. The buffer is emptied rather than discarded, so it stays usable.
     */
    private double[][] predictBufferedBatch() throws DistributedWekaException {
        try {
            double[][] predictions = this.m_model.distributionsForInstances(this.m_batchScoringData);
            if (predictions.length != this.m_batchScoringData.numInstances()) {
                throw new DistributedWekaException("Number of predictions did not match the number of instances in the batch");
            }
            this.m_batchScoringData.delete();
            return predictions;
        } catch (DistributedWekaException e) {
            throw e;
        } catch (Exception e) {
            throw new DistributedWekaException(e);
        }
    }

    /** Returns one line per missing or mismatched model attribute: "name reason". */
    public String getMissingMismatchAttributeInfo() {
        StringBuilder sb = new StringBuilder();
        for (Map.Entry<String, String> entry : this.m_missingMismatch.entrySet()) {
            sb.append(entry.getKey()).append(" ").append(entry.getValue()).append("\n");
        }
        return sb.toString();
    }

    /** Returns the prediction labels of the installed model. */
    public List<String> getPredictionLabels() throws DistributedWekaException {
        return this.m_model.getPredictionLabels();
    }

    /** Returns true if a model is installed and it implements {@link BatchPredictor}. */
    public boolean isBatchPredictor() {
        return this.m_model != null && this.m_model.isBatchPredicor();
    }

    /** Returns true if the model header contains string attributes. */
    public boolean modelIsUsingStringAttributes() {
        return this.m_isUsingStringAttributes;
    }
}
