package org.bigml.binding.laminar;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import org.json.simple.JSONArray;
import org.json.simple.JSONObject;

/* loaded from: input_file:org/bigml/binding/laminar/MathOps.class */
/**
 * Numeric helpers used to evaluate laminar (deepnet) layers locally.
 *
 * <p>Matrices are represented as {@code ArrayList<List<Double>>}, one inner list per row. Layer
 * parameters arrive as json-simple {@link JSONArray}s whose elements are {@link Number}s (for the
 * {@code "weights"} entry, each element is itself a list of {@link Number}s).
 */
public class MathOps {
    /**
     * Magnitude at or beyond which {@link #sigmoid(double)} and softplus skip {@link Math#exp}
     * and return their limiting value directly.
     */
    private static final int LARGE_EXP = 512;

    /**
     * Applies a binary arithmetic operator element-wise between every row of {@code mat} and the
     * vector {@code vec}: {@code out[r][c] = mat[r][c] op vec[c]}.
     *
     * @param op one of {@code "+"}, {@code "-"}, {@code "*"}, {@code "/"}
     * @param mat input matrix; each row must be no longer than {@code vec}
     * @param vec vector of {@link Number}s applied to every row
     * @return a new matrix with the same shape as {@code mat}
     * @throws IllegalArgumentException if {@code op} is not a supported operator
     */
    private static ArrayList<List<Double>> operation(
            String op, ArrayList<List<Double>> mat, JSONArray vec) {
        ArrayList<List<Double>> result = new ArrayList<>(mat.size());
        for (List<Double> row : mat) {
            List<Double> newRow = new ArrayList<>(row.size());
            for (int c = 0; c < row.size(); c++) {
                double operand = ((Number) vec.get(c)).doubleValue();
                newRow.add(applyOperator(op, row.get(c), operand));
            }
            result.add(newRow);
        }
        return result;
    }

    /** Evaluates {@code a op b} for one of the four arithmetic operators. */
    private static double applyOperator(String op, double a, double b) {
        switch (op) {
            case "+":
                return a + b;
            case "-":
                return a - b;
            case "*":
                return a * b;
            case "/":
                return a / b;
            default:
                // Previously an unknown operator silently produced empty rows.
                throw new IllegalArgumentException("Unsupported operator: " + op);
        }
    }

    /** Returns {@code mat + vec}, with {@code vec} added to every row. */
    private static ArrayList<List<Double>> plus(ArrayList<List<Double>> mat, JSONArray vec) {
        return operation("+", mat, vec);
    }

    /** Returns {@code mat - vec}, with {@code vec} subtracted from every row. */
    private static ArrayList<List<Double>> minus(ArrayList<List<Double>> mat, JSONArray vec) {
        return operation("-", mat, vec);
    }

    /** Returns the element-wise product of every row of {@code mat} with {@code vec}. */
    private static ArrayList<List<Double>> times(ArrayList<List<Double>> mat, JSONArray vec) {
        return operation("*", mat, vec);
    }

    /** Returns the element-wise quotient of every row of {@code mat} by {@code vec}. */
    private static ArrayList<List<Double>> divide(ArrayList<List<Double>> mat, JSONArray vec) {
        return operation("/", mat, vec);
    }

    /**
     * Multiplies {@code mat} by the transpose of {@code weights}. Output cell {@code [r][j]} is
     * the sum over {@code k < mat[r].size()} of {@code mat[r][k] * weights[j][k]}.
     *
     * @param mat input matrix
     * @param weights array whose elements are lists of {@link Number}s, one list per output column;
     *     each list must be at least as long as the rows of {@code mat}
     * @return a matrix with {@code mat.size()} rows and {@code weights.size()} columns
     */
    public static ArrayList<List<Double>> dot(ArrayList<List<Double>> mat, JSONArray weights) {
        ArrayList<List<Double>> result = new ArrayList<>(mat.size());
        for (List<Double> row : mat) {
            List<Double> newRow = new ArrayList<>(weights.size());
            for (Object weightRowObject : weights) {
                List<?> weightRow = (List<?>) weightRowObject;
                double sum = 0.0d;
                for (int k = 0; k < row.size(); k++) {
                    sum += row.get(k) * ((Number) weightRow.get(k)).doubleValue();
                }
                newRow.add(sum);
            }
            result.add(newRow);
        }
        return result;
    }

    /**
     * Applies batch normalization: {@code ((mat - mean) / stdev) * scale + offset}, each vector
     * applied to every row.
     */
    private static ArrayList<List<Double>> batchNorm(
            ArrayList<List<Double>> mat,
            JSONArray mean,
            JSONArray stdev,
            JSONArray offset,
            JSONArray scale) {
        return plus(times(divide(minus(mat, mean), stdev), scale), offset);
    }

    /**
     * Reverses a standardization: every cell becomes {@code value * stdev + mean}.
     *
     * @param mat standardized matrix
     * @param mean value added back to every cell
     * @param stdev factor every cell is multiplied by
     * @return a new matrix with the same shape as {@code mat}
     */
    public static ArrayList<List<Double>> destandardize(
            ArrayList<List<Double>> mat, Double mean, Double stdev) {
        double m = mean.doubleValue();
        double s = stdev.doubleValue();
        ArrayList<List<Double>> result = new ArrayList<>(mat.size());
        for (List<Double> row : mat) {
            List<Double> newRow = new ArrayList<>(row.size());
            for (Double value : row) {
                newRow.add(value * s + m);
            }
            result.add(newRow);
        }
        return result;
    }

    /**
     * Resizes each row to exactly {@code width} entries. Shorter rows are tiled (repeated
     * cyclically) and longer rows are truncated, so {@code out[r][c] == mat[r][c % mat[r].size()]}.
     *
     * @throws IllegalArgumentException if a row is empty but {@code width > 0}
     */
    private static ArrayList<List<Double>> toWidth(ArrayList<List<Double>> mat, int width) {
        ArrayList<List<Double>> result = new ArrayList<>(mat.size());
        for (List<Double> row : mat) {
            if (row.isEmpty() && width > 0) {
                throw new IllegalArgumentException("Cannot tile an empty row to width " + width);
            }
            List<Double> newRow = new ArrayList<>(Math.max(width, 0));
            // Cyclic indexing is equivalent to "repeat the row enough times, then truncate".
            for (int c = 0; c < width; c++) {
                newRow.add(row.get(c % row.size()));
            }
            result.add(newRow);
        }
        return result;
    }

    /**
     * Adds a residual (skip) connection: {@code residuals} is first tiled/truncated to the width of
     * {@code mat}'s first row, then added element-wise to {@code mat}.
     *
     * @param mat layer output before activation
     * @param residuals earlier activations; must have at least as many rows as {@code mat}
     * @return a new matrix with the same shape as {@code mat}; empty if {@code mat} is empty
     */
    private static ArrayList<List<Double>> addResiduals(
            ArrayList<List<Double>> mat, ArrayList<List<Double>> residuals) {
        ArrayList<List<Double>> result = new ArrayList<>(mat.size());
        if (mat.isEmpty()) {
            return result;
        }
        ArrayList<List<Double>> widened = toWidth(residuals, mat.get(0).size());
        for (int r = 0; r < mat.size(); r++) {
            List<Double> row = mat.get(r);
            List<Double> residualRow = widened.get(r);
            List<Double> newRow = new ArrayList<>(row.size());
            for (int c = 0; c < row.size(); c++) {
                newRow.add(row.get(c) + residualRow.get(c));
            }
            result.add(newRow);
        }
        return result;
    }

    /**
     * Feeds {@code input} forward through every layer in {@code layers}.
     *
     * <p>Each layer is a {@link JSONObject} read through the keys {@code "weights"},
     * {@code "mean"}, {@code "stdev"}, {@code "scale"}, {@code "offset"}, {@code "residuals"} and
     * {@code "activation_function"}. The layer output is {@code dot(x, weights)}. That product is
     * batch-normalized when both {@code mean} and {@code stdev} are present; otherwise
     * {@code offset} is added. When {@code residuals} is {@code true}, the activations from the
     * most recent residual layer (initially {@code input}) are added before activation, and the
     * result becomes the new residual source.
     *
     * @param input input matrix
     * @param layers array of layer {@link JSONObject}s
     * @return the activations of the last layer (or {@code input} itself if there are no layers)
     */
    public static ArrayList<List<Double>> propagate(
            ArrayList<List<Double>> input, JSONArray layers) {
        ArrayList<List<Double>> identities = input;
        ArrayList<List<Double>> lastOutput = input;
        for (Object layerObject : layers) {
            JSONObject layer = (JSONObject) layerObject;
            JSONArray weights = (JSONArray) layer.get("weights");
            JSONArray mean = (JSONArray) layer.get("mean");
            JSONArray stdev = (JSONArray) layer.get("stdev");
            JSONArray scale = (JSONArray) layer.get("scale");
            JSONArray offset = (JSONArray) layer.get("offset");
            Boolean residuals = (Boolean) layer.get("residuals");
            String activation = (String) layer.get("activation_function");

            ArrayList<List<Double>> product = dot(lastOutput, weights);
            ArrayList<List<Double>> preActivation =
                    (mean != null && stdev != null)
                            ? batchNorm(product, mean, stdev, offset, scale)
                            : plus(product, offset);

            if (Boolean.TRUE.equals(residuals)) {
                lastOutput = broadcast(activation, addResiduals(preActivation, identities));
                identities = lastOutput;
            } else {
                lastOutput = broadcast(activation, preActivation);
            }
        }
        return lastOutput;
    }

    /**
     * Applies the named activation function to {@code mat}.
     *
     * <p>{@code "identity"} returns {@code mat} itself (not a copy). {@code "softmax"} is applied
     * per row. {@code "tanh"}, {@code "sigmoid"}, {@code "softplus"} and {@code "relu"} are
     * applied element-wise.
     *
     * @return an empty list if {@code mat} is empty, otherwise the activated matrix
     * @throws IllegalArgumentException if {@code fn} is not one of the names above and
     *     {@code mat} contains at least one value
     */
    private static ArrayList<List<Double>> broadcast(String fn, ArrayList<List<Double>> mat) {
        if (mat.isEmpty()) {
            return new ArrayList<>();
        }
        if ("identity".equals(fn)) {
            return mat;
        }
        if ("softmax".equals(fn)) {
            return softmax(mat);
        }
        ArrayList<List<Double>> result = new ArrayList<>(mat.size());
        for (List<Double> row : mat) {
            List<Double> newRow = new ArrayList<>(row.size());
            for (Double value : row) {
                newRow.add(activate(fn, value));
            }
            result.add(newRow);
        }
        return result;
    }

    /** Applies a single element-wise activation function to {@code x}. */
    private static double activate(String fn, double x) {
        if ("tanh".equals(fn)) {
            return Math.tanh(x);
        }
        if ("sigmoid".equals(fn)) {
            return sigmoid(x);
        }
        if ("softplus".equals(fn)) {
            // log1p keeps precision for negative x, where exp(x) + 1 rounds to 1.
            return x < LARGE_EXP ? Math.log1p(Math.exp(x)) : x;
        }
        if ("relu".equals(fn)) {
            return x > 0.0d ? x : 0.0d;
        }
        // Previously unknown names silently produced empty rows, corrupting later layers.
        throw new IllegalArgumentException("Unknown activation function: " + fn);
    }

    /**
     * Logistic sigmoid. It returns exactly 1 or 0 once {@code |x|} reaches {@link #LARGE_EXP}, so
     * {@link Math#exp} is only evaluated on arguments below that bound.
     */
    private static double sigmoid(double x) {
        if (x > 0.0d) {
            if (x < LARGE_EXP) {
                double e = Math.exp(x);
                return e / (e + 1.0d);
            }
            return 1.0d;
        }
        if (-x < LARGE_EXP) {
            return 1.0d / (1.0d + Math.exp(-x));
        }
        return 0.0d;
    }

    /** Applies a numerically stable softmax independently to every row of {@code mat}. */
    private static ArrayList<List<Double>> softmax(ArrayList<List<Double>> mat) {
        ArrayList<List<Double>> result = new ArrayList<>(mat.size());
        for (List<Double> row : mat) {
            result.add(softmaxRow(row));
        }
        return result;
    }

    /**
     * Softmax of one row. The row's true maximum is subtracted before exponentiating, so
     * all-negative or large logits neither underflow to 0/0 nor overflow.
     */
    private static List<Double> softmaxRow(List<Double> row) {
        List<Double> out = new ArrayList<>(row.size());
        if (row.isEmpty()) {
            return out;
        }
        double max = Collections.max(row);
        double[] exps = new double[row.size()];
        double sum = 0.0d;
        for (int i = 0; i < exps.length; i++) {
            exps[i] = Math.exp(row.get(i) - max);
            sum += exps[i];
        }
        for (double e : exps) {
            out.add(e / sum);
        }
        return out;
    }

    /**
     * Sums the first row of every matrix in {@code outputs} element-wise, then divides each sum by
     * either the number of matrices ({@code averageByCount == true}) or the total of all sums
     * ({@code averageByCount == false}).
     *
     * @param outputs non-empty list of matrices; only row 0 of each is read, and no row may be
     *     longer than row 0 of the first matrix
     * @param averageByCount selects the divisor as described above
     * @return a single-row matrix holding the normalized sums
     */
    public static ArrayList<List<Double>> sumAndNormalize(
            ArrayList<ArrayList<List<Double>>> outputs, boolean averageByCount) {
        double[] sums = new double[outputs.get(0).get(0).size()];
        for (ArrayList<List<Double>> output : outputs) {
            List<Double> firstRow = output.get(0);
            for (int c = 0; c < firstRow.size(); c++) {
                sums[c] += firstRow.get(c);
            }
        }

        double total = 0.0d;
        for (double s : sums) {
            total += s;
        }
        double divisor = averageByCount ? outputs.size() : total;

        List<Double> normalized = new ArrayList<>(sums.length);
        for (double s : sums) {
            normalized.add(s / divisor);
        }
        ArrayList<List<Double>> result = new ArrayList<>(1);
        result.add(normalized);
        return result;
    }
}
