package tools.refinery.logic.dnf;

import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.lang.runtime.ObjectMethods;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.PriorityQueue;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.jetbrains.annotations.NotNull;
import tools.refinery.logic.InvalidQueryException;
import tools.refinery.logic.literal.BooleanLiteral;
import tools.refinery.logic.literal.CallLiteral;
import tools.refinery.logic.literal.CallPolarity;
import tools.refinery.logic.literal.EquivalenceLiteral;
import tools.refinery.logic.literal.Literal;
import tools.refinery.logic.substitution.MapBasedSubstitution;
import tools.refinery.logic.substitution.StatelessSubstitution;
import tools.refinery.logic.term.ParameterDirection;
import tools.refinery.logic.term.Variable;

/**
 * Normalizes a single conjunctive clause of a DNF query.
 *
 * <p>{@link #postProcessClause()} performs the following steps:
 * <ol>
 *     <li>variables connected by positive {@link EquivalenceLiteral}s are merged into a single representative,
 *     preferring the parameter with the smallest index;</li>
 *     <li>equivalences between two parameters are kept as explicit literals, and all other literals are rewritten
 *     to refer to representatives only;</li>
 *     <li>the set of bound ("positive") variables is computed and validated;</li>
 *     <li>literals are topologically sorted so that each literal comes after the literals binding its inputs;</li>
 *     <li>literals are reduced, and clauses that are trivially true or false are detected.</li>
 * </ol>
 *
 * <p>State accumulates in instance fields, so an instance is intended to process its clause once.
 */
public class ClausePostProcessor {
    private final Map<Variable, ParameterInfo> parameters;
    private final List<Literal> literals;
    // Maps every variable that occurs in a positive equivalence to the representative of its class. The map must
    // stay flat (every value is itself a representative) because it is applied as a single-step substitution.
    private final Map<Variable, Variable> representatives = new LinkedHashMap<>();
    // Maps each representative to all variables of its equivalence class, including itself.
    private final Map<Variable, Set<Variable>> equivalencePartition = new HashMap<>();
    // Variables that appear as outputs of some substituted literal.
    private final Set<Variable> existentiallyQuantifiedVariables = new LinkedHashSet<>();
    private List<Literal> substitutedLiterals;
    private Set<Variable> positiveVariables;
    private Map<Variable, Set<SortableLiteral>> variableToLiteralInputMap;
    private PriorityQueue<SortableLiteral> literalsWithAllInputsBound;
    private LinkedHashSet<Literal> topologicallySortedLiterals;

    /** Outcome of post-processing a clause. */
    public interface Result {
    }

    /**
     * A clause that did not simplify to a constant.
     *
     * @param clause the normalized clause
     */
    public record ClauseResult(DnfClause clause) implements Result {
    }

    /** A clause that simplified to a constant truth value. */
    public enum ConstantResult implements Result {
        ALWAYS_TRUE,
        ALWAYS_FALSE
    }

    /**
     * Direction and position of a clause parameter.
     *
     * @param direction whether the parameter is an input ({@link ParameterDirection#IN}) or not
     * @param index     position of the parameter; lower indices are preferred as equivalence representatives
     */
    public record ParameterInfo(ParameterDirection direction, int index) {
    }

    /**
     * A literal together with its position in the clause and the input variables that are not bound yet.
     * Used by {@link #topologicallySortLiterals()}.
     */
    private final class SortableLiteral implements Comparable<SortableLiteral> {
        private final int index;
        private final Literal literal;
        private final Set<Variable> remainingInputs;

        private SortableLiteral(int index, Literal literal) {
            this.index = index;
            this.literal = literal;
            remainingInputs = new HashSet<>(literal.getInputVariables(positiveVariables));
            // Input parameters are supplied by the caller, so they never block a literal.
            for (var entry : parameters.entrySet()) {
                if (entry.getValue().direction() == ParameterDirection.IN) {
                    remainingInputs.remove(entry.getKey());
                }
            }
        }

        /** Registers this literal as ready, or as waiting on each of its still-unbound inputs. */
        void enqueue() {
            if (allInputsBound()) {
                addToAllInputsBoundQueue();
            } else {
                addToVariableToLiteralInputMap();
            }
        }

        /** Marks {@code variable} as bound and moves this literal to the ready queue once nothing is left. */
        private void bindVariable(Variable variable) {
            if (!remainingInputs.remove(variable)) {
                throw new AssertionError("Already processed input %s of literal %s".formatted(variable, literal));
            }
            if (allInputsBound()) {
                addToAllInputsBoundQueue();
            }
        }

        private boolean allInputsBound() {
            return remainingInputs.isEmpty();
        }

        private void addToVariableToLiteralInputMap() {
            for (var input : remainingInputs) {
                variableToLiteralInputMap.computeIfAbsent(input, ignored -> new HashSet<>()).add(this);
            }
        }

        private void addToAllInputsBoundQueue() {
            literalsWithAllInputsBound.add(this);
        }

        /**
         * Emits this literal into the sorted output and binds its output variables in every literal waiting on them.
         */
        void addToSortedLiterals() {
            if (!allInputsBound()) {
                throw new AssertionError("Inputs %s of %s are not yet bound".formatted(remainingInputs, literal));
            }
            topologicallySortedLiterals.add(literal);
            for (var variable : literal.getOutputVariables()) {
                var waitingLiterals = variableToLiteralInputMap.remove(variable);
                if (waitingLiterals != null) {
                    for (var waitingLiteral : waitingLiterals) {
                        waitingLiteral.bindVariable(variable);
                    }
                }
            }
        }

        /** Orders by position in the clause, so the earliest ready literal is emitted first. */
        @Override
        public int compareTo(@NotNull SortableLiteral other) {
            return Integer.compare(index, other.index);
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (obj == null || getClass() != obj.getClass()) {
                return false;
            }
            var other = (SortableLiteral) obj;
            return index == other.index && Objects.equals(literal, other.literal);
        }

        @Override
        public int hashCode() {
            return Objects.hash(index, literal);
        }
    }

    /**
     * Creates a post-processor for one clause.
     *
     * @param map  the clause parameters with their direction and position
     * @param list the literals of the clause
     */
    public ClausePostProcessor(Map<Variable, ParameterInfo> map, List<Literal> list) {
        this.parameters = map;
        this.literals = list;
    }

    /**
     * Runs all post-processing steps on the clause passed to the constructor.
     *
     * @return {@link ConstantResult#ALWAYS_FALSE} if some literal reduces to {@link BooleanLiteral#FALSE} or a
     *     negative call contradicts a positive call; {@link ConstantResult#ALWAYS_TRUE} if no literal remains after
     *     dropping those that reduce to {@link BooleanLiteral#TRUE}; otherwise a {@link ClauseResult} holding the
     *     bound variables and the reduced, topologically sorted literals
     * @throws InvalidQueryException if a non-input parameter or a merged variable is unbound, a private variable
     *     is shared by several literals, or some literal input can never be bound
     */
    public Result postProcessClause() {
        mergeEquivalentNodeVariables();
        substitutedLiterals = new ArrayList<>(literals.size());
        keepParameterEquivalences();
        substituteLiterals();
        computeExistentiallyQuantifiedVariables();
        computePositiveVariables();
        validatePositiveRepresentatives();
        validatePrivateVariables();
        topologicallySortLiterals();
        var filteredLiterals = new ArrayList<Literal>(topologicallySortedLiterals.size());
        for (var literal : topologicallySortedLiterals) {
            var reducedLiteral = literal.reduce();
            if (BooleanLiteral.FALSE.equals(reducedLiteral)) {
                return ConstantResult.ALWAYS_FALSE;
            }
            if (!BooleanLiteral.TRUE.equals(reducedLiteral)) {
                filteredLiterals.add(reducedLiteral);
            }
        }
        if (filteredLiterals.isEmpty()) {
            return ConstantResult.ALWAYS_TRUE;
        }
        if (hasContradictoryCall(filteredLiterals)) {
            return ConstantResult.ALWAYS_FALSE;
        }
        var clause = new DnfClause(Collections.unmodifiableSet(positiveVariables),
                Collections.unmodifiableList(filteredLiterals));
        return new ClauseResult(clause);
    }

    /** Merges the two sides of every positive equivalence literal into one equivalence class. */
    private void mergeEquivalentNodeVariables() {
        for (var literal : literals) {
            if (literal instanceof EquivalenceLiteral equivalenceLiteral && equivalenceLiteral.isPositive()) {
                mergeVariables(equivalenceLiteral.getLeft(), equivalenceLiteral.getRight());
            }
        }
    }

    private static boolean isPositiveEquivalence(Literal literal) {
        return literal instanceof EquivalenceLiteral equivalenceLiteral && equivalenceLiteral.isPositive();
    }

    /**
     * Unions the equivalence classes of {@code left} and {@code right}. A parameter always wins over a
     * non-parameter as representative; between two parameters, the one with the smaller index wins.
     */
    private void mergeVariables(Variable left, Variable right) {
        var leftRepresentative = getRepresentative(left);
        var rightRepresentative = getRepresentative(right);
        if (leftRepresentative.equals(rightRepresentative)) {
            // Already in the same class (e.g. a repeated or reversed equivalence). Merging a class into itself
            // would drop it from equivalencePartition, so a later merge would recreate it as a singleton and
            // leave stale, non-flat entries in representatives.
            return;
        }
        var leftInfo = parameters.get(leftRepresentative);
        var rightInfo = parameters.get(rightRepresentative);
        if (leftInfo != null && (rightInfo == null || leftInfo.index() <= rightInfo.index())) {
            doMergeVariables(leftRepresentative, rightRepresentative);
        } else {
            doMergeVariables(rightRepresentative, leftRepresentative);
        }
    }

    /** Moves every member of {@code newChildRepresentative}'s class under {@code parentRepresentative}. */
    private void doMergeVariables(Variable parentRepresentative, Variable newChildRepresentative) {
        var parentSet = getEquivalentVariables(parentRepresentative);
        var childSet = getEquivalentVariables(newChildRepresentative);
        parentSet.addAll(childSet);
        equivalencePartition.remove(newChildRepresentative);
        for (var child : childSet) {
            representatives.put(child, parentRepresentative);
        }
    }

    /** Returns the representative of {@code variable}, registering it as its own representative if unseen. */
    private Variable getRepresentative(Variable variable) {
        return representatives.computeIfAbsent(variable, Function.identity());
    }

    /** Returns the (mutable) class of a representative, creating a singleton class if necessary. */
    private Set<Variable> getEquivalentVariables(Variable variable) {
        var representative = getRepresentative(variable);
        if (!representative.equals(variable)) {
            throw new AssertionError("NodeVariable %s already has a representative %s".formatted(variable,
                    representative));
        }
        return equivalencePartition.computeIfAbsent(variable, key -> {
            Set<Variable> singleton = HashSet.newHashSet(1);
            singleton.add(key);
            return singleton;
        });
    }

    /**
     * Re-adds an explicit equivalence for every parameter that was merged into another parameter, because
     * substitution alone would make that parameter disappear from the clause body.
     */
    private void keepParameterEquivalences() {
        for (var entry : representatives.entrySet()) {
            var variable = entry.getKey();
            var representative = entry.getValue();
            if (!variable.equals(representative) && parameters.containsKey(variable)
                    && parameters.containsKey(representative)) {
                substitutedLiterals.add(new EquivalenceLiteral(true, variable, representative));
            }
        }
    }

    /** Replaces variables by their representatives in all literals except positive equivalences. */
    private void substituteLiterals() {
        var substitution = representatives.isEmpty() ? null :
                new MapBasedSubstitution(Collections.unmodifiableMap(representatives), StatelessSubstitution.IDENTITY);
        for (var literal : literals) {
            if (!isPositiveEquivalence(literal)) {
                substitutedLiterals.add(substitution == null ? literal : literal.substitute(substitution));
            }
        }
    }

    private void computeExistentiallyQuantifiedVariables() {
        for (var literal : substitutedLiterals) {
            existentiallyQuantifiedVariables.addAll(literal.getOutputVariables());
        }
    }

    /**
     * Collects input parameters plus all literal outputs as bound variables.
     *
     * @throws InvalidQueryException if a non-input parameter is not an output of any literal
     */
    private void computePositiveVariables() {
        positiveVariables = new LinkedHashSet<>();
        for (var entry : parameters.entrySet()) {
            var parameter = entry.getKey();
            if (entry.getValue().direction() == ParameterDirection.IN) {
                positiveVariables.add(parameter);
            } else if (!existentiallyQuantifiedVariables.contains(parameter)) {
                throw new InvalidQueryException("Unbound %s parameter %s".formatted(ParameterDirection.OUT,
                        parameter));
            }
        }
        positiveVariables.addAll(existentiallyQuantifiedVariables);
    }

    /** Rejects equivalence classes whose representative is never bound. */
    private void validatePositiveRepresentatives() {
        for (var entry : equivalencePartition.entrySet()) {
            if (!positiveVariables.contains(entry.getKey())) {
                throw new InvalidQueryException("Variables %s were merged by equivalence but are not bound"
                        .formatted(entry.getValue()));
            }
        }
    }

    /** Rejects any private (unbound) variable that appears in more than one literal. */
    private void validatePrivateVariables() {
        Map<Variable, Literal> literalByPrivateVariable = new HashMap<>();
        for (var literal : substitutedLiterals) {
            for (var variable : literal.getPrivateVariables(positiveVariables)) {
                var previousLiteral = literalByPrivateVariable.put(variable, literal);
                if (previousLiteral != null) {
                    throw new InvalidQueryException("Unbound variable %s appears in multiple literals %s and %s"
                            .formatted(variable, previousLiteral, literal));
                }
            }
        }
    }

    /**
     * Orders literals so that each one comes after the literals binding its inputs. Among ready literals, the one
     * appearing earliest in the clause is emitted first.
     *
     * @throws InvalidQueryException if some literal input is never bound
     */
    private void topologicallySortLiterals() {
        topologicallySortedLiterals = LinkedHashSet.newLinkedHashSet(substitutedLiterals.size());
        variableToLiteralInputMap = new HashMap<>();
        literalsWithAllInputsBound = new PriorityQueue<>();
        for (int i = 0; i < substitutedLiterals.size(); i++) {
            new SortableLiteral(i, substitutedLiterals.get(i)).enqueue();
        }
        while (!literalsWithAllInputsBound.isEmpty()) {
            literalsWithAllInputsBound.remove().addToSortedLiterals();
        }
        if (!variableToLiteralInputMap.isEmpty()) {
            throw new InvalidQueryException("Unbound input variables %s".formatted(variableToLiteralInputMap.keySet()));
        }
    }

    /** Returns whether some negative call is contradicted by a positive call to the same target. */
    private boolean hasContradictoryCall(Collection<Literal> collection) {
        var positiveCallsByTarget = collection.stream()
                .filter(CallLiteral.class::isInstance)
                .map(CallLiteral.class::cast)
                .filter(call -> call.getPolarity() == CallPolarity.POSITIVE)
                .collect(Collectors.groupingBy(CallLiteral::getTarget, Collectors.toSet()));
        for (var literal : collection) {
            if (literal instanceof CallLiteral call && call.getPolarity() == CallPolarity.NEGATIVE
                    && contradicts(call, positiveCallsByTarget.get(call.getTarget()))) {
                return true;
            }
        }
        return false;
    }

    /** Returns whether any of {@code positiveCalls} (may be {@code null}) contradicts {@code negativeCall}. */
    private boolean contradicts(CallLiteral negativeCall, Collection<CallLiteral> positiveCalls) {
        if (positiveCalls == null) {
            return false;
        }
        for (var positiveCall : positiveCalls) {
            if (contradicts(negativeCall, positiveCall)) {
                return true;
            }
        }
        return false;
    }

    /**
     * A negative call contradicts a positive call when every argument of the negative call is either private to it
     * (per {@code getPrivateVariables}) or equal to the positive call's argument at the same position. Both calls
     * are assumed to share a target and hence an argument count.
     */
    private boolean contradicts(CallLiteral negativeCall, CallLiteral positiveCall) {
        var privateVariables = negativeCall.getPrivateVariables(positiveVariables);
        var negativeArguments = negativeCall.getArguments();
        var positiveArguments = positiveCall.getArguments();
        for (int i = 0; i < negativeArguments.size(); i++) {
            var argument = negativeArguments.get(i);
            if (!privateVariables.contains(argument) && !argument.equals(positiveArguments.get(i))) {
                return false;
            }
        }
        return true;
    }
}
