package tl.lin.data.cfd;

import it.unimi.dsi.fastutil.objects.Object2ObjectMap;
import it.unimi.dsi.fastutil.objects.Object2ObjectOpenHashMap;
import it.unimi.dsi.fastutil.objects.ObjectIterator;
import java.lang.Comparable;
import java.util.Iterator;
import tl.lin.data.fd.Object2IntFrequencyDistribution;
import tl.lin.data.fd.Object2IntFrequencyDistributionFastutil;
import tl.lin.data.fd.Object2LongFrequencyDistribution;
import tl.lin.data.fd.Object2LongFrequencyDistributionFastutil;
import tl.lin.data.pair.PairOfObjectInt;

/* loaded from: input_file:tl/lin/data/cfd/Object2IntConditionalFrequencyDistributionFastutil.class */
/**
 * Conditional frequency distribution with {@code int} counts, backed by a fastutil hash map that
 * holds one {@link Object2IntFrequencyDistribution} per conditioning value.
 *
 * <p>Alongside the per-condition distributions this class maintains a marginal count per event
 * (accumulated across every conditioning value) and the grand total of all counts; both are
 * updated on every call to {@link #set(Comparable, Comparable, int)}.
 *
 * @param <K> type of both the events and the conditioning values
 */
public class Object2IntConditionalFrequencyDistributionFastutil<K extends Comparable<K>> implements Object2IntConditionalFrequencyDistribution<K> {
    /** Per-condition distributions, keyed by conditioning value. No null values are ever stored. */
    private final Object2ObjectMap<K, Object2IntFrequencyDistribution<K>> distributions =
            new Object2ObjectOpenHashMap<K, Object2IntFrequencyDistribution<K>>();

    /** Count of each event accumulated across all conditioning values. */
    private final Object2LongFrequencyDistribution<K> marginals =
            new Object2LongFrequencyDistributionFastutil<K>();

    /** Running total of every count stored in {@link #distributions}. */
    private long sumOfAllCounts = 0;

    /**
     * Sets the count of event {@code k} under conditioning value {@code k2} to {@code i},
     * adjusting the marginal count of {@code k} and the grand total by the difference between
     * the new and the previously stored count.
     *
     * @param k the event
     * @param k2 the conditioning value
     * @param i the new count
     */
    public void set(K k, K k2, int i) {
        Object2IntFrequencyDistribution<K> conditional = this.distributions.get(k2);
        int previous = 0;
        if (conditional == null) {
            // First observation under this condition: register an empty distribution.
            conditional = new Object2IntFrequencyDistributionFastutil<K>();
            this.distributions.put(k2, conditional);
        } else {
            previous = conditional.get(k);
        }
        conditional.set(k, i);

        // Widen before subtracting so the delta cannot wrap around in int arithmetic.
        long delta = (long) i - previous;
        this.marginals.increment(k, delta);
        this.sumOfAllCounts += delta;
    }

    /**
     * Adds one to the count of event {@code k} under conditioning value {@code k2}.
     *
     * @param k the event
     * @param k2 the conditioning value
     */
    public void increment(K k, K k2) {
        increment(k, k2, 1);
    }

    /**
     * Adds {@code i} to the count of event {@code k} under conditioning value {@code k2}.
     *
     * @param k the event
     * @param k2 the conditioning value
     * @param i the amount to add
     */
    public void increment(K k, K k2, int i) {
        // get() yields 0 for an unseen condition, so a single expression covers both cases.
        set(k, k2, get(k, k2) + i);
    }

    /**
     * Returns the count of event {@code k} under conditioning value {@code k2}.
     *
     * @param k the event
     * @param k2 the conditioning value
     * @return the value reported by the distribution stored for {@code k2}, or 0 when no
     *         distribution has been stored for {@code k2}
     */
    public int get(K k, K k2) {
        Object2IntFrequencyDistribution<K> conditional = this.distributions.get(k2);
        return conditional == null ? 0 : conditional.get(k);
    }

    /**
     * Returns the marginal count of event {@code k}, i.e. its count accumulated over all
     * conditioning values.
     *
     * @param k the event
     * @return the marginal count recorded for {@code k}
     */
    public long getMarginalCount(K k) {
        return this.marginals.get(k);
    }

    /**
     * Returns the distribution of events under conditioning value {@code k}.
     *
     * @param k the conditioning value
     * @return the stored distribution itself (not a copy) when one exists; otherwise a new, empty
     *         distribution that is not registered with this object
     */
    public Object2IntFrequencyDistribution<K> getConditionalDistribution(K k) {
        Object2IntFrequencyDistribution<K> conditional = this.distributions.get(k);
        return conditional != null ? conditional : new Object2IntFrequencyDistributionFastutil<K>();
    }

    /**
     * Returns the sum of all counts across every event and conditioning value.
     *
     * @return the running grand total
     */
    public long getSumOfAllCounts() {
        return this.sumOfAllCounts;
    }

    /**
     * Verifies the internal bookkeeping by recomputing the totals from the stored distributions.
     *
     * @throws RuntimeException if a conditional distribution's entries do not add up to its own
     *         reported sum, if the recomputed grand total differs from
     *         {@link #getSumOfAllCounts()}, or if a recomputed per-event total differs from the
     *         stored marginal count
     */
    public void check() {
        // Per-event totals recomputed from scratch, independent of the marginals field.
        Object2IntFrequencyDistributionFastutil<K> recomputed = new Object2IntFrequencyDistributionFastutil<K>();
        long grandTotal = 0;

        for (Object2IntFrequencyDistribution<K> conditional : this.distributions.values()) {
            long conditionalTotal = 0;
            for (PairOfObjectInt<K> entry : conditional) {
                conditionalTotal += entry.getRightElement();
                recomputed.increment(entry.getLeftElement(), entry.getRightElement());
            }
            if (conditionalTotal != conditional.getSumOfCounts()) {
                throw new RuntimeException("Internal Error!");
            }
            grandTotal += conditional.getSumOfCounts();
        }

        if (grandTotal != getSumOfAllCounts()) {
            throw new RuntimeException("Internal Error! Got " + grandTotal + ", Expected " + getSumOfAllCounts());
        }

        // Every recomputed per-event total must match the incrementally maintained marginal.
        for (PairOfObjectInt<K> entry : recomputed) {
            if (entry.getRightElement() != this.marginals.get(entry.getLeftElement())) {
                throw new RuntimeException("Internal Error!");
            }
        }

        // NOTE(review): this pass compares the recomputed tally with itself, so it only verifies
        // that iteration and lookup agree. It may have been intended to cross-check the marginals
        // in the reverse direction - confirm before changing.
        for (PairOfObjectInt<K> entry : recomputed) {
            if (entry.getRightElement() != recomputed.get(entry.getLeftElement())) {
                throw new RuntimeException("Internal Error!");
            }
        }
    }
}
