package org.bigml.binding;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.bigml.binding.laminar.MathOps;
import org.bigml.binding.laminar.Preprocess;
import org.bigml.binding.resources.AbstractResource;
import org.bigml.binding.utils.Utils;
import org.json.simple.JSONArray;
import org.json.simple.JSONObject;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/bigml/binding/LocalDeepnet.class */
/**
 * Computes predictions locally from a BigML deepnet resource description.
 *
 * <p>The constructor reads the network definition ({@code deepnet.network}) from the resource
 * JSON. Predictions are then computed in-process by preprocessing the input fields and
 * propagating them through the stored layers via {@link MathOps}. Both single networks and
 * lists of networks ({@code network.networks}) are supported.
 */
public class LocalDeepnet extends ModelFields implements SupervisedModelInterface {
    private static final long serialVersionUID = 1;

    /** Valid deepnet resource IDs: "deepnet/" followed by 24 lowercase hex digits. */
    private static final String DEEPNET_RE = "^deepnet/[a-f0-9]{24}$";

    static Logger logger = LoggerFactory.getLogger(LocalDeepnet.class.getName());

    private String deepnetId;
    private JSONArray inputFields = null;
    private String objectiveField = null;
    private JSONArray objectiveFields = null;
    private Boolean regression = false;
    private List<String> classNames = new ArrayList<>();
    private JSONObject network = null;
    private JSONArray networks = null;
    private JSONArray preprocess = null;

    /**
     * Builds a local deepnet from its resource JSON.
     *
     * @param jSONObject the deepnet resource, either bare or wrapped in an {@code "object"} key
     * @throws Exception if the resource ID is invalid, the {@code deepnet} key is missing, the
     *     resource status is not finished, or no objective field can be determined
     */
    public LocalDeepnet(JSONObject jSONObject) throws Exception {
        super((JSONObject) Utils.getJSONObject(jSONObject, "deepnet.fields", new JSONObject()));

        // Without the fields needed for local predictions, the resource ID must at least be valid.
        if (!checkModelFields(jSONObject)) {
            this.deepnetId = (String) jSONObject.get("resource");
            if (this.deepnetId == null || !this.deepnetId.matches(DEEPNET_RE)) {
                throw new Exception(this.deepnetId + " is not a valid resource ID.");
            }
        }

        // No embedded resource: retrieve it as a deepnet through a BigMLClient configured with
        // BigMLClient.STORAGE.
        if (!jSONObject.containsKey("resource") || jSONObject.get("resource") == null) {
            jSONObject = new BigMLClient(null, null, BigMLClient.STORAGE).getDeepnet(this.deepnetId);
            if (jSONObject == null || jSONObject.get("resource") == null) {
                throw new Exception(this.deepnetId + " is not a valid resource ID.");
            }
        }

        // Unwrap API responses that carry the resource under "object".
        if (jSONObject.get("object") instanceof JSONObject) {
            jSONObject = (JSONObject) jSONObject.get("object");
        }

        this.deepnetId = (String) jSONObject.get("resource");
        this.inputFields = (JSONArray) Utils.getJSONObject(jSONObject, "input_fields");

        if (!(jSONObject.get(AbstractResource.DEEPNET_PATH) instanceof JSONObject)) {
            throw new Exception(String.format("Cannot create the Deepnet instance. Could not find the 'deepnet' key in the resource:\n\n%s", jSONObject));
        }

        JSONObject status = (JSONObject) Utils.getJSONObject(jSONObject, "status");
        if (status == null || !status.containsKey("code")
                || AbstractResource.FINISHED != ((Number) status.get("code")).intValue()) {
            throw new Exception("The deepnet isn't finished yet");
        }

        // Prefer "objective_field"; otherwise use the first entry of "objective_fields".
        this.objectiveField = (String) Utils.getJSONObject(jSONObject, "objective_field");
        this.objectiveFields = (JSONArray) Utils.getJSONObject(jSONObject, "objective_fields");
        if (this.objectiveField == null) {
            if (this.objectiveFields == null || this.objectiveFields.isEmpty()) {
                throw new Exception("Cannot create the Deepnet instance. No objective field found in the resource.");
            }
            this.objectiveField = (String) this.objectiveFields.get(0);
        }

        JSONObject deepnetInfo = (JSONObject) Utils.getJSONObject(jSONObject, AbstractResource.DEEPNET_PATH);
        JSONObject deepnetFields = (JSONObject) Utils.getJSONObject(deepnetInfo, "fields", new JSONObject());
        super.initialize(deepnetFields, this.objectiveField, null, null, true, true, false);

        this.regression = Boolean.valueOf(((String) Utils.getJSONObject(
                deepnetFields, this.objectiveFieldId + ".optype", "")).equals(Constants.OPTYPE_NUMERIC));

        // Class names come from the objective field's category summary. They are sorted so that
        // network output positions map onto them in toPrediction.
        if (!this.regression.booleanValue()) {
            JSONArray categories = (JSONArray) Utils.getJSONObject(
                    deepnetFields, this.objectiveFieldId + ".summary.categories", new JSONArray());
            for (Object category : categories) {
                this.classNames.add((String) ((JSONArray) category).get(0));
            }
            Collections.sort(this.classNames);
        }

        this.missingNumerics = (Boolean) Utils.getJSONObject(deepnetInfo, "missing_numerics");
        if (deepnetInfo.containsKey("network")) {
            this.network = (JSONObject) deepnetInfo.get("network");
            this.networks = (JSONArray) Utils.getJSONObject(this.network, "networks", new JSONArray());
            this.preprocess = (JSONArray) Utils.getJSONObject(this.network, "preprocess", new JSONArray());
        }
    }

    @Override // org.bigml.binding.SupervisedModelInterface
    public String getResourceId() {
        return this.deepnetId;
    }

    @Override // org.bigml.binding.SupervisedModelInterface
    public List<String> getClassNames() {
        return this.classNames;
    }

    /**
     * Predicts the objective field for the given input.
     *
     * @param jSONObject raw input data
     * @param jSONObject2 optional operating point (classification only); may be null
     * @param str optional operating kind (classification only, only "probability"); may be null
     * @param bool when true, the unused input fields are added under "unused_fields"
     * @return the prediction map
     * @throws IllegalArgumentException if an operating point/kind is given for a regression
     */
    public HashMap<String, Object> predict(JSONObject jSONObject, JSONObject jSONObject2, String str, Boolean bool) {
        Boolean addUnusedFields = bool == null ? Boolean.FALSE : bool;

        JSONObject filtered = filterInputData(jSONObject, addUnusedFields);
        List unusedFields = (List) filtered.get("unusedFields");
        JSONObject inputData = (JSONObject) filtered.get("newInputData");
        Utils.cast(inputData, this.fields);

        if (jSONObject2 != null) {
            if (this.regression.booleanValue()) {
                throw new IllegalArgumentException("The operating_point argument can only be used in classifications.");
            }
            return predictOperating(inputData, jSONObject2);
        }
        if (str != null) {
            if (this.regression.booleanValue()) {
                throw new IllegalArgumentException("The operating_kind argument can only be used in classifications.");
            }
            return predictOperatingKind(inputData, str);
        }

        ArrayList<List<Double>> columns = fillArray(inputData, uniqueTerms(inputData));
        HashMap<String, Object> prediction = (this.networks == null || this.networks.isEmpty())
                ? predictSingle(columns)
                : predictList(columns);
        if (addUnusedFields.booleanValue()) {
            prediction.put("unused_fields", unusedFields);
        }
        return prediction;
    }

    /**
     * Returns the class distribution for internal callers that cannot throw checked exceptions.
     * A checked failure is rethrown as an IllegalStateException, keeping the original as its cause,
     * instead of being turned into a null result.
     */
    private JSONArray predictProbability(JSONObject inputData) {
        try {
            return predictProbability(inputData, null);
        } catch (RuntimeException e) {
            throw e;
        } catch (Exception e) {
            throw new IllegalStateException("Could not compute the deepnet class probabilities.", e);
        }
    }

    /**
     * Returns the per-class distribution, ordered by {@code Utils.sortPredictions}.
     *
     * @param jSONObject raw input data
     * @param missingStrategy not used by deepnets
     * @return the distribution, or null for regression deepnets (they produce no distribution)
     */
    @Override // org.bigml.binding.SupervisedModelInterface
    public JSONArray predictProbability(JSONObject jSONObject, MissingStrategy missingStrategy) throws Exception {
        JSONArray distribution = (JSONArray) predict(jSONObject, null, null, true).get("distribution");
        if (distribution != null) {
            Utils.sortPredictions(distribution, "probability", AbstractResource.PREDICTION_PATH);
        }
        return distribution;
    }

    /**
     * Predicts using an operating point.
     *
     * The positive class is returned when its probability exceeds the threshold. Otherwise the
     * first non-positive entry of the sorted distribution is returned (assumed to be the most
     * probable alternative).
     */
    private HashMap<String, Object> predictOperating(JSONObject inputData, JSONObject operatingPoint) {
        Object[] parsed = Utils.parseOperatingPoint(operatingPoint, new String[]{"probability"}, this.classNames);
        String kind = (String) parsed[0];
        double threshold = ((Double) parsed[1]).doubleValue();
        String positiveClass = (String) parsed[2];

        JSONArray distribution = predictProbability(inputData);
        for (Object item : distribution) {
            HashMap<String, Object> entry = (HashMap) item;
            if (((String) entry.get("category")).equals(positiveClass)
                    && ((Number) entry.get(kind)).doubleValue() > threshold) {
                return categoryToPrediction(entry);
            }
        }

        HashMap<String, Object> fallback = (HashMap) distribution.get(0);
        // With a single class there is no alternative to fall back to.
        if (((String) fallback.get("category")).equals(positiveClass) && distribution.size() > 1) {
            fallback = (HashMap) distribution.get(1);
        }
        return categoryToPrediction(fallback);
    }

    /** Predicts the first entry of the sorted distribution; only "probability" is accepted. */
    private HashMap<String, Object> predictOperatingKind(JSONObject inputData, String operatingKind) {
        // equalsIgnoreCase avoids default-locale lowercasing (e.g. Turkish dotless i).
        if (!"probability".equalsIgnoreCase(operatingKind)) {
            throw new IllegalArgumentException("Only probability is allowed as operating kind for deepnets.");
        }
        return categoryToPrediction((HashMap) predictProbability(inputData).get(0));
    }

    /** Moves a distribution entry's "category" value to the prediction key, in place. */
    private static HashMap<String, Object> categoryToPrediction(HashMap<String, Object> entry) {
        entry.put(AbstractResource.PREDICTION_PATH, entry.get("category"));
        entry.remove("category");
        return entry;
    }

    /**
     * Builds a term-occurrence vector aligned with {@code vocabulary}.
     *
     * Terms that are not in the vocabulary are ignored.
     */
    private List<Double> expandTerms(List vocabulary, Map<String, Integer> occurrences) {
        Double[] counts = new Double[vocabulary.size()];
        Arrays.fill(counts, Double.valueOf(0.0d));
        if (occurrences != null) {
            for (Map.Entry<String, Integer> term : occurrences.entrySet()) {
                int position = vocabulary.indexOf(term.getKey());
                if (position >= 0) {
                    counts[position] = Double.valueOf(term.getValue().doubleValue());
                }
            }
        }
        return Arrays.asList(counts);
    }

    /**
     * Flattens the input into the network's input columns, one or more per input field, then
     * applies the stored preprocessing.
     *
     * - Text and items fields expand into term-count vectors.
     * - Categorical fields contribute their first term, or null when absent.
     * - Numeric fields contribute their value. When missing numerics are enabled and the field
     *   had missing values in training, a (value, missing flag) pair is added instead.
     */
    private ArrayList<List<Double>> fillArray(JSONObject inputData, Map<String, Object> uniqueTerms) {
        ArrayList columns = new ArrayList();
        for (Object fieldId : this.inputFields) {
            Map<String, Integer> terms = (Map) uniqueTerms.get(fieldId);
            if (this.tagClouds.containsKey(fieldId)) {
                columns.addAll(expandTerms(this.tagClouds.get(fieldId), terms));
            } else if (this.items.containsKey(fieldId)) {
                columns.addAll(expandTerms(this.items.get(fieldId), terms));
            } else if (this.categories.containsKey(fieldId)) {
                columns.add(terms == null || terms.isEmpty() ? null : terms.keySet().iterator().next());
            } else {
                int missingCount = ((Number) Utils.getJSONObject(
                        this.fields, fieldId + ".summary.missing_count", 0)).intValue();
                Object raw = inputData.get(fieldId);
                Double value = raw == null ? null : Double.valueOf(((Number) raw).doubleValue());
                if (Boolean.TRUE.equals(this.missingNumerics) && missingCount > 0) {
                    if (inputData.containsKey(fieldId)) {
                        columns.add(value);
                        columns.add(Double.valueOf(0.0d));
                    } else {
                        columns.add(Double.valueOf(0.0d));
                        columns.add(Double.valueOf(1.0d));
                    }
                } else {
                    columns.add(value);
                }
            }
        }
        return Preprocess.preprocess(columns, this.preprocess);
    }

    /** Runs the single stored network, applying its tree transform first when present. */
    private HashMap<String, Object> predictSingle(ArrayList<List<Double>> columns) {
        JSONArray trees = (JSONArray) this.network.get("trees");
        if (trees != null && !trees.isEmpty()) {
            columns = Preprocess.treeTransform(columns, trees);
        }
        return toPrediction(modelPredict(columns, this.network));
    }

    /**
     * Runs every network in the list and combines their outputs with MathOps.sumAndNormalize.
     *
     * Networks flagged with {@code "trees": true} receive the tree-transformed input.
     */
    private HashMap<String, Object> predictList(ArrayList<List<Double>> columns) {
        ArrayList<List<Double>> treeColumns = new ArrayList<>();
        JSONArray trees = (JSONArray) this.network.get("trees");
        if (trees != null && !trees.isEmpty()) {
            treeColumns = Preprocess.treeTransform(columns, trees);
        }
        JSONArray outputs = new JSONArray();
        for (Object item : this.networks) {
            JSONObject model = (JSONObject) item;
            Boolean usesTrees = (Boolean) model.get("trees");
            outputs.add(modelPredict(usesTrees != null && usesTrees.booleanValue() ? treeColumns : columns, model));
        }
        return toPrediction(MathOps.sumAndNormalize(outputs, this.regression.booleanValue()));
    }

    /**
     * Propagates the input through the model's layers.
     *
     * For regression, the output is destandardized with the model's output_exposition mean/stdev.
     */
    private ArrayList<List<Double>> modelPredict(ArrayList<List<Double>> columns, JSONObject model) {
        ArrayList<List<Double>> output = MathOps.propagate(columns, (JSONArray) model.get("layers"));
        if (this.regression.booleanValue()) {
            JSONObject exposition = (JSONObject) Utils.getJSONObject(model, "output_exposition", new JSONObject());
            output = MathOps.destandardize(output,
                    Double.valueOf(((Number) exposition.get("mean")).doubleValue()),
                    Double.valueOf(((Number) exposition.get("stdev")).doubleValue()));
        }
        return output;
    }

    /**
     * Converts the network's first output row into a prediction map.
     *
     * - "probability" always holds the row's maximum (kept for backward compatibility).
     * - Regression: the predicted value is the row's first output.
     * - Classification: the prediction is the class at the argmax, and "distribution" is
     *   included with values rounded to 5 decimals.
     */
    private HashMap<String, Object> toPrediction(ArrayList<List<Double>> arrayList) {
        List<Double> row = arrayList.get(0);
        Double best = Collections.max(row);
        HashMap<String, Object> result = new HashMap<>();
        result.put("probability", best);
        if (this.regression.booleanValue()) {
            result.put(AbstractResource.PREDICTION_PATH, row.get(0));
            return result;
        }
        if (this.classNames != null && !this.classNames.isEmpty()) {
            result.put(AbstractResource.PREDICTION_PATH, this.classNames.get(row.indexOf(best)));
            JSONArray distribution = new JSONArray();
            for (int i = 0; i < this.classNames.size(); i++) {
                JSONObject entry = new JSONObject();
                entry.put("category", this.classNames.get(i));
                entry.put("probability", Double.valueOf(Utils.roundOff(row.get(i).doubleValue(), 5)));
                distribution.add(entry);
            }
            result.put("distribution", distribution);
        }
        return result;
    }
}
