package org.grouplens.lenskit.vectors;

import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
import it.unimi.dsi.fastutil.Swapper;
import it.unimi.dsi.fastutil.doubles.DoubleArrayList;
import it.unimi.dsi.fastutil.doubles.DoubleArrays;
import it.unimi.dsi.fastutil.ints.AbstractIntComparator;
import it.unimi.dsi.fastutil.longs.Long2DoubleMap;
import it.unimi.dsi.fastutil.longs.LongArrayList;
import it.unimi.dsi.fastutil.longs.LongArraySet;
import it.unimi.dsi.fastutil.longs.LongComparators;
import it.unimi.dsi.fastutil.longs.LongOpenHashSet;
import it.unimi.dsi.fastutil.longs.LongSet;
import it.unimi.dsi.fastutil.objects.Reference2ObjectArrayMap;
import java.io.Serializable;
import java.util.Arrays;
import java.util.BitSet;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import org.grouplens.lenskit.collections.BitSetIterator;
import org.grouplens.lenskit.collections.LongSortedArraySet;
import org.grouplens.lenskit.collections.MoreArrays;
import org.grouplens.lenskit.symbols.Symbol;
import org.grouplens.lenskit.symbols.TypedSymbol;

/**
 * A sparse vector over a fixed key domain whose values can be set, unset and modified in place.
 *
 * <p>Each key of the domain is either <em>set</em> (it has a value) or <em>unset</em>; the
 * {@code usedKeys} bit set tracks which. Besides its own values, the vector carries side channels:
 * untyped channels keyed by {@link Symbol} (stored as {@code MutableSparseVector}s) and typed
 * channels keyed by {@link TypedSymbol} (stored as {@link TypedSideChannel}s).
 *
 * <p>{@link #freeze()} (equivalently {@code immutable(true)}) may hand this vector's storage over
 * to the resulting {@link ImmutableSparseVector}; afterwards the value array is released and every
 * operation guarded by {@code checkFrozen()} throws {@link IllegalStateException}.
 */
public final class MutableSparseVector extends SparseVector implements Serializable {
    private static final long serialVersionUID = 2;

    /** Untyped side channels (fastutil reference map, i.e. keyed by symbol identity). */
    @SuppressFBWarnings(value = {"SE_BAD_FIELD"}, justification = "stored value is always serializable")
    private final Map<Symbol, MutableSparseVector> channelMap;

    /** Typed side channels (fastutil reference map, i.e. keyed by symbol identity). */
    @SuppressFBWarnings(value = {"SE_BAD_FIELD"}, justification = "stored value is always serializable")
    private final Map<TypedSymbol<?>, TypedSideChannel<?>> typedChannelMap;

    /** Orders positions of a key array by the keys stored at those positions. */
    private static class IdComparator extends AbstractIntComparator {
        private final long[] keys;

        public IdComparator(long[] keys) {
            this.keys = keys;
        }

        @Override
        public int compare(int a, int b) {
            return Long.compare(this.keys[a], this.keys[b]);
        }
    }

    /** Swaps positions in a key array and a value array together, keeping them aligned. */
    private static class ParallelSwapper implements Swapper {
        private final long[] keys;
        private final double[] values;

        public ParallelSwapper(long[] keys, double[] values) {
            this.keys = keys;
            this.values = values;
        }

        @Override
        public void swap(int a, int b) {
            long tmpKey = this.keys[a];
            this.keys[a] = this.keys[b];
            this.keys[b] = tmpKey;
            double tmpValue = this.values[a];
            this.values[a] = this.values[b];
            this.values[b] = tmpValue;
        }
    }

    /** Creates an empty vector with an empty key domain. */
    public MutableSparseVector() {
        this(new long[0], new double[0]);
    }

    /**
     * Creates a vector whose domain and values come from a map.
     *
     * @param long2DoubleMap the key/value pairs to copy
     */
    public MutableSparseVector(Long2DoubleMap long2DoubleMap) {
        super(long2DoubleMap);
        this.channelMap = new Reference2ObjectArrayMap<>();
        this.typedChannelMap = new Reference2ObjectArrayMap<>();
    }

    /**
     * Creates a vector whose key domain is the given collection of keys.
     *
     * @param collection the key domain
     */
    public MutableSparseVector(Collection<Long> collection) {
        super(collection);
        this.channelMap = new Reference2ObjectArrayMap<>();
        this.typedChannelMap = new Reference2ObjectArrayMap<>();
    }

    /**
     * Creates a vector over the given key domain with every key set to the same value.
     *
     * @param longSet the key domain
     * @param d       the value stored for every key
     */
    public MutableSparseVector(LongSet longSet, double d) {
        this((Collection<Long>) longSet);
        DoubleArrays.fill(this.values, 0, this.domainSize, d);
        this.usedKeys.set(0, this.domainSize);
    }

    /** Wraps parallel key/value arrays, using the whole key array as the domain. */
    protected MutableSparseVector(long[] jArr, double[] dArr) {
        this(jArr, dArr, jArr.length);
    }

    /** Wraps the first {@code i} entries of parallel key/value arrays. */
    protected MutableSparseVector(long[] jArr, double[] dArr, int i) {
        super(jArr, dArr, i);
        this.channelMap = new Reference2ObjectArrayMap<>();
        this.typedChannelMap = new Reference2ObjectArrayMap<>();
    }

    /**
     * Wraps parallel key/value arrays of domain size {@code i} with an explicit used-key set.
     * (Originally protected; the decompiler widened access to public, kept for compatibility.)
     */
    public MutableSparseVector(long[] jArr, double[] dArr, int i, BitSet bitSet) {
        super(jArr, dArr, i, bitSet);
        this.channelMap = new Reference2ObjectArrayMap<>();
        this.typedChannelMap = new Reference2ObjectArrayMap<>();
    }

    /** Wraps existing storage and channel maps without copying any of them. */
    protected MutableSparseVector(long[] jArr, double[] dArr, int i, BitSet bitSet, Map<Symbol, MutableSparseVector> map, Map<TypedSymbol<?>, TypedSideChannel<?>> map2) {
        super(jArr, dArr, i, bitSet);
        this.channelMap = map;
        this.typedChannelMap = map2;
    }

    /**
     * Creates a new vector over a different key domain, copying the values (and, recursively,
     * every side channel) for keys present in both domains.
     *
     * @param longSet the key domain of the new vector
     * @return a new, independent vector over {@code longSet}
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    public MutableSparseVector withDomain(LongSet longSet) {
        MutableSparseVector result = new MutableSparseVector((Collection<Long>) longSet);
        result.set(this);
        for (Map.Entry<Symbol, MutableSparseVector> entry : this.channelMap.entrySet()) {
            result.addChannel(entry.getKey(), entry.getValue().withDomain(longSet));
        }
        for (Map.Entry<TypedSymbol<?>, TypedSideChannel<?>> entry : this.typedChannelMap.entrySet()) {
            // A symbol and its channel share a type parameter, but the map's wildcards hide that;
            // the raw casts let the generic addChannel(TypedSymbol<K>, TypedSideChannel<K>) accept them.
            result.addChannel((TypedSymbol) entry.getKey(), (TypedSideChannel) entry.getValue().withDomain(longSet));
        }
        return result;
    }

    /**
     * Creates a copy whose key domain is reduced to the keys that are currently set.
     *
     * @return a new vector over this vector's set keys
     */
    public MutableSparseVector shrinkDomain() {
        LongArraySet usedKeySet = new LongArraySet();
        Iterator<VectorEntry> it = iterator();
        while (it.hasNext()) {
            usedKeySet.add(it.next().getKey());
        }
        return withDomain(usedKeySet);
    }

    /**
     * Fails if this vector has been frozen (its value array released).
     *
     * @throws IllegalStateException if the vector is frozen
     */
    private void checkFrozen() {
        if (this.values == null) {
            throw new IllegalStateException("The mutable sparse vector is frozen");
        }
    }

    /**
     * Stores a value at a domain index and marks that index as set.
     *
     * @param i the domain index (must be non-negative)
     * @param d the new value
     * @return the previous value, or {@link Double#NaN} if the index was unset
     */
    private double setAt(int i, double d) {
        assert i >= 0;
        double previous = this.usedKeys.get(i) ? this.values[i] : Double.NaN;
        this.values[i] = d;
        this.usedKeys.set(i);
        return previous;
    }

    /**
     * Sets the value for a key.
     *
     * @param j the key; must be in the key domain
     * @param d the new value
     * @return the previous value, or {@link Double#NaN} if the key was unset
     * @throws IllegalArgumentException if {@code j} is not in the key domain
     * @throws IllegalStateException    if the vector is frozen
     */
    public double set(long j, double d) {
        checkFrozen();
        int idx = findIndex(j);
        if (idx < 0) {
            throw new IllegalArgumentException("Cannot 'set' key=" + j + " that is not in the key domain.");
        }
        return setAt(idx, d);
    }

    /**
     * Sets the value for the key of an entry, using the entry's index directly.
     *
     * <p>The entry must come from a vector sharing this vector's key array. If the entry belongs
     * to this vector itself, the entry's own value is updated as well.
     *
     * @param vectorEntry the entry identifying the key
     * @param d           the new value
     * @return the previous value, or {@link Double#NaN} if the key was unset
     * @throws IllegalArgumentException if the entry has no vector, a different key array, a
     *                                  negative index, or a key that does not match its index
     * @throws IllegalStateException    if the vector is frozen
     */
    public double set(VectorEntry vectorEntry, double d) {
        checkFrozen();
        SparseVector owner = vectorEntry.getVector();
        int idx = vectorEntry.getIndex();
        if (owner == null) {
            throw new IllegalArgumentException("entry is not associated with a vector");
        }
        if (owner.keys != this.keys) {
            throw new IllegalArgumentException("entry does not have safe key domain");
        }
        if (idx < 0) {
            throw new IllegalArgumentException("Cannot 'set' a key with a negative index.");
        }
        if (vectorEntry.getKey() != this.keys[idx]) {
            throw new IllegalArgumentException("entry does not have the correct key for its index");
        }
        if (owner == this) {
            vectorEntry.setValue(d);
        }
        return setAt(idx, d);
    }

    /**
     * Sets every key in the domain to the same value.
     *
     * @param d the value to store
     * @throws IllegalStateException if the vector is frozen
     */
    public void fill(double d) {
        checkFrozen();
        DoubleArrays.fill(this.values, 0, this.domainSize, d);
        this.usedKeys.set(0, this.domainSize);
    }

    /**
     * Unsets a key.
     *
     * @deprecated use {@link #unset(long)}
     */
    @Deprecated
    public void clear(long j) {
        unset(j);
    }

    /**
     * Unsets the key of an entry.
     *
     * @deprecated use {@link #unset(VectorEntry)}
     */
    @Deprecated
    public void clear(VectorEntry vectorEntry) {
        unset(vectorEntry);
    }

    /**
     * Unsets a key; the key stays in the domain.
     *
     * @param j the key
     * @throws IllegalArgumentException if {@code j} is not in the key domain
     * @throws IllegalStateException    if the vector is frozen
     */
    public void unset(long j) {
        checkFrozen();
        int idx = findIndex(j);
        if (idx < 0) {
            throw new IllegalArgumentException("unset should only be used on keys that are in the key domain");
        }
        this.usedKeys.clear(idx);
    }

    /**
     * Unsets the key of an entry belonging to this vector.
     *
     * @param vectorEntry an entry obtained from this vector
     * @throws IllegalArgumentException if the entry belongs to another vector
     * @throws IllegalStateException    if the vector is frozen
     */
    public void unset(VectorEntry vectorEntry) {
        if (vectorEntry.getVector() != this) {
            throw new IllegalArgumentException("clearing vector from wrong entry");
        }
        checkFrozen();
        this.usedKeys.clear(vectorEntry.getIndex());
    }

    /**
     * Unsets every key; the key domain is unchanged.
     *
     * @throws IllegalStateException if the vector is frozen. A frozen vector may share its
     *                               used-key set with the immutable vector it produced, so clearing
     *                               it would silently corrupt that immutable vector.
     */
    public void clear() {
        checkFrozen();
        this.usedKeys.clear();
    }

    /**
     * Adds to the value of a key if that key is set.
     *
     * @param j the key
     * @param d the amount to add
     * @return the new value, or {@link Double#NaN} if the key is absent from the domain or unset
     * @throws IllegalStateException if the vector is frozen
     */
    public double add(long j, double d) {
        checkFrozen();
        int idx = findIndex(j);
        if (idx < 0 || !this.usedKeys.get(idx)) {
            return Double.NaN;
        }
        this.values[idx] += d;
        return this.values[idx];
    }

    /**
     * Adds a constant to every value. Storage slots of unset keys are modified too, but those
     * keys remain unset.
     *
     * @param d the amount to add
     * @throws IllegalStateException if the vector is frozen
     */
    public void add(double d) {
        checkFrozen();
        for (int i = 0; i < this.domainSize; i++) {
            this.values[i] += d;
        }
    }

    /**
     * Subtracts another vector's values from this vector's values for keys set in both. Keys set
     * only in {@code sparseVector} are ignored. The merge walks forward only, so it relies on
     * {@code sparseVector.fast()} yielding keys in ascending order.
     *
     * @param sparseVector the vector to subtract
     * @throws IllegalStateException if the vector is frozen
     */
    public void subtract(SparseVector sparseVector) {
        checkFrozen();
        int i = 0;
        for (VectorEntry vectorEntry : sparseVector.fast()) {
            long key = vectorEntry.getKey();
            while (i < this.domainSize && this.keys[i] < key) {
                i++;
            }
            if (i >= this.domainSize) {
                return;
            }
            if (this.keys[i] == key && this.usedKeys.get(i)) {
                this.values[i] -= vectorEntry.getValue();
            }
        }
    }

    /**
     * Adds another vector's values to this vector's values for keys set in both. Keys set only
     * in {@code sparseVector} are ignored. The merge walks forward only, so it relies on
     * {@code sparseVector.fast()} yielding keys in ascending order.
     *
     * @param sparseVector the vector to add
     * @throws IllegalStateException if the vector is frozen
     */
    public void add(SparseVector sparseVector) {
        checkFrozen();
        int i = 0;
        for (VectorEntry vectorEntry : sparseVector.fast()) {
            long key = vectorEntry.getKey();
            while (i < this.domainSize && this.keys[i] < key) {
                i++;
            }
            if (i >= this.domainSize) {
                return;
            }
            if (this.keys[i] == key && this.usedKeys.get(i)) {
                this.values[i] += vectorEntry.getValue();
            }
        }
    }

    /**
     * Copies another vector's values into this one for every key in this vector's domain, setting
     * those keys. Keys outside this domain are ignored. The merge walks forward only, so it relies
     * on {@code sparseVector.fast()} yielding keys in ascending order.
     *
     * @param sparseVector the source vector
     * @throws IllegalStateException if the vector is frozen
     */
    public void set(SparseVector sparseVector) {
        checkFrozen();
        int i = 0;
        for (VectorEntry vectorEntry : sparseVector.fast()) {
            long key = vectorEntry.getKey();
            while (i < this.domainSize && this.keys[i] < key) {
                i++;
            }
            if (i >= this.domainSize) {
                return;
            }
            if (this.keys[i] == key) {
                this.values[i] = vectorEntry.getValue();
                this.usedKeys.set(i);
            }
        }
    }

    /**
     * Multiplies every set value by a constant.
     *
     * @param d the scale factor
     * @throws IllegalStateException if the vector is frozen
     */
    public void scale(double d) {
        checkFrozen();
        BitSetIterator setIndices = new BitSetIterator(this.usedKeys, 0, this.domainSize);
        while (setIndices.hasNext()) {
            int idx = setIndices.nextInt();
            this.values[idx] *= d;
        }
    }

    /**
     * Same as {@link #mutableCopy()}.
     *
     * @return an independent mutable copy
     */
    public MutableSparseVector copy() {
        return mutableCopy();
    }

    /** Deep-copies the untyped channel map (each channel vector is copied). */
    private Map<Symbol, MutableSparseVector> copyOfChannelMap() {
        Map<Symbol, MutableSparseVector> copy = new Reference2ObjectArrayMap<>();
        for (Map.Entry<Symbol, MutableSparseVector> entry : this.channelMap.entrySet()) {
            copy.put(entry.getKey(), entry.getValue().copy());
        }
        return copy;
    }

    /** Deep-copies the typed channel map (each side channel is copied). */
    private Map<TypedSymbol<?>, TypedSideChannel<?>> copyOfTypedChannelMap() {
        Map<TypedSymbol<?>, TypedSideChannel<?>> copy = new Reference2ObjectArrayMap<>();
        for (Map.Entry<TypedSymbol<?>, TypedSideChannel<?>> entry : this.typedChannelMap.entrySet()) {
            copy.put(entry.getKey(), entry.getValue().mutableCopy());
        }
        return copy;
    }

    /**
     * Creates an independent mutable copy with the same key domain. Values, the used-key set, and
     * all channels are copied; the key array is shared.
     *
     * @throws IllegalStateException if the vector is frozen
     */
    @Override
    public MutableSparseVector mutableCopy() {
        checkFrozen();
        return new MutableSparseVector(this.keys, Arrays.copyOf(this.values, this.domainSize), this.domainSize, (BitSet) this.usedKeys.clone(), copyOfChannelMap(), copyOfTypedChannelMap());
    }

    /**
     * Creates an immutable copy of this vector; this vector remains usable.
     *
     * @throws IllegalStateException if the vector is frozen
     */
    @Override
    public ImmutableSparseVector immutable() {
        return immutable(false);
    }

    /**
     * Converts this vector into an immutable vector and freezes it. Afterwards, operations
     * that call {@code checkFrozen()} throw {@link IllegalStateException}.
     *
     * @return the immutable vector, which may share this vector's storage
     * @throws IllegalStateException if the vector is already frozen
     */
    public ImmutableSparseVector freeze() {
        return immutable(true);
    }

    /**
     * Creates an immutable vector over the keys that are currently set.
     *
     * <p>When {@code z} is true and every key of the key array is set, the storage is handed over
     * without copying. In every {@code z == true} case this vector is frozen afterwards.
     *
     * @param z whether to freeze this vector
     * @return the immutable vector
     * @throws IllegalStateException if the vector is already frozen
     */
    public ImmutableSparseVector immutable(boolean z) {
        checkFrozen();
        boolean reuseKeys = z && this.usedKeys.cardinality() == this.keys.length;
        return immutable(z, reuseKeys ? this.keys : keySet().toLongArray());
    }

    /**
     * Builds the immutable vector over {@code jArr}, recursing into the channels.
     *
     * @param z    whether to freeze this vector (release its value array) afterwards
     * @param jArr the sorted key domain of the result; must be a subset of this vector's domain.
     *             If it is this vector's own key array and {@code z} is true, the value array and
     *             used-key set are shared instead of copied.
     * @return the immutable vector
     * @throws AssertionError if {@code jArr} contains a key outside this vector's domain that is
     *                        reached by the merge
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private ImmutableSparseVector immutable(boolean z, long[] jArr) {
        double[] newValues;
        BitSet newUsed;
        if (jArr == this.keys && z) {
            // Freezing with an unchanged domain: hand our arrays over instead of copying them.
            newValues = this.values;
            newUsed = this.usedKeys;
        } else {
            newValues = new double[jArr.length];
            newUsed = new BitSet(jArr.length);
            int dst = 0;
            int src = 0;
            // Merge-join the new domain against ours; both arrays are sorted ascending.
            while (dst < newValues.length && src < this.domainSize) {
                if (jArr[dst] == this.keys[src]) {
                    newValues[dst] = this.values[src];
                    if (this.usedKeys.get(src)) {
                        newUsed.set(dst);
                    }
                    dst++;
                    src++;
                } else if (this.keys[src] < jArr[dst]) {
                    src++;
                } else {
                    throw new AssertionError("Key domain of new immutable vector must be subset of original domain");
                }
            }
        }
        // Raw maps: their element types must match the ImmutableSparseVector constructor.
        Reference2ObjectArrayMap frozenChannels = new Reference2ObjectArrayMap(this.channelMap.size());
        for (Map.Entry<Symbol, MutableSparseVector> entry : this.channelMap.entrySet()) {
            frozenChannels.put(entry.getKey(), entry.getValue().immutable(z, jArr));
        }
        Reference2ObjectArrayMap frozenTypedChannels = new Reference2ObjectArrayMap();
        for (Map.Entry<TypedSymbol<?>, TypedSideChannel<?>> entry : this.typedChannelMap.entrySet()) {
            if (jArr == this.keys) {
                frozenTypedChannels.put(entry.getKey(), entry.getValue().freeze());
            } else {
                frozenTypedChannels.put(entry.getKey(), entry.getValue().withDomain(new LongSortedArraySet(jArr)).freeze());
                if (z) {
                    entry.getValue().partialFreeze();
                }
            }
        }
        ImmutableSparseVector result = new ImmutableSparseVector(jArr, newValues, jArr.length, newUsed, frozenChannels, frozenTypedChannels);
        if (z) {
            // Releasing the value array is what marks this vector as frozen (see checkFrozen).
            this.values = null;
        }
        return result;
    }

    /**
     * Wraps parallel key and value arrays without copying them; every key starts unset.
     *
     * @throws IllegalArgumentException if the value array is shorter than the key array or the
     *                                  keys are not sorted
     */
    public static MutableSparseVector wrap(long[] jArr, double[] dArr) {
        return wrap(jArr, dArr, jArr.length);
    }

    /**
     * Wraps the first {@code i} entries of parallel key and value arrays without copying them.
     *
     * @param jArr the keys; the first {@code i} must be sorted ascending
     * @param dArr the values
     * @param i    the domain size
     * @return a vector backed by the given arrays
     * @throws IllegalArgumentException if either array is shorter than {@code i} or the keys are
     *                                  not sorted
     */
    public static MutableSparseVector wrap(long[] jArr, double[] dArr, int i) {
        if (dArr.length < i) {
            throw new IllegalArgumentException("value array too short");
        }
        if (jArr.length < i) {
            throw new IllegalArgumentException("key array too short");
        }
        if (!MoreArrays.isSorted(jArr, 0, i)) {
            throw new IllegalArgumentException("item array not sorted");
        }
        return new MutableSparseVector(jArr, dArr, i);
    }

    /**
     * Wraps the backing arrays of a key list and a value list without copying them.
     *
     * @param longArrayList   the keys, sorted ascending
     * @param doubleArrayList the values; must contain at least as many elements as the key list
     * @return a vector backed by the lists' arrays, with domain size equal to the key list's size
     * @throws IllegalArgumentException if the value list is shorter than the key list or the keys
     *                                  are not sorted
     */
    public static MutableSparseVector wrap(LongArrayList longArrayList, DoubleArrayList doubleArrayList) {
        // The backing array may be longer than the list; check the logical size, not just capacity.
        if (doubleArrayList.size() < longArrayList.size()) {
            throw new IllegalArgumentException("value list too short");
        }
        return wrap(longArrayList.elements(), doubleArrayList.elements(), longArrayList.size());
    }

    /**
     * Creates a vector whose key domain is the distinct values of {@code jArr}.
     *
     * @param jArr the keys (duplicates and order are irrelevant)
     * @return a new vector with every key unset
     */
    public static MutableSparseVector create(long... jArr) {
        return new MutableSparseVector((Collection<Long>) new LongOpenHashSet(jArr));
    }

    /**
     * Sorts parallel key and value arrays by key, in place, and wraps them.
     *
     * @param jArr the keys; reordered in place
     * @param dArr the values; reordered in place alongside the keys
     * @return a vector backed by the sorted arrays
     * @throws IllegalArgumentException if the value array is shorter than the key array (checked
     *                                  before any reordering)
     */
    public static MutableSparseVector wrapUnsorted(long[] jArr, double[] dArr) {
        if (dArr.length < jArr.length) {
            throw new IllegalArgumentException("value array too short");
        }
        it.unimi.dsi.fastutil.Arrays.quickSort(0, jArr.length, new IdComparator(jArr), new ParallelSwapper(jArr, dArr));
        return wrap(jArr, dArr);
    }

    /**
     * Removes an untyped channel.
     *
     * @return the removed channel vector
     * @throws IllegalArgumentException if there is no such channel
     * @throws IllegalStateException    if the vector is frozen
     */
    public SparseVector removeChannel(Symbol symbol) {
        checkFrozen();
        if (!hasChannel(symbol)) {
            throw new IllegalArgumentException("No such channel " + symbol.getName());
        }
        return this.channelMap.remove(symbol);
    }

    /**
     * Removes a typed channel.
     *
     * @return the removed side channel
     * @throws IllegalArgumentException if there is no such channel
     * @throws IllegalStateException    if the vector is frozen
     */
    @SuppressWarnings("unchecked")
    public <K> TypedSideChannel<K> removeChannel(TypedSymbol<K> typedSymbol) {
        checkFrozen();
        if (!hasChannel(typedSymbol)) {
            throw new IllegalArgumentException("No such channel " + typedSymbol.getName() + " with type " + typedSymbol.getType().getSimpleName());
        }
        // Safe: a channel is only ever stored under a symbol of its own type.
        return (TypedSideChannel<K>) this.typedChannelMap.remove(typedSymbol);
    }

    /**
     * Removes every untyped and typed channel.
     *
     * @throws IllegalStateException if the vector is frozen
     */
    public void removeAllChannels() {
        checkFrozen();
        this.channelMap.clear();
        this.typedChannelMap.clear();
    }

    /**
     * Adds a new, empty untyped channel over this vector's key domain.
     *
     * @return the new channel vector
     * @throws IllegalArgumentException if the channel already exists
     * @throws IllegalStateException    if the vector is frozen
     */
    public MutableSparseVector addChannel(Symbol symbol) {
        checkFrozen();
        if (hasChannel(symbol)) {
            throw new IllegalArgumentException("Channel " + symbol.getName() + " already exists");
        }
        MutableSparseVector channel = new MutableSparseVector((Collection<Long>) keyDomain());
        this.channelMap.put(symbol, channel);
        return channel;
    }

    /**
     * Adds a new, empty typed channel over this vector's key domain.
     *
     * @return the new side channel
     * @throws IllegalArgumentException if the channel already exists
     * @throws IllegalStateException    if the vector is frozen
     */
    public <K> TypedSideChannel<K> addChannel(TypedSymbol<K> typedSymbol) {
        checkFrozen();
        if (hasChannel(typedSymbol)) {
            throw new IllegalArgumentException("Channel " + typedSymbol.getName() + " with type " + typedSymbol.getType().getSimpleName() + " already exists");
        }
        TypedSideChannel<K> channel = new TypedSideChannel<>(keyDomain().toLongArray());
        this.typedChannelMap.put(typedSymbol, channel);
        return channel;
    }

    /**
     * Returns the untyped channel for {@code symbol}, creating an empty one if it does not exist.
     *
     * @throws IllegalStateException if the channel must be created and the vector is frozen
     */
    public MutableSparseVector alwaysAddChannel(Symbol symbol) {
        MutableSparseVector existing = this.channelMap.get(symbol);
        return existing != null ? existing : addChannel(symbol);
    }

    /**
     * Returns the typed channel for {@code typedSymbol}, creating an empty one if it does not
     * exist.
     *
     * @throws IllegalStateException if the channel must be created and the vector is frozen
     */
    @SuppressWarnings("unchecked")
    public <K> TypedSideChannel<K> alwaysAddChannel(TypedSymbol<K> typedSymbol) {
        if (!hasChannel(typedSymbol)) {
            addChannel(typedSymbol);
        }
        // Safe: a channel is only ever stored under a symbol of its own type.
        return (TypedSideChannel<K>) this.typedChannelMap.get(typedSymbol);
    }

    /**
     * Adds a copy of {@code sparseVector} as an untyped channel.
     *
     * @return the stored copy
     * @throws IllegalArgumentException if the channel already exists, or the vector's key domain
     *                                  is not contained in this vector's key domain
     * @throws IllegalStateException    if the vector is frozen
     */
    public MutableSparseVector addChannel(Symbol symbol, SparseVector sparseVector) {
        checkFrozen();
        if (hasChannel(symbol)) {
            throw new IllegalArgumentException("Channel " + symbol.getName() + " already exists");
        }
        if (!keyDomain().containsAll(sparseVector.keyDomain())) {
            throw new IllegalArgumentException("The channel you are trying to add to this vector has an incompatible key domain.");
        }
        MutableSparseVector channel = sparseVector.mutableCopy();
        this.channelMap.put(symbol, channel);
        return channel;
    }

    /**
     * Adds a copy of {@code typedSideChannel} as a typed channel.
     *
     * @return the stored copy
     * @throws IllegalArgumentException if the channel already exists, or the channel's key domain
     *                                  is not contained in this vector's key domain
     * @throws IllegalStateException    if the vector is frozen
     */
    public <K> TypedSideChannel<K> addChannel(TypedSymbol<K> typedSymbol, TypedSideChannel<K> typedSideChannel) {
        checkFrozen();
        if (hasChannel(typedSymbol)) {
            throw new IllegalArgumentException("Channel " + typedSymbol.getName() + " with the type " + typedSymbol.getType().getSimpleName() + " already exists");
        }
        if (!keyDomain().containsAll(typedSideChannel.keyDomain())) {
            throw new IllegalArgumentException("The channel you are trying to add to this vector has an incompatible key domain.");
        }
        TypedSideChannel<K> channel = typedSideChannel.mutableCopy();
        this.typedChannelMap.put(typedSymbol, channel);
        return channel;
    }

    /** Reports whether an untyped channel exists for {@code symbol}. */
    @Override
    public boolean hasChannel(Symbol symbol) {
        return this.channelMap.containsKey(symbol);
    }

    /** Reports whether a typed channel exists for {@code typedSymbol}. */
    @Override
    public boolean hasChannel(TypedSymbol<?> typedSymbol) {
        return this.typedChannelMap.containsKey(typedSymbol);
    }

    /**
     * Returns the untyped channel for {@code symbol}.
     *
     * @throws IllegalArgumentException if there is no such channel
     * @throws IllegalStateException    if the vector is frozen
     */
    @Override
    public MutableSparseVector channel(Symbol symbol) {
        checkFrozen();
        if (!hasChannel(symbol)) {
            throw new IllegalArgumentException("No existing channel under name " + symbol.getName());
        }
        return this.channelMap.get(symbol);
    }

    /**
     * Returns the typed channel for {@code typedSymbol}. (Source name {@code channel}; renamed by
     * the decompiler when it merged the covariant bridge method, kept for binary compatibility.)
     *
     * @throws IllegalArgumentException if there is no such channel
     * @throws IllegalStateException    if the vector is frozen
     */
    @Override
    @SuppressWarnings("unchecked")
    public <K> TypedSideChannel<K> mo25channel(TypedSymbol<K> typedSymbol) {
        checkFrozen();
        if (!hasChannel(typedSymbol)) {
            throw new IllegalArgumentException("No existing channel under name " + typedSymbol.getName() + " with the type " + typedSymbol.getType().getSimpleName());
        }
        // Safe: a channel is only ever stored under a symbol of its own type.
        return (TypedSideChannel<K>) this.typedChannelMap.get(typedSymbol);
    }

    /** Returns an unmodifiable live view of the untyped channel symbols. */
    @Override
    public Set<Symbol> getChannels() {
        return Collections.unmodifiableSet(this.channelMap.keySet());
    }

    /** Returns an unmodifiable live view of the typed channel symbols. */
    @Override
    public Set<TypedSymbol<?>> getTypedChannels() {
        return Collections.unmodifiableSet(this.typedChannelMap.keySet());
    }
}
