package com.gengoai.apollo.ml.model;

import com.gengoai.apollo.ml.DataSet;
import com.gengoai.apollo.ml.Datum;
import com.gengoai.apollo.ml.ObservationMetadata;
import com.gengoai.apollo.ml.encoder.Encoder;
import com.gengoai.apollo.ml.encoder.NoOptEncoder;
import com.gengoai.apollo.ml.observation.Observation;
import com.gengoai.apollo.ml.transform.MultiInputTransform;
import com.gengoai.apollo.ml.transform.SingleSourceTransform;
import com.gengoai.apollo.ml.transform.Transform;
import com.gengoai.apollo.ml.transform.Transformer;
import com.gengoai.collection.Sets;
import com.gengoai.conversion.Cast;
import java.lang.invoke.SerializedLambda;
import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import lombok.NonNull;

/**
 * A {@link Model} that runs its input through a {@link Transformer} (an ordered collection of
 * {@link Transform}s) before delegating to a wrapped {@link Model}, and then post-processes the
 * wrapped model's outputs: any NDArray output for which an encoder is recorded in the
 * transformer's metadata is converted via the output's {@link LabelType}.
 *
 * <p>Instances are created through {@link #builder()}.
 */
public class PipelineModel implements Model {
    private static final long serialVersionUID = 1;

    /** The wrapped model that is estimated on, and applied to, the transformed data. */
    private final Model model;

    /** The preprocessing pipeline applied before the wrapped model. */
    @NonNull
    private final Transformer transformer;

    /**
     * Fluent builder that accumulates {@link Transform}s in the order they are added and then wraps
     * a model with them via {@link #build(Model)}.
     */
    public static class Builder {
        /** Transforms in insertion order; handed (as a copy) to the {@link Transformer} on build. */
        private final List<Transform> transforms = new ArrayList<>();

        /**
         * Creates a {@link PipelineModel} that wraps the given model with the transforms added so far.
         *
         * @param model the model to wrap
         * @return the new pipeline model
         * @throws NullPointerException if {@code model} is null
         */
        public PipelineModel build(@NonNull Model model) {
            if (model == null) {
                throw new NullPointerException("model is marked non-null but is null");
            }
            // Snapshot the list so that adding transforms to this builder after build() cannot
            // affect a model that was already built (in case Transformer keeps the list reference).
            List<Transform> snapshot = new ArrayList<>(this.transforms);
            return new PipelineModel(model, new Transformer(snapshot));
        }

        /**
         * Adds the given transforms, each configured to read from {@link Datum#DEFAULT_INPUT}.
         *
         * @param singleSourceTransformArr the transforms to add
         * @return this builder
         * @throws NullPointerException if the array is null
         */
        public Builder defaultInput(@NonNull SingleSourceTransform... singleSourceTransformArr) {
            if (singleSourceTransformArr == null) {
                throw new NullPointerException("transforms is marked non-null but is null");
            }
            return source(Datum.DEFAULT_INPUT, singleSourceTransformArr);
        }

        /**
         * Adds the given transforms, each configured to read from {@link Datum#DEFAULT_OUTPUT}.
         *
         * @param singleSourceTransformArr the transforms to add
         * @return this builder
         * @throws NullPointerException if the array is null
         */
        public Builder defaultOutput(@NonNull SingleSourceTransform... singleSourceTransformArr) {
            if (singleSourceTransformArr == null) {
                throw new NullPointerException("transforms is marked non-null but is null");
            }
            return source(Datum.DEFAULT_OUTPUT, singleSourceTransformArr);
        }

        /**
         * Adds each transform after configuring it (via {@code inputs(...)}) to read from the given
         * input names.
         *
         * @param strArr                 the input names
         * @param multiInputTransformArr the transforms to add
         * @return this builder
         * @throws NullPointerException if either array is null
         */
        public Builder source(@NonNull String[] strArr, @NonNull MultiInputTransform... multiInputTransformArr) {
            if (strArr == null) {
                throw new NullPointerException("inputs is marked non-null but is null");
            }
            if (multiInputTransformArr == null) {
                throw new NullPointerException("transforms is marked non-null but is null");
            }
            for (MultiInputTransform multiInputTransform : multiInputTransformArr) {
                this.transforms.add(multiInputTransform.inputs(strArr));
            }
            return this;
        }

        /**
         * Adds each transform after configuring it (via {@code source(...)}) to read from the given
         * source name.
         *
         * @param str                      the source name
         * @param singleSourceTransformArr the transforms to add
         * @return this builder
         * @throws NullPointerException if the name or the array is null
         */
        public Builder source(@NonNull String str, @NonNull SingleSourceTransform... singleSourceTransformArr) {
            if (str == null) {
                throw new NullPointerException("name is marked non-null but is null");
            }
            if (singleSourceTransformArr == null) {
                throw new NullPointerException("transforms is marked non-null but is null");
            }
            for (SingleSourceTransform singleSourceTransform : singleSourceTransformArr) {
                this.transforms.add(singleSourceTransform.source(str));
            }
            return this;
        }

        /**
         * Adds a single, already-configured transform.
         *
         * @param transform the transform to add
         * @return this builder
         * @throws NullPointerException if {@code transform} is null
         */
        public Builder transform(@NonNull Transform transform) {
            if (transform == null) {
                throw new NullPointerException("transform is marked non-null but is null");
            }
            this.transforms.add(transform);
            return this;
        }
    }

    private PipelineModel(@NonNull Model model, @NonNull Transformer transformer) {
        if (model == null) {
            throw new NullPointerException("model is marked non-null but is null");
        }
        if (transformer == null) {
            throw new NullPointerException("transformer is marked non-null but is null");
        }
        this.model = model;
        this.transformer = transformer;
    }

    /**
     * Creates a new, empty {@link Builder}.
     *
     * @return a new builder
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Creates a copy of this pipeline by copying both the wrapped model and the transformer.
     *
     * <p>NOTE: {@code mo26copy} is a decompiler-assigned alias for {@code copy()}. It is kept
     * because it must match the declaration overridden from {@link Model}.
     *
     * @return the copy
     */
    @Override
    public PipelineModel mo26copy() {
        return new PipelineModel(this.model.mo26copy(), this.transformer.mo26copy());
    }

    /**
     * Fits the transformer on the dataset and estimates the wrapped model on the transformed data.
     *
     * @param dataSet the training data
     * @throws NullPointerException if {@code dataSet} is null
     */
    @Override
    public void estimate(@NonNull DataSet dataSet) {
        if (dataSet == null) {
            throw new NullPointerException("dataset is marked non-null but is null");
        }
        this.model.estimate(this.transformer.fitAndTransform(dataSet));
    }

    /**
     * Estimates this pipeline on the dataset, then applies the wrapped model to it.
     *
     * <p>NOTE(review): the wrapped model is applied to {@code dataSet} itself, not to the
     * transformer's output, and the output decoding done by {@link #transform(DataSet)} is not
     * applied. That is only equivalent if {@code DataSet.map} mutates the dataset in place. TODO
     * confirm before changing this to {@code transform(dataSet)}.
     *
     * @param dataSet the training data
     * @return the result of the wrapped model's transform
     * @throws NullPointerException if {@code dataSet} is null
     */
    @Override
    public DataSet fitAndTransform(@NonNull DataSet dataSet) {
        if (dataSet == null) {
            throw new NullPointerException("dataset is marked non-null but is null");
        }
        estimate(dataSet);
        return this.model.transform(dataSet);
    }

    /**
     * Returns the encoder recorded in the transformer metadata for the given output when
     * {@link #hasEncoder(String)} holds, otherwise {@link NoOptEncoder#INSTANCE}.
     */
    private Encoder getEncoder(String str) {
        return hasEncoder(str) ? this.transformer.getMetadata().get(str).getEncoder() : NoOptEncoder.INSTANCE;
    }

    /** Delegates to the wrapped model. */
    @Override
    public FitParameters<?> getFitParameters() {
        return this.model.getFitParameters();
    }

    /**
     * Returns the union of the transformer's inputs and the wrapped model's inputs.
     *
     * @return the input names
     */
    @Override
    public Set<String> getInputs() {
        return Sets.union(this.transformer.getInputs(), this.model.getInputs());
    }

    /**
     * Delegates to the wrapped model.
     *
     * @param str the output name
     * @return the wrapped model's label type for that output
     * @throws NullPointerException if {@code str} is null
     */
    @Override
    public LabelType getLabelType(@NonNull String str) {
        if (str == null) {
            throw new NullPointerException("name is marked non-null but is null");
        }
        return this.model.getLabelType(str);
    }

    /**
     * Returns the wrapped model's outputs.
     *
     * @return the output names
     */
    @Override
    public Set<String> getOutputs() {
        return this.model.getOutputs();
    }

    /**
     * Returns the wrapped model, cast to the caller's expected type (unchecked).
     *
     * @param <T> the expected model type
     * @return the wrapped model
     */
    public <T extends Model> T getWrappedModel() {
        return Cast.as(this.model);
    }

    /**
     * True when the output's label type is not {@code NDArray} and the transformer's metadata
     * records a non-null encoder for it.
     */
    private boolean hasEncoder(String str) {
        if (getLabelType(str) == LabelType.NDArray) {
            return false;
        }
        ObservationMetadata metadata = this.transformer.getMetadata().get(str);
        return metadata != null && metadata.getEncoder() != null;
    }

    /**
     * Transforms a dataset: the transformer is applied to the whole dataset, each datum is run
     * through the wrapped model (with output decoding), and the metadata of every model output is
     * updated with its observation type, encoder and (for non-NDArray labels) dimension.
     *
     * @param dataSet the data to transform
     * @return the transformed dataset
     * @throws NullPointerException if {@code dataSet} is null
     */
    @Override
    public DataSet transform(@NonNull DataSet dataSet) {
        if (dataSet == null) {
            throw new NullPointerException("dataset is marked non-null but is null");
        }
        // The transformer has already been applied at the dataset level. Mapping with
        // transform(Datum) would run it a second time, so only the model step is mapped here.
        DataSet map = this.transformer.transform(dataSet).map(this::applyModel);
        for (String str : this.model.getOutputs()) {
            map.updateMetadata(str, observationMetadata -> {
                LabelType labelType = getLabelType(str);
                observationMetadata.setType(labelType.getObservationClass());
                observationMetadata.setEncoder(getEncoder(str));
                if (labelType != LabelType.NDArray) {
                    observationMetadata.setDimension(observationMetadata.getEncoder().size());
                }
            });
        }
        return map;
    }

    /**
     * Transforms a single datum: applies the transformer, then the wrapped model, then decodes
     * NDArray outputs that have an encoder.
     *
     * @param datum the datum to transform
     * @return the transformed datum
     * @throws NullPointerException if {@code datum} is null
     */
    @Override
    public Datum transform(@NonNull Datum datum) {
        if (datum == null) {
            throw new NullPointerException("datum is marked non-null but is null");
        }
        return applyModel(this.transformer.transform(datum));
    }

    /**
     * Runs the wrapped model on an already-transformed datum. For each model output that is an
     * NDArray and has an encoder, it replaces the observation with its
     * {@code LabelType.transform(...)} conversion.
     */
    private Datum applyModel(Datum transformed) {
        Datum result = this.model.transform(transformed);
        for (String output : this.model.getOutputs()) {
            Observation observation = result.get(output);
            // Skip outputs the model did not populate rather than failing with an NPE.
            if (observation != null && observation.isNDArray() && hasEncoder(output)) {
                result.put(output, getLabelType(output).transform(getEncoder(output), observation));
            }
        }
        return result;
    }
}
