package org.apache.hadoop.yarn.submarine.client.cli.runjob.tensorflow;

import com.google.common.collect.ImmutableList;
import java.io.File;
import java.util.List;
import org.apache.hadoop.yarn.api.records.Resource;
import org.apache.hadoop.yarn.submarine.client.cli.YamlConfigTestUtils;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
import org.apache.hadoop.yarn.submarine.client.cli.runjob.RunJobCli;
import org.apache.hadoop.yarn.submarine.client.cli.runjob.TestRunJobCliParsingCommon;
import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
import org.apache.hadoop.yarn.submarine.common.exception.SubmarineRuntimeException;
import org.apache.hadoop.yarn.submarine.common.resource.ResourceUtils;
import org.apache.hadoop.yarn.util.resource.Resources;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/apache/hadoop/yarn/submarine/client/cli/runjob/tensorflow/TestRunJobCliParsingTensorFlowYaml.class */
/**
 * Tests that {@link RunJobCli} correctly parses TensorFlow run-job parameters from YAML
 * configuration files passed with {@code -f}.
 *
 * <p>Each test materializes a YAML test resource from the {@code runjob-tensorflow-yaml}
 * directory into a temporary file (via {@link YamlConfigTestUtils#createTempFileWithContents}),
 * runs the CLI against it with {@code --verbose}, and verifies the resulting parameters.
 * The temporary file is deleted after each test.
 */
public class TestRunJobCliParsingTensorFlowYaml {
    /** Prefix expected on role-specific values in the "with-overrides" YAML config. */
    private static final String OVERRIDDEN_PREFIX = "overridden_";
    /** Test-resource directory that holds the YAML configs used by this class. */
    private static final String DIR_NAME = "runjob-tensorflow-yaml";
    /** Resource type name used by the GPU parsing test. */
    private static final String GPU_RESOURCE_NAME = "yarn.io/gpu";
    private static final Logger LOG = LoggerFactory.getLogger(TestRunJobCliParsingTensorFlowYaml.class);

    /** Temporary YAML file created by the current test; removed in {@link #after()}. */
    private File yamlConfig;

    @Rule
    public ExpectedException exception = ExpectedException.none();

    @Before
    public void before() {
        SubmarineLogs.verboseOff();
    }

    @After
    public void after() {
        YamlConfigTestUtils.deleteFile(this.yamlConfig);
    }

    /**
     * Creates a temporary file from the named YAML resource, runs the CLI with it in verbose
     * mode and returns the parsed parameters.
     *
     * @param runJobCli the CLI instance to run
     * @param fileName  name of the YAML resource inside {@link #DIR_NAME}
     * @return the parameters parsed by {@code runJobCli}
     * @throws Exception if creating the temporary file or running the CLI fails
     */
    private RunJobParameters runWithYamlConfig(RunJobCli runJobCli, String fileName) throws Exception {
        // Stored in a field so that after() can clean the file up even if the test fails.
        this.yamlConfig = YamlConfigTestUtils.createTempFileWithContents(DIR_NAME + "/" + fileName);
        runJobCli.run(new String[]{"-f", this.yamlConfig.getAbsolutePath(), "--verbose"});
        return runJobCli.getRunJobParameters();
    }

    /**
     * Asserts that the parsed parameters are TensorFlow-specific and returns them cast.
     *
     * @param runJobParameters parameters returned by the CLI
     * @return the same object as {@link TensorFlowRunJobParameters}
     */
    private static TensorFlowRunJobParameters asTensorFlowParameters(RunJobParameters runJobParameters) {
        Assert.assertTrue(RunJobParameters.class + " must be an instance of " + TensorFlowRunJobParameters.class,
                runJobParameters instanceof TensorFlowRunJobParameters);
        return (TensorFlowRunJobParameters) runJobParameters;
    }

    /** Verifies the common top-level values, expecting the default two environment entries. */
    private void verifyBasicConfigValues(RunJobParameters runJobParameters) {
        verifyBasicConfigValues(runJobParameters, ImmutableList.of("env1=env1Value", "env2=env2Value"));
    }

    /**
     * Verifies the common top-level values shared by all YAML configs.
     *
     * @param runJobParameters parsed parameters
     * @param expectedEnvs     entries that must all be present in the parsed env list
     */
    private void verifyBasicConfigValues(RunJobParameters runJobParameters, List<String> expectedEnvs) {
        Assert.assertEquals("testInputPath", runJobParameters.getInputPath());
        Assert.assertEquals("testCheckpointPath", runJobParameters.getCheckpointPath());
        Assert.assertEquals("testDockerImage", runJobParameters.getDockerImageName());
        Assert.assertNotNull(runJobParameters.getLocalizations());
        Assert.assertEquals(2L, runJobParameters.getLocalizations().size());
        Assert.assertNotNull(runJobParameters.getQuicklinks());
        Assert.assertEquals(2L, runJobParameters.getQuicklinks().size());
        // "--verbose" is always passed by runWithYamlConfig.
        Assert.assertTrue(SubmarineLogs.isVerbose());
        Assert.assertTrue(runJobParameters.isWaitJobFinish());
        for (String env : expectedEnvs) {
            Assert.assertTrue(String.format("%s should be in env list of jobRunParameters!", env),
                    runJobParameters.getEnvars().contains(env));
        }
    }

    /**
     * Verifies the parameter-server role values.
     *
     * @param prefix expected prefix on the launch command and docker image ("" or {@link #OVERRIDDEN_PREFIX})
     */
    private void verifyPsValues(RunJobParameters runJobParameters, String prefix) {
        TensorFlowRunJobParameters tensorFlowParams = asTensorFlowParameters(runJobParameters);
        Assert.assertEquals(4L, tensorFlowParams.getNumPS());
        Assert.assertEquals(prefix + "testLaunchCmdPs", tensorFlowParams.getPSLaunchCmd());
        Assert.assertEquals(prefix + "testDockerImagePs", tensorFlowParams.getPsDockerImage());
        Assert.assertEquals(Resources.createResource(20500, 34), tensorFlowParams.getPsResource());
    }

    /**
     * Verifies the worker values that do not depend on GPU support.
     *
     * @param prefix expected prefix on the launch command and docker image
     * @return the parameters cast to {@link TensorFlowRunJobParameters}
     */
    private TensorFlowRunJobParameters verifyWorkerCommonValues(RunJobParameters runJobParameters, String prefix) {
        TensorFlowRunJobParameters tensorFlowParams = asTensorFlowParameters(runJobParameters);
        Assert.assertEquals(3L, tensorFlowParams.getNumWorkers());
        Assert.assertEquals(prefix + "testLaunchCmdWorker", tensorFlowParams.getWorkerLaunchCmd());
        Assert.assertEquals(prefix + "testDockerImageWorker", tensorFlowParams.getWorkerDockerImage());
        return tensorFlowParams;
    }

    /** Verifies worker values, expecting a worker resource without GPUs. */
    private void verifyWorkerValues(RunJobParameters runJobParameters, String prefix) {
        Assert.assertEquals(Resources.createResource(20480, 32),
                verifyWorkerCommonValues(runJobParameters, prefix).getWorkerResource());
    }

    /** Verifies worker values, expecting a worker resource that additionally requests 2 GPUs. */
    private void verifyWorkerValuesWithGpu(RunJobParameters runJobParameters, String prefix) {
        TensorFlowRunJobParameters tensorFlowParams = verifyWorkerCommonValues(runJobParameters, prefix);
        Resource expectedResource = Resources.createResource(20480, 32);
        ResourceUtils.setResource(expectedResource, GPU_RESOURCE_NAME, 2);
        Assert.assertEquals(expectedResource, tensorFlowParams.getWorkerResource());
    }

    /** Verifies the security section (keytab, principal, keytab distribution). */
    private void verifySecurityValues(RunJobParameters runJobParameters) {
        Assert.assertEquals("keytabPath", runJobParameters.getKeytab());
        Assert.assertEquals("testPrincipal", runJobParameters.getPrincipal());
        Assert.assertTrue(runJobParameters.isDistributeKeytab());
    }

    /** Verifies that tensorboard is enabled with the expected image and resource. */
    private void verifyTensorboardValues(RunJobParameters runJobParameters) {
        TensorFlowRunJobParameters tensorFlowParams = asTensorFlowParameters(runJobParameters);
        Assert.assertTrue(tensorFlowParams.isTensorboardEnabled());
        Assert.assertEquals("tensorboardDockerImage", tensorFlowParams.getTensorboardDockerImage());
        Assert.assertEquals(Resources.createResource(21000, 37), tensorFlowParams.getTensorboardResource());
    }

    @Test
    public void testValidYamlParsing() throws Exception {
        RunJobCli runJobCli = new RunJobCli(TestRunJobCliParsingCommon.getMockClientContext());
        Assert.assertFalse(SubmarineLogs.isVerbose());
        RunJobParameters runJobParameters = runWithYamlConfig(runJobCli, "valid-config.yaml");
        verifyBasicConfigValues(runJobParameters);
        verifyPsValues(runJobParameters, "");
        verifyWorkerValues(runJobParameters, "");
        verifySecurityValues(runJobParameters);
        verifyTensorboardValues(runJobParameters);
    }

    @Test
    public void testValidGpuYamlParsing() throws Exception {
        // Only the capability probe is guarded: a SubmarineRuntimeException raised while parsing
        // or verifying the config is a real failure and must not be swallowed.
        try {
            ResourceUtils.configureResourceType(GPU_RESOURCE_NAME);
        } catch (SubmarineRuntimeException e) {
            LOG.info("The hadoop dependency doesn't support gpu resource, so just skip this test case.");
            return;
        }
        RunJobCli runJobCli = new RunJobCli(TestRunJobCliParsingCommon.getMockClientContext());
        Assert.assertFalse(SubmarineLogs.isVerbose());
        RunJobParameters runJobParameters = runWithYamlConfig(runJobCli, "valid-gpu-config.yaml");
        verifyBasicConfigValues(runJobParameters);
        verifyPsValues(runJobParameters, "");
        verifyWorkerValuesWithGpu(runJobParameters, "");
        verifySecurityValues(runJobParameters);
        verifyTensorboardValues(runJobParameters);
    }

    @Test
    public void testRoleOverrides() throws Exception {
        RunJobCli runJobCli = new RunJobCli(TestRunJobCliParsingCommon.getMockClientContext());
        Assert.assertFalse(SubmarineLogs.isVerbose());
        RunJobParameters runJobParameters = runWithYamlConfig(runJobCli, "valid-config-with-overrides.yaml");
        verifyBasicConfigValues(runJobParameters);
        verifyPsValues(runJobParameters, OVERRIDDEN_PREFIX);
        verifyWorkerValues(runJobParameters, OVERRIDDEN_PREFIX);
        verifySecurityValues(runJobParameters);
        verifyTensorboardValues(runJobParameters);
    }

    @Test
    public void testMissingPrincipalUnderSecuritySection() throws Exception {
        RunJobCli runJobCli = new RunJobCli(TestRunJobCliParsingCommon.getMockClientContext());
        RunJobParameters runJobParameters = runWithYamlConfig(runJobCli, "security-principal-is-missing.yaml");
        verifyBasicConfigValues(runJobParameters);
        verifyPsValues(runJobParameters, "");
        verifyWorkerValues(runJobParameters, "");
        verifyTensorboardValues(runJobParameters);
        Assert.assertEquals("keytabPath", runJobParameters.getKeytab());
        Assert.assertNull("Principal should be null!", runJobParameters.getPrincipal());
        Assert.assertTrue(runJobParameters.isDistributeKeytab());
    }

    @Test
    public void testMissingTensorBoardDockerImage() throws Exception {
        RunJobCli runJobCli = new RunJobCli(TestRunJobCliParsingCommon.getMockClientContext());
        RunJobParameters runJobParameters = runWithYamlConfig(runJobCli, "tensorboard-dockerimage-is-missing.yaml");
        verifyBasicConfigValues(runJobParameters);
        verifyPsValues(runJobParameters, "");
        verifyWorkerValues(runJobParameters, "");
        verifySecurityValues(runJobParameters);
        // Checked cast instead of an implicit narrowing assignment, which does not compile.
        TensorFlowRunJobParameters tensorFlowParams = asTensorFlowParameters(runJobParameters);
        Assert.assertTrue(tensorFlowParams.isTensorboardEnabled());
        Assert.assertNull("tensorboardDockerImage should be null!", tensorFlowParams.getTensorboardDockerImage());
        Assert.assertEquals(Resources.createResource(21000, 37), tensorFlowParams.getTensorboardResource());
    }

    @Test
    public void testMissingEnvs() throws Exception {
        RunJobCli runJobCli = new RunJobCli(TestRunJobCliParsingCommon.getMockClientContext());
        RunJobParameters runJobParameters = runWithYamlConfig(runJobCli, "envs-are-missing.yaml");
        verifyBasicConfigValues(runJobParameters, ImmutableList.of());
        verifyPsValues(runJobParameters, "");
        verifyWorkerValues(runJobParameters, "");
        verifySecurityValues(runJobParameters);
        verifyTensorboardValues(runJobParameters);
    }
}
