package edu.iu.dsc.tws.examples.ml.svm.job;

import edu.iu.dsc.tws.api.comms.messaging.types.MessageTypes;
import edu.iu.dsc.tws.api.compute.executor.ExecutionPlan;
import edu.iu.dsc.tws.api.compute.executor.IExecutor;
import edu.iu.dsc.tws.api.compute.graph.ComputeGraph;
import edu.iu.dsc.tws.api.compute.graph.OperationMode;
import edu.iu.dsc.tws.api.dataset.DataObject;
import edu.iu.dsc.tws.examples.ml.svm.aggregate.IterativeAccuracyReduceFunction;
import edu.iu.dsc.tws.examples.ml.svm.aggregate.IterativeSVMAccuracyReduce;
import edu.iu.dsc.tws.examples.ml.svm.aggregate.IterativeSVMWeightVectorReduce;
import edu.iu.dsc.tws.examples.ml.svm.aggregate.IterativeWeightVectorReduceFunction;
import edu.iu.dsc.tws.examples.ml.svm.config.DataPartitionType;
import edu.iu.dsc.tws.examples.ml.svm.constant.Constants;
import edu.iu.dsc.tws.examples.ml.svm.constant.IterativeSVMConstants;
import edu.iu.dsc.tws.examples.ml.svm.data.DataPartitioner;
import edu.iu.dsc.tws.examples.ml.svm.data.IterativeSVMDataObjectCompute;
import edu.iu.dsc.tws.examples.ml.svm.data.IterativeSVMDataObjectDirectSink;
import edu.iu.dsc.tws.examples.ml.svm.data.IterativeSVMWeightVectorObjectCompute;
import edu.iu.dsc.tws.examples.ml.svm.data.IterativeSVMWeightVectorObjectDirectSink;
import edu.iu.dsc.tws.examples.ml.svm.data.SVMDataObjectSource;
import edu.iu.dsc.tws.examples.ml.svm.math.Matrix;
import edu.iu.dsc.tws.examples.ml.svm.streamer.IterativeDataStream;
import edu.iu.dsc.tws.examples.ml.svm.streamer.IterativePredictionDataStreamer;
import edu.iu.dsc.tws.examples.ml.svm.util.BinaryBatchModel;
import edu.iu.dsc.tws.examples.ml.svm.util.DataUtils;
import edu.iu.dsc.tws.examples.ml.svm.util.SVMJobParameters;
import edu.iu.dsc.tws.examples.ml.svm.util.TrainedModel;
import edu.iu.dsc.tws.task.dataobjects.DataFileReplicatedReadSource;
import edu.iu.dsc.tws.task.impl.ComputeConnection;
import edu.iu.dsc.tws.task.impl.ComputeGraphBuilder;
import edu.iu.dsc.tws.task.impl.TaskWorker;
import java.util.Arrays;
import java.util.logging.Logger;

/* loaded from: input_file:edu/iu/dsc/tws/examples/ml/svm/job/SvmSgdIterativeRunner.class */
public class SvmSgdIterativeRunner extends TaskWorker {
    /** Logger shared by all instances of this runner. */
    private static final Logger LOG = Logger.getLogger(SvmSgdIterativeRunner.class.getName());
    // Not referenced by name in this class; generateGenericDataPointLoader passes the same "," literal directly.
    private static final String DELIMITER = ",";
    // STREAMING or BATCH, selected in initializeParameters() from SVMJobParameters.isStreaming().
    private OperationMode operationMode;
    // Job configuration built from this.config in initializeParameters().
    private SVMJobParameters svmJobParameters;
    // Model settings (iterations, alpha, features, samples, seed weights) filled in initializeParameters().
    private BinaryBatchModel binaryBatchModel;
    // Graph builders created in initializeParameters(); used by the training / testing graph build methods.
    private ComputeGraphBuilder trainingBuilder;
    private ComputeGraphBuilder testingBuilder;
    // Training graph and plan, assigned in runTrainingGraph().
    private ComputeGraph iterativeSVMTrainingTaskGraph;
    private ExecutionPlan iterativeSVMTrainingExecutionPlan;
    // Prediction graph and plan, assigned in runPredictionGraph().
    private ComputeGraph iterativeSVMTestingTaskGraph;
    private ExecutionPlan iterativeSVMTestingExecutionPlan;
    // Weight-vector loading graph and plan, assigned in loadWeightVector().
    private ComputeGraph weightVectorTaskGraph;
    private ExecutionPlan weightVectorExecutionPlan;
    // Task instances created while building the training and prediction graphs.
    private IterativeDataStream iterativeDataStream;
    private IterativePredictionDataStreamer iterativePredictionDataStreamer;
    private IterativeSVMAccuracyReduce iterativeSVMAccuracyReduce;
    private IterativeSVMWeightVectorReduce iterativeSVMRiterativeSVMWeightVectorReduce;
    // Outputs of the data loading graphs (loadTrainingData() / loadTestingData()).
    private DataObject<double[][]> trainingDoubleDataPointObject;
    private DataObject<double[][]> testingDoubleDataPointObject;
    // Set by loadWeightVector(), then replaced after every training iteration in runTrainingGraph().
    private DataObject<double[]> inputDoubleWeightvectorObject;
    // Output of the prediction graph, assigned in runPredictionGraph().
    private DataObject<Double> finalAccuracyDoubleObject;
    // Defaults below are overwritten from the job parameters in initializeParameters().
    private int dataStreamerParallelism = 4;
    private int svmComputeParallelism = 4;
    private int features = 10;
    // Elapsed phase times measured as differences of System.nanoTime(), i.e. in nanoseconds.
    private long initializingTime = 0;
    private long dataLoadingTime = 0;
    private long trainingTime = 0;
    private long testingTime = 0;
    // The same phase times in seconds, filled by convert2Seconds().
    private double initializingDTime = 0.0d;
    private double dataLoadingDTime = 0.0d;
    private double trainingDTime = 0.0d;
    private double testingDTime = 0.0d;
    // Sum of the four phase times in seconds, computed in generateSummary().
    private double totalDTime = 0.0d;
    // Accuracy value read from the prediction graph output in runPredictionGraph().
    private double accuracy = 0.0d;
    // Never read or reassigned anywhere in this class.
    private boolean debug = false;
    // Experiment name taken from the job parameters; printed in the summary header.
    private String experimentName = "";

    /**
     * Runs the whole job: initialization, data loading, training, prediction
     * and finally the summary, strictly in that order.
     */
    public void execute() {
        SvmSgdIterativeRunner runner = initialize();
        runner = runner.loadData();
        runner = runner.train();
        runner = runner.predict();
        runner.summary();
    }

    /**
     * Reads the job parameters from the worker configuration and prepares the
     * model, the operation mode and the two graph builders used later on.
     */
    public void initializeParameters() {
        final SVMJobParameters params = SVMJobParameters.build(this.config);
        this.svmJobParameters = params;

        final BinaryBatchModel model = new BinaryBatchModel();
        this.binaryBatchModel = model;

        // The compute parallelism simply mirrors the data streamer parallelism.
        final int parallelism = params.getParallelism();
        this.dataStreamerParallelism = parallelism;
        this.experimentName = params.getExperimentName();
        this.svmComputeParallelism = parallelism;
        this.features = params.getFeatures();

        model.setIterations(params.getIterations());
        model.setAlpha(params.getAlpha());
        model.setFeatures(params.getFeatures());
        model.setSamples(params.getSamples());
        model.setW(DataUtils.seedDoubleArray(params.getFeatures()));
        LOG.info(model.toString());

        if (params.isStreaming()) {
            this.operationMode = OperationMode.STREAMING;
        } else {
            this.operationMode = OperationMode.BATCH;
        }

        this.trainingBuilder = ComputeGraphBuilder.newBuilder(this.config);
        this.testingBuilder = ComputeGraphBuilder.newBuilder(this.config);
    }

    /**
     * Initializes the job parameters and records how long that took
     * (nanoseconds) in {@code initializingTime}.
     *
     * @return this runner, for chaining
     */
    public SvmSgdIterativeRunner initialize() {
        final long startedAt = System.nanoTime();
        initializeParameters();
        final long finishedAt = System.nanoTime();
        this.initializingTime = finishedAt - startedAt;
        return this;
    }

    /**
     * Loads the training data followed by the testing data and stores the
     * combined elapsed time (nanoseconds) in {@code dataLoadingTime}.
     *
     * @return this runner, for chaining
     */
    public SvmSgdIterativeRunner loadData() {
        final long startedAt = System.nanoTime();
        loadTrainingData();
        loadTestingData();
        final long finishedAt = System.nanoTime();
        this.dataLoadingTime = finishedAt - startedAt;
        return this;
    }

    /**
     * Loads the initial weight vector and adds the time spent doing so to the
     * already accumulated {@code dataLoadingTime}.
     *
     * @return this runner, for chaining
     */
    public SvmSgdIterativeRunner withWeightVector() {
        final long startedAt = System.nanoTime();
        loadWeightVector();
        final long elapsed = System.nanoTime() - startedAt;
        this.dataLoadingTime = this.dataLoadingTime + elapsed;
        return this;
    }

    /**
     * Loads the weight vector, then runs the training graph. Only the graph
     * run itself is counted in {@code trainingTime}; the weight vector load is
     * accounted for as data loading by {@link #withWeightVector()}.
     *
     * @return this runner, for chaining
     */
    public SvmSgdIterativeRunner train() {
        withWeightVector();
        final long startedAt = System.nanoTime();
        runTrainingGraph();
        final long finishedAt = System.nanoTime();
        this.trainingTime = finishedAt - startedAt;
        return this;
    }

    /**
     * Runs the prediction graph and stores the elapsed time (nanoseconds) in
     * {@code testingTime}.
     *
     * @return this runner, for chaining
     */
    public SvmSgdIterativeRunner predict() {
        final long startedAt = System.nanoTime();
        runPredictionGraph();
        final long finishedAt = System.nanoTime();
        this.testingTime = finishedAt - startedAt;
        return this;
    }

    /**
     * Produces the end-of-job summary by delegating to
     * {@code generateSummary()}.
     *
     * @return this runner, for chaining
     */
    public SvmSgdIterativeRunner summary() {
        this.generateSummary();
        return this;
    }

    /**
     * Debug helper: partitions the configured number of samples with the
     * default partition type and logs the resulting partition map.
     * It is not called from anywhere else in this class.
     */
    private void testDataPartitionLogic() {
        final String partitionMap = new DataPartitioner()
                .withParallelism(this.dataStreamerParallelism)
                .withSamples(this.svmJobParameters.getSamples())
                .withPartitionType(DataPartitionType.DEFAULT)
                .withImbalancePartitionId(1)
                .partition()
                .getDataPartitionMap()
                .toString();
        LOG.info(String.format("Map Info : %s", partitionMap));
    }

    /**
     * Builds and executes the weight vector loading graph, keeps its sink
     * output as the current weight vector object and logs the vector found in
     * the first partition.
     */
    private void loadWeightVector() {
        final ComputeGraph graph = buildWeightVectorTG();
        this.weightVectorTaskGraph = graph;

        final ExecutionPlan plan = this.taskExecutor.plan(graph);
        this.weightVectorExecutionPlan = plan;

        this.taskExecutor.execute(graph, plan);
        this.inputDoubleWeightvectorObject = this.taskExecutor.getOutput(graph, plan, Constants.SimpleGraphConfig.WEIGHT_VECTOR_OBJECT_SINK);

        final double[] loadedWeights = (double[]) this.inputDoubleWeightvectorObject.getPartitions()[0].getConsumer().next();
        LOG.info(String.format("Weight Vector Loaded : %s", Arrays.toString(loadedWeights)));
    }

    /**
     * Builds and executes the training data loading graph and keeps its sink
     * output in {@code trainingDoubleDataPointObject}.
     *
     * <p>The dimensions of every partition are logged. Afterwards the number of
     * rows whose elements sum to zero is printed for the last partition that
     * was read. A missing partition list, a {@code null} partition payload or an
     * empty partition is reported with zero dimensions instead of failing with
     * a {@code NullPointerException} / {@code ArrayIndexOutOfBoundsException}.
     */
    private void loadTrainingData() {
        ComputeGraph trainingDataGraph = buildTrainingDataPointsTG();
        ExecutionPlan plan = this.taskExecutor.plan(trainingDataGraph);
        this.taskExecutor.execute(trainingDataGraph, plan);
        this.trainingDoubleDataPointObject = this.taskExecutor.getOutput(trainingDataGraph, plan, Constants.SimpleGraphConfig.DATA_OBJECT_SINK);

        double[][] lastPartition = null;
        for (int i = 0; i < this.trainingDoubleDataPointObject.getPartitions().length; i++) {
            lastPartition = (double[][]) this.trainingDoubleDataPointObject.getPartitions()[i].getConsumer().next();
            // Guard the dimension lookup: a partition may carry no rows at all.
            int rows = lastPartition == null ? 0 : lastPartition.length;
            int columns = rows == 0 ? 0 : lastPartition[0].length;
            LOG.info(String.format("Training Datapoints : %d,%d", rows, columns));
        }

        System.out.println("---------Training Data-------------");
        int rowCount = 0;
        int columnCount = 0;
        int zeroSumRows = 0;
        if (lastPartition != null && lastPartition.length > 0) {
            rowCount = lastPartition.length;
            columnCount = lastPartition[0].length;
            for (double[] row : lastPartition) {
                if (Matrix.sum(row) == 0.0d) {
                    zeroSumRows++;
                }
            }
        } else {
            LOG.warning("No training data points were loaded; zero-sum check skipped");
        }
        System.out.println(String.format("Training: %d,%d, Zero Sum: %d", rowCount, columnCount, zeroSumRows));
    }

    /**
     * Builds and executes the testing data loading graph and keeps its sink
     * output in {@code testingDoubleDataPointObject}.
     *
     * <p>The dimensions of every partition are logged. Afterwards the number of
     * rows whose elements sum to zero is printed for the last partition that
     * was read. A missing partition list, a {@code null} partition payload or an
     * empty partition is reported with zero dimensions instead of failing with
     * a {@code NullPointerException} / {@code ArrayIndexOutOfBoundsException}.
     */
    private void loadTestingData() {
        ComputeGraph testingDataGraph = buildTestingDataPointsTG();
        ExecutionPlan plan = this.taskExecutor.plan(testingDataGraph);
        this.taskExecutor.execute(testingDataGraph, plan);
        this.testingDoubleDataPointObject = this.taskExecutor.getOutput(testingDataGraph, plan, Constants.SimpleGraphConfig.DATA_OBJECT_SINK_TESTING);

        double[][] lastPartition = null;
        for (int i = 0; i < this.testingDoubleDataPointObject.getPartitions().length; i++) {
            lastPartition = (double[][]) this.testingDoubleDataPointObject.getPartitions()[i].getConsumer().next();
            // Guard the dimension lookup: a partition may carry no rows at all.
            int rows = lastPartition == null ? 0 : lastPartition.length;
            int columns = rows == 0 ? 0 : lastPartition[0].length;
            LOG.info(String.format("Partition[%d] Testing Datapoints : %d,%d", i, rows, columns));
        }

        System.out.println("---------Testing Data-------------");
        int rowCount = 0;
        int columnCount = 0;
        int zeroSumRows = 0;
        if (lastPartition != null && lastPartition.length > 0) {
            rowCount = lastPartition.length;
            columnCount = lastPartition[0].length;
            for (double[] row : lastPartition) {
                if (Matrix.sum(row) == 0.0d) {
                    zeroSumRows++;
                }
            }
        } else {
            LOG.warning("No testing data points were loaded; zero-sum check skipped");
        }
        System.out.println(String.format("%d,%d, Zero Sum: %d", rowCount, columnCount, zeroSumRows));
    }

    /**
     * Creates the graph that loads the training data points, using the
     * training sample count and training data directory from the job
     * parameters.
     *
     * @return the training data loading graph
     */
    private ComputeGraph buildTrainingDataPointsTG() {
        final int sampleCount = this.svmJobParameters.getSamples();
        final int parallelism = this.dataStreamerParallelism;
        final int featureCount = this.svmJobParameters.getFeatures();
        final String dataDirectory = this.svmJobParameters.getTrainingDataDir();
        return generateGenericDataPointLoader(
                sampleCount,
                parallelism,
                featureCount,
                dataDirectory,
                Constants.SimpleGraphConfig.DATA_OBJECT_SOURCE,
                Constants.SimpleGraphConfig.DATA_OBJECT_COMPUTE,
                Constants.SimpleGraphConfig.DATA_OBJECT_SINK,
                IterativeSVMConstants.TRAINING_DATA_LOADING_TASK_GRAPH);
    }

    /**
     * Creates the graph that loads the testing data points, using the testing
     * sample count and testing data directory from the job parameters.
     *
     * @return the testing data loading graph
     */
    private ComputeGraph buildTestingDataPointsTG() {
        final int sampleCount = this.svmJobParameters.getTestingSamples();
        final int parallelism = this.dataStreamerParallelism;
        final int featureCount = this.svmJobParameters.getFeatures();
        final String dataDirectory = this.svmJobParameters.getTestingDataDir();
        return generateGenericDataPointLoader(
                sampleCount,
                parallelism,
                featureCount,
                dataDirectory,
                Constants.SimpleGraphConfig.DATA_OBJECT_SOURCE_TESTING,
                Constants.SimpleGraphConfig.DATA_OBJECT_COMPUTE_TESTING,
                Constants.SimpleGraphConfig.DATA_OBJECT_SINK_TESTING,
                IterativeSVMConstants.TESTING_DATA_LOADING_TASK_GRAPH);
    }

    /**
     * Assembles a three-stage data loading graph: source, compute and sink,
     * each connected to the previous stage by a "direct" edge carrying
     * {@code MessageTypes.OBJECT}.
     *
     * @param sampleCount   sample count handed to the source and the compute task
     * @param parallelism   parallelism of all three tasks, also passed to the compute task
     * @param featureCount  feature count handed to the compute task
     * @param dataDirectory data directory handed to the source
     * @param sourceName    name of the source task
     * @param computeName   name of the compute task
     * @param sinkName      name of the sink task
     * @param graphName     name given to the resulting graph
     * @return the built graph, in the runner's current operation mode
     */
    private ComputeGraph generateGenericDataPointLoader(int sampleCount, int parallelism, int featureCount, String dataDirectory, String sourceName, String computeName, String sinkName, String graphName) {
        SVMDataObjectSource sourceTask = new SVMDataObjectSource("direct", dataDirectory, sampleCount);
        IterativeSVMDataObjectCompute computeTask = new IterativeSVMDataObjectCompute("direct", parallelism, sampleCount, featureCount, ",");
        IterativeSVMDataObjectDirectSink sinkTask = new IterativeSVMDataObjectDirectSink();

        ComputeGraphBuilder graphBuilder = ComputeGraphBuilder.newBuilder(this.config);
        graphBuilder.addSource(sourceName, sourceTask, parallelism);
        ComputeConnection computeConnection = graphBuilder.addCompute(computeName, computeTask, parallelism);
        ComputeConnection sinkConnection = graphBuilder.addCompute(sinkName, sinkTask, parallelism);

        // source -> compute -> sink, both hops over a "direct" edge
        computeConnection
                .direct(sourceName)
                .viaEdge("direct")
                .withDataType(MessageTypes.OBJECT);
        sinkConnection
                .direct(computeName)
                .viaEdge("direct")
                .withDataType(MessageTypes.OBJECT);

        graphBuilder.setMode(this.operationMode);
        graphBuilder.setTaskGraphName(graphName);
        return graphBuilder.build();
    }

    /**
     * Assembles the graph that loads the weight vector: a replicated file read
     * source, a compute task and a sink, chained by "direct" edges. The first
     * edge carries {@code MessageTypes.OBJECT}, the second
     * {@code MessageTypes.DOUBLE_ARRAY}.
     *
     * @return the built weight vector loading graph
     */
    private ComputeGraph buildWeightVectorTG() {
        final String sourceName = Constants.SimpleGraphConfig.WEIGHT_VECTOR_OBJECT_SOURCE;
        final String computeName = Constants.SimpleGraphConfig.WEIGHT_VECTOR_OBJECT_COMPUTE;
        final String sinkName = Constants.SimpleGraphConfig.WEIGHT_VECTOR_OBJECT_SINK;

        DataFileReplicatedReadSource sourceTask = new DataFileReplicatedReadSource("direct", this.svmJobParameters.getWeightVectorDataDir(), 1);
        IterativeSVMWeightVectorObjectCompute computeTask = new IterativeSVMWeightVectorObjectCompute("direct", 1, this.svmJobParameters.getFeatures());
        IterativeSVMWeightVectorObjectDirectSink sinkTask = new IterativeSVMWeightVectorObjectDirectSink();

        ComputeGraphBuilder graphBuilder = ComputeGraphBuilder.newBuilder(this.config);
        graphBuilder.addSource(sourceName, sourceTask, this.dataStreamerParallelism);
        ComputeConnection computeConnection = graphBuilder.addCompute(computeName, computeTask, this.dataStreamerParallelism);
        ComputeConnection sinkConnection = graphBuilder.addCompute(sinkName, sinkTask, this.dataStreamerParallelism);

        computeConnection
                .direct(sourceName)
                .viaEdge("direct")
                .withDataType(MessageTypes.OBJECT);
        sinkConnection
                .direct(computeName)
                .viaEdge("direct")
                .withDataType(MessageTypes.DOUBLE_ARRAY);

        graphBuilder.setMode(this.operationMode);
        graphBuilder.setTaskGraphName(IterativeSVMConstants.WEIGHT_VECTOR_LOADING_TASK_GRAPH);
        return graphBuilder.build();
    }

    /**
     * Builds the training graph once and executes it for the configured number
     * of iterations. Before each run the training data and the current weight
     * vector are supplied as inputs; after each run the reduce task's output
     * becomes the weight vector for the next iteration.
     */
    private void runTrainingGraph() {
        final ComputeGraph graph = buildSvmSgdIterativeTrainingTG();
        this.iterativeSVMTrainingTaskGraph = graph;

        final ExecutionPlan plan = this.taskExecutor.plan(graph);
        this.iterativeSVMTrainingExecutionPlan = plan;

        final IExecutor executor = this.taskExecutor.createExecution(graph, plan);
        final String sourceTask = Constants.SimpleGraphConfig.ITERATIVE_DATASTREAMER_SOURCE;

        int iteration = 0;
        while (iteration < this.binaryBatchModel.getIterations()) {
            LOG.info(String.format("Iteration  %d ", Integer.valueOf(iteration)));
            this.taskExecutor.addInput(graph, plan, sourceTask, Constants.SimpleGraphConfig.INPUT_DATA, this.trainingDoubleDataPointObject);
            this.taskExecutor.addInput(graph, plan, sourceTask, Constants.SimpleGraphConfig.INPUT_WEIGHT_VECTOR, this.inputDoubleWeightvectorObject);
            executor.execute();
            // Feed the reduced weight vector back in as the next iteration's input.
            this.inputDoubleWeightvectorObject = this.taskExecutor.getOutput(graph, plan, Constants.SimpleGraphConfig.ITERATIVE_SVM_REDUCE);
            iteration++;
        }

        executor.closeExecution();
    }

    /**
     * Assembles the iterative training graph on {@code trainingBuilder}: a data
     * stream source connected to a weight vector reduce task through an
     * all-reduce edge of {@code MessageTypes.DOUBLE_ARRAY}.
     *
     * @return the built training graph
     */
    private ComputeGraph buildSvmSgdIterativeTrainingTG() {
        final String sourceName = Constants.SimpleGraphConfig.ITERATIVE_DATASTREAMER_SOURCE;
        final String reduceName = Constants.SimpleGraphConfig.ITERATIVE_SVM_REDUCE;

        this.iterativeDataStream = new IterativeDataStream(
                this.svmJobParameters.getFeatures(),
                this.operationMode,
                this.svmJobParameters.isDummy(),
                this.binaryBatchModel);
        this.iterativeSVMRiterativeSVMWeightVectorReduce = new IterativeSVMWeightVectorReduce(this.operationMode);

        this.trainingBuilder.addSource(sourceName, this.iterativeDataStream, this.dataStreamerParallelism);
        ComputeConnection reduceConnection = this.trainingBuilder.addCompute(reduceName, this.iterativeSVMRiterativeSVMWeightVectorReduce, this.dataStreamerParallelism);
        reduceConnection
                .allreduce(sourceName)
                .viaEdge(Constants.SimpleGraphConfig.REDUCE_EDGE)
                .withReductionFunction(new IterativeWeightVectorReduceFunction())
                .withDataType(MessageTypes.DOUBLE_ARRAY);

        this.trainingBuilder.setMode(this.operationMode);
        this.trainingBuilder.setTaskGraphName(IterativeSVMConstants.ITERATIVE_TRAINING_TASK_GRAPH);
        return this.trainingBuilder.build();
    }

    /**
     * Builds and executes the prediction graph with the testing data and the
     * current weight vector as inputs, then reads the accuracy from the first
     * partition of the reduce task's output and logs it.
     */
    private void runPredictionGraph() {
        final ComputeGraph graph = buildSvmSgdTestingTG();
        this.iterativeSVMTestingTaskGraph = graph;

        final ExecutionPlan plan = this.taskExecutor.plan(graph);
        this.iterativeSVMTestingExecutionPlan = plan;

        final String sourceTask = Constants.SimpleGraphConfig.PREDICTION_SOURCE_TASK;
        this.taskExecutor.addInput(graph, plan, sourceTask, "test_data", this.testingDoubleDataPointObject);
        this.taskExecutor.addInput(graph, plan, sourceTask, Constants.SimpleGraphConfig.INPUT_WEIGHT_VECTOR, this.inputDoubleWeightvectorObject);
        this.taskExecutor.execute(graph, plan);

        this.finalAccuracyDoubleObject = this.taskExecutor.getOutput(graph, plan, Constants.SimpleGraphConfig.PREDICTION_REDUCE_TASK);
        final Double reducedAccuracy = (Double) this.finalAccuracyDoubleObject.getPartitions()[0].getConsumer().next();
        this.accuracy = reducedAccuracy.doubleValue();
        LOG.info(String.format("Final Accuracy : %f ", Double.valueOf(this.accuracy)));
    }

    /**
     * Assembles the prediction graph on {@code testingBuilder}: a prediction
     * data streamer source connected to an accuracy reduce task through an
     * all-reduce edge of {@code MessageTypes.DOUBLE}.
     *
     * @return the built prediction graph
     */
    private ComputeGraph buildSvmSgdTestingTG() {
        final String sourceName = Constants.SimpleGraphConfig.PREDICTION_SOURCE_TASK;
        final String reduceName = Constants.SimpleGraphConfig.PREDICTION_REDUCE_TASK;

        this.iterativePredictionDataStreamer = new IterativePredictionDataStreamer(
                this.svmJobParameters.getFeatures(),
                this.operationMode,
                this.svmJobParameters.isDummy(),
                this.binaryBatchModel);
        this.iterativeSVMAccuracyReduce = new IterativeSVMAccuracyReduce(this.operationMode);

        this.testingBuilder.addSource(sourceName, this.iterativePredictionDataStreamer, this.dataStreamerParallelism);
        ComputeConnection reduceConnection = this.testingBuilder.addCompute(reduceName, this.iterativeSVMAccuracyReduce, this.dataStreamerParallelism);
        reduceConnection
                .allreduce(sourceName)
                .viaEdge(Constants.SimpleGraphConfig.PREDICTION_EDGE)
                .withReductionFunction(new IterativeAccuracyReduceFunction())
                .withDataType(MessageTypes.DOUBLE);

        this.testingBuilder.setMode(this.operationMode);
        this.testingBuilder.setTaskGraphName(IterativeSVMConstants.ITERATIVE_PREDICTION_TASK_GRAPH);
        return this.testingBuilder.build();
    }

    /**
     * Converts the recorded phase times to seconds, logs a formatted summary
     * of the run (datasets, memory, timings, accuracy) and saves the model.
     *
     * <p>Two fixes relative to the previous version: "Maximum Memory" now
     * reports {@code Runtime.maxMemory()} rather than repeating
     * {@code totalMemory()}, and the finished report is logged as-is instead of
     * being passed through {@code String.format} as a format string, so a
     * {@code %} in the experiment name or a dataset path can no longer break
     * or corrupt the output.
     */
    private void generateSummary() {
        final double bytesPerMegabyte = 1048576.0d;
        final Runtime runtime = Runtime.getRuntime();
        final double totalMemoryMb = runtime.totalMemory() / bytesPerMegabyte;
        final double maxMemoryMb = runtime.maxMemory() / bytesPerMegabyte;

        convert2Seconds();
        // NOTE: the total also includes the initialization time, although the
        // label printed below only names loading, training and testing.
        this.totalDTime = this.initializingDTime + this.dataLoadingDTime + this.trainingDTime + this.testingDTime;

        final String rule = "======================================================================================\n";
        final StringBuilder report = new StringBuilder();
        report.append("\n\n").append(rule);
        report.append("\t\t\tIterative SVM Task Summary : [").append(this.experimentName).append("]\n");
        report.append(rule);
        report.append("Training Dataset [").append(this.svmJobParameters.getTrainingDataDir()).append("] \n");
        report.append("Testing  Dataset [").append(this.svmJobParameters.getTestingDataDir()).append("] \n");
        report.append("Total Memory [ ").append(totalMemoryMb).append(" MB] \n");
        report.append("Maximum Memory [ ").append(maxMemoryMb).append(" MB] \n");
        report.append("Data Loading Time (Training + Testing) \t\t\t\t= ").append(String.format("%3.9f", this.dataLoadingDTime)).append("  s \n");
        report.append("Training Time \t\t\t\t\t\t\t= ").append(String.format("%3.9f", this.trainingDTime)).append("  s \n");
        report.append("Testing Time  \t\t\t\t\t\t\t= ").append(String.format("%3.9f", this.testingDTime)).append("  s \n");
        report.append("Total Time (Data Loading Time + Training Time + Testing Time) \t=").append(String.format(" %.9f", this.totalDTime)).append("  s \n");
        // A literal percent sign: the report is no longer used as a format string.
        report.append(String.format("Accuracy of the Trained Model \t\t\t\t\t= %2.9f", this.accuracy)).append(" %\n");
        report.append(rule);

        LOG.info(report.toString());
        save();
    }

    /**
     * Wraps the model, accuracy, training time (nanoseconds), a tag derived
     * from the experiment name and the parallelism in a {@code TrainedModel}
     * and saves it to the configured model directory.
     */
    private void save() {
        final String modelTag = this.svmJobParameters.getExperimentName() + "-itr-task";
        final int parallelism = this.svmJobParameters.getParallelism();
        final TrainedModel trainedModel = new TrainedModel(this.binaryBatchModel, this.accuracy, this.trainingTime, modelTag, parallelism);
        final String targetDirectory = this.svmJobParameters.getModelSaveDir();
        trainedModel.saveModel(targetDirectory);
    }

    /**
     * Converts the four nanosecond phase timers into their second-based
     * {@code double} counterparts.
     */
    private void convert2Seconds() {
        final double nanosPerSecond = 1.0E9d;
        this.initializingDTime = this.initializingTime / nanosPerSecond;
        this.dataLoadingDTime = this.dataLoadingTime / nanosPerSecond;
        this.trainingDTime = this.trainingTime / nanosPerSecond;
        this.testingDTime = this.testingTime / nanosPerSecond;
    }
}
