package cc.mallet.fst;

import cc.mallet.fst.CRF;
import cc.mallet.fst.MEMM;
import cc.mallet.fst.Transducer;
import cc.mallet.optimize.LimitedMemoryBFGS;
import cc.mallet.optimize.Optimizable;
import cc.mallet.types.FeatureSequence;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.FeatureVectorSequence;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.util.MalletLogger;
import java.util.BitSet;
import java.util.logging.Level;
import java.util.logging.Logger;

/* loaded from: input_file:cc/mallet/fst/MEMMTrainer.class */
/**
 * Trainer for a {@link MEMM}.
 *
 * <p>Training first collects, for every state of the MEMM, the weighted
 * (input, output) pairs of the transitions leaving that state (see
 * {@link #gatherTrainingSets(InstanceList)}), then maximizes the objective
 * exposed by {@link MEMMOptimizableByLabelLikelihood} with
 * {@link LimitedMemoryBFGS}, one optimizer step at a time so that evaluators
 * can run between steps.
 */
public class MEMMTrainer extends TransducerTrainer {
    private static final Logger logger = MalletLogger.getLogger(MEMMTrainer.class.getName());

    /** The model being trained. */
    MEMM memm;
    private boolean gatheringTrainingData = false;
    /** The instance list whose per-state training sets have been gathered, or null if none yet. */
    private InstanceList trainingGatheredFor;
    /** The optimizable created by the most recent call to {@link #train(InstanceList, int)}. */
    MEMMOptimizableByLabelLikelihood omemm;
    /** Number of optimizer steps completed so far by this trainer. */
    private int iterationCount = 0;
    /** Whether the most recent training run reported convergence. */
    private boolean converged = false;

    /**
     * Objective for MEMM training, computed from the per-state training sets
     * stored on each {@link MEMM.State} rather than from whole sequences.
     */
    public class MEMMOptimizableByLabelLikelihood extends CRFOptimizableByLabelLikelihood implements Optimizable.ByGradientValue {
        /** Lazily created on the first call to {@link #gatherExpectationsOrConstraints(boolean)}. */
        BitSet infiniteValues;

        /**
         * Creates the optimizable and replaces the inherited expectation and
         * constraint factors with fresh ones sized for {@code memm}.
         *
         * @param memm the model whose parameters are optimized
         * @param instanceList the training data handed to the superclass
         */
        protected MEMMOptimizableByLabelLikelihood(MEMM memm, InstanceList instanceList) {
            super(memm, instanceList);
            this.infiniteValues = null;
            this.expectations = new CRF.Factors(memm);
            this.constraints = new CRF.Factors(memm);
        }

        /**
         * Walks every state's training set and accumulates transition counts
         * into either the constraints or the expectations.
         *
         * @param z {@code true} to gather constraints (transitions restricted
         *          to each instance's target output), {@code false} to gather
         *          expectations (all transitions)
         * @return when gathering expectations, the instance-weighted sum of the
         *         weights of the transitions whose output equals the target;
         *         {@code 0} when gathering constraints
         * @throws IllegalStateException if a transition matching the target
         *         has an infinite weight
         */
        protected double gatherExpectationsOrConstraints(boolean z) {
            final boolean gatherConstraints = z;
            CRF.Factors factors = gatherConstraints ? this.constraints : this.expectations;
            // The incrementor must be bound to the factors it updates.
            CRF.Factors.Incrementor incrementor = factors.new Incrementor();

            boolean initializingInfiniteValues = false;
            if (this.infiniteValues == null) {
                this.infiniteValues = new BitSet();
                initializingInfiniteValues = true;
            }

            double value = 0.0d;
            final int numStates = MEMMTrainer.this.memm.numStates();
            for (int stateIndex = 0; stateIndex < numStates; stateIndex++) {
                MEMM.State state = (MEMM.State) MEMMTrainer.this.memm.getState(stateIndex);
                if (state.trainingSet == null) {
                    logger.info("Empty training set for state " + state.name);
                    continue;
                }
                for (int instIndex = 0; instIndex < state.trainingSet.size(); instIndex++) {
                    Instance instance = state.trainingSet.get(instIndex);
                    double instanceWeight = state.trainingSet.getInstanceWeight(instIndex);
                    FeatureVector input = (FeatureVector) instance.getData();
                    String target = (String) instance.getTarget();
                    MEMM.TransitionIterator iter = new MEMM.TransitionIterator(
                            state, input, gatherConstraints ? target : null, MEMMTrainer.this.memm);
                    while (iter.hasNext()) {
                        iter.nextState();
                        double weight = iter.getWeight();
                        incrementor.incrementTransition(iter, Math.exp(weight) * instanceWeight);
                        if (gatherConstraints) {
                            continue;
                        }
                        // Compare label contents, not references; both null counts as a match.
                        Object output = iter.getOutput();
                        boolean matchesTarget = (target == null) ? (output == null) : target.equals(output);
                        if (!matchesTarget) {
                            continue;
                        }
                        if (Double.isInfinite(weight)) {
                            logger.warning("State " + stateIndex + " transition " + instIndex + " has infinite cost; skipping.");
                            if (initializingInfiniteValues) {
                                throw new IllegalStateException("Infinite-cost transitions not yet supported");
                            }
                            if (!this.infiniteValues.get(instIndex)) {
                                throw new IllegalStateException("Instance " + instIndex + " of state " + stateIndex
                                        + " used to have non-infinite value, but now it has infinite value.");
                            }
                        } else {
                            value += instanceWeight * weight;
                        }
                    }
                }
            }

            // Initial and final weights are always reset to zero after gathering.
            for (int stateIndex = 0; stateIndex < MEMMTrainer.this.memm.numStates(); stateIndex++) {
                factors.initialWeights[stateIndex] = 0.0d;
                factors.finalWeights[stateIndex] = 0.0d;
            }
            return value;
        }

        /**
         * Gathers expectations over the per-state training sets.
         *
         * @return the value produced by {@code gatherExpectationsOrConstraints(false)}
         */
        @Override // cc.mallet.fst.CRFOptimizableByLabelLikelihood
        protected double getExpectationValue() {
            return gatherExpectationsOrConstraints(false);
        }
    }

    /**
     * Creates a trainer for the given model.
     *
     * @param memm the model to train
     */
    public MEMMTrainer(MEMM memm) {
        this.memm = memm;
    }

    /**
     * Builds a new optimizable for this trainer's model.
     *
     * @param instanceList the training data passed to the optimizable
     * @return a freshly constructed optimizable
     */
    public MEMMOptimizableByLabelLikelihood getOptimizableMEMM(InstanceList instanceList) {
        return new MEMMOptimizableByLabelLikelihood(this.memm, instanceList);
    }

    /**
     * Trains with no practical limit on the number of iterations.
     *
     * @param instanceList the training data
     * @return whether training converged
     */
    @Override // cc.mallet.fst.TransducerTrainer
    public boolean train(InstanceList instanceList) {
        return train(instanceList, Integer.MAX_VALUE);
    }

    /**
     * Trains the model for at most {@code i} optimizer steps, running the
     * registered evaluators after each successful step.
     *
     * @param instanceList the training data
     * @param i the maximum number of optimizer steps; non-positive values
     *          return {@code false} immediately without training
     * @return {@code true} if the optimizer reported convergence, or if an
     *         {@link IllegalArgumentException} was thrown during a step (which
     *         is treated as convergence); {@code false} otherwise
     * @throws UnsupportedOperationException if training sets were already
     *         gathered for a different instance list
     */
    @Override // cc.mallet.fst.TransducerTrainer
    public boolean train(InstanceList instanceList, int i) {
        if (i <= 0) {
            return false;
        }
        assert instanceList.size() > 0;
        if (this.trainingGatheredFor != instanceList) {
            gatherTrainingSets(instanceList);
        }
        this.omemm = new MEMMOptimizableByLabelLikelihood(this.memm, instanceList);
        this.omemm.gatherExpectationsOrConstraints(true);
        LimitedMemoryBFGS optimizer = new LimitedMemoryBFGS(this.omemm);
        this.converged = false;
        logger.info("CRF about to train with " + i + " iterations");
        for (int iter = 0; iter < i; iter++) {
            try {
                this.converged = optimizer.optimize(1);
                // Count the step before evaluators run so they observe the updated iteration.
                this.iterationCount++;
                logger.info("CRF finished one iteration of maximizer, i=" + iter);
                runEvaluators();
            } catch (IllegalArgumentException e) {
                logger.log(Level.WARNING, "Exception during optimization step " + iter, e);
                logger.info("Catching exception; saying converged.");
                this.converged = true;
            }
            if (this.converged) {
                logger.info("CRF training has converged, i=" + iter);
                break;
            }
        }
        logger.info("About to setTrainable(false)");
        return this.converged;
    }

    /**
     * Populates each state's {@code trainingSet} from the given sequences.
     *
     * <p>For every instance a {@link SumLatticeDefault} is constructed with a
     * callback that, for each transition reported with a non-zero count, adds
     * an (input, output) instance weighted by that count to the training set
     * of the transition's source state. Initial- and final-state callbacks are
     * ignored.
     *
     * @param instanceList sequences whose data are {@link FeatureVectorSequence}s
     *        and whose targets are {@link FeatureSequence}s
     * @throws UnsupportedOperationException if training sets were already
     *         gathered by this trainer
     */
    void gatherTrainingSets(InstanceList instanceList) {
        if (this.trainingGatheredFor != null) {
            throw new UnsupportedOperationException("Training with multiple sets not supported.");
        }
        this.trainingGatheredFor = instanceList;
        for (int index = 0; index < instanceList.size(); index++) {
            Instance instance = instanceList.get(index);
            FeatureVectorSequence input = (FeatureVectorSequence) instance.getData();
            FeatureSequence output = (FeatureSequence) instance.getTarget();
            // Constructed only for the side effects delivered through the incrementor.
            new SumLatticeDefault(this.memm, input, output, new Transducer.Incrementor() { // from class: cc.mallet.fst.MEMMTrainer.1
                @Override // cc.mallet.fst.Transducer.Incrementor
                public void incrementFinalState(Transducer.State state, double d) {
                    // Final-state counts are not needed for the per-state training sets.
                }

                @Override // cc.mallet.fst.Transducer.Incrementor
                public void incrementInitialState(Transducer.State state, double d) {
                    // Initial-state counts are not needed for the per-state training sets.
                }

                @Override // cc.mallet.fst.Transducer.Incrementor
                public void incrementTransition(Transducer.TransitionIterator transitionIterator, double d) {
                    if (d == 0.0d) {
                        return;
                    }
                    MEMM.State source = (MEMM.State) transitionIterator.getSourceState();
                    if (source.trainingSet == null) {
                        source.trainingSet = new InstanceList(null);
                    }
                    source.trainingSet.add(new Instance(transitionIterator.getInput(), transitionIterator.getOutput(), null, null), d);
                }
            });
        }
    }

    /**
     * Not supported by this trainer.
     *
     * @throws UnsupportedOperationException always
     */
    public boolean train(InstanceList instanceList, InstanceList instanceList2, InstanceList instanceList3, TransducerEvaluator transducerEvaluator, int i, int i2, double[] dArr) {
        throw new UnsupportedOperationException("train with validation/testing sets is not supported by MEMMTrainer");
    }

    /**
     * Not supported by this trainer.
     *
     * @throws UnsupportedOperationException always
     */
    public boolean trainWithFeatureInduction(InstanceList instanceList, InstanceList instanceList2, InstanceList instanceList3, TransducerEvaluator transducerEvaluator, int i, int i2, int i3, int i4, double d, boolean z, double[] dArr, String str) {
        throw new UnsupportedOperationException("trainWithFeatureInduction is not supported by MEMMTrainer");
    }

    /**
     * Prints every state's gathered training set to standard output: the
     * state's index and name, then for each instance its target and data, or
     * "No data" when the state has no training set.
     */
    public void printInstanceLists() {
        for (int stateIndex = 0; stateIndex < this.memm.numStates(); stateIndex++) {
            MEMM.State state = (MEMM.State) this.memm.getState(stateIndex);
            InstanceList training = state.trainingSet;
            System.out.println("State " + stateIndex + " : " + state.getName());
            if (training == null) {
                System.out.println("No data");
                continue;
            }
            for (int instIndex = 0; instIndex < training.size(); instIndex++) {
                Instance instance = training.get(instIndex);
                System.out.println("From : " + state.getName() + " To : " + instance.getTarget());
                System.out.println("Instance " + instIndex);
                System.out.println(instance.getTarget());
                System.out.println(instance.getData());
            }
        }
    }

    /**
     * Returns the number of optimizer steps completed by this trainer across
     * all calls to {@link #train(InstanceList, int)}.
     */
    @Override // cc.mallet.fst.TransducerTrainer
    public int getIteration() {
        return this.iterationCount;
    }

    /** Returns the model being trained. */
    @Override // cc.mallet.fst.TransducerTrainer
    public Transducer getTransducer() {
        return this.memm;
    }

    /**
     * Returns whether the most recent training run ended with convergence
     * reported; {@code false} before any training has run.
     */
    @Override // cc.mallet.fst.TransducerTrainer
    public boolean isFinishedTraining() {
        return this.converged;
    }
}
