package org.apereo.cas.mgmt;

import java.io.IOException;
import java.io.InputStream;
import java.io.StringReader;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import javax.xml.XMLConstants;
import javax.xml.parsers.DocumentBuilderFactory;
import javax.xml.parsers.ParserConfigurationException;
import lombok.Generated;
import org.apache.commons.io.IOUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.http.HttpResponse;
import org.apereo.cas.configuration.CasConfigurationProperties;
import org.apereo.cas.configuration.CasManagementConfigurationProperties;
import org.apereo.cas.configuration.model.support.saml.idp.metadata.SamlIdPMetadataProperties;
import org.apereo.cas.services.UnauthorizedServiceException;
import org.apereo.cas.support.saml.OpenSamlConfigBean;
import org.apereo.cas.util.EncodingUtils;
import org.apereo.cas.util.HttpUtils;
import org.opensaml.saml.metadata.resolver.filter.FilterException;
import org.opensaml.saml.metadata.resolver.filter.MetadataFilter;
import org.opensaml.saml.metadata.resolver.filter.MetadataFilterContext;
import org.opensaml.saml.saml2.metadata.EntityDescriptor;
import org.opensaml.xmlsec.signature.support.SignatureException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.http.HttpMethod;
import org.springframework.scheduling.annotation.Scheduled;
import org.w3c.dom.Document;
import org.w3c.dom.Element;
import org.w3c.dom.NodeList;
import org.xml.sax.InputSource;
import org.xml.sax.SAXException;

/* loaded from: input_file:org/apereo/cas/mgmt/InCommonMetadataAggregateResolver.class */
public class InCommonMetadataAggregateResolver implements MetadataAggregateResolver {

    @Generated
    private static final Logger LOGGER = LoggerFactory.getLogger(InCommonMetadataAggregateResolver.class);
    private final CasConfigurationProperties casProperties;
    private final CasManagementConfigurationProperties mgmtProperties;
    private final OpenSamlConfigBean configBean;
    private final MetadataFilter signatureValidationFilter;
    private List<String> sps;

    public InCommonMetadataAggregateResolver(CasConfigurationProperties casConfigurationProperties, CasManagementConfigurationProperties casManagementConfigurationProperties, OpenSamlConfigBean openSamlConfigBean, MetadataFilter metadataFilter) {
        this.casProperties = casConfigurationProperties;
        this.mgmtProperties = casManagementConfigurationProperties;
        this.configBean = openSamlConfigBean;
        this.signatureValidationFilter = metadataFilter;
        reloadInCommon();
    }

    @Override // org.apereo.cas.mgmt.MetadataAggregateResolver
    public List<String> query(String str) {
        return (List) this.sps.stream().filter(str2 -> {
            return str2.contains(str);
        }).collect(Collectors.toList());
    }

    @Override // org.apereo.cas.mgmt.MetadataAggregateResolver
    public String location() {
        return this.mgmtProperties.getInCommonMDQUrl() + "/{0}";
    }

    @Override // org.apereo.cas.mgmt.MetadataAggregateResolver
    public EntityDescriptor find(String str) throws SignatureException {
        EntityDescriptor fromXML = MetadataUtil.fromXML(xml(str), this.configBean);
        try {
            this.signatureValidationFilter.filter(fromXML, new MetadataFilterContext());
            return fromXML;
        } catch (FilterException e) {
            LOGGER.error(e.getMessage(), e);
            throw new SignatureException("Invalid metadata signature for [" + str + "]");
        }
    }

    @Override // org.apereo.cas.mgmt.MetadataAggregateResolver
    public String xml(String str) {
        if (this.sps.contains(str)) {
            try {
                InputStream content = fetchMetadata(this.mgmtProperties.getInCommonMDQUrl() + "/" + EncodingUtils.urlEncode(str)).getEntity().getContent();
                try {
                    String iOUtils = IOUtils.toString(content, StandardCharsets.UTF_8);
                    if (content != null) {
                        content.close();
                    }
                    return iOUtils;
                } catch (Throwable th) {
                    if (content != null) {
                        try {
                            content.close();
                        } catch (Throwable th2) {
                            th.addSuppressed(th2);
                        }
                    }
                    throw th;
                }
            } catch (Exception e) {
                LOGGER.error(e.getMessage(), e);
            }
        }
        throw new IllegalArgumentException("Entity not found");
    }

    @Scheduled(fixedDelayString = "PT60M")
    private void reloadInCommon() {
        this.sps = fromInCommon();
    }

    private HttpResponse fetchMetadata(String str) {
        SamlIdPMetadataProperties metadata = this.casProperties.getAuthn().getSamlIdp().getMetadata();
        LinkedHashMap linkedHashMap = new LinkedHashMap();
        linkedHashMap.put("Content-Type", metadata.getMdq().getSupportedContentTypes());
        linkedHashMap.put("Accept", "*/*");
        LOGGER.debug("Fetching dynamic metadata via MDQ for [{}]", str);
        HttpResponse execute = HttpUtils.execute(HttpUtils.HttpExecutionRequest.builder().url(str).basicAuthUsername(metadata.getMdq().getBasicAuthnUsername()).basicAuthPassword(metadata.getMdq().getBasicAuthnPassword()).parameters(new HashMap()).headers(linkedHashMap).method(HttpMethod.GET).build());
        if (execute != null) {
            return execute;
        }
        LOGGER.error("Unable to fetch metadata from [{}]", str);
        throw new UnauthorizedServiceException("screen.service.error.message");
    }

    private List<String> fromInCommon() {
        if (!StringUtils.isNotBlank(this.mgmtProperties.getInCommonMDQUrl())) {
            return new ArrayList(0);
        }
        Document parse = DocumentBuilderFactory.newInstance().newDocumentBuilder().parse(new InputSource(new StringReader(IOUtils.toString(fetchMetadata(this.mgmtProperties.getInCommonMDQUrl()).getEntity().getContent(), StandardCharsets.UTF_8))));
        ArrayList arrayList = new ArrayList();
        NodeList elementsByTagName = parse.getDocumentElement().getElementsByTagName("EntityDescriptor");
        for (int i = 0; i < elementsByTagName.getLength(); i++) {
            if (((Element) elementsByTagName.item(i)).getElementsByTagName("SPSSODescriptor").getLength() > 0) {
                arrayList.add(((Element) elementsByTagName.item(i)).getAttribute("entityID"));
            }
        }
        return arrayList;
    }
}
