package ai.libs.jaicore.ml.classification.multilabel;

import ai.libs.jaicore.ml.core.evaluation.Prediction;
import java.util.stream.IntStream;
import org.api4.java.ai.ml.classification.multilabel.evaluation.IMultiLabelClassification;

/* loaded from: input_file:ai/libs/jaicore/ml/classification/multilabel/MultiLabelClassification.class */
/**
 * Prediction for a multi-label problem: one numeric prediction value per label, together with
 * one threshold per label that is used to turn the values into 0/1 decisions.
 *
 * <p>The prediction values are handed to the {@link Prediction} superclass constructor and read
 * back through {@link #getPrediction()}. The threshold array is stored by reference (it is not
 * copied), so later modifications of that array by the caller are visible to this object.
 */
public class MultiLabelClassification extends Prediction implements IMultiLabelClassification {
    /** Threshold applied to every label when the caller does not supply one. */
    private static final double DEFAULT_THRESHOLD = 0.5d;

    /** Per-label thresholds used by {@link #getThresholdedPrediction()}; index i belongs to label i. */
    private final double[] threshold;

    /**
     * Creates a classification that uses the default threshold of 0.5 for every label.
     *
     * @param dArr the per-label prediction values
     * @throws NullPointerException if {@code dArr} is {@code null}
     */
    public MultiLabelClassification(double[] dArr) {
        this(dArr, DEFAULT_THRESHOLD);
    }

    /**
     * Creates a classification that uses the same threshold for every label.
     *
     * @param dArr the per-label prediction values
     * @param d the threshold applied to each label
     * @throws NullPointerException if {@code dArr} is {@code null}
     */
    public MultiLabelClassification(double[] dArr, double d) {
        this(dArr, uniformThresholds(dArr.length, d));
    }

    /**
     * Creates a classification with an individual threshold per label.
     *
     * <p>No length check is performed here; a threshold array shorter than the prediction array
     * only surfaces as an {@link ArrayIndexOutOfBoundsException} when
     * {@link #getThresholdedPrediction()} is called.
     *
     * @param dArr the per-label prediction values
     * @param dArr2 the per-label thresholds, stored by reference
     */
    public MultiLabelClassification(double[] dArr, double[] dArr2) {
        super(dArr);
        this.threshold = dArr2;
    }

    /**
     * Returns the raw per-label prediction values held by the superclass.
     *
     * @return the value of {@code super.getPrediction()} cast to {@code double[]}
     */
    @Override // ai.libs.jaicore.ml.core.evaluation.Prediction
    public double[] getPrediction() {
        return (double[]) super.getPrediction();
    }

    /**
     * Binarizes the prediction with the thresholds given at construction time.
     *
     * @return an array with 1 where the prediction value is {@code >=} its threshold and 0 otherwise
     */
    public int[] getThresholdedPrediction() {
        return binarize(getPrediction(), this.threshold);
    }

    /**
     * Binarizes the prediction with a single threshold for all labels.
     *
     * @param d the threshold applied to each label
     * @return an array with 1 where the prediction value is {@code >= d} and 0 otherwise
     */
    public int[] getPrediction(double d) {
        // Read the prediction array once instead of once per element.
        final double[] values = getPrediction();
        final int[] decisions = new int[values.length];
        for (int label = 0; label < values.length; label++) {
            decisions[label] = values[label] >= d ? 1 : 0;
        }
        return decisions;
    }

    /**
     * Binarizes the prediction with the given per-label thresholds.
     *
     * @param dArr the per-label thresholds; index i is compared against prediction value i
     * @return an array with 1 where the prediction value is {@code >=} its threshold and 0 otherwise
     * @throws ArrayIndexOutOfBoundsException if {@code dArr} is shorter than the prediction array
     */
    public int[] getPrediction(double[] dArr) {
        return binarize(getPrediction(), dArr);
    }

    /**
     * Returns the indices of all labels whose prediction value reaches the threshold.
     *
     * @param d the threshold
     * @return the ascending label indices i with prediction value {@code >= d}
     */
    public int[] getRelevantLabels(double d) {
        final double[] values = getPrediction();
        return IntStream.range(0, values.length).filter(label -> values[label] >= d).toArray();
    }

    /**
     * Returns the indices of all labels whose prediction value stays below the threshold.
     *
     * @param d the threshold
     * @return the ascending label indices i with prediction value {@code < d}
     */
    public int[] getIrrelevantLabels(double d) {
        final double[] values = getPrediction();
        return IntStream.range(0, values.length).filter(label -> values[label] < d).toArray();
    }

    /** Builds an array of the given length in which every entry equals {@code value}. */
    private static double[] uniformThresholds(int length, double value) {
        final double[] thresholds = new double[length];
        for (int label = 0; label < length; label++) {
            thresholds[label] = value;
        }
        return thresholds;
    }

    /** Maps each value to 1 if it is {@code >=} the cutoff at the same index, otherwise to 0. */
    private static int[] binarize(double[] values, double[] cutoffs) {
        final int[] decisions = new int[values.length];
        for (int label = 0; label < values.length; label++) {
            // cutoffs is only dereferenced per element, so an empty prediction never touches it.
            decisions[label] = values[label] >= cutoffs[label] ? 1 : 0;
        }
        return decisions;
    }
}
