package io.cdap.mmds.modeler.train;

import io.cdap.cdap.api.Admin;
import io.cdap.cdap.api.Transactional;
import io.cdap.cdap.api.data.schema.Schema;
import io.cdap.cdap.api.dataset.lib.PartitionDetail;
import io.cdap.cdap.api.dataset.lib.PartitionKey;
import io.cdap.cdap.api.dataset.lib.PartitionOutput;
import io.cdap.cdap.api.dataset.lib.PartitionedFileSet;
import io.cdap.cdap.api.dataset.lib.PartitionedFileSetProperties;
import io.cdap.cdap.api.dataset.lib.Partitioning;
import io.cdap.mmds.Constants;
import io.cdap.mmds.Schemas;
import io.cdap.mmds.api.AlgorithmType;
import io.cdap.mmds.data.ModelKey;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicReference;
import javax.annotation.Nullable;
import org.apache.spark.sql.SaveMode;
import org.apache.twill.filesystem.Location;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:lib/mmds-model-1.10.0.jar:io/cdap/mmds/modeler/train/ModelOutputWriter.class */
/**
 * Persists the artifacts produced by a model training run.
 *
 * <p>The target index model (if present), the feature generation pipeline and the trained model
 * are written under {@code <baseLocation>/<experiment>/<model>/<component>}. Optionally, the
 * predictions computed on the training data are written as CSV into a partition of a
 * {@link PartitionedFileSet}, which is created on first use.
 */
public class ModelOutputWriter {
    private static final Logger LOG = LoggerFactory.getLogger(ModelOutputWriter.class);
    private final Admin admin;
    private final Transactional transactional;
    private final Location baseLocation;
    private final boolean overwrite;

    /**
     * @param admin used to check for and create the predictions dataset
     * @param transactional used to run the transaction that looks up or registers the predictions partition
     * @param baseLocation root directory under which model components are stored
     * @param overwrite whether an existing component directory may be replaced; when {@code false},
     *     saving over an existing component fails with {@link IllegalArgumentException}
     */
    public ModelOutputWriter(Admin admin, Transactional transactional, Location baseLocation, boolean overwrite) {
        this.admin = admin;
        this.transactional = transactional;
        this.baseLocation = baseLocation;
        this.overwrite = overwrite;
    }

    /**
     * Saves all components of a trained model and, optionally, its predictions on the training data.
     *
     * @param modelKey identifies the experiment and model the components belong to
     * @param modelOutput the training result to persist
     * @param predictionsDataset name of the partitioned file set to write predictions to, or
     *     {@code null} to skip writing predictions
     * @throws IllegalArgumentException if a component already exists and overwriting is disabled
     * @throws Exception if writing a component, creating the dataset or the transaction fails
     */
    public void save(ModelKey modelKey, ModelOutput modelOutput, @Nullable String predictionsDataset) throws Exception {
        if (modelOutput.getTargetIndexModel() != null) {
            LOG.info("Saving outcome indices...");
            modelOutput.getTargetIndexModel().save(getPath(modelKey, Constants.Component.TARGET_INDICES));
            LOG.info("Outcome indices successfully saved.");
        }
        LOG.info("Saving feature generation pipeline...");
        modelOutput.getFeatureGenModel().write().overwrite().save(getPath(modelKey, Constants.Component.FEATUREGEN));
        LOG.info("Feature generation pipeline successfully saved.");
        LOG.info("Saving trained model...");
        modelOutput.getModel().write().overwrite().save(getPath(modelKey, Constants.Component.MODEL));
        LOG.info("Model successfully saved.");
        if (predictionsDataset != null) {
            savePredictions(modelKey, modelOutput, predictionsDataset);
        }
    }

    /**
     * Deletes the stored target indices, feature generation pipeline and model for the given key.
     * Components that do not exist are skipped.
     *
     * @param modelKey identifies the experiment and model whose components are removed
     * @throws IOException if the file system cannot be accessed
     */
    public void deleteComponents(ModelKey modelKey) throws IOException {
        deleteComponent(modelKey, Constants.Component.TARGET_INDICES);
        deleteComponent(modelKey, Constants.Component.FEATUREGEN);
        deleteComponent(modelKey, Constants.Component.MODEL);
    }

    /** Writes the training-data predictions as CSV into the (experiment, model) partition of the dataset. */
    private void savePredictions(ModelKey modelKey, ModelOutput modelOutput, String datasetName) throws Exception {
        if (!this.admin.datasetExists(datasetName)) {
            createPredictionsDataset(modelOutput, datasetName);
        }
        PartitionKey partitionKey = PartitionKey.builder()
            .addStringField(Constants.Component.MODEL, modelKey.getModel())
            .addStringField("experiment", modelKey.getExperiment())
            .build();
        AtomicReference<String> outputPath = new AtomicReference<>();
        // Reuse the partition if it already exists; otherwise register a new one and write into its location.
        // NOTE(review): a new partition is added before any data is written, so a failed write below
        // leaves a registered partition without (complete) data.
        this.transactional.execute(datasetContext -> {
            PartitionedFileSet fileSet = datasetContext.getDataset(datasetName);
            PartitionDetail existing = fileSet.getPartition(partitionKey);
            if (existing != null) {
                outputPath.set(existing.getLocation().toURI().getPath());
                return;
            }
            PartitionOutput partitionOutput = fileSet.getPartitionOutput(partitionKey);
            outputPath.set(partitionOutput.getLocation().toURI().getPath());
            partitionOutput.addPartition();
        });
        modelOutput.getPredictions().write().format("csv").mode(SaveMode.Overwrite).save(outputPath.get());
        LOG.info("Predictions on training data successfully saved.");
    }

    /**
     * Creates the predictions dataset, partitioned by experiment and model and explorable as
     * comma-delimited text. Its schema is the input schema prefixed with a {@code prediction}
     * field: DOUBLE for regression and STRING otherwise.
     */
    private void createPredictionsDataset(ModelOutput modelOutput, String datasetName) throws Exception {
        Schema.Type predictionType =
            modelOutput.getAlgorithmType() == AlgorithmType.REGRESSION ? Schema.Type.DOUBLE : Schema.Type.STRING;
        List<Schema.Field> fields = new ArrayList<>();
        fields.add(Schema.Field.of("prediction", Schema.of(predictionType)));
        fields.addAll(modelOutput.getSchema().getFields());
        Schema predictionSchema = Schema.recordOf(modelOutput.getSchema().getRecordName() + ".prediction", fields);
        this.admin.createDataset(datasetName, PartitionedFileSet.class.getName(),
            PartitionedFileSetProperties.builder()
                .setPartitioning(Partitioning.builder()
                    .addStringField("experiment")
                    .addStringField(Constants.Component.MODEL)
                    .build())
                .setEnableExploreOnCreate(true)
                .setExploreFormat("text")
                .setExploreFormatProperty("delimiter", ",")
                .setExploreSchema(Schemas.toHiveSchema(predictionSchema))
                .build());
    }

    /** Returns {@code <baseLocation>/<experiment>/<model>/<component>}. */
    private Location componentLocation(ModelKey modelKey, String component) throws IOException {
        return this.baseLocation.append(modelKey.getExperiment()).append(modelKey.getModel()).append(component);
    }

    /**
     * Resolves the path a component should be written to. Any existing content is removed first,
     * unless overwriting is disabled.
     *
     * @throws IllegalArgumentException if the component exists and overwriting is disabled
     */
    private String getPath(ModelKey modelKey, String component) throws IOException {
        Location location = componentLocation(modelKey, component);
        if (location.exists()) {
            if (!this.overwrite) {
                throw new IllegalArgumentException(location + " already exists.");
            }
            deleteRecursively(location);
        }
        return location.toURI().getPath();
    }

    /** Removes a single stored component if it exists. */
    private void deleteComponent(ModelKey modelKey, String component) throws IOException {
        Location location = componentLocation(modelKey, component);
        if (location.exists()) {
            deleteRecursively(location);
        }
    }

    /**
     * Deletes a location together with its contents. Spark ML saves each component as a non-empty
     * directory, which a non-recursive delete cannot remove.
     */
    private static void deleteRecursively(Location location) throws IOException {
        if (!location.delete(true)) {
            LOG.warn("Unable to delete {}", location);
        }
    }
}
