package org.deeplearning4j.models.glove;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import org.canova.api.conf.Configuration;
import org.deeplearning4j.scaleout.aggregator.JobAggregator;
import org.deeplearning4j.scaleout.job.Job;
import org.deeplearning4j.util.MultiDimensionalMap;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:org/deeplearning4j/models/glove/GloveJobAggregator.class */
/**
 * Merges the GloVe results carried by finished jobs into a single result.
 *
 * <p>{@link #accumulate(Job)} collects the results carried by jobs; {@link #aggregate()} then
 * averages, word by word, every syn0 change collected so far.
 *
 * <p>NOTE(review): collected results are never cleared, so every call to {@link #aggregate()}
 * includes all results accumulated since construction — confirm callers use a fresh aggregator
 * per round. No synchronization is performed on the backing list.
 */
public class GloveJobAggregator implements JobAggregator {
    /** Every GloVe result accumulated so far, in arrival order. */
    private final List<org.deeplearning4j.scaleout.perform.models.glove.GloveResult> work =
            new ArrayList<>();

    /**
     * Records the GloVe result(s) carried by {@code job}.
     *
     * <p>A result that is itself a {@code GloveResult} is recorded directly. A result that is a
     * {@link Collection} has each of its {@code GloveResult} elements recorded. Any other result
     * (including {@code null}) and any non-{@code GloveResult} collection element is ignored.
     * Filtering here keeps {@link #aggregate()} from hitting a {@link ClassCastException} on a
     * mixed collection.
     *
     * @param job a finished job whose result should be merged
     */
    public void accumulate(Job job) {
        Object result = job.getResult();
        if (result instanceof org.deeplearning4j.scaleout.perform.models.glove.GloveResult) {
            work.add((org.deeplearning4j.scaleout.perform.models.glove.GloveResult) result);
        } else if (result instanceof Collection) {
            for (Object element : (Collection<?>) result) {
                if (element instanceof org.deeplearning4j.scaleout.perform.models.glove.GloveResult) {
                    work.add((org.deeplearning4j.scaleout.perform.models.glove.GloveResult) element);
                }
            }
        }
    }

    /**
     * Averages the accumulated syn0 changes word by word.
     *
     * <p>For every word that appears in at least one accumulated result, the merged result maps
     * that word to the element-wise mean of its changes, as computed by {@link #average(List)}.
     *
     * @return a new {@code Job} (constructed with two empty strings) whose result is a
     *     single-element list containing the merged {@code GloveResult}
     */
    public Job aggregate() {
        Job job = new Job("", "");
        org.deeplearning4j.scaleout.perform.models.glove.GloveResult merged =
                new org.deeplearning4j.scaleout.perform.models.glove.GloveResult();
        MultiDimensionalMap<String, String, List<INDArray>> changesByWord =
                MultiDimensionalMap.newHashBackedMap();
        // De-duplicates the words seen across all results; iteration order is irrelevant.
        Set<String> words = new HashSet<>();
        for (org.deeplearning4j.scaleout.perform.models.glove.GloveResult result : work) {
            for (String word : result.getSyn0Change().keySet()) {
                getOrPutIfNotExists(changesByWord, word, "syn0")
                        .add(result.getSyn0Change().get(word));
                words.add(word);
            }
        }
        for (String word : words) {
            // Every word in 'words' already has a list, so this lookup never inserts.
            merged.getSyn0Change()
                    .put(word, average(getOrPutIfNotExists(changesByWord, word, "syn0")));
        }
        job.setResult((Serializable) Arrays.asList(merged));
        return job;
    }

    /**
     * Returns the element-wise mean of the non-null arrays in {@code list}.
     *
     * <p>{@code null} entries are skipped and excluded from the divisor. The sum is accumulated
     * into a fresh array created with {@code Nd4j.create} using the shape of the first non-null
     * entry, rather than into one of the inputs. NOTE(review): assumes all non-null entries share
     * that shape.
     *
     * @param list the arrays to average
     * @return the mean, or {@code null} if every entry is {@code null}
     * @throws IllegalArgumentException if {@code list} is {@code null} or empty
     */
    private INDArray average(List<INDArray> list) {
        if (list == null || list.isEmpty()) {
            throw new IllegalArgumentException("Can't average empty or null list");
        }
        INDArray sum = null;
        int count = 0;
        for (INDArray array : list) {
            if (array == null) {
                // A null entry carries no change; skipping it keeps the result independent of the
                // order in which workers' results arrived.
                continue;
            }
            if (sum == null) {
                sum = Nd4j.create(array.shape());
            }
            sum.addi(array);
            count++;
        }
        if (sum == null) {
            return null;
        }
        return count > 1 ? sum.divi(Double.valueOf(count)) : sum;
    }

    /**
     * Returns the list stored under ({@code key}, {@code column}) in {@code map}. If no list is
     * stored there yet, a new empty one is inserted first.
     *
     * @param map the map to read from and possibly insert into
     * @param key first key component (here, a word)
     * @param column second key component (here, the parameter name, e.g. {@code "syn0"})
     * @return the existing or newly inserted list, never {@code null}
     */
    private List<INDArray> getOrPutIfNotExists(
            MultiDimensionalMap<String, String, List<INDArray>> map, String key, String column) {
        List<INDArray> list = map.get(key, column);
        if (list == null) {
            list = new ArrayList<>();
            map.put(key, column, list);
        }
        return list;
    }

    /**
     * Initializes the aggregator from {@code configuration}. This implementation reads no
     * settings and does nothing.
     *
     * @param configuration ignored
     */
    public void init(Configuration configuration) {
    }
}
