package com.gengoai.apollo.ml.transform;

import com.gengoai.apollo.math.linalg.NDArray;
import com.gengoai.apollo.math.linalg.NDArrayFactory;
import com.gengoai.apollo.ml.DataSet;
import com.gengoai.apollo.ml.Datum;
import com.gengoai.apollo.ml.ObservationMetadata;
import com.gengoai.apollo.ml.observation.Observation;
import com.gengoai.collection.Sets;
import java.lang.invoke.SerializedLambda;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Stream;
import lombok.NonNull;

/**
 * Transform that horizontally stacks the {@link NDArray} observations found under one or more
 * input sources of a {@link Datum} into a single {@link NDArray} stored under an output source.
 *
 * <p>Inputs are concatenated in the order they were configured. They are held in a
 * {@link LinkedHashSet}, so duplicate names are ignored. When {@link #dropInputs()} is
 * {@code true} (the default), the input sources are removed after concatenation. The output
 * source is never removed, even when its name is also one of the inputs.
 *
 * <p>The {@link NDArrayFactory} used for stacking is captured from the data set in
 * {@link #transform(DataSet)}. Calling {@link #transform(Datum)} on an instance that has not yet
 * processed a data set therefore fails with an {@link IllegalStateException}.
 */
public class VectorConcatenation implements Transform {
    private static final long serialVersionUID = 1;
    /** Factory used to stack vectors; assigned from the data set in {@link #transform(DataSet)}. */
    protected NDArrayFactory factory;
    private final LinkedHashSet<String> inputs = Sets.linkedHashSetOf(new String[]{Datum.DEFAULT_INPUT});
    private String output = Datum.DEFAULT_INPUT;
    private boolean dropInputs = true;

    /**
     * Applies this transform to the given data set. Nothing is learned from the data, so this
     * simply delegates to {@link #transform(DataSet)}.
     *
     * @param dataSet the data set to transform (non-null)
     * @return the transformed data set
     */
    @Override // com.gengoai.apollo.ml.transform.Transform
    public DataSet fitAndTransform(@NonNull DataSet dataSet) {
        if (dataSet == null) {
            throw new NullPointerException("dataset is marked non-null but is null");
        }
        return transform(dataSet);
    }

    /** @return an unmodifiable view of the configured input sources, in concatenation order */
    @Override // com.gengoai.apollo.ml.transform.Transform
    public Set<String> getInputs() {
        return Collections.unmodifiableSet(this.inputs);
    }

    /** @return a singleton set containing the output source name */
    @Override // com.gengoai.apollo.ml.transform.Transform
    public Set<String> getOutputs() {
        return Collections.singleton(this.output);
    }

    /**
     * Replaces the input sources with the given names.
     *
     * @param strArr the input source names, in concatenation order (non-null)
     * @return this transform, for chaining
     */
    public VectorConcatenation inputs(@NonNull String... strArr) {
        if (strArr == null) {
            throw new NullPointerException("inputs is marked non-null but is null");
        }
        return inputs(Arrays.asList(strArr));
    }

    /**
     * Replaces the input sources with the given names. Duplicate names are kept only once.
     *
     * @param list the input source names, in concatenation order (non-null)
     * @return this transform, for chaining
     */
    public VectorConcatenation inputs(@NonNull List<String> list) {
        if (list == null) {
            throw new NullPointerException("inputs is marked non-null but is null");
        }
        this.inputs.clear();
        this.inputs.addAll(list);
        return this;
    }

    /**
     * Sets the source name under which the concatenated vector is stored.
     *
     * @param str the output source name (non-null)
     * @return this transform, for chaining
     */
    public VectorConcatenation output(@NonNull String str) {
        if (str == null) {
            throw new NullPointerException("output is marked non-null but is null");
        }
        this.output = str;
        return this;
    }

    /**
     * Concatenates the datum's input observations (via {@code asNDArray()}) with
     * {@code factory.hstack} and stores the result under the output source. If
     * {@link #dropInputs()} is set, the inputs are then removed, except one whose name equals
     * the output.
     *
     * @param datum the datum to transform in place (non-null)
     * @return the same datum instance
     * @throws IllegalStateException if no {@link NDArrayFactory} has been set yet
     */
    @Override // com.gengoai.apollo.ml.transform.Transform
    public Datum transform(@NonNull Datum datum) {
        if (datum == null) {
            throw new NullPointerException("datum is marked non-null but is null");
        }
        if (this.factory == null) {
            throw new IllegalStateException(
                "No NDArrayFactory set: call transform(DataSet) or fitAndTransform(DataSet) first");
        }
        List<NDArray> vectors = new ArrayList<>(this.inputs.size());
        for (String input : this.inputs) {
            vectors.add(datum.get(input).asNDArray());
        }
        datum.put(this.output, (Observation) this.factory.hstack(vectors));
        if (this.dropInputs) {
            for (String input : this.inputs) {
                // The output may reuse an input's name; never delete the result just stored.
                if (!this.output.equals(input)) {
                    datum.remove(input);
                }
            }
        }
        return datum;
    }

    /**
     * Captures the data set's {@link NDArrayFactory}, maps every datum through
     * {@link #transform(Datum)}, and rewrites the metadata. The output's dimension becomes the
     * sum of the input dimensions, its type becomes {@link NDArray}, and its encoder is cleared.
     *
     * @param dataSet the data set to transform (non-null)
     * @return the mapped data set
     * @throws ArithmeticException if the summed dimension does not fit in an {@code int}
     */
    @Override // com.gengoai.apollo.ml.transform.Transform
    public DataSet transform(@NonNull DataSet dataSet) {
        if (dataSet == null) {
            throw new NullPointerException("dataset is marked non-null but is null");
        }
        this.factory = dataSet.getNDArrayFactory();
        DataSet data = dataSet.map(this::transform);
        // Fail loudly rather than silently wrapping to a bogus (possibly negative) dimension.
        int dimension = Math.toIntExact(data.getMetadata()
                                            .entrySet()
                                            .stream()
                                            .filter(e -> this.inputs.contains(e.getKey()))
                                            .mapToLong(e -> e.getValue().getDimension())
                                            .sum());
        if (this.dropInputs) {
            for (String input : this.inputs) {
                // Mirror transform(Datum): keep the output's metadata when it shares an input's name.
                if (!this.output.equals(input)) {
                    data.removeMetadata(input);
                }
            }
        }
        data.updateMetadata(this.output, metadata -> {
            metadata.setDimension(dimension);
            metadata.setType(NDArray.class);
            metadata.setEncoder(null);
        });
        return data;
    }

    /** @return {@code true} if input sources are removed after concatenation */
    public boolean dropInputs() {
        return this.dropInputs;
    }

    /**
     * Sets whether input sources are removed after concatenation.
     *
     * @param z {@code true} to remove inputs (the output source is always kept)
     * @return this transform, for chaining
     */
    public VectorConcatenation dropInputs(boolean z) {
        this.dropInputs = z;
        return this;
    }
}
