package io.trino.plugin.postgresql.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
import io.trino.matching.Capture;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.plugin.base.expression.ConnectorExpressionPatterns;
import io.trino.plugin.base.projection.ProjectFunctionRule;
import io.trino.plugin.jdbc.JdbcColumnHandle;
import io.trino.plugin.jdbc.JdbcExpression;
import io.trino.plugin.jdbc.JdbcTypeHandle;
import io.trino.plugin.jdbc.expression.ParameterizedExpression;
import io.trino.spi.block.Block;
import io.trino.spi.expression.Call;
import io.trino.spi.expression.ConnectorExpression;
import io.trino.spi.expression.Constant;
import io.trino.spi.expression.FunctionName;
import io.trino.spi.expression.StandardFunctions;
import io.trino.spi.expression.Variable;
import io.trino.spi.type.ArrayType;
import io.trino.spi.type.DoubleType;
import io.trino.spi.type.RealType;
import io.trino.spi.type.Type;

import java.sql.Types;
import java.util.Objects;
import java.util.Optional;
import java.util.StringJoiner;

/**
 * Pushes a two-argument vector distance function down to PostgreSQL as a binary operator.
 *
 * <p>The rule matches a call to the configured function whose result type is {@code double} and whose
 * two arguments are both {@code array(real)} or {@code array(double)}. Each argument is translated by
 * {@link #rewrite(ConnectorExpression, ProjectFunctionRule.RewriteContext)}. The pair is rendered as
 * {@code <left> <operator> <right>} and typed as a JDBC {@code DOUBLE}. The operator text is supplied
 * by the caller (presumably a pgvector distance operator such as {@code <->}; confirm at the call site).
 */
public final class RewriteVectorDistanceFunction implements ProjectFunctionRule<JdbcExpression, ParameterizedExpression> {
    private static final Capture<ConnectorExpression> LEFT_ARGUMENT = Capture.newCapture();
    private static final Capture<ConnectorExpression> RIGHT_ARGUMENT = Capture.newCapture();

    private final Pattern<Call> pattern;
    private final String operator;

    /**
     * Creates a rule that rewrites {@code functionName(left, right)} into {@code left operator right}.
     *
     * @param functionName name of the engine-side function to match
     * @param operator remote SQL operator emitted between the two rewritten arguments
     * @throws NullPointerException if either argument is null
     */
    public RewriteVectorDistanceFunction(String functionName, String operator) {
        Objects.requireNonNull(functionName, "functionName is null");
        this.pattern = ConnectorExpressionPatterns.call()
                .with(ConnectorExpressionPatterns.functionName().equalTo(new FunctionName(functionName)))
                .with(ConnectorExpressionPatterns.type().matching(type -> type == DoubleType.DOUBLE))
                .with(ConnectorExpressionPatterns.argumentCount().equalTo(2))
                .with(ConnectorExpressionPatterns.argument(0).matching(ConnectorExpressionPatterns.expression()
                        .capturedAs(LEFT_ARGUMENT)
                        .with(ConnectorExpressionPatterns.type().matching(RewriteVectorDistanceFunction::isArrayTypeWithRealOrDouble))))
                .with(ConnectorExpressionPatterns.argument(1).matching(ConnectorExpressionPatterns.expression()
                        .capturedAs(RIGHT_ARGUMENT)
                        .with(ConnectorExpressionPatterns.type().matching(RewriteVectorDistanceFunction::isArrayTypeWithRealOrDouble))));
        this.operator = Objects.requireNonNull(operator, "operator is null");
    }

    @Override
    public Pattern<? extends ConnectorExpression> getPattern() {
        return pattern;
    }

    /**
     * Renders the matched call as {@code <left> <operator> <right>}.
     *
     * @return the remote expression, or empty when either argument cannot be pushed down
     *         (the right argument is not examined if the left one already failed)
     */
    @Override
    public Optional<JdbcExpression> rewrite(ConnectorExpression projectionExpression, Captures captures, ProjectFunctionRule.RewriteContext<ParameterizedExpression> context) {
        Optional<ParameterizedExpression> left = rewrite(captures.get(LEFT_ARGUMENT), context);
        if (left.isEmpty()) {
            return Optional.empty();
        }
        Optional<ParameterizedExpression> right = rewrite(captures.get(RIGHT_ARGUMENT), context);
        if (right.isEmpty()) {
            return Optional.empty();
        }
        return Optional.of(new JdbcExpression(
                "%s %s %s".formatted(left.get().expression(), operator, right.get().expression()),
                // Parameters must stay in left-then-right order to line up with the placeholders.
                ImmutableList.copyOf(Iterables.concat(left.get().parameters(), right.get().parameters())),
                new JdbcTypeHandle(Types.DOUBLE, Optional.of("double"), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty())));
    }

    /**
     * Translates one argument of a vector distance call into remote SQL.
     *
     * <ul>
     *   <li>A {@link Constant} real/double array is inlined as a literal {@code '[v1,v2,...]'}.</li>
     *   <li>A {@code CAST} applied directly to a column whose JDBC type name is {@code vector} becomes
     *       the quoted column name (the cast is dropped). Any other {@code CAST} is not pushed down.</li>
     *   <li>Anything else is delegated to {@code context.rewriteExpression}.</li>
     * </ul>
     *
     * @return the remote expression, or empty when the argument cannot be pushed down
     */
    public static Optional<ParameterizedExpression> rewrite(ConnectorExpression expression, ProjectFunctionRule.RewriteContext<ParameterizedExpression> context) {
        if (expression instanceof Constant constant) {
            return rewriteConstant(constant);
        }
        if (expression instanceof Call call && call.getFunctionName().equals(StandardFunctions.CAST_FUNCTION_NAME)) {
            return rewriteVectorColumnCast(call, context);
        }
        return context.rewriteExpression(expression);
    }

    /** Inlines a non-null, non-empty real/double array constant as a vector literal string. */
    private static Optional<ParameterizedExpression> rewriteConstant(Constant constant) {
        if (!isArrayTypeWithRealOrDouble(constant.getType())) {
            return Optional.empty();
        }
        // A NULL array literal carries no Block; leave it to the engine instead of failing with an NPE.
        if (!(constant.getValue() instanceof Block block)) {
            return Optional.empty();
        }
        // '[]' is not a useful remote literal (presumably rejected by pgvector as zero-dimensional).
        if (block.getPositionCount() == 0) {
            return Optional.empty();
        }
        Type elementType = ((ArrayType) constant.getType()).getElementType();
        StringJoiner vector = new StringJoiner(",", "'[", "]'");
        for (int position = 0; position < block.getPositionCount(); position++) {
            if (block.isNull(position)) {
                return Optional.empty();
            }
            double value = readElement(elementType, block, position);
            if (!isSupportedVector(value)) {
                return Optional.empty();
            }
            vector.add(Double.toString(value));
        }
        return Optional.of(new ParameterizedExpression(vector.toString(), ImmutableList.of()));
    }

    /** Rewrites {@code CAST(column)} to the bare quoted column when the column's remote type is {@code vector}. */
    private static Optional<ParameterizedExpression> rewriteVectorColumnCast(Call cast, ProjectFunctionRule.RewriteContext<ParameterizedExpression> context) {
        if (!(Iterables.getOnlyElement(cast.getArguments()) instanceof Variable variable)) {
            return Optional.empty();
        }
        JdbcColumnHandle column = (JdbcColumnHandle) context.getAssignment(variable.getName());
        boolean isVectorColumn = column.getJdbcTypeHandle().jdbcTypeName()
                .map("vector"::equals)
                .orElse(false);
        if (!isVectorColumn) {
            return Optional.empty();
        }
        return Optional.of(new ParameterizedExpression(quoted(column.getColumnName()), ImmutableList.of()));
    }

    /**
     * Reads one array element as a double.
     *
     * <p>REAL blocks hold the float's int bits and do not support {@code getDouble}, so the float is
     * decoded from {@code getLong} instead.
     */
    private static double readElement(Type elementType, Block block, int position) {
        if (elementType == RealType.REAL) {
            return Float.intBitsToFloat((int) elementType.getLong(block, position));
        }
        return elementType.getDouble(block, position);
    }

    /** Returns true when {@code type} is {@code array(real)} or {@code array(double)}. */
    public static boolean isArrayTypeWithRealOrDouble(Type type) {
        return type instanceof ArrayType arrayType
                && (arrayType.getElementType() == RealType.REAL || arrayType.getElementType() == DoubleType.DOUBLE);
    }

    /**
     * Returns whether {@code value} can be written as a vector component: finite and within the
     * range of a {@code float}.
     *
     * <p>Zero and negative components are valid. The previous lower bound of {@code Float.MIN_VALUE}
     * (the smallest positive float) rejected them by mistake.
     */
    private static boolean isSupportedVector(double value) {
        return Double.isFinite(value) && value >= -Float.MAX_VALUE && value <= Float.MAX_VALUE;
    }

    /** Quotes an identifier with double quotes, doubling any embedded double quote. */
    private static String quoted(String identifier) {
        return "\"" + identifier.replace("\"", "\"\"") + "\"";
    }
}
