package org.tensorflow.ndarray.impl.sparse;

import java.nio.ReadOnlyBufferException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.Objects;
import java.util.stream.LongStream;
import org.tensorflow.ndarray.IllegalRankException;
import org.tensorflow.ndarray.LongNdArray;
import org.tensorflow.ndarray.NdArray;
import org.tensorflow.ndarray.NdArraySequence;
import org.tensorflow.ndarray.NdArrays;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.ndarray.impl.AbstractNdArray;
import org.tensorflow.ndarray.impl.dense.AbstractDenseNdArray;
import org.tensorflow.ndarray.impl.dimension.Dimension;
import org.tensorflow.ndarray.impl.dimension.DimensionalSpace;
import org.tensorflow.ndarray.impl.dimension.RelativeDimensionalSpace;
import org.tensorflow.ndarray.impl.sequence.SingleElementSequence;
import org.tensorflow.ndarray.impl.sequence.SlicingElementSequence;
import org.tensorflow.ndarray.index.Index;

/* loaded from: input_file:org/tensorflow/ndarray/impl/sparse/AbstractSparseNdArray.class */
/**
 * Base class for sparse N-dimensional arrays stored as a list of coordinates ({@code indices})
 * paired with the values found at those coordinates ({@code values}). Every coordinate that is not
 * listed in {@code indices} reads as {@code defaultValue}.
 *
 * <p>Element writes through {@link #setObject} and {@link #set2} are rejected with
 * {@link ReadOnlyBufferException}.
 *
 * <p>Several overriding methods carry a {@code 2} suffix ({@code slice2}, {@code get2},
 * {@code set2}, {@code copyTo2}). Those names come from the decompiler, which renamed the original
 * {@code slice}/{@code get}/{@code set}/{@code copyTo} overrides; they are kept as-is so existing
 * callers keep resolving.
 *
 * @param <T> the element type
 * @param <U> the NdArray type used to hold the stored values
 */
public abstract class AbstractSparseNdArray<T, U extends NdArray<T>> extends AbstractNdArray<T, U> implements org.tensorflow.ndarray.SparseNdArray<T, U> {
    /** Coordinates of the explicitly stored elements, one row per stored element. */
    private LongNdArray indices;
    /** Stored element values; the element found at row {@code i} of {@link #indices} is {@code values[i]}. */
    private U values;
    /** Value reported for every coordinate that is not present in {@link #indices}. */
    private T defaultValue;
    /** Lazily created by {@link #getDefaultArray()}; cleared whenever the default value changes. */
    private U defaultArray;

    /**
     * Creates a sparse array from explicit coordinates and values.
     *
     * @param longNdArray coordinates of the stored elements; must have one row per element of
     *     {@code u} and one column per dimension of {@code dimensionalSpace}
     * @param u values stored at those coordinates
     * @param t value reported for every coordinate absent from {@code longNdArray}
     * @param dimensionalSpace dimensions of the array being represented
     * @throws IllegalArgumentException if the row count of {@code longNdArray} differs from the
     *     number of values, or its column count differs from the rank of the array
     */
    public AbstractSparseNdArray(LongNdArray longNdArray, U u, T t, DimensionalSpace dimensionalSpace) {
        super(dimensionalSpace);
        this.indices = longNdArray;
        this.values = u;
        setDefaultValue(t);
        long indexRows = this.indices.shape().get(0);
        long valueCount = this.values.shape().get(0);
        if (indexRows != valueCount) {
            throw new IllegalArgumentException(String.format("The number of rows in indices (%d) does not  match the number of elements in values(%d).", Long.valueOf(indexRows), Long.valueOf(valueCount)));
        }
        long indexColumns = this.indices.shape().get(1);
        int rank = shape().numDimensions();
        if (indexColumns != rank) {
            // Report the rank itself; the size of dimension 0 is unrelated and does not even
            // exist for a rank-0 (scalar) shape.
            throw new IllegalArgumentException(String.format("The number of columns in indices (%d) does not  match the number of dimensions in shape (%d).", Long.valueOf(indexColumns), Integer.valueOf(rank)));
        }
    }

    /**
     * Creates a sparse array holding only a default value. {@code indices} and {@code values} are
     * left unset ({@code null}) and must be provided later through {@link #setIndices} and
     * {@link #setValues}.
     *
     * @param t value reported for every coordinate
     * @param dimensionalSpace dimensions of the array being represented
     */
    public AbstractSparseNdArray(T t, DimensionalSpace dimensionalSpace) {
        super(dimensionalSpace);
        setDefaultValue(t);
    }

    /**
     * Returns a sequence over the sub-arrays found along dimension {@code i}.
     *
     * @param i dimension to iterate; a negative value on a rank-0 array yields the array itself
     * @return the element sequence
     * @throws IllegalArgumentException if {@code i} is not smaller than the rank
     */
    @Override // org.tensorflow.ndarray.NdArray
    public NdArraySequence<U> elements(int i) {
        if (i >= shape().numDimensions()) {
            throw new IllegalArgumentException("Cannot iterate elements in dimension '" + i + "' of array with shape " + shape());
        }
        if (rank() == 0 && i < 0) {
            return new SingleElementSequence(this);
        }
        return new SlicingElementSequence(this, i, dimensions().from(i + 1));
    }

    /**
     * Converts a flat position into one coordinate per dimension of {@code dimensionalSpace},
     * dividing by each dimension's element size in turn.
     *
     * @param dimensionalSpace the dimensions used for the conversion
     * @param j the flat position
     * @return the coordinates, one entry per dimension
     */
    public long[] toCoordinates(DimensionalSpace dimensionalSpace, long j) {
        long[] coordinates = new long[dimensionalSpace.numDimensions()];
        long remainder = j;
        for (int i = 0; i < dimensionalSpace.numDimensions(); i++) {
            Dimension dimension = dimensionalSpace.get(i);
            coordinates[i] = remainder / dimension.elementSize();
            remainder %= dimension.elementSize();
        }
        return coordinates;
    }

    /**
     * Copies the elements of {@code longNdArray} into a plain {@code long[]}, reading each one with
     * a single coordinate (presumably {@code longNdArray} is rank 1; TODO confirm with callers).
     *
     * @param longNdArray the array to copy
     * @return a new array of {@code longNdArray.size()} elements
     */
    public long[] getIndicesCoordinates(LongNdArray longNdArray) {
        long[] coordinates = new long[(int) longNdArray.size()];
        for (int i = 0; i < longNdArray.size(); i++) {
            coordinates[i] = longNdArray.getLong(i);
        }
        return coordinates;
    }

    /** Returns a dense equivalent of this sparse array. */
    public abstract U toDense();

    /**
     * Returns a view of this array restricted by {@code indexArr} (decompiled name; originally
     * {@code slice}).
     *
     * @throws IllegalArgumentException if {@code indexArr} is {@code null}
     */
    @Override // org.tensorflow.ndarray.NdArray
    public NdArray<T> slice2(Index... indexArr) {
        if (indexArr == null) {
            throw new IllegalArgumentException("Slicing requires at least one index");
        }
        RelativeDimensionalSpace mapTo = dimensions().mapTo(indexArr);
        return slice(mapTo.position(), mapTo);
    }

    /**
     * Returns the sub-array located at the given leading coordinates (decompiled name; originally
     * {@code get}).
     */
    @Override // org.tensorflow.ndarray.NdArray
    public NdArray<T> get2(long... jArr) {
        return slice(positionOf(jArr, false), dimensions().from(jArr.length));
    }

    /**
     * Returns the scalar at the given coordinates: the stored value if the coordinates appear in
     * {@code indices}, otherwise the default value.
     *
     * @param jArr one coordinate per dimension
     * @return the element value
     * @throws IllegalRankException if the number of coordinates differs from the rank
     */
    @Override // org.tensorflow.ndarray.NdArray
    public T getObject(long... jArr) {
        if (jArr.length != shape().numDimensions()) {
            throw new IllegalRankException(String.format("Length of coordinates (%s)%s does not match the rank %d", Integer.valueOf(jArr.length), Arrays.toString(jArr), Integer.valueOf(shape().numDimensions())));
        }
        long row = locateIndex(jArr);
        return row >= 0 ? (T) getValues().getObject(row) : this.defaultValue;
    }

    /** Always throws: element writes are not supported. */
    @Override // org.tensorflow.ndarray.NdArray
    public NdArray<T> setObject(T t, long... jArr) {
        throw new ReadOnlyBufferException();
    }

    /** Always throws: element writes are not supported (decompiled name; originally {@code set}). */
    @Override // org.tensorflow.ndarray.NdArray
    public NdArray<T> set2(NdArray<T> ndArray, long... jArr) {
        throw new ReadOnlyBufferException();
    }

    /** Creates an empty values array of the given shape. */
    public abstract U createValues(Shape shape);

    /**
     * Copies this array into {@code ndArray} (decompiled name; originally {@code copyTo}).
     *
     * <p>A sparse target receives fresh copies of this array's indices and values; the target's
     * default value is not touched. Any other target receives the dense form of this array.
     *
     * @return this array
     */
    @Override // org.tensorflow.ndarray.NdArray
    public NdArray<T> copyTo2(NdArray<T> ndArray) {
        if (ndArray instanceof AbstractSparseNdArray) {
            AbstractSparseNdArray sparseTarget = (AbstractSparseNdArray) ndArray;
            LongNdArray indicesCopy = NdArrays.ofLongs(this.indices.shape());
            this.indices.copyTo2((NdArray<Long>) indicesCopy);
            U valuesCopy = createValues(this.values.shape());
            this.values.copyTo2(valuesCopy);
            sparseTarget.setIndices(indicesCopy);
            sparseTarget.setValues(valuesCopy);
        } else {
            toDense().copyTo2(ndArray);
        }
        return this;
    }

    /**
     * Returns the flat position of the given coordinates, or 0 when none are given.
     *
     * @param jArr the coordinates, may be {@code null} or empty
     * @param z forwarded to {@code Validator.coordinates} (presumably whether coordinates for every
     *     dimension are required; TODO confirm)
     */
    protected long positionOf(long[] jArr, boolean z) {
        if (jArr == null || jArr.length == 0) {
            return 0L;
        }
        Validator.coordinates(this.dimensions, jArr, z);
        return this.dimensions.positionOf(jArr);
    }

    /**
     * Element-by-element copy used when a bulk copy is not possible.
     *
     * <p>A dense target receives every scalar of this array in iteration order, written with a
     * single running coordinate 0, 1, 2, ... A sparse target receives this array's indices and
     * values copied into its existing indices and values arrays. Other targets fall back to the
     * superclass implementation.
     *
     * <p>NOTE(review): for a multi-dimensional dense target only one coordinate is passed to
     * {@code setObject}; confirm that the dense implementation accepts this.
     */
    @Override // org.tensorflow.ndarray.impl.AbstractNdArray
    public void slowCopyTo(NdArray<T> ndArray) {
        if (ndArray instanceof AbstractDenseNdArray) {
            AbstractDenseNdArray denseTarget = (AbstractDenseNdArray) ndArray;
            long offset = 0L;
            Iterator<T> it = scalars().iterator();
            while (it.hasNext()) {
                denseTarget.setObject(((NdArray) it.next()).getObject(new long[0]), offset++);
            }
        } else if (ndArray instanceof AbstractSparseNdArray) {
            AbstractSparseNdArray sparseTarget = (AbstractSparseNdArray) ndArray;
            this.indices.copyTo2((NdArray<Long>) sparseTarget.getIndices());
            this.values.copyTo2(sparseTarget.values);
        } else {
            super.slowCopyTo(ndArray);
        }
    }

    /** Returns the coordinates of the stored elements, one row per element. */
    @Override // org.tensorflow.ndarray.SparseNdArray
    public LongNdArray getIndices() {
        return this.indices;
    }

    /** Replaces the coordinates of the stored elements; no validation is performed. */
    public void setIndices(LongNdArray longNdArray) {
        this.indices = longNdArray;
    }

    /** Returns the stored element values. */
    @Override // org.tensorflow.ndarray.SparseNdArray
    public U getValues() {
        return this.values;
    }

    /** Replaces the stored element values; no validation is performed. */
    public void setValues(U u) {
        this.values = u;
    }

    /**
     * Finds the row of {@code indices} equal to the given coordinates.
     *
     * @return the row index, or {@code -(insertionPoint + 1)} when the coordinates are not stored
     */
    protected long locateIndex(long[] jArr) {
        return binarySearch(this.indices.shape().get(0), NdArrays.vectorOf(jArr));
    }

    /**
     * Hashes the indices, values and shape. Segmented layouts fall back to the superclass's
     * {@code slowHashCode()}. Null indices or values hash as 0 so arrays built with the
     * default-value-only constructor can be hashed.
     */
    @Override // org.tensorflow.ndarray.impl.AbstractNdArray
    public int hashCode() {
        if (dimensions().isSegmented()) {
            return slowHashCode();
        }
        int result = 1;
        result = 31 * result + Objects.hashCode(this.indices);
        result = 31 * result + Objects.hashCode(this.values);
        result = 31 * result + shape().hashCode();
        return result;
    }

    /**
     * Two sparse arrays are equal when their shapes, indices and values are equal; the default
     * value is not compared. Non-sparse objects are delegated to the superclass.
     */
    @Override // org.tensorflow.ndarray.impl.AbstractNdArray, org.tensorflow.ndarray.NdArray
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (!(obj instanceof AbstractSparseNdArray)) {
            return super.equals(obj);
        }
        AbstractSparseNdArray other = (AbstractSparseNdArray) obj;
        return shape().equals(other.shape())
                && Objects.equals(this.indices, other.indices)
                && Objects.equals(this.values, other.values);
    }

    /**
     * Describes the array as {@code ClassName(defaultValue=..., numElements=..., shape=...)}.
     * Non-numeric default values are quoted; a null default prints as {@code <null>}.
     */
    @Override
    public String toString() {
        long numElements = this.values == null ? 0L : this.values.size();
        String defaultText;
        if (this.defaultValue == null) {
            defaultText = "<null>";
        } else if (this.defaultValue instanceof Number) {
            defaultText = this.defaultValue.toString();
        } else {
            defaultText = "'" + this.defaultValue + "'";
        }
        return getClass().getSimpleName() + "(defaultValue=" + defaultText + ", numElements=" + numElements + ", shape=" + shape() + ")";
    }

    /**
     * Binary search over the first {@code j} rows of {@code indices}. The rows must be sorted in
     * the order defined by {@link #compareCoordinates}, for example by {@link #sortIndicesAndValues}.
     *
     * @param j number of rows to search
     * @param longNdArray the coordinates to find
     * @return the matching row, or {@code -(insertionPoint + 1)} if absent
     */
    private long binarySearch(long j, LongNdArray longNdArray) {
        long low = 0;
        long high = j - 1;
        while (low <= high) {
            long mid = (low + high) >>> 1; // unsigned shift: overflow-safe midpoint
            int cmp = compareCoordinates(this.indices.get2(mid), longNdArray);
            if (cmp < 0) {
                low = mid + 1;
            } else if (cmp > 0) {
                high = mid - 1;
            } else {
                return mid;
            }
        }
        return -(low + 1);
    }

    /**
     * Reorders the rows of {@code indices} (and the matching {@code values}) so the coordinates are
     * ascending, as required by {@link #locateIndex}. The fields are replaced with newly created
     * arrays.
     *
     * @return this array
     */
    public AbstractSparseNdArray<T, U> sortIndicesAndValues() {
        ArrayList<Long> order = new ArrayList<>();
        LongStream.range(0L, this.values.size()).forEach(order::add);
        order.sort((a, b) -> compareCoordinates(this.indices.get2(a.longValue()), this.indices.get2(b.longValue())));
        LongNdArray sortedIndices = NdArrays.ofLongs(this.indices.shape());
        U sortedValues = createValues(this.values.shape());
        for (long dst = 0; dst < order.size(); dst++) {
            long src = order.get((int) dst).longValue();
            sortedIndices.set2(this.indices.get2(src), dst);
            sortedValues.setObject(this.values.getObject(src), dst);
        }
        this.indices = sortedIndices;
        this.values = sortedValues;
        return this;
    }

    /**
     * Orders two coordinate vectors: shorter vectors first, then lexicographically by element.
     * Uses {@link Long#compare} because subtracting longs and truncating to {@code int} can
     * overflow and report distinct coordinates as equal or give the wrong sign.
     *
     * @return a negative, zero or positive value as {@code longNdArray} is less than, equal to or
     *     greater than {@code longNdArray2}
     */
    private int compareCoordinates(LongNdArray longNdArray, LongNdArray longNdArray2) {
        int bySize = Long.compare(longNdArray.size(), longNdArray2.size());
        if (bySize != 0) {
            return bySize;
        }
        for (long i = 0; i < longNdArray.size(); i++) {
            int byElement = Long.compare(longNdArray.getLong(i), longNdArray2.getLong(i));
            if (byElement != 0) {
                return byElement;
            }
        }
        return 0;
    }

    /** Returns the value reported for coordinates that are not stored explicitly. */
    public T getDefaultValue() {
        return this.defaultValue;
    }

    /** Sets the default value and discards the cached default array so it is rebuilt on demand. */
    public void setDefaultValue(T t) {
        this.defaultValue = t;
        this.defaultArray = null;
    }

    /** Creates the array returned by {@link #getDefaultArray()}. */
    public abstract U createDefaultArray();

    /** Returns the default array, creating it on first use via {@link #createDefaultArray()}. */
    public U getDefaultArray() {
        if (this.defaultArray == null) {
            this.defaultArray = createDefaultArray();
        }
        return this.defaultArray;
    }
}
