/*
 * Decompiled with CFR 0.152.
 */
package org.infinispan.remoting.transport.jgroups;

import java.nio.file.Path;
import java.util.Collection;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Supplier;
import org.infinispan.commons.io.ByteBuffer;
import org.infinispan.commons.io.ByteBufferImpl;
import org.infinispan.commons.util.concurrent.CompletableFutures;
import org.infinispan.configuration.global.GlobalConfiguration;
import org.infinispan.remoting.transport.jgroups.JGroupsStateMachineAdapter;
import org.infinispan.remoting.transport.jgroups.JGroupsTransport;
import org.infinispan.remoting.transport.raft.RaftChannel;
import org.infinispan.remoting.transport.raft.RaftChannelConfiguration;
import org.infinispan.remoting.transport.raft.RaftManager;
import org.infinispan.remoting.transport.raft.RaftStateMachine;
import org.infinispan.util.logging.Log;
import org.infinispan.util.logging.LogFactory;
import org.jgroups.JChannel;
import org.jgroups.fork.ForkChannel;
import org.jgroups.protocols.pbcast.GMS;
import org.jgroups.protocols.raft.ELECTION;
import org.jgroups.protocols.raft.FileBasedLog;
import org.jgroups.protocols.raft.InMemoryLog;
import org.jgroups.protocols.raft.NO_DUPES;
import org.jgroups.protocols.raft.RAFT;
import org.jgroups.protocols.raft.REDIRECT;
import org.jgroups.raft.RaftHandle;
import org.jgroups.raft.StateMachine;
import org.jgroups.stack.Protocol;
import org.jgroups.stack.ProtocolStack;

/**
 * {@link RaftManager} implementation built on JGroups Raft.
 * <p>
 * Every state machine registered through {@link #getOrRegisterStateMachine} gets its own {@link ForkChannel}
 * stacked on the main transport channel; the fork runs an {@code ELECTION} / {@code RAFT} / {@code REDIRECT}
 * protocol stack. Registered channels are kept in a {@link ConcurrentHashMap} keyed by channel name.
 */
class JGroupsRaftManager
implements RaftManager {
    private static final Log log = LogFactory.getLog(JGroupsRaftManager.class);
    /** Transport channel on which the Raft fork channels are created. */
    private final JChannel mainChannel;
    /** Raft member ids, read from the transport configuration. */
    private final Collection<String> raftMembers;
    /** This node's Raft id (the transport node name). */
    private final String raftId;
    /** Base directory for persistent Raft logs; {@code null} when global state is disabled. */
    private final String persistenceDirectory;
    /** Registered Raft channels, keyed by channel name. */
    private final Map<String, JgroupsRaftChannel<? extends RaftStateMachine>> raftStateMachineMap = new ConcurrentHashMap<>(16);

    /**
     * Creates the manager on top of {@code mainChannel}.
     *
     * @param globalConfiguration source of the Raft members, node name and global-state location
     * @param mainChannel         the transport channel; it must contain a fork protocol
     * @throws RuntimeException the exception produced by {@code log.forkProtocolRequired()} when no fork
     *                          protocol is found on {@code mainChannel}
     */
    JGroupsRaftManager(GlobalConfiguration globalConfiguration, JChannel mainChannel) {
        if (JGroupsTransport.findFork(mainChannel) == null) {
            // Raft channels are created as ForkChannels, so a fork protocol is mandatory.
            throw log.forkProtocolRequired();
        }
        this.mainChannel = mainChannel;
        this.raftMembers = globalConfiguration.transport().raftMembers();
        this.raftId = globalConfiguration.transport().nodeName();
        this.persistenceDirectory = globalConfiguration.globalState().enabled() ? globalConfiguration.globalState().persistentLocation() : null;
    }

    /**
     * Returns the state machine registered under {@code channelName}, creating and connecting its Raft
     * channel first if none exists yet.
     *
     * @return the state machine, or {@code null} if the fork channel could not be created or connected
     */
    @Override
    @SuppressWarnings("unchecked")
    public <T extends RaftStateMachine> T getOrRegisterStateMachine(String channelName, Supplier<T> supplier, RaftChannelConfiguration configuration) {
        Objects.requireNonNull(channelName, "channelName");
        Objects.requireNonNull(supplier, "supplier");
        Objects.requireNonNull(configuration, "configuration");
        // computeIfAbsent runs the creation at most once per name. When createRaftChannel returns null
        // no mapping is recorded, so a later call will retry.
        JgroupsRaftChannel<? extends RaftStateMachine> raftChannel =
                raftStateMachineMap.computeIfAbsent(channelName, name -> createRaftChannel(name, configuration, supplier));
        // Unchecked: callers are expected to use the same state machine type for a given channel name.
        return raftChannel == null ? null : (T) raftChannel.stateMachine();
    }

    /** JGroups Raft is always available in this implementation. */
    @Override
    public boolean isRaftAvailable() {
        return true;
    }

    /**
     * @return {@code true} if a channel named {@code channelName} is registered and its Raft handle
     *         currently reports a non-null leader
     */
    @Override
    public boolean hasLeader(String channelName) {
        JgroupsRaftChannel<? extends RaftStateMachine> raftChannel = raftStateMachineMap.get(channelName);
        return raftChannel != null && raftChannel.raftHandle.leader() != null;
    }

    @Override
    public String raftId() {
        return raftId;
    }

    /**
     * Creates and connects the fork channel for {@code name}, then builds and initializes the state machine.
     *
     * @return the new channel, or {@code null} if the fork channel could not be created or connected
     *         (the error is logged)
     */
    private <T extends RaftStateMachine> JgroupsRaftChannel<T> createRaftChannel(String name, RaftChannelConfiguration configuration, Supplier<? extends T> supplier) {
        ForkChannel forkChannel = null;
        try {
            forkChannel = createForkChannel(name, configuration);
            forkChannel.connect(name);
        } catch (Exception e) {
            log.errorCreatingForkChannel(name, e);
            if (forkChannel != null) {
                forkChannel.disconnect();
            } else {
                // The ForkChannel constructor failed; drop whatever fork stack was registered under this name.
                JGroupsTransport.findFork(mainChannel).remove(name);
            }
            return null;
        }
        try {
            T stateMachine = supplier.get();
            JgroupsRaftChannel<T> raftChannel = new JgroupsRaftChannel<>(name, forkChannel, stateMachine);
            stateMachine.init(raftChannel);
            return raftChannel;
        } catch (RuntimeException | Error t) {
            // The fork channel is already connected, and computeIfAbsent records no mapping when this
            // method throws, so stop() would never see it: disconnect here to avoid leaking it.
            try {
                forkChannel.disconnect();
            } catch (RuntimeException disconnectFailure) {
                t.addSuppressed(disconnectFailure);
            }
            throw t;
        }
    }

    /**
     * Builds (but does not connect) the fork channel for {@code name} with an ELECTION/RAFT/REDIRECT stack.
     * <p>
     * {@code VOLATILE} mode uses an {@link InMemoryLog} with prefix {@code name-raftId}; {@code PERSISTENT}
     * mode uses a {@link FileBasedLog} under {@code persistenceDirectory/name/raftId}.
     *
     * @throws RuntimeException the exception produced by {@code log.raftGlobalStateDisabled()} when
     *                          persistent mode is requested but global state is disabled
     */
    private ForkChannel createForkChannel(String name, RaftChannelConfiguration configuration) throws Exception {
        RAFT raftProtocol = new RAFT();
        switch (configuration.logMode()) {
            case VOLATILE: {
                raftProtocol.logClass(InMemoryLog.class.getCanonicalName()).logPrefix(name + "-" + raftId);
                break;
            }
            case PERSISTENT: {
                if (persistenceDirectory == null) {
                    throw log.raftGlobalStateDisabled();
                }
                raftProtocol.logClass(FileBasedLog.class.getCanonicalName()).logPrefix(Path.of(persistenceDirectory, name, raftId).toAbsolutePath().toString());
                break;
            }
            default: {
                throw new IllegalStateException("Unsupported Raft log mode: " + configuration.logMode());
            }
        }
        raftProtocol.members(raftMembers).raftId(raftId);
        return new ForkChannel(mainChannel, name, name, new Protocol[]{new ELECTION(), raftProtocol, new REDIRECT()});
    }

    /**
     * Inserts {@link NO_DUPES} directly below {@link GMS} in the main channel's stack, unless it is already
     * present or the stack has no GMS.
     */
    @Override
    public void start() {
        ProtocolStack protocolStack = mainChannel.getProtocolStack();
        if (protocolStack.findProtocol(NO_DUPES.class) != null) {
            return;
        }
        GMS gms = (GMS) protocolStack.findProtocol(GMS.class);
        if (gms == null) {
            return;
        }
        protocolStack.insertProtocolInStack(new NO_DUPES(), gms, ProtocolStack.Position.BELOW);
    }

    /**
     * Removes and disconnects every registered Raft channel.
     * <p>
     * All channels are processed even if some disconnects fail; the first failure is rethrown at the end
     * with any later ones attached as suppressed exceptions.
     */
    @Override
    public void stop() {
        RuntimeException failure = null;
        // Remove entries one by one (ConcurrentHashMap iteration tolerates this) so that no channel is
        // dropped from the map without being disconnected.
        for (String name : raftStateMachineMap.keySet()) {
            JgroupsRaftChannel<? extends RaftStateMachine> channel = raftStateMachineMap.remove(name);
            if (channel == null) {
                continue;
            }
            try {
                channel.disconnect();
            } catch (RuntimeException e) {
                if (failure == null) {
                    failure = e;
                } else {
                    failure.addSuppressed(e);
                }
            }
        }
        if (failure != null) {
            throw failure;
        }
    }

    /**
     * A {@link RaftChannel} backed by a fork channel and a {@link RaftHandle} whose state machine is a
     * {@link JGroupsStateMachineAdapter} wrapping a {@code T}.
     */
    private static final class JgroupsRaftChannel<T extends RaftStateMachine>
    implements RaftChannel {
        private final RaftHandle raftHandle;
        private final String channelName;
        private final JChannel forkedChannel;

        JgroupsRaftChannel(String channelName, JChannel forkedChannel, T stateMachine) {
            this.channelName = channelName;
            this.forkedChannel = forkedChannel;
            this.raftHandle = new RaftHandle(forkedChannel, new JGroupsStateMachineAdapter<>(stateMachine));
        }

        /**
         * Submits {@code buffer} through {@link RaftHandle#setAsync}; any exception thrown synchronously is
         * returned as an exceptionally completed stage.
         */
        @Override
        public CompletionStage<ByteBuffer> send(ByteBuffer buffer) {
            try {
                return raftHandle.setAsync(buffer.getBuf(), buffer.getOffset(), buffer.getLength()).thenApply(ByteBufferImpl::create);
            } catch (Exception e) {
                return CompletableFutures.completedExceptionFuture(e);
            }
        }

        @Override
        public String channelName() {
            return channelName;
        }

        @Override
        public String raftId() {
            return raftHandle.raftId();
        }

        /** Unwraps the user state machine from the adapter installed in the constructor. */
        T stateMachine() {
            StateMachine stateMachine = raftHandle.stateMachine();
            assert stateMachine instanceof JGroupsStateMachineAdapter;
            // Unchecked: the constructor always installs a JGroupsStateMachineAdapter wrapping a T.
            @SuppressWarnings("unchecked")
            JGroupsStateMachineAdapter<T> adapter = (JGroupsStateMachineAdapter<T>) stateMachine;
            return adapter.getStateMachine();
        }

        void disconnect() {
            forkedChannel.disconnect();
        }
    }
}

