package org.apache.tez.mapreduce.input;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.UUID;

import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.SequenceFile;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapred.FileInputFormat;
import org.apache.hadoop.mapred.InputSplit;
import org.apache.hadoop.mapred.JobConf;
import org.apache.hadoop.mapred.SequenceFileInputFormat;
import org.apache.hadoop.yarn.api.records.ApplicationId;
import org.apache.tez.common.TezUtils;
import org.apache.tez.common.counters.TezCounters;
import org.apache.tez.dag.api.UserPayload;
import org.apache.tez.mapreduce.hadoop.MRInputHelpers;
import org.apache.tez.mapreduce.protos.MRRuntimeProtos;
import org.apache.tez.runtime.api.Event;
import org.apache.tez.runtime.api.InputContext;
import org.apache.tez.runtime.api.events.InputDataInformationEvent;
import org.apache.tez.runtime.library.api.KeyValueReader;
import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.mockito.Mockito;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Tests for {@link MultiMRInput}.
 *
 * <p>Each test writes SequenceFile data under a local-filesystem scratch directory. It computes
 * splits over that data and delivers them to the input as {@link InputDataInformationEvent}s.
 * The tests then check:
 * <ul>
 *   <li>the number of {@link KeyValueReader}s produced;</li>
 *   <li>that the readers return exactly the records that were written;</li>
 *   <li>that misuse is rejected: events sent to an input with 0 physical inputs, or more events
 *       than physical inputs.</li>
 * </ul>
 */
public class TestMultiMRInput {
    private static final Logger LOG = LoggerFactory.getLogger(TestMultiMRInput.class);
    private static final JobConf defaultConf = new JobConf();
    private static final String testTmpDir;
    private static final Path TEST_ROOT_DIR;
    private static final FileSystem localFs;

    static {
        // Pin the default filesystem to the local one so the tests never reach a real cluster.
        defaultConf.set("fs.defaultFS", "file:///");
        testTmpDir = System.getProperty("test.build.data", "target");
        TEST_ROOT_DIR = new Path(testTmpDir, TestMultiMRInput.class.getSimpleName() + "-tmpDir");
        try {
            localFs = FileSystem.getLocal(defaultConf);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /** Recreates an empty scratch directory before every test. */
    @Before
    public void setup() throws IOException {
        LOG.info("Setup. Using test dir: {}", TEST_ROOT_DIR);
        localFs.delete(TEST_ROOT_DIR, true);
        localFs.mkdirs(TEST_ROOT_DIR);
    }

    /** An input configured for 0 physical inputs has no readers and must reject any events. */
    @Test(timeout = 5000)
    public void test0PhysicalInputs() throws Exception {
        JobConf jobConf = createJobConf(new Path(TEST_ROOT_DIR, "test0PhysicalInputs"));
        MultiMRInput multiMRInput = new MultiMRInput(createTezInputContext(createUserPayload(jobConf)), 0);
        multiMRInput.initialize();
        multiMRInput.start();
        Assert.assertEquals(0, multiMRInput.getKeyValueReaders().size());
        try {
            multiMRInput.handleEvents(new LinkedList<Event>());
            Assert.fail("HandleEvents should cause an input with 0 physical inputs to fail");
        } catch (IllegalStateException expected) {
            // Expected: any other exception type propagates and fails the test with its real cause.
        }
    }

    /** One split yields one reader that returns every written record and reports progress per record. */
    @Test(timeout = 5000)
    public void testSingleSplit() throws Exception {
        Path workDir = new Path(TEST_ROOT_DIR, "testSingleSplit");
        JobConf jobConf = createJobConf(workDir);
        InputContext inputContext = createTezInputContext(createUserPayload(jobConf));
        MultiMRInput multiMRInput = new MultiMRInput(inputContext, 1);
        multiMRInput.initialize();

        Map<LongWritable, Text> expected = createInputData(localFs, workDir, jobConf, "file1", 0L, 10L);
        InputSplit[] splits = new SequenceFileInputFormat<LongWritable, Text>().getSplits(jobConf, 1);
        Assert.assertEquals(1, splits.length);

        List<Event> events = new ArrayList<>();
        events.add(createSplitEvent(0, splits[0]));
        multiMRInput.handleEvents(events);

        int readerCount = 0;
        int recordCount = 0;
        for (KeyValueReader reader : multiMRInput.getKeyValueReaders()) {
            readerCount++;
            while (reader.next()) {
                recordCount++;
                // Each record read so far must have been reported as progress exactly once.
                Mockito.verify(inputContext, Mockito.times(recordCount)).notifyProgress();
                if (expected.isEmpty()) {
                    Assert.fail("Found more records than expected");
                }
                Assert.assertEquals(reader.getCurrentValue(), expected.remove(reader.getCurrentKey()));
            }
            assertReaderExhausted(reader);
        }
        Assert.assertEquals(1, readerCount);
        // Catch readers that stop early instead of returning every record.
        Assert.assertTrue("Records not read: " + expected.keySet(), expected.isEmpty());
    }

    /** Two splits yield two readers that together return every record from both files. */
    @Test(timeout = 5000)
    public void testMultipleSplits() throws Exception {
        Path workDir = new Path(TEST_ROOT_DIR, "testMultipleSplits");
        JobConf jobConf = createJobConf(workDir);
        MultiMRInput multiMRInput = new MultiMRInput(createTezInputContext(createUserPayload(jobConf)), 2);
        multiMRInput.initialize();

        // The two files use disjoint key ranges, so one map can track records from both readers.
        Map<LongWritable, Text> expected = new LinkedHashMap<>();
        expected.putAll(createInputData(localFs, workDir, jobConf, "file1", 0L, 10L));
        expected.putAll(createInputData(localFs, workDir, jobConf, "file2", 10L, 20L));
        InputSplit[] splits = new SequenceFileInputFormat<LongWritable, Text>().getSplits(jobConf, 2);
        Assert.assertEquals(2, splits.length);

        List<Event> events = new ArrayList<>();
        events.add(createSplitEvent(0, splits[0]));
        events.add(createSplitEvent(0, splits[1]));
        multiMRInput.handleEvents(events);

        int readerCount = 0;
        for (KeyValueReader reader : multiMRInput.getKeyValueReaders()) {
            readerCount++;
            while (reader.next()) {
                if (expected.isEmpty()) {
                    Assert.fail("Found more records than expected");
                }
                Assert.assertEquals(reader.getCurrentValue(), expected.remove(reader.getCurrentKey()));
            }
            assertReaderExhausted(reader);
        }
        Assert.assertEquals(2, readerCount);
        // Catch readers that stop early instead of returning every record.
        Assert.assertTrue("Records not read: " + expected.keySet(), expected.isEmpty());
    }

    /** Delivering more events than configured physical inputs must be rejected. */
    @Test(timeout = 5000)
    public void testExtraEvents() throws Exception {
        Path workDir = new Path(TEST_ROOT_DIR, "testExtraEvents");
        JobConf jobConf = createJobConf(workDir);
        MultiMRInput multiMRInput = new MultiMRInput(createTezInputContext(createUserPayload(jobConf)), 1);
        multiMRInput.initialize();

        createInputData(localFs, workDir, jobConf, "file1", 0L, 10L);
        InputSplit[] splits = new SequenceFileInputFormat<LongWritable, Text>().getSplits(jobConf, 1);
        Assert.assertEquals(1, splits.length);

        // Two events for an input configured with a single physical input.
        List<Event> events = new ArrayList<>();
        events.add(createSplitEvent(0, splits[0]));
        events.add(createSplitEvent(1, splits[0]));
        try {
            multiMRInput.handleEvents(events);
            Assert.fail("Expecting Exception due to too many events");
        } catch (Exception e) {
            String message = e.getMessage();
            Assert.assertTrue("Unexpected exception: " + e,
                    message != null && message.contains("Unexpected event. All physical sources already initialized"));
        }
    }

    /** Builds a JobConf, derived from the shared default conf, that reads SequenceFiles from {@code inputDir}. */
    private static JobConf createJobConf(Path inputDir) {
        JobConf jobConf = new JobConf(defaultConf);
        jobConf.setInputFormat(SequenceFileInputFormat.class);
        FileInputFormat.setInputPaths(jobConf, inputDir);
        return jobConf;
    }

    /** Serializes an MRInput user payload that has grouping disabled and embeds {@code jobConf}. */
    private static byte[] createUserPayload(JobConf jobConf) throws IOException {
        MRRuntimeProtos.MRInputUserPayloadProto.Builder builder =
                MRRuntimeProtos.MRInputUserPayloadProto.newBuilder();
        builder.setGroupingEnabled(false);
        builder.setConfigurationBytes(TezUtils.createByteStringFromConf(jobConf));
        return builder.build().toByteArray();
    }

    /** Wraps {@code split} in an InputDataInformationEvent with the given source index. */
    private static InputDataInformationEvent createSplitEvent(int sourceIndex, InputSplit split)
            throws IOException {
        MRRuntimeProtos.MRSplitProto splitProto = MRInputHelpers.createSplitProto(split);
        return InputDataInformationEvent.createWithSerializedPayload(sourceIndex,
                splitProto.toByteString().asReadOnlyByteBuffer());
    }

    /**
     * Asserts that calling {@code next()} on an already-drained reader fails with an IOException
     * whose message points to usage documentation.
     */
    private static void assertReaderExhausted(KeyValueReader reader) {
        try {
            reader.next();
            Assert.fail("Expected next() on an exhausted reader to throw");
        } catch (IOException e) {
            Assert.assertTrue(e.getMessage().contains("For usage, please refer to"));
        }
    }

    /**
     * Creates a Mockito {@link InputContext} that returns fixed identifiers, fresh counters,
     * a random unique identifier, and {@code payload} as its user payload.
     */
    private InputContext createTezInputContext(byte[] payload) {
        ApplicationId applicationId = ApplicationId.newInstance(10000L, 1);
        TezCounters counters = new TezCounters();
        InputContext inputContext = Mockito.mock(InputContext.class);
        Mockito.doReturn(applicationId).when(inputContext).getApplicationId();
        Mockito.doReturn(counters).when(inputContext).getCounters();
        Mockito.doReturn(1).when(inputContext).getDAGAttemptNumber();
        Mockito.doReturn("dagName").when(inputContext).getDAGName();
        Mockito.doReturn(1).when(inputContext).getInputIndex();
        Mockito.doReturn("srcVertexName").when(inputContext).getSourceVertexName();
        Mockito.doReturn(1).when(inputContext).getTaskAttemptNumber();
        Mockito.doReturn(1).when(inputContext).getTaskIndex();
        Mockito.doReturn(1).when(inputContext).getTaskVertexIndex();
        Mockito.doReturn(UUID.randomUUID().toString()).when(inputContext).getUniqueIdentifier();
        Mockito.doReturn("taskVertexName").when(inputContext).getTaskVertexName();
        Mockito.doReturn(UserPayload.create(ByteBuffer.wrap(payload))).when(inputContext).getUserPayload();
        return inputContext;
    }

    /** Removes the scratch directory after all tests have run. */
    @AfterClass
    public static void cleanUp() throws IOException {
        localFs.delete(TEST_ROOT_DIR, true);
    }

    /**
     * Writes a SequenceFile at {@code workDir/fileName}. It contains one record for each key in
     * {@code [startKey, endKey)}; each value is the decimal text of a random int in {@code [0, 10000)}.
     *
     * @return the written records, in insertion order, as independent copies
     * @throws IOException if the file cannot be created or written
     */
    public static LinkedHashMap<LongWritable, Text> createInputData(FileSystem fs, Path workDir, JobConf jobConf,
            String fileName, long startKey, long endKey) throws IOException {
        LinkedHashMap<LongWritable, Text> written = new LinkedHashMap<>();
        Path file = new Path(workDir, fileName);
        LOG.info("Generating data at path: {}", file);
        Random random = new Random(System.currentTimeMillis());
        try (SequenceFile.Writer writer =
                     SequenceFile.createWriter(fs, jobConf, file, LongWritable.class, Text.class)) {
            // One key and one value object are reused for every append; the map stores copies.
            LongWritable key = new LongWritable();
            Text value = new Text();
            for (long k = startKey; k < endKey; k++) {
                key.set(k);
                value.set(Integer.toString(random.nextInt(10000)));
                written.put(new LongWritable(k), new Text(value.toString()));
                writer.append(key, value);
                LOG.info("<k, v> : <{}, {}>", k, value);
            }
        }
        return written;
    }
}
