package org.apache.tez.runtime;

import com.google.common.collect.HashMultimap;
import com.google.common.collect.Lists;
import com.google.common.collect.Multimap;
import java.nio.ByteBuffer;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.tez.dag.api.InputDescriptor;
import org.apache.tez.dag.api.OutputDescriptor;
import org.apache.tez.dag.api.ProcessorDescriptor;
import org.apache.tez.dag.api.TezConfiguration;
import org.apache.tez.dag.records.TezDAGID;
import org.apache.tez.dag.records.TezTaskAttemptID;
import org.apache.tez.dag.records.TezTaskID;
import org.apache.tez.dag.records.TezVertexID;
import org.apache.tez.runtime.api.AbstractLogicalIOProcessor;
import org.apache.tez.runtime.api.AbstractLogicalInput;
import org.apache.tez.runtime.api.AbstractLogicalOutput;
import org.apache.tez.runtime.api.Event;
import org.apache.tez.runtime.api.InputContext;
import org.apache.tez.runtime.api.LogicalInput;
import org.apache.tez.runtime.api.LogicalOutput;
import org.apache.tez.runtime.api.MemoryUpdateCallback;
import org.apache.tez.runtime.api.ObjectRegistry;
import org.apache.tez.runtime.api.OutputContext;
import org.apache.tez.runtime.api.ProcessorContext;
import org.apache.tez.runtime.api.Reader;
import org.apache.tez.runtime.api.Writer;
import org.apache.tez.runtime.api.impl.InputSpec;
import org.apache.tez.runtime.api.impl.OutputSpec;
import org.apache.tez.runtime.api.impl.TaskSpec;
import org.apache.tez.runtime.api.impl.TezUmbilical;
import org.apache.tez.runtime.common.resources.ScalingAllocator;
import org.junit.Assert;
import org.junit.Test;
import org.mockito.Mockito;

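/**
 * Tests for {@link LogicalIOProcessorRuntimeTask}: when two tasks of the same vertex share a
 * started-inputs map, the input is started automatically only once, outputs are not started
 * automatically, and each task's vertex parallelism is exposed through its processor, input and
 * output contexts.
 */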
public class TestLogicalIOProcessorRuntimeTask {

    /**
     * Minimal input that records how often it was started and the vertex parallelism reported
     * by its context at start time. The runtime task creates instances from the class name in
     * the {@link InputDescriptor}, so the test can only observe them through static state; that
     * state is cleared by {@code resetCounters()} before each test.
     */
    public static class TestInput extends AbstractLogicalInput {
        /** Number of times {@link #start()} has been invoked on any instance. */
        public static volatile int startCount = 0;
        /** Vertex parallelism reported by the context during the most recent {@link #start()}. */
        public static volatile int vertexParallelism;

        public TestInput(InputContext inputContext, int numPhysicalInputs) {
            super(inputContext, numPhysicalInputs);
        }

        /** Requests no memory and immediately reports the input as ready. */
        @Override
        public List<Event> initialize() throws Exception {
            getContext().requestInitialMemory(0L, (MemoryUpdateCallback) null);
            getContext().inputIsReady();
            return null;
        }

        /** Counts the start and captures the vertex parallelism seen by the context. */
        @Override
        public void start() throws Exception {
            startCount++;
            vertexParallelism = getContext().getVertexParallelism();
            System.err.println("In started");
        }

        @Override
        public Reader getReader() throws Exception {
            return null;
        }

        @Override
        public void handleEvents(List<Event> inputEvents) throws Exception {
        }

        @Override
        public List<Event> close() throws Exception {
            return null;
        }
    }

    /**
     * Minimal output mirroring {@link TestInput}: records start invocations and the vertex
     * parallelism seen at start time in static fields the test can inspect.
     */
    public static class TestOutput extends AbstractLogicalOutput {
        /** Number of times {@link #start()} has been invoked on any instance. */
        public static volatile int startCount = 0;
        /** Vertex parallelism reported by the context during the most recent {@link #start()}. */
        public static volatile int vertexParallelism;

        public TestOutput(OutputContext outputContext, int numPhysicalOutputs) {
            super(outputContext, numPhysicalOutputs);
        }

        /** Requests no memory; unlike {@link TestInput} it does not signal readiness. */
        @Override
        public List<Event> initialize() throws Exception {
            getContext().requestInitialMemory(0L, (MemoryUpdateCallback) null);
            return null;
        }

        /** Counts the start and captures the vertex parallelism seen by the context. */
        @Override
        public void start() throws Exception {
            System.err.println("Out started");
            startCount++;
            vertexParallelism = getContext().getVertexParallelism();
        }

        @Override
        public Writer getWriter() throws Exception {
            return null;
        }

        @Override
        public void handleEvents(List<Event> outputEvents) {
        }

        @Override
        public List<Event> close() throws Exception {
            return null;
        }
    }

    /** Processor that only counts how many times {@link #run(Map, Map)} was invoked. */
    public static class TestProcessor extends AbstractLogicalIOProcessor {
        /** Number of times {@link #run(Map, Map)} has been invoked on any instance. */
        public static volatile int runCount = 0;

        public TestProcessor(ProcessorContext processorContext) {
            super(processorContext);
        }

        @Override
        public void initialize() throws Exception {
        }

        @Override
        public void run(Map<String, LogicalInput> inputs, Map<String, LogicalOutput> outputs)
                throws Exception {
            runCount++;
        }

        @Override
        public void handleEvents(List<Event> processorEvents) {
        }

        @Override
        public void close() throws Exception {
        }
    }

    /**
     * Runs two tasks of the same vertex that share one started-inputs map. The first task is
     * expected to auto-start its input (but not its output); the second task must not start the
     * input again. Each task's contexts must report that task's own vertex parallelism.
     */
    @Test(timeout = 5000)
    public void testAutoStart() throws Exception {
        // The counters are static; clear them so the assertions hold even if this test has
        // already run in the same JVM (e.g. an IDE rerun or surefire rerunning failures).
        resetCounters();

        TezVertexID vertexId = createTezVertexId(createTezDagId());
        Map<String, ByteBuffer> serviceConsumerMetadata = new HashMap<String, ByteBuffer>();
        // Shared by both tasks: the assertions below expect the second task on the same vertex
        // not to start its input again.
        Multimap<String, String> startedInputsMap = HashMultimap.create();
        TezUmbilical umbilical = Mockito.mock(TezUmbilical.class);
        TezConfiguration tezConf = new TezConfiguration();
        tezConf.set("tez.task.scale.memory.allocator.class", ScalingAllocator.class.getName());

        TaskSpec firstTaskSpec =
                createTaskSpec(createTaskAttemptID(vertexId, 1), "dag1", "vertex1", 30);
        TaskSpec secondTaskSpec =
                createTaskSpec(createTaskAttemptID(vertexId, 2), "dag2", "vertex1", 10);

        LogicalIOProcessorRuntimeTask firstTask = runTask(
                firstTaskSpec, tezConf, umbilical, serviceConsumerMetadata, startedInputsMap);
        Assert.assertEquals(1L, TestProcessor.runCount);
        Assert.assertEquals(1L, TestInput.startCount);
        Assert.assertEquals(0L, TestOutput.startCount);
        Assert.assertEquals(30L, TestInput.vertexParallelism);
        Assert.assertEquals(0L, TestOutput.vertexParallelism);
        Assert.assertEquals(30L, firstTask.getProcessorContext().getVertexParallelism());
        Assert.assertEquals(30L, ((InputContext) firstTask.getInputContexts().iterator().next())
                .getVertexParallelism());
        Assert.assertEquals(30L, ((OutputContext) firstTask.getOutputContexts().iterator().next())
                .getVertexParallelism());

        LogicalIOProcessorRuntimeTask secondTask = runTask(
                secondTaskSpec, tezConf, umbilical, serviceConsumerMetadata, startedInputsMap);
        Assert.assertEquals(2L, TestProcessor.runCount);
        // Input was already started for this vertex, so the counters still reflect task one.
        Assert.assertEquals(1L, TestInput.startCount);
        Assert.assertEquals(0L, TestOutput.startCount);
        Assert.assertEquals(30L, TestInput.vertexParallelism);
        Assert.assertEquals(0L, TestOutput.vertexParallelism);
        Assert.assertEquals(10L, secondTask.getProcessorContext().getVertexParallelism());
        Assert.assertEquals(10L, ((InputContext) secondTask.getInputContexts().iterator().next())
                .getVertexParallelism());
        Assert.assertEquals(10L, ((OutputContext) secondTask.getOutputContexts().iterator().next())
                .getVertexParallelism());
    }

    /** Clears every static counter of the test input, output and processor. */
    private static void resetCounters() {
        TestProcessor.runCount = 0;
        TestInput.startCount = 0;
        TestInput.vertexParallelism = 0;
        TestOutput.startCount = 0;
        TestOutput.vertexParallelism = 0;
    }

    /**
     * Creates a runtime task for {@code taskSpec} and drives it through initialize, run and
     * close.
     *
     * @return the closed task, so callers can inspect its contexts
     */
    private LogicalIOProcessorRuntimeTask runTask(TaskSpec taskSpec, TezConfiguration tezConf,
            TezUmbilical umbilical, Map<String, ByteBuffer> serviceConsumerMetadata,
            Multimap<String, String> startedInputsMap) throws Exception {
        LogicalIOProcessorRuntimeTask task = new LogicalIOProcessorRuntimeTask(taskSpec, 0,
                tezConf, (String[]) null, umbilical, serviceConsumerMetadata, startedInputsMap,
                (ObjectRegistry) null);
        task.initialize();
        task.run();
        task.close();
        return task;
    }

    /**
     * Builds a task spec with one {@link TestInput} edge, one {@link TestOutput} edge and a
     * {@link TestProcessor}, and no group inputs.
     */
    private TaskSpec createTaskSpec(TezTaskAttemptID taskAttemptId, String dagName,
            String vertexName, int vertexParallelism) {
        return new TaskSpec(taskAttemptId, dagName, vertexName, vertexParallelism,
                createProcessorDescriptor(), createInputSpecList(), createOutputSpecList(),
                (List) null);
    }

    /** Returns a single input spec named "inedge" backed by {@link TestInput}. */
    private List<InputSpec> createInputSpecList() {
        return Lists.newArrayList(
                new InputSpec("inedge", InputDescriptor.create(TestInput.class.getName()), 1));
    }

    /** Returns a single output spec named "outedge" backed by {@link TestOutput}. */
    private List<OutputSpec> createOutputSpecList() {
        return Lists.newArrayList(
                new OutputSpec("outedge", OutputDescriptor.create(TestOutput.class.getName()), 1));
    }

    private ProcessorDescriptor createProcessorDescriptor() {
        return ProcessorDescriptor.create(TestProcessor.class.getName());
    }

    /** Uses {@code id} as both the task index and the attempt index. */
    private TezTaskAttemptID createTaskAttemptID(TezVertexID vertexId, int id) {
        return TezTaskAttemptID.getInstance(TezTaskID.getInstance(vertexId, id), id);
    }

    private TezVertexID createTezVertexId(TezDAGID dagId) {
        return TezVertexID.getInstance(dagId, 1);
    }

    private TezDAGID createTezDagId() {
        return TezDAGID.getInstance("2000", 100, 1);
    }
}
