package org.deeplearning4j.spark.impl.graph.scoring;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.broadcast.Broadcast;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Tuple2;

/* loaded from: input_file:org/deeplearning4j/spark/impl/graph/scoring/GraphFeedForwardWithKeyFunction.class */
/**
 * Spark function that runs inference with a {@link ComputationGraph} over an iterator of keyed,
 * multi-input feature arrays and emits the network outputs under the same keys.
 *
 * <p>Each call builds a fresh network from the broadcast JSON configuration and a copy of the
 * broadcast parameters. It then buffers all incoming records and concatenates consecutive records
 * along dimension 0 into mini-batches of roughly {@code batchSize} examples.
 *
 * <p>If any record's input shapes differ from the first record's shapes in a dimension other than
 * 0 (or in rank or number of inputs), a batch is also closed as soon as the next record no longer
 * matches the first record of that batch. This ensures {@link Nd4j#concat} only sees compatible
 * arrays.
 *
 * @param <K> type of the record keys
 */
public class GraphFeedForwardWithKeyFunction<K>
        implements PairFlatMapFunction<Iterator<Tuple2<K, INDArray[]>>, K, INDArray[]> {
    private static final Logger log = LoggerFactory.getLogger(GraphFeedForwardWithKeyFunction.class);
    private final Broadcast<INDArray> params;
    private final Broadcast<String> jsonConfig;
    private final int batchSize;

    /**
     * @param params     broadcast parameter array; its length must equal the network's
     *                   {@code numParams(false)}
     * @param jsonConfig broadcast {@link ComputationGraphConfiguration} serialized as JSON
     * @param batchSize  target number of examples (summed over dimension 0) per forward pass;
     *                   each batch always holds at least one record
     */
    public GraphFeedForwardWithKeyFunction(Broadcast<INDArray> params, Broadcast<String> jsonConfig, int batchSize) {
        this.params = params;
        this.jsonConfig = jsonConfig;
        this.batchSize = batchSize;
    }

    /**
     * Runs the network over every record from {@code it}.
     *
     * @param it keyed records; the value is one feature array per network input
     * @return one {@code (key, outputs)} tuple per input record, in input order. Each output array
     *     is the slice of the batched output that belongs to that record.
     * @throws IllegalStateException if the broadcast parameter count does not match the network
     */
    @Override
    public Iterator<Tuple2<K, INDArray[]>> call(Iterator<Tuple2<K, INDArray[]>> it) throws Exception {
        if (!it.hasNext()) {
            return Collections.emptyIterator();
        }
        ComputationGraph network = createNetwork();

        // Buffer the partition so batches can be formed and outputs mapped back to keys.
        List<INDArray[]> featuresList = new ArrayList<>();
        List<K> keyList = new ArrayList<>();
        List<Long> exampleCounts = new ArrayList<>();
        long[][] firstShapes = null;
        boolean sizesDiffer = false;
        while (it.hasNext()) {
            Tuple2<K, INDArray[]> next = it.next();
            INDArray[] features = next._2();
            if (firstShapes == null) {
                firstShapes = shapesOf(features);
            } else if (!sizesDiffer) {
                // Compare the CURRENT record against the first one, so a mismatch anywhere
                // (including the last record) switches on shape-aware batching.
                sizesDiffer = !sameShapeExceptDim0(firstShapes, features);
            }
            featuresList.add(features);
            keyList.add(next._1());
            exampleCounts.add(features[0].size(0));
        }

        List<Tuple2<K, INDArray[]>> results = new ArrayList<>(featuresList.size());
        int start = 0;
        while (start < featuresList.size()) {
            int end = nextBatchEnd(featuresList, start, sizesDiffer);
            INDArray[] out = network.output(false, concatBatch(featuresList.subList(start, end)));

            // Split the batched outputs back into per-record slices along dimension 0.
            long offset = 0;
            for (int r = start; r < end; r++) {
                long count = exampleCounts.get(r);
                INDArray[] outSubset = new INDArray[out.length];
                for (int o = 0; o < out.length; o++) {
                    outSubset[o] = getSubset(offset, offset + count, out[o]);
                }
                offset += count;
                results.add(new Tuple2<>(keyList.get(r), outSubset));
            }
            start = end;
        }
        // Make sure all queued ND4J operations have executed before the results leave this call.
        Nd4j.getExecutioner().commit();
        return results.iterator();
    }

    /** Builds and initializes the network from the broadcast configuration and parameters. */
    private ComputationGraph createNetwork() {
        ComputationGraph network = new ComputationGraph(ComputationGraphConfiguration.fromJson(jsonConfig.getValue()));
        network.init();
        INDArray paramsCopy = params.value().unsafeDuplication();
        if (paramsCopy.length() != network.numParams(false)) {
            throw new IllegalStateException("Network did not have same number of parameters as the broadcast set parameters");
        }
        network.setParams(paramsCopy);
        return network;
    }

    /**
     * Returns the exclusive end index of the mini-batch that starts at {@code start}.
     *
     * <p>The batch always contains the record at {@code start}, so the result is greater than
     * {@code start} and the caller always makes progress. More records are added while fewer than
     * {@code batchSize} examples have been collected. When {@code sizesDiffer} is true, only
     * records whose shapes match the batch's first record are added.
     */
    private int nextBatchEnd(List<INDArray[]> featuresList, int start, boolean sizesDiffer) {
        INDArray[] first = featuresList.get(start);
        long[][] batchShapes = sizesDiffer ? shapesOf(first) : null;
        long examples = first[0].size(0);
        int end = start + 1;
        while (end < featuresList.size() && examples < batchSize) {
            INDArray[] candidate = featuresList.get(end);
            if (sizesDiffer && !sameShapeExceptDim0(batchShapes, candidate)) {
                // Incompatible with the current batch; it will start the next one.
                break;
            }
            examples += candidate[0].size(0);
            end++;
        }
        return end;
    }

    /** Concatenates each network input across {@code records} along dimension 0. */
    private static INDArray[] concatBatch(List<INDArray[]> records) {
        int numInputs = records.get(0).length;
        INDArray[] batch = new INDArray[numInputs];
        for (int input = 0; input < numInputs; input++) {
            INDArray[] parts = new INDArray[records.size()];
            for (int r = 0; r < parts.length; r++) {
                parts[r] = records.get(r)[input];
            }
            batch[input] = Nd4j.concat(0, parts);
        }
        return batch;
    }

    /** Returns the shape of every array in {@code arrays}. */
    private static long[][] shapesOf(INDArray[] arrays) {
        long[][] shapes = new long[arrays.length][];
        for (int i = 0; i < arrays.length; i++) {
            shapes[i] = arrays[i].shape();
        }
        return shapes;
    }

    /**
     * Returns true if {@code arrays} has the same number of arrays as {@code reference}, and each
     * array has the same rank and the same size in every dimension except dimension 0.
     */
    private static boolean sameShapeExceptDim0(long[][] reference, INDArray[] arrays) {
        if (reference.length != arrays.length) {
            return false;
        }
        for (int i = 0; i < reference.length; i++) {
            long[] shape = arrays[i].shape();
            if (shape.length != reference[i].length) {
                return false;
            }
            for (int d = 1; d < shape.length; d++) {
                if (shape[d] != reference[i][d]) {
                    return false;
                }
            }
        }
        return true;
    }

    /**
     * Returns rows {@code [from, to)} along dimension 0 of {@code array}, keeping all other
     * dimensions.
     *
     * @throws RuntimeException if {@code array} has rank below 2
     */
    private static INDArray getSubset(long from, long to, INDArray array) {
        int rank = array.rank();
        if (rank < 2) {
            throw new RuntimeException("Invalid rank: " + rank);
        }
        INDArrayIndex[] indices = new INDArrayIndex[rank];
        indices[0] = NDArrayIndex.interval(from, to);
        for (int d = 1; d < rank; d++) {
            // Use a fresh all() per dimension rather than sharing one index instance across
            // dimensions of different sizes.
            indices[d] = NDArrayIndex.all();
        }
        return array.get(indices);
    }
}
