package org.hepeng.commons.spring.security.web.filter;

import org.apache.commons.codec.binary.Base64;
import org.apache.commons.lang3.StringUtils;
import org.springframework.data.redis.serializer.JdkSerializationRedisSerializer;
import org.springframework.data.redis.serializer.RedisSerializer;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContext;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import java.security.Principal;
import java.util.Objects;

/**
 * Filter for requests that carry a security context supplied by an upstream service.
 *
 * <p>A request counts as "upstream" when it has a non-blank {@code UpstreamSecurityContext}
 * header. For such requests {@link #isSkipOver(HttpServletRequest)} returns {@code true}.
 * {@link #wrapRequest(HttpServletRequest)} then wraps the request so that
 * {@link HttpServletRequest#getUserPrincipal()} returns the {@link Authentication} found in the
 * header. The header holds a Base64-encoded, JDK-serialized {@link SecurityContext}.
 *
 * <p><b>SECURITY NOTE(review):</b> the header value is read directly from the HTTP request and
 * passed to Java native deserialization. Any client that can reach this service directly can:
 * <ul>
 *   <li>set the header to skip security processing (presumably, judging by the superclass
 *       {@code SkipOverSpringSecurityFilterChainFilter}; confirm there);</li>
 *   <li>submit a crafted serialized payload. This is a well-known remote-code-execution vector
 *       (deserialization gadget chains).</li>
 * </ul>
 * This is only safe if a trusted gateway strips this header from all external traffic. TODO:
 * confirm the deployment guarantees that, and consider signing the payload or switching to a
 * data-only format with a class allow-list.
 *
 * @author he peng
 */
public class UpstreamRequestSkipOverSpringSecurityChainFilter extends SkipOverSpringSecurityFilterChainFilter {

    /** Name of the request header carrying the Base64-encoded, JDK-serialized {@link SecurityContext}. */
    private static final String SECURITY_CONTEXT_HEADER_NAME = "UpstreamSecurityContext";

    /** Turns the Base64-decoded header bytes back into an object. */
    private final RedisSerializer<Object> redisSerializer = new JdkSerializationRedisSerializer();

    /** Creates the filter, decoding the upstream security context with JDK serialization. */
    public UpstreamRequestSkipOverSpringSecurityChainFilter() {
    }

    /**
     * Decides whether the current request is treated as an upstream request.
     *
     * @param request the incoming request
     * @return {@code true} when the request carries a non-blank upstream security-context header
     */
    @Override
    protected boolean isSkipOver(HttpServletRequest request) {
        return StringUtils.isNotBlank(request.getHeader(SECURITY_CONTEXT_HEADER_NAME));
    }

    /**
     * Wraps an upstream request so its user principal is resolved from the header payload.
     *
     * @param request the request to wrap
     * @return a wrapper whose {@link HttpServletRequest#getUserPrincipal()} reads the upstream
     *     security context
     */
    @Override
    protected HttpServletRequest wrapRequest(HttpServletRequest request) {
        return new UpstreamSecurityContextAwareRequestWrapper(request, redisSerializer);
    }

    /**
     * Request wrapper that exposes the {@link Authentication} from the upstream security context
     * as the request's user principal.
     *
     * <p>This class is static, and receives the serializer explicitly, so that it holds no hidden
     * reference to the enclosing filter.
     */
    private static final class UpstreamSecurityContextAwareRequestWrapper extends HttpServletRequestWrapper {

        /** Decodes the header payload. */
        private final RedisSerializer<Object> serializer;

        UpstreamSecurityContextAwareRequestWrapper(HttpServletRequest request,
                                                   RedisSerializer<Object> serializer) {
            super(request);
            this.serializer = serializer;
        }

        /**
         * Returns the principal carried in the upstream security-context header.
         *
         * <p>Falls back to the wrapped request's principal when the header is blank. The header is
         * decoded again on every call. Any exception the serializer raises while decoding a
         * malformed payload propagates to the caller.
         *
         * @return the upstream {@link Authentication}, or {@code null} in any of these cases:
         *     <ul>
         *       <li>the payload decodes to {@code null};</li>
         *       <li>the payload decodes to something that is not a {@link SecurityContext};</li>
         *       <li>the context holds no authentication.</li>
         *     </ul>
         */
        @Override
        public Principal getUserPrincipal() {
            String securityContextVal = super.getHeader(SECURITY_CONTEXT_HEADER_NAME);
            if (StringUtils.isBlank(securityContextVal)) {
                return super.getUserPrincipal();
            }

            // SECURITY NOTE(review): client-controlled bytes reach Java native deserialization
            // here; see the class-level note.
            byte[] bytes = Base64.decodeBase64(securityContextVal);
            Object decoded = serializer.deserialize(bytes);

            // A payload of any other type (or null) is treated like an absent context instead of
            // failing the request with a ClassCastException.
            if (!(decoded instanceof SecurityContext)) {
                return null;
            }
            return ((SecurityContext) decoded).getAuthentication();
        }
    }
}
