package com.gengoai.apollo.ml;

import com.gengoai.Validation;
import com.gengoai.apollo.ml.feature.ObservationExtractor;
import com.gengoai.function.SerializableFunction;
import com.gengoai.stream.MStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Objects;
import java.util.stream.Stream;
import lombok.NonNull;

/* loaded from: input_file:com/gengoai/apollo/ml/DataSetGenerator.class */
public class DataSetGenerator<T> implements SerializableFunction<T, Datum> {
    private static final long serialVersionUID = 1;
    private final List<GeneratorInfo<T>> generators = new ArrayList();
    private final DataSetType dataSetType;

    /* loaded from: input_file:com/gengoai/apollo/ml/DataSetGenerator$Builder.class */
    /**
     * Fluent builder that accumulates named observation sources and produces a
     * {@link DataSetGenerator} from them.
     *
     * <p>Each registration method appends one {@link GeneratorInfo} to {@link #generators} and
     * returns this builder so calls can be chained. Registrations are kept in call order.
     *
     * @param <T> the type of object the resulting generator converts into a {@link Datum}
     */
    public static class Builder<T> {
        /** Sources registered so far, in registration order. */
        protected final List<GeneratorInfo<T>> generators = new ArrayList<>();

        /** Data set type handed to the generator; starts as {@link DataSetType#InMemory}. */
        @NonNull
        protected DataSetType dataSetType = DataSetType.InMemory;

        /**
         * Creates the generator from the sources registered so far.
         *
         * <p>The emptiness precondition is enforced through {@code Validation.checkState}; the
         * exact exception it raises is defined by that helper.
         *
         * @return a new generator configured with this builder's data set type and sources
         */
        public DataSetGenerator<T> build() {
            Validation.checkState(!this.generators.isEmpty(), "No extractors have been specified.");
            return new DataSetGenerator<>(this.dataSetType, this.generators);
        }

        /**
         * Sets the type of data set the generator will create.
         *
         * @param dataSetType the data set type; must not be null
         * @return this builder
         * @throws NullPointerException if {@code dataSetType} is null
         */
        public Builder<T> dataSetType(@NonNull DataSetType dataSetType) {
            // Fail at the call site instead of letting a null surface later inside build().
            if (dataSetType == null) {
                throw new NullPointerException("dataSetType is marked non-null but is null");
            }
            this.dataSetType = dataSetType;
            return this;
        }

        /**
         * Registers a sequence source under {@link Datum#DEFAULT_INPUT}.
         *
         * @param observationExtractor extractor applied to the produced sequence
         * @param serializableFunction converts an input object into the list handed to the extractor
         * @return this builder
         * @throws NullPointerException if either argument is null
         */
        public Builder<T> defaultInput(@NonNull ObservationExtractor<? super T> observationExtractor, @NonNull SerializableFunction<? super T, List<? extends T>> serializableFunction) {
            if (observationExtractor == null) {
                throw new NullPointerException("extractor is marked non-null but is null");
            }
            if (serializableFunction == null) {
                throw new NullPointerException("toSequence is marked non-null but is null");
            }
            return register(Datum.DEFAULT_INPUT, observationExtractor, serializableFunction);
        }

        /**
         * Registers a single-observation source under {@link Datum#DEFAULT_INPUT}.
         *
         * @param observationExtractor extractor applied directly to the input object
         * @return this builder
         * @throws NullPointerException if {@code observationExtractor} is null
         */
        public Builder<T> defaultInput(@NonNull ObservationExtractor<? super T> observationExtractor) {
            if (observationExtractor == null) {
                throw new NullPointerException("extractor is marked non-null but is null");
            }
            return register(Datum.DEFAULT_INPUT, observationExtractor, null);
        }

        /**
         * Registers a sequence source under {@link Datum#DEFAULT_OUTPUT}.
         *
         * @param observationExtractor extractor applied to the produced sequence
         * @param serializableFunction converts an input object into the list handed to the extractor
         * @return this builder
         * @throws NullPointerException if either argument is null
         */
        public Builder<T> defaultOutput(@NonNull ObservationExtractor<? super T> observationExtractor, @NonNull SerializableFunction<? super T, List<? extends T>> serializableFunction) {
            if (observationExtractor == null) {
                throw new NullPointerException("extractor is marked non-null but is null");
            }
            if (serializableFunction == null) {
                throw new NullPointerException("toSequence is marked non-null but is null");
            }
            return register(Datum.DEFAULT_OUTPUT, observationExtractor, serializableFunction);
        }

        /**
         * Registers a single-observation source under {@link Datum#DEFAULT_OUTPUT}.
         *
         * @param observationExtractor extractor applied directly to the input object
         * @return this builder
         * @throws NullPointerException if {@code observationExtractor} is null
         */
        public Builder<T> defaultOutput(@NonNull ObservationExtractor<? super T> observationExtractor) {
            if (observationExtractor == null) {
                throw new NullPointerException("extractor is marked non-null but is null");
            }
            return register(Datum.DEFAULT_OUTPUT, observationExtractor, null);
        }

        /**
         * Registers a single-observation source under the given name.
         *
         * @param str                  the name the extracted value is stored under
         * @param observationExtractor extractor applied directly to the input object
         * @return this builder
         * @throws NullPointerException if either argument is null
         */
        public Builder<T> source(@NonNull String str, @NonNull ObservationExtractor<? super T> observationExtractor) {
            if (str == null) {
                throw new NullPointerException("name is marked non-null but is null");
            }
            if (observationExtractor == null) {
                throw new NullPointerException("extractor is marked non-null but is null");
            }
            return register(str, observationExtractor, null);
        }

        /**
         * Registers a sequence source under the given name.
         *
         * @param str                  the name the extracted value is stored under
         * @param observationExtractor extractor applied to the produced sequence
         * @param serializableFunction converts an input object into the list handed to the extractor
         * @return this builder
         * @throws NullPointerException if any argument is null
         */
        public Builder<T> source(@NonNull String str, @NonNull ObservationExtractor<? super T> observationExtractor, @NonNull SerializableFunction<? super T, List<? extends T>> serializableFunction) {
            if (str == null) {
                throw new NullPointerException("name is marked non-null but is null");
            }
            if (observationExtractor == null) {
                throw new NullPointerException("extractor is marked non-null but is null");
            }
            if (serializableFunction == null) {
                throw new NullPointerException("toSequence is marked non-null but is null");
            }
            return register(str, observationExtractor, serializableFunction);
        }

        /**
         * Appends one source description; callers have already validated their arguments.
         * Kept private so subclasses overriding the public methods cannot alter how the
         * convenience overloads record their sources.
         */
        private Builder<T> register(String name, ObservationExtractor<? super T> extractor, SerializableFunction<? super T, List<? extends T>> toSequence) {
            this.generators.add(new GeneratorInfo<>(name, extractor, toSequence));
            return this;
        }
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /* loaded from: input_file:com/gengoai/apollo/ml/DataSetGenerator$GeneratorInfo.class */
    /**
     * Immutable description of one source of a generated {@link Datum}: the name the value is
     * stored under, the extractor that produces it, and an optional function that turns the
     * input object into a list before extraction.
     *
     * @param <T> the type of object the extractor is applied to
     */
    public static final class GeneratorInfo<T> implements Serializable {
        private static final long serialVersionUID = 1;

        @NonNull
        private final String name;

        @NonNull
        private final ObservationExtractor<? super T> extractor;

        /** May be null; null means the extractor is applied to the input object itself. */
        private final SerializableFunction<? super T, List<? extends T>> toSequence;

        /**
         * Creates a source description.
         *
         * @param str                  the name the extracted value is stored under
         * @param observationExtractor the extractor producing the value
         * @param serializableFunction optional input-to-list conversion; may be null
         * @throws NullPointerException if {@code str} or {@code observationExtractor} is null
         */
        public GeneratorInfo(@NonNull String str, @NonNull ObservationExtractor<? super T> observationExtractor, SerializableFunction<? super T, List<? extends T>> serializableFunction) {
            if (str == null) {
                throw new NullPointerException("name is marked non-null but is null");
            }
            if (observationExtractor == null) {
                throw new NullPointerException("extractor is marked non-null but is null");
            }
            this.name = str;
            this.extractor = observationExtractor;
            this.toSequence = serializableFunction;
        }

        /** @return the name the extracted value is stored under */
        @NonNull
        public String getName() {
            return this.name;
        }

        /** @return the extractor producing the value */
        @NonNull
        public ObservationExtractor<? super T> getExtractor() {
            return this.extractor;
        }

        /** @return the input-to-list conversion, or null when none was supplied */
        public SerializableFunction<? super T, List<? extends T>> getToSequence() {
            return this.toSequence;
        }

        /**
         * Two instances are equal when their name, extractor and sequence function are all equal
         * (null-safe comparison on each component).
         */
        @Override
        public boolean equals(Object obj) {
            if (obj == this) {
                return true;
            }
            if (!(obj instanceof GeneratorInfo)) {
                return false;
            }
            // Wildcard instead of a raw type: the element type is irrelevant for comparison.
            GeneratorInfo<?> other = (GeneratorInfo<?>) obj;
            return Objects.equals(this.name, other.name)
                    && Objects.equals(this.extractor, other.extractor)
                    && Objects.equals(this.toSequence, other.toSequence);
        }

        /**
         * Combines the three components with multiplier 59, substituting 43 for a null component,
         * so the produced values match the previous implementation exactly.
         */
        @Override
        public int hashCode() {
            final int prime = 59;
            int result = 1;
            result = result * prime + (this.name == null ? 43 : this.name.hashCode());
            result = result * prime + (this.extractor == null ? 43 : this.extractor.hashCode());
            result = result * prime + (this.toSequence == null ? 43 : this.toSequence.hashCode());
            return result;
        }

        @Override
        public String toString() {
            return "DataSetGenerator.GeneratorInfo(name=" + this.name + ", extractor=" + this.extractor + ", toSequence=" + this.toSequence + ")";
        }
    }

    /**
     * Entry point for configuring a generator.
     *
     * @param <T> the type of object the eventual generator will convert
     * @return a newly created {@link Builder}
     */
    public static <T> Builder<T> builder() {
        final Builder<T> fresh = new Builder<>();
        return fresh;
    }

    /**
     * Creates a generator for the given data set type, copying the supplied source descriptions
     * into this instance's own list.
     *
     * @param dataSetType the type used to create data sets
     * @param collection  the source descriptions to copy
     * @throws NullPointerException if either argument is null
     */
    protected DataSetGenerator(@NonNull DataSetType dataSetType, @NonNull Collection<GeneratorInfo<T>> collection) {
        if (dataSetType == null) {
            throw new NullPointerException("dataSetType is marked non-null but is null");
        }
        if (collection == null) {
            throw new NullPointerException("generators is marked non-null but is null");
        }
        // Both arguments are validated above; the two assignments are independent of each other.
        generators.addAll(collection);
        this.dataSetType = dataSetType;
    }

    /**
     * Converts one input object into a {@link Datum} by running every registered source against it.
     *
     * <p>For a source without a sequence function the extractor's {@code extractObservation} is
     * applied to the object itself; otherwise the sequence function is applied first and its
     * result is passed to the extractor's {@code extractSequence}. Each result is stored in the
     * datum under the source's name.
     *
     * @param t the object to convert
     * @return a new datum holding one entry per registered source
     */
    @Override
    public Datum apply(T t) {
        Datum datum = new Datum();
        for (GeneratorInfo<T> info : this.generators) {
            SerializableFunction<? super T, List<? extends T>> toSequence = info.getToSequence();
            if (toSequence == null) {
                datum.put(info.getName(), info.getExtractor().extractObservation(t));
            } else {
                datum.put(info.getName(), info.getExtractor().extractSequence(toSequence.apply(t)));
            }
        }
        return datum;
    }

    /**
     * Builds a data set from a collection by mapping each element through this generator and
     * passing the resulting stream to the configured data set type.
     *
     * @param collection the objects to convert
     * @return the data set created by the configured {@link DataSetType}
     * @throws NullPointerException if {@code collection} is null
     */
    public DataSet generate(@NonNull Collection<? extends T> collection) {
        if (collection == null) {
            throw new NullPointerException("data is marked non-null but is null");
        }
        // This generator is itself the mapping function from T to Datum.
        Stream<Datum> datums = collection.stream().map(this);
        return dataSetType.create(datums);
    }

    /**
     * Builds a data set from an {@link MStream} by mapping each element through this generator
     * and passing the mapped stream to the configured data set type.
     *
     * @param mStream the stream of objects to convert
     * @return the data set created by the configured {@link DataSetType}
     * @throws NullPointerException if {@code mStream} is null
     */
    public DataSet generate(@NonNull MStream<? extends T> mStream) {
        if (mStream == null) {
            throw new NullPointerException("data is marked non-null but is null");
        }
        // This generator is itself the mapping function from T to Datum.
        MStream<Datum> datums = mStream.map(this);
        return dataSetType.create(datums);
    }

    /* renamed from: apply, reason: collision with other method in class */
    /**
     * Untyped entry point that forwards to {@link #apply(Object)}.
     *
     * <p>NOTE(review): per the decompiler's rename note this appears to be the erased bridge
     * for {@code apply}; confirm before relying on the name.
     *
     * @param obj the object to convert; expected to be an instance of {@code T}
     * @return the datum produced by {@code apply}
     */
    @SuppressWarnings("unchecked") // T is erased at runtime; the caller guarantees the element type.
    public /* bridge */ /* synthetic */ Object m12apply(Object obj) {
        // The argument is the element to convert, not a generator: cast to T.
        return apply((T) obj);
    }
}
