package org.tribuo.regression.slm;

import java.util.ArrayList;
import java.util.Collections;
import java.util.logging.Logger;
import org.apache.commons.math3.linear.RealVector;
import org.tribuo.regression.slm.SLMTrainer;

/**
 * A trainer for linear regression models using Least Angle Regression (LARS).
 * <p>
 * On each non-final iteration the coefficients move along the least-squares direction
 * of the current active set. The move goes only as far as the smallest non-negative step
 * size computed from the inactive features' correlations (see {@link #newWeights}).
 */
public class LARSTrainer extends SLMTrainer {
    private static final Logger logger = Logger.getLogger(LARSTrainer.class.getName());

    /**
     * Constructs a LARS trainer.
     *
     * @param maxNumFeatures The maximum number of features the model may select. The
     *                       no-argument constructor passes {@code -1} here (presumably
     *                       "no limit" — confirm against {@code SLMTrainer}).
     */
    public LARSTrainer(int maxNumFeatures) {
        // NOTE(review): the meaning of the leading 'true' flag is defined by SLMTrainer.
        super(true, maxNumFeatures);
    }

    /**
     * Constructs a LARS trainer with {@code maxNumFeatures = -1}.
     */
    public LARSTrainer() {
        this(-1);
    }

    /**
     * Computes the next coefficient vector for the LARS path.
     * <p>
     * When {@code state.last} is set, this delegates to {@code SLMTrainer.newWeights}.
     * Otherwise it solves ordinary least squares of {@code state.xpi} against the
     * residual {@code state.r} to get a direction, then picks a step size {@code gamma}.
     * {@code gamma} is the smallest non-negative value of {@code (C - c_j) / (A - a_j)}
     * and {@code (C + c_j) / (A + a_j)} over every feature {@code j} not in the active set.
     * In these expressions:
     * <ul>
     *   <li>{@code C = state.C}</li>
     *   <li>{@code c_j = state.corr[j]}</li>
     *   <li>{@code A = sumInverted(state.xpi)}</li>
     *   <li>{@code a_j} is entry {@code j} of {@code getA(...)}</li>
     * </ul>
     * If no such candidate exists, for example because every feature is already active,
     * the step is treated as final and delegated to {@code SLMTrainer.newWeights}, just as
     * when {@code state.last} is set. Previously this case threw
     * {@code NoSuchElementException}.
     *
     * @param state The current solver state.
     * @return {@code state.beta + gamma * direction}, the result of
     *         {@code SLMTrainer.newWeights} on a final step, or {@code null} if the
     *         least-squares solve returned {@code null}.
     */
    @Override
    public RealVector newWeights(SLMTrainer.SLMState state) {
        if (state.last) {
            return super.newWeights(state);
        }

        RealVector activeDirection = SLMTrainer.ordinaryLeastSquares(state.xpi, state.r);
        if (activeDirection == null) {
            return null;
        }
        RealVector direction = state.unpack(activeDirection);

        double normalizer = SLMTrainer.sumInverted(state.xpi);
        double maxCorrelation = state.C;
        RealVector weights = SLMTrainer.getwa(state.xpi, normalizer);
        RealVector angles = SLMTrainer.getA(state.X, state.xpi, weights);

        // Track the minimum directly instead of collecting boxed candidates. Math.min
        // matches Double.compareTo ordering here, because NaN never passes the >= 0 check.
        boolean foundCandidate = false;
        double gamma = Double.POSITIVE_INFINITY;
        for (int feature = 0; feature < state.numFeatures; feature++) {
            if (state.activeSet.contains(feature)) {
                continue;
            }
            double corr = state.corr.getEntry(feature);
            double angle = angles.getEntry(feature);
            double towardsPositive = (maxCorrelation - corr) / (normalizer - angle);
            double towardsNegative = (maxCorrelation + corr) / (normalizer + angle);
            if (towardsPositive >= 0.0) {
                gamma = foundCandidate ? Math.min(gamma, towardsPositive) : towardsPositive;
                foundCandidate = true;
            }
            if (towardsNegative >= 0.0) {
                gamma = foundCandidate ? Math.min(gamma, towardsNegative) : towardsNegative;
                foundCandidate = true;
            }
        }

        if (!foundCandidate) {
            // No inactive feature can join the active set, so take the final step the
            // same way the state.last branch above does.
            return super.newWeights(state);
        }

        return state.beta.add(direction.mapMultiplyToSelf(gamma));
    }

    @Override
    public String toString() {
        return "LARSTrainer(maxNumFeatures=" + this.maxNumFeatures + ")";
    }
}
