package edu.iu.dsc.tws.examples.ml.svm.job;

import edu.iu.dsc.tws.api.comms.messaging.types.MessageTypes;
import edu.iu.dsc.tws.api.compute.executor.ExecutionPlan;
import edu.iu.dsc.tws.api.compute.graph.ComputeGraph;
import edu.iu.dsc.tws.api.compute.graph.OperationMode;
import edu.iu.dsc.tws.api.dataset.DataPartitionConsumer;
import edu.iu.dsc.tws.examples.ml.svm.aggregate.ReduceAggregator;
import edu.iu.dsc.tws.examples.ml.svm.aggregate.SVMReduce;
import edu.iu.dsc.tws.examples.ml.svm.compute.SVMCompute;
import edu.iu.dsc.tws.examples.ml.svm.constant.Constants;
import edu.iu.dsc.tws.examples.ml.svm.streamer.DataStreamer;
import edu.iu.dsc.tws.examples.ml.svm.util.BinaryBatchModel;
import edu.iu.dsc.tws.examples.ml.svm.util.DataUtils;
import edu.iu.dsc.tws.examples.ml.svm.util.SVMJobParameters;
import edu.iu.dsc.tws.task.impl.ComputeConnection;
import edu.iu.dsc.tws.task.impl.ComputeGraphBuilder;
import edu.iu.dsc.tws.task.impl.TaskWorker;
import java.util.Arrays;
import java.util.logging.Logger;

/* loaded from: input_file:edu/iu/dsc/tws/examples/ml/svm/job/SvmSgdRunner.class */
/**
 * Twister2 worker that builds and runs the SVM (SGD) example compute graph.
 *
 * <p>Job parameters are read from the worker configuration and used to seed an initial
 * {@link BinaryBatchModel}. The worker then runs a three-stage graph:
 * <ol>
 *   <li>a {@link DataStreamer} source;</li>
 *   <li>{@link SVMCompute} tasks, fed from the source over a direct edge;</li>
 *   <li>a single {@link SVMReduce} task, fed from the compute tasks over a reduce edge that
 *   uses {@link ReduceAggregator}.</li>
 * </ol>
 * In batch mode, the values produced by the reduce task are logged after execution.
 */
public class SvmSgdRunner extends TaskWorker {
    private static final Logger LOG = Logger.getLogger(SvmSgdRunner.class.getName());

    /**
     * Parallelism of the reduce task. {@link #logBatchOutput} reads only partition 0 of the
     * reduce output, so this is kept at 1.
     */
    private static final int REDUCE_PARALLELISM = 1;
    /** Parallelism of the {@link DataStreamer} source task. */
    private static final int DATA_STREAMER_PARALLELISM = 4;
    /** Parallelism of the {@link SVMCompute} task. */
    private static final int SVM_COMPUTE_PARALLELISM = 4;

    /** Streaming or batch mode, derived from the job parameters in {@link #initializeExecute()}. */
    private OperationMode operationMode;
    /** Job parameters built from the worker configuration in {@link #initializeParameters()}. */
    private SVMJobParameters svmJobParameters;
    /** Initial model passed to the source and compute tasks. */
    private BinaryBatchModel binaryBatchModel;

    /**
     * Worker entry point. Initializes the job parameters and model, then builds and executes
     * the compute graph.
     */
    public void execute() {
        initializeParameters();
        initializeExecute();
    }

    /**
     * Builds the {@link SVMJobParameters} from {@code this.config} and creates the initial
     * {@link BinaryBatchModel}.
     *
     * <p>The model copies iterations, alpha, features and samples from the parameters. Its
     * weight vector is seeded by {@code DataUtils.seedDoubleArray} with one entry per feature.
     */
    public void initializeParameters() {
        this.svmJobParameters = SVMJobParameters.build(this.config);
        int featureCount = this.svmJobParameters.getFeatures();

        this.binaryBatchModel = new BinaryBatchModel();
        this.binaryBatchModel.setIterations(this.svmJobParameters.getIterations());
        this.binaryBatchModel.setAlpha(this.svmJobParameters.getAlpha());
        this.binaryBatchModel.setFeatures(featureCount);
        this.binaryBatchModel.setSamples(this.svmJobParameters.getSamples());
        this.binaryBatchModel.setW(DataUtils.seedDoubleArray(featureCount));
        LOG.info(this.binaryBatchModel.toString());
    }

    /**
     * Builds, plans and executes the compute graph. In batch mode it then logs the values
     * produced by the reduce task.
     *
     * @throws IllegalStateException if {@link #initializeParameters()} has not been called yet
     */
    public void initializeExecute() {
        if (this.svmJobParameters == null || this.binaryBatchModel == null) {
            throw new IllegalStateException(
                "initializeParameters() must be called before initializeExecute()");
        }
        this.operationMode = this.svmJobParameters.isStreaming()
            ? OperationMode.STREAMING
            : OperationMode.BATCH;

        ComputeGraph graph = buildGraph();
        ExecutionPlan plan = this.taskExecutor.plan(graph);
        this.taskExecutor.execute(graph, plan);
        LOG.info("Task Graph Executed !!! ");

        // Reduce output is only collected here in batch mode; the streaming path logs nothing.
        if (this.operationMode == OperationMode.BATCH) {
            logBatchOutput(graph, plan);
        }
    }

    /**
     * Builds the graph: source, then compute over a direct edge, then reduce over a reduce
     * edge. The graph is built for the current {@link #operationMode}.
     */
    private ComputeGraph buildGraph() {
        ComputeGraphBuilder builder = ComputeGraphBuilder.newBuilder(this.config);

        DataStreamer dataStreamer = new DataStreamer(
            this.operationMode, this.svmJobParameters.isDummy(), this.binaryBatchModel);
        SVMCompute svmCompute = new SVMCompute(this.binaryBatchModel, this.operationMode);
        SVMReduce svmReduce = new SVMReduce(this.operationMode);

        builder.addSource(Constants.SimpleGraphConfig.DATASTREAMER_SOURCE, dataStreamer,
            DATA_STREAMER_PARALLELISM);
        ComputeConnection computeConnection = builder.addCompute(
            Constants.SimpleGraphConfig.SVM_COMPUTE, svmCompute, SVM_COMPUTE_PARALLELISM);
        ComputeConnection reduceConnection = builder.addCompute(
            Constants.SimpleGraphConfig.SVM_REDUCE, svmReduce, REDUCE_PARALLELISM);

        computeConnection.direct(Constants.SimpleGraphConfig.DATASTREAMER_SOURCE)
            .viaEdge(Constants.SimpleGraphConfig.DATA_EDGE)
            .withDataType(MessageTypes.OBJECT);
        reduceConnection.reduce(Constants.SimpleGraphConfig.SVM_COMPUTE)
            .viaEdge(Constants.SimpleGraphConfig.REDUCE_EDGE)
            .withReductionFunction(new ReduceAggregator())
            .withDataType(MessageTypes.OBJECT);

        builder.setMode(this.operationMode);
        return builder.build();
    }

    /**
     * Logs every value in partition 0 of the reduce task's output.
     *
     * <p>Each value is cast to {@code double[]}.
     * NOTE(review): confirm this matches the output type of {@link SVMReduce}.
     */
    private void logBatchOutput(ComputeGraph graph, ExecutionPlan plan) {
        // Only one partition exists because the reduce task runs with REDUCE_PARALLELISM = 1.
        DataPartitionConsumer consumer = this.taskExecutor
            .getOutput(graph, plan, Constants.SimpleGraphConfig.SVM_REDUCE)
            .getPartitions()[0]
            .getConsumer();
        while (consumer.hasNext()) {
            double[] aggregated = (double[]) consumer.next();
            LOG.info("Final Aggregated Values Are:" + Arrays.toString(aggregated));
        }
    }
}
