package io.prestosql.plugin.sqlserver;

import com.google.common.base.Joiner;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableSet;
import io.airlift.slice.Slice;
import io.airlift.slice.Slices;
import io.prestosql.plugin.jdbc.BaseJdbcClient;
import io.prestosql.plugin.jdbc.BaseJdbcConfig;
import io.prestosql.plugin.jdbc.ColumnMapping;
import io.prestosql.plugin.jdbc.ConnectionFactory;
import io.prestosql.plugin.jdbc.JdbcColumnHandle;
import io.prestosql.plugin.jdbc.JdbcErrorCode;
import io.prestosql.plugin.jdbc.JdbcExpression;
import io.prestosql.plugin.jdbc.JdbcIdentity;
import io.prestosql.plugin.jdbc.JdbcTableHandle;
import io.prestosql.plugin.jdbc.JdbcTypeHandle;
import io.prestosql.plugin.jdbc.PredicatePushdownController;
import io.prestosql.plugin.jdbc.SliceWriteFunction;
import io.prestosql.plugin.jdbc.StandardColumnMappings;
import io.prestosql.plugin.jdbc.WriteMapping;
import io.prestosql.plugin.jdbc.expression.AggregateFunctionRewriter;
import io.prestosql.plugin.jdbc.expression.ImplementAvgDecimal;
import io.prestosql.plugin.jdbc.expression.ImplementAvgFloatingPoint;
import io.prestosql.plugin.jdbc.expression.ImplementCount;
import io.prestosql.plugin.jdbc.expression.ImplementCountAll;
import io.prestosql.plugin.jdbc.expression.ImplementMinMax;
import io.prestosql.plugin.jdbc.expression.ImplementSum;
import io.prestosql.spi.PrestoException;
import io.prestosql.spi.StandardErrorCode;
import io.prestosql.spi.connector.AggregateFunction;
import io.prestosql.spi.connector.ColumnHandle;
import io.prestosql.spi.connector.ConnectorSession;
import io.prestosql.spi.connector.SchemaTableName;
import io.prestosql.spi.predicate.Domain;
import io.prestosql.spi.type.BooleanType;
import io.prestosql.spi.type.CharType;
import io.prestosql.spi.type.DecimalType;
import io.prestosql.spi.type.Type;
import io.prestosql.spi.type.VarbinaryType;
import io.prestosql.spi.type.VarcharType;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.sql.Types;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.BiFunction;
import java.util.stream.Collectors;
import javax.inject.Inject;

/* loaded from: input_file:io/prestosql/plugin/sqlserver/SqlServerClient.class */
/**
 * JDBC client specialization for Microsoft SQL Server.
 *
 * <p>On top of {@link BaseJdbcClient} this class provides:
 * <ul>
 *   <li>table and column renames through the {@code sp_rename} procedure,</li>
 *   <li>table schema copies through {@code SELECT ... INTO ... WHERE 0 = 1},</li>
 *   <li>type mappings for {@code bit}, {@code nvarchar}, {@code nchar} and {@code varbinary},</li>
 *   <li>aggregation rewriting through an {@link AggregateFunctionRewriter},</li>
 *   <li>limit handling by rewriting {@code SELECT} into {@code SELECT TOP n}.</li>
 * </ul>
 */
public class SqlServerClient extends BaseJdbcClient {
    /** Domains with more ranges than this are simplified before being pushed down. */
    private static final int SQL_SERVER_MAX_LIST_EXPRESSIONS = 500;

    /**
     * Largest length written as a bounded {@code nvarchar(n)} / {@code nchar(n)};
     * longer types are written as {@code nvarchar(max)}.
     */
    private static final int MAX_BOUNDED_NATIONAL_CHAR_LENGTH = 4000;

    /** Prefix every query handed to the limit function is required to start with. */
    private static final String SELECT_PREFIX = "SELECT ";

    private static final Joiner DOT_JOINER = Joiner.on(".");

    /**
     * Pushdown controller that simplifies a domain when it has more than
     * {@link #SQL_SERVER_MAX_LIST_EXPRESSIONS} ranges.
     *
     * <p>The first result argument is the (possibly simplified) domain; the second is always the
     * untouched input domain. NOTE(review): the second argument is presumably the filter that is
     * still evaluated after pushdown, which is why it must not be the simplified domain; confirm
     * against {@code PredicatePushdownController.DomainPushdownResult}.
     */
    private static final PredicatePushdownController SIMPLIFY_UNSUPPORTED_PUSHDOWN = domain -> {
        Domain pushedDown = domain;
        if (domain.getValues().getRanges().getRangeCount() > SQL_SERVER_MAX_LIST_EXPRESSIONS) {
            pushedDown = domain.simplify();
        }
        return new PredicatePushdownController.DomainPushdownResult(pushedDown, domain);
    };

    private final AggregateFunctionRewriter aggregateFunctionRewriter;

    /**
     * Creates the client, using {@code "} as the identifier quote and registering the
     * aggregation rewrite rules.
     *
     * @param baseJdbcConfig base JDBC configuration, forwarded to {@link BaseJdbcClient}
     * @param connectionFactory connection factory, forwarded to {@link BaseJdbcClient}
     */
    @Inject
    public SqlServerClient(BaseJdbcConfig baseJdbcConfig, ConnectionFactory connectionFactory) {
        super(baseJdbcConfig, "\"", connectionFactory);
        // Type handle passed to the count rules.
        JdbcTypeHandle bigintTypeHandle = new JdbcTypeHandle(
                Types.BIGINT,
                Optional.of("bigint"),
                0,
                0,
                Optional.empty(),
                Optional.empty());
        // ImmutableSet.of lets the element type be inferred from the constructor parameter;
        // an untyped ImmutableSet.builder() chain would be inferred as a set of Object.
        this.aggregateFunctionRewriter = new AggregateFunctionRewriter(
                this::quoted,
                ImmutableSet.of(
                        new ImplementCountAll(bigintTypeHandle),
                        new ImplementCount(bigintTypeHandle),
                        new ImplementMinMax(),
                        new ImplementSum(SqlServerClient::toTypeHandle),
                        new ImplementAvgFloatingPoint(),
                        new ImplementAvgDecimal(),
                        new ImplementAvgBigint(),
                        new ImplementSqlServerStdev(),
                        new ImplementSqlServerStddevPop(),
                        new ImplementSqlServerVariance(),
                        new ImplementSqlServerVariancePop()));
    }

    /**
     * Renames a table within its schema by executing {@code sp_rename}.
     *
     * @param jdbcIdentity identity used to execute the statement
     * @param catalogName catalog of the existing table
     * @param schemaName schema of the existing table
     * @param tableName current table name
     * @param newTable target name; its schema must equal {@code schemaName}
     * @throws PrestoException with {@code NOT_SUPPORTED} when the target schema differs
     */
    protected void renameTable(JdbcIdentity jdbcIdentity, String catalogName, String schemaName, String tableName, SchemaTableName newTable) {
        if (!schemaName.equals(newTable.getSchemaName())) {
            throw new PrestoException(StandardErrorCode.NOT_SUPPORTED, "Table rename across schemas is not supported");
        }
        String sql = String.format(
                "sp_rename %s, %s",
                singleQuote(catalogName, schemaName, tableName),
                singleQuote(newTable.getTableName()));
        execute(jdbcIdentity, sql);
    }

    /**
     * Renames a column by executing {@code sp_rename ..., 'COLUMN'}.
     *
     * @param jdbcIdentity identity used to execute the statement
     * @param jdbcTableHandle table that owns the column
     * @param jdbcColumnHandle column to rename
     * @param newColumnName new column name
     */
    public void renameColumn(JdbcIdentity jdbcIdentity, JdbcTableHandle jdbcTableHandle, JdbcColumnHandle jdbcColumnHandle, String newColumnName) {
        String qualifiedColumn = singleQuote(
                jdbcTableHandle.getCatalogName(),
                jdbcTableHandle.getSchemaName(),
                jdbcTableHandle.getTableName(),
                jdbcColumnHandle.getColumnName());
        String sql = String.format("sp_rename %s, %s, 'COLUMN'", qualifiedColumn, singleQuote(newColumnName));
        execute(jdbcIdentity, sql);
    }

    /**
     * Creates an empty copy of a table's selected columns using
     * {@code SELECT <columns> INTO <new> FROM <source> WHERE 0 = 1}.
     *
     * @param connection connection the statement is executed on
     * @param catalogName catalog of both tables
     * @param schemaName schema of both tables
     * @param tableName source table
     * @param newTableName table to create
     * @param columnNames columns to copy, quoted individually
     */
    protected void copyTableSchema(Connection connection, String catalogName, String schemaName, String tableName, String newTableName, List<String> columnNames) {
        String columnList = columnNames.stream()
                .map(this::quoted)
                .collect(Collectors.joining(", "));
        String sql = String.format(
                "SELECT %s INTO %s FROM %s WHERE 0 = 1",
                columnList,
                quoted(catalogName, schemaName, newTableName),
                quoted(catalogName, schemaName, tableName));
        execute(connection, sql);
    }

    /**
     * Maps a JDBC type to a column mapping.
     *
     * <p>Resolution order: a forced varchar mapping if one applies, then the dedicated
     * {@code varbinary} mapping, then the superclass mapping re-wrapped with
     * {@link #SIMPLIFY_UNSUPPORTED_PUSHDOWN}.
     *
     * @throws PrestoException with {@code JDBC_ERROR} when the type handle carries no type name
     */
    public Optional<ColumnMapping> toPrestoType(ConnectorSession connectorSession, Connection connection, JdbcTypeHandle jdbcTypeHandle) {
        Optional<ColumnMapping> forcedMapping = getForcedMappingToVarchar(jdbcTypeHandle);
        if (forcedMapping.isPresent()) {
            return forcedMapping;
        }

        String jdbcTypeName = jdbcTypeHandle.getJdbcTypeName()
                .orElseThrow(() -> new PrestoException(JdbcErrorCode.JDBC_ERROR, "Type name is missing: " + jdbcTypeHandle));
        if (jdbcTypeName.equals("varbinary")) {
            return Optional.of(varbinaryColumnMapping());
        }

        return super.toPrestoType(connectorSession, connection, jdbcTypeHandle)
                .map(mapping -> new ColumnMapping(
                        mapping.getType(),
                        mapping.getReadFunction(),
                        mapping.getWriteFunction(),
                        SIMPLIFY_UNSUPPORTED_PUSHDOWN));
    }

    /**
     * Chooses the SQL Server data type and write function for an engine type.
     *
     * <p>Boolean maps to {@code bit}; varchar and char map to {@code nvarchar(n)} /
     * {@code nchar(n)}, or {@code nvarchar(max)} when unbounded or longer than
     * {@link #MAX_BOUNDED_NATIONAL_CHAR_LENGTH}; varbinary maps to {@code varbinary(max)}.
     * Every other type is delegated to the superclass.
     */
    public WriteMapping toWriteMapping(ConnectorSession connectorSession, Type type) {
        if (type == BooleanType.BOOLEAN) {
            return WriteMapping.booleanMapping("bit", StandardColumnMappings.booleanWriteFunction());
        }
        if (type instanceof VarcharType) {
            VarcharType varcharType = (VarcharType) type;
            String dataType;
            if (varcharType.isUnbounded() || varcharType.getBoundedLength() > MAX_BOUNDED_NATIONAL_CHAR_LENGTH) {
                dataType = "nvarchar(max)";
            }
            else {
                dataType = "nvarchar(" + varcharType.getBoundedLength() + ")";
            }
            return WriteMapping.sliceMapping(dataType, StandardColumnMappings.varcharWriteFunction());
        }
        if (type instanceof CharType) {
            CharType charType = (CharType) type;
            String dataType;
            if (charType.getLength() > MAX_BOUNDED_NATIONAL_CHAR_LENGTH) {
                dataType = "nvarchar(max)";
            }
            else {
                dataType = "nchar(" + charType.getLength() + ")";
            }
            return WriteMapping.sliceMapping(dataType, StandardColumnMappings.charWriteFunction());
        }
        if (type instanceof VarbinaryType) {
            return WriteMapping.sliceMapping("varbinary(max)", varbinaryWriteFunction());
        }
        return super.toWriteMapping(connectorSession, type);
    }

    /**
     * Delegates aggregation rewriting to the rule set registered in the constructor.
     *
     * @return the rewritten expression, or empty when the rewriter produces none
     */
    public Optional<JdbcExpression> implementAggregation(ConnectorSession connectorSession, AggregateFunction aggregateFunction, Map<String, ColumnHandle> map) {
        return this.aggregateFunctionRewriter.rewrite(connectorSession, aggregateFunction, map);
    }

    /** Builds a {@code decimal} type handle carrying the given type's precision and scale. */
    private static Optional<JdbcTypeHandle> toTypeHandle(DecimalType decimalType) {
        return Optional.of(new JdbcTypeHandle(
                Types.NUMERIC,
                Optional.of("decimal"),
                decimalType.getPrecision(),
                decimalType.getScale(),
                Optional.empty(),
                Optional.empty()));
    }

    /**
     * Returns a function that applies a row limit by turning {@code SELECT ...} into
     * {@code SELECT TOP <limit> ...}. The function rejects SQL that does not start with
     * {@code "SELECT "}.
     */
    protected Optional<BiFunction<String, Long, String>> limitFunction() {
        return Optional.of((sql, limit) -> {
            Preconditions.checkArgument(sql.startsWith(SELECT_PREFIX));
            return "SELECT TOP " + limit + " " + sql.substring(SELECT_PREFIX.length());
        });
    }

    /** Always reports the limit as guaranteed. */
    public boolean isLimitGuaranteed(ConnectorSession connectorSession) {
        return true;
    }

    /** Joins the parts with {@code .} and returns them as one single-quoted T-SQL literal. */
    private static String singleQuote(String... parts) {
        return singleQuote(DOT_JOINER.join(parts));
    }

    /**
     * Wraps a value in single quotes for use as a T-SQL string literal.
     *
     * <p>Embedded single quotes are doubled so that a name containing {@code '} cannot terminate
     * the literal early and alter the statement.
     */
    private static String singleQuote(String literal) {
        return "'" + literal.replace("'", "''") + "'";
    }

    /**
     * Column mapping for varbinary values: reads with {@code ResultSet.getBytes}, writes with
     * {@link #varbinaryWriteFunction()}, and has pushdown disabled.
     */
    public static ColumnMapping varbinaryColumnMapping() {
        return ColumnMapping.sliceMapping(
                VarbinaryType.VARBINARY,
                (resultSet, columnIndex) -> Slices.wrappedBuffer(resultSet.getBytes(columnIndex)),
                varbinaryWriteFunction(),
                ColumnMapping.DISABLE_PUSHDOWN);
    }

    /** Write function that binds slices, and nulls, through {@code PreparedStatement.setBytes}. */
    private static SliceWriteFunction varbinaryWriteFunction() {
        return new SliceWriteFunction() {
            public void set(PreparedStatement statement, int index, Slice value) throws SQLException {
                statement.setBytes(index, value.getBytes());
            }

            // Null is bound with setBytes(index, null) rather than PreparedStatement.setNull.
            public void setNull(PreparedStatement statement, int index) throws SQLException {
                statement.setBytes(index, null);
            }
        };
    }
}
