/*
 * Decompiled with CFR 0.152.
 */
package com.predic8.membrane.core.interceptor.jwt;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.predic8.membrane.annot.MCAttribute;
import com.predic8.membrane.annot.MCChildElement;
import com.predic8.membrane.annot.MCElement;
import com.predic8.membrane.core.exchange.Exchange;
import com.predic8.membrane.core.http.Response;
import com.predic8.membrane.core.interceptor.AbstractInterceptor;
import com.predic8.membrane.core.interceptor.Interceptor;
import com.predic8.membrane.core.interceptor.Outcome;
import com.predic8.membrane.core.interceptor.jwt.HeaderJwtRetriever;
import com.predic8.membrane.core.interceptor.jwt.JWTException;
import com.predic8.membrane.core.interceptor.jwt.JsonWebToken;
import com.predic8.membrane.core.interceptor.jwt.Jwks;
import com.predic8.membrane.core.interceptor.jwt.JwtRetriever;
import com.predic8.membrane.core.security.JWTSecurityScheme;
import java.security.Key;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.Map;
import org.apache.commons.text.StringEscapeUtils;
import org.jose4j.jwk.RsaJsonWebKey;
import org.jose4j.jwt.consumer.InvalidJwtException;
import org.jose4j.jwt.consumer.JwtConsumer;
import org.jose4j.jwt.consumer.JwtConsumerBuilder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Request interceptor that authenticates an exchange by validating a JSON Web Token.
 *
 * <p>The token is obtained through the configured {@link JwtRetriever} (by default the
 * {@code Authorization: Bearer ...} header), its {@code kid} header is matched against the
 * configured {@link Jwks}, and the token is then validated with jose4j. On success the claims
 * are stored on the exchange; on any failure a JSON error response with status 400 is set and
 * processing of the exchange stops.
 *
 * <p>Audience handling is controlled by {@code expectedAud}: the special value {@code "any!!"}
 * switches audience validation off, any other non-empty value is passed to the validator as
 * the expected audience, and a missing/empty value configures no expected audience at all.
 */
@MCElement(name="jwtAuth")
public class JwtAuthInterceptor
extends AbstractInterceptor {
    private static final Logger LOG = LoggerFactory.getLogger(JwtAuthInterceptor.class);
    public static final String ERROR_JWT_NOT_FOUND = "Could not retrieve JWT";
    public static final String ERROR_DECODED_HEADER_NOT_JSON = "JWT header is not valid JSON";
    public static final String ERROR_UNKNOWN_KEY = "JWT signed by unknown key";
    public static final String ERROR_VALIDATION_FAILED = "JWT validation failed";
    ObjectMapper mapper = new ObjectMapper();
    JwtRetriever jwtRetriever;
    Jwks jwks;
    String expectedAud;
    // Replaced as a whole by init(); readers must take a single snapshot of the reference.
    volatile HashMap<String, RsaJsonWebKey> kidToKey;

    /** Creates the interceptor and restricts it to the request flow. */
    public JwtAuthInterceptor() {
        this.name = "jwt checker.";
        this.setFlow(EnumSet.of(Interceptor.Flow.REQUEST));
    }

    /**
     * Installs the default header-based retriever if none was configured and loads all
     * configured JWKs into the {@code kid -> key} lookup table.
     *
     * @throws RuntimeException if no {@code jwks} child is configured, if a JWK cannot be
     *         read or parsed, or if no key could be loaded at all
     */
    @Override
    public void init() {
        super.init();
        if (this.jwtRetriever == null) {
            this.jwtRetriever = new HeaderJwtRetriever("Authorization", "Bearer");
        }
        if (this.jwks == null) {
            // Fail with a configuration hint instead of a bare NullPointerException.
            throw new RuntimeException("No JWKs given or none resolvable - please specify at least one resolvable JWK");
        }
        this.jwks.init(this.router.getResolverMap(), this.router.getBaseLocation());
        HashMap<String, RsaJsonWebKey> loadedKeys = this.jwks.getJwks().stream().map(jwk -> {
            try {
                @SuppressWarnings("unchecked") // a JWK is a JSON object, i.e. a map with string keys
                Map<String, Object> params = this.mapper.readValue(jwk.getJwk(this.router.getResolverMap(), this.router.getBaseLocation(), this.mapper), Map.class);
                return new RsaJsonWebKey(params);
            }
            catch (Exception e) {
                throw new RuntimeException(e);
            }
        }).collect(HashMap::new, (m, e) -> m.put(e.getKeyId(), e), HashMap::putAll);
        if (loadedKeys.isEmpty()) {
            throw new RuntimeException("No JWKs given or none resolvable - please specify at least one resolvable JWK");
        }
        // Publish only after validation so request threads never observe an empty key table.
        this.kidToKey = loadedKeys;
    }

    /**
     * Retrieves the JWT from the exchange and validates it.
     *
     * @param exc the exchange being processed
     * @return {@link Outcome#CONTINUE} if the token is valid, otherwise {@link Outcome#RETURN}
     *         after a JSON error response with status 400 has been set on the exchange
     */
    @Override
    public Outcome handleRequest(Exchange exc) {
        try {
            String jwt = this.jwtRetriever.get(exc);
            return this.handleJwt(exc, jwt);
        }
        catch (JWTException e) {
            return this.setJsonErrorAndReturn(e, exc, 400, e.getMessage());
        }
        catch (JsonProcessingException e) {
            return this.setJsonErrorAndReturn(e, exc, 400, ERROR_DECODED_HEADER_NOT_JSON);
        }
        catch (InvalidJwtException e) {
            return this.setJsonErrorAndReturn(e, exc, 400, ERROR_VALIDATION_FAILED);
        }
        catch (Exception e) {
            return this.setJsonErrorAndReturn(e, exc, 400, ERROR_JWT_NOT_FOUND);
        }
    }

    /**
     * Validates the given token and, on success, attaches its claims to the exchange: the
     * claims map is stored under the exchange property {@code "jwt"} and a
     * {@link JWTSecurityScheme} built from it is added to the exchange.
     *
     * @param exc the exchange to attach the claims to
     * @param jwt the serialized token, may be {@code null}
     * @return {@link Outcome#CONTINUE}
     * @throws JWTException if {@code jwt} is {@code null} or its {@code kid} is not known
     * @throws JsonProcessingException declared for failures while decoding the token
     * @throws InvalidJwtException if jose4j rejects the token
     */
    public Outcome handleJwt(Exchange exc, String jwt) throws JWTException, JsonProcessingException, InvalidJwtException {
        if (jwt == null) {
            throw new JWTException(ERROR_JWT_NOT_FOUND);
        }
        JsonWebToken decodedJwt = new JsonWebToken(jwt);
        String kid = decodedJwt.getHeader().kid();
        // Single read of the volatile field and a single lookup: a separate containsKey()/get()
        // pair could see two different maps if init() runs concurrently.
        Map<String, RsaJsonWebKey> keys = this.kidToKey;
        RsaJsonWebKey key = keys.get(kid);
        if (key == null) {
            throw new JWTException(ERROR_UNKNOWN_KEY);
        }
        Map<String, Object> jwtClaims = this.createValidator(key).processToClaims(jwt).getClaimsMap();
        exc.getProperties().put("jwt", jwtClaims);
        new JWTSecurityScheme(jwtClaims).add(exc);
        return Outcome.CONTINUE;
    }

    /**
     * Builds a jose4j consumer that verifies the signature with the given key and requires
     * an expiration time and a subject, allowing 30 seconds of clock skew.
     *
     * @param key the RSA key whose public part verifies the signature
     * @return the configured consumer
     */
    private JwtConsumer createValidator(RsaJsonWebKey key) {
        JwtConsumerBuilder jwtConsumerBuilder = new JwtConsumerBuilder()
                .setRequireExpirationTime()
                .setAllowedClockSkewInSeconds(30)
                .setRequireSubject()
                .setVerificationKey(key.getRsaPublicKey());
        if (this.acceptAnyAud()) {
            jwtConsumerBuilder.setSkipDefaultAudienceValidation();
        } else if (this.expectedAud != null && !this.expectedAud.isEmpty()) {
            jwtConsumerBuilder.setExpectedAudience(this.expectedAud);
        }
        return jwtConsumerBuilder.build();
    }

    /** Returns whether {@code expectedAud} is the sentinel {@code "any!!"}. */
    private boolean acceptAnyAud() {
        return "any!!".equals(this.expectedAud);
    }

    /**
     * Logs the failure and sets a JSON error response of the form
     * {@code {"code": ..., "description": ...}} on the exchange.
     *
     * @param e the cause, may be {@code null}; validation failures are logged without stack trace
     * @param exc the exchange that receives the error response
     * @param code the HTTP status code, also echoed in the body
     * @param description the human-readable error description placed in the body
     * @return {@link Outcome#RETURN}
     */
    private Outcome setJsonErrorAndReturn(Exception e, Exchange exc, int code, String description) {
        if (e != null) {
            if (e instanceof InvalidJwtException) {
                LOG.error(e.getMessage());
            } else {
                LOG.error("", e);
            }
        }
        try {
            String body = this.mapper.writeValueAsString(Map.of("code", code, "description", description));
            // The body is JSON, so declare it as such for clients that inspect Content-Type.
            exc.setResponse(Response.ResponseBuilder.newInstance().status(code, "Bad Request").contentType("application/json").body(body).build());
        }
        catch (JsonProcessingException jsonProcessingException) {
            throw new RuntimeException(jsonProcessingException);
        }
        return Outcome.RETURN;
    }

    /** Returns the retriever used to extract the token from an exchange. */
    public JwtRetriever getJwtRetriever() {
        return this.jwtRetriever;
    }

    /** Sets the retriever used to extract the token from an exchange. */
    @MCChildElement
    public void setJwtRetriever(JwtRetriever jwtRetriever) {
        this.jwtRetriever = jwtRetriever;
    }

    /** Returns the configured key set. */
    public Jwks getJwks() {
        return this.jwks;
    }

    /** Sets the key set whose keys are accepted for signature verification. */
    @MCChildElement(order=1)
    public void setJwks(Jwks jwks) {
        this.jwks = jwks;
    }

    /** Returns the configured expected audience, possibly {@code null}. */
    public String getExpectedAud() {
        return this.expectedAud;
    }

    /**
     * Sets the expected audience; {@code "any!!"} disables audience validation.
     *
     * @return this interceptor, for chaining
     */
    @MCAttribute
    public JwtAuthInterceptor setExpectedAud(String expectedAud) {
        this.expectedAud = expectedAud;
        return this;
    }

    @Override
    public String getShortDescription() {
        return "Checks for a valid JWT.";
    }

    /**
     * Returns an HTML description of the configured audience policy and, if a key set is
     * configured, of the keys the signature is checked against.
     */
    @Override
    public String getLongDescription() {
        StringBuilder description = new StringBuilder("Checks for a valid JWT.<br/>");
        if (this.acceptAnyAud()) {
            description.append("Accepts any value for the <font style=\"font-family: monospace\">aud</font> field. <b>THIS IS STRONGLY DISCOURAGED!</b><br/>");
        } else if (this.expectedAud == null || this.expectedAud.isEmpty()) {
            // Mirrors createValidator(): no expected audience is handed to the validator here,
            // so do not claim that the literal value "null" is accepted.
            description.append("No expected value is configured for the <font style=\"font-family: monospace\">aud</font> payload entry.<br/>");
        } else {
            description.append("Accepts <font style=\"font-family: monospace\">")
                    .append(StringEscapeUtils.escapeHtml4(this.expectedAud))
                    .append("</font> as valid value for the <font style=\"font-family: monospace\">aud</font> payload entry.<br/>");
        }
        if (this.jwks != null) {
            description.append("Validates the JWT signature against ").append(this.jwks.getLongDescription()).append(" .");
        }
        return description.toString();
    }
}

