/*
 * Decompiled with CFR 0.152.
 */
package org.voltdb.client;

import com.google_voltpatches.common.base.Throwables;
import java.io.EOFException;
import java.io.IOException;
import java.io.UnsupportedEncodingException;
import java.net.InetSocketAddress;
import java.net.UnknownHostException;
import java.nio.ByteBuffer;
import java.nio.channels.SocketChannel;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.PrivilegedAction;
import java.util.HashMap;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
import javax.security.auth.Subject;
import org.ietf.jgss.GSSContext;
import org.ietf.jgss.GSSException;
import org.ietf.jgss.GSSManager;
import org.ietf.jgss.GSSName;
import org.ietf.jgss.Oid;
import org.voltcore.network.ReverseDNSCache;
import org.voltdb.ClientResponseImpl;
import org.voltdb.client.ClientAuthHashScheme;
import org.voltdb.client.ClientResponse;
import org.voltdb.client.ProcedureInvocation;
import org.voltdb.common.Constants;
import org.voltdb.utils.SerializationHelper;

/**
 * Low-level helpers for opening an authenticated VoltDB client connection on a raw
 * {@link SocketChannel}, sending procedure invocations over it, and reading responses back.
 *
 * <p>Each channel used with {@link #sendInvocation(SocketChannel, String, Object...)} or
 * {@link #readResponse(SocketChannel)} is lazily given its own {@link ExecutorPair} (one
 * single-threaded executor for writes, one for reads). {@link #closeConnection(SocketChannel)}
 * shuts that pair down and closes the channel.
 */
public class ConnectionUtil {
    /** Thread factory shared by every {@link ExecutorPair}. */
    private static final TF m_tf = new TF();
    /** Per-channel executor pairs; every access synchronizes on this map. */
    private static final HashMap<SocketChannel, ExecutorPair> m_executors = new HashMap<SocketChannel, ExecutorPair>();
    /** Source of client handles for {@link #sendInvocation}; starts at {@link Long#MIN_VALUE}. */
    private static final AtomicLong m_handle = new AtomicLong(Long.MIN_VALUE);
    private static final GSSManager m_gssManager = GSSManager.getInstance();

    /**
     * Hashes {@code password} with the SHA-256 scheme.
     *
     * @param password clear-text password, or {@code null}
     * @return the digest of the UTF-8 bytes of {@code password}, or {@code null} if it was {@code null}
     */
    public static byte[] getHashedPassword(String password) {
        return ConnectionUtil.getHashedPassword(ClientAuthHashScheme.HASH_SHA256, password);
    }

    /**
     * Hashes {@code password} with the digest algorithm selected by {@code scheme}.
     *
     * @param scheme hash scheme; mapped to a JCA algorithm name via {@link ClientAuthHashScheme#getDigestScheme}
     * @param password clear-text password, or {@code null}
     * @return the digest of the UTF-8 bytes of {@code password}, or {@code null} if it was {@code null}
     * @throws RuntimeException if the JVM lacks the digest algorithm or UTF-8 support
     */
    public static byte[] getHashedPassword(ClientAuthHashScheme scheme, String password) {
        if (password == null) {
            return null;
        }
        MessageDigest md;
        try {
            md = MessageDigest.getInstance(ClientAuthHashScheme.getDigestScheme(scheme));
        }
        catch (NoSuchAlgorithmException e) {
            // This is library code running inside the caller's process: report the failure
            // instead of terminating the whole JVM.
            throw new RuntimeException("JVM doesn't support the digest algorithm for hash scheme " + scheme, e);
        }
        try {
            return md.digest(password.getBytes("UTF-8"));
        }
        catch (UnsupportedEncodingException e) {
            throw new RuntimeException("JVM doesn't support UTF-8. Please use a supported JVM", e);
        }
    }

    /**
     * Opens a connection to {@code host:port} and performs the login exchange. On success the
     * channel is left in non-blocking mode with TCP keep-alive enabled.
     *
     * <p>The login service is {@code "database"} when {@code subject} is {@code null} and
     * {@code "kerberos"} otherwise.
     *
     * @param host server host name or address
     * @param username user name, or {@code null}
     * @param hashedPassword password digest to send; must not be {@code null}
     * @param port server port
     * @param subject JAAS subject used for the Kerberos handshake, or {@code null}
     * @param scheme hash scheme that produced {@code hashedPassword}
     * @return a three-element array: [0] the connected {@link SocketChannel}, [1] a {@code long[4]}
     *         with the numeric fields of the login response, [2] the server's build string
     * @throws IOException if the host is unresolvable, the connection fails, or login is rejected;
     *         the channel is closed in every failure case
     */
    public static Object[] getAuthenticatedConnection(String host, String username, byte[] hashedPassword, int port, Subject subject, ClientAuthHashScheme scheme) throws IOException {
        String service = subject == null ? "database" : "kerberos";
        return ConnectionUtil.getAuthenticatedConnection(service, host, username, hashedPassword, port, subject, scheme);
    }

    private static Object[] getAuthenticatedConnection(String service, String host, String username, byte[] hashedPassword, int port, Subject subject, ClientAuthHashScheme scheme) throws IOException {
        InetSocketAddress address = new InetSocketAddress(host, port);
        return ConnectionUtil.getAuthenticatedConnection(service, address, username, hashedPassword, subject, scheme);
    }

    /**
     * Core of the login exchange; see the public overload for the returned array layout.
     * The channel is closed whenever this method exits abnormally.
     */
    private static Object[] getAuthenticatedConnection(String service, InetSocketAddress addr, String username, byte[] hashedPassword, Subject subject, ClientAuthHashScheme scheme) throws IOException {
        if (addr.isUnresolved()) {
            throw new UnknownHostException(addr.getHostName());
        }
        Object[] returnArray = new Object[3];
        long[] retvals = new long[4];
        SocketChannel aChannel = SocketChannel.open(addr);
        returnArray[0] = aChannel;
        returnArray[1] = retvals;
        boolean success = false;
        try {
            // Inside the try so the channel is closed if this check ever fails.
            assert (aChannel.isConnected());
            if (!aChannel.isConnected()) {
                throw new IOException("Failed to open host " + ReverseDNSCache.hostnameOrAddress(addr.getAddress()));
            }
            aChannel.configureBlocking(true);
            aChannel.socket().setTcpNoDelay(true);

            ByteBuffer request = buildLoginRequest(service, username, hashedPassword, scheme);
            boolean successfulWrite = false;
            IOException writeException = null;
            try {
                // A bounded number of write attempts. A write failure is not thrown here; it is
                // only reported if the server then closes the connection without a full response.
                for (int attempt = 0; attempt < 4 && request.hasRemaining(); ++attempt) {
                    aChannel.write(request);
                }
                successfulWrite = !request.hasRemaining();
            }
            catch (IOException e) {
                writeException = e;
            }

            ByteBuffer lengthBuffer = ByteBuffer.allocate(4);
            readLoginBytes(aChannel, lengthBuffer, successfulWrite, writeException);
            lengthBuffer.flip();
            int len = lengthBuffer.getInt();
            if (len <= 0) {
                throw new IOException("Wire protocol format violation error");
            }
            ByteBuffer loginResponse = ByteBuffer.allocate(len);
            readLoginBytes(aChannel, loginResponse, successfulWrite, writeException);
            loginResponse.flip();

            byte version = loginResponse.get();
            byte loginResponseCode = loginResponse.get();
            if (version == 2) {
                // Version 2: the second byte is a tag (4 = Kerberos handshake request) followed by
                // the service principal; the real result code arrives after the GSS handshake.
                byte tag = loginResponseCode;
                if (subject == null) {
                    throw new IOException("Server requires an authenticated JAAS principal");
                }
                if (tag != 4) {
                    throw new IOException("Wire protocol format violation error");
                }
                String servicePrincipal = SerializationHelper.getString(loginResponse);
                loginResponse = ConnectionUtil.performAuthenticationHandShake(aChannel, subject, servicePrincipal);
                loginResponseCode = loginResponse.get();
            }
            if (loginResponseCode != 0) {
                throw new IOException(describeLoginFailure(loginResponseCode));
            }

            // Numeric login-response fields (int, long, long, int); their meanings are defined by
            // the server's response format, which is not visible here.
            retvals[0] = loginResponse.getInt();
            retvals[1] = loginResponse.getLong();
            retvals[2] = loginResponse.getLong();
            retvals[3] = loginResponse.getInt();
            int buildStringLength = loginResponse.getInt();
            byte[] buildStringBytes = new byte[buildStringLength];
            loginResponse.get(buildStringBytes);
            returnArray[2] = new String(buildStringBytes, "UTF-8");

            aChannel.configureBlocking(false);
            aChannel.socket().setKeepAlive(true);
            success = true;
        }
        finally {
            if (!success) {
                aChannel.close();
            }
        }
        return returnArray;
    }

    /**
     * Serializes the version-1 login request: an int length prefix (excluding itself), the
     * version byte (1), the hash-scheme byte, the service and username as varbinary, and the
     * raw password hash. The buffer is returned flipped and ready to write.
     */
    private static ByteBuffer buildLoginRequest(String service, String username, byte[] hashedPassword, ClientAuthHashScheme scheme) throws IOException {
        byte[] serviceBytes = service == null ? null : service.getBytes(Constants.UTF8ENCODING);
        byte[] usernameBytes = username == null ? null : username.getBytes(Constants.UTF8ENCODING);
        int requestSize = 4                                              // length prefix
                + 2                                                      // version + scheme bytes
                + 4 + (serviceBytes == null ? 0 : serviceBytes.length)   // varbinary service
                + 4 + (usernameBytes == null ? 0 : usernameBytes.length) // varbinary username
                + hashedPassword.length;
        ByteBuffer b = ByteBuffer.allocate(requestSize);
        b.putInt(requestSize - 4);
        b.put((byte)1);
        b.put((byte)scheme.getValue());
        SerializationHelper.writeVarbinary(serviceBytes, b);
        SerializationHelper.writeVarbinary(usernameBytes, b);
        b.put(hashedPassword);
        b.flip();
        return b;
    }

    /**
     * Fills {@code buf} from {@code channel}. If the peer hangs up first, the most specific
     * explanation is thrown: the deferred write exception, else a write-failure message, else
     * "Authentication rejected".
     */
    private static void readLoginBytes(SocketChannel channel, ByteBuffer buf, boolean successfulWrite, IOException writeException) throws IOException {
        while (buf.hasRemaining()) {
            if (channel.read(buf) != -1) {
                continue;
            }
            if (writeException != null) {
                throw writeException;
            }
            if (!successfulWrite) {
                throw new IOException("Unable to write authentication info to server");
            }
            throw new IOException("Authentication rejected");
        }
    }

    /** Maps a non-zero login response code to the message reported to the caller. */
    private static String describeLoginFailure(byte loginResponseCode) {
        switch (loginResponseCode) {
            case 1:
                return "Server has too many connections";
            case 2:
                return "Connection timed out during authentication. The VoltDB server may be overloaded.";
            case 3:
                return "Wire protocol format violation error";
            case 4:
                return "Failed to authenticate to rejoining node";
            case 5:
                return "Export not enabled for server";
            default:
                return "Authentication rejected";
        }
    }

    /**
     * Runs a Kerberos v5 GSS-API context establishment with the server as {@code subject}, then
     * reads the final login response.
     *
     * <p>Client tokens are framed as {@code [int length][byte 2][byte 5][token]}; server frames
     * must be at most 4096 bytes with version 2 and tag 5.
     *
     * @return the final login response, positioned just after its version byte (which must be 0)
     * @throws IOException on any handshake or protocol failure; the channel is closed in that case
     */
    private static final ByteBuffer performAuthenticationHandShake(final SocketChannel channel, Subject subject, final String serviceName) throws IOException {
        try {
            Subject.doAs(subject, new PrivilegedAction<GSSContext>(){

                @Override
                public GSSContext run() {
                    GSSContext context = null;
                    try {
                        Oid krb5Oid = new Oid("1.2.840.113554.1.2.2");
                        Oid krb5PrincipalNameType = new Oid("1.2.840.113554.1.2.2.1");
                        GSSName serverName = m_gssManager.createName(serviceName, krb5PrincipalNameType);
                        ByteBuffer bb = ByteBuffer.allocate(4096);
                        context = m_gssManager.createContext(serverName, krb5Oid, null, 0);
                        context.requestMutualAuth(true);
                        context.requestConf(true);
                        context.requestInteg(true);
                        int msgSize = 0;
                        // The first initSecContext call gets an empty input token.
                        bb.limit(msgSize);
                        while (!context.isEstablished()) {
                            byte[] token = context.initSecContext(bb.array(), bb.arrayOffset() + bb.position(), bb.remaining());
                            if (token != null) {
                                msgSize = 6 + token.length;
                                bb.clear().limit(msgSize);
                                bb.putInt(msgSize - 4).put((byte)2).put((byte)5);
                                bb.put(token).flip();
                                while (bb.hasRemaining()) {
                                    channel.write(bb);
                                }
                            }
                            if (context.isEstablished()) continue;
                            bb.clear().limit(4);
                            while (bb.hasRemaining()) {
                                if (channel.read(bb) != -1) continue;
                                throw new EOFException();
                            }
                            bb.flip();
                            msgSize = bb.getInt();
                            if (msgSize > bb.capacity()) {
                                throw new IOException("Authentication packet exceeded alloted size");
                            }
                            if (msgSize <= 0) {
                                throw new IOException("Wire Protocol Format error 0 or negative message length prefix");
                            }
                            bb.clear().limit(msgSize);
                            while (bb.hasRemaining()) {
                                if (channel.read(bb) != -1) continue;
                                throw new EOFException();
                            }
                            bb.flip();
                            byte version = bb.get();
                            if (version != 2) {
                                throw new IOException("Encountered unexpected authentication protocol version " + version);
                            }
                            byte tag = bb.get();
                            // On tag 5, bb is now positioned at the server token for the next round.
                            if (tag == 5) continue;
                            throw new IOException("Encountered unexpected authentication protocol tag " + tag);
                        }
                        if (!context.getMutualAuthState()) {
                            throw new IOException("Authentication Handshake Failed");
                        }
                        context.dispose();
                        context = null;
                    }
                    catch (GSSException ex) {
                        Throwables.propagate(ex);
                    }
                    catch (IOException ex) {
                        Throwables.propagate(ex);
                    }
                    finally {
                        if (context != null) {
                            try {
                                context.dispose();
                            }
                            catch (Exception ignoreIt) {
                                // Best-effort cleanup; any primary failure is already propagating.
                            }
                        }
                    }
                    return null;
                }
            });
        }
        catch (RuntimeException ex) {
            // Subject.doAs(PrivilegedAction) lets unchecked exceptions through unchanged. That covers
            // a SecurityException from doAs itself and the RuntimeException wrappers that
            // Throwables.propagate puts around the handshake's IOException/GSSException. Catching
            // only SecurityException let those wrappers escape untranslated.
            try {
                channel.close();
            }
            catch (Exception ignoreIt) {
                // Already failing; a close error would only mask the real cause.
            }
            Throwable cause = ex.getCause();
            if (cause instanceof RuntimeException && cause.getCause() != null) {
                cause = cause.getCause();
            } else if (cause == null) {
                cause = ex;
            }
            if (cause instanceof IOException) {
                throw (IOException)cause;
            }
            throw new IOException("Authentication Handshake Failed", cause);
        }
        ByteBuffer lengthBuffer = ByteBuffer.allocate(4);
        while (lengthBuffer.hasRemaining()) {
            if (channel.read(lengthBuffer) != -1) continue;
            channel.close();
            throw new EOFException();
        }
        lengthBuffer.flip();
        int responseSize = lengthBuffer.getInt();
        ByteBuffer loginResponse = ByteBuffer.allocate(responseSize);
        while (loginResponse.hasRemaining()) {
            if (channel.read(loginResponse) != -1) continue;
            channel.close();
            throw new EOFException();
        }
        loginResponse.flip();
        byte version = loginResponse.get();
        if (version != 0) {
            channel.close();
            throw new IOException("Encountered unexpected version for the login response message: " + version);
        }
        return loginResponse;
    }

    /**
     * Shuts down the executors associated with {@code connection} (if any) and closes it.
     * The channel is closed even if waiting for executor termination is interrupted.
     *
     * @throws InterruptedException if interrupted while awaiting executor termination
     * @throws IOException if closing the channel fails
     */
    public static void closeConnection(SocketChannel connection) throws InterruptedException, IOException {
        ExecutorPair p;
        synchronized (m_executors) {
            p = m_executors.remove(connection);
        }
        try {
            // A channel that was never used with sendInvocation/readResponse has no executors.
            // The (potentially long) termination wait runs outside the global map lock.
            if (p != null) {
                p.shutdown();
            }
        }
        finally {
            connection.close();
        }
    }

    /** Returns the executor pair for {@code channel}, creating and registering it on first use. */
    private static ExecutorPair getExecutorPair(SocketChannel channel) {
        synchronized (m_executors) {
            ExecutorPair p = m_executors.get(channel);
            if (p == null) {
                p = new ExecutorPair();
                m_executors.put(channel, p);
            }
            return p;
        }
    }

    /**
     * Asynchronously sends an invocation of {@code procName} on the channel's write executor.
     *
     * @return a future yielding the client handle assigned to the invocation
     */
    public static Future<Long> sendInvocation(SocketChannel channel, String procName, Object ... parameters) {
        ExecutorPair p = ConnectionUtil.getExecutorPair(channel);
        return ConnectionUtil.sendInvocation(p.m_writeExecutor, channel, procName, parameters);
    }

    /**
     * Serializes an invocation with a fresh handle, prefixed by its int length, and writes it
     * fully to {@code channel} on {@code executor}, yielding between partial writes.
     *
     * @return a future yielding the client handle assigned to the invocation
     */
    public static Future<Long> sendInvocation(ExecutorService executor, final SocketChannel channel, final String procName, final Object ... parameters) {
        return executor.submit(new Callable<Long>(){

            @Override
            public Long call() throws Exception {
                long handle = m_handle.getAndIncrement();
                ProcedureInvocation invocation = new ProcedureInvocation(handle, procName, parameters);
                ByteBuffer buf = ByteBuffer.allocate(4 + invocation.getSerializedSize());
                buf.position(4);
                invocation.flattenToBuffer(buf);
                buf.putInt(0, buf.capacity() - 4);
                buf.flip();
                channel.write(buf);
                while (buf.hasRemaining()) {
                    Thread.yield();
                    channel.write(buf);
                }
                return handle;
            }
        });
    }

    /** Asynchronously reads one response from {@code channel} on its read executor. */
    public static Future<ClientResponse> readResponse(SocketChannel channel) {
        ExecutorPair p = ConnectionUtil.getExecutorPair(channel);
        return ConnectionUtil.readResponse(p.m_readExecutor, channel);
    }

    /**
     * Reads one length-prefixed message from {@code channel} on {@code executor} and decodes it
     * as a {@link ClientResponseImpl}. Yields between partial reads, since the channel may be in
     * non-blocking mode; the future fails with {@link EOFException} if the peer closes first.
     */
    public static Future<ClientResponse> readResponse(ExecutorService executor, final SocketChannel channel) {
        return executor.submit(new Callable<ClientResponse>(){

            @Override
            public ClientResponse call() throws Exception {
                ByteBuffer lengthBuffer = ByteBuffer.allocate(4);
                do {
                    if (channel.read(lengthBuffer) == -1) {
                        throw new EOFException();
                    }
                    if (lengthBuffer.hasRemaining()) {
                        Thread.yield();
                    }
                } while (lengthBuffer.hasRemaining());
                lengthBuffer.flip();
                ByteBuffer message = ByteBuffer.allocate(lengthBuffer.getInt());
                do {
                    if (channel.read(message) == -1) {
                        throw new EOFException();
                    }
                    // Must test the message buffer: the drained length buffer never has remaining
                    // bytes, which previously made this loop spin without ever yielding.
                    if (message.hasRemaining()) {
                        Thread.yield();
                    }
                } while (message.hasRemaining());
                message.flip();
                ClientResponseImpl response = new ClientResponseImpl();
                response.initFromBuffer(message);
                return response;
            }
        });
    }

    /** Decompiler-generated accessor for {@link #m_tf}, used by {@link ExecutorPair}. */
    static /* synthetic */ TF access$100() {
        return m_tf;
    }

    /** A dedicated single-threaded write executor and read executor for one channel. */
    public static class ExecutorPair {
        public final ExecutorService m_writeExecutor = Executors.newSingleThreadExecutor(ConnectionUtil.access$100());
        public final ExecutorService m_readExecutor = Executors.newSingleThreadExecutor(ConnectionUtil.access$100());

        /** Interrupts both executors and waits (up to one day each) for them to terminate. */
        private void shutdown() throws InterruptedException {
            this.m_readExecutor.shutdownNow();
            this.m_writeExecutor.shutdownNow();
            this.m_readExecutor.awaitTermination(1L, TimeUnit.DAYS);
            this.m_writeExecutor.awaitTermination(1L, TimeUnit.DAYS);
        }
    }

    /** Creates executor threads with a 64 KiB requested stack size. */
    private static class TF
    implements ThreadFactory {
        private TF() {
        }

        @Override
        public Thread newThread(Runnable r) {
            return new Thread(null, r, "Yet another thread", 65536L);
        }
    }
}

