package org.apache.hadoop.yarn.submarine.client.cli.runjob;

import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import org.apache.commons.cli.ParseException;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.PyTorchRunJobParameters;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
import org.apache.hadoop.yarn.submarine.common.MockClientContext;
import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
import org.apache.hadoop.yarn.submarine.runtimes.common.JobMonitor;
import org.apache.hadoop.yarn.submarine.runtimes.common.JobSubmitter;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.mockito.Mockito;

/**
 * Parameterized command-line parsing tests for {@link RunJobCli}.
 *
 * <p>Every test method is executed once per {@link Framework} value supplied
 * by {@link #data()}; the framework's string value is passed to the CLI via
 * the {@code --framework} option.
 */
@RunWith(Parameterized.class)
public class TestRunJobCliParsingParameterized {
    /** Worker launch command expected once the placeholders are substituted. */
    private static final String EXPECTED_WORKER_LAUNCH_CMD =
            "python run-job.py --input=hdfs://input --model_dir=hdfs://output --export_dir=hdfs://output/savedmodel";

    /** Framework the current parameterized run is exercising. */
    private final Framework framework;

    /** JUnit rule kept for API compatibility; no test in this class configures it. */
    @Rule
    public ExpectedException expectedException = ExpectedException.none();

    /** Turns verbose logging off before each test. */
    @Before
    public void before() {
        SubmarineLogs.verboseOff();
    }

    /**
     * Supplies the constructor arguments for each parameterized run.
     *
     * @return one single-element array per framework under test
     */
    @Parameterized.Parameters
    public static Collection<Object[]> data() {
        Collection<Object[]> params = new ArrayList<>();
        params.add(new Object[]{Framework.TENSORFLOW});
        params.add(new Object[]{Framework.PYTORCH});
        return params;
    }

    /**
     * @param framework framework injected by the {@link Parameterized} runner
     */
    public TestRunJobCliParsingParameterized(Framework framework) {
        this.framework = framework;
    }

    /** Returns the value passed to the {@code --framework} CLI option. */
    private String getFrameworkName() {
        return this.framework.getValue();
    }

    /**
     * Calls {@code printUsages()} on a CLI built from mocks. There is no
     * assertion: the test passes as long as no exception is thrown.
     */
    @Test
    public void testPrintHelp() {
        MockClientContext mockClientContext = new MockClientContext();
        JobSubmitter mockJobSubmitter = Mockito.mock(JobSubmitter.class);
        JobMonitor mockJobMonitor = Mockito.mock(JobMonitor.class);
        RunJobCli runJobCli = new RunJobCli(mockClientContext, mockJobSubmitter, mockJobMonitor);
        runJobCli.printUsages();
    }

    /**
     * Runs the CLI without {@code --input_path} and expects a
     * {@link ParseException} carrying the "absent" message.
     */
    @Test
    public void testNoInputPathOptionSpecified() throws Exception {
        RunJobCli runJobCli = new RunJobCli(TestRunJobCliParsingCommon.getMockClientContext());
        // Stays empty when no ParseException is thrown, so the assertion below fails.
        String actualMessage = "";
        try {
            runJobCli.run(new String[]{"--framework", getFrameworkName(), "--name", "my-job", "--docker_image", "tf-docker:1.1.0", "--checkpoint_path", "hdfs://output", "--num_workers", "1", "--worker_launch_cmd", "python run-job.py", "--worker_resources", "memory=4g,vcores=2", "--verbose", "--wait_job_finish"});
        } catch (ParseException e) {
            actualMessage = e.getMessage();
        }
        Assert.assertEquals("\"--input_path\" is absent", actualMessage);
    }

    /**
     * Runs the CLI without {@code --name} and expects a
     * {@link ParseException} carrying the "absent" message.
     */
    @Test
    public void testJobWithoutName() throws Exception {
        RunJobCli runJobCli = new RunJobCli(TestRunJobCliParsingCommon.getMockClientContext());
        // Stays empty when no ParseException is thrown, so the assertion below fails.
        String actualMessage = "";
        try {
            runJobCli.run(new String[]{"--framework", getFrameworkName(), "--docker_image", "tf-docker:1.1.0", "--num_workers", "0", "--verbose"});
        } catch (ParseException e) {
            actualMessage = e.getMessage();
        }
        Assert.assertEquals("--name is absent", actualMessage);
    }

    /**
     * Verifies that the {@code %input_path%}, {@code %checkpoint_path%} and
     * {@code %saved_model_path%} placeholders in launch commands are replaced
     * with concrete paths. No saved-model path is passed, and the expected
     * command has {@code %saved_model_path%} resolving to the checkpoint path.
     */
    @Test
    public void testLaunchCommandPatternReplace() throws Exception {
        RunJobCli runJobCli = new RunJobCli(TestRunJobCliParsingCommon.getMockClientContext());
        Assert.assertFalse(SubmarineLogs.isVerbose());

        List<String> args = Lists.newArrayList("--framework", getFrameworkName(), "--name", "my-job", "--docker_image", "tf-docker:1.1.0", "--input_path", "hdfs://input", "--checkpoint_path", "hdfs://output", "--num_workers", "3", "--worker_launch_cmd", "python run-job.py --input=%input_path% --model_dir=%checkpoint_path% --export_dir=%saved_model_path%/savedmodel", "--worker_resources", "memory=2048,vcores=2");
        if (this.framework == Framework.TENSORFLOW) {
            // The parameter-server options (and --verbose) are only added for TensorFlow.
            args.addAll(Lists.newArrayList("--ps_resources", "memory=4096,vcores=4", "--ps_launch_cmd", "python run-ps.py --input=%input_path% --model_dir=%checkpoint_path%/model", "--verbose"));
        }
        runJobCli.run(args.toArray(new String[0]));

        // Hold the result as the common base type and downcast per framework;
        // the helper has already asserted the runtime type matches.
        RunJobParameters runJobParameters = checkExpectedFrameworkParams(runJobCli);
        if (this.framework == Framework.TENSORFLOW) {
            TensorFlowRunJobParameters tensorFlowParams = (TensorFlowRunJobParameters) runJobParameters;
            Assert.assertEquals(EXPECTED_WORKER_LAUNCH_CMD, tensorFlowParams.getWorkerLaunchCmd());
            Assert.assertEquals("python run-ps.py --input=hdfs://input --model_dir=hdfs://output/model", tensorFlowParams.getPSLaunchCmd());
        } else if (this.framework == Framework.PYTORCH) {
            PyTorchRunJobParameters pyTorchParams = (PyTorchRunJobParameters) runJobParameters;
            Assert.assertEquals(EXPECTED_WORKER_LAUNCH_CMD, pyTorchParams.getWorkerLaunchCmd());
        }
    }

    /**
     * Asserts that the parameters parsed by the CLI have the concrete type
     * matching the framework under test.
     *
     * @param runJobCli CLI instance whose {@code run} has already been invoked
     * @return the parsed parameters, typed as the common base class
     */
    private RunJobParameters checkExpectedFrameworkParams(RunJobCli runJobCli) {
        RunJobParameters runJobParameters = runJobCli.getRunJobParameters();
        if (this.framework == Framework.TENSORFLOW) {
            Assert.assertTrue(RunJobParameters.class + " must be an instance of " + TensorFlowRunJobParameters.class, runJobParameters instanceof TensorFlowRunJobParameters);
        } else if (this.framework == Framework.PYTORCH) {
            Assert.assertTrue(RunJobParameters.class + " must be an instance of " + PyTorchRunJobParameters.class, runJobParameters instanceof PyTorchRunJobParameters);
        }
        return runJobParameters;
    }
}
