package org.jpmml.converter.general_regression;

import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.general_regression.CovariateList;
import org.dmg.pmml.general_regression.FactorList;
import org.dmg.pmml.general_regression.GeneralRegressionModel;
import org.dmg.pmml.general_regression.PCell;
import org.dmg.pmml.general_regression.PPCell;
import org.dmg.pmml.general_regression.PPMatrix;
import org.dmg.pmml.general_regression.ParamMatrix;
import org.dmg.pmml.general_regression.Parameter;
import org.dmg.pmml.general_regression.ParameterList;
import org.dmg.pmml.general_regression.Predictor;
import org.dmg.pmml.general_regression.PredictorList;
import org.jpmml.converter.BinaryFeature;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.Feature;
import org.jpmml.converter.InteractionFeature;
import org.jpmml.converter.ValueUtil;

/* loaded from: input_file:org/jpmml/converter/general_regression/GeneralRegressionModelUtil.class */
public class GeneralRegressionModelUtil {

    private GeneralRegressionModelUtil() {
    }

    /**
     * Appends one regression table (an optional intercept plus one parameter per usable
     * coefficient) to a {@link GeneralRegressionModel}.
     *
     * <p>The model's {@link ParameterList}, {@link PPMatrix} and {@link ParamMatrix} are created
     * and attached when absent. New parameters are named {@code "p" + index}, where the index
     * continues from the number of parameters already present in the parameter list, so repeated
     * calls on the same model do not reuse parameter names. Field names referenced by continuous
     * features are registered in the model's {@link CovariateList}; field names referenced by
     * binary features are registered in its {@link FactorList}. Each list is created only if
     * there is something to add, and names already present are not duplicated.
     *
     * @param generalRegressionModel the model to populate; it is mutated in place
     * @param list the features, one per entry of {@code list2}
     * @param d the intercept; no intercept parameter is written when it is {@code null} or when
     *          {@code ValueUtil.isZero(d)} holds
     * @param list2 the coefficients, parallel to {@code list}; {@code null} or NaN entries are
     *              skipped, so no parameter, PPCell or PCell is written for that feature
     * @param str the value passed to {@code PCell.setTargetCategory} for every PCell written
     * @return the same {@code generalRegressionModel} instance
     * @throws IllegalArgumentException if {@code list} and {@code list2} differ in size, or if a
     *         feature with a usable coefficient is not a binary, continuous, or interaction feature
     *         (interactions are checked recursively)
     */
    public static GeneralRegressionModel encodeRegressionTable(GeneralRegressionModel generalRegressionModel, List<Feature> list, Double d, List<Double> list2, String str) {
        if (list.size() != list2.size()) {
            throw new IllegalArgumentException("Expected " + list.size() + " coefficients (one per feature), got " + list2.size());
        }

        ParameterList parameterList = generalRegressionModel.getParameterList();
        if (parameterList == null) {
            parameterList = new ParameterList();
            generalRegressionModel.setParameterList(parameterList);
        }

        PPMatrix pPMatrix = generalRegressionModel.getPPMatrix();
        if (pPMatrix == null) {
            pPMatrix = new PPMatrix();
            generalRegressionModel.setPPMatrix(pPMatrix);
        }

        ParamMatrix paramMatrix = generalRegressionModel.getParamMatrix();
        if (paramMatrix == null) {
            paramMatrix = new ParamMatrix();
            generalRegressionModel.setParamMatrix(paramMatrix);
        }

        // Continue numbering after any parameters written by earlier calls.
        int size = parameterList.getParameters().size();

        if (d != null && !ValueUtil.isZero(d)) {
            // The intercept parameter gets a label and a PCell, but no PPCell.
            Parameter interceptParameter = new Parameter("p" + size).setLabel("(intercept)");
            parameterList.addParameters(new Parameter[]{interceptParameter});
            size++;

            paramMatrix.addPCells(new PCell[]{new PCell(interceptParameter.getName(), d.doubleValue()).setTargetCategory(str)});
        }

        // Insertion-ordered so that predictors are emitted in first-seen order.
        Set<FieldName> covariateNames = new LinkedHashSet<>();
        Set<FieldName> factorNames = new LinkedHashSet<>();

        for (int i = 0; i < list.size(); i++) {
            Feature feature = list.get(i);
            Double coefficient = list2.get(i);

            if (coefficient == null || coefficient.isNaN()) {
                continue;
            }

            Parameter parameter = new Parameter("p" + size);
            parameterList.addParameters(new Parameter[]{parameter});
            size++;

            createPPCells(feature, parameter, pPMatrix, covariateNames, factorNames);

            paramMatrix.addPCells(new PCell[]{new PCell(parameter.getName(), coefficient.doubleValue()).setTargetCategory(str)});
        }

        if (!covariateNames.isEmpty()) {
            CovariateList covariateList = generalRegressionModel.getCovariateList();
            if (covariateList == null) {
                covariateList = new CovariateList();
                generalRegressionModel.setCovariateList(covariateList);
            }

            createPredictors(covariateList, covariateNames);
        }

        if (!factorNames.isEmpty()) {
            FactorList factorList = generalRegressionModel.getFactorList();
            if (factorList == null) {
                factorList = new FactorList();
                generalRegressionModel.setFactorList(factorList);
            }

            createPredictors(factorList, factorNames);
        }

        return generalRegressionModel;
    }

    /**
     * Writes the PPCells that link {@code parameter} to the field(s) referenced by {@code feature},
     * and records each referenced field name as a covariate or a factor.
     *
     * <ul>
     *   <li>{@link BinaryFeature}: one PPCell whose value is the feature's value; the name is
     *       recorded as a factor.</li>
     *   <li>{@link ContinuousFeature}: one PPCell with value {@code "1"}; the name is recorded as
     *       a covariate.</li>
     *   <li>{@link InteractionFeature}: recurses into each input feature, so every input
     *       contributes PPCells for the same parameter.</li>
     * </ul>
     *
     * @throws IllegalArgumentException for any other feature type (including {@code null})
     */
    private static void createPPCells(Feature feature, Parameter parameter, PPMatrix pPMatrix, Set<FieldName> covariateNames, Set<FieldName> factorNames) {
        // BinaryFeature is tested first, before the continuous case.
        if (feature instanceof BinaryFeature) {
            BinaryFeature binaryFeature = (BinaryFeature) feature;

            PPCell pPCell = new PPCell(binaryFeature.getValue(), binaryFeature.getName(), parameter.getName());
            pPMatrix.addPPCells(new PPCell[]{pPCell});

            factorNames.add(pPCell.getPredictorName());
        } else if (feature instanceof InteractionFeature) {
            for (Feature inputFeature : ((InteractionFeature) feature).getInputFeatures()) {
                createPPCells(inputFeature, parameter, pPMatrix, covariateNames, factorNames);
            }
        } else if (feature instanceof ContinuousFeature) {
            PPCell pPCell = new PPCell("1", ((ContinuousFeature) feature).getName(), parameter.getName());
            pPMatrix.addPPCells(new PPCell[]{pPCell});

            covariateNames.add(pPCell.getPredictorName());
        } else {
            throw new IllegalArgumentException("Expected a binary, continuous or interaction feature, got " + feature);
        }
    }

    /**
     * Adds a {@link Predictor} to {@code predictorList} for every name in {@code names} that the
     * list does not already contain, preserving the iteration order of {@code names}.
     */
    private static void createPredictors(PredictorList predictorList, Set<FieldName> names) {
        Set<FieldName> missingNames = new LinkedHashSet<>(names);

        for (Predictor predictor : predictorList.getPredictors()) {
            missingNames.remove(predictor.getName());
        }

        // Collect first, then add, so the predictor list is never modified while being iterated.
        for (FieldName missingName : missingNames) {
            predictorList.addPredictors(new Predictor[]{new Predictor(missingName)});
        }
    }
}
