package tech.ydb.yoj.repository.ydb.statement;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterables;
import com.google.common.collect.Sets;
import com.google.common.collect.Streams;
import java.lang.reflect.Type;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import tech.ydb.proto.ValueProtos;
import tech.ydb.yoj.databind.expression.FilterExpression;
import tech.ydb.yoj.databind.expression.OrderExpression;
import tech.ydb.yoj.databind.schema.ObjectSchema;
import tech.ydb.yoj.databind.schema.Schema;
import tech.ydb.yoj.repository.db.Entity;
import tech.ydb.yoj.repository.db.EntitySchema;
import tech.ydb.yoj.repository.db.TableDescriptor;
import tech.ydb.yoj.repository.ydb.statement.Statement;
import tech.ydb.yoj.repository.ydb.yql.YqlListingQuery;
import tech.ydb.yoj.repository.ydb.yql.YqlPredicate;
import tech.ydb.yoj.repository.ydb.yql.YqlType;

/* loaded from: input_file:tech/ydb/yoj/repository/ydb/statement/FindInStatement.class */
/**
 * YQL statement that looks up rows by a collection of primary-key ids or secondary-index keys.
 *
 * <p>The generated query joins the {@code $Input} list ({@code AS_TABLE($Input) AS k}) with the table
 * ({@code AS t}, optionally read through {@code VIEW <index>}) on the key fields that are non-null in every
 * input value. When a filter is given, the join is wrapped in a subquery and the filter clause is appended
 * after it; an optional {@code ORDER BY} and {@code LIMIT} follow.
 *
 * @param <IN>     statement input type (the factories produce a {@code Set} of ids or keys)
 * @param <T>      entity type
 * @param <RESULT> type of the rows produced by the query
 */
public final class FindInStatement<IN, T extends Entity<T>, RESULT> extends MultipleVarsYqlStatement<IN, T, RESULT> {
    private static final Logger log = LoggerFactory.getLogger(FindInStatement.class);

    /** Secondary index to read through, or {@code null} to read the table itself. */
    private final String indexName;
    /** Schema used to flatten the input ids/keys into query variables. */
    private final Schema<?> keySchema;
    /** Names of the columns the input rows are joined on. */
    private final Set<String> keyFields;
    /** Filter applied on top of the join result, or {@code null}. */
    private final PredicateClause<T> predicate;
    /** Result ordering, or {@code null}. */
    private final OrderExpression<T> orderBy;
    /** Maximum number of returned rows, or {@code null} for no limit. */
    private final Integer limit;

    /**
     * Adapter that reuses the {@link PredicateStatement} machinery to render a filter as a YQL clause
     * together with its parameter declarations and values. Its own {@link #getQuery(String)} is only a
     * placeholder ({@code SELECT 1}); the clause text is embedded into the enclosing statement's query.
     */
    public static class PredicateClause<T extends Entity<T>> extends PredicateStatement<Class<Void>, T, T> {
        private final YqlPredicate predicate;

        public PredicateClause(TableDescriptor<T> tableDescriptor, EntitySchema<T> entitySchema, YqlPredicate yqlPredicate) {
            super(tableDescriptor, entitySchema, entitySchema, Void.class, ignored -> yqlPredicate);
            this.predicate = yqlPredicate;
        }

        @Override
        public Statement.QueryType getQueryType() {
            return Statement.QueryType.UNTYPED;
        }

        /** Returns the filter rendered as YQL with parameter names resolved, followed by a newline. */
        public String getClause() {
            return resolveParamNames(predicate.toFullYql(schema)) + "\n";
        }

        @Override
        public String getQuery(String tablespace) {
            return "SELECT 1";
        }

        /** Debug representation of the filter. */
        public String toDebugString() {
            return toDebugString(Void.TYPE);
        }

        /** Query parameter values referenced by the filter clause. */
        public Map<String, ValueProtos.TypedValue> toQueryParameters() {
            return toQueryParameters(Void.TYPE);
        }

        @Override
        public String toDebugString(Class<Void> unused) {
            return predicate.toString();
        }
    }

    /**
     * Creates a statement that finds rows by primary-key ids.
     *
     * <p>Ids may leave trailing key fields {@code null}, but every id must have nulls in the same fields.
     * A query that does not use a prefix of the primary key only logs a warning (it results in a full scan).
     *
     * @throws IllegalArgumentException if {@code ids} is empty, ids have nulls in different fields,
     *                                  or all id fields are null
     */
    public static <ID extends Entity.Id<T>, T extends Entity<T>, RESULT> FindInStatement<Set<ID>, T, RESULT> from(
            TableDescriptor<T> tableDescriptor,
            EntitySchema<T> schema,
            Schema<RESULT> outSchema,
            Iterable<ID> ids,
            @Nullable FilterExpression<T> filter,
            @Nullable OrderExpression<T> orderBy,
            @Nullable Integer limit
    ) {
        Schema<Entity.Id<T>> idSchema = schema.getIdSchema();
        Set<String> keyFields = collectKeyFieldsFromIds(idSchema, ids);
        return new FindInStatement<>(tableDescriptor, schema, outSchema, idSchema, keyFields, null, filter, orderBy, limit);
    }

    /**
     * Creates a statement that finds rows by keys of the global index {@code indexName}.
     *
     * <p>The key schema is derived from the runtime class of the first key. The non-null key fields must
     * form a prefix of the index fields and have the same Java types as the corresponding entity columns.
     *
     * @throws IllegalArgumentException if {@code keys} is empty, keys have nulls in different fields, all key
     *                                  fields are null, the index does not exist, or the key fields do not match it
     */
    public static <K, T extends Entity<T>, RESULT> FindInStatement<Set<K>, T, RESULT> from(
            TableDescriptor<T> tableDescriptor,
            EntitySchema<T> schema,
            Schema<RESULT> outSchema,
            String indexName,
            Iterable<K> keys,
            @Nullable FilterExpression<T> filter,
            @Nullable OrderExpression<T> orderBy,
            @Nullable Integer limit
    ) {
        Schema<K> keySchema = getKeySchemaFromValues(keys);
        Set<String> keyFields = collectKeyFieldsFromKeys(tableDescriptor, schema, indexName, keySchema, keys);
        return new FindInStatement<>(tableDescriptor, schema, outSchema, keySchema, keyFields, indexName, filter, orderBy, limit);
    }

    private <PARAMS> FindInStatement(
            TableDescriptor<T> tableDescriptor,
            EntitySchema<T> schema,
            Schema<RESULT> outSchema,
            Schema<PARAMS> keySchema,
            Set<String> keyFields,
            @Nullable String indexName,
            @Nullable FilterExpression<T> filter,
            @Nullable OrderExpression<T> orderBy,
            @Nullable Integer limit
    ) {
        super(tableDescriptor, schema, outSchema);
        this.indexName = indexName;
        this.orderBy = orderBy;
        this.limit = limit;
        this.keySchema = keySchema;
        this.keyFields = keyFields;
        this.predicate = filter != null
                ? new PredicateClause<>(tableDescriptor, schema, YqlListingQuery.toYqlPredicate(filter))
                : null;
        validateOrderByFields();
    }

    /**
     * Validates primary-key ids and returns the id fields they all have non-null.
     * Logs a warning when those fields are not a prefix of the primary key.
     */
    private static <T extends Entity<T>> Set<String> collectKeyFieldsFromIds(
            Schema<Entity.Id<T>> idSchema,
            Iterable<? extends Entity.Id<T>> ids
    ) {
        Preconditions.checkArgument(!Iterables.isEmpty(ids), "ids should be non empty");
        Set<String> keyFields = commonNonNullFieldNames(idSchema, ids, "ids");

        List<String> primaryKeyFields = idSchema.flattenFieldNames();
        if (!isPrefixedFields(primaryKeyFields, keyFields)) {
            log.warn("FindIn(ids) not by the primary key prefix will result in a FullScan, PK: {}, query uses the fields: {}",
                    primaryKeyFields, keyFields);
        }
        return keyFields;
    }

    /**
     * Returns the set of non-null flattened field names shared by every value.
     *
     * @param what noun used in error messages ({@code "ids"} or {@code "keys"})
     * @throws IllegalArgumentException if there are no values, values have nulls in different fields,
     *                                  or every field is null
     */
    private static <V> Set<String> commonNonNullFieldNames(Schema<V> schema, Iterable<? extends V> values, String what) {
        Set<Set<String>> distinctFieldSets = Streams.stream(values)
                .map(value -> nonNullKeyFieldNames(schema, value))
                .collect(Collectors.toUnmodifiableSet());
        Preconditions.checkArgument(!distinctFieldSets.isEmpty(), "%s should have at least one non-null field", what);
        Preconditions.checkArgument(distinctFieldSets.size() == 1, "%s should have nulls in the same fields", what);

        Set<String> fields = Iterables.getOnlyElement(distinctFieldSets);
        // An empty field set would render a join with an empty ON condition, i.e. invalid YQL.
        Preconditions.checkArgument(!fields.isEmpty(), "%s should have at least one non-null field", what);
        return fields;
    }

    /**
     * Builds an {@link ObjectSchema} from the runtime class of the first key.
     *
     * @throws IllegalArgumentException if there are no keys (or the first key is {@code null})
     */
    private static <V> Schema<V> getKeySchemaFromValues(Iterable<V> keys) {
        V first = Iterables.getFirst(keys, null);
        Preconditions.checkArgument(first != null, "keys should be non empty");
        // getClass() returns Class<? extends V>; the cast only restores the static type.
        @SuppressWarnings("unchecked")
        Class<V> keyType = (Class<V>) first.getClass();
        return ObjectSchema.of(keyType);
    }

    /**
     * Validates index keys against the index definition and entity column types.
     *
     * @return the first {@code n} index fields in index order, where {@code n} is the number of non-null key fields
     */
    private static <E extends Entity<E>, K> Set<String> collectKeyFieldsFromKeys(
            TableDescriptor<E> tableDescriptor,
            Schema<E> schema,
            String indexName,
            Schema<K> keySchema,
            Iterable<K> keys
    ) {
        Set<String> keyFields = commonNonNullFieldNames(keySchema, keys, "keys");

        Schema.Index index = schema.getGlobalIndexes().stream()
                .filter(candidate -> indexName.equals(candidate.getIndexName()))
                .findAny()
                .orElseThrow(() -> new IllegalArgumentException(
                        "Table `%s` doesn't have index `%s`".formatted(tableDescriptor.toDebugString(), indexName)));
        List<String> indexFields = index.getFieldNames();

        Set<String> unknownFields = Sets.difference(keyFields, Set.copyOf(indexFields));
        Preconditions.checkArgument(unknownFields.isEmpty(),
                "Index `%s` of table `%s` doesn't contain key(s): [%s]",
                indexName, tableDescriptor.toDebugString(), String.join(", ", unknownFields));
        Preconditions.checkArgument(isPrefixedFields(indexFields, keyFields),
                "FindIn(keys) is allowed only by the prefix of the index key fields, index key: %s, query uses the fields: %s",
                indexFields, keyFields);

        Map<String, Type> keyFieldTypes = getKeyFieldTypeMap(keySchema, keyFields);
        Map<String, Type> columnTypes = getKeyFieldTypeMap(schema, keyFields);
        for (Map.Entry<String, Type> keyField : keyFieldTypes.entrySet()) {
            Type columnType = columnTypes.get(keyField.getKey());
            Preconditions.checkArgument(columnType.equals(keyField.getValue()),
                    "Table `%s` has column `%s` of type `%s`, but corresponding key field is `%s`",
                    tableDescriptor.toDebugString(), keyField.getKey(), columnType, keyField.getValue());
        }

        return indexFields.stream()
                .limit(keyFields.size())
                .collect(Collectors.toCollection(LinkedHashSet::new));
    }

    /** Returns the names of the fields present in {@code schema.flatten(value)}. */
    public static <V> Set<String> nonNullKeyFieldNames(Schema<V> schema, V value) {
        return schema.flatten(value).keySet();
    }

    /**
     * Returns {@code true} if the first {@code candidates.size()} entries of {@code fields} are all contained
     * in {@code candidates}; {@code false} if {@code candidates} is larger than {@code fields}.
     */
    private static boolean isPrefixedFields(List<String> fields, Set<String> candidates) {
        if (candidates.size() > fields.size()) {
            return false;
        }
        return fields.subList(0, candidates.size()).stream().allMatch(candidates::contains);
    }

    /** Maps each flattened field of {@code schema} whose name is in {@code fieldNames} to its Java type. */
    private static Map<String, Type> getKeyFieldTypeMap(Schema<?> schema, Set<String> fieldNames) {
        return schema.flattenFields().stream()
                .filter(field -> fieldNames.contains(field.getName()))
                .collect(Collectors.toUnmodifiableMap(Schema.JavaField::getName, Schema.JavaField::getType));
    }

    /**
     * When ordering is requested and the result schema differs from the entity schema, checks that every
     * ORDER BY field is part of the result schema.
     */
    private void validateOrderByFields() {
        if (!hasOrderBy() || schema.equals(resultSchema)) {
            return;
        }
        Set<String> resultFields = resultSchema.flattenFields().stream()
                .map(Schema.JavaField::getName)
                .collect(Collectors.toUnmodifiableSet());
        List<String> missingFields = orderBy.getKeys().stream()
                .map(key -> key.getField())
                .flatMap(field -> field.flatten())
                .map(field -> field.getName())
                .filter(name -> !resultFields.contains(name))
                .toList();
        Preconditions.checkArgument(missingFields.isEmpty(),
                "Result schema of '%s' does not contain field(s): [%s] by which the result is ordered: %s",
                resultSchema.getTypeName(), String.join(", ", missingFields), orderBy);
    }

    @Override
    public Statement.QueryType getQueryType() {
        return Statement.QueryType.SELECT;
    }

    /**
     * Renders the query:
     * <pre>
     * &lt;declarations&gt;
     * SELECT &lt;result columns&gt;
     * [FROM (SELECT &lt;all entity columns&gt;]     -- only with a filter
     * FROM AS_TABLE($Input) AS k
     * JOIN &lt;table&gt; [VIEW &lt;index&gt;] AS t
     * ON t.f1 = k.f1 AND ...
     * [)]                                        -- only with a filter
     * [&lt;filter clause&gt;][&lt;order by&gt;][LIMIT n]
     * </pre>
     */
    @Override
    public String getQuery(String tablespace) {
        return declarations()
                + "SELECT " + outNames() + "\n"
                + (hasPredicate() ? "FROM (\nSELECT " + allColumnNames() + "\n" : "")
                + "FROM AS_TABLE($Input) AS k\n"
                + "JOIN " + table(tablespace) + indexUsage() + " AS t\n"
                + "ON " + joinExpression() + "\n"
                + (hasPredicate() ? ")\n" : "")
                + predicateClause()
                + orderByClause()
                + limitClause();
    }

    /** One required parameter per entity column that is used as a join key. */
    @Override
    public List<YqlStatementParam> getParams() {
        return schema.flattenFields().stream()
                .filter(field -> keyFields.contains(field.getName()))
                .map(field -> YqlStatementParam.required(YqlType.of(field), field.getName()))
                .toList();
    }

    /** Input list parameters, plus the filter's parameters when a filter is present. */
    @Override
    public Map<String, ValueProtos.TypedValue> toQueryParameters(IN in) {
        Map<String, ValueProtos.TypedValue> inputParams = super.toQueryParameters(in);
        if (!hasPredicate()) {
            return inputParams;
        }
        // Explicit type witness: without it builder() infers <Object, Object> and the result does not type-check.
        return ImmutableMap.<String, ValueProtos.TypedValue>builder()
                .putAll(inputParams)
                .putAll(predicate.toQueryParameters())
                .build();
    }

    @Override
    public String declarations() {
        return super.declarations() + predicateClauseDeclarations();
    }

    @Override
    public String outNames() {
        return resultSchema.flattenFields().stream()
                .map(this::getOutName)
                .collect(Collectors.joining(", "));
    }

    /** All entity columns as {@code t.<col> AS <col>}, used inside the filtered subquery. */
    private String allColumnNames() {
        return schema.flattenFields().stream()
                .map(this::getAliasedName)
                .collect(Collectors.joining(", "));
    }

    /** With a filter the outer SELECT reads subquery columns by name; otherwise it reads {@code t.} columns directly. */
    private String getOutName(Schema.JavaField field) {
        return hasPredicate() ? escape(field.getName()) : getAliasedName(field);
    }

    private String getAliasedName(Schema.JavaField field) {
        String column = escape(field.getName());
        return "t." + column + " AS " + column;
    }

    @Override
    protected Function<IN, Map<String, Object>> flattenInputVariables() {
        // keySchema is stored as Schema<?>; this unchecked cast is erased at runtime and only lets the
        // method reference type-check. NOTE(review): confirm against MultipleVarsYqlStatement's contract.
        @SuppressWarnings("unchecked")
        Schema<IN> inputSchema = (Schema<IN>) keySchema;
        return inputSchema::flatten;
    }

    private String indexUsage() {
        return isFindByIndex() ? " VIEW " + escape(indexName) : "";
    }

    /** {@code t.<f> = k.<f>} for every key field, joined with {@code AND}. */
    private String joinExpression() {
        return keyFields.stream()
                .map(field -> "t.%1$s = k.%1$s".formatted(escape(field)))
                .collect(Collectors.joining(" AND "));
    }

    private String orderByClause() {
        return hasOrderBy() ? YqlListingQuery.toYqlOrderBy(orderBy).toFullYql(schema) + "\n" : "";
    }

    private String limitClause() {
        return hasLimit() ? "LIMIT " + limit + "\n" : "";
    }

    private String predicateClauseDeclarations() {
        return hasPredicate() ? predicate.declarations() : "";
    }

    private String predicateClause() {
        return hasPredicate() ? predicate.getClause() : "";
    }

    @Override
    public String toDebugString(IN in) {
        StringBuilder debug = new StringBuilder("findIn(").append(toDebugParams(in));
        if (isFindByIndex()) {
            debug.append(" by index ").append(escape(indexName));
        }
        if (hasPredicate()) {
            debug.append(", filter [").append(predicate.toDebugString()).append("]");
        }
        if (hasOrderBy()) {
            debug.append(", orderBy [").append(orderBy).append("]");
        }
        if (hasLimit()) {
            debug.append(", limit [").append(limit).append("]");
        }
        return debug.append(")").toString();
    }

    private boolean isFindByIndex() {
        return indexName != null;
    }

    private boolean hasLimit() {
        return limit != null;
    }

    private boolean hasOrderBy() {
        return orderBy != null;
    }

    private boolean hasPredicate() {
        return predicate != null;
    }
}
