package org.opensearch.ml.common;

import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.HashMap;
import java.util.Map;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.ml.common.annotation.ExecuteInput;
import org.opensearch.ml.common.annotation.ExecuteOutput;
import org.opensearch.ml.common.annotation.InputDataSet;
import org.opensearch.ml.common.annotation.MLAlgoOutput;
import org.opensearch.ml.common.annotation.MLAlgoParameter;
import org.opensearch.ml.common.dataset.MLInputDataType;
import org.opensearch.ml.common.exception.MLException;
import org.opensearch.ml.common.output.MLOutputType;
import org.reflections.Reflections;
import org.reflections.scanners.Scanner;

/* loaded from: input_file:org/opensearch/ml/common/MLCommonsClassLoader.class */
/**
 * Registry that maps ML enum types to the concrete classes that implement them, and
 * instantiates those classes reflectively.
 *
 * <p>Mappings are discovered by scanning fixed {@code org.opensearch.ml.common.*} packages with
 * {@link Reflections} for the {@link MLAlgoParameter}, {@link MLAlgoOutput}, {@link InputDataSet},
 * {@link ExecuteInput} and {@link ExecuteOutput} annotations. They are populated from the static
 * initializer at the bottom of this class, inside a privileged block.
 */
public class MLCommonsClassLoader {
    private static final Logger logger = LogManager.getLogger(MLCommonsClassLoader.class);

    /** Algorithm-parameter, ML-output and input-data-set classes, keyed by their annotation's enum value. */
    private static final Map<Enum<?>, Class<?>> parameterClassMap = new HashMap<>();
    /** Execute-input classes, keyed by {@link FunctionName}. */
    private static final Map<Enum<?>, Class<?>> executeInputClassMap = new HashMap<>();
    /** Execute-output classes, keyed by {@link FunctionName}. */
    private static final Map<Enum<?>, Class<?>> executeOutputClassMap = new HashMap<>();

    /**
     * Scans the ML commons packages and registers every annotated class in the type-to-class maps.
     *
     * <p>Invoked from the static initializer. Calling it again re-scans and puts the discovered
     * entries into the maps again.
     */
    public static void loadClassMapping() {
        loadMLAlgoParameterClassMapping();
        loadMLOutputClassMapping();
        loadMLInputDataSetClassMapping();
        loadExecuteInputClassMapping();
        loadExecuteOutputClassMapping();
    }

    /** Registers {@link MLAlgoParameter} and {@link MLAlgoOutput} classes found in the parameter package. */
    private static void loadMLAlgoParameterClassMapping() {
        Reflections reflections = new Reflections("org.opensearch.ml.common.input.parameter", new Scanner[0]);
        for (Class<?> clazz : reflections.getTypesAnnotatedWith(MLAlgoParameter.class)) {
            // Reflections can also return subtypes of annotated classes; for a non-@Inherited
            // annotation getAnnotation() is null on those, so skip instead of failing class init.
            MLAlgoParameter annotation = clazz.getAnnotation(MLAlgoParameter.class);
            if (annotation != null) {
                registerAll(parameterClassMap, annotation.algorithms(), clazz);
            }
        }
        // Output classes living under the parameter package are registered in the same map.
        for (Class<?> clazz : reflections.getTypesAnnotatedWith(MLAlgoOutput.class)) {
            MLAlgoOutput annotation = clazz.getAnnotation(MLAlgoOutput.class);
            if (annotation != null) {
                registerIfPresent(parameterClassMap, annotation.value(), clazz);
            }
        }
    }

    /** Registers {@link MLAlgoOutput} classes found in the output package, keyed by {@link MLOutputType}. */
    private static void loadMLOutputClassMapping() {
        Reflections reflections = new Reflections("org.opensearch.ml.common.output", new Scanner[0]);
        for (Class<?> clazz : reflections.getTypesAnnotatedWith(MLAlgoOutput.class)) {
            MLAlgoOutput annotation = clazz.getAnnotation(MLAlgoOutput.class);
            if (annotation != null) {
                registerIfPresent(parameterClassMap, annotation.value(), clazz);
            }
        }
    }

    /** Registers {@link InputDataSet} classes found in the dataset package, keyed by {@link MLInputDataType}. */
    private static void loadMLInputDataSetClassMapping() {
        Reflections reflections = new Reflections("org.opensearch.ml.common.dataset", new Scanner[0]);
        for (Class<?> clazz : reflections.getTypesAnnotatedWith(InputDataSet.class)) {
            InputDataSet annotation = clazz.getAnnotation(InputDataSet.class);
            if (annotation != null) {
                registerIfPresent(parameterClassMap, annotation.value(), clazz);
            }
        }
    }

    /** Registers {@link ExecuteInput} classes found in the execute-input package, one entry per algorithm. */
    private static void loadExecuteInputClassMapping() {
        Reflections reflections = new Reflections("org.opensearch.ml.common.input.execute", new Scanner[0]);
        for (Class<?> clazz : reflections.getTypesAnnotatedWith(ExecuteInput.class)) {
            ExecuteInput annotation = clazz.getAnnotation(ExecuteInput.class);
            if (annotation != null) {
                registerAll(executeInputClassMap, annotation.algorithms(), clazz);
            }
        }
    }

    /** Registers {@link ExecuteOutput} classes found in the execute-output package, one entry per algorithm. */
    private static void loadExecuteOutputClassMapping() {
        Reflections reflections = new Reflections("org.opensearch.ml.common.output.execute", new Scanner[0]);
        for (Class<?> clazz : reflections.getTypesAnnotatedWith(ExecuteOutput.class)) {
            ExecuteOutput annotation = clazz.getAnnotation(ExecuteOutput.class);
            if (annotation != null) {
                registerAll(executeOutputClassMap, annotation.algorithms(), clazz);
            }
        }
    }

    /** Maps every entry of {@code algorithms} to {@code clazz}; a null or empty array registers nothing. */
    private static void registerAll(Map<Enum<?>, Class<?>> map, FunctionName[] algorithms, Class<?> clazz) {
        if (algorithms == null) {
            return;
        }
        for (FunctionName functionName : algorithms) {
            map.put(functionName, clazz);
        }
    }

    /** Maps {@code type} to {@code clazz} unless {@code type} is null. */
    private static void registerIfPresent(Map<Enum<?>, Class<?>> map, Enum<?> type, Class<?> clazz) {
        if (type != null) {
            map.put(type, clazz);
        }
    }

    /**
     * Instantiates the parameter/output/input-data-set class registered for {@code type}.
     *
     * @param type enum key the target class was registered under
     * @param in single argument passed to the target class's constructor
     * @param constructorParamClass declared parameter type of that constructor
     * @return the new instance, or {@code null} if construction failed for a reason other than an
     *     {@link MLException} thrown by the constructor
     * @throws IllegalArgumentException if no class is registered for {@code type}
     * @throws MLException if the constructor itself threw one
     */
    public static <T extends Enum<T>, S, I> S initMLInstance(T type, I in, Class<?> constructorParamClass) {
        return init(parameterClassMap, type, in, constructorParamClass);
    }

    /**
     * Instantiates the execute-input class registered for {@code type}.
     *
     * <p>Same contract as {@link #initMLInstance}, using the execute-input mapping.
     */
    public static <T extends Enum<T>, S, I> S initExecuteInputInstance(T type, I in, Class<?> constructorParamClass) {
        return init(executeInputClassMap, type, in, constructorParamClass);
    }

    /**
     * Instantiates the execute-output class registered for {@code type}.
     *
     * <p>Same contract as {@link #initMLInstance}, using the execute-output mapping.
     */
    public static <T extends Enum<T>, S, I> S initExecuteOutputInstance(T type, I in, Class<?> constructorParamClass) {
        return init(executeOutputClassMap, type, in, constructorParamClass);
    }

    /**
     * Looks up the class registered for {@code type} in {@code map} and invokes its
     * single-argument constructor whose parameter type is {@code constructorParamClass}.
     *
     * @return the new instance, or {@code null} (after logging) if reflection fails and the
     *     underlying cause is not an {@link MLException}
     * @throws IllegalArgumentException if {@code map} has no entry for {@code type}
     * @throws MLException rethrown when it is the cause of the reflective failure
     */
    private static <T extends Enum<T>, S, I> S init(Map<Enum<?>, Class<?>> map, T type, I in, Class<?> constructorParamClass) {
        Class<?> clazz = map.get(type);
        if (clazz == null) {
            throw new IllegalArgumentException("Can't find class for type " + type);
        }
        try {
            // The caller chooses S; the registry cannot verify it, hence the unchecked cast.
            @SuppressWarnings("unchecked")
            S instance = (S) clazz.getConstructor(constructorParamClass).newInstance(in);
            return instance;
        } catch (Exception e) {
            // Constructor exceptions arrive wrapped (e.g. InvocationTargetException); surface MLException as-is.
            Throwable cause = e.getCause();
            if (cause instanceof MLException) {
                throw (MLException) cause;
            }
            logger.error("Failed to init instance for type {}", type, e);
            return null;
        }
    }

    // Must stay after the map field initializers: static initializers run in textual order.
    static {
        try {
            // The explicit cast is required. A bare lambda matches both doPrivileged(PrivilegedAction)
            // and doPrivileged(PrivilegedExceptionAction), which is ambiguous, and only the latter
            // can throw the PrivilegedActionException caught below.
            AccessController.doPrivileged((PrivilegedExceptionAction<Void>) () -> {
                loadClassMapping();
                return null;
            });
        } catch (PrivilegedActionException e) {
            throw new RuntimeException("Can't load class mapping in ML commons", e);
        }
    }
}
