package com.gengoai.apollo.ml.transform;

import com.gengoai.apollo.math.statistics.measure.ContingencyTable;
import com.gengoai.apollo.math.statistics.measure.ContingencyTableCalculator;
import com.gengoai.apollo.ml.DataSet;
import com.gengoai.apollo.ml.observation.Observation;
import com.gengoai.collection.counter.Counter;
import com.gengoai.collection.counter.HashMapMultiCounter;
import com.gengoai.collection.counter.MultiCounter;
import com.gengoai.stream.MCounterAccumulator;
import com.gengoai.stream.MMultiCounterAccumulator;
import com.gengoai.stream.MStream;
import java.lang.invoke.SerializedLambda;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.NonNull;

/* loaded from: input_file:com/gengoai/apollo/ml/transform/ContingencyFeatureSelection.class */
/**
 * Supervised feature selection based on feature/label co-occurrence.
 *
 * <p>For every label value and every feature that co-occurs with it, a 2x2 contingency table is
 * built from the co-occurrence counts and scored with the configured
 * {@link ContingencyTableCalculator}. For each label, up to {@code numFeaturesPerClass}
 * features scoring at least {@code threshold} are kept (highest scores first). The union of the
 * kept features is retained and all other variables are removed from the output observation.
 */
public class ContingencyFeatureSelection extends AbstractSingleSourceTransform<ContingencyFeatureSelection> {
    private static final long serialVersionUID = 1;

    /** Datum source read as a variable; the variable's name is used as the label. */
    @NonNull
    private final String labelSource;
    /** Maximum number of features kept per label; must be non-negative. */
    private final int numFeaturesPerClass;
    /** Minimum (inclusive) calculator score a feature needs to be kept for a label. */
    private final double threshold;
    /** Scores the 2x2 feature/label contingency table. */
    @NonNull
    private final ContingencyTableCalculator calculator;
    /**
     * Feature names chosen by the most recent {@link #fitAndTransform(DataSet)}, or {@code null}
     * if it has not run, in which case {@link #transform(Observation)} passes observations
     * through unchanged. Retained so that data transformed after fitting receives the same
     * selection as the training data.
     */
    private HashSet<String> selectedFeatures;

    /**
     * Creates the transform.
     *
     * @param labelSource         datum source holding the label
     * @param numFeaturesPerClass maximum number of features kept per label (non-negative)
     * @param threshold           minimum score for a feature to be kept
     * @param calculator          measure applied to each feature/label contingency table
     * @throws NullPointerException if {@code labelSource} or {@code calculator} is null
     */
    public ContingencyFeatureSelection(@NonNull String labelSource,
                                       int numFeaturesPerClass,
                                       double threshold,
                                       @NonNull ContingencyTableCalculator calculator) {
        if (labelSource == null) {
            throw new NullPointerException("labelSource is marked non-null but is null");
        }
        if (calculator == null) {
            throw new NullPointerException("calculator is marked non-null but is null");
        }
        this.labelSource = labelSource;
        this.numFeaturesPerClass = numFeaturesPerClass;
        this.threshold = threshold;
        this.calculator = calculator;
    }

    /**
     * No-op. Selection needs the label source, which this single-source observation stream does
     * not carry; the work is done in {@link #fitAndTransform(DataSet)}.
     */
    @Override
    protected void fit(@NonNull MStream<Observation> observations) {
        if (observations == null) {
            throw new NullPointerException("observations is marked non-null but is null");
        }
    }

    /**
     * Counts label and feature/label co-occurrences over the data set, selects the features to
     * keep, remembers them for later {@link #transform(Observation)} calls, and writes the
     * filtered copy of each input observation to the output source.
     *
     * @param dataSet the data to fit on and transform
     * @return the data set with the filtered observation stored under the output source
     */
    @Override
    public DataSet fitAndTransform(DataSet dataSet) {
        MCounterAccumulator<String> labelCounts =
                dataSet.getType().getStreamingContext().counterAccumulator();
        MMultiCounterAccumulator<String, String> featureLabelCounts =
                dataSet.getType().getStreamingContext().multiCounterAccumulator();

        dataSet.parallelStream().forEach(datum -> {
            String label = datum.get(labelSource).asVariable().getName();
            labelCounts.increment(label, 1.0d);
            // Count locally, then merge once per datum into the shared accumulator.
            MultiCounter<String, String> local = new HashMapMultiCounter<>();
            datum.get(input).getVariableSpace()
                 .forEach(variable -> local.increment(variable.getName(), label));
            featureLabelCounts.merge(local);
        });

        this.selectedFeatures = selectFeatures(labelCounts.value(), featureLabelCounts.value());

        return dataSet.map(datum -> {
            datum.put(output, transform(datum.get(input)));
            return datum;
        });
    }

    /**
     * Scores every co-occurring (feature, label) pair and returns the union over labels of the
     * top {@code numFeaturesPerClass} features whose score is at least {@code threshold}.
     *
     * @param labelCounts        number of datums per label
     * @param featureLabelCounts co-occurrence counts keyed by feature, then label
     * @return the selected feature names
     */
    private HashSet<String> selectFeatures(Counter<String> labelCounts,
                                           MultiCounter<String, String> featureLabelCounts) {
        double total = labelCounts.sum();

        // A feature's total count does not depend on the label, so compute it once up front.
        Map<String, Double> featureTotals = new HashMap<>();
        for (String feature : featureLabelCounts.firstKeys()) {
            featureTotals.put(feature, featureLabelCounts.get(feature).sum());
        }

        HashSet<String> selected = new HashSet<>();
        for (String label : labelCounts.items()) {
            double labelCount = labelCounts.get(label);
            Map<String, Double> scores = new HashMap<>();
            for (String feature : featureLabelCounts.firstKeys()) {
                double jointCount = featureLabelCounts.get(feature, label);
                // Only features that actually co-occur with this label are candidates.
                if (jointCount > 0.0d) {
                    scores.put(feature, calculator.calculate(
                            ContingencyTable.create2X2(jointCount, labelCount, featureTotals.get(feature), total)));
                }
            }
            scores.entrySet().stream()
                  .filter(entry -> entry.getValue() >= threshold)
                  .sorted(Map.Entry.<String, Double>comparingByValue().reversed())
                  .limit(numFeaturesPerClass)
                  .forEach(entry -> selected.add(entry.getKey()));
        }
        return selected;
    }

    /**
     * Returns a copy of the observation containing only the selected features, or the
     * observation itself if no selection has been fitted yet.
     *
     * @param observation the observation to filter
     * @return the filtered copy, or {@code observation} when unfitted
     * @throws NullPointerException if {@code observation} is null
     */
    @Override
    protected Observation transform(@NonNull Observation observation) {
        if (observation == null) {
            throw new NullPointerException("observation is marked non-null but is null");
        }
        Set<String> features = this.selectedFeatures;
        if (features == null) {
            return observation;
        }
        Observation filtered = (Observation) observation.copy();
        filtered.removeVariables(variable -> !features.contains(variable.getName()));
        return filtered;
    }

    /** No metadata is produced by this transform. */
    @Override
    protected void updateMetadata(@NonNull DataSet data) {
        if (data == null) {
            throw new NullPointerException("data is marked non-null but is null");
        }
    }
}
