package org.deeplearning4j.spark.impl.multilayer.scoring;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.broadcast.Broadcast;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSetUtil;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Tuple2;

/**
 * Spark partition function that runs inference (feed-forward, {@code train = false}) with a
 * {@link MultiLayerNetwork} over keyed feature arrays and emits one output array per key.
 *
 * <p>For each partition, the network is rebuilt from a broadcast JSON configuration and a
 * broadcast parameter vector. Consecutive records are merged into minibatches of roughly
 * {@code batchSize} examples. A record is added while the running example count is below
 * {@code batchSize}, so a batch can overshoot it. When records in the partition have differing
 * non-batch shapes, a batch is also cut wherever the shape changes, so only shape-compatible
 * arrays are merged. The network output rows are then split back per record, in input order.
 *
 * @param <K> key type carried through unchanged from input to output
 */
public class FeedForwardWithKeyFunction<K> implements PairFlatMapFunction<Iterator<Tuple2<K, Tuple2<INDArray, INDArray>>>, K, INDArray> {
    protected static Logger log = LoggerFactory.getLogger(FeedForwardWithKeyFunction.class);
    private final Broadcast<INDArray> params;
    private final Broadcast<String> jsonConfig;
    private final int batchSize;

    /**
     * @param params     broadcast flattened network parameters
     * @param jsonConfig broadcast {@link MultiLayerConfiguration} in JSON form
     * @param batchSize  target number of examples per inference minibatch; values {@code <= 0}
     *                   degrade to one record per minibatch
     */
    public FeedForwardWithKeyFunction(Broadcast<INDArray> params, Broadcast<String> jsonConfig, int batchSize) {
        this.params = params;
        this.jsonConfig = jsonConfig;
        this.batchSize = batchSize;
    }

    /**
     * Runs the network over every record in the partition.
     *
     * @param it records of {@code (key, (features, featuresMask))}; the mask may be null
     * @return {@code (key, output)} pairs, one per input record, in input order
     * @throws IllegalStateException if the broadcast parameter count does not match the network
     */
    public Iterator<Tuple2<K, INDArray>> call(Iterator<Tuple2<K, Tuple2<INDArray, INDArray>>> it) throws Exception {
        if (!it.hasNext()) {
            return Collections.emptyIterator();
        }

        MultiLayerNetwork network = new MultiLayerNetwork(MultiLayerConfiguration.fromJson(jsonConfig.value()));
        network.init();
        // Duplicate so this partition's network does not share the broadcast buffer.
        INDArray networkParams = params.value().unsafeDuplication();
        if (networkParams.length() != network.numParams(false)) {
            throw new IllegalStateException("Network did not have same number of parameters as the broadcasted set parameters");
        }
        network.setParameters(networkParams);

        // Materialize the partition so records can be regrouped into minibatches.
        List<K> keys = new ArrayList<>();
        List<INDArray> features = new ArrayList<>();
        List<INDArray> featureMasks = new ArrayList<>();
        long[] firstShape = null;
        boolean sizesDiffer = false;
        while (it.hasNext()) {
            Tuple2<K, Tuple2<INDArray, INDArray>> record = it.next();
            INDArray f = record._2()._1();
            if (firstShape == null) {
                firstShape = f.shape();
            } else if (!sizesDiffer) {
                // Compare the CURRENT record (not the previous one) so the last record is checked too.
                sizesDiffer = !sameNonBatchShape(firstShape, f);
            }
            keys.add(record._1());
            features.add(f);
            featureMasks.add(record._2()._2());
        }

        List<Tuple2<K, INDArray>> results = new ArrayList<>(keys.size());
        int start = 0;
        while (start < features.size()) {
            long[] batchShape = features.get(start).shape();
            int end = start; // exclusive
            long examplesInBatch = 0;
            // Always take at least one record so the loop makes progress even if batchSize <= 0.
            while (end < features.size() && (end == start || examplesInBatch < batchSize)) {
                INDArray f = features.get(end);
                if (sizesDiffer && end > start && !sameNonBatchShape(batchShape, f)) {
                    // Incompatible shape: close this batch and start a new one at this record.
                    break;
                }
                examplesInBatch += f.size(0);
                end++;
            }

            INDArray[] batchFeatures = features.subList(start, end).toArray(new INDArray[0]);
            INDArray[] batchMasks = featureMasks.subList(start, end).toArray(new INDArray[0]);
            Pair<INDArray, INDArray> merged = DataSetUtil.mergeFeatures(batchFeatures, batchMasks);
            INDArray out = network.output(merged.getFirst(), false, merged.getSecond(), (INDArray) null);

            // Slice the merged output back into per-record row ranges, in the same order.
            int offset = 0;
            for (int j = start; j < end; j++) {
                int rows = (int) features.get(j).size(0);
                results.add(new Tuple2<>(keys.get(j), getSubset(offset, offset + rows, out)));
                offset += rows;
            }
            start = end;
        }

        Nd4j.getExecutioner().commit();
        return results.iterator();
    }

    /**
     * Returns true when {@code array} has the same rank as {@code reference} and matches it on
     * every dimension except dimension 0 (the example/batch dimension).
     */
    private static boolean sameNonBatchShape(long[] reference, INDArray array) {
        if (array.rank() != reference.length) {
            return false;
        }
        for (int d = 1; d < reference.length; d++) {
            if (reference[d] != array.size(d)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Returns rows {@code [from, to)} along dimension 0 of {@code array}, keeping all other
     * dimensions whole.
     *
     * @throws RuntimeException if {@code array} has rank below 2
     */
    private INDArray getSubset(int from, int to, INDArray array) {
        int rank = array.rank();
        if (rank < 2) {
            throw new RuntimeException("Invalid rank: " + rank);
        }
        INDArrayIndex[] indices = new INDArrayIndex[rank];
        indices[0] = NDArrayIndex.interval(from, to);
        for (int d = 1; d < rank; d++) {
            indices[d] = NDArrayIndex.all();
        }
        return array.get(indices);
    }
}
