package org.tribuo.dataset;

import com.oracle.labs.mlrg.olcut.provenance.ListProvenance;
import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import com.oracle.labs.mlrg.olcut.provenance.ProvenanceException;
import com.oracle.labs.mlrg.olcut.provenance.ProvenanceUtil;
import com.oracle.labs.mlrg.olcut.provenance.primitives.BooleanProvenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.IntProvenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.LongProvenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.Set;
import java.util.SplittableRandom;
import java.util.function.Predicate;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.ImmutableDataset;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Output;
import org.tribuo.provenance.DatasetProvenance;
import org.tribuo.util.Util;

/* loaded from: input_file:org/tribuo/dataset/DatasetView.class */
public final class DatasetView<T extends Output<T>> extends ImmutableDataset<T> {
    /** Serialization version identifier. */
    private static final long serialVersionUID = 1;
    /** The dataset this view reads its examples from. */
    private final Dataset<T> innerDataset;
    /** Number of examples in the view; both constructors set it to the length of {@code exampleIndices}. */
    private final int size;
    /** Position {@code i} holds the index into {@code innerDataset} of the view's i-th example. */
    private final int[] exampleIndices;
    /** Seed used to sample the indices, or {@code -1} for a view built from explicit indices. */
    private final long seed;
    /** Caller-supplied tag for views built from explicit indices; the empty string for sampled views. */
    private final String tag;
    /** {@code true} when the view was built by the sampling constructor (bootstrap factories). */
    private final boolean sampled;
    /** {@code true} when the view was built by the weighted bootstrap factory. */
    private final boolean weighted;
    /** Whether provenance objects created for this view embed the example indices. */
    private boolean storeIndices;

    /**
     * Provenance for a {@link DatasetView}.
     * <p>
     * On top of the fields recorded by {@link DatasetProvenance} it stores the view's size, seed,
     * tag, whether the view was sampled and/or weighted, and (optionally) the example indices.
     */
    public static final class DatasetViewProvenance extends DatasetProvenance {
        private static final long serialVersionUID = 1;

        private static final String SIZE = "size";
        private static final String SEED = "seed";
        private static final String TAG = "tag";
        private static final String SAMPLED = "sampled";
        private static final String WEIGHTED = "weighted";
        private static final String INDICES = "indices";

        private final IntProvenance size;
        private final LongProvenance seed;
        private final StringProvenance tag;
        private final BooleanProvenance weighted;
        private final BooleanProvenance sampled;
        /** The view's example indices, or an empty array when they were not stored. */
        private final int[] indices;

        /**
         * Captures the provenance of the supplied view.
         *
         * @param datasetView the view to describe.
         * @param z if {@code true} a copy of the view's example indices is recorded,
         *          otherwise an empty index array is recorded.
         * @param <T> the output type of the view.
         */
        <T extends Output<T>> DatasetViewProvenance(DatasetView<T> datasetView, boolean z) {
            super(datasetView.sourceProvenance, new ListProvenance<ObjectProvenance>(), datasetView);
            this.size = new IntProvenance(SIZE, datasetView.size);
            this.seed = new LongProvenance(SEED, datasetView.seed);
            this.weighted = new BooleanProvenance(WEIGHTED, datasetView.weighted);
            this.sampled = new BooleanProvenance(SAMPLED, datasetView.sampled);
            this.tag = new StringProvenance(TAG, datasetView.tag);
            // Record the view's own index array (exampleIndices); getExampleIndices() hands back a
            // private copy so this provenance does not alias the view's state.
            this.indices = z ? datasetView.getExampleIndices() : new int[0];
        }

        /**
         * Rebuilds the provenance from its marshalled form.
         *
         * @param map the provenance fields keyed by name.
         * @throws ProvenanceException if a field is missing or has the wrong type, including when
         *                             the indices list contains anything but {@link IntProvenance}.
         */
        public DatasetViewProvenance(Map<String, Provenance> map) {
            super(map);
            String hostName = DatasetViewProvenance.class.getSimpleName();
            this.size = ObjectProvenance.checkAndExtractProvenance(map, SIZE, IntProvenance.class, hostName);
            this.seed = ObjectProvenance.checkAndExtractProvenance(map, SEED, LongProvenance.class, hostName);
            this.tag = ObjectProvenance.checkAndExtractProvenance(map, TAG, StringProvenance.class, hostName);
            this.weighted = ObjectProvenance.checkAndExtractProvenance(map, WEIGHTED, BooleanProvenance.class, hostName);
            this.sampled = ObjectProvenance.checkAndExtractProvenance(map, SAMPLED, BooleanProvenance.class, hostName);
            ListProvenance<?> storedIndices = ObjectProvenance.checkAndExtractProvenance(map, INDICES, ListProvenance.class, hostName);
            // The element type is erased, so check every element before treating the list as
            // ListProvenance<IntProvenance>; otherwise a bad list surfaces as a bare ClassCastException.
            for (Object element : storedIndices.getList()) {
                if (!(element instanceof IntProvenance)) {
                    throw new ProvenanceException("Loaded another class when expecting an ListProvenance<IntProvenance>");
                }
            }
            @SuppressWarnings("unchecked") // every element was verified to be an IntProvenance above
            ListProvenance<IntProvenance> typedIndices = (ListProvenance<IntProvenance>) storedIndices;
            this.indices = Util.toPrimitiveInt(ProvenanceUtil.unwrap(typedIndices));
        }

        /**
         * Generates bootstrap indices from the recorded size and seed.
         *
         * @return the indices produced by {@link Util#generateBootstrapIndices} for this size and seed.
         */
        public int[] generateBootstrap() {
            return Util.generateBootstrapIndices(size.getValue(), new SplittableRandom(seed.getValue()));
        }

        /**
         * @return the recorded "sampled" flag.
         */
        public boolean isSampled() {
            return sampled.getValue();
        }

        /**
         * @return the recorded "weighted" flag.
         */
        public boolean isWeighted() {
            return weighted.getValue();
        }

        /**
         * Two view provenances are equal when the superclass fields and every field written by
         * {@link #allProvenances()} (size, seed, tag, weighted, sampled, indices) are equal.
         */
        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (!(obj instanceof DatasetViewProvenance) || !super.equals(obj)) {
                return false;
            }
            DatasetViewProvenance other = (DatasetViewProvenance) obj;
            return size.equals(other.size)
                    && seed.equals(other.seed)
                    && tag.equals(other.tag)
                    && weighted.equals(other.weighted)
                    && sampled.equals(other.sampled)
                    && Arrays.equals(indices, other.indices);
        }

        @Override
        public int hashCode() {
            int result = Objects.hash(super.hashCode(), size, seed, tag, weighted, sampled);
            return 31 * result + Arrays.hashCode(indices);
        }

        /**
         * Lists every provenance field, including the boxed example indices.
         *
         * @return the superclass fields followed by size, seed, tag, weighted, sampled and indices.
         */
        @Override
        public List<Pair<String, Provenance>> allProvenances() {
            List<Pair<String, Provenance>> provenances = scalarProvenances();
            provenances.add(new Pair<>(INDICES, boxArray()));
            return provenances;
        }

        /** Returns the superclass provenances followed by this class's non-index fields. */
        private List<Pair<String, Provenance>> scalarProvenances() {
            List<Pair<String, Provenance>> provenances = super.allProvenances();
            provenances.add(new Pair<>(SIZE, size));
            provenances.add(new Pair<>(SEED, seed));
            provenances.add(new Pair<>(TAG, tag));
            provenances.add(new Pair<>(WEIGHTED, weighted));
            provenances.add(new Pair<>(SAMPLED, sampled));
            return provenances;
        }

        /** Wraps each stored index in an {@link IntProvenance}. */
        private ListProvenance<IntProvenance> boxArray() {
            List<IntProvenance> boxed = new ArrayList<>(indices.length);
            for (int index : indices) {
                boxed.add(new IntProvenance(INDICES, index));
            }
            return new ListProvenance<>(boxed);
        }

        /**
         * Renders the provenance as {@code DatasetView(key=value,...)}.
         * <p>
         * The indices entry is always printed as an empty list, whatever this provenance stores.
         */
        @Override
        public String toString() {
            List<Pair<String, Provenance>> provenances = scalarProvenances();
            provenances.add(new Pair<>(INDICES, new ListProvenance<IntProvenance>()));
            StringBuilder sb = new StringBuilder("DatasetView(");
            String separator = "";
            for (Pair<String, Provenance> entry : provenances) {
                sb.append(separator).append(entry.getA()).append('=').append(entry.getB().toString());
                separator = ",";
            }
            return sb.append(')').toString();
        }
    }

    /**
     * Iterates over the examples of a {@link DatasetView} in view order, from position 0 up to
     * {@code size() - 1}.
     */
    private static class ViewIterator<T extends Output<T>> implements Iterator<Example<T>> {
        private final DatasetView<T> dataset;
        /** Position in the view of the example returned by the next call to {@link #next()}. */
        private int counter = 0;

        /**
         * @param datasetView the view to iterate over.
         */
        ViewIterator(DatasetView<T> datasetView) {
            this.dataset = datasetView;
        }

        @Override
        public boolean hasNext() {
            return counter < dataset.size();
        }

        /**
         * Returns the example at the current position and advances.
         *
         * @throws NoSuchElementException if the view has no more examples.
         */
        @Override
        public Example<T> next() {
            // Honour the Iterator contract instead of leaking getExample's IllegalArgumentException.
            if (!hasNext()) {
                throw new NoSuchElementException("No more examples in the view, size = " + dataset.size());
            }
            Example<T> example = dataset.getExample(counter);
            counter++;
            return example;
        }
    }

    /**
     * Creates a view over the examples of {@code dataset} selected by {@code iArr}, using the
     * dataset's own feature id map and output id info.
     * <p>
     * Delegates to the five-argument constructor, which validates the indices.
     *
     * @param dataset the dataset to wrap.
     * @param iArr indices into {@code dataset} of the examples exposed by the view.
     * @param str the tag recorded for this view.
     * @throws IllegalArgumentException if an index is negative or not less than {@code dataset.size()}.
     */
    public DatasetView(Dataset<T> dataset, int[] iArr, String str) {
        this(dataset, iArr, dataset.getFeatureIDMap(), dataset.getOutputIDInfo(), str);
    }

    public DatasetView(Dataset<T> dataset, int[] iArr, ImmutableFeatureMap immutableFeatureMap, ImmutableOutputInfo<T> immutableOutputInfo, String str) {
        super(dataset.getProvenance(), dataset.getOutputFactory(), immutableFeatureMap, immutableOutputInfo);
        this.storeIndices = false;
        if (!validateIndices(dataset.size(), iArr)) {
            throw new IllegalArgumentException("Invalid indices supplied, dataset.size() = " + dataset.size() + ", but found a negative index or a value greater than or equal to size.");
        }
        this.innerDataset = dataset;
        this.size = iArr.length;
        this.exampleIndices = iArr;
        this.seed = -1L;
        this.tag = str;
        this.storeIndices = true;
        this.sampled = false;
        this.weighted = false;
    }

    /**
     * Creates a sampled view; used by the bootstrap factory methods.
     * <p>
     * The tag is the empty string and the view is marked as sampled. Unlike the public
     * constructors this one does not validate {@code iArr}. Indices are stored in the provenance
     * by default only when {@code z} is {@code true}.
     *
     * @param dataset the dataset to wrap.
     * @param iArr the sampled indices into {@code dataset}; stored without copying.
     * @param j the seed recorded for the sample.
     * @param immutableFeatureMap the feature id map for this view.
     * @param immutableOutputInfo the output id info for this view.
     * @param z whether the sample was weighted.
     */
    private DatasetView(Dataset<T> dataset, int[] iArr, long j, ImmutableFeatureMap immutableFeatureMap, ImmutableOutputInfo<T> immutableOutputInfo, boolean z) {
        super(dataset.getProvenance(), dataset.getOutputFactory(), immutableFeatureMap, immutableOutputInfo);
        // What the view wraps.
        this.innerDataset = dataset;
        this.exampleIndices = iArr;
        this.size = iArr.length;
        // How the indices were produced.
        this.seed = j;
        this.sampled = true;
        this.weighted = z;
        this.tag = "";
        // Weighted samples keep their indices in the provenance by default.
        this.storeIndices = z;
    }

    /**
     * Creates a view containing every example of {@code dataset} accepted by {@code predicate},
     * in iteration order.
     *
     * @param dataset the dataset to filter.
     * @param predicate the test applied to each example; accepted examples are kept.
     * @param str the tag recorded for the view.
     * @param <T> the output type.
     * @return a view over the accepted examples.
     */
    public static <T extends Output<T>> DatasetView<T> createView(Dataset<T> dataset, Predicate<Example<T>> predicate, String str) {
        List<Integer> accepted = new ArrayList<>();
        int position = 0;
        for (Example<T> example : dataset) {
            if (predicate.test(example)) {
                accepted.add(position);
            }
            position++;
        }
        int[] acceptedIndices = Util.toPrimitiveInt(accepted);
        return new DatasetView<>(dataset, acceptedIndices, str);
    }

    /**
     * Creates a bootstrap view using the dataset's own feature id map and output id info.
     *
     * @param dataset the dataset to sample from.
     * @param i the size passed to the bootstrap index generator.
     * @param j the seed for the random number generator.
     * @param <T> the output type.
     * @return a sampled view of {@code dataset}.
     */
    public static <T extends Output<T>> DatasetView<T> createBootstrapView(Dataset<T> dataset, int i, long j) {
        ImmutableFeatureMap featureIDs = dataset.getFeatureIDMap();
        ImmutableOutputInfo<T> outputIDs = dataset.getOutputIDInfo();
        return createBootstrapView(dataset, i, j, featureIDs, outputIDs);
    }

    /**
     * Creates a bootstrap view using the supplied feature and output domains.
     * <p>
     * The indices come from {@link Util#generateBootstrapIndices} driven by a
     * {@link SplittableRandom} seeded with {@code j}; the resulting view is marked as sampled and
     * not weighted.
     * NOTE(review): the generated indices are not validated against {@code dataset.size()} here;
     * confirm the generator only yields valid positions when {@code i} differs from the dataset size.
     *
     * @param dataset the dataset to sample from.
     * @param i the size passed to the bootstrap index generator.
     * @param j the seed for the random number generator.
     * @param immutableFeatureMap the feature id map for the view.
     * @param immutableOutputInfo the output id info for the view.
     * @param <T> the output type.
     * @return a sampled view of {@code dataset}.
     */
    public static <T extends Output<T>> DatasetView<T> createBootstrapView(Dataset<T> dataset, int i, long j, ImmutableFeatureMap immutableFeatureMap, ImmutableOutputInfo<T> immutableOutputInfo) {
        SplittableRandom rng = new SplittableRandom(j);
        int[] sampledIndices = Util.generateBootstrapIndices(i, rng);
        return new DatasetView<>(dataset, sampledIndices, j, immutableFeatureMap, immutableOutputInfo, false);
    }

    /**
     * Creates a weighted bootstrap view using the dataset's own feature id map and output id info.
     *
     * @param dataset the dataset to sample from.
     * @param i the size passed to the weighted index sampler.
     * @param j the seed for the random number generator.
     * @param fArr one weight per example of {@code dataset}.
     * @param <T> the output type.
     * @return a sampled, weighted view of {@code dataset}.
     * @throws IllegalArgumentException if {@code fArr.length} differs from {@code dataset.size()}.
     */
    public static <T extends Output<T>> DatasetView<T> createWeightedBootstrapView(Dataset<T> dataset, int i, long j, float[] fArr) {
        ImmutableFeatureMap featureIDs = dataset.getFeatureIDMap();
        ImmutableOutputInfo<T> outputIDs = dataset.getOutputIDInfo();
        return createWeightedBootstrapView(dataset, i, j, fArr, featureIDs, outputIDs);
    }

    /**
     * Creates a weighted bootstrap view using the supplied feature and output domains.
     * <p>
     * The indices come from {@link Util#generateWeightedIndicesSample} driven by a
     * {@link SplittableRandom} seeded with {@code j}; the resulting view is marked as sampled and
     * weighted.
     *
     * @param dataset the dataset to sample from.
     * @param i the size passed to the weighted index sampler.
     * @param j the seed for the random number generator.
     * @param fArr one weight per example of {@code dataset}.
     * @param immutableFeatureMap the feature id map for the view.
     * @param immutableOutputInfo the output id info for the view.
     * @param <T> the output type.
     * @return a sampled, weighted view of {@code dataset}.
     * @throws IllegalArgumentException if {@code fArr.length} differs from {@code dataset.size()}.
     */
    public static <T extends Output<T>> DatasetView<T> createWeightedBootstrapView(Dataset<T> dataset, int i, long j, float[] fArr, ImmutableFeatureMap immutableFeatureMap, ImmutableOutputInfo<T> immutableOutputInfo) {
        // Every example needs exactly one weight.
        if (dataset.size() != fArr.length) {
            throw new IllegalArgumentException("There must be a weight for each example, dataset.size()=" + dataset.size() + ", exampleWeights.length=" + fArr.length);
        }
        SplittableRandom rng = new SplittableRandom(j);
        int[] sampledIndices = Util.generateWeightedIndicesSample(i, fArr, rng);
        return new DatasetView<>(dataset, sampledIndices, j, immutableFeatureMap, immutableOutputInfo, true);
    }

    /**
     * Reports whether provenance objects created for this view include the example indices.
     *
     * @return the current value of the store-indices flag.
     */
    public boolean storeIndicesInProvenance() {
        return storeIndices;
    }

    /**
     * Sets whether provenance objects created for this view include the example indices.
     *
     * @param z the new value of the store-indices flag.
     */
    public void setStoreIndices(boolean z) {
        storeIndices = z;
    }

    /**
     * Describes the view by its inner dataset's source description, size, seed and tag.
     *
     * @return a string of the form {@code DatasetView(innerDataset=...,size=...,seed=...,tag=...)}.
     */
    @Override
    public String toString() {
        StringBuilder description = new StringBuilder("DatasetView(innerDataset=");
        description.append(innerDataset.getSourceDescription());
        description.append(",size=").append(size);
        description.append(",seed=").append(seed);
        description.append(",tag=").append(tag);
        return description.append(")").toString();
    }

    /**
     * Returns the outputs reported by the inner dataset (not restricted to the view's examples).
     *
     * @return the inner dataset's outputs.
     */
    @Override
    public Set<T> getOutputs() {
        Set<T> innerOutputs = innerDataset.getOutputs();
        return innerOutputs;
    }

    /**
     * Returns the number of examples in this view.
     *
     * @return the length of the view's index array.
     */
    @Override
    public int size() {
        return size;
    }

    /**
     * Returns the feature id map held by this view.
     *
     * @return the feature id map.
     */
    @Override
    public ImmutableFeatureMap getFeatureMap() {
        return featureIDMap;
    }

    /**
     * Returns the output id info held by this view.
     *
     * @return the output id info.
     */
    @Override
    public ImmutableOutputInfo<T> getOutputInfo() {
        return outputIDInfo;
    }

    /**
     * Returns a fresh iterator over the view's examples, starting at position 0.
     *
     * @return an iterator over this view.
     */
    @Override
    public Iterator<Example<T>> iterator() {
        return new ViewIterator<>(this);
    }

    /**
     * Materialises the view as a list by looking up each stored index in the inner dataset.
     *
     * @return an unmodifiable list of the view's examples, in view order.
     */
    @Override
    public List<Example<T>> getData() {
        List<Example<T>> examples = new ArrayList<>(exampleIndices.length);
        for (int position = 0; position < exampleIndices.length; position++) {
            examples.add(innerDataset.getExample(exampleIndices[position]));
        }
        return Collections.unmodifiableList(examples);
    }

    /**
     * Returns the example at the given position of the view.
     *
     * @param i the position within the view, from 0 to {@code size() - 1}.
     * @return the inner dataset's example at the index stored for that position.
     * @throws IllegalArgumentException if {@code i} is negative or not less than {@code size()}.
     */
    @Override
    public Example<T> getExample(int i) {
        boolean withinView = 0 <= i && i < size();
        if (!withinView) {
            throw new IllegalArgumentException("Example index " + i + " is out of bounds.");
        }
        int innerIndex = exampleIndices[i];
        return innerDataset.getExample(innerIndex);
    }

    /**
     * Builds a new provenance object describing this view; the example indices are included
     * only when the store-indices flag is currently set.
     * <p>
     * NOTE(review): this method was renamed from {@code getProvenance} (the {@code mo5} prefix
     * looks like a decompiler artefact) — confirm which name callers and the superclass expect.
     *
     * @return a freshly constructed provenance for this view.
     */
    @Override
    /* renamed from: getProvenance */
    public DatasetViewProvenance mo5getProvenance() {
        boolean includeIndices = storeIndices;
        return new DatasetViewProvenance(this, includeIndices);
    }

    /**
     * Returns a copy of the indices into the inner dataset that make up this view.
     *
     * @return a new array; changes to it do not affect the view.
     */
    public int[] getExampleIndices() {
        return exampleIndices.clone();
    }

    /**
     * Checks that every index lies in the range {@code [0, i)}.
     *
     * @param i the exclusive upper bound for a valid index.
     * @param iArr the indices to check.
     * @return {@code true} if all indices are valid (trivially so for an empty array),
     *         {@code false} as soon as one is negative or not less than {@code i}.
     */
    private static boolean validateIndices(int i, int[] iArr) {
        for (int candidate : iArr) {
            if (candidate < 0 || candidate >= i) {
                return false;
            }
        }
        return true;
    }
}
