package com.github.chen0040.mlp.ann.classifiers;

import com.github.chen0040.data.frame.DataRow;
import com.github.chen0040.mlp.ann.MLP;
import java.util.List;
import java.util.Objects;
import java.util.function.Supplier;

/* loaded from: input_file:com/github/chen0040/mlp/ann/classifiers/MLPWithLabelOutput.class */
/**
 * {@link MLP} variant whose training target is a categorical label, encoded as a one-hot vector
 * over the labels provided by {@link #classLabelsModel}.
 */
public class MLPWithLabelOutput extends MLP {

    /**
     * Supplies the ordered list of class labels; position {@code i} in the list corresponds to
     * output {@code i} of the target vector built by {@link #getTarget(DataRow)}. Must be set
     * before {@code getTarget} is called.
     */
    public Supplier<List<String>> classLabelsModel;

    /**
     * Reports whether {@code dataRow} can be used as a training sample for this classifier.
     *
     * @param dataRow the candidate row
     * @return {@code true} if the row declares at least one categorical target column name
     */
    @Override
    public boolean isValidTrainingSample(DataRow dataRow) {
        return !dataRow.getCategoricalTargetColumnNames().isEmpty();
    }

    /**
     * Builds the target vector for {@code dataRow}: entry {@code i} is {@code 1.0} when the
     * {@code i}-th class label equals the row's categorical target, otherwise {@code 0.0}.
     * If the target matches no label, the result is all zeros.
     *
     * @param dataRow the row whose categorical target is encoded
     * @return a new array whose length equals the number of supplied class labels
     * @throws NullPointerException if {@link #classLabelsModel} is unset, supplies a null list,
     *     or the list contains a null label
     */
    @Override
    public double[] getTarget(DataRow dataRow) {
        final List<String> labels = Objects.requireNonNull(
                Objects.requireNonNull(classLabelsModel, "classLabelsModel has not been set").get(),
                "classLabelsModel supplied a null label list");

        // The row's target does not change while scanning labels; read it once.
        final String target = dataRow.categoricalTarget();

        final double[] oneHot = new double[labels.size()];
        int index = 0;
        // Iterate rather than index with get(i), which is O(n) per call on linked lists.
        for (String label : labels) {
            oneHot[index++] = label.equals(target) ? 1.0d : 0.0d;
        }
        return oneHot;
    }
}
