package org.tribuo;

import com.oracle.labs.mlrg.olcut.util.MutableLong;
import com.oracle.labs.mlrg.olcut.util.MutableNumber;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.SplittableRandom;
import java.util.stream.Collectors;
import org.tribuo.util.Util;

/* loaded from: input_file:org/tribuo/CategoricalInfo.class */
/**
 * Tracks the distinct non-zero values observed for a feature together with how
 * often each one was seen.
 * <p>
 * While only a single distinct value has been observed it is held in
 * {@link #observedValue}/{@link #observedCount}; as soon as a second distinct
 * value arrives the state is moved into {@link #valueCounts}. Zero is never
 * recorded by {@link #observe(double)}; the sampling methods treat it as an
 * implicit value whose count is derived from a caller-supplied total.
 * <p>
 * The {@code values} and {@code cdf} arrays are transient caches that are
 * rebuilt lazily by the sampling methods.
 */
public class CategoricalInfo extends SkeletalVariableInfo {
    private static final long serialVersionUID = 2;

    /** Shared default returned by count lookups for values that were never observed. */
    private static final MutableLong ZERO = new MutableLong(0);

    /** Not referenced inside this class; exposed for use elsewhere. */
    public static final int THRESHOLD = 50;

    /** Tolerance used when comparing a value against the single observed value. */
    private static final double COMPARISON_THRESHOLD = 1.0E-10d;

    /** Per-value counts; {@code null} until a second distinct value is observed. */
    protected Map<Double, MutableLong> valueCounts = null;

    /** The single observed value, or NaN when none (or more than one) has been seen. */
    protected double observedValue = Double.NaN;

    /** Number of times {@link #observedValue} has been seen. */
    protected long observedCount = 0L;

    /** Cached array of sampleable values; {@code null} when it must be rebuilt. */
    protected transient double[] values = null;

    /** The total that the cached {@link #cdf} was built for; -1 when no CDF is cached. */
    protected transient long totalObservations = -1L;

    /** Cached cumulative distribution aligned with {@link #values}; {@code null} when stale. */
    protected transient double[] cdf = null;

    /**
     * Creates an empty info for the named feature.
     *
     * @param str the feature name
     */
    public CategoricalInfo(String str) {
        super(str);
    }

    /**
     * Copy constructor which keeps the source's name.
     * (The decompiler reports that the original access modifier was protected.)
     *
     * @param categoricalInfo the info to copy
     */
    public CategoricalInfo(CategoricalInfo categoricalInfo) {
        this(categoricalInfo, categoricalInfo.name);
    }

    /**
     * Copy constructor which takes the counts from {@code categoricalInfo} and
     * applies a new name. The sampling caches are not copied.
     * (The decompiler reports that the original access modifier was protected.)
     *
     * @param categoricalInfo the info to copy
     * @param str the name for the new info
     */
    public CategoricalInfo(CategoricalInfo categoricalInfo, String str) {
        super(str, categoricalInfo.count);
        if (categoricalInfo.valueCounts != null) {
            // Presumably copies the MutableLong values too, so the two infos do not share counters.
            this.valueCounts = MutableNumber.copyMap(categoricalInfo.valueCounts);
        } else {
            this.observedValue = categoricalInfo.observedValue;
            this.observedCount = categoricalInfo.observedCount;
        }
    }

    /**
     * Records one occurrence of {@code d}. Zero values are ignored entirely.
     * (The decompiler reports that the original access modifier was protected.)
     *
     * @param d the observed value
     */
    @Override // org.tribuo.SkeletalVariableInfo
    public void observe(double d) {
        if (d == 0.0d) {
            return;
        }
        super.observe(d);
        if (this.valueCounts != null) {
            this.valueCounts.computeIfAbsent(d, key -> new MutableLong()).increment();
        } else if (Double.isNaN(this.observedValue)) {
            // First non-zero value seen.
            this.observedValue = d;
            this.observedCount++;
        } else if (Math.abs(d - this.observedValue) < COMPARISON_THRESHOLD) {
            this.observedCount++;
        } else {
            // A second distinct value: switch from the single-value fields to the map.
            this.valueCounts = new HashMap<>(4);
            this.valueCounts.put(this.observedValue, new MutableLong(this.observedCount));
            this.valueCounts.put(d, new MutableLong(1L));
            this.observedValue = Double.NaN;
            this.observedCount = 0L;
        }
        // Both caches are derived from the counts, so both are stale now. Clearing only
        // 'values' would let frequencyBasedSample reuse an old CDF against a null or
        // differently sized values array.
        this.values = null;
        this.cdf = null;
    }

    /**
     * Returns how many times {@code d} has been observed.
     * <p>
     * In map mode the lookup is an exact key match; in single-value mode the
     * comparison uses {@link #COMPARISON_THRESHOLD}.
     *
     * @param d the value to look up
     * @return the observation count, or 0 if the value was not seen
     */
    public long getObservationCount(double d) {
        if (this.valueCounts != null) {
            return this.valueCounts.getOrDefault(d, ZERO).longValue();
        }
        if (Math.abs(d - this.observedValue) < COMPARISON_THRESHOLD) {
            return this.observedCount;
        }
        return 0L;
    }

    /**
     * Returns the number of distinct values recorded.
     *
     * @return the map size in map mode, otherwise 1 if a single value has been seen, else 0
     */
    public int getUniqueObservations() {
        if (this.valueCounts != null) {
            return this.valueCounts.size();
        }
        return Double.isNaN(this.observedValue) ? 0 : 1;
    }

    /**
     * Builds a {@link RealInfo} from the recorded counts.
     * <p>
     * In map mode the minimum, maximum, count-weighted mean (divided by
     * {@code count}) and the count-weighted sum of squared deviations from that
     * mean are computed over the recorded values. In single-value mode the
     * observed value is used for min, max and mean, with a zero sum of squares.
     *
     * @return a new RealInfo carrying this info's name and count
     */
    public RealInfo generateRealInfo() {
        double min = Double.POSITIVE_INFINITY;
        double max = Double.NEGATIVE_INFINITY;
        double mean;
        double sumSquares = 0.0d;
        if (this.valueCounts != null) {
            List<Map.Entry<Double, MutableLong>> entries = sortedEntries();
            double weightedSum = 0.0d;
            for (Map.Entry<Double, MutableLong> entry : entries) {
                double value = entry.getKey();
                double valueCount = entry.getValue().longValue();
                if (value > max) {
                    max = value;
                }
                if (value < min) {
                    min = value;
                }
                weightedSum += value * valueCount;
            }
            mean = weightedSum / this.count;
            // Second pass: the mean must be known before the deviations can be summed.
            for (Map.Entry<Double, MutableLong> entry : entries) {
                double value = entry.getKey();
                sumSquares += (value - mean) * (value - mean) * entry.getValue().longValue();
            }
        } else {
            min = this.observedValue;
            max = this.observedValue;
            mean = this.observedValue;
        }
        return new RealInfo(this.name, this.count, max, min, mean, sumSquares);
    }

    /**
     * Returns a copy of this info with the same name.
     *
     * @return the copy
     */
    @Override // org.tribuo.VariableInfo
    public CategoricalInfo copy() {
        return new CategoricalInfo(this);
    }

    /**
     * Wraps this info in a {@link CategoricalIDInfo} built with the supplied id.
     *
     * @param i the id
     * @return the id-carrying info
     */
    @Override // org.tribuo.VariableInfo
    public CategoricalIDInfo makeIDInfo(int i) {
        return new CategoricalIDInfo(this, i);
    }

    /**
     * Returns a copy of this info carrying a new name.
     *
     * @param str the new name
     * @return the renamed copy
     */
    @Override // org.tribuo.VariableInfo
    public CategoricalInfo rename(String str) {
        return new CategoricalInfo(this, str);
    }

    /**
     * Picks one of the cached values (zero plus every recorded value) with
     * equal probability, rebuilding the cache first if necessary.
     *
     * @param splittableRandom the source of randomness
     * @return the sampled value
     */
    @Override // org.tribuo.VariableInfo
    public synchronized double uniformSample(SplittableRandom splittableRandom) {
        if (this.values == null) {
            regenerateValues();
        }
        int index = splittableRandom.nextInt(this.values.length);
        return this.values[index];
    }

    /**
     * Samples a value using a CDF built from the recorded counts, where zero
     * receives whatever part of {@code j} is not accounted for by the recorded
     * non-zero counts.
     *
     * @param splittableRandom the source of randomness
     * @param j the total number of observations, including the implicit zeros
     * @return the sampled value
     */
    public double frequencyBasedSample(SplittableRandom splittableRandom, long j) {
        if (this.cdf == null || j != this.totalObservations) {
            regenerateCDF(j);
        }
        int index = Util.sampleFromCDF(this.cdf, splittableRandom);
        return this.values[index];
    }

    /**
     * Samples a value using a CDF built from the recorded counts, where zero
     * receives whatever part of {@code j} is not accounted for by the recorded
     * non-zero counts.
     *
     * @param random the source of randomness
     * @param j the total number of observations, including the implicit zeros
     * @return the sampled value
     */
    public double frequencyBasedSample(Random random, long j) {
        if (this.cdf == null || j != this.totalObservations) {
            regenerateCDF(j);
        }
        int index = Util.sampleFromCDF(this.cdf, random);
        return this.values[index];
    }

    /**
     * Rebuilds {@code values} and {@code cdf} for the supplied total.
     * Slot 0 always holds zero; the remaining slots hold the recorded non-zero
     * values in ascending order.
     *
     * @param j the total number of observations, including the implicit zeros
     * @throws IllegalStateException if the per-value counts do not sum to {@code j}
     */
    private synchronized void regenerateCDF(long j) {
        long[] counts;
        if (this.valueCounts != null) {
            // Zero needs a slot of its own unless it is (unexpectedly) already a key.
            int slots = this.valueCounts.containsKey(0.0d)
                    ? this.valueCounts.size()
                    : this.valueCounts.size() + 1;
            this.values = new double[slots];
            counts = new long[slots];
            this.values[0] = 0.0d;
            counts[0] = j;
            int next = 1;
            long nonZeroTotal = 0L;
            for (Map.Entry<Double, MutableLong> entry : sortedEntries()) {
                double value = entry.getKey();
                if (value != 0.0d) {
                    this.values[next] = value;
                    counts[next] = entry.getValue().longValue();
                    nonZeroTotal += counts[next];
                    next++;
                }
            }
            // Whatever is not a recorded non-zero observation is counted as zero.
            counts[0] -= nonZeroTotal;
        } else if (Double.isNaN(this.observedValue) || this.observedValue == 0.0d) {
            this.values = new double[]{0.0d};
            counts = new long[]{j};
        } else {
            this.values = new double[]{0.0d, this.observedValue};
            counts = new long[]{j - this.observedCount, this.observedCount};
        }
        long total = 0L;
        for (long c : counts) {
            total += c;
        }
        if (total != j) {
            throw new IllegalStateException("Total counts = " + total + ", supplied value = " + j);
        }
        this.cdf = Util.generateCDF(counts, total);
        this.totalObservations = j;
    }

    /**
     * Rebuilds the {@code values} cache used by {@link #uniformSample(SplittableRandom)}.
     * Zero is placed in slot 0 when it is not already a key of the map; the
     * recorded keys follow in their natural sorted order.
     */
    private synchronized void regenerateValues() {
        if (this.valueCounts == null) {
            if (Double.isNaN(this.observedValue) || this.observedValue == 0.0d) {
                this.values = new double[]{0.0d};
            } else {
                this.values = new double[]{0.0d, this.observedValue};
            }
            return;
        }
        int next;
        if (this.valueCounts.containsKey(0.0d)) {
            this.values = new double[this.valueCounts.size()];
            next = 0;
        } else {
            this.values = new double[this.valueCounts.size() + 1];
            this.values[0] = 0.0d;
            next = 1;
        }
        List<Double> sortedKeys = this.valueCounts.keySet().stream().sorted().collect(Collectors.toList());
        for (Double key : sortedKeys) {
            this.values[next] = key;
            next++;
        }
    }

    /**
     * Returns the entries of {@link #valueCounts} ordered by ascending key.
     * Must only be called when {@code valueCounts} is non-null.
     *
     * @return a new list of the map entries sorted by value
     */
    private List<Map.Entry<Double, MutableLong>> sortedEntries() {
        return this.valueCounts.entrySet().stream()
                .sorted(Comparator.comparingDouble((Map.Entry<Double, MutableLong> e) -> e.getKey()))
                .collect(Collectors.toList());
    }

    /**
     * Renders the name, count and either the value map or the single observed
     * value/count pair.
     *
     * @return the string form
     */
    @Override // org.tribuo.SkeletalVariableInfo
    public String toString() {
        if (this.valueCounts != null) {
            return "CategoricalFeature(name=" + this.name + ",count=" + this.count + ",map=" + this.valueCounts.toString() + ")";
        }
        return "CategoricalFeature(name=" + this.name + ",count=" + this.count + ",map={" + this.observedValue + "," + this.observedCount + "})";
    }

    /**
     * Restores the serialised fields and resets the transient sampling caches
     * so they are rebuilt on first use.
     *
     * @param objectInputStream the stream to read from
     * @throws IOException if the default read fails
     * @throws ClassNotFoundException if the default read fails to resolve a class
     */
    private void readObject(ObjectInputStream objectInputStream) throws IOException, ClassNotFoundException {
        objectInputStream.defaultReadObject();
        this.totalObservations = -1L;
        this.values = null;
        this.cdf = null;
    }
}
