package javax.visrec.ri.ml.classification;

import deepnetts.data.MLDataItem;
import deepnetts.net.FeedForwardNetwork;
import deepnetts.net.layers.activation.ActivationType;
import deepnetts.net.loss.LossType;
import deepnetts.net.train.BackpropagationTrainer;
import deepnetts.util.Tensor;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import javax.visrec.ml.classification.AbstractMultiClassClassifier;
import javax.visrec.ml.data.DataSet;

/* loaded from: input_file:javax/visrec/ri/ml/classification/MultiClassClassifierNetwork.class */
public class MultiClassClassifierNetwork extends AbstractMultiClassClassifier<FeedForwardNetwork, float[], String> {

    /**
     * Fluent builder that assembles a deepnetts {@link FeedForwardNetwork}, optionally trains it
     * with backpropagation, and wraps it in a new {@link MultiClassClassifierNetwork}.
     *
     * <p>Defaults: learning rate {@code 0.01}, max error {@code 0.03},
     * max epochs {@link Long#MAX_VALUE}, no hidden layers, and no training set (untrained network).
     */
    public static class Builder implements javax.visrec.util.Builder<MultiClassClassifierNetwork> {
        private float learningRate = 0.01f;
        private float maxError = 0.03f;
        private long maxEpochs = Long.MAX_VALUE;
        private int inputsNum;
        private int outputsNum;
        // Empty by default so build() works even if hiddenLayers(...) is never called.
        private int[] hiddenLayers = new int[0];
        private DataSet<? extends MLDataItem> trainingSet;

        /**
         * Builds the network and returns a new classifier wrapping it.
         *
         * <p>The network consists of an input layer of {@code inputsNum} neurons. It is followed by
         * one fully connected TANH layer per entry of {@code hiddenLayers}, and then a SOFTMAX
         * output layer of {@code outputsNum} neurons trained against a cross-entropy loss. If a
         * training set was supplied, the network is trained with backpropagation before it is
         * returned; otherwise it is returned untrained.
         *
         * <p>Every call produces a new, independent classifier instance.
         *
         * @return a new classifier whose model is the freshly built (and possibly trained) network
         */
        public MultiClassClassifierNetwork build() {
            FeedForwardNetwork.Builder networkBuilder = FeedForwardNetwork.builder().addInputLayer(inputsNum);
            for (int neurons : hiddenLayers) {
                networkBuilder.addFullyConnectedLayer(neurons, ActivationType.TANH);
            }
            // NOTE(review): hiddenActivationFunction is applied after the layers were added with an
            // explicit TANH activation; it is kept as in the original, and its effect depends on deepnetts.
            networkBuilder.addOutputLayer(outputsNum, ActivationType.SOFTMAX)
                    .lossFunction(LossType.CROSS_ENTROPY)
                    .hiddenActivationFunction(ActivationType.TANH);
            FeedForwardNetwork network = networkBuilder.build();

            BackpropagationTrainer trainer = new BackpropagationTrainer(network);
            trainer.setLearningRate(learningRate).setMaxError(maxError).setMaxEpochs(maxEpochs);
            if (trainingSet != null) {
                trainer.train(trainingSet);
            }

            // A fresh instance per build() so repeated builds never share or overwrite a model.
            MultiClassClassifierNetwork classifier = new MultiClassClassifierNetwork();
            classifier.setModel(network);
            return classifier;
        }

        /**
         * Alias of {@link #build()} produced by decompilation; kept only for source compatibility.
         *
         * @return the result of {@link #build()}
         * @deprecated use {@link #build()} instead.
         */
        @Deprecated
        public MultiClassClassifierNetwork m6build() {
            return build();
        }

        /**
         * Sets the backpropagation learning rate.
         *
         * @param learningRate the learning rate (default {@code 0.01})
         * @return this builder
         */
        public Builder learningRate(float learningRate) {
            this.learningRate = learningRate;
            return this;
        }

        /**
         * Sets the error threshold passed to the trainer.
         *
         * @param maxError the max error (default {@code 0.03})
         * @return this builder
         */
        public Builder maxError(float maxError) {
            this.maxError = maxError;
            return this;
        }

        /**
         * Sets the maximum number of training epochs.
         *
         * @param maxEpochs the epoch limit (default {@link Long#MAX_VALUE})
         * @return this builder
         */
        public Builder maxEpochs(int maxEpochs) {
            this.maxEpochs = maxEpochs;
            return this;
        }

        /**
         * Sets the number of neurons in the input layer.
         *
         * @param inputsNum input layer width
         * @return this builder
         */
        public Builder inputsNum(int inputsNum) {
            this.inputsNum = inputsNum;
            return this;
        }

        /**
         * Sets the number of neurons in the (softmax) output layer, i.e. the number of classes.
         *
         * @param outputsNum output layer width
         * @return this builder
         */
        public Builder outputsNum(int outputsNum) {
            this.outputsNum = outputsNum;
            return this;
        }

        /**
         * Sets the widths of the hidden fully connected layers, in order.
         *
         * @param hiddenLayers neuron count per hidden layer; {@code null} means no hidden layers
         * @return this builder
         */
        public Builder hiddenLayers(int... hiddenLayers) {
            // Defensive copy: later changes to the caller's array must not alter the topology.
            this.hiddenLayers = hiddenLayers == null ? new int[0] : hiddenLayers.clone();
            return this;
        }

        /**
         * Sets the data set used to train the network during {@link #build()}.
         *
         * @param trainingSet training data, or {@code null} to skip training
         * @return this builder
         */
        public Builder trainingSet(DataSet<? extends MLDataItem> trainingSet) {
            this.trainingSet = trainingSet;
            return this;
        }
    }

    /**
     * Feeds {@code input} through the wrapped network and maps each output neuron's value to its label.
     *
     * <p>When the network was created by {@link Builder}, the output layer is softmax, so the values
     * are class scores. If the network has no label for an output neuron, the neuron's index
     * (as a string) is used as its key.
     *
     * <p>This method sets the input on the shared model, so it is presumably not safe for
     * concurrent calls on the same instance.
     *
     * @param input feature vector; its length should match the network's input layer
     * @return a new map from output label to output value
     * @throws NullPointerException if {@code input} is {@code null}
     * @throws IllegalStateException if no model has been set
     */
    public Map<String, Float> classify(float[] input) {
        Objects.requireNonNull(input, "input");
        FeedForwardNetwork network = (FeedForwardNetwork) getModel();
        if (network == null) {
            throw new IllegalStateException("No model set; create the classifier via MultiClassClassifierNetwork.builder()");
        }
        network.setInput(Tensor.create(1, input.length, input));
        float[] output = network.getOutput();
        String[] labels = network.getOutputLabels();

        Map<String, Float> scores = new HashMap<>();
        for (int i = 0; i < output.length; i++) {
            // Fall back to the neuron index when labels are absent or shorter than the output.
            String label = (labels != null && i < labels.length) ? labels[i] : String.valueOf(i);
            scores.put(label, output[i]);
        }
        return scores;
    }

    /**
     * Creates a new {@link Builder} with default settings.
     *
     * @return a new builder
     */
    public static Builder builder() {
        return new Builder();
    }
}
