package edu.iu.dsc.tws.examples.ml.svm.compute;

import edu.iu.dsc.tws.api.compute.IFunction;
import edu.iu.dsc.tws.api.compute.IMessage;
import edu.iu.dsc.tws.api.compute.TaskContext;
import edu.iu.dsc.tws.api.compute.graph.OperationMode;
import edu.iu.dsc.tws.api.compute.nodes.BaseCompute;
import edu.iu.dsc.tws.api.config.Config;
import edu.iu.dsc.tws.api.dataset.DataObject;
import edu.iu.dsc.tws.api.dataset.DataPartition;
import edu.iu.dsc.tws.dataset.partition.EntityPartition;
import edu.iu.dsc.tws.examples.ml.svm.exceptions.MatrixMultiplicationException;
import edu.iu.dsc.tws.examples.ml.svm.integration.test.ICollector;
import edu.iu.dsc.tws.examples.ml.svm.integration.test.IReceptor;
import edu.iu.dsc.tws.examples.ml.svm.math.Matrix;
import edu.iu.dsc.tws.examples.ml.svm.util.MLUtils;
import edu.iu.dsc.tws.examples.ml.svm.util.SVMJobParameters;
import edu.iu.dsc.tws.examples.ml.svm.util.TrainedModel;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * Streaming compute task that collects every weight vector it receives and averages them into one
 * model. It evaluates that model against this task's partition of the test data and writes the
 * resulting accuracy to the {@code "window-evaluation-edge"} edge.
 *
 * <p>Test data must be injected through {@link #add(String, DataObject)} under the key
 * {@code "test_data"} before {@link #prepare(Config, TaskContext)} runs. The most recently
 * received weight vector is exposed through {@link #get()}.
 *
 * <p>The averaging step relies on the supplied reduce function combining two vectors into their
 * element-wise sum. The result is divided by the number of collected models. NOTE(review): the
 * reduce-function contract is assumed from this usage — confirm against the callers.
 */
public class IterativeStreamingCompute extends BaseCompute<double[]>
        implements ICollector<double[]>, IReceptor<double[][]> {
    private static final long serialVersionUID = 332173590941256461L;
    private static final Logger LOG = Logger.getLogger(IterativeStreamingCompute.class.getName());

    /** Most recently received weight vector; null until the first message arrives. */
    private double[] newWeightVector;
    private OperationMode operationMode;
    /** Combines two weight vectors; required by {@link #execute(IMessage)}. */
    private IFunction<double[]> reduceFn;
    private SVMJobParameters svmJobParameters;
    /** Result of the latest successful evaluation. */
    private TrainedModel trainedModel;
    /** Every weight vector received so far, in arrival order. */
    private final List<double[]> aggregatedModels = new ArrayList<>();
    private boolean debug = false;
    private int evaluationInterval = 10;
    private DataObject<double[][]> dataPointsObject = null;
    private double[][] datapoints = null;

    public IterativeStreamingCompute(OperationMode operationMode) {
        this(operationMode, null, null);
    }

    public IterativeStreamingCompute(OperationMode operationMode, IFunction<double[]> iFunction) {
        this(operationMode, iFunction, null);
    }

    public IterativeStreamingCompute(OperationMode operationMode, IFunction<double[]> iFunction,
                                     SVMJobParameters sVMJobParameters) {
        this.operationMode = operationMode;
        this.reduceFn = iFunction;
        this.svmJobParameters = sVMJobParameters;
    }

    /**
     * Prepares the task and loads this task's partition of the test data.
     *
     * @throws IllegalStateException if no {@code "test_data"} object was supplied via
     *     {@link #add(String, DataObject)} beforehand
     */
    @Override
    public void prepare(Config config, TaskContext taskContext) {
        super.prepare(config, taskContext);
        prepareDataPoints();
        LOG.info(String.format("Test Data Size : %d ", this.datapoints.length));
    }

    /** Returns the most recently received weight vector, wrapped as this task's partition. */
    @Override
    public DataPartition<double[]> get() {
        return new EntityPartition<>(this.context.taskIndex(), this.newWeightVector);
    }

    /**
     * Receives a named input. Only the {@code "test_data"} input is used; all others are ignored.
     */
    @Override
    @SuppressWarnings("unchecked")
    public void add(String name, DataObject<?> data) {
        if ("test_data".equals(name)) {
            // IReceptor<double[][]> means the test data carries double[][] partitions.
            this.dataPointsObject = (DataObject<double[][]>) data;
        }
    }

    /** Reads the test-data partition that belongs to this task index. */
    private void prepareDataPoints() {
        if (this.dataPointsObject == null) {
            throw new IllegalStateException(
                    "No test data received: add(\"test_data\", ...) must be called before prepare()");
        }
        this.datapoints = (double[][]) this.dataPointsObject
                .getPartition(this.context.taskIndex()).getConsumer().next();
        if (this.debug) {
            LOG.info(String.format("Recieved Input Data : %s ", this.datapoints.getClass().getName()));
        }
        LOG.info(String.format("Data Point TaskIndex[%d], Size : %d ",
                this.context.taskIndex(), this.datapoints.length));
    }

    /**
     * Records the incoming weight vector, averages all vectors received so far and evaluates the
     * averaged model.
     *
     * <p>Messages with null content are logged and skipped.
     *
     * @return always {@code true}
     * @throws IllegalStateException if the task was built without a reduce function
     */
    @Override
    public boolean execute(IMessage<double[]> iMessage) {
        double[] received = (double[]) iMessage.getContent();
        if (received == null) {
            LOG.info("Something Went Wrong !!!");
            return true;
        }
        this.newWeightVector = received;
        if (this.debug) {
            // Log only after assignment; guard [0] against an empty vector.
            double first = received.length > 0 ? received[0] : Double.NaN;
            LOG.info(String.format("Received Sink Value : %d, %s, %f",
                    received.length, Arrays.toString(received), first));
        }
        this.aggregatedModels.add(received);
        evaluateModel(averageModels(), this.aggregatedModels.size());
        return true;
    }

    /**
     * Combines every collected model with the reduce function and divides the combined vector
     * once by the number of models.
     */
    private double[] averageModels() {
        if (this.reduceFn == null) {
            throw new IllegalStateException(
                    "IterativeStreamingCompute requires a reduce function to aggregate models");
        }
        int count = this.aggregatedModels.size();
        double[] combined = new double[this.aggregatedModels.get(0).length];
        for (double[] model : this.aggregatedModels) {
            combined = (double[]) this.reduceFn.onMessage(combined, model);
        }
        // Divide once after combining; dividing inside the loop would not yield an average.
        return Matrix.scalarDivide(combined, count);
    }

    /**
     * Evaluates {@code dArr} against this task's test data and writes the accuracy to
     * {@code "window-evaluation-edge"}.
     *
     * <p>If the prediction fails with a {@link MatrixMultiplicationException}, the error is logged
     * and nothing is written. This avoids emitting a stale accuracy or dereferencing a missing
     * model.
     *
     * @param dArr the model (weight vector) to evaluate
     * @param evaluationIndex the value reported as the evaluation timestamp in debug logs
     */
    public void evaluateModel(double[] dArr, int evaluationIndex) {
        try {
            this.trainedModel = MLUtils.predictSGDSVM(dArr, this.datapoints, this.svmJobParameters,
                    "final-model");
        } catch (MatrixMultiplicationException e) {
            // Plain concatenation: a '%' in the message must not be parsed as a format specifier.
            LOG.log(Level.SEVERE, "MatrixMultiplicationException : " + e.getMessage(), e);
            return;
        }
        double accuracy = this.trainedModel.getAccuracy();
        if (this.debug) {
            LOG.info(String.format("Evaluation TimeStamp [%d] Model : %s, Accuracy : %f",
                    evaluationIndex, Arrays.toString(dArr), accuracy));
        }
        this.context.write("window-evaluation-edge", Double.valueOf(accuracy));
    }
}
