package org.tribuo.math.optimisers.util;

import com.google.protobuf.Any;
import com.google.protobuf.ByteString;
import com.google.protobuf.InvalidProtocolBufferException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.DoubleBuffer;
import java.util.Arrays;
import java.util.Iterator;
import java.util.NoSuchElementException;
import java.util.function.DoubleUnaryOperator;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.SGDVector;
import org.tribuo.math.la.Tensor;
import org.tribuo.math.la.VectorIterator;
import org.tribuo.math.la.VectorTuple;
import org.tribuo.math.protos.DenseTensorProto;
import org.tribuo.math.protos.ShrinkingDenseTensorProto;
import org.tribuo.math.protos.TensorProto;

/* loaded from: input_file:org/tribuo/math/optimisers/util/ShrinkingVector.class */
/**
 * A {@link DenseVector} whose logical values are the stored {@code elements} multiplied by a
 * lazily-applied scalar {@code multiplier}. Uniform scaling via {@link #scaleInPlace(double)}
 * therefore only touches the multiplier.
 * <p>
 * Each call to {@link #intersectAndAddInPlace(Tensor, DoubleUnaryOperator)} does three things:
 * <ol>
 *   <li>It shrinks the vector by {@code 1 - baseRate / iteration} (when {@code scaleShrinking}
 *       is set) or by {@code 1 - baseRate}.</li>
 *   <li>It adds the transformed update.</li>
 *   <li>When {@code reproject} is set, it rescales the vector so that {@link #twoNorm()} is at
 *       most {@code 1 / sqrt(lambda)}.</li>
 * </ol>
 * <p>
 * The squared two-norm is maintained incrementally, so {@link #twoNorm()} does not rescan the
 * elements.
 * <p>
 * NOTE(review): inherited {@code DenseVector} mutators that are not overridden here may operate
 * on the raw stored elements and bypass both the multiplier and the norm bookkeeping. Confirm
 * this before relying on them.
 */
public class ShrinkingVector extends DenseVector implements ShrinkingTensor {
    /** Multipliers smaller than this in magnitude are folded back into the stored elements. */
    private static final double REIFY_TOLERANCE = 1.0e-6;

    private final double baseRate;
    private final boolean scaleShrinking;
    private final double lambdaSqrt;
    private final boolean reproject;
    /** Running value of the sum over i of get(i)^2. */
    private double squaredTwoNorm;
    /** Update counter starting at 1; used by the {@code 1 - baseRate / iteration} schedule. */
    private int iteration;
    /** Scalar applied to every stored element to obtain its logical value. */
    private double multiplier;

    /**
     * Iterates the logical (multiplier-applied) values of a {@link ShrinkingVector} in index order.
     * <p>
     * The same {@link VectorTuple} instance is reused and overwritten by every call to
     * {@link #next()}; {@link #getReference()} returns that shared instance.
     */
    public static class ShrinkingVectorIterator implements VectorIterator {
        private final ShrinkingVector vector;
        private final VectorTuple tuple = new VectorTuple();
        private int index = 0;

        /**
         * Creates an iterator over the supplied vector.
         *
         * @param vector The vector to iterate.
         */
        public ShrinkingVectorIterator(ShrinkingVector vector) {
            this.vector = vector;
        }

        @Override
        public boolean hasNext() {
            return index < vector.size();
        }

        /**
         * Advances to the next index and returns the shared tuple populated with it.
         *
         * @return The shared tuple holding the next index and its logical value.
         * @throws NoSuchElementException If the iterator is exhausted.
         */
        @Override
        public VectorTuple next() {
            // The Iterator contract requires NoSuchElementException rather than an
            // ArrayIndexOutOfBoundsException leaking out of get().
            if (!hasNext()) {
                throw new NoSuchElementException("Off the end of the iterator, size = " + vector.size());
            }
            tuple.index = index;
            tuple.value = vector.get(index);
            index++;
            return tuple;
        }

        @Override
        public VectorTuple getReference() {
            return tuple;
        }
    }

    /**
     * Creates a shrinking vector that never reprojects.
     *
     * @param v              The initial values.
     * @param baseRate       The shrinkage base rate.
     * @param scaleShrinking If true, shrink by {@code 1 - baseRate / iteration}; otherwise by
     *                       {@code 1 - baseRate}.
     */
    public ShrinkingVector(DenseVector v, double baseRate, boolean scaleShrinking) {
        super(v);
        this.baseRate = baseRate;
        this.scaleShrinking = scaleShrinking;
        this.lambdaSqrt = 0.0d;
        this.reproject = false;
        // Seed the running norm from the initial values so twoNorm() is right from the start.
        this.squaredTwoNorm = sumOfSquares(this.elements);
        this.iteration = 1;
        this.multiplier = 1.0d;
    }

    /**
     * Creates a shrinking vector with iteration-scaled shrinkage and reprojection.
     * <p>
     * Reprojection rescales the vector after each update so its two-norm is at most
     * {@code 1 / sqrt(lambda)}.
     *
     * @param v        The initial values.
     * @param baseRate The shrinkage base rate.
     * @param lambda   The value whose inverse square root bounds the two-norm.
     */
    public ShrinkingVector(DenseVector v, double baseRate, double lambda) {
        super(v);
        this.baseRate = baseRate;
        this.scaleShrinking = true;
        this.lambdaSqrt = Math.sqrt(lambda);
        this.reproject = true;
        // Seed the running norm from the initial values so twoNorm() is right from the start.
        this.squaredTwoNorm = sumOfSquares(this.elements);
        this.iteration = 1;
        this.multiplier = 1.0d;
    }

    /** Full-state constructor used by {@link #copy()} and deserialization. */
    private ShrinkingVector(double[] values, double baseRate, boolean scaleShrinking, double lambdaSqrt,
                            boolean reproject, double squaredTwoNorm, int iteration, double multiplier) {
        super(values);
        this.baseRate = baseRate;
        this.scaleShrinking = scaleShrinking;
        this.lambdaSqrt = lambdaSqrt;
        this.reproject = reproject;
        this.squaredTwoNorm = squaredTwoNorm;
        this.iteration = iteration;
        this.multiplier = multiplier;
    }

    /** Returns the sum of the squares of {@code values}. */
    private static double sumOfSquares(double[] values) {
        double total = 0.0d;
        for (double value : values) {
            total += value * value;
        }
        return total;
    }

    /**
     * Rebuilds a {@code ShrinkingVector} from its protobuf form.
     *
     * @param version   The serialized version; only 0 is supported.
     * @param className The serialized class name (unused).
     * @param message   The packed {@link ShrinkingDenseTensorProto}.
     * @return The deserialized vector, including its raw elements and multiplier.
     * @throws InvalidProtocolBufferException If the message is not a {@code ShrinkingDenseTensorProto}.
     * @throws IllegalArgumentException       If the version is not 0.
     */
    public static ShrinkingVector deserializeFromProto(int version, String className, Any message) throws InvalidProtocolBufferException {
        if (version != 0) {
            throw new IllegalArgumentException("Unknown version " + version + ", this class supports at most version 0");
        }
        ShrinkingDenseTensorProto proto = message.unpack(ShrinkingDenseTensorProto.class);
        double[] rawValues = DenseVector.unpackProto(proto.getData()).toArray();
        return new ShrinkingVector(rawValues, proto.getBaseRate(), proto.getScaleShrinking(), proto.getLambdaSqrt(),
                proto.getReproject(), proto.getSquaredTwoNorm(), proto.getIteration(), proto.getMultiplier());
    }

    /**
     * Serializes this vector.
     * <p>
     * The raw stored elements are written as little-endian doubles, alongside the multiplier and
     * the rest of the shrinkage state, so {@link #deserializeFromProto} restores the same state.
     */
    @Override // org.tribuo.math.la.DenseVector
    /* renamed from: serialize */
    public TensorProto mo20serialize() {
        TensorProto.Builder tensorBuilder = TensorProto.newBuilder();
        tensorBuilder.setVersion(0);
        tensorBuilder.setClassName(ShrinkingVector.class.getName());

        DenseTensorProto.Builder dataBuilder = DenseTensorProto.newBuilder();
        dataBuilder.addDimensions(elements.length);
        ByteBuffer bytes = ByteBuffer.allocate(elements.length * 8).order(ByteOrder.LITTLE_ENDIAN);
        DoubleBuffer doubles = bytes.asDoubleBuffer();
        doubles.put(elements);
        doubles.rewind();
        // Writing through the DoubleBuffer view does not move the ByteBuffer's position,
        // so copyFrom captures the whole buffer.
        dataBuilder.setValues(ByteString.copyFrom(bytes));

        ShrinkingDenseTensorProto.Builder shrinkingBuilder = ShrinkingDenseTensorProto.newBuilder();
        shrinkingBuilder.setData(dataBuilder.m178build());
        shrinkingBuilder.setBaseRate(baseRate);
        shrinkingBuilder.setLambdaSqrt(lambdaSqrt);
        shrinkingBuilder.setScaleShrinking(scaleShrinking);
        shrinkingBuilder.setReproject(reproject);
        shrinkingBuilder.setSquaredTwoNorm(squaredTwoNorm);
        shrinkingBuilder.setIteration(iteration);
        shrinkingBuilder.setMultiplier(multiplier);

        tensorBuilder.setSerializedData(Any.pack(shrinkingBuilder.m695build()));
        return tensorBuilder.m836build();
    }

    /**
     * Returns a plain {@link DenseVector} holding the logical (multiplier-applied) values.
     *
     * @return A dense copy of the logical values.
     */
    @Override
    public DenseVector convertToDense() {
        return DenseVector.createDenseVector(toArray());
    }

    /**
     * Returns a deep copy, including the multiplier and all shrinkage state.
     *
     * @return A copy of this vector.
     */
    @Override
    public ShrinkingVector copy() {
        return new ShrinkingVector(Arrays.copyOf(elements, elements.length), baseRate, scaleShrinking,
                lambdaSqrt, reproject, squaredTwoNorm, iteration, multiplier);
    }

    /**
     * Returns a new array of the logical (multiplier-applied) values.
     *
     * @return The logical values.
     */
    @Override
    public double[] toArray() {
        double[] values = new double[elements.length];
        for (int i = 0; i < values.length; i++) {
            values[i] = get(i);
        }
        return values;
    }

    /**
     * Returns the logical value at {@code index}, i.e. the stored element times the multiplier.
     *
     * @param index The index.
     * @return The logical value.
     */
    @Override
    public double get(int index) {
        return elements[index] * multiplier;
    }

    @Override
    public double sum() {
        double total = 0.0d;
        for (int i = 0; i < elements.length; i++) {
            total += get(i);
        }
        return total;
    }

    /**
     * Shrinks the vector, adds {@code f} applied to each value of {@code other}, then optionally
     * reprojects, and finally increments the iteration counter.
     *
     * @param other The update; must be an {@link SGDVector}.
     * @param f     Transformation applied to each update value before adding.
     * @throws IllegalArgumentException If {@code other} is not an {@link SGDVector}.
     */
    @Override
    public void intersectAndAddInPlace(Tensor other, DoubleUnaryOperator f) {
        if (!(other instanceof SGDVector)) {
            throw new IllegalArgumentException("Adding a non-Vector to a ShrinkingVector, found "
                    + (other == null ? "null" : other.getClass().getName()));
        }
        SGDVector update = (SGDVector) other;

        double shrinkage = scaleShrinking ? 1.0d - (baseRate / iteration) : 1.0d - baseRate;
        scaleInPlace(shrinkage);

        for (VectorTuple tuple : update) {
            double delta = f.applyAsDouble(tuple.value);
            double oldValue = elements[tuple.index] * multiplier;
            double newValue = oldValue + delta;
            // Incrementally maintain the squared norm: swap the old contribution for the new one.
            squaredTwoNorm -= oldValue * oldValue;
            squaredTwoNorm += newValue * newValue;
            // Store the raw value so that raw * multiplier == newValue.
            elements[tuple.index] = newValue / multiplier;
        }

        if (reproject) {
            double projection = (1.0d / lambdaSqrt) / twoNorm();
            if (projection < 1.0d) {
                scaleInPlace(projection);
            }
        }
        iteration++;
    }

    @Override
    public int indexOfMax() {
        int bestIndex = 0;
        double bestValue = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < elements.length; i++) {
            double value = get(i);
            if (value > bestValue) {
                bestIndex = i;
                bestValue = value;
            }
        }
        return bestIndex;
    }

    /**
     * Dot product of the logical values with {@code other}, visiting only the entries that
     * {@code other} iterates.
     *
     * @param other The other vector.
     * @return The dot product.
     */
    @Override
    public double dot(SGDVector other) {
        double total = 0.0d;
        for (VectorTuple tuple : other) {
            total += get(tuple.index) * tuple.value;
        }
        return total;
    }

    /**
     * Scales every logical value by {@code factor} by updating the multiplier.
     * <p>
     * If the multiplier's magnitude drops below a small tolerance, it is folded into the stored
     * elements to avoid precision loss.
     *
     * @param factor The scale factor.
     */
    @Override
    public void scaleInPlace(double factor) {
        multiplier *= factor;
        // Every logical value is scaled by factor, so the squared norm scales by factor^2.
        // Without this, twoNorm() (and hence reprojection) used a stale norm after shrinking.
        squaredTwoNorm *= factor * factor;
        if (Math.abs(multiplier) < REIFY_TOLERANCE) {
            reifyMultiplier();
        }
    }

    /** Applies the multiplier to the stored elements and resets it to 1; logical values are unchanged. */
    private void reifyMultiplier() {
        for (int i = 0; i < elements.length; i++) {
            elements[i] *= multiplier;
        }
        multiplier = 1.0d;
    }

    /**
     * Returns the two-norm from the incrementally-maintained squared norm.
     *
     * @return The two-norm of the logical values.
     */
    @Override
    public double twoNorm() {
        // Clamp so accumulated rounding error can never produce sqrt(negative) = NaN.
        return Math.sqrt(Math.max(0.0d, squaredTwoNorm));
    }

    @Override
    public double maxValue() {
        // A negative multiplier reverses the ordering of the stored elements,
        // so the logical maximum then comes from the stored minimum.
        return multiplier >= 0.0d ? multiplier * super.maxValue() : multiplier * super.minValue();
    }

    @Override
    public double minValue() {
        // A negative multiplier reverses the ordering of the stored elements,
        // so the logical minimum then comes from the stored maximum.
        return multiplier >= 0.0d ? multiplier * super.minValue() : multiplier * super.maxValue();
    }

    /**
     * Returns an iterator over the logical values, in index order.
     *
     * @return A fresh iterator.
     */
    @Override // org.tribuo.math.la.DenseVector, java.lang.Iterable
    /* renamed from: iterator */
    public Iterator<VectorTuple> iterator2() {
        return new ShrinkingVectorIterator(this);
    }
}
