package org.apache.beam.sdk.extensions.smb;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.coders.VarIntCoder;
import org.apache.beam.sdk.extensions.smb.BucketShardId;
import org.apache.beam.sdk.extensions.smb.FileOperations;
import org.apache.beam.sdk.extensions.smb.SMBFilenamePolicy;
import org.apache.beam.sdk.extensions.smb.SortedBucketSink;
import org.apache.beam.sdk.extensions.smb.SortedBucketSource;
import org.apache.beam.sdk.io.fs.ResourceId;
import org.apache.beam.sdk.io.fs.ResourceIdCoder;
import org.apache.beam.sdk.metrics.Counter;
import org.apache.beam.sdk.metrics.Distribution;
import org.apache.beam.sdk.metrics.Metrics;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.Reshuffle;
import org.apache.beam.sdk.transforms.display.DisplayData;
import org.apache.beam.sdk.transforms.join.CoGbkResult;
import org.apache.beam.sdk.transforms.join.CoGbkResultSchema;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PBegin;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.sdk.values.TupleTagList;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.primitives.UnsignedBytes;

/* loaded from: input_file:org/apache/beam/sdk/extensions/smb/SortedBucketTransform.class */
/**
 * A {@link PTransform} that merges several sorted-bucket sources bucket by bucket on their encoded
 * keys, applies a user {@link TransformFn} to every co-grouped key, and writes the results as a new
 * set of sorted buckets described by {@code bucketMetadata}.
 *
 * <p>The expansion creates one element per input bucket ({@code [0, leastNumBuckets)}, where
 * {@code leastNumBuckets} is the smallest bucket count among the sources), reshuffles them, merges
 * and writes the corresponding output buckets in {@link MergeAndWriteBuckets}, and finally renames
 * the temporary files to their destinations.
 *
 * @param <FinalKeyT> the key type of the output buckets
 * @param <FinalValueT> the record type written to the output buckets
 */
public class SortedBucketTransform<FinalKeyT, FinalValueT> extends PTransform<PBegin, SortedBucketSink.WriteResult> {
    private final SMBFilenamePolicy filenamePolicy;
    private final ResourceId tempDirectory;
    private final FileOperations<FinalValueT> fileOperations;
    private final Class<FinalKeyT> finalKeyClass;
    private final List<SortedBucketSource.BucketedInput<?, ?>> sources;
    private final BucketMetadata<FinalKeyT, FinalValueT> bucketMetadata;
    private final TransformFn<FinalKeyT, FinalValueT> transformFn;

    /**
     * Merges the key groups of all sources for one input bucket and writes every output bucket owned
     * by that input bucket.
     *
     * <p>An input element {@code b} owns the output buckets {@code b, b + leastNumBuckets,
     * b + 2 * leastNumBuckets, ...} below {@code numBuckets}. When the output has more buckets than
     * {@code leastNumBuckets}, each key group is routed to its output bucket via
     * {@link BucketMetadata#getBucketId}; otherwise everything goes to bucket {@code b}.
     *
     * <p>If processing fails, every writer still open for the element is closed before the failure
     * propagates; close failures are attached to it as suppressed exceptions.
     */
    public static class MergeAndWriteBuckets<FinalKeyT, FinalValueT> extends DoFn<Integer, KV<BucketShardId, ResourceId>> {
        private static final Comparator<byte[]> bytesComparator = UnsignedBytes.lexicographicalComparator();

        private final List<SortedBucketSource.BucketedInput<?, ?>> sources;
        private final SMBFilenamePolicy.FileAssignment fileAssignment;
        private final FileOperations<FinalValueT> fileOperations;
        private final TransformFn<FinalKeyT, FinalValueT> transformFn;
        private final BucketMetadata<FinalKeyT, FinalValueT> bucketMetadata;
        private final Coder<FinalKeyT> keyCoder;
        private final int leastNumBuckets;
        private final Counter elementsWritten;
        private final Counter elementsRead;
        private final Distribution keyGroupSize;

        /**
         * @param transformName prefix for the metric names of this DoFn
         * @param sources the bucketed inputs to merge
         * @param sourceSpec supplies the key coder and the smallest bucket count among the sources
         * @param fileAssignment assigns a (temporary) file to each output bucket
         * @param fileOperations creates the writers for the output files
         * @param bucketMetadata metadata of the output buckets
         * @param transformFn user function applied to every merged key group
         */
        MergeAndWriteBuckets(
                String transformName,
                List<SortedBucketSource.BucketedInput<?, ?>> sources,
                SourceSpec<FinalKeyT> sourceSpec,
                SMBFilenamePolicy.FileAssignment fileAssignment,
                FileOperations<FinalValueT> fileOperations,
                BucketMetadata<FinalKeyT, FinalValueT> bucketMetadata,
                TransformFn<FinalKeyT, FinalValueT> transformFn) {
            this.sources = sources;
            this.fileAssignment = fileAssignment;
            this.fileOperations = fileOperations;
            this.transformFn = transformFn;
            this.bucketMetadata = bucketMetadata;
            this.keyCoder = sourceSpec.keyCoder;
            this.leastNumBuckets = sourceSpec.leastNumBuckets;
            this.elementsWritten = Metrics.counter(SortedBucketTransform.class, transformName + "-ElementsWritten");
            this.elementsRead = Metrics.counter(SortedBucketTransform.class, transformName + "-ElementsRead");
            this.keyGroupSize = Metrics.distribution(SortedBucketTransform.class, transformName + "-KeyGroupSize");
        }

        @Override
        public void populateDisplayData(DisplayData.Builder builder) {
            super.populateDisplayData(builder);
            builder.add(DisplayData.item("keyCoder", keyCoder.getClass()));
            builder.add(DisplayData.item("numBuckets", bucketMetadata.getNumBuckets()));
            builder.add(DisplayData.item("numShards", bucketMetadata.getNumShards()));
        }

        /**
         * Opens one writer per owned output bucket, merges all sources for the input bucket into those
         * writers, then closes each writer and emits its (bucket, temp file) pair.
         */
        @ProcessElement
        public void processElement(ProcessContext context) {
            final int bucketId = context.element();
            // More output buckets than input buckets: each key group must be re-hashed to pick one of
            // the output buckets owned by this input bucket.
            final boolean reHashBuckets = bucketMetadata.getNumBuckets() > leastNumBuckets;

            // Output bucket id -> open writer. An entry is removed once its writer is closed, so on
            // failure exactly the still-open writers are released.
            final Map<Integer, OutputCollector<FinalValueT>> writers = new HashMap<>();
            // (bucket, temp file) pairs in creation order; emitted only after the merge completes.
            final List<KV<BucketShardId, ResourceId>> destinations = new ArrayList<>();

            try {
                for (int outputBucket = bucketId; outputBucket < bucketMetadata.getNumBuckets(); outputBucket += leastNumBuckets) {
                    final BucketShardId bucketShardId = BucketShardId.of(outputBucket, 0);
                    final ResourceId dst = fileAssignment.forBucket(bucketShardId, bucketMetadata);
                    try {
                        writers.put(outputBucket, new OutputCollector<>(fileOperations.createWriter(dst), elementsWritten));
                    } catch (IOException e) {
                        throw new RuntimeException(e);
                    }
                    destinations.add(KV.of(bucketShardId, dst));
                }

                final KeyGroupIterator[] iterators = sources.stream()
                        .map(input -> input.createIterator(bucketId, leastNumBuckets))
                        .toArray(KeyGroupIterator[]::new);

                merge(
                        iterators,
                        SortedBucketSource.BucketedInput.schemaOf(sources),
                        keyGroup -> writeKeyGroup(keyGroup, bucketId, reHashBuckets, writers),
                        elementsRead,
                        keyGroupSize);

                for (KV<BucketShardId, ResourceId> destination : destinations) {
                    writers.remove(destination.getKey().getBucketId()).onComplete();
                    context.output(destination);
                }
            } catch (RuntimeException | Error e) {
                closeQuietly(writers.values(), e);
                throw e;
            }
        }

        /**
         * Decodes the key of one merged key group and runs the user transform with the writer of the
         * output bucket the key belongs to. Any failure is wrapped in a RuntimeException.
         */
        private void writeKeyGroup(
                KV<byte[], CoGbkResult> keyGroup,
                int bucketId,
                boolean reHashBuckets,
                Map<Integer, OutputCollector<FinalValueT>> writers) {
            try {
                final byte[] keyBytes = keyGroup.getKey();
                final FinalKeyT key = keyCoder.decode(new ByteArrayInputStream(keyBytes));
                final int outputBucket = reHashBuckets ? bucketMetadata.getBucketId(keyBytes) : bucketId;
                final OutputCollector<FinalValueT> writer = writers.get(outputBucket);
                if (writer == null) {
                    // Previously surfaced as an NPE inside the user transform.
                    throw new IllegalStateException(
                            "Key group was assigned to output bucket " + outputBucket
                                    + ", which is not written for input bucket " + bucketId
                                    + " (written buckets are " + bucketId + " + k * " + leastNumBuckets + ")");
                }
                transformFn.writeTransform(KV.of(key, keyGroup.getValue()), writer);
            } catch (Exception e) {
                throw new RuntimeException("Failed to decode and merge key group", e);
            }
        }

        /**
         * Best-effort close of writers left open by a failed element; close failures are attached to
         * {@code primary} as suppressed exceptions instead of masking it.
         */
        private static void closeQuietly(Iterable<? extends OutputCollector<?>> collectors, Throwable primary) {
            for (OutputCollector<?> collector : collectors) {
                try {
                    collector.onComplete();
                } catch (RuntimeException closeFailure) {
                    primary.addSuppressed(closeFailure);
                }
            }
        }

        /**
         * K-way merges the key groups produced by {@code iterators} and passes each distinct key,
         * together with its values from every source, to {@code consumer}.
         *
         * <p>The i-th iterator is associated with the i-th tag of {@code resultSchema}. Each round
         * emits the smallest head key (unsigned lexicographic order on the encoded bytes). Keys come out
         * in ascending order provided each iterator yields its key groups in ascending order; that is
         * assumed here, not checked. The value iterator of every emitted group is fully drained into
         * memory before the consumer is called.
         *
         * @param iterators one key-group iterator per source; expected to yield
         *     {@code KV<byte[], Iterator<?>>}
         * @param resultSchema schema used to build each {@link CoGbkResult}
         * @param consumer receives (encoded key, co-grouped values) for every distinct key
         * @param elementsRead incremented by the number of values in each merged group
         * @param keyGroupSize updated with the number of values in each merged group
         */
        static void merge(
                KeyGroupIterator[] iterators,
                CoGbkResultSchema resultSchema,
                Consumer<KV<byte[], CoGbkResult>> consumer,
                Counter elementsRead,
                Distribution keyGroupSize) {
            final int numSources = iterators.length;
            final TupleTagList tupleTags = resultSchema.getTupleTagList();
            // Current head key group of every source whose head has not been emitted yet.
            final Map<TupleTag<?>, KV<byte[], Iterator<?>>> headGroups = new HashMap<>();

            while (true) {
                // Refill the head of every source whose previous head was consumed.
                for (int i = 0; i < numSources; i++) {
                    final TupleTag<?> tag = tupleTags.get(i);
                    if (!headGroups.containsKey(tag) && iterators[i].hasNext()) {
                        @SuppressWarnings("unchecked")
                        final KV<byte[], Iterator<?>> next = (KV<byte[], Iterator<?>>) iterators[i].next();
                        headGroups.put(tag, next);
                    }
                }

                if (headGroups.isEmpty()) {
                    // No pending heads and nothing left to read: every source is exhausted.
                    return;
                }

                final byte[] minKey = headGroups.values().stream()
                        .map(KV::getKey)
                        .min(bytesComparator)
                        .orElseThrow(IllegalStateException::new);

                final List<List<Object>> valuesByTag = new ArrayList<>(resultSchema.size());
                for (int i = 0; i < resultSchema.size(); i++) {
                    valuesByTag.add(new ArrayList<>());
                }

                int keyGroupCount = 0;
                final Iterator<Map.Entry<TupleTag<?>, KV<byte[], Iterator<?>>>> heads = headGroups.entrySet().iterator();
                while (heads.hasNext()) {
                    final Map.Entry<TupleTag<?>, KV<byte[], Iterator<?>>> head = heads.next();
                    if (bytesComparator.compare(head.getValue().getKey(), minKey) == 0) {
                        final List<Object> values = valuesByTag.get(resultSchema.getIndex(head.getKey()));
                        // Copy the values out now: the head is removed, so its source is advanced in the
                        // next round.
                        head.getValue().getValue().forEachRemaining(values::add);
                        heads.remove();
                        keyGroupCount += values.size();
                    }
                }

                keyGroupSize.update(keyGroupCount);
                elementsRead.inc(keyGroupCount);
                consumer.accept(KV.of(minKey, CoGbkResultUtil.newCoGbkResult(resultSchema, new ArrayList<Iterable<?>>(valuesByTag))));
            }
        }
    }

    /**
     * Adapts a {@link FileOperations.Writer} to a {@link SerializableConsumer} so a {@link TransformFn}
     * can emit records; every successful write increments {@code elementsWritten}.
     *
     * @param <ValueT> the record type accepted by the writer
     */
    public static class OutputCollector<ValueT> implements SerializableConsumer<ValueT> {
        private final FileOperations.Writer<ValueT> writer;
        private final Counter elementsWritten;

        OutputCollector(FileOperations.Writer<ValueT> writer, Counter elementsWritten) {
            this.writer = writer;
            this.elementsWritten = elementsWritten;
        }

        /** Closes the underlying writer, rethrowing an {@link IOException} as a RuntimeException. */
        void onComplete() {
            try {
                writer.close();
            } catch (IOException e) {
                throw new RuntimeException("Closing writer failed: ", e);
            }
        }

        /** Writes one record, rethrowing an {@link IOException} as a RuntimeException. */
        @Override
        public void accept(ValueT value) {
            try {
                writer.write(value);
                elementsWritten.inc();
            } catch (IOException e) {
                throw new RuntimeException("Write of element " + value + " failed: ", e);
            }
        }
    }

    /** A {@link Consumer} that is also {@link Serializable}. */
    public interface SerializableConsumer<ValueT> extends Consumer<ValueT>, Serializable {
    }

    /**
     * User function invoked once per merged key with the values co-grouped from all sources; records
     * passed to the consumer are written to the output bucket chosen for that key.
     */
    @FunctionalInterface
    public interface TransformFn<KeyT, ValueT> extends Serializable {
        void writeTransform(KV<KeyT, CoGbkResult> kv, SerializableConsumer<ValueT> serializableConsumer);
    }

    /**
     * @param finalKeyClass class of the output key, used to derive the source spec
     * @param bucketMetadata metadata of the output buckets; must have exactly one shard
     * @param outputDirectory directory handed to {@link SMBFilenamePolicy} for the final files
     * @param tempDirectory directory for temporary files written before the final rename
     * @param filenameSuffix second argument of {@link SMBFilenamePolicy} (presumably the file
     *     suffix/extension; confirm)
     * @param fileOperations creates the writers for the output files
     * @param sources the bucketed inputs to merge
     * @param transformFn user function applied to every merged key group
     */
    public SortedBucketTransform(
            Class<FinalKeyT> finalKeyClass,
            BucketMetadata<FinalKeyT, FinalValueT> bucketMetadata,
            ResourceId outputDirectory,
            ResourceId tempDirectory,
            String filenameSuffix,
            FileOperations<FinalValueT> fileOperations,
            List<SortedBucketSource.BucketedInput<?, ?>> sources,
            TransformFn<FinalKeyT, FinalValueT> transformFn) {
        this.filenamePolicy = new SMBFilenamePolicy(outputDirectory, filenameSuffix);
        this.tempDirectory = tempDirectory;
        this.fileOperations = fileOperations;
        this.finalKeyClass = finalKeyClass;
        this.sources = sources;
        this.transformFn = transformFn;
        this.bucketMetadata = bucketMetadata;
    }

    /**
     * Validates the output metadata against the sources and builds the
     * create → reshuffle → merge/write → rename pipeline.
     *
     * @throws IllegalArgumentException if numShards != 1 or numBuckets < leastNumBuckets
     */
    @Override
    public final SortedBucketSink.WriteResult expand(PBegin begin) {
        Preconditions.checkArgument(bucketMetadata.getNumShards() == 1, "Sharding is not supported in SortedBucketTransform. numShards must == 1.");
        final SourceSpec<FinalKeyT> sourceSpec = SourceSpec.from(finalKeyClass, sources);
        Preconditions.checkArgument(bucketMetadata.getNumBuckets() >= sourceSpec.leastNumBuckets, "numBuckets in BucketMetadata must be >= leastNumBuckets among sources: " + sourceSpec.leastNumBuckets);

        final List<Integer> inputBuckets = IntStream.range(0, sourceSpec.leastNumBuckets).boxed().collect(Collectors.toList());

        return SortedBucketSink.WriteResult.fromTuple(
                begin.getPipeline()
                        .apply("CreateBuckets", Create.of(inputBuckets).withCoder(VarIntCoder.of()))
                        .apply("ReshuffleKeys", Reshuffle.viaRandomKey())
                        .apply(
                                "MergeTransformWrite",
                                ParDo.of(new MergeAndWriteBuckets<>(
                                        getName(),
                                        sources,
                                        sourceSpec,
                                        filenamePolicy.forTempFiles(tempDirectory),
                                        fileOperations,
                                        bucketMetadata,
                                        transformFn)))
                        .setCoder(KvCoder.of(BucketShardId.BucketShardIdCoder.of(), ResourceIdCoder.of()))
                        .apply(
                                "FinalizeTempFiles",
                                new SortedBucketSink.RenameBuckets<>(filenamePolicy.forDestination(), bucketMetadata, fileOperations)));
    }
}
