package com.gengoai.apollo.ml.transform;

import com.gengoai.apollo.ml.DataSet;
import com.gengoai.apollo.ml.Datum;
import com.gengoai.apollo.ml.observation.Observation;
import com.gengoai.apollo.ml.observation.Variable;
import com.gengoai.apollo.ml.observation.VariableCollection;
import com.gengoai.apollo.ml.observation.VariableCollectionSequence;
import com.gengoai.apollo.ml.observation.VariableList;
import com.gengoai.collection.counter.Counter;
import com.gengoai.collection.counter.Counters;
import com.gengoai.collection.counter.MultiCounter;
import com.gengoai.stream.MCounterAccumulator;
import com.gengoai.stream.MMultiCounterAccumulator;
import java.lang.invoke.SerializedLambda;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import lombok.NonNull;

/* loaded from: input_file:com/gengoai/apollo/ml/transform/TFIDFTransform.class */
/**
 * Transform that replaces the variables of an observation with per-prefix TF-IDF weighted real-valued variables.
 *
 * <p>Variables are grouped by their prefix and each group is scored independently:
 * <ul>
 *   <li><b>TF</b> - the number of occurrences of a suffix within the observation divided by the total number of
 *   occurrences for the prefix, rescaled as {@code 0.5 + 0.5 * tf}.</li>
 *   <li><b>IDF</b> - {@code log((N + 0.5) / (df + 0.5))}, where {@code N} is the number of datum in which the prefix
 *   was seen during fitting and {@code df} is the number of those datum containing the suffix. Suffixes not seen
 *   during fitting use {@code df = 0}.</li>
 * </ul>
 *
 * <p>All statistics are gathered in {@link #fitAndTransform(DataSet)}; the per-prefix {@code fit}, {@code reset},
 * per-variable {@code transform} and {@code updateMetadata} hooks are intentionally no-ops in this class.
 *
 * <p>Note: no hand-written {@code $deserializeLambda$} method is declared here. That method is synthesized by the
 * Java compiler for classes containing serializable lambdas and must not be present in source.
 */
public class TFIDFTransform extends PerPrefixTransform<TFIDFTransform> {
    private static final long serialVersionUID = 1;

    /**
     * After fitting: prefix -> suffix -> smoothed inverse document frequency. (While fitting it briefly holds the
     * raw document frequencies before they are converted in place.)
     */
    private MultiCounter<String, String> prefixWordDocumentCounts;

    /**
     * After fitting: prefix -> inverse document frequency used for suffixes that were never observed for that
     * prefix. (While fitting it briefly holds the raw number of datum containing the prefix.)
     */
    private Counter<String> totalDocuments;

    /**
     * Smoothed inverse document frequency, {@code log((total + 0.5) / (frequency + 0.5))}.
     *
     * @param totalDocumentCount the number of datum in which the prefix was observed
     * @param documentFrequency  the number of those datum that contained the suffix
     * @return the smoothed inverse document frequency
     */
    private static double smoothedIdf(double totalDocumentCount, double documentFrequency) {
        return Math.log((totalDocumentCount + 0.5d) / (documentFrequency + 0.5d));
    }

    /**
     * Scales a term frequency by the inverse document frequency learned for the given prefix and suffix.
     *
     * @param prefix        the variable prefix
     * @param suffix        the variable suffix
     * @param termFrequency the (already normalized) term frequency of the suffix
     * @return the term frequency multiplied by the learned IDF of the suffix, or by the prefix's fallback IDF when
     * the suffix was not seen during fitting
     */
    protected double calculateTFIDF(String prefix, String suffix, double termFrequency) {
        if (this.prefixWordDocumentCounts.contains(prefix, suffix)) {
            return termFrequency * this.prefixWordDocumentCounts.get(prefix, suffix);
        }
        return termFrequency * this.totalDocuments.get(prefix);
    }

    /**
     * No-op: the statistics are collected over whole datum in {@link #fitAndTransform(DataSet)} instead.
     *
     * @param prefix    the prefix being fit
     * @param variables the variables having the prefix
     * @throws NullPointerException if either argument is null
     */
    @Override // com.gengoai.apollo.ml.transform.PerPrefixTransform
    protected void fit(@NonNull String prefix, @NonNull Iterable<Variable> variables) {
        if (prefix == null) {
            throw new NullPointerException("prefix is marked non-null but is null");
        }
        if (variables == null) {
            throw new NullPointerException("variables is marked non-null but is null");
        }
    }

    /**
     * Collects per-prefix document counts and per-prefix/suffix document frequencies over the data set, converts
     * them into inverse document frequencies, and then transforms the data set.
     *
     * @param dataSet the data set to fit on and transform
     * @return the result of {@code transform(dataSet)}
     */
    @Override // com.gengoai.apollo.ml.transform.AbstractSingleSourceTransform, com.gengoai.apollo.ml.transform.Transform
    public DataSet fitAndTransform(DataSet dataSet) {
        MCounterAccumulator<String> documentCounts = dataSet.getType().getStreamingContext().counterAccumulator();
        MMultiCounterAccumulator<String, String> documentFrequencies =
                dataSet.getType().getStreamingContext().multiCounterAccumulator();

        dataSet.parallelStream().forEach(datum -> {
            Map<String, List<Variable>> variablesByPrefix = splitIntoPrefixes(datum.get(this.input));
            for (Map.Entry<String, List<Variable>> entry : variablesByPrefix.entrySet()) {
                String prefix = entry.getKey();
                // Each datum counts once per prefix ...
                documentCounts.increment(prefix, 1.0d);
                // ... and once per distinct suffix, regardless of how often the suffix repeats in the datum.
                entry.getValue().stream()
                        .map(Variable::getSuffix)
                        .distinct()
                        .forEach(suffix -> documentFrequencies.increment(prefix, suffix));
            }
        });

        this.totalDocuments = documentCounts.value();
        this.prefixWordDocumentCounts = documentFrequencies.value();

        // Convert the raw counts into IDF values in place.
        for (String prefix : this.prefixWordDocumentCounts.firstKeys()) {
            final double documentCount = this.totalDocuments.get(prefix);
            this.prefixWordDocumentCounts.get(prefix)
                    .adjustValuesSelf(documentFrequency -> smoothedIdf(documentCount, documentFrequency));
            // Fallback IDF for suffixes never seen with this prefix (document frequency of zero).
            this.totalDocuments.set(prefix, smoothedIdf(documentCount, 0.0d));
        }
        return transform(dataSet);
    }

    /**
     * No-op.
     */
    @Override // com.gengoai.apollo.ml.transform.PerPrefixTransform
    protected void reset() {
    }

    /**
     * Groups the variables in the observation's variable space by their prefix.
     *
     * @param observation the observation whose variables are grouped
     * @return a map from prefix to the variables having that prefix
     */
    private Map<String, List<Variable>> splitIntoPrefixes(Observation observation) {
        return observation.getVariableSpace()
                .collect(Collectors.groupingBy(Variable::getPrefix, Collectors.toList()));
    }

    /**
     * Builds a new collection containing one real-valued variable per distinct prefix/suffix pair in the
     * observation, valued with its TF-IDF score.
     *
     * <p>NOTE(review): the decompiler reported that this method was originally {@code protected}; it is left
     * {@code public} here so the current interface is unchanged.
     *
     * @param observation the observation to transform
     * @return the TF-IDF weighted variables
     */
    @Override // com.gengoai.apollo.ml.transform.PerPrefixTransform, com.gengoai.apollo.ml.transform.AbstractSingleSourceTransform
    public VariableCollection transform(Observation observation) {
        VariableList weighted = new VariableList();
        for (Map.Entry<String, List<Variable>> entry : splitIntoPrefixes(observation).entrySet()) {
            String prefix = entry.getKey();
            Counter<String> termFrequencies = Counters.newCounter(entry.getValue().stream().map(Variable::getSuffix));
            // Normalize the counts to sum to one, then rescale each value to 0.5 + 0.5 * tf.
            termFrequencies.divideBySum().adjustValuesSelf(tf -> 0.5d + (0.5d * tf));
            termFrequencies.forEach((suffix, tf) ->
                    weighted.add(Variable.real(prefix, suffix, calculateTFIDF(prefix, suffix, tf.doubleValue()))));
        }
        return weighted;
    }

    /**
     * Transforms the observation stored under the input source and stores the result under the output source.
     * Sequence observations are transformed element by element.
     *
     * @param datum the datum to transform (modified in place)
     * @return the same datum instance
     * @throws NullPointerException if the datum is null
     */
    @Override // com.gengoai.apollo.ml.transform.AbstractSingleSourceTransform, com.gengoai.apollo.ml.transform.Transform
    public Datum transform(@NonNull Datum datum) {
        if (datum == null) {
            throw new NullPointerException("datum is marked non-null but is null");
        }
        Observation source = datum.get(this.input);
        if (source.isSequence()) {
            VariableCollectionSequence transformed = new VariableCollectionSequence();
            for (int i = 0; i < source.asSequence().size(); i++) {
                transformed.add(transform((Observation) source.asSequence().get(i)));
            }
            datum.put(this.output, (Observation) transformed);
        } else {
            datum.put(this.output, transform(source));
        }
        return datum;
    }

    /**
     * Identity: individual variables are not transformed in isolation by this class.
     *
     * @param variable the variable
     * @return the variable, unchanged
     * @throws NullPointerException if the variable is null
     */
    @Override // com.gengoai.apollo.ml.transform.PerPrefixTransform
    protected Variable transform(@NonNull Variable variable) {
        if (variable == null) {
            throw new NullPointerException("variable is marked non-null but is null");
        }
        return variable;
    }

    /**
     * No-op apart from the null check.
     *
     * @param data the data set
     * @throws NullPointerException if the data set is null
     */
    @Override // com.gengoai.apollo.ml.transform.AbstractSingleSourceTransform
    protected void updateMetadata(@NonNull DataSet data) {
        if (data == null) {
            throw new NullPointerException("data is marked non-null but is null");
        }
    }
}
