package org.jpmml.converter;

import com.google.common.base.Function;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.MiningFunctionType;
import org.dmg.pmml.MiningModel;
import org.dmg.pmml.Model;
import org.dmg.pmml.MultipleModelMethodType;
import org.dmg.pmml.NumericPredictor;
import org.dmg.pmml.Output;
import org.dmg.pmml.OutputField;
import org.dmg.pmml.RegressionModel;
import org.dmg.pmml.RegressionNormalizationMethodType;
import org.dmg.pmml.RegressionTable;
import org.dmg.pmml.Segment;
import org.dmg.pmml.Segmentation;
import org.dmg.pmml.True;

/* loaded from: input_file:org/jpmml/converter/MiningModelUtil.class */
public class MiningModelUtil {

    /**
     * Resolves the name of the field that holds a model's prediction, which is taken to be
     * the last {@link OutputField} of the model's {@link Output}.
     *
     * <p>Throws {@link IllegalArgumentException} if the model has no {@code Output} or its
     * {@code Output} has no output fields.
     */
    private static final Function<Model, FieldName> MODEL_PREDICTION = new Function<Model, FieldName>() {

        @Override
        public FieldName apply(Model model) {
            Output output = model.getOutput();

            if (output == null || !output.hasOutputFields()) {
                throw new IllegalArgumentException("Model does not define any output fields, cannot determine its prediction field");
            }

            return Iterables.getLast(output.getOutputFields()).getName();
        }
    };

    private MiningModelUtil() {
    }

    /**
     * Wraps {@code model} into a regression model chain.
     *
     * <p>Uses the target field and active fields of {@code schema}.
     *
     * @param schema supplies the target field and active fields
     * @param model the model whose last output field is the prediction
     * @return a {@link MiningModel} chaining {@code model} with an identity {@link RegressionModel}
     * @throws IllegalArgumentException if {@code model} has no output fields
     */
    public static MiningModel createRegression(Schema schema, Model model) {
        return createRegression(schema.getTargetField(), schema.getActiveFields(), model);
    }

    /**
     * Wraps {@code model} into a two-stage model chain.
     *
     * <p>The first stage is {@code model} itself. The second stage is a {@link RegressionModel}
     * that maps the model's prediction field onto {@code targetField} as
     * {@code 0 + 1.0 * prediction}.
     *
     * @param targetField the target field of the resulting mining model
     * @param activeFields the active fields of the resulting mining model
     * @param model the model whose last output field is the prediction
     * @return the chained {@link MiningModel}
     * @throws IllegalArgumentException if {@code model} has no output fields
     */
    public static MiningModel createRegression(FieldName targetField, List<FieldName> activeFields, Model model) {
        FieldName predictionField = MODEL_PREDICTION.apply(model);
        List<FieldName> inputFields = Collections.singletonList(predictionField);

        // Identity transform: intercept 0, coefficient 1 on the upstream prediction
        RegressionTable regressionTable = new RegressionTable(0d)
            .addNumericPredictors(new NumericPredictor(predictionField, 1d));

        // Regression tables are attached below rather than through the constructor
        RegressionModel regressionModel = new RegressionModel(MiningFunctionType.REGRESSION, ModelUtil.createMiningSchema(targetField, inputFields), (List<RegressionTable>) null);
        regressionModel.addRegressionTables(regressionTable);

        List<Model> models = Arrays.<Model>asList(model, regressionModel);

        return createModelChain(targetField, activeFields, models);
    }

    /**
     * Wraps {@code model} into a binary classification model chain.
     *
     * <p>Uses the target field, target categories and active fields of {@code schema}.
     *
     * @param schema supplies the target field, target categories and active fields
     * @param model the model whose last output field is the prediction
     * @param coefficient the coefficient applied to the prediction for the first target category
     * @param hasProbabilityDistribution whether to attach probability output fields
     * @return the chained {@link MiningModel}
     * @throws IllegalArgumentException if the schema does not have exactly two target categories,
     *         or {@code model} has no output fields
     */
    public static MiningModel createBinaryLogisticClassification(Schema schema, Model model, double coefficient, boolean hasProbabilityDistribution) {
        return createBinaryLogisticClassification(schema.getTargetField(), schema.getTargetCategories(), schema.getActiveFields(), model, coefficient, hasProbabilityDistribution);
    }

    /**
     * Wraps {@code model} into a two-stage binary classification model chain.
     *
     * <p>The second stage is a {@link RegressionModel} with SOFTMAX normalization and two
     * regression tables:
     * <ul>
     *   <li>the first category scores {@code coefficient * prediction};</li>
     *   <li>the second category has a constant score of 0.</li>
     * </ul>
     * Under softmax, this yields a logistic function of {@code coefficient * prediction} for
     * the first category.
     *
     * @param targetField the target field of the resulting mining model
     * @param targetCategories exactly two target category values
     * @param activeFields the active fields of the resulting mining model
     * @param model the model whose last output field is the prediction
     * @param coefficient the coefficient applied to the prediction for the first category
     * @param hasProbabilityDistribution whether to attach the output fields produced by
     *        {@code ModelUtil.createProbabilityFields(targetCategories)}
     * @return the chained {@link MiningModel}
     * @throws IllegalArgumentException if {@code targetCategories} does not have exactly two
     *         elements, or {@code model} has no output fields
     */
    public static MiningModel createBinaryLogisticClassification(FieldName targetField, List<String> targetCategories, List<FieldName> activeFields, Model model, double coefficient, boolean hasProbabilityDistribution) {
        if (targetCategories.size() != 2) {
            throw new IllegalArgumentException("Expected 2 target categories, got " + targetCategories.size());
        }

        FieldName predictionField = MODEL_PREDICTION.apply(model);
        List<FieldName> inputFields = Collections.singletonList(predictionField);

        RegressionTable activeTable = new RegressionTable(0d)
            .setTargetCategory(targetCategories.get(0))
            .addNumericPredictors(new NumericPredictor(predictionField, coefficient));

        // Reference category: constant zero score
        RegressionTable referenceTable = new RegressionTable(0d)
            .setTargetCategory(targetCategories.get(1));

        // Regression tables are attached below rather than through the constructor
        RegressionModel regressionModel = new RegressionModel(MiningFunctionType.CLASSIFICATION, ModelUtil.createMiningSchema(targetField, inputFields), (List<RegressionTable>) null);
        regressionModel.setNormalizationMethod(RegressionNormalizationMethodType.SOFTMAX);
        regressionModel.addRegressionTables(activeTable, referenceTable);
        regressionModel.setOutput(hasProbabilityDistribution ? new Output(ModelUtil.createProbabilityFields(targetCategories)) : null);

        List<Model> models = Arrays.<Model>asList(model, regressionModel);

        return createModelChain(targetField, activeFields, models);
    }

    /**
     * Wraps one model per target category into a classification model chain.
     *
     * <p>Uses the target field, target categories and active fields of {@code schema}.
     *
     * @param schema supplies the target field, target categories and active fields
     * @param models one model per target category, in the same order
     * @param normalizationMethod the normalization method of the final regression stage
     * @param hasProbabilityDistribution whether to attach probability output fields
     * @return the chained {@link MiningModel}
     * @throws IllegalArgumentException if the number of models differs from the number of
     *         target categories, or any model has no output fields
     */
    public static MiningModel createClassification(Schema schema, List<? extends Model> models, RegressionNormalizationMethodType normalizationMethod, boolean hasProbabilityDistribution) {
        return createClassification(schema.getTargetField(), schema.getTargetCategories(), schema.getActiveFields(), models, normalizationMethod, hasProbabilityDistribution);
    }

    /**
     * Wraps one model per target category into a classification model chain.
     *
     * <p>The chain consists of all {@code models} followed by a {@link RegressionModel}.
     * For category {@code i}, that regression model has a table scoring
     * {@code 0 + 1.0 * prediction(models[i])}, normalized with {@code normalizationMethod}.
     *
     * @param targetField the target field of the resulting mining model
     * @param targetCategories the target category values
     * @param activeFields the active fields of the resulting mining model
     * @param models one model per target category, in the same order
     * @param normalizationMethod the normalization method of the final regression stage
     * @param hasProbabilityDistribution whether to attach the output fields produced by
     *        {@code ModelUtil.createProbabilityFields(targetCategories)}
     * @return the chained {@link MiningModel}
     * @throws IllegalArgumentException if {@code models.size() != targetCategories.size()}, or
     *         any model has no output fields
     */
    public static MiningModel createClassification(FieldName targetField, List<String> targetCategories, List<FieldName> activeFields, List<? extends Model> models, RegressionNormalizationMethodType normalizationMethod, boolean hasProbabilityDistribution) {
        if (targetCategories.size() != models.size()) {
            throw new IllegalArgumentException("Expected " + targetCategories.size() + " models (one per target category), got " + models.size());
        }

        // Copy eagerly. Lists.transform returns a lazy view that would re-apply MODEL_PREDICTION
        // on every access, and would defer validation of the models.
        List<FieldName> predictionFields = new ArrayList<FieldName>(Lists.transform(models, MODEL_PREDICTION));

        List<RegressionTable> regressionTables = new ArrayList<RegressionTable>(targetCategories.size());
        for (int i = 0; i < targetCategories.size(); i++) {
            RegressionTable regressionTable = new RegressionTable(0d)
                .setTargetCategory(targetCategories.get(i))
                .addNumericPredictors(new NumericPredictor(predictionFields.get(i), 1d));

            regressionTables.add(regressionTable);
        }

        RegressionModel regressionModel = new RegressionModel(MiningFunctionType.CLASSIFICATION, ModelUtil.createMiningSchema(targetField, predictionFields), regressionTables);
        regressionModel.setNormalizationMethod(normalizationMethod);
        regressionModel.setOutput(hasProbabilityDistribution ? new Output(ModelUtil.createProbabilityFields(targetCategories)) : null);

        List<Model> chain = new ArrayList<Model>(models);
        chain.add(regressionModel);

        return createModelChain(targetField, activeFields, chain);
    }

    /**
     * Builds a {@link MiningModel} whose segmentation is a MODEL_CHAIN over {@code models}.
     *
     * <p>The mining model's function name is taken from the last model in the chain.
     *
     * @param targetField the target field of the resulting mining model
     * @param activeFields the active fields of the resulting mining model
     * @param models the chained models, in evaluation order (must not be empty)
     * @return the chained {@link MiningModel}
     */
    private static MiningModel createModelChain(FieldName targetField, List<FieldName> activeFields, List<? extends Model> models) {
        Segmentation segmentation = createSegmentation(MultipleModelMethodType.MODEL_CHAIN, models);

        Model lastModel = Iterables.getLast(models);

        return new MiningModel(lastModel.getFunctionName(), ModelUtil.createMiningSchema(targetField, activeFields))
            .setSegmentation(segmentation);
    }

    /**
     * Creates an unweighted {@link Segmentation} with one always-true segment per model.
     *
     * @param multipleModelMethod the method used to combine the segments
     * @param models the segment models, in order
     * @return the segmentation
     */
    public static Segmentation createSegmentation(MultipleModelMethodType multipleModelMethod, List<? extends Model> models) {
        return createSegmentation(multipleModelMethod, models, null);
    }

    /**
     * Creates a {@link Segmentation} with one always-true segment per model.
     *
     * <p>Segments get 1-based string ids. A weight is written only when it is non-null and
     * {@code ValueUtil.isOne} reports it as not equal to one.
     *
     * @param multipleModelMethod the method used to combine the segments
     * @param models the segment models, in order
     * @param weights per-model weights aligned with {@code models}, or {@code null} for none;
     *        individual elements may be {@code null}
     * @return the segmentation
     * @throws IllegalArgumentException if {@code weights} is non-null and its size differs
     *         from {@code models}
     */
    public static Segmentation createSegmentation(MultipleModelMethodType multipleModelMethod, List<? extends Model> models, List<? extends Number> weights) {
        if (weights != null && models.size() != weights.size()) {
            throw new IllegalArgumentException("Expected " + models.size() + " weights (one per model), got " + weights.size());
        }

        List<Segment> segments = new ArrayList<Segment>(models.size());
        for (int i = 0; i < models.size(); i++) {
            Model model = models.get(i);
            Number weight = (weights != null ? weights.get(i) : null);

            Segment segment = new Segment()
                .setId(String.valueOf(i + 1))
                .setPredicate(new True())
                .setModel(model);

            // Unit weights are omitted, presumably because 1 is the PMML default for Segment@weight
            if (weight != null && !ValueUtil.isOne(weight)) {
                segment.setWeight(ValueUtil.asDouble(weight));
            }

            segments.add(segment);
        }

        return new Segmentation(multipleModelMethod, segments);
    }
}
