package edu.iu.dsc.tws.tsched.batch.batchscheduler;

import edu.iu.dsc.tws.api.compute.exceptions.TaskSchedulerException;
import edu.iu.dsc.tws.api.compute.graph.ComputeGraph;
import edu.iu.dsc.tws.api.compute.graph.Edge;
import edu.iu.dsc.tws.api.compute.graph.Vertex;
import edu.iu.dsc.tws.api.compute.modifiers.Collector;
import edu.iu.dsc.tws.api.compute.modifiers.Receptor;
import edu.iu.dsc.tws.api.compute.schedule.ITaskScheduler;
import edu.iu.dsc.tws.api.compute.schedule.elements.Resource;
import edu.iu.dsc.tws.api.compute.schedule.elements.TaskInstanceId;
import edu.iu.dsc.tws.api.compute.schedule.elements.TaskInstancePlan;
import edu.iu.dsc.tws.api.compute.schedule.elements.TaskSchedulePlan;
import edu.iu.dsc.tws.api.compute.schedule.elements.Worker;
import edu.iu.dsc.tws.api.compute.schedule.elements.WorkerPlan;
import edu.iu.dsc.tws.api.compute.schedule.elements.WorkerSchedulePlan;
import edu.iu.dsc.tws.api.config.Config;
import edu.iu.dsc.tws.api.exceptions.Twister2RuntimeException;
import edu.iu.dsc.tws.tsched.spi.common.TaskSchedulerContext;
import edu.iu.dsc.tws.tsched.spi.taskschedule.TaskInstanceMapCalculation;
import edu.iu.dsc.tws.tsched.utils.TaskAttributes;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;
import java.util.logging.Logger;
import java.util.stream.IntStream;

/* loaded from: input_file:edu/iu/dsc/tws/tsched/batch/batchscheduler/BatchTaskScheduler.class */
/**
 * Task scheduler for batch compute graphs.
 *
 * <p>Task instances of every vertex are spread over the workers in round-robin order
 * (instance index modulo worker count). For each worker a {@link WorkerSchedulePlan} is built
 * whose resource is either the resource reported by the {@link WorkerPlan} (when cpu, disk and
 * ram are all positive) or the configured container padding plus the summed requirements of
 * the instances placed on that worker.
 *
 * <p>When more than one graph is handed to {@link #schedule(WorkerPlan, ComputeGraph...)} the
 * graphs are treated as dependent: the receivable/collectible names of all graphs are recorded
 * together with the parallelism of the owning vertex and cross-checked before scheduling.
 *
 * <p>NOTE(review): the receivable/collectible name maps are static and therefore shared by all
 * scheduler instances in the JVM, and instance state (allocation map, plan map, worker id list,
 * graph counter) is never cleared between calls. Confirm this is intended by the callers.
 */
public class BatchTaskScheduler implements ITaskScheduler {
    private static final Logger LOG = Logger.getLogger(BatchTaskScheduler.class.getName());

    /** Error raised when a vertex needs more instances than the configured per-worker limit. */
    private static final String INSTANCE_LIMIT_MESSAGE =
            "Task Scheduling couldn't be possible for the present configuration, please check the number of workers, maximum instances per worker";

    /** Receivable name -> parallelism of the receptor vertex that declared it. */
    private static final Map<String, Integer> receivableNameMap = new LinkedHashMap<>();

    /** Collectible name -> parallelism of the collector vertex that declared it. */
    private static final Map<String, Integer> collectibleNameMap = new LinkedHashMap<>();

    private Double instanceRAM;
    private Double instanceDisk;
    private Double instanceCPU;
    private Config config;
    private TaskAttributes taskAttributes;
    private int workerId;

    /** Number of graphs scheduled so far by this instance. */
    private int index;

    /** Worker id -> task instances placed on that worker. */
    private Map<Integer, List<TaskInstanceId>> batchTaskAllocation;

    /** Set once several graphs have been scheduled together. */
    private boolean dependentGraphs = false;

    /** Worker ids used by the first graph of a dependent set. */
    private final List<Integer> workerIdList = new ArrayList<>();

    /** Graph name -> schedule plan, in scheduling order. */
    private final Map<String, TaskSchedulePlan> taskSchedulePlanMap = new LinkedHashMap<>();

    /**
     * Orders vertices lexicographically by their name.
     *
     * <p>Kept because it is a public nested type; nothing in this class instantiates it.
     */
    public static class VertexComparator implements Comparator<Vertex> {
        private VertexComparator() {
        }

        @Override // java.util.Comparator
        public int compare(Vertex vertex, Vertex vertex2) {
            String leftName = vertex.getName();
            String rightName = vertex2.getName();
            return leftName.compareTo(rightName);
        }
    }

    /**
     * Reads the default per-instance ram, disk and cpu values from the configuration and
     * resets the allocation map and the task attribute helper.
     *
     * @param config configuration the scheduler settings are read from
     */
    public void initialize(Config config) {
        this.config = config;
        this.instanceRAM = Double.valueOf(TaskSchedulerContext.taskInstanceRam(this.config));
        this.instanceDisk = Double.valueOf(TaskSchedulerContext.taskInstanceDisk(this.config));
        this.instanceCPU = Double.valueOf(TaskSchedulerContext.taskInstanceCpu(this.config));
        this.batchTaskAllocation = new LinkedHashMap<>();
        this.taskAttributes = new TaskAttributes();
    }

    /**
     * Same as {@link #initialize(Config)} and additionally stores the id of the calling worker;
     * only the worker with id 0 logs the resulting plans.
     *
     * @param config configuration the scheduler settings are read from
     * @param i id of the worker running this scheduler
     */
    public void initialize(Config config, int i) {
        initialize(config);
        this.workerId = i;
    }

    /**
     * Schedules one or more graphs and returns the plans keyed by graph name.
     *
     * <p>With more than one graph, receptor/collector names are collected from all graphs and
     * their parallelism is validated before any graph is scheduled; the graphs are then marked
     * as dependent.
     *
     * @param workerPlan workers available for scheduling
     * @param computeGraphArr graphs to schedule, at least one
     * @return the accumulated map of graph name to schedule plan held by this scheduler
     * @throws TaskSchedulerException if no graph is supplied or a graph cannot be scheduled
     * @throws Twister2RuntimeException if a dependent receptor and collector differ in parallelism
     */
    public Map<String, TaskSchedulePlan> schedule(WorkerPlan workerPlan, ComputeGraph... computeGraphArr) {
        if (computeGraphArr == null || computeGraphArr.length == 0) {
            throw new TaskSchedulerException("At least one compute graph is required for scheduling");
        }
        if (computeGraphArr.length > 1) {
            addReceptorsCollectors(computeGraphArr);
            validateParallelism();
            this.dependentGraphs = true;
        }
        for (ComputeGraph graph : computeGraphArr) {
            this.taskSchedulePlanMap.put(graph.getGraphName(), schedule(graph, workerPlan));
        }
        return this.taskSchedulePlanMap;
    }

    /**
     * Records the receivable and collectible names of every vertex of the given graphs together
     * with the parallelism of the vertex. A task that is a receptor is only registered as a
     * receptor, even if it is also a collector.
     */
    private void addReceptorsCollectors(ComputeGraph... computeGraphArr) {
        for (ComputeGraph graph : computeGraphArr) {
            // Iterate over a copy so the graph's own vertex set is never touched.
            Set<Vertex> vertices = new LinkedHashSet<>(graph.getTaskVertexSet());
            for (Vertex vertex : vertices) {
                Object task = vertex.getTask();
                if (task instanceof Receptor) {
                    recordReceivableNames((Receptor) task, vertex);
                } else if (task instanceof Collector) {
                    recordCollectibleNames((Collector) task, vertex);
                }
            }
        }
    }

    /** Stores every receivable name of the receptor with the vertex parallelism; null-safe. */
    private void recordReceivableNames(Receptor receptor, Vertex vertex) {
        if (receptor.getReceivableNames() != null) {
            receptor.getReceivableNames().forEach(
                name -> receivableNameMap.put(name, Integer.valueOf(vertex.getParallelism())));
        }
    }

    /** Stores every collectible name of the collector with the vertex parallelism; null-safe. */
    private void recordCollectibleNames(Collector collector, Vertex vertex) {
        if (collector.getCollectibleNames() != null) {
            collector.getCollectibleNames().forEach(
                name -> collectibleNameMap.put(name, Integer.valueOf(vertex.getParallelism())));
        }
    }

    /**
     * Schedules a single graph onto the workers of the given plan.
     *
     * @param computeGraph graph whose task instances are placed
     * @param workerPlan workers available for scheduling
     * @return the schedule plan (created with id 0) holding one entry per allocated worker
     * @throws TaskSchedulerException if the instances cannot be allocated
     */
    public TaskSchedulePlan schedule(ComputeGraph computeGraph, WorkerPlan workerPlan) {
        Set<WorkerSchedulePlan> workerSchedulePlans = new LinkedHashSet<>();
        Set<Vertex> taskVertexSet = new LinkedHashSet<>(computeGraph.getTaskVertexSet());
        Map<Integer, List<TaskInstanceId>> allocation =
                batchSchedulingAlgorithm(computeGraph, workerPlan.getNumberOfWorkers());

        TaskInstanceMapCalculation calculation =
                new TaskInstanceMapCalculation(this.instanceRAM, this.instanceCPU, this.instanceDisk);
        Map<Integer, Map<TaskInstanceId, Double>> ramByContainer =
                calculation.getInstancesRamMapInContainer(allocation, taskVertexSet);
        Map<Integer, Map<TaskInstanceId, Double>> diskByContainer =
                calculation.getInstancesDiskMapInContainer(allocation, taskVertexSet);
        Map<Integer, Map<TaskInstanceId, Double>> cpuByContainer =
                calculation.getInstancesCPUMapInContainer(allocation, taskVertexSet);

        for (Map.Entry<Integer, List<TaskInstanceId>> allocationEntry : allocation.entrySet()) {
            int containerId = allocationEntry.getKey().intValue();

            // Totals start at the configured padding and grow by each instance's requirement.
            double containerRam = TaskSchedulerContext.containerRamPadding(this.config);
            double containerDisk = TaskSchedulerContext.containerDiskPadding(this.config);
            double containerCpu = TaskSchedulerContext.containerCpuPadding(this.config);

            Map<TaskInstanceId, Double> ramByInstance = ramByContainer.get(Integer.valueOf(containerId));
            Map<TaskInstanceId, Double> diskByInstance = diskByContainer.get(Integer.valueOf(containerId));
            Map<TaskInstanceId, Double> cpuByInstance = cpuByContainer.get(Integer.valueOf(containerId));

            Map<TaskInstanceId, TaskInstancePlan> instancePlans = new HashMap<>();
            for (TaskInstanceId instanceId : allocationEntry.getValue()) {
                double ram = ramByInstance.get(instanceId).doubleValue();
                double disk = diskByInstance.get(instanceId).doubleValue();
                double cpu = cpuByInstance.get(instanceId).doubleValue();
                Resource instanceResource =
                        new Resource(Double.valueOf(ram), Double.valueOf(disk), Double.valueOf(cpu));
                instancePlans.put(instanceId, new TaskInstancePlan(
                        instanceId.getTaskName(), instanceId.getTaskId(), instanceId.getTaskIndex(),
                        instanceResource));
                containerRam += ram;
                containerDisk += disk;
                // Accumulate the CPU requirement (the disk value was added here before).
                containerCpu += cpu;
            }

            Worker worker = workerPlan.getWorker(containerId);
            boolean useComputedResource = worker == null
                    || worker.getCpu() <= 0 || worker.getDisk() <= 0 || worker.getRam() <= 0;
            Resource containerResource;
            if (useComputedResource) {
                containerResource = new Resource(
                        Double.valueOf(containerRam), Double.valueOf(containerDisk), Double.valueOf(containerCpu));
            } else {
                containerResource = new Resource(
                        Double.valueOf(worker.getRam()), Double.valueOf(worker.getDisk()), Double.valueOf(worker.getCpu()));
            }

            Set<TaskInstancePlan> containerInstances = new LinkedHashSet<>(instancePlans.values());
            workerSchedulePlans.add(new WorkerSchedulePlan(containerId, containerInstances, containerResource));

            // Remember the workers used by the first graph of a dependent set; later graphs of
            // the set are allocated modulo the size of this list.
            if (this.dependentGraphs && this.index == 0) {
                this.workerIdList.add(Integer.valueOf(containerId));
            }
        }
        this.index++;

        TaskSchedulePlan taskSchedulePlan = new TaskSchedulePlan(0, workerSchedulePlans);
        if (this.workerId == 0) {
            logSchedulePlan(computeGraph, taskSchedulePlan);
        }
        return taskSchedulePlan;
    }

    /** Writes the container and task instance layout of the plan to the log at FINE level. */
    private void logSchedulePlan(ComputeGraph computeGraph, TaskSchedulePlan taskSchedulePlan) {
        for (Map.Entry<Integer, WorkerSchedulePlan> entry : taskSchedulePlan.getContainersMap().entrySet()) {
            Integer containerId = entry.getKey();
            Set<TaskInstancePlan> taskInstances = entry.getValue().getTaskInstances();
            LOG.fine("Graph Name:" + computeGraph.getGraphName() + "\tcontainer id:" + containerId);
            for (TaskInstancePlan instancePlan : taskInstances) {
                LOG.fine("Task Id:" + instancePlan.getTaskId() + "\tIndex" + instancePlan.getTaskIndex()
                        + "\tName:" + instancePlan.getTaskName());
            }
        }
    }

    /**
     * Resets the allocation lists of workers {@code 0..numberOfWorkers-1} and places the
     * instances of every vertex of the graph. Vertices are visited in the iteration order of
     * the graph's vertex set and their position in that order is used as the task id.
     *
     * @return the scheduler's allocation map (worker id to task instances)
     * @throws TaskSchedulerException if validation or allocation fails
     */
    private Map<Integer, List<TaskInstanceId>> batchSchedulingAlgorithm(ComputeGraph computeGraph, int numberOfWorkers) throws TaskSchedulerException {
        Set<Vertex> taskVertexSet = new LinkedHashSet<>(computeGraph.getTaskVertexSet());
        for (int worker = 0; worker < numberOfWorkers; worker++) {
            this.batchTaskAllocation.put(Integer.valueOf(worker), new ArrayList<TaskInstanceId>());
        }

        int taskId = 0;
        for (Vertex vertex : taskVertexSet) {
            Object task = vertex.getTask();
            if (this.dependentGraphs) {
                if (task instanceof Receptor) {
                    validateReceptor(computeGraph, vertex);
                }
                dependentTaskWorkerAllocation(computeGraph, vertex, numberOfWorkers, taskId);
            } else {
                // Unlike addReceptorsCollectors, the collector check takes precedence here.
                if (task instanceof Collector) {
                    recordCollectibleNames((Collector) task, vertex);
                } else if (task instanceof Receptor) {
                    recordReceivableNames((Receptor) task, vertex);
                    validateParallelism();
                }
                independentTaskWorkerAllocation(computeGraph, vertex, numberOfWorkers, taskId);
            }
            taskId++;
        }
        return this.batchTaskAllocation;
    }

    /**
     * Checks that every receivable name that is also a collectible name was registered with the
     * same parallelism.
     *
     * @throws Twister2RuntimeException if the two parallelism values differ
     */
    private void validateParallelism() {
        for (Map.Entry<String, Integer> receivable : receivableNameMap.entrySet()) {
            Integer collectibleParallelism = collectibleNameMap.get(receivable.getKey());
            if (collectibleParallelism != null
                    && collectibleParallelism.intValue() != receivable.getValue().intValue()) {
                throw new Twister2RuntimeException("Please verify the dependent collector(s) and receptor(s) parallelism values which are not equal");
            }
        }
    }

    /**
     * Allocates the instances of a vertex belonging to a dependent graph. The round-robin
     * modulus is the size of the recorded worker id list, or the number of workers while that
     * list is still empty.
     */
    private void dependentTaskWorkerAllocation(ComputeGraph computeGraph, Vertex vertex, int numberOfWorkers, int taskId) {
        int workerCount = this.workerIdList.isEmpty() ? numberOfWorkers : this.workerIdList.size();
        allocateInstances(computeGraph, vertex, workerCount, taskId);
    }

    /** Allocates the instances of a vertex of an independent graph over all workers. */
    private void independentTaskWorkerAllocation(ComputeGraph computeGraph, Vertex vertex, int numberOfWorkers, int taskId) {
        allocateInstances(computeGraph, vertex, numberOfWorkers, taskId);
    }

    /**
     * Places instance {@code k} of the vertex on worker {@code k % workerCount}.
     *
     * <p>Without node constraints the instance count comes from the vertex alone and no limit
     * is applied. With node constraints the count is derived from them and the limit returned
     * for the graph constraints is enforced.
     *
     * <p>NOTE(review): the limit is compared against the running number of instances of the
     * whole vertex, not against the number placed on a single worker — confirm that this is the
     * intended meaning of "instances per worker".
     *
     * @throws TaskSchedulerException if there are instances but no workers, or the limit is hit
     */
    private void allocateInstances(ComputeGraph computeGraph, Vertex vertex, int workerCount, int taskId) {
        boolean constrained = !computeGraph.getNodeConstraints().isEmpty();
        int totalInstances = constrained
                ? this.taskAttributes.getTotalNumberOfInstances(vertex, computeGraph.getNodeConstraints())
                : this.taskAttributes.getTotalNumberOfInstances(vertex);
        if (totalInstances > 0 && workerCount <= 0) {
            throw new TaskSchedulerException("Task instances cannot be allocated because no workers are available");
        }
        int instanceLimit = constrained
                ? this.taskAttributes.getInstancesPerWorker(computeGraph.getGraphConstraints())
                : Integer.MAX_VALUE;

        String taskName = vertex.getName();
        for (int instance = 0; instance < totalInstances; instance++) {
            if (instance >= instanceLimit) {
                throw new TaskSchedulerException(INSTANCE_LIMIT_MESSAGE);
            }
            this.batchTaskAllocation.get(Integer.valueOf(instance % workerCount))
                    .add(new TaskInstanceId(taskName, taskId, instance));
        }
    }

    /**
     * Verifies that every collector directly fed by the given vertex has the same parallelism
     * as the vertex itself.
     *
     * @throws TaskSchedulerException if a child collector has a different parallelism
     */
    private void validateReceptor(ComputeGraph computeGraph, Vertex vertex) {
        for (Edge edge : computeGraph.outEdges(vertex)) {
            Vertex child = computeGraph.childOfTask(vertex, edge.getName());
            boolean childIsCollector = child.getTask() instanceof Collector;
            if (childIsCollector && child.getParallelism() != vertex.getParallelism()) {
                throw new TaskSchedulerException("Specify the same parallelism for parent and child tasks which depends on the input from the parent in " + computeGraph.getGraphName() + " graph");
            }
        }
    }
}
