package org.bigml.binding;

import java.io.BufferedWriter;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;
import org.apache.commons.csv.CSVFormat;
import org.apache.commons.csv.CSVParser;
import org.apache.commons.csv.CSVPrinter;
import org.apache.commons.csv.CSVRecord;
import org.bigml.binding.localmodel.Prediction;
import org.bigml.binding.resources.AbstractResource;
import org.json.simple.JSONArray;
import org.json.simple.JSONObject;
import org.json.simple.JSONValue;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/bigml/binding/MultiModel.class */
public class MultiModel implements Serializable {
    private static final long serialVersionUID = 1;
    static Logger logger = LoggerFactory.getLogger(MultiModel.class.getName());

    /** Suffix appended to a model's sanitized resource id to form its predictions file name. */
    private static final String PREDICTIONS_FILE_SUFFIX = "_predictions.csv";

    /** The model JSON objects this multi-model was built from. */
    private JSONArray models;
    /** Not assigned by this class; remains {@code null} unless set elsewhere. */
    protected JSONObject fields;
    /** Class names used to lay out plurality vote distributions; may be {@code null}. */
    private List<String> classNames;
    /** Votes produced by the most recent {@code predict} call. */
    private MultiVote votes;
    /** One local model per entry of {@link #models}, built in the constructor. */
    private List<LocalPredictiveModel> localModels;

    /**
     * Builds a multi-model with no fields override and no class names.
     *
     * @param models a {@link JSONArray} or other {@link List} of model JSON objects, or a single model
     * @throws Exception propagated from the {@link LocalPredictiveModel} constructor
     */
    public MultiModel(Object models) throws Exception {
        this(models, null, null);
    }

    /**
     * Builds a multi-model and eagerly creates one {@link LocalPredictiveModel} per model.
     *
     * @param models      a {@link JSONArray} (used as-is), any other {@link List} (copied), or a single
     *                    model object (wrapped); every element must be a {@link JSONObject}
     * @param modelFields if non-null, applied to every local model via {@code setFields}
     * @param classNames  class names used by {@link #generateVotesDistribution}; may be {@code null}
     * @throws Exception propagated from the {@link LocalPredictiveModel} constructor
     */
    @SuppressWarnings("unchecked")
    public MultiModel(Object models, JSONObject modelFields, List<String> classNames) throws Exception {
        if (models instanceof JSONArray) {
            this.models = (JSONArray) models;
        } else {
            this.models = new JSONArray();
            if (models instanceof List) {
                this.models.addAll((List<?>) models);
            } else {
                this.models.add(models);
            }
        }
        this.classNames = classNames;
        this.localModels = new ArrayList<LocalPredictiveModel>(this.models.size());
        for (Object model : this.models) {
            LocalPredictiveModel localModel = new LocalPredictiveModel((JSONObject) model);
            if (modelFields != null) {
                localModel.setFields(modelFields);
            }
            this.localModels.add(localModel);
        }
    }

    /**
     * Returns the underlying model JSON objects (the live array, not a copy).
     */
    public JSONArray listModels() {
        return this.models;
    }

    /**
     * Asks every local model for a prediction on {@code inputData} and collects them as votes.
     * For boosted models, the model's boosting {@code weight} is copied into the prediction and
     * its {@code objective_class} (when present) is stored under {@code class}.
     *
     * @param inputData       input to predict
     * @param missingStrategy strategy for missing values; {@code null} means LAST_PREDICTION
     * @param list            passed unchanged as the last argument of {@code LocalPredictiveModel.predict}
     * @return the collected votes, one per local model
     * @throws Exception propagated from the local models' {@code predict}
     */
    public MultiVote generateVotes(JSONObject inputData, MissingStrategy missingStrategy, List<String> list) throws Exception {
        MissingStrategy strategy = missingStrategy == null ? MissingStrategy.LAST_PREDICTION : missingStrategy;
        MultiVote multiVote = new MultiVote();
        for (LocalPredictiveModel localModel : this.localModels) {
            Prediction prediction = localModel.predict(inputData, strategy, null, null, true, list);
            if (localModel.isBoosting()) {
                multiVote.boosting = true;
                prediction.put("weight", localModel.getBoosting().get("weight"));
                String objectiveClass = (String) localModel.getBoosting().get("objective_class");
                if (objectiveClass != null) {
                    prediction.put("class", objectiveClass);
                }
            }
            multiVote.append(prediction);
        }
        return multiVote;
    }

    /**
     * Builds one vote-distribution row per local model. Each local model first receives this
     * multi-model's class names via {@code setClassNames}.
     * <ul>
     *   <li>PLURALITY: a one-hot row over {@link #classNames} marking the predicted class.</li>
     *   <li>CONFIDENCE: the {@code confidence} value of each entry of {@code predictConfidence}.</li>
     *   <li>otherwise: the {@code probability} value of each entry of {@code predictProbability}.</li>
     * </ul>
     *
     * @param inputData        input to predict
     * @param missingStrategy  strategy for missing values; {@code null} means LAST_PREDICTION
     * @param predictionMethod how to build rows; {@code null} means PROBABILITY
     * @return the per-model distribution rows
     * @throws IllegalStateException for PLURALITY when class names are unset or the predicted class
     *                               is not one of them
     * @throws Exception propagated from the local models
     */
    public MultiVoteList generateVotesDistribution(JSONObject inputData, MissingStrategy missingStrategy, PredictionMethod predictionMethod) throws Exception {
        MissingStrategy strategy = missingStrategy == null ? MissingStrategy.LAST_PREDICTION : missingStrategy;
        PredictionMethod method = predictionMethod == null ? PredictionMethod.PROBABILITY : predictionMethod;
        MultiVoteList multiVoteList = new MultiVoteList(null);
        for (LocalPredictiveModel localModel : this.localModels) {
            localModel.setClassNames(this.classNames);
            if (method == PredictionMethod.PLURALITY) {
                if (this.classNames == null) {
                    throw new IllegalStateException("Class names are required to build plurality vote distributions");
                }
                ArrayList<Double> oneHot = new ArrayList<Double>(this.classNames.size());
                for (int k = 0; k < this.classNames.size(); k++) {
                    oneHot.add(Double.valueOf(0.0d));
                }
                String predictedClass = (String) localModel.predict(inputData, (Boolean) false, strategy).get(AbstractResource.PREDICTION_PATH);
                int classIndex = this.classNames.indexOf(predictedClass);
                if (classIndex < 0) {
                    throw new IllegalStateException(String.format("Predicted class %s is not one of the known classes %s", predictedClass, this.classNames));
                }
                oneHot.set(classIndex, Double.valueOf(1.0d));
                multiVoteList.append(oneHot);
            } else {
                JSONArray distribution;
                String key;
                if (method == PredictionMethod.CONFIDENCE) {
                    distribution = localModel.predictConfidence(inputData, strategy);
                    key = "confidence";
                } else {
                    distribution = localModel.predictProbability(inputData, strategy);
                    key = "probability";
                }
                ArrayList<Double> row = new ArrayList<Double>(distribution.size());
                for (Object entry : distribution) {
                    row.add((Double) ((Prediction) entry).get(key));
                }
                multiVoteList.append(row);
            }
        }
        return multiVoteList;
    }

    /**
     * Collects votes from every local model and combines them.
     *
     * @param inputData        input to predict
     * @param missingStrategy  strategy for missing values; {@code null} means LAST_PREDICTION
     * @param predictionMethod combiner; {@code null} means PLURALITY
     * @param map              options forwarded to {@code MultiVote.combine}
     * @return the combined prediction; the raw votes are kept in {@link #votes}
     * @throws Exception propagated from prediction or combination
     */
    public HashMap<Object, Object> predict(JSONObject inputData, MissingStrategy missingStrategy, PredictionMethod predictionMethod, Map map) throws Exception {
        PredictionMethod method = predictionMethod == null ? PredictionMethod.PLURALITY : predictionMethod;
        MissingStrategy strategy = missingStrategy == null ? MissingStrategy.LAST_PREDICTION : missingStrategy;
        this.votes = generateVotes(inputData, strategy, null);
        return this.votes.combine(method, map);
    }

    /**
     * Collects votes with the default missing strategy and combines them without options.
     * The {@code Boolean} argument is accepted for API compatibility and is not used.
     *
     * @param inputData        input to predict
     * @param predictionMethod combiner; {@code null} means PLURALITY
     * @param bool             ignored
     * @return the combined prediction; the raw votes are kept in {@link #votes}
     * @throws Exception propagated from prediction or combination
     */
    public HashMap<Object, Object> predict(JSONObject inputData, PredictionMethod predictionMethod, Boolean bool) throws Exception {
        PredictionMethod method = predictionMethod == null ? PredictionMethod.PLURALITY : predictionMethod;
        this.votes = generateVotes(inputData, null, null);
        return this.votes.combine(method, null);
    }

    /**
     * Collects votes and combines them. Only {@code inputData}, {@code predictionMethod},
     * {@code map} and {@code missingStrategy} are used; the {@code Boolean} arguments are accepted
     * for API compatibility and ignored.
     *
     * @param inputData        input to predict
     * @param predictionMethod combiner; {@code null} means PLURALITY
     * @param map              options forwarded to {@code MultiVote.combine}
     * @param missingStrategy  strategy for missing values; {@code null} means LAST_PREDICTION
     * @return the combined prediction; the raw votes are kept in {@link #votes}
     * @throws Exception propagated from prediction or combination
     */
    public HashMap<Object, Object> predict(JSONObject inputData, PredictionMethod predictionMethod, Boolean bool, Map map, MissingStrategy missingStrategy, Boolean bool2, Boolean bool3, Boolean bool4, Boolean bool5) throws Exception {
        PredictionMethod method = predictionMethod == null ? PredictionMethod.PLURALITY : predictionMethod;
        this.votes = generateVotes(inputData, missingStrategy, null);
        return this.votes.combine(method, map);
    }

    /**
     * Predicts every input with every model and writes one predictions CSV per model into
     * {@code outputDirectory}, overwriting existing files.
     */
    public void batchPredict(JSONArray inputs, String outputDirectory) throws Exception {
        batchPredict(inputs, outputDirectory, null, null, null, null, null);
    }

    /**
     * Predicts every input with every model.
     *
     * @param inputs          input objects (each a {@link JSONObject})
     * @param outputDirectory directory holding the per-model predictions files
     * @param append          when true, rows are appended to existing files (the header is only
     *                        written if the file is missing or empty); {@code null} means false
     * @param missingStrategy strategy for missing values; {@code null} means LAST_PREDICTION
     * @param headers         if non-empty, only these prediction keys are written as columns
     * @param toFile          whether to write the CSV files; {@code null} means true
     * @param useMedian       for regression models, replace the prediction by its median;
     *                        {@code null} means false
     * @return one {@link MultiVote} per input, holding a copy of each model's prediction
     * @throws Exception wrapping, as cause, any prediction or file-writing failure
     */
    public List<MultiVote> batchPredict(JSONArray inputs, String outputDirectory, Boolean append, MissingStrategy missingStrategy, Set<String> headers, Boolean toFile, Boolean useMedian) throws Exception {
        MissingStrategy strategy = missingStrategy == null ? MissingStrategy.LAST_PREDICTION : missingStrategy;
        boolean appendToFile = append != null && append.booleanValue();
        boolean writeFiles = toFile == null || toFile.booleanValue();
        boolean median = useMedian != null && useMedian.booleanValue();

        List<MultiVote> votesPerInput = new ArrayList<MultiVote>();
        for (Object modelObject : this.models) {
            JSONObject model = (JSONObject) modelObject;
            try {
                // Built once per model rather than once per input row.
                LocalPredictiveModel localModel = new LocalPredictiveModel(model);
                List<Prediction> predictions = new ArrayList<Prediction>(inputs.size());
                TreeSet<Object> keys = new TreeSet<Object>();
                int index = 0;
                for (Object input : inputs) {
                    Prediction prediction = localModel.predict((JSONObject) input, strategy);
                    if (median && localModel.isRegression()) {
                        prediction.setPrediction(prediction.getMedian());
                    }
                    predictions.add(prediction);
                    keys.addAll(prediction.keySet());

                    // The vote gets its own copy so later changes to the row do not alter it.
                    Prediction copy = new Prediction();
                    copy.putAll(prediction);
                    if (votesPerInput.size() <= index) {
                        votesPerInput.add(new MultiVote());
                    }
                    votesPerInput.get(index).append(copy);
                    index++;
                }

                if (writeFiles) {
                    List<Object> columns = new ArrayList<Object>(keys.size());
                    for (Object key : keys) {
                        if (headers == null || headers.isEmpty() || headers.contains(key)) {
                            columns.add(key);
                        }
                    }
                    String fileName = getPredictionsFileName(model.get("resource").toString(), outputDirectory);
                    writePredictionsFile(fileName, outputDirectory, appendToFile, columns, predictions);
                }
            } catch (Exception e) {
                throw new Exception("Error generating the CSV !!!", e);
            }
        }
        return votesPerInput;
    }

    /**
     * Writes the given predictions as UTF-8 CSV rows, one column per entry of {@code columns}.
     * The file is always closed, and every failure keeps its original exception as the cause.
     * When appending to a non-empty file the header is not repeated, so the file stays readable
     * by {@link #readVotes}. Appended rows are assumed to use the same columns as the existing
     * header; this is not checked.
     */
    private void writePredictionsFile(String fileName, String directory, boolean append, List<Object> columns, List<Prediction> predictions) throws Exception {
        // File.length() is 0 for a missing file, so a new file always gets a header.
        boolean writeHeader = !append || new File(fileName).length() == 0L;
        BufferedWriter writer;
        try {
            writer = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(fileName, append), "UTF-8"));
        } catch (IOException e) {
            throw new Exception(String.format("Cannot find %s directory.", directory), e);
        }
        try (BufferedWriter out = writer) {
            CSVFormat format = writeHeader
                    ? CSVFormat.DEFAULT.withHeader(columns.toArray(new String[columns.size()]))
                    : CSVFormat.DEFAULT;
            CSVPrinter printer = format.print(out);
            for (Prediction prediction : predictions) {
                Object[] record = new Object[columns.size()];
                for (int c = 0; c < columns.size(); c++) {
                    record[c] = prediction.get(columns.get(c));
                }
                printer.printRecord(record);
            }
            out.flush();
        } catch (Exception e) {
            throw new Exception("Error generating the CSV !!!", e);
        }
    }

    /**
     * Reads back the predictions files that {@link #batchPredict} wrote into {@code directory}.
     * The first model acts as the prediction converter.
     *
     * @return one {@link MultiVote} per CSV row; empty when there are no models
     * @throws Exception propagated from {@link #readVotes}
     */
    public List<MultiVote> batchVotes(String directory, Locale locale) throws Exception {
        if (this.models.isEmpty()) {
            return new ArrayList<MultiVote>();
        }
        List<String> fileNames = new ArrayList<String>(this.models.size());
        for (Object model : this.models) {
            fileNames.add(getPredictionsFileName(((JSONObject) model).get("resource").toString(), directory));
        }
        return readVotes(fileNames, new LocalPredictiveModel((JSONObject) this.models.get(0)), locale);
    }

    /**
     * Reads UTF-8 predictions CSV files, each with a header row. Row {@code i} of every file is
     * added to the {@code i}-th returned {@link MultiVote}. Empty cells are skipped. These columns
     * are converted:
     * <ul>
     *   <li>the prediction column, via {@code predictionConverter};</li>
     *   <li>{@code order}, to Integer;</li>
     *   <li>{@code distribution}, parsed as JSON;</li>
     *   <li>{@code instances}, to Long;</li>
     *   <li>{@code confidence}, to Double.</li>
     * </ul>
     * All other cells are kept as strings.
     *
     * @throws Exception on I/O errors, malformed numbers, or converter failures
     */
    public List<MultiVote> readVotes(List<String> list, PredictionConverter predictionConverter, Locale locale) throws Exception {
        List<MultiVote> votesPerRow = new ArrayList<MultiVote>();
        for (String fileName : list) {
            try (FileInputStream in = new FileInputStream(fileName);
                 InputStreamReader reader = new InputStreamReader(in, "UTF-8");
                 CSVParser parser = new CSVParser(reader, CSVFormat.EXCEL.withHeader(new String[0]))) {
                Map<String, Integer> headerMap = parser.getHeaderMap();
                int row = 0;
                for (CSVRecord record : parser) {
                    if (row == votesPerRow.size()) {
                        votesPerRow.add(new MultiVote());
                    }
                    HashMap<String, Object> cells = new HashMap<String, Object>(4);
                    for (String column : headerMap.keySet()) {
                        String raw = record.get(column);
                        if (raw != null && raw.length() > 0) {
                            cells.put(column, parseVoteValue(column, raw, predictionConverter, locale));
                        }
                    }
                    ArrayList<String> names = new ArrayList<String>(cells.size());
                    ArrayList<Object> values = new ArrayList<Object>(cells.size());
                    for (Map.Entry<String, Object> cell : cells.entrySet()) {
                        names.add(cell.getKey());
                        values.add(cell.getValue());
                    }
                    votesPerRow.get(row).appendRow(values, names);
                    row++;
                }
            }
        }
        return votesPerRow;
    }

    /** Converts one non-empty CSV cell to the type expected for its column. */
    private static Object parseVoteValue(String column, String raw, PredictionConverter predictionConverter, Locale locale) throws Exception {
        if (AbstractResource.PREDICTION_PATH.equals(column)) {
            return predictionConverter.toPrediction(raw, locale);
        } else if ("order".equals(column)) {
            return Integer.valueOf(Integer.parseInt(raw));
        } else if ("distribution".equals(column)) {
            return JSONValue.parse(raw);
        } else if ("instances".equals(column)) {
            return Long.valueOf(Long.parseLong(raw));
        } else if ("confidence".equals(column)) {
            return Double.valueOf(Double.parseDouble(raw));
        }
        return raw;
    }

    /** Returns the predictions file name for a model JSON object, using its {@code resource} id. */
    protected String getPredictionsFileName(JSONObject jSONObject, String str) {
        return getPredictionsFileName((String) jSONObject.get("resource"), str);
    }

    /**
     * Returns {@code directory + File.separator + resourceId + "_predictions.csv"}, where every
     * '/' in {@code resourceId} is replaced by '_'.
     */
    protected String getPredictionsFileName(String resourceId, String directory) {
        return String.format("%s%s%s%s", directory, File.separator, resourceId.replace('/', '_'), PREDICTIONS_FILE_SUFFIX);
    }
}
