package org.apache.beam.sdk.extensions.smb;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import java.io.ByteArrayInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.lang.invoke.SerializedLambda;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.coders.ListCoder;
import org.apache.beam.sdk.coders.SerializableCoder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.extensions.smb.BucketMetadataUtil;
import org.apache.beam.sdk.io.BoundedSource;
import org.apache.beam.sdk.io.FileSystems;
import org.apache.beam.sdk.io.fs.MatchResult;
import org.apache.beam.sdk.io.fs.ResolveOptions;
import org.apache.beam.sdk.io.fs.ResourceId;
import org.apache.beam.sdk.io.fs.ResourceIdCoder;
import org.apache.beam.sdk.metrics.Distribution;
import org.apache.beam.sdk.metrics.Metrics;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.transforms.display.DisplayData;
import org.apache.beam.sdk.transforms.join.CoGbkResult;
import org.apache.beam.sdk.transforms.join.CoGbkResultSchema;
import org.apache.beam.sdk.transforms.join.UnionCoder;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.sdk.values.TupleTagList;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.primitives.UnsignedBytes;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A {@link BoundedSource} that performs a sort-merge co-group over one or more sorted-bucket inputs,
 * emitting one {@link CoGbkResult} per distinct key.
 *
 * <p>Each instance reads the subset of buckets selected by its bucket offset and effective
 * parallelism; {@link #split} fans a source out into several instances with a larger effective
 * parallelism.
 *
 * @param <FinalKeyT> the type the encoded key bytes are decoded into via the source spec's key coder
 */
public class SortedBucketSource<FinalKeyT> extends BoundedSource<KV<FinalKeyT, CoGbkResult>> {
    static final Double DESIRED_SIZE_BYTES_ADJUSTMENT_FACTOR = 0.5d;
    private static final AtomicInteger metricsId = new AtomicInteger(1);
    private static final Comparator<byte[]> bytesComparator = UnsignedBytes.lexicographicalComparator();
    private static final Logger LOG = LoggerFactory.getLogger(SortedBucketSource.class);
    private final Class<FinalKeyT> finalKeyClass;
    private final List<BucketedInput<?, ?>> sources;
    private final TargetParallelism targetParallelism;
    private final int effectiveParallelism;
    private final int bucketOffsetId;
    private SourceSpec<FinalKeyT> sourceSpec;
    private final Distribution keyGroupSize;
    private Long estimatedSizeBytes;
    private final String metricsKey;

    /**
     * One sorted-bucket input: a tuple tag, the directories holding its bucket files, the file
     * suffix, the {@link FileOperations} used to read them, and a per-value {@link Predicate}.
     *
     * <p>Bucket metadata is resolved lazily and cached in a transient field, so it is recomputed
     * after deserialization. Serialization is customized via {@code writeObject}/{@code readObject}.
     */
    public static class BucketedInput<K, V> implements Serializable {
        private static final Pattern BUCKET_PATTERN = Pattern.compile("(\\d+)-of-(\\d+)");
        private TupleTag<V> tupleTag;
        private String filenameSuffix;
        private FileOperations<V> fileOperations;
        private List<ResourceId> inputDirectories;
        private Predicate<V> predicate;
        private transient BucketMetadataUtil.SourceMetadata<K, V> sourceMetadata;

        /** Creates an input backed by a single directory and an accept-all predicate. */
        public BucketedInput(TupleTag<V> tupleTag, ResourceId inputDirectory, String filenameSuffix, FileOperations<V> fileOperations) {
            this(tupleTag, Collections.singletonList(inputDirectory), filenameSuffix, fileOperations);
        }

        /** Creates an input backed by several directories and an accept-all predicate. */
        public BucketedInput(TupleTag<V> tupleTag, List<ResourceId> inputDirectories, String filenameSuffix, FileOperations<V> fileOperations) {
            this(tupleTag, inputDirectories, filenameSuffix, fileOperations, null);
        }

        /**
         * Creates an input.
         *
         * @param predicate per-value filter; {@code null} means every value is accepted
         * @throws IllegalArgumentException if any of {@code inputDirectories} is not a directory
         */
        public BucketedInput(TupleTag<V> tupleTag, List<ResourceId> inputDirectories, String filenameSuffix, FileOperations<V> fileOperations, Predicate<V> predicate) {
            for (ResourceId directory : inputDirectories) {
                // Template form: the message is only built when the check fails.
                Preconditions.checkArgument(directory.isDirectory(), "Cannot construct SMB source from non-directory input %s", directory);
            }
            this.tupleTag = tupleTag;
            this.filenameSuffix = filenameSuffix;
            this.fileOperations = fileOperations;
            this.inputDirectories = inputDirectories;
            this.predicate = predicate != null ? predicate : (values, value) -> true;
        }

        public TupleTag<V> getTupleTag() {
            return this.tupleTag;
        }

        /** Returns the value coder reported by this input's {@link FileOperations}. */
        public Coder<V> getCoder() {
            return this.fileOperations.getCoder();
        }

        /** Builds a co-group schema whose tag order matches the order of {@code inputs}. */
        static CoGbkResultSchema schemaOf(List<BucketedInput<?, ?>> inputs) {
            final List<TupleTag<?>> tags = new ArrayList<>(inputs.size());
            for (BucketedInput<?, ?> input : inputs) {
                tags.add(input.getTupleTag());
            }
            return CoGbkResultSchema.of(tags);
        }

        /** Returns the canonical bucket metadata for this input, computing it on first use. */
        public BucketMetadata<K, V> getMetadata() {
            return getOrComputeMetadata().getCanonicalMetadata();
        }

        /** Returns per-directory partition metadata, computing it on first use. */
        public Map<ResourceId, BucketMetadataUtil.PartitionMetadata> getPartitionMetadata() {
            return getOrComputeMetadata().getPartitionMetadata();
        }

        /** Lazily resolves and caches (in a transient field) the source metadata of all directories. */
        public BucketMetadataUtil.SourceMetadata<K, V> getOrComputeMetadata() {
            if (this.sourceMetadata == null) {
                this.sourceMetadata = BucketMetadataUtil.get().getSourceMetadata(this.inputDirectories, this.filenameSuffix);
            }
            return this.sourceMetadata;
        }

        /** Matches {@code glob} inside {@code directory}; a missing directory yields an empty list. */
        private static List<MatchResult.Metadata> sampleDirectory(ResourceId directory, String glob) {
            try {
                return FileSystems.match(directory.resolve(glob, ResolveOptions.StandardResolveOptions.RESOLVE_FILE).toString()).metadata();
            } catch (FileNotFoundException e) {
                return Collections.emptyList();
            } catch (IOException e) {
                throw new RuntimeException("Exception fetching metadata for " + directory, e);
            }
        }

        /**
         * Estimates the total byte size of this input by sampling the first buckets of each
         * directory and extrapolating to the full bucket count.
         *
         * @throws IllegalArgumentException if a directory contains no matching bucket files
         */
        public long getOrSampleByteSize() {
            return this.inputDirectories.parallelStream().mapToLong(this::estimateDirectorySize).sum();
        }

        /** Samples buckets 0-9 of one directory (unsharded, then sharded naming) and extrapolates. */
        private long estimateDirectorySize(ResourceId directory) {
            List<MatchResult.Metadata> sample = sampleDirectory(directory, "*-0000?-of-?????" + this.filenameSuffix);
            if (sample.isEmpty()) {
                sample = sampleDirectory(directory, "*-0000?-of-*-shard-00000-of-?????" + this.filenameSuffix);
            }
            int numBuckets = 0;
            long sampledBytes = 0;
            final HashSet<String> sampledBuckets = new HashSet<>();
            for (MatchResult.Metadata file : sample) {
                final String filename = file.resourceId().getFilename();
                final Matcher matcher = BUCKET_PATTERN.matcher(filename);
                if (!matcher.find()) {
                    throw new RuntimeException("Couldn't match bucket information from filename: " + filename);
                }
                sampledBuckets.add(matcher.group(1));
                if (numBuckets == 0) {
                    numBuckets = Integer.parseInt(matcher.group(2));
                }
                sampledBytes += file.sizeBytes();
            }
            if (numBuckets == 0) {
                throw new IllegalArgumentException("Directory " + directory + " has no bucket files");
            }
            // Scale the sampled bytes up when only some of the buckets were matched.
            return sampledBuckets.size() < numBuckets
                ? (long) (sampledBytes * (numBuckets / (sampledBuckets.size() * 1.0d)))
                : sampledBytes;
        }

        /**
         * Opens every file of the buckets selected by {@code bucketId}/{@code parallelism} and merges
         * them into one key-ordered iterator.
         */
        public KeyGroupIterator<byte[], V> createIterator(int bucketId, int parallelism) {
            final List<Iterator<V>> fileIterators = mapBucketFiles(bucketId, parallelism, file -> {
                try {
                    return this.fileOperations.iterator(file);
                } catch (Exception e) {
                    throw new RuntimeException("Failed to open " + file, e);
                }
            });
            // Go through the lazy accessor rather than the transient field, which may be unset.
            final BucketMetadata<K, V> canonicalMetadata = getOrComputeMetadata().getCanonicalMetadata();
            return new KeyGroupIterator<>(fileIterators, canonicalMetadata::getKeyBytes, SortedBucketSource.bytesComparator);
        }

        /**
         * Applies {@code mapper} to every shard file of the buckets {@code bucketId % numBuckets},
         * {@code + parallelism}, ... of every partition, in that order.
         */
        private <T> List<T> mapBucketFiles(int bucketId, int parallelism, Function<ResourceId, T> mapper) {
            final List<T> mapped = new ArrayList<>();
            getPartitionMetadata().forEach((directory, partition) -> {
                final int numBuckets = partition.getNumBuckets();
                final int numShards = partition.getNumShards();
                for (int bucket = bucketId % numBuckets; bucket < numBuckets; bucket += parallelism) {
                    for (int shard = 0; shard < numShards; shard++) {
                        mapped.add(mapper.apply(partition.getFileAssignment().forBucket(BucketShardId.of(bucket, shard), numBuckets, numShards)));
                    }
                }
            });
            return mapped;
        }

        /** Shows the tag id and at most five directories (first four, an ellipsis, then the last). */
        @Override
        public String toString() {
            final int count = this.inputDirectories.size();
            final List<ResourceId> shown = count > 5 ? this.inputDirectories.subList(0, 4) : this.inputDirectories;
            String directories = shown.stream().map(Object::toString).collect(Collectors.joining(", "));
            if (count > 5) {
                directories += ", ..., " + this.inputDirectories.get(count - 1);
            }
            return String.format("BucketedInput[tupleTag=%s, inputDirectories=[%s]]", this.tupleTag.getId(), directories);
        }

        private void writeObject(ObjectOutputStream out) throws IOException {
            SerializableCoder.of(TupleTag.class).encode(this.tupleTag, out);
            StringUtf8Coder.of().encode(this.filenameSuffix, out);
            SerializableCoder.of(FileOperations.class).encode(this.fileOperations, out);
            ListCoder.of(ResourceIdCoder.of()).encode(this.inputDirectories, out);
            SerializableCoder.of(Predicate.class).encode(this.predicate, out);
        }

        @SuppressWarnings("unchecked") // the coders only know the raw classes written by writeObject
        private void readObject(ObjectInputStream in) throws ClassNotFoundException, IOException {
            this.tupleTag = (TupleTag<V>) SerializableCoder.of(TupleTag.class).decode(in);
            this.filenameSuffix = StringUtf8Coder.of().decode(in);
            this.fileOperations = (FileOperations<V>) SerializableCoder.of(FileOperations.class).decode(in);
            this.inputDirectories = ListCoder.of(ResourceIdCoder.of()).decode(in);
            this.predicate = (Predicate<V>) SerializableCoder.of(Predicate.class).decode(in);
        }
    }

    /**
     * Reader that merges the key-ordered iterators of all inputs and emits one co-grouped result
     * per key that belongs to this reader's bucket.
     */
    public static class MergeBucketsReader<FinalKeyT> extends BoundedSource.BoundedReader<KV<FinalKeyT, CoGbkResult>> {
        /** Outcome for the key group currently being assembled in {@link #advance()}. */
        private enum KeyGroupDecision { UNDECIDED, ACCEPT, REJECT }

        private static final Comparator<Map.Entry<TupleTag<?>, KV<byte[], Iterator<?>>>> keyComparator =
            (a, b) -> SortedBucketSource.bytesComparator.compare(a.getValue().getKey(), b.getValue().getKey());
        private final Coder<FinalKeyT> keyCoder;
        private final SortedBucketSource<FinalKeyT> currentSource;
        private final Distribution keyGroupSize;
        private final int numSources;
        private final int parallelism;
        private final List<KeyGroupIterator<byte[], ?>> iterators;
        private final Function<byte[], Boolean> keyGroupFilter;
        private final List<Predicate<?>> predicates;
        private final CoGbkResultSchema resultSchema;
        private final TupleTagList tupleTags;
        private final Map<TupleTag<?>, Integer> bucketsPerSource;
        private KV<byte[], CoGbkResult> next;
        private Map<TupleTag<?>, KV<byte[], Iterator<?>>> nextKeyGroups;

        /**
         * Opens the iterators of every source for {@code bucketId} at the given parallelism.
         *
         * @param keyGroupSize distribution updated with the number of values in each emitted group
         */
        public MergeBucketsReader(List<BucketedInput<?, ?>> sources, Integer bucketId, int parallelism, SourceSpec<FinalKeyT> sourceSpec, SortedBucketSource<FinalKeyT> currentSource, Distribution keyGroupSize) {
            this.keyCoder = sourceSpec.keyCoder;
            this.numSources = sources.size();
            this.currentSource = currentSource;
            this.keyGroupSize = keyGroupSize;
            this.parallelism = parallelism;
            final int bucket = bucketId;
            // True when the key re-hashes (per the first source's metadata) into this reader's bucket.
            this.keyGroupFilter = keyBytes -> sources.get(0).getMetadata().rehashBucket(keyBytes, parallelism) == bucket;
            this.predicates = sources.stream().<Predicate<?>>map(input -> input.predicate).collect(Collectors.toList());
            this.iterators = sources.stream()
                .<KeyGroupIterator<byte[], ?>>map(input -> input.createIterator(bucket, parallelism))
                .collect(Collectors.toList());
            this.resultSchema = BucketedInput.schemaOf(sources);
            this.tupleTags = this.resultSchema.getTupleTagList();
            this.bucketsPerSource = sources.stream().collect(Collectors.<BucketedInput<?, ?>, TupleTag<?>, Integer>toMap(
                input -> input.getTupleTag(),
                input -> input.getOrComputeMetadata().getCanonicalMetadata().getNumBuckets()));
        }

        /** Views a key group with its value iterator widened to {@code Iterator<?>}. */
        @SuppressWarnings("unchecked")
        private static KV<byte[], Iterator<?>> nextKeyGroup(KeyGroupIterator<byte[], ?> iterator) {
            return (KV<byte[], Iterator<?>>) (KV<byte[], ?>) iterator.next();
        }

        @Override
        public boolean start() throws IOException {
            this.nextKeyGroups = new HashMap<>();
            return advance();
        }

        /**
         * Returns the current key (decoded with the key coder) and its co-grouped values.
         *
         * @throws NoSuchElementException if {@link #start()} / {@link #advance()} has not returned true
         */
        @Override
        public KV<FinalKeyT, CoGbkResult> getCurrent() throws NoSuchElementException {
            if (this.next == null) {
                throw new NoSuchElementException();
            }
            try {
                return KV.of(this.keyCoder.decode(new ByteArrayInputStream(this.next.getKey())), this.next.getValue());
            } catch (IOException e) {
                throw new RuntimeException("Failed to decode key group", e);
            }
        }

        /**
         * Moves to the next key group that belongs to this reader's bucket.
         *
         * <p>Each pass buffers the head key group of every source, takes the smallest key, and
         * drains that key's group from every source holding it. The first such source decides
         * whether the key is kept: it is rejected only if that source has fewer buckets than the
         * parallelism and the key re-hashes to another bucket. Rejected groups are drained and the
         * loop continues.
         */
        @Override
        public boolean advance() throws IOException {
            // Clear the previous element so getCurrent() fails once the reader is exhausted.
            this.next = null;
            while (true) {
                for (int i = 0; i < this.numSources; i++) {
                    final TupleTag<?> tag = this.tupleTags.get(i);
                    final KeyGroupIterator<byte[], ?> iterator = this.iterators.get(i);
                    if (!this.nextKeyGroups.containsKey(tag) && iterator.hasNext()) {
                        this.nextKeyGroups.put(tag, nextKeyGroup(iterator));
                    }
                }
                // Nothing buffered means every source is exhausted.
                if (this.nextKeyGroups.isEmpty()) {
                    return false;
                }

                // Capture the key bytes now: the entry holding them is removed below.
                final byte[] minKey = Collections.min(this.nextKeyGroups.entrySet(), keyComparator).getValue().getKey();
                final List<Iterable<?>> valueMap = new ArrayList<>(this.resultSchema.size());
                for (int i = 0; i < this.resultSchema.size(); i++) {
                    valueMap.add(new ArrayList<>());
                }

                KeyGroupDecision decision = KeyGroupDecision.UNDECIDED;
                int groupSize = 0;
                final Iterator<Map.Entry<TupleTag<?>, KV<byte[], Iterator<?>>>> entries = this.nextKeyGroups.entrySet().iterator();
                while (entries.hasNext()) {
                    final Map.Entry<TupleTag<?>, KV<byte[], Iterator<?>>> entry = entries.next();
                    if (SortedBucketSource.bytesComparator.compare(entry.getValue().getKey(), minKey) != 0) {
                        continue;
                    }
                    final TupleTag<?> tag = entry.getKey();
                    final int index = this.resultSchema.getIndex(tag);
                    @SuppressWarnings("unchecked")
                    final List<Object> values = (List<Object>) valueMap.get(index);
                    @SuppressWarnings("unchecked")
                    final Predicate<Object> predicate = (Predicate<Object>) this.predicates.get(index);
                    if (decision == KeyGroupDecision.UNDECIDED) {
                        final boolean otherBucket = this.bucketsPerSource.get(tag) < this.parallelism
                            && !this.keyGroupFilter.apply(minKey);
                        decision = otherBucket ? KeyGroupDecision.REJECT : KeyGroupDecision.ACCEPT;
                    }
                    final Iterator<?> groupValues = entry.getValue().getValue();
                    if (decision == KeyGroupDecision.ACCEPT) {
                        groupValues.forEachRemaining(value -> {
                            if (predicate.apply(values, value)) {
                                values.add(value);
                            }
                        });
                    } else {
                        // Exhaust the group so the underlying iterator can move past this key.
                        groupValues.forEachRemaining(value -> { });
                    }
                    groupSize += values.size();
                    entries.remove();
                }

                if (decision == KeyGroupDecision.ACCEPT) {
                    this.keyGroupSize.update(groupSize);
                    this.next = KV.of(minKey, CoGbkResultUtil.newCoGbkResult(this.resultSchema, valueMap));
                    return true;
                }
            }
        }

        /** No-op: this reader holds no resources of its own to release. */
        @Override
        public void close() throws IOException {
        }

        @Override
        public BoundedSource<KV<FinalKeyT, CoGbkResult>> getCurrentSource() {
            return this.currentSource;
        }
    }

    /**
     * Serializable per-value filter applied while a key group is read. It receives the values
     * already accepted for the current key group of that input and the candidate value; the
     * candidate is kept when it returns {@code true}.
     */
    @FunctionalInterface
    public interface Predicate<T> extends BiFunction<List<T>, T, Boolean>, Serializable {
    }

    public SortedBucketSource(Class<FinalKeyT> finalKeyClass, List<BucketedInput<?, ?>> sources) {
        this(finalKeyClass, sources, TargetParallelism.auto());
    }

    public SortedBucketSource(Class<FinalKeyT> finalKeyClass, List<BucketedInput<?, ?>> sources, TargetParallelism targetParallelism) {
        this(finalKeyClass, sources, targetParallelism, 0, 1, getDefaultMetricsKey());
    }

    public SortedBucketSource(Class<FinalKeyT> finalKeyClass, List<BucketedInput<?, ?>> sources, TargetParallelism targetParallelism, String metricsKey) {
        this(finalKeyClass, sources, targetParallelism, 0, 1, metricsKey);
    }

    private SortedBucketSource(Class<FinalKeyT> finalKeyClass, List<BucketedInput<?, ?>> sources, TargetParallelism targetParallelism, int bucketOffsetId, int effectiveParallelism, String metricsKey) {
        this(finalKeyClass, sources, targetParallelism, bucketOffsetId, effectiveParallelism, metricsKey, null);
    }

    private SortedBucketSource(Class<FinalKeyT> finalKeyClass, List<BucketedInput<?, ?>> sources, TargetParallelism targetParallelism, int bucketOffsetId, int effectiveParallelism, String metricsKey, Long estimatedSizeBytes) {
        this.finalKeyClass = finalKeyClass;
        this.sources = sources;
        this.targetParallelism = targetParallelism;
        this.bucketOffsetId = bucketOffsetId;
        this.effectiveParallelism = effectiveParallelism;
        this.metricsKey = metricsKey;
        this.keyGroupSize = Metrics.distribution(SortedBucketSource.class, metricsKey + "-KeyGroupSize");
        this.estimatedSizeBytes = estimatedSizeBytes;
    }

    /** Returns "SortedBucketSource" for the first default key and "SortedBucketSource{n}" afterwards. */
    private static String getDefaultMetricsKey() {
        final int id = metricsId.getAndIncrement();
        return id == 1 ? "SortedBucketSource" : "SortedBucketSource{" + id + "}";
    }

    @VisibleForTesting
    int getBucketOffset() {
        return this.bucketOffsetId;
    }

    @VisibleForTesting
    int getEffectiveParallelism() {
        return this.effectiveParallelism;
    }

    private SourceSpec<FinalKeyT> getOrComputeSourceSpec() {
        if (this.sourceSpec == null) {
            this.sourceSpec = SourceSpec.from(this.finalKeyClass, this.sources);
        }
        return this.sourceSpec;
    }

    @Override
    public Coder<KV<FinalKeyT, CoGbkResult>> getOutputCoder() {
        final Coder<FinalKeyT> keyCoder = getOrComputeSourceSpec().keyCoder;
        final CoGbkResultSchema schema = BucketedInput.schemaOf(this.sources);
        final List<Coder<?>> valueCoders = this.sources.stream().<Coder<?>>map(input -> input.getCoder()).collect(Collectors.toList());
        return KvCoder.of(keyCoder, CoGbkResult.CoGbkResultCoder.of(schema, UnionCoder.of(valueCoders)));
    }

    @Override
    public void populateDisplayData(DisplayData.Builder builder) {
        super.populateDisplayData(builder);
        builder.add(DisplayData.item("targetParallelism", this.targetParallelism.toString()));
        builder.add(DisplayData.item("keyClass", this.finalKeyClass.toString()));
        builder.add(DisplayData.item("metricsKey", this.metricsKey));
    }

    /**
     * Splits this source into {@link #getNumSplits} sources. Split {@code k} takes bucket offset
     * {@code bucketOffsetId + k * effectiveParallelism}, and all splits share the parallelism
     * {@code numSplits * effectiveParallelism}.
     */
    @Override
    public List<? extends BoundedSource<KV<FinalKeyT, CoGbkResult>>> split(long desiredBundleSizeBytes, PipelineOptions options) throws Exception {
        final SourceSpec<FinalKeyT> spec = getOrComputeSourceSpec();
        final long totalSizeBytes = getEstimatedSizeBytes(options);
        final int numSplits = getNumSplits(spec, this.effectiveParallelism, this.targetParallelism, totalSizeBytes, desiredBundleSizeBytes, DESIRED_SIZE_BYTES_ADJUSTMENT_FACTOR);
        final long splitSizeBytes = totalSizeBytes / numSplits;
        final DecimalFormat megabytes = new DecimalFormat("0.00");
        LOG.info("Parallelism was adjusted by {}splitting source of size {} MB into {} source(s) of size {} MB",
            this.effectiveParallelism > 1 ? "further " : "",
            megabytes.format(totalSizeBytes / 1000000.0d),
            numSplits,
            megabytes.format(splitSizeBytes / 1000000.0d));
        final int newParallelism = numSplits * this.effectiveParallelism;
        final List<SortedBucketSource<FinalKeyT>> splits = new ArrayList<>(numSplits);
        for (int split = 0; split < numSplits; split++) {
            splits.add(new SortedBucketSource<>(this.finalKeyClass, this.sources, this.targetParallelism,
                this.bucketOffsetId + split * this.effectiveParallelism, newParallelism, this.metricsKey, splitSizeBytes));
        }
        return splits;
    }

    /**
     * Computes how many sources to split into.
     *
     * <p>Returns 1 when parallelism already equals the greatest bucket count. For a non-auto
     * target it returns the spec's parallelism. Otherwise it takes
     * {@code estimatedSizeBytes / (desiredSizeBytes * adjustmentFactor)}, rounds it up to a power
     * of two, and caps the result at the greatest bucket count.
     */
    public static int getNumSplits(SourceSpec<?> sourceSpec, int effectiveParallelism, TargetParallelism targetParallelism, long estimatedSizeBytes, long desiredSizeBytes, double adjustmentFactor) {
        final long adjustedDesiredSizeBytes = (long) (desiredSizeBytes * adjustmentFactor);
        final int greatestNumBuckets = sourceSpec.greatestNumBuckets;
        if (effectiveParallelism == greatestNumBuckets) {
            LOG.info("Parallelism is already maxed, can't split further.");
            return 1;
        }
        if (!targetParallelism.isAuto()) {
            return sourceSpec.getParallelism(targetParallelism);
        }
        // Keep the long from Math.round: it saturates at Long.MAX_VALUE (e.g. a zero adjusted size),
        // and narrowing that to int used to yield -1 and thus a single split.
        final long desiredSplits = Math.round(estimatedSizeBytes / (adjustedDesiredSizeBytes * 1.0d));
        if (desiredSplits > 1) {
            if (desiredSplits >= greatestNumBuckets) {
                return greatestNumBuckets;
            }
            // Next power of two >= desiredSplits, still capped by the greatest bucket count.
            return Math.min(Integer.highestOneBit((int) desiredSplits - 1) * 2, greatestNumBuckets);
        }
        LOG.info("Total input size is <= desired byte size, can't split further.");
        return 1;
    }

    /** Returns the cached size estimate, sampling every input on first call. */
    @Override
    public long getEstimatedSizeBytes(PipelineOptions options) throws Exception {
        if (this.estimatedSizeBytes == null) {
            this.estimatedSizeBytes = this.sources.parallelStream().mapToLong(input -> input.getOrSampleByteSize()).sum();
            LOG.info("Estimated byte size is {}", this.estimatedSizeBytes);
        }
        return this.estimatedSizeBytes;
    }

    @Override
    public BoundedSource.BoundedReader<KV<FinalKeyT, CoGbkResult>> createReader(PipelineOptions options) throws IOException {
        return new MergeBucketsReader<>(this.sources, this.bucketOffsetId, this.effectiveParallelism, getOrComputeSourceSpec(), this, this.keyGroupSize);
    }
}
