package org.mitre.caasd.commons.collect;

import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.TreeMap;
import org.mitre.caasd.commons.Pair;

/* loaded from: input_file:org/mitre/caasd/commons/collect/MetricSet.class */
/**
 * A set of keys organized into a binary tree of "spheres" under a {@link DistanceMetric}.
 * It supports k-nearest-neighbor queries ({@link #getNClosest}) and range queries
 * ({@link #getAllWithinRange}).
 *
 * <p>Every key is also indexed in a hash map from key to the leaf sphere that stores it. This
 * makes {@link #contains}, {@link #size} and {@link #remove} hash lookups.
 *
 * <p>This class performs no synchronization. Presumably callers must confine an instance to one
 * thread or lock externally.
 *
 * @param <K> the key type; distances between keys are measured by the supplied metric
 */
public class MetricSet<K> implements Serializable {

    private static final long serialVersionUID = 1;

    /** Sphere capacity used by the single-argument constructor. */
    private static final int DEFAULT_SPHERE_SIZE = 50;

    /** Chooses the two new center points whenever a full sphere of points is split. */
    private final CenterPointSelector<K> centerPointSelector;

    /** The distance function that organizes every sphere in this set. */
    private final DistanceMetric<K> metric;

    /** Root of the sphere tree; null until the first add, and again after {@link #clear()}. */
    private MetricSet<K>.Sphere rootSphere;

    /** Number of Sphere objects created for this set since construction or the last clear(). */
    private int sphereCount = 0;

    /**
     * Maximum number of keys a sphere of points holds before the next add splits it.
     *
     * <p>The constant-style name is kept on purpose. This class is Serializable, and renaming the
     * field would break deserialization of previously written instances.
     */
    private final int MAX_INNER_SPHERE_SIZE;

    /** Maps every key in the set to the leaf sphere (sphere of points) that currently stores it. */
    private HashMap<K, MetricSet<K>.Sphere> globalHashMap = new HashMap<>();

    /**
     * One node of the sphere tree.
     *
     * <p>A sphere starts as a "sphere of points" that stores keys directly. Once it holds
     * {@code MAX_INNER_SPHERE_SIZE} keys, the next add splits it into a "sphere of spheres" with
     * exactly two children, and its keys are redistributed into those children.
     *
     * <p>This is a non-static inner class. Every Sphere reads and updates the state
     * (sphere count, key index, capacity) of the MetricSet instance that created it.
     */
    public class Sphere implements Serializable {

        /** The point this sphere is centered on; fixed when the sphere is created. */
        final K centerPoint;

        /** Largest distance from centerPoint to any key routed through this sphere; never shrinks. */
        private double radius;

        private SphereType type = SphereType.SPHERE_OF_POINTS;

        /** Keys stored directly in this sphere; set to null once the sphere is split. */
        private Set<K> entries = new HashSet<>();

        /** The two children of a split sphere; null while this is still a sphere of points. */
        private Pair<MetricSet<K>.Sphere, MetricSet<K>.Sphere> childSpheres = null;

        Sphere(K centerPoint) {
            this.centerPoint = centerPoint;
            MetricSet.this.sphereCount++;
        }

        /** Returns the largest distance between centerPoint and any key added through this sphere. */
        public double radius() {
            return this.radius;
        }

        /** Returns true while this sphere stores keys directly, i.e. it has not been split. */
        public boolean isSphereOfPoints() {
            return this.type == SphereType.SPHERE_OF_POINTS;
        }

        /** Returns true once this sphere has been split into two child spheres. */
        boolean isSphereOfSpheres() {
            return this.type == SphereType.SPHERE_OF_SPHERES;
        }

        /**
         * Returns the internal (live, mutable) set of keys stored directly in this sphere. Returns
         * null once the sphere has been split.
         */
        public Set<K> points() {
            return this.entries;
        }

        /** Returns the two child spheres, or null while this is still a sphere of points. */
        public Pair<MetricSet<K>.Sphere, MetricSet<K>.Sphere> children() {
            return this.childSpheres;
        }

        /**
         * Adds a key to this sphere or its subtree. A full sphere of points is split first.
         *
         * @param k the key to add
         * @return the result of adding k to the leaf sphere's entry set
         */
        boolean add(K k) {
            if (isFull()) {
                split();
            }
            // The radius covers every key routed through this sphere, including keys that end up
            // stored in a descendant, so it bounds the whole subtree.
            this.radius = Math.max(this.radius, MetricSet.this.verifiedDistance(this.centerPoint, k));
            if (isSphereOfPoints()) {
                MetricSet.this.globalHashMap.put(k, this);
                return this.entries.add(k);
            }
            if (isSphereOfSpheres()) {
                return findClosestChildSphere(k).add(k);
            }
            throw new AssertionError("Should never get here, all SphereTypes covered");
        }

        /** Converts this sphere of points into a sphere of spheres with two new children. */
        private void split() {
            Pair<MetricSet<K>.Sphere, MetricSet<K>.Sphere> children = splitSphereOfPoints();
            this.type = SphereType.SPHERE_OF_SPHERES;
            this.entries = null;
            this.childSpheres = children;
        }

        /**
         * Removes a key stored directly in this sphere.
         *
         * @throws AssertionError if called on a sphere of spheres
         */
        boolean remove(K k) {
            if (isSphereOfPoints()) {
                return this.entries.remove(k);
            }
            throw new AssertionError("Should never get here.  This should only be called on \"Sphere of Points\"");
        }

        /** Returns true when this is a sphere of points that has reached its capacity. */
        private boolean isFull() {
            return this.type == SphereType.SPHERE_OF_POINTS
                && this.entries.size() >= MetricSet.this.MAX_INNER_SPHERE_SIZE;
        }

        /** Returns the child whose center is strictly closer to k; ties go to the second child. */
        private MetricSet<K>.Sphere findClosestChildSphere(K k) {
            MetricSet<K>.Sphere first = this.childSpheres.first();
            MetricSet<K>.Sphere second = this.childSpheres.second();
            return MetricSet.this.verifiedDistance(k, first.centerPoint)
                    < MetricSet.this.verifiedDistance(k, second.centerPoint)
                ? first
                : second;
        }

        /** Creates the two child spheres and moves every key of this sphere into one of them. */
        private Pair<MetricSet<K>.Sphere, MetricSet<K>.Sphere> splitSphereOfPoints() {
            Preconditions.checkState(this.type == SphereType.SPHERE_OF_POINTS, "Only SPHERE_OF_POINTS should be split");
            Pair<K, K> centers = pickCentersForNewSpheres();
            MetricSet<K>.Sphere first = new Sphere(centers.first());
            MetricSet<K>.Sphere second = new Sphere(centers.second());
            moveEntriesToChildren(first, second);
            return new Pair<>(first, second);
        }

        /** Delegates the choice of the two new center points to the configured selector. */
        private Pair<K, K> pickCentersForNewSpheres() {
            return MetricSet.this.centerPointSelector.selectNewCenterPoints(Lists.newArrayList(this.entries), MetricSet.this.metric);
        }

        /**
         * Routes every key of this sphere into the nearer child. The tie-break flag flips after
         * every key (not only after ties), so equidistant keys are spread between both children.
         */
        private void moveEntriesToChildren(MetricSet<K>.Sphere first, MetricSet<K>.Sphere second) {
            boolean preferFirstOnTie = false;
            for (K entry : this.entries) {
                addToBestOf(first, second, entry, preferFirstOnTie);
                preferFirstOnTie = !preferFirstOnTie;
            }
        }

        /** Adds k to whichever sphere's center is nearer, using preferFirstOnTie to break ties. */
        private void addToBestOf(MetricSet<K>.Sphere first, MetricSet<K>.Sphere second, K k, boolean preferFirstOnTie) {
            double toFirst = MetricSet.this.verifiedDistance(k, first.centerPoint);
            double toSecond = MetricSet.this.verifiedDistance(k, second.centerPoint);
            MetricSet<K>.Sphere target;
            if (toFirst == toSecond) {
                target = preferFirstOnTie ? first : second;
            } else {
                target = toFirst < toSecond ? first : second;
            }
            target.add(k);
        }

        /**
         * Returns a new set holding every key stored in this sphere's subtree. The caller owns the
         * returned set; modifying it does not affect the sphere.
         */
        Set<K> entries() {
            if (isSphereOfPoints()) {
                // Copy the field. Calling entries() here would recurse forever.
                return new HashSet<>(this.entries);
            }
            if (!isSphereOfSpheres()) {
                throw new AssertionError("Should never get here, all SphereTypes covered");
            }
            Set<K> all = new HashSet<>(this.childSpheres.first().entries());
            all.addAll(this.childSpheres.second().entries());
            return all;
        }
    }

    /** The two states a {@link Sphere} can be in. */
    public enum SphereType {
        /** The sphere stores keys directly. */
        SPHERE_OF_POINTS,
        /** The sphere has been split and delegates to exactly two child spheres. */
        SPHERE_OF_SPHERES
    }

    /** Creates an empty set with the default sphere size (50) and center-point selector. */
    public MetricSet(DistanceMetric<K> distanceMetric) {
        this(distanceMetric, DEFAULT_SPHERE_SIZE);
    }

    /** Creates an empty set with the given sphere size and the default center-point selector. */
    public MetricSet(DistanceMetric<K> distanceMetric, int maxSphereSize) {
        this(distanceMetric, maxSphereSize, CenterPointSelectors.maxOfRandomSamples());
    }

    /**
     * Creates an empty set.
     *
     * @param distanceMetric      the metric that organizes the spheres; must not be null
     * @param maxSphereSize       keys a sphere of points may hold before it is split; at least 4
     * @param centerPointSelector picks the new center points when a sphere is split; not null
     * @throws NullPointerException     if distanceMetric or centerPointSelector is null
     * @throws IllegalArgumentException if maxSphereSize is less than 4
     */
    public MetricSet(DistanceMetric<K> distanceMetric, int maxSphereSize, CenterPointSelector<K> centerPointSelector) {
        Preconditions.checkNotNull(distanceMetric, "The input DistanceMetric cannot be null");
        Preconditions.checkArgument(maxSphereSize >= 4, "The maxSphereSize must be at least 4, it was: %s", maxSphereSize);
        Preconditions.checkNotNull(centerPointSelector, "The CenterPointSelector cannot be null");
        this.metric = distanceMetric;
        this.centerPointSelector = centerPointSelector;
        this.MAX_INNER_SPHERE_SIZE = maxSphereSize;
    }

    /** Returns the metric this set was built with. */
    public final DistanceMetric<K> metric() {
        return this.metric;
    }

    /**
     * Adds a key to the set.
     *
     * @return true if the key was not already present
     * @throws NullPointerException if k is null
     */
    public boolean add(K k) {
        Preconditions.checkNotNull(k);
        if (this.rootSphere == null) {
            this.rootSphere = new Sphere(k);
        }
        if (this.globalHashMap.containsKey(k)) {
            return false;
        }
        this.rootSphere.add(k);
        return true;
    }

    /**
     * Adds every key in the collection.
     *
     * @return true if at least one key was newly added
     */
    public boolean addAll(Collection<K> collection) {
        boolean changed = false;
        for (K key : collection) {
            changed |= add(key);
        }
        return changed;
    }

    /** Returns the number of keys in the set. */
    public int size() {
        return this.globalHashMap.size();
    }

    /** Returns true when the set holds no keys. */
    public boolean isEmpty() {
        return this.globalHashMap.isEmpty();
    }

    /** Returns true when k is in the set (hash lookup, no distance computations). */
    public boolean contains(K k) {
        return this.globalHashMap.containsKey(k);
    }

    /**
     * Returns the key closest to k. If k itself is in the set, it is returned with distance 0.
     *
     * @throws IllegalArgumentException  if k is null
     * @throws IndexOutOfBoundsException if the set is empty (no closest key exists)
     */
    public SetSearchResult<K> getClosest(K k) {
        if (this.globalHashMap.containsKey(k)) {
            return new SetSearchResult<>(k, 0.0d);
        }
        return getNClosest(k, 1).get(0);
    }

    /**
     * Returns up to n keys closest to searchKey. The list is sorted by SetSearchResult's natural
     * ordering (presumably ascending distance — confirm against SetSearchResult).
     *
     * @return a new mutable list, or an immutable empty list if the set is empty
     * @throws IllegalArgumentException if searchKey is null or n is less than 1
     */
    public List<SetSearchResult<K>> getNClosest(K searchKey, int n) {
        Preconditions.checkArgument(Objects.nonNull(searchKey));
        if (n < 1) {
            throw new IllegalArgumentException("n must be at least 1");
        }
        if (isEmpty()) {
            return Collections.emptyList();
        }
        return runSearch(new SetSearch(searchKey, n, this.metric));
    }

    /**
     * Returns every key the search finds within the given range of searchKey. Results are sorted
     * by SetSearchResult's natural ordering.
     *
     * @return a new mutable list, or an immutable empty list if the set is empty
     * @throws IllegalArgumentException if searchKey is null, or range is not strictly positive
     *                                  (NaN included)
     */
    public List<SetSearchResult<K>> getAllWithinRange(K searchKey, double range) {
        Preconditions.checkArgument(Objects.nonNull(searchKey));
        // Written as !(range > 0) so that NaN is rejected too (NaN <= 0 is false).
        if (!(range > 0.0d)) {
            throw new IllegalArgumentException("The range must be strictly positive " + range);
        }
        if (isEmpty()) {
            return Collections.emptyList();
        }
        return runSearch(new SetSearch(searchKey, this.metric, range));
    }

    /**
     * Runs a prepared search from the root sphere and returns its results in a new list sorted
     * by natural ordering. SetSearch is used as a raw type here, as in the callers, so the
     * results are converted unchecked.
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    private List<SetSearchResult<K>> runSearch(SetSearch search) {
        search.startQuery(this.rootSphere);
        List<SetSearchResult<K>> results = new ArrayList<>(search.results());
        results.sort(null);
        return results;
    }

    /**
     * Removes a key from the set. Sphere radii and the tree shape are not shrunk.
     *
     * @return true if the key was present
     * @throws IllegalArgumentException if k is null
     */
    public boolean remove(K k) {
        Preconditions.checkArgument(Objects.nonNull(k));
        MetricSet<K>.Sphere owner = this.globalHashMap.remove(k);
        if (owner == null) {
            return false;
        }
        if (owner.remove(k)) {
            return true;
        }
        throw new AssertionError("Unexpected state, hadImpact should always be true here becuase the key was found in the global map");
    }

    /** Removes every key and discards the sphere tree. */
    public void clear() {
        this.rootSphere = null;
        this.globalHashMap = new HashMap<>();
        this.sphereCount = 0;
    }

    /**
     * Returns a read-only view of the keys. Mutation is blocked because removing through the
     * index alone would leave the key stored in its sphere. The view reflects later add/remove
     * calls until clear() or rebalance() replaces the backing index.
     */
    public Set<K> keySet() {
        return Collections.unmodifiableSet(this.globalHashMap.keySet());
    }

    /** Returns the number of spheres created since construction or the last clear(). */
    public int sphereCount() {
        return this.sphereCount;
    }

    /**
     * Returns a new MetricSet with the same keys, metric, sphere size and center-point selector.
     * The keys are inserted in shuffled order, which tends to produce a more balanced tree.
     */
    public MetricSet<K> makeBalancedCopy() {
        List<K> shuffledKeys = new ArrayList<>(this.globalHashMap.keySet());
        Collections.shuffle(shuffledKeys);
        MetricSet<K> copy = new MetricSet<>(this.metric, this.MAX_INNER_SPHERE_SIZE, this.centerPointSelector);
        copy.addAll(shuffledKeys);
        if (size() != copy.size()) {
            throw new AssertionError("The rebalancing process changed the number of entries");
        }
        return copy;
    }

    /**
     * Rebuilds this set's sphere tree by re-inserting every key in shuffled order.
     *
     * <p>The rebuild happens in place rather than by adopting a copy's spheres. Sphere is an
     * inner class, so spheres built by another MetricSet would keep updating that other
     * instance's sphere count and using its capacity. If the rebuild throws, the previous tree
     * and index are restored.
     */
    public void rebalance() {
        List<K> shuffledKeys = new ArrayList<>(this.globalHashMap.keySet());
        Collections.shuffle(shuffledKeys);

        MetricSet<K>.Sphere previousRoot = this.rootSphere;
        HashMap<K, MetricSet<K>.Sphere> previousIndex = this.globalHashMap;
        int previousSphereCount = this.sphereCount;

        clear();
        boolean rebuilt = false;
        try {
            addAll(shuffledKeys);
            if (size() != previousIndex.size()) {
                throw new AssertionError("The rebalancing process changed the number of entries");
            }
            rebuilt = true;
        } finally {
            if (!rebuilt) {
                // clear() swapped in a fresh index, so the old tree and index are untouched.
                this.rootSphere = previousRoot;
                this.globalHashMap = previousIndex;
                this.sphereCount = previousSphereCount;
            }
        }
    }

    /**
     * Measures the distance between two keys and rejects NaN or negative results.
     *
     * @throws IllegalStateException if the metric returns NaN or a negative value
     */
    public double verifiedDistance(K k, K k2) {
        double distance = this.metric.distanceBtw(k, k2);
        Preconditions.checkState(!Double.isNaN(distance), "A distance measurement was NaN.");
        Preconditions.checkState(distance >= 0.0d, "A negative distance measurement was observed.");
        return distance;
    }
}
