package com.gengoai.apollo.ml.model;

import com.gengoai.apollo.math.linalg.NDArray;
import com.gengoai.apollo.ml.encoder.Encoder;
import com.gengoai.apollo.ml.observation.Classification;
import com.gengoai.apollo.ml.observation.Observation;
import com.gengoai.apollo.ml.observation.Sequence;
import com.gengoai.apollo.ml.observation.Variable;
import com.gengoai.apollo.ml.observation.VariableSequence;
import lombok.NonNull;

/* loaded from: input_file:com/gengoai/apollo/ml/model/LabelType.class */
/**
 * Enumerates the kinds of labels a model can work with and converts raw observations (e.g. model
 * output) into the {@link Observation} type expected for that kind of label.
 */
public enum LabelType {
    /**
     * Two-class classification. Accepts {@link Classification} observations as-is and wraps
     * {@link com.gengoai.apollo.math.linalg.NDArray} observations in a {@link Classification}.
     */
    BinaryClassification {
        @Override
        public Observation transform(Encoder encoder, @NonNull Observation observation) {
            if (observation == null) {
                throw new NullPointerException("observation is marked non-null but is null");
            }
            return LabelType.transformToClassification(encoder, observation);
        }

        @Override
        public Class<? extends Observation> getObservationClass() {
            return Classification.class;
        }
    },
    /**
     * Classification with more than two classes. Conversion is identical to
     * {@link #BinaryClassification}.
     */
    MultiClassClassification {
        @Override
        public Observation transform(Encoder encoder, @NonNull Observation observation) {
            if (observation == null) {
                throw new NullPointerException("observation is marked non-null but is null");
            }
            return LabelType.transformToClassification(encoder, observation);
        }

        @Override
        public Class<? extends Observation> getObservationClass() {
            return Classification.class;
        }
    },
    /**
     * Sequence labeling. Accepts sequence observations as-is and decodes NDArray observations
     * row-by-row into a {@link VariableSequence}.
     */
    Sequence {
        @Override
        public Observation transform(Encoder encoder, @NonNull Observation observation) {
            if (observation == null) {
                throw new NullPointerException("observation is marked non-null but is null");
            }
            if (observation.isSequence()) {
                return observation;
            }
            if (!observation.isNDArray()) {
                throw new IllegalArgumentException("Expecting Sequence or NDArray, but found: " + observation.getClass());
            }
            return LabelType.decodeSequence(encoder, observation.asNDArray());
        }

        @Override
        public Class<? extends Observation> getObservationClass() {
            // Class literal resolves to the imported observation type, not this enum constant.
            return Sequence.class;
        }
    },
    /**
     * Raw NDArray labels. Only NDArray observations are accepted and they are returned unchanged.
     */
    NDArray {
        @Override
        public Observation transform(Encoder encoder, @NonNull Observation observation) {
            if (observation == null) {
                throw new NullPointerException("observation is marked non-null but is null");
            }
            if (observation.isNDArray()) {
                return observation;
            }
            throw new IllegalArgumentException("Expecting NDArray, but found: " + observation.getClass());
        }

        @Override
        public Class<? extends Observation> getObservationClass() {
            // Class literal resolves to the imported linalg NDArray type, not this enum constant.
            return NDArray.class;
        }
    };

    /**
     * Chooses between binary and multi-class classification.
     *
     * @param numberOfLabels the number of labels (presumably the size of the label vocabulary)
     * @return {@link #MultiClassClassification} when {@code numberOfLabels > 2}, otherwise
     *     {@link #BinaryClassification}
     */
    public static LabelType classificationType(int numberOfLabels) {
        return numberOfLabels > 2 ? MultiClassClassification : BinaryClassification;
    }

    /**
     * Shared conversion for the classification label types.
     *
     * @param encoder the encoder passed to the {@link Classification} constructor
     * @param observation the observation to convert (non-null; checked by callers)
     * @return the observation itself if it is already a classification, otherwise a new
     *     {@link Classification} built from its NDArray form
     * @throws IllegalArgumentException if the observation is neither a classification nor an NDArray
     */
    private static Observation transformToClassification(Encoder encoder, Observation observation) {
        if (observation.isClassification()) {
            return observation;
        }
        if (observation.isNDArray()) {
            return new Classification(observation.asNDArray(), encoder);
        }
        throw new IllegalArgumentException("Expecting Classification or NDArray, but found: " + observation.getClass());
    }

    /**
     * Decodes every row of {@code matrix} into one element of a {@link VariableSequence}.
     *
     * <p>The two input layouts are handled as follows:
     * <ul>
     *   <li>Single column: each row's value is passed to {@link Encoder#decode} and the result is
     *       added as a binary variable.</li>
     *   <li>Multiple columns: the column index of the row's maximum value is decoded. That label
     *       is added as a real-valued variable whose value is the maximum.</li>
     * </ul>
     *
     * @param encoder encoder used to map indices back to label names
     * @param matrix the NDArray to decode, one sequence element per row
     * @return a sequence with exactly {@code matrix.rows()} elements (empty when there are no rows)
     */
    private static VariableSequence decodeSequence(Encoder encoder, NDArray matrix) {
        VariableSequence sequence = new VariableSequence();
        // Loop-invariant: the layout is the same for every row.
        boolean columnVector = matrix.columns() == 1;
        for (int row = 0; row < matrix.rows(); row++) {
            if (columnVector) {
                // With a single column, the linear index equals the row index.
                sequence.add(Variable.binary(encoder.decode(matrix.get((long) row))));
            } else {
                // Argmax over this row only; previously the whole array was used for every row.
                NDArray scores = matrix.getRow(row);
                long best = scores.argmax();
                sequence.add(Variable.real(encoder.decode(best), scores.get(best)));
            }
        }
        return sequence;
    }

    /**
     * Returns the observation class that {@link #transform} produces for this label type.
     *
     * @return the observation class associated with this label type
     */
    public abstract Class<? extends Observation> getObservationClass();

    /**
     * Converts an observation into the form expected for this label type.
     *
     * @param encoder encoder used when the observation must be decoded or wrapped
     * @param observation the observation to convert; must not be null
     * @return the converted observation
     * @throws NullPointerException if {@code observation} is null
     * @throws IllegalArgumentException if the observation's type cannot be converted
     */
    public abstract Observation transform(Encoder encoder, @NonNull Observation observation);
}
