package com.gengoai.apollo.ml.model.sequence;

import com.gengoai.ParamMap;
import com.gengoai.ParameterDef;
import com.gengoai.Validation;
import com.gengoai.apollo.math.linalg.NDArray;
import com.gengoai.apollo.ml.DataSet;
import com.gengoai.apollo.ml.Datum;
import com.gengoai.apollo.ml.model.LabelType;
import com.gengoai.apollo.ml.model.Params;
import com.gengoai.apollo.ml.model.SingleSourceFitParameters;
import com.gengoai.apollo.ml.model.SingleSourceModel;
import com.gengoai.apollo.ml.observation.Observation;
import com.gengoai.apollo.ml.observation.Sequence;
import com.gengoai.apollo.ml.observation.Variable;
import com.gengoai.apollo.ml.observation.VariableCollection;
import com.gengoai.apollo.ml.observation.VariableList;
import com.gengoai.apollo.ml.observation.VariableSequence;
import com.gengoai.conversion.Cast;
import com.gengoai.function.Functional;
import com.gengoai.io.Resources;
import com.gengoai.io.resource.Resource;
import com.gengoai.jcrfsuite.CrfTagger;
import com.gengoai.jcrfsuite.util.Pair;
import com.gengoai.tuple.Tuple2;
import com.gengoai.tuple.Tuples;
import java.io.File;
import java.io.IOException;
import java.io.InvalidObjectException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.Base64;
import java.util.Iterator;
import java.util.Objects;
import java.util.function.Consumer;
import lombok.NonNull;
import third_party.org.chokkan.crfsuite.Attribute;
import third_party.org.chokkan.crfsuite.Item;
import third_party.org.chokkan.crfsuite.ItemSequence;
import third_party.org.chokkan.crfsuite.StringList;
import third_party.org.chokkan.crfsuite.Trainer;

/* loaded from: input_file:com/gengoai/apollo/ml/model/sequence/Crf.class */
/**
 * Sequence-labeling model backed by a CRFSuite conditional random field.
 *
 * <p>{@link #estimate(DataSet)} does the following:
 * <ul>
 *   <li>converts every datum's input observation into a CRFSuite {@link ItemSequence};</li>
 *   <li>converts its output sequence into a {@link StringList} of labels;</li>
 *   <li>trains a CRFSuite {@code crf1d} model into a temporary file;</li>
 *   <li>loads a {@link CrfTagger} from that file.</li>
 * </ul>
 *
 * <p>Supported inputs are a {@link Sequence} whose elements are {@link Variable}s or
 * {@link VariableCollection}s, or an {@link NDArray} in which every row is one item.
 *
 * <p>{@link #transform(Observation)} tags an input and returns a {@link VariableSequence}. Each
 * variable is named by the predicted label and valued by the second component of the pair returned
 * by the tagger. NOTE(review): that value is presumably the label's marginal probability; confirm
 * against {@code CrfTagger}.
 *
 * <p>Java serialization stores the bytes of the trained model file (Base64-encoded). On
 * deserialization they are restored into a fresh temporary file.
 */
public class Crf extends SingleSourceModel<Crf.Parameters, Crf> {
    private static final long serialVersionUID = 1;
    /** Parameter definition for the value passed to CRFSuite's {@code c1} setting. */
    public static final ParameterDef<Double> C1 = ParameterDef.doubleParam("c1");
    /** Parameter definition for the value passed to CRFSuite's {@code c2} setting. */
    public static final ParameterDef<Double> C2 = ParameterDef.doubleParam("c2");
    /** Parameter definition for the value passed to CRFSuite's {@code epsilon} setting. */
    public static final ParameterDef<Double> EPS = ParameterDef.doubleParam("eps");
    /** Parameter definition for the value passed to CRFSuite's {@code feature.minfreq} setting. */
    public static final ParameterDef<Integer> MIN_FEATURE_FREQ = ParameterDef.intParam("minFeatureFreq");
    /** Parameter definition for the CRFSuite training algorithm. */
    public static final ParameterDef<CrfSolver> SOLVER = ParameterDef.param("solver", CrfSolver.class);
    /** Absolute path of the trained CRFSuite model file; {@code null} until trained or deserialized. */
    protected String modelFile;
    /** Tagger loaded from {@link #modelFile}; {@code null} until trained or deserialized. */
    protected volatile CrfTagger tagger;

    /**
     * Fit parameters for {@link Crf}. The {@code input} and {@code output} observation names are
     * inherited from {@link SingleSourceFitParameters}.
     */
    public static class Parameters extends SingleSourceFitParameters<Parameters> {
        private static final long serialVersionUID = 1;
        /** Maximum number of training iterations; sent as {@code max_iterations} (default 250). */
        public final ParamMap<Parameters>.Parameter<Integer> maxIterations = parameter(Params.Optimizable.maxIterations, 250);
        /** CRFSuite training algorithm (default {@code CrfSolver.LBFGS}). */
        public final ParamMap<Parameters>.Parameter<CrfSolver> crfSolver = parameter(Crf.SOLVER, CrfSolver.LBFGS);
        /** Sent as CRFSuite's {@code c1} (L1 regularization coefficient); default 0.0. */
        public final ParamMap<Parameters>.Parameter<Double> c1 = parameter(Crf.C1, 0.0);
        /** Sent as CRFSuite's {@code c2} (L2 regularization coefficient); default 1.0. */
        public final ParamMap<Parameters>.Parameter<Double> c2 = parameter(Crf.C2, 1.0);
        /** Sent as CRFSuite's {@code epsilon} (convergence threshold); default 1e-5. */
        public final ParamMap<Parameters>.Parameter<Double> eps = parameter(Crf.EPS, 1.0E-5);
        /** Sent as CRFSuite's {@code feature.minfreq} (minimum feature frequency); default 0. */
        public final ParamMap<Parameters>.Parameter<Integer> minFeatureFreq = parameter(Crf.MIN_FEATURE_FREQ, 0);
    }

    /** Creates a model with default {@link Parameters}. */
    public Crf() {
        super(new Parameters());
    }

    /**
     * Creates a model with the given parameters.
     *
     * @param parameters the fit parameters
     * @throws NullPointerException if {@code parameters} is null
     */
    public Crf(@NonNull Parameters parameters) {
        // Check before the superclass sees the argument, not after.
        super(Objects.requireNonNull(parameters, "parameters is marked non-null but is null"));
    }

    /**
     * Creates a model with default {@link Parameters} modified by the given updater.
     *
     * @param updater consumer applied to a fresh {@link Parameters} instance
     * @throws NullPointerException if {@code updater} is null
     */
    public Crf(@NonNull Consumer<Parameters> updater) {
        super(Functional.with(new Parameters(),
                              Objects.requireNonNull(updater, "updater is marked non-null but is null")));
    }

    /**
     * Converts one element of an input sequence into a CRFSuite item.
     *
     * <p>Each variable becomes one attribute, named by the variable name and weighted by its value.
     *
     * @param observation a {@link VariableCollection} or a single {@link Variable}
     * @return the item holding one attribute per variable
     * @throws IllegalStateException if the observation is of any other type
     */
    private Item createItem(Observation observation) {
        final Item item = new Item();
        if (observation instanceof VariableCollection) {
            // Iterate the type that was actually checked; casting to the concrete VariableList would
            // fail for any other VariableCollection implementation.
            for (Variable variable : (VariableCollection) observation) {
                item.add(new Attribute(variable.getName(), variable.getValue()));
            }
        } else if (observation instanceof Variable) {
            final Variable variable = (Variable) observation;
            item.add(new Attribute(variable.getName(), variable.getValue()));
        } else {
            throw new IllegalStateException("Unsupported type: " + observation.getClass());
        }
        return item;
    }

    /**
     * Converts one NDArray row into a CRFSuite item.
     *
     * <p>Every entry visited by {@code forEachSparse} becomes an attribute. The attribute is named by
     * the entry's index and weighted by its value.
     *
     * @param nDArray the row to convert
     * @return the item holding one attribute per visited entry
     */
    private Item createItem(NDArray nDArray) {
        final Item item = new Item();
        nDArray.forEachSparse((j, d) -> item.add(new Attribute(Long.toString(j), d)));
        return item;
    }

    /**
     * Trains the CRF on the given data set and loads a tagger for the trained model.
     *
     * <p>If training fails, the previously trained model (if any) is left untouched.
     *
     * @param dataset the training data; every datum must contain the configured input and output
     * @throws NullPointerException     if {@code dataset} is null, or a datum lacks its input/output
     * @throws IllegalArgumentException if an input or output type is unsupported, the input and label
     *                                  counts differ, or CRFSuite rejects the solver
     * @throws IllegalStateException    if CRFSuite reports a non-zero training status
     */
    @Override
    public void estimate(@NonNull DataSet dataset) {
        if (dataset == null) {
            throw new NullPointerException("dataset is marked non-null but is null");
        }
        CrfSuiteLoader.INSTANCE.load();
        final Trainer trainer = new Trainer();
        final String trainedModelFile;
        try {
            for (Datum datum : dataset) {
                final Observation input = Validation.notNull(datum.get(parameters.input.value()),
                                                             "Null Input Observation");
                // Validate the raw output before calling asSequence() on it. Otherwise a missing output
                // surfaces as a bare NullPointerException instead of the intended message.
                final Observation output = Validation.notNull(datum.get(parameters.output.value()),
                                                              "Null Output Observation");
                final Tuple2<ItemSequence, StringList> instance =
                        toItemSequence(input, Validation.notNull(output.asSequence(), "Null Output Observation"));
                trainer.append(instance.v1, instance.v2, 0);
            }

            final String algorithm = parameters.crfSolver.value().parameterSetting;
            if (!trainer.select(algorithm, "crf1d")) {
                throw new IllegalArgumentException("CRFSuite does not support the training algorithm '" + algorithm + "'");
            }
            // NOTE(review): CRFSuite defines some of these settings only for some algorithms
            // (e.g. c1 for lbfgs). Confirm that every CrfSolver accepts all of them.
            trainer.set("max_iterations", Integer.toString(parameters.maxIterations.value()));
            trainer.set("c2", Double.toString(parameters.c2.value()));
            trainer.set("c1", Double.toString(parameters.c1.value()));
            trainer.set("epsilon", Double.toString(parameters.eps.value()));
            trainer.set("feature.minfreq", Integer.toString(parameters.minFeatureFreq.value()));

            trainedModelFile = Resources.temporaryFile()
                                        .asFile()
                                        .orElseThrow(IllegalArgumentException::new)
                                        .getAbsolutePath();
            // A holdout of -1 means no held-out group; every instance was appended to group 0.
            final int status = trainer.train(trainedModelFile, -1);
            if (status != 0) {
                throw new IllegalStateException("CRFSuite training failed with status " + status);
            }
        } finally {
            // Release the appended training data even when training fails.
            trainer.clear();
        }
        // Publish the model file and tagger together, and only after training succeeded.
        this.modelFile = trainedModelFile;
        this.tagger = new CrfTagger(trainedModelFile);
    }

    /**
     * Returns this model's fit parameters.
     *
     * @return the fit parameters
     */
    @Override
    public Parameters getFitParameters() {
        return parameters;
    }

    /**
     * Returns the label type of the named output.
     *
     * @param name the output name
     * @return {@link LabelType#Sequence} if {@code name} is the configured output
     * @throws NullPointerException     if {@code name} is null
     * @throws IllegalArgumentException if {@code name} is not the configured output
     */
    @Override
    public LabelType getLabelType(@NonNull String name) {
        if (name == null) {
            throw new NullPointerException("name is marked non-null but is null");
        }
        if (parameters.output.value().equals(name)) {
            return LabelType.Sequence;
        }
        throw new IllegalArgumentException("'" + name + "' is not a valid output for this model.");
    }

    /**
     * Restores the model bytes written by {@link #writeObject(ObjectOutputStream)}.
     *
     * <p>The bytes are written to a new temporary file, and the tagger is loaded from that file.
     * NOTE(review): the bytes are handed to native CRFSuite code, so do not deserialize models from
     * untrusted sources.
     */
    private void readObject(ObjectInputStream in) throws Exception {
        CrfSuiteLoader.INSTANCE.load();
        final int length = in.readInt();
        if (length < 0) {
            throw new InvalidObjectException("Negative serialized CRF model length: " + length);
        }
        final byte[] encoded = new byte[length];
        in.readFully(encoded);
        final Resource temporaryFile = Resources.temporaryFile();
        temporaryFile.write(Base64.getDecoder().decode(encoded));
        this.modelFile = temporaryFile.asFile()
                                      .orElseThrow(IllegalArgumentException::new)
                                      .getAbsolutePath();
        this.tagger = new CrfTagger(this.modelFile);
    }

    /**
     * Converts an input observation, and optionally its label sequence, into CRFSuite structures.
     *
     * @param input  a {@link Sequence} of variables or variable collections, or an {@link NDArray}
     *               whose rows are items
     * @param output the label sequence, or {@code null} at inference time
     * @return the item sequence paired with the label names (empty when {@code output} is null)
     * @throws IllegalArgumentException if a type is unsupported or the item and label counts differ
     */
    private Tuple2<ItemSequence, StringList> toItemSequence(Observation input,
                                                            Sequence<? extends Observation> output) {
        final ItemSequence items = new ItemSequence();
        int itemCount = 0;
        if (input instanceof Sequence) {
            final Sequence<? extends Observation> sequence = Cast.as(input);
            for (Observation element : sequence) {
                items.add(createItem(element));
                itemCount++;
            }
        } else if (input instanceof NDArray) {
            final NDArray matrix = (NDArray) input;
            for (int row = 0; row < matrix.rows(); row++) {
                items.add(createItem(matrix.getRow(row)));
                itemCount++;
            }
        } else {
            throw new IllegalArgumentException("Observations of type '" + input.getClass() + "' are not supported as input");
        }

        final StringList labels = new StringList();
        if (output != null) {
            int labelCount = 0;
            for (Observation label : output) {
                if (!(label instanceof Variable)) {
                    throw new IllegalArgumentException("Observations of type '" + label.getClass() + "' are not supported as an output");
                }
                labels.add(((Variable) label).getName());
                labelCount++;
            }
            // CRFSuite rejects mismatched instances natively; report the problem here with context.
            if (labelCount != itemCount) {
                throw new IllegalArgumentException("Input has " + itemCount + " items but output has "
                                                           + labelCount + " labels");
            }
        }
        return Tuples.$(items, labels);
    }

    /**
     * Tags the given input with the trained CRF.
     *
     * @param observation the input, in any form accepted for training
     * @return one variable per item, named by the predicted label
     * @throws NullPointerException  if {@code observation} is null
     * @throws IllegalStateException if the model has not been trained or deserialized
     */
    @Override
    protected Observation transform(@NonNull Observation observation) {
        if (observation == null) {
            throw new NullPointerException("observation is marked non-null but is null");
        }
        // Read the volatile field once so the check and the use see the same tagger.
        final CrfTagger activeTagger = this.tagger;
        if (activeTagger == null) {
            throw new IllegalStateException("The CRF model has not been trained or loaded");
        }
        CrfSuiteLoader.INSTANCE.load();
        final ItemSequence itemSequence = toItemSequence(observation, null).v1;
        final VariableSequence variableSequence = new VariableSequence();
        for (Pair pair : activeTagger.tag(itemSequence)) {
            variableSequence.add(new Variable((String) pair.first, ((Double) pair.second).doubleValue()));
        }
        return variableSequence;
    }

    /**
     * Marks the configured output as a {@link VariableSequence} in the data set's metadata.
     *
     * @param data the data set whose metadata is updated
     * @throws NullPointerException if {@code data} is null
     */
    @Override
    protected void updateMetadata(@NonNull DataSet data) {
        if (data == null) {
            throw new NullPointerException("data is marked non-null but is null");
        }
        data.updateMetadata(parameters.output.value(), metadata -> metadata.setType(VariableSequence.class));
    }

    /**
     * Writes the trained model file's bytes to the stream, Base64-encoded and length-prefixed.
     *
     * @throws IllegalStateException if the model has not been trained or deserialized
     */
    private void writeObject(ObjectOutputStream out) throws IOException {
        if (modelFile == null) {
            throw new IllegalStateException("Cannot serialize a CRF model that has not been trained");
        }
        final byte[] encoded = Base64.getEncoder().encode(Resources.from(modelFile).readBytes());
        out.writeInt(encoded.length);
        out.write(encoded);
    }
}
