package org.deeplearning4j.arbiter.ui.listener;

import it.unimi.dsi.fastutil.floats.FloatArrayList;
import it.unimi.dsi.fastutil.ints.IntArrayList;
import java.io.IOException;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import lombok.NonNull;
import org.deeplearning4j.arbiter.optimize.api.OptimizationResult;
import org.deeplearning4j.arbiter.optimize.runner.CandidateInfo;
import org.deeplearning4j.arbiter.optimize.runner.IOptimizationRunner;
import org.deeplearning4j.arbiter.optimize.runner.listener.StatusListener;
import org.deeplearning4j.arbiter.optimize.serde.jackson.JsonMapper;
import org.deeplearning4j.arbiter.ui.data.GlobalConfigPersistable;
import org.deeplearning4j.arbiter.ui.data.ModelInfoPersistable;
import org.deeplearning4j.core.storage.StatsStorageRouter;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.common.primitives.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/arbiter/ui/listener/ArbiterStatusListener.class */
/**
 * {@link StatusListener} that publishes Arbiter hyperparameter-optimization progress to a
 * {@link StatsStorageRouter}.
 *
 * <p>Runner-level state (configuration JSON and candidate counts) is published as static info
 * ({@link GlobalConfigPersistable}). Per-candidate state (status, score, score-vs-iteration
 * history, model configuration) is published as updates ({@link ModelInfoPersistable}).
 *
 * <p>The per-candidate maps are {@link ConcurrentHashMap}s. The score history lists stored inside
 * them are not synchronized. NOTE(review): this presumably relies on each candidate's iterations
 * being reported from a single thread; confirm against the runner implementations.
 */
public class ArbiterStatusListener implements StatusListener {
    private static final Logger log = LoggerFactory.getLogger(ArbiterStatusListener.class);

    /** Maximum number of score-vs-iteration points retained per candidate before subsampling. */
    public static final int MAX_SCORE_VS_ITER_PTS = 1024;

    private final String sessionId;
    private final StatsStorageRouter statsStorage;

    /** Current subsampling period (in iterations) of the score history, keyed by candidate index. */
    private final Map<Integer, Integer> candidateScoreVsIterSubsampleFreq = new ConcurrentHashMap<>();
    /** Retained (iteration, score) history, keyed by candidate index. */
    private final Map<Integer, Pair<IntArrayList, FloatArrayList>> candidateScoreVsIter =
            new ConcurrentHashMap<>();
    /** Most recent persistable published for each candidate index. */
    private final Map<Integer, ModelInfoPersistable> lastModelInfoPersistable = new ConcurrentHashMap<>();

    /**
     * Creates a listener with a randomly generated session id.
     *
     * @param statsStorage router receiving static info and updates; must not be null
     * @throws NullPointerException if {@code statsStorage} is null
     */
    public ArbiterStatusListener(@NonNull StatsStorageRouter statsStorage) {
        // The delegated constructor performs the null check on statsStorage.
        this(UUID.randomUUID().toString(), statsStorage);
    }

    /**
     * Creates a listener for the given session.
     *
     * @param sessionId session id attached to every persistable; must not be null
     * @param statsStorage router receiving static info and updates; must not be null
     * @throws NullPointerException if either argument is null
     */
    public ArbiterStatusListener(@NonNull String sessionId, @NonNull StatsStorageRouter statsStorage) {
        if (sessionId == null) {
            throw new NullPointerException("sessionId is marked non-null but is null");
        }
        if (statsStorage == null) {
            throw new NullPointerException("statsStorage is marked non-null but is null");
        }
        this.sessionId = sessionId;
        this.statsStorage = statsStorage;
    }

    /** Publishes the runner's configuration and candidate counts as static info. */
    public void onInitialization(IOptimizationRunner runner) {
        this.statsStorage.putStaticInfo(getNewStatusPersistable(runner));
    }

    /** No-op: nothing is published on shutdown. */
    public void onShutdown(IOptimizationRunner runner) {
    }

    /** Re-publishes the runner's configuration and current candidate counts as static info. */
    public void onRunnerStatusChange(IOptimizationRunner runner) {
        this.statsStorage.putStaticInfo(getNewStatusPersistable(runner));
    }

    /**
     * Publishes a status change for one candidate.
     *
     * <p>Reuses the last persistable published for the candidate. If there is none, a minimal one
     * is created from {@code candidateInfo}. A missing score is filled from {@code candidateInfo}.
     * A missing stack trace is filled from {@code result} when available. The status is always
     * updated.
     *
     * <p>NOTE(review): the cached persistable, which was previously handed to
     * {@code putUpdate}, is mutated in place. Confirm the storage implementation copies or
     * serializes on put.
     *
     * @param candidateInfo current candidate state
     * @param runner the optimization runner (unused here)
     * @param result optimization result; may be null
     */
    public void onCandidateStatusChange(CandidateInfo candidateInfo, IOptimizationRunner runner,
                                        OptimizationResult result) {
        int index = candidateInfo.getIndex();
        ModelInfoPersistable info = this.lastModelInfoPersistable.get(index);
        if (info == null) {
            info = new ModelInfoPersistable.Builder()
                    .timestamp(candidateInfo.getCreatedTime())
                    .sessionId(this.sessionId)
                    .workerId(String.valueOf(index))
                    .modelIdx(index)
                    .score(candidateInfo.getScore())
                    .status(candidateInfo.getCandidateStatus())
                    .exceptionStackTrace(candidateInfo.getExceptionStackTrace())
                    .build();
            this.lastModelInfoPersistable.put(index, info);
        }
        if (info.getScore() == null) {
            info.setScore(candidateInfo.getScore());
        }
        if (result != null && info.getExceptionStackTrace() == null
                && result.getCandidateInfo().getExceptionStackTrace() != null) {
            info.setExceptionStackTrace(result.getCandidateInfo().getExceptionStackTrace());
        }
        info.setStatus(candidateInfo.getCandidateStatus());
        this.statsStorage.putUpdate(info);
    }

    /**
     * Records the candidate's score at {@code iteration} and publishes a full model-info update.
     *
     * <p>The score history is subsampled. Only iterations that are multiples of the candidate's
     * current subsampling period are kept. Once {@code iteration / period} exceeds
     * {@link #MAX_SCORE_VS_ITER_PTS}, the period is doubled (repeatedly if needed) and the
     * retained history is filtered to the new period.
     *
     * @param candidateInfo current candidate state
     * @param candidate the model; {@link MultiLayerNetwork} and {@link ComputationGraph} are
     *     inspected, anything else is reported with zero score, zero sizes and empty config JSON
     * @param iteration the iteration number being reported
     */
    public void onCandidateIteration(CandidateInfo candidateInfo, Object candidate, int iteration) {
        double score;
        long numParams;
        int numLayers;
        int totalNumUpdates;
        String modelConfigJson;
        if (candidate instanceof MultiLayerNetwork) {
            MultiLayerNetwork net = (MultiLayerNetwork) candidate;
            score = net.score();
            numParams = net.numParams();
            numLayers = net.getnLayers();
            modelConfigJson = net.getLayerWiseConfigurations().toJson();
            totalNumUpdates = net.getLayerWiseConfigurations().getIterationCount();
        } else if (candidate instanceof ComputationGraph) {
            ComputationGraph graph = (ComputationGraph) candidate;
            score = graph.score();
            numParams = graph.numParams();
            numLayers = graph.getNumLayers();
            modelConfigJson = graph.getConfiguration().toJson();
            totalNumUpdates = graph.getConfiguration().getIterationCount();
        } else {
            score = 0.0d;
            numParams = 0;
            numLayers = 0;
            totalNumUpdates = 0;
            modelConfigJson = "";
        }

        int index = candidateInfo.getIndex();
        Pair<IntArrayList, FloatArrayList> history = this.candidateScoreVsIter.computeIfAbsent(
                index, k -> new Pair<>(new IntArrayList(), new FloatArrayList()));
        IntArrayList iterations = history.getFirst();
        FloatArrayList scores = history.getSecond();

        int period = this.candidateScoreVsIterSubsampleFreq.computeIfAbsent(index, k -> 1);
        if (iteration / period > MAX_SCORE_VS_ITER_PTS) {
            // Double until the cap holds again. A single doubling is not enough if the
            // iteration count jumped since the last report.
            while (iteration / period > MAX_SCORE_VS_ITER_PTS) {
                period *= 2;
            }
            this.candidateScoreVsIterSubsampleFreq.put(index, period);

            IntArrayList keptIterations = new IntArrayList();
            FloatArrayList keptScores = new FloatArrayList();
            for (int i = 0; i < iterations.size(); i++) {
                int iter = iterations.getInt(i);
                if (iter % period == 0) {
                    keptIterations.add(iter);
                    keptScores.add(scores.getFloat(i));
                }
            }
            iterations = keptIterations;
            scores = keptScores;
            this.candidateScoreVsIter.put(index, new Pair<>(iterations, scores));
        }
        if (iteration % period == 0) {
            iterations.add(iteration);
            scores.add((float) score);
        }

        ModelInfoPersistable update = new ModelInfoPersistable.Builder()
                .timestamp(candidateInfo.getCreatedTime())
                .sessionId(this.sessionId)
                .workerId(String.valueOf(index))
                .modelIdx(index)
                .score(candidateInfo.getScore())
                .status(candidateInfo.getCandidateStatus())
                .scoreVsIter(iterations.toIntArray(), scores.toFloatArray())
                .lastUpdateTime(System.currentTimeMillis())
                .numParameters(numParams)
                .numLayers(numLayers)
                .totalNumUpdates(totalNumUpdates)
                .paramSpaceValues(candidateInfo.getFlatParams())
                .modelConfigJson(modelConfigJson)
                .exceptionStackTrace(candidateInfo.getExceptionStackTrace())
                .build();
        this.lastModelInfoPersistable.put(index, update);
        this.statsStorage.putUpdate(update);
    }

    /**
     * Builds a snapshot of the runner's configuration (as JSON) and its candidate counts.
     *
     * @throws RuntimeException wrapping the {@link IOException} if the configuration cannot be
     *     serialized
     */
    private GlobalConfigPersistable getNewStatusPersistable(IOptimizationRunner runner) {
        try {
            String configJson = JsonMapper.getMapper().writeValueAsString(runner.getConfiguration());
            return new GlobalConfigPersistable.Builder()
                    .sessionId(this.sessionId)
                    .timestamp(System.currentTimeMillis())
                    .optimizationConfigJson(configJson)
                    .candidateCounts(runner.numCandidatesQueued(), runner.numCandidatesCompleted(),
                            runner.numCandidatesFailed(), runner.numCandidatesTotal())
                    .optimizationRunner(runner.getClass().getSimpleName())
                    .build();
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }
}
