package org.apache.beam.sdk.extensions.smb;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.beam.sdk.coders.CannotProvideCoderException;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.coders.NullableCoder;
import org.apache.beam.sdk.coders.SerializableCoder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.coders.VarIntCoder;
import org.apache.beam.sdk.extensions.smb.BucketMetadataUtil;
import org.apache.beam.sdk.io.fs.ResourceId;
import org.apache.beam.sdk.metrics.Counter;
import org.apache.beam.sdk.metrics.Distribution;
import org.apache.beam.sdk.metrics.Metrics;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.Reshuffle;
import org.apache.beam.sdk.transforms.display.DisplayData;
import org.apache.beam.sdk.transforms.join.CoGbkResult;
import org.apache.beam.sdk.transforms.join.CoGbkResultSchema;
import org.apache.beam.sdk.transforms.join.UnionCoder;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PBegin;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.sdk.values.TupleTagList;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.primitives.UnsignedBytes;

/* loaded from: input_file:org/apache/beam/sdk/extensions/smb/SortedBucketSource.class */
public class SortedBucketSource<FinalKeyT> extends PTransform<PBegin, PCollection<KV<FinalKeyT, CoGbkResult>>> {
    private static final Comparator<byte[]> bytesComparator = UnsignedBytes.lexicographicalComparator();
    private final Class<FinalKeyT> finalKeyClass;
    private final transient List<BucketedInput<?, ?>> sources;

    /* loaded from: input_file:org/apache/beam/sdk/extensions/smb/SortedBucketSource$BucketedInput.class */
    /**
     * One input to a sorted-bucket read: a tuple tag, the directories to read from, the filename
     * suffix of the bucket files and the {@link FileOperations} used to open them.
     *
     * <p>Source metadata is resolved lazily through {@link BucketMetadataUtil} on first use and
     * cached on the instance.
     *
     * <p>Serialization is hand-written: each field is encoded with a Beam coder rather than by
     * default Java serialization. All fields, including {@code inputDirectories}, are part of the
     * serialized form so that a deserialized instance can still resolve its metadata and be printed.
     *
     * @param <K> key type of the bucket metadata
     * @param <V> value type produced by the file operations
     */
    public static class BucketedInput<K, V> implements Serializable {
        private TupleTag<V> tupleTag;
        private String filenameSuffix;
        private FileOperations<V> fileOperations;
        private List<ResourceId> inputDirectories;
        // Cached result of the metadata lookup; null until computeMetadataIfAbsent() has run.
        private transient BucketMetadataUtil.SourceMetadata<K, V> sourceMetadata;

        /**
         * Creates an input that reads from a single directory.
         *
         * @param tupleTag tag under which this input's values are grouped
         * @param resourceId the directory to read
         * @param str filename suffix of the bucket files
         * @param fileOperations operations used to open bucket files and to obtain the value coder
         */
        public BucketedInput(TupleTag<V> tupleTag, ResourceId resourceId, String str, FileOperations<V> fileOperations) {
            this(tupleTag, Collections.<ResourceId>singletonList(resourceId), str, fileOperations);
        }

        /**
         * Creates an input that reads from several directories.
         *
         * @param tupleTag tag under which this input's values are grouped
         * @param list the directories to read
         * @param str filename suffix of the bucket files
         * @param fileOperations operations used to open bucket files and to obtain the value coder
         */
        public BucketedInput(TupleTag<V> tupleTag, List<ResourceId> list, String str, FileOperations<V> fileOperations) {
            this.tupleTag = tupleTag;
            this.filenameSuffix = str;
            this.fileOperations = fileOperations;
            this.inputDirectories = list;
        }

        /** Returns the tuple tag of this input. */
        public TupleTag<V> getTupleTag() {
            return this.tupleTag;
        }

        /** Returns the value coder reported by this input's file operations. */
        public Coder<V> getCoder() {
            return this.fileOperations.getCoder();
        }

        /**
         * Builds a {@link CoGbkResultSchema} from the tuple tags of the given inputs, in list order.
         *
         * @param list the inputs whose tags make up the schema
         * @return the schema over those tags
         */
        static CoGbkResultSchema schemaOf(List<BucketedInput<?, ?>> list) {
            final List<TupleTag<?>> tags = new ArrayList<>(list.size());
            for (BucketedInput<?, ?> input : list) {
                tags.add(input.getTupleTag());
            }
            return CoGbkResultSchema.of(tags);
        }

        /** Returns the canonical metadata of this input, resolving the source metadata if needed. */
        public BucketMetadata<K, V> getMetadata() {
            computeMetadataIfAbsent();
            return this.sourceMetadata.getCanonicalMetadata();
        }

        /** Resolves and caches the source metadata unless it is already present. */
        private void computeMetadataIfAbsent() {
            if (this.sourceMetadata == null) {
                this.sourceMetadata = BucketMetadataUtil.get().getSourceMetadata(this.inputDirectories, this.filenameSuffix);
            }
        }

        /**
         * Creates a key-group iterator over the files of the selected buckets.
         *
         * <p>For every partition in the source metadata, bucket ids {@code i}, {@code i + i2},
         * {@code i + 2 * i2}, ... below that partition's bucket count are visited, and one file
         * iterator is collected per shard of each visited bucket.
         *
         * @param i first bucket id to read
         * @param i2 stride between visited bucket ids; must be positive
         * @return an iterator grouping values by key bytes, ordered by the unsigned byte comparator
         * @throws IllegalArgumentException if {@code i2} is not positive
         * @throws RuntimeException wrapping any exception raised while obtaining a file iterator
         */
        public KeyGroupIterator<byte[], V> createIterator(int i, int i2) {
            // A zero or negative stride would never advance past the bucket count.
            Preconditions.checkArgument(i2 > 0, "Bucket stride must be positive, got %s", i2);
            // The metadata may not have been resolved yet on this instance.
            computeMetadataIfAbsent();

            final List<Iterator<V>> shardIterators = new ArrayList<>();
            this.sourceMetadata.getPartitionMetadata().forEach((directory, partition) -> {
                final int numBuckets = partition.getNumBuckets();
                final int numShards = partition.getNumShards();
                for (int bucketId = i; bucketId < numBuckets; bucketId += i2) {
                    for (int shardId = 0; shardId < numShards; shardId++) {
                        try {
                            shardIterators.add(this.fileOperations.iterator(partition.getFileAssignment().forBucket(BucketShardId.of(bucketId, shardId), numBuckets, numShards)));
                        } catch (Exception e) {
                            throw new RuntimeException(e);
                        }
                    }
                }
            });

            final BucketMetadata<K, V> canonicalMetadata = Objects.requireNonNull(this.sourceMetadata.getCanonicalMetadata(), "canonical metadata");
            return new KeyGroupIterator<>(shardIterators, canonicalMetadata::getKeyBytes, SortedBucketSource.bytesComparator);
        }

        /**
         * Describes this input without triggering a metadata lookup.
         *
         * <p>Lists of more than five directories are abbreviated to the first four followed by the
         * last one. The metadata is printed as {@code null} when it has not been resolved yet.
         */
        @Override
        public String toString() {
            final List<ResourceId> directories = this.inputDirectories;
            final Object shownDirectories;
            if (directories != null && directories.size() > 5) {
                shownDirectories = directories.subList(0, 4) + "..." + directories.get(directories.size() - 1);
            } else {
                shownDirectories = directories;
            }
            final Object shownMetadata = this.sourceMetadata == null ? null : this.sourceMetadata.getCanonicalMetadata();
            return String.format("BucketedInput[tupleTag=%s, inputDirectories=[%s], metadata=%s]", this.tupleTag.getId(), shownDirectories, shownMetadata);
        }

        /**
         * Writes every field with a Beam coder. The order here must match {@link #readObject}.
         *
         * <p>The directory list is copied into an {@link ArrayList} before encoding because the
         * list handed to the constructor is not guaranteed to be of a serializable type.
         * NOTE(review): this assumes the {@link ResourceId} implementations in use are
         * Java-serializable - confirm for custom filesystems.
         */
        @SuppressWarnings({"unchecked", "rawtypes"})
        private void writeObject(ObjectOutputStream objectOutputStream) throws IOException {
            SerializableCoder.of(TupleTag.class).encode(this.tupleTag, objectOutputStream);
            final ArrayList<ResourceId> directories = this.inputDirectories == null ? null : new ArrayList<>(this.inputDirectories);
            NullableCoder.of(SerializableCoder.of(ArrayList.class)).encode(directories, objectOutputStream);
            StringUtf8Coder.of().encode(this.filenameSuffix, objectOutputStream);
            SerializableCoder.of(FileOperations.class).encode(this.fileOperations, objectOutputStream);
            // The metadata may not have been resolved when the instance is serialized.
            NullableCoder.of(SerializableCoder.of(BucketMetadataUtil.SourceMetadata.class)).encode(this.sourceMetadata, objectOutputStream);
        }

        /** Reads the fields back in the order {@link #writeObject} wrote them. */
        @SuppressWarnings({"unchecked", "rawtypes"})
        private void readObject(ObjectInputStream objectInputStream) throws ClassNotFoundException, IOException {
            this.tupleTag = SerializableCoder.of(TupleTag.class).decode(objectInputStream);
            this.inputDirectories = NullableCoder.of(SerializableCoder.of(ArrayList.class)).decode(objectInputStream);
            this.filenameSuffix = StringUtf8Coder.of().decode(objectInputStream);
            this.fileOperations = (FileOperations) SerializableCoder.of(FileOperations.class).decode(objectInputStream);
            this.sourceMetadata = (BucketMetadataUtil.SourceMetadata) NullableCoder.of(SerializableCoder.of(BucketMetadataUtil.SourceMetadata.class)).decode(objectInputStream);
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:org/apache/beam/sdk/extensions/smb/SortedBucketSource$MergeBuckets.class */
    public static class MergeBuckets<FinalKeyT> extends DoFn<Integer, KV<FinalKeyT, CoGbkResult>> {
        private static final Comparator<Map.Entry<TupleTag, KV<byte[], Iterator<?>>>> keyComparator = (entry, entry2) -> {
            return SortedBucketSource.bytesComparator.compare((byte[]) ((KV) entry.getValue()).getKey(), (byte[]) ((KV) entry2.getValue()).getKey());
        };
        private final Integer leastNumBuckets;
        private final Coder<FinalKeyT> keyCoder;
        private final List<BucketedInput<?, ?>> sources;
        private final Counter elementsRead;
        private final Distribution keyGroupSize;

        MergeBuckets(String str, List<BucketedInput<?, ?>> list, int i, Coder<FinalKeyT> coder) {
            this.leastNumBuckets = Integer.valueOf(i);
            this.keyCoder = coder;
            this.sources = list;
            this.elementsRead = Metrics.counter(SortedBucketSource.class, str + "-ElementsRead");
            this.keyGroupSize = Metrics.distribution(SortedBucketSource.class, str + "-KeyGroupSize");
        }

        @DoFn.ProcessElement
        public void processElement(DoFn<Integer, KV<FinalKeyT, CoGbkResult>>.ProcessContext processContext) {
            merge(((Integer) processContext.element()).intValue(), this.sources, this.leastNumBuckets.intValue(), kv -> {
                try {
                    processContext.output(KV.of(this.keyCoder.decode(new ByteArrayInputStream((byte[]) kv.getKey())), (CoGbkResult) kv.getValue()));
                } catch (Exception e) {
                    throw new RuntimeException("Failed to decode and merge key group", e);
                }
            }, this.elementsRead, this.keyGroupSize);
        }

        /* JADX INFO: Access modifiers changed from: package-private */
        public static void merge(int i, List<BucketedInput<?, ?>> list, int i2, Consumer<KV<byte[], CoGbkResult>> consumer, Counter counter, Distribution distribution) {
            int i3;
            int size = list.size();
            KeyGroupIterator[] keyGroupIteratorArr = (KeyGroupIterator[]) list.stream().map(bucketedInput -> {
                return bucketedInput.createIterator(i, i2);
            }).toArray(i4 -> {
                return new KeyGroupIterator[i4];
            });
            CoGbkResultSchema schemaOf = BucketedInput.schemaOf(list);
            TupleTagList tupleTagList = schemaOf.getTupleTagList();
            HashMap hashMap = new HashMap();
            do {
                i3 = 0;
                for (int i5 = 0; i5 < size; i5++) {
                    KeyGroupIterator keyGroupIterator = keyGroupIteratorArr[i5];
                    if (!hashMap.containsKey(tupleTagList.get(i5))) {
                        if (keyGroupIterator.hasNext()) {
                            hashMap.put(tupleTagList.get(i5), keyGroupIterator.next());
                        } else {
                            i3++;
                        }
                    }
                }
                if (hashMap.isEmpty()) {
                    return;
                }
                Map.Entry<TupleTag, KV<byte[], Iterator<?>>> entry = (Map.Entry) hashMap.entrySet().stream().min(keyComparator).orElse(null);
                Iterator it = hashMap.entrySet().iterator();
                ArrayList arrayList = new ArrayList();
                for (int i6 = 0; i6 < schemaOf.size(); i6++) {
                    arrayList.add(new ArrayList());
                }
                int i7 = 0;
                while (it.hasNext()) {
                    Map.Entry<TupleTag, KV<byte[], Iterator<?>>> entry2 = (Map.Entry) it.next();
                    if (keyComparator.compare(entry2, entry) == 0) {
                        List list2 = (List) arrayList.get(schemaOf.getIndex(entry2.getKey()));
                        Iterator it2 = (Iterator) entry2.getValue().getValue();
                        Objects.requireNonNull(list2);
                        it2.forEachRemaining(list2::add);
                        it.remove();
                        i7 += list2.size();
                    }
                }
                distribution.update(i7);
                counter.inc(i7);
                consumer.accept(KV.of((byte[]) entry.getValue().getKey(), CoGbkResultUtil.newCoGbkResult(schemaOf, arrayList)));
            } while (i3 != size);
        }

        public void populateDisplayData(DisplayData.Builder builder) {
            super.populateDisplayData(builder);
            builder.add(DisplayData.item("keyCoder", this.keyCoder.getClass()));
            builder.add(DisplayData.item("leastNumBuckets", this.leastNumBuckets));
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:org/apache/beam/sdk/extensions/smb/SortedBucketSource$SourceSpec.class */
    /**
     * Holder for the two values derived from a list of inputs: the smallest bucket count and the
     * key coder.
     *
     * @param <K> key type of the coder
     */
    public static class SourceSpec<K> {
        /** Smallest bucket count among the inputs. */
        int leastNumBuckets;
        /** Coder for the key type. */
        Coder<K> keyCoder;

        /**
         * @param i smallest bucket count among the inputs
         * @param coder coder for the key type
         */
        SourceSpec(int i, Coder<K> coder) {
            leastNumBuckets = i;
            keyCoder = coder;
        }
    }

    /**
     * Creates a source over the given inputs.
     *
     * @param cls class of the key that results are keyed by
     * @param list the bucketed inputs to read and merge
     */
    public SortedBucketSource(Class<FinalKeyT> cls, List<BucketedInput<?, ?>> list) {
        finalKeyClass = cls;
        sources = list;
    }

    /**
     * Expands into three steps: create one element per bucket id in
     * {@code [0, leastNumBuckets)}, reshuffle those ids via random keys, then merge each bucket
     * with {@link MergeBuckets}. The output coder pairs the key coder with a
     * {@link CoGbkResult} coder built from the inputs' tuple tags and value coders.
     *
     * @param pBegin the pipeline start
     * @return the merged key groups
     */
    @Override
    public final PCollection<KV<FinalKeyT, CoGbkResult>> expand(PBegin pBegin) {
        final SourceSpec<FinalKeyT> spec = getSourceSpec(this.finalKeyClass, this.sources);

        final List<Integer> bucketIds = IntStream.range(0, spec.leastNumBuckets).boxed().collect(Collectors.toList());
        final PCollection<Integer> buckets =
                pBegin.getPipeline().apply("CreateBuckets", Create.of(bucketIds).withCoder(VarIntCoder.of()));
        final PCollection<Integer> reshuffled = buckets.apply("ReshuffleKeys", Reshuffle.viaRandomKey());
        final PCollection<KV<FinalKeyT, CoGbkResult>> merged =
                reshuffled.apply("MergeBuckets", ParDo.of(new MergeBuckets<FinalKeyT>(getName(), this.sources, spec.leastNumBuckets, spec.keyCoder)));

        // The result coder is assembled only after the transforms have been applied.
        final CoGbkResultSchema schema = BucketedInput.schemaOf(this.sources);
        final List<Coder<?>> valueCoders = new ArrayList<>();
        for (BucketedInput<?, ?> source : this.sources) {
            valueCoders.add(source.getCoder());
        }
        return merged.setCoder(KvCoder.of(spec.keyCoder, CoGbkResult.CoGbkResultCoder.of(schema, UnionCoder.of(valueCoders))));
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Derives the merge parameters from the inputs' metadata.
     *
     * <p>Every input's metadata must be compatible with the first input's metadata. The result
     * holds the smallest bucket count among the inputs and the key coder of the first input
     * whose key class is {@code cls}.
     *
     * @param cls the key class to find a coder for
     * @param list the inputs to inspect
     * @return the least bucket count and the key coder
     * @throws IllegalStateException if an input is incompatible with the first one
     * @throws NullPointerException if no input has key class {@code cls}
     * @throws RuntimeException if the matching metadata cannot provide a deterministic key coder
     */
    public static <KeyT> SourceSpec<KeyT> getSourceSpec(Class<KeyT> cls, List<BucketedInput<?, ?>> list) {
        BucketMetadata<?, ?> reference = null;
        Coder<KeyT> keyCoder = null;
        int leastNumBuckets = Integer.MAX_VALUE;

        for (BucketedInput<?, ?> source : list) {
            final BucketMetadata<?, ?> current = source.getMetadata();
            if (reference != null) {
                Preconditions.checkState(reference.isCompatibleWith(current), "Source %s is incompatible with source %s", list.get(0), source);
            } else {
                reference = current;
            }
            leastNumBuckets = Math.min(current.getNumBuckets(), leastNumBuckets);

            // Only the first input with a matching key class supplies the coder.
            if (current.getKeyClass() != cls || keyCoder != null) {
                continue;
            }
            keyCoder = keyCoderOf(current, cls);
        }

        Preconditions.checkNotNull(keyCoder, "Could not infer coder for key class %s", cls);
        return new SourceSpec<>(leastNumBuckets, keyCoder);
    }

    /** Fetches the key coder from the metadata, rethrowing coder failures as unchecked exceptions. */
    @SuppressWarnings("unchecked")
    private static <KeyT> Coder<KeyT> keyCoderOf(BucketMetadata<?, ?> metadata, Class<KeyT> cls) {
        try {
            return (Coder<KeyT>) metadata.getKeyCoder();
        } catch (Coder.NonDeterministicException e) {
            throw new RuntimeException("Non-deterministic coder for key class " + cls, e);
        } catch (CannotProvideCoderException e2) {
            throw new RuntimeException("Could not provide coder for key class " + cls, e2);
        }
    }
}
