/*
 * Decompiled with CFR 0.152.
 */
package com.networknt.security;

import com.github.benmanes.caffeine.cache.Cache;
import com.github.benmanes.caffeine.cache.Caffeine;
import com.networknt.client.oauth.OauthHelper;
import com.networknt.client.oauth.SignKeyRequest;
import com.networknt.client.oauth.TokenKeyRequest;
import com.networknt.config.Config;
import com.networknt.exception.ExpiredTokenException;
import com.networknt.utility.FingerPrintUtil;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.security.cert.CertificateFactory;
import java.security.cert.X509Certificate;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import java.util.function.BiFunction;
import java.util.regex.Pattern;
import org.jose4j.jwk.JsonWebKey;
import org.jose4j.jwk.JsonWebKeySet;
import org.jose4j.jwt.JwtClaims;
import org.jose4j.jwt.MalformedClaimException;
import org.jose4j.jwt.NumericDate;
import org.jose4j.jwt.consumer.ErrorCodeValidator;
import org.jose4j.jwt.consumer.InvalidJwtException;
import org.jose4j.jwt.consumer.JwtConsumer;
import org.jose4j.jwt.consumer.JwtConsumerBuilder;
import org.jose4j.jwt.consumer.JwtContext;
import org.jose4j.jwx.JsonWebStructure;
import org.jose4j.keys.resolvers.JwksVerificationKeyResolver;
import org.jose4j.keys.resolvers.VerificationKeyResolver;
import org.jose4j.keys.resolvers.X509VerificationKeyResolver;
import org.owasp.encoder.Encode;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class JwtVerifier {
    static final Logger logger = LoggerFactory.getLogger(JwtVerifier.class);
    public static final String KID = "kid";
    public static final String SECURITY_CONFIG = "security";
    public static final String JWT_CONFIG = "jwt";
    public static final String JWT_CERTIFICATE = "certificate";
    public static final String JWT_CLOCK_SKEW_IN_SECONDS = "clockSkewInSeconds";
    public static final String ENABLE_VERIFY_JWT = "enableVerifyJwt";
    private static final String ENABLE_JWT_CACHE = "enableJwtCache";
    private static final String BOOTSTRAP_FROM_KEY_SERVICE = "bootstrapFromKeyService";
    private static final int CACHE_EXPIRED_IN_MINUTES = 15;
    public static final String JWT_KEY_RESOLVER = "keyResolver";
    public static final String JWT_KEY_RESOLVER_X509CERT = "X509Certificate";
    public static final String JWT_KEY_RESOLVER_JWKS = "JsonWebKeySet";
    Map<String, Object> config;
    Map<String, Object> jwtConfig;
    int secondsOfAllowedClockSkew;
    Boolean enableJwtCache;
    Boolean bootstrapFromKeyService;
    static Cache<String, JwtClaims> cache;
    static Map<String, X509Certificate> certMap;
    static Map<String, List<JsonWebKey>> jwksMap;
    static List<String> fingerPrints;

    public JwtVerifier(Map<String, Object> config) {
        this.config = config;
        this.jwtConfig = (Map)config.get(JWT_CONFIG);
        this.secondsOfAllowedClockSkew = (Integer)this.jwtConfig.get(JWT_CLOCK_SKEW_IN_SECONDS);
        this.bootstrapFromKeyService = (Boolean)config.get(BOOTSTRAP_FROM_KEY_SERVICE);
        this.enableJwtCache = (Boolean)config.get(ENABLE_JWT_CACHE);
        if (Boolean.TRUE.equals(this.enableJwtCache)) {
            cache = Caffeine.newBuilder().expireAfterWrite(15L, TimeUnit.MINUTES).build();
        }
        switch ((String)this.jwtConfig.getOrDefault(JWT_KEY_RESOLVER, JWT_KEY_RESOLVER_X509CERT)) {
            case "JsonWebKeySet": {
                jwksMap = new HashMap<String, List<JsonWebKey>>();
                break;
            }
            case "X509Certificate": {
                if (this.bootstrapFromKeyService != null && !Boolean.FALSE.equals(this.bootstrapFromKeyService)) break;
                certMap = new HashMap<String, X509Certificate>();
                fingerPrints = new ArrayList<String>();
                if (this.jwtConfig.get(JWT_CERTIFICATE) == null) break;
                Map keyMap = (Map)this.jwtConfig.get(JWT_CERTIFICATE);
                for (String kid : keyMap.keySet()) {
                    X509Certificate cert = null;
                    try {
                        cert = this.readCertificate((String)keyMap.get(kid));
                    }
                    catch (Exception e) {
                        logger.error("Exception:", e);
                    }
                    certMap.put(kid, cert);
                    fingerPrints.add(FingerPrintUtil.getCertFingerPrint(cert));
                }
                break;
            }
            default: {
                logger.info("{} not found or not recognized in jwt config. Use {} as default {}", JWT_KEY_RESOLVER, JWT_KEY_RESOLVER_X509CERT, JWT_KEY_RESOLVER);
            }
        }
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    public X509Certificate readCertificate(String filename) throws Exception {
        InputStream inStream = null;
        X509Certificate cert = null;
        try {
            inStream = Config.getInstance().getInputStreamFromFile(filename);
            if (inStream != null) {
                CertificateFactory cf = CertificateFactory.getInstance("X.509");
                cert = (X509Certificate)cf.generateCertificate(inStream);
            } else {
                logger.info("Certificate " + Encode.forJava(filename) + " not found.");
            }
        }
        catch (Exception e) {
            logger.error("Exception: ", e);
        }
        finally {
            if (inStream != null) {
                try {
                    inStream.close();
                }
                catch (IOException ioe) {
                    logger.error("Exception: ", ioe);
                }
            }
        }
        return cert;
    }

    public static String getJwtFromAuthorization(String authorization) {
        String[] parts;
        String jwt = null;
        if (authorization != null && (parts = authorization.split(" ")).length == 2) {
            String scheme = parts[0];
            String credentials = parts[1];
            Pattern pattern = Pattern.compile("^Bearer$", 2);
            if (pattern.matcher(scheme).matches()) {
                jwt = credentials;
            }
        }
        return jwt;
    }

    public JwtClaims verifyJwt(String jwt, boolean ignoreExpiry, boolean isToken) throws InvalidJwtException, ExpiredTokenException {
        return this.verifyJwt(jwt, ignoreExpiry, isToken, this::getKeyResolver);
    }

    public JwtClaims verifyJwt(String jwt, boolean ignoreExpiry, boolean isToken, BiFunction<String, Boolean, VerificationKeyResolver> getKeyResolver) throws InvalidJwtException, ExpiredTokenException {
        JwtClaims claims;
        if (Boolean.TRUE.equals(this.enableJwtCache) && (claims = cache.getIfPresent(jwt)) != null) {
            if (!ignoreExpiry) {
                try {
                    if (NumericDate.now().getValue() - (long)this.secondsOfAllowedClockSkew >= claims.getExpirationTime().getValue()) {
                        logger.info("Cached jwt token is expired!");
                        throw new ExpiredTokenException("Token is expired");
                    }
                }
                catch (MalformedClaimException e) {
                    logger.error("MalformedClaimException:", e);
                }
            }
            return claims;
        }
        JwtConsumer consumer = new JwtConsumerBuilder().setSkipAllValidators().setDisableRequireSignature().setSkipSignatureVerification().build();
        JwtContext jwtContext = consumer.process(jwt);
        claims = jwtContext.getJwtClaims();
        JsonWebStructure structure = jwtContext.getJoseObjects().get(0);
        String kid = structure.getKeyIdHeaderValue();
        if (!ignoreExpiry) {
            try {
                if (NumericDate.now().getValue() - (long)this.secondsOfAllowedClockSkew >= claims.getExpirationTime().getValue()) {
                    logger.info("jwt token is expired!");
                    throw new ExpiredTokenException("Token is expired");
                }
            }
            catch (MalformedClaimException e) {
                logger.error("MalformedClaimException:", e);
                throw new InvalidJwtException("MalformedClaimException", new ErrorCodeValidator.Error(18, "Invalid ExpirationTime Format"), e, jwtContext);
            }
        }
        consumer = new JwtConsumerBuilder().setRequireExpirationTime().setAllowedClockSkewInSeconds(315360000).setSkipDefaultAudienceValidation().setVerificationKeyResolver(getKeyResolver.apply(kid, isToken)).build();
        jwtContext = consumer.process(jwt);
        claims = jwtContext.getJwtClaims();
        if (Boolean.TRUE.equals(this.enableJwtCache)) {
            cache.put(jwt, claims);
        }
        return claims;
    }

    private VerificationKeyResolver getKeyResolver(String kid, boolean isToken) {
        String keyResolver;
        VerificationKeyResolver verificationKeyResolver = null;
        switch (keyResolver = (String)this.jwtConfig.getOrDefault(JWT_KEY_RESOLVER, JWT_KEY_RESOLVER_X509CERT)) {
            default: {
                X509Certificate certificate;
                X509Certificate x509Certificate = certificate = certMap == null ? null : certMap.get(kid);
                if (certificate == null) {
                    X509Certificate x509Certificate2 = certificate = isToken ? this.getCertForToken(kid) : this.getCertForSign(kid);
                    if (certMap == null) {
                        certMap = new HashMap<String, X509Certificate>();
                    }
                    certMap.put(kid, certificate);
                } else {
                    logger.debug("Got raw certificate for kid: {} from local cache", (Object)kid);
                }
                X509VerificationKeyResolver x509VerificationKeyResolver = new X509VerificationKeyResolver(certificate);
                x509VerificationKeyResolver.setTryAllOnNoThumbHeader(true);
                verificationKeyResolver = x509VerificationKeyResolver;
                break;
            }
            case "JsonWebKeySet": {
                List<JsonWebKey> jwkList;
                List<JsonWebKey> list = jwkList = jwksMap == null ? null : jwksMap.get(kid);
                if (jwkList == null) {
                    jwkList = this.getJsonWebKeySetForToken(kid);
                    if (jwkList != null) {
                        if (jwksMap == null) {
                            jwksMap = new HashMap<String, List<JsonWebKey>>();
                        }
                        jwksMap.put(kid, jwkList);
                    }
                } else {
                    logger.debug("Got Json web key set for kid: {} from local cache", (Object)kid);
                }
                if (jwkList == null) break;
                verificationKeyResolver = new JwksVerificationKeyResolver(jwkList);
            }
        }
        return verificationKeyResolver;
    }

    private List<JsonWebKey> getJsonWebKeySetForToken(String kid) {
        TokenKeyRequest keyRequest = new TokenKeyRequest(kid);
        try {
            logger.debug("Getting Json Web Key for kid: {} from {}", (Object)kid, (Object)keyRequest.getServerUrl());
            String key = OauthHelper.getKey(keyRequest);
            logger.debug("Got Json Web Key '{}' for kid: {}", (Object)key, (Object)kid);
            return new JsonWebKeySet(key).getJsonWebKeys();
        }
        catch (Exception e) {
            logger.error("Exception: ", e);
            throw new RuntimeException(e);
        }
    }

    public X509Certificate getCertForToken(String kid) {
        X509Certificate certificate = null;
        TokenKeyRequest keyRequest = new TokenKeyRequest(kid);
        try {
            logger.warn("<Deprecated: use JsonWebKeySet instead> Getting raw certificate for kid: {} from {}", (Object)kid, (Object)keyRequest.getServerUrl());
            String key = OauthHelper.getKey(keyRequest);
            logger.warn("<Deprecated: use JsonWebKeySet instead> Got raw certificate {} for kid: {}", (Object)key, (Object)kid);
            CertificateFactory cf = CertificateFactory.getInstance("X.509");
            certificate = (X509Certificate)cf.generateCertificate(new ByteArrayInputStream(key.getBytes(StandardCharsets.UTF_8)));
        }
        catch (Exception e) {
            logger.error("Exception: ", e);
            throw new RuntimeException(e);
        }
        return certificate;
    }

    public X509Certificate getCertForSign(String kid) {
        X509Certificate certificate = null;
        SignKeyRequest keyRequest = new SignKeyRequest(kid);
        try {
            String key = OauthHelper.getKey(keyRequest);
            CertificateFactory cf = CertificateFactory.getInstance("X.509");
            certificate = (X509Certificate)cf.generateCertificate(new ByteArrayInputStream(key.getBytes(StandardCharsets.UTF_8)));
        }
        catch (Exception e) {
            logger.error("Exception: ", e);
            throw new RuntimeException(e);
        }
        return certificate;
    }

    public List getFingerPrints() {
        return fingerPrints;
    }
}

