package com.gengoai.apollo.ml.observation;

import com.gengoai.Copyable;
import com.gengoai.Validation;
import com.gengoai.apollo.math.linalg.NDArray;
import com.gengoai.apollo.math.linalg.NDArrayFactory;
import com.gengoai.apollo.ml.encoder.Encoder;
import com.gengoai.collection.counter.Counter;
import com.gengoai.collection.counter.Counters;
import java.io.Serializable;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.stream.Stream;
import lombok.NonNull;

/* loaded from: input_file:com/gengoai/apollo/ml/observation/Classification.class */
public final class Classification implements Serializable, Observation {
    private static final long serialVersionUID = 1;
    /** Label (or index, when no encoder is present) of the highest-scoring entry, computed once at construction. */
    private final String argMax;
    /** Score vector held by this classification; a scalar input is expanded to two entries. */
    private final NDArray distribution;
    /** Optional encoder mapping indices to labels; may be {@code null}. */
    private final Encoder encoder;

    /**
     * Creates a classification from a score vector.
     *
     * <p>A scalar input {@code p} is expanded into the two-entry array {@code [1 - p, p]}.
     * A column vector is transposed. Any other shape is copied.
     *
     * @param distribution the scores; must not be {@code null}
     * @param encoder      optional encoder used to turn indices into labels; may be {@code null}
     * @throws NullPointerException if {@code distribution} is {@code null}
     */
    public Classification(@NonNull NDArray distribution, Encoder encoder) {
        if (distribution == null) {
            throw new NullPointerException("distribution is marked non-null but is null");
        }
        if (distribution.shape().isScalar()) {
            // Expand a single score p into the two-slot array [1 - p, p].
            // Index 1 is written with an explicit literal. The original reused serialVersionUID here,
            // which tied the slot index to the serialization version.
            double p = distribution.scalar();
            this.distribution = NDArrayFactory.DENSE.array(1, 2);
            this.distribution.set(0L, 1.0d - p);
            this.distribution.set(1L, p);
        } else {
            // NOTE(review): unclear from here whether T() returns a copy or a view of the caller's array.
            // Confirm before relying on isolation; the non-column branch copies explicitly via m1copy().
            this.distribution = distribution.shape().isColumnVector()
                    ? distribution.T()
                    : distribution.m1copy();
        }
        this.argMax = encoder != null
                ? encoder.decode(this.distribution.argmax())
                : Long.toString(this.distribution.argmax());
        this.encoder = encoder;
    }

    /**
     * Returns this instance, since a classification already is a classification.
     *
     * @return {@code this}
     */
    @Override // com.gengoai.apollo.ml.observation.Observation
    public Classification asClassification() {
        return this;
    }

    /**
     * Builds a counter that maps each decoded label to its score in the distribution.
     *
     * @return a new counter with one entry per distribution index
     * @throws RuntimeException whatever {@code Validation.notNull} throws when no encoder was provided
     */
    public Counter<String> asCounter() {
        Validation.notNull(this.encoder, "No Encoder was provided");
        Counter<String> counter = Counters.newCounter(new String[0]);
        // Plain increment. The original stepped by serialVersionUID, which only worked because it is 1.
        for (long i = 0; i < this.distribution.length(); i++) {
            counter.set(this.encoder.decode(i), this.distribution.get((int) i));
        }
        return counter;
    }

    /**
     * Returns the underlying distribution (not a copy).
     *
     * @return the distribution array
     */
    @Override // com.gengoai.apollo.ml.observation.Observation
    public NDArray asNDArray() {
        return this.distribution;
    }

    /**
     * Makes a deep copy of this classification via {@code Copyable.deepCopy}.
     *
     * <p>The name comes from the decompiler, which merged this method with its bridge method.
     * Other decompiled code may call it by this name, so it is left unchanged.
     *
     * @return a deep copy of this instance
     */
    public Classification m52copy() {
        return (Classification) Copyable.deepCopy(this);
    }

    /**
     * Returns the underlying distribution (not a copy).
     *
     * @return the distribution array
     */
    public NDArray distribution() {
        return this.distribution;
    }

    /**
     * Returns the best-scoring result computed at construction time.
     *
     * @return the decoded label of the argmax entry, or its index as a string when no encoder was given
     */
    public String getResult() {
        return this.argMax;
    }

    /**
     * Looks up the score of the given label.
     *
     * @param str the label to look up; must not be {@code null}
     * @return the score stored at the label's encoded index
     * @throws NullPointerException if {@code str} is {@code null}
     */
    public double getScore(@NonNull String str) {
        if (str == null) {
            throw new NullPointerException("label is marked non-null but is null");
        }
        Validation.notNull(this.encoder, "No Encoder was provided");
        // NOTE(review): behavior for labels unknown to the encoder depends on Encoder.encode — confirm.
        return this.distribution.get(this.encoder.encode(str));
    }

    /**
     * A classification has no variables.
     *
     * @return an empty stream
     */
    @Override // com.gengoai.apollo.ml.observation.Observation
    public Stream<Variable> getVariableSpace() {
        return Stream.empty();
    }

    /**
     * Always {@code true}.
     *
     * @return {@code true}
     */
    @Override // com.gengoai.apollo.ml.observation.Observation
    public boolean isClassification() {
        return true;
    }

    /**
     * Not supported.
     *
     * @param function ignored; must not be {@code null}
     * @throws NullPointerException          if {@code function} is {@code null}
     * @throws UnsupportedOperationException always, otherwise
     */
    @Override // com.gengoai.apollo.ml.observation.Observation
    public void mapVariables(@NonNull Function<Variable, Variable> function) {
        if (function == null) {
            throw new NullPointerException("mapper is marked non-null but is null");
        }
        throw new UnsupportedOperationException("Classification does not support mapping.");
    }

    /**
     * Not supported.
     *
     * @param predicate ignored; must not be {@code null}
     * @throws NullPointerException          if {@code predicate} is {@code null}
     * @throws UnsupportedOperationException always, otherwise
     */
    @Override // com.gengoai.apollo.ml.observation.Observation
    public void removeVariables(@NonNull Predicate<Variable> predicate) {
        if (predicate == null) {
            throw new NullPointerException("filter is marked non-null but is null");
        }
        throw new UnsupportedOperationException("Classification does not support filtering.");
    }

    /**
     * Renders the label-to-score counter when an encoder is present, otherwise the raw distribution.
     *
     * @return a string form of this classification
     */
    @Override
    public String toString() {
        return this.encoder == null
                ? "Classification{" + this.distribution + "}"
                : "Classification{" + asCounter() + "}";
    }

    /**
     * Not supported.
     *
     * @param consumer ignored; must not be {@code null}
     * @throws NullPointerException          if {@code consumer} is {@code null}
     * @throws UnsupportedOperationException always, otherwise
     */
    @Override // com.gengoai.apollo.ml.observation.Observation
    public void updateVariables(@NonNull Consumer<Variable> consumer) {
        if (consumer == null) {
            throw new NullPointerException("updater is marked non-null but is null");
        }
        throw new UnsupportedOperationException("Classification does not support updating.");
    }
}
