package org.apache.beam.sdk.extensions.smb;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.coders.ListCoder;
import org.apache.beam.sdk.coders.SerializableCoder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.extensions.smb.BucketMetadataUtil;
import org.apache.beam.sdk.extensions.smb.SortedBucketIO;
import org.apache.beam.sdk.io.BoundedSource;
import org.apache.beam.sdk.io.FileSystems;
import org.apache.beam.sdk.io.fs.MatchResult;
import org.apache.beam.sdk.io.fs.ResolveOptions;
import org.apache.beam.sdk.metrics.Distribution;
import org.apache.beam.sdk.metrics.Metrics;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.transforms.display.DisplayData;
import org.apache.beam.sdk.transforms.join.CoGbkResult;
import org.apache.beam.sdk.transforms.join.CoGbkResultSchema;
import org.apache.beam.sdk.transforms.join.UnionCoder;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterators;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/apache/beam/sdk/extensions/smb/SortedBucketSource.class */
/**
 * A {@link BoundedSource} that reads co-grouped {@code KV<KeyType, CoGbkResult>} records from one
 * or more {@link BucketedInput}s.
 *
 * <p>A source may be split (see {@link #split}) into sub-sources identified by a bucket offset
 * ({@link #bucketOffsetId}) and an effective parallelism ({@link #effectiveParallelism}); both are
 * handed to the {@link MultiSourceKeyGroupReader} created by {@link #createReader}.
 *
 * @param <KeyType> the key type of the emitted records
 */
public abstract class SortedBucketSource<KeyType> extends BoundedSource<KV<KeyType, CoGbkResult>> {
    /** Multiplier applied to the runner's desired bundle size before computing the split count. */
    static final Double DESIRED_SIZE_BYTES_ADJUSTMENT_FACTOR = Double.valueOf(0.5d);
    private static final Logger LOG = LoggerFactory.getLogger(SortedBucketSource.class);
    /** Counter used to generate unique default metrics keys. */
    private static final AtomicInteger metricsId = new AtomicInteger(1);
    protected final List<BucketedInput<?>> sources;
    protected final TargetParallelism targetParallelism;
    protected final int effectiveParallelism;
    protected final int bucketOffsetId;
    /** Lazily computed by {@link #getOrComputeSourceSpec()}. */
    protected SourceSpec sourceSpec;
    protected final Distribution keyGroupSize;
    /** Lazily computed by {@link #getEstimatedSizeBytes}, or supplied by the constructor. */
    protected Long estimatedSizeBytes;
    protected final String metricsKey;

    /**
     * One input of the co-group: a tuple tag, the directories holding its bucket files, the filename
     * suffix used to locate them, the {@link FileOperations} used to read them, and a
     * {@link Predicate}.
     *
     * <p>Serialized via custom {@code writeObject}/{@code readObject}; the resolved
     * {@link #sourceMetadata} is transient and recomputed lazily after deserialization.
     */
    public static abstract class BucketedInput<V> implements Serializable {
        /** Extracts {@code <bucketId>-of-<numBuckets>} from a bucket filename. */
        private static final Pattern BUCKET_PATTERN = Pattern.compile("(\\d+)-of-(\\d+)");
        protected TupleTag<V> tupleTag;
        protected String filenameSuffix;
        protected FileOperations<V> fileOperations;
        protected List<String> inputDirectories;
        protected Predicate<V> predicate;
        protected Keying keying;
        protected transient BucketMetadataUtil.SourceMetadata<V> sourceMetadata = null;

        /** Creates the {@link BucketedInput} subclass matching {@code keying}. */
        public static <V> BucketedInput<V> of(Keying keying, TupleTag<V> tupleTag, List<String> inputDirectories, String filenameSuffix, FileOperations<V> fileOperations, Predicate<V> predicate) {
            if (keying == Keying.PRIMARY) {
                return new PrimaryKeyedBucketedInput<>(tupleTag, inputDirectories, filenameSuffix, fileOperations, predicate);
            }
            return new PrimaryAndSecondaryKeyedBucktedInput<>(tupleTag, inputDirectories, filenameSuffix, fileOperations, predicate);
        }

        public BucketedInput(Keying keying, TupleTag<V> tupleTag, List<String> inputDirectories, String filenameSuffix, FileOperations<V> fileOperations, Predicate<V> predicate) {
            this.keying = keying;
            this.tupleTag = tupleTag;
            this.filenameSuffix = filenameSuffix;
            this.fileOperations = fileOperations;
            this.inputDirectories = inputDirectories;
            this.predicate = predicate;
        }

        /** Returns the bucket metadata for this input's directories, resolving it on first use. */
        public abstract BucketMetadataUtil.SourceMetadata<V> getSourceMetadata();

        public TupleTag<V> getTupleTag() {
            return this.tupleTag;
        }

        public Predicate<V> getPredicate() {
            return this.predicate;
        }

        public Coder<V> getCoder() {
            return this.fileOperations.getCoder();
        }

        /** Builds a {@link CoGbkResultSchema} from the tuple tags of {@code inputs}, in order. */
        static CoGbkResultSchema schemaOf(List<BucketedInput<?>> inputs) {
            final List<TupleTag<?>> tags = inputs.stream()
                .<TupleTag<?>>map(BucketedInput::getTupleTag)
                .collect(Collectors.toList());
            return CoGbkResultSchema.of(tags);
        }

        /**
         * Lists files in {@code directory} matching the glob {@code filePattern}.
         *
         * @return matched file metadata, or an empty list if the pattern matched nothing
         * @throws RuntimeException wrapping any other {@link IOException}
         */
        private static List<MatchResult.Metadata> sampleDirectory(String directory, String filePattern) {
            try {
                final String spec = FileSystems.matchNewResource(directory, true)
                    .resolve(filePattern, ResolveOptions.StandardResolveOptions.RESOLVE_FILE)
                    .toString();
                return FileSystems.match(spec).metadata();
            } catch (FileNotFoundException e) {
                return Collections.emptyList();
            } catch (IOException e) {
                throw new RuntimeException("Exception fetching metadata for " + directory, e);
            }
        }

        /**
         * Estimates the total byte size of this input by sampling each input directory in parallel.
         *
         * @throws IllegalArgumentException if a directory contains no recognizable bucket files
         */
        public long getOrSampleByteSize() {
            return this.inputDirectories.parallelStream().mapToLong(this::estimateDirectoryByteSize).sum();
        }

        /**
         * Estimates the byte size of one directory by summing the sizes of bucket files with ids
         * {@code 00000}-{@code 00009}, then extrapolating to the total bucket count parsed from the
         * first matched filename.
         */
        private long estimateDirectoryByteSize(String directory) {
            List<MatchResult.Metadata> sampled = sampleDirectory(directory, "*-0000?-of-?????" + this.filenameSuffix);
            if (sampled.isEmpty()) {
                // Sharded layout fallback.
                // NOTE(review): only shard 00000 of each bucket is matched here, so inputs with more
                // than one shard per bucket are presumably under-estimated.
                sampled = sampleDirectory(directory, "*-0000?-of-*-shard-00000-of-?????" + this.filenameSuffix);
            }
            int numBuckets = 0;
            long sampledBytes = 0;
            final HashSet<String> sampledBucketIds = new HashSet<>();
            for (MatchResult.Metadata file : sampled) {
                final String filename = file.resourceId().getFilename();
                final Matcher matcher = BUCKET_PATTERN.matcher(filename);
                if (!matcher.find()) {
                    throw new RuntimeException("Couldn't match bucket information from filename: " + filename);
                }
                sampledBucketIds.add(matcher.group(1));
                if (numBuckets == 0) {
                    numBuckets = Integer.parseInt(matcher.group(2));
                }
                sampledBytes += file.sizeBytes();
            }
            if (numBuckets == 0) {
                throw new IllegalArgumentException("Directory " + directory + " has no bucket files");
            }
            if (sampledBucketIds.size() < numBuckets) {
                // Scale the sampled bytes up to all numBuckets buckets.
                return (long) (sampledBytes * (numBuckets / (sampledBucketIds.size() * 1.0d)));
            }
            return sampledBytes;
        }

        /**
         * Creates a {@link KeyGroupIterator} over every bucket file whose bucket id is reached by
         * starting at {@code bucketOffset % numBuckets} and stepping by {@code parallelism}, across
         * all shards of each such bucket.
         *
         * @param bucketOffset first bucket id to read (taken modulo each source's bucket count)
         * @param parallelism stride between bucket ids; assumed positive
         * @param pipelineOptions options providing the read buffer sizes
         */
        public KeyGroupIterator<V> createIterator(int bucketOffset, int parallelism, PipelineOptions pipelineOptions) {
            final BucketMetadataUtil.SourceMetadata<V> metadata = getSourceMetadata();
            final Comparator<SortedBucketIO.ComparableKeyBytes> keyComparator;
            if (this.keying == Keying.PRIMARY) {
                keyComparator = new SortedBucketIO.PrimaryKeyComparator();
            } else {
                keyComparator = new SortedBucketIO.PrimaryAndSecondaryKeyComparator();
            }
            final SortedBucketOptions options = pipelineOptions.as(SortedBucketOptions.class);
            final int bufferSize = options.getSortedBucketReadBufferSize();
            FileOperations.setDiskBufferMb(options.getSortedBucketReadDiskBufferMb());

            final List<Iterator<KV<SortedBucketIO.ComparableKeyBytes, V>>> iterators = new ArrayList<>();
            metadata.mapping.forEach((resourceId, value) -> {
                final int numBuckets = value.metadata.getNumBuckets();
                final int numShards = value.metadata.getNumShards();
                final Function<V, SortedBucketIO.ComparableKeyBytes> keyFn;
                if (this.keying == Keying.PRIMARY) {
                    keyFn = value.metadata::primaryComparableKeyBytes;
                } else {
                    keyFn = value.metadata::primaryAndSecondaryComparableKeyBytes;
                }
                for (int bucketId = bucketOffset % numBuckets; bucketId < numBuckets; bucketId += parallelism) {
                    for (int shardId = 0; shardId < numShards; shardId++) {
                        try {
                            final Iterator<KV<SortedBucketIO.ComparableKeyBytes, V>> keyed = Iterators.transform(
                                this.fileOperations.iterator(
                                    value.fileAssignment.forBucket(BucketShardId.of(bucketId, shardId), numBuckets, numShards)),
                                element -> KV.of(keyFn.apply(element), element));
                            if (bufferSize > 0) {
                                iterators.add(new BufferedIterator<>(keyed, bufferSize));
                            } else {
                                iterators.add(keyed);
                            }
                        } catch (Exception e) {
                            throw new RuntimeException(e);
                        }
                    }
                }
            });
            return new KeyGroupIterator<>(iterators, keyComparator);
        }

        @Override
        public String toString() {
            final Object directories = this.inputDirectories.size() > 5
                ? this.inputDirectories.subList(0, 4) + "..." + this.inputDirectories.get(this.inputDirectories.size() - 1)
                : this.inputDirectories;
            return String.format("BucketedInput[tupleTag=%s, inputDirectories=[%s]]", this.tupleTag.getId(), directories);
        }

        private void writeObject(ObjectOutputStream out) throws IOException {
            SerializableCoder.of(TupleTag.class).encode(this.tupleTag, out);
            ListCoder.of(StringUtf8Coder.of()).encode(this.inputDirectories, out);
            out.writeUTF(this.filenameSuffix);
            out.writeObject(this.fileOperations);
            out.writeObject(this.predicate);
            out.writeObject(this.keying);
            out.flush();
        }

        /** Mirrors {@link #writeObject}; field order must match. */
        @SuppressWarnings("unchecked")
        private void readObject(ObjectInputStream in) throws ClassNotFoundException, IOException {
            this.tupleTag = SerializableCoder.of(TupleTag.class).decode(in);
            this.inputDirectories = ListCoder.of(StringUtf8Coder.of()).decode(in);
            this.filenameSuffix = in.readUTF();
            this.fileOperations = (FileOperations<V>) in.readObject();
            this.predicate = (Predicate<V>) in.readObject();
            this.keying = (Keying) in.readObject();
        }
    }

    /** Which key bytes are used to order and group records. */
    public enum Keying {
        PRIMARY,
        PRIMARY_AND_SECONDARY
    }

    /** A {@link BoundedSource.BoundedReader} that pulls key groups from a {@link MultiSourceKeyGroupReader}. */
    static class MergeBucketsReader<KeyType> extends BoundedSource.BoundedReader<KV<KeyType, CoGbkResult>> {
        private final SortedBucketSource<KeyType> currentSource;
        private final MultiSourceKeyGroupReader<KeyType> iter;
        /** Current element; {@code null} before {@link #start()} or after exhaustion. */
        private KV<KeyType, CoGbkResult> next = null;

        MergeBucketsReader(MultiSourceKeyGroupReader<KeyType> reader, SortedBucketSource<KeyType> source) {
            this.currentSource = source;
            this.iter = reader;
        }

        @Override
        public boolean start() throws IOException {
            return advance();
        }

        @Override
        public KV<KeyType, CoGbkResult> getCurrent() throws NoSuchElementException {
            if (this.next == null) {
                throw new NoSuchElementException();
            }
            return this.next;
        }

        @Override
        public boolean advance() throws IOException {
            this.next = this.iter.readNext();
            return this.next != null;
        }

        @Override
        public void close() throws IOException {
        }

        @Override
        public BoundedSource<KV<KeyType, CoGbkResult>> getCurrentSource() {
            return this.currentSource;
        }
    }

    /**
     * Serializable {@code (List<T>, T) -> Boolean} predicate attached to a {@link BucketedInput}.
     * NOTE(review): how the list and element arguments are supplied is decided by the caller
     * (not visible here).
     */
    public interface Predicate<T> extends BiFunction<List<T>, T, Boolean>, Serializable {
    }

    /** {@link BucketedInput} keyed on primary and secondary keys. */
    public static class PrimaryAndSecondaryKeyedBucktedInput<V> extends BucketedInput<V> {
        public PrimaryAndSecondaryKeyedBucktedInput(TupleTag<V> tupleTag, List<String> inputDirectories, String filenameSuffix, FileOperations<V> fileOperations, Predicate<V> predicate) {
            super(Keying.PRIMARY_AND_SECONDARY, tupleTag, inputDirectories, filenameSuffix, fileOperations, predicate);
        }

        @Override
        public BucketMetadataUtil.SourceMetadata<V> getSourceMetadata() {
            if (this.sourceMetadata == null) {
                this.sourceMetadata = BucketMetadataUtil.get().getPrimaryAndSecondaryKeyedSourceMetadata(this.inputDirectories, this.filenameSuffix);
            }
            return this.sourceMetadata;
        }
    }

    /** {@link BucketedInput} keyed on the primary key only. */
    public static class PrimaryKeyedBucketedInput<V> extends BucketedInput<V> {
        public PrimaryKeyedBucketedInput(TupleTag<V> tupleTag, List<String> inputDirectories, String filenameSuffix, FileOperations<V> fileOperations, Predicate<V> predicate) {
            super(Keying.PRIMARY, tupleTag, inputDirectories, filenameSuffix, fileOperations, predicate);
        }

        @Override
        public BucketMetadataUtil.SourceMetadata<V> getSourceMetadata() {
            if (this.sourceMetadata == null) {
                this.sourceMetadata = BucketMetadataUtil.get().getPrimaryKeyedSourceMetadata(this.inputDirectories, this.filenameSuffix);
            }
            return this.sourceMetadata;
        }
    }

    /** An {@link Iterable} whose {@link #iterator()} may be called at most once. */
    public static class TraversableOnceIterable<V> implements Iterable<V> {
        private final Iterator<V> underlying;
        private boolean called = false;

        public TraversableOnceIterable(Iterator<V> underlying) {
            this.underlying = underlying;
        }

        /**
         * Returns the wrapped iterator.
         *
         * @throws IllegalArgumentException if called more than once
         */
        @Override
        public Iterator<V> iterator() {
            Preconditions.checkArgument(!this.called, "CoGbkResult .iterator() can only be called once. To be re-iterable, it must be materialized as a List.");
            this.called = true;
            return this.underlying;
        }

        /** Drains any elements remaining in the wrapped iterator. */
        public void ensureExhausted() {
            this.underlying.forEachRemaining(ignored -> {
            });
        }
    }

    public SortedBucketSource(List<BucketedInput<?>> sources) {
        this(sources, TargetParallelism.auto());
    }

    public SortedBucketSource(List<BucketedInput<?>> sources, TargetParallelism targetParallelism) {
        this(sources, targetParallelism, 0, 1, null);
    }

    public SortedBucketSource(List<BucketedInput<?>> sources, TargetParallelism targetParallelism, String metricsKey) {
        this(sources, targetParallelism, 0, 1, metricsKey);
    }

    private SortedBucketSource(List<BucketedInput<?>> sources, TargetParallelism targetParallelism, int bucketOffsetId, int effectiveParallelism, String metricsKey) {
        this(sources, targetParallelism, bucketOffsetId, effectiveParallelism, metricsKey, null);
    }

    /**
     * @param targetParallelism defaults to {@link TargetParallelism#auto()} when {@code null}
     * @param metricsKey defaults to a generated unique key when {@code null}
     * @param estimatedSizeBytes pre-computed size, or {@code null} to sample lazily
     */
    public SortedBucketSource(List<BucketedInput<?>> sources, TargetParallelism targetParallelism, int bucketOffsetId, int effectiveParallelism, String metricsKey, Long estimatedSizeBytes) {
        this.sources = sources;
        this.targetParallelism = targetParallelism == null ? TargetParallelism.auto() : targetParallelism;
        this.bucketOffsetId = bucketOffsetId;
        this.effectiveParallelism = effectiveParallelism;
        this.metricsKey = metricsKey == null ? getDefaultMetricsKey() : metricsKey;
        this.keyGroupSize = Metrics.distribution(SortedBucketSource.class, this.metricsKey + "-KeyGroupSize");
        this.estimatedSizeBytes = estimatedSizeBytes;
    }

    protected abstract Coder<KeyType> keyTypeCoder();

    protected abstract Function<SortedBucketIO.ComparableKeyBytes, KeyType> toKeyFn();

    /**
     * Creates one sub-source produced by {@link #split}.
     *
     * @param splitId index of this split in {@code [0, numSplits)}
     * @param totalParallelism {@code numSplits * effectiveParallelism}
     * @param splitSizeBytes estimated size of the sub-source
     */
    protected abstract SortedBucketSource<KeyType> createSplitSource(int splitId, int totalParallelism, long splitSizeBytes);

    protected abstract Comparator<SortedBucketIO.ComparableKeyBytes> comparator();

    private static String getDefaultMetricsKey() {
        return "SortedBucketSource{" + metricsId.getAndIncrement() + "}";
    }

    @VisibleForTesting
    int getBucketOffset() {
        return this.bucketOffsetId;
    }

    @VisibleForTesting
    int getEffectiveParallelism() {
        return this.effectiveParallelism;
    }

    /** Returns the cached {@link SourceSpec}, computing it from {@link #sources} on first call. */
    protected SourceSpec getOrComputeSourceSpec() {
        if (this.sourceSpec == null) {
            this.sourceSpec = SourceSpec.from(this.sources);
        }
        return this.sourceSpec;
    }

    protected CoGbkResultSchema coGbkResultSchema() {
        return BucketedInput.schemaOf(this.sources);
    }

    @Override
    public Coder<KV<KeyType, CoGbkResult>> getOutputCoder() {
        // Read the coder straight from each input's FileOperations, as before.
        final List<Coder<?>> valueCoders = this.sources.stream()
            .<Coder<?>>map(input -> input.fileOperations.getCoder())
            .collect(Collectors.toList());
        return KvCoder.of(keyTypeCoder(), CoGbkResult.CoGbkResultCoder.of(coGbkResultSchema(), UnionCoder.of(valueCoders)));
    }

    @Override
    public void populateDisplayData(DisplayData.Builder builder) {
        super.populateDisplayData(builder);
        builder.add(DisplayData.item("targetParallelism", this.targetParallelism.toString()));
        builder.add(DisplayData.item("metricsKey", this.metricsKey));
    }

    /**
     * Splits this source into {@link #getNumSplits} sub-sources via {@link #createSplitSource}.
     */
    @Override
    public List<? extends BoundedSource<KV<KeyType, CoGbkResult>>> split(long desiredBundleSizeBytes, PipelineOptions pipelineOptions) throws Exception {
        final SourceSpec spec = getOrComputeSourceSpec();
        final long totalSizeBytes = getEstimatedSizeBytes(pipelineOptions);
        final int numSplits = getNumSplits(spec, this.effectiveParallelism, this.targetParallelism, totalSizeBytes, desiredBundleSizeBytes, DESIRED_SIZE_BYTES_ADJUSTMENT_FACTOR.doubleValue());
        final long splitSizeBytes = totalSizeBytes / numSplits;

        final DecimalFormat megabytes = new DecimalFormat("0.00");
        LOG.info("Parallelism was adjusted by {}splitting source of size {} MB into {} source(s) of size {} MB",
            this.effectiveParallelism > 1 ? "further " : "",
            megabytes.format(totalSizeBytes / 1000000.0d),
            numSplits,
            megabytes.format(splitSizeBytes / 1000000.0d));

        final int totalParallelism = numSplits * this.effectiveParallelism;
        return IntStream.range(0, numSplits)
            .mapToObj(splitId -> createSplitSource(splitId, totalParallelism, splitSizeBytes))
            .collect(Collectors.toList());
    }

    /**
     * Computes how many sub-sources {@link #split} should produce.
     *
     * <p>Returns 1 when {@code effectiveParallelism} already equals the greatest bucket count. Returns
     * the spec's parallelism when {@code targetParallelism} is explicit. Otherwise it divides the
     * estimated size by {@code desiredSizeBytes * adjustmentFactor}. It rounds that ratio up to a power
     * of two and caps the result at the greatest bucket count.
     *
     * @return a split count; in the auto case it is always in {@code [1, greatestNumBuckets]}
     */
    public static int getNumSplits(SourceSpec sourceSpec, int effectiveParallelism, TargetParallelism targetParallelism, long estimatedSizeBytes, long desiredSizeBytes, double adjustmentFactor) {
        final int greatestNumBuckets = sourceSpec.greatestNumBuckets;
        if (effectiveParallelism == greatestNumBuckets) {
            LOG.info("Parallelism is already maxed, can't split further.");
            return 1;
        }
        if (!targetParallelism.isAuto()) {
            final int parallelism = sourceSpec.getParallelism(targetParallelism);
            LOG.info("Splitting using specified parallelism: " + targetParallelism);
            return parallelism;
        }
        final long desiredSplitSizeBytes = (long) (desiredSizeBytes * adjustmentFactor);
        // Keep the ratio as a long: truncating it to int and then doubling a highestOneBit overflowed
        // for large ratios, yielding a zero or negative split count (and hence no sub-sources).
        final long ratio = Math.round(estimatedSizeBytes / (desiredSplitSizeBytes * 1.0d));
        if (ratio <= 1) {
            LOG.info("Desired byte size is >= total input size, can't split further.");
            return 1;
        }
        if (ratio >= greatestNumBuckets) {
            return greatestNumBuckets;
        }
        // ratio is now in [2, greatestNumBuckets), so the int cast is safe; round up to a power of two.
        return Math.min(Integer.highestOneBit((int) ratio - 1) * 2, greatestNumBuckets);
    }

    /** Returns the cached size estimate, sampling all inputs (in parallel) on first call. */
    @Override
    public long getEstimatedSizeBytes(PipelineOptions pipelineOptions) throws Exception {
        if (this.estimatedSizeBytes == null) {
            this.estimatedSizeBytes = this.sources.parallelStream().mapToLong(BucketedInput::getOrSampleByteSize).sum();
            LOG.info("Estimated byte size is " + this.estimatedSizeBytes);
        }
        return this.estimatedSizeBytes.longValue();
    }

    @Override
    public BoundedSource.BoundedReader<KV<KeyType, CoGbkResult>> createReader(PipelineOptions pipelineOptions) throws IOException {
        // Any bucket metadata of the first input serves as the reference metadata for the reader.
        final BucketMetadata<?, ?, ?> someMetadata = this.sources.get(0).getSourceMetadata().mapping.values().stream()
            .findAny()
            .orElseThrow(() -> new IllegalStateException("No bucket metadata found for " + this.sources.get(0)))
            .metadata;
        return new MergeBucketsReader<>(
            new MultiSourceKeyGroupReader<KeyType>(this.sources, toKeyFn(), coGbkResultSchema(), someMetadata, comparator(), this.keyGroupSize, true, this.bucketOffsetId, this.effectiveParallelism, pipelineOptions),
            this);
    }
}
