package tools.refinery.store.reasoning.internal;

import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.lang.runtime.ObjectMethods;
import java.lang.runtime.SwitchBootstraps;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Deque;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.function.BinaryOperator;
import tools.refinery.logic.Constraint;
import tools.refinery.logic.dnf.Dnf;
import tools.refinery.logic.dnf.DnfBuilder;
import tools.refinery.logic.dnf.DnfClause;
import tools.refinery.logic.literal.AbstractCallLiteral;
import tools.refinery.logic.literal.AbstractCountLiteral;
import tools.refinery.logic.literal.CallPolarity;
import tools.refinery.logic.literal.Literal;
import tools.refinery.logic.term.Aggregator;
import tools.refinery.logic.term.ConstantTerm;
import tools.refinery.logic.term.DataVariable;
import tools.refinery.logic.term.Parameter;
import tools.refinery.logic.term.Term;
import tools.refinery.logic.term.Variable;
import tools.refinery.logic.term.int_.IntTerms;
import tools.refinery.logic.term.uppercardinality.UpperCardinalities;
import tools.refinery.logic.term.uppercardinality.UpperCardinalityTerms;
import tools.refinery.store.reasoning.ReasoningAdapter;
import tools.refinery.store.reasoning.literal.Concreteness;
import tools.refinery.store.reasoning.literal.CountCandidateLowerBoundLiteral;
import tools.refinery.store.reasoning.literal.CountCandidateUpperBoundLiteral;
import tools.refinery.store.reasoning.literal.CountLowerBoundLiteral;
import tools.refinery.store.reasoning.literal.CountUpperBoundLiteral;
import tools.refinery.store.reasoning.literal.ModalConstraint;
import tools.refinery.store.reasoning.literal.Modality;
import tools.refinery.store.reasoning.representation.PartialRelation;
import tools.refinery.store.reasoning.translator.multiobject.MultiObjectTranslator;

/* loaded from: input_file:tools/refinery/store/reasoning/internal/PartialClauseRewriter.class */
/**
 * Rewrites the literals of a single {@link DnfClause} with the help of a {@link PartialQueryRewriter}.
 *
 * <p>The rewriter is driven by a work list. Literals that need no further rewriting are appended to the list of
 * completed literals, and their output variables are recorded as positive variables. Literals produced by a
 * rewriting step are pushed back onto the <em>front</em> of the work list so that they are processed next.
 *
 * <p>Instances are stateful: the completed literals and the positive variables are never cleared by this class.
 * NOTE(review): an instance is presumably meant to rewrite exactly one clause; confirm against the caller.
 */
class PartialClauseRewriter {
    private final PartialQueryRewriter rewriter;
    private final List<Literal> completedLiterals = new ArrayList<>();
    private final Deque<Literal> workList = new ArrayDeque<>();
    private final Set<Variable> positiveVariables = new LinkedHashSet<>();
    // Read-only view handed to relation rewriters so that they cannot modify the bookkeeping of this class.
    private final Set<Variable> unmodifiablePositiveVariables = Collections.unmodifiableSet(positiveVariables);

    /**
     * Creates a clause rewriter.
     *
     * @param partialQueryRewriter the query rewriter used to rewrite called queries, lift modal DNF queries and
     *                             look up relation rewriters
     */
    public PartialClauseRewriter(PartialQueryRewriter partialQueryRewriter) {
        this.rewriter = partialQueryRewriter;
    }

    /**
     * Rewrites every literal of the given clause.
     *
     * @param dnfClause the clause whose literals are queued for rewriting
     * @return the internal list of completed literals (not a copy); it also contains the literals completed by
     * earlier calls on the same instance
     */
    public List<Literal> rewriteClause(DnfClause dnfClause) {
        workList.addAll(dnfClause.literals());
        while (!workList.isEmpty()) {
            rewrite(workList.removeFirst());
        }
        return completedLiterals;
    }

    /**
     * Dispatches a single literal taken from the work list.
     *
     * <p>Literals that are not call literals are completed unchanged. Count bound literals are handed to the
     * matching count rewriting; every other call literal is dispatched on its target.
     *
     * @param literal the literal to process
     */
    private void rewrite(Literal literal) {
        if (!(literal instanceof AbstractCallLiteral callLiteral)) {
            markAsDone(literal);
            return;
        }
        switch (callLiteral) {
            case CountLowerBoundLiteral lowerBound -> rewriteCountLowerBound(lowerBound);
            case CountUpperBoundLiteral upperBound -> rewriteCountUpperBound(upperBound);
            case CountCandidateLowerBoundLiteral candidateLowerBound ->
                    rewriteCountCandidateLowerBound(candidateLowerBound);
            case CountCandidateUpperBoundLiteral candidateUpperBound ->
                    rewriteCountCandidateUpperBound(candidateUpperBound);
            default -> rewriteCallTarget(callLiteral);
        }
    }

    /**
     * Handles a call literal that is not one of the count bound literals by inspecting its target.
     *
     * @param callLiteral the call literal to process
     * @throws IllegalArgumentException if the target is a modal constraint that wraps neither a {@link Dnf} nor a
     *                                  {@link PartialRelation}
     */
    private void rewriteCallTarget(AbstractCallLiteral callLiteral) {
        switch (callLiteral.getTarget()) {
            case Dnf dnf -> rewriteRecursively(callLiteral, dnf);
            case ModalConstraint modalConstraint -> {
                var modality = modalConstraint.modality().toModality();
                var concreteness = modalConstraint.concreteness().toConcreteness();
                switch (modalConstraint.constraint()) {
                    case Dnf dnf -> rewriteRecursively(callLiteral, modality, concreteness, dnf);
                    case PartialRelation partialRelation ->
                            rewrite(callLiteral, modality, concreteness, partialRelation);
                    default -> throw new IllegalArgumentException(
                            "Cannot interpret modal constraint: " + modalConstraint);
                }
            }
            // Any other kind of target is left as it is.
            default -> markAsDone(callLiteral);
        }
    }

    /** Rewrites a partial lower bound count using the lower cardinality view, integer product and integer sum. */
    private void rewriteCountLowerBound(CountLowerBoundLiteral countLowerBoundLiteral) {
        rewritePartialCount(countLowerBoundLiteral, "lower", Modality.MUST,
                MultiObjectTranslator.LOWER_CARDINALITY_VIEW, 1, IntTerms::mul, IntTerms.INT_SUM);
    }

    /** Rewrites a partial upper bound count using the upper cardinality view, product and sum. */
    private void rewriteCountUpperBound(CountUpperBoundLiteral countUpperBoundLiteral) {
        rewritePartialCount(countUpperBoundLiteral, "upper", Modality.MAY,
                MultiObjectTranslator.UPPER_CARDINALITY_VIEW, UpperCardinalities.ONE, UpperCardinalityTerms::mul,
                UpperCardinalityTerms.UPPER_CARDINALITY_SUM);
    }

    /**
     * Replaces a partial count literal by an aggregation over a freshly built helper query.
     *
     * <p>The helper query has a single clause. It calls the count target under the given modality with
     * {@link Concreteness#PARTIAL} and binds an extra output parameter to
     * <ul>
     *     <li>the constant {@code one} if there is no variable to count,</li>
     *     <li>the value that {@code view} associates with the single variable to count, or</li>
     *     <li>the {@code mul}-product of the {@code view} values of all variables to count.</li>
     * </ul>
     * The result variable of the literal is then assigned the {@code sum} aggregation of that output over the
     * helper query, and the assignment is pushed to the front of the work list.
     *
     * @param literal  the count literal to replace
     * @param name     suffix used in the name of the helper query
     * @param modality modality of the call to the count target inside the helper query
     * @param view     binary constraint relating a counted variable to its value of type {@code T}
     * @param one      value assigned to the output when there is nothing to count
     * @param mul      combines the values of several counted variables
     * @param sum      aggregator applied to the output of the helper query
     * @param <T>      result type of the count literal
     */
    private <T> void rewritePartialCount(AbstractCountLiteral<T> literal, String name, Modality modality,
                                         Constraint view, T one, BinaryOperator<Term<T>> mul,
                                         Aggregator<T, T> sum) {
        var type = literal.getResultType();
        var countResult = computeCountVariables(literal, Concreteness.PARTIAL, name);
        var builder = countResult.builder();
        // Declared after the argument parameters that computeCountVariables registered on the builder.
        var outputVariable = builder.parameter(type);
        var variablesToCount = countResult.variablesToCount();

        List<Literal> literals = new ArrayList<>();
        literals.add(ModalConstraint.of(modality, Concreteness.PARTIAL, literal.getTarget())
                .call(CallPolarity.POSITIVE, countResult.rewrittenArguments()));
        switch (variablesToCount.size()) {
            case 0 -> literals.add(outputVariable.assign(new ConstantTerm<>(type, one)));
            case 1 -> literals.add(view.call(new Variable[]{variablesToCount.getFirst(), outputVariable}));
            default -> {
                var firstCount = Variable.of(type);
                literals.add(view.call(new Variable[]{variablesToCount.getFirst(), firstCount}));
                Term<T> product = firstCount;
                int length = variablesToCount.size();
                for (int i = 1; i < length; i++) {
                    var nextCount = Variable.of(type);
                    literals.add(view.call(new Variable[]{variablesToCount.get(i), nextCount}));
                    product = mul.apply(product, nextCount);
                }
                literals.add(outputVariable.assign(product));
            }
        }
        builder.clause(literals);
        var helperQuery = builder.build();

        var aggregatedVariable = Variable.of(type);
        // The helper argument list is extended in place with the variable standing for the output parameter.
        var helperArguments = countResult.helperArguments();
        helperArguments.add(aggregatedVariable);
        workList.addFirst(literal.getResultVariable()
                .assign(helperQuery.aggregateBy(aggregatedVariable, sum, helperArguments)));
    }

    /** Rewrites a candidate lower bound count with {@link Modality#MUST}. */
    private void rewriteCountCandidateLowerBound(CountCandidateLowerBoundLiteral countCandidateLowerBoundLiteral) {
        rewriteCandidateCount(countCandidateLowerBoundLiteral, "lower", Modality.MUST);
    }

    /** Rewrites a candidate upper bound count with {@link Modality#MAY}. */
    private void rewriteCountCandidateUpperBound(CountCandidateUpperBoundLiteral countCandidateUpperBoundLiteral) {
        rewriteCandidateCount(countCandidateUpperBoundLiteral, "upper", Modality.MAY);
    }

    /**
     * Replaces a candidate count literal by a plain count over a freshly built helper query.
     *
     * <p>The single clause of the helper query calls the count target and, for every variable to count, the
     * {@link ReasoningAdapter#EXISTS_SYMBOL}, all under the given modality with {@link Concreteness#CANDIDATE}.
     * The assignment of the count to the result variable is pushed to the front of the work list.
     *
     * @param literal  the count literal to replace
     * @param name     suffix used in the name of the helper query
     * @param modality modality of the calls inside the helper query
     */
    private void rewriteCandidateCount(AbstractCountLiteral<Integer> literal, String name, Modality modality) {
        var countResult = computeCountVariables(literal, Concreteness.CANDIDATE, name);
        var builder = countResult.builder();

        List<Literal> literals = new ArrayList<>();
        literals.add(ModalConstraint.of(modality, Concreteness.CANDIDATE, literal.getTarget())
                .call(CallPolarity.POSITIVE, countResult.rewrittenArguments()));
        for (Variable variable : countResult.variablesToCount()) {
            literals.add(ModalConstraint.of(modality, Concreteness.CANDIDATE, ReasoningAdapter.EXISTS_SYMBOL)
                    .call(new Variable[]{variable}));
        }
        builder.clause(literals);

        var helperQuery = builder.build();
        workList.addFirst(literal.getResultVariable().assign(helperQuery.count(countResult.helperArguments())));
    }

    /**
     * Starts a helper query for a count literal and sorts out its variables.
     *
     * <p>The helper query is named {@code <target name>#<concreteness>#<name>}. Every distinct argument of the
     * literal gets exactly one parameter on the helper builder, created from the target parameter at the position
     * of the argument's first occurrence; repeated arguments reuse that parameter.
     *
     * @param literal      the count literal being rewritten
     * @param concreteness concreteness used in the name of the helper query
     * @param name         suffix used in the name of the helper query
     * @return the helper builder together with the rewritten argument list, the distinct original arguments and
     * the helper parameters that stand for private variables of the literal
     */
    private CountResult computeCountVariables(AbstractCountLiteral<?> literal, Concreteness concreteness,
                                              String name) {
        var target = literal.getTarget();
        int arity = target.arity();
        var parameters = target.getParameters();
        var literalArguments = literal.getArguments();
        var privateVariables = literal.getPrivateVariables(positiveVariables);
        var builder = Dnf.builder("%s#%s#%s".formatted(target.name(), concreteness, name));

        var rewrittenArguments = new ArrayList<Variable>(parameters.size());
        var variablesToCount = new ArrayList<Variable>();
        var helperArguments = new ArrayList<Variable>();
        var literalToHelperParameter = new HashMap<Variable, Variable>();
        for (int i = 0; i < arity; i++) {
            var literalArgument = literalArguments.get(i);
            var parameter = parameters.get(i);
            // The mapping function only runs for the first occurrence of an argument.
            var rewrittenArgument = literalToHelperParameter.computeIfAbsent(literalArgument, argument -> {
                helperArguments.add(argument);
                var helperParameter = builder.parameter(parameter);
                if (privateVariables.contains(argument)) {
                    variablesToCount.add(helperParameter);
                }
                return helperParameter;
            });
            rewrittenArguments.add(rewrittenArgument);
        }
        return new CountResult(builder, rewrittenArguments, helperArguments, variablesToCount);
    }

    /**
     * Appends a literal to the completed literals and records its output variables as positive variables.
     *
     * @param literal the literal that needs no further rewriting
     */
    private void markAsDone(Literal literal) {
        completedLiterals.add(literal);
        positiveVariables.addAll(literal.getOutputVariables());
    }

    /**
     * Lifts a DNF query to the given modality and concreteness, then rewrites the call to target the result.
     */
    private void rewriteRecursively(AbstractCallLiteral callLiteral, Modality modality, Concreteness concreteness,
                                    Dnf dnf) {
        var liftedDnf = rewriter.getLifter().lift(modality, concreteness, dnf);
        rewriteRecursively(callLiteral, liftedDnf);
    }

    /** Completes the call literal with its target replaced by the rewritten form of the given DNF query. */
    private void rewriteRecursively(AbstractCallLiteral callLiteral, Dnf dnf) {
        var rewrittenDnf = rewriter.rewrite(dnf);
        markAsDone(callLiteral.withTarget(rewrittenDnf));
    }

    /**
     * Replaces a modal call to a partial relation by the literals supplied by the rewriter of that relation and
     * queues them for further rewriting.
     */
    private void rewrite(AbstractCallLiteral callLiteral, Modality modality, Concreteness concreteness,
                         PartialRelation partialRelation) {
        var relationRewriter = rewriter.getRelationRewriter(partialRelation);
        List<Literal> rewrittenLiterals = relationRewriter.rewriteLiteral(unmodifiablePositiveVariables, callLiteral,
                modality, concreteness);
        // Push back to front so that the literals leave the work list in their original order.
        for (int i = rewrittenLiterals.size() - 1; i >= 0; i--) {
            workList.addFirst(rewrittenLiterals.get(i));
        }
    }

    /**
     * Intermediate result of {@link #computeCountVariables}.
     *
     * @param builder            builder of the helper query, with one parameter per distinct literal argument
     * @param rewrittenArguments helper parameters in the order of the literal's arguments
     * @param helperArguments    distinct arguments of the literal in order of first occurrence; mutable, and
     *                           extended in place by the partial count rewriting
     * @param variablesToCount   helper parameters that stand for private variables of the literal
     */
    private record CountResult(DnfBuilder builder, List<Variable> rewrittenArguments,
                               List<Variable> helperArguments, List<Variable> variablesToCount) {
    }
}
