package com.gengoai.apollo.ml.data.sampling;

import com.gengoai.apollo.ml.DataSet;
import com.gengoai.apollo.ml.Datum;
import com.gengoai.apollo.ml.InMemoryDataSet;
import com.gengoai.collection.counter.Counter;
import com.gengoai.stream.MStream;
import java.io.Serializable;
import java.lang.invoke.SerializedLambda;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.stream.Stream;
import lombok.NonNull;

/**
 * Balances a {@link DataSet} by over-sampling its classes.
 * <p>
 * Class labels are the variable names of the observation named at construction time. Every class is grown to
 * the size of the most frequent class. For a class with {@code count} examples and a target of {@code target}
 * examples, all of the class's examples are emitted {@code target / count} times. Then {@code target % count}
 * additional examples are drawn from the class without replacement.
 * <p>
 * Emitted examples are copies ({@code Datum.copy()}) of the input examples, and the result is an
 * {@link InMemoryDataSet} sharing the input's metadata and NDArray factory.
 */
public class OverSampling extends BaseObservationDataSetSampler implements Serializable {
    private static final long serialVersionUID = 1;

    /**
     * Creates an over-sampler that reads class labels from the given observation.
     *
     * @param observationName the name of the observation whose variable names are treated as class labels
     * @throws NullPointerException if {@code observationName} is null
     */
    public OverSampling(@NonNull String observationName) {
        super(observationName);
        if (observationName == null) {
            throw new NullPointerException("observationName is marked non-null but is null");
        }
    }

    /**
     * Produces a class-balanced copy of the given data set by replicating minority-class examples.
     *
     * @param dataSet the data set to balance
     * @return a new in-memory data set in which every class has as many examples as the largest class
     * @throws NullPointerException if {@code dataSet} is null
     */
    @Override // com.gengoai.apollo.ml.data.sampling.DataSetSampler
    public DataSet sample(@NonNull DataSet dataSet) {
        if (dataSet == null) {
            throw new NullPointerException("dataSet is marked non-null but is null");
        }
        Counter<String> classDistribution = calculateClassDistribution(dataSet);
        // Over-sampling grows every class up to the size of the LARGEST class. Targeting the minimum
        // would only ever shrink classes, which is under-sampling.
        int targetCount = (int) classDistribution.maximumCount();
        List<Datum> output = new ArrayList<>();
        for (String label : classDistribution.items()) {
            // All examples carrying this label, copied so the input data set is never mutated.
            // Cached because the stream is traversed several times below.
            MStream<Datum> examples = dataSet.stream()
                                             .filter(datum -> datum.get(getObservationName())
                                                                   .getVariableSpace()
                                                                   .map(v -> v.getName())
                                                                   .anyMatch(label::equals))
                                             .map(Datum::copy)
                                             .cache();
            int count = (int) examples.count();
            if (count == 0) {
                // Nothing to replicate; also prevents a division by zero / non-terminating replication.
                continue;
            }
            // Emit whole copies of the class as many times as they fit into the target...
            int fullCopies = targetCount / count;
            for (int copy = 0; copy < fullCopies; copy++) {
                examples.forEach(output::add);
            }
            // ...then top up with a sample drawn without replacement (remainder < count, so this is valid).
            int remainder = targetCount % count;
            if (remainder > 0) {
                examples.sample(false, remainder).forEach(output::add);
            }
        }
        return new InMemoryDataSet(output, dataSet.getMetadata(), dataSet.getNDArrayFactory());
    }
}
