package org.tribuo.math.la;

import com.google.protobuf.Any;
import com.google.protobuf.ByteString;
import com.google.protobuf.InvalidProtocolBufferException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.DoubleBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.function.BiFunction;
import java.util.function.DoubleBinaryOperator;
import java.util.function.DoubleUnaryOperator;
import java.util.function.ToDoubleBiFunction;
import java.util.stream.Collectors;
import org.tribuo.Example;
import org.tribuo.Feature;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.Output;
import org.tribuo.math.protos.DenseTensorProto;
import org.tribuo.math.protos.TensorProto;
import org.tribuo.math.util.VectorNormalizer;
import org.tribuo.util.MeanVarianceAccumulator;
import org.tribuo.util.Util;

/* loaded from: input_file:org/tribuo/math/la/DenseVector.class */
public class DenseVector implements SGDVector {
    private static final long serialVersionUID = 1;
    public static final int CURRENT_VERSION = 0;
    private final int[] shape;
    protected final double[] elements;

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/tribuo/math/la/DenseVector$DenseVectorIterator.class */
    public static class DenseVectorIterator implements VectorIterator {
        private final DenseVector vector;
        private final VectorTuple tuple = new VectorTuple();
        private int index = 0;

        public DenseVectorIterator(DenseVector denseVector) {
            this.vector = denseVector;
        }

        @Override // java.util.Iterator
        public boolean hasNext() {
            return this.index < this.vector.elements.length;
        }

        /* JADX WARN: Can't rename method to resolve collision */
        @Override // java.util.Iterator
        public VectorTuple next() {
            if (!hasNext()) {
                throw new NoSuchElementException("Off the end of the iterator.");
            }
            this.tuple.index = this.index;
            this.tuple.value = this.vector.elements[this.index];
            this.index++;
            return this.tuple;
        }

        @Override // org.tribuo.math.la.VectorIterator
        public VectorTuple getReference() {
            return this.tuple;
        }
    }

    public DenseVector(int i) {
        this(i, 0.0d);
    }

    public DenseVector(int i, double d) {
        this.elements = new double[i];
        Arrays.fill(this.elements, d);
        this.shape = new int[]{i};
    }

    /* JADX INFO: Access modifiers changed from: protected */
    public DenseVector(double[] dArr) {
        this.elements = dArr;
        this.shape = new int[]{this.elements.length};
    }

    /* JADX INFO: Access modifiers changed from: protected */
    public DenseVector(DenseVector denseVector) {
        this(denseVector.toArray());
    }

    public static DenseVector createDenseVector(double[] dArr) {
        return new DenseVector(Arrays.copyOf(dArr, dArr.length));
    }

    public static <T extends Output<T>> DenseVector createDenseVector(Example<T> example, ImmutableFeatureMap immutableFeatureMap, boolean z) {
        int size = z ? immutableFeatureMap.size() + 1 : immutableFeatureMap.size();
        double[] dArr = new double[size];
        boolean z2 = false;
        Iterator it = example.iterator();
        while (it.hasNext()) {
            Feature feature = (Feature) it.next();
            int id = immutableFeatureMap.getID(feature.getName());
            if (id != -1) {
                dArr[id] = feature.getValue();
                z2 = true;
                if (Double.isNaN(dArr[id])) {
                    throw new IllegalArgumentException("Example contained a NaN feature, " + feature.toString());
                }
            }
        }
        if (!z2) {
            throw new IllegalArgumentException("No features in this example were found in the feature map. Example - " + example.toString());
        }
        if (z) {
            dArr[size - 1] = 1.0d;
        }
        return new DenseVector(dArr);
    }

    public static DenseVector deserializeFromProto(int i, String str, Any any) throws InvalidProtocolBufferException {
        if (i < 0 || i > 0) {
            throw new IllegalArgumentException("Unknown version " + i + ", this class supports at most version 0");
        }
        return unpackProto(any.unpack(DenseTensorProto.class));
    }

    /* JADX INFO: Access modifiers changed from: protected */
    public static DenseVector unpackProto(DenseTensorProto denseTensorProto) {
        int[] primitiveInt = Util.toPrimitiveInt(denseTensorProto.getDimensionsList());
        if (primitiveInt.length != 1) {
            throw new IllegalArgumentException("Invalid proto, expected a vector, found shape " + Arrays.toString(primitiveInt));
        }
        if (primitiveInt[0] < 1) {
            throw new IllegalArgumentException("Invalid proto, shape must be positive, found " + primitiveInt[0] + " at position 0");
        }
        int product = Util.product(primitiveInt);
        DoubleBuffer asDoubleBuffer = denseTensorProto.getValues().asReadOnlyByteBuffer().order(ByteOrder.LITTLE_ENDIAN).asDoubleBuffer();
        if (asDoubleBuffer.remaining() != product) {
            throw new IllegalArgumentException("Invalid proto, claimed " + product + ", but only had " + asDoubleBuffer.remaining() + " values");
        }
        double[] dArr = new double[primitiveInt[0]];
        asDoubleBuffer.get(dArr);
        return new DenseVector(dArr);
    }

    @Override // 
    /* renamed from: serialize, reason: merged with bridge method [inline-methods] */
    public TensorProto mo20serialize() {
        TensorProto.Builder newBuilder = TensorProto.newBuilder();
        newBuilder.setVersion(0);
        newBuilder.setClassName(DenseVector.class.getName());
        DenseTensorProto.Builder newBuilder2 = DenseTensorProto.newBuilder();
        newBuilder2.addAllDimensions((Iterable) Arrays.stream(this.shape).boxed().collect(Collectors.toList()));
        ByteBuffer order = ByteBuffer.allocate(this.elements.length * 8).order(ByteOrder.LITTLE_ENDIAN);
        DoubleBuffer asDoubleBuffer = order.asDoubleBuffer();
        asDoubleBuffer.put(this.elements);
        asDoubleBuffer.rewind();
        newBuilder2.setValues(ByteString.copyFrom(order));
        newBuilder.setSerializedData(Any.pack(newBuilder2.m178build()));
        return newBuilder.m836build();
    }

    @Override // org.tribuo.math.la.SGDVector
    public double[] toArray() {
        return Arrays.copyOf(this.elements, this.elements.length);
    }

    @Override // org.tribuo.math.la.Tensor
    public int[] getShape() {
        return this.shape;
    }

    @Override // org.tribuo.math.la.Tensor
    public Tensor reshape(int[] iArr) {
        if (Tensor.shapeSum(iArr) != this.elements.length) {
            throw new IllegalArgumentException("Invalid shape " + Arrays.toString(iArr) + ", expected something with " + this.elements.length + " elements.");
        }
        if (iArr.length != 2) {
            if (iArr.length == 1) {
                return new DenseVector(this);
            }
            throw new IllegalArgumentException("Only supports 1 or 2 dimensional tensors.");
        }
        DenseMatrix denseMatrix = new DenseMatrix(iArr[0], iArr[1]);
        for (int i = 0; i < size(); i++) {
            denseMatrix.set(i % iArr[0], i / iArr[0], get(i));
        }
        return denseMatrix;
    }

    @Override // org.tribuo.math.la.SGDVector, org.tribuo.math.la.Tensor
    public DenseVector copy() {
        return new DenseVector(toArray());
    }

    @Override // org.tribuo.math.la.SGDVector
    public int size() {
        return this.elements.length;
    }

    @Override // org.tribuo.math.la.SGDVector
    public int numActiveElements() {
        return this.elements.length;
    }

    @Override // org.tribuo.math.la.SGDVector
    public double reduce(double d, DoubleUnaryOperator doubleUnaryOperator, DoubleBinaryOperator doubleBinaryOperator) {
        double d2 = d;
        for (int i = 0; i < this.elements.length; i++) {
            d2 = doubleBinaryOperator.applyAsDouble(doubleUnaryOperator.applyAsDouble(get(i)), d2);
        }
        return d2;
    }

    public <T> T reduce(T t, DoubleUnaryOperator doubleUnaryOperator, BiFunction<Double, T, T> biFunction) {
        T t2 = t;
        for (int i = 0; i < this.elements.length; i++) {
            t2 = biFunction.apply(Double.valueOf(doubleUnaryOperator.applyAsDouble(get(i))), t2);
        }
        return t2;
    }

    public boolean equals(Object obj) {
        if (!(obj instanceof SGDVector) || this.elements.length != ((SGDVector) obj).size()) {
            return false;
        }
        Iterator<VectorTuple> iterator2 = iterator2();
        Iterator<VectorTuple> it = ((SGDVector) obj).iterator();
        while (iterator2.hasNext() && it.hasNext()) {
            if (!iterator2.next().equals(it.next())) {
                return false;
            }
        }
        return (iterator2.hasNext() || it.hasNext()) ? false : true;
    }

    public int hashCode() {
        return Arrays.hashCode(this.elements);
    }

    @Override // org.tribuo.math.la.SGDVector
    public DenseVector add(SGDVector sGDVector) {
        if (sGDVector.size() != this.elements.length) {
            throw new IllegalArgumentException("Can't add two vectors of different dimension, this = " + this.elements.length + ", other = " + sGDVector.size());
        }
        double[] array = toArray();
        for (VectorTuple vectorTuple : sGDVector) {
            int i = vectorTuple.index;
            array[i] = array[i] + vectorTuple.value;
        }
        return new DenseVector(array);
    }

    @Override // org.tribuo.math.la.SGDVector
    public DenseVector subtract(SGDVector sGDVector) {
        if (sGDVector.size() != this.elements.length) {
            throw new IllegalArgumentException("Can't subtract two vectors of different dimension, this = " + this.elements.length + ", other = " + sGDVector.size());
        }
        double[] array = toArray();
        for (VectorTuple vectorTuple : sGDVector) {
            int i = vectorTuple.index;
            array[i] = array[i] - vectorTuple.value;
        }
        return new DenseVector(array);
    }

    @Override // org.tribuo.math.la.Tensor
    public void intersectAndAddInPlace(Tensor tensor, DoubleUnaryOperator doubleUnaryOperator) {
        if (!(tensor instanceof SGDVector)) {
            throw new IllegalArgumentException("Adding a non-Vector to a Vector");
        }
        SGDVector sGDVector = (SGDVector) tensor;
        if (sGDVector.size() != this.elements.length) {
            throw new IllegalArgumentException("Can't intersect two vectors of different dimension, this = " + this.elements.length + ", other = " + sGDVector.size());
        }
        if (sGDVector instanceof DenseVector) {
            for (int i = 0; i < this.elements.length; i++) {
                double[] dArr = this.elements;
                int i2 = i;
                dArr[i2] = dArr[i2] + doubleUnaryOperator.applyAsDouble(sGDVector.get(i));
            }
            return;
        }
        for (VectorTuple vectorTuple : sGDVector) {
            double[] dArr2 = this.elements;
            int i3 = vectorTuple.index;
            dArr2[i3] = dArr2[i3] + doubleUnaryOperator.applyAsDouble(vectorTuple.value);
        }
    }

    @Override // org.tribuo.math.la.Tensor
    public void hadamardProductInPlace(Tensor tensor, DoubleUnaryOperator doubleUnaryOperator) {
        if (!(tensor instanceof SGDVector)) {
            throw new IllegalArgumentException("Scaling a Vector by a non-Vector");
        }
        SGDVector sGDVector = (SGDVector) tensor;
        if (sGDVector.size() != this.elements.length) {
            throw new IllegalArgumentException("Can't hadamard product two vectors of different dimension, this = " + this.elements.length + ", other = " + sGDVector.size());
        }
        if (sGDVector instanceof DenseVector) {
            for (int i = 0; i < this.elements.length; i++) {
                double[] dArr = this.elements;
                int i2 = i;
                dArr[i2] = dArr[i2] * doubleUnaryOperator.applyAsDouble(sGDVector.get(i));
            }
            return;
        }
        for (VectorTuple vectorTuple : sGDVector) {
            double[] dArr2 = this.elements;
            int i3 = vectorTuple.index;
            dArr2[i3] = dArr2[i3] * doubleUnaryOperator.applyAsDouble(vectorTuple.value);
        }
    }

    @Override // org.tribuo.math.la.Tensor
    public void foreachInPlace(DoubleUnaryOperator doubleUnaryOperator) {
        for (int i = 0; i < this.elements.length; i++) {
            this.elements[i] = doubleUnaryOperator.applyAsDouble(this.elements[i]);
        }
    }

    @Override // org.tribuo.math.la.SGDVector
    public void foreachIndexedInPlace(ToDoubleBiFunction<Integer, Double> toDoubleBiFunction) {
        for (int i = 0; i < this.elements.length; i++) {
            this.elements[i] = toDoubleBiFunction.applyAsDouble(Integer.valueOf(i), Double.valueOf(this.elements[i]));
        }
    }

    @Override // org.tribuo.math.la.SGDVector
    public DenseVector scale(double d) {
        DenseVector copy = copy();
        copy.scaleInPlace(d);
        return copy;
    }

    @Override // org.tribuo.math.la.SGDVector
    public void add(int i, double d) {
        double[] dArr = this.elements;
        dArr[i] = dArr[i] + d;
    }

    @Override // org.tribuo.math.la.SGDVector
    public double dot(SGDVector sGDVector) {
        if (sGDVector.size() != this.elements.length) {
            throw new IllegalArgumentException("Can't dot two vectors of different dimension, this = " + this.elements.length + ", other = " + sGDVector.size());
        }
        double d = 0.0d;
        if (sGDVector instanceof DenseVector) {
            for (int i = 0; i < this.elements.length; i++) {
                d += get(i) * sGDVector.get(i);
            }
        } else {
            for (VectorTuple vectorTuple : sGDVector) {
                d += get(vectorTuple.index) * vectorTuple.value;
            }
        }
        return d;
    }

    /* JADX WARN: Type inference failed for: r0v20, types: [double[], double[][]] */
    @Override // org.tribuo.math.la.SGDVector
    public Matrix outer(SGDVector sGDVector) {
        if (sGDVector instanceof DenseVector) {
            DenseVector denseVector = (DenseVector) sGDVector;
            ?? r0 = new double[this.elements.length];
            for (int i = 0; i < this.elements.length; i++) {
                r0[i] = denseVector.scale(get(i)).elements;
            }
            return new DenseMatrix((double[][]) r0);
        }
        if (!(sGDVector instanceof SparseVector)) {
            throw new IllegalArgumentException("Invalid vector subclass " + sGDVector.getClass().getCanonicalName() + " for input");
        }
        SparseVector sparseVector = (SparseVector) sGDVector;
        SparseVector[] sparseVectorArr = new SparseVector[this.elements.length];
        for (int i2 = 0; i2 < this.elements.length; i2++) {
            sparseVectorArr[i2] = sparseVector.scale(get(i2));
        }
        return new DenseSparseMatrix(sparseVectorArr);
    }

    @Override // org.tribuo.math.la.SGDVector
    public double sum() {
        double d = 0.0d;
        for (int i = 0; i < this.elements.length; i++) {
            d += get(i);
        }
        return d;
    }

    public double sum(DoubleUnaryOperator doubleUnaryOperator) {
        double d = 0.0d;
        for (int i = 0; i < this.elements.length; i++) {
            d += doubleUnaryOperator.applyAsDouble(get(i));
        }
        return d;
    }

    @Override // org.tribuo.math.la.SGDVector, org.tribuo.math.la.Tensor
    public double twoNorm() {
        double d = 0.0d;
        for (int i = 0; i < this.elements.length; i++) {
            double d2 = get(i);
            d += d2 * d2;
        }
        return Math.sqrt(d);
    }

    @Override // org.tribuo.math.la.SGDVector
    public double oneNorm() {
        double d = 0.0d;
        for (int i = 0; i < this.elements.length; i++) {
            d += Math.abs(get(i));
        }
        return d;
    }

    @Override // org.tribuo.math.la.SGDVector
    public double get(int i) {
        return this.elements[i];
    }

    @Override // org.tribuo.math.la.SGDVector
    public void set(int i, double d) {
        this.elements[i] = d;
    }

    public void setElements(DenseVector denseVector) {
        for (int i = 0; i < this.elements.length; i++) {
            this.elements[i] = denseVector.get(i);
        }
    }

    public void fill(double d) {
        Arrays.fill(this.elements, d);
    }

    @Override // org.tribuo.math.la.SGDVector
    public int indexOfMax() {
        int i = 0;
        double d = Double.NEGATIVE_INFINITY;
        for (int i2 = 0; i2 < this.elements.length; i2++) {
            double d2 = get(i2);
            if (d2 > d) {
                i = i2;
                d = d2;
            }
        }
        return i;
    }

    @Override // org.tribuo.math.la.SGDVector
    public double maxValue() {
        double d = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < this.elements.length; i++) {
            double d2 = get(i);
            if (d2 > d) {
                d = d2;
            }
        }
        return d;
    }

    @Override // org.tribuo.math.la.SGDVector
    public double minValue() {
        double d = Double.POSITIVE_INFINITY;
        for (int i = 0; i < this.elements.length; i++) {
            double d2 = get(i);
            if (d2 < d) {
                d = d2;
            }
        }
        return d;
    }

    @Override // org.tribuo.math.la.SGDVector
    public void normalize(VectorNormalizer vectorNormalizer) {
        vectorNormalizer.normalizeInPlace(this.elements);
    }

    public void expNormalize(double d) {
        for (int i = 0; i < this.elements.length; i++) {
            this.elements[i] = Math.exp(this.elements[i] - d);
        }
    }

    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append("DenseVector(size=");
        sb.append(this.elements.length);
        sb.append(",values=[");
        for (int i = 0; i < this.elements.length; i++) {
            sb.append(get(i));
            sb.append(",");
        }
        sb.setCharAt(sb.length() - 1, ']');
        sb.append(")");
        return sb.toString();
    }

    @Override // org.tribuo.math.la.SGDVector
    public double variance(double d) {
        double d2 = 0.0d;
        for (int i = 0; i < this.elements.length; i++) {
            double d3 = get(i) - d;
            d2 += d3 * d3;
        }
        return d2;
    }

    @Override // java.lang.Iterable
    /* renamed from: iterator, reason: merged with bridge method [inline-methods] */
    public Iterator<VectorTuple> iterator2() {
        return new DenseVectorIterator(this);
    }

    public SparseVector sparsify() {
        return sparsify(1.0E-12d);
    }

    public SparseVector sparsify(double d) {
        ArrayList arrayList = new ArrayList();
        ArrayList arrayList2 = new ArrayList();
        for (int i = 0; i < this.elements.length; i++) {
            double d2 = get(i);
            if (Math.abs(d2) > d) {
                arrayList.add(Integer.valueOf(i));
                arrayList2.add(Double.valueOf(d2));
            }
        }
        return new SparseVector(this.elements.length, Util.toPrimitiveInt(arrayList), Util.toPrimitiveDouble(arrayList2));
    }

    @Override // org.tribuo.math.la.SGDVector
    public double euclideanDistance(SGDVector sGDVector) {
        if (sGDVector.size() != this.elements.length) {
            throw new IllegalArgumentException("Can't measure distance of two vectors of different lengths, this = " + this.elements.length + ", other = " + sGDVector.size());
        }
        if (sGDVector instanceof DenseVector) {
            double d = 0.0d;
            for (int i = 0; i < this.elements.length; i++) {
                double d2 = get(i) - sGDVector.get(i);
                d += d2 * d2;
            }
            return Math.sqrt(d);
        }
        if (!(sGDVector instanceof SparseVector)) {
            throw new IllegalArgumentException("Unknown vector subclass " + sGDVector.getClass().getCanonicalName() + " for input");
        }
        double d3 = 0.0d;
        int i2 = 0;
        Iterator<VectorTuple> it = sGDVector.iterator();
        while (i2 < this.elements.length && it.hasNext()) {
            VectorTuple next = it.next();
            while (i2 < this.elements.length && i2 < next.index) {
                double d4 = get(i2);
                d3 += d4 * d4;
                i2++;
            }
            if (i2 == next.index) {
                double d5 = get(i2) - next.value;
                d3 += d5 * d5;
                i2++;
            }
        }
        while (i2 < this.elements.length) {
            double d6 = get(i2);
            d3 += d6 * d6;
            i2++;
        }
        return Math.sqrt(d3);
    }

    @Override // org.tribuo.math.la.SGDVector
    public double l1Distance(SGDVector sGDVector) {
        if (sGDVector.size() != this.elements.length) {
            throw new IllegalArgumentException("Can't measure distance of two vectors of different lengths, this = " + this.elements.length + ", other = " + sGDVector.size());
        }
        if (sGDVector instanceof DenseVector) {
            double d = 0.0d;
            for (int i = 0; i < this.elements.length; i++) {
                d += Math.abs(get(i) - sGDVector.get(i));
            }
            return d;
        }
        if (!(sGDVector instanceof SparseVector)) {
            throw new IllegalArgumentException("Unknown vector subclass " + sGDVector.getClass().getCanonicalName() + " for input");
        }
        double d2 = 0.0d;
        int i2 = 0;
        Iterator<VectorTuple> it = sGDVector.iterator();
        while (i2 < this.elements.length && it.hasNext()) {
            VectorTuple next = it.next();
            while (i2 < this.elements.length && i2 < next.index) {
                d2 += Math.abs(get(i2));
                i2++;
            }
            if (i2 == next.index) {
                d2 += Math.abs(get(i2) - next.value);
                i2++;
            }
        }
        while (i2 < this.elements.length) {
            d2 += Math.abs(get(i2));
            i2++;
        }
        return d2;
    }

    public MeanVarianceAccumulator meanVariance() {
        MeanVarianceAccumulator meanVarianceAccumulator = new MeanVarianceAccumulator();
        for (int i = 0; i < this.elements.length; i++) {
            meanVarianceAccumulator.observe(get(i));
        }
        return meanVarianceAccumulator;
    }
}
