package com.gengoai.apollo.ml.model.sequence;

import cc.mallet.fst.CRF;
import cc.mallet.fst.CRFCacheStaleIndicator;
import cc.mallet.fst.CRFOptimizableByBatchLabelLikelihood;
import cc.mallet.fst.CRFTrainerByValueGradients;
import cc.mallet.fst.SumLatticeDefault;
import cc.mallet.fst.ThreadedOptimizable;
import cc.mallet.fst.Transducer;
import cc.mallet.optimize.LimitedMemoryBFGS;
import cc.mallet.optimize.Optimizable;
import cc.mallet.pipe.Pipe;
import cc.mallet.pipe.SerialPipes;
import cc.mallet.pipe.TokenSequence2FeatureVectorSequence;
import cc.mallet.types.Alphabet;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.Label;
import cc.mallet.types.LabelAlphabet;
import cc.mallet.types.LabelSequence;
import cc.mallet.util.MalletLogger;
import com.gengoai.ParamMap;
import com.gengoai.ParameterDef;
import com.gengoai.apollo.ml.DataSet;
import com.gengoai.apollo.ml.Datum;
import com.gengoai.apollo.ml.model.LabelType;
import com.gengoai.apollo.ml.model.Params;
import com.gengoai.apollo.ml.model.SingleSourceFitParameters;
import com.gengoai.apollo.ml.model.SingleSourceModel;
import com.gengoai.apollo.ml.observation.Observation;
import com.gengoai.apollo.ml.observation.Sequence;
import com.gengoai.apollo.ml.observation.Variable;
import com.gengoai.apollo.ml.observation.VariableSequence;
import com.gengoai.collection.Arrays2;
import com.gengoai.conversion.Cast;
import com.gengoai.function.Functional;
import java.util.Arrays;
import java.util.Iterator;
import java.util.function.Consumer;
import java.util.logging.Level;
import lombok.NonNull;

/* loaded from: input_file:com/gengoai/apollo/ml/model/sequence/MalletCrf.class */
/**
 * Sequence labeling model that wraps a Mallet {@link CRF}.
 *
 * <p>{@link #estimate(DataSet)} converts every {@link Datum} into a Mallet {@link Instance},
 * builds an order-N state topology and trains the CRF with a {@link ThreadedOptimizable}.
 * {@link #transform(Observation)} decodes an input sequence and emits one real-valued
 * {@link Variable} per position (label name plus a probability-based score).
 */
public class MalletCrf extends SingleSourceModel<MalletCrf.Parameters, MalletCrf> {
    private static final long serialVersionUID = 1L;
    /** Whether the state topology should be fully connected (passed to {@code CRF.addOrderNStates}). */
    public static final ParameterDef<Boolean> FULLY_CONNECTED = ParameterDef.boolParam("fullyConnected");
    /** Markov order of the CRF. */
    public static final ParameterDef<Order> ORDER = ParameterDef.param("order", Order.class);
    /** Name of the start state handed to {@code CRF.addOrderNStates}. */
    public static final ParameterDef<String> START_STATE = ParameterDef.strParam("startState");
    /** Value forwarded to {@link CRFOptimizableByBatchLabelLikelihood} during training. */
    public static final ParameterDef<Integer> THREADS = ParameterDef.intParam("numThreads");
    /** Mallet classes whose loggers are switched on/off according to the {@code verbose} parameter. */
    private static final Class<?>[] MALLET_LOGGED_CLASSES = {
        ThreadedOptimizable.class,
        CRFTrainerByValueGradients.class,
        CRF.class,
        CRFOptimizableByBatchLabelLikelihood.class,
        LimitedMemoryBFGS.class
    };
    /** Pipeline used to turn sequences into Mallet feature-vector sequences; set by {@link #estimate(DataSet)}. */
    private SerialPipes pipes;
    /** The trained CRF; {@code null} until {@link #estimate(DataSet)} has run. */
    private CRF model;
    /** Start state name captured at training time and reused when scoring in {@link #transform(Observation)}. */
    private String startState;

    /**
     * Fit parameters for {@link MalletCrf}.
     *
     * <p>Defaults: validator {@code SequenceValidator.ALWAYS_TRUE}, 20 threads, {@code Order.FIRST},
     * 250 max iterations, fully connected, start state {@code "O"}.
     */
    public static class Parameters extends SingleSourceFitParameters<Parameters> {
        private static final long serialVersionUID = 1L;
        /** Sequence validator; only a {@link MalletSequenceValidator} influences the CRF topology. */
        public final ParamMap<Parameters>.Parameter<SequenceValidator> validator = parameter(Params.Sequence.validator, SequenceValidator.ALWAYS_TRUE);
        /** Value passed as the third constructor argument of {@link CRFOptimizableByBatchLabelLikelihood}. */
        public final ParamMap<Parameters>.Parameter<Integer> numberOfThreads = parameter(MalletCrf.THREADS, 20);
        /** Markov order of the model. */
        public final ParamMap<Parameters>.Parameter<Order> order = parameter(MalletCrf.ORDER, Order.FIRST);
        /** Maximum number of training iterations. */
        public final ParamMap<Parameters>.Parameter<Integer> maxIterations = parameter(Params.Optimizable.maxIterations, 250);
        /** Whether the state topology is fully connected. */
        public final ParamMap<Parameters>.Parameter<Boolean> fullyConnected = parameter(MalletCrf.FULLY_CONNECTED, true);
        /** Name of the start state. */
        public final ParamMap<Parameters>.Parameter<String> startState = parameter(MalletCrf.START_STATE, "O");
    }

    /** Creates a model with default {@link Parameters}. */
    public MalletCrf() {
        super(new Parameters());
    }

    /**
     * Creates a model with the given parameters.
     *
     * @param parameters the fit parameters
     * @throws NullPointerException if {@code parameters} is null
     */
    public MalletCrf(@NonNull Parameters parameters) {
        super(parameters);
        if (parameters == null) {
            throw new NullPointerException("parameters is marked non-null but is null");
        }
    }

    /**
     * Creates a model with default parameters that are then adjusted by the given consumer.
     *
     * @param consumer callback applied to a fresh {@link Parameters} instance
     * @throws NullPointerException if {@code consumer} is null
     */
    public MalletCrf(@NonNull Consumer<Parameters> consumer) {
        super((Parameters) Functional.with(new Parameters(), consumer));
        // The super call must come first, so this check only runs after the consumer was already used.
        if (consumer == null) {
            throw new NullPointerException("updater is marked non-null but is null");
        }
    }

    /**
     * Trains the CRF on the given data set.
     *
     * <p>For every datum the configured input is used as the instance data and the variable
     * names of the configured output become the target labels. The label array is sized by
     * the input sequence, so the output sequence is read for exactly that many positions.
     *
     * @param dataSet the training data
     */
    @Override
    public void estimate(DataSet dataSet) {
        Parameters fit = (Parameters) this.parameters;
        boolean verbose = (Boolean) fit.verbose.value();
        configureMalletLogging(verbose);

        Alphabet dataAlphabet = new Alphabet();
        this.pipes = new SerialPipes(Arrays.asList(new SequenceToTokenSequence(), new TokenSequence2FeatureVectorSequence(dataAlphabet, false, true)));
        this.pipes.setDataAlphabet(dataAlphabet);
        this.pipes.setTargetAlphabet(new LabelAlphabet());

        InstanceList trainingData = new InstanceList(this.pipes);
        for (Iterator<Datum> it = dataSet.iterator(); it.hasNext(); ) {
            Datum datum = it.next();
            Sequence<? extends Observation> input = datum.get(fit.input.value()).asSequence();
            Sequence<? extends Observation> output = datum.get(fit.output.value()).asSequence();
            Label[] labels = new Label[input.size()];
            LabelAlphabet labelAlphabet = (LabelAlphabet) Cast.as(trainingData.getTargetAlphabet());
            for (int i = 0; i < labels.length; i++) {
                // "true" adds the label to the alphabet when it has not been seen before.
                labels[i] = labelAlphabet.lookupLabel(((Observation) output.get(i)).asVariable().getName(), true);
            }
            trainingData.addThruPipe(new Instance(input, new LabelSequence(labels), (Object) null, (Object) null));
        }

        this.model = new CRF(this.pipes, (Pipe) null);
        int[] orders = ordersFor((Order) fit.order.value());

        // Only a MalletSequenceValidator can contribute forbidden/allowed constraints to the topology.
        Object configuredValidator = fit.validator.value();
        MalletSequenceValidator malletValidator = configuredValidator instanceof MalletSequenceValidator
            ? (MalletSequenceValidator) configuredValidator
            : null;
        String start = (String) fit.startState.value();
        boolean fullyConnected = (Boolean) fit.fullyConnected.value();
        this.model.addOrderNStates(
            trainingData,
            orders,
            (boolean[]) null,
            start,
            malletValidator == null ? null : malletValidator.getForbidden(),
            malletValidator == null ? null : malletValidator.getAllowed(),
            fullyConnected);
        this.startState = start;
        this.model.setWeightsDimensionAsIn(trainingData, false);

        int threads = (Integer) fit.numberOfThreads.value();
        int maxIterations = (Integer) fit.maxIterations.value();
        // Declared with its concrete type: shutdown() is not part of Optimizable.ByGradientValue.
        ThreadedOptimizable optimizable = new ThreadedOptimizable(
            new CRFOptimizableByBatchLabelLikelihood(this.model, trainingData, threads),
            trainingData,
            this.model.getParameters().getNumFactors(),
            new CRFCacheStaleIndicator(this.model));
        try {
            CRFTrainerByValueGradients trainer = new CRFTrainerByValueGradients(this.model, new Optimizable.ByGradientValue[]{optimizable});
            trainer.setMaxResets(0);
            trainer.train(trainingData, maxIterations);
        } finally {
            // Always release the optimizer's workers, even when training fails.
            optimizable.shutdown();
        }
    }

    /** Sets the Mallet training loggers to INFO when verbose, otherwise turns them OFF. */
    private static void configureMalletLogging(boolean verbose) {
        Level level = verbose ? Level.INFO : Level.OFF;
        for (Class<?> source : MALLET_LOGGED_CLASSES) {
            MalletLogger.getLogger(source.getName()).setLevel(level);
        }
    }

    /** Maps the configured order to the orders array for {@code CRF.addOrderNStates}; empty when unmatched. */
    private static int[] ordersFor(Order order) {
        switch (order) {
            case FIRST:
                return Arrays2.arrayOfInt(new int[]{1});
            case SECOND:
                return Arrays2.arrayOfInt(new int[]{1, 2});
            case THIRD:
                return Arrays2.arrayOfInt(new int[]{1, 2, 3});
            default:
                return new int[0];
        }
    }

    /**
     * Returns a new {@link Parameters} instance populated with defaults.
     *
     * <p>NOTE(review): this does not return the parameters this model was constructed with;
     * confirm that handing out fresh defaults is intended.
     *
     * @return a fresh default parameter set
     */
    @Override
    public Parameters getFitParameters() {
        return new Parameters();
    }

    /**
     * Reports the label type of the model output.
     *
     * @param str the output name to query
     * @return {@link LabelType#Sequence} when {@code str} equals the configured output
     * @throws NullPointerException if {@code str} is null
     * @throws IllegalArgumentException if {@code str} is not the configured output
     */
    @Override
    public LabelType getLabelType(@NonNull String str) {
        if (str == null) {
            throw new NullPointerException("name is marked non-null but is null");
        }
        if (((String) ((Parameters) this.parameters).output.value()).equals(str)) {
            return LabelType.Sequence;
        }
        throw new IllegalArgumentException("'" + str + "' is not a valid output for this model.");
    }

    /**
     * Labels the given observation with the trained CRF.
     *
     * <p>Each position of the result carries the decoded label and a score that is the larger
     * of (a) the gamma probability of the decoded state at that position and (b) the xi
     * probability of moving from the previous decoded state (the start state for position 0)
     * into it.
     *
     * @param observation the input sequence
     * @return a {@link VariableSequence} with one scored label per position
     * @throws NullPointerException if {@code observation} is null
     */
    @Override
    protected Observation transform(@NonNull Observation observation) {
        if (observation == null) {
            throw new NullPointerException("observation is marked non-null but is null");
        }
        // Result intentionally discarded. NOTE(review): presumably kept as an early check that the
        // observation is a sequence - confirm before removing.
        observation.asSequence().size();
        cc.mallet.types.Sequence features = (cc.mallet.types.Sequence) Cast.as(this.model.getInputPipe().instanceFrom(new Instance(observation, (Object) null, (Object) null, (Object) null)).getData());
        cc.mallet.types.Sequence decoded = this.model.transduce(features);
        SumLatticeDefault lattice = new SumLatticeDefault(this.model, features, true);
        Transducer.State previous = this.model.getState(this.startState);
        VariableSequence labeled = new VariableSequence();
        for (int i = 0; i < features.size(); i++) {
            String label = (String) decoded.get(i);
            Transducer.State current = this.model.getState(label);
            double score = Math.max(lattice.getGammaProbability(i, current), lattice.getXiProbability(i, previous, current));
            previous = current;
            labeled.add(Variable.real(label, score));
        }
        return labeled;
    }

    /**
     * Updates the metadata of the configured output: clears the encoder, sets the type to
     * {@link VariableSequence} and the dimension to {@code -1}.
     *
     * @param dataSet the data set whose metadata is updated
     * @throws NullPointerException if {@code dataSet} is null
     */
    @Override
    protected void updateMetadata(@NonNull DataSet dataSet) {
        if (dataSet == null) {
            throw new NullPointerException("data is marked non-null but is null");
        }
        dataSet.updateMetadata((String) ((Parameters) this.parameters).output.value(), observationMetadata -> {
            observationMetadata.setEncoder(null);
            observationMetadata.setType(VariableSequence.class);
            observationMetadata.setDimension(-1L);
        });
    }
}
