package org.neo4j.gds.ml.pipeline;

import java.time.ZonedDateTime;
import java.util.ArrayList;
import java.util.EnumMap;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.neo4j.gds.core.utils.TimeUtil;
import org.neo4j.gds.ml.api.TrainingMethod;
import org.neo4j.gds.ml.models.TrainerConfig;
import org.neo4j.gds.ml.models.automl.TunableTrainerConfig;
import org.neo4j.gds.ml.pipeline.FeatureStep;
import org.neo4j.gds.utils.StringFormatting;

/**
 * Base class for machine-learning training pipelines.
 *
 * <p>A pipeline holds an ordered list of node property steps, an ordered list of feature steps,
 * a training parameter space keyed by {@link TrainingMethod}, and an {@link AutoTuningConfig}.
 * The training methods that may receive trainer configs are fixed at construction time by the
 * pipeline's {@link TrainingType}.
 *
 * @param <FEATURE_STEP> the feature step type used by the concrete pipeline
 */
public abstract class TrainingPipeline<FEATURE_STEP extends FeatureStep> implements Pipeline<FEATURE_STEP> {
    protected final List<ExecutableNodePropertyStep> nodePropertySteps = new ArrayList<>();
    protected final List<FEATURE_STEP> featureSteps = new ArrayList<>();
    private final ZonedDateTime creationTime = TimeUtil.now();
    // EnumMap iterates in TrainingMethod declaration order, giving a stable order in toMap().
    protected Map<TrainingMethod, List<TunableTrainerConfig>> trainingParameterSpace = new EnumMap<>(TrainingMethod.class);
    protected AutoTuningConfig autoTuningConfig = AutoTuningConfig.DEFAULT_CONFIG;

    /**
     * The kind of task a pipeline trains for; determines which training methods it accepts.
     */
    protected enum TrainingType {
        CLASSIFICATION {
            @Override
            List<TrainingMethod> supportedMethods() {
                return List.of(
                    TrainingMethod.LogisticRegression,
                    TrainingMethod.RandomForestClassification,
                    TrainingMethod.MLPClassification
                );
            }
        },
        REGRESSION {
            @Override
            List<TrainingMethod> supportedMethods() {
                return List.of(TrainingMethod.LinearRegression, TrainingMethod.RandomForestRegression);
            }
        };

        /**
         * Returns the training methods for which a pipeline of this type accepts trainer configs.
         */
        abstract List<TrainingMethod> supportedMethods();
    }

    /**
     * Converts a training parameter space into a map representation.
     *
     * @param map the parameter space, keyed by training method
     * @return a map from {@code TrainingMethod#toString()} to the {@code toMap()} form of each
     *     config registered for that method, in list order
     */
    public static Map<String, List<Map<String, Object>>> toMapParameterSpace(
        Map<TrainingMethod, List<TunableTrainerConfig>> map
    ) {
        return map.entrySet().stream().collect(Collectors.toMap(
            entry -> entry.getKey().toString(),
            entry -> entry.getValue().stream()
                .map(TunableTrainerConfig::toMap)
                .collect(Collectors.toList())
        ));
    }

    /**
     * Creates a pipeline whose parameter space has an empty candidate list for every method
     * supported by {@code trainingType}.
     *
     * @param trainingType decides which training methods may later receive trainer configs
     */
    public TrainingPipeline(TrainingType trainingType) {
        for (TrainingMethod method : trainingType.supportedMethods()) {
            this.trainingParameterSpace.put(method, new ArrayList<>());
        }
    }

    /**
     * Builds a map description of this pipeline.
     *
     * <p>The map contains the feature pipeline description, the training parameter space,
     * the auto-tuning config, and any subclass-specific entries from {@link #additionalEntries()}.
     * Those subclass entries are added last, so they overwrite same-named keys.
     *
     * @return a new mutable map describing this pipeline
     */
    public Map<String, Object> toMap() {
        Map<String, Object> result = new HashMap<>();
        result.put("featurePipeline", featurePipelineDescription());
        result.put("trainingParameterSpace", toMapParameterSpace(this.trainingParameterSpace));
        result.put("autoTuningConfig", autoTuningConfig().toMap());
        result.putAll(additionalEntries());
        return result;
    }

    public abstract String type();

    protected abstract Map<String, List<Map<String, Object>>> featurePipelineDescription();

    protected abstract Map<String, Object> additionalEntries();

    /** Total number of trainer configs across all training methods. */
    private int numberOfTrainerConfigs() {
        return trainingParameterSpace().values().stream()
            .mapToInt(List::size)
            .sum();
    }

    /**
     * Appends a node property step after checking that its mutate property is not already used.
     *
     * @param executableNodePropertyStep the step to add
     * @throws IllegalArgumentException if an existing step has the same mutate node property
     */
    public void addNodePropertyStep(ExecutableNodePropertyStep executableNodePropertyStep) {
        validateUniqueMutateProperty(executableNodePropertyStep);
        this.nodePropertySteps.add(executableNodePropertyStep);
    }

    /**
     * Appends a feature step.
     *
     * @param featureStep the step to add
     */
    public void addFeatureStep(FEATURE_STEP featureStep) {
        this.featureSteps.add(featureStep);
    }

    /** Returns the live, mutable list of node property steps. */
    @Override
    public List<ExecutableNodePropertyStep> nodePropertySteps() {
        return this.nodePropertySteps;
    }

    /** Returns the live, mutable list of feature steps. */
    @Override
    public List<FEATURE_STEP> featureSteps() {
        return this.featureSteps;
    }

    /** Returns the live training parameter space, keyed by training method. */
    public Map<TrainingMethod, List<TunableTrainerConfig>> trainingParameterSpace() {
        return this.trainingParameterSpace;
    }

    /** Number of trainer configs (across all methods) for which {@code isConcrete()} is true. */
    private int concreteTrainerConfigsCount() {
        return (int) trainingParameterSpace().values().stream()
            .flatMap(List::stream)
            .filter(TunableTrainerConfig::isConcrete)
            .count();
    }

    /**
     * Returns how many model candidates model selection will evaluate.
     *
     * <p>If every registered config is concrete, the result is the number of configs.
     * Otherwise it is the auto-tuning {@code maxTrials()} plus the number of concrete configs.
     */
    public int numberOfModelSelectionTrials() {
        int totalConfigs = numberOfTrainerConfigs();
        int concreteConfigs = concreteTrainerConfigsCount();
        if (concreteConfigs == totalConfigs) {
            return totalConfigs;
        }
        return autoTuningConfig().maxTrials() + concreteConfigs;
    }

    /**
     * Registers a tunable trainer config under its training method.
     *
     * @throws IllegalArgumentException if the config's training method is not supported by
     *     this pipeline's training type
     */
    public void addTrainerConfig(TunableTrainerConfig tunableTrainerConfig) {
        candidatesFor(tunableTrainerConfig.trainingMethod()).add(tunableTrainerConfig);
    }

    /**
     * Registers a trainer config, converted with {@code toTunableConfig()}, under its method.
     *
     * @throws IllegalArgumentException if the config's method is not supported by this
     *     pipeline's training type
     */
    public void addTrainerConfig(TrainerConfig trainerConfig) {
        candidatesFor(trainerConfig.method()).add(trainerConfig.toTunableConfig());
    }

    /**
     * Looks up the candidate list for {@code method}. A clear error is raised for unsupported
     * methods instead of letting a null lookup surface as a NullPointerException.
     */
    private List<TunableTrainerConfig> candidatesFor(TrainingMethod method) {
        List<TunableTrainerConfig> candidates = this.trainingParameterSpace.get(method);
        if (candidates == null) {
            throw new IllegalArgumentException(StringFormatting.formatWithLocale(
                "Training method `%s` is not supported by this pipeline. Supported methods are %s.",
                method,
                this.trainingParameterSpace.keySet()
            ));
        }
        return candidates;
    }

    public AutoTuningConfig autoTuningConfig() {
        return this.autoTuningConfig;
    }

    public void setAutoTuningConfig(AutoTuningConfig autoTuningConfig) {
        this.autoTuningConfig = autoTuningConfig;
    }

    /**
     * Checks that model selection has at least one candidate.
     *
     * @throws IllegalArgumentException if {@link #numberOfModelSelectionTrials()} is zero
     */
    public void validateTrainingParameterSpace() {
        if (numberOfModelSelectionTrials() == 0) {
            throw new IllegalArgumentException("Need at least one model candidate for training.");
        }
    }

    /**
     * Ensures no already-registered node property step writes the same mutate property.
     *
     * @throws IllegalArgumentException naming the conflicting property and the existing step's
     *     procedure
     */
    private void validateUniqueMutateProperty(ExecutableNodePropertyStep executableNodePropertyStep) {
        // Loop-invariant: the candidate's property does not depend on the existing step.
        String mutateNodeProperty = executableNodePropertyStep.mutateNodeProperty();
        for (ExecutableNodePropertyStep existingStep : this.nodePropertySteps) {
            if (mutateNodeProperty.equals(existingStep.mutateNodeProperty())) {
                throw new IllegalArgumentException(StringFormatting.formatWithLocale(
                    "The value of `%s` is expected to be unique, but %s was already specified in the %s procedure.",
                    "mutateProperty",
                    mutateNodeProperty,
                    existingStep.procName()
                ));
            }
        }
    }

    /** Returns the time at which this pipeline instance was created. */
    public ZonedDateTime creationTime() {
        return this.creationTime;
    }
}
