package edu.iu.dsc.tws.examples.ml.svm.streamer;

import java.util.Objects;
import java.util.Random;
import java.util.logging.Level;
import java.util.logging.Logger;

import edu.iu.dsc.tws.api.compute.graph.OperationMode;
import edu.iu.dsc.tws.api.compute.nodes.BaseSource;
import edu.iu.dsc.tws.api.dataset.DataObject;
import edu.iu.dsc.tws.examples.ml.svm.constant.Constants;
import edu.iu.dsc.tws.examples.ml.svm.exceptions.InputDataFormatException;
import edu.iu.dsc.tws.examples.ml.svm.exceptions.MatrixMultiplicationException;
import edu.iu.dsc.tws.examples.ml.svm.exceptions.NullDataSetException;
import edu.iu.dsc.tws.examples.ml.svm.integration.test.IReceptor;
import edu.iu.dsc.tws.examples.ml.svm.sgd.pegasos.PegasosSgdSvm;
import edu.iu.dsc.tws.examples.ml.svm.util.BinaryBatchModel;
import edu.iu.dsc.tws.examples.ml.svm.util.DataUtils;

/* loaded from: input_file:edu/iu/dsc/tws/examples/ml/svm/streamer/IterativeDataStream.class */
/**
 * Source task for the iterative Pegasos SVM example.
 *
 * <p>Real-data mode ({@code isDummy == false}) only does work in {@link OperationMode#BATCH}. It:
 * <ul>
 *   <li>reads this task's partition of the training points and the weight vector, both delivered
 *       through {@link #add(String, DataObject)};</li>
 *   <li>runs {@link PegasosSgdSvm#iterativeTaskSgd};</li>
 *   <li>writes the averaged weight vector on {@code REDUCE_EDGE}.</li>
 * </ul>
 *
 * <p>Dummy mode writes generated samples on {@code DATA_EDGE} instead.
 */
public class IterativeDataStream extends BaseSource implements IReceptor<double[][]> {
    private static final Logger LOG = Logger.getLogger(IterativeDataStream.class.getName());
    private static final long serialVersionUID = 6672551932831677547L;

    /** Feature count used by the constructors that do not take one. */
    private static final int DEFAULT_FEATURES = 10;

    /** Shared generator used to pick a random label for dummy streaming samples. */
    private static final Random RANDOM = new Random();

    /** Candidate class labels for generated (dummy) samples. */
    private final double[] labels = {-1.0d, 1.0d};
    private int features;
    private OperationMode operationMode;
    private boolean isDummy;
    private BinaryBatchModel binaryBatchModel;
    private DataObject<double[][]> dataPointsObject;
    private DataObject<double[]> weightVectorObject;
    private double[][] datapoints;
    private double[] weightVector;
    private double[] computedWeightVector;
    private PegasosSgdSvm pegasosSgdSvm;
    private boolean debug;

    /**
     * Creates a real-data source with the default feature count and no batch model.
     *
     * @param operationMode streaming or batch execution mode
     */
    public IterativeDataStream(OperationMode operationMode) {
        this(DEFAULT_FEATURES, operationMode);
    }

    /**
     * Creates a real-data source with no batch model.
     *
     * @param features      number of features per data point
     * @param operationMode streaming or batch execution mode
     */
    public IterativeDataStream(int features, OperationMode operationMode) {
        this(features, operationMode, false, null);
    }

    /**
     * Creates a source with every option specified.
     *
     * @param features         number of features per data point
     * @param operationMode    streaming or batch execution mode
     * @param isDummy          {@code true} to emit generated data instead of training on real data
     * @param binaryBatchModel model holding the hyper-parameters and data used for training
     */
    public IterativeDataStream(int features, OperationMode operationMode, boolean isDummy,
                               BinaryBatchModel binaryBatchModel) {
        this.features = features;
        this.operationMode = operationMode;
        this.isDummy = isDummy;
        this.binaryBatchModel = binaryBatchModel;
    }

    /**
     * Receives an input data set by name.
     *
     * <p>{@code INPUT_DATA} supplies the training points; {@code INPUT_WEIGHT_VECTOR} supplies the
     * weight vector. Any other name is ignored.
     *
     * @param str        name of the input
     * @param dataObject the data set delivered under that name
     */
    @Override
    @SuppressWarnings("unchecked") // element types are fixed by the input names checked below
    public void add(String str, DataObject<?> dataObject) {
        if (this.debug) {
            LOG.log(Level.INFO, String.format("Received input: %s ", str));
        }
        if (Constants.SimpleGraphConfig.INPUT_DATA.equals(str)) {
            this.dataPointsObject = (DataObject<double[][]>) dataObject;
        }
        if (Constants.SimpleGraphConfig.INPUT_WEIGHT_VECTOR.equals(str)) {
            this.weightVectorObject = (DataObject<double[]>) dataObject;
        }
    }

    /**
     * Runs the task.
     *
     * <p>Takes the real-data path unless this instance was built as a dummy generator. A format
     * error from the dummy generator is logged, not propagated.
     */
    @Override
    public void execute() {
        if (!this.isDummy) {
            realDataStreamer();
            return;
        }
        try {
            dummyDataStreamer();
        } catch (InputDataFormatException e) {
            LOG.log(Level.SEVERE, "Dummy data generation failed", e);
        }
    }

    /**
     * Pulls this task's partition of the data points and of the weight vector.
     *
     * @throws NullPointerException if either input has not been delivered via {@link #add}
     */
    public void getData() {
        int taskIndex = this.context.taskIndex();
        DataObject<double[][]> points = Objects.requireNonNull(this.dataPointsObject,
            "Input data has not been received (INPUT_DATA)");
        DataObject<double[]> weights = Objects.requireNonNull(this.weightVectorObject,
            "Weight vector has not been received (INPUT_WEIGHT_VECTOR)");
        this.datapoints = (double[][]) points.getPartition(taskIndex).getConsumer().next();
        this.weightVector = (double[]) weights.getPartition(taskIndex).getConsumer().next();
        if (this.debug) {
            LOG.info(String.format("Recieved Input Data : %s ", this.datapoints.getClass().getName()));
        }
    }

    /**
     * Emits generated data on {@code DATA_EDGE}.
     *
     * <p>In STREAMING mode it writes one sample (features plus a random label). In BATCH mode it
     * writes a whole generated data set and then ends the edge.
     *
     * @throws InputDataFormatException if a streaming sample does not have {@code features + 1}
     *                                  entries
     */
    public void dummyDataStreamer() throws InputDataFormatException {
        if (this.operationMode.equals(OperationMode.STREAMING)) {
            int featureCount = this.binaryBatchModel.getFeatures();
            double label = this.labels[RANDOM.nextInt(this.labels.length)];
            double[] sample =
                DataUtils.combineLabelAndData(DataUtils.seedDoubleArray(featureCount), label);
            if (sample.length != featureCount + 1) {
                throw new InputDataFormatException(String.format(
                    "Input Data Format Exception : [data length : %d, feature length +1 : %d]",
                    sample.length, featureCount + 1));
            }
            this.context.write(Constants.SimpleGraphConfig.DATA_EDGE, sample);
        }
        if (this.operationMode.equals(OperationMode.BATCH)) {
            this.context.write(Constants.SimpleGraphConfig.DATA_EDGE,
                DataUtils.generateDummyDataPoints(this.binaryBatchModel.getSamples(),
                    this.binaryBatchModel.getFeatures()));
            this.context.end(Constants.SimpleGraphConfig.DATA_EDGE);
        }
    }

    /**
     * Trains on real data and writes the resulting weight vector on {@code REDUCE_EDGE}.
     *
     * <p>Only BATCH mode is supported; any other mode logs a message and returns without writing.
     */
    public void realDataStreamer() {
        if (!this.operationMode.equals(OperationMode.BATCH)) {
            LOG.info("Real data stream got stuck!");
            return;
        }
        getData();
        initializeBatchMode();
        compute();
        this.context.writeEnd(Constants.SimpleGraphConfig.REDUCE_EDGE, this.computedWeightVector);
    }

    /**
     * Runs Pegasos SGD on the model's data and stores the averaged weight vector.
     *
     * <p>SGD failures are logged, and the averaging step still runs on whatever weights the solver
     * holds. Best-effort behaviour is preserved.
     */
    public void compute() {
        double[][] x = this.binaryBatchModel.getX();
        try {
            this.pegasosSgdSvm.iterativeTaskSgd(this.binaryBatchModel.getW(), x,
                this.binaryBatchModel.getY());
        } catch (MatrixMultiplicationException e) {
            LOG.log(Level.SEVERE, "Matrix multiplication failed during SGD", e);
        } catch (NullDataSetException e) {
            LOG.log(Level.SEVERE, "Null data set supplied to SGD", e);
        }
        // NOTE(review): presumably divides by parallelism so a downstream sum-reduce yields the
        // mean model across tasks -- confirm against DataUtils.average.
        this.computedWeightVector =
            DataUtils.average(this.pegasosSgdSvm.getW(), this.context.getParallelism());
    }

    /**
     * Loads the received data and weights into the batch model and builds the SGD solver from it.
     */
    public void initializeBatchMode() {
        initializeBinaryModel(this.datapoints);
        this.binaryBatchModel.setW(this.weightVector);
        this.pegasosSgdSvm = new PegasosSgdSvm(this.binaryBatchModel.getW(),
            this.binaryBatchModel.getX(), this.binaryBatchModel.getY(),
            this.binaryBatchModel.getAlpha(), this.binaryBatchModel.getIterations(),
            this.binaryBatchModel.getFeatures());
    }

    /**
     * Replaces the batch model with one updated from the given data points.
     *
     * @param dArr data points for this task's partition
     * @throws NullPointerException if no batch model was supplied at construction
     */
    public void initializeBinaryModel(double[][] dArr) {
        Objects.requireNonNull(this.binaryBatchModel, "Binary Batch Model is Null !!!");
        if (this.debug) {
            LOG.info("Binary Batch Model Before Updated : " + this.binaryBatchModel.toString());
        }
        this.binaryBatchModel = DataUtils.updateModelData(this.binaryBatchModel, dArr);
        if (this.debug) {
            LOG.info("Binary Batch Model After Updated : " + this.binaryBatchModel.toString());
            double[][] x = this.binaryBatchModel.getX();
            // guard the column count so debug logging cannot throw on an empty matrix
            int columns = x.length > 0 ? x[0].length : 0;
            LOG.info(String.format("Updated Data [%d,%d] ", x.length, columns));
        }
    }
}
