package com.gengoai.apollo.ml.model;

import cc.mallet.classify.Classifier;
import cc.mallet.classify.ClassifierTrainer;
import cc.mallet.pipe.SerialPipes;
import cc.mallet.pipe.Target2Label;
import cc.mallet.types.Alphabet;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.Labeling;
import com.gengoai.Validation;
import com.gengoai.apollo.math.linalg.NDArrayFactory;
import com.gengoai.apollo.ml.DataSet;
import com.gengoai.apollo.ml.Datum;
import com.gengoai.apollo.ml.encoder.MalletEncoder;
import com.gengoai.apollo.ml.model.SingleSourceFitParameters;
import com.gengoai.apollo.ml.observation.Classification;
import com.gengoai.apollo.ml.observation.Observation;
import java.util.Arrays;
import java.util.Iterator;
import lombok.NonNull;

/* loaded from: input_file:com/gengoai/apollo/ml/model/MalletClassifier.class */
/**
 * Base class for single-source models that delegate classification to a Mallet {@link Classifier}.
 * <p>
 * Subclasses supply the concrete Mallet {@link ClassifierTrainer} through {@link #getTrainer()}. During
 * {@link #estimate(DataSet)} each {@link Datum} is turned into a Mallet {@link Instance} whose data is the input
 * observation and whose target is the name of the first variable of the output observation; the instances are
 * passed through a pipe of {@link Target2Label} followed by {@code VectorToTokensPipe}.
 * </p>
 *
 * @param <T> the type of fit parameters used by this model
 */
public abstract class MalletClassifier<T extends SingleSourceFitParameters<T>> extends SingleSourceModel<T, MalletClassifier<T>> {
    private static final long serialVersionUID = 1;
    /** Feature alphabet created fresh on every call to {@link #estimate(DataSet)}. */
    protected Alphabet featureAlphabet;
    /** Target (label) alphabet of the trained classifier's pipe; {@code null} until trained. */
    protected Alphabet labelAlphabet;
    /** The trained Mallet classifier; {@code null} until {@link #estimate(DataSet)} has run. */
    protected Classifier model;

    /**
     * Creates the model with the given fit parameters.
     *
     * @param parameters the fit parameters (must not be null)
     * @throws NullPointerException if {@code parameters} is null
     */
    public MalletClassifier(@NonNull T parameters) {
        super(parameters);
        if (parameters == null) {
            throw new NullPointerException("parameters is marked non-null but is null");
        }
    }

    /**
     * Trains the underlying Mallet classifier on the given data set.
     * <p>
     * For every datum, the observation named by the {@code input} parameter becomes the instance data and the
     * name of the first variable of the observation named by the {@code output} parameter becomes the label.
     * </p>
     *
     * @param dataset the training data (must not be null)
     * @throws NullPointerException           if {@code dataset} is null
     * @throws java.util.NoSuchElementException if an output observation has no variables
     */
    @Override
    public void estimate(@NonNull DataSet dataset) {
        if (dataset == null) {
            throw new NullPointerException("dataset is marked non-null but is null");
        }
        this.featureAlphabet = new Alphabet();
        // Target2Label converts each instance's target into a Mallet Label; VectorToTokensPipe handles the
        // input observation using the freshly created feature alphabet.
        SerialPipes pipes = new SerialPipes(Arrays.asList(new Target2Label(), new VectorToTokensPipe(this.featureAlphabet)));
        pipes.setDataAlphabet(this.featureAlphabet);
        InstanceList instances = new InstanceList(pipes);

        Iterator<Datum> it = dataset.iterator();
        while (it.hasNext()) {
            Datum datum = it.next();
            Observation input = datum.get(this.parameters.input.value());
            Observation output = datum.get(this.parameters.output.value());
            Validation.notNull(input, "Null Input Observation");
            Validation.notNull(output, "Null Output Observation");
            // The label of this example is the name of the first variable in the output's variable space.
            String label = output.getVariableSpace()
                                 .findFirst()
                                 .map(v -> v.getName())
                                 .orElseThrow(() -> new java.util.NoSuchElementException(
                                         "Output observation has no variables to use as a label"));
            instances.addThruPipe(new Instance(input, label, null, null));
        }

        this.model = getTrainer().train(instances);
        this.labelAlphabet = this.model.getInstancePipe().getTargetAlphabet();
    }

    /**
     * Returns the fit parameters this model was created with.
     *
     * @return the fit parameters
     */
    @Override
    public T getFitParameters() {
        return (T) this.parameters;
    }

    /**
     * Returns the label type of the given output.
     * <p>
     * Only the output named by the {@code output} parameter is valid; it is a classification over the labels of
     * the trained model.
     * </p>
     *
     * @param name the output name (must not be null)
     * @return a classification label type sized to the trained label alphabet
     * @throws NullPointerException     if {@code name} is null
     * @throws IllegalArgumentException if {@code name} is not this model's output
     * @throws IllegalStateException    if the model has not been trained yet
     */
    @Override
    public LabelType getLabelType(@NonNull String name) {
        if (name == null) {
            throw new NullPointerException("name is marked non-null but is null");
        }
        if (name.equals(this.parameters.output.value())) {
            checkTrained();
            return LabelType.classificationType(this.labelAlphabet.size());
        }
        throw new IllegalArgumentException("'" + name + "' is not a valid output for this model.");
    }

    /**
     * Supplies the Mallet trainer used by {@link #estimate(DataSet)} to build {@link #model}.
     *
     * @return the classifier trainer
     */
    protected abstract ClassifierTrainer<?> getTrainer();

    /**
     * Classifies a single observation.
     *
     * @param observation the input observation (must not be null)
     * @return a {@link Classification} holding one score per label of the trained label alphabet, in alphabet order
     * @throws NullPointerException  if {@code observation} is null
     * @throws IllegalStateException if the model has not been trained yet
     */
    @Override
    protected Observation transform(@NonNull Observation observation) {
        if (observation == null) {
            throw new NullPointerException("observation is marked non-null but is null");
        }
        checkTrained();
        // Use a null target, as Mallet's own Classifier.classify(Object) does. Target2Label leaves null targets
        // alone, whereas a non-null placeholder (previously "") is looked up in the label alphabet and, if that
        // alphabet is still growing, added to it as a spurious label.
        Instance instance = this.model.getInstancePipe().instanceFrom(new Instance(observation, null, null, null));
        Labeling labeling = this.model.classify(instance).getLabeling();
        int numLabels = this.labelAlphabet.size();
        double[] scores = new double[numLabels];
        for (int i = 0; i < numLabels; i++) {
            scores[i] = labeling.value(i);
        }
        return new Classification(NDArrayFactory.ND.rowVector(scores), new MalletEncoder(this.labelAlphabet));
    }

    /**
     * Updates the data set's metadata for this model's output: the output is marked as a
     * {@link Classification} with dimension {@code -1} and no encoder.
     *
     * @param data the data set whose metadata is updated (must not be null)
     * @throws NullPointerException if {@code data} is null
     */
    @Override
    protected void updateMetadata(@NonNull DataSet data) {
        if (data == null) {
            throw new NullPointerException("data is marked non-null but is null");
        }
        dataSet_updateOutputMetadata(data);
    }

    /** Applies the output-metadata settings used by {@link #updateMetadata(DataSet)}. */
    private void dataSet_updateOutputMetadata(DataSet data) {
        data.updateMetadata((String) this.parameters.output.value(), observationMetadata -> {
            // NOTE(review): -1 presumably means "dimension unknown until training" -- confirm with DataSet docs.
            observationMetadata.setDimension(-1L);
            observationMetadata.setEncoder(null);
            observationMetadata.setType(Classification.class);
        });
    }

    /** Fails fast with a clear message when the model is used before {@link #estimate(DataSet)} has run. */
    private void checkTrained() {
        if (this.model == null || this.labelAlphabet == null) {
            throw new IllegalStateException("Model has not been trained; call estimate(DataSet) first.");
        }
    }
}
