package org.apache.spark.sql.rapids;

import ai.rapids.cudf.ContiguousTable;
import ai.rapids.cudf.DeviceMemoryBuffer;
import ai.rapids.cudf.NvtxColor;
import ai.rapids.cudf.NvtxRange;
import com.nvidia.spark.rapids.DegenerateRapidsBuffer;
import com.nvidia.spark.rapids.GpuCompressedColumnVector;
import com.nvidia.spark.rapids.GpuPackedTableColumn;
import com.nvidia.spark.rapids.MetaUtils$;
import com.nvidia.spark.rapids.RapidsDeviceMemoryStore;
import com.nvidia.spark.rapids.ShuffleBufferCatalog;
import com.nvidia.spark.rapids.ShuffleBufferId;
import com.nvidia.spark.rapids.SpillPriorities$;
import com.nvidia.spark.rapids.format.TableMeta;
import com.nvidia.spark.rapids.shuffle.RapidsShuffleServer;
import com.nvidia.spark.rapids.shuffle.RapidsShuffleTransport$;
import org.apache.spark.internal.Logging;
import org.apache.spark.scheduler.MapStatus;
import org.apache.spark.scheduler.MapStatus$;
import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
import org.apache.spark.shuffle.ShuffleWriter;
import org.apache.spark.sql.execution.metric.SQLMetric;
import org.apache.spark.sql.vectorized.ColumnVector;
import org.apache.spark.sql.vectorized.ColumnarBatch;
import org.apache.spark.storage.BlockManager;
import org.apache.spark.storage.BlockManagerId;
import org.apache.spark.storage.BlockManagerId$;
import org.apache.spark.storage.ShuffleBlockId;
import org.slf4j.Logger;
import scala.Function0;
import scala.None$;
import scala.Option;
import scala.Predef$;
import scala.Product2;
import scala.Some;
import scala.collection.Iterator;
import scala.collection.immutable.Map;
import scala.collection.mutable.ArrayBuffer;
import scala.collection.mutable.ArrayOps;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.LongRef;

/* compiled from: RapidsShuffleInternalManagerBase.scala */
@ScalaSignature(bytes = "\u0006\u0001\u0005ue\u0001B\u000b\u0017\u0001\u0005B\u0001\u0002\u0011\u0001\u0003\u0002\u0003\u0006I!\u0011\u0005\t\u000f\u0002\u0011\t\u0011)A\u0005\u0011\"AA\n\u0001B\u0001B\u0003%Q\n\u0003\u0005Q\u0001\t\u0005\t\u0015!\u0003R\u0011!!\u0006A!A!\u0002\u0013)\u0006\u0002C0\u0001\u0005\u0003\u0005\u000b\u0011\u00021\t\u0011\r\u0004!\u0011!Q\u0001\n\u0011D\u0001\u0002\u001c\u0001\u0003\u0002\u0003\u0006I!\u001c\u0005\b\u0003\u000f\u0001A\u0011AA\u0005\u0011%\ti\u0002\u0001b\u0001\n\u0013\ty\u0002\u0003\u0005\u0002(\u0001\u0001\u000b\u0011BA\u0011\u0011%\tI\u0003\u0001b\u0001\n\u0013\tY\u0003\u0003\u0005\u00024\u0001\u0001\u000b\u0011BA\u0017\u0011%\t)\u0004\u0001b\u0001\n\u0013\t9\u0004\u0003\u0005\u0002P\u0001\u0001\u000b\u0011BA\u001d\u0011%\t\t\u0006\u0001b\u0001\n\u0013\t\u0019\u0006C\u0004\u0002V\u0001\u0001\u000b\u0011B>\t\u000f\u0005]\u0003\u0001\"\u0011\u0002Z!9\u0011Q\u0010\u0001\u0005\n\u0005}\u0004bBAA\u0001\u0011\u0005\u00131\u0011\u0002\u0014%\u0006\u0004\u0018\u000eZ:DC\u000eD\u0017N\\4Xe&$XM\u001d\u0006\u0003/a\taA]1qS\u0012\u001c(BA\r\u001b\u0003\r\u0019\u0018\u000f\u001c\u0006\u00037q\tQa\u001d9be.T!!\b\u0010\u0002\r\u0005\u0004\u0018m\u00195f\u0015\u0005y\u0012aA8sO\u000e\u0001Qc\u0001\u0012,qM\u0019\u0001a\t\u001e\u0011\t\u0011:\u0013fN\u0007\u0002K)\u0011aEG\u0001\bg\",hM\u001a7f\u0013\tASEA\u0007TQV4g\r\\3Xe&$XM\u001d\t\u0003U-b\u0001\u0001B\u0003-\u0001\t\u0007QFA\u0001L#\tqC\u0007\u0005\u00020e5\t\u0001GC\u00012\u0003\u0015\u00198-\u00197b\u0013\t\u0019\u0004GA\u0004O_RD\u0017N\\4\u0011\u0005=*\u0014B\u0001\u001c1\u0005\r\te.\u001f\t\u0003Ua\"Q!\u000f\u0001C\u00025\u0012\u0011A\u0016\t\u0003wyj\u0011\u0001\u0010\u0006\u0003{i\t\u0001\"\u001b8uKJt\u0017\r\\\u0005\u0003\u007fq\u0012q\u0001T8hO&tw-\u0001\u0007cY>\u001c7.T1oC\u001e,'\u000f\u0005\u0002C\u000b6\t1I\u0003\u0002E5\u000591\u000f^8sC\u001e,\u0017B\u0001$D\u00051\u0011En\\2l\u001b\u0006t\u0017mZ3s\u0003\u0019A\u0017M\u001c3mKB!\u0011JS\u00158\u001b\u00051\u0012BA&
\u0017\u0005A9\u0005/^*ik\u001a4G.\u001a%b]\u0012dW-A\u0003nCBLE\r\u0005\u00020\u001d&\u0011q\n\r\u0002\u0005\u0019>tw-A\bnKR\u0014\u0018nY:SKB|'\u000f^3s!\t!#+\u0003\u0002TK\tY2\u000b[;gM2,wK]5uK6+GO]5dgJ+\u0007o\u001c:uKJ\fqaY1uC2|w\r\u0005\u0002W;6\tqK\u0003\u0002\u00181*\u00111$\u0017\u0006\u00035n\u000baA\u001c<jI&\f'\"\u0001/\u0002\u0007\r|W.\u0003\u0002_/\n!2\u000b[;gM2,')\u001e4gKJ\u001c\u0015\r^1m_\u001e\fab\u001d5vM\u001adWm\u0015;pe\u0006<W\r\u0005\u0002WC&\u0011!m\u0016\u0002\u0018%\u0006\u0004\u0018\u000eZ:EKZL7-Z'f[>\u0014\u0018p\u0015;pe\u0016\f1C]1qS\u0012\u001c8\u000b[;gM2,7+\u001a:wKJ\u00042aL3h\u0013\t1\u0007G\u0001\u0004PaRLwN\u001c\t\u0003Q*l\u0011!\u001b\u0006\u0003M]K!a[5\u0003'I\u000b\u0007/\u001b3t'\",hM\u001a7f'\u0016\u0014h/\u001a:\u0002\u000f5,GO]5dgB!a.\u001e=|\u001d\ty7\u000f\u0005\u0002qa5\t\u0011O\u0003\u0002sA\u00051AH]8pizJ!\u0001\u001e\u0019\u0002\rA\u0013X\rZ3g\u0013\t1xOA\u0002NCBT!\u0001\u001e\u0019\u0011\u00059L\u0018B\u0001>x\u0005\u0019\u0019FO]5oOB\u0019A0a\u0001\u000e\u0003uT!A`@\u0002\r5,GO]5d\u0015\r\t\t\u0001G\u0001\nKb,7-\u001e;j_:L1!!\u0002~\u0005%\u0019\u0016\u000bT'fiJL7-\u0001\u0004=S:LGO\u0010\u000b\u0013\u0003\u0017\ti!a\u0004\u0002\u0012\u0005M\u0011QCA\f\u00033\tY\u0002\u0005\u0003J\u0001%:\u0004\"\u0002!\n\u0001\u0004\t\u0005\"B$\n\u0001\u0004A\u0005\"\u0002'\n\u0001\u0004i\u0005\"\u0002)\n\u0001\u0004\t\u0006\"\u0002+\n\u0001\u0004)\u0006\"B0\n\u0001\u0004\u0001\u0007\"B2\n\u0001\u0004!\u0007\"\u00027\n\u0001\u0004i\u0017\u0001\u00038v[B\u000b'\u000f^:\u0016\u0005\u0005\u0005\u0002cA\u0018\u0002$%\u0019\u0011Q\u0005\u0019\u0003\u0007%sG/A\u0005ok6\u0004\u0016M\u001d;tA\u0005)1/\u001b>fgV\u0011\u0011Q\u0006\t\u0005_\u0005=R*C\u0002\u00022A\u0012Q!\u0011:sCf\faa]5{KN\u0004\u0013\u0001E<sSR$XM\u001c\"vM\u001a,'/\u00133t+\t\tI\u0004\u0005\u0004\u0002<\u0005\u0015\u0013\u0011J\u0007\u0003\u0003{QA!a\u0010\u0002B\u00059Q.\u001e;bE2,'bAA\"a\u0005Q1m\u001c7mK\u000e$\u0018n\u001c8\n\t\u0005\u001d\u0013Q\b\u0002\f\u0003
J\u0014\u0018-\u001f\"vM\u001a,'\u000fE\u0002W\u0003\u0017J1!!\u0014X\u0005=\u0019\u0006.\u001e4gY\u0016\u0014UO\u001a4fe&#\u0017!E<sSR$XM\u001c\"vM\u001a,'/\u00133tA\u0005\u0011RO\\2p[B\u0014Xm]:fI6+GO]5d+\u0005Y\u0018aE;oG>l\u0007O]3tg\u0016$W*\u001a;sS\u000e\u0004\u0013!B<sSR,G\u0003BA.\u0003C\u00022aLA/\u0013\r\ty\u0006\r\u0002\u0005+:LG\u000fC\u0004\u0002dI\u0001\r!!\u001a\u0002\u000fI,7m\u001c:egB1\u0011qMA9\u0003orA!!\u001b\u0002n9\u0019\u0001/a\u001b\n\u0003EJ1!a\u001c1\u0003\u001d\u0001\u0018mY6bO\u0016LA!a\u001d\u0002v\tA\u0011\n^3sCR|'OC\u0002\u0002pA\u0002RaLA=S]J1!a\u001f1\u0005!\u0001&o\u001c3vGR\u0014\u0014\u0001D2mK\u0006t7\u000b^8sC\u001e,GCAA.\u0003\u0011\u0019Ho\u001c9\u0015\t\u0005\u0015\u00151\u0013\t\u0005_\u0015\f9\t\u0005\u0003\u0002\n\u0006=UBAAF\u0015\r\tiIG\u0001\ng\u000eDW\rZ;mKJLA!!%\u0002\f\nIQ*\u00199Ti\u0006$Xo\u001d\u0005\b\u0003+#\u0002\u0019AAL\u0003\u001d\u0019XoY2fgN\u00042aLAM\u0013\r\tY\n\r\u0002\b\u0005>|G.Z1o\u0001")
/* loaded from: input_file:org/apache/spark/sql/rapids/RapidsCachingWriter.class */
/**
 * Shuffle writer that keeps each map-side {@link ColumnarBatch} on the GPU: device data is handed
 * to {@code shuffleStorage} (a {@link RapidsDeviceMemoryStore}) under a fresh
 * {@link ShuffleBufferId} obtained from {@code catalog}, and per-partition byte sizes are
 * accumulated for the {@link MapStatus} returned by {@link #stop(boolean)}.
 *
 * <p>NOTE(review): this source is a decompilation of Scala bytecode; the {@code $anonfun$...}
 * statics and the {@code Logging} forwarders mirror what scalac emits and are kept public so
 * existing linkage is unaffected.
 */
public class RapidsCachingWriter<K, V> extends ShuffleWriter<K, V> implements Logging {
    /**
     * Size reported for a partition that received rows but no columns (metadata-only batch).
     * The value is an arbitrary non-zero placeholder. Presumably it is non-zero so the partition
     * is not treated as empty by the reduce side (Spark skips zero-sized blocks) — confirm.
     */
    private static final long DEGENERATE_PARTITION_SIZE = 100L;

    private final BlockManager blockManager;
    private final GpuShuffleHandle<K, V> handle;
    private final long mapId;
    private final ShuffleWriteMetricsReporter metricsReporter;
    private final ShuffleBufferCatalog catalog;
    private final RapidsDeviceMemoryStore shuffleStorage;
    private final Option<RapidsShuffleServer> rapidsShuffleServer;
    private final int numParts;
    /** Bytes cached per reduce partition; indexed by partition id and reported in the MapStatus. */
    private final long[] sizes;
    /** Every buffer id registered by this writer, so they can be removed on a failed stop. */
    private final ArrayBuffer<ShuffleBufferId> writtenBufferIds;
    /** The "dataSize" metric from the supplied metrics map. */
    private final SQLMetric uncompressedMetric;
    private transient Logger org$apache$spark$internal$Logging$$log_;

    /**
     * @param blockManager local block manager; its shuffle server id is reported when no RAPIDS
     *     shuffle server is configured
     * @param gpuShuffleHandle handle providing the shuffle id and the partition count
     * @param mapId id of the map task whose output this writer caches
     * @param shuffleWriteMetricsReporter receives bytes/records written totals from {@link #write}
     * @param shuffleBufferCatalog source of buffer ids; also tracks degenerate buffers
     * @param rapidsDeviceMemoryStore store that receives the device buffers
     * @param option optional RAPIDS shuffle server whose port is advertised in the MapStatus
     * @param map SQL metrics; must contain a "dataSize" entry
     */
    public RapidsCachingWriter(BlockManager blockManager, GpuShuffleHandle<K, V> gpuShuffleHandle, long mapId, ShuffleWriteMetricsReporter shuffleWriteMetricsReporter, ShuffleBufferCatalog shuffleBufferCatalog, RapidsDeviceMemoryStore rapidsDeviceMemoryStore, Option<RapidsShuffleServer> option, Map<String, SQLMetric> map) {
        this.blockManager = blockManager;
        this.handle = gpuShuffleHandle;
        this.mapId = mapId;
        this.metricsReporter = shuffleWriteMetricsReporter;
        this.catalog = shuffleBufferCatalog;
        this.shuffleStorage = rapidsDeviceMemoryStore;
        this.rapidsShuffleServer = option;
        Logging.$init$(this);
        this.numParts = gpuShuffleHandle.m1151dependency().partitioner().numPartitions();
        this.sizes = new long[numParts()];
        this.writtenBufferIds = new ArrayBuffer<>(numParts());
        this.uncompressedMetric = (SQLMetric) map.apply("dataSize");
    }

    // ---- Forwarders for the Scala Logging trait (as generated by scalac) ----

    public String logName() {
        return Logging.logName$(this);
    }

    public Logger log() {
        return Logging.log$(this);
    }

    public void logInfo(Function0<String> function0) {
        Logging.logInfo$(this, function0);
    }

    public void logDebug(Function0<String> function0) {
        Logging.logDebug$(this, function0);
    }

    public void logTrace(Function0<String> function0) {
        Logging.logTrace$(this, function0);
    }

    public void logWarning(Function0<String> function0) {
        Logging.logWarning$(this, function0);
    }

    public void logError(Function0<String> function0) {
        Logging.logError$(this, function0);
    }

    public void logInfo(Function0<String> function0, Throwable th) {
        Logging.logInfo$(this, function0, th);
    }

    public void logDebug(Function0<String> function0, Throwable th) {
        Logging.logDebug$(this, function0, th);
    }

    public void logTrace(Function0<String> function0, Throwable th) {
        Logging.logTrace$(this, function0, th);
    }

    public void logWarning(Function0<String> function0, Throwable th) {
        Logging.logWarning$(this, function0, th);
    }

    public void logError(Function0<String> function0, Throwable th) {
        Logging.logError$(this, function0, th);
    }

    public boolean isTraceEnabled() {
        return Logging.isTraceEnabled$(this);
    }

    public void initializeLogIfNecessary(boolean z) {
        Logging.initializeLogIfNecessary$(this, z);
    }

    public boolean initializeLogIfNecessary(boolean z, boolean z2) {
        return Logging.initializeLogIfNecessary$(this, z, z2);
    }

    public boolean initializeLogIfNecessary$default$2() {
        return Logging.initializeLogIfNecessary$default$2$(this);
    }

    public void initializeForcefully(boolean z, boolean z2) {
        Logging.initializeForcefully$(this, z, z2);
    }

    public Logger org$apache$spark$internal$Logging$$log_() {
        return this.org$apache$spark$internal$Logging$$log_;
    }

    public void org$apache$spark$internal$Logging$$log__$eq(Logger logger) {
        this.org$apache$spark$internal$Logging$$log_ = logger;
    }

    // ---- Field accessors ----

    private int numParts() {
        return this.numParts;
    }

    private long[] sizes() {
        return this.sizes;
    }

    private ArrayBuffer<ShuffleBufferId> writtenBufferIds() {
        return this.writtenBufferIds;
    }

    private SQLMetric uncompressedMetric() {
        return this.uncompressedMetric;
    }

    // ---- ShuffleWriter implementation ----

    /**
     * Caches every {@code (partitionId, batch)} record. It then reports the total device bytes
     * and the total row count to the metrics reporter.
     *
     * @param records map output; each key is an {@code Integer} partition id and each value a
     *     {@link ColumnarBatch}
     * @throws IllegalStateException if a non-degenerate batch has an unsupported first column
     */
    @Override
    public void write(Iterator<Product2<K, V>> records) {
        try (NvtxRange ignored = new NvtxRange("RapidsCachingWriter.write", NvtxColor.CYAN)) {
            long bytesWritten = 0L;
            long recordsWritten = 0L;
            while (records.hasNext()) {
                Product2<K, V> record = records.next();
                int partId = BoxesRunTime.unboxToInt(record._1());
                ColumnarBatch batch = (ColumnarBatch) record._2();
                // The "record" count reported to Spark is the number of rows, not batches.
                recordsWritten += batch.numRows();
                bytesWritten += cacheBatch(partId, batch);
            }
            this.metricsReporter.incBytesWritten(bytesWritten);
            this.metricsReporter.incRecordsWritten(recordsWritten);
        }
    }

    /**
     * Finishes the map task.
     *
     * @param success whether the task succeeded
     * @return on success, a {@link MapStatus} carrying the per-partition sizes and the shuffle
     *     server id; on failure, {@code None}, after every buffer this writer registered has been
     *     removed from the catalog
     */
    @Override
    public Option<MapStatus> stop(boolean success) {
        try (NvtxRange ignored = new NvtxRange("RapidsCachingWriter.close", NvtxColor.CYAN)) {
            if (!success) {
                cleanStorage();
                return Option.empty();
            }
            BlockManagerId serverId = resolveShuffleServerId();
            logInfo(() -> "Done caching shuffle success=" + success + ", server_id=" + serverId + ", "
                + "map_id=" + this.mapId + ", sizes=" + formatSizes());
            return new Some<>(MapStatus$.MODULE$.apply(serverId, sizes(), this.mapId));
        }
    }

    // ---- Helpers ----

    /**
     * Caches one batch for {@code partId} and records its buffer id.
     *
     * @return device bytes added for this batch (0 for batches without rows or without columns)
     */
    private long cacheBatch(int partId, ColumnarBatch batch) {
        logDebug(() -> "Caching shuffle_id=" + this.handle.shuffleId() + " map_id=" + this.mapId
            + ", partId=" + partId + ", " + "batch=[num_cols=" + batch.numCols()
            + ", num_rows=" + batch.numRows() + "]");
        ShuffleBufferId bufferId = this.catalog.nextShuffleBufferId(
            new ShuffleBlockId(this.handle.shuffleId(), this.mapId, partId));
        long partSize = 0L;
        if (batch.numRows() > 0 && batch.numCols() > 0) {
            partSize = cacheDeviceBatch(bufferId, batch);
            sizes()[partId] += partSize;
        } else {
            // No device data: only the table metadata is tracked in the catalog.
            this.catalog.registerNewBuffer(
                new DegenerateRapidsBuffer(bufferId, MetaUtils$.MODULE$.buildDegenerateTableMeta(batch)));
            if (batch.numRows() > 0) {
                // Rows but no columns: still report a non-zero size (see DEGENERATE_PARTITION_SIZE).
                sizes()[partId] += DEGENERATE_PARTITION_SIZE;
            }
        }
        writtenBufferIds().append(Predef$.MODULE$.wrapRefArray(new ShuffleBufferId[] {bufferId}));
        return partSize;
    }

    /**
     * Hands a non-degenerate batch's device data to {@code shuffleStorage} under {@code bufferId}
     * and updates the uncompressed-size metric.
     *
     * @return length in bytes of the table buffer that was stored
     * @throws IllegalStateException if column 0 is neither a {@link GpuPackedTableColumn} nor a
     *     {@link GpuCompressedColumnVector}
     */
    private long cacheDeviceBatch(ShuffleBufferId bufferId, ColumnarBatch batch) {
        ColumnVector firstColumn = batch.column(0);
        if (firstColumn instanceof GpuPackedTableColumn) {
            GpuPackedTableColumn packed = (GpuPackedTableColumn) firstColumn;
            ContiguousTable table = packed.getContiguousTable();
            long size = packed.getTableBuffer().getLength();
            uncompressedMetric().$plus$eq(size);
            this.shuffleStorage.addContiguousTable(bufferId, table,
                SpillPriorities$.MODULE$.OUTPUT_FOR_SHUFFLE_INITIAL_PRIORITY(),
                this.shuffleStorage.addContiguousTable$default$4());
            return size;
        }
        if (firstColumn instanceof GpuCompressedColumnVector) {
            GpuCompressedColumnVector compressed = (GpuCompressedColumnVector) firstColumn;
            DeviceMemoryBuffer buffer = compressed.getTableBuffer();
            // NOTE(review): extra reference presumably owned by the store, since the batch keeps
            // its own reference to the same buffer — confirm against addBuffer's ownership rules.
            buffer.incRefCount();
            long size = buffer.getLength();
            TableMeta meta = compressed.getTableMeta();
            // Re-key the buffer metadata to the table id allocated for this shuffle buffer.
            meta.bufferMeta().mutateId(bufferId.tableId());
            uncompressedMetric().$plus$eq(meta.bufferMeta().uncompressedSize());
            this.shuffleStorage.addBuffer(bufferId, buffer, meta,
                SpillPriorities$.MODULE$.OUTPUT_FOR_SHUFFLE_INITIAL_PRIORITY(),
                this.shuffleStorage.addBuffer$default$5());
            return size;
        }
        throw new IllegalStateException("Unexpected column type: " + firstColumn.getClass());
    }

    /**
     * Returns the id advertised in the MapStatus. When a RAPIDS shuffle server is present, this is
     * its original id with a topology entry {@code <prefix>=<port>}. Otherwise it is the block
     * manager's shuffle server id.
     */
    private BlockManagerId resolveShuffleServerId() {
        if (!this.rapidsShuffleServer.isDefined()) {
            return this.blockManager.shuffleServerId();
        }
        RapidsShuffleServer server = this.rapidsShuffleServer.get();
        BlockManagerId original = server.originalShuffleServerId();
        String topologyInfo =
            RapidsShuffleTransport$.MODULE$.BLOCK_MANAGER_ID_TOPO_PREFIX() + "=" + server.getPort();
        return BlockManagerId$.MODULE$.apply(
            original.executorId(), original.host(), original.port(), new Some<>(topologyInfo));
    }

    /** Renders {@link #sizes} as a comma-separated list, e.g. {@code "0,128,0"}. */
    private String formatSizes() {
        long[] partSizes = sizes();
        StringBuilder joined = new StringBuilder();
        for (int i = 0; i < partSizes.length; i++) {
            if (i > 0) {
                joined.append(',');
            }
            joined.append(partSizes[i]);
        }
        return joined.toString();
    }

    /** Removes every buffer this writer registered from the catalog. */
    private void cleanStorage() {
        Iterator<ShuffleBufferId> bufferIds = writtenBufferIds().iterator();
        while (bufferIds.hasNext()) {
            this.catalog.removeBuffer(bufferIds.next());
        }
    }

    /**
     * Former Scala-lambda body of {@link #write}. It is no longer called from this class but is
     * kept public so existing callers still link. It caches one record and updates the
     * accumulators.
     *
     * @param recordsWritten accumulator incremented by the batch's row count
     * @param bytesWritten accumulator incremented by the device bytes cached
     */
    public static final /* synthetic */ void $anonfun$write$1(RapidsCachingWriter rapidsCachingWriter, LongRef recordsWritten, LongRef bytesWritten, Product2 product2) {
        int partId = BoxesRunTime.unboxToInt(product2._1());
        ColumnarBatch batch = (ColumnarBatch) product2._2();
        recordsWritten.elem += batch.numRows();
        bytesWritten.elem += rapidsCachingWriter.cacheBatch(partId, batch);
    }

    /**
     * Former Scala-lambda body of {@link #cleanStorage}. It is kept public so existing callers
     * still link. It removes a single buffer from the writer's catalog.
     */
    public static final /* synthetic */ void $anonfun$cleanStorage$1(RapidsCachingWriter rapidsCachingWriter, ShuffleBufferId shuffleBufferId) {
        rapidsCachingWriter.catalog.removeBuffer(shuffleBufferId);
    }
}
