package org.neo4j.storageengine.api.enrichment;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.Arrays;
import java.util.Iterator;
import java.util.Objects;
import java.util.function.Supplier;
import org.assertj.core.api.AbstractByteArrayAssert;
import org.assertj.core.api.Assertions;
import org.eclipse.collections.api.factory.Lists;
import org.eclipse.collections.api.factory.list.MutableListFactory;
import org.eclipse.collections.api.list.MutableList;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Mockito;
import org.neo4j.internal.helpers.collection.Pair;
import org.neo4j.internal.kernel.api.connectioninfo.ClientConnectionInfo;
import org.neo4j.internal.kernel.api.security.AuthSubject;
import org.neo4j.internal.kernel.api.security.SecurityContext;
import org.neo4j.memory.EmptyMemoryTracker;
import org.neo4j.memory.MemoryTracker;
import org.neo4j.storageengine.api.enrichment.Enrichment;
import org.neo4j.test.Race;
import org.neo4j.test.RandomSupport;
import org.neo4j.test.extension.Inject;
import org.neo4j.test.extension.RandomExtension;

/**
 * Tests for {@link Enrichment}: reading metadata and skipping past the serialized enrichment data,
 * concurrent reads of the deserialized sections, and heap accounting through a {@link MemoryTracker}.
 *
 * <p>Each test serializes a {@link TxMetadata} into a {@link ChannelBuffer}, then writes four
 * {@code int} section sizes followed by that many random bytes, and reads the result back.
 *
 * <p>NOTE(review): the {@code m6putInt}/{@code m1putAll}/{@code m2put} calls look like
 * decompiler-renamed forms of {@code putInt}/{@code putAll}/{@code put}; confirm against the
 * {@code ChannelBuffer} this file is compiled with.
 */
@ExtendWith({RandomExtension.class})
class EnrichmentTest {

    /** Bytes occupied by the four {@code int} section sizes written after the metadata. */
    private static final int SECTION_HEADER_BYTES = 4 * Integer.BYTES;

    /** Lower bound handed to the random generator for the first three section sizes. */
    private static final int MIN_SECTION_SIZE = 13;

    @Inject
    private RandomSupport random;

    @Test
    void readMetadataAndPastEnrichmentData() throws IOException {
        final int channelSize = 1024;
        final int trailingSize = 128;
        final long expectedPositionAfterEnrichment = channelSize - trailingSize;

        final TxMetadata metadata = TxMetadata.create(CaptureMode.DIFF, "some.server", securityContext(), 42L);
        final byte[] trailingBytes = random.nextBytes(new byte[trailingSize]);
        final ByteBuffer trailing = littleEndianBuffer(trailingBytes);

        try (ChannelBuffer channel = new ChannelBuffer(channelSize)) {
            metadata.serialize(channel);
            // the enrichment sections fill whatever space is left after the metadata, the size header
            // and the trailing bytes that must still be readable afterwards
            final int enrichmentSize = (int) (channelSize - channel.position() - trailingSize - SECTION_HEADER_BYTES);
            final int[] sizes = randomSectionSizes(enrichmentSize);
            channel.m6putInt(sizes[0])
                    .m6putInt(sizes[1])
                    .m6putInt(sizes[2])
                    .m6putInt(sizes[3])
                    .m1putAll(littleEndianBuffer(random.nextBytes(new byte[enrichmentSize])))
                    .m1putAll(trailing)
                    .flip();

            final TxMetadata read = Enrichment.readMetadataAndPastEnrichmentData(channel);
            Assertions.assertThat(read.captureMode()).isEqualTo(metadata.captureMode());
            Assertions.assertThat(read.serverId()).isEqualTo(metadata.serverId());
            Assertions.assertThat(read.subject().executingUser())
                    .isEqualTo(metadata.subject().executingUser());
            Assertions.assertThat(read.connectionInfo().protocol())
                    .isEqualTo(metadata.connectionInfo().protocol());
            Assertions.assertThat(channel.position())
                    .as("should have read all of the enrichment data")
                    .isEqualTo(expectedPositionAfterEnrichment);

            final byte[] readBack = new byte[trailingSize];
            channel.get(readBack, trailingSize);
            Assertions.assertThat(readBack)
                    .as("should read the rest of the channel OK AFTER the enrichment")
                    .isEqualTo(trailingBytes);
        }
    }

    @Test
    void concurrentReadingOfBuffers() throws Throwable {
        final int channelSize = 4096;
        final TxMetadata metadata = TxMetadata.create(CaptureMode.DIFF, "some.server", securityContext(), 42L);

        try (ChannelBuffer channel = new ChannelBuffer(channelSize)) {
            metadata.serialize(channel);
            final int enrichmentSize = (int) (channelSize - channel.position() - SECTION_HEADER_BYTES);
            final int[] sizes = randomSectionSizes(enrichmentSize);
            final byte[] enrichmentBytes = random.nextBytes(new byte[enrichmentSize]);
            channel.m6putInt(sizes[0])
                    .m6putInt(sizes[1])
                    .m6putInt(sizes[2])
                    .m6putInt(sizes[3])
                    .m1putAll(littleEndianBuffer(enrichmentBytes))
                    .flip();

            try (Enrichment.Read enrichment = Enrichment.Read.deserialize(channel, EmptyMemoryTracker.INSTANCE)) {
                // every contestant fetches each section's buffer itself and walks all four sections in
                // order, checking that together they reproduce the bytes that were written
                final Runnable reader = () -> {
                    int offset = 0;
                    offset = assertChunkedContent(enrichment.entities(), sizes[0], enrichmentBytes, offset);
                    offset = assertChunkedContent(enrichment.entityDetails(), sizes[1], enrichmentBytes, offset);
                    offset = assertChunkedContent(enrichment.entityChanges(), sizes[2], enrichmentBytes, offset);
                    assertChunkedContent(enrichment.values(), sizes[3], enrichmentBytes, offset);
                };
                final Race race = new Race();
                race.addContestants(8, reader);
                race.go();
            }
        }
    }

    @Test
    void memoryTracking() throws IOException {
        final int channelSize = 2048;
        final ArgumentCaptor<Long> allocated = ArgumentCaptor.forClass(Long.class);
        final ArgumentCaptor<Long> released = ArgumentCaptor.forClass(Long.class);
        final MemoryTracker memoryTracker = Mockito.mock(MemoryTracker.class);
        Mockito.doNothing().when(memoryTracker).allocateHeap(allocated.capture());
        Mockito.doNothing().when(memoryTracker).releaseHeap(released.capture());

        final TxMetadata metadata = TxMetadata.create(CaptureMode.DIFF, "some.server", securityContext(), 42L);
        try (ChannelBuffer channel = new ChannelBuffer(channelSize)) {
            metadata.serialize(channel);
            final int enrichmentSize = (int) (channelSize - channel.position() - SECTION_HEADER_BYTES);
            final int[] sizes = randomSectionSizes(enrichmentSize);
            channel.m6putInt(sizes[0])
                    .m6putInt(sizes[1])
                    .m6putInt(sizes[2])
                    .m6putInt(sizes[3])
                    .m2put(random.nextBytes(new byte[enrichmentSize]), 0, enrichmentSize)
                    .flip();

            // try-with-resources so the reader is closed even when the allocation assertion fails
            try (Enrichment.Read enrichment = Enrichment.Read.deserialize(channel, memoryTracker)) {
                Assertions.assertThat(sum(allocated))
                        .as("should have allocated the enrichment data")
                        .isEqualTo(enrichmentSize);
            }
            Assertions.assertThat(sum(released))
                    .as("should have deallocated the enrichment data")
                    .isEqualTo(enrichmentSize);
        }
    }

    /**
     * Picks four section sizes that add up to {@code total}. The first three are drawn with
     * {@code random.nextInt(MIN_SECTION_SIZE, total / 4)} (in that order); the last takes the remainder.
     *
     * @param total number of enrichment bytes to split across the sections
     * @return the four sizes in write order
     */
    private int[] randomSectionSizes(int total) {
        final int first = random.nextInt(MIN_SECTION_SIZE, total / 4);
        final int second = random.nextInt(MIN_SECTION_SIZE, total / 4);
        final int third = random.nextInt(MIN_SECTION_SIZE, total / 4);
        return new int[] {first, second, third, total - first - second - third};
    }

    /** Copies {@code bytes} into a new little-endian heap buffer and flips it, ready for reading. */
    private static ByteBuffer littleEndianBuffer(byte[] bytes) {
        return ByteBuffer.allocate(bytes.length)
                .order(ByteOrder.LITTLE_ENDIAN)
                .put(bytes)
                .flip();
    }

    /**
     * Reads {@code size} bytes from {@code buffer}, at most 8 at a time, and asserts that each chunk
     * equals the matching slice of {@code expected} starting at {@code offset}.
     *
     * @return the offset in {@code expected} just past the last byte checked
     */
    private static int assertChunkedContent(ByteBuffer buffer, int size, byte[] expected, int offset) {
        final byte[] chunk = new byte[8];
        int cursor = offset;
        for (int read = 0; read < size; read += chunk.length) {
            final int length = Math.min(chunk.length, size - read);
            buffer.get(chunk, 0, length);
            final byte[] actual = length == chunk.length ? chunk : Arrays.copyOf(chunk, length);
            Assertions.assertThat(actual).isEqualTo(Arrays.copyOfRange(expected, cursor, cursor + length));
            cursor += length;
        }
        return cursor;
    }

    /** Mocked security context whose subject executes as {@code "freddy"} over an embedded connection. */
    private static SecurityContext securityContext() {
        // build the subject mock first: creating it inside the when(...) call below would interleave
        // two stubbings and leave Mockito with unfinished stubbing
        final AuthSubject subject = subject();
        final SecurityContext securityContext = Mockito.mock(SecurityContext.class);
        Mockito.when(securityContext.subject()).thenReturn(subject);
        Mockito.when(securityContext.connectionInfo()).thenReturn(ClientConnectionInfo.EMBEDDED_CONNECTION);
        return securityContext;
    }

    /** Mocked auth subject whose executing user is {@code "freddy"}. */
    private static AuthSubject subject() {
        final AuthSubject authSubject = Mockito.mock(AuthSubject.class);
        Mockito.when(authSubject.executingUser()).thenReturn("freddy");
        return authSubject;
    }

    /** Sums every value the captor has recorded. */
    private static long sum(ArgumentCaptor<Long> captor) {
        return captor.getAllValues().stream().mapToLong(Long::longValue).sum();
    }
}
