package pl.edu.icm.unity.rest.authn;

import com.google.common.collect.ImmutableSet;
import java.security.cert.X509Certificate;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.Set;
import org.apache.cxf.interceptor.Fault;
import org.apache.cxf.message.Message;
import org.apache.cxf.phase.AbstractPhaseInterceptor;
import org.apache.log4j.MDC;
import org.apache.logging.log4j.Logger;
import pl.edu.icm.unity.MessageSource;
import pl.edu.icm.unity.base.utils.Log;
import pl.edu.icm.unity.engine.api.EntityManagement;
import pl.edu.icm.unity.engine.api.authn.AuthenticatedEntity;
import pl.edu.icm.unity.engine.api.authn.AuthenticationException;
import pl.edu.icm.unity.engine.api.authn.AuthenticationFlow;
import pl.edu.icm.unity.engine.api.authn.AuthenticationProcessor;
import pl.edu.icm.unity.engine.api.authn.AuthenticationResult;
import pl.edu.icm.unity.engine.api.authn.AuthenticatorInstance;
import pl.edu.icm.unity.engine.api.authn.DefaultUnsuccessfulAuthenticationCounter;
import pl.edu.icm.unity.engine.api.authn.InvocationContext;
import pl.edu.icm.unity.engine.api.authn.LoginSession;
import pl.edu.icm.unity.engine.api.authn.PartialAuthnState;
import pl.edu.icm.unity.engine.api.authn.UnsuccessfulAuthenticationCounter;
import pl.edu.icm.unity.engine.api.server.HTTPRequestContext;
import pl.edu.icm.unity.engine.api.session.SessionManagement;
import pl.edu.icm.unity.engine.api.utils.MDCKeys;
import pl.edu.icm.unity.exceptions.AuthorizationException;
import pl.edu.icm.unity.exceptions.EngineException;
import pl.edu.icm.unity.rest.authn.ext.TLSRetrieval;
import pl.edu.icm.unity.types.authn.AuthenticationOptionKey;
import pl.edu.icm.unity.types.authn.AuthenticationRealm;
import pl.edu.icm.unity.types.basic.EntityParam;
import pl.edu.icm.unity.types.basic.IdentityTaV;

/**
 * Apache CXF interceptor, registered for the {@code pre-invoke} phase, which authenticates the client of
 * every incoming message.
 * <p>
 * Processing order:
 * <ol>
 * <li>messages from a client IP that is currently blocked by the unsuccessful-authentication counter are
 * rejected;</li>
 * <li>an {@link InvocationContext} is created (carrying the subject of the first TLS client certificate,
 * if any) and installed as the current one;</li>
 * <li>messages addressed to one of the not protected paths proceed without authentication;</li>
 * <li>the configured {@link AuthenticationFlow}s are tried in order; the first one that succeeds wins and
 * a login session is bound to the invocation context;</li>
 * <li>if every flow fails, the message is rejected with a {@link Fault} and the failure is counted against
 * the client IP, unless the message targets one of the optional-authentication paths.</li>
 * </ol>
 * Results of individual authenticators are cached per message, so an authenticator shared by several
 * flows is invoked at most once per message.
 */
public class AuthenticationInterceptor extends AbstractPhaseInterceptor<Message> {
    private static final Logger log = Log.getLogger("unity.server.rest", AuthenticationInterceptor.class);

    private final MessageSource msg;
    private final AuthenticationProcessor authenticationProcessor;
    /** Authentication flows, tried in list order. */
    protected List<AuthenticationFlow> authenticators;
    /** Tracks failed attempts per client IP and decides whether an IP is temporarily blocked. */
    protected UnsuccessfulAuthenticationCounter UnsuccessfulAuthenticationCounterImpl;
    protected SessionManagement sessionMan;
    protected AuthenticationRealm realm;
    /** Request URIs (exact match) for which authentication is skipped entirely. */
    protected final Set<String> notProtectedPaths;
    /** Request URIs (exact match) which may be invoked anonymously when authentication fails. */
    protected final Set<String> optionalAuthnPaths;
    private final Properties endpointProperties;
    private final EntityManagement entityMan;

    /**
     * A successfully authenticated entity together with the authentication option keys of the factors
     * that were used. {@code secondFactor} is {@code null} when only a single factor was required.
     */
    public static class EntityWithAuthenticators {
        private final AuthenticatedEntity entity;
        private final AuthenticationOptionKey firstFactor;
        private final AuthenticationOptionKey secondFactor;

        EntityWithAuthenticators(AuthenticatedEntity entity, AuthenticationOptionKey firstFactor,
                AuthenticationOptionKey secondFactor) {
            this.entity = entity;
            this.firstFactor = firstFactor;
            this.secondFactor = secondFactor;
        }
    }

    /**
     * @param msg message source used to resolve the message of the first authentication error
     * @param authenticationProcessor processes primary/secondary authentication results
     * @param authenticators authentication flows, tried in the given order
     * @param realm authentication realm; also provides the blocking policy for failed attempts
     * @param sessionMan used to obtain a login session for authenticated entities
     * @param notProtectedPaths request URIs for which authentication is skipped (copied)
     * @param optionalAuthnPaths request URIs which may proceed unauthenticated when authentication fails (copied)
     * @param endpointProperties endpoint configuration passed to every {@link CXFAuthentication}
     * @param entityMan used to resolve the label of the authenticated entity
     */
    public AuthenticationInterceptor(MessageSource msg, AuthenticationProcessor authenticationProcessor,
            List<AuthenticationFlow> authenticators, AuthenticationRealm realm, SessionManagement sessionMan,
            Set<String> notProtectedPaths, Set<String> optionalAuthnPaths, Properties endpointProperties,
            EntityManagement entityMan) {
        super("pre-invoke");
        this.msg = msg;
        this.authenticationProcessor = authenticationProcessor;
        this.realm = realm;
        this.authenticators = authenticators;
        this.endpointProperties = endpointProperties;
        this.entityMan = entityMan;
        // NOTE(review): getBlockFor() is presumably in seconds and the counter expects milliseconds — confirm.
        this.UnsuccessfulAuthenticationCounterImpl = new DefaultUnsuccessfulAuthenticationCounter(
                realm.getBlockAfterUnsuccessfulLogins(), realm.getBlockFor() * 1000);
        this.sessionMan = sessionMan;
        this.notProtectedPaths = ImmutableSet.copyOf(notProtectedPaths);
        this.optionalAuthnPaths = ImmutableSet.copyOf(optionalAuthnPaths);
    }

    /**
     * Authenticates the client of the given message, as described in the class documentation.
     *
     * @throws Fault when the client IP is blocked, or when authentication fails for a path which is
     *         neither not-protected nor optionally authenticated
     */
    @Override
    public void handleMessage(Message message) throws Fault {
        String clientIP = getClientIP();
        if (UnsuccessfulAuthenticationCounterImpl.getRemainingBlockedTime(clientIP) > 0) {
            log.info("Authentication blocked for client with IP " + clientIP);
            throw new Fault(new Exception("Too many invalid authentication attempts, try again later"));
        }

        // The context is installed even for not protected paths, so downstream code always finds one.
        X509Certificate[] tlsCertificates = TLSRetrieval.getTLSCertificates();
        IdentityTaV tlsIdentity = (tlsCertificates == null || tlsCertificates.length == 0)
                ? null
                : new IdentityTaV("x500Name", tlsCertificates[0].getSubjectX500Principal().getName());
        InvocationContext invocationContext = new InvocationContext(tlsIdentity, realm, authenticators);
        InvocationContext.setCurrent(invocationContext);

        if (isToNotProtected(message)) {
            return;
        }

        // Per-message cache of authenticator results, keyed by authenticator id.
        Map<String, AuthenticationResult> resultsCache = new HashMap<>();
        EntityWithAuthenticators authenticated = null;
        AuthenticationException firstError = null;
        for (AuthenticationFlow flow : authenticators) {
            try {
                authenticated = processAuthnFlow(resultsCache, flow);
                break;
            } catch (AuthenticationException e) {
                log.debug("Authentication set failed to authenticate the client using flow {}, will try another: {}",
                        flow.getId(), e);
                // Only the first failure is reported to the client; its message is treated as a message key.
                if (firstError == null) {
                    firstError = new AuthenticationException(msg.getMessage(e.getMessage(), new Object[0]));
                }
            }
        }

        if (authenticated != null) {
            authnSuccess(authenticated, clientIP, invocationContext);
            return;
        }
        if (isToOptionallyAuthenticatedPath(message)) {
            log.debug("Request to an address with optional authentication - {} - invocation will proceed without authentication",
                    message.get(Message.REQUEST_URI));
            return;
        }
        log.info("Authentication failed for client with IP {}", clientIP);
        UnsuccessfulAuthenticationCounterImpl.unsuccessfulAttempt(clientIP);
        throw new Fault(firstError == null ? new Exception("Authentication failed") : firstError);
    }

    /** @return whether the message targets one of {@link #optionalAuthnPaths} */
    private boolean isToOptionallyAuthenticatedPath(Message message) {
        return isToSpecialPath(message, optionalAuthnPaths);
    }

    /** @return whether the message targets one of {@link #notProtectedPaths}; logs at debug when it does */
    private boolean isToNotProtected(Message message) {
        boolean notProtected = isToSpecialPath(message, notProtectedPaths);
        if (notProtected) {
            log.debug("Request to a not protected address - {} - invocation will proceed without authentication",
                    message.get(Message.REQUEST_URI));
        }
        return notProtected;
    }

    /**
     * @return {@code true} when the message's request URI is exactly one of {@code paths};
     *         {@code false} (with an error logged) when the request URI is not available
     */
    private boolean isToSpecialPath(Message message, Set<String> paths) {
        String requestUri = (String) message.get(Message.REQUEST_URI);
        if (requestUri == null) {
            log.error("Can not establish the destination address");
            return false;
        }
        return paths.contains(requestUri);
    }

    /**
     * Records the success for the client IP, obtains a login session for the entity, binds it to the
     * invocation context and publishes the entity id and label in the logging MDC.
     */
    private void authnSuccess(EntityWithAuthenticators authenticated, String clientIP,
            InvocationContext invocationContext) {
        AuthenticatedEntity entity = authenticated.entity;
        log.debug("Client was successfully authenticated: [{}] {}", entity.getEntityId(),
                entity.getAuthenticatedWith());
        UnsuccessfulAuthenticationCounterImpl.successfulAttempt(clientIP);

        long entityId = entity.getEntityId().longValue();
        LoginSession session = sessionMan.getCreateSession(entityId, realm, getLabel(entityId),
                entity.getOutdatedCredentialId(), new LoginSession.RememberMeInfo(false, false),
                authenticated.firstFactor, authenticated.secondFactor);
        invocationContext.setLoginSession(session);
        session.addAuthenticatedIdentities(entity.getAuthenticatedWith());
        session.setRemoteIdP(entity.getRemoteIdP());
        MDC.put(MDCKeys.ENTITY_ID.key, Long.valueOf(session.getEntityId()));
        MDC.put(MDCKeys.USER.key, session.getEntityLabel());
    }

    /**
     * @return the entity's label, or {@code null} when it cannot be read (authorization denial is
     *         logged at debug, any other engine error at error)
     */
    private String getLabel(long entityId) {
        try {
            return entityMan.getEntityLabel(new EntityParam(Long.valueOf(entityId)));
        } catch (AuthorizationException e) {
            // Caught before EngineException so the more specific handler is reachable if it is a subtype.
            log.debug("Not setting entity's label as the client is not authorized to read the attribute", e);
            return null;
        } catch (EngineException e) {
            log.error("Can not get the attribute designated with EntityName", e);
            return null;
        }
    }

    /**
     * Runs a single authentication flow: first-factor authenticators are tried in order until one
     * succeeds, then the second factor is evaluated when the processor requires it.
     *
     * @param resultsCache per-message cache of authenticator results
     * @throws AuthenticationException the first first-factor failure (or a generic one when the flow has
     *         no first-factor authenticators), or a failure of the second factor
     */
    private EntityWithAuthenticators processAuthnFlow(Map<String, AuthenticationResult> resultsCache,
            AuthenticationFlow flow) throws AuthenticationException {
        PartialAuthnState state = null;
        AuthenticationException firstError = null;
        for (AuthenticatorInstance authenticator : flow.getFirstFactorAuthenticators()) {
            CXFAuthentication retrieval = (CXFAuthentication) authenticator.getRetrieval();
            try {
                AuthenticationResult result = processAuthenticator(resultsCache, retrieval);
                state = authenticationProcessor.processPrimaryAuthnResult(result, flow,
                        AuthenticationOptionKey.authenticatorOnlyKey(retrieval.getAuthenticatorId()));
                break;
            } catch (AuthenticationException e) {
                // Keep the original exception (type and stack trace), not just its message.
                if (firstError == null) {
                    firstError = e;
                }
            }
        }

        if (state == null) {
            throw firstError != null ? firstError : new AuthenticationException("Authentication failed");
        }

        if (!state.isSecondaryAuthenticationRequired()) {
            return new EntityWithAuthenticators(
                    authenticationProcessor.finalizeAfterPrimaryAuthentication(state, false),
                    state.getFirstFactorOptionId(), null);
        }

        CXFAuthentication secondFactor = (CXFAuthentication) state.getSecondaryAuthenticator();
        AuthenticatedEntity entity = authenticationProcessor.finalizeAfterSecondaryAuthentication(state,
                processAuthenticator(resultsCache, secondFactor));
        return new EntityWithAuthenticators(entity, state.getFirstFactorOptionId(),
                AuthenticationOptionKey.authenticatorOnlyKey(secondFactor.getAuthenticatorId()));
    }

    /**
     * Returns the result of the given authenticator, invoking it only if no result is cached yet for
     * its id in this message.
     */
    private AuthenticationResult processAuthenticator(Map<String, AuthenticationResult> resultsCache,
            CXFAuthentication authenticator) throws AuthenticationException {
        String authenticatorId = authenticator.getAuthenticatorId();
        AuthenticationResult result = resultsCache.get(authenticatorId);
        if (result != null) {
            log.trace("Using cached result of {}: {}", authenticatorId, result);
            return result;
        }
        log.trace("Processing authenticator {}", authenticatorId);
        result = authenticator.getAuthenticationResult(endpointProperties);
        resultsCache.put(authenticatorId, result);
        log.trace("Authenticator {} returned {}", authenticatorId, result);
        return result;
    }

    /** @return the client IP of the HTTP request currently being processed */
    private String getClientIP() {
        return HTTPRequestContext.getCurrent().getClientIP();
    }
}
