package org.apache.spark.shuffle.celeborn;

import java.io.IOException;
import java.util.concurrent.ConcurrentHashMap;
import org.apache.celeborn.client.LifecycleManager;
import org.apache.celeborn.client.ShuffleClient;
import org.apache.celeborn.common.CelebornConf;
import org.apache.celeborn.common.protocol.ShuffleMode;
import org.apache.spark.ShuffleDependency;
import org.apache.spark.SparkConf;
import org.apache.spark.SparkEnv;
import org.apache.spark.TaskContext;
import org.apache.spark.shuffle.ShuffleBlockResolver;
import org.apache.spark.shuffle.ShuffleHandle;
import org.apache.spark.shuffle.ShuffleManager;
import org.apache.spark.shuffle.ShuffleReadMetricsReporter;
import org.apache.spark.shuffle.ShuffleReader;
import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
import org.apache.spark.shuffle.ShuffleWriter;
import org.apache.spark.shuffle.sort.SortShuffleManager;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/apache/spark/shuffle/celeborn/RssShuffleManager.class */
/**
 * Spark {@link ShuffleManager} that hands shuffles to Celeborn and falls back to Spark's
 * {@link SortShuffleManager} when the configured fallback policies say so.
 *
 * <p>Shuffles registered with the fallback manager are remembered in {@link #sortShuffleIds} so
 * that {@link #unregisterShuffle(int)} can be routed to the same manager later on.
 *
 * <p>The lazily created collaborators are published through {@code volatile} fields because they
 * are read outside of the {@code synchronized} sections that create them.
 */
public class RssShuffleManager implements ShuffleManager {
    private static final Logger logger = LoggerFactory.getLogger(RssShuffleManager.class);
    private static final String SORT_SHUFFLE_MANAGER_NAME = "org.apache.spark.shuffle.sort.SortShuffleManager";

    private final SparkConf conf;
    private final CelebornConf celebornConf;
    /** Value of {@code spark.executor.cores} (default 1); used to size the send buffer pool. */
    private final int cores;
    /** Application id generated by the most recent {@link #registerShuffle} call; null until then. */
    private volatile String newAppId;
    /** Created lazily, and only when this instance runs inside the driver. */
    private volatile LifecycleManager lifecycleManager;
    /** Client bound to {@link #lifecycleManager}; created together with it. */
    private volatile ShuffleClient rssShuffleClient;
    /** Lazily created fallback manager; access it through {@link #sortShuffleManager()}. */
    private volatile SortShuffleManager _sortShuffleManager;
    /** Ids of the shuffles that are handled by the fallback {@link SortShuffleManager}. */
    private final ConcurrentHashMap.KeySetView<Integer, Boolean> sortShuffleIds = ConcurrentHashMap.newKeySet();
    private final RssShuffleFallbackPolicyRunner fallbackPolicyRunner;

    /**
     * Creates the manager from the Spark configuration.
     *
     * @param sparkConf Spark configuration; the Celeborn configuration is derived from it
     */
    public RssShuffleManager(SparkConf sparkConf) {
        this.conf = sparkConf;
        this.celebornConf = SparkUtils.fromSparkConf(sparkConf);
        this.cores = sparkConf.getInt("spark.executor.cores", 1);
        this.fallbackPolicyRunner = new RssShuffleFallbackPolicyRunner(this.celebornConf);
    }

    /** Returns true when the current {@link SparkEnv} reports the executor id {@code "driver"}. */
    private boolean isDriver() {
        return "driver".equals(SparkEnv.get().executorId());
    }

    /**
     * Returns the fallback {@link SortShuffleManager}, instantiating it on first use.
     *
     * <p>This accessor never returns null and creates the manager as a side effect, so callers
     * that only want to know whether it exists must read {@link #_sortShuffleManager} directly.
     */
    private SortShuffleManager sortShuffleManager() {
        SortShuffleManager manager = this._sortShuffleManager;
        if (manager == null) {
            synchronized (this) {
                manager = this._sortShuffleManager;
                if (manager == null) {
                    manager = (SortShuffleManager) SparkUtils.instantiateClass(SORT_SHUFFLE_MANAGER_NAME, this.conf, Boolean.valueOf(isDriver()));
                    this._sortShuffleManager = manager;
                }
            }
        }
        return manager;
    }

    /**
     * Creates the {@link LifecycleManager} and its {@link ShuffleClient} once, on the driver only.
     *
     * @param appId application id handed to the new {@link LifecycleManager}
     */
    private void initializeLifecycleManager(String appId) {
        // The unsynchronized pre-check is only safe because lifecycleManager is volatile.
        if (isDriver() && this.lifecycleManager == null) {
            synchronized (this) {
                if (this.lifecycleManager == null) {
                    this.lifecycleManager = new LifecycleManager(appId, this.celebornConf);
                    this.rssShuffleClient = ShuffleClient.get(this.lifecycleManager.self(), this.celebornConf, this.lifecycleManager.getUserIdentifier());
                }
            }
        }
    }

    /**
     * Registers a shuffle and returns the handle that tasks will later present to this manager.
     *
     * <p>When the fallback policy runner reports that a fallback applies, the shuffle id is
     * recorded and registration is delegated to the {@link SortShuffleManager}; otherwise an
     * {@link RssShuffleHandle} describing the Celeborn endpoint is returned.
     *
     * @param shuffleId id of the shuffle being registered
     * @param shuffleDependency dependency that describes the shuffle
     * @return an {@link RssShuffleHandle}, or the handle produced by the fallback manager
     */
    public <K, V, C> ShuffleHandle registerShuffle(int shuffleId, ShuffleDependency<K, V, C> shuffleDependency) {
        String appId = SparkUtils.genNewAppId(shuffleDependency.rdd().context());
        this.newAppId = appId;
        initializeLifecycleManager(appId);
        // Read the volatile field once so that the whole method works against one instance.
        LifecycleManager manager = this.lifecycleManager;

        boolean fallback = this.fallbackPolicyRunner.applyAllFallbackPolicy(manager, shuffleDependency.partitioner().numPartitions());
        if (fallback) {
            logger.warn("Fallback to SortShuffleManager!");
            this.sortShuffleIds.add(shuffleId);
            return sortShuffleManager().registerShuffle(shuffleId, shuffleDependency);
        }
        return new RssShuffleHandle(appId, manager.getRssMetaServiceHost(), manager.getRssMetaServicePort(), manager.getUserIdentifier(), shuffleId, shuffleDependency.rdd().getNumPartitions(), shuffleDependency);
    }

    /**
     * Unregisters a shuffle with whichever manager owns it.
     *
     * @param shuffleId id of the shuffle to unregister
     * @return the fallback manager's result for fallback shuffles; {@code true} when no
     *     application id has been generated by this instance yet; {@code false} when an
     *     application id exists but no shuffle client does; otherwise the client's result
     */
    public boolean unregisterShuffle(int shuffleId) {
        if (this.sortShuffleIds.contains(shuffleId)) {
            return sortShuffleManager().unregisterShuffle(shuffleId);
        }
        String appId = this.newAppId;
        if (appId == null) {
            return true;
        }
        ShuffleClient client = this.rssShuffleClient;
        if (client == null) {
            return false;
        }
        return client.unregisterShuffle(appId, shuffleId, isDriver());
    }

    /** Returns the block resolver of the fallback {@link SortShuffleManager}, creating it if needed. */
    public ShuffleBlockResolver shuffleBlockResolver() {
        return sortShuffleManager().shuffleBlockResolver();
    }

    /**
     * Shuts down the shuffle client, the lifecycle manager and the fallback manager, skipping any
     * of them that was never created.
     */
    public void stop() {
        ShuffleClient client = this.rssShuffleClient;
        if (client != null) {
            client.shutdown();
        }
        LifecycleManager manager = this.lifecycleManager;
        if (manager != null) {
            manager.stop();
        }
        // Inspect the field rather than calling sortShuffleManager(): the accessor would
        // instantiate a brand-new SortShuffleManager during shutdown just to stop it again.
        SortShuffleManager fallbackManager = this._sortShuffleManager;
        if (fallbackManager != null) {
            fallbackManager.stop();
        }
    }

    /**
     * Creates the writer for one map task.
     *
     * <p>Handles that are not {@link RssShuffleHandle}s belong to the fallback manager: their
     * shuffle id is recorded and the request is delegated. For Celeborn handles the writer
     * implementation is selected by {@link CelebornConf#shuffleWriterMode()}.
     *
     * @param shuffleHandle handle obtained from {@link #registerShuffle}
     * @param mapId id of the map task (per Spark's ShuffleManager contract); only forwarded to the
     *     fallback manager
     * @param taskContext context of the running task
     * @param shuffleWriteMetricsReporter sink for write metrics
     * @return a sort-based or hash-based Celeborn writer, or the fallback manager's writer
     * @throws UnsupportedOperationException if the configured writer mode is neither SORT nor HASH
     * @throws RuntimeException wrapping any {@link IOException} raised while building the writer
     */
    public <K, V> ShuffleWriter<K, V> getWriter(ShuffleHandle shuffleHandle, long mapId, TaskContext taskContext, ShuffleWriteMetricsReporter shuffleWriteMetricsReporter) {
        if (!(shuffleHandle instanceof RssShuffleHandle)) {
            this.sortShuffleIds.add(shuffleHandle.shuffleId());
            return sortShuffleManager().getWriter(shuffleHandle, mapId, taskContext, shuffleWriteMetricsReporter);
        }
        try {
            RssShuffleHandle rssShuffleHandle = (RssShuffleHandle) shuffleHandle;
            // The client is obtained before the mode is inspected, exactly as before.
            ShuffleClient shuffleClient = ShuffleClient.get(rssShuffleHandle.rssMetaServiceHost(), rssShuffleHandle.rssMetaServicePort(), this.celebornConf, rssShuffleHandle.userIdentifier());
            if (ShuffleMode.SORT.equals(this.celebornConf.shuffleWriterMode())) {
                return new SortBasedShuffleWriter(rssShuffleHandle.dependency(), rssShuffleHandle.newAppId(), rssShuffleHandle.numMappers(), taskContext, this.celebornConf, shuffleClient, shuffleWriteMetricsReporter);
            }
            if (ShuffleMode.HASH.equals(this.celebornConf.shuffleWriterMode())) {
                return new HashBasedShuffleWriter(rssShuffleHandle, taskContext, this.celebornConf, shuffleClient, shuffleWriteMetricsReporter, SendBufferPool.get(this.cores));
            }
            throw new UnsupportedOperationException("Unrecognized shuffle write mode!" + this.celebornConf.shuffleWriterMode());
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Creates a reader for a range of reduce partitions restricted to a range of map outputs.
     *
     * @param shuffleHandle handle obtained from {@link #registerShuffle}
     * @param startMapIndex first map index of the range
     * @param endMapIndex end of the map index range
     * @param startPartition first reduce partition of the range
     * @param endPartition end of the reduce partition range
     * @param taskContext context of the running task
     * @param shuffleReadMetricsReporter sink for read metrics
     * @return an {@link RssShuffleReader} for Celeborn handles, otherwise a fallback reader
     */
    public <K, C> ShuffleReader<K, C> getReader(ShuffleHandle shuffleHandle, int startMapIndex, int endMapIndex, int startPartition, int endPartition, TaskContext taskContext, ShuffleReadMetricsReporter shuffleReadMetricsReporter) {
        return createReader(shuffleHandle, startMapIndex, endMapIndex, startPartition, endPartition, taskContext, shuffleReadMetricsReporter);
    }

    /**
     * Creates a reader for a range of reduce partitions over every map output: the map index range
     * is fixed to {@code [0, Integer.MAX_VALUE)}.
     *
     * <p>NOTE(review): presumably kept for Spark versions whose ShuffleManager declares this
     * shorter signature - confirm against the supported Spark versions.
     *
     * @param shuffleHandle handle obtained from {@link #registerShuffle}
     * @param startPartition first reduce partition of the range
     * @param endPartition end of the reduce partition range
     * @param taskContext context of the running task
     * @param shuffleReadMetricsReporter sink for read metrics
     * @return an {@link RssShuffleReader} for Celeborn handles, otherwise a fallback reader
     */
    public <K, C> ShuffleReader<K, C> getReader(ShuffleHandle shuffleHandle, int startPartition, int endPartition, TaskContext taskContext, ShuffleReadMetricsReporter shuffleReadMetricsReporter) {
        return createReader(shuffleHandle, 0, Integer.MAX_VALUE, startPartition, endPartition, taskContext, shuffleReadMetricsReporter);
    }

    /**
     * Same behaviour and parameters as the seven-argument {@code getReader}.
     *
     * <p>NOTE(review): presumably kept for Spark versions that name this operation
     * {@code getReaderForRange} - confirm against the supported Spark versions.
     */
    public <K, C> ShuffleReader<K, C> getReaderForRange(ShuffleHandle shuffleHandle, int startMapIndex, int endMapIndex, int startPartition, int endPartition, TaskContext taskContext, ShuffleReadMetricsReporter shuffleReadMetricsReporter) {
        return createReader(shuffleHandle, startMapIndex, endMapIndex, startPartition, endPartition, taskContext, shuffleReadMetricsReporter);
    }

    /**
     * Shared implementation behind every public reader factory method.
     *
     * <p>{@link RssShuffleReader} takes the partition range before the map index range, whereas
     * the fallback helper takes the map index range first; the argument order below is deliberate.
     */
    private <K, C> ShuffleReader<K, C> createReader(ShuffleHandle shuffleHandle, int startMapIndex, int endMapIndex, int startPartition, int endPartition, TaskContext taskContext, ShuffleReadMetricsReporter shuffleReadMetricsReporter) {
        if (shuffleHandle instanceof RssShuffleHandle) {
            return new RssShuffleReader((RssShuffleHandle) shuffleHandle, startPartition, endPartition, startMapIndex, endMapIndex, taskContext, this.celebornConf, shuffleReadMetricsReporter);
        }
        return SparkUtils.getReader(sortShuffleManager(), shuffleHandle, Integer.valueOf(startMapIndex), Integer.valueOf(endMapIndex), Integer.valueOf(startPartition), Integer.valueOf(endPartition), taskContext, shuffleReadMetricsReporter);
    }
}
