package io.virtdata.core;

import io.virtdata.api.DataMapperLibrary;
import io.virtdata.api.ValueType;
import io.virtdata.api.VirtDataFunctionLibrary;
import io.virtdata.api.composers.FunctionAssembly;
import io.virtdata.ast.Expression;
import io.virtdata.ast.FunctionCall;
import io.virtdata.ast.VirtDataFlow;
import io.virtdata.parser.VirtDataDSL;
import java.lang.invoke.MethodHandles;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.commons.lang3.ClassUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:io/virtdata/core/VirtDataComposer.class */
/**
 * Resolves a VirtData DSL flow into a single {@link ResolvedFunction}.
 *
 * <p>For every expression in the flow, candidate functions are looked up in a
 * {@link VirtDataFunctionLibrary}. The candidate lists are then pruned by a set of type-matching
 * rules until each expression has one remaining candidate. The survivors are composed, in flow
 * order, into one function. The first function in the chain must accept a {@code long} input.
 */
public class VirtDataComposer {
    /** Optional prefix that is stripped from a flow specification before parsing. */
    private static final String PREAMBLE = "compose ";
    private static final Logger logger = LoggerFactory.getLogger(VirtDataComposer.class);

    private final VirtDataFunctionLibrary functionLibrary;

    /**
     * Creates a composer that resolves functions from the given library.
     *
     * @param virtDataFunctionLibrary library used to look up candidate functions
     */
    public VirtDataComposer(VirtDataFunctionLibrary virtDataFunctionLibrary) {
        this.functionLibrary = virtDataFunctionLibrary;
    }

    /** Creates a composer backed by the library returned from {@code VirtDataLibraries.get()}. */
    public VirtDataComposer() {
        this.functionLibrary = VirtDataLibraries.get();
    }

    /**
     * Parses and resolves a flow specification.
     *
     * @param str flow specification, optionally prefixed with {@code "compose "}
     * @return the resolved function, or empty if resolution failed
     * @throws RuntimeException wrapping the parser's throwable if the specification fails to parse
     */
    public Optional<ResolvedFunction> resolveFunctionFlow(String str) {
        return resolveFunctionFlow(parseFlow(str));
    }

    /**
     * Parses and resolves a flow specification, returning the full diagnostic trace.
     *
     * @param str flow specification, optionally prefixed with {@code "compose "}
     * @return diagnostics holding either the resolved function or the error
     * @throws RuntimeException wrapping the parser's throwable if the specification fails to parse
     */
    public ResolverDiagnostics resolveDiagnosticFunctionFlow(String str) {
        return resolveDiagnosticFunctionFlow(parseFlow(str));
    }

    /**
     * Resolves an already-parsed flow, recording every step in the returned diagnostics.
     *
     * <p>Expressions are resolved from the last (output) to the first (input). Errors from argument
     * population or library lookup, a missing long-accepting first function, and assembly failures
     * are reported through {@link ResolverDiagnostics#error}. Exceptions raised during path
     * optimization propagate to the caller.
     *
     * @param virtDataFlow the parsed flow
     * @return diagnostics holding either the resolved function or the error
     */
    public ResolverDiagnostics resolveDiagnosticFunctionFlow(VirtDataFlow virtDataFlow) {
        ResolverDiagnostics diagnostics = new ResolverDiagnostics();
        diagnostics.trace("processing flow " + virtDataFlow.toString() + " from output to input");

        // Candidate functions for each expression, in flow order (index 0 = first expression).
        List<List<ResolvedFunction>> stages = new ArrayList<>();

        // Input classes accepted by the earliest stage resolved so far. Before any stage has been
        // resolved it holds the flow's declared output value class, if one is declared.
        Set<Class<?>> firstStageInputs = new HashSet<>();
        Optional.ofNullable(virtDataFlow.getLastExpression().getCall().getOutputType())
                .map(ValueType::valueOfClassName)
                .map(valueType -> valueType.getValueClass())
                .ifPresent(firstStageInputs::add);

        int lastIndex = virtDataFlow.getExpressions().size() - 1;
        diagnostics.trace("working backwards from " + lastIndex);
        for (int index = lastIndex; index >= 0; index--) {
            FunctionCall call = ((Expression) virtDataFlow.getExpressions().get(index)).getCall();
            diagnostics.trace("resolving args for " + call.toString());
            String functionName = call.getFunctionName();
            Class<?> inputType = ValueType.classOfType(call.getInputType());
            Class<?> outputType = ValueType.classOfType(call.getOutputType());
            try {
                Object[] args = populateFunctions(diagnostics, call.getArguments());
                diagnostics.trace("resolved args: ");
                for (Object arg : args) {
                    diagnostics.trace(" " + describeArgument(arg));
                }
                List<ResolvedFunction> candidates =
                        this.functionLibrary.resolveFunctions(outputType, inputType, functionName, args);
                if (candidates.isEmpty()) {
                    return diagnostics.error(
                            new RuntimeException("Unable to find even one function for " + call));
                }
                diagnostics.trace(" resolved functions:");
                diagnostics.trace(summarize(candidates));
                // Copy so that later pruning never mutates the list returned by the library.
                List<ResolvedFunction> stage = new ArrayList<>(candidates);
                stages.add(0, stage);
                firstStageInputs = inputClasses(stage);
            } catch (Exception e) {
                return diagnostics.error(e);
            }
        }

        if (!firstStageInputs.contains(Long.TYPE)) {
            return diagnostics.error(new RuntimeException(
                    "There is no initial function which accepts a long input. Function chain, after type filtering: \n"
                            + summarizeBulk(stages)));
        }
        removeNonLongFunctions(stages.get(0));

        Class<?> requiredResultType =
                ValueType.classOfType(virtDataFlow.getLastExpression().getCall().getOutputType());
        List<ResolvedFunction> path = optimizePath(stages, requiredResultType);
        return assemble(diagnostics, virtDataFlow, path);
    }

    /**
     * Resolves an already-parsed flow.
     *
     * @param virtDataFlow the parsed flow
     * @return the resolved function, or empty if resolution failed
     */
    public Optional<ResolvedFunction> resolveFunctionFlow(VirtDataFlow virtDataFlow) {
        return resolveDiagnosticFunctionFlow(virtDataFlow).getResolvedFunction();
    }

    /** Strips the optional preamble, parses the spec, and rethrows any parse failure. */
    private static VirtDataFlow parseFlow(String spec) {
        String body = spec.startsWith(PREAMBLE) ? spec.substring(PREAMBLE.length()) : spec;
        VirtDataDSL.ParseResult parse = VirtDataDSL.parse(body);
        if (parse.throwable != null) {
            throw new RuntimeException(parse.throwable);
        }
        return parse.flow;
    }

    /** Renders one resolved argument for the trace; null-safe so tracing never fails resolution. */
    private static String describeArgument(Object arg) {
        return arg == null ? "null" : arg.getClass().getSimpleName() + ": " + arg;
    }

    /**
     * Turns the selected path into the final function: returned directly if it has a single
     * element, otherwise chained through a {@link FunctionAssembly}. The composed function is
     * thread-safe only if every element is.
     */
    private ResolverDiagnostics assemble(
            ResolverDiagnostics diagnostics, VirtDataFlow flow, List<ResolvedFunction> path) {
        if (path.size() == 1) {
            diagnostics.trace("FUNCTION resolution succeeded (single): '" + flow.toString() + "'");
            return diagnostics.setResolvedFunction(path.get(0));
        }
        FunctionAssembly assembly = new FunctionAssembly();
        String summary = summarize(path);
        diagnostics.trace("composed summary: " + summary);
        diagnostics.trace("FUNCTION chain selected: (multi) '" + summary + "'");
        boolean threadSafe = true;
        for (ResolvedFunction function : path) {
            try {
                assembly.andThen(function.getFunctionObject());
                if (!function.isThreadSafe()) {
                    threadSafe = false;
                }
            } catch (Exception e) {
                // Keep the original exception as the cause so the root failure is not lost.
                return diagnostics.error(new RuntimeException(
                        "FUNCTION resolution failed: '" + flow.toString() + "': " + e.toString(), e));
            }
        }
        ResolvedFunction composed = assembly.getResolvedFunction(threadSafe);
        diagnostics.trace("FUNCTION resolution succeeded (lambda): '" + flow.toString() + "'");
        return diagnostics.setResolvedFunction(composed);
    }

    /**
     * Replaces every {@link FunctionCall} in {@code args} with the function object of the first
     * candidate the library resolves for it. Nested calls are resolved recursively.
     *
     * <p>NOTE(review): the array is modified in place and returned, so the caller's argument array
     * (from {@code FunctionCall.getArguments()}) ends up holding function objects. Confirm that
     * this mutation of the parsed flow is intended.
     *
     * @throws RuntimeException if no candidate is found for a nested call
     */
    private Object[] populateFunctions(ResolverDiagnostics diagnostics, Object[] args) {
        for (int i = 0; i < args.length; i++) {
            if (args[i] instanceof FunctionCall) {
                FunctionCall nested = (FunctionCall) args[i];
                String functionName = nested.getFunctionName();
                Class<?> inputType = ValueType.classOfType(nested.getInputType());
                Class<?> outputType = ValueType.classOfType(nested.getOutputType());
                Object[] nestedArgs = nested.getArguments();
                diagnostics.trace("resolving argument as function '" + nested.toString() + "'");
                List<ResolvedFunction> candidates = this.functionLibrary.resolveFunctions(
                        outputType, inputType, functionName, populateFunctions(diagnostics, nestedArgs));
                if (candidates.isEmpty()) {
                    throw new RuntimeException("Unable to resolve even one function for argument: " + nested);
                }
                args[i] = candidates.get(0).getFunctionObject();
            }
        }
        return args;
    }

    /**
     * Removes every candidate whose input class cannot accept a primitive {@code long}.
     *
     * @throws RuntimeException if that would remove every candidate
     */
    private void removeNonLongFunctions(List<ResolvedFunction> list) {
        List<ResolvedFunction> nonLong = new ArrayList<>();
        for (ResolvedFunction function : list) {
            if (!function.getInputClass().isAssignableFrom(Long.TYPE)) {
                logger.trace("input type {} is not assignable from long", function.getInputClass().getCanonicalName());
                nonLong.add(function);
            }
        }
        if (!nonLong.isEmpty() && nonLong.size() == list.size()) {
            throw new RuntimeException("removeNonLongFunctions would remove all functions: " + list);
        }
        list.removeAll(nonLong);
    }

    /** Joins the string forms of the functions with {@code "|"}. */
    private String summarize(List<ResolvedFunction> list) {
        return list.stream()
                .map(function -> String.valueOf(function))
                .collect(Collectors.joining("|"));
    }

    /** Renders every stage (functions joined by {@code "|\n"}, stages by a blank line) inside a banner. */
    private String summarizeBulk(List<List<ResolvedFunction>> list) {
        String body = list.stream()
                .map(stage -> stage.stream()
                        .map(function -> String.valueOf(function))
                        .collect(Collectors.joining("|\n")))
                .collect(Collectors.joining("\n\n"));
        return "---\\\\\n" + body + "\n---////\n";
    }

    /**
     * Repeatedly prunes the candidate lists until a full pass removes nothing, then returns the
     * first remaining candidate of each stage, in stage order.
     *
     * <p>Each pass first filters the last stage by the required result type. With several stages,
     * it then walks adjacent (previous, next) pairs and applies the pair rules only until one pair
     * makes progress. With a single stage, it keeps that stage's most preferred candidate.
     *
     * @param list candidate lists per stage; modified in place
     * @param cls required result type of the final stage, or {@code null} for no constraint
     */
    private List<ResolvedFunction> optimizePath(List<List<ResolvedFunction>> list, Class<?> cls) {
        int reductions;
        do {
            reductions = reduceByRequiredResultsType(list.get(list.size() - 1), cls);
            if (list.size() > 1) {
                List<ResolvedFunction> previous = null;
                for (List<ResolvedFunction> next : list) {
                    if (previous != null && next != null && reductions == 0) {
                        reductions = reduceAdjacentPair(previous, next);
                    }
                    previous = next;
                }
            } else {
                reductions += reduceByPreferredResultTypes(list.get(0));
            }
        } while (reductions != 0);
        return list.stream().map(stage -> stage.get(0)).collect(Collectors.toList());
    }

    /**
     * Applies the pair rules in priority order and stops at the first one that removes anything:
     * direct type matches, then assignability without autoboxing, then with autoboxing, then
     * preferred types.
     *
     * @return the number of candidates removed by the rule that made progress, or 0
     */
    private int reduceAdjacentPair(List<ResolvedFunction> previous, List<ResolvedFunction> next) {
        int removed = reduceByDirectTypes(previous, next);
        if (removed == 0) {
            removed = reduceByAssignableTypes(previous, next, false);
        }
        if (removed == 0) {
            removed = reduceByAssignableTypes(previous, next, true);
        }
        if (removed == 0) {
            removed = reduceByPreferredTypes(previous, next);
        }
        return removed;
    }

    /**
     * Removes candidates whose result class is not assignable (with autoboxing) to {@code cls}.
     * Does nothing when {@code cls} is null.
     *
     * @return the number of candidates removed
     * @throws RuntimeException if no candidates remain
     */
    private int reduceByRequiredResultsType(List<ResolvedFunction> list, Class<?> cls) {
        int removed = 0;
        if (cls != null) {
            // Iterate over a snapshot because the list is modified inside the loop.
            for (ResolvedFunction function : new ArrayList<>(list)) {
                if (!ClassUtils.isAssignable(function.getResultClass(), cls, true)) {
                    list.remove(function);
                    logger.trace("BY-REQUIRED-RESULT-TYPE removed function '{}' because is not assignable to {}", function, cls);
                    removed++;
                }
            }
        }
        if (list.isEmpty()) {
            throw new RuntimeException("BY-REQUIRED-RESULT-TYPE No end funcs were found which are assignable to " + cls);
        }
        return removed;
    }

    /** Keeps only the most preferred candidate of a single-stage path; returns the number removed. */
    private int reduceByPreferredResultTypes(List<ResolvedFunction> list) {
        return keepMostPreferred(list, "BY-SINGLE-PREFERRED-TYPE");
    }

    /**
     * Keeps only the most preferred candidate of the previous stage if it has several. Otherwise
     * it does the same for the next stage.
     *
     * @return the number of candidates removed
     */
    private int reduceByPreferredTypes(List<ResolvedFunction> list, List<ResolvedFunction> list2) {
        if (list.size() > 1) {
            return keepMostPreferred(list, "BY-PREV-PREFERRED-TYPE");
        }
        return keepMostPreferred(list2, "BY-NEXT-PREFERRED-TYPE");
    }

    /**
     * Sorts by {@link ResolvedFunction#PREFERRED_TYPE_COMPARATOR} and drops every candidate but
     * the first.
     *
     * @param candidates list to prune in place
     * @param rule label used as the prefix of trace messages
     * @return the number of candidates removed (0 if at most one was present)
     */
    private int keepMostPreferred(List<ResolvedFunction> candidates, String rule) {
        if (candidates.size() <= 1) {
            return 0;
        }
        int removed = candidates.size() - 1;
        candidates.sort(ResolvedFunction.PREFERRED_TYPE_COMPARATOR);
        while (candidates.size() > 1) {
            ResolvedFunction dropped = candidates.remove(candidates.size() - 1);
            logger.trace("{} removing func {} because {} has more preferred types.", rule, dropped, candidates.get(0));
        }
        return removed;
    }

    /**
     * When some next-stage arg types exactly equal a previous-stage result class, removes every
     * next-stage candidate whose arg type is not among those direct matches.
     *
     * @return the number of candidates removed
     */
    private int reduceByDirectTypes(List<ResolvedFunction> list, List<ResolvedFunction> list2) {
        // getInputs returns a fresh set, so narrowing it in place is safe.
        Set<Class<?>> directMatches = getInputs(list2);
        directMatches.retainAll(getOutputs(list));
        if (directMatches.isEmpty()) {
            return 0;
        }
        List<ResolvedFunction> unsatisfied = new ArrayList<>();
        for (ResolvedFunction function : list2) {
            if (!directMatches.contains(function.getArgType())) {
                logger.trace("BY-DIRECT-TYPE removing next func: {} because its input types are not satisfied by any previous func", function);
                unsatisfied.add(function);
            }
        }
        list2.removeAll(unsatisfied);
        return unsatisfied.size();
    }

    /**
     * Removes next-stage candidates whose input class cannot be assigned from any previous-stage
     * result class, unless that would remove all of them.
     *
     * <p>NOTE(review): the assignable set is built from arg types ({@link #getInputs}), but
     * candidates are tested by {@code getInputClass()}. Confirm this asymmetry is intended.
     *
     * @param z whether autoboxing is allowed when checking assignability
     * @return the number of candidates removed
     */
    private int reduceByAssignableTypes(List<ResolvedFunction> list, List<ResolvedFunction> list2, boolean z) {
        Set<Class<?>> outputs = getOutputs(list);
        Set<Class<?>> assignableInputs = new HashSet<>();
        for (Class<?> input : getInputs(list2)) {
            for (Class<?> output : outputs) {
                if (ClassUtils.isAssignable(output, input, z)) {
                    assignableInputs.add(input);
                    break;
                }
            }
        }
        List<ResolvedFunction> unassignable = new ArrayList<>();
        for (ResolvedFunction function : list2) {
            if (!assignableInputs.contains(function.getInputClass())) {
                unassignable.add(function);
            }
        }
        if (unassignable.size() == list2.size()) {
            logger.trace("BY-ASSIGNABLE-TYPE Not removing remaining {} next funcs {}because no functions would be left.",
                    list2.size(), z ? "with autoboxing " : "");
            return 0;
        }
        for (ResolvedFunction function : unassignable) {
            logger.trace("BY-ASSIGNABLE-TYPE removing next func: {} because its input types are not assignable from any of the previous funcs", function);
        }
        list2.removeAll(unassignable);
        return unassignable.size();
    }

    /** Returns a new mutable set of the result classes of the given functions. */
    private Set<Class<?>> getOutputs(List<ResolvedFunction> list) {
        Set<Class<?>> outputs = new HashSet<>();
        for (ResolvedFunction function : list) {
            outputs.add(function.getResultClass());
        }
        return outputs;
    }

    /** Returns a new mutable set of the arg types ({@code getArgType()}) of the given functions. */
    private Set<Class<?>> getInputs(List<ResolvedFunction> list) {
        Set<Class<?>> inputs = new HashSet<>();
        for (ResolvedFunction function : list) {
            inputs.add(function.getArgType());
        }
        return inputs;
    }

    /** Returns a new mutable set of the input classes ({@code getInputClass()}) of the given functions. */
    private static Set<Class<?>> inputClasses(List<ResolvedFunction> list) {
        Set<Class<?>> classes = new HashSet<>();
        for (ResolvedFunction function : list) {
            classes.add(function.getInputClass());
        }
        return classes;
    }
}
