package ai.libs.jaicore.search.algorithms.mdp.mcts;

import ai.libs.jaicore.graph.Graph;
import ai.libs.jaicore.search.algorithms.standard.bestfirst.events.RolloutEvent;
import com.google.common.eventbus.Subscribe;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.ToIntFunction;
import java.util.stream.Collectors;
import org.apache.commons.math3.stat.descriptive.DescriptiveStatistics;

/* loaded from: input_file:ai/libs/jaicore/search/algorithms/mdp/mcts/RolloutAnalyzer.class */
/**
 * Collects statistics about the rollouts of a Monte Carlo tree search.
 *
 * <p>Every received {@link RolloutEvent} contributes its path to an exploration graph and its score to the
 * statistics of every node on that path. The query methods then analyze the most visited path through that graph.
 *
 * <p>NOTE(review): none of the state is synchronized. If the event bus this instance is registered on dispatches
 * events concurrently, external synchronization is required.
 *
 * @param <N> node type of the search graph
 */
public class RolloutAnalyzer<N> {

    /** Union of all rollout paths received so far. */
    private final Graph<N> explorationGraph = new Graph<>();

    /** Position in the rollout path (root = 0) at which each node was first observed. */
    private final Map<N, Integer> depths = new HashMap<>();

    /** Scores of all rollouts that passed through each node; {@code getN()} is that node's visit count. */
    private final Map<N, DescriptiveStatistics> currentScoreOfNodes = new HashMap<>();

    /** Index (1-based) of the most recent rollout that passed through each node. */
    private final Map<N, Integer> iterationOfLastRollout = new HashMap<>();

    /**
     * Per parent: copies of its children's score statistics. A copy is taken each time a rollout passes through
     * the child that (after counting that rollout) has the fewest visits.
     */
    private final Map<N, Map<N, DescriptiveStatistics>> statsOfChildrenOfNodesAtTimeOfLastRolloutOfNodeWithLessRollouts = new HashMap<>();

    /** Per parent: index of the rollout at which the snapshot above was last taken. */
    private final Map<N, Integer> iterationOfDecision = new HashMap<>();

    /** Per parent: the children chosen by the rollouts passing through it, in order of arrival. */
    private final Map<N, List<N>> decisionLists = new HashMap<>();

    /** Number of rollouts received so far. */
    private int numRollouts = 0;

    /**
     * Records one rollout.
     *
     * <p>The path is added to the exploration graph. For each node on the path, this method then records its
     * first-seen depth, adds the rollout score to its statistics, and marks it as visited by this rollout. For every
     * parent/child edge on the path, it appends the child to the parent's decision list. If the child is (now) the
     * parent's least-visited child, it also snapshots the statistics of all of the parent's children.
     *
     * @param rolloutEvent the rollout, carrying its path from the root and its score
     */
    @Subscribe
    public void receiveRolloutEvent(final RolloutEvent<N, Double> rolloutEvent) {
        List<N> path = rolloutEvent.getPath();
        this.explorationGraph.addPath(path);
        this.numRollouts++;
        int depth = 0;
        N parent = null;
        for (N node : path) {
            this.depths.putIfAbsent(node, depth);
            this.currentScoreOfNodes.computeIfAbsent(node, k -> new DescriptiveStatistics()).addValue(rolloutEvent.getScore().doubleValue());
            this.iterationOfLastRollout.put(node, this.numRollouts);
            depth++;
            if (parent != null) {
                // the node's own stats already include this rollout, so the ordering reflects the updated counts
                if (this.getChildrenOfNodesInOrderOfTheNumberOfVisits(parent).get(0).equals(node)) {
                    Map<N, DescriptiveStatistics> snapshot = this.statsOfChildrenOfNodesAtTimeOfLastRolloutOfNodeWithLessRollouts.computeIfAbsent(parent, k -> new HashMap<>());
                    for (N sibling : this.explorationGraph.getSuccessors(parent)) {
                        snapshot.put(sibling, this.currentScoreOfNodes.get(sibling).copy());
                    }
                    this.iterationOfDecision.put(parent, this.numRollouts);
                }
                this.decisionLists.computeIfAbsent(parent, k -> new ArrayList<>()).add(node);
            }
            parent = node;
        }
    }

    /**
     * Follows, from the root, the child with the most recorded rollouts, for at most {@code i} steps.
     *
     * <p>Ties keep the first maximal child in the successor set's iteration order. If a node without children is
     * reached early, the path ends there. The result then has fewer than {@code i + 1} nodes, instead of being
     * padded with {@code null}.
     *
     * @param i maximum number of steps below the root
     * @return the root followed by up to {@code i} most-visited descendants; never contains {@code null} nodes
     *         (assuming {@code Graph.getRoot()} returns a node)
     */
    public List<N> getMostVisistedSubPath(final int i) {
        List<N> path = new ArrayList<>();
        N current = this.explorationGraph.getRoot();
        path.add(current);
        for (int step = 0; step < i; step++) {
            N mostVisitedChild = null;
            long maxVisits = 0;
            for (N child : this.explorationGraph.getSuccessors(current)) {
                long visits = this.visitsOf(child);
                if (visits > maxVisits) {
                    maxVisits = visits;
                    mostVisitedChild = child;
                }
            }
            if (mostVisitedChild == null) {
                break; // leaf reached: no deeper node to follow
            }
            current = mostVisitedChild;
            path.add(current);
        }
        return path;
    }

    /**
     * For each depth along the most visited path, reports the visit counts of the nodes at that depth. These are
     * the nodes below the path node that lies {@code i2} levels higher.
     *
     * @param i  maximum depth to report (inclusive); bounded by the length of the most visited path
     * @param i2 lookahead window; values below 0 are treated as 0
     * @return map from depth to the visit counts of the enumerated nodes at that depth
     */
    public Map<Integer, int[]> getVisitStatsOfMostVisitedChildrenPerDepth(final int i, final int i2) {
        return this.collectPerDepthAlongMostVisitedPath(i, i2, node -> (int) this.visitsOf(node));
    }

    /**
     * Returns the children of {@code n} sorted ascending by their number of recorded rollouts.
     * Ties are broken arbitrarily.
     *
     * @param n a node of the exploration graph
     * @return the children of {@code n}, least visited first
     */
    public List<N> getChildrenOfNodesInOrderOfTheNumberOfVisits(final N n) {
        return this.explorationGraph.getSuccessors(n).stream()
                .sorted((a, b) -> Long.compare(this.visitsOf(a), this.visitsOf(b)))
                .collect(Collectors.toList());
    }

    /**
     * For each depth along the most visited path, reports the index of the latest rollout that passed through each
     * node at that depth. These are the nodes below the path node that lies {@code i2} levels higher.
     *
     * @param i  maximum depth to report (inclusive); bounded by the length of the most visited path
     * @param i2 lookahead window; values below 0 are treated as 0
     * @return map from depth to the 1-based rollout indices of the enumerated nodes at that depth
     */
    public Map<Integer, int[]> getLatestRolloutAlongMostVisitedChildrenPerDepth(final int i, final int i2) {
        return this.collectPerDepthAlongMostVisitedPath(i, i2, node -> this.iterationOfLastRollout.get(node));
    }

    /**
     * For each node on the most visited path, returns the snapshot of its children's statistics taken the last time
     * a rollout went through its least-visited child.
     *
     * @param i maximum depth to report (inclusive); bounded by the length of the most visited path
     * @return map from depth to that snapshot; the value is {@code null} if no snapshot was ever taken for the node
     */
    public Map<Integer, Map<N, DescriptiveStatistics>> getChildrenStatisticsAtPointOfDecisionOfMostVisitedPathPerDepth(final int i) {
        List<N> path = this.getMostVisistedSubPath(i);
        Map<Integer, Map<N, DescriptiveStatistics>> result = new HashMap<>();
        for (int depth = 0; depth <= i && depth < path.size(); depth++) {
            result.put(depth, this.statsOfChildrenOfNodesAtTimeOfLastRolloutOfNodeWithLessRollouts.get(path.get(depth)));
        }
        return result;
    }

    /**
     * Enumerates all descendants of {@code n} exactly {@code i} levels below it, in successor-iteration order.
     *
     * @param n the start node
     * @param i number of levels to descend; 0 yields {@code [n]}, negative values yield an empty list
     * @return the nodes at distance {@code i} below {@code n}
     */
    public List<N> enumerateChildrenOfNodeUpToDepth(final N n, final int i) {
        if (i < 0) {
            return new ArrayList<>(); // no node lies a negative distance below n
        }
        if (i == 0) {
            return Arrays.asList(n);
        }
        // no presizing: the branching factor is unknown, and 2^i overflows for deep lookaheads
        List<N> descendants = new ArrayList<>();
        for (N child : this.explorationGraph.getSuccessors(n)) {
            descendants.addAll(this.enumerateChildrenOfNodeUpToDepth(child, i - 1));
        }
        return descendants;
    }

    /**
     * Groups the visit counts of all nodes of the exploration graph by the depth at which they were first observed.
     *
     * @return map from depth to the distribution of visit counts of the nodes at that depth
     */
    public Map<Integer, DescriptiveStatistics> getVisitStatsPerDepth() {
        Map<Integer, DescriptiveStatistics> statsPerDepth = new HashMap<>();
        for (N node : this.explorationGraph.getItems()) {
            statsPerDepth.computeIfAbsent(this.depths.get(node), d -> new DescriptiveStatistics()).addValue(this.visitsOf(node));
        }
        return statsPerDepth;
    }

    /**
     * Returns, per parent node, the index of the rollout at which its children's statistics were last snapshotted.
     *
     * @return the live internal map (changes as further rollouts arrive)
     */
    public Map<N, Integer> getIterationOfDecision() {
        return this.iterationOfDecision;
    }

    /**
     * Returns the children chosen by rollouts passing through {@code n}, in order of arrival.
     *
     * @param n a node of the exploration graph
     * @return the live internal list, or {@code null} if no rollout has descended below {@code n}
     */
    public List<N> getDecisionListForNode(final N n) {
        return this.decisionLists.get(n);
    }

    /** Number of rollouts recorded for {@code node}. */
    private long visitsOf(final N node) {
        return this.currentScoreOfNodes.get(node).getN();
    }

    /**
     * Shared driver for the per-depth reports along the most visited path. For each depth {@code d} (up to
     * {@code maxDepth} and the path length), it enumerates the nodes {@code min(d, lookahead)} levels below path node
     * {@code d - min(d, lookahead)} and maps each through {@code valueOfNode}.
     */
    private Map<Integer, int[]> collectPerDepthAlongMostVisitedPath(final int maxDepth, final int lookahead, final ToIntFunction<N> valueOfNode) {
        List<N> path = this.getMostVisistedSubPath(maxDepth);
        int window = Math.max(0, lookahead);
        Map<Integer, int[]> result = new HashMap<>();
        for (int depth = 0; depth <= maxDepth && depth < path.size(); depth++) {
            int offset = Math.min(depth, window);
            List<N> nodesAtDepth = this.enumerateChildrenOfNodeUpToDepth(path.get(depth - offset), offset);
            int[] values = new int[nodesAtDepth.size()];
            for (int k = 0; k < values.length; k++) {
                values[k] = valueOfNode.applyAsInt(nodesAtDepth.get(k));
            }
            result.put(depth, values);
        }
        return result;
    }
}
