package io.cdap.mmds.data;

import com.google.common.base.Ascii;
import com.google.gson.Gson;
import com.google.gson.reflect.TypeToken;
import io.cdap.cdap.api.common.Bytes;
import io.cdap.cdap.api.data.schema.Schema;
import io.cdap.cdap.api.dataset.DatasetProperties;
import io.cdap.cdap.api.dataset.lib.IndexedTable;
import io.cdap.cdap.api.dataset.module.EmbeddedDataset;
import io.cdap.cdap.api.dataset.table.Put;
import io.cdap.cdap.api.dataset.table.Row;
import io.cdap.cdap.api.dataset.table.Scan;
import io.cdap.cdap.api.dataset.table.Scanner;
import io.cdap.mmds.proto.CreateModelRequest;
import io.cdap.mmds.proto.TrainModelRequest;
import java.lang.reflect.Type;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.UUID;
import javax.annotation.Nullable;

/* loaded from: input_file:lib/mmds-model-1.7.1.jar:io/cdap/mmds/data/ModelTable.class */
/**
 * Stores model metadata in an {@link IndexedTable}, one row per model.
 *
 * <p>Row keys have the form {@code <experiment>/<model id>}, so every model of an experiment shares
 * the {@code <experiment>/} prefix and can be read or removed with a prefix scan. The {@code name}
 * column is configured as an index column through {@link #DATASET_PROPERTIES}. Per-experiment model
 * counts are maintained with the row-count helpers inherited from {@link CountTable}.
 */
public class ModelTable extends CountTable<IndexedTable> {
    private static final String SEPARATOR = "/";
    private static final String EXPERIMENT_COL = "experiment";
    private static final String ID_COL = "id";
    private static final String NAME_COL = "name";
    private static final String DESC_COL = "description";
    private static final String ALGO_COL = "algorithm";
    private static final String SPLIT_COL = "split";
    private static final String OUTCOME_COL = "outcome";
    private static final String HYPER_PARAMS_COL = "hyperparameters";
    private static final String FEATURES_COL = "features";
    private static final String CATEGORICAL_FEATURES_COL = "catfeatures";
    private static final String CREATE_TIME_COL = "createtime";
    private static final String TRAINING_TIME_COL = "trainingtime";
    private static final String TRAIN_TIME_COL = "trainedtime";
    private static final String DEPLOY_TIME_COL = "deploytime";
    private static final String STATUS_COL = "status";
    private static final String DIRECTIVES_COL = "directives";
    private static final String PREDICTIONS_COL = "predictions";
    private static final String PRECISION_COL = "precision";
    private static final String RECALL_COL = "recall";
    private static final String F1_COL = "f1";
    private static final String RMSE_COL = "rmse";
    private static final String R2_COL = "r2";
    private static final String EVARIANCE_COL = "evariance";
    private static final String MAE_COL = "mae";

    private static final Gson GSON = new Gson();
    private static final Type MAP_TYPE = new TypeToken<Map<String, String>>() { }.getType();
    private static final Type LIST_TYPE = new TypeToken<List<String>>() { }.getType();
    private static final Type SET_TYPE = new TypeToken<Set<String>>() { }.getType();

    /** Dataset properties that declare the {@code name} column as an index column. */
    public static final DatasetProperties DATASET_PROPERTIES =
        DatasetProperties.builder().add(IndexedTable.INDEX_COLUMNS_CONF_KEY, NAME_COL).build();

    public ModelTable(IndexedTable indexedTable) {
        super(indexedTable);
    }

    /** Returns the underlying table with its concrete type. */
    private IndexedTable indexedTable() {
        return (IndexedTable) this.table;
    }

    /**
     * Lists one page of the models that belong to an experiment, sorted by model name.
     *
     * @param experiment the experiment whose models are listed
     * @param offset zero-based index of the first model to return; negative values are treated as 0
     * @param limit maximum number of models to return; non-positive values yield an empty page
     * @param sortInfo sort direction; {@link SortType#DESC} sorts names descending, anything else
     *     ascending
     * @return the total number of models in the experiment together with the requested page. An
     *     offset past the end yields an empty page rather than an exception.
     */
    public ModelsMeta list(String experiment, int offset, int limit, SortInfo sortInfo) {
        List<ModelMeta> models = new ArrayList<>();
        byte[] startKey = Bytes.toBytes(experiment + SEPARATOR);
        try (Scanner scanner = indexedTable().scan(startKey, Bytes.stopKeyForPrefix(startKey))) {
            Row row;
            while ((row = scanner.next()) != null) {
                models.add(fromRow(row));
            }
        }

        Comparator<ModelMeta> byName = Comparator.comparing(ModelMeta::getName);
        models.sort(SortType.DESC.equals(sortInfo.getSortType()) ? byName.reversed() : byName);

        int total = models.size();
        int from = Math.min(Math.max(offset, 0), total);
        // Compute the end index in long arithmetic so a large limit cannot overflow to a negative.
        int to = (int) Math.min((long) from + Math.max(limit, 0), total);
        // Copy the page so the result does not keep a view onto the full model list.
        return new ModelsMeta(total, new ArrayList<>(models.subList(from, to)));
    }

    /**
     * Reads the metadata of a single model.
     *
     * @param modelKey experiment and model id of the model
     * @return the model metadata, or {@code null} if no row exists for the key
     */
    @Nullable
    public ModelMeta get(ModelKey modelKey) {
        Row row = indexedTable().get(getKey(modelKey));
        return row.isEmpty() ? null : fromRow(row);
    }

    /**
     * Updates a model's status and records the current wall-clock time in the matching timestamp column.
     *
     * <p>{@link ModelStatus#DEPLOYED} writes the deploy time. {@link ModelStatus#TRAINED} and
     * {@link ModelStatus#TRAINING_FAILED} write the trained time. Other statuses write only the status.
     *
     * @param modelKey experiment and model id of the model
     * @param modelStatus the new status
     */
    public void setStatus(ModelKey modelKey, ModelStatus modelStatus) {
        Put put = new Put(getKey(modelKey)).add(STATUS_COL, modelStatus.name());
        long now = System.currentTimeMillis();
        if (modelStatus == ModelStatus.DEPLOYED) {
            put.add(DEPLOY_TIME_COL, now);
        } else if (modelStatus == ModelStatus.TRAINED || modelStatus == ModelStatus.TRAINING_FAILED) {
            put.add(TRAIN_TIME_COL, now);
        }
        indexedTable().put(put);
    }

    /**
     * Deletes a single model row and decrements the experiment's model count by one.
     *
     * <p>NOTE(review): the count is decremented even if no row existed for the key; callers
     * presumably check existence first — confirm.
     *
     * @param modelKey experiment and model id of the model
     */
    public void delete(ModelKey modelKey) {
        indexedTable().delete(getKey(modelKey));
        decrementRowCount(1, modelKey.getExperiment());
    }

    /**
     * Deletes every model of an experiment and decrements the experiment's model count accordingly.
     *
     * @param experiment the experiment whose models are deleted
     * @return the number of rows deleted
     */
    public int delete(String experiment) {
        int deleted = delete(experiment, Integer.MAX_VALUE);
        decrementRowCount(deleted, experiment);
        return deleted;
    }

    /**
     * Deletes up to {@code limit} model rows of an experiment without touching the model count.
     *
     * @param experiment the experiment whose models are deleted
     * @param limit maximum number of rows to delete; a non-positive value deletes nothing
     * @return the number of rows deleted
     */
    public int delete(String experiment, int limit) {
        byte[] startKey = Bytes.toBytes(experiment + SEPARATOR);
        Scan scan = new Scan(startKey, Bytes.stopKeyForPrefix(startKey));
        List<byte[]> rowKeys = new ArrayList<>();
        // Row keys are collected first and deleted only after the scanner is closed.
        try (Scanner scanner = indexedTable().scan(scan)) {
            Row row;
            // Check the limit before reading so a limit of 0 deletes nothing.
            while (rowKeys.size() < limit && (row = scanner.next()) != null) {
                rowKeys.add(row.getRow());
            }
        }
        for (byte[] rowKey : rowKeys) {
            indexedTable().delete(rowKey);
        }
        return rowKeys.size();
    }

    /**
     * Creates a new model row in {@link ModelStatus#PREPARING} state and increments the experiment's
     * model count.
     *
     * @param experiment the experiment that owns the model; its name and outcome are copied into the row
     * @param createModelRequest the requested name, description and directives
     * @param createTime value stored in the create-time column
     * @return the generated model id (a random UUID without dashes)
     */
    public String add(Experiment experiment, CreateModelRequest createModelRequest, long createTime) {
        String modelId = UUID.randomUUID().toString().replace("-", "");
        Put put = new Put(getKey(experiment.getName(), modelId))
            .add(EXPERIMENT_COL, experiment.getName())
            .add(ID_COL, modelId)
            .add(NAME_COL, createModelRequest.getName())
            .add(DESC_COL, createModelRequest.getDescription())
            .add(OUTCOME_COL, experiment.getOutcome())
            .add(CREATE_TIME_COL, createTime)
            .add(STATUS_COL, ModelStatus.PREPARING.name())
            // -1 marks "not yet trained / deployed".
            .add(TRAIN_TIME_COL, -1L)
            .add(DEPLOY_TIME_COL, -1L);
        if (!createModelRequest.getDirectives().isEmpty()) {
            put.add(DIRECTIVES_COL, GSON.toJson(createModelRequest.getDirectives()));
        }
        indexedTable().put(put);
        incrementRowCount(experiment.getName());
        return modelId;
    }

    /**
     * Stores a model's directives as a JSON list. An empty list is ignored, leaving any existing
     * directives unchanged.
     *
     * @param modelKey experiment and model id of the model
     * @param directives the directives to store
     */
    public void setDirectives(ModelKey modelKey, List<String> directives) {
        if (directives.isEmpty()) {
            return;
        }
        indexedTable().put(new Put(getKey(modelKey)).add(DIRECTIVES_COL, GSON.toJson(directives)));
    }

    /**
     * Assigns a data split to a model.
     *
     * <p>Derives the model status from the split status, and stores the split id, the split's
     * directives and the feature list. The feature list is the names of all schema fields except
     * {@code excludedField}.
     *
     * @param modelKey experiment and model id of the model
     * @param dataSplitStats the split being assigned
     * @param excludedField schema field left out of the feature list (presumably the outcome
     *     field — confirm with callers)
     * @throws IllegalStateException if the split status is not one this table knows how to map
     */
    public void setSplit(ModelKey modelKey, DataSplitStats dataSplitStats, String excludedField) {
        SplitStatus splitStatus = dataSplitStats.getStatus();
        ModelStatus modelStatus;
        switch (splitStatus) {
            case SPLITTING:
                modelStatus = ModelStatus.SPLITTING;
                break;
            case FAILED:
                modelStatus = ModelStatus.SPLIT_FAILED;
                break;
            case COMPLETE:
                modelStatus = ModelStatus.DATA_READY;
                break;
            default:
                throw new IllegalStateException("Unknown split status " + splitStatus);
        }

        List<Schema.Field> fields = dataSplitStats.getSchema().getFields();
        // Size by field count; the former "size() - 1" threw for a schema without fields.
        List<String> features = new ArrayList<>(fields.size());
        for (Schema.Field field : fields) {
            String name = field.getName();
            if (!name.equals(excludedField)) {
                features.add(name);
            }
        }

        indexedTable().put(new Put(getKey(modelKey))
            .add(SPLIT_COL, dataSplitStats.getId())
            .add(STATUS_COL, modelStatus.name())
            .add(DIRECTIVES_COL, GSON.toJson(dataSplitStats.getDirectives()))
            .add(FEATURES_COL, GSON.toJson(features)));
    }

    /**
     * Removes the split column from a model row.
     *
     * @param modelKey experiment and model id of the model
     */
    public void unassignSplit(ModelKey modelKey) {
        indexedTable().delete(getKey(modelKey), Bytes.toBytes(SPLIT_COL));
    }

    /**
     * Records the training parameters of a model and moves it to {@link ModelStatus#TRAINING}.
     *
     * <p>The trained time is reset to -1. The predictions dataset is written only when the request
     * specifies one.
     *
     * @param modelKey experiment and model id of the model
     * @param trainModelRequest algorithm, hyperparameters and optional predictions dataset
     * @param trainingTime value stored in the training-time column
     */
    public void setTrainingInfo(ModelKey modelKey, TrainModelRequest trainModelRequest, long trainingTime) {
        Put put = new Put(getKey(modelKey))
            .add(ALGO_COL, trainModelRequest.getAlgorithm())
            .add(HYPER_PARAMS_COL, GSON.toJson(trainModelRequest.getHyperparameters()))
            .add(STATUS_COL, ModelStatus.TRAINING.name())
            .add(TRAINING_TIME_COL, trainingTime)
            .add(TRAIN_TIME_COL, -1L);
        if (trainModelRequest.getPredictionsDataset() != null) {
            put.add(PREDICTIONS_COL, trainModelRequest.getPredictionsDataset());
        }
        indexedTable().put(put);
    }

    /**
     * Stores the results of a finished training run and moves the model to {@link ModelStatus#TRAINED}.
     *
     * <p>Only non-null evaluation metrics are written; absent metrics keep their previous values.
     *
     * @param modelKey experiment and model id of the model
     * @param evaluationMetrics metrics computed for the trained model
     * @param trainedTime value stored in the trained-time column
     * @param categoricalFeatures categorical feature names, stored as a JSON set
     */
    public void update(ModelKey modelKey, EvaluationMetrics evaluationMetrics, long trainedTime,
                       Set<String> categoricalFeatures) {
        Put put = new Put(getKey(modelKey));
        addIfPresent(put, PRECISION_COL, evaluationMetrics.getPrecision());
        addIfPresent(put, RECALL_COL, evaluationMetrics.getRecall());
        addIfPresent(put, F1_COL, evaluationMetrics.getF1());
        addIfPresent(put, RMSE_COL, evaluationMetrics.getRmse());
        addIfPresent(put, R2_COL, evaluationMetrics.getR2());
        addIfPresent(put, EVARIANCE_COL, evaluationMetrics.getEvariance());
        addIfPresent(put, MAE_COL, evaluationMetrics.getMae());
        put.add(STATUS_COL, ModelStatus.TRAINED.name());
        put.add(TRAIN_TIME_COL, trainedTime);
        put.add(CATEGORICAL_FEATURES_COL, GSON.toJson(categoricalFeatures));
        indexedTable().put(put);
    }

    /** Adds {@code value} as a double column to {@code put} unless it is null. */
    private static void addIfPresent(Put put, String column, @Nullable Number value) {
        if (value != null) {
            put.add(column, value.doubleValue());
        }
    }

    /**
     * Converts a table row back into {@link ModelMeta}.
     *
     * <p>Missing JSON columns become empty collections, a missing description becomes {@code ""},
     * a missing status becomes {@code null}, and missing timestamps become -1.
     */
    private ModelMeta fromRow(Row row) {
        String rowKey = Bytes.toString(row.getRow());
        // The model id is everything after the first separator of "<experiment>/<model id>".
        String modelId = rowKey.substring(rowKey.indexOf(SEPARATOR) + 1);

        Map<String, String> hyperParameters =
            fromJson(row.getString(HYPER_PARAMS_COL), MAP_TYPE, new HashMap<String, String>());
        List<String> features = fromJson(row.getString(FEATURES_COL), LIST_TYPE, new ArrayList<String>());
        Set<String> categoricalFeatures =
            fromJson(row.getString(CATEGORICAL_FEATURES_COL), SET_TYPE, new HashSet<String>());
        List<String> directives = fromJson(row.getString(DIRECTIVES_COL), LIST_TYPE, new ArrayList<String>());

        String description = row.getString(DESC_COL);
        String statusName = row.getString(STATUS_COL);
        ModelStatus status = statusName == null ? null : ModelStatus.valueOf(statusName);

        EvaluationMetrics metrics = new EvaluationMetrics(
            row.getDouble(PRECISION_COL), row.getDouble(RECALL_COL), row.getDouble(F1_COL),
            row.getDouble(RMSE_COL), row.getDouble(R2_COL), row.getDouble(EVARIANCE_COL),
            row.getDouble(MAE_COL));

        return ModelMeta.builder(modelId)
            .setName(row.getString(NAME_COL))
            .setDescription(description == null ? "" : description)
            .setOutcome(row.getString(OUTCOME_COL))
            .setAlgorithm(row.getString(ALGO_COL))
            .setSplit(row.getString(SPLIT_COL))
            .setHyperParameters(hyperParameters)
            .setFeatures(features)
            .setStatus(status)
            .setCategoricalFeatures(categoricalFeatures)
            .setCreateTime(row.getLong(CREATE_TIME_COL, -1L))
            .setTrainedTime(row.getLong(TRAIN_TIME_COL, -1L))
            .setTrainingTime(row.getLong(TRAINING_TIME_COL, -1L))
            .setDeployTime(row.getLong(DEPLOY_TIME_COL, -1L))
            .setEvaluationMetrics(metrics)
            .setDirectives(directives)
            .setPredictionsDataset(row.getString(PREDICTIONS_COL))
            .build();
    }

    /**
     * Parses {@code json} as {@code type}.
     *
     * @return the parsed value, or {@code defaultValue} when the column is absent or holds JSON null
     */
    private static <T> T fromJson(@Nullable String json, Type type, T defaultValue) {
        T value = GSON.fromJson(json, type);
        return value == null ? defaultValue : value;
    }

    /** Builds the row key {@code <experiment>/<model id>} for a model key. */
    private byte[] getKey(ModelKey modelKey) {
        return getKey(modelKey.getExperiment(), modelKey.getModel());
    }

    /** Builds the row key {@code <experiment>/<model id>}. */
    private byte[] getKey(String experiment, String modelId) {
        return Bytes.toBytes(experiment + SEPARATOR + modelId);
    }
}
