/*
 * Decompiled with CFR 0.152.
 */
package com.predic8.membrane.core.interceptor.xmlprotection;

import com.predic8.membrane.annot.MCAttribute;
import com.predic8.membrane.annot.MCElement;
import com.predic8.membrane.core.exceptions.ProblemDetails;
import com.predic8.membrane.core.exchange.Exchange;
import com.predic8.membrane.core.interceptor.AbstractInterceptor;
import com.predic8.membrane.core.interceptor.Interceptor;
import com.predic8.membrane.core.interceptor.Outcome;
import com.predic8.membrane.core.interceptor.xmlprotection.XMLProtector;
import java.io.ByteArrayOutputStream;
import java.io.InputStreamReader;
import java.io.OutputStream;
import java.io.OutputStreamWriter;
import java.nio.charset.StandardCharsets;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Request-flow interceptor that screens XML request bodies with an {@link XMLProtector}.
 *
 * <p>Handling per request:
 * <ul>
 *   <li>an empty body is passed through without scanning;</li>
 *   <li>a body whose message is not XML (per {@code isXML()}) aborts the exchange with a
 *       "user" problem-details response;</li>
 *   <li>an XML body is streamed through an {@link XMLProtector}; if the protector rejects it, the
 *       exchange is aborted with a "security" problem-details response plus an
 *       {@value #X_PROTECTION} header, otherwise the request body is replaced by the protector's
 *       output.</li>
 * </ul>
 */
@MCElement(name="xmlProtection")
public class XMLProtectionInterceptor extends AbstractInterceptor {

    private static final Logger log = LoggerFactory.getLogger(XMLProtectionInterceptor.class);

    /** Response header added when a request is rejected by the XML security policy. */
    public static final String X_PROTECTION = "X-Protection";

    /** Attribute-count limit handed to {@link XMLProtector}; defaults to 1000. */
    private int maxAttributeCount = 1000;
    /** Element-name-length limit handed to {@link XMLProtector}; defaults to 1000. */
    private int maxElementNameLength = 1000;
    /** DTD-removal flag handed to {@link XMLProtector}; defaults to {@code true}. */
    private boolean removeDTD = true;

    /** Creates the interceptor, restricted to the request flow. */
    public XMLProtectionInterceptor() {
        this.name = "xml protection";
        setFlow(Interceptor.Flow.Set.REQUEST_FLOW);
    }

    /**
     * Scans the request body. Any exception raised while doing so is logged and converted into a
     * "user" problem-details response.
     *
     * @param exc the current exchange
     * @return {@link Outcome#CONTINUE} if the request may proceed, {@link Outcome#ABORT} otherwise
     */
    @Override
    public Outcome handleRequest(Exchange exc) {
        try {
            return handleInternal(exc);
        } catch (Exception e) {
            // Previously logged with an empty message, which gave no context in the log.
            log.error("Error inspecting request body for XML protection.", e);
            ProblemDetails.user(router.isProduction(), getDisplayName())
                    .detail("Error inspecting body!")
                    .exception(e)
                    .buildAndSetResponse(exc);
            return Outcome.ABORT;
        }
    }

    /**
     * Applies the empty-body, content-type, and XML-protection checks in that order.
     *
     * @param exc the current exchange
     * @return the outcome for the exchange
     * @throws Exception if reading or rewriting the body fails
     */
    private Outcome handleInternal(Exchange exc) throws Exception {
        if (exc.getRequest().isBodyEmpty()) {
            log.info("body is empty -> request is not scanned");
            return Outcome.CONTINUE;
        }
        if (!exc.getRequest().isXML()) {
            String msg = "Content-Type %s was not XML.".formatted(exc.getRequest().getHeader().getContentType());
            log.warn(msg);
            ProblemDetails.user(router.isProduction(), getDisplayName())
                    .title("Request discarded by xmlProtection")
                    .detail(msg)
                    .buildAndSetResponse(exc);
            return Outcome.ABORT;
        }
        if (!protectXML(exc)) {
            String msg = "Request was rejected by XML protection. Please check XML.";
            log.warn(msg);
            ProblemDetails.security(router.isProduction(), getDisplayName())
                    .title("Content violates XML security policy")
                    .detail(msg)
                    .buildAndSetResponse(exc);
            // Added after buildAndSetResponse so it lands on the response that call installed.
            exc.getResponse().getHeader().add(X_PROTECTION, "Content violates XML security policy");
            return Outcome.ABORT;
        }
        log.debug("protected against XML attacks");
        return Outcome.CONTINUE;
    }

    /**
     * Streams the decoded request body through an {@link XMLProtector} and, on success, replaces the
     * request body with the protector's output (encoded in the request charset).
     *
     * @param exc the current exchange
     * @return {@code false} if the protector rejected the body (body left untouched), else {@code true}
     * @throws Exception if the charset is unsupported or reading/writing the body fails
     */
    private boolean protectXML(Exchange exc) throws Exception {
        String charset = getCharset(exc);
        ByteArrayOutputStream buffer = new ByteArrayOutputStream();
        // OutputStreamWriter buffers encoded characters internally. This class never flushed it
        // before calling toByteArray(), so the result depended on XMLProtector flushing. Closing
        // the writer (end of this try) guarantees everything has reached `buffer`.
        try (OutputStreamWriter out = new OutputStreamWriter(buffer, charset)) {
            XMLProtector protector = new XMLProtector(out, removeDTD, maxElementNameLength, maxAttributeCount);
            try (InputStreamReader in = new InputStreamReader(exc.getRequest().getBodyAsStreamDecoded(), charset)) {
                if (!protector.protect(in)) {
                    return false;
                }
            }
        }
        exc.getRequest().setBodyContent(buffer.toByteArray());
        return true;
    }

    /**
     * Returns the request's declared charset, falling back to UTF-8 when none is set.
     *
     * @param exc the current exchange
     * @return a charset name, never {@code null}
     */
    private String getCharset(Exchange exc) {
        String charset = exc.getRequest().getCharset();
        return charset != null ? charset : StandardCharsets.UTF_8.name();
    }

    /**
     * Sets the attribute-count limit passed to the {@link XMLProtector} (default 1000).
     *
     * @param maxAttributeCount the limit
     */
    @MCAttribute
    public void setMaxAttributeCount(int maxAttributeCount) {
        this.maxAttributeCount = maxAttributeCount;
    }

    /** @return the attribute-count limit passed to the {@link XMLProtector} */
    public int getMaxAttributeCount() {
        return maxAttributeCount;
    }

    /**
     * Sets the element-name-length limit passed to the {@link XMLProtector} (default 1000).
     *
     * @param maxElementNameLength the limit
     */
    @MCAttribute
    public void setMaxElementNameLength(int maxElementNameLength) {
        this.maxElementNameLength = maxElementNameLength;
    }

    /** @return the element-name-length limit passed to the {@link XMLProtector} */
    public int getMaxElementNameLength() {
        return maxElementNameLength;
    }

    /**
     * Sets the DTD-removal flag passed to the {@link XMLProtector} (default {@code true}).
     *
     * @param removeDTD the flag
     */
    @MCAttribute
    public void setRemoveDTD(boolean removeDTD) {
        this.removeDTD = removeDTD;
    }

    /** @return the DTD-removal flag passed to the {@link XMLProtector} */
    public boolean isRemoveDTD() {
        return removeDTD;
    }

    @Override
    public String getShortDescription() {
        return "Protects against XML attacks.";
    }
}

