package com.uber.rss.handlers;

import com.uber.rss.common.AppShufflePartitionId;
import com.uber.rss.common.FilePathAndLength;
import com.uber.rss.common.MapTaskCommitStatus;
import com.uber.rss.exceptions.RssInvalidDataException;
import com.uber.rss.execution.ShuffleExecutor;
import com.uber.rss.messages.BaseMessage;
import com.uber.rss.messages.ConnectDownloadRequest;
import com.uber.rss.messages.ConnectDownloadResponse;
import com.uber.rss.messages.GetDataAvailabilityRequest;
import com.uber.rss.messages.GetDataAvailabilityResponse;
import com.uber.rss.messages.ShuffleStageStatus;
import com.uber.rss.metrics.M3Stats;
import com.uber.rss.util.NettyUtils;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.util.ReferenceCountUtil;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import rss_shaded.com.uber.m3.tally.Counter;
import rss_shaded.com.uber.m3.tally.Gauge;

/* loaded from: input_file:com/uber/rss/handlers/DownloadChannelInboundHandler.class */
/**
 * Netty inbound handler serving one shuffle-data download connection.
 *
 * <p>The client is expected to send a {@link ConnectDownloadRequest} first. It identifies the
 * partition and the map task attempt ids to fetch. The handler replies with the stage status and,
 * when the requested data is available, streams the partition files. Later
 * {@link GetDataAvailabilityRequest} messages on the same connection re-check availability for the
 * partition recorded by the connect request. Any other message type is rejected with
 * {@link RssInvalidDataException}.
 *
 * <p>Instances hold per-connection mutable state (connection info, partition id, task attempt ids,
 * idle check), so one instance is expected per channel.
 */
public class DownloadChannelInboundHandler extends ChannelInboundHandlerAdapter {
    private static final Logger logger = LoggerFactory.getLogger(DownloadChannelInboundHandler.class);

    private static final Counter numChannelActive = M3Stats.getDefaultScope().counter("numDownloadChannelActive");
    private static final Counter numChannelInactive = M3Stats.getDefaultScope().counter("numDownloadChannelInactive");
    // Tracks the number of currently open download channels across all handler instances.
    private static final AtomicInteger concurrentChannelsAtomicInteger = new AtomicInteger();
    private static final Gauge numConcurrentChannels = M3Stats.getDefaultScope().gauge("numConcurrentDownloadChannels");
    private static final Counter closedIdleDownloadChannels = M3Stats.getDefaultScope().counter("closedIdleDownloadChannels");

    // File status value this handler treats as "shuffle stage not started" (per its log message).
    // NOTE(review): presumably mirrors a constant on ShuffleStageStatus — confirm and reuse it.
    private static final int FILE_STATUS_STAGE_NOT_STARTED = 1;
    // File status value this handler treats as "partition file corrupted" (per its log message).
    // NOTE(review): presumably mirrors a constant on ShuffleStageStatus — confirm and reuse it.
    private static final int FILE_STATUS_CORRUPTED = 2;
    // Response status byte sent when the shuffle stage has not started.
    // NOTE(review): presumably mirrors a shared message-status constant — confirm and reuse it.
    private static final byte RESPONSE_STATUS_STAGE_NOT_STARTED = (byte) 44;

    private final String serverId;
    private final long idleTimeoutMillis;
    private final DownloadServerHandler downloadServerHandler;
    private String connectionInfo = "";
    // Set by the first ConnectDownloadRequest; null until then.
    private AppShufflePartitionId appShufflePartitionId = null;
    private List<Long> fetchTaskAttemptIds = new ArrayList<>();
    private ChannelIdleCheck idleCheck;

    /**
     * @param str server id reported back to clients in {@link ConnectDownloadResponse}
     * @param j idle timeout in milliseconds used for the per-channel {@link ChannelIdleCheck}
     * @param shuffleExecutor executor backing the {@link DownloadServerHandler} for this channel
     */
    public DownloadChannelInboundHandler(String str, long j, ShuffleExecutor shuffleExecutor) {
        this.serverId = str;
        this.idleTimeoutMillis = j;
        this.downloadServerHandler = new DownloadServerHandler(shuffleExecutor);
    }

    @Override
    public void channelActive(ChannelHandlerContext channelHandlerContext) throws Exception {
        super.channelActive(channelHandlerContext);
        processChannelActive(channelHandlerContext);
    }

    /**
     * Records metrics and connection info for a newly active channel and starts its idle check.
     *
     * @param channelHandlerContext context of the channel that became active
     */
    public void processChannelActive(ChannelHandlerContext channelHandlerContext) {
        numChannelActive.inc(1L);
        numConcurrentChannels.update(concurrentChannelsAtomicInteger.incrementAndGet());
        this.connectionInfo = NettyUtils.getServerConnectionInfo(channelHandlerContext);
        logger.debug("Channel active: {}", this.connectionInfo);
        this.idleCheck = new ChannelIdleCheck(channelHandlerContext, this.idleTimeoutMillis, closedIdleDownloadChannels);
        this.idleCheck.schedule();
    }

    @Override
    public void channelInactive(ChannelHandlerContext channelHandlerContext) throws Exception {
        super.channelInactive(channelHandlerContext);
        numChannelInactive.inc(1L);
        numConcurrentChannels.update(concurrentChannelsAtomicInteger.decrementAndGet());
        logger.debug("Channel inactive: {}", this.connectionInfo);
        if (this.idleCheck != null) {
            this.idleCheck.cancel();
        }
    }

    /**
     * Dispatches an incoming message and always releases it exactly once.
     *
     * @throws RssInvalidDataException if the message type is unsupported, or if a
     *     {@link GetDataAvailabilityRequest} arrives before a {@link ConnectDownloadRequest}
     */
    @Override
    public void channelRead(ChannelHandlerContext channelHandlerContext, Object obj) {
        try {
            logger.debug("Got incoming message: {}, {}", obj, this.connectionInfo);
            if (this.idleCheck != null) {
                this.idleCheck.updateLastReadTime();
            }
            if (obj instanceof ConnectDownloadRequest) {
                handleConnectDownloadRequest(channelHandlerContext, (ConnectDownloadRequest) obj);
            } else if (obj instanceof GetDataAvailabilityRequest) {
                handleGetDataAvailabilityRequest(channelHandlerContext, (GetDataAvailabilityRequest) obj);
            } else {
                throw new RssInvalidDataException(String.format("Unsupported message: %s, %s", obj, this.connectionInfo));
            }
        } finally {
            // The single release point for the inbound message. Helpers must not release it
            // themselves; doing so would double-decrement a reference-counted message.
            ReferenceCountUtil.release(obj);
        }
    }

    @Override
    public void exceptionCaught(ChannelHandlerContext channelHandlerContext, Throwable th) throws Exception {
        M3Stats.addException(th, "serverHandler");
        logger.warn("Got exception {}", this.connectionInfo, th);
        channelHandlerContext.close();
    }

    /**
     * Records the requested partition and task attempts, then replies. If the shuffle stage has
     * not started, only a status is written. Otherwise a {@link ConnectDownloadResponse} is
     * written, followed by the files when the data is available.
     */
    private void handleConnectDownloadRequest(ChannelHandlerContext ctx, ConnectDownloadRequest request) {
        logger.info("ConnectDownloadRequest: {}, {}", request, this.connectionInfo);
        this.appShufflePartitionId = new AppShufflePartitionId(request.getAppId(), request.getAppAttempt(), request.getShuffleId(), request.getPartitionId());
        this.fetchTaskAttemptIds = request.getTaskAttemptIds();
        ShuffleStageStatus shuffleStageStatus = this.downloadServerHandler.getShuffleStageStatus(this.appShufflePartitionId.getAppShuffleId());
        if (shuffleStageStatus.getFileStatus() == FILE_STATUS_STAGE_NOT_STARTED) {
            logger.warn("Shuffle stage not started for {}, {}", this.appShufflePartitionId.getAppShuffleId(), this.connectionInfo);
            HandlerUtil.writeResponseStatus(ctx, RESPONSE_STATUS_STAGE_NOT_STARTED);
            return;
        }
        this.downloadServerHandler.initialize(request);
        MapTaskCommitStatus mapTaskCommitStatus = shuffleStageStatus.getMapTaskCommitStatus();
        boolean dataAvailable = isPartitionDataAvailable(mapTaskCommitStatus);
        ConnectDownloadResponse response = new ConnectDownloadResponse(this.serverId, "", mapTaskCommitStatus, dataAvailable);
        logger.info("ConnectDownloadResponse: {}, {}", response, this.connectionInfo);
        sendResponseAndFiles(ctx, dataAvailable, shuffleStageStatus, response, this.idleCheck);
    }

    /**
     * Re-checks data availability for the partition recorded by the earlier connect request and
     * replies, streaming the files if the data is now available.
     *
     * @throws RssInvalidDataException if no {@link ConnectDownloadRequest} has been processed yet
     */
    private void handleGetDataAvailabilityRequest(ChannelHandlerContext ctx, GetDataAvailabilityRequest request) {
        logger.info("GetDataAvailabilityRequest: {}, {}", request, this.connectionInfo);
        if (this.appShufflePartitionId == null) {
            // Without a prior connect request there is no partition to look up; reject instead of
            // failing with a NullPointerException.
            throw new RssInvalidDataException(String.format("GetDataAvailabilityRequest received before ConnectDownloadRequest: %s, %s", request, this.connectionInfo));
        }
        ShuffleStageStatus shuffleStageStatus = this.downloadServerHandler.getShuffleStageStatus(this.appShufflePartitionId.getAppShuffleId());
        MapTaskCommitStatus mapTaskCommitStatus = shuffleStageStatus.getMapTaskCommitStatus();
        boolean dataAvailable = isPartitionDataAvailable(mapTaskCommitStatus);
        sendResponseAndFiles(ctx, dataAvailable, shuffleStageStatus, new GetDataAvailabilityResponse(mapTaskCommitStatus, dataAvailable), this.idleCheck);
    }

    /** Returns true when a commit status exists and reports this channel's task attempts as available. */
    private boolean isPartitionDataAvailable(MapTaskCommitStatus mapTaskCommitStatus) {
        return mapTaskCommitStatus != null && mapTaskCommitStatus.isPartitionDataAvailable(this.fetchTaskAttemptIds);
    }

    /**
     * Writes the response message. When data is available, it also writes the total file length
     * as an 8-byte long and then streams the non-empty partition files. The connection is closed
     * after the pending write completes if the partition file is corrupted, there are no files, or
     * no file could be sent.
     */
    private void sendResponseAndFiles(ChannelHandlerContext ctx, boolean dataAvailable, ShuffleStageStatus shuffleStageStatus, BaseMessage responseMsg, ChannelIdleCheck channelIdleCheck) {
        byte responseStatus = shuffleStageStatus.transformToMessageResponseStatus();
        if (!dataAvailable) {
            ChannelFuture responseFuture = HandlerUtil.writeResponseMsg(ctx, responseStatus, responseMsg, true);
            closeIfPartitionFileCorrupted(shuffleStageStatus, responseFuture);
            return;
        }

        this.downloadServerHandler.finishShuffleStage(this.appShufflePartitionId.getAppShuffleId());
        List<FilePathAndLength> partitionFiles = this.downloadServerHandler.getNonEmptyPartitionFiles(this.connectionInfo);
        ChannelFuture responseFuture = HandlerUtil.writeResponseMsg(ctx, responseStatus, responseMsg, true);
        if (closeIfPartitionFileCorrupted(shuffleStageStatus, responseFuture)) {
            return;
        }

        // Announce the total payload size before streaming the files themselves.
        long totalLength = partitionFiles.stream().mapToLong(FilePathAndLength::getLength).sum();
        ByteBuf lengthBuf = ctx.alloc().buffer(Long.BYTES);
        lengthBuf.writeLong(totalLength);
        ChannelFuture lengthFuture = ctx.writeAndFlush(lengthBuf);

        if (partitionFiles.isEmpty()) {
            logger.warn("No partition file, partition {}, {}", this.appShufflePartitionId, this.connectionInfo);
            lengthFuture.addListener(new ChannelFutureCloseListener(this.connectionInfo));
        } else if (this.downloadServerHandler.sendFiles(ctx, partitionFiles, channelIdleCheck) == null) {
            logger.warn("No file sent out, closing the connection, partition {}, {}", this.appShufflePartitionId, this.connectionInfo);
            lengthFuture.addListener(new ChannelFutureCloseListener(this.connectionInfo));
        }
    }

    /**
     * If the stage reports a corrupted partition file, logs it and closes the connection once
     * {@code future} completes.
     *
     * @return true if the file was corrupted and a close was scheduled
     */
    private boolean closeIfPartitionFileCorrupted(ShuffleStageStatus shuffleStageStatus, ChannelFuture future) {
        if (shuffleStageStatus.getFileStatus() != FILE_STATUS_CORRUPTED) {
            return false;
        }
        logger.warn("Partition file corrupted, partition {}, {}", this.appShufflePartitionId, this.connectionInfo);
        future.addListener(new ChannelFutureCloseListener(this.connectionInfo));
        return true;
    }
}
