package org.apache.celeborn.common.network.server;

import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean;
import org.apache.celeborn.common.network.server.memory.MemoryManager;
import org.apache.celeborn.shaded.io.netty.buffer.PooledByteBufAllocator;
import org.apache.celeborn.shaded.io.netty.channel.Channel;
import org.apache.celeborn.shaded.io.netty.channel.ChannelDuplexHandler;
import org.apache.celeborn.shaded.io.netty.channel.ChannelHandler;
import org.apache.celeborn.shaded.io.netty.channel.ChannelHandlerContext;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Netty handler that throttles inbound reads of every channel belonging to one module in response to
 * memory-pressure notifications delivered by {@link MemoryManager}.
 *
 * <p>On a pause notification for {@link #moduleName}, auto-read is disabled on every tracked channel,
 * and channels added while paused start with auto-read disabled. On a resume notification for this
 * module (or for "all"), auto-read is re-enabled. On a trim notification, a {@link TrimCache} user
 * event is fired through each tracked channel's pipeline. The handler then trims the current thread's
 * cache of the channel's {@link PooledByteBufAllocator}.
 *
 * <p>The class is {@link ChannelHandler.Sharable}: one instance is installed in many pipelines and
 * tracks their channels in a concurrent set.
 */
@ChannelHandler.Sharable
public class ChannelsLimiter extends ChannelDuplexHandler implements MemoryManager.MemoryPressureListener {
    private static final Logger logger = LoggerFactory.getLogger(ChannelsLimiter.class);

    /** Module name that, in a resume notification, addresses every limiter (compared ignoring case). */
    private static final String ALL_MODULES = "all";

    /** Module whose pause/resume notifications this limiter reacts to. */
    private final String moduleName;

    /** Channels whose pipeline currently contains this handler. */
    private final Set<Channel> channels = ConcurrentHashMap.newKeySet();

    /**
     * Whether reads are currently paused. Its monitor guards every pause/resume transition and the
     * "pause a newly added channel" decision. Without it, concurrent pause/resume/add could interleave
     * and leave a channel's auto-read flag inconsistent with this value.
     */
    private final AtomicBoolean isPaused = new AtomicBoolean(false);

    /** Marker user event asking this handler to trim the current thread's pooled buffer cache. */
    static class TrimCache {
        TrimCache() {
        }
    }

    /**
     * Creates a limiter for the given module and registers it with the {@link MemoryManager}
     * singleton.
     *
     * @param moduleName module name matched against pause/resume notifications
     */
    public ChannelsLimiter(String moduleName) {
        this.moduleName = moduleName;
        // Registration publishes `this`; every field has already been assigned at this point.
        MemoryManager.instance().registerMemoryListener(this);
    }

    /** Marks the limiter paused and disables auto-read on every tracked channel. */
    private void pauseAllChannels() {
        // Same monitor as resumeAllChannels/handlerAdded so a concurrent resume cannot interleave.
        synchronized (isPaused) {
            isPaused.set(true);
            channels.forEach(channel -> {
                if (channel.config().isAutoRead()) {
                    channel.config().setAutoRead(false);
                }
            });
        }
    }

    /** Fires a {@link TrimCache} event through the pipeline of every tracked channel. */
    private void trimCache() {
        channels.forEach(channel -> channel.pipeline().fireUserEventTriggered(new TrimCache()));
    }

    /** Marks the limiter resumed and re-enables auto-read on every tracked channel. */
    private void resumeAllChannels() {
        synchronized (isPaused) {
            isPaused.set(false);
            channels.forEach(channel -> {
                if (!channel.config().isAutoRead()) {
                    channel.config().setAutoRead(true);
                }
            });
        }
    }

    /**
     * Starts tracking the channel. If the limiter is currently paused, the channel's auto-read is
     * disabled immediately.
     */
    @Override
    public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
        // Add before checking the flag: a concurrent pause either sees the channel in the set
        // or this method sees isPaused == true, so the channel cannot slip through unpaused.
        channels.add(ctx.channel());
        synchronized (isPaused) {
            if (isPaused.get()) {
                ctx.channel().config().setAutoRead(false);
            }
        }
        super.handlerAdded(ctx);
    }

    /** Restores auto-read on the channel (if this limiter disabled it) and stops tracking it. */
    @Override
    public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
        if (!ctx.channel().config().isAutoRead()) {
            ctx.channel().config().setAutoRead(true);
        }
        channels.remove(ctx.channel());
        super.handlerRemoved(ctx);
    }

    /**
     * Consumes {@link TrimCache} events by trimming the pooled allocator's cache for the current
     * thread. Every other event is forwarded to the next handler in the pipeline.
     */
    @Override
    public void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
        if (evt instanceof TrimCache) {
            // Only a pooled allocator has a per-thread cache to trim; skip other allocators
            // instead of failing with a ClassCastException.
            if (ctx.alloc() instanceof PooledByteBufAllocator) {
                ((PooledByteBufAllocator) ctx.alloc()).trimCurrentThreadCache();
            }
            return;
        }
        // Not ours: propagate so downstream handlers still receive the event.
        ctx.fireUserEventTriggered(evt);
    }

    /** Pauses reads on all tracked channels when {@code targetModule} equals this limiter's module. */
    @Override
    public void onPause(String targetModule) {
        if (moduleName.equals(targetModule)) {
            logger.info("{} channels pause read.", moduleName);
            pauseAllChannels();
        }
    }

    /**
     * Resumes reads on all tracked channels when {@code targetModule} is "all" (ignoring case) or
     * equals this limiter's module. A resume is performed at most once per notification.
     */
    @Override
    public void onResume(String targetModule) {
        if (ALL_MODULES.equalsIgnoreCase(targetModule) || moduleName.equals(targetModule)) {
            logger.info("{} channels resume read.", moduleName);
            resumeAllChannels();
        }
    }

    /** Asks every tracked channel to trim its pooled allocator cache. */
    @Override
    public void onTrim() {
        trimCache();
    }
}
