package org.deeplearning4j.aws.ec2.provision;

import java.io.File;
import java.util.Iterator;
import java.util.List;
import org.deeplearning4j.aws.ec2.Ec2BoxCreator;
import org.kohsuke.args4j.CmdLineException;
import org.kohsuke.args4j.CmdLineParser;
import org.kohsuke.args4j.Option;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/aws/ec2/provision/ClusterSetup.class */
/**
 * Command line tool that launches a group of EC2 instances through
 * {@link Ec2BoxCreator} and then provisions them through {@link HostProvisioner}.
 *
 * <p>{@code numWorkers + 1} instances are requested. The first host reported by
 * {@link Ec2BoxCreator#getHosts()} is provisioned as the master, every remaining
 * host is provisioned as a worker.
 *
 * <p>Options are bound to the fields below by args4j. The security group
 * ({@code -sg}), key pair ({@code -kp}) and private key path ({@code -kpath}) are
 * mandatory; {@link #exec()} refuses to launch anything when one of them is missing
 * or when the command line could not be parsed.
 */
public class ClusterSetup {

    private static final Logger log = LoggerFactory.getLogger(ClusterSetup.class);

    /** Second constructor argument handed to every {@link HostProvisioner}; presumably the remote login name. */
    private static final String REMOTE_USER = "ec2-user";

    @Option(name = "-sg", usage = "security group, this needs to be set")
    private String securityGroupName;

    @Option(name = "-kp", usage = "key pair name, also needs to be set.")
    private String keyPairName;

    @Option(name = "-kpath", usage = "path to private key - needs to be set, this is used to login to amazon.")
    private String pathToPrivateKey;

    @Option(name = "-wscript", usage = "path to worker script to run, this will allow customization of dependencies")
    private String workerSetupScriptPath;

    @Option(name = "-mscript", usage = "path to master script to run this will allow customization of the dependencies")
    private String masterSetupScriptPath;

    @Option(name = "-lib", usage = "path to lib directory, this could be a default dl4j distribution or your own custom dependencies")
    private String libDirPath;

    @Option(name = "-datapath", usage = "path to serialized dataset")
    private String dataSetPath;

    @Option(name = "-w", usage = "Number of workers")
    private int numWorkers = 1;

    @Option(name = "-ami", usage = "Amazon machine image: default, amazon linux (only works with RHEL right now")
    private String ami = "ami-bba18dd2";

    @Option(name = "-s", usage = "size of instance: default m1.medium")
    private String size = "m1.medium";

    @Option(name = "-uploddeps", usage = "whether to uploade deps: default true")
    private boolean uploadDeps = true;

    /** {@code false} when the command line was rejected by the parser; checked by {@link #exec()}. */
    private final boolean argsParsed;

    /**
     * Binds the given command line arguments to this instance's option fields.
     *
     * <p>A parse failure does not throw: usage is printed to {@code System.err}, the
     * error is logged, and the failure is remembered so that {@link #exec()} does not
     * launch instances from a half-parsed command line.
     *
     * @param strArr raw command line arguments
     */
    public ClusterSetup(String[] strArr) {
        CmdLineParser parser = new CmdLineParser(this);
        boolean parsed = true;
        try {
            parser.parseArgument(strArr);
        } catch (CmdLineException e) {
            parsed = false;
            parser.printUsage(System.err);
            log.error("Unable to parse args", e);
        }
        this.argsParsed = parsed;
    }

    /**
     * Launches {@code numWorkers + 1} instances, waits until
     * {@link Ec2BoxCreator#blockTillAllRunning()} returns, then provisions the first
     * reported host as master and all remaining hosts as workers.
     *
     * <p>Nothing is launched (an error is logged instead) when the command line was
     * invalid or a mandatory option is missing.
     */
    public void exec() {
        if (!argsParsed) {
            log.error("Command line arguments could not be parsed; not launching any instances");
            return;
        }
        if (securityGroupName == null || keyPairName == null || pathToPrivateKey == null) {
            log.error("Security group (-sg), key pair (-kp) and private key path (-kpath) must all be set; "
                    + "not launching any instances");
            return;
        }

        // One extra box on top of the workers: it becomes the master.
        Ec2BoxCreator boxCreator =
                new Ec2BoxCreator(ami, numWorkers + 1, size, securityGroupName, keyPairName);
        boxCreator.create();
        boxCreator.blockTillAllRunning();

        List<String> hosts = boxCreator.getHosts();
        if (hosts == null || hosts.isEmpty()) {
            log.error("No hosts were reported after launch; nothing to provision");
            return;
        }

        provisionMaster(hosts.get(0));
        provisionWorkers(hosts.subList(1, hosts.size()));
    }

    /**
     * Provisions the master: uploads the lib directory and the optional data set,
     * copies the private key into {@code /home/ec2-user/.ssh/}, restricts the
     * permissions of the files there, and runs the master setup script.
     *
     * <p>Any failure is logged together with the host and does not propagate.
     *
     * @param host address of the master as reported by {@link Ec2BoxCreator#getHosts()}
     */
    private void provisionMaster(String host) {
        try {
            HostProvisioner provisioner = new HostProvisioner(host, REMOTE_USER);
            provisioner.addKeyFile(pathToPrivateKey);
            uploadLibs(provisioner);
            if (dataSetPath != null) {
                provisioner.uploadForDeployment(dataSetPath, "");
            }
            // The key is placed on the master as well; presumably so it can reach the workers.
            String remoteKeyPath = "/home/ec2-user/.ssh/" + new File(pathToPrivateKey).getName();
            provisioner.uploadForDeployment(pathToPrivateKey, remoteKeyPath);
            provisioner.runRemoteCommand("chmod 0400 /home/ec2-user/.ssh/*");
            runSetupScript(provisioner, masterSetupScriptPath, "master");
        } catch (Exception e) {
            log.error("Error provisioning master " + host, e);
        }
    }

    /**
     * Provisions every worker: uploads the lib directory and runs the worker setup
     * script. A failure on one host is logged and does not stop the remaining hosts.
     *
     * @param list worker host addresses; may be empty
     */
    private void provisionWorkers(List<String> list) {
        for (String host : list) {
            try {
                HostProvisioner provisioner = new HostProvisioner(host, REMOTE_USER);
                provisioner.addKeyFile(pathToPrivateKey);
                uploadLibs(provisioner);
                runSetupScript(provisioner, workerSetupScriptPath, "worker");
            } catch (Exception e) {
                log.error("Error provisioning worker " + host, e);
            }
        }
    }

    /**
     * Uploads the lib directory to the remote {@code lib} location, unless dependency
     * upload was switched off or no lib directory was supplied.
     */
    private void uploadLibs(HostProvisioner provisioner) throws Exception {
        if (!uploadDeps) {
            log.info("Dependency upload disabled (-uploddeps); skipping lib upload");
            return;
        }
        if (libDirPath == null) {
            log.warn("No lib directory given (-lib); skipping lib upload");
            return;
        }
        provisioner.uploadForDeployment(libDirPath, "lib");
    }

    /**
     * Uploads and runs a setup script, or logs a warning when no script was supplied.
     *
     * @param scriptPath local path of the script, may be {@code null}
     * @param role       "master" or "worker", only used in the log message
     */
    private void runSetupScript(HostProvisioner provisioner, String scriptPath, String role) throws Exception {
        if (scriptPath == null) {
            log.warn("No " + role + " setup script given; skipping script execution");
            return;
        }
        provisioner.uploadAndRun(scriptPath, "");
    }

    /**
     * Command line entry point: parses {@code strArr} and runs {@link #exec()}.
     *
     * @param strArr raw command line arguments
     */
    public static void main(String[] strArr) {
        new ClusterSetup(strArr).exec();
    }
}
