package org.jpmml.rexp;

import java.util.ArrayList;
import java.util.List;
import org.dmg.pmml.DataType;
import org.dmg.pmml.Field;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Model;
import org.dmg.pmml.OpType;
import org.dmg.pmml.neural_network.NeuralEntity;
import org.dmg.pmml.neural_network.NeuralInputs;
import org.dmg.pmml.neural_network.NeuralLayer;
import org.dmg.pmml.neural_network.NeuralNetwork;
import org.dmg.pmml.neural_network.Neuron;
import org.jpmml.converter.ContinuousLabel;
import org.jpmml.converter.FortranMatrixUtil;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.neural_network.NeuralNetworkUtil;

/* loaded from: input_file:org/jpmml/rexp/NNConverter.class */
/**
 * Converts an R neural-network object into a PMML {@link NeuralNetwork} regression model.
 *
 * <p>The object is read through the elements {@code model.list}, {@code act.fct},
 * {@code linear.output} and {@code weights}. NOTE(review): this looks like the {@code nn} class of
 * the R {@code neuralnet} package; confirm.
 */
public class NNConverter extends ModelConverter<RGenericVector> {

    public NNConverter(RGenericVector rGenericVector) {
        super(rGenericVector);
    }

    /**
     * Registers the label and the input features with the encoder.
     *
     * <p>Names are taken from the {@code model.list} element. {@code response} names the single
     * target field and {@code variables} lists the input fields. Every field is created as a
     * continuous {@code double} field.
     *
     * @param rExpEncoder the encoder that receives the label and the features
     */
    @Override
    public void encodeSchema(RExpEncoder rExpEncoder) {
        RGenericVector modelList = getObject().getGenericElement("model.list");

        RStringVector response = modelList.getStringElement("response");
        RStringVector variables = modelList.getStringElement("variables");

        rExpEncoder.setLabel(rExpEncoder.createDataField(response.asScalar(), OpType.CONTINUOUS, DataType.DOUBLE));

        for (int i = 0; i < variables.size(); i++) {
            rExpEncoder.addFeature((Field<?>) rExpEncoder.createDataField(variables.getValue(i), OpType.CONTINUOUS, DataType.DOUBLE));
        }
    }

    /**
     * Builds a PMML {@link NeuralNetwork} regression model from the fitted weight matrices.
     *
     * <p>Only the first element of {@code weights} is converted. Each element of that list is one
     * layer's weight matrix, stored column-major. Column {@code j} describes neuron {@code j}: row 0
     * is the bias and the remaining rows are the weights for the previous layer's neurons, or for
     * the inputs in the first layer. Every layer except the last is a hidden layer and always gets
     * the {@code act.fct} activation. The last (output) layer gets that activation only when
     * {@code linear.output} is present and {@code FALSE}. Otherwise no activation is set on that
     * layer and the network-level {@code IDENTITY} activation applies.
     *
     * <p>NOTE(review): the decompiler renamed this method from {@code encodeModel}. The name is kept
     * so callers stay unchanged. Confirm that it really overrides the {@link ModelConverter} method.
     *
     * @param schema the schema produced by {@link #encodeSchema}; its label must be a
     *     {@link ContinuousLabel}
     * @return the encoded neural network model
     * @throws IllegalArgumentException if the {@code type} attribute of {@code act.fct} is neither
     *     {@code "logistic"} nor {@code "tanh"}
     */
    @Override
    public Model mo0encodeModel(Schema schema) {
        RGenericVector nn = getObject();

        RExp actFct = nn.getElement("act.fct");
        RBooleanVector linearOutput = nn.getBooleanElement("linear.output");
        RGenericVector weights = nn.getGenericElement("weights");

        RStringVector actFctType = actFct.getStringAttribute("type");

        // Only the first repetition's weight matrices are converted
        RGenericVector layerWeightsList = weights.getGenericValue(0);

        NeuralNetwork.ActivationFunction activationFunction = parseActivationFunction(actFctType.asScalar());

        ContinuousLabel label = (ContinuousLabel) schema.getLabel();
        NeuralInputs neuralInputs = NeuralNetworkUtil.createNeuralInputs(schema.getFeatures(), DataType.DOUBLE);

        List<NeuralLayer> neuralLayers = new ArrayList<>();

        // The entities feeding the current layer: the network inputs first, then each layer's neurons
        List<? extends NeuralEntity> entities = neuralInputs.getNeuralInputs();

        int layerCount = layerWeightsList.size();
        for (int i = 0; i < layerCount; i++) {
            boolean hidden = i < layerCount - 1;

            NeuralLayer neuralLayer = new NeuralLayer();

            // The output layer stays linear unless linear.output is explicitly FALSE
            if (hidden || (linearOutput != null && !linearOutput.asScalar().booleanValue())) {
                neuralLayer.setActivationFunction(activationFunction);
            }

            RDoubleVector layerWeights = layerWeightsList.getDoubleValue(i);
            RIntegerVector dim = layerWeights.dim();

            int rows = dim.getValue(0).intValue();
            int columns = dim.getValue(1).intValue();

            for (int j = 0; j < columns; j++) {
                List<? extends Number> column = FortranMatrixUtil.getColumn(layerWeights.getValues(), rows, columns, j);

                String id = hidden ? "hidden/" + i + "/" + j : "output/" + j;

                // Row 0 holds the bias; rows 1.. hold the incoming connection weights
                Neuron neuron = NeuralNetworkUtil.createNeuron(entities, column.subList(1, column.size()), column.get(0));
                neuron.setId(id);

                neuralLayer.addNeurons(neuron);
            }

            neuralLayers.add(neuralLayer);

            entities = neuralLayer.getNeurons();
        }

        return new NeuralNetwork(MiningFunction.REGRESSION, NeuralNetwork.ActivationFunction.IDENTITY, ModelUtil.createMiningSchema(label), neuralInputs, neuralLayers)
            .setNeuralOutputs(NeuralNetworkUtil.createRegressionNeuralOutputs(entities, label));
    }

    /**
     * Maps the {@code type} attribute of the {@code act.fct} element to a PMML activation function.
     *
     * @param type the activation type name
     * @return {@code LOGISTIC} for {@code "logistic"}, {@code TANH} for {@code "tanh"}
     * @throws IllegalArgumentException for any other type name
     */
    private static NeuralNetwork.ActivationFunction parseActivationFunction(String type) {
        switch (type) {
            case "logistic":
                return NeuralNetwork.ActivationFunction.LOGISTIC;
            case "tanh":
                return NeuralNetwork.ActivationFunction.TANH;
            default:
                throw new IllegalArgumentException("Activation function type " + type + " is not supported");
        }
    }
}
