package org.hepeng.commons.spring.security.filter;

import org.apache.commons.lang3.StringUtils;
import org.hepeng.commons.serializer.ObjectSerializationUtils;
import org.hepeng.commons.serializer.ObjectSerializer;
import org.hepeng.commons.serializer.SupportSerializer;
import org.hepeng.commons.spring.security.AuthenticationWrapper;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.web.filter.OncePerRequestFilter;

import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.util.Objects;

/**
 * Servlet filter for upstream services that restores the {@link SecurityContext} forwarded
 * in a request header.
 *
 * <p>When a request carries a non-blank {@value #SECURITY_CONTEXT_HEADER} header, the header
 * value is decoded with the configured {@link ObjectSerializer} (HESSIAN by default, via
 * {@code deserializeBase64String}). The result is installed into {@link SecurityContextHolder}
 * for the duration of the filter chain, and the holder is cleared again once the chain returns.
 * When the header is absent, an {@link AuthenticationWrapper} found in the holder is discarded
 * before the chain runs.
 *
 * <p>NOTE(review): the header value is trusted as-is. Any caller able to reach this service
 * directly can forge it and impersonate an arbitrary principal. Deserializing request-supplied
 * bytes (Hessian) is also a known deserialization-attack vector. Deploy only behind a gateway
 * that strips or overwrites this header, or sign and verify the payload.
 *
 * @author he peng
 */

public class UpstreamServiceSecurityContextFilter extends OncePerRequestFilter {

    /** Name of the request header carrying the Base64-encoded, serialized security context. */
    public static final String SECURITY_CONTEXT_HEADER = "SecurityContext";

    private final ObjectSerializer<SecurityContext> objectSerializer;

    /**
     * Creates a filter that decodes the forwarded context with a HESSIAN serializer.
     */
    public UpstreamServiceSecurityContextFilter() {
        this.objectSerializer = ObjectSerializationUtils.newObjectSerializer(SupportSerializer.HESSIAN , SecurityContext.class);
    }

    /**
     * Creates a filter that decodes the forwarded context with the given serializer.
     *
     * @param objectSerializer serializer used to turn the header value back into a
     *                         {@link SecurityContext}; must not be {@code null}
     * @throws NullPointerException if {@code objectSerializer} is {@code null}
     */
    // The raw parameter type is kept so existing callers keep compiling; the serializer is
    // expected to produce SecurityContext instances.
    @SuppressWarnings("unchecked")
    public UpstreamServiceSecurityContextFilter(ObjectSerializer objectSerializer) {
        this.objectSerializer = Objects.requireNonNull(objectSerializer, "objectSerializer");
    }

    /**
     * Installs the header-forwarded security context, if present, around the rest of the chain.
     *
     * @param request     current request; its {@value #SECURITY_CONTEXT_HEADER} header is read
     * @param response    current response, passed through unchanged
     * @param filterChain remaining filters
     * @throws ServletException propagated from the filter chain
     * @throws IOException      propagated from the filter chain
     */
    @Override
    protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException {
        String securityContextHeader = request.getHeader(SECURITY_CONTEXT_HEADER);
        boolean requestFromGateway = StringUtils.isNotBlank(securityContextHeader);
        try {
            if (requestFromGateway) {
                SecurityContext securityContext = objectSerializer.deserializeBase64String(securityContextHeader);
                SecurityContextHolder.setContext(securityContext);
            } else {
                clearWrappedAuthentication();
            }

            filterChain.doFilter(request , response);
        } finally {
            // Clear only when this filter installed the context from the header; otherwise the
            // holder is left as other filters set it.
            if (requestFromGateway) {
                SecurityContextHolder.clearContext();
            }
        }
    }

    /**
     * Clears {@link SecurityContextHolder} when its current authentication is an
     * {@link AuthenticationWrapper}, so a request without the header does not run with one.
     */
    private static void clearWrappedAuthentication() {
        SecurityContext securityContext = SecurityContextHolder.getContext();
        // Check the context before dereferencing it.
        if (Objects.isNull(securityContext)) {
            return;
        }
        // instanceof is false for null, so no separate null check on the authentication is needed.
        Authentication authentication = securityContext.getAuthentication();
        if (authentication instanceof AuthenticationWrapper) {
            SecurityContextHolder.clearContext();
        }
    }
}
