/*
 * Decompiled with CFR 0.152.
 */
package com.predic8.membrane.core.interceptor.jwt;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.collect.ImmutableMap;
import com.predic8.membrane.annot.MCAttribute;
import com.predic8.membrane.annot.MCChildElement;
import com.predic8.membrane.annot.MCElement;
import com.predic8.membrane.core.Router;
import com.predic8.membrane.core.exchange.Exchange;
import com.predic8.membrane.core.http.Response;
import com.predic8.membrane.core.interceptor.AbstractInterceptor;
import com.predic8.membrane.core.interceptor.Outcome;
import com.predic8.membrane.core.interceptor.jwt.HeaderJwtRetriever;
import com.predic8.membrane.core.interceptor.jwt.Jwks;
import com.predic8.membrane.core.interceptor.jwt.JwtRetriever;
import java.nio.charset.StandardCharsets;
import java.security.Key;
import java.util.HashMap;
import java.util.Map;
import java.util.regex.Pattern;
import org.jose4j.base64url.Base64Url;
import org.jose4j.jwk.RsaJsonWebKey;
import org.jose4j.jwt.consumer.InvalidJwtException;
import org.jose4j.jwt.consumer.JwtConsumer;
import org.jose4j.jwt.consumer.JwtConsumerBuilder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@MCElement(name="jwtAuth")
public class JwtAuthInterceptor
extends AbstractInterceptor {
    private static final Logger LOG = LoggerFactory.getLogger(JwtAuthInterceptor.class);
    public static final String ERROR_JWT_NOT_FOUND = "Could not retrieve JWT";
    public static final String ERROR_MALFORMED_COMPACT_SERIALIZATION = "JWTs compact serialization not valid";
    public static final String ERROR_DECODED_HEADER_NOT_JSON = "JWT header is not valid JSON";
    public static final String ERROR_NO_KID_GIVEN = "JWT does not contain a kid";
    public static final String ERROR_UNKNOWN_KEY = "JWT signed by unknown key";
    public static final String ERROR_VALIDATION_FAILED = "JWT validation failed";
    ObjectMapper mapper = new ObjectMapper();
    JwtRetriever jwtRetriever;
    Jwks jwks;
    String expectedAud;
    volatile HashMap<String, RsaJsonWebKey> kidToKey;

    @Override
    public void init(Router router) throws Exception {
        super.init(router);
        if (this.jwtRetriever == null) {
            this.jwtRetriever = new HeaderJwtRetriever("Authorization", "Bearer");
        }
        this.jwks.init(router.getResolverMap(), router.getBaseLocation());
        this.kidToKey = this.jwks.getJwks().stream().map(jwk -> {
            try {
                return new RsaJsonWebKey((Map)this.mapper.readValue(jwk.getJwk(router.getResolverMap(), router.getBaseLocation(), this.mapper), Map.class));
            }
            catch (Exception e) {
                throw new RuntimeException(e);
            }
        }).collect(HashMap::new, (m, e) -> m.put(e.getKeyId(), e), (m1, m2) -> m1.putAll(m2));
        if (this.kidToKey.size() == 0) {
            throw new RuntimeException("No JWKs given or none resolvable - please specify at least one resolvable JWK");
        }
    }

    @Override
    public Outcome handleRequest(Exchange exc) throws Exception {
        String jwt;
        try {
            jwt = this.jwtRetriever.get(exc);
        }
        catch (Exception e) {
            return this.setJsonErrorAndReturn(e, exc, 400, ERROR_JWT_NOT_FOUND);
        }
        return this.handleJwt(exc, jwt);
    }

    public Outcome handleJwt(Exchange exc, String jwt) {
        Map jwtClaims;
        String kid;
        Map map;
        String decode;
        if (jwt == null) {
            return this.setJsonErrorAndReturn(null, exc, 400, ERROR_JWT_NOT_FOUND);
        }
        try {
            String[] split = jwt.split(Pattern.quote("."));
            if (split.length < 3) {
                return this.setJsonErrorAndReturn(null, exc, 400, ERROR_MALFORMED_COMPACT_SERIALIZATION);
            }
            decode = new String(Base64Url.decode((String)split[0]));
        }
        catch (Exception e) {
            return this.setJsonErrorAndReturn(e, exc, 400, ERROR_MALFORMED_COMPACT_SERIALIZATION);
        }
        try {
            map = (Map)this.mapper.readValue(decode, Map.class);
        }
        catch (Exception e) {
            return this.setJsonErrorAndReturn(e, exc, 400, ERROR_DECODED_HEADER_NOT_JSON);
        }
        try {
            Object kidMaybe = map.get("kid");
            if (kidMaybe == null) {
                throw new RuntimeException();
            }
            kid = kidMaybe.toString();
        }
        catch (Exception e) {
            return this.setJsonErrorAndReturn(e, exc, 400, ERROR_NO_KID_GIVEN);
        }
        RsaJsonWebKey key = this.kidToKey.get(kid);
        if (key == null) {
            return this.setJsonErrorAndReturn(null, exc, 400, ERROR_UNKNOWN_KEY);
        }
        JwtConsumer jwtValidator = this.createValidator(key);
        try {
            jwtClaims = jwtValidator.processToClaims(jwt).getClaimsMap();
        }
        catch (Exception e) {
            return this.setJsonErrorAndReturn(e, exc, 400, ERROR_VALIDATION_FAILED);
        }
        exc.getProperties().put("jwt", jwtClaims);
        return Outcome.CONTINUE;
    }

    private JwtConsumer createValidator(RsaJsonWebKey key) {
        JwtConsumerBuilder jwtConsumerBuilder = new JwtConsumerBuilder().setRequireExpirationTime().setAllowedClockSkewInSeconds(30).setRequireSubject().setVerificationKey((Key)key.getRsaPublicKey());
        if (this.expectedAud != null && !this.expectedAud.isEmpty()) {
            jwtConsumerBuilder.setExpectedAudience(new String[]{this.expectedAud});
        }
        if (this.expectedAud.equals("any!!")) {
            jwtConsumerBuilder.setSkipDefaultAudienceValidation();
        }
        JwtConsumer jwtValidator = jwtConsumerBuilder.build();
        return jwtValidator;
    }

    private Outcome setJsonErrorAndReturn(Exception e, Exchange exc, int code, String description) {
        if (e != null) {
            if (e instanceof InvalidJwtException) {
                LOG.error(e.getMessage());
            } else {
                LOG.error("", (Throwable)e);
            }
        }
        try {
            exc.setResponse(Response.ResponseBuilder.newInstance().status(code, "Bad Request").body(this.mapper.writeValueAsString((Object)ImmutableMap.builder().put((Object)"code", (Object)code).put((Object)"description", (Object)description).build())).build());
        }
        catch (JsonProcessingException jsonProcessingException) {
            throw new RuntimeException(jsonProcessingException);
        }
        return Outcome.RETURN;
    }

    public JwtRetriever getJwtRetriever() {
        return this.jwtRetriever;
    }

    @MCChildElement
    public void setJwtRetriever(JwtRetriever jwtRetriever) {
        this.jwtRetriever = jwtRetriever;
    }

    public Jwks getJwks() {
        return this.jwks;
    }

    @MCChildElement(order=1)
    public void setJwks(Jwks jwks) {
        this.jwks = jwks;
    }

    public String getExpectedAud() {
        return this.expectedAud;
    }

    @MCAttribute
    public JwtAuthInterceptor setExpectedAud(String expectedAud) {
        this.expectedAud = expectedAud;
        return this;
    }
}

