package io.ray.streaming.runtime.transfer;

import com.google.common.base.Preconditions;
import io.ray.api.BaseActorHandle;
import io.ray.streaming.runtime.config.StreamingWorkerConfig;
import io.ray.streaming.runtime.config.types.TransferChannelType;
import io.ray.streaming.runtime.transfer.channel.ChannelId;
import io.ray.streaming.runtime.transfer.channel.ChannelUtils;
import io.ray.streaming.runtime.transfer.channel.OffsetInfo;
import io.ray.streaming.runtime.util.Platform;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Java-side handle to a native streaming data writer.
 *
 * <p>The constructor creates the native writer through {@code createWriterNative}; every other
 * operation delegates to a native method using the returned pointer. Outgoing payloads are staged in
 * a single reusable direct {@link ByteBuffer} whose address is handed to native code, so an instance
 * must not be used by several threads concurrently.
 *
 * <p>{@link #close()} resets the native pointer to 0. Afterwards, writes, checkpoint queries and
 * barrier broadcasts fail with {@link IllegalStateException} instead of handing a 0 pointer to native
 * code, and {@link #stop()} does nothing.
 */
public class DataWriter {
    private static final Logger LOG = LoggerFactory.getLogger(DataWriter.class);

    /** Pointer to the native writer; 0 once {@link #close()} has run. */
    private long nativeWriterPtr;

    /** Reusable direct staging buffer for outgoing payloads; starts empty and only ever grows. */
    private ByteBuffer buffer = ByteBuffer.allocateDirect(0);

    /** Native address of {@link #buffer}; stays 0 until the buffer is first grown. */
    private long bufferAddress;

    /**
     * Private snapshot of the output channel ids, in the order handed to the native writer.
     * {@link #getOutputCheckpoints()} pairs entry {@code i} with the {@code i}-th native message id.
     */
    private final List<String> outputChannels;

    /**
     * Creates the native writer for the given output channels.
     *
     * @param outputChannelIds ids of the output channels; must not be empty
     * @param targetActors actor handles for the channels, one per entry of {@code outputChannelIds}
     * @param checkpoints previously recorded offsets keyed by channel id; a channel with no (or a
     *     {@code null}) entry starts from message id 0
     * @param workerConfig worker configuration supplying channel size, channel type and native conf
     * @throws IllegalArgumentException if {@code outputChannelIds} is empty or the two lists differ
     *     in size
     */
    public DataWriter(
            List<String> outputChannelIds,
            List<BaseActorHandle> targetActors,
            Map<String, OffsetInfo> checkpoints,
            StreamingWorkerConfig workerConfig) {
        Preconditions.checkArgument(!outputChannelIds.isEmpty(), "Output channel list must not be empty.");
        Preconditions.checkArgument(
                outputChannelIds.size() == targetActors.size(),
                "Got %s output channels but %s target actors.",
                outputChannelIds.size(),
                targetActors.size());
        // Copy so later mutation by the caller cannot desynchronize us from the native channel order.
        this.outputChannels = new ArrayList<>(outputChannelIds);

        ChannelCreationParametersBuilder creationParams =
                new ChannelCreationParametersBuilder()
                        .buildOutputQueueParameters(this.outputChannels, targetActors);
        byte[][] channelIdBytes =
                this.outputChannels.stream().map(ChannelId::idStrToBytes).toArray(byte[][]::new);

        long[] initialMsgIds = new long[this.outputChannels.size()];
        for (int i = 0; i < initialMsgIds.length; i++) {
            OffsetInfo offset = checkpoints.get(this.outputChannels.get(i));
            initialMsgIds[i] = offset == null ? 0L : offset.getStreamingMsgId();
        }

        long channelSize = workerConfig.transferConfig.channelSize();
        boolean isMemoryChannel =
                TransferChannelType.MEMORY_CHANNEL == workerConfig.transferConfig.channelType();
        this.nativeWriterPtr =
                createWriterNative(
                        creationParams,
                        channelIdBytes,
                        initialMsgIds,
                        channelSize,
                        ChannelUtils.toNativeConf(workerConfig),
                        isMemoryChannel);
        LOG.info("Create DataWriter succeed for worker: {}.", workerConfig.workerInternalConfig.workerName());
    }

    private static native long createWriterNative(
            ChannelCreationParametersBuilder creationParams,
            byte[][] outputChannelIds,
            long[] initialMsgIds,
            long channelSize,
            byte[] confBytes,
            boolean isMemoryChannel);

    /**
     * Writes the remaining bytes of {@code byteBuffer} to a single channel.
     *
     * <p>The bytes are copied into the internal staging buffer, which advances the position of
     * {@code byteBuffer} to its limit.
     *
     * @throws IllegalStateException if the writer has been closed
     */
    public void write(ChannelId channelId, ByteBuffer byteBuffer) {
        long writerPtr = requireOpen();
        int length = byteBuffer.remaining();
        ensureBuffer(length);
        buffer.clear();
        buffer.put(byteBuffer);
        writeMessageNative(writerPtr, channelId.getNativeIdPtr(), bufferAddress, length);
    }

    /**
     * Writes the remaining bytes of {@code byteBuffer} to every channel in {@code channelIds}.
     *
     * <p>Copies are taken from a duplicate, so unlike {@link #write(ChannelId, ByteBuffer)} the
     * position of {@code byteBuffer} is left unchanged.
     *
     * @throws IllegalStateException if the writer has been closed
     */
    public void write(Set<ChannelId> channelIds, ByteBuffer byteBuffer) {
        long writerPtr = requireOpen();
        int length = byteBuffer.remaining();
        ensureBuffer(length);
        for (ChannelId channelId : channelIds) {
            // NOTE(review): the payload is re-staged for every channel; this could be hoisted out of
            // the loop if writeMessageNative never modifies the staging buffer — confirm first.
            buffer.clear();
            buffer.put(byteBuffer.duplicate());
            writeMessageNative(writerPtr, channelId.getNativeIdPtr(), bufferAddress, length);
        }
    }

    /**
     * Grows the staging buffer so that it can hold at least {@code size} bytes. The buffer is never
     * shrunk; a newly allocated buffer uses native byte order and its address is cached.
     */
    private void ensureBuffer(int size) {
        if (buffer.capacity() >= size) {
            return;
        }
        buffer = ByteBuffer.allocateDirect(size);
        buffer.order(ByteOrder.nativeOrder());
        bufferAddress = Platform.getAddress(buffer);
    }

    /**
     * Returns the message id reported by the native writer for each output channel, keyed by
     * channel id.
     *
     * @throws IllegalStateException if the writer has been closed
     */
    public Map<String, OffsetInfo> getOutputCheckpoints() {
        long[] msgIds = getOutputMsgIdNative(requireOpen());
        Map<String, OffsetInfo> checkpoints = new HashMap<>(outputChannels.size());
        // Assumes the native array is ordered like the ids passed to createWriterNative.
        for (int i = 0; i < outputChannels.size(); i++) {
            checkpoints.put(outputChannels.get(i), new OffsetInfo(msgIds[i]));
        }
        LOG.info("got output points, {}.", checkpoints);
        return checkpoints;
    }

    /**
     * Broadcasts a barrier for {@code checkpointId}, attaching the content of {@code byteBuffer}.
     *
     * <p>The entire backing array of {@code byteBuffer} is handed to native code (position and limit
     * are not applied), so it must be an array-backed buffer; {@link ByteBuffer#array()} throws
     * otherwise.
     *
     * @throws IllegalArgumentException if {@code byteBuffer} is not in native byte order
     * @throws IllegalStateException if the writer has been closed
     */
    public void broadcastBarrier(long checkpointId, ByteBuffer byteBuffer) {
        LOG.info("Broadcast barrier, cpId={}.", checkpointId);
        Preconditions.checkArgument(
                byteBuffer.order() == ByteOrder.nativeOrder(),
                "Barrier attachment must use native byte order.");
        broadcastBarrierNative(requireOpen(), checkpointId, byteBuffer.array());
    }

    /**
     * Asks the native writer to clear state for {@code checkpointId}.
     *
     * @throws IllegalStateException if the writer has been closed
     */
    public void clearCheckpoint(long checkpointId) {
        LOG.info("Producer clear checkpoint, checkpointId={}.", checkpointId);
        clearCheckpointNative(requireOpen(), checkpointId);
    }

    /** Stops the native writer; does nothing once {@link #close()} has reset the native pointer. */
    public void stop() {
        if (nativeWriterPtr == 0) {
            return;
        }
        stopWriterNative(nativeWriterPtr);
    }

    /** Closes the native writer and resets the pointer to 0. Calling it again is a no-op. */
    public void close() {
        if (nativeWriterPtr == 0) {
            return;
        }
        LOG.info("Closing data writer.");
        closeWriterNative(nativeWriterPtr);
        nativeWriterPtr = 0L;
        LOG.info("Finish closing data writer.");
    }

    /** Returns the native writer pointer, failing fast if {@link #close()} has already reset it. */
    private long requireOpen() {
        Preconditions.checkState(nativeWriterPtr != 0, "DataWriter has already been closed.");
        return nativeWriterPtr;
    }

    // The long result of writeMessageNative is not used anywhere in this class.
    private native long writeMessageNative(long writerPtr, long channelIdPtr, long dataAddress, int length);

    private native void stopWriterNative(long writerPtr);

    private native void closeWriterNative(long writerPtr);

    private native long[] getOutputMsgIdNative(long writerPtr);

    private native void broadcastBarrierNative(long writerPtr, long checkpointId, byte[] attachment);

    private native void clearCheckpointNative(long writerPtr, long checkpointId);
}
