package org.tribuo;

import com.oracle.labs.mlrg.olcut.provenance.Provenancable;
import com.oracle.labs.mlrg.olcut.util.MutableLong;
import java.io.Serializable;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Queue;
import java.util.Set;
import java.util.SplittableRandom;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.regex.Pattern;
import org.tribuo.Output;
import org.tribuo.provenance.DataProvenance;
import org.tribuo.provenance.DatasetProvenance;
import org.tribuo.transform.TransformStatistics;
import org.tribuo.transform.Transformation;
import org.tribuo.transform.TransformationMap;
import org.tribuo.transform.Transformer;
import org.tribuo.transform.TransformerMap;
import org.tribuo.util.Util;

/* loaded from: input_file:org/tribuo/Dataset.class */
/**
 * Base class for an in-memory collection of {@link Example}s with a common output type.
 *
 * <p>Iteration follows insertion order unless {@link #shuffle(boolean)} has installed a
 * permutation of the example indices.
 *
 * @param <T> the output type of the examples.
 */
public abstract class Dataset<T extends Output<T>> implements Iterable<Example<T>>, Provenancable<DatasetProvenance>, Serializable {
    private static final long serialVersionUID = 2;

    private static final Logger logger = Logger.getLogger(Dataset.class.getName());

    /**
     * RNG used by {@link #shuffle(boolean)}. It is shared by every dataset instance, and
     * {@link SplittableRandom} is not thread-safe, so every use must synchronize on it.
     */
    private static final SplittableRandom rng = new SplittableRandom(Trainer.DEFAULT_SEED);

    /** The examples in this dataset, in insertion order. */
    protected final List<Example<T>> data;

    /** Provenance of the source this dataset was built from. */
    protected final DataProvenance sourceProvenance;

    /** Factory for the output type of this dataset. */
    protected final OutputFactory<T> outputFactory;

    /**
     * Permutation of example indices used for iteration, or {@code null} for insertion order.
     * Written by {@link #shuffle(boolean)} and read by {@link #iterator()}, both synchronized on {@code this}.
     */
    protected int[] indices;

    /**
     * Iterates a dataset following a fixed array of example indices.
     *
     * <p>The iterator keeps its own reference to the index array it was given, so a later
     * {@link Dataset#shuffle(boolean)} (which reassigns the dataset's field) does not disturb it.
     */
    private static final class ShuffleIterator<T extends Output<T>> implements Iterator<Example<T>> {
        private final Dataset<T> data;
        private final int[] indices;
        private int index = 0;

        ShuffleIterator(Dataset<T> dataset, int[] indices) {
            this.data = dataset;
            this.indices = indices;
        }

        @Override
        public boolean hasNext() {
            return index < indices.length;
        }

        /**
         * Returns the next example in permutation order.
         *
         * @throws NoSuchElementException if the permutation is exhausted (required by the
         *     {@link Iterator} contract; previously an {@link ArrayIndexOutOfBoundsException} escaped).
         */
        @Override
        public Example<T> next() {
            if (!hasNext()) {
                throw new NoSuchElementException("Shuffled iteration exhausted after " + indices.length + " examples.");
            }
            return data.getExample(indices[index++]);
        }
    }

    /**
     * Creates an empty dataset.
     *
     * <p>NOTE(review): the decompiler reports this constructor was originally {@code protected}.
     * It is left {@code public} here so the visible interface is unchanged.
     *
     * @param dataProvenance the provenance of the data source.
     * @param outputFactory the factory for the output type.
     */
    public Dataset(DataProvenance dataProvenance, OutputFactory<T> outputFactory) {
        this.data = new ArrayList<>();
        this.indices = null;
        this.sourceProvenance = dataProvenance;
        this.outputFactory = outputFactory;
    }

    /**
     * Creates an empty dataset using the provenance and output factory of the supplied source.
     * The source's examples are not read here.
     *
     * @param dataSource the source whose provenance and output factory are used.
     */
    protected Dataset(DataSource<T> dataSource) {
        this(dataSource.getProvenance(), dataSource.getOutputFactory());
    }

    /**
     * Returns a short description of the data source.
     *
     * @return {@code "Dataset(source=<provenance>)"}.
     */
    public String getSourceDescription() {
        return "Dataset(source=" + this.sourceProvenance.toString() + ")";
    }

    /**
     * Returns the provenance of the source this dataset was built from.
     *
     * @return the source provenance.
     */
    public DataProvenance getSourceProvenance() {
        return this.sourceProvenance;
    }

    /**
     * Returns an unmodifiable view of the examples in insertion order.
     *
     * @return a read-only view backed by this dataset's example list.
     */
    public List<Example<T>> getData() {
        return Collections.unmodifiableList(this.data);
    }

    /**
     * Returns the output factory for this dataset.
     *
     * @return the output factory.
     */
    public OutputFactory<T> getOutputFactory() {
        return this.outputFactory;
    }

    /**
     * Returns the set of outputs present in this dataset.
     *
     * @return the outputs.
     */
    public abstract Set<T> getOutputs();

    /**
     * Returns the example at the given insertion-order index.
     *
     * @param i the index.
     * @return the example at {@code i}.
     * @throws IllegalArgumentException if {@code i} is negative or not less than {@link #size()}.
     */
    public Example<T> getExample(int i) {
        if (i < 0 || i >= size()) {
            throw new IllegalArgumentException("Example index " + i + " is out of bounds.");
        }
        return this.data.get(i);
    }

    /**
     * Returns the number of examples in this dataset.
     *
     * @return the example count.
     */
    public int size() {
        return this.data.size();
    }

    /**
     * Enables or disables shuffled iteration.
     *
     * <p>When enabled, a fresh random permutation of the example indices is drawn and used by
     * subsequent calls to {@link #iterator()}. When disabled, iteration reverts to insertion order.
     *
     * @param shuffle {@code true} to draw a new permutation, {@code false} to clear it.
     */
    public synchronized void shuffle(boolean shuffle) {
        if (shuffle) {
            int[] permutation;
            // The RNG is static and shared across all datasets; locking "this" alone does not
            // protect it from another dataset shuffling concurrently.
            synchronized (rng) {
                permutation = Util.randperm(this.data.size(), rng);
            }
            this.indices = permutation;
        } else {
            this.indices = null;
        }
    }

    /**
     * Returns the immutable output domain, with ids.
     *
     * @return the output id info.
     */
    public abstract ImmutableOutputInfo<T> getOutputIDInfo();

    /**
     * Returns the output domain.
     *
     * @return the output info.
     */
    public abstract OutputInfo<T> getOutputInfo();

    /**
     * Returns the immutable feature domain, with ids.
     *
     * @return the feature id map.
     */
    public abstract ImmutableFeatureMap getFeatureIDMap();

    /**
     * Returns the feature domain.
     *
     * @return the feature map.
     */
    public abstract FeatureMap getFeatureMap();

    /**
     * Iterates the examples. If {@link #shuffle(boolean)} installed a permutation, that order is
     * used; otherwise insertion order is used.
     *
     * @return an iterator over the examples.
     */
    @Override
    public synchronized Iterator<Example<T>> iterator() {
        if (this.indices == null) {
            return this.data.iterator();
        }
        return new ShuffleIterator<T>(this, this.indices);
    }

    @Override
    public String toString() {
        return "Dataset(source=" + this.sourceProvenance + ")";
    }

    /**
     * Fits the transformations in the supplied map to this dataset, ignoring implicit zeros.
     *
     * @param transformationMap the transformations to fit.
     * @return the fitted transformers.
     * @throws IllegalArgumentException if a feature name matches more than one feature regex.
     */
    public TransformerMap createTransformers(TransformationMap transformationMap) {
        return createTransformers(transformationMap, false);
    }

    /**
     * Fits the transformations in the supplied map to this dataset.
     *
     * <p>Feature-specific transformations are keyed by regex and applied to every feature name
     * that fully matches. Global transformations are appended after them for every feature. Each
     * feature's transformations are fitted in sequence: pass {@code k} over the data observes
     * values already transformed by the transformers fitted in passes {@code 0..k-1}. The data is
     * therefore scanned once per transformation in the longest per-feature chain.
     *
     * @param transformationMap the transformations to fit.
     * @param includeImplicitZeroFeatures if {@code true}, each statistic is also told how many
     *     examples lack the feature (via {@code observeSparse}).
     * @return the fitted transformers, keyed by feature name.
     * @throws IllegalArgumentException if a feature name matches more than one feature regex.
     */
    public TransformerMap createTransformers(TransformationMap transformationMap, boolean includeImplicitZeroFeatures) {
        List<String> featureNames = new ArrayList<>(getFeatureMap().keySet());
        Map<String, List<Transformation>> matched = matchFeatureTransformations(transformationMap, featureNames);

        // Per-feature queue of statistics still to be fitted, in application order.
        Map<String, Queue<TransformStatistics>> featureStats = new HashMap<>();
        // Starts at data.size() and is decremented once per observation in the first pass,
        // leaving the number of examples in which the feature does not appear.
        Map<String, MutableLong> sparseCounts = new HashMap<>();
        for (Map.Entry<String, List<Transformation>> entry : matched.entrySet()) {
            Queue<TransformStatistics> queue = new ArrayDeque<>();
            for (Transformation transformation : entry.getValue()) {
                queue.add(transformation.createStats());
            }
            featureStats.put(entry.getKey(), queue);
            sparseCounts.put(entry.getKey(), new MutableLong(this.data.size()));
        }
        if (!transformationMap.getGlobalTransformations().isEmpty()) {
            appendGlobalStatistics(transformationMap, featureNames, featureStats, sparseCounts);
        }
        // A regex mapped to an empty list (with no global transformations) leaves an empty queue,
        // which previously caused an NPE on peek()/poll() below. Such features have nothing to fit.
        featureStats.values().removeIf(Queue::isEmpty);

        Map<String, List<Transformer>> output = new LinkedHashMap<>();
        boolean sparseCountsInitialised = false;
        while (!featureStats.isEmpty()) {
            for (Example<T> example : this.data) {
                for (Feature feature : example) {
                    String name = feature.getName();
                    Queue<TransformStatistics> queue = featureStats.get(name);
                    if (queue == null) {
                        continue;
                    }
                    if (!sparseCountsInitialised) {
                        sparseCounts.get(name).decrement();
                    }
                    // Observe the value after the transformers already fitted for this feature.
                    double value = TransformerMap.applyTransformerList(feature.getValue(), output.get(name));
                    queue.peek().observeValue(value);
                }
            }
            sparseCountsInitialised = true;

            // Emit one transformer per remaining feature; drop features whose queue is now empty.
            Iterator<Map.Entry<String, Queue<TransformStatistics>>> entries = featureStats.entrySet().iterator();
            while (entries.hasNext()) {
                Map.Entry<String, Queue<TransformStatistics>> entry = entries.next();
                TransformStatistics current = entry.getValue().poll();
                if (includeImplicitZeroFeatures) {
                    current.observeSparse(sparseCounts.get(entry.getKey()).intValue());
                }
                output.computeIfAbsent(entry.getKey(), k -> new ArrayList<>()).add(current.generateTransformer());
                if (entry.getValue().isEmpty()) {
                    entries.remove();
                }
            }
        }
        return new TransformerMap(output, getProvenance(), transformationMap.m59getProvenance());
    }

    /**
     * Resolves the regex-keyed feature-specific transformations against concrete feature names.
     *
     * @param transformationMap the transformations.
     * @param featureNames the feature names in this dataset.
     * @return map from feature name to the transformation list of the single regex it matched.
     * @throws IllegalArgumentException if a feature name matches more than one regex.
     */
    private static Map<String, List<Transformation>> matchFeatureTransformations(TransformationMap transformationMap, List<String> featureNames) {
        Map<String, List<Transformation>> byRegex = transformationMap.getFeatureTransformations();
        logger.fine(String.format("Processing %d feature specific transforms", byRegex.size()));
        Map<String, List<Transformation>> matched = new HashMap<>();
        for (Map.Entry<String, List<Transformation>> entry : byRegex.entrySet()) {
            Pattern pattern = Pattern.compile(entry.getKey());
            for (String name : featureNames) {
                if (pattern.matcher(name).matches() && matched.put(name, entry.getValue()) != null) {
                    throw new IllegalArgumentException("Feature name '" + name + "' matches multiple regexes, at least one of which was '" + entry.getKey() + "'.");
                }
            }
        }
        return matched;
    }

    /**
     * Appends fresh statistics for every global transformation to every feature's queue, after any
     * feature-specific statistics already queued, and ensures each feature has a sparse counter.
     */
    private void appendGlobalStatistics(TransformationMap transformationMap, List<String> featureNames,
                                        Map<String, Queue<TransformStatistics>> featureStats,
                                        Map<String, MutableLong> sparseCounts) {
        int total = featureNames.size();
        logger.fine(String.format("Starting %,d global transformations", total));
        int done = 0;
        for (String name : featureNames) {
            Queue<TransformStatistics> queue = featureStats.computeIfAbsent(name, k -> new ArrayDeque<>());
            for (Transformation transformation : transformationMap.getGlobalTransformations()) {
                queue.add(transformation.createStats());
            }
            sparseCounts.putIfAbsent(name, new MutableLong(this.data.size()));
            done++;
            if (logger.isLoggable(Level.FINE) && done % 10000 == 0) {
                logger.fine(String.format("Completed %,d of %,d global transformations", done, total));
            }
        }
    }

    /**
     * Checks whether every output in this dataset's output domain is an instance of the given class.
     *
     * @param cls the class to test against.
     * @return {@code true} if all outputs are instances of {@code cls} (vacuously true for an empty domain).
     */
    public boolean validate(Class<? extends Output<?>> cls) {
        for (T output : getOutputInfo().getDomain()) {
            if (!cls.isInstance(output)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Casts a wildcard dataset to a specific output type after checking its output domain.
     *
     * @param dataset the dataset to cast.
     * @param cls the target output class.
     * @param <T> the target output type.
     * @return the same dataset instance, typed as {@code Dataset<T>}.
     * @throws ClassCastException if {@link #validate(Class)} fails for {@code cls}.
     */
    @SuppressWarnings("unchecked") // Safe: validate(cls) checked every output in the domain is a T.
    public static <T extends Output<T>> Dataset<T> castDataset(Dataset<?> dataset, Class<T> cls) {
        if (dataset.validate(cls)) {
            return (Dataset<T>) dataset;
        }
        throw new ClassCastException("Attempted to cast dataset to " + cls.getName() + " which is not valid for dataset " + dataset.toString());
    }
}
