package edu.uiuc.ncsa.myproxy.oa4mp.oauth2.servlet;

import edu.uiuc.ncsa.myproxy.oa4mp.oauth2.OA2SE;
import edu.uiuc.ncsa.myproxy.oa4mp.server.servlet.AbstractRegistrationServlet;
import edu.uiuc.ncsa.security.delegation.storage.Client;
import edu.uiuc.ncsa.security.oauth_2_0.OA2Client;
import edu.uiuc.ncsa.security.servlet.PresentableState;
import org.apache.commons.codec.binary.Base64;
import org.apache.commons.codec.digest.DigestUtils;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.StringReader;
import java.net.URI;
import java.security.SecureRandom;
import java.util.LinkedList;

/**
 * Client registration servlet for the OAuth 2 server.
 * <p>
 * On top of what {@link AbstractRegistrationServlet} does, this servlet:
 * <ul>
 *     <li>reads the newline-separated callback URIs and the requested refresh token lifetime
 *     from the registration form,</li>
 *     <li>generates a random client secret that is shown once on the registration page, and</li>
 *     <li>after that page has been presented, replaces the secret with a hex digest of it
 *     ({@code DigestUtils.shaHex}) and saves the client, so the plain secret is not stored.</li>
 * </ul>
 * <p>Created by Jeff Gaynor<br>
 * on 3/20/14 at  4:48 PM
 */
public class OA2RegistrationServlet extends AbstractRegistrationServlet {

    /**
     * Source of randomness for generated client secrets.
     */
    protected static SecureRandom random = new SecureRandom();
    /**
     * Name of the request parameter (and request attribute) carrying the callback URIs, one per line.
     */
    public static final String CALLBACK_URI = "callbackURI";
    /**
     * Name of the request parameter (and request attribute) carrying the requested refresh token lifetime.
     */
    public static final String REFRESH_TOKEN_LIFETIME = "rtLifetime";

    /**
     * @return the service environment, cast to {@link OA2SE}.
     */
    protected OA2SE getOA2SE() {
        return (OA2SE) getServiceEnvironment();
    }

    /**
     * In the initial state, publishes the names of the OAuth 2 specific form fields as request
     * attributes (attribute name and value are both the parameter name).
     * NOTE(review): presumably so the registration page can name its inputs consistently — confirm against the JSP.
     *
     * @param state the current presentation state
     * @throws Throwable whatever the parent implementation throws
     */
    @Override
    public void prepare(PresentableState state) throws Throwable {
        super.prepare(state);
        HttpServletRequest request = state.getRequest();

        if (state.getState() == INITIAL_STATE) {
            request.setAttribute(CALLBACK_URI, CALLBACK_URI);
            request.setAttribute(REFRESH_TOKEN_LIFETIME, REFRESH_TOKEN_LIFETIME);
        }
    }

    /**
     * Creates the client via the parent, then fills in the OAuth 2 specific fields: refresh token
     * lifetime, a freshly generated (plain text) secret and the callback URIs. Finally fires the
     * new-client event.
     *
     * @param request  the registration request
     * @param response the servlet response
     * @return the new {@link OA2Client}
     * @throws Throwable if the parent fails or a required parameter is missing
     */
    @Override
    protected Client addNewClient(HttpServletRequest request, HttpServletResponse response) throws Throwable {
        OA2Client client = (OA2Client) super.addNewClient(request, response);
        String rawCBs = getRequiredParam(request, CALLBACK_URI);
        String rawRTLifetime = getRequiredParam(request, REFRESH_TOKEN_LIFETIME);
        client.setRtLifetime(parseRefreshTokenLifetime(rawRTLifetime));
        // The plain secret must be shown on the registration ok page, so it is set here and only
        // replaced by its hash in present(), after the page has been rendered. That way the
        // only copy of the plain secret ends up with the client.
        client.setSecret(generateSecret());
        client.setCallbackURIs(parseCallbackURIs(rawCBs, client));
        fireNewClientEvent(client);
        return client;
    }

    /**
     * Parses the requested refresh token lifetime.
     *
     * @param rawRTLifetime the raw parameter value; may be null or blank
     * @return the parsed lifetime, or 0 if the value is missing, blank, not a number or negative.
     * Illegal (non-blank) values are logged.
     */
    private long parseRefreshTokenLifetime(String rawRTLifetime) {
        if (rawRTLifetime == null || rawRTLifetime.trim().isEmpty()) {
            return 0L;
        }
        try {
            long lifetime = Long.parseLong(rawRTLifetime.trim());
            if (0L <= lifetime) {
                return lifetime;
            }
        } catch (NumberFormatException ignored) {
            // Not a number: falls through to the same handling as a negative value.
        }
        // Never store a negative or unparseable lifetime; use the same 0 as for "nothing requested".
        info("Client requested illegal value for refresh token lifetime at registration of \"" + rawRTLifetime + "\"");
        return 0L;
    }

    /**
     * Generates a new client secret from {@code getClientSecretLength()} random bytes,
     * encoded with the URL-safe base64 alphabet.
     *
     * @return the encoded secret
     */
    private String generateSecret() {
        byte[] bytes = new byte[getOA2SE().getClientSecretLength()];
        random.nextBytes(bytes);
        return Base64.encodeBase64URLSafeString(bytes);
    }

    /**
     * Splits the raw callback parameter into lines and keeps every non-blank, trimmed line that
     * {@link URI#create(String)} accepts. Rejected lines are logged and skipped.
     *
     * @param rawCBs the raw callback parameter, one URI per line
     * @param client the client being registered (used only in log messages)
     * @return the accepted callback URIs, in input order
     * @throws IOException never in practice, since the input is an in-memory string
     */
    private LinkedList<String> parseCallbackURIs(String rawCBs, OA2Client client) throws IOException {
        LinkedList<String> uris = new LinkedList<>();
        try (BufferedReader br = new BufferedReader(new StringReader(rawCBs))) {
            for (String line = br.readLine(); line != null; line = br.readLine()) {
                String candidate = line.trim();
                if (candidate.isEmpty()) {
                    // Blank lines (e.g. a trailing newline in the form) are not callbacks.
                    continue;
                }
                try {
                    URI.create(candidate);
                    uris.add(candidate);
                } catch (IllegalArgumentException e) {
                    warn("Attempt to add bad callback uri \"" + candidate + "\" for client " + client.getIdentifierString());
                }
            }
        }
        return uris;
    }

    /**
     * Presents the page via the parent. In the request state, the client's plain text secret has
     * just been displayed; it is then replaced by its hex digest and the client is saved, so the
     * store never holds the plain secret.
     *
     * @param state the current presentation state; must be a {@code ClientState} in the request state
     * @throws IllegalStateException if the state is the request state but not a {@code ClientState}
     * @throws Throwable             whatever the parent or the client store throws
     */
    @Override
    public void present(PresentableState state) throws Throwable {
        super.present(state);
        if (state.getState() != REQUEST_STATE) {
            return;
        }
        if (!(state instanceof ClientState)) {
            throw new IllegalStateException("Error: An instance of ClientState was expected, but got an instance of \"" + state.getClass().getName() + "\"");
        }
        ClientState cState = (ClientState) state;
        String hashedSecret = DigestUtils.shaHex(cState.getClient().getSecret());
        cState.getClient().setSecret(hashedSecret);
        getServiceEnvironment().getClientStore().save(cState.getClient());
    }
}
