package org.opensearch.ml.client;

import java.util.Map;
import java.util.Objects;
import java.util.function.Function;
import lombok.Generated;
import org.opensearch.action.ActionListener;
import org.opensearch.action.ActionResponse;
import org.opensearch.action.delete.DeleteResponse;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.client.Client;
import org.opensearch.common.CheckedConsumer;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.input.Constants;
import org.opensearch.ml.common.input.InputHelper;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.input.parameter.MLAlgoParams;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.common.transport.model.MLModelDeleteAction;
import org.opensearch.ml.common.transport.model.MLModelDeleteRequest;
import org.opensearch.ml.common.transport.model.MLModelGetAction;
import org.opensearch.ml.common.transport.model.MLModelGetRequest;
import org.opensearch.ml.common.transport.model.MLModelGetResponse;
import org.opensearch.ml.common.transport.model.MLModelSearchAction;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
import org.opensearch.ml.common.transport.task.MLTaskDeleteAction;
import org.opensearch.ml.common.transport.task.MLTaskDeleteRequest;
import org.opensearch.ml.common.transport.task.MLTaskGetAction;
import org.opensearch.ml.common.transport.task.MLTaskGetRequest;
import org.opensearch.ml.common.transport.task.MLTaskGetResponse;
import org.opensearch.ml.common.transport.task.MLTaskSearchAction;
import org.opensearch.ml.common.transport.training.MLTrainingTaskAction;
import org.opensearch.ml.common.transport.training.MLTrainingTaskRequest;
import org.opensearch.ml.common.transport.trainpredict.MLTrainAndPredictionTaskAction;

/**
 * {@link MachineLearningClient} backed by an OpenSearch node {@link Client}.
 *
 * <p>Every operation is forwarded to the node client by executing the matching ml-commons
 * transport action ({@code *Action.INSTANCE}). Results and failures are delivered to the
 * caller's {@link ActionListener}. Argument-validation errors are thrown synchronously as
 * {@link IllegalArgumentException} before anything is executed.
 */
public class MachineLearningNodeClient implements MachineLearningClient {
    /** Node client used to execute every ML transport action. */
    private final Client client;

    /**
     * Runs prediction with an existing model.
     *
     * @param modelId  ID of the model to predict with
     * @param mlInput  prediction input; must be non-null and carry an input data set
     * @param listener receives the task output, or the failure
     * @throws IllegalArgumentException if {@code mlInput} or its input data set is null
     */
    @Override
    public void predict(String modelId, MLInput mlInput, ActionListener<MLOutput> listener) {
        validateMLInput(mlInput, true);
        MLPredictionTaskRequest request = MLPredictionTaskRequest.builder()
                .mlInput(mlInput)
                .modelId(modelId)
                .dispatchTask(true)
                .build();
        client.execute(MLPredictionTaskAction.INSTANCE, request, getMlPredictionTaskResponseActionListener(listener));
    }

    /**
     * Trains a model and predicts with it in a single request.
     *
     * @param mlInput  training/prediction input; must be non-null and carry an input data set
     * @param listener receives the task output, or the failure
     * @throws IllegalArgumentException if {@code mlInput} or its input data set is null
     */
    @Override
    public void trainAndPredict(MLInput mlInput, ActionListener<MLOutput> listener) {
        validateMLInput(mlInput, true);
        MLTrainingTaskRequest request = MLTrainingTaskRequest.builder()
                .mlInput(mlInput)
                .dispatchTask(true)
                .build();
        client.execute(MLTrainAndPredictionTaskAction.INSTANCE, request, getMlPredictionTaskResponseActionListener(listener));
    }

    /**
     * Trains a model.
     *
     * @param mlInput   training input; must be non-null and carry an input data set
     * @param asyncTask value passed to the request's {@code async} flag
     * @param listener  receives the task output, or the failure
     * @throws IllegalArgumentException if {@code mlInput} or its input data set is null
     */
    @Override
    public void train(MLInput mlInput, boolean asyncTask, ActionListener<MLOutput> listener) {
        validateMLInput(mlInput, true);
        MLTrainingTaskRequest request = MLTrainingTaskRequest.builder()
                .mlInput(mlInput)
                .async(asyncTask)
                .dispatchTask(true)
                .build();
        client.execute(MLTrainingTaskAction.INSTANCE, request, getMlPredictionTaskResponseActionListener(listener));
    }

    /**
     * Runs the action named in {@code args} ({@link Constants#TRAIN}, {@link Constants#PREDICT}
     * or {@link Constants#TRAINANDPREDICT}).
     *
     * <p>Before dispatching, this sets the algorithm and algorithm parameters derived from
     * {@code args} on {@code mlInput}, so the caller's input object is mutated.
     *
     * @param mlInput  input to run; must be non-null
     * @param args     arguments read via {@link InputHelper}, plus {@link Constants#ASYNC} (train)
     *                 and {@link Constants#MODELID} (predict)
     * @param listener receives the task output, or the failure
     * @throws IllegalArgumentException if the action is missing or unsupported, if
     *                                  {@code mlInput} is null, or if a predict action has no model ID
     */
    @Override
    public void run(MLInput mlInput, Map<String, Object> args, ActionListener<MLOutput> listener) {
        String action = InputHelper.getAction(args);
        if (action == null) {
            throw new IllegalArgumentException("The parameter action is required.");
        }
        FunctionName functionName = InputHelper.getFunctionName(args);
        MLAlgoParams parameters = InputHelper.convertArgumentToMLParameter(args, functionName);
        // Fail with the same message the other entry points use, rather than an NPE on the setters.
        validateMLInput(mlInput, false);
        mlInput.setAlgorithm(functionName);
        mlInput.setParameters(parameters);

        switch (action) {
            case Constants.TRAIN:
                train(mlInput, isAsync(args), listener);
                return;
            case Constants.PREDICT: {
                String modelId = (String) args.get(Constants.MODELID);
                if (modelId == null) {
                    throw new IllegalArgumentException("The model ID is required for prediction.");
                }
                predict(modelId, mlInput, listener);
                return;
            }
            case Constants.TRAINANDPREDICT:
                trainAndPredict(mlInput, listener);
                return;
            default:
                throw new IllegalArgumentException("Unsupported action.");
        }
    }

    /**
     * Reads the optional {@link Constants#ASYNC} flag from {@code args}.
     * An absent key or a null value means "not async". A non-Boolean value still fails with
     * {@link ClassCastException}, as before.
     */
    private static boolean isAsync(Map<String, Object> args) {
        Object async = args.get(Constants.ASYNC);
        return async != null && (Boolean) async;
    }

    /**
     * Fetches a model by ID.
     *
     * @param modelId  ID of the model to fetch
     * @param listener receives the {@link MLModel}, or the failure
     */
    @Override
    public void getModel(String modelId, ActionListener<MLModel> listener) {
        MLModelGetRequest request = MLModelGetRequest.builder().modelId(modelId).build();
        client.execute(MLModelGetAction.INSTANCE, request, getMlGetModelResponseActionListener(listener));
    }

    /** Adapts a model-get response listener into one that delivers only the contained {@link MLModel}. */
    private ActionListener<MLModelGetResponse> getMlGetModelResponseActionListener(ActionListener<MLModel> listener) {
        ActionListener<MLModelGetResponse> internalListener =
                ActionListener.wrap(response -> listener.onResponse(response.getMlModel()), listener::onFailure);
        return wrapActionListener(internalListener, MLModelGetResponse::fromActionResponse);
    }

    /**
     * Deletes a model by ID.
     *
     * @param modelId  ID of the model to delete
     * @param listener receives the {@link DeleteResponse}, or the failure
     */
    @Override
    public void deleteModel(String modelId, ActionListener<DeleteResponse> listener) {
        MLModelDeleteRequest request = MLModelDeleteRequest.builder().modelId(modelId).build();
        client.execute(MLModelDeleteAction.INSTANCE, request, ActionListener.wrap(listener::onResponse, listener::onFailure));
    }

    /**
     * Searches models.
     *
     * @param searchRequest search to execute
     * @param listener      receives the {@link SearchResponse}, or the failure
     */
    @Override
    public void searchModel(SearchRequest searchRequest, ActionListener<SearchResponse> listener) {
        client.execute(MLModelSearchAction.INSTANCE, searchRequest, ActionListener.wrap(listener::onResponse, listener::onFailure));
    }

    /**
     * Fetches a task by ID.
     *
     * @param taskId   ID of the task to fetch
     * @param listener receives the {@link MLTask}, or the failure
     */
    @Override
    public void getTask(String taskId, ActionListener<MLTask> listener) {
        MLTaskGetRequest request = MLTaskGetRequest.builder().taskId(taskId).build();
        client.execute(MLTaskGetAction.INSTANCE, request, ActionListener.wrap(
                response -> listener.onResponse(MLTaskGetResponse.fromActionResponse(response).getMlTask()),
                listener::onFailure));
    }

    /**
     * Deletes a task by ID.
     *
     * @param taskId   ID of the task to delete
     * @param listener receives the {@link DeleteResponse}, or the failure
     */
    @Override
    public void deleteTask(String taskId, ActionListener<DeleteResponse> listener) {
        MLTaskDeleteRequest request = MLTaskDeleteRequest.builder().taskId(taskId).build();
        client.execute(MLTaskDeleteAction.INSTANCE, request, ActionListener.wrap(listener::onResponse, listener::onFailure));
    }

    /**
     * Searches tasks.
     *
     * @param searchRequest search to execute
     * @param listener      receives the {@link SearchResponse}, or the failure
     */
    @Override
    public void searchTask(SearchRequest searchRequest, ActionListener<SearchResponse> listener) {
        client.execute(MLTaskSearchAction.INSTANCE, searchRequest, ActionListener.wrap(listener::onResponse, listener::onFailure));
    }

    /** Adapts an {@link MLTaskResponse} listener into one that delivers only the task's {@link MLOutput}. */
    private ActionListener<MLTaskResponse> getMlPredictionTaskResponseActionListener(ActionListener<MLOutput> listener) {
        ActionListener<MLTaskResponse> internalListener =
                ActionListener.wrap(response -> listener.onResponse(response.getOutput()), listener::onFailure);
        return wrapActionListener(internalListener, MLTaskResponse::fromActionResponse);
    }

    /**
     * Returns a listener that passes each response through {@code recreate} before handing it
     * to {@code listener}, and forwards failures unchanged.
     *
     * <p>NOTE(review): presumably {@code recreate} rebuilds the response as this client's own
     * response class (e.g. when the instance was produced elsewhere) — confirm with the
     * {@code fromActionResponse} implementations.
     */
    private <T extends ActionResponse> ActionListener<T> wrapActionListener(
            final ActionListener<T> listener, final Function<ActionResponse, T> recreate) {
        return ActionListener.wrap(response -> listener.onResponse(recreate.apply(response)), listener::onFailure);
    }

    /**
     * Validates an {@link MLInput}.
     *
     * @param mlInput      input to check
     * @param requireInput when true, the input data set must also be non-null
     * @throws IllegalArgumentException if {@code mlInput} is null, or if {@code requireInput} is
     *                                  true and its input data set is null
     */
    private void validateMLInput(MLInput mlInput, boolean requireInput) {
        if (mlInput == null) {
            throw new IllegalArgumentException("ML Input can't be null");
        }
        if (requireInput && mlInput.getInputDataset() == null) {
            throw new IllegalArgumentException("input data set can't be null");
        }
    }

    /**
     * Creates a client that executes all ML actions through {@code client}.
     *
     * @param client node client used for every transport call
     */
    @Generated
    public MachineLearningNodeClient(Client client) {
        this.client = client;
    }
}
