package net.solarnetwork.ocpp.web.json;

import java.net.URI;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.time.Instant;
import java.util.Base64;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.BiFunction;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import net.solarnetwork.ocpp.dao.SystemUserDao;
import net.solarnetwork.ocpp.domain.ChargePointAuthorizationDetails;
import net.solarnetwork.ocpp.domain.SystemUser;
import net.solarnetwork.service.PasswordEncoder;
import net.solarnetwork.util.ObjectUtils;
import net.solarnetwork.util.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.http.HttpStatus;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.web.socket.SubProtocolCapable;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketHttpHeaders;
import org.springframework.web.socket.handler.WebSocketHandlerDecorator;
import org.springframework.web.socket.server.HandshakeInterceptor;

/* loaded from: input_file:net/solarnetwork/ocpp/web/json/OcppWebSocketHandshakeInterceptor.class */
public class OcppWebSocketHandshakeInterceptor implements HandshakeInterceptor {
    public static final String REQUEST_URI_ATTR = "requestUri";
    public static final String DEFAULT_CLIENT_ID_URI_PATTERN = "/ocpp/v16/cs/json/(.*)";
    public static final String CLIENT_ID_ATTR = "clientId";
    private static final Logger log = LoggerFactory.getLogger(OcppWebSocketHandshakeInterceptor.class);
    private final SystemUserDao systemUserDao;
    private final PasswordEncoder passwordEncoder;
    private Pattern clientIdUriPattern;
    private BiFunction<ServerHttpRequest, String, ChargePointAuthorizationDetails> clientCredentialsExtractor;
    private String fixedIdentityUsername;

    public OcppWebSocketHandshakeInterceptor(SystemUserDao systemUserDao, PasswordEncoder passwordEncoder) {
        this.systemUserDao = systemUserDao;
        this.passwordEncoder = passwordEncoder;
        setClientIdUriPattern(Pattern.compile(DEFAULT_CLIENT_ID_URI_PATTERN));
        this.clientCredentialsExtractor = this::extractBasicAuthentication;
    }

    public boolean beforeHandshake(ServerHttpRequest serverHttpRequest, ServerHttpResponse serverHttpResponse, WebSocketHandler webSocketHandler, Map<String, Object> map) throws Exception {
        List subProtocols;
        URI uri = serverHttpRequest.getURI();
        Matcher matcher = this.clientIdUriPattern.matcher(uri.getPath());
        if (!matcher.find()) {
            log.debug("OCPP handshake request rejected, client ID not found in URI path: {}", uri.getPath());
            serverHttpResponse.setStatusCode(HttpStatus.NOT_FOUND);
            didForbidChargerConnection(serverHttpRequest, null, null, String.format("Client identifier not provided in URL path [%s].", uri.getPath()));
            return false;
        }
        String group = matcher.group(1);
        SubProtocolCapable unwrap = WebSocketHandlerDecorator.unwrap(webSocketHandler);
        if ((unwrap instanceof SubProtocolCapable) && (subProtocols = unwrap.getSubProtocols()) != null && !subProtocols.isEmpty()) {
            List secWebSocketProtocol = new WebSocketHttpHeaders(serverHttpRequest.getHeaders()).getSecWebSocketProtocol();
            boolean z = false;
            if (secWebSocketProtocol != null) {
                Iterator it = secWebSocketProtocol.iterator();
                while (true) {
                    if (!it.hasNext()) {
                        break;
                    }
                    if (subProtocols.contains((String) it.next())) {
                        z = true;
                        break;
                    }
                }
            }
            if (!z) {
                log.debug("OCPP handshake request rejected, supported sub-protocol(s) {}, requested: {}", subProtocols, secWebSocketProtocol);
                serverHttpResponse.setStatusCode(HttpStatus.BAD_REQUEST);
                didForbidChargerConnection(serverHttpRequest, group, null, String.format("WebSocket sub-protocols [%s] provided but only [%s] supported.", StringUtils.commaDelimitedStringFromCollection(secWebSocketProtocol), StringUtils.commaDelimitedStringFromCollection(subProtocols)));
                return false;
            }
        }
        if (this.systemUserDao == null) {
            return true;
        }
        ChargePointAuthorizationDetails apply = this.clientCredentialsExtractor.apply(serverHttpRequest, group);
        if (apply == null) {
            log.warn("OCPP handshake request rejected for {}, invalid Authorization provided", group);
            serverHttpResponse.setStatusCode(HttpStatus.FORBIDDEN);
            return false;
        }
        String username = apply.getUsername();
        String password = apply.getPassword();
        SystemUser forUsernameAndChargePoint = this.systemUserDao.getForUsernameAndChargePoint(username, group);
        if (forUsernameAndChargePoint == null) {
            log.warn("OCPP handshake request rejected for {}, system user {} not found.", group, username);
            didForbidChargerConnection(serverHttpRequest, group, apply, String.format("System user [%s] not available, or not allowed for [%s].", username, group));
            serverHttpResponse.setStatusCode(HttpStatus.FORBIDDEN);
            return false;
        }
        Set allowedChargePoints = forUsernameAndChargePoint.getAllowedChargePoints();
        if (allowedChargePoints != null && !allowedChargePoints.isEmpty() && !allowedChargePoints.contains(group)) {
            log.warn("OCPP handshake request rejected for {}, system user {} does not allow identifier.", group, username);
            serverHttpResponse.setStatusCode(HttpStatus.FORBIDDEN);
            didForbidChargerConnection(serverHttpRequest, group, forUsernameAndChargePoint, String.format("System user [%s] does not allow identifier [%s]", username, group));
            return false;
        }
        if (forUsernameAndChargePoint.getPassword() == null || ((this.passwordEncoder != null && this.passwordEncoder.matches(password, forUsernameAndChargePoint.getPassword())) || forUsernameAndChargePoint.getPassword().equals(password))) {
            map.putIfAbsent(CLIENT_ID_ATTR, forUsernameAndChargePoint.chargePointIdentity(group));
            return true;
        }
        log.warn("OCPP handshake request rejected for {}, system user {} password does not match.", group, username);
        serverHttpResponse.setStatusCode(HttpStatus.FORBIDDEN);
        didForbidChargerConnection(serverHttpRequest, group, forUsernameAndChargePoint, String.format("System user [%s] password mismatch by identifier [%s].", username, group));
        return false;
    }

    protected void didForbidChargerConnection(ServerHttpRequest serverHttpRequest, String str, ChargePointAuthorizationDetails chargePointAuthorizationDetails, String str2) {
    }

    public ChargePointAuthorizationDetails extractBasicAuthentication(ServerHttpRequest serverHttpRequest, String str) {
        String first = serverHttpRequest.getHeaders().getFirst("Authorization");
        if (first == null) {
            log.warn("OCPP handshake request rejected for {}, Authorization header not provided.", str);
            didForbidChargerConnection(serverHttpRequest, str, null, "HTTP Authorization header not provided (no credentials provided).");
            return null;
        }
        String[] decodeBasicAuthorizationHeader = decodeBasicAuthorizationHeader(first);
        if (decodeBasicAuthorizationHeader != null) {
            return new SystemUser(Instant.now(), decodeBasicAuthorizationHeader[0], decodeBasicAuthorizationHeader[1]);
        }
        log.warn("OCPP handshake request rejected for {}, invalid Basic Authorization header provided: [{}]", str, first);
        didForbidChargerConnection(serverHttpRequest, str, null, "Invalid HTTP Basic Authorization header provided.");
        return null;
    }

    private static String[] decodeBasicAuthorizationHeader(String str) {
        Charset forName = Charset.forName("UTF-8");
        int indexOf = str.indexOf(32);
        if (indexOf < 0 || indexOf + 1 >= str.length()) {
            return null;
        }
        try {
            String str2 = new String(Base64.getDecoder().decode(str.substring(indexOf + 1).getBytes(forName)), forName);
            int indexOf2 = str2.indexOf(":");
            if (indexOf2 == -1) {
                return null;
            }
            return new String[]{str2.substring(0, indexOf2), str2.substring(indexOf2 + 1)};
        } catch (IllegalArgumentException e) {
            return null;
        }
    }

    public void afterHandshake(ServerHttpRequest serverHttpRequest, ServerHttpResponse serverHttpResponse, WebSocketHandler webSocketHandler, Exception exc) {
    }

    public Pattern getClientIdUriPattern() {
        return this.clientIdUriPattern;
    }

    public void setClientIdUriPattern(Pattern pattern) {
        this.clientIdUriPattern = (Pattern) ObjectUtils.requireNonNullArgument(pattern, "clientIdUriPattern");
    }

    public String getFixedIdentityUsername() {
        return this.fixedIdentityUsername;
    }

    public void setFixedIdentityUsername(String str) {
        this.fixedIdentityUsername = str;
    }

    public BiFunction<ServerHttpRequest, String, ChargePointAuthorizationDetails> getClientCredentialsExtractor() {
        return this.clientCredentialsExtractor;
    }

    public void setClientCredentialsExtractor(BiFunction<ServerHttpRequest, String, ChargePointAuthorizationDetails> biFunction) {
        this.clientCredentialsExtractor = (BiFunction) ObjectUtils.requireNonNullArgument(biFunction, "clientCredentialsExtractor");
    }
}
