package org.bigml.binding;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import org.bigml.binding.resources.AbstractResource;
import org.bigml.binding.utils.Utils;
import org.json.simple.JSONArray;
import org.json.simple.JSONObject;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/bigml/binding/LocalFusion.class */
/**
 * Local scorer for a BigML Fusion resource.
 *
 * <p>A fusion combines the outputs of several supervised resources (models,
 * ensembles, logistic regressions, deepnets or other fusions). This class reads
 * the fusion's JSON description. At prediction time it downloads each component
 * resource through a {@link BigMLClient} and wraps it in the matching local
 * class. It asks each one for its probabilities (or numeric value for
 * regressions), optionally weighs them, and combines them into a single result.
 *
 * <p>Note: component resources are fetched again on every prediction call.
 */
public class LocalFusion extends ModelFields implements SupervisedModelInterface {
    private static final long serialVersionUID = 1;
    static String FUSION_RE = "^fusion/[a-f,0-9]{24}$";
    private static final String[] OPERATING_POINT_KINDS = {"probability"};
    private static final String[] LOCAL_SUPERVISED = {AbstractResource.MODEL_PATH, AbstractResource.ENSEMBLE_PATH, AbstractResource.LOGISTICREGRESSION_PATH, AbstractResource.DEEPNET_PATH, AbstractResource.FUSION_PATH};
    static Logger logger = LoggerFactory.getLogger(LocalFusion.class.getName());
    private String fusionId;
    private String objectiveField;
    private JSONArray modelsIds;
    // Either empty (unweighted fusion) or index-aligned with modelsIds.
    private List<Double> weights;
    private final List<JSONArray> modelsSplit;
    private Boolean regression;
    private List<String> classNames;
    private Boolean missingNumerics;

    /**
     * Builds a local fusion from its JSON resource, without an explicit client
     * and without splitting the component models.
     *
     * @param jSONObject the fusion resource
     * @throws Exception if the resource cannot be parsed as a finished fusion
     */
    public LocalFusion(JSONObject jSONObject) throws Exception {
        this(null, jSONObject, null);
    }

    /**
     * Builds a local fusion from its JSON resource.
     *
     * @param bigMLClient client used to download the component resources at
     *     prediction time; when {@code null}, a default {@code BigMLClient} is
     *     created for each prediction
     * @param jSONObject the fusion resource
     * @param num maximum number of component resources per split, or
     *     {@code null} to keep them all in a single split
     * @throws Exception if the resource is not a finished fusion, lists a
     *     component of an unsupported type, or {@code num} is not positive
     */
    public LocalFusion(BigMLClient bigMLClient, JSONObject jSONObject, Integer num) throws Exception {
        super(bigMLClient, jSONObject);
        this.objectiveField = null;
        this.weights = new ArrayList<>();
        this.modelsSplit = new ArrayList<>();
        this.regression = false;
        this.classNames = new ArrayList<>();
        this.missingNumerics = true;

        JSONObject resource = this.model;
        this.fusionId = (String) resource.get("resource");
        if (!(resource.get(AbstractResource.FUSION_PATH) instanceof JSONObject)) {
            throw new Exception(String.format("Cannot create the Fusion instance. Could not find the 'fusion' key in the resource:\n\n%s", resource));
        }
        JSONObject status = (JSONObject) Utils.getJSONObject(resource, "status");
        if (status == null || !status.containsKey("code")
                || AbstractResource.FINISHED != ((Number) status.get("code")).intValue()) {
            throw new Exception("The Fusion isn't finished yet");
        }
        JSONObject fusionInfo = (JSONObject) Utils.getJSONObject(resource, AbstractResource.FUSION_PATH);

        parseModels((JSONArray) resource.get("models"));

        this.missingNumerics = (Boolean) Utils.getJSONObject(resource, "missing_numerics", true);
        JSONObject fusionFields = (JSONObject) Utils.getJSONObject(fusionInfo, "fields", new JSONObject());
        super.initialize(fusionFields, null, null, null, true, true, true);
        this.objectiveField = (String) Utils.getJSONObject(resource, "objective_field");

        splitModels(num);

        this.regression = Constants.OPTYPE_NUMERIC.equals(
                (String) Utils.getJSONObject(fusionFields, this.objectiveField + ".optype"));
        if (!this.regression) {
            JSONArray categories = (JSONArray) Utils.getJSONObject(
                    (JSONObject) fusionFields.get(this.objectiveField), "summary.categories", new JSONArray());
            for (Object category : categories) {
                this.classNames.add((String) ((JSONArray) category).get(0));
            }
            Collections.sort(this.classNames);
        }
    }

    /**
     * Reads the fusion's {@code models} list into {@link #modelsIds} and, when
     * every entry carries a numeric {@code weight}, into {@link #weights}.
     *
     * <p>Entries are either plain resource ids or objects with an {@code id}
     * and an optional {@code weight}. Weights are all-or-nothing: if any entry
     * lacks one, no weights are kept. This keeps {@code weights} index-aligned
     * with {@code modelsIds}.
     *
     * @throws IllegalArgumentException if an entry has no id or references an
     *     unsupported resource type
     */
    @SuppressWarnings("unchecked")
    private void parseModels(JSONArray models) {
        this.modelsIds = new JSONArray();
        List<Double> parsedWeights = new ArrayList<>();
        boolean allWeighted = true;
        for (Object entry : models) {
            String modelId;
            if (entry instanceof String) {
                modelId = (String) entry;
                allWeighted = false;
            } else {
                JSONObject info = (JSONObject) entry;
                modelId = (String) info.get("id");
                Object weight = info.get("weight");
                if (weight instanceof Number) {
                    parsedWeights.add(((Number) weight).doubleValue());
                } else {
                    allWeighted = false;
                }
            }
            if (modelId == null) {
                throw new IllegalArgumentException("The fusion information does not contain the model ids.");
            }
            if (!Arrays.asList(LOCAL_SUPERVISED).contains(modelId.split("/")[0])) {
                throw new IllegalArgumentException(String.format("The resource %s has not an allowed supervised model type.", modelId));
            }
            this.modelsIds.add(modelId);
        }
        this.weights = allWeighted ? parsedWeights : new ArrayList<Double>();
    }

    /**
     * Partitions {@link #modelsIds} into consecutive chunks of at most
     * {@code maxModels} ids. The last chunk may be shorter, so every component
     * is used. A {@code null} limit keeps all ids in a single chunk.
     *
     * @throws IllegalArgumentException if {@code maxModels} is not positive
     */
    @SuppressWarnings("unchecked")
    private void splitModels(Integer maxModels) {
        if (maxModels == null) {
            this.modelsSplit.add(this.modelsIds);
            return;
        }
        int chunk = maxModels.intValue();
        if (chunk <= 0) {
            throw new IllegalArgumentException("The maximum number of models per split must be positive: " + chunk);
        }
        int size = this.modelsIds.size();
        for (int start = 0; start < size; start += chunk) {
            // Written this way so that start + chunk cannot overflow.
            int end = start + Math.min(chunk, size - start);
            JSONArray split = new JSONArray();
            split.addAll(this.modelsIds.subList(start, end));
            this.modelsSplit.add(split);
        }
    }

    @Override // org.bigml.binding.ModelFields
    public String getModelIdRe() {
        return FUSION_RE;
    }

    @Override // org.bigml.binding.ModelFields
    public JSONObject getBigMLModel(String str) {
        return this.bigmlClient.getFusion(str);
    }

    @Override // org.bigml.binding.SupervisedModelInterface
    public String getResourceId() {
        return this.fusionId;
    }

    /** Returns the objective field's categories, sorted (empty for regressions). */
    @Override // org.bigml.binding.SupervisedModelInterface
    public List<String> getClassNames() {
        return this.classNames;
    }

    /**
     * Combines the component predictions for {@code jSONObject}.
     *
     * <p>Components whose prediction fails are skipped and logged. For
     * classifications, each component's per-class values are scaled by its
     * weight (if the fusion is weighted) and realigned to this fusion's class
     * order. They are then combined with {@code combineToDistribution}. For
     * regressions, the result is the weighted mean of the component values over
     * the components that contributed.
     *
     * @param jSONObject input data
     * @param missingStrategy strategy for missing values; defaults to
     *     {@code LAST_PREDICTION}
     * @return for classifications, one entry per class name (in
     *     {@link #getClassNames()} order) with the class under the
     *     {@code AbstractResource.PREDICTION_PATH} key and its
     *     {@code "probability"}; for regressions, a single entry holding the
     *     numeric prediction under {@code AbstractResource.PREDICTION_PATH}
     * @throws Exception if a component cannot be downloaded or instantiated,
     *     or if no component produced a usable prediction
     */
    @Override // org.bigml.binding.SupervisedModelInterface
    public JSONArray predictProbability(JSONObject jSONObject, MissingStrategy missingStrategy) throws Exception {
        if (missingStrategy == null) {
            missingStrategy = MissingStrategy.LAST_PREDICTION;
        }
        if (!this.missingNumerics) {
            Utils.checkNoMissingNumerics(jSONObject, this.fields, null);
        }
        BigMLClient client = this.bigmlClient != null ? this.bigmlClient : new BigMLClient();
        MultiVoteList votes = new MultiVoteList(null);
        // Sum of the weights of the components that actually voted
        // (1 per component when the fusion is unweighted).
        double totalWeight = 0.0d;
        int contributions = 0;
        for (JSONArray split : this.modelsSplit) {
            MultiVoteList splitVotes = new MultiVoteList(null);
            for (SupervisedModelInterface localModel : instantiateModels(client, split)) {
                String resourceId = localModel.getResourceId();
                try {
                    List<Double> values = extractValues(localModel.predictProbability(jSONObject, missingStrategy));
                    double modelWeight = modelWeight(resourceId);
                    if (!this.weights.isEmpty()) {
                        values = weight(values, resourceId);
                    }
                    if (!this.regression && !this.classNames.equals(localModel.getClassNames())) {
                        values = rearrangePrediction(localModel.getClassNames(), this.classNames, values);
                    }
                    splitVotes.append(values);
                    totalWeight += modelWeight;
                    contributions++;
                } catch (Exception e) {
                    // Best effort: one failing component must not sink the fusion.
                    logger.warn("Skipping model {} of fusion {}: its prediction could not be used", resourceId, this.fusionId, e);
                }
            }
            votes.extend(splitVotes);
        }
        if (contributions == 0) {
            throw new Exception(String.format("None of the models in fusion %s could produce a prediction.", this.fusionId));
        }

        JSONArray result = new JSONArray();
        if (this.regression) {
            double weightedSum = 0.0d;
            for (List<Double> prediction : votes.predictions) {
                for (Double value : prediction) {
                    weightedSum += value;
                }
            }
            JSONObject entry = new JSONObject();
            entry.put(AbstractResource.PREDICTION_PATH, Double.valueOf(weightedSum / totalWeight));
            result.add(entry);
        } else {
            List<Double> distribution = votes.combineToDistribution(true);
            for (int i = 0; i < this.classNames.size(); i++) {
                JSONObject entry = new JSONObject();
                entry.put(AbstractResource.PREDICTION_PATH, this.classNames.get(i));
                entry.put("probability", distribution.get(i));
                result.add(entry);
            }
        }
        return result;
    }

    /**
     * Downloads each resource id in {@code ids} and wraps it in the local
     * class that matches its type prefix.
     */
    private List<SupervisedModelInterface> instantiateModels(BigMLClient client, JSONArray ids) throws Exception {
        List<SupervisedModelInterface> localModels = new ArrayList<>(ids.size());
        for (Object id : ids) {
            String resourceId = (String) id;
            String type = resourceId.split("/")[0];
            if (AbstractResource.MODEL_PATH.equals(type)) {
                localModels.add(new LocalPredictiveModel(client.getModel(resourceId)));
            } else if (AbstractResource.ENSEMBLE_PATH.equals(type)) {
                localModels.add(new LocalEnsemble(client.getEnsemble(resourceId)));
            } else if (AbstractResource.LOGISTICREGRESSION_PATH.equals(type)) {
                localModels.add(new LocalLogisticRegression(client.getLogisticRegression(resourceId)));
            } else if (AbstractResource.DEEPNET_PATH.equals(type)) {
                localModels.add(new LocalDeepnet(client.getDeepnet(resourceId)));
            } else if (AbstractResource.FUSION_PATH.equals(type)) {
                localModels.add(new LocalFusion(client, client.getFusion(resourceId), null));
            }
        }
        return localModels;
    }

    /**
     * Extracts the numeric value from each entry of a component's
     * {@code predictProbability} output.
     *
     * @throws IllegalStateException if an entry carries no value
     */
    private List<Double> extractValues(JSONArray componentOutput) {
        List<Double> values = new ArrayList<>(componentOutput.size());
        for (Object item : componentOutput) {
            JSONObject entry = (JSONObject) item;
            Object value = entry.get("probability");
            if (value == null && this.regression) {
                // NOTE(review): regression components presumably report their
                // value under the prediction key rather than "probability" —
                // confirm against the component implementations.
                value = entry.get(AbstractResource.PREDICTION_PATH);
            }
            if (value == null) {
                throw new IllegalStateException("Missing prediction value in " + entry);
            }
            values.add(((Number) value).doubleValue());
        }
        return values;
    }

    /**
     * Returns the fusion weight of component {@code resourceId}, or 1 when the
     * fusion is unweighted.
     *
     * @throws IllegalArgumentException if the fusion is weighted and the id is
     *     not one of its components
     */
    private double modelWeight(String resourceId) {
        if (this.weights.isEmpty()) {
            return 1.0d;
        }
        int index = this.modelsIds.indexOf(resourceId);
        if (index < 0) {
            throw new IllegalArgumentException(String.format("The model %s is not part of fusion %s.", resourceId, this.fusionId));
        }
        return this.weights.get(index);
    }

    /** Returns a copy of {@code list} with every value multiplied by the weight of component {@code str}. */
    private List<Double> weight(List<Double> list, String str) {
        double factor = modelWeight(str);
        List<Double> weighted = new ArrayList<>(list.size());
        for (Double value : list) {
            weighted.add(value * factor);
        }
        return weighted;
    }

    /**
     * Reorders {@code list3}, whose values follow the class order
     * {@code list}, so that they follow the class order {@code list2}.
     * Classes missing from {@code list} get 0.
     */
    private List<Double> rearrangePrediction(List<String> list, List<String> list2, List<Double> list3) {
        List<Double> rearranged = new ArrayList<>(list2.size());
        for (String className : list2) {
            int index = list.indexOf(className);
            rearranged.add(index > -1 ? list3.get(index) : Double.valueOf(0.0d));
        }
        return rearranged;
    }

    /** Returns the class of a prediction entry, read from "category" or else the prediction key. */
    private static String categoryOf(HashMap<String, Object> prediction) {
        String category = (String) prediction.get("category");
        return category != null ? category : (String) prediction.get(AbstractResource.PREDICTION_PATH);
    }

    /**
     * Predicts using an operating point.
     *
     * <p>The positive class is returned when its value for the operating kind
     * exceeds the threshold. Otherwise, the most probable other class is
     * returned.
     *
     * @throws IllegalArgumentException if the operating kind is not supported
     */
    @SuppressWarnings("unchecked")
    private HashMap<String, Object> predictOperating(JSONObject jSONObject, MissingStrategy missingStrategy, JSONObject jSONObject2) throws Exception {
        if (missingStrategy == null) {
            missingStrategy = MissingStrategy.LAST_PREDICTION;
        }
        Object[] operatingPoint = Utils.parseOperatingPoint(jSONObject2, OPERATING_POINT_KINDS, this.classNames);
        String kind = (String) operatingPoint[0];
        Double threshold = (Double) operatingPoint[1];
        String positiveClass = (String) operatingPoint[2];
        if (!Arrays.asList(OPERATING_POINT_KINDS).contains(kind)) {
            throw new IllegalArgumentException(String.format("Allowed operating kinds are %s", Arrays.toString(OPERATING_POINT_KINDS)));
        }
        JSONArray predictions = predictProbability(jSONObject, missingStrategy);
        for (Object item : predictions) {
            HashMap<String, Object> prediction = (HashMap<String, Object>) item;
            if (categoryOf(prediction).equals(positiveClass)
                    && ((Number) prediction.get(kind)).doubleValue() > threshold.doubleValue()) {
                return prediction;
            }
        }
        // Threshold not met: pick the most probable class other than the
        // positive one, so the entries must be ranked first.
        Utils.sortPredictions(predictions, kind, AbstractResource.PREDICTION_PATH);
        HashMap<String, Object> alternative = (HashMap<String, Object>) predictions.get(0);
        if (categoryOf(alternative).equals(positiveClass)) {
            alternative = (HashMap<String, Object>) predictions.get(1);
        }
        if (alternative.get("category") != null) {
            alternative.put(AbstractResource.PREDICTION_PATH, alternative.get("category"));
            alternative.remove("category");
        }
        return alternative;
    }

    /**
     * Predicts the objective field for {@code jSONObject}.
     *
     * @param jSONObject input data
     * @param missingStrategy strategy for missing values; defaults to
     *     {@code LAST_PREDICTION}
     * @param jSONObject2 optional operating point (classifications only)
     * @param bool when true, the result also carries the input fields the
     *     fusion did not use, under {@code "unused_fields"}; defaults to false
     * @return the winning prediction entry
     * @throws IllegalArgumentException if an operating point is given for a
     *     regression
     */
    @SuppressWarnings("unchecked")
    public HashMap<String, Object> predict(JSONObject jSONObject, MissingStrategy missingStrategy, JSONObject jSONObject2, Boolean bool) throws Exception {
        if (missingStrategy == null) {
            missingStrategy = MissingStrategy.LAST_PREDICTION;
        }
        if (bool == null) {
            bool = false;
        }
        JSONObject filtered = filterInputData(jSONObject, bool);
        List unusedFields = (List) filtered.get("unusedFields");
        JSONObject inputData = (JSONObject) filtered.get("newInputData");
        if (!this.missingNumerics) {
            Utils.checkNoMissingNumerics(inputData, this.fields, null);
        }
        Utils.cast(inputData, this.fields);

        HashMap<String, Object> prediction;
        if (jSONObject2 != null) {
            if (this.regression) {
                throw new IllegalArgumentException("The operating_point argument can only be used in classifications.");
            }
            prediction = predictOperating(inputData, missingStrategy, jSONObject2);
        } else {
            JSONArray predictions = predictProbability(inputData, missingStrategy);
            if (!this.regression) {
                Utils.sortPredictions(predictions, "probability", AbstractResource.PREDICTION_PATH);
            }
            prediction = (HashMap<String, Object>) predictions.get(0);
        }
        if (bool) {
            prediction.put("unused_fields", unusedFields);
        }
        return prediction;
    }
}
