package org.wlld.rnnJumpNerveEntity;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.wlld.i.OutBack;
import org.wlld.matrixTools.Matrix;

/* loaded from: input_file:org/wlld/rnnJumpNerveEntity/SoftMax.class */
/**
 * Terminal nerve that turns the values collected for one event id into a softmax distribution.
 *
 * <p>In prediction mode it reports the most probable class (as a 1-based id) together with its
 * probability, either to an {@link OutBack} or to the {@link NerveCenter}. In training mode it
 * computes the error term {@code oneHot(label) - softmax} and pushes one component back to each
 * {@link OutNerve}, by index.
 */
public class SoftMax extends Nerve {
    /** Upstream output nerves; error component {@code k} is sent to {@code outNerves.get(k)}. */
    private final List<OutNerve> outNerves;
    /** When true, each training step prints the label, winning probability and winning id. */
    private final boolean isShowLog;
    /** Receives prediction results from {@link #sendAppointSoftMax}; set via {@link #setNerveCenter}. */
    private NerveCenter nerveCenter;

    /**
     * Result of one softmax evaluation.
     *
     * <p>Kept {@code public} for source compatibility; it is only produced inside this class.
     */
    public static class Mes {
        /** 1-based index of the most probable class; 0 if no probability was greater than 0. */
        int typeID;
        /** Probability of {@link #typeID} (0 when {@code typeID} is 0). */
        double poi;
        /** Full softmax distribution, in the same order as the collected input values. */
        List<Double> softMax;

        Mes() {
        }
    }

    /**
     * Creates the softmax nerve.
     *
     * @param z    forwarded to the {@link Nerve} constructor (meaning defined there)
     * @param list the upstream output nerves that receive error terms during training
     * @param z2   whether to print a log line for every training step
     * @param i    forwarded to the {@link Nerve} constructor
     * @param i2   forwarded to the {@link Nerve} constructor
     * @param i3   forwarded to the {@link Nerve} constructor
     * @param i4   forwarded to the {@link Nerve} constructor
     * @throws Exception if the {@link Nerve} constructor throws
     */
    public SoftMax(boolean z, List<OutNerve> list, boolean z2, int i, int i2, int i3, int i4) throws Exception {
        super(0, "softMax", 0.0d, false, null, z, 0, 0.0d, 0, 0, i, i2, i3, i4, false, 0);
        this.outNerves = list;
        this.isShowLog = z2;
    }

    /** Sets the nerve center that {@link #sendAppointSoftMax} reports results to. */
    public void setNerveCenter(NerveCenter nerveCenter) {
        this.nerveCenter = nerveCenter;
    }

    /**
     * Prediction path that reports to the {@link NerveCenter}.
     *
     * <p>Stores {@code d} for event {@code j}. Once {@code insertParameter} returns true, it
     * evaluates the softmax, releases the stored parameters, and forwards the winning probability
     * and 1-based class id to {@link NerveCenter#backType}.
     *
     * <p>NOTE(review): {@code insertParameter} presumably returns true once all upstream values
     * for {@code j} have arrived (it is defined in {@link Nerve}); confirm. {@link #nerveCenter}
     * must have been set via {@link #setNerveCenter}, or this throws a NullPointerException.
     */
    @Override // org.wlld.rnnJumpNerveEntity.Nerve
    protected void sendAppointSoftMax(long j, double d, Matrix matrix, OutBack outBack, String str) throws Exception {
        if (!insertParameter(j, d)) {
            return;
        }
        Mes result = softMax(j);
        destroyParameter(j);
        this.nerveCenter.backType(j, result.poi, result.typeID, matrix, outBack, str);
    }

    /**
     * Receives one upstream value for event {@code j} and, once {@code insertParameter} accepts
     * it, either predicts or trains.
     *
     * <ul>
     *   <li>{@code z == false} (prediction): releases the stored parameters and reports the winning
     *       probability, class id and full distribution to {@code outBack}.</li>
     *   <li>{@code z == true} (training): picks the label from {@code map} (see
     *       {@link #findLabel}), computes {@code oneHot(label) - softmax}, and sends component
     *       {@code k} to {@code outNerves.get(k)}.</li>
     * </ul>
     *
     * @throws Exception "not find outBack" in prediction mode when {@code outBack} is null
     */
    @Override // org.wlld.rnnJumpNerveEntity.Nerve
    protected void input(long j, double d, boolean z, Map<Integer, Double> map, OutBack outBack, Matrix matrix, int[] iArr, int i, int i2) throws Exception {
        if (!insertParameter(j, d)) {
            return;
        }
        Mes result = softMax(j);
        if (!z) {
            destroyParameter(j);
            if (outBack == null) {
                throw new Exception("not find outBack");
            }
            outBack.getBack(result.poi, result.typeID, j);
            outBack.getSoftMaxBack(j, result.softMax);
            return;
        }
        // NOTE(review): map must be non-null in training mode, otherwise this throws a
        // NullPointerException.
        int label = findLabel(map);
        if (this.isShowLog) {
            System.out.println("softMax==" + label + ",out==" + result.poi + ",nerveId==" + result.typeID);
        }
        List<Double> error = error(result, label);
        // NOTE(review): the prediction paths call destroyParameter(j), while this path removes
        // only the feature entry. Confirm this difference is intentional.
        this.features.remove(Long.valueOf(j));
        int nerveCount = this.outNerves.size();
        for (int k = 0; k < nerveCount; k++) {
            this.outNerves.get(k).getGBySoftMax(error.get(k).doubleValue(), j, iArr, i);
        }
    }

    /**
     * Returns the key of the first entry (in the map's iteration order) whose value exceeds 0.9,
     * or 0 when no entry does.
     */
    private static int findLabel(Map<Integer, Double> map) {
        for (Map.Entry<Integer, Double> entry : map.entrySet()) {
            if (entry.getValue().doubleValue() > 0.9d) {
                return entry.getKey().intValue();
            }
        }
        return 0;
    }

    /**
     * Computes {@code oneHot(i) - softmax} for a 1-based label {@code i}.
     *
     * <p>When {@code i} is 0 (no label found), no component is the target, so every entry is
     * {@code -p}.
     */
    private List<Double> error(Mes mes, int i) {
        int target = i - 1; // labels are 1-based, list indices are 0-based
        List<Double> probabilities = mes.softMax;
        List<Double> errors = new ArrayList<>(probabilities.size());
        for (int k = 0; k < probabilities.size(); k++) {
            double p = probabilities.get(k).doubleValue();
            errors.add(k == target ? 1.0d - p : -p);
        }
        return errors;
    }

    /**
     * Evaluates the softmax over the values stored in {@code features} for event {@code j}.
     *
     * <p>The largest finite value is subtracted before exponentiating. Softmax is invariant to a
     * constant shift, so the result is mathematically unchanged, but large inputs no longer
     * overflow {@code Math.exp} to Infinity (which previously produced NaN probabilities). Ties
     * are resolved in favour of the lowest index.
     *
     * <p>NOTE(review): throws a NullPointerException if no features are stored for {@code j}.
     */
    private Mes softMax(long j) {
        List<Double> scores = this.features.get(Long.valueOf(j));
        int size = scores.size();

        double max = Double.NEGATIVE_INFINITY;
        for (double score : scores) {
            max = Math.max(max, score);
        }
        // Don't shift by an infinite max (or by -Infinity for an empty list): Inf - Inf would
        // turn every entry into NaN. Leaving those degenerate cases unshifted matches the
        // unshifted computation.
        double shift = Double.isInfinite(max) ? 0.0d : max;

        double[] exps = new double[size];
        double sum = 0.0d;
        for (int k = 0; k < size; k++) {
            exps[k] = Math.exp(scores.get(k).doubleValue() - shift);
            sum += exps[k];
        }

        List<Double> probabilities = new ArrayList<>(size);
        int bestId = 0;
        double bestProbability = 0.0d;
        for (int k = 0; k < size; k++) {
            double p = exps[k] / sum;
            probabilities.add(p);
            if (p > bestProbability) {
                bestProbability = p;
                bestId = k + 1; // class ids are 1-based
            }
        }

        Mes mes = new Mes();
        mes.softMax = probabilities;
        mes.typeID = bestId;
        mes.poi = bestProbability;
        return mes;
    }
}
