package edu.iu.dsc.tws.examples.ml.svm.job;

import edu.iu.dsc.tws.api.comms.messaging.types.MessageTypes;
import edu.iu.dsc.tws.api.compute.executor.ExecutionPlan;
import edu.iu.dsc.tws.api.compute.executor.IExecutor;
import edu.iu.dsc.tws.api.compute.graph.ComputeGraph;
import edu.iu.dsc.tws.api.compute.graph.OperationMode;
import edu.iu.dsc.tws.api.dataset.DataObject;
import edu.iu.dsc.tws.api.dataset.DataPartitionConsumer;
import edu.iu.dsc.tws.examples.ml.svm.aggregate.ReduceAggregator;
import edu.iu.dsc.tws.examples.ml.svm.aggregate.SVMReduce;
import edu.iu.dsc.tws.examples.ml.svm.compute.IterativeSVMCompute;
import edu.iu.dsc.tws.examples.ml.svm.compute.SVMCompute;
import edu.iu.dsc.tws.examples.ml.svm.constant.Constants;
import edu.iu.dsc.tws.examples.ml.svm.streamer.InputDataStreamer;
import edu.iu.dsc.tws.examples.ml.svm.test.PredictionAggregator;
import edu.iu.dsc.tws.examples.ml.svm.test.PredictionReduceTask;
import edu.iu.dsc.tws.examples.ml.svm.test.PredictionSourceTask;
import edu.iu.dsc.tws.examples.ml.svm.util.BinaryBatchModel;
import edu.iu.dsc.tws.examples.ml.svm.util.DataUtils;
import edu.iu.dsc.tws.examples.ml.svm.util.ResultsSaver;
import edu.iu.dsc.tws.examples.ml.svm.util.SVMJobParameters;
import edu.iu.dsc.tws.task.dataobjects.DataObjectSink;
import edu.iu.dsc.tws.task.dataobjects.DataObjectSource;
import edu.iu.dsc.tws.task.impl.ComputeConnection;
import edu.iu.dsc.tws.task.impl.ComputeGraphBuilder;
import edu.iu.dsc.tws.task.impl.TaskWorker;
import java.io.IOException;
import java.util.Arrays;
import java.util.logging.Level;
import java.util.logging.Logger;

/* loaded from: input_file:edu/iu/dsc/tws/examples/ml/svm/job/SvmSgdAdvancedRunner.class */
public class SvmSgdAdvancedRunner extends TaskWorker {
    private static final Logger LOG = Logger.getLogger(SvmSgdAdvancedRunner.class.getName());
    private static final double NANO_TO_SEC = 1.0E9d;
    private static final double B2MB = 1048576.0d;
    private OperationMode operationMode;
    private SVMJobParameters svmJobParameters;
    private BinaryBatchModel binaryBatchModel;
    private ComputeGraphBuilder trainingBuilder;
    private ComputeGraphBuilder testingBuilder;
    private InputDataStreamer dataStreamer;
    private SVMCompute svmCompute;
    private IterativeSVMCompute iterativeSVMCompute;
    private SVMReduce svmReduce;
    private DataObject<Object> trainingData;
    private DataObject<Object> inputWeightVector;
    private DataObject<Object> testingData;
    private DataObject<Object> testingResults;
    private DataObject<double[]> trainedWeightVector;
    private PredictionSourceTask predictionSourceTask;
    private PredictionReduceTask predictionReduceTask;
    private PredictionAggregator predictionAggregator;
    private final int reduceParallelism = 1;
    private int dataStreamerParallelism = 4;
    private int svmComputeParallelism = 4;
    private int features = 10;
    private double dataLoadingTime = 0.0d;
    private double trainingTime = 0.0d;
    private double testingTime = 0.0d;
    private double accuracy = 0.0d;
    private boolean debug = false;
    private String experimentName = "";

    public void execute() {
        initializeParameters();
        initializeExecute();
    }

    public void initializeParameters() {
        this.svmJobParameters = SVMJobParameters.build(this.config);
        this.binaryBatchModel = new BinaryBatchModel();
        this.dataStreamerParallelism = this.svmJobParameters.getParallelism();
        this.experimentName = this.svmJobParameters.getExperimentName();
        this.svmComputeParallelism = this.dataStreamerParallelism;
        this.binaryBatchModel.setIterations(this.svmJobParameters.getIterations());
        this.binaryBatchModel.setAlpha(this.svmJobParameters.getAlpha());
        this.binaryBatchModel.setFeatures(this.svmJobParameters.getFeatures());
        this.binaryBatchModel.setSamples(this.svmJobParameters.getSamples());
        this.binaryBatchModel.setW(DataUtils.seedDoubleArray(this.svmJobParameters.getFeatures()));
        LOG.info(this.binaryBatchModel.toString());
    }

    public void initializeExecute() {
        this.trainingBuilder = ComputeGraphBuilder.newBuilder(this.config);
        this.testingBuilder = ComputeGraphBuilder.newBuilder(this.config);
        this.operationMode = this.svmJobParameters.isStreaming() ? OperationMode.STREAMING : OperationMode.BATCH;
        this.inputWeightVector = executeWeightVectorLoadingTaskGraph();
        Long valueOf = Long.valueOf(System.nanoTime());
        this.trainingData = executeTrainingDataLoadingTaskGraph();
        this.dataLoadingTime = (System.nanoTime() - valueOf.longValue()) / 1.0E9d;
        Long valueOf2 = Long.valueOf(System.nanoTime());
        this.trainedWeightVector = executeTrainingGraph();
        this.trainingTime = (System.nanoTime() - valueOf2.longValue()) / 1.0E9d;
        Long valueOf3 = Long.valueOf(System.nanoTime());
        this.testingData = executeTestingDataLoadingTaskGraph();
        this.dataLoadingTime += (System.nanoTime() - valueOf3.longValue()) / 1.0E9d;
        printTaskSummary();
        if (this.workerId == 0) {
            try {
                saveResults();
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
        LOG.info(String.format("Rank[%d] Total Memory %f MB, Max Memory %f MB", Integer.valueOf(this.workerId), Double.valueOf(Runtime.getRuntime().totalMemory() / 1048576.0d), Double.valueOf(Runtime.getRuntime().maxMemory() / 1048576.0d)));
        if (this.operationMode.equals(OperationMode.STREAMING)) {
            LOG.info("Not Yet Implemented");
        }
    }

    public DataObject<Object> executeTrainingDataLoadingTaskGraph() {
        DataObjectSource dataObjectSource = new DataObjectSource("direct", this.svmJobParameters.getTrainingDataDir());
        DataObjectSink dataObjectSink = new DataObjectSink();
        this.trainingBuilder.addSource(Constants.SimpleGraphConfig.DATA_OBJECT_SOURCE, dataObjectSource, this.dataStreamerParallelism);
        this.trainingBuilder.addCompute(Constants.SimpleGraphConfig.DATA_OBJECT_SINK, dataObjectSink, this.dataStreamerParallelism).direct(Constants.SimpleGraphConfig.DATA_OBJECT_SOURCE).viaEdge("direct").withDataType(MessageTypes.OBJECT);
        this.trainingBuilder.setMode(OperationMode.BATCH);
        ComputeGraph build = this.trainingBuilder.build();
        build.setGraphName("training-data-loading-graph");
        ExecutionPlan plan = this.taskExecutor.plan(build);
        this.taskExecutor.execute(build, plan);
        DataObject<Object> output = this.taskExecutor.getOutput(build, plan, Constants.SimpleGraphConfig.DATA_OBJECT_SINK);
        if (output == null) {
            throw new NullPointerException("Something Went Wrong in Loading Training Data");
        }
        LOG.info("Training Data Total Partitions : " + output.getPartitions().length);
        return output;
    }

    public DataObject<Object> executeWeightVectorLoadingTaskGraph() {
        DataObjectSource dataObjectSource = new DataObjectSource("direct", this.svmJobParameters.getWeightVectorDataDir());
        DataObjectSink dataObjectSink = new DataObjectSink();
        this.trainingBuilder.addSource(Constants.SimpleGraphConfig.DATA_OBJECT_SOURCE, dataObjectSource, this.dataStreamerParallelism);
        this.trainingBuilder.addCompute(Constants.SimpleGraphConfig.DATA_OBJECT_SINK, dataObjectSink, this.dataStreamerParallelism).direct(Constants.SimpleGraphConfig.DATA_OBJECT_SOURCE).viaEdge("direct").withDataType(MessageTypes.OBJECT);
        this.trainingBuilder.setMode(OperationMode.BATCH);
        ComputeGraph build = this.trainingBuilder.build();
        build.setGraphName("weight-vector-loading-graph");
        ExecutionPlan plan = this.taskExecutor.plan(build);
        this.taskExecutor.execute(build, plan);
        DataObject<Object> output = this.taskExecutor.getOutput(build, plan, Constants.SimpleGraphConfig.DATA_OBJECT_SINK);
        if (output == null) {
            throw new NullPointerException("Something Went Wrong in Loading Weight Vector");
        }
        LOG.info("Training Data Total Partitions : " + output.getPartitions().length);
        return output;
    }

    public DataObject<Object> executeTestingDataLoadingTaskGraph() {
        DataObjectSource dataObjectSource = new DataObjectSource("direct2", this.svmJobParameters.getTestingDataDir());
        DataObjectSink dataObjectSink = new DataObjectSink();
        this.testingBuilder.addSource(Constants.SimpleGraphConfig.DATA_OBJECT_SOURCE_TESTING, dataObjectSource, this.dataStreamerParallelism);
        this.testingBuilder.addCompute(Constants.SimpleGraphConfig.DATA_OBJECT_SINK_TESTING, dataObjectSink, this.dataStreamerParallelism).direct(Constants.SimpleGraphConfig.DATA_OBJECT_SOURCE_TESTING).viaEdge("direct2").withDataType(MessageTypes.OBJECT);
        this.testingBuilder.setMode(OperationMode.BATCH);
        ComputeGraph build = this.testingBuilder.build();
        build.setGraphName("testing-data-loading-graph");
        ExecutionPlan plan = this.taskExecutor.plan(build);
        this.taskExecutor.execute(build, plan);
        DataObject<Object> output = this.taskExecutor.getOutput(build, plan, Constants.SimpleGraphConfig.DATA_OBJECT_SINK_TESTING);
        if (output == null) {
            throw new NullPointerException("Something Went Wrong in Loading Testing Data");
        }
        LOG.info("Testing Data Total Partitions : " + output.getPartitions().length);
        return output;
    }

    public DataObject<double[]> executeTrainingGraph() {
        DataObject<double[]> dataObject = null;
        this.dataStreamer = new InputDataStreamer(this.operationMode, this.svmJobParameters.isDummy(), this.binaryBatchModel);
        this.svmCompute = new SVMCompute(this.binaryBatchModel, this.operationMode);
        this.svmReduce = new SVMReduce(this.operationMode);
        this.trainingBuilder.addSource(Constants.SimpleGraphConfig.DATASTREAMER_SOURCE, this.dataStreamer, this.dataStreamerParallelism);
        ComputeConnection addCompute = this.trainingBuilder.addCompute(Constants.SimpleGraphConfig.SVM_COMPUTE, this.svmCompute, this.svmComputeParallelism);
        ComputeConnection addCompute2 = this.trainingBuilder.addCompute(Constants.SimpleGraphConfig.SVM_REDUCE, this.svmReduce, 1);
        addCompute.direct(Constants.SimpleGraphConfig.DATASTREAMER_SOURCE).viaEdge(Constants.SimpleGraphConfig.DATA_EDGE).withDataType(MessageTypes.OBJECT);
        addCompute2.allreduce(Constants.SimpleGraphConfig.SVM_COMPUTE).viaEdge(Constants.SimpleGraphConfig.REDUCE_EDGE).withReductionFunction(new ReduceAggregator()).withDataType(MessageTypes.OBJECT);
        this.trainingBuilder.setMode(this.operationMode);
        ComputeGraph build = this.trainingBuilder.build();
        build.setGraphName("training-graph");
        ExecutionPlan plan = this.taskExecutor.plan(build);
        this.taskExecutor.addInput(build, plan, Constants.SimpleGraphConfig.DATASTREAMER_SOURCE, Constants.SimpleGraphConfig.INPUT_DATA, this.trainingData);
        this.taskExecutor.addInput(build, plan, Constants.SimpleGraphConfig.DATASTREAMER_SOURCE, Constants.SimpleGraphConfig.INPUT_WEIGHT_VECTOR, this.inputWeightVector);
        this.taskExecutor.execute(build, plan);
        LOG.info("Task Graph Executed !!! ");
        if (this.workerId == 0) {
            dataObject = retrieveWeightVectorFromTaskGraph(build, plan);
            this.trainedWeightVector = dataObject;
        }
        return dataObject;
    }

    public DataObject<double[]> executeIterativeTrainingGraph() {
        DataObject<double[]> dataObject = null;
        this.dataStreamer = new InputDataStreamer(this.operationMode, this.svmJobParameters.isDummy(), this.binaryBatchModel);
        this.iterativeSVMCompute = new IterativeSVMCompute(this.binaryBatchModel, this.operationMode);
        this.svmReduce = new SVMReduce(this.operationMode);
        this.trainingBuilder.addSource(Constants.SimpleGraphConfig.DATASTREAMER_SOURCE, this.dataStreamer, this.dataStreamerParallelism);
        ComputeConnection addCompute = this.trainingBuilder.addCompute(Constants.SimpleGraphConfig.SVM_COMPUTE, this.iterativeSVMCompute, this.svmComputeParallelism);
        ComputeConnection addCompute2 = this.trainingBuilder.addCompute(Constants.SimpleGraphConfig.SVM_REDUCE, this.svmReduce, 1);
        addCompute.direct(Constants.SimpleGraphConfig.DATASTREAMER_SOURCE).viaEdge(Constants.SimpleGraphConfig.DATA_EDGE).withDataType(MessageTypes.OBJECT);
        addCompute2.allreduce(Constants.SimpleGraphConfig.SVM_COMPUTE).viaEdge(Constants.SimpleGraphConfig.REDUCE_EDGE).withReductionFunction(new ReduceAggregator()).withDataType(MessageTypes.OBJECT);
        this.trainingBuilder.setMode(this.operationMode);
        ComputeGraph build = this.trainingBuilder.build();
        build.setGraphName("training-graph");
        ExecutionPlan plan = this.taskExecutor.plan(build);
        IExecutor createExecution = this.taskExecutor.createExecution(build, plan);
        for (int i = 0; i < this.binaryBatchModel.getIterations(); i++) {
            this.taskExecutor.addInput(build, plan, Constants.SimpleGraphConfig.DATASTREAMER_SOURCE, Constants.SimpleGraphConfig.INPUT_DATA, this.trainingData);
            this.taskExecutor.addInput(build, plan, Constants.SimpleGraphConfig.DATASTREAMER_SOURCE, Constants.SimpleGraphConfig.INPUT_WEIGHT_VECTOR, this.inputWeightVector);
            this.inputWeightVector = this.taskExecutor.getOutput(build, plan, Constants.SimpleGraphConfig.SVM_REDUCE);
            createExecution.execute();
        }
        createExecution.closeExecution();
        LOG.info("Task Graph Executed !!! ");
        if (this.workerId == 0) {
            dataObject = retrieveWeightVectorFromTaskGraph(build, plan);
            this.trainedWeightVector = dataObject;
        }
        return dataObject;
    }

    public DataObject<double[]> retrieveWeightVectorFromTaskGraph(ComputeGraph computeGraph, ExecutionPlan executionPlan) {
        DataObject<double[]> output = this.taskExecutor.getOutput(computeGraph, executionPlan, Constants.SimpleGraphConfig.SVM_REDUCE);
        if (this.debug) {
            LOG.info(String.format("Number of Partitions : %d ", Integer.valueOf(output.getPartitions().length)));
        }
        DataPartitionConsumer consumer = output.getPartitions()[0].getConsumer();
        while (consumer.hasNext()) {
            LOG.info("Final Weight Vector:" + Arrays.toString((double[]) consumer.next()));
        }
        if (output == null) {
            throw new NullPointerException(" Something went wrong in retrieving trained weights from training task graph");
        }
        return output;
    }

    public DataObject<Object> executeTestingTaskGraph() {
        this.predictionSourceTask = new PredictionSourceTask(this.svmJobParameters.isDummy(), this.binaryBatchModel, this.operationMode);
        this.predictionReduceTask = new PredictionReduceTask(this.operationMode);
        this.testingBuilder.addSource(Constants.SimpleGraphConfig.PREDICTION_SOURCE_TASK, this.predictionSourceTask, this.dataStreamerParallelism);
        this.testingBuilder.addCompute(Constants.SimpleGraphConfig.PREDICTION_REDUCE_TASK, this.predictionReduceTask, 1).reduce(Constants.SimpleGraphConfig.PREDICTION_SOURCE_TASK).viaEdge(Constants.SimpleGraphConfig.PREDICTION_EDGE).withReductionFunction(new PredictionAggregator()).withDataType(MessageTypes.OBJECT);
        this.testingBuilder.setMode(this.operationMode);
        ComputeGraph build = this.testingBuilder.build();
        build.setGraphName("testing-graph");
        ExecutionPlan plan = this.taskExecutor.plan(build);
        this.taskExecutor.addInput(build, plan, Constants.SimpleGraphConfig.PREDICTION_SOURCE_TASK, "test_data", this.testingData);
        this.taskExecutor.addInput(build, plan, Constants.SimpleGraphConfig.PREDICTION_SOURCE_TASK, Constants.SimpleGraphConfig.FINAL_WEIGHT_VECTOR, this.trainedWeightVector);
        this.taskExecutor.execute(build, plan);
        return retrieveTestingAccuracyObject(build, plan);
    }

    public DataObject<Object> retrieveTestingAccuracyObject(ComputeGraph computeGraph, ExecutionPlan executionPlan) {
        return this.taskExecutor.getOutput(computeGraph, executionPlan, Constants.SimpleGraphConfig.PREDICTION_REDUCE_TASK);
    }

    public double retriveFinalTestingAccuracy(DataObject<Object> dataObject) {
        double d = 0.0d;
        Object next = dataObject.getPartitions()[0].getConsumer().next();
        if (next instanceof Double) {
            d = ((Double) next).doubleValue() / this.dataStreamerParallelism;
            LOG.info(String.format("Testing Accuracy  : %f ", Double.valueOf(d)));
        } else {
            LOG.severe("Something Went Wrong In Calculating Testing Accuracy");
        }
        return d;
    }

    public void printTaskSummary() {
        LOG.info(String.format((((((((((((("\n\n======================================================================================\n") + "\t\t\tSVM Task Summary : [" + this.experimentName + "]\n") + "======================================================================================\n") + "Training Dataset [" + this.svmJobParameters.getTrainingDataDir() + "] \n") + "Testing  Dataset [" + this.svmJobParameters.getTestingDataDir() + "] \n") + "Total Memory [ " + (Runtime.getRuntime().totalMemory() / 1048576.0d) + " MB] \n") + "Maximum Memory [ " + (Runtime.getRuntime().totalMemory() / 1048576.0d) + " MB] \n") + "Data Loading Time (Training + Testing) \t\t\t\t= " + String.format("%3.9f", Double.valueOf(this.dataLoadingTime)) + "  s \n") + "Training Time \t\t\t\t\t\t\t= " + String.format("%3.9f", Double.valueOf(this.trainingTime)) + "  s \n") + "Testing Time  \t\t\t\t\t\t\t= " + String.format("%3.9f", Double.valueOf(this.testingTime)) + "  s \n") + "Total Time (Data Loading Time + Training Time + Testing Time) \t=" + String.format(" %.9f", Double.valueOf(this.dataLoadingTime + this.trainingTime + this.testingTime)) + "  s \n") + String.format("Accuracy of the Trained Model \t\t\t\t\t= %2.9f", Double.valueOf(this.accuracy)) + " %%\n") + "======================================================================================\n", new Object[0]));
    }

    public void saveResults() throws IOException {
        new ResultsSaver(this.trainingTime, this.testingTime, this.dataLoadingTime, this.dataLoadingTime + this.trainingTime + this.testingTime, this.svmJobParameters, Constants.SimpleGraphConfig.TASK_RUNNER).save();
    }
}
