package io.substrait.expression.proto;

import io.substrait.expression.Expression;
import io.substrait.expression.ExpressionCreator;
import io.substrait.expression.FieldReference;
import io.substrait.expression.FunctionArg;
import io.substrait.expression.ImmutableExpression;
import io.substrait.expression.WindowBound;
import io.substrait.extension.ExtensionLookup;
import io.substrait.extension.SimpleExtension;
import io.substrait.proto.Expression;
import io.substrait.relation.ProtoRelConverter;
import io.substrait.relation.Rel;
import io.substrait.type.Type;
import io.substrait.type.proto.ProtoTypeConverter;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:io/substrait/expression/proto/ProtoExpressionConverter.class */
public class ProtoExpressionConverter {
    static final Logger logger = LoggerFactory.getLogger(ProtoExpressionConverter.class);
    public static final Type.Struct EMPTY_TYPE = Type.Struct.builder().nullable(false).build();
    private final ExtensionLookup lookup;
    private final SimpleExtension.ExtensionCollection extensions;
    private final Type.Struct rootType;
    private final ProtoTypeConverter protoTypeConverter;
    private final ProtoRelConverter protoRelConverter;

    public ProtoExpressionConverter(ExtensionLookup extensionLookup, SimpleExtension.ExtensionCollection extensionCollection, Type.Struct struct, ProtoRelConverter protoRelConverter) {
        this.lookup = extensionLookup;
        this.extensions = extensionCollection;
        this.rootType = (Type.Struct) Objects.requireNonNull(struct, "rootType");
        this.protoTypeConverter = new ProtoTypeConverter(extensionLookup, extensionCollection);
        this.protoRelConverter = protoRelConverter;
    }

    public FieldReference from(Expression.FieldReference fieldReference) {
        FieldReference newRootStructOuterReference;
        Object of;
        switch (fieldReference.getReferenceTypeCase()) {
            case DIRECT_REFERENCE:
                Expression.ReferenceSegment directReference = fieldReference.getDirectReference();
                ArrayList arrayList = new ArrayList();
                while (directReference != Expression.ReferenceSegment.getDefaultInstance()) {
                    switch (directReference.getReferenceTypeCase()) {
                        case MAP_KEY:
                            Expression.ReferenceSegment.MapKey mapKey = directReference.getMapKey();
                            directReference = mapKey.getChild();
                            of = FieldReference.MapKey.of(from(mapKey.getMapKey()));
                            break;
                        case STRUCT_FIELD:
                            Expression.ReferenceSegment.StructField structField = directReference.getStructField();
                            directReference = structField.getChild();
                            of = FieldReference.StructField.of(structField.getField());
                            break;
                        case LIST_ELEMENT:
                            Expression.ReferenceSegment.ListElement listElement = directReference.getListElement();
                            directReference = listElement.getChild();
                            of = FieldReference.ListElement.of(listElement.getOffset());
                            break;
                        case REFERENCETYPE_NOT_SET:
                            throw new IllegalArgumentException("Unhandled type: " + directReference.getReferenceTypeCase());
                        default:
                            throw new IncompatibleClassChangeError();
                    }
                    arrayList.add(of);
                }
                Collections.reverse(arrayList);
                switch (fieldReference.getRootTypeCase()) {
                    case EXPRESSION:
                        newRootStructOuterReference = FieldReference.ofExpression(from(fieldReference.getExpression()), arrayList);
                        break;
                    case ROOT_REFERENCE:
                        newRootStructOuterReference = FieldReference.ofRoot(this.rootType, arrayList);
                        break;
                    case OUTER_REFERENCE:
                        newRootStructOuterReference = FieldReference.newRootStructOuterReference(fieldReference.getDirectReference().getStructField().getField(), this.rootType, fieldReference.getOuterReference().getStepsOut());
                        break;
                    case ROOTTYPE_NOT_SET:
                        throw new IllegalArgumentException("Unhandled type: " + fieldReference.getRootTypeCase());
                    default:
                        throw new IncompatibleClassChangeError();
                }
                return newRootStructOuterReference;
            case MASKED_REFERENCE:
                throw new IllegalArgumentException("Unsupported type: " + fieldReference.getReferenceTypeCase());
            default:
                throw new IllegalArgumentException("Unhandled type: " + fieldReference.getReferenceTypeCase());
        }
    }

    public io.substrait.expression.Expression from(Expression expression) {
        switch (expression.getRexTypeCase()) {
            case LITERAL:
                return from(expression.getLiteral());
            case SELECTION:
                return from(expression.getSelection());
            case SCALAR_FUNCTION:
                Expression.ScalarFunction scalarFunction = expression.getScalarFunction();
                SimpleExtension.ScalarFunctionVariant scalarFunction2 = this.lookup.getScalarFunction(scalarFunction.getFunctionReference(), this.extensions);
                FunctionArg.ProtoFrom protoFrom = new FunctionArg.ProtoFrom(this, this.protoTypeConverter);
                return ImmutableExpression.ScalarFunctionInvocation.builder().addAllArguments((List) IntStream.range(0, scalarFunction.getArgumentsCount()).mapToObj(i -> {
                    return protoFrom.convert(scalarFunction2, i, scalarFunction.getArguments(i));
                }).collect(Collectors.toList())).declaration(scalarFunction2).outputType(this.protoTypeConverter.from(scalarFunction.getOutputType())).build();
            case WINDOW_FUNCTION:
                Expression.WindowFunction windowFunction = expression.getWindowFunction();
                SimpleExtension.WindowFunctionVariant windowFunction2 = this.lookup.getWindowFunction(windowFunction.getFunctionReference(), this.extensions);
                FunctionArg.ProtoFrom protoFrom2 = new FunctionArg.ProtoFrom(this, this.protoTypeConverter);
                List list = (List) IntStream.range(0, windowFunction.getArgumentsCount()).mapToObj(i2 -> {
                    return protoFrom2.convert(windowFunction2, i2, windowFunction.getArguments(i2));
                }).collect(Collectors.toList());
                List list2 = (List) windowFunction.getPartitionsList().stream().map(this::from).collect(Collectors.toList());
                return Expression.WindowFunctionInvocation.builder().arguments(list).declaration(windowFunction2).outputType(this.protoTypeConverter.from(windowFunction.getOutputType())).aggregationPhase(Expression.AggregationPhase.fromProto(windowFunction.getPhase())).partitionBy(list2).sort((List) windowFunction.getSortsList().stream().map(sortField -> {
                    return Expression.SortField.builder().direction(Expression.SortDirection.fromProto(sortField.getDirection())).expr(from(sortField.getExpr())).build();
                }).collect(Collectors.toList())).boundsType(Expression.WindowBoundsType.fromProto(windowFunction.getBoundsType())).lowerBound(toWindowBound(windowFunction.getLowerBound())).upperBound(toWindowBound(windowFunction.getUpperBound())).invocation(Expression.AggregationInvocation.fromProto(windowFunction.getInvocation())).build();
            case IF_THEN:
                Expression.IfThen ifThen = expression.getIfThen();
                return ExpressionCreator.ifThenStatement(from(ifThen.getElse()), (List) ifThen.getIfsList().stream().map(ifClause -> {
                    return ExpressionCreator.ifThenClause(from(ifClause.getIf()), from(ifClause.getThen()));
                }).collect(Collectors.toList()));
            case SWITCH_EXPRESSION:
                Expression.SwitchExpression switchExpression = expression.getSwitchExpression();
                return ExpressionCreator.switchStatement(from(switchExpression.getMatch()), from(switchExpression.getElse()), (List) switchExpression.getIfsList().stream().map(ifValue -> {
                    return ExpressionCreator.switchClause(from(ifValue.getIf()), from(ifValue.getThen()));
                }).collect(Collectors.toList()));
            case SINGULAR_OR_LIST:
                Expression.SingularOrList singularOrList = expression.getSingularOrList();
                return ImmutableExpression.SingleOrList.builder().condition(from(singularOrList.getValue())).addAllOptions((List) singularOrList.getOptionsList().stream().map(this::from).collect(Collectors.toList())).build();
            case MULTI_OR_LIST:
                Expression.MultiOrList multiOrList = expression.getMultiOrList();
                return ImmutableExpression.MultiOrList.builder().addAllOptionCombinations((List) multiOrList.getOptionsList().stream().map(record -> {
                    return ImmutableExpression.MultiOrListRecord.builder().addAllValues((Iterable) record.getFieldsList().stream().map(this::from).collect(Collectors.toList())).build();
                }).collect(Collectors.toList())).addAllConditions((Iterable) multiOrList.getValueList().stream().map(this::from).collect(Collectors.toList())).build();
            case CAST:
                return ExpressionCreator.cast(this.protoTypeConverter.from(expression.getCast().getType()), from(expression.getCast().getInput()));
            case SUBQUERY:
                switch (expression.getSubquery().getSubqueryTypeCase()) {
                    case SET_PREDICATE:
                        return ImmutableExpression.SetPredicate.builder().tuples(this.protoRelConverter.from(expression.getSubquery().getSetPredicate().getTuples())).predicateOp(Expression.PredicateOp.fromProto(expression.getSubquery().getSetPredicate().getPredicateOp())).build();
                    case SCALAR:
                        Rel from = this.protoRelConverter.from(expression.getSubquery().getScalar().getInput());
                        return ImmutableExpression.ScalarSubquery.builder().input(from).type(from.getRecordType()).build();
                    case IN_PREDICATE:
                        return ImmutableExpression.InPredicate.builder().haystack(this.protoRelConverter.from(expression.getSubquery().getInPredicate().getHaystack())).needles((List) expression.getSubquery().getInPredicate().getNeedlesList().stream().map(expression2 -> {
                            return from(expression2);
                        }).collect(Collectors.toList())).build();
                    case SET_COMPARISON:
                        throw new UnsupportedOperationException("Unsupported subquery type: " + expression.getSubquery().getSubqueryTypeCase());
                    default:
                        throw new IllegalArgumentException("Unknown subquery type: " + expression.getSubquery().getSubqueryTypeCase());
                }
            case ENUM:
                throw new UnsupportedOperationException("Unsupported type: " + expression.getRexTypeCase());
            default:
                throw new IllegalArgumentException("Unknown type: " + expression.getRexTypeCase());
        }
    }

    private WindowBound toWindowBound(Expression.WindowFunction.Bound bound) {
        switch (bound.getKindCase()) {
            case PRECEDING:
                return WindowBound.Preceding.of(bound.getPreceding().getOffset());
            case FOLLOWING:
                return WindowBound.Following.of(bound.getFollowing().getOffset());
            case CURRENT_ROW:
                return WindowBound.CURRENT_ROW;
            case UNBOUNDED:
                return WindowBound.UNBOUNDED;
            case KIND_NOT_SET:
                return WindowBound.UNBOUNDED;
            default:
                throw new IncompatibleClassChangeError();
        }
    }

    public Expression.Literal from(Expression.Literal literal) {
        switch (literal.getLiteralTypeCase()) {
            case BOOLEAN:
                return ExpressionCreator.bool(literal.getNullable(), literal.getBoolean());
            case I8:
                return ExpressionCreator.i8(literal.getNullable(), literal.getI8());
            case I16:
                return ExpressionCreator.i16(literal.getNullable(), literal.getI16());
            case I32:
                return ExpressionCreator.i32(literal.getNullable(), literal.getI32());
            case I64:
                return ExpressionCreator.i64(literal.getNullable(), literal.getI64());
            case FP32:
                return ExpressionCreator.fp32(literal.getNullable(), literal.getFp32());
            case FP64:
                return ExpressionCreator.fp64(literal.getNullable(), literal.getFp64());
            case STRING:
                return ExpressionCreator.string(literal.getNullable(), literal.getString());
            case BINARY:
                return ExpressionCreator.binary(literal.getNullable(), literal.getBinary());
            case TIMESTAMP:
                return ExpressionCreator.timestamp(literal.getNullable(), literal.getTimestamp());
            case DATE:
                return ExpressionCreator.date(literal.getNullable(), literal.getDate());
            case TIME:
                return ExpressionCreator.time(literal.getNullable(), literal.getTime());
            case INTERVAL_YEAR_TO_MONTH:
                return ExpressionCreator.intervalYear(literal.getNullable(), literal.getIntervalYearToMonth().getYears(), literal.getIntervalYearToMonth().getMonths());
            case INTERVAL_DAY_TO_SECOND:
                return ExpressionCreator.intervalDay(literal.getNullable(), literal.getIntervalDayToSecond().getDays(), literal.getIntervalDayToSecond().getSeconds(), literal.getIntervalDayToSecond().getMicroseconds());
            case FIXED_CHAR:
                return ExpressionCreator.fixedChar(literal.getNullable(), literal.getFixedChar());
            case VAR_CHAR:
                return ExpressionCreator.varChar(literal.getNullable(), literal.getVarChar().getValue(), literal.getVarChar().getLength());
            case FIXED_BINARY:
                return ExpressionCreator.fixedBinary(literal.getNullable(), literal.getFixedBinary());
            case DECIMAL:
                return ExpressionCreator.decimal(literal.getNullable(), literal.getDecimal().getValue(), literal.getDecimal().getPrecision(), literal.getDecimal().getScale());
            case STRUCT:
                return ExpressionCreator.struct(literal.getNullable(), (Iterable<? extends Expression.Literal>) literal.getStruct().getFieldsList().stream().map(this::from).collect(Collectors.toList()));
            case MAP:
                return ExpressionCreator.map(literal.getNullable(), (Map) literal.getMap().getKeyValuesList().stream().collect(Collectors.toMap(keyValue -> {
                    return from(keyValue.getKey());
                }, keyValue2 -> {
                    return from(keyValue2.getValue());
                })));
            case TIMESTAMP_TZ:
                return ExpressionCreator.timestampTZ(literal.getNullable(), literal.getTimestampTz());
            case UUID:
                return ExpressionCreator.uuid(literal.getNullable(), literal.getUuid());
            case NULL:
                return ExpressionCreator.typedNull(this.protoTypeConverter.from(literal.getNull()));
            case LIST:
                return ExpressionCreator.list(literal.getNullable(), (Iterable<? extends Expression.Literal>) literal.getList().getValuesList().stream().map(this::from).collect(Collectors.toList()));
            default:
                throw new IllegalStateException("Unexpected value: " + literal.getLiteralTypeCase());
        }
    }
}
