package hex.ensemble;

import hex.Model;
import hex.ModelBuilder;
import hex.ModelCategory;
import hex.ensemble.Metalearner;
import hex.glm.GLM;
import hex.glm.GLMModel;
import java.util.Iterator;
import java.util.ServiceLoader;
import java.util.function.Supplier;
import joptsimple.internal.Strings;
import water.exceptions.H2OIllegalArgumentException;
import water.nbhm.NonBlockingHashMap;

/* loaded from: input_file:hex/ensemble/Metalearners.class */
/**
 * Registry of the metalearner algorithms that a Stacked Ensemble can use.
 *
 * <p>Built-in metalearners (AUTO, deeplearning, drf, gbm, glm, naivebayes) are registered first
 * during class initialization. Any {@link MetalearnerProvider} implementations discovered through
 * {@link ServiceLoader} are registered afterwards. A discovered provider whose name matches a
 * built-in therefore replaces the built-in entry.
 */
public class Metalearners {
    /** Registered metalearner providers, keyed by {@link MetalearnerProvider#getName()}. */
    static final NonBlockingHashMap<String, MetalearnerProvider> providersByName = new NonBlockingHashMap<>();

    /**
     * Default ("AUTO") metalearner: a GLM whose family is derived like {@link GLMMetalearner}.
     * It additionally forces non-negative coefficients and disables standardization.
     * When a validation frame is set, lambda search is enabled and GLM early stopping is disabled.
     */
    static class AUTOMetalearner extends GLMMetalearner {
        AUTOMetalearner() {
        }

        @Override
        public void setCustomParams(GLMModel.GLMParameters parms) {
            super.setCustomParams(parms);
            parms._non_negative = true;
            parms._standardize = false;
            if (parms._valid != null) {
                parms._lambda_search = true;
                parms._early_stopping = false;
            }
        }
    }

    /** Metalearner backed by the generic "deeplearning" model builder. */
    static class DLMetalearner extends SimpleMetalearner {
        public DLMetalearner() {
            super(Metalearner.Algorithm.deeplearning.name());
        }
    }

    /** Metalearner backed by the generic "drf" model builder. */
    static class DRFMetalearner extends SimpleMetalearner {
        public DRFMetalearner() {
            super(Metalearner.Algorithm.drf.name());
        }
    }

    /** Metalearner backed by the generic "gbm" model builder. */
    static class GBMMetalearner extends SimpleMetalearner {
        public GBMMetalearner() {
            super(Metalearner.Algorithm.gbm.name());
        }
    }

    /**
     * GLM metalearner whose GLM family is chosen from the ensemble model's category:
     * Regression -> gaussian, Binomial -> binomial, Multinomial -> multinomial.
     */
    public static class GLMMetalearner extends Metalearner<GLM, GLMModel, GLMModel.GLMParameters> {
        GLMMetalearner() {
        }

        /** Creates a GLM builder bound to this metalearner's job and destination key. */
        @Override
        public GLM createBuilder() {
            return (GLM) ModelBuilder.make("GLM", this._metalearnerJob, this._metalearnerKey);
        }

        /**
         * Sets the GLM family that matches the ensemble's model category.
         *
         * @param parms GLM parameters to adjust in place
         * @throws H2OIllegalArgumentException if the model category is not Regression, Binomial or
         *     Multinomial
         */
        @Override
        public void setCustomParams(GLMModel.GLMParameters parms) {
            ModelCategory category = this._model.modelCategory;
            if (category == ModelCategory.Regression) {
                parms._family = GLMModel.GLMParameters.Family.gaussian;
            } else if (category == ModelCategory.Binomial) {
                parms._family = GLMModel.GLMParameters.Family.binomial;
            } else if (category == ModelCategory.Multinomial) {
                parms._family = GLMModel.GLMParameters.Family.multinomial;
            } else {
                throw new H2OIllegalArgumentException("Family " + category + "  is not supported.");
            }
        }
    }

    /**
     * Provider for a built-in metalearner. The provider's name is the name of its
     * {@link Metalearner.Algorithm}, and each {@link #newInstance()} call delegates to the factory.
     */
    static class LocalProvider<M extends Metalearner> implements MetalearnerProvider<M> {
        private final Metalearner.Algorithm _algorithm;
        private final Supplier<M> _instanceFactory;

        public LocalProvider(Metalearner.Algorithm algorithm, Supplier<M> instanceFactory) {
            this._algorithm = algorithm;
            this._instanceFactory = instanceFactory;
        }

        @Override
        public String getName() {
            return this._algorithm.name();
        }

        @Override
        public M newInstance() {
            return this._instanceFactory.get();
        }
    }

    /** Metalearner backed by the generic "naivebayes" model builder. */
    static class NaiveBayesMetalearner extends SimpleMetalearner {
        public NaiveBayesMetalearner() {
            super(Metalearner.Algorithm.naivebayes.name());
        }
    }

    /**
     * Metalearner that applies no custom parameters.
     * It only creates a model builder for the given algorithm name.
     */
    public static class SimpleMetalearner extends Metalearner {
        private final String _algo;

        public SimpleMetalearner(String algo) {
            this._algo = algo;
        }

        /** Creates a builder for {@code _algo} bound to this metalearner's job and destination key. */
        @Override
        public ModelBuilder createBuilder() {
            return ModelBuilder.make(this._algo, this._metalearnerJob, this._metalearnerKey);
        }
    }

    /**
     * Resolves the algorithm that will actually be trained.
     * AUTO maps to glm, because the AUTO metalearner is GLM-based.
     *
     * @param algorithm requested metalearner algorithm
     * @return {@code glm} for AUTO, otherwise {@code algorithm} itself
     * @throws H2OIllegalArgumentException if no provider is registered under the algorithm's name
     */
    public static Metalearner.Algorithm getActualMetalearnerAlgo(Metalearner.Algorithm algorithm) {
        assertAvailable(algorithm.name());
        return algorithm == Metalearner.Algorithm.AUTO ? Metalearner.Algorithm.glm : algorithm;
    }

    /**
     * Returns the parameters object of a freshly created builder for the named metalearner.
     *
     * @param name registered metalearner name
     * @return the new builder's {@code _parms}
     * @throws H2OIllegalArgumentException if no provider is registered under {@code name}
     */
    public static Model.Parameters createParameters(String name) {
        // createInstance performs the availability check.
        return createInstance(name).createBuilder()._parms;
    }

    /**
     * Creates a new metalearner instance from the provider registered under {@code name}.
     *
     * @param name registered metalearner name
     * @return a new metalearner instance
     * @throws H2OIllegalArgumentException if no provider is registered under {@code name}
     */
    public static Metalearner createInstance(String name) {
        assertAvailable(name);
        return providersByName.get(name).newInstance();
    }

    /** Throws if no provider is registered under {@code name}. */
    private static void assertAvailable(String name) {
        if (!providersByName.containsKey(name)) {
            throw new H2OIllegalArgumentException("'" + name + "' metalearner is not supported or available.");
        }
    }

    /** Registers {@code provider} under its own name, replacing any previous entry with that name. */
    private static void registerProvider(MetalearnerProvider provider) {
        providersByName.put(provider.getName(), provider);
    }

    static {
        MetalearnerProvider[] builtIns = {
            new LocalProvider<>(Metalearner.Algorithm.AUTO, AUTOMetalearner::new),
            new LocalProvider<>(Metalearner.Algorithm.deeplearning, DLMetalearner::new),
            new LocalProvider<>(Metalearner.Algorithm.drf, DRFMetalearner::new),
            new LocalProvider<>(Metalearner.Algorithm.gbm, GBMMetalearner::new),
            new LocalProvider<>(Metalearner.Algorithm.glm, GLMMetalearner::new),
            new LocalProvider<>(Metalearner.Algorithm.naivebayes, NaiveBayesMetalearner::new),
        };
        for (MetalearnerProvider provider : builtIns) {
            registerProvider(provider);
        }
        // Service-loaded providers are registered last, so they override built-ins with the same name.
        for (MetalearnerProvider provider : ServiceLoader.load(MetalearnerProvider.class)) {
            registerProvider(provider);
        }
    }
}
