package com.nvidia.spark.rapids;

import ai.rapids.cudf.HostMemoryBuffer;
import java.lang.reflect.InvocationTargetException;
import java.net.URI;
import java.nio.ByteBuffer;
import java.time.Duration;
import java.util.ArrayList;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import org.apache.commons.lang3.reflect.MethodUtils;
import org.apache.hadoop.conf.Configuration;
import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProviderChain;
import software.amazon.awssdk.auth.credentials.AwsSessionCredentials;
import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider;
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
import software.amazon.awssdk.core.async.AsyncResponseTransformer;
import software.amazon.awssdk.core.async.SdkPublisher;
import software.amazon.awssdk.core.client.config.SdkAdvancedAsyncClientOption;
import software.amazon.awssdk.http.crt.AwsCrtAsyncHttpClient;
import software.amazon.awssdk.http.crt.TcpKeepAliveConfiguration;
import software.amazon.awssdk.http.nio.netty.NettyNioAsyncHttpClient;
import software.amazon.awssdk.services.s3.S3AsyncClient;
import software.amazon.awssdk.services.s3.model.GetObjectRequest;
import software.amazon.awssdk.services.s3.model.GetObjectResponse;
import software.amazon.awssdk.utils.ThreadFactoryBuilder;

/* JADX INFO: Access modifiers changed from: package-private */
/* loaded from: input_file:com/nvidia/spark/rapids/RangeCopier.class */
public class RangeCopier {
    /** Logger used by this class and by its nested helper classes. */
    private static final Logger LOG = LoggerFactory.getLogger(RangeCopier.class);
    /**
     * Shared S3 client, created on first use by {@code create(Configuration)} and released by
     * {@code shutdown()}. Both of those methods are {@code static synchronized}.
     */
    private static S3AsyncClient asyncClient;
    /** Hadoop configuration key read to obtain the list of credentials-provider classes. */
    private static final String CREDENTIALS_PROVIDER = "fs.s3a.aws.credentials.provider";
    /** Hadoop configuration key read to obtain the access key id. */
    private static final String ACCESS_KEY = "fs.s3a.access.key";
    /** Hadoop configuration key read to obtain the secret access key. */
    private static final String SECRET_KEY = "fs.s3a.secret.key";
    /**
     * Hadoop configuration key read to obtain the session token.
     * NOTE(review): the Hadoop S3A documentation names this property
     * {@code fs.s3a.session.token}; confirm that {@code fs.s3a.session.key} is intentional.
     */
    private static final String SESSION_KEY = "fs.s3a.session.key";
    /** Root package name of the AWS SDK (matches the {@code software.amazon.awssdk.*} imports of this file). */
    private static final String AWSSDK = "software.amazon.awssdk";
    /**
     * Selects the HTTP client used when the shared client is built: {@code true} picks the Netty
     * NIO client, {@code false} (the default) picks the AWS CRT client. The flag is only read
     * while the client is being created.
     */
    static volatile boolean useNetty;

    /* loaded from: input_file:com/nvidia/spark/rapids/RangeCopier$AsyncRangeRequestTransformer.class */
    private static class AsyncRangeRequestTransformer implements AsyncResponseTransformer<GetObjectResponse, Long> {
        private final HostMemoryBuffer hostMemoryBuffer;
        private final long outputOffset;
        private final long rangeLength;
        private volatile CompletableFuture<Long> resultFuture;

        /* loaded from: input_file:com/nvidia/spark/rapids/RangeCopier$AsyncRangeRequestTransformer$ByteBufferSubscriber.class */
        private static class ByteBufferSubscriber implements Subscriber<ByteBuffer> {
            private final CompletableFuture<Long> resultFuture;
            private final HostMemoryBuffer hostMemoryBuffer;
            private long pos;
            private final long limit;
            private long totalCopied;
            private Subscription byteBufferSubscription;

            public ByteBufferSubscriber(CompletableFuture<Long> completableFuture, HostMemoryBuffer hostMemoryBuffer, long j, long j2) {
                this.resultFuture = completableFuture;
                this.hostMemoryBuffer = hostMemoryBuffer;
                this.pos = j;
                this.limit = j + j2;
            }

            public void onSubscribe(Subscription subscription) {
                if (this.byteBufferSubscription != null) {
                    this.byteBufferSubscription.cancel();
                } else {
                    this.byteBufferSubscription = subscription;
                    this.byteBufferSubscription.request(Long.MAX_VALUE);
                }
            }

            public void onNext(ByteBuffer byteBuffer) {
                int remaining = byteBuffer.remaining();
                this.hostMemoryBuffer.asByteBuffer(this.pos, remaining).put(byteBuffer);
                this.pos += remaining;
                this.totalCopied += remaining;
                if (this.pos < this.limit) {
                    this.byteBufferSubscription.request(1L);
                } else if (this.pos > this.limit) {
                    this.resultFuture.completeExceptionally(new IllegalStateException("INFEASIBLE: Remaining zero bytes expected, read past the range by bytes: " + (this.pos - this.limit)));
                }
            }

            public void onError(Throwable th) {
                this.resultFuture.completeExceptionally(th);
            }

            public void onComplete() {
                this.resultFuture.complete(Long.valueOf(this.totalCopied));
            }
        }

        public AsyncRangeRequestTransformer(HostMemoryBuffer hostMemoryBuffer, long j, long j2) {
            this.hostMemoryBuffer = hostMemoryBuffer;
            this.outputOffset = j;
            this.rangeLength = j2;
        }

        public CompletableFuture<Long> prepare() {
            this.resultFuture = new CompletableFuture<>();
            return this.resultFuture;
        }

        public void onResponse(GetObjectResponse getObjectResponse) {
            RangeCopier.LOG.debug("Response available: {}", getObjectResponse);
        }

        public void onStream(SdkPublisher<ByteBuffer> sdkPublisher) {
            sdkPublisher.subscribe(new ByteBufferSubscriber(this.resultFuture, this.hostMemoryBuffer, this.outputOffset, this.rangeLength));
        }

        public void exceptionOccurred(Throwable th) {
            this.resultFuture.completeExceptionally(th);
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/nvidia/spark/rapids/RangeCopier$Conf.class */
    /**
     * Snapshot of the {@code fs.s3a.*} Hadoop settings that drive the async S3 client.
     * All values are read once in the constructor; the credentials are resolved on demand by
     * {@link #getAwsCredentialsProvider()}.
     */
    public static class Conf {
        private final Configuration hadoopConf;
        /** {@code fs.s3a.connection.maximum}, default 200. */
        final int maxConcurrency;
        /** {@code fs.s3a.max.total.tasks}, default 1000. */
        final int maxTasks;
        /** {@code fs.s3a.threads.keepalivetime} in seconds, default 60. */
        final long threadsKeepAliveTime;
        /** {@code fs.s3a.threads.max}, default 136. */
        final int maxThreads;
        /** {@code fs.s3a.connection.keepalive}, default {@code true}. */
        final boolean connectionKeepAlive;
        /** {@code fs.s3a.connection.ttl} in minutes, default 5. */
        final long connectionTTL;
        /** {@code fs.s3a.path.style.access}, default {@code false}. */
        final boolean pathStyle;

        /**
         * Reads the client settings from the given Hadoop configuration.
         *
         * @param configuration Hadoop configuration holding the {@code fs.s3a.*} properties
         */
        Conf(Configuration configuration) {
            LOG.debug("Creating async S3 client conf from S3AFileSystem conf: {}", configuration);
            this.hadoopConf = configuration;
            this.maxConcurrency = configuration.getInt("fs.s3a.connection.maximum", 200);
            this.maxTasks = configuration.getInt("fs.s3a.max.total.tasks", 1000);
            this.threadsKeepAliveTime = configuration.getTimeDuration("fs.s3a.threads.keepalivetime", 60L, TimeUnit.SECONDS);
            this.maxThreads = configuration.getInt("fs.s3a.threads.max", 136);
            this.connectionKeepAlive = configuration.getBoolean("fs.s3a.connection.keepalive", true);
            this.connectionTTL = configuration.getTimeDuration("fs.s3a.connection.ttl", 5L, TimeUnit.MINUTES);
            this.pathStyle = configuration.getBoolean("fs.s3a.path.style.access", false);
        }

        /**
         * Builds the credentials provider for the S3 client. The first matching rule wins:
         * <ol>
         *   <li>access key, secret key and session key set: static session credentials;</li>
         *   <li>access key and secret key set: static basic credentials;</li>
         *   <li>provider class list set: a chain of the configured providers;</li>
         *   <li>otherwise: the SDK default provider.</li>
         * </ol>
         *
         * @return the credentials provider to hand to the S3 client builder
         * @throws RuntimeException if a configured provider class cannot be instantiated
         */
        AwsCredentialsProvider getAwsCredentialsProvider() {
            LOG.debug("Building AwsCredentialsProvider ...");
            String accessKey = this.hadoopConf.get(ACCESS_KEY);
            String secretKey = this.hadoopConf.get(SECRET_KEY);
            String sessionKey = this.hadoopConf.get(SESSION_KEY);
            String providerClasses = this.hadoopConf.get(CREDENTIALS_PROVIDER);

            // Declared as the common interface: the branches produce different implementations.
            AwsCredentialsProvider provider;
            if (accessKey != null && secretKey != null && sessionKey != null) {
                LOG.debug("StaticCredentialsProvider using {}, {}, {}", ACCESS_KEY, SECRET_KEY, SESSION_KEY);
                provider = StaticCredentialsProvider.create(AwsSessionCredentials.create(accessKey, secretKey, sessionKey));
            } else if (accessKey != null && secretKey != null) {
                LOG.debug("StaticCredentialsProvider using {}, {}", ACCESS_KEY, SECRET_KEY);
                provider = StaticCredentialsProvider.create(AwsBasicCredentials.create(accessKey, secretKey));
            } else if (providerClasses != null) {
                LOG.debug("AwsCredentialsProviderChain using {}", CREDENTIALS_PROVIDER);
                AwsCredentialsProviderChain.Builder chain = AwsCredentialsProviderChain.builder();
                for (Class<?> configured : this.hadoopConf.getClasses(CREDENTIALS_PROVIDER, new Class<?>[0])) {
                    chain.addCredentialsProvider(instantiateProvider(configured));
                }
                provider = chain.build();
            } else {
                LOG.warn("Missing fs.s3a access config, using default CredentialsProvider");
                provider = DefaultCredentialsProvider.create();
            }
            LOG.info("Configured CredentialsProvider object for S3 Client: {}", provider.getClass().getName());
            return provider;
        }

        /**
         * Instantiates one configured provider through its static {@code create()} factory.
         * A class from a {@code com.amazonaws.*} package is replaced by the class with the same
         * simple name in {@code software.amazon.awssdk.auth.credentials}.
         */
        private static AwsCredentialsProvider instantiateProvider(Class<?> configured) {
            try {
                // getPackage() may return null, so guard it instead of dereferencing blindly.
                Package pkg = configured.getPackage();
                Class<?> providerClass = configured;
                if (pkg != null && pkg.getName().startsWith("com.amazonaws.")) {
                    providerClass = Class.forName(AWSSDK + ".auth.credentials." + configured.getSimpleName());
                }
                return (AwsCredentialsProvider) MethodUtils.invokeStaticMethod(providerClass, "create", new Object[0]);
            } catch (ClassNotFoundException | IllegalAccessException | NoSuchMethodException | InvocationTargetException e) {
                throw new RuntimeException("Cannot instantiate credentials provider " + configured.getName(), e);
            }
        }
    }

    /** Package-private no-op constructor; every other member of this class visible in this file is static. */
    RangeCopier() {
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Copies the given byte ranges of one S3 object into a host memory buffer.
     *
     * <p>One ranged GET is issued per element of {@code iterable}; all requests are started
     * before any is awaited. The method blocks until every request has finished, because the
     * combined future only completes once each of its inputs has completed.
     *
     * @param configuration    Hadoop configuration used to create the shared S3 client on first use
     * @param hostMemoryBuffer destination buffer
     * @param uri              object location; the authority is the bucket, the path is the key
     * @param iterable         ranges to fetch together with their destination offsets
     * @return total number of bytes copied over all ranges, {@code 0} when there are none
     * @throws java.util.concurrent.CompletionException if any of the range requests failed
     */
    public static long copyToHMB(Configuration configuration, HostMemoryBuffer hostMemoryBuffer, URI uri, Iterable<RangeWithOffset> iterable) {
        String bucket = uri.getAuthority();
        // NOTE(review): getRawPath() keeps percent-escapes (e.g. "%20"); confirm that callers
        // expect the escaped form to be used as the object key.
        String key = uri.getRawPath().substring(1);

        ArrayList<CompletableFuture<Long>> pending = new ArrayList<>(3);
        for (RangeWithOffset range : iterable) {
            long destOffset = range.destOffset();
            // Looked up per range so that an empty range list never initializes the client;
            // after the first call this is only a synchronized null check.
            S3AsyncClient client = create(configuration);
            GetObjectRequest request = GetObjectRequest.builder().bucket(bucket).key(key).range(range.rangeSpec()).build();
            pending.add(client.getObject(request, new AsyncRangeRequestTransformer(hostMemoryBuffer, destOffset, range.length())));
        }

        // Left fold: total = ((0 + f1) + f2) + ... ; completes when all inputs have completed.
        CompletableFuture<Long> total = CompletableFuture.completedFuture(0L);
        for (CompletableFuture<Long> copied : pending) {
            total = total.thenCombine(copied, Long::sum);
        }
        return total.join();
    }

    /**
     * Returns the shared S3 client, building it from {@code configuration} on the first call.
     * Later calls return the existing client and ignore {@code configuration}.
     *
     * @param configuration Hadoop configuration used only when the client has to be built
     * @return the shared client, never {@code null}
     */
    private static synchronized S3AsyncClient create(Configuration configuration) {
        if (asyncClient != null) {
            return asyncClient;
        }
        LOG.debug("Initializing RAPIDS S3 Range Copier ...");
        final Conf settings = new Conf(configuration);
        asyncClient = (S3AsyncClient) S3AsyncClient.builder()
                .credentialsProvider(settings.getAwsCredentialsProvider())
                .forcePathStyle(Boolean.valueOf(settings.pathStyle))
                // HTTP layer: Netty when the flag is set, AWS CRT otherwise.
                .httpClientBuilder(useNetty ? nettyBuilder(settings) : crtBuilder(settings))
                // Response futures are completed on a dedicated, bounded thread pool.
                .asyncConfiguration(asyncConfig ->
                        asyncConfig.advancedOption(SdkAdvancedAsyncClientOption.FUTURE_COMPLETION_EXECUTOR, createThreadPoolExecutor(settings)))
                .build();
        LOG.debug("Done initializing RAPIDS S3 Range Copier: {}", asyncClient);
        return asyncClient;
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Closes the shared S3 client, if one exists, and clears the reference so that the next
     * {@code create(Configuration)} call builds a new client. Does nothing when no client exists.
     */
    public static synchronized void shutdown() {
        if (asyncClient == null) {
            return;
        }
        // try-with-resources closes the client even if the log call throws; the finally
        // block then drops the reference whether or not close() succeeded.
        try (S3AsyncClient closing = asyncClient) {
            LOG.debug("Closing client {}", closing);
        } finally {
            asyncClient = null;
        }
    }

    /**
     * Creates the executor on which the S3 client completes its response futures.
     *
     * <p>The pool keeps up to 50 core threads, capped at {@code conf.maxThreads}: without the cap
     * a configured maximum below 50 makes the {@link ThreadPoolExecutor} constructor throw
     * {@link IllegalArgumentException} because the core size would exceed the maximum size.
     * Idle threads, core threads included, are released after {@code conf.threadsKeepAliveTime}
     * seconds. Pending tasks wait in a queue bounded by {@code conf.maxTasks}.
     *
     * @param conf client settings providing the pool limits
     * @return a new executor whose threads are named with the prefix {@code spark-rapids-async-s3}
     */
    private static ThreadPoolExecutor createThreadPoolExecutor(Conf conf) {
        int coreThreads = Math.min(50, conf.maxThreads);
        ThreadPoolExecutor executor = new ThreadPoolExecutor(
                coreThreads,
                conf.maxThreads,
                conf.threadsKeepAliveTime,
                TimeUnit.SECONDS,
                new LinkedBlockingQueue<Runnable>(conf.maxTasks),
                new ThreadFactoryBuilder().threadNamePrefix("spark-rapids-async-s3").build());
        executor.allowCoreThreadTimeOut(true);
        return executor;
    }

    /**
     * Configures a Netty NIO HTTP client builder from the given settings: maximum concurrency,
     * TCP keep-alive flag and connection time-to-live (in minutes).
     *
     * @param conf client settings
     * @return the configured, not yet built, Netty client builder
     */
    private static NettyNioAsyncHttpClient.Builder nettyBuilder(Conf conf) {
        Duration timeToLive = Duration.ofMinutes(conf.connectionTTL);
        return NettyNioAsyncHttpClient.builder()
                .maxConcurrency(conf.maxConcurrency)
                .tcpKeepAlive(conf.connectionKeepAlive)
                .connectionTimeToLive(timeToLive);
    }

    /**
     * Configures an AWS CRT HTTP client builder from the given settings. The maximum concurrency
     * is always applied; when keep-alive is enabled, TCP keep-alive is set up with a 5 minute
     * interval and a 30 second timeout.
     *
     * @param conf client settings
     * @return the configured, not yet built, CRT client builder
     */
    private static AwsCrtAsyncHttpClient.Builder crtBuilder(Conf conf) {
        AwsCrtAsyncHttpClient.Builder crt = AwsCrtAsyncHttpClient.builder().maxConcurrency(conf.maxConcurrency);
        if (!conf.connectionKeepAlive) {
            return crt;
        }
        TcpKeepAliveConfiguration keepAlive = TcpKeepAliveConfiguration.builder()
                .keepAliveInterval(Duration.ofMinutes(5L))
                .keepAliveTimeout(Duration.ofSeconds(30L))
                .build();
        crt.tcpKeepAliveConfiguration(keepAlive);
        return crt;
    }
}
