package org.bigml.binding;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.bigml.binding.resources.AbstractResource;
import org.bigml.binding.utils.Utils;
import org.json.simple.JSONArray;
import org.json.simple.JSONObject;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/bigml/binding/LocalLogisticRegression.class */
/**
 * Local predictor for a BigML logistic regression resource.
 *
 * <p>The instance is built from the logistic regression JSON (optionally fetched
 * through a {@link BigMLClient}). Predictions are computed locally from the
 * per-class coefficients stored in that JSON. Each class score is passed through
 * the logistic function, and the per-class probabilities are then rescaled by
 * their sum.
 */
public class LocalLogisticRegression extends ModelFields implements SupervisedModelInterface {
    private static final long serialVersionUID = 1;
    static String LOGISTICREGRESSION_RE = "^logisticregression/[a-f,0-9]{24}$";
    static HashMap<String, String> EXPANSION_ATTRIBUTES = new HashMap<>();
    protected static final String[] OPTIONAL_FIELDS;
    static Logger logger;
    private String logisticRegressionId;
    private JSONObject datasetFieldTypes;
    private JSONArray inputFields;
    private String objectiveField;
    private JSONArray objectiveFields;
    // class name -> array of per-input-field coefficient arrays, indexed like
    // inputFields; the first value of the last array is added as a constant term.
    private JSONObject coefficients;
    private Boolean bias;
    private Boolean normalize;
    private Boolean balanceFields;
    // field -> { codingType -> dummy class or coefficient matrix }
    private JSONObject fieldCodings;
    private List<String> classNames;
    private String weightField;
    private String defaultNumericValue;

    /**
     * Builds the local model from the logistic regression JSON without a client.
     *
     * @param jSONObject the logistic regression resource JSON
     * @throws Exception if the JSON is not a usable, finished logistic regression
     */
    public LocalLogisticRegression(JSONObject jSONObject) throws Exception {
        this(null, jSONObject);
    }

    /**
     * Builds the local model from the logistic regression JSON.
     *
     * @param bigMLClient client forwarded to {@link ModelFields}; may be {@code null}
     * @param jSONObject  the logistic regression resource; the resolved JSON is read
     *                    back from {@code this.model} after the superclass constructor
     * @throws Exception if the expected keys are missing, the resource is not
     *                   finished, or the coefficients use the old string-based format
     */
    public LocalLogisticRegression(BigMLClient bigMLClient, JSONObject jSONObject) throws Exception {
        super(bigMLClient, jSONObject);
        this.datasetFieldTypes = null;
        this.inputFields = null;
        this.objectiveField = null;
        this.objectiveFields = null;
        this.coefficients = null;
        this.classNames = new ArrayList<>();
        this.defaultNumericValue = null;

        JSONObject resource = this.model;
        this.logisticRegressionId = (String) resource.get("resource");
        this.datasetFieldTypes = (JSONObject) Utils.getJSONObject(resource, "dataset_field_types");
        this.inputFields = (JSONArray) Utils.getJSONObject(resource, "input_fields");
        this.objectiveField = (String) Utils.getJSONObject(resource, "objective_field");
        this.objectiveFields = (JSONArray) Utils.getJSONObject(resource, "objective_fields");
        this.weightField = (String) Utils.getJSONObject(resource, "weight_field");
        if (this.datasetFieldTypes == null || this.inputFields == null
                || (this.objectiveField == null && this.objectiveFields == null)) {
            throw new Exception("Failed to find the logistic regression expected JSON structure. Check your arguments.");
        }
        if (!resource.containsKey("logistic_regression") || !(resource.get("logistic_regression") instanceof JSONObject)) {
            throw new Exception(String.format("Cannot create the LogisticRegression instance. Could not find the 'logistic_regression' key in the resource:\n\n%s", resource));
        }
        JSONObject status = (JSONObject) Utils.getJSONObject(resource, "status");
        if (status == null || !status.containsKey("code")
                || AbstractResource.FINISHED != ((Number) status.get("code")).intValue()) {
            throw new Exception("The logistic regression isn't finished yet");
        }

        JSONObject lrInfo = (JSONObject) Utils.getJSONObject(resource, "logistic_regression");
        JSONArray rawCoefficients = (JSONArray) Utils.getJSONObject(lrInfo, "coefficients", new JSONArray());
        // Only peek at the first entry when there is one; an empty list is not the old format.
        if (!rawCoefficients.isEmpty() && rawCoefficients.get(0) instanceof String) {
            throw new Exception("Detected old format of logistic regression detected.");
        }
        JSONObject lrFields = (JSONObject) Utils.getJSONObject(lrInfo, "fields", new JSONObject());
        this.defaultNumericValue = (String) resource.get("default_numeric_value");

        // NOTE(review): unreachable as written, because a missing "input_fields"
        // already makes the structure check above throw. Kept as a fallback that
        // orders the field ids by their "column_number".
        if (this.inputFields == null) {
            this.inputFields = new JSONArray();
            String[] byColumn = new String[lrFields.values().size()];
            for (Object fieldId : lrFields.keySet()) {
                byColumn[((Number) Utils.getJSONObject(lrFields, fieldId + ".column_number")).intValue()] = (String) fieldId;
            }
            this.inputFields.addAll(Arrays.asList(byColumn));
        }

        // Each coefficients entry is a [className, coefficientArrays] pair.
        this.coefficients = new JSONObject();
        for (Object entry : rawCoefficients) {
            JSONArray pair = (JSONArray) entry;
            this.coefficients.put((String) pair.get(0), (JSONArray) pair.get(1));
        }

        this.bias = (Boolean) Utils.getJSONObject(lrInfo, "bias", true);
        // Both flags are unboxed unconditionally while predicting, so default them
        // to false instead of leaving them null (which would throw an NPE later).
        this.normalize = (Boolean) Utils.getJSONObject(lrInfo, "normalize", false);
        this.balanceFields = (Boolean) Utils.getJSONObject(lrInfo, "balance_fields");
        this.missingNumerics = (Boolean) Utils.getJSONObject(lrInfo, "missing_numerics", false);
        super.initialize(lrFields, null, null, null, true, true, true);

        Object rawCodings = Utils.getJSONObject(lrInfo, "field_codings");
        if (rawCodings instanceof JSONArray) {
            formatFieldCodings((JSONArray) rawCodings);
        } else {
            this.fieldCodings = (JSONObject) Utils.getJSONObject(lrInfo, "field_codings", new JSONObject());
        }

        // Codings whose key is not one of the model's field keys but is a key of
        // invertedFields are re-keyed to the mapped value. Iterate over a snapshot
        // of the keys because the map is modified inside the loop.
        for (Object key : new ArrayList<Object>(this.fieldCodings.keySet())) {
            String codedField = (String) key;
            if (!lrFields.containsKey(codedField) && this.invertedFields.containsKey(codedField)) {
                Object coding = this.fieldCodings.remove(codedField);
                this.fieldCodings.put(this.invertedFields.get(codedField), coding);
            }
        }

        JSONArray objectiveCategories = (JSONArray) Utils.getJSONObject(
                (JSONObject) lrFields.get(this.objectiveField), "summary.categories", new JSONArray());
        // More coefficient sets than categories: the extra class is the empty string.
        if (this.coefficients.keySet().size() > objectiveCategories.size()) {
            this.classNames.add("");
        }
        for (Object category : objectiveCategories) {
            this.classNames.add((String) ((JSONArray) category).get(0));
        }
    }

    /** Returns the regular expression that valid logistic regression ids must match. */
    @Override // org.bigml.binding.ModelFields
    public String getModelIdRe() {
        return LOGISTICREGRESSION_RE;
    }

    /** Fetches the logistic regression JSON with the given id through the client. */
    @Override // org.bigml.binding.ModelFields
    public JSONObject getBigMLModel(String str) {
        return this.bigmlClient.getLogisticRegression(str);
    }

    /** Returns the "resource" id read from the logistic regression JSON. */
    @Override // org.bigml.binding.SupervisedModelInterface
    public String getResourceId() {
        return this.logisticRegressionId;
    }

    /**
     * Returns the class names: an optional leading empty string (when there are
     * more coefficient sets than objective categories) followed by the objective
     * categories in summary order. The internal list is returned, not a copy.
     */
    @Override // org.bigml.binding.SupervisedModelInterface
    public List<String> getClassNames() {
        return this.classNames;
    }

    /**
     * Returns the sorted class probability distribution for the input.
     *
     * @param jSONObject input data
     * @return the distribution, or {@code null} if the prediction failed (the
     *         failure is logged)
     */
    public JSONArray predictProbability(JSONObject jSONObject) {
        try {
            return predictProbability(jSONObject, null);
        } catch (Exception e) {
            logger.error("Failed to compute the logistic regression probability distribution", e);
            return null;
        }
    }

    /**
     * Returns the class probability distribution for the input, sorted with
     * {@code Utils.sortPredictions} on "probability" and the prediction key.
     *
     * @param jSONObject      input data
     * @param missingStrategy ignored by this model
     * @return one entry per class holding the prediction and its "probability"
     */
    @Override // org.bigml.binding.SupervisedModelInterface
    public JSONArray predictProbability(JSONObject jSONObject, MissingStrategy missingStrategy) throws Exception {
        JSONArray distribution = (JSONArray) predict(jSONObject, null, null, true).get("distribution");
        Utils.sortPredictions(distribution, "probability", AbstractResource.PREDICTION_PATH);
        return distribution;
    }

    /**
     * Applies an operating point: returns the first distribution entry for the
     * operating point's class whose value for the operating kind exceeds the
     * threshold, or the top entry when none does.
     */
    private HashMap<String, Object> predictOperating(JSONObject jSONObject, JSONObject jSONObject2) {
        Object[] operatingPoint = Utils.parseOperatingPoint(jSONObject2, new String[]{"probability"}, this.classNames);
        String kind = (String) operatingPoint[0];
        Double threshold = (Double) operatingPoint[1];
        String positiveClass = (String) operatingPoint[2];
        JSONArray distribution = predictProbability(jSONObject);
        for (Object item : distribution) {
            HashMap<String, Object> prediction = (HashMap) item;
            if (((String) prediction.get(AbstractResource.PREDICTION_PATH)).equals(positiveClass)
                    && ((Double) prediction.get(kind)).doubleValue() > threshold.doubleValue()) {
                return prediction;
            }
        }
        return (HashMap) distribution.get(0);
    }

    /**
     * Returns the top entry of the distribution for the given operating kind.
     *
     * @throws IllegalArgumentException unless the kind is "probability" (any case)
     */
    private HashMap<String, Object> predictOperatingKind(JSONObject jSONObject, String str) {
        // Locale-independent comparison (toLowerCase() depends on the default locale).
        if ("probability".equalsIgnoreCase(str)) {
            return (HashMap) predictProbability(jSONObject).get(0);
        }
        throw new IllegalArgumentException("Only probability is allowed as operating kind for logistic regressions.");
    }

    /**
     * Predicts the objective class for the given input.
     *
     * @param jSONObject  input data; filtered with {@code filterInputData} and cast
     *                    against the model fields before scoring
     * @param jSONObject2 optional operating point; when non-null the matching
     *                    distribution entry is returned directly
     * @param str         optional operating kind; only "probability" is accepted and
     *                    the top distribution entry is returned directly
     * @param bool        whether to add the unused input fields to the result
     *                    ({@code null} means false)
     * @return a map with the prediction (under {@code AbstractResource.PREDICTION_PATH}),
     *         its "probability", the sorted "distribution" and, when requested,
     *         "unused_fields"
     * @throws IllegalArgumentException if {@code str} is not "probability"
     */
    public HashMap<String, Object> predict(JSONObject jSONObject, JSONObject jSONObject2, String str, Boolean bool) {
        if (bool == null) {
            bool = false;
        }
        JSONObject filtered = filterInputData(jSONObject, bool);
        List unusedFields = (List) filtered.get("unusedFields");
        JSONObject inputData = (JSONObject) filtered.get("newInputData");
        Utils.cast(inputData, this.fields);
        if (jSONObject2 != null) {
            return predictOperating(inputData, jSONObject2);
        }
        if (str != null) {
            return predictOperatingKind(inputData, str);
        }
        if (!this.missingNumerics.booleanValue()) {
            Utils.checkNoMissingNumerics(inputData, this.fields, this.weightField);
        }
        if (this.balanceFields != null && this.balanceFields.booleanValue()) {
            balanceInput(inputData, this.fields);
        }
        Map<String, Object> uniqueTerms = uniqueTerms(inputData);

        // Score every class, then rescale the scores by their sum.
        JSONArray distribution = new JSONArray();
        double total = 0.0d;
        for (String className : this.coefficients.keySet()) {
            double probability = categoryProbability(inputData, uniqueTerms, className);
            JSONObject entry = new JSONObject();
            entry.put(AbstractResource.PREDICTION_PATH, className);
            entry.put("probability", Double.valueOf(probability));
            distribution.add(entry);
            total += probability;
        }
        for (Object item : distribution) {
            JSONObject entry = (JSONObject) item;
            double raw = ((Number) entry.get("probability")).doubleValue();
            entry.put("probability", Double.valueOf(Utils.roundOff(raw / total, 5)));
        }
        Utils.sortPredictions(distribution, "probability", AbstractResource.PREDICTION_PATH);

        JSONObject best = (JSONObject) distribution.get(0);
        HashMap<String, Object> result = new HashMap<>();
        result.put(AbstractResource.PREDICTION_PATH, (String) best.get(AbstractResource.PREDICTION_PATH));
        result.put("probability", (Double) best.get("probability"));
        result.put("distribution", distribution);
        if (bool.booleanValue()) {
            result.put("unused_fields", list(unusedFields));
        }
        return result;
    }

    // Identity helper keeping the raw List typed as Object for the result map.
    private static Object list(List unusedFields) {
        return unusedFields;
    }

    /**
     * Computes the logistic probability of one class, rounded to 5 decimals.
     *
     * @param jSONObject input values left after term extraction (read as numbers)
     * @param map        field -> (term -> occurrence count) from {@code uniqueTerms}
     * @param str        the class whose coefficients are used
     */
    private double categoryProbability(JSONObject jSONObject, Map<String, Object> map, String str) {
        double score = 0.0d;
        double norm2 = 0.0d;

        // Values still present in the input: linear term with coefficient 0.
        for (String fieldId : jSONObject.keySet()) {
            JSONArray fieldCoefficients = getCoefficients(str, fieldId);
            double value = ((Number) jSONObject.get(fieldId)).doubleValue();
            score += ((Number) fieldCoefficients.get(0)).doubleValue() * value;
            if (this.normalize.booleanValue()) {
                norm2 += Math.pow(value, 2.0d);
            }
        }

        // Term / category occurrences of the input fields.
        for (String fieldId : map.keySet()) {
            if (!this.inputFields.contains(fieldId)) {
                continue;
            }
            Map termCounts = (Map) map.get(fieldId);
            JSONArray fieldCoefficients = getCoefficients(str, fieldId);
            for (Object term : termCounts.keySet()) {
                int occurrences = ((Number) termCounts.get(term)).intValue();
                try {
                    boolean oneHot = true;
                    Integer index = null;
                    if (this.tagClouds.containsKey(fieldId)) {
                        index = Integer.valueOf(this.tagClouds.get(fieldId).indexOf(term));
                    } else if (this.items.containsKey(fieldId)) {
                        index = Integer.valueOf(this.items.get(fieldId).indexOf(term));
                    } else {
                        JSONObject coding = (JSONObject) this.fieldCodings.get(fieldId);
                        if (this.categories.containsKey(fieldId)) {
                            index = Integer.valueOf(((JSONArray) this.categories.get(fieldId)).indexOf(term));
                            boolean dummyCoded = !this.fieldCodings.containsKey(fieldId)
                                    || "dummy".equals((String) coding.keySet().toArray()[0]);
                            if (!dummyCoded) {
                                // Non-dummy coding: combine each coding row's value for
                                // this category with the matching coefficient.
                                oneHot = false;
                                JSONArray codingRows = (JSONArray) coding.values().toArray()[0];
                                for (int i = 0; i < codingRows.size(); i++) {
                                    score += ((Number) fieldCoefficients.get(i)).doubleValue()
                                            * ((Number) ((JSONArray) codingRows.get(i)).get(index.intValue())).doubleValue()
                                            * occurrences;
                                }
                            }
                        }
                    }
                    if (oneHot) {
                        score += ((Number) fieldCoefficients.get(index.intValue())).doubleValue() * occurrences;
                    }
                    norm2 += Math.pow(occurrences, 2.0d);
                } catch (Exception ignored) {
                    // Unknown terms (index -1), fields without an expansion (index
                    // stays null) or malformed codings make the lookup fail; the rest
                    // of this occurrence's contribution and its norm term are skipped.
                }
            }
        }

        // Contributions of missing values for every input field.
        for (Object item : this.inputFields) {
            String fieldId = (String) item;
            boolean contributed = false;
            JSONArray fieldCoefficients = getCoefficients(str, fieldId);
            try {
                if (this.numericFields.containsKey(fieldId) && !jSONObject.containsKey(fieldId)) {
                    // Numeric field absent from the input: coefficient 1 (or 0 when it is the only one).
                    Number missingCoefficient = fieldCoefficients.size() == 1
                            ? (Number) fieldCoefficients.get(0)
                            : (Number) fieldCoefficients.get(1);
                    score += missingCoefficient.doubleValue();
                    contributed = true;
                } else {
                    boolean noTerms = !map.containsKey(fieldId) || map.get(fieldId) == null
                            || ((Map) map.get(fieldId)).isEmpty();
                    if (this.tagClouds.containsKey(fieldId) && noTerms) {
                        score += ((Number) fieldCoefficients.get(this.tagClouds.get(fieldId).size())).doubleValue();
                        contributed = true;
                    } else if (this.items.containsKey(fieldId) && noTerms) {
                        score += ((Number) fieldCoefficients.get(this.items.get(fieldId).size())).doubleValue();
                        contributed = true;
                    } else if (this.categories.containsKey(fieldId) && !this.objectiveField.equals(fieldId)
                            && !map.containsKey(fieldId)) {
                        JSONObject coding = (JSONObject) this.fieldCodings.get(fieldId);
                        if (!this.fieldCodings.containsKey(fieldId) || "dummy".equals((String) coding.keySet().toArray()[0])) {
                            // The slot right after the last category.
                            score += ((Number) fieldCoefficients.get(((List) this.categories.get(fieldId)).size())).doubleValue();
                        } else {
                            JSONArray codingRows = (JSONArray) coding.values().toArray()[0];
                            for (int i = 0; i < codingRows.size(); i++) {
                                JSONArray row = (JSONArray) codingRows.get(i);
                                score += ((Number) fieldCoefficients.get(i)).doubleValue()
                                        * ((Number) row.get(row.size() - 1)).doubleValue();
                            }
                        }
                        contributed = true;
                    }
                }
            } catch (Exception e) {
                logger.warn("Could not add the missing-value contribution of field " + fieldId, e);
            }
            if (contributed && this.normalize.booleanValue()) {
                norm2 += 1.0d;
            }
        }

        JSONArray classCoefficients = (JSONArray) this.coefficients.get(str);
        double linear = score
                + ((Number) ((JSONArray) classCoefficients.get(classCoefficients.size() - 1)).get(0)).doubleValue();
        if (this.bias.booleanValue()) {
            norm2 += 1.0d;
        }
        if (this.normalize.booleanValue()) {
            // Floating-point division by zero does not throw in Java; a zero norm
            // would yield Infinity/NaN, so fall back to 0 explicitly.
            linear = norm2 > 0.0d ? linear / Math.sqrt(norm2) : 0.0d;
        }
        // Math.exp overflows to Infinity instead of throwing, so extreme scores
        // saturate to 0 or 1 on their own.
        double probability = 1.0d / (1.0d + Math.exp(-linear));
        return Utils.roundOff(probability, 5);
    }

    /**
     * Standardizes numeric inputs in place using the field summary: subtracts the
     * "mean" and divides by "standard_deviation" when it is positive (both
     * default to 0 when absent).
     */
    private void balanceInput(JSONObject jSONObject, JSONObject jSONObject2) {
        for (Object fieldId : jSONObject.keySet()) {
            JSONObject field = (JSONObject) jSONObject2.get(fieldId);
            if (!Constants.OPTYPE_NUMERIC.equals(field.get("optype"))) {
                continue;
            }
            JSONObject summary = (JSONObject) field.get("summary");
            double mean = ((Number) Utils.getJSONObject(summary, "mean", 0)).doubleValue();
            double stddev = ((Number) Utils.getJSONObject(summary, "standard_deviation", 0)).doubleValue();
            double value = ((Number) jSONObject.get(fieldId)).doubleValue();
            double balanced = stddev <= 0.0d ? value - mean : (value - mean) / stddev;
            // Replacing the value of an existing key is not a structural change.
            jSONObject.put(fieldId, Double.valueOf(balanced));
        }
    }

    /**
     * Returns the coefficient array of the given class for the given input field.
     * Throws IndexOutOfBoundsException if the field is not in {@code inputFields}.
     */
    private JSONArray getCoefficients(String str, String str2) {
        return (JSONArray) ((JSONArray) this.coefficients.get(str)).get(this.inputFields.indexOf(str2));
    }

    /**
     * Converts the list form of "field_codings" into
     * field -> { codingType -> dummy_class (dummy coding) or coefficients (others) }.
     */
    private void formatFieldCodings(JSONArray jSONArray) {
        this.fieldCodings = new JSONObject();
        for (Object item : jSONArray) {
            JSONObject fieldCoding = (JSONObject) item;
            String fieldKey = (String) fieldCoding.get("field");
            String codingType = (String) fieldCoding.get("coding");
            JSONObject coding = new JSONObject();
            if (codingType.equals("dummy")) {
                coding.put(codingType, fieldCoding.get("dummy_class"));
            } else {
                coding.put(codingType, fieldCoding.get("coefficients"));
            }
            this.fieldCodings.put(fieldKey, coding);
        }
    }

    static {
        EXPANSION_ATTRIBUTES.put(Constants.OPTYPE_CATEGORICAL, "categories");
        EXPANSION_ATTRIBUTES.put(Constants.OPTYPE_TEXT, "tag_cloud");
        EXPANSION_ATTRIBUTES.put("items", "items");
        OPTIONAL_FIELDS = new String[]{Constants.OPTYPE_CATEGORICAL, Constants.OPTYPE_TEXT, "items", Constants.OPTYPE_DATETIME};
        logger = LoggerFactory.getLogger(LocalLogisticRegression.class.getName());
    }
}
