package edu.iu.dsc.tws.examples.internal.taskgraph;

import edu.iu.dsc.tws.api.JobConfig;
import edu.iu.dsc.tws.api.Twister2Job;
import edu.iu.dsc.tws.api.comms.messaging.types.MessageTypes;
import edu.iu.dsc.tws.api.compute.IFunction;
import edu.iu.dsc.tws.api.compute.IMessage;
import edu.iu.dsc.tws.api.compute.TaskContext;
import edu.iu.dsc.tws.api.compute.graph.ComputeGraph;
import edu.iu.dsc.tws.api.compute.graph.OperationMode;
import edu.iu.dsc.tws.api.compute.modifiers.Collector;
import edu.iu.dsc.tws.api.compute.nodes.BaseCompute;
import edu.iu.dsc.tws.api.compute.nodes.BaseSource;
import edu.iu.dsc.tws.api.config.Config;
import edu.iu.dsc.tws.api.dataset.DataPartition;
import edu.iu.dsc.tws.dataset.partition.EntityPartition;
import edu.iu.dsc.tws.examples.batch.cdfw.CDFConstants;
import edu.iu.dsc.tws.examples.batch.kmeans.KMeansUtils;
import edu.iu.dsc.tws.examples.comms.Constants;
import edu.iu.dsc.tws.rsched.core.ResourceAllocator;
import edu.iu.dsc.tws.rsched.job.Twister2Submitter;
import edu.iu.dsc.tws.task.impl.ComputeConnection;
import edu.iu.dsc.tws.task.impl.ComputeGraphBuilder;
import edu.iu.dsc.tws.task.impl.TaskWorker;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Random;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.DefaultParser;
import org.apache.commons.cli.Options;
import org.apache.commons.cli.ParseException;

/* loaded from: input_file:edu/iu/dsc/tws/examples/internal/taskgraph/MultiComputeTasksGraphExample.class */
public class MultiComputeTasksGraphExample extends TaskWorker {
    private static final Logger LOG = Logger.getLogger(MultiComputeTasksGraphExample.class.getName());
    private int parallelismValue = 0;

    /* loaded from: input_file:edu/iu/dsc/tws/examples/internal/taskgraph/MultiComputeTasksGraphExample$Aggregator.class */
    /**
     * Reduction function that combines two {@code double[]} messages by adding them
     * element by element.
     *
     * <p>NOTE(review): this is a non-static inner class, so every instance keeps a
     * reference to the enclosing {@code MultiComputeTasksGraphExample}. Confirm that is
     * intended before relying on instances being shipped independently of the worker.
     */
    public class Aggregator implements IFunction {
        private static final long serialVersionUID = -254264120110286748L;

        public Aggregator() {
        }

        /**
         * Adds two arrays position by position.
         *
         * @param obj  first operand; must be a {@code double[]}. Its length determines the
         *             length of the result.
         * @param obj2 second operand; must be a {@code double[]} at least as long as
         *             {@code obj}.
         * @return a newly allocated {@code double[]} holding the pairwise sums; neither
         *         input array is modified
         * @throws ArrayIndexOutOfBoundsException if {@code obj2} is shorter than {@code obj}
         */
        public Object onMessage(Object obj, Object obj2) throws ArrayIndexOutOfBoundsException {
            final double[] left = (double[]) obj;
            final double[] right = (double[]) obj2;
            final int length = left.length;
            final double[] sums = new double[length];
            int idx = 0;
            while (idx < length) {
                sums[idx] = left[idx] + right[idx];
                idx++;
            }
            return sums;
        }
    }

    /* loaded from: input_file:edu/iu/dsc/tws/examples/internal/taskgraph/MultiComputeTasksGraphExample$FirstComputeTask.class */
    /**
     * Compute task that doubles every element of each incoming {@code double[]} and
     * forwards the result on the {@code "freduce"} edge.
     *
     * <p>The task keeps no state between messages.
     */
    private static class FirstComputeTask extends BaseCompute {
        private static final long serialVersionUID = -254264120110286748L;

        private FirstComputeTask() {
        }

        /**
         * Processes one incoming message.
         *
         * <p>If the message content is an {@link Iterator}, each element is cast to
         * {@code double[]}, multiplied by two into a fresh array, and written to
         * {@code "freduce"}. Content of any other type is ignored. In both cases the
         * {@code "freduce"} edge is ended afterwards.
         *
         * @param iMessage the incoming message
         * @return always {@code true}
         */
        public boolean execute(IMessage iMessage) {
            MultiComputeTasksGraphExample.LOG.log(Level.INFO, "First Compute Received Data: " + this.context.getWorkerId() + ":" + this.context.globalTaskId());
            // Fetch the content once so the same iterator is used for hasNext() and next().
            Object content = iMessage.getContent();
            if (content instanceof Iterator) {
                Iterator<?> values = (Iterator<?>) content;
                while (values.hasNext()) {
                    double[] received = (double[]) values.next();
                    // Do not scale in place: the source task writes the same array instance
                    // to both of its outgoing edges, so the received array may be shared with
                    // the second compute task (assumes the payload can be handed over by
                    // reference -- TODO confirm).
                    double[] doubled = new double[received.length];
                    for (int i = 0; i < received.length; i++) {
                        doubled[i] = received[i] * 2.0d;
                    }
                    this.context.write("freduce", doubled);
                }
            }
            this.context.end("freduce");
            return true;
        }
    }

    /* loaded from: input_file:edu/iu/dsc/tws/examples/internal/taskgraph/MultiComputeTasksGraphExample$ReduceTask.class */
    /**
     * Terminal task that remembers the most recently received {@code double[]} and
     * exposes it as a data partition.
     */
    private static class ReduceTask extends BaseCompute implements Collector {
        private static final long serialVersionUID = -5190777711234234L;

        /** Content of the last message passed to {@link #execute(IMessage)}; null until then. */
        private double[] newValues;

        private ReduceTask() {
        }

        /**
         * Logs the incoming content and stores it, replacing any previously stored value.
         *
         * @param iMessage the incoming message; its content must be a {@code double[]}
         * @return always {@code true}
         * @throws ClassCastException if the content is not a {@code double[]}
         */
        public boolean execute(IMessage iMessage) {
            Object content = iMessage.getContent();
            // Concatenating a double[] directly would log its identity ("[D@..."), not its
            // values, so render the elements explicitly. Non-array content is still logged
            // before the cast below fails.
            String rendered = content instanceof double[]
                    ? Arrays.toString((double[]) content)
                    : String.valueOf(content);
            MultiComputeTasksGraphExample.LOG.log(Level.INFO, "Received Data from workerId: " + this.context.getWorkerId() + ":" + this.context.globalTaskId() + ":" + rendered);
            this.newValues = (double[]) content;
            return true;
        }

        /**
         * Wraps the stored values in a partition identified by this task's index.
         *
         * @return a new {@code EntityPartition} over the last received array, which is
         *         {@code null} if no message has been processed yet
         */
        public DataPartition<double[]> get() {
            return new EntityPartition<>(this.context.taskIndex(), this.newValues);
        }
    }

    /* loaded from: input_file:edu/iu/dsc/tws/examples/internal/taskgraph/MultiComputeTasksGraphExample$SecondComputeTask.class */
    /**
     * Compute task that divides every element of each incoming {@code double[]} by four
     * and forwards the result on the {@code "sreduce"} edge.
     *
     * <p>The task keeps no state between messages.
     */
    private static class SecondComputeTask extends BaseCompute {
        private static final long serialVersionUID = -254264120110286748L;

        private SecondComputeTask() {
        }

        /**
         * Processes one incoming message.
         *
         * <p>If the message content is an {@link Iterator}, each element is cast to
         * {@code double[]}, divided by four into a fresh array, and written to
         * {@code "sreduce"}. Content of any other type is ignored. In both cases the
         * {@code "sreduce"} edge is ended afterwards.
         *
         * @param iMessage the incoming message
         * @return always {@code true}
         */
        public boolean execute(IMessage iMessage) {
            MultiComputeTasksGraphExample.LOG.log(Level.INFO, "Second Compute Received Data: " + this.context.getWorkerId() + ":" + this.context.globalTaskId());
            // Fetch the content once so the same iterator is used for hasNext() and next().
            Object content = iMessage.getContent();
            if (content instanceof Iterator) {
                Iterator<?> values = (Iterator<?>) content;
                while (values.hasNext()) {
                    double[] received = (double[]) values.next();
                    // Do not scale in place: the source task writes the same array instance
                    // to both of its outgoing edges, so the received array may be shared with
                    // the first compute task (assumes the payload can be handed over by
                    // reference -- TODO confirm).
                    double[] quartered = new double[received.length];
                    for (int i = 0; i < received.length; i++) {
                        quartered[i] = received[i] / 4.0d;
                    }
                    this.context.write("sreduce", quartered);
                }
            }
            this.context.end("sreduce");
            return true;
        }
    }

    /* loaded from: input_file:edu/iu/dsc/tws/examples/internal/taskgraph/MultiComputeTasksGraphExample$SourceTask.class */
    /**
     * Source task that produces one array of pseudo-random doubles and emits it on both
     * the {@code "fdirect"} and {@code "sdirect"} edges.
     */
    private static class SourceTask extends BaseSource {
        private static final long serialVersionUID = -254264120110286748L;

        /** The array produced by the latest {@link #execute()} call; null before that. */
        private double[] datapoints = null;

        /** Number of values generated per {@link #execute()} call. */
        private int numPoints = 100;

        private SourceTask() {
        }

        /** Delegates to the superclass without adding any behavior. */
        public void prepare(Config config, TaskContext taskContext) {
            super.prepare(config, taskContext);
        }

        /**
         * Generates {@code numPoints} values from a {@link Random} seeded with 100 (so the
         * sequence is the same on every call), writes the same array instance to both
         * outgoing edges, and then ends both edges.
         */
        public void execute() {
            final int count = this.numPoints;
            final double[] points = new double[count];
            final Random generator = new Random(100);
            int idx = 0;
            while (idx < count) {
                points[idx] = generator.nextDouble();
                idx++;
            }
            this.datapoints = points;

            this.context.write("fdirect", points);
            this.context.write("sdirect", points);
            this.context.end("fdirect");
            this.context.end("sdirect");
        }
    }

    /**
     * Builds the batch compute graph for this example and runs it on the task executor.
     *
     * <p>Steps, in order:
     * <ol>
     *   <li>Reads the parallelism and the data-generation settings from the worker config
     *       (all stored as strings) and calls {@code KMeansUtils.generateDataPoints}.</li>
     *   <li>Registers one source and three compute tasks, all with the same parallelism.</li>
     *   <li>Wires the graph:
     *       {@code source --fdirect--> firstcompute --freduce(allreduce)--> compute} and
     *       {@code source --sdirect--> secondcompute --sreduce(allreduce)--> compute}.
     *       Both allreduce edges use an {@link Aggregator} as the reduction function.</li>
     *   <li>Sets batch mode, adds the {@code twister2.max.task.instances.per.worker}
     *       constraint, builds the graph, then plans and executes it.</li>
     * </ol>
     *
     * @throws NumberFormatException if a numeric config entry is missing or not an integer
     */
    public void execute() {
        LOG.log(Level.INFO, "Task worker starting: " + this.workerId);
        ComputeGraphBuilder graphBuilder = ComputeGraphBuilder.newBuilder(this.config);
        int parallelism = Integer.parseInt((String) this.config.get(CDFConstants.ARGS_PARALLELISM_VALUE));

        SourceTask sourceTask = new SourceTask();
        FirstComputeTask firstComputeTask = new FirstComputeTask();
        SecondComputeTask secondComputeTask = new SecondComputeTask();
        ReduceTask reduceTask = new ReduceTask();

        // The worker id is appended to both input locations, so each worker uses its own.
        String dinput = ((String) this.config.get(CDFConstants.ARGS_DINPUT)) + this.workerId;
        String cinput = ((String) this.config.get(CDFConstants.ARGS_CINPUT)) + this.workerId;
        int dimension = Integer.parseInt((String) this.config.get("dim"));
        int numberOfFiles = Integer.parseInt((String) this.config.get(Constants.ARGS_NUMBER_OF_FILES));
        int dsize = Integer.parseInt((String) this.config.get(CDFConstants.ARGS_DSIZE));
        int csize = Integer.parseInt((String) this.config.get(CDFConstants.ARGS_CSIZE));
        LOG.info("Input Values:" + dinput + cinput + dimension + numberOfFiles);
        KMeansUtils.generateDataPoints(this.config, dimension, numberOfFiles, dsize, csize, dinput, cinput);

        graphBuilder.addSource("source", sourceTask, parallelism);
        ComputeConnection firstConnection = graphBuilder.addCompute("firstcompute", firstComputeTask, parallelism);
        ComputeConnection secondConnection = graphBuilder.addCompute("secondcompute", secondComputeTask, parallelism);
        ComputeConnection reduceConnection = graphBuilder.addCompute("compute", reduceTask, parallelism);

        firstConnection.direct("source").viaEdge("fdirect").withDataType(MessageTypes.OBJECT);
        secondConnection.direct("source").viaEdge("sdirect").withDataType(MessageTypes.OBJECT);
        // The reduce task has two inputs: one allreduce from each compute task.
        reduceConnection
                .allreduce("firstcompute")
                .viaEdge("freduce")
                .withReductionFunction(new Aggregator())
                .withDataType(MessageTypes.OBJECT)
                .connect()
                .allreduce("secondcompute")
                .viaEdge("sreduce")
                .withReductionFunction(new Aggregator())
                .withDataType(MessageTypes.OBJECT);

        graphBuilder.setMode(OperationMode.BATCH);
        graphBuilder.addGraphConstraints("twister2.max.task.instances.per.worker", "4");
        ComputeGraph graph = graphBuilder.build();
        LOG.info("%%% Graph Constraints:%%%" + graph.getGraphConstraints());
        this.taskExecutor.execute(graph, this.taskExecutor.plan(graph));
    }

    /**
     * Command-line entry point: parses the job options, copies them into a
     * {@code JobConfig} as strings, and submits the job with this class as the worker.
     *
     * <p>Numeric options that must be present: {@code workers},
     * {@code CDFConstants.ARGS_PARALLELISM_VALUE}, {@code CDFConstants.ARGS_DSIZE},
     * {@code CDFConstants.ARGS_CSIZE}, {@code dim} and
     * {@code Constants.ARGS_NUMBER_OF_FILES}. The {@code filesys},
     * {@code CDFConstants.ARGS_DINPUT} and {@code CDFConstants.ARGS_CINPUT} values are
     * passed through unchanged and are stored as {@code null} when absent.
     *
     * @param strArr the raw command-line arguments
     * @throws ParseException if the arguments cannot be parsed or a required numeric
     *         option is missing
     * @throws NumberFormatException if a numeric option is present but not an integer
     */
    public static void main(String[] strArr) throws ParseException {
        LOG.log(Level.INFO, "MultiComputeTaskGraph");
        Config loadConfig = ResourceAllocator.loadConfig(new HashMap<>());

        // NOTE(review): the key below is reproduced exactly as found, including the
        // "exector" spelling; confirm it matches the framework's constant before editing.
        HashMap<String, Object> overrides = new HashMap<>();
        overrides.put("twister2.exector.worker.threads", 8);

        Options options = new Options();
        options.addOption(CDFConstants.ARGS_PARALLELISM_VALUE, true, CDFConstants.ARGS_PARALLELISM_VALUE);
        options.addOption("workers", true, "workers");
        options.addOption(CDFConstants.ARGS_DSIZE, true, CDFConstants.ARGS_DSIZE);
        options.addOption("dim", true, "dim");
        options.addOption(CDFConstants.ARGS_CSIZE, true, CDFConstants.ARGS_CSIZE);
        options.addOption(CDFConstants.ARGS_DINPUT, true, CDFConstants.ARGS_DINPUT);
        options.addOption(CDFConstants.ARGS_CINPUT, true, CDFConstants.ARGS_CINPUT);
        options.addOption("filesys", true, "filesys");
        options.addOption(Constants.ARGS_NUMBER_OF_FILES, true, Constants.ARGS_NUMBER_OF_FILES);
        CommandLine commandLine = new DefaultParser().parse(options, strArr);

        int workers = requiredIntOption(commandLine, "workers");
        int parallelism = requiredIntOption(commandLine, CDFConstants.ARGS_PARALLELISM_VALUE);
        int dsize = requiredIntOption(commandLine, CDFConstants.ARGS_DSIZE);
        int csize = requiredIntOption(commandLine, CDFConstants.ARGS_CSIZE);
        int dimension = requiredIntOption(commandLine, "dim");
        int numberOfFiles = requiredIntOption(commandLine, Constants.ARGS_NUMBER_OF_FILES);
        String fileSystem = commandLine.getOptionValue("filesys");
        String dinput = commandLine.getOptionValue(CDFConstants.ARGS_DINPUT);
        String cinput = commandLine.getOptionValue(CDFConstants.ARGS_CINPUT);

        // Values are stored as strings; the worker parses them back to integers.
        JobConfig jobConfig = new JobConfig();
        jobConfig.put("workers", Integer.toString(workers));
        jobConfig.put(CDFConstants.ARGS_PARALLELISM_VALUE, Integer.toString(parallelism));
        jobConfig.put(CDFConstants.ARGS_DSIZE, Integer.toString(dsize));
        jobConfig.put(CDFConstants.ARGS_CSIZE, Integer.toString(csize));
        jobConfig.put(CDFConstants.ARGS_DINPUT, dinput);
        jobConfig.put(CDFConstants.ARGS_CINPUT, cinput);
        jobConfig.put("filesys", fileSystem);
        jobConfig.put(Constants.ARGS_NUMBER_OF_FILES, Integer.toString(numberOfFiles));
        jobConfig.put("dim", Integer.toString(dimension));
        jobConfig.putAll(overrides);

        Twister2Job.Twister2JobBuilder jobBuilder = Twister2Job.newBuilder();
        jobBuilder.setJobName("MultiComputeTasksGraph");
        jobBuilder.setWorkerClass(MultiComputeTasksGraphExample.class.getName());
        jobBuilder.addComputeResource(2.0d, 512, 1.0d, workers);
        jobBuilder.setConfig(jobConfig);
        Twister2Submitter.submitJob(jobBuilder.build(), loadConfig);
    }

    /**
     * Reads an integer option that must be present on the command line.
     *
     * @param commandLine the parsed command line
     * @param name the option name
     * @return the option value parsed as an {@code int}
     * @throws ParseException if the option was not supplied
     * @throws NumberFormatException if the supplied value is not an integer
     */
    private static int requiredIntOption(CommandLine commandLine, String name) throws ParseException {
        String raw = commandLine.getOptionValue(name);
        if (raw == null) {
            // Fail with the option name instead of an opaque parseInt(null) error.
            throw new ParseException("Missing required option: " + name);
        }
        return Integer.parseInt(raw);
    }
}
