package io.spiffe.svid.jwtsvid;

import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.JWSHeader;
import com.nimbusds.jose.JWSVerifier;
import com.nimbusds.jose.crypto.ECDSAVerifier;
import com.nimbusds.jose.crypto.RSASSAVerifier;
import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.SignedJWT;
import io.spiffe.Algorithm;
import io.spiffe.bundle.BundleSource;
import io.spiffe.bundle.jwtbundle.JwtBundle;
import io.spiffe.exception.AuthorityNotFoundException;
import io.spiffe.exception.BundleNotFoundException;
import io.spiffe.exception.JwtSvidException;
import io.spiffe.spiffeid.SpiffeId;
import java.security.PublicKey;
import java.security.interfaces.ECPublicKey;
import java.security.interfaces.RSAPublicKey;
import java.text.ParseException;
import java.util.Collections;
import java.util.Date;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import lombok.Generated;
import lombok.NonNull;
import org.apache.commons.lang3.StringUtils;

/* loaded from: input_file:io/spiffe/svid/jwtsvid/JwtSvid.class */
public final class JwtSvid {
    private final SpiffeId spiffeId;
    private final Set<String> audience;
    private final Date expiry;
    private final Map<String, Object> claims;
    private final String token;

    private JwtSvid(SpiffeId spiffeId, Set<String> set, Date date, Map<String, Object> map, String str) {
        this.spiffeId = spiffeId;
        this.audience = set;
        this.expiry = date;
        this.claims = map;
        this.token = str;
    }

    public static JwtSvid parseAndValidate(@NonNull String str, @NonNull BundleSource<JwtBundle> bundleSource, @NonNull Set<String> set) throws JwtSvidException, BundleNotFoundException, AuthorityNotFoundException {
        if (str == null) {
            throw new NullPointerException("token is marked non-null but is null");
        }
        if (bundleSource == null) {
            throw new NullPointerException("jwtBundleSource is marked non-null but is null");
        }
        if (set == null) {
            throw new NullPointerException("audience is marked non-null but is null");
        }
        if (StringUtils.isBlank(str)) {
            throw new IllegalArgumentException("Token cannot be blank");
        }
        SignedJWT signedJWT = getSignedJWT(str);
        JWTClaimsSet jwtClaimsSet = getJwtClaimsSet(signedJWT);
        validateAudience(jwtClaimsSet.getAudience(), set);
        Date expirationTime = jwtClaimsSet.getExpirationTime();
        validateExpiration(expirationTime);
        SpiffeId spiffeIdOfSubject = getSpiffeIdOfSubject(jwtClaimsSet);
        JwtBundle bundleForTrustDomain = bundleSource.getBundleForTrustDomain(spiffeIdOfSubject.getTrustDomain());
        String keyId = getKeyId(signedJWT.getHeader());
        verifySignature(signedJWT, bundleForTrustDomain.findJwtAuthority(keyId), signedJWT.getHeader().getAlgorithm().getName(), keyId);
        return new JwtSvid(spiffeIdOfSubject, new HashSet(jwtClaimsSet.getAudience()), expirationTime, jwtClaimsSet.getClaims(), str);
    }

    public static JwtSvid parseInsecure(@NonNull String str, @NonNull Set<String> set) throws JwtSvidException {
        if (str == null) {
            throw new NullPointerException("token is marked non-null but is null");
        }
        if (set == null) {
            throw new NullPointerException("audience is marked non-null but is null");
        }
        if (StringUtils.isBlank(str)) {
            throw new IllegalArgumentException("Token cannot be blank");
        }
        JWTClaimsSet jwtClaimsSet = getJwtClaimsSet(getSignedJWT(str));
        validateAudience(jwtClaimsSet.getAudience(), set);
        Date expirationTime = jwtClaimsSet.getExpirationTime();
        validateExpiration(expirationTime);
        return new JwtSvid(getSpiffeIdOfSubject(jwtClaimsSet), new HashSet(jwtClaimsSet.getAudience()), expirationTime, jwtClaimsSet.getClaims(), str);
    }

    public String marshal() {
        return this.token;
    }

    public Date getExpiry() {
        return new Date(this.expiry.getTime());
    }

    public Map<String, Object> getClaims() {
        return Collections.unmodifiableMap(this.claims);
    }

    public Set<String> getAudience() {
        return Collections.unmodifiableSet(this.audience);
    }

    private static JWTClaimsSet getJwtClaimsSet(SignedJWT signedJWT) {
        try {
            return signedJWT.getJWTClaimsSet();
        } catch (ParseException e) {
            throw new IllegalArgumentException("Unable to parse JWT token", e);
        }
    }

    private static SignedJWT getSignedJWT(String str) {
        try {
            return SignedJWT.parse(str);
        } catch (ParseException e) {
            throw new IllegalArgumentException("Unable to parse JWT token", e);
        }
    }

    private static void verifySignature(SignedJWT signedJWT, PublicKey publicKey, String str, String str2) throws JwtSvidException {
        try {
            if (!signedJWT.verify(getJwsVerifier(publicKey, str))) {
                throw new JwtSvidException(String.format("Signature invalid: cannot be verified with the authority with keyId=%s", str2));
            }
        } catch (ClassCastException | JOSEException e) {
            throw new JwtSvidException(String.format("Error verifying signature with the authority with keyId=%s", str2), e);
        }
    }

    private static JWSVerifier getJwsVerifier(PublicKey publicKey, String str) throws JOSEException, JwtSvidException {
        ECDSAVerifier rSASSAVerifier;
        Algorithm parse = Algorithm.parse(str);
        if (Algorithm.Family.EC.contains(parse)) {
            rSASSAVerifier = new ECDSAVerifier((ECPublicKey) publicKey);
        } else {
            if (!Algorithm.Family.RSA.contains(parse)) {
                throw new JwtSvidException(String.format("Unsupported token signature algorithm %s", str));
            }
            rSASSAVerifier = new RSASSAVerifier((RSAPublicKey) publicKey);
        }
        return rSASSAVerifier;
    }

    private static String getKeyId(JWSHeader jWSHeader) throws JwtSvidException {
        String keyID = jWSHeader.getKeyID();
        if (keyID == null) {
            throw new JwtSvidException("Token header missing key id");
        }
        if (StringUtils.isBlank(keyID)) {
            throw new JwtSvidException("Token header key id contains an empty value");
        }
        return keyID;
    }

    private static void validateExpiration(Date date) throws JwtSvidException {
        if (date == null) {
            throw new JwtSvidException("Token missing expiration claim");
        }
        if (date.before(new Date())) {
            throw new JwtSvidException("Token has expired");
        }
    }

    private static SpiffeId getSpiffeIdOfSubject(JWTClaimsSet jWTClaimsSet) throws JwtSvidException {
        String subject = jWTClaimsSet.getSubject();
        if (StringUtils.isBlank(subject)) {
            throw new JwtSvidException("Token missing subject claim");
        }
        try {
            return SpiffeId.parse(subject);
        } catch (IllegalArgumentException e) {
            throw new JwtSvidException(String.format("Subject %s cannot be parsed as a SPIFFE ID", subject), e);
        }
    }

    private static void validateAudience(List<String> list, Set<String> set) throws JwtSvidException {
        Iterator<String> it = list.iterator();
        while (it.hasNext()) {
            if (!set.contains(it.next())) {
                throw new JwtSvidException(String.format("expected audience in %s (audience=%s)", set, list));
            }
        }
    }

    @Generated
    public SpiffeId getSpiffeId() {
        return this.spiffeId;
    }

    @Generated
    public String getToken() {
        return this.token;
    }

    @Generated
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof JwtSvid)) {
            return false;
        }
        JwtSvid jwtSvid = (JwtSvid) obj;
        SpiffeId spiffeId = getSpiffeId();
        SpiffeId spiffeId2 = jwtSvid.getSpiffeId();
        if (spiffeId == null) {
            if (spiffeId2 != null) {
                return false;
            }
        } else if (!spiffeId.equals(spiffeId2)) {
            return false;
        }
        Set<String> audience = getAudience();
        Set<String> audience2 = jwtSvid.getAudience();
        if (audience == null) {
            if (audience2 != null) {
                return false;
            }
        } else if (!audience.equals(audience2)) {
            return false;
        }
        Date expiry = getExpiry();
        Date expiry2 = jwtSvid.getExpiry();
        if (expiry == null) {
            if (expiry2 != null) {
                return false;
            }
        } else if (!expiry.equals(expiry2)) {
            return false;
        }
        Map<String, Object> claims = getClaims();
        Map<String, Object> claims2 = jwtSvid.getClaims();
        if (claims == null) {
            if (claims2 != null) {
                return false;
            }
        } else if (!claims.equals(claims2)) {
            return false;
        }
        String token = getToken();
        String token2 = jwtSvid.getToken();
        return token == null ? token2 == null : token.equals(token2);
    }

    @Generated
    public int hashCode() {
        SpiffeId spiffeId = getSpiffeId();
        int hashCode = (1 * 59) + (spiffeId == null ? 43 : spiffeId.hashCode());
        Set<String> audience = getAudience();
        int hashCode2 = (hashCode * 59) + (audience == null ? 43 : audience.hashCode());
        Date expiry = getExpiry();
        int hashCode3 = (hashCode2 * 59) + (expiry == null ? 43 : expiry.hashCode());
        Map<String, Object> claims = getClaims();
        int hashCode4 = (hashCode3 * 59) + (claims == null ? 43 : claims.hashCode());
        String token = getToken();
        return (hashCode4 * 59) + (token == null ? 43 : token.hashCode());
    }

    @Generated
    public String toString() {
        return "JwtSvid(spiffeId=" + getSpiffeId() + ", audience=" + getAudience() + ", expiry=" + getExpiry() + ", claims=" + getClaims() + ", token=" + getToken() + ")";
    }
}
