package sklearn.neighbors;

import com.google.common.base.Function;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.dmg.pmml.CityBlock;
import org.dmg.pmml.CompareFunction;
import org.dmg.pmml.ComparisonMeasure;
import org.dmg.pmml.DataType;
import org.dmg.pmml.Euclidean;
import org.dmg.pmml.Measure;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Minkowski;
import org.dmg.pmml.Output;
import org.dmg.pmml.OutputField;
import org.dmg.pmml.nearest_neighbor.InstanceField;
import org.dmg.pmml.nearest_neighbor.InstanceFields;
import org.dmg.pmml.nearest_neighbor.KNNInput;
import org.dmg.pmml.nearest_neighbor.KNNInputs;
import org.dmg.pmml.nearest_neighbor.NearestNeighborModel;
import org.dmg.pmml.nearest_neighbor.TrainingInstances;
import org.jpmml.converter.CMatrixUtil;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.ContinuousLabel;
import org.jpmml.converter.Feature;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.MultiLabel;
import org.jpmml.converter.PMMLUtil;
import org.jpmml.converter.ScalarLabel;
import org.jpmml.converter.Schema;
import org.jpmml.converter.ValueUtil;
import org.jpmml.python.ClassDictUtil;
import sklearn.Estimator;

/* loaded from: input_file:sklearn/neighbors/KNeighborsUtil.class */
/**
 * Static helpers for encoding scikit-learn nearest-neighbour estimators as PMML
 * {@link NearestNeighborModel} elements.
 */
public final class KNeighborsUtil {

    /** Instance field name used for the optional per-instance identifier column. */
    private static final String VARIABLE_ID = "id";

    private KNeighborsUtil() {
    }

    /**
     * Returns the integers of the half-open range {@code [start, end)} in ascending order.
     *
     * @param start first value (inclusive)
     * @param end upper bound (exclusive)
     * @return a new mutable list; empty when {@code end <= start}
     */
    public static List<Integer> createRange(int start, int end) {
        List<Integer> result = new ArrayList<>();
        for (int value = start; value < end; value++) {
            result.add(value);
        }
        return result;
    }

    /**
     * Derives the number of target columns from the estimator's {@code getYShape()}: a
     * 1-dimensional shape means one output, a 2-dimensional shape means {@code shape[1]} outputs.
     *
     * @param estimator the fitted estimator
     * @return the number of outputs
     * @throws IllegalArgumentException if the shape is neither 1- nor 2-dimensional
     */
    public static <E extends Estimator & HasTrainingData> int getNumberOfOutputs(E estimator) {
        int[] yShape = estimator.getYShape();
        switch (yShape.length) {
            case 1:
                return 1;
            case 2:
                return yShape[1];
            default:
                throw new IllegalArgumentException("Expected a 1- or 2-dimensional y shape, got " + yShape.length + " dimension(s)");
        }
    }

    /**
     * Encodes a fitted neighbours estimator as a PMML {@link NearestNeighborModel}.
     *
     * <p>The training data is embedded as an inline table. Columns are added in this order:
     * <ol>
     *   <li>{@code data:id}, when the estimator exposes instance ids;</li>
     *   <li>{@code data:y} (single output) or {@code data:y1..data:yN} (multiple outputs), one
     *       per non-null label;</li>
     *   <li>{@code data:x1..data:xM}, one per schema feature.</li>
     * </ol>
     *
     * @param estimator the fitted estimator
     * @param miningFunction the PMML mining function of the resulting model
     * @param numberOfInstances number of rows in the training matrix
     * @param numberOfFeatures number of columns in the training matrix {@code fitX}
     * @param schema the label and features of the model
     * @return the encoded model
     * @throws IllegalArgumentException if the label does not agree with the number of outputs,
     *     or if the metric is not supported
     */
    public static <E extends Estimator & HasMetric & HasNumberOfNeighbors & HasTrainingData> NearestNeighborModel encodeNeighbors(E estimator, MiningFunction miningFunction, int numberOfInstances, int numberOfFeatures, Schema schema) {
        int numberOfNeighbors = estimator.getNumberOfNeighbors();
        int numberOfOutputs = estimator.getNumberOfOutputs();

        if (numberOfOutputs < 0) {
            throw new IllegalArgumentException("Negative number of outputs: " + numberOfOutputs);
        }

        List<? extends Number> fitX = estimator.getFitX();
        List<?> id = estimator.getId();
        List<? extends Number> y = estimator.getY();

        if (id != null) {
            ClassDictUtil.checkSize(numberOfInstances, id);
        }
        if (y != null) {
            // y is stored as a flat row-major matrix of numberOfInstances x numberOfOutputs.
            ClassDictUtil.checkSize(numberOfInstances * numberOfOutputs, y);
        }

        List<? extends Feature> features = schema.getFeatures();

        // Insertion order defines the column order of the inline table.
        Map<String, List<?>> data = new LinkedHashMap<>();
        InstanceFields instanceFields = new InstanceFields();

        if (id != null) {
            InstanceField idField = new InstanceField(VARIABLE_ID).setColumn("data:id");
            instanceFields.addInstanceFields(idField);
            data.put(idField.getColumn(), id);
        }

        // Label whose name becomes the target field of the neighbour output fields.
        // It is only set in the multi-output case, matching the original behaviour.
        ScalarLabel outputTarget = null;

        if (numberOfOutputs == 0) {
            // No target columns exist, so a label would have no training values to bind to.
            if (schema.getLabel() != null) {
                throw new IllegalArgumentException("Expected no label for an estimator without outputs");
            }
        } else if (numberOfOutputs == 1) {
            ScalarLabel scalarLabel = (ScalarLabel) schema.getLabel();
            if (scalarLabel != null) {
                InstanceField yField = new InstanceField(scalarLabel.getName()).setColumn("data:y");
                instanceFields.addInstanceFields(yField);
                data.put(yField.getColumn(), translateValues(scalarLabel, y));
            }
        } else {
            MultiLabel multiLabel = (MultiLabel) schema.getLabel();
            if (multiLabel == null) {
                throw new IllegalArgumentException("Expected a multi-label for an estimator with " + numberOfOutputs + " outputs");
            }

            List<?> labels = multiLabel.getLabels();
            if (labels.size() != numberOfOutputs) {
                throw new IllegalArgumentException("Expected " + numberOfOutputs + " labels, got " + labels.size());
            }

            for (int output = 0; output < labels.size(); output++) {
                ScalarLabel scalarLabel = (ScalarLabel) labels.get(output);
                if (scalarLabel == null) {
                    continue;
                }

                InstanceField yField = new InstanceField(scalarLabel.getName()).setColumn("data:y" + (output + 1));
                instanceFields.addInstanceFields(yField);
                data.put(yField.getColumn(), translateValues(scalarLabel, CMatrixUtil.getColumn(y, numberOfInstances, numberOfOutputs, output)));
            }

            outputTarget = (ScalarLabel) labels.get(0);
        }

        DataType dataType = estimator.getDataType();
        KNNInputs knnInputs = new KNNInputs();

        for (int column = 0; column < features.size(); column++) {
            String name = features.get(column).toContinuousFeature(dataType).getName();

            InstanceField xField = new InstanceField(name).setColumn("data:x" + (column + 1));
            instanceFields.addInstanceFields(xField);
            knnInputs.addKNNInputs(new KNNInput(name));
            data.put(xField.getColumn(), CMatrixUtil.getColumn(fitX, numberOfInstances, numberOfFeatures, column));
        }

        TrainingInstances trainingInstances = new TrainingInstances(instanceFields, PMMLUtil.createInlineTable(data))
            .setTransformed(true);

        ComparisonMeasure comparisonMeasure = encodeComparisonMeasure(estimator);

        NearestNeighborModel nearestNeighborModel = new NearestNeighborModel(miningFunction, numberOfNeighbors, ModelUtil.createMiningSchema(schema.getLabel()), trainingInstances, comparisonMeasure, knnInputs)
            .setOutput(createOutput(numberOfNeighbors, outputTarget));

        if (id != null) {
            nearestNeighborModel.setInstanceIdVariable(VARIABLE_ID);
        }

        return nearestNeighborModel;
    }

    /**
     * Builds a distance {@link ComparisonMeasure} with the {@code absDiff} compare function,
     * using the measure selected by {@link #encodeMeasure}.
     */
    private static <E extends Estimator & HasMetric> ComparisonMeasure encodeComparisonMeasure(E estimator) {
        return new ComparisonMeasure(ComparisonMeasure.Kind.DISTANCE, encodeMeasure(estimator))
            .setCompareFunction(CompareFunction.ABS_DIFF);
    }

    /**
     * Maps the estimator's metric name (and Minkowski power {@code p}) to a PMML measure.
     *
     * <p>{@code "minkowski"} with {@code p == 1} or {@code p == 2} is emitted as the equivalent
     * {@link CityBlock} or {@link Euclidean} measure.
     *
     * @throws IllegalArgumentException if the metric name is not supported
     */
    private static <E extends Estimator & HasMetric> Measure encodeMeasure(E estimator) {
        String metric = estimator.getMetric();
        int p = estimator.getP();

        switch (metric) {
            case "euclidean":
                return new Euclidean();
            case "manhattan":
                return new CityBlock();
            case "minkowski":
                switch (p) {
                    case 1:
                        return new CityBlock();
                    case 2:
                        return new Euclidean();
                    default:
                        return new Minkowski(p);
                }
            default:
                throw new IllegalArgumentException(metric);
        }
    }

    /**
     * Creates the neighbour output fields for a model that considers {@code numberOfNeighbors}
     * neighbours.
     *
     * @param numberOfNeighbors the configured number of neighbours
     * @param targetLabel when non-null, its name is set as the target field of every
     *     neighbour output field
     * @return {@code null} when {@code numberOfNeighbors == 1}; otherwise an {@link Output}
     *     holding the fields produced by {@code ModelUtil.createNeighborFields}
     */
    private static Output createOutput(int numberOfNeighbors, ScalarLabel targetLabel) {
        if (numberOfNeighbors == 1) {
            return null;
        }

        List<OutputField> neighborFields = ModelUtil.createNeighborFields(numberOfNeighbors);
        if (targetLabel != null) {
            for (OutputField neighborField : neighborFields) {
                neighborField.setTargetField(targetLabel.getName());
            }
        }

        Output output = new Output();
        output.getOutputFields().addAll(neighborFields);
        return output;
    }

    /**
     * Converts encoded target values into the representation stored in the training table.
     *
     * <p>Continuous labels are returned unchanged. Categorical labels are mapped lazily via
     * {@code Lists.transform}: each value is converted with {@code ValueUtil.asInt} and then looked
     * up with {@code CategoricalLabel.getValue(int)}. This presumably treats it as a class index
     * (to be confirmed against jpmml-converter).
     *
     * @throws IllegalArgumentException if the label is neither continuous nor categorical
     */
    private static List<?> translateValues(ScalarLabel scalarLabel, List<? extends Number> values) {
        if (scalarLabel instanceof ContinuousLabel) {
            return values;
        }
        if (!(scalarLabel instanceof CategoricalLabel)) {
            throw new IllegalArgumentException("Unsupported label type: " + (scalarLabel != null ? scalarLabel.getClass().getName() : null));
        }

        final CategoricalLabel categoricalLabel = (CategoricalLabel) scalarLabel;
        return Lists.transform(values, new Function<Number, Object>() {

            @Override
            public Object apply(Number number) {
                return categoricalLabel.getValue(ValueUtil.asInt(number));
            }
        });
    }
}
