package sklearn2pmml.preprocessing;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import org.dmg.pmml.DataType;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.Discretize;
import org.dmg.pmml.DiscretizeBin;
import org.dmg.pmml.Expression;
import org.dmg.pmml.Interval;
import org.dmg.pmml.OpType;
import org.jpmml.converter.CategoricalFeature;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.ExpressionUtil;
import org.jpmml.converter.Feature;
import org.jpmml.converter.FeatureUtil;
import org.jpmml.converter.IndexFeature;
import org.jpmml.converter.TypeUtil;
import org.jpmml.python.ClassDictUtil;
import org.jpmml.sklearn.SkLearnEncoder;
import sklearn.Transformer;

/* loaded from: input_file:sklearn2pmml/preprocessing/CutTransformer.class */
/**
 * Converter for the {@code CutTransformer} step: bins a single continuous input
 * feature into categories using a PMML {@link Discretize} expression.
 *
 * <p>The binning is driven by four attributes read from the underlying object:
 * {@code bins} (the bin edges), {@code labels} (optional bin values),
 * {@code right} (which side of each interval is closed) and
 * {@code include_lowest} (whether the first interval is also closed on the left).
 *
 * <p>NOTE(review): the attribute names suggest this mirrors {@code pandas.cut} —
 * confirm against the Python side.
 */
public class CutTransformer extends Transformer {

    public CutTransformer(String module, String name) {
        super(module, name);
    }

    /**
     * Encodes the binning of exactly one input feature.
     *
     * <p>For {@code n} bin edges, {@code n - 1} intervals are generated. Each interval
     * maps to the corresponding label when labels are present, or to its zero-based
     * index otherwise.
     *
     * @param features the input features; must contain exactly one element
     * @param encoder the encoder used for creating the derived field
     * @return a singleton list holding the binned feature
     */
    @Override // sklearn.Transformer
    public List<Feature> encodeFeatures(List<Feature> features, SkLearnEncoder encoder) {
        List<? extends Number> bins = getBins();
        List<?> labels = getLabels();
        Boolean right = getRight();
        Boolean includeLowest = getIncludeLowest();

        ClassDictUtil.checkSize(1, features);

        // With labels the result type is derived from the label values;
        // without labels the bin index is used, hence an integer result
        DataType dataType;
        if (labels != null) {
            ClassDictUtil.checkSize(bins.size() - 1, labels);
            dataType = TypeUtil.getDataType(labels, DataType.STRING);
        } else {
            dataType = DataType.INTEGER;
        }

        Feature feature = features.get(0);

        Interval.Closure closure = right ? Interval.Closure.OPEN_CLOSED : Interval.Closure.CLOSED_OPEN;

        ContinuousFeature continuousFeature = feature.toContinuousFeature();

        // Bin values in bin order; they become the category values of the result feature
        List<Object> binValues = new ArrayList<>();

        Discretize discretize = new Discretize(continuousFeature.getName())
            .setDataType(dataType);

        for (int i = 0; i < bins.size() - 1; i++) {
            Interval interval = new Interval(closure)
                .setLeftMargin(formatMargin(bins.get(i)))
                .setRightMargin(formatMargin(bins.get(i + 1)));

            // Only the first interval is affected by include_lowest, and only when
            // it would otherwise be open on the left
            if (i == 0 && includeLowest && closure == Interval.Closure.OPEN_CLOSED) {
                interval.setClosure(Interval.Closure.CLOSED_CLOSED);
            }

            Object binValue = (labels != null) ? labels.get(i) : Integer.valueOf(i);
            binValues.add(binValue);

            discretize.addDiscretizeBins(new DiscretizeBin(binValue, interval));
        }

        List<DiscretizeBin> discretizeBins = discretize.getDiscretizeBins();

        // Special case: a single bin without any finite margin.
        // The Discretize element is replaced with an "if(isNotMissing(x), binValue)" expression
        if (discretizeBins.size() == 1) {
            DiscretizeBin discretizeBin = discretizeBins.get(0);

            Interval interval = discretizeBin.requireInterval();
            Object binValue = discretizeBin.requireBinValue();

            if (interval.getLeftMargin() == null && interval.getRightMargin() == null) {
                Expression isNotMissing = ExpressionUtil.createApply("isNotMissing", continuousFeature.ref());
                // The cast selects the createConstant(DataType, Object) overload
                Expression constant = ExpressionUtil.createConstant((DataType) null, binValue);
                Expression expression = ExpressionUtil.createApply("if", isNotMissing, constant);

                DerivedField derivedField = encoder.createDerivedField(createFieldName("cut", continuousFeature), OpType.CATEGORICAL, dataType, expression);

                Feature constantFeature = FeatureUtil.createFeature(derivedField, encoder);

                return Collections.singletonList(constantFeature);
            }
        }

        DerivedField derivedField = encoder.createDerivedField(createFieldName("cut", continuousFeature), OpType.CATEGORICAL, dataType, discretize);

        Feature result;
        if (labels != null) {
            result = new CategoricalFeature(encoder, derivedField, binValues);
        } else {
            result = new IndexFeature(encoder, derivedField, binValues);
        }

        return Collections.singletonList(result);
    }

    /**
     * @return the value of the {@code bins} attribute as a list of numbers
     */
    public List<? extends Number> getBins() {
        return getListLike("bins", Number.class);
    }

    /**
     * @return the value of the {@code labels} attribute as a list, or {@code null}
     *         when the attribute is unset or is the boolean {@code false}
     */
    public List<?> getLabels() {
        Object labels = getOptionalScalar("labels");

        boolean noLabels = (labels == null) || Boolean.FALSE.equals(labels);
        if (noLabels) {
            return null;
        }

        return getList("labels");
    }

    /**
     * @return the value of the {@code right} attribute
     */
    public Boolean getRight() {
        return getBoolean("right");
    }

    /**
     * @return the value of the {@code include_lowest} attribute
     */
    public Boolean getIncludeLowest() {
        return getBoolean("include_lowest");
    }

    /**
     * Converts a bin edge to an interval margin.
     *
     * @param number the bin edge
     * @return the edge as a {@code Double}, or {@code null} when the edge is
     *         positive or negative infinity (i.e. the interval has no margin on that side)
     */
    private static Double formatMargin(Number number) {
        double value = number.doubleValue();

        return Double.isInfinite(value) ? null : Double.valueOf(value);
    }
}
