package org.tribuo.regression.slm;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.config.PropertyException;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import java.time.OffsetDateTime;
import java.util.Arrays;
import java.util.Iterator;
import java.util.Map;
import java.util.SplittableRandom;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.SparseModel;
import org.tribuo.SparseTrainer;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.SGDVector;
import org.tribuo.math.la.SparseVector;
import org.tribuo.math.la.VectorIterator;
import org.tribuo.math.la.VectorTuple;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.provenance.TrainerProvenance;
import org.tribuo.provenance.impl.TrainerProvenanceImpl;
import org.tribuo.regression.Regressor;
import org.tribuo.util.Util;

/* loaded from: input_file:org/tribuo/regression/slm/ElasticNetCDTrainer.class */
/**
 * A trainer for sparse linear regression models regularised with an elastic net penalty (a mix of
 * L1 and L2 penalties), fitted by coordinate descent.
 *
 * <p>Each output dimension of the {@link Regressor} is fitted independently against the same
 * feature columns. If any feature has a non-zero mean, every feature is implicitly centred by its
 * mean during fitting (the data itself is not modified). Features are never rescaled, so the
 * returned model is given unit feature variances.
 *
 * <p>With {@code n} training examples, each coordinate update soft-thresholds by the L1 penalty
 * {@code alpha * l1Ratio * n} and shrinks by the L2 penalty {@code alpha * (1 - l1Ratio) * n}.
 * The duality gap is evaluated only after a sweep in which all weights are near zero or the
 * largest weight change is below {@code tolerance} relative to the largest weight, or on the final
 * sweep. Fitting of a dimension stops as soon as that gap is below {@code tolerance * ||y||^2};
 * otherwise it runs for {@code maxIterations} sweeps.
 *
 * <p>The RNG (used to permute the coordinate order when {@code randomise} is set) and the
 * invocation counter are guarded by this object's monitor, so each call to {@code train} receives
 * its own split of the RNG.
 */
public class ElasticNetCDTrainer implements SparseTrainer<Regressor> {
    private static final Logger logger = Logger.getLogger(ElasticNetCDTrainer.class.getName());

    /** Magnitudes below this are treated as zero (feature means, column norms, weights, l1Ratio). */
    private static final double EPSILON = 1.0E-12;

    /** Default tolerance for the weight-change and duality-gap convergence checks. */
    private static final double DEFAULT_TOLERANCE = 1.0E-4;

    /** Default maximum number of coordinate descent sweeps per output dimension. */
    private static final int DEFAULT_MAX_ITERATIONS = 500;

    /** Default seed for the RNG. */
    private static final long DEFAULT_SEED = 12345L;

    /** Invocation count meaning "continue from the current RNG state" rather than resetting it. */
    private static final int NEXT_INVOCATION_COUNT = -1;

    @Config(mandatory = true, description = "Overall regularisation penalty.")
    private double alpha;

    @Config(mandatory = true, description = "Ratio of l1 to l2 parameters.")
    private double l1Ratio;

    @Config(description = "Tolerance on the error.")
    private double tolerance = DEFAULT_TOLERANCE;

    @Config(description = "Maximium number of iterations to run.")
    private int maxIterations = DEFAULT_MAX_ITERATIONS;

    @Config(description = "Randomises the order in which the features are probed.")
    private boolean randomise = false;

    @Config(description = "The seed for the RNG.")
    private long seed = DEFAULT_SEED;

    /** Guarded by {@code this}; recreated from {@link #seed} in postConfig and setInvocationCount. */
    private SplittableRandom rng;

    /** Number of train invocations started so far; guarded by {@code this}. */
    private int trainInvocationCounter;

    /**
     * The per-dataset quantities shared by every output dimension during a single call to train.
     *
     * <p>Note that {@link #featureIndices} is permuted in place by the trainer when randomisation
     * is enabled, so the coordinate order carries over from one output dimension to the next.
     */
    public static class ElasticNetState {
        /** One sparse column per feature, each with one entry slot per training example. */
        final SparseVector[] columns;
        /** Number of feature columns. */
        final int numFeatures;
        /** Number of training examples (the length of each column). */
        final int numExamples;
        /** The order in which features are visited during a coordinate descent sweep. */
        final int[] featureIndices;
        /** Mean of each feature column, used for implicit centring. */
        final double[] featureMeans;
        /** Squared two-norm of each mean-centred feature column. */
        final double[] columnNorms;
        /** L1 penalty, already scaled by the number of examples. */
        final double l1Penalty;
        /** L2 penalty, already scaled by the number of examples. */
        final double l2Penalty;
        /** True if features are implicitly centred by {@link #featureMeans}. */
        final boolean center;

        /**
         * Bundles the shared training state.
         *
         * @param columns The feature columns.
         * @param featureIndices The feature visiting order (mutated in place when randomising).
         * @param featureMeans The mean of each column.
         * @param columnNorms The squared norm of each centred column.
         * @param l1Penalty The scaled L1 penalty.
         * @param l2Penalty The scaled L2 penalty.
         * @param center Whether to centre the features by their means.
         */
        public ElasticNetState(SparseVector[] columns, int[] featureIndices, double[] featureMeans,
                               double[] columnNorms, double l1Penalty, double l2Penalty, boolean center) {
            this.columns = columns;
            this.numFeatures = columns.length;
            // A feature-less dataset has no column to read the example count from.
            this.numExamples = columns.length > 0 ? columns[0].size() : 0;
            this.featureIndices = featureIndices;
            this.featureMeans = featureMeans;
            this.columnNorms = columnNorms;
            this.l1Penalty = l1Penalty;
            this.l2Penalty = l2Penalty;
            this.center = center;
        }
    }

    /**
     * No-argument constructor, presumably used by the OLCUT configuration system to instantiate the
     * trainer reflectively. Non-mandatory fields keep their declared defaults.
     */
    private ElasticNetCDTrainer() {}

    /**
     * Constructs a trainer with the default tolerance, iteration limit and seed, visiting features
     * in a fixed order.
     *
     * @param alpha The overall regularisation strength.
     * @param l1Ratio The fraction of {@code alpha} applied as the L1 penalty.
     */
    public ElasticNetCDTrainer(double alpha, double l1Ratio) {
        this(alpha, l1Ratio, DEFAULT_TOLERANCE, DEFAULT_MAX_ITERATIONS, false, DEFAULT_SEED);
    }

    /**
     * Constructs a trainer with the default tolerance and iteration limit that visits features in a
     * random order drawn from an RNG seeded with {@code seed}.
     *
     * @param alpha The overall regularisation strength.
     * @param l1Ratio The fraction of {@code alpha} applied as the L1 penalty.
     * @param seed The RNG seed.
     */
    public ElasticNetCDTrainer(double alpha, double l1Ratio, long seed) {
        this(alpha, l1Ratio, DEFAULT_TOLERANCE, DEFAULT_MAX_ITERATIONS, true, seed);
    }

    /**
     * Constructs a fully specified trainer.
     *
     * @param alpha The overall regularisation strength; must be finite and non-negative.
     * @param l1Ratio The fraction of {@code alpha} applied as the L1 penalty; must lie within
     *                [1e-12, 1 + 1e-12].
     * @param tolerance The convergence tolerance.
     * @param maxIterations The maximum number of coordinate descent sweeps per output dimension.
     * @param randomise If true, the feature visiting order is shuffled every sweep.
     * @param seed The RNG seed.
     * @throws PropertyException If {@code alpha} or {@code l1Ratio} is out of range.
     */
    public ElasticNetCDTrainer(double alpha, double l1Ratio, double tolerance, int maxIterations,
                               boolean randomise, long seed) {
        this.alpha = alpha;
        this.l1Ratio = l1Ratio;
        this.tolerance = tolerance;
        this.maxIterations = maxIterations;
        this.randomise = randomise;
        this.seed = seed;
        postConfig();
    }

    /**
     * Validates the hyperparameters and (re)creates the RNG from the seed.
     *
     * @throws PropertyException If {@code l1Ratio} is NaN or outside [1e-12, 1 + 1e-12], or if
     *                           {@code alpha} is negative, NaN or infinite.
     */
    public synchronized void postConfig() {
        // NaN fails every ordered comparison, so it must be rejected explicitly.
        if (Double.isNaN(l1Ratio) || l1Ratio < EPSILON || l1Ratio > 1.0 + EPSILON) {
            throw new PropertyException("l1Ratio", "L1 Ratio must be between 0 and 1. Found value " + l1Ratio);
        }
        // A negative or non-finite alpha yields negative or NaN penalties, which break the updates.
        if (!Double.isFinite(alpha) || alpha < 0.0) {
            throw new PropertyException("alpha", "Alpha must be finite and non-negative. Found value " + alpha);
        }
        rng = new SplittableRandom(seed);
    }

    /**
     * Trains a model, continuing from the current RNG state.
     *
     * @param examples The training data.
     * @param runProvenance Extra provenance recorded in the model provenance.
     * @return The trained sparse linear model.
     */
    public SparseModel<Regressor> train(Dataset<Regressor> examples, Map<String, Provenance> runProvenance) {
        return train(examples, runProvenance, NEXT_INVOCATION_COUNT);
    }

    /**
     * Trains a sparse linear model with one weight vector per output dimension.
     *
     * @param examples The training data; must be non-empty and contain no unknown outputs.
     * @param runProvenance Extra provenance recorded in the model provenance.
     * @param invocationCount If not -1, the RNG is reset and advanced to this invocation count
     *                        before training; -1 continues from the current RNG state.
     * @return The trained {@link SparseLinearModel}.
     * @throws IllegalArgumentException If the dataset contains unknown outputs or is empty, or if
     *                                  {@code invocationCount} is negative and not -1.
     */
    public SparseModel<Regressor> train(Dataset<Regressor> examples, Map<String, Provenance> runProvenance, int invocationCount) {
        if (examples.getOutputInfo().getUnknownCount() > 0) {
            throw new IllegalArgumentException("The supplied Dataset contained unknown Outputs, and this Trainer is supervised.");
        }
        if (examples.size() == 0) {
            throw new IllegalArgumentException("The supplied Dataset was empty.");
        }

        // Split the RNG, capture provenance and bump the counter atomically with respect to other
        // trains. Provenance is captured before the counter is incremented.
        SplittableRandom localRNG;
        TrainerProvenance trainerProvenance;
        synchronized (this) {
            if (invocationCount != NEXT_INVOCATION_COUNT) {
                setInvocationCount(invocationCount);
            }
            localRNG = rng.split();
            trainerProvenance = getProvenance();
            trainInvocationCounter++;
        }

        ImmutableFeatureMap featureIDMap = examples.getFeatureIDMap();
        ImmutableOutputInfo<Regressor> outputInfo = examples.getOutputIDInfo();
        int numFeatures = featureIDMap.size();
        int numOutputs = outputInfo.size();
        int numExamples = examples.size();
        SparseVector[] columns = SparseVector.transpose(examples, featureIDMap);

        // Gather the targets column-wise: one dense vector of length numExamples per output dimension.
        String[] dimensionNames = new String[numOutputs];
        DenseVector[] targets = new DenseVector[numOutputs];
        for (int dim = 0; dim < numOutputs; dim++) {
            dimensionNames[dim] = outputInfo.getOutput(dim).getNames()[0];
            targets[dim] = new DenseVector(numExamples);
        }
        int row = 0;
        for (Example<Regressor> example : examples) {
            for (Regressor.DimensionTuple tuple : example.getOutput()) {
                targets[outputInfo.getID(tuple)].set(row, tuple.getValue());
            }
            row++;
        }

        double l1Penalty = alpha * l1Ratio * numExamples;
        double l2Penalty = alpha * (1.0 - l1Ratio) * numExamples;
        double[] featureMeans = calculateMeans(columns);
        // Features are not rescaled, so the model is told every feature has unit variance.
        double[] featureVariances = new double[columns.length];
        Arrays.fill(featureVariances, 1.0);

        boolean center = false;
        for (int f = 0; f < numFeatures; f++) {
            if (Math.abs(featureMeans[f]) > EPSILON) {
                center = true;
                break;
            }
        }

        double[] columnNorms = new double[numFeatures];
        int[] featureIndices = new int[numFeatures];
        for (int f = 0; f < numFeatures; f++) {
            featureIndices[f] = f;
            double mean = featureMeans[f];
            double sumSquares = 0.0;
            for (VectorTuple v : columns[f]) {
                sumSquares += (v.value - mean) * (v.value - mean);
            }
            // Every implicit zero in the sparse column contributes (0 - mean)^2 once centred.
            columnNorms[f] = sumSquares + ((numExamples - columns[f].numActiveElements()) * mean * mean);
        }

        ElasticNetState state = new ElasticNetState(columns, featureIndices, featureMeans, columnNorms, l1Penalty, l2Penalty, center);

        SparseVector[] weights = new SparseVector[numOutputs];
        double[] outputMeans = new double[numOutputs];
        for (int dim = 0; dim < numOutputs; dim++) {
            weights[dim] = trainSingleDimension(targets[dim], state, localRNG.split());
            outputMeans[dim] = targets[dim].sum() / numExamples;
        }
        double[] outputVariances = new double[numOutputs];
        Arrays.fill(outputVariances, 1.0);

        ModelProvenance provenance = new ModelProvenance(SparseLinearModel.class.getName(), OffsetDateTime.now(),
                examples.getProvenance(), trainerProvenance, runProvenance);
        return new SparseLinearModel("elastic-net-model", dimensionNames, provenance, featureIDMap, outputInfo, weights,
                DenseVector.createDenseVector(featureMeans), DenseVector.createDenseVector(featureVariances),
                outputMeans, outputVariances, false);
    }

    /**
     * Fits the weights for one output dimension by coordinate descent.
     *
     * @param targets The regression targets for this dimension, one entry per example; not modified.
     * @param state The shared columns, statistics and penalties. Its {@code featureIndices} array
     *              is permuted in place when {@code randomise} is set.
     * @param localRNG The RNG used to permute the feature visiting order.
     * @return The fitted weights, one per feature, as a sparse vector.
     */
    private SparseVector trainSingleDimension(DenseVector targets, ElasticNetState state, SplittableRandom localRNG) {
        int numFeatures = state.numFeatures;
        // residuals = targets - (centred X) * weights; the weights start at zero.
        DenseVector residuals = targets.copy();
        DenseVector weights = new DenseVector(numFeatures);
        double targetNorm = targets.twoNorm();
        double gapTolerance = tolerance * targetNorm * targetNorm;

        for (int iteration = 0; iteration < maxIterations; iteration++) {
            double maxWeight = 0.0;
            double maxUpdate = 0.0;

            if (randomise) {
                Util.randpermInPlace(state.featureIndices, localRNG);
            }

            for (int position = 0; position < numFeatures; position++) {
                int feature = state.featureIndices[position];
                if (Math.abs(state.columnNorms[feature]) < EPSILON) {
                    // A zero-variance column cannot explain any residual; its weight stays at zero.
                    continue;
                }

                double oldWeight = weights.get(feature);
                if (oldWeight != 0.0) {
                    // Add the feature's current contribution back so the residual excludes it.
                    subtractScaledColumn(residuals, state, feature, -oldWeight);
                }

                double correlation = residuals.dot(state.columns[feature]);
                if (state.center) {
                    correlation -= residuals.sum() * state.featureMeans[feature];
                }

                // Soft-threshold by the L1 penalty, then shrink by the column norm plus L2 penalty.
                double newWeight = Math.signum(correlation) * Math.max(Math.abs(correlation) - state.l1Penalty, 0.0)
                        / (state.columnNorms[feature] + state.l2Penalty);
                weights.set(feature, newWeight);

                if (newWeight != 0.0) {
                    subtractScaledColumn(residuals, state, feature, newWeight);
                }

                double update = Math.abs(newWeight - oldWeight);
                if (update > maxUpdate) {
                    maxUpdate = update;
                }
                double magnitude = Math.abs(newWeight);
                if (magnitude > maxWeight) {
                    maxWeight = magnitude;
                }
            }

            // The duality gap is relatively expensive, so only check it once the weights settle
            // (or on the last sweep, so the log reports the final gap if it converged there).
            boolean weightsSettled = maxWeight < EPSILON || maxUpdate / maxWeight < tolerance;
            if (weightsSettled || iteration == maxIterations - 1) {
                double gap = dualityGap(targets, residuals, weights, state);
                if (gap < gapTolerance) {
                    logger.log(Level.INFO, "Iteration: " + iteration + ", duality gap = " + gap + ", tolerance = " + gapTolerance);
                    break;
                }
            }
        }

        return weights.sparsify();
    }

    /**
     * Subtracts {@code scale} times the (implicitly centred) feature column from the residuals, in place.
     * Passing {@code -w} adds the column's contribution back instead.
     *
     * @param residuals The residual vector to update.
     * @param state The shared training state.
     * @param feature The feature column index.
     * @param scale The multiplier applied to the column.
     */
    private static void subtractScaledColumn(DenseVector residuals, ElasticNetState state, int feature, double scale) {
        for (VectorTuple v : state.columns[feature]) {
            residuals.set(v.index, residuals.get(v.index) - (v.value * scale));
        }
        if (state.center) {
            // Centring subtracts the mean from every row, including the implicit zeros.
            for (int k = 0; k < state.numExamples; k++) {
                residuals.set(k, residuals.get(k) + (state.featureMeans[feature] * scale));
            }
        }
    }

    /**
     * Computes the duality gap for the current weights and residuals.
     *
     * @param targets The regression targets for this dimension.
     * @param residuals The current residuals.
     * @param weights The current weights.
     * @param state The shared training state.
     * @return The duality gap.
     */
    private static double dualityGap(DenseVector targets, DenseVector residuals, DenseVector weights, ElasticNetState state) {
        double residualSum = residuals.sum();
        double maxAbsXTA = 0.0;
        for (int f = 0; f < state.numFeatures; f++) {
            double xTransposeR = residuals.dot(state.columns[f]);
            if (state.center) {
                xTransposeR -= state.featureMeans[f] * residualSum;
            }
            double absXTA = Math.abs(xTransposeR - (state.l2Penalty * weights.get(f)));
            if (absXTA > maxAbsXTA) {
                maxAbsXTA = absXTA;
            }
        }

        double residualNorm = residuals.twoNorm();
        double residualSquared = residualNorm * residualNorm;
        double weightNorm = weights.twoNorm();
        double weightSquared = weightNorm * weightNorm;
        double weightOneNorm = weights.oneNorm();

        // Rescale the dual point into the feasible region when the L1 constraint is violated.
        double scale;
        double gap;
        if (maxAbsXTA > state.l1Penalty) {
            scale = state.l1Penalty / maxAbsXTA;
            gap = 0.5 * (residualSquared + (residualSquared * scale * scale));
        } else {
            scale = 1.0;
            gap = residualSquared;
        }
        gap += (state.l1Penalty * weightOneNorm) - (scale * residuals.dot(targets));
        gap += 0.5 * state.l2Penalty * (1.0 + (scale * scale)) * weightSquared;
        return gap;
    }

    /**
     * Returns the number of train invocations started so far.
     *
     * @return The invocation count.
     */
    public synchronized int getInvocationCount() {
        return trainInvocationCounter;
    }

    /**
     * Resets the RNG from the seed and advances it as if {@code invocationCount} trains had run.
     *
     * @param invocationCount The target invocation count.
     * @throws IllegalArgumentException If {@code invocationCount} is negative.
     */
    public synchronized void setInvocationCount(int invocationCount) {
        if (invocationCount < 0) {
            throw new IllegalArgumentException("The supplied invocationCount is less than zero.");
        }
        rng = new SplittableRandom(seed);
        trainInvocationCounter = 0;
        while (trainInvocationCounter < invocationCount) {
            rng.split();
            trainInvocationCounter++;
        }
    }

    @Override
    public String toString() {
        return "ElasticNetCDTrainer(alpha=" + alpha + ",l1Ratio=" + l1Ratio + ",tolerance=" + tolerance
                + ",maxIterations=" + maxIterations + ",randomise=" + randomise + ",seed=" + seed + ")";
    }

    /**
     * Computes the mean of each vector, counting implicit zeros.
     *
     * @param vectors The vectors.
     * @return The mean of each vector.
     */
    private static double[] calculateMeans(SGDVector[] vectors) {
        double[] means = new double[vectors.length];
        for (int i = 0; i < vectors.length; i++) {
            means[i] = vectors[i].sum() / vectors[i].size();
        }
        return means;
    }

    /**
     * Returns the provenance of this trainer.
     *
     * @return The trainer provenance.
     */
    public TrainerProvenance getProvenance() {
        return new TrainerProvenanceImpl(this);
    }

    /**
     * Alias of {@link #getProvenance()} left behind by the decompiler's method renaming.
     *
     * @return The trainer provenance.
     * @deprecated Use {@link #getProvenance()}.
     */
    @Deprecated
    public TrainerProvenance m3getProvenance() {
        return getProvenance();
    }

    /**
     * Raw-typed alias of {@link #train(Dataset, Map, int)} left behind by the decompiler.
     *
     * @param dataset The training data.
     * @param map The run provenance.
     * @param i The invocation count.
     * @return The trained model.
     * @deprecated Use {@link #train(Dataset, Map, int)}.
     */
    @Deprecated
    @SuppressWarnings({"unchecked", "rawtypes"})
    public Model m1train(Dataset dataset, Map map, int i) {
        return train((Dataset<Regressor>) dataset, (Map<String, Provenance>) map, i);
    }

    /**
     * Raw-typed alias of {@link #train(Dataset, Map)} left behind by the decompiler.
     *
     * @param dataset The training data.
     * @param map The run provenance.
     * @return The trained model.
     * @deprecated Use {@link #train(Dataset, Map)}.
     */
    @Deprecated
    @SuppressWarnings({"unchecked", "rawtypes"})
    public Model m2train(Dataset dataset, Map map) {
        return train((Dataset<Regressor>) dataset, (Map<String, Provenance>) map);
    }
}
