package com.gengoai.apollo.ml.evaluation;

import com.gengoai.Validation;
import com.gengoai.apollo.math.linalg.NDArray;
import com.gengoai.apollo.ml.DataSet;
import com.gengoai.apollo.ml.model.Model;
import com.gengoai.apollo.ml.observation.Observation;
import com.gengoai.apollo.ml.observation.Variable;
import com.gengoai.apollo.ml.observation.VariableSequence;
import com.gengoai.conversion.Cast;
import java.io.PrintStream;
import java.io.Serializable;
import lombok.NonNull;

/* loaded from: input_file:com/gengoai/apollo/ml/evaluation/PerInstanceEvaluation.class */
/**
 * Sequence-labeler evaluation that scores every position of a sequence independently.
 *
 * <p>For each datum, the gold sequence and the model's predicted sequence, both read from the
 * observation named {@code outputName}, are aligned position by position. Each (gold, predicted)
 * label-name pair is fed into a {@link MultiClassEvaluation}. Observations may be
 * {@link VariableSequence}s or NDArrays; NDArrays are converted row by row (see
 * {@link #toVariableSequence(Observation)}).
 */
public class PerInstanceEvaluation implements SequenceLabelerEvaluation, Serializable {
    private static final long serialVersionUID = 1;
    /** Accumulates the per-position (gold, predicted) label pairs. */
    private final MultiClassEvaluation eval;
    /** Name of the observation holding the gold labels in a datum and the predictions in model output. */
    private final String outputName;

    /**
     * Creates a per-instance evaluation for the given output observation.
     *
     * @param outputName the name of the observation to compare between gold data and model output
     */
    public PerInstanceEvaluation(String outputName) {
        this.outputName = outputName;
        this.eval = new MultiClassEvaluation(outputName);
    }

    /**
     * Runs the model over every datum in the data set and records one multi-class entry per
     * sequence position.
     *
     * <p>Positions are taken from the gold sequence. Extra predicted positions beyond the gold
     * length are ignored.
     *
     * @param model   the model used to produce predictions
     * @param dataSet the data set containing gold sequences
     * @throws NullPointerException  if {@code model} or {@code dataSet} is null, or a datum or the
     *                               model output lacks the {@code outputName} observation
     * @throws IllegalStateException if a predicted sequence is shorter than its gold sequence
     */
    @Override // com.gengoai.apollo.ml.evaluation.Evaluation
    public void evaluate(@NonNull Model model, @NonNull DataSet dataSet) {
        if (model == null) {
            throw new NullPointerException("model is marked non-null but is null");
        }
        if (dataSet == null) {
            throw new NullPointerException("dataset is marked non-null but is null");
        }
        dataSet.forEach(datum -> {
            // Gold is read before calling transform.
            // NOTE(review): presumably transform may write the prediction into the same datum;
            // confirm before reordering these two lines.
            VariableSequence gold = toVariableSequence(datum.get(this.outputName));
            VariableSequence predicted = toVariableSequence(model.transform(datum).get(this.outputName));
            if (predicted.size() < gold.size()) {
                throw new IllegalStateException("Predicted sequence for '" + this.outputName + "' has "
                        + predicted.size() + " elements but the gold sequence has " + gold.size());
            }
            for (int i = 0; i < gold.size(); i++) {
                this.eval.entry(gold.get(i).getName(), predicted.get(i).getName());
            }
        });
    }

    /**
     * Merges the counts of another {@code PerInstanceEvaluation} into this one.
     *
     * @param sequenceLabelerEvaluation the evaluation to merge; must be a {@code PerInstanceEvaluation}
     * @throws NullPointerException if the argument is null
     */
    @Override // com.gengoai.apollo.ml.evaluation.SequenceLabelerEvaluation
    public void merge(@NonNull SequenceLabelerEvaluation sequenceLabelerEvaluation) {
        if (sequenceLabelerEvaluation == null) {
            throw new NullPointerException("evaluation is marked non-null but is null");
        }
        Validation.checkArgument(sequenceLabelerEvaluation instanceof PerInstanceEvaluation);
        this.eval.merge(((PerInstanceEvaluation) Cast.as(sequenceLabelerEvaluation)).eval);
    }

    /**
     * Writes the underlying multi-class evaluation report to the given stream.
     *
     * @param printStream the stream to write to
     * @param z           passed through unchanged to {@code MultiClassEvaluation.output}
     * @throws NullPointerException if {@code printStream} is null
     */
    @Override // com.gengoai.apollo.ml.evaluation.SequenceLabelerEvaluation
    public void output(@NonNull PrintStream printStream, boolean z) {
        if (printStream == null) {
            throw new NullPointerException("printStream is marked non-null but is null");
        }
        this.eval.output(printStream, z);
    }

    /**
     * Converts an observation into a {@link VariableSequence}.
     *
     * <p>A {@code VariableSequence} is returned as-is. An NDArray is converted one row per sequence
     * element. A single-column array uses the row's value as the label name. A multi-column array
     * uses the row's argmax index as the label name. In both cases the name is rendered via
     * {@link Double#toString(double)}, for example {@code "2.0"}.
     *
     * @param observation the observation to convert
     * @return the equivalent variable sequence
     * @throws NullPointerException     if {@code observation} is null (the named output is missing)
     * @throws IllegalArgumentException if the observation is neither a sequence nor an NDArray
     */
    private VariableSequence toVariableSequence(Observation observation) {
        if (observation == null) {
            throw new NullPointerException("No observation named '" + this.outputName + "' was found");
        }
        if (observation instanceof VariableSequence) {
            return (VariableSequence) Cast.as(observation);
        }
        if (!observation.isNDArray()) {
            throw new IllegalArgumentException(observation.getClass() + " is not supported");
        }
        VariableSequence sequence = new VariableSequence();
        NDArray matrix = observation.asNDArray();
        int rowCount = (int) matrix.rows();
        // The layout is fixed for the whole array, so decide once which label encoding to use.
        boolean singleColumn = matrix.columns() == 1;
        for (int i = 0; i < rowCount; i++) {
            NDArray row = matrix.getRow(i);
            double label = singleColumn ? row.get(0L) : row.argmax();
            sequence.add(Variable.binary(Double.toString(label)));
        }
        return sequence;
    }
}
