package io.trino.plugin.postgresql.rule;

import com.google.common.collect.ImmutableList;
import io.trino.matching.Capture;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.plugin.base.expression.ConnectorExpressionPatterns;
import io.trino.plugin.base.projection.ProjectFunctionRule;
import io.trino.plugin.jdbc.JdbcExpression;
import io.trino.plugin.jdbc.JdbcTypeHandle;
import io.trino.plugin.jdbc.expression.ParameterizedExpression;
import io.trino.spi.expression.Call;
import io.trino.spi.expression.ConnectorExpression;
import io.trino.spi.expression.FunctionName;
import io.trino.spi.type.DoubleType;

import java.sql.Types;
import java.util.Optional;

/* loaded from: input_file:io/trino/plugin/postgresql/rule/RewriteDotProductFunction.class */
/**
 * Projection pushdown rule that rewrites {@code $negate(dot_product(a, b))} into the PostgreSQL
 * expression {@code a <#> b}.
 *
 * <p>The rule matches a unary {@code $negate} call of type {@code DOUBLE} whose single argument is
 * a two-argument {@code dot_product} call. Both arguments of that call must satisfy
 * {@link RewriteVectorDistanceFunction#isArrayTypeWithRealOrDouble}. pgvector documents {@code <#>}
 * as the <em>negative</em> inner product, which is why only the negated form is pushed down.
 */
public final class RewriteDotProductFunction implements ProjectFunctionRule<JdbcExpression, ParameterizedExpression> {
    private static final FunctionName NEGATE = new FunctionName("$negate");
    private static final FunctionName DOT_PRODUCT = new FunctionName("dot_product");

    // Captures the inner dot_product(...) call so rewrite() can reach its two operands.
    private static final Capture<ConnectorExpression> CALL = Capture.newCapture();

    private static final Pattern<Call> PATTERN = ConnectorExpressionPatterns.call()
            .with(ConnectorExpressionPatterns.functionName().equalTo(NEGATE))
            .with(ConnectorExpressionPatterns.type().matching(type -> type == DoubleType.DOUBLE))
            .with(ConnectorExpressionPatterns.argumentCount().equalTo(1))
            .with(ConnectorExpressionPatterns.argument(0).matching(
                    ConnectorExpressionPatterns.expression()
                            .capturedAs(CALL)
                            .matching(RewriteDotProductFunction::isVectorDotProduct)));

    @Override
    public Pattern<? extends ConnectorExpression> getPattern() {
        return PATTERN;
    }

    /**
     * Builds the {@code left <#> right} SQL fragment for the captured {@code dot_product} call.
     *
     * @param connectorExpression the matched {@code $negate(...)} expression (the operands are read
     *     from the capture instead)
     * @param captures captures produced by {@link #getPattern()}; must contain {@code CALL}
     * @param rewriteContext context passed through to the operand rewrites
     * @return the pushed-down expression typed as JDBC {@code DOUBLE}, or empty if either operand
     *     cannot be rewritten
     */
    @Override
    public Optional<JdbcExpression> rewrite(ConnectorExpression connectorExpression, Captures captures, ProjectFunctionRule.RewriteContext<ParameterizedExpression> rewriteContext) {
        var operands = captures.get(CALL).getChildren();

        Optional<ParameterizedExpression> left = RewriteVectorDistanceFunction.rewrite(operands.get(0), rewriteContext);
        if (left.isEmpty()) {
            return Optional.empty();
        }
        // Only attempt the right operand once the left one succeeded (same order as before).
        Optional<ParameterizedExpression> right = RewriteVectorDistanceFunction.rewrite(operands.get(1), rewriteContext);
        if (right.isEmpty()) {
            return Optional.empty();
        }

        return Optional.of(new JdbcExpression(
                "%s <#> %s".formatted(left.get().expression(), right.get().expression()),
                concat(left.get().parameters(), right.get().parameters()),
                new JdbcTypeHandle(Types.DOUBLE, Optional.of("double"), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty())));
    }

    /**
     * Returns true when {@code expression} is a {@code dot_product} call with exactly two
     * arguments, each of whose type is accepted by
     * {@link RewriteVectorDistanceFunction#isArrayTypeWithRealOrDouble}.
     */
    private static boolean isVectorDotProduct(ConnectorExpression expression) {
        return expression instanceof Call call
                && call.getFunctionName().equals(DOT_PRODUCT)
                && call.getArguments().size() == 2
                && call.getArguments().stream()
                        .allMatch(argument -> RewriteVectorDistanceFunction.isArrayTypeWithRealOrDouble(argument.getType()));
    }

    /**
     * Concatenates two parameter lists into an immutable list, left first.
     *
     * <p>The order matches the order in which the operands appear in the generated SQL text. Being
     * generic lets the element type be inferred, unlike the untyped {@code ImmutableList.builder()},
     * which would yield {@code ImmutableList<Object>}.
     */
    private static <T> ImmutableList<T> concat(Iterable<? extends T> first, Iterable<? extends T> second) {
        return ImmutableList.<T>builder().addAll(first).addAll(second).build();
    }
}
