package io.cdap.mmds.data;

import com.google.common.base.Joiner;
import com.google.common.collect.ImmutableSet;
import io.cdap.cdap.api.data.schema.Schema;
import io.cdap.mmds.modeler.Modelers;
import io.cdap.mmds.proto.BadRequestException;
import io.cdap.mmds.proto.ConflictException;
import io.cdap.mmds.proto.CreateModelRequest;
import io.cdap.mmds.proto.ExperimentNotFoundException;
import io.cdap.mmds.proto.ModelNotFoundException;
import io.cdap.mmds.proto.SplitNotFoundException;
import io.cdap.mmds.proto.TrainModelRequest;
import io.cdap.mmds.stats.CategoricalHisto;
import io.cdap.mmds.stats.NumericHisto;
import io.cdap.mmds.stats.NumericStats;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import java.util.function.Predicate;

/* loaded from: input_file:lib/mmds-model-1.6.0.jar:io/cdap/mmds/data/ExperimentStore.class */
/**
 * Single entry point for reading and mutating experiments, data splits, and models.
 *
 * <p>Every operation is delegated to one of three backing tables ({@link ExperimentMetaTable},
 * {@link DataSplitTable}, {@link ModelTable}); this class adds existence checks, model/split state
 * validation, and keeps the model-to-split registrations in the split table in step with the split
 * assignments stored on the models.
 */
public class ExperimentStore {
    private static final Set<Schema.Type> CATEGORICAL_TYPES = ImmutableSet.of(Schema.Type.BOOLEAN, Schema.Type.STRING);
    private static final Set<Schema.Type> NUMERIC_TYPES = ImmutableSet.of(Schema.Type.INT, Schema.Type.LONG, Schema.Type.FLOAT, Schema.Type.DOUBLE);

    /** Largest number of bins used for any evaluation-metric histogram. */
    private static final int MAX_HISTOGRAM_BINS = 10;

    private final ExperimentMetaTable experiments;
    private final DataSplitTable splits;
    private final ModelTable models;

    /**
     * Creates a store backed by the given tables.
     *
     * @param experimentMetaTable table holding experiment metadata
     * @param dataSplitTable table holding data splits
     * @param modelTable table holding model metadata
     */
    public ExperimentStore(ExperimentMetaTable experimentMetaTable, DataSplitTable dataSplitTable, ModelTable modelTable) {
        this.experiments = experimentMetaTable;
        this.splits = dataSplitTable;
        this.models = modelTable;
    }

    /**
     * Lists experiments using the requested sort order.
     *
     * <p>NOTE(review): {@code offset} and {@code limit} are forwarded unchanged to
     * {@link ExperimentMetaTable}; they are presumably paging bounds — confirm against that class.
     *
     * @param offset first paging argument forwarded to the experiment table
     * @param limit second paging argument forwarded to the experiment table
     * @param sortInfo sort order to apply
     * @return the listing produced by the experiment table
     */
    public ExperimentsMeta listExperiments(int offset, int limit, SortInfo sortInfo) {
        // Route through the filtering overload with an accept-all predicate so that the caller's
        // sortInfo is actually honoured instead of being silently dropped.
        return this.experiments.list(offset, limit, experiment -> true, sortInfo);
    }

    /**
     * Lists experiments that satisfy {@code predicate}, using the requested sort order.
     *
     * @param offset first paging argument forwarded to the experiment table
     * @param limit second paging argument forwarded to the experiment table
     * @param predicate filter forwarded to the experiment table
     * @param sortInfo sort order to apply
     * @return the listing produced by the experiment table
     */
    public ExperimentsMeta listExperiments(int offset, int limit, Predicate<Experiment> predicate, SortInfo sortInfo) {
        return this.experiments.list(offset, limit, predicate, sortInfo);
    }

    /**
     * Looks up an experiment.
     *
     * @param experimentId identifier of the experiment
     * @return the experiment, never {@code null}
     * @throws ExperimentNotFoundException if the experiment table has no entry for the identifier
     */
    public Experiment getExperiment(String experimentId) {
        Experiment experiment = this.experiments.get(experimentId);
        if (experiment == null) {
            throw new ExperimentNotFoundException(experimentId);
        }
        return experiment;
    }

    /**
     * Summarises all models of an experiment: a categorical histogram of algorithms, a categorical
     * histogram of statuses, and one numeric histogram per evaluation metric.
     *
     * <p>A metric only appears in the result when the {@link NumericStats} accumulated over all
     * models reports a non-null minimum. The metrics rmse, r2, mae and evariance are binned over
     * their observed min/max with at most {@value #MAX_HISTOGRAM_BINS} bins (fewer when the status
     * histogram's total count is smaller); precision, recall and f1 are binned over the fixed
     * range [0, 1] with {@value #MAX_HISTOGRAM_BINS} bins.
     *
     * @param experimentId identifier of the experiment
     * @return the statistics for the experiment; the metric map is empty when there are no models
     * @throws ExperimentNotFoundException if the experiment does not exist
     */
    public ExperimentStats getExperimentStats(String experimentId) {
        Experiment experiment = getExperiment(experimentId);
        Map<String, ColumnStats> metricStats = new HashMap<>();
        CategoricalHisto algorithmHisto = new CategoricalHisto();
        CategoricalHisto statusHisto = new CategoricalHisto();

        List<ModelMeta> modelMetas =
            listModels(experimentId, 0, Integer.MAX_VALUE, new SortInfo(SortType.ASC)).getModels();
        for (ModelMeta modelMeta : modelMetas) {
            algorithmHisto.update(modelMeta.getAlgorithm());
            statusHisto.update(modelMeta.getStatus() == null ? null : modelMeta.getStatus().toString());
        }

        if (!modelMetas.isEmpty()) {
            int bins = Math.min(MAX_HISTOGRAM_BINS, (int) statusHisto.getTotalCount());
            putMetricHistogram(metricStats, "rmse", modelMetas, EvaluationMetrics::getRmse, bins, false);
            putMetricHistogram(metricStats, "r2", modelMetas, EvaluationMetrics::getR2, bins, false);
            putMetricHistogram(metricStats, "mae", modelMetas, EvaluationMetrics::getMae, bins, false);
            putMetricHistogram(metricStats, "evariance", modelMetas, EvaluationMetrics::getEvariance, bins, false);
            putMetricHistogram(metricStats, "precision", modelMetas, EvaluationMetrics::getPrecision, bins, true);
            putMetricHistogram(metricStats, "recall", modelMetas, EvaluationMetrics::getRecall, bins, true);
            putMetricHistogram(metricStats, "f1", modelMetas, EvaluationMetrics::getF1, bins, true);
        }
        return new ExperimentStats(experiment, metricStats, new ColumnStats(algorithmHisto), new ColumnStats(statusHisto));
    }

    /**
     * Builds the histogram of one evaluation metric over all models and stores it under
     * {@code name}, or stores nothing when the accumulated stats report no minimum.
     *
     * @param target map receiving the histogram
     * @param name key under which the histogram is stored
     * @param modelMetas models to aggregate; must not be empty
     * @param metric accessor that reads the metric from a model's evaluation metrics
     * @param bins number of bins to use when binning over the observed range
     * @param unitRange when {@code true}, bin over [0, 1] with {@value #MAX_HISTOGRAM_BINS} bins
     *                  instead of over the observed min/max
     */
    private static void putMetricHistogram(Map<String, ColumnStats> target, String name, List<ModelMeta> modelMetas,
                                           Function<EvaluationMetrics, Double> metric, int bins, boolean unitRange) {
        // First pass: accumulate the observed range, seeded with the first model's value.
        Iterator<ModelMeta> rangePass = modelMetas.iterator();
        NumericStats range = new NumericStats(metric.apply(rangePass.next().getEvaluationMetrics()));
        while (rangePass.hasNext()) {
            range.update(metric.apply(rangePass.next().getEvaluationMetrics()));
        }
        if (range.getMin() == null) {
            // NOTE(review): presumably no model reported this metric — nothing to bin.
            return;
        }

        // Second pass: the histogram needs its bounds up front, so it can only be filled once the
        // range is known. It is seeded with the first model's value, like the stats above.
        Iterator<ModelMeta> histoPass = modelMetas.iterator();
        Double firstValue = metric.apply(histoPass.next().getEvaluationMetrics());
        NumericHisto histo = unitRange
            ? new NumericHisto(0.0d, 1.0d, MAX_HISTOGRAM_BINS, firstValue)
            : new NumericHisto(range.getMin().doubleValue(), range.getMax().doubleValue(), bins, firstValue);
        while (histoPass.hasNext()) {
            histo.update(metric.apply(histoPass.next().getEvaluationMetrics()));
        }
        target.put(name, new ColumnStats(histo));
    }

    /**
     * Writes the experiment to the experiment table.
     *
     * @param experiment experiment to store
     */
    public void putExperiment(Experiment experiment) {
        this.experiments.put(experiment);
    }

    /**
     * Deletes an experiment together with its models and splits.
     *
     * @param experimentId identifier of the experiment
     * @throws ExperimentNotFoundException if the experiment does not exist
     */
    public void deleteExperiment(String experimentId) {
        getExperiment(experimentId);
        // Dependent data goes first; the experiment entry itself is removed last.
        this.models.delete(experimentId);
        this.splits.delete(experimentId);
        this.experiments.delete(experimentId);
    }

    /**
     * Lists the models of an experiment.
     *
     * @param experimentId identifier of the experiment
     * @param offset first paging argument forwarded to the model table
     * @param limit second paging argument forwarded to the model table
     * @param sortInfo sort order to apply
     * @return the listing produced by the model table
     * @throws ExperimentNotFoundException if the experiment does not exist
     */
    public ModelsMeta listModels(String experimentId, int offset, int limit, SortInfo sortInfo) {
        getExperiment(experimentId);
        return this.models.list(experimentId, offset, limit, sortInfo);
    }

    /**
     * Looks up a model.
     *
     * @param modelKey key of the model
     * @return the model metadata, never {@code null}
     * @throws ExperimentNotFoundException if the model's experiment does not exist
     * @throws ModelNotFoundException if the model table has no entry for the key
     */
    public ModelMeta getModel(ModelKey modelKey) {
        getExperiment(modelKey.getExperiment());
        ModelMeta modelMeta = this.models.get(modelKey);
        if (modelMeta == null) {
            throw new ModelNotFoundException(modelKey);
        }
        return modelMeta;
    }

    /**
     * Records the training request on a model and returns what is needed to train it.
     *
     * <p>The hyperparameters of the request are passed through the modeler of the requested
     * algorithm ({@code getParams(...).toMap()}) before being stored.
     *
     * @param modelKey key of the model
     * @param trainModelRequest requested algorithm, predictions dataset, and hyperparameters
     * @param j value forwarded to {@link ModelTable}'s {@code setTrainingInfo}
     *          (NOTE(review): presumably the training start timestamp — confirm)
     * @return the experiment, the model's split, the model id, and the freshly re-read model metadata
     * @throws ExperimentNotFoundException if the experiment does not exist
     * @throws ModelNotFoundException if the model does not exist
     * @throws ConflictException if the model is not in the {@code DATA_READY} state
     * @throws SplitNotFoundException if the model's split does not exist
     */
    public ModelTrainerInfo trainModel(ModelKey modelKey, TrainModelRequest trainModelRequest, long j) {
        Experiment experiment = getExperiment(modelKey.getExperiment());
        ModelMeta model = getModel(modelKey);
        ModelStatus status = model.getStatus();
        if (status != ModelStatus.DATA_READY) {
            throw new ConflictException(String.format("Cannot train a model that is in the '%s' state.", status));
        }
        TrainModelRequest storedRequest = new TrainModelRequest(
            trainModelRequest.getAlgorithm(),
            trainModelRequest.getPredictionsDataset(),
            Modelers.getModeler(trainModelRequest.getAlgorithm()).getParams(trainModelRequest.getHyperparameters()).toMap());
        this.models.setTrainingInfo(modelKey, storedRequest, j);

        DataSplitStats split = getSplit(new SplitKey(modelKey.getExperiment(), model.getSplit()));
        // Re-read the model so the returned metadata includes the training info stored above.
        return new ModelTrainerInfo(experiment, split, modelKey.getModel(), this.models.get(modelKey));
    }

    /**
     * Assigns a split to a model, replacing any split it had before.
     *
     * @param modelKey key of the model
     * @param splitId identifier of the split to assign
     * @throws ExperimentNotFoundException if the experiment does not exist
     * @throws ModelNotFoundException if the model does not exist
     * @throws ConflictException if the model is not in the {@code PREPARING}, {@code SPLIT_FAILED},
     *                           {@code TRAINING_FAILED}, or {@code DATA_READY} state
     * @throws SplitNotFoundException if the split does not exist
     */
    public void setModelSplit(ModelKey modelKey, String splitId) {
        Experiment experiment = getExperiment(modelKey.getExperiment());
        ModelMeta model = getModel(modelKey);
        ModelStatus status = model.getStatus();
        if (status != ModelStatus.PREPARING && status != ModelStatus.SPLIT_FAILED && status != ModelStatus.TRAINING_FAILED && status != ModelStatus.DATA_READY) {
            throw new ConflictException(String.format("Cannot set a split for a model in the '%s' state. The model must be in the '%s', '%s', '%s', or '%s' state.", status, ModelStatus.PREPARING, ModelStatus.SPLIT_FAILED, ModelStatus.TRAINING_FAILED, ModelStatus.DATA_READY));
        }
        DataSplitStats newSplit = getSplit(new SplitKey(modelKey.getExperiment(), splitId));
        String previousSplitId = model.getSplit();
        if (previousSplitId != null) {
            // Drop the registration on the old split before attaching the new one.
            this.splits.unregisterModel(new SplitKey(modelKey.getExperiment(), previousSplitId), modelKey.getModel());
        }
        this.models.setSplit(modelKey, newSplit, experiment.getOutcome());
        this.splits.registerModel(new SplitKey(modelKey.getExperiment(), splitId), modelKey.getModel());
    }

    /**
     * Detaches a model from its split and moves the model back to {@code PREPARING}. The split is
     * deleted as well when this model was the only one registered on it.
     *
     * @param modelKey key of the model
     * @throws ExperimentNotFoundException if the experiment does not exist
     * @throws ModelNotFoundException if the model does not exist
     * @throws ConflictException if the model is not in the {@code SPLIT_FAILED},
     *                           {@code TRAINING_FAILED}, or {@code DATA_READY} state
     * @throws SplitNotFoundException if the model's split does not exist
     */
    public void unassignModelSplit(ModelKey modelKey) {
        getExperiment(modelKey.getExperiment());
        ModelMeta model = getModel(modelKey);
        ModelStatus status = model.getStatus();
        if (status != ModelStatus.SPLIT_FAILED && status != ModelStatus.DATA_READY && status != ModelStatus.TRAINING_FAILED) {
            throw new ConflictException(String.format("Cannot unassign the split for a model in the '%s' state. The model must be in the '%s', '%s', or '%s' state.", status, ModelStatus.SPLIT_FAILED, ModelStatus.TRAINING_FAILED, ModelStatus.DATA_READY));
        }
        // Snapshot of the split taken before the model is unregistered from it.
        DataSplitStats split = getSplit(new SplitKey(modelKey.getExperiment(), model.getSplit()));
        this.models.unassignSplit(modelKey);
        this.models.setStatus(modelKey, ModelStatus.PREPARING);
        SplitKey splitKey = new SplitKey(modelKey.getExperiment(), model.getSplit());
        this.splits.unregisterModel(splitKey, modelKey.getModel());
        if (split.getModels().size() == 1) {
            // Exactly one model was registered before the unregister above, so no model is left.
            this.splits.delete(splitKey);
        }
    }

    /**
     * Adds a model to an experiment, optionally attaching it to an existing split.
     *
     * @param experimentId identifier of the experiment
     * @param createModelRequest creation request; its split may be {@code null}
     * @return the identifier returned by the model table for the new model
     * @throws ExperimentNotFoundException if the experiment does not exist
     * @throws SplitNotFoundException if the request names a split that does not exist
     */
    public String addModel(String experimentId, CreateModelRequest createModelRequest) {
        Experiment experiment = getExperiment(experimentId);
        String splitId = createModelRequest.getSplit();
        SplitKey splitKey = null;
        DataSplitStats splitStats = null;
        if (splitId != null) {
            // Validate the split before creating the model so a bad request leaves nothing behind.
            splitKey = new SplitKey(experimentId, splitId);
            splitStats = this.splits.get(splitKey);
            if (splitStats == null) {
                throw new SplitNotFoundException(splitKey);
            }
        }
        String modelId = this.models.add(experiment, createModelRequest, System.currentTimeMillis());
        if (splitStats != null) {
            this.models.setSplit(new ModelKey(experimentId, modelId), splitStats, experiment.getOutcome());
            // Register the model on the split, as setModelSplit does, so that the split's model
            // list (consulted by deleteSplit, finishSplit and splitFailed) includes this model.
            this.splits.registerModel(splitKey, modelId);
        }
        return modelId;
    }

    /**
     * Sets the directives of a model.
     *
     * @param modelKey key of the model
     * @param directives directives to store
     * @throws ExperimentNotFoundException if the experiment does not exist
     * @throws ModelNotFoundException if the model does not exist
     * @throws ConflictException if the model is not in the {@code PREPARING} state
     */
    public void setModelDirectives(ModelKey modelKey, List<String> directives) {
        if (getModel(modelKey).getStatus() != ModelStatus.PREPARING) {
            throw new ConflictException(String.format("Directives can only be set or modified if the model is in the %s state.", ModelStatus.PREPARING));
        }
        this.models.setDirectives(modelKey, directives);
    }

    /**
     * Forwards evaluation results for a model to the model table without further validation.
     *
     * @param modelKey key of the model
     * @param evaluationMetrics metrics to store
     * @param j value forwarded to {@link ModelTable}'s {@code update}
     *          (NOTE(review): presumably a timestamp — confirm)
     * @param set strings forwarded to {@link ModelTable}'s {@code update}; their meaning is
     *            defined there
     */
    public void updateModelMetrics(ModelKey modelKey, EvaluationMetrics evaluationMetrics, long j, Set<String> set) {
        this.models.update(modelKey, evaluationMetrics, j, set);
    }

    /**
     * Deletes a model and removes its registration from the split it was assigned to, if any.
     *
     * @param modelKey key of the model
     * @throws ModelNotFoundException if the model does not exist
     */
    public void deleteModel(ModelKey modelKey) {
        ModelMeta modelMeta = this.models.get(modelKey);
        if (modelMeta == null) {
            throw new ModelNotFoundException(modelKey);
        }
        this.models.delete(modelKey);
        String splitId = modelMeta.getSplit();
        if (splitId != null) {
            this.splits.unregisterModel(new SplitKey(modelKey.getExperiment(), splitId), modelKey.getModel());
        }
    }

    /**
     * Marks a model as {@code DEPLOYED}. Does nothing when the model already has a positive
     * deploy time.
     *
     * @param modelKey key of the model
     * @throws ExperimentNotFoundException if the experiment does not exist
     * @throws ModelNotFoundException if the model does not exist
     */
    public void deployModel(ModelKey modelKey) {
        if (getModel(modelKey).getDeploytime() > 0) {
            return;
        }
        this.models.setStatus(modelKey, ModelStatus.DEPLOYED);
    }

    /**
     * Moves a model from {@code TRAINING} to {@code TRAINING_FAILED}.
     *
     * @param modelKey key of the model
     * @throws ExperimentNotFoundException if the experiment does not exist
     * @throws ModelNotFoundException if the model does not exist
     * @throws IllegalStateException if the model is not in the {@code TRAINING} state
     */
    public void modelFailed(ModelKey modelKey) {
        ModelStatus status = getModel(modelKey).getStatus();
        if (status != ModelStatus.TRAINING) {
            // Target state first, current state second, matching the wording of the message.
            throw new IllegalStateException(String.format("Cannot transition model to '%s' from '%s'", ModelStatus.TRAINING_FAILED, status));
        }
        this.models.setStatus(modelKey, ModelStatus.TRAINING_FAILED);
    }

    /**
     * Lists the splits of an experiment.
     *
     * @param experimentId identifier of the experiment
     * @return the splits returned by the split table
     * @throws ExperimentNotFoundException if the experiment does not exist
     */
    public List<DataSplitStats> listSplits(String experimentId) {
        getExperiment(experimentId);
        return this.splits.list(experimentId);
    }

    /**
     * Adds a split to an experiment after checking that the split schema contains the experiment
     * outcome field and that the field's type family (categorical vs. numeric) matches the
     * experiment's outcome type.
     *
     * @param experimentId identifier of the experiment
     * @param dataSplit split to add
     * @param j value forwarded to {@link DataSplitTable}'s {@code addSplit}
     *          (NOTE(review): presumably the split start timestamp — confirm)
     * @return the new split id together with the experiment, the split, and the split's location
     * @throws ExperimentNotFoundException if the experiment does not exist
     * @throws BadRequestException if the outcome field is missing from the split schema or its type
     *                             family does not match the experiment's outcome type
     * @throws IllegalArgumentException if the experiment's outcome type is not a {@link Schema.Type} name
     */
    public DataSplitInfo addSplit(String experimentId, DataSplit dataSplit, long j) {
        Experiment experiment = getExperiment(experimentId);
        // Locale.ROOT keeps the enum lookup independent of the JVM's default locale
        // (e.g. a Turkish locale would upper-case "int" to "İNT").
        Schema.Type experimentType = Schema.Type.valueOf(experiment.getOutcomeType().toUpperCase(Locale.ROOT));
        Schema.Field outcomeField = dataSplit.getSchema().getField(experiment.getOutcome());
        if (outcomeField == null) {
            throw new BadRequestException(String.format("Invalid split schema. The split must contain the experiment outcome '%s'.", experiment.getOutcome()));
        }
        Schema outcomeSchema = outcomeField.getSchema();
        if (outcomeSchema.isNullable()) {
            // Compare against the underlying type, not the nullable wrapper.
            outcomeSchema = outcomeSchema.getNonNullable();
        }
        Schema.Type splitType = outcomeSchema.getType();
        if (CATEGORICAL_TYPES.contains(experimentType) && !CATEGORICAL_TYPES.contains(splitType)) {
            throw new BadRequestException(String.format("Invalid split schema. Outcome field '%s' is of categorical type '%s' in the experiment , but is of non-categorical type '%s' in the split.", experiment.getOutcome(), experimentType, splitType));
        }
        if (NUMERIC_TYPES.contains(experimentType) && !NUMERIC_TYPES.contains(splitType)) {
            throw new BadRequestException(String.format("Invalid split schema. Outcome field '%s' is of numeric type '%s' in the experiment, but is of non-numeric type '%s' in the split.", experiment.getOutcome(), experimentType, splitType));
        }
        String splitId = this.splits.addSplit(experimentId, dataSplit, j);
        return new DataSplitInfo(splitId, experiment, dataSplit, this.splits.getLocation(new SplitKey(experimentId, splitId)));
    }

    /**
     * Looks up a split.
     *
     * @param splitKey key of the split
     * @return the split, never {@code null}
     * @throws ExperimentNotFoundException if the split's experiment does not exist
     * @throws SplitNotFoundException if the split table has no entry for the key
     */
    public DataSplitStats getSplit(SplitKey splitKey) {
        getExperiment(splitKey.getExperiment());
        DataSplitStats splitStats = this.splits.get(splitKey);
        if (splitStats == null) {
            throw new SplitNotFoundException(splitKey);
        }
        return splitStats;
    }

    /**
     * Stores the results of a finished split and moves every model registered on the split to
     * {@code DATA_READY}.
     *
     * @param splitKey key of the split
     * @param str first string forwarded to {@link DataSplitTable}'s {@code updateStats}
     *            (NOTE(review): presumably the training data path — confirm)
     * @param str2 second string forwarded to {@link DataSplitTable}'s {@code updateStats}
     *             (NOTE(review): presumably the test data path — confirm)
     * @param columnStats per-column statistics of the split
     * @param j value forwarded to {@link DataSplitTable}'s {@code updateStats}
     *          (NOTE(review): presumably the end timestamp — confirm)
     * @throws ExperimentNotFoundException if the experiment does not exist
     * @throws SplitNotFoundException if the split does not exist when it is re-read after the update
     */
    public void finishSplit(SplitKey splitKey, String str, String str2, List<ColumnSplitStats> columnStats, long j) {
        this.splits.updateStats(splitKey, str, str2, columnStats, j);
        for (String modelId : getSplit(splitKey).getModels()) {
            this.models.setStatus(new ModelKey(splitKey.getExperiment(), modelId), ModelStatus.DATA_READY);
        }
    }

    /**
     * Marks a split as failed and moves every model registered on the split to
     * {@code SPLIT_FAILED}.
     *
     * @param splitKey key of the split
     * @param j value forwarded to {@link DataSplitTable}'s {@code splitFailed}
     *          (NOTE(review): presumably the end timestamp — confirm)
     * @throws ExperimentNotFoundException if the experiment does not exist
     * @throws SplitNotFoundException if the split does not exist
     * @throws IllegalStateException if the split is not in the {@code SPLITTING} state
     */
    public void splitFailed(SplitKey splitKey, long j) {
        // getSplit already verifies that the experiment exists.
        DataSplitStats split = getSplit(splitKey);
        if (split.getStatus() != SplitStatus.SPLITTING) {
            throw new IllegalStateException("Cannot transition split to failed state unless it is in the splitting state.");
        }
        this.splits.splitFailed(splitKey, j);
        for (String modelId : split.getModels()) {
            this.models.setStatus(new ModelKey(splitKey.getExperiment(), modelId), ModelStatus.SPLIT_FAILED);
        }
    }

    /**
     * Deletes a split that no model is registered on.
     *
     * @param splitKey key of the split
     * @throws ExperimentNotFoundException if the experiment does not exist
     * @throws SplitNotFoundException if the split does not exist
     * @throws ConflictException if at least one model is still registered on the split
     */
    public void deleteSplit(SplitKey splitKey) {
        DataSplitStats split = getSplit(splitKey);
        if (!split.getModels().isEmpty()) {
            throw new ConflictException(String.format("Cannot delete split '%s' since it is used by model(s) '%s'.", splitKey.getSplit(), Joiner.on(',').join(split.getModels())));
        }
        this.splits.delete(splitKey);
    }
}
