package com.gengoai.apollo.ml.transform;

import com.gengoai.Validation;
import com.gengoai.apollo.ml.DataSet;
import com.gengoai.apollo.ml.observation.Variable;
import com.gengoai.stream.Streams;
import java.util.Arrays;
import java.util.DoubleSummaryStatistics;
import java.util.HashMap;
import java.util.Map;
import java.util.stream.Collectors;
import lombok.NonNull;

/**
 * Discretizes numeric {@link Variable}s into equal-width bins, fitting a separate set of bin edges
 * for each prefix.
 *
 * <p>When fitting a prefix, the minimum and maximum observed values define a range. That range is
 * split into {@code numberOfBins} intervals of equal width. When transforming, each variable is
 * replaced by a binary variable with the same prefix. Its suffix names the bin the value falls
 * into: {@code Bin[k]}, or {@code <originalSuffix>-Bin[k]} when {@code includeSuffix} is enabled.
 *
 * <p>Values below the first edge land in bin {@code 0}. The last bin is open-ended, so values at or
 * above the second-to-last edge (including values beyond the fitted maximum) land in bin
 * {@code numberOfBins - 1}.
 */
public class BinTransform extends PerPrefixTransform<BinTransform> {
    private static final long serialVersionUID = 1L;

    /** Number of bins each prefix's value range is divided into; validated to be > 0. */
    private final int numberOfBins;

    /**
     * Upper bin edges keyed by prefix. {@code edges[k]} is the exclusive upper bound of bin
     * {@code k}.
     */
    private final Map<String, double[]> prefixBinMap = new HashMap<>();

    /** When true, the variable's original suffix is kept in front of the bin name. */
    private final boolean includeSuffix;

    /**
     * Creates a bin transform.
     *
     * @param numberOfBins  the number of equal-width bins per prefix; rejected via
     *                      {@code Validation.checkArgument} when not greater than zero
     * @param includeSuffix whether to retain the variable's original suffix in the binned name
     */
    public BinTransform(int numberOfBins, boolean includeSuffix) {
        Validation.checkArgument(numberOfBins > 0, "Number of bins must be > 0.");
        this.numberOfBins = numberOfBins;
        this.includeSuffix = includeSuffix;
    }

    /**
     * Computes and stores equal-width bin edges for {@code prefix} from the values of
     * {@code variables}.
     *
     * <p>If {@code variables} is empty, every edge is {@code +Infinity}. As a result, every value
     * other than {@code +Infinity} itself is later assigned to bin {@code 0}.
     *
     * @param prefix    the prefix whose edges are being fit
     * @param variables the observed variables for the prefix
     * @throws NullPointerException if either argument is {@code null}
     */
    @Override
    protected void fit(@NonNull String prefix, @NonNull Iterable<Variable> variables) {
        if (prefix == null) {
            throw new NullPointerException("prefix is marked non-null but is null");
        }
        if (variables == null) {
            throw new NullPointerException("variables is marked non-null but is null");
        }
        DoubleSummaryStatistics stats = Streams.asStream(variables)
                                               .collect(Collectors.summarizingDouble(Variable::getValue));
        double[] edges = new double[numberOfBins];
        if (stats.getCount() == 0) {
            // With no observations, min/max are +/-Infinity and the width would be NaN.
            // Use explicit +Infinity edges instead so the outcome (bin 0) is intentional.
            Arrays.fill(edges, Double.POSITIVE_INFINITY);
        } else {
            double min = stats.getMin();
            double width = (stats.getMax() - min) / numberOfBins;
            for (int k = 0; k < edges.length; k++) {
                // Multiply instead of accumulating so rounding error does not grow with k.
                edges[k] = min + width * (k + 1);
            }
        }
        prefixBinMap.put(prefix, edges);
    }

    /**
     * Returns the index of the bin that {@code value} falls into for {@code prefix}.
     *
     * <p>The result is the first index {@code k < numberOfBins - 1} with {@code value < edges[k]},
     * or {@code numberOfBins - 1} when there is no such index.
     *
     * @param prefix the prefix whose fitted edges are used
     * @param value  the value to bin
     * @return a bin index in {@code [0, numberOfBins)}
     * @throws IllegalStateException if no edges have been fit for {@code prefix}
     */
    private int getBin(String prefix, double value) {
        double[] edges = prefixBinMap.get(prefix);
        if (edges == null) {
            throw new IllegalStateException("No bins have been fit for prefix '" + prefix + "'");
        }
        int bin = 0;
        // The last bin is open-ended, so its edge is never compared.
        while (bin < edges.length - 1 && value >= edges[bin]) {
            bin++;
        }
        return bin;
    }

    /** Discards all fitted bin edges. */
    @Override
    protected void reset() {
        prefixBinMap.clear();
    }

    /**
     * Replaces a numeric variable with a binary variable that names its bin.
     *
     * @param variable the variable to transform
     * @return a binary variable with the same prefix and suffix {@code Bin[k]}, or
     *         {@code <suffix>-Bin[k]} when {@code includeSuffix} is enabled
     * @throws NullPointerException  if {@code variable} is {@code null}
     * @throws IllegalStateException if the variable's prefix has not been fit
     */
    @Override
    protected Variable transform(@NonNull Variable variable) {
        if (variable == null) {
            throw new NullPointerException("variable is marked non-null but is null");
        }
        String binName = "Bin[" + getBin(variable.getPrefix(), variable.getValue()) + "]";
        String suffix = includeSuffix ? variable.getSuffix() + "-" + binName : binName;
        return Variable.binary(variable.getPrefix(), suffix);
    }

    /**
     * Performs only a null check; this transform records no dataset metadata.
     *
     * @param data the dataset being transformed
     * @throws NullPointerException if {@code data} is {@code null}
     */
    @Override
    protected void updateMetadata(@NonNull DataSet data) {
        if (data == null) {
            throw new NullPointerException("data is marked non-null but is null");
        }
    }
}
