package io.substrait.isthmus;

import io.substrait.expression.AggregateFunctionInvocation;
import io.substrait.expression.Expression;
import io.substrait.expression.ExpressionCreator;
import io.substrait.expression.FieldReference;
import io.substrait.extension.SimpleExtension;
import io.substrait.isthmus.expression.AggregateFunctionConverter;
import io.substrait.isthmus.expression.CallConverters;
import io.substrait.isthmus.expression.LiteralConverter;
import io.substrait.isthmus.expression.RexExpressionConverter;
import io.substrait.isthmus.expression.ScalarFunctionConverter;
import io.substrait.isthmus.expression.WindowFunctionConverter;
import io.substrait.relation.Aggregate;
import io.substrait.relation.Cross;
import io.substrait.relation.EmptyScan;
import io.substrait.relation.Fetch;
import io.substrait.relation.Filter;
import io.substrait.relation.ImmutableFetch;
import io.substrait.relation.ImmutableMeasure;
import io.substrait.relation.ImmutableSort;
import io.substrait.relation.Join;
import io.substrait.relation.NamedScan;
import io.substrait.relation.Project;
import io.substrait.relation.Rel;
import io.substrait.relation.Set;
import io.substrait.relation.Sort;
import io.substrait.relation.VirtualTableScan;
import io.substrait.type.NamedStruct;
import io.substrait.type.Type;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.calcite.rel.RelFieldCollation;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.RelRoot;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.core.TableFunctionScan;
import org.apache.calcite.rel.core.TableScan;
import org.apache.calcite.rel.logical.LogicalAggregate;
import org.apache.calcite.rel.logical.LogicalCalc;
import org.apache.calcite.rel.logical.LogicalCorrelate;
import org.apache.calcite.rel.logical.LogicalExchange;
import org.apache.calcite.rel.logical.LogicalFilter;
import org.apache.calcite.rel.logical.LogicalIntersect;
import org.apache.calcite.rel.logical.LogicalJoin;
import org.apache.calcite.rel.logical.LogicalMatch;
import org.apache.calcite.rel.logical.LogicalMinus;
import org.apache.calcite.rel.logical.LogicalProject;
import org.apache.calcite.rel.logical.LogicalSort;
import org.apache.calcite.rel.logical.LogicalTableModify;
import org.apache.calcite.rel.logical.LogicalUnion;
import org.apache.calcite.rel.logical.LogicalValues;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexFieldAccess;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.util.ImmutableBitSet;
import org.immutables.value.Value;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@Value.Enclosing
/**
 * Converts a Calcite {@link RelNode} tree into a Substrait {@link Rel} tree.
 *
 * <p>Each supported logical Calcite node has its own {@code visit} overload. Nodes without a
 * conversion reach {@link #visitOther(RelNode)}, which throws {@link UnsupportedOperationException}.
 * Row expressions are translated by a {@link RexExpressionConverter} that holds a back-reference to
 * this visitor.
 *
 * <p>Instances carry mutable state: the project visitor temporarily configures the shared
 * expression converter, and {@link #convert} fills in the correlated field-access depth map.
 * Nothing here synchronizes that state.
 */
public class SubstraitRelVisitor extends RelNodeVisitor<Rel, RuntimeException> {
    static final Logger logger = LoggerFactory.getLogger(SubstraitRelVisitor.class);
    private static final FeatureBoard FEATURES_DEFAULT = ImmutableFeatureBoard.builder().build();
    /** Non-nullable boolean literal {@code true}, used to detect cross joins. */
    private static final Expression.BoolLiteral TRUE = ExpressionCreator.bool(false, true);
    private final RexExpressionConverter converter;
    private final AggregateFunctionConverter aggregateFunctionConverter;
    private final TypeConverter typeConverter;
    private final FeatureBoard featureBoard;
    /** Set by {@link #popFieldAccessDepthMap(RelNode)}; {@code null} until that runs. */
    private Map<RexFieldAccess, Integer> fieldAccessDepthMap;

    /** How an inner join whose condition is the literal {@code true} is emitted. */
    public enum CrossJoinPolicy {
        KEEP_AS_CROSS_JOIN,
        CONVERT_TO_INNER_JOIN
    }

    /** Conversion options; currently only carries a {@link CrossJoinPolicy}. */
    public static class Options {
        private final CrossJoinPolicy crossJoinPolicy;

        /** Creates options that use {@link CrossJoinPolicy#CONVERT_TO_INNER_JOIN}. */
        public Options() {
            this(CrossJoinPolicy.CONVERT_TO_INNER_JOIN);
        }

        public Options(CrossJoinPolicy crossJoinPolicy) {
            this.crossJoinPolicy = crossJoinPolicy;
        }

        public CrossJoinPolicy getCrossJoinPolicy() {
            return this.crossJoinPolicy;
        }
    }

    /** Creates a visitor that uses the default {@link FeatureBoard}. */
    public SubstraitRelVisitor(RelDataTypeFactory relDataTypeFactory, SimpleExtension.ExtensionCollection extensionCollection) {
        this(relDataTypeFactory, extensionCollection, FEATURES_DEFAULT);
    }

    /**
     * Creates a visitor whose scalar, aggregate and window converters are built from
     * {@code extensionCollection}, using {@link TypeConverter#DEFAULT}.
     */
    public SubstraitRelVisitor(RelDataTypeFactory relDataTypeFactory, SimpleExtension.ExtensionCollection extensionCollection, FeatureBoard featureBoard) {
        this.typeConverter = TypeConverter.DEFAULT;
        // NOTE(review): raw list as recovered from bytecode; its elements are call converters for
        // RexExpressionConverter (presumably List<CallConverter> - confirm before adding generics).
        ArrayList callConverters = new ArrayList();
        callConverters.addAll(CallConverters.defaults(this.typeConverter));
        callConverters.add(new ScalarFunctionConverter(extensionCollection.scalarFunctions(), relDataTypeFactory));
        callConverters.add(CallConverters.CREATE_SEARCH_CONV.apply(new RexBuilder(relDataTypeFactory)));
        this.aggregateFunctionConverter = new AggregateFunctionConverter(extensionCollection.aggregateFunctions(), relDataTypeFactory);
        WindowFunctionConverter windowFunctionConverter = new WindowFunctionConverter(extensionCollection.windowFunctions(), relDataTypeFactory, this.aggregateFunctionConverter, this.typeConverter);
        this.converter = new RexExpressionConverter(this, callConverters, windowFunctionConverter, this.typeConverter);
        this.featureBoard = featureBoard;
    }

    /** Creates a visitor from caller-supplied converters. */
    public SubstraitRelVisitor(RelDataTypeFactory relDataTypeFactory, ScalarFunctionConverter scalarFunctionConverter, AggregateFunctionConverter aggregateFunctionConverter, WindowFunctionConverter windowFunctionConverter, TypeConverter typeConverter, FeatureBoard featureBoard) {
        // NOTE(review): raw list as recovered from bytecode (see the other constructor).
        ArrayList callConverters = new ArrayList();
        callConverters.addAll(CallConverters.defaults(typeConverter));
        callConverters.add(scalarFunctionConverter);
        callConverters.add(CallConverters.CREATE_SEARCH_CONV.apply(new RexBuilder(relDataTypeFactory)));
        this.aggregateFunctionConverter = aggregateFunctionConverter;
        this.converter = new RexExpressionConverter(this, callConverters, windowFunctionConverter, typeConverter);
        this.typeConverter = typeConverter;
        this.featureBoard = featureBoard;
    }

    /** Translates a Calcite row expression with the shared expression converter. */
    private Expression toExpression(RexNode rexNode) {
        return (Expression) rexNode.accept(this.converter);
    }

    /** Emits a {@link NamedScan} carrying the table's schema and qualified name. */
    @Override
    public Rel visit(TableScan tableScan) {
        return NamedScan.builder()
                .initialSchema(this.typeConverter.toNamedStruct(tableScan.getRowType()))
                .addAllNames(tableScan.getTable().getQualifiedName())
                .build();
    }

    @Override
    public Rel visit(TableFunctionScan tableFunctionScan) {
        return super.visit(tableFunctionScan);
    }

    /**
     * Emits an {@link EmptyScan} for a tuple-less VALUES node. Otherwise emits a
     * {@link VirtualTableScan} with one non-nullable struct literal per tuple.
     */
    @Override
    public Rel visit(LogicalValues logicalValues) {
        NamedStruct namedStruct = this.typeConverter.toNamedStruct(logicalValues.getRowType());
        if (logicalValues.getTuples().isEmpty()) {
            return EmptyScan.builder().initialSchema(namedStruct).build();
        }
        LiteralConverter literalConverter = new LiteralConverter(this.typeConverter);
        return VirtualTableScan.builder()
                .addAllDfsNames(namedStruct.names())
                .addAllRows(logicalValues.getTuples().stream()
                        .map(tuple -> {
                            var fields = tuple.stream()
                                    .map(literal -> literalConverter.convert(literal))
                                    .collect(Collectors.toUnmodifiableList());
                            return ExpressionCreator.struct(false, fields);
                        })
                        .collect(Collectors.toUnmodifiableList()))
                .build();
    }

    /** Emits a {@link Filter}. The condition is converted before the input, as in the original. */
    @Override
    public Rel visit(LogicalFilter logicalFilter) {
        Expression condition = toExpression(logicalFilter.getCondition());
        Rel input = apply(logicalFilter.getInput());
        return Filter.builder().condition(condition).input(input).build();
    }

    @Override
    public Rel visit(LogicalCalc logicalCalc) {
        return super.visit(logicalCalc);
    }

    /**
     * Emits a {@link Project}. The input is converted exactly once; converting it again would
     * repeat the work for every nested subtree.
     */
    @Override
    public Rel visit(LogicalProject logicalProject) {
        RelNode calciteInput = logicalProject.getInput();
        Rel input = apply(calciteInput);
        // The expression converter resolves input references against this rel and type.
        this.converter.setInputRel(calciteInput);
        this.converter.setInputType(input.getRecordType());
        List<Expression> expressions;
        try {
            expressions = logicalProject.getProjects().stream()
                    .map(this::toExpression)
                    .collect(Collectors.toList());
        } finally {
            // Always clear the temporary converter state, even when conversion fails.
            this.converter.setInputRel(null);
            this.converter.setInputType(null);
        }
        // Remap.offset(inputFieldCount, n) presumably keeps only the n expression outputs that
        // follow the pass-through input fields - matches Calcite project output; confirm.
        return Project.builder()
                .remap(Rel.Remap.offset(calciteInput.getRowType().getFieldCount(), expressions.size()))
                .expressions(expressions)
                .input(input)
                .build();
    }

    /**
     * Emits a {@link Join}. When the join is INNER, the condition is literal {@code true}, and the
     * feature board requests {@link CrossJoinPolicy#KEEP_AS_CROSS_JOIN}, a {@link Cross} is emitted
     * instead.
     *
     * @throws UnsupportedOperationException for join types with no Substrait mapping
     */
    @Override
    public Rel visit(LogicalJoin logicalJoin) {
        Rel left = apply(logicalJoin.getLeft());
        Rel right = apply(logicalJoin.getRight());
        Expression condition = toExpression(logicalJoin.getCondition());
        Join.JoinType joinType = toSubstraitJoinType(logicalJoin.getJoinType());
        boolean keepAsCross = joinType == Join.JoinType.INNER
                && TRUE.equals(condition)
                && this.featureBoard.crossJoinPolicy() == CrossJoinPolicy.KEEP_AS_CROSS_JOIN;
        if (keepAsCross) {
            return Cross.builder().left(left).right(right).build();
        }
        return Join.builder().condition(condition).joinType(joinType).left(left).right(right).build();
    }

    /** Maps a Calcite join type to its Substrait counterpart; FULL maps to OUTER. */
    private static Join.JoinType toSubstraitJoinType(JoinRelType joinRelType) {
        switch (joinRelType) {
            case INNER:
                return Join.JoinType.INNER;
            case LEFT:
                return Join.JoinType.LEFT;
            case RIGHT:
                return Join.JoinType.RIGHT;
            case FULL:
                return Join.JoinType.OUTER;
            case SEMI:
                return Join.JoinType.SEMI;
            case ANTI:
                return Join.JoinType.ANTI;
            default:
                throw new UnsupportedOperationException("Unsupported join type: " + joinRelType);
        }
    }

    /**
     * Validates a correlated join, then defers to the base visitor's handling.
     *
     * @throws IllegalArgumentException if the join type is neither INNER nor LEFT
     */
    @Override
    public Rel visit(LogicalCorrelate logicalCorrelate) {
        // NOTE(review): both inputs are converted and the results discarded, presumably so that
        // conversion failures inside them surface; kept to preserve behavior.
        apply(logicalCorrelate.getLeft());
        apply(logicalCorrelate.getRight());
        JoinRelType joinType = logicalCorrelate.getJoinType();
        if (joinType != JoinRelType.INNER && joinType != JoinRelType.LEFT) {
            throw new IllegalArgumentException("Invalid correlated join type: " + joinType);
        }
        return super.visit(logicalCorrelate);
    }

    @Override
    public Rel visit(LogicalUnion logicalUnion) {
        return Set.builder()
                .inputs(apply(logicalUnion.getInputs()))
                .setOp(logicalUnion.all ? Set.SetOp.UNION_ALL : Set.SetOp.UNION_DISTINCT)
                .build();
    }

    @Override
    public Rel visit(LogicalIntersect logicalIntersect) {
        return Set.builder()
                .inputs(apply(logicalIntersect.getInputs()))
                .setOp(logicalIntersect.all ? Set.SetOp.INTERSECTION_MULTISET : Set.SetOp.INTERSECTION_PRIMARY)
                .build();
    }

    @Override
    public Rel visit(LogicalMinus logicalMinus) {
        return Set.builder()
                .inputs(apply(logicalMinus.getInputs()))
                .setOp(logicalMinus.all ? Set.SetOp.MINUS_MULTISET : Set.SetOp.MINUS_PRIMARY)
                .build();
    }

    /**
     * Emits an {@link Aggregate}. Uses {@code groupSets} when present, otherwise the single group
     * set. Null group sets are skipped.
     */
    @Override
    public Rel visit(LogicalAggregate logicalAggregate) {
        Rel input = apply(logicalAggregate.getInput());
        Stream<ImmutableBitSet> groupSets = logicalAggregate.groupSets != null
                ? logicalAggregate.groupSets.stream()
                : Stream.of(logicalAggregate.getGroupSet());
        List<Aggregate.Grouping> groupings = groupSets
                .filter(groupSet -> groupSet != null)
                .map(groupSet -> fromGroupSet(groupSet, input))
                .collect(Collectors.toList());
        List<Aggregate.Measure> measures = logicalAggregate.getAggCallList().stream()
                .map(call -> fromAggCall(logicalAggregate.getInput(), input.getRecordType(), call))
                .collect(Collectors.toList());
        return Aggregate.builder()
                .input(input)
                .addAllGroupings(groupings)
                .addAllMeasures(measures)
                .build();
    }

    /** Builds a grouping with one input field reference per set bit of {@code immutableBitSet}. */
    Aggregate.Grouping fromGroupSet(ImmutableBitSet immutableBitSet, Rel rel) {
        return Aggregate.Grouping.builder()
                .addAllExpressions(immutableBitSet.asList().stream()
                        .map(index -> FieldReference.newInputRelReference(index, rel))
                        .collect(Collectors.toList()))
                .build();
    }

    /**
     * Converts one aggregate call into a measure. If the call has a filter argument
     * ({@code filterArg != -1}), that field is attached as the pre-measure filter.
     *
     * @throws UnsupportedOperationException if no aggregate function binding matches the call
     */
    Aggregate.Measure fromAggCall(RelNode relNode, Type.Struct struct, AggregateCall aggregateCall) {
        AggregateFunctionInvocation invocation = this.aggregateFunctionConverter
                .convert(relNode, struct, aggregateCall, rexNode -> toExpression(rexNode))
                .orElseThrow(() -> new UnsupportedOperationException("Unable to find binding for call " + aggregateCall));
        ImmutableMeasure.Builder measure = Aggregate.Measure.builder().function(invocation);
        if (aggregateCall.filterArg != -1) {
            measure.preMeasureFilter(FieldReference.newRootStructReference(aggregateCall.filterArg, struct));
        }
        return measure.build();
    }

    @Override
    public Rel visit(LogicalMatch logicalMatch) {
        return super.visit(logicalMatch);
    }

    /**
     * Emits a {@link Sort} and, if the node has an OFFSET or FETCH, wraps it in a {@link Fetch}.
     * A missing offset defaults to 0, and a missing fetch leaves the count unset.
     */
    @Override
    public Rel visit(LogicalSort logicalSort) {
        Rel input = apply(logicalSort.getInput());
        List<Expression.SortField> sortFields = logicalSort.getCollation().getFieldCollations().stream()
                .map(collation -> toSortField(collation, input.getRecordType()))
                .collect(Collectors.toList());
        Sort sorted = Sort.builder().addAllSortFields(sortFields).input(input).build();
        if (logicalSort.fetch == null && logicalSort.offset == null) {
            return sorted;
        }
        long offset = logicalSort.offset == null ? 0L : asLong(logicalSort.offset);
        ImmutableFetch.Builder fetch = Fetch.builder().input(sorted).offset(offset);
        if (logicalSort.fetch != null) {
            fetch.count(asLong(logicalSort.fetch));
        }
        return fetch.build();
    }

    /**
     * Evaluates a row expression that must convert to an i64 or i32 literal.
     *
     * @throws UnsupportedOperationException if it converts to anything else
     */
    private long asLong(RexNode rexNode) {
        Expression expression = toExpression(rexNode);
        if (expression instanceof Expression.I64Literal) {
            return ((Expression.I64Literal) expression).value();
        }
        if (expression instanceof Expression.I32Literal) {
            return ((Expression.I32Literal) expression).value();
        }
        throw new UnsupportedOperationException("Unknown type: " + rexNode);
    }

    /**
     * Converts a Calcite field collation into a Substrait sort field over {@code struct}.
     * Strict and non-strict directions map to the same Substrait direction. Nulls sort last only
     * when the collation explicitly says {@code LAST}; any other null direction sorts them first.
     */
    public static Expression.SortField toSortField(RelFieldCollation relFieldCollation, Type.Struct struct) {
        boolean nullsLast = relFieldCollation.nullDirection == RelFieldCollation.NullDirection.LAST;
        Expression.SortDirection sortDirection;
        switch (relFieldCollation.direction) {
            case STRICTLY_ASCENDING:
            case ASCENDING:
                sortDirection = nullsLast ? Expression.SortDirection.ASC_NULLS_LAST : Expression.SortDirection.ASC_NULLS_FIRST;
                break;
            case STRICTLY_DESCENDING:
            case DESCENDING:
                sortDirection = nullsLast ? Expression.SortDirection.DESC_NULLS_LAST : Expression.SortDirection.DESC_NULLS_FIRST;
                break;
            case CLUSTERED:
                sortDirection = Expression.SortDirection.CLUSTERED;
                break;
            default:
                throw new UnsupportedOperationException("Unsupported sort direction: " + relFieldCollation.direction);
        }
        return Expression.SortField.builder()
                .expr(FieldReference.newRootStructReference(relFieldCollation.getFieldIndex(), struct))
                .direction(sortDirection)
                .build();
    }

    @Override
    public Rel visit(LogicalExchange logicalExchange) {
        return super.visit(logicalExchange);
    }

    @Override
    public Rel visit(LogicalTableModify logicalTableModify) {
        return super.visit(logicalTableModify);
    }

    /** Rejects any node without a dedicated conversion. */
    @Override
    public Rel visitOther(RelNode relNode) {
        throw new UnsupportedOperationException("Unable to handle node: " + relNode);
    }

    /** Runs {@link OuterReferenceResolver} over {@code relNode} and stores its depth map. */
    private void popFieldAccessDepthMap(RelNode relNode) {
        OuterReferenceResolver outerReferenceResolver = new OuterReferenceResolver();
        outerReferenceResolver.apply(relNode);
        this.fieldAccessDepthMap = outerReferenceResolver.getFieldAccessDepthMap();
    }

    /**
     * Looks up the recorded depth for a correlated field access. The map is only populated by the
     * static {@code convert} entry points; on a visitor built directly, this throws
     * {@link NullPointerException}.
     */
    public Integer getFieldAccessDepth(RexFieldAccess rexFieldAccess) {
        return this.fieldAccessDepthMap.get(rexFieldAccess);
    }

    /** Converts a single Calcite node, dispatching to the matching {@code visit} overload. */
    public Rel apply(RelNode relNode) {
        return reverseAccept(relNode);
    }

    /** Converts each node in order. */
    public List<Rel> apply(List<RelNode> list) {
        return list.stream()
                .map(relNode -> apply(relNode))
                .collect(Collectors.toList());
    }

    /** Converts a plan root's rel using the default {@link FeatureBoard}. */
    public static Rel convert(RelRoot relRoot, SimpleExtension.ExtensionCollection extensionCollection) {
        return convert(relRoot.rel, extensionCollection, FEATURES_DEFAULT);
    }

    /** Converts a plan root's rel using {@code featureBoard}. */
    public static Rel convert(RelRoot relRoot, SimpleExtension.ExtensionCollection extensionCollection, FeatureBoard featureBoard) {
        return convert(relRoot.rel, extensionCollection, featureBoard);
    }

    /** Builds a visitor, resolves outer-reference depths for {@code relNode}, then converts it. */
    private static Rel convert(RelNode relNode, SimpleExtension.ExtensionCollection extensionCollection, FeatureBoard featureBoard) {
        SubstraitRelVisitor visitor = new SubstraitRelVisitor(relNode.getCluster().getTypeFactory(), extensionCollection, featureBoard);
        visitor.popFieldAccessDepthMap(relNode);
        return visitor.apply(relNode);
    }
}
