package edu.iu.dsc.tws.examples.ml.svm.tset;

import java.util.logging.Level;
import java.util.logging.Logger;

import edu.iu.dsc.tws.api.tset.TSetContext;
import edu.iu.dsc.tws.api.tset.fn.BaseTFunction;
import edu.iu.dsc.tws.api.tset.fn.MapFunc;
import edu.iu.dsc.tws.examples.ml.svm.constant.Constants;
import edu.iu.dsc.tws.examples.ml.svm.exceptions.MatrixMultiplicationException;
import edu.iu.dsc.tws.examples.ml.svm.exceptions.NullDataSetException;
import edu.iu.dsc.tws.examples.ml.svm.sgd.pegasos.PegasosSgdSvm;
import edu.iu.dsc.tws.examples.ml.svm.util.BinaryBatchModel;
import edu.iu.dsc.tws.examples.ml.svm.util.DataUtils;
import edu.iu.dsc.tws.examples.ml.svm.util.SVMJobParameters;

/* loaded from: input_file:edu/iu/dsc/tws/examples/ml/svm/tset/SvmTrainMap.class */
public class SvmTrainMap extends BaseTFunction<double[][], double[]> implements MapFunc<double[][], double[]> {
    private static final Logger LOG = Logger.getLogger(SvmTrainMap.class.getName());
    private double[] w;
    private BinaryBatchModel binaryBatchModel;
    private SVMJobParameters svmJobParameters;
    private PegasosSgdSvm pegasosSgdSvm;
    private boolean debug = false;

    public SvmTrainMap(BinaryBatchModel binaryBatchModel, SVMJobParameters sVMJobParameters) {
        this.binaryBatchModel = binaryBatchModel;
        this.svmJobParameters = sVMJobParameters;
    }

    public void prepare(TSetContext tSetContext) {
        super.prepare(tSetContext);
        this.w = this.binaryBatchModel.getW();
    }

    public double[] map(double[][] dArr) {
        if (this.debug) {
            LOG.info(String.format("Training Dimensions [%d,%d]", Integer.valueOf(dArr.length), Integer.valueOf(dArr[0].length)));
        }
        this.binaryBatchModel = DataUtils.updateModelData(this.binaryBatchModel, dArr);
        this.binaryBatchModel.setW((double[]) getTSetContext().getInput(Constants.SimpleGraphConfig.INPUT_WEIGHT_VECTOR).getConsumer().next());
        this.pegasosSgdSvm = new PegasosSgdSvm(this.binaryBatchModel.getW(), this.binaryBatchModel.getX(), this.binaryBatchModel.getY(), this.binaryBatchModel.getAlpha(), this.binaryBatchModel.getIterations(), this.binaryBatchModel.getFeatures());
        try {
            this.pegasosSgdSvm.iterativeTaskSgd(this.binaryBatchModel.getW(), this.binaryBatchModel.getX(), this.binaryBatchModel.getY());
        } catch (MatrixMultiplicationException e) {
            e.printStackTrace();
        } catch (NullDataSetException e2) {
            e2.printStackTrace();
        }
        return this.pegasosSgdSvm.getW();
    }
}
