package org.apache.celeborn.client.write;

import java.io.IOException;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import org.apache.celeborn.common.CelebornConf;
import org.apache.celeborn.common.network.client.RpcResponseCallback;
import org.apache.celeborn.common.protocol.PartitionLocation;
import org.apache.celeborn.common.protocol.message.StatusCode;
import org.apache.celeborn.shaded.io.netty.channel.ChannelFuture;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Bookkeeping for push requests.
 *
 * <p>Tracks the batches that are currently in flight, so that they can be timed out or
 * cancelled. Also tracks the data buffered per key in {@link DataBatches} before it is pushed.
 * Both maps are {@link ConcurrentHashMap}s, and {@link #failExpiredBatch()} is
 * {@code synchronized}. Per-batch fields are {@code volatile} because they are written and
 * cleared by different methods of this class.
 */
public class PushState {
    private static final Logger logger = LoggerFactory.getLogger(PushState.class);

    /** Buffered size per key above which {@link #addBatchData} returns {@code true}. */
    private final int pushBufferMaxSize;

    /** Milliseconds after {@link #pushStarted} after which a batch is treated as expired. */
    private final long pushTimeout;

    // NOTE(review): presumably used by callers to allocate batch ids; not used in this class.
    public final AtomicInteger batchId = new AtomicInteger();

    /** Batches that have been added and not yet removed, keyed by batch id. */
    private final ConcurrentHashMap<Integer, BatchInfo> inflightBatchInfos = new ConcurrentHashMap<>();

    // NOTE(review): set/read outside this class; presumably holds a push failure to rethrow.
    public AtomicReference<IOException> exception = new AtomicReference<>();

    /** Data buffered per key until {@link #takeDataBatches} hands it out. */
    public final ConcurrentHashMap<String, DataBatches> batchesMap = new ConcurrentHashMap<>();

    /**
     * State of a single in-flight batch.
     *
     * <p>The fields are {@code volatile} because {@link #pushStarted} writes them without holding
     * a lock, while {@link #failExpiredBatch}, {@link #removeBatch} and {@link #cleanup} read or
     * clear them. Readers must read each field once into a local before using it.
     */
    public static class BatchInfo {
        volatile ChannelFuture channelFuture;
        /** {@link System#currentTimeMillis()} when the push started, or -1 if not started. */
        volatile long pushTime = -1;
        volatile RpcResponseCallback callback;

        BatchInfo() {
        }
    }

    /**
     * Creates push state configured from {@code celebornConf}.
     *
     * @param celebornConf source of the push buffer limit and the push timeout (ms)
     */
    public PushState(CelebornConf celebornConf) {
        this.pushBufferMaxSize = celebornConf.pushBufferMaxSize();
        this.pushTimeout = celebornConf.pushDataTimeoutMs();
    }

    /**
     * Registers {@code batchId} as in flight. Does nothing if it is already registered.
     *
     * @param batchId id of the batch to track
     */
    public void addBatch(int batchId) {
        this.inflightBatchInfos.computeIfAbsent(batchId, id -> new BatchInfo());
    }

    /**
     * Stops tracking {@code batchId} and cancels its channel future, if one was recorded.
     *
     * @param batchId id of the batch to remove; unknown ids are ignored
     */
    public void removeBatch(int batchId) {
        BatchInfo removed = this.inflightBatchInfos.remove(batchId);
        if (removed != null) {
            cancelFuture(removed);
        }
    }

    /** @return the number of batches currently tracked as in flight */
    public int inflightBatchCount() {
        return this.inflightBatchInfos.size();
    }

    /**
     * Fails every started batch whose push began more than {@code pushTimeout} ms ago and that
     * still has a callback.
     *
     * <p>For each such batch, its future is cancelled and its callback receives an
     * {@link IOException} carrying {@link StatusCode#PUSH_DATA_TIMEOUT}'s message. The batch
     * stays in the in-flight map, but its future and callback are cleared, so it is failed at
     * most once.
     */
    public synchronized void failExpiredBatch() {
        long now = System.currentTimeMillis();
        for (BatchInfo info : this.inflightBatchInfos.values()) {
            // Snapshot each volatile field once; other threads may change them concurrently.
            long startedAt = info.pushTime;
            RpcResponseCallback callback = info.callback;
            if (startedAt == -1 || now - startedAt <= this.pushTimeout || callback == null) {
                continue;
            }
            ChannelFuture future = info.channelFuture;
            // Clear before notifying, so a throwing or re-entrant callback cannot cause
            // this batch to be failed a second time.
            info.channelFuture = null;
            info.callback = null;
            if (future != null) {
                future.cancel(true);
            }
            callback.onFailure(new IOException(StatusCode.PUSH_DATA_TIMEOUT.getMessage()));
        }
    }

    /**
     * Records that the push of {@code batchId} has started. Ignored if the batch is not
     * registered (see {@link #addBatch}).
     *
     * @param batchId id of the batch being pushed
     * @param channelFuture future of the write, cancelled on timeout or removal
     * @param rpcResponseCallback callback that is failed if the push times out
     */
    public void pushStarted(int batchId, ChannelFuture channelFuture, RpcResponseCallback rpcResponseCallback) {
        BatchInfo info = this.inflightBatchInfos.get(batchId);
        if (info != null) {
            info.pushTime = System.currentTimeMillis();
            info.channelFuture = channelFuture;
            info.callback = rpcResponseCallback;
        }
    }

    /**
     * Removes every in-flight batch and cancels its channel future.
     *
     * <p>Each batch is removed through {@link #removeBatch}, so no batch leaves the map without
     * its future being cancelled. The original code used {@code forEach} followed by
     * {@code clear()}, which could drop a batch added between the two steps without cancelling
     * it.
     */
    public void cleanup() {
        if (this.inflightBatchInfos.isEmpty()) {
            return;
        }
        logger.debug("Cancel all {} futures.", Integer.valueOf(this.inflightBatchInfos.size()));
        // ConcurrentHashMap key-set iteration tolerates concurrent removal.
        for (Integer id : this.inflightBatchInfos.keySet()) {
            removeBatch(id);
        }
    }

    /**
     * Buffers one batch of data under {@code key}.
     *
     * @param key map key for the buffer (NOTE(review): presumably identifies the push target)
     * @param partitionLocation location the data belongs to
     * @param batchId id of the batch
     * @param body batch payload
     * @return {@code true} when the buffered total for {@code key} exceeds the configured
     *     push buffer max size
     */
    public boolean addBatchData(String key, PartitionLocation partitionLocation, int batchId, byte[] body) {
        DataBatches batches = this.batchesMap.computeIfAbsent(key, k -> new DataBatches());
        batches.addDataBatch(partitionLocation, batchId, body);
        return batches.getTotalSize() > this.pushBufferMaxSize;
    }

    /**
     * Removes and returns the data buffered under {@code key}.
     *
     * @param key map key used in {@link #addBatchData}
     * @return the buffered batches, or {@code null} if nothing is buffered for {@code key}
     */
    public DataBatches takeDataBatches(String key) {
        return this.batchesMap.remove(key);
    }

    /** Cancels the batch's channel future, if any, reading the volatile field only once. */
    private static void cancelFuture(BatchInfo info) {
        ChannelFuture future = info.channelFuture;
        if (future != null) {
            future.cancel(true);
        }
    }
}
