package org.neo4j.gds.decisiontree;

import org.neo4j.gds.core.utils.paged.HugeIntArray;
import org.neo4j.gds.core.utils.paged.HugeLongArray;
import org.neo4j.gds.ml.core.subgraph.LocalIdMap;

/**
 * Gini impurity loss used to score decision tree splits.
 *
 * <p>For a group of {@code n} samples with per-class counts {@code c_1..c_k}, the group loss is
 * {@code n * (1 - sum(c_i^2) / n^2) = n - sum(c_i^2) / n}, i.e. the group's Gini impurity
 * weighted by its size. The loss of a split is the sum of the left and right group losses divided
 * by the combined size of both groups, i.e. the size-weighted average Gini impurity.
 */
public class GiniIndex implements DecisionTreeLoss {

    /** Mapped class label of each sample, indexed by the sample ids stored in the groups. */
    private final HugeIntArray expectedMappedLabels;

    /** Number of classes; every mapped label must lie in {@code [0, numberOfClasses)}. */
    private final int numberOfClasses;

    /**
     * Creates a Gini loss over already-mapped labels.
     *
     * @param expectedMappedLabels mapped class label of each sample, indexed by sample id
     * @param numberOfClasses      number of classes; labels outside {@code [0, numberOfClasses)}
     *                             cause an {@link ArrayIndexOutOfBoundsException} during loss
     *                             computation
     */
    public GiniIndex(HugeIntArray expectedMappedLabels, int numberOfClasses) {
        this.expectedMappedLabels = expectedMappedLabels;
        this.numberOfClasses = numberOfClasses;
    }

    /**
     * Builds a Gini loss from original labels by translating each one through {@code localIdMap}.
     *
     * @param originalLabels original class label of each sample; must be non-empty (only checked
     *                       when assertions are enabled)
     * @param localIdMap     translation from original to mapped labels; its size is used as the
     *                       number of classes
     * @return a new {@code GiniIndex} over the mapped labels
     */
    public static GiniIndex fromOriginalLabels(HugeLongArray originalLabels, LocalIdMap localIdMap) {
        assert originalLabels.size() > 0;

        HugeIntArray mappedLabels = HugeIntArray.newArray(originalLabels.size());
        mappedLabels.setAll(idx -> localIdMap.toMapped(originalLabels.get(idx)));
        return new GiniIndex(mappedLabels, localIdMap.size());
    }

    /**
     * Computes the size-weighted average Gini impurity of a two-way split.
     *
     * @param groups     sample ids of the left and right groups
     * @param groupSizes number of valid leading entries in each group array
     * @return the weighted impurity of the split
     * @throws IllegalStateException if both groups are empty
     */
    @Override
    public double splitLoss(Groups groups, GroupSizes groupSizes) {
        long totalSize = groupSizes.left() + groupSizes.right();
        if (totalSize == 0) {
            throw new IllegalStateException("Cannot compute loss over only empty groups");
        }

        double leftLoss = computeGroupLoss(groups.left(), groupSizes.left());
        double rightLoss = computeGroupLoss(groups.right(), groupSizes.right());
        return (leftLoss + rightLoss) / totalSize;
    }

    /**
     * Computes {@code groupSize * giniImpurity} for the first {@code groupSize} samples of a group.
     *
     * @param group     sample ids; only the first {@code groupSize} entries are read
     * @param groupSize number of samples in the group
     * @return the size-weighted Gini impurity, or {@code 0} for an empty group
     */
    private double computeGroupLoss(HugeLongArray group, long groupSize) {
        assert group.size() >= groupSize;

        if (groupSize == 0) {
            return 0.0;
        }

        long[] countPerClass = new long[numberOfClasses];
        for (long i = 0; i < groupSize; i++) {
            countPerClass[expectedMappedLabels.get(group.get(i))]++;
        }

        // Accumulate in double: count * count would overflow long for classes above ~3e9 members.
        double sumOfSquares = 0.0;
        for (long count : countPerClass) {
            sumOfSquares += (double) count * count;
        }

        // Floating-point division is required; long division would truncate the impurity term.
        return groupSize - sumOfSquares / groupSize;
    }
}
