package edu.iu.dsc.tws.examples.ml.svm.job;

import java.util.Arrays;
import java.util.Random;
import java.util.concurrent.TimeUnit;
import java.util.logging.Level;
import java.util.logging.Logger;

import edu.iu.dsc.tws.api.comms.messaging.types.MessageTypes;
import edu.iu.dsc.tws.api.compute.IMessage;
import edu.iu.dsc.tws.api.compute.executor.ExecutionPlan;
import edu.iu.dsc.tws.api.compute.graph.ComputeGraph;
import edu.iu.dsc.tws.api.compute.graph.OperationMode;
import edu.iu.dsc.tws.api.dataset.DataObject;
import edu.iu.dsc.tws.examples.ml.svm.aggregate.IterativeAccuracyReduceFunction;
import edu.iu.dsc.tws.examples.ml.svm.aggregate.IterativeSVMAccuracyReduce;
import edu.iu.dsc.tws.examples.ml.svm.aggregate.IterativeSVMWeightVectorReduce;
import edu.iu.dsc.tws.examples.ml.svm.aggregate.ReduceAggregator;
import edu.iu.dsc.tws.examples.ml.svm.compute.IterativeStreamingCompute;
import edu.iu.dsc.tws.examples.ml.svm.compute.window.IterativeStreamingSinkEvaluator;
import edu.iu.dsc.tws.examples.ml.svm.compute.window.IterativeStreamingWindowedCompute;
import edu.iu.dsc.tws.examples.ml.svm.constant.Constants;
import edu.iu.dsc.tws.examples.ml.svm.constant.IterativeSVMConstants;
import edu.iu.dsc.tws.examples.ml.svm.exceptions.MatrixMultiplicationException;
import edu.iu.dsc.tws.examples.ml.svm.math.Matrix;
import edu.iu.dsc.tws.examples.ml.svm.streamer.IterativeDataStream;
import edu.iu.dsc.tws.examples.ml.svm.streamer.IterativePredictionDataStreamer;
import edu.iu.dsc.tws.examples.ml.svm.streamer.IterativeStreamingDataStreamer;
import edu.iu.dsc.tws.examples.ml.svm.util.BinaryBatchModel;
import edu.iu.dsc.tws.examples.ml.svm.util.DataUtils;
import edu.iu.dsc.tws.examples.ml.svm.util.SVMJobParameters;
import edu.iu.dsc.tws.examples.ml.svm.util.TGUtils;
import edu.iu.dsc.tws.examples.ml.svm.util.TrainedModel;
import edu.iu.dsc.tws.examples.ml.svm.util.WindowArguments;
import edu.iu.dsc.tws.task.impl.ComputeConnection;
import edu.iu.dsc.tws.task.impl.ComputeGraphBuilder;
import edu.iu.dsc.tws.task.impl.TaskWorker;
import edu.iu.dsc.tws.task.window.api.IWindowMessage;
import edu.iu.dsc.tws.task.window.constant.WindowType;
import edu.iu.dsc.tws.task.window.core.BaseWindowedSink;
import edu.iu.dsc.tws.task.window.function.ProcessWindowedFunction;

/* loaded from: input_file:edu/iu/dsc/tws/examples/ml/svm/job/SvmSgdOnlineRunner.class */
/**
 * Worker that trains a binary SVM model from a stream of training points using a windowed
 * Twister2 streaming task graph.
 *
 * <p>{@link #execute()} chains the phases:
 * <ol>
 *   <li>{@link #initialize()} reads {@link SVMJobParameters} and seeds the model.</li>
 *   <li>{@link #paramCheck()} logs the parameters.</li>
 *   <li>{@link #loadData()} runs batch graphs that load training and testing points.</li>
 *   <li>{@link #stream()} loads the initial weight vector, then builds and executes the
 *       streaming training graph.</li>
 *   <li>{@link #summary()} logs timings and saves the model.</li>
 * </ol>
 */
public class SvmSgdOnlineRunner extends TaskWorker {
    private static final String DELIMITER = ",";
    /** Bytes per mebibyte, used to report JVM memory in MB. */
    private static final double BYTES_PER_MB = 1048576.0d;
    private static final Logger LOG = Logger.getLogger(SvmSgdOnlineRunner.class.getName());
    // NOTE(review): several fields below (count, debug, testingBuilder, iterativeDataStream,
    // iterativePredictionDataStreamer, the reduce/graph/plan holders, ...) are never read in
    // this class; they are kept only so the declarations and their imports stay intact.
    private static int count = 0;
    private OperationMode operationMode;
    private SVMJobParameters svmJobParameters;
    private BinaryBatchModel binaryBatchModel;
    private ComputeGraphBuilder trainingBuilder;
    private ComputeGraphBuilder testingBuilder;
    private ComputeGraph iterativeSVMTrainingTaskGraph;
    private ExecutionPlan iterativeSVMTrainingExecutionPlan;
    private ComputeGraph iterativeSVMTestingTaskGraph;
    private ExecutionPlan iterativeSVMTestingExecutionPlan;
    private ComputeGraph weightVectorTaskGraph;
    private ExecutionPlan weightVectorExecutionPlan;
    private IterativeDataStream iterativeDataStream;
    private IterativeStreamingDataStreamer iterativeStreamingDataStreamer;
    private IterativeStreamingCompute iterativeStreamingCompute;
    private IterativeStreamingWindowedCompute iterativeStreamingWindowedCompute;
    private IterativePredictionDataStreamer iterativePredictionDataStreamer;
    private IterativeSVMAccuracyReduce iterativeSVMAccuracyReduce;
    private IterativeSVMWeightVectorReduce iterativeSVMRiterativeSVMWeightVectorReduce;
    private DataObject<double[][]> trainingDoubleDataPointObject;
    private DataObject<double[][]> testingDoubleDataPointObject;
    private DataObject<double[]> inputDoubleWeightvectorObject;
    private DataObject<double[]> currentDataPoint;
    private DataObject<Double> finalAccuracyDoubleObject;
    private int dataStreamerParallelism = 4;
    private int svmComputeParallelism = 4;
    private int features = 10;
    // Phase durations in nanoseconds (System.nanoTime deltas).
    private long initializingTime = 0;
    private long dataLoadingTime = 0;
    private long trainingTime = 0;
    // NOTE(review): testingTime is never assigned in this class, so it always reports 0.
    private long testingTime = 0;
    // Phase durations in seconds, derived by convert2Seconds().
    private double initializingDTime = 0.0d;
    private double dataLoadingDTime = 0.0d;
    private double trainingDTime = 0.0d;
    private double testingDTime = 0.0d;
    private double totalDTime = 0.0d;
    // NOTE(review): accuracy is never assigned in this class, so the summary and the saved
    // model always report 0.0 accuracy.
    private double accuracy = 0.0d;
    private boolean debug = false;
    private String experimentName = "";
    /** Shared RNG for choosing which sample row to log; avoids a new Random per partition. */
    private final Random random = new Random();

    /**
     * Window function used by the streaming windowed compute.
     *
     * <p>Window and late messages pass through unchanged. Two {@code double[]} values are
     * combined with {@link Matrix#add}.
     */
    public static class ProcessWindowFunctionImpl implements ProcessWindowedFunction<double[]> {
        private static final long serialVersionUID = 8517840191276879034L;
        private static final Logger LOG = Logger.getLogger(ProcessWindowFunctionImpl.class.getName());

        protected ProcessWindowFunctionImpl() {
        }

        /** Returns the window message unchanged. */
        public IWindowMessage<double[]> process(IWindowMessage<double[]> iWindowMessage) {
            return iWindowMessage;
        }

        /** Returns the late message unchanged. */
        public IMessage<double[]> processLateMessage(IMessage<double[]> iMessage) {
            return iMessage;
        }

        /**
         * Combines two vectors with {@link Matrix#add}.
         *
         * @param dArr first vector
         * @param dArr2 second vector
         * @return the combined vector, or {@code null} if {@code Matrix.add} throws
         *     {@link MatrixMultiplicationException} (the failure is logged with its stack trace)
         */
        public double[] onMessage(double[] dArr, double[] dArr2) {
            try {
                return Matrix.add(dArr, dArr2);
            } catch (MatrixMultiplicationException e) {
                LOG.log(Level.SEVERE, String.format("Math Error : %s", e.getMessage()), e);
                return null;
            }
        }
    }

    /** Runs the whole online-training pipeline: initialize, log params, load, stream, summarize. */
    public void execute() {
        initialize().paramCheck().loadData().stream().summary();
    }

    /**
     * Reads job parameters from {@code config} and prepares the model and graph builders.
     *
     * <p>It does the following:
     * <ul>
     *   <li>Copies iterations, alpha, features and samples into a fresh {@link BinaryBatchModel}.</li>
     *   <li>Seeds the model weights with {@link DataUtils#seedDoubleArray}.</li>
     *   <li>Selects STREAMING or BATCH mode from {@code isStreaming()}.</li>
     *   <li>Creates the training and testing graph builders.</li>
     * </ul>
     */
    public void initializeParameters() {
        this.svmJobParameters = SVMJobParameters.build(this.config);
        this.binaryBatchModel = new BinaryBatchModel();
        this.dataStreamerParallelism = this.svmJobParameters.getParallelism();
        this.experimentName = this.svmJobParameters.getExperimentName();
        this.svmComputeParallelism = this.dataStreamerParallelism;
        this.features = this.svmJobParameters.getFeatures();
        this.binaryBatchModel.setIterations(this.svmJobParameters.getIterations());
        this.binaryBatchModel.setAlpha(this.svmJobParameters.getAlpha());
        this.binaryBatchModel.setFeatures(this.svmJobParameters.getFeatures());
        this.binaryBatchModel.setSamples(this.svmJobParameters.getSamples());
        this.binaryBatchModel.setW(DataUtils.seedDoubleArray(this.svmJobParameters.getFeatures()));
        LOG.info(this.binaryBatchModel.toString());
        this.operationMode = this.svmJobParameters.isStreaming() ? OperationMode.STREAMING : OperationMode.BATCH;
        this.trainingBuilder = ComputeGraphBuilder.newBuilder(this.config);
        this.testingBuilder = ComputeGraphBuilder.newBuilder(this.config);
    }

    /** Runs {@link #initializeParameters()} and records its duration. */
    public SvmSgdOnlineRunner initialize() {
        long startTime = System.nanoTime();
        initializeParameters();
        this.initializingTime = System.nanoTime() - startTime;
        return this;
    }

    /** Loads training and testing data points and records the combined duration. */
    public SvmSgdOnlineRunner loadData() {
        long startTime = System.nanoTime();
        loadTrainingData();
        loadTestingData();
        this.dataLoadingTime = System.nanoTime() - startTime;
        return this;
    }

    /** Loads the initial weight vector; its duration is added to the data-loading time. */
    public SvmSgdOnlineRunner withWeightVector() {
        long startTime = System.nanoTime();
        loadWeightVector();
        this.dataLoadingTime += System.nanoTime() - startTime;
        return this;
    }

    /** Logs the run summary and saves the trained model. */
    public SvmSgdOnlineRunner summary() {
        generateSummary();
        return this;
    }

    /**
     * Loads the weight vector, then builds and executes the streaming training graph.
     *
     * <p>The duration of the streaming execution is recorded as the training time. Without this,
     * the summary and the saved model would always report a training time of zero.
     */
    public SvmSgdOnlineRunner stream() {
        withWeightVector();
        long startTime = System.nanoTime();
        streamData();
        this.trainingTime = System.nanoTime() - startTime;
        return this;
    }

    /** Logs the parsed job parameters. */
    public SvmSgdOnlineRunner paramCheck() {
        LOG.info(String.format("Info : %s ", this.svmJobParameters.toString()));
        return this;
    }

    /** Runs the training-data loading graph and logs the shape and one sample of each partition. */
    private void loadTrainingData() {
        ComputeGraph graph = TGUtils.buildTrainingDataPointsTG(this.dataStreamerParallelism, this.svmJobParameters, this.config, this.operationMode);
        ExecutionPlan plan = this.taskExecutor.plan(graph);
        this.taskExecutor.execute(graph, plan);
        this.trainingDoubleDataPointObject = this.taskExecutor.getOutput(graph, plan, Constants.SimpleGraphConfig.DATA_OBJECT_SINK);
        for (int i = 0; i < this.trainingDoubleDataPointObject.getPartitions().length; i++) {
            double[][] points = (double[][]) this.trainingDoubleDataPointObject.getPartitions()[i].getConsumer().next();
            logPartitionSample("Training Datapoints", points);
        }
    }

    /** Runs the testing-data loading graph and logs the shape and one sample of each partition. */
    private void loadTestingData() {
        ComputeGraph graph = TGUtils.buildTestingDataPointsTG(this.dataStreamerParallelism, this.svmJobParameters, this.config, this.operationMode);
        ExecutionPlan plan = this.taskExecutor.plan(graph);
        this.taskExecutor.execute(graph, plan);
        this.testingDoubleDataPointObject = this.taskExecutor.getOutput(graph, plan, Constants.SimpleGraphConfig.DATA_OBJECT_SINK_TESTING);
        for (int i = 0; i < this.testingDoubleDataPointObject.getPartitions().length; i++) {
            double[][] points = (double[][]) this.testingDoubleDataPointObject.getPartitions()[i].getConsumer().next();
            logPartitionSample(String.format("Partition[%d] Testing Datapoints", i), points);
        }
    }

    /**
     * Logs the dimensions of a partition and one randomly chosen row from it.
     *
     * <p>The row index is bounded by the partition's actual length, not by the configured sample
     * count. The configured count can exceed a partition's size, or divide down to zero, so using
     * it could index out of bounds or throw from {@code nextInt}.
     *
     * @param label prefix of the shape log line
     * @param points rows of the partition; may be null or empty
     */
    private void logPartitionSample(String label, double[][] points) {
        if (points == null || points.length == 0) {
            LOG.info(String.format("%s : empty partition", label));
            return;
        }
        LOG.info(String.format("%s : %d,%d", label, points.length, points[0].length));
        int index = this.random.nextInt(points.length);
        LOG.info(String.format("Random DataPoint[%d] : %s", index, Arrays.toString(points[index])));
    }

    /** Runs the weight-vector graph and logs the vector from the first partition. */
    private void loadWeightVector() {
        this.weightVectorTaskGraph = TGUtils.buildWeightVectorTG(this.config, this.dataStreamerParallelism, this.svmJobParameters, this.operationMode);
        this.weightVectorExecutionPlan = this.taskExecutor.plan(this.weightVectorTaskGraph);
        this.taskExecutor.execute(this.weightVectorTaskGraph, this.weightVectorExecutionPlan);
        this.inputDoubleWeightvectorObject = this.taskExecutor.getOutput(this.weightVectorTaskGraph, this.weightVectorExecutionPlan, Constants.SimpleGraphConfig.WEIGHT_VECTOR_OBJECT_SINK);
        LOG.info(String.format("Weight Vector Loaded : %s", Arrays.toString((double[]) this.inputDoubleWeightvectorObject.getPartitions()[0].getConsumer().next())));
    }

    /**
     * Logs memory usage, phase timings and accuracy, then saves the model.
     *
     * <p>The report is assembled with a {@link StringBuilder} rather than passed to
     * {@code String.format} as a format string. Otherwise a {@code %} in the experiment name
     * or a data directory would make the call throw.
     */
    private void generateSummary() {
        Runtime runtime = Runtime.getRuntime();
        double totalMemoryMb = runtime.totalMemory() / BYTES_PER_MB;
        double maxMemoryMb = runtime.maxMemory() / BYTES_PER_MB;
        convert2Seconds();
        // NOTE(review): the total also includes initialization time, which the label omits.
        this.totalDTime = this.initializingDTime + this.dataLoadingDTime + this.trainingDTime + this.testingDTime;
        String rule = "======================================================================================\n";
        StringBuilder report = new StringBuilder("\n\n")
            .append(rule)
            .append("\t\t\tIterative SVM Task Summary : [").append(this.experimentName).append("]\n")
            .append(rule)
            .append("Training Dataset [").append(this.svmJobParameters.getTrainingDataDir()).append("] \n")
            .append("Testing  Dataset [").append(this.svmJobParameters.getTestingDataDir()).append("] \n")
            .append("Total Memory [ ").append(totalMemoryMb).append(" MB] \n")
            .append("Maximum Memory [ ").append(maxMemoryMb).append(" MB] \n")
            .append("Data Loading Time (Training + Testing) \t\t\t\t= ").append(String.format("%3.9f", this.dataLoadingDTime)).append("  s \n")
            .append("Training Time \t\t\t\t\t\t\t= ").append(String.format("%3.9f", this.trainingDTime)).append("  s \n")
            .append("Testing Time  \t\t\t\t\t\t\t= ").append(String.format("%3.9f", this.testingDTime)).append("  s \n")
            .append("Total Time (Data Loading Time + Training Time + Testing Time) \t=").append(String.format(" %.9f", this.totalDTime)).append("  s \n")
            .append(String.format("Accuracy of the Trained Model \t\t\t\t\t= %2.9f", this.accuracy)).append(" %\n")
            .append(rule);
        LOG.info(report.toString());
        save();
    }

    /** Persists the model, accuracy and training time under the configured model directory. */
    private void save() {
        new TrainedModel(this.binaryBatchModel, this.accuracy, this.trainingTime, this.svmJobParameters.getExperimentName() + "-online", this.svmJobParameters.getParallelism()).saveModel(this.svmJobParameters.getModelSaveDir());
    }

    /** Converts the nanosecond phase durations into seconds. */
    private void convert2Seconds() {
        this.initializingDTime = this.initializingTime / 1.0E9d;
        this.dataLoadingDTime = this.dataLoadingTime / 1.0E9d;
        this.trainingDTime = this.trainingTime / 1.0E9d;
        this.testingDTime = this.testingTime / 1.0E9d;
    }

    /**
     * Builds the streaming training graph and executes it.
     *
     * <p>Inputs wired before execution:
     * <ul>
     *   <li>The data-streamer source receives the training points and the initial weight vector.</li>
     *   <li>The "window-sink" task receives the testing points.</li>
     * </ul>
     */
    private void streamData() {
        ComputeGraph graph = buildStreamingTrainingTG();
        ExecutionPlan plan = this.taskExecutor.plan(graph);
        this.taskExecutor.addInput(graph, plan, Constants.SimpleGraphConfig.ITERATIVE_STREAMING_DATASTREAMER_SOURCE, Constants.SimpleGraphConfig.INPUT_DATA, this.trainingDoubleDataPointObject);
        this.taskExecutor.addInput(graph, plan, Constants.SimpleGraphConfig.ITERATIVE_STREAMING_DATASTREAMER_SOURCE, Constants.SimpleGraphConfig.INPUT_WEIGHT_VECTOR, this.inputDoubleWeightvectorObject);
        this.taskExecutor.addInput(graph, plan, "window-sink", "test_data", this.testingDoubleDataPointObject);
        this.taskExecutor.execute(graph, plan);
    }

    /**
     * Builds the streaming training graph: source -> windowed compute -> compute -> evaluator.
     *
     * <p>Edges:
     * <ul>
     *   <li>The windowed compute is fed by a direct edge carrying DOUBLE_ARRAY.</li>
     *   <li>"window-sink" uses an allreduce with {@link ReduceAggregator}.</li>
     *   <li>The evaluation sink uses an allreduce with {@link IterativeAccuracyReduceFunction}
     *       over DOUBLE.</li>
     * </ul>
     */
    private ComputeGraph buildStreamingTrainingTG() {
        this.iterativeStreamingDataStreamer = new IterativeStreamingDataStreamer(this.svmJobParameters.getFeatures(), OperationMode.STREAMING, this.svmJobParameters.isDummy(), this.binaryBatchModel);
        BaseWindowedSink windowSinkInstance = getWindowSinkInstance();
        this.iterativeStreamingCompute = new IterativeStreamingCompute(OperationMode.STREAMING, new ReduceAggregator(), this.svmJobParameters);
        IterativeStreamingSinkEvaluator sinkEvaluator = new IterativeStreamingSinkEvaluator();
        this.trainingBuilder.addSource(Constants.SimpleGraphConfig.ITERATIVE_STREAMING_DATASTREAMER_SOURCE, this.iterativeStreamingDataStreamer, this.dataStreamerParallelism);
        // NOTE(review): the task named ITERATIVE_STREAMING_SVM_COMPUTE hosts the windowed sink,
        // while "window-sink" hosts IterativeStreamingCompute; the names look swapped but the
        // wiring is preserved as-is.
        ComputeConnection windowedConnection = this.trainingBuilder.addCompute(Constants.SimpleGraphConfig.ITERATIVE_STREAMING_SVM_COMPUTE, windowSinkInstance, this.dataStreamerParallelism);
        ComputeConnection computeConnection = this.trainingBuilder.addCompute("window-sink", this.iterativeStreamingCompute, this.dataStreamerParallelism);
        ComputeConnection evaluatorConnection = this.trainingBuilder.addCompute("window-evaluation-sink", sinkEvaluator, this.dataStreamerParallelism);
        windowedConnection.direct(Constants.SimpleGraphConfig.ITERATIVE_STREAMING_DATASTREAMER_SOURCE).viaEdge(Constants.SimpleGraphConfig.STREAMING_EDGE).withDataType(MessageTypes.DOUBLE_ARRAY);
        computeConnection.allreduce(Constants.SimpleGraphConfig.ITERATIVE_STREAMING_SVM_COMPUTE).viaEdge("window-sink-edge").withReductionFunction(new ReduceAggregator()).withDataType(MessageTypes.DOUBLE_ARRAY);
        evaluatorConnection.allreduce("window-sink").viaEdge("window-evaluation-edge").withReductionFunction(new IterativeAccuracyReduceFunction()).withDataType(MessageTypes.DOUBLE);
        this.trainingBuilder.setMode(OperationMode.STREAMING);
        this.trainingBuilder.setTaskGraphName(IterativeSVMConstants.ITERATIVE_STREAMING_TRAINING_TASK_GRAPH);
        return this.trainingBuilder.build();
    }

    /**
     * Creates the windowed compute and configures its window from the job's window arguments.
     *
     * <p>If no window arguments are configured, the compute is returned without a window. Only
     * TUMBLING and SLIDING windows are handled; any other (or null) window type is logged and
     * left unconfigured instead of throwing.
     */
    private BaseWindowedSink getWindowSinkInstance() {
        IterativeStreamingWindowedCompute windowedCompute = new IterativeStreamingWindowedCompute(new ProcessWindowFunctionImpl(), OperationMode.STREAMING, this.svmJobParameters, this.binaryBatchModel, "online-training-graph");
        WindowArguments windowArguments = this.svmJobParameters.getWindowArguments();
        if (windowArguments == null) {
            return windowedCompute;
        }
        // NOTE(review): duration windows are interpreted in microseconds — confirm intended.
        TimeUnit timeUnit = TimeUnit.MICROSECONDS;
        WindowType windowType = windowArguments.getWindowType();
        if (windowType == WindowType.TUMBLING) {
            if (windowArguments.isDuration()) {
                windowedCompute.withTumblingDurationWindow(windowArguments.getWindowLength(), timeUnit);
            } else {
                windowedCompute.withTumblingCountWindow(windowArguments.getWindowLength());
            }
        } else if (windowType == WindowType.SLIDING) {
            if (windowArguments.isDuration()) {
                windowedCompute.withSlidingDurationWindow(windowArguments.getWindowLength(), timeUnit, windowArguments.getSlidingLength(), timeUnit);
            } else {
                windowedCompute.withSlidingCountWindow(windowArguments.getWindowLength(), windowArguments.getSlidingLength());
            }
        } else {
            LOG.warning(String.format("Unsupported window type %s; window left unconfigured", windowType));
        }
        return windowedCompute;
    }
}
