package ai.konduit.serving.configcreator.converter;

import ai.konduit.serving.configcreator.StringSplitter;
import java.util.Arrays;
import java.util.Locale;
import java.util.Map;
import org.nd4j.linalg.learning.config.AMSGrad;
import org.nd4j.linalg.learning.config.AdaBelief;
import org.nd4j.linalg.learning.config.AdaDelta;
import org.nd4j.linalg.learning.config.AdaGrad;
import org.nd4j.linalg.learning.config.AdaMax;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.learning.config.Nadam;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.learning.config.NoOp;
import org.nd4j.linalg.learning.config.RmsProp;
import org.nd4j.linalg.learning.config.Sgd;
import picocli.CommandLine;

/* loaded from: input_file:ai/konduit/serving/configcreator/converter/UpdaterConverter.class */
public class UpdaterConverter implements CommandLine.ITypeConverter<IUpdater> {
    public static final String DELIMITER = ",";
    public static final String UPDATER_TYPE_KEY = "type";

    /* loaded from: input_file:ai/konduit/serving/configcreator/converter/UpdaterConverter$UpdaterTypes.class */
    public enum UpdaterTypes {
        AMSGRAD,
        ADABELIEF,
        ADAGRAD,
        ADADELTA,
        ADAMAX,
        ADAM,
        NADAM,
        NESTEROVS,
        NOOP,
        RMSPROP,
        SGD
    }

    /* renamed from: convert, reason: merged with bridge method [inline-methods] */
    public IUpdater m12convert(String str) throws Exception {
        Map<String, String> splitResult = new StringSplitter(",").splitResult(str);
        if (!splitResult.containsKey("type")) {
            throw new IllegalArgumentException("Please specify an updater type for proper creation.");
        }
        IUpdater instanceForName = instanceForName(splitResult.get("type"));
        setValuesFor(instanceForName, splitResult);
        return instanceForName;
    }

    private void setValuesFor(IUpdater iUpdater, Map<String, String> map) throws Exception {
        for (Map.Entry<String, String> entry : map.entrySet()) {
            if (iUpdater instanceof Sgd) {
                Sgd sgd = (Sgd) iUpdater;
                if (entry.getKey().equals("learningRate")) {
                    sgd.setLearningRate(Double.valueOf(Double.parseDouble(entry.getValue())).doubleValue());
                }
                if (entry.getKey().equals("learningRateSchedule")) {
                    sgd.setLearningRateSchedule(new LearningRateScheduleConverter().m6convert(entry.getValue()));
                }
            } else if (iUpdater instanceof RmsProp) {
                RmsProp rmsProp = (RmsProp) iUpdater;
                if (entry.getKey().equals("epsilon")) {
                    rmsProp.setEpsilon(Double.parseDouble(entry.getValue()));
                }
                if (entry.getKey().equals("learningRate")) {
                    rmsProp.setLearningRate(Double.parseDouble(entry.getValue()));
                }
                if (entry.getKey().equals("rmsDecay")) {
                    rmsProp.setRmsDecay(Double.parseDouble(entry.getValue()));
                }
                if (entry.getKey().equals("learningRateSchedule")) {
                    rmsProp.setLearningRateSchedule(new LearningRateScheduleConverter().m6convert(entry.getValue()));
                }
            } else if (iUpdater instanceof AMSGrad) {
                AMSGrad aMSGrad = (AMSGrad) iUpdater;
                if (entry.getKey().equals("beta1")) {
                    aMSGrad.setBeta1(Double.parseDouble(entry.getValue()));
                }
                if (entry.getKey().equals("beta2")) {
                    aMSGrad.setBeta2(Double.parseDouble(entry.getValue()));
                }
                if (entry.getKey().equals("epsilon")) {
                    aMSGrad.setEpsilon(Double.parseDouble(entry.getValue()));
                }
                if (entry.getKey().equals("learningRate")) {
                    aMSGrad.setLearningRate(Double.valueOf(Double.parseDouble(entry.getValue())).doubleValue());
                }
                if (entry.getKey().equals("learningRateSchedule")) {
                    aMSGrad.setLearningRateSchedule(new LearningRateScheduleConverter().m6convert(entry.getValue()));
                }
            } else if (iUpdater instanceof AdaDelta) {
                AdaDelta adaDelta = (AdaDelta) iUpdater;
                if (entry.getKey().equals("epsilon")) {
                    adaDelta.setEpsilon(Double.parseDouble(entry.getValue()));
                }
                if (entry.getKey().equals("rho")) {
                    adaDelta.setRho(Double.parseDouble(entry.getValue()));
                }
            } else if (iUpdater instanceof NoOp) {
            } else if (iUpdater instanceof AdaGrad) {
                AdaGrad adaGrad = (AdaGrad) iUpdater;
                if (entry.getKey().equals("learningRate")) {
                    adaGrad.setLearningRate(Double.valueOf(Double.parseDouble(entry.getValue())).doubleValue());
                }
                if (entry.getKey().equals("learningRateSchedule")) {
                    adaGrad.setLearningRateSchedule(new LearningRateScheduleConverter().m6convert(entry.getValue()));
                }
                if (entry.getKey().equals("epsilon")) {
                    adaGrad.setEpsilon(Double.parseDouble(entry.getValue()));
                }
            } else if (iUpdater instanceof Adam) {
                Adam adam = (Adam) iUpdater;
                if (entry.getKey().equals("beta1")) {
                    adam.setBeta1(Double.parseDouble(entry.getValue()));
                }
                if (entry.getKey().equals("beta2")) {
                    adam.setBeta2(Double.parseDouble(entry.getValue()));
                }
                if (entry.getKey().equals("epsilon")) {
                    adam.setEpsilon(Double.parseDouble(entry.getValue()));
                }
                if (entry.getKey().equals("learningRate")) {
                    adam.setLearningRate(Double.valueOf(Double.parseDouble(entry.getValue())).doubleValue());
                }
                if (entry.getKey().equals("learningRateSchedule")) {
                    adam.setLearningRateSchedule(new LearningRateScheduleConverter().m6convert(entry.getValue()));
                }
            } else if (iUpdater instanceof AdaMax) {
                AdaMax adaMax = (AdaMax) iUpdater;
                if (entry.getKey().equals("beta1")) {
                    adaMax.setBeta1(Double.parseDouble(entry.getValue()));
                }
                if (entry.getKey().equals("beta2")) {
                    adaMax.setBeta2(Double.parseDouble(entry.getValue()));
                }
                if (entry.getKey().equals("epsilon")) {
                    adaMax.setEpsilon(Double.parseDouble(entry.getValue()));
                }
                if (entry.getKey().equals("learningRate")) {
                    adaMax.setLearningRate(Double.valueOf(Double.parseDouble(entry.getValue())).doubleValue());
                }
                if (entry.getKey().equals("learningRateSchedule")) {
                    adaMax.setLearningRateSchedule(new LearningRateScheduleConverter().m6convert(entry.getValue()));
                }
            } else if (iUpdater instanceof AdaBelief) {
                AdaBelief adaBelief = (AdaBelief) iUpdater;
                if (entry.getKey().equals("beta1")) {
                    adaBelief.setBeta1(Double.parseDouble(entry.getValue()));
                }
                if (entry.getKey().equals("beta2")) {
                    adaBelief.setBeta2(Double.parseDouble(entry.getValue()));
                }
                if (entry.getKey().equals("epsilon")) {
                    adaBelief.setEpsilon(Double.parseDouble(entry.getValue()));
                }
                if (entry.getKey().equals("learningRate")) {
                    adaBelief.setLearningRate(Double.valueOf(Double.parseDouble(entry.getValue())).doubleValue());
                }
                if (entry.getKey().equals("learningRateSchedule")) {
                    adaBelief.setLearningRateSchedule(new LearningRateScheduleConverter().m6convert(entry.getValue()));
                }
            } else if (iUpdater instanceof Nesterovs) {
                Nesterovs nesterovs = (Nesterovs) iUpdater;
                if (entry.getKey().equals("learningRate")) {
                    nesterovs.setLearningRate(Double.valueOf(Double.parseDouble(entry.getValue())).doubleValue());
                }
                if (entry.getKey().equals("momentum")) {
                    nesterovs.setMomentum(Double.valueOf(Double.parseDouble(entry.getValue())).doubleValue());
                }
                if (entry.getKey().equals("learningRateSchedule")) {
                    nesterovs.setLearningRateSchedule(new LearningRateScheduleConverter().m6convert(entry.getValue()));
                }
                if (entry.getKey().equals("momentumISchedule")) {
                    nesterovs.setMomentumISchedule(new LearningRateScheduleConverter().m6convert(entry.getValue()));
                }
            } else if (iUpdater instanceof Nadam) {
                Nadam nadam = (Nadam) iUpdater;
                if (entry.getKey().equals("beta1")) {
                    nadam.setBeta1(Double.parseDouble(entry.getValue()));
                }
                if (entry.getKey().equals("beta2")) {
                    nadam.setBeta2(Double.parseDouble(entry.getValue()));
                }
                if (entry.getKey().equals("epsilon")) {
                    nadam.setEpsilon(Double.parseDouble(entry.getValue()));
                }
                if (entry.getKey().equals("learningRate")) {
                    nadam.setLearningRate(Double.valueOf(Double.parseDouble(entry.getValue())).doubleValue());
                }
                if (entry.getKey().equals("learningRateSchedule")) {
                    nadam.setLearningRateSchedule(new LearningRateScheduleConverter().m6convert(entry.getValue()));
                }
            }
        }
    }

    private IUpdater instanceForName(String str) {
        switch (UpdaterTypes.valueOf(str.toUpperCase())) {
            case SGD:
                return new Sgd();
            case ADAM:
                return new Adam();
            case NOOP:
                return new NoOp();
            case NADAM:
                return new Nadam();
            case ADAMAX:
                return new AdaMax();
            case ADAGRAD:
                return new AdaGrad();
            case AMSGRAD:
                return new AMSGrad();
            case RMSPROP:
                return new RmsProp();
            case ADADELTA:
                return new AdaDelta();
            case ADABELIEF:
                return new AdaBelief();
            case NESTEROVS:
                return new Nesterovs();
            default:
                throw new IllegalArgumentException("Illegal type " + str);
        }
    }
}
