package com.github.chen0040.mlp.ann.regression;

import com.github.chen0040.data.frame.DataFrame;
import com.github.chen0040.data.frame.DataRow;
import com.github.chen0040.mlp.enums.WeightUpdateMode;
import com.github.chen0040.mlp.functions.Identity;
import com.github.chen0040.mlp.functions.Sigmoid;
import com.github.chen0040.mlp.functions.TransferFunction;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

/* loaded from: input_file:com/github/chen0040/mlp/ann/regression/MLPRegression.class */
/**
 * Regression model backed by a multi-layer perceptron with a single numeric output unit.
 *
 * <p>Configure the model through the setters, then call {@link #fit(DataFrame)} to build and
 * train a fresh network. After that, call {@link #transform(DataRow)} to obtain predictions.
 * Every call to {@code fit} discards the previously trained network and builds a new one from
 * the current configuration.
 */
public class MLPRegression {
    /** Default number of training epochs. */
    private static final int DEFAULT_EPOCHES = 1000;
    /** Default learning rate passed to the network. */
    private static final double DEFAULT_LEARNING_RATE = 0.2d;
    /** Size of the single hidden layer created by the default constructor. */
    private static final int DEFAULT_HIDDEN_LAYER_SIZE = 6;

    /** Trained network; {@code null} until {@link #fit(DataFrame)} completes successfully. */
    private MLPWithNumericOutput mlp;
    private int epoches = DEFAULT_EPOCHES;
    private double learningRate = DEFAULT_LEARNING_RATE;
    private int miniBatchSize = 50;
    protected WeightUpdateMode weightUpdateMode = WeightUpdateMode.StochasticGradientDescend;
    private TransferFunction hiddenLayerTransfer = new Sigmoid();
    private TransferFunction outputLayerTransfer = new Identity();
    /** Sizes of the hidden layers, in order from input side to output side. */
    private List<Integer> hiddenLayers = new ArrayList<>();

    /**
     * Creates a model with the default configuration. The defaults are 1000 epochs, a learning
     * rate of 0.2, and one hidden layer of 6 units.
     */
    public MLPRegression() {
        hiddenLayers.add(DEFAULT_HIDDEN_LAYER_SIZE);
    }

    /**
     * Returns the hidden-layer sizes.
     *
     * <p>The returned list is the live internal list. Modifying it changes the topology used by
     * the next call to {@link #fit(DataFrame)}.
     *
     * @return the mutable list of hidden-layer sizes
     */
    public List<Integer> getHiddenLayers() {
        return hiddenLayers;
    }

    /**
     * Replaces the hidden-layer topology.
     *
     * @param iArr the size of each hidden layer, in order from input side to output side
     */
    public void setHiddenLayers(int... iArr) {
        List<Integer> layers = new ArrayList<>(iArr.length);
        for (int size : iArr) {
            layers.add(size);
        }
        hiddenLayers = layers;
    }

    /**
     * Predicts the target value for one row.
     *
     * @param dataRow the input row
     * @return the first (and only) output of the trained network
     * @throws IllegalStateException if {@link #fit(DataFrame)} has not completed yet
     */
    public double transform(DataRow dataRow) {
        if (mlp == null) {
            throw new IllegalStateException(
                    "fit(DataFrame) must be called before transform(DataRow)");
        }
        return mlp.transform(dataRow)[0];
    }

    /**
     * Builds a new network from the current configuration and trains it on {@code dataFrame}.
     *
     * <p>The input-layer width is taken from the first row of {@code dataFrame}, so the frame
     * must contain at least one row. The output layer has a single unit, and output
     * normalization is enabled on the network. The new network replaces the current one only
     * after training returns normally. If training throws, any previously trained model is kept.
     *
     * @param dataFrame the training data
     */
    public void fit(DataFrame dataFrame) {
        int inputDimension = dataFrame.row(0).toArray().length;

        MLPWithNumericOutput network = new MLPWithNumericOutput();
        network.setNormalizeOutputs(true);
        network.setMiniBatchSize(miniBatchSize);
        network.setWeightUpdateMode(weightUpdateMode);
        network.setLearningRate(learningRate);
        network.createInputLayer(inputDimension);
        for (int size : hiddenLayers) {
            network.addHiddenLayer(size, hiddenLayerTransfer);
        }
        network.createOutputLayer(1);
        network.outputLayer.setTransfer(outputLayerTransfer);
        network.train(dataFrame, epoches);

        // Publish only a fully trained network, so a failed fit never leaves a half-built model.
        this.mlp = network;
    }

    /** @return the number of training epochs used by {@link #fit(DataFrame)} */
    public int getEpoches() {
        return epoches;
    }

    /** @param i the number of training epochs used by {@link #fit(DataFrame)} */
    public void setEpoches(int i) {
        epoches = i;
    }

    /** @return the mini-batch size passed to the network on the next {@link #fit(DataFrame)} */
    public int getMiniBatchSize() {
        return miniBatchSize;
    }

    /** @param i the mini-batch size passed to the network on the next {@link #fit(DataFrame)} */
    public void setMiniBatchSize(int i) {
        miniBatchSize = i;
    }

    /** @return the weight-update mode passed to the network on the next {@link #fit(DataFrame)} */
    public WeightUpdateMode getWeightUpdateMode() {
        return weightUpdateMode;
    }

    /**
     * @param weightUpdateMode the weight-update mode passed to the network on the next
     *     {@link #fit(DataFrame)}
     */
    public void setWeightUpdateMode(WeightUpdateMode weightUpdateMode) {
        this.weightUpdateMode = weightUpdateMode;
    }

    /** @return the transfer function applied to every hidden layer */
    public TransferFunction getHiddenLayerTransfer() {
        return hiddenLayerTransfer;
    }

    /** @param transferFunction the transfer function applied to every hidden layer */
    public void setHiddenLayerTransfer(TransferFunction transferFunction) {
        hiddenLayerTransfer = transferFunction;
    }

    /** @return the transfer function applied to the output layer */
    public TransferFunction getOutputLayerTransfer() {
        return outputLayerTransfer;
    }

    /** @param transferFunction the transfer function applied to the output layer */
    public void setOutputLayerTransfer(TransferFunction transferFunction) {
        outputLayerTransfer = transferFunction;
    }

    /** @return the learning rate passed to the network on the next {@link #fit(DataFrame)} */
    public double getLearningRate() {
        return learningRate;
    }

    /** @param d the learning rate passed to the network on the next {@link #fit(DataFrame)} */
    public void setLearningRate(double d) {
        learningRate = d;
    }
}
