package org.vertexium.accumulo;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.accumulo.core.client.IteratorSetting;
import org.apache.accumulo.core.client.ScannerBase;
import org.apache.accumulo.core.data.Key;
import org.apache.accumulo.core.data.Value;
import org.apache.accumulo.core.trace.Span;
import org.apache.accumulo.core.trace.Trace;
import org.vertexium.Authorizations;
import org.vertexium.ElementType;
import org.vertexium.FetchHint;
import org.vertexium.FindPathOptions;
import org.vertexium.Path;
import org.vertexium.ProgressCallback;
import org.vertexium.VertexiumException;
import org.vertexium.accumulo.iterator.ConnectedVertexIdsIterator;
import org.vertexium.accumulo.util.RangeUtils;
import org.vertexium.util.IterableUtils;
import org.vertexium.util.StreamUtils;
import org.vertexium.util.VertexiumLogger;
import org.vertexium.util.VertexiumLoggerFactory;

/* loaded from: input_file:org/vertexium/accumulo/AccumuloFindPathStrategy.class */
/**
 * Finds paths between the source and destination vertices named by a
 * {@link FindPathOptions}, using vertex adjacency read from an {@link AccumuloGraph}.
 *
 * <p>The search algorithm is selected from the configured maximum hop count:
 * <ul>
 *   <li>1 hop: a direct adjacency check,</li>
 *   <li>2 hops: intersection of the source's and destination's neighbor sets,</li>
 *   <li>3 or more hops: neighbor sets are loaded level by level, then paths are
 *       enumerated with a depth-limited recursive walk.</li>
 * </ul>
 *
 * <p>Progress is reported through the supplied {@link ProgressCallback}.
 */
public class AccumuloFindPathStrategy {
    private static final VertexiumLogger LOGGER = VertexiumLoggerFactory.getLogger(AccumuloFindPathStrategy.class);
    private final AccumuloGraph graph;
    private final FindPathOptions options;
    private final ProgressCallback progressCallback;
    private final Authorizations authorizations;
    private final String[] deflatedLabels;
    private final String[] deflatedExcludedLabels;

    /**
     * Creates a strategy for a single path search.
     *
     * <p>The included and excluded labels from {@code findPathOptions} are deflated once here,
     * through the graph's name substitution strategy, so every scan reuses the same arrays.
     *
     * @param accumuloGraph    graph whose vertices are scanned
     * @param findPathOptions  source/destination ids, hop limit and label filters
     * @param progressCallback receiver of progress notifications
     * @param authorizations   authorizations passed to every graph call made by the search
     */
    public AccumuloFindPathStrategy(AccumuloGraph accumuloGraph, FindPathOptions findPathOptions, ProgressCallback progressCallback, Authorizations authorizations) {
        this.graph = accumuloGraph;
        this.options = findPathOptions;
        this.progressCallback = progressCallback;
        this.authorizations = authorizations;
        this.deflatedLabels = deflateLabels(accumuloGraph.getNameSubstitutionStrategy(), findPathOptions.getLabels());
        this.deflatedExcludedLabels = deflateLabels(accumuloGraph.getNameSubstitutionStrategy(), findPathOptions.getExcludedLabels());
    }

    /**
     * Applies {@code deflate} to every label.
     *
     * @return a new array of deflated labels, or {@code null} when {@code labels} is {@code null}
     */
    private static String[] deflateLabels(AccumuloNameSubstitutionStrategy nameSubstitutionStrategy, String[] labels) {
        if (labels == null) {
            return null;
        }
        String[] deflated = new String[labels.length];
        for (int i = 0; i < labels.length; i++) {
            deflated[i] = nameSubstitutionStrategy.deflate(labels[i]);
        }
        return deflated;
    }

    /**
     * Runs the search described by the options given at construction time.
     *
     * @return the paths found; empty when there are none
     * @throws IllegalArgumentException if the configured maximum hop count is less than 1
     */
    public Iterable<Path> findPaths() {
        this.progressCallback.progress(0.0d, ProgressCallback.Step.FINDING_PATH);
        List<Path> foundPaths = new ArrayList<>();
        int maxHops = this.options.getMaxHops();
        if (maxHops < 1) {
            throw new IllegalArgumentException("maxHops cannot be less than 1");
        }
        String sourceVertexId = this.options.getSourceVertexId();
        String destVertexId = this.options.getDestVertexId();
        if (maxHops == 1) {
            if (getConnectedVertexIds(sourceVertexId).contains(destVertexId)) {
                foundPaths.add(new Path(new String[]{sourceVertexId, destVertexId}));
            }
        } else if (maxHops == 2) {
            findPathsSetIntersection(foundPaths);
        } else {
            findPathsBreadthFirst(foundPaths, sourceVertexId, destVertexId, maxHops);
        }
        this.progressCallback.progress(1.0d, ProgressCallback.Step.COMPLETE);
        return foundPaths;
    }

    /**
     * Two-hop search: adds the direct path (when the destination is a neighbor of the source)
     * and one three-vertex path for every vertex adjacent to both endpoints.
     *
     * <p>When {@code isGetAnyPath()} is set and the direct path exists, only that path is added.
     *
     * @param foundPaths list that receives the discovered paths
     */
    private void findPathsSetIntersection(List<Path> foundPaths) {
        String sourceVertexId = this.options.getSourceVertexId();
        String destVertexId = this.options.getDestVertexId();

        // Fetch both endpoints' neighbor sets in a single scan.
        Set<String> endpointIds = new HashSet<>();
        endpointIds.add(sourceVertexId);
        endpointIds.add(destVertexId);
        Map<String, Set<String>> connectedVertexIds = getConnectedVertexIds(endpointIds);

        this.progressCallback.progress(0.1d, ProgressCallback.Step.SEARCHING_SOURCE_VERTEX_EDGES);
        Set<String> sourceNeighbors = connectedVertexIds.get(sourceVertexId);
        if (sourceNeighbors == null) {
            return;
        }

        this.progressCallback.progress(0.3d, ProgressCallback.Step.SEARCHING_DESTINATION_VERTEX_EDGES);
        Set<String> destNeighbors = connectedVertexIds.get(destVertexId);
        if (destNeighbors == null) {
            return;
        }

        if (sourceNeighbors.contains(destVertexId)) {
            foundPaths.add(new Path(new String[]{sourceVertexId, destVertexId}));
            if (this.options.isGetAnyPath()) {
                return;
            }
        }

        this.progressCallback.progress(0.6d, ProgressCallback.Step.MERGING_EDGES);
        // After this call sourceNeighbors holds only vertices adjacent to both endpoints.
        sourceNeighbors.retainAll(destNeighbors);

        this.progressCallback.progress(0.9d, ProgressCallback.Step.ADDING_PATHS);
        for (String middleVertexId : sourceNeighbors) {
            foundPaths.add(new Path(new String[]{sourceVertexId, middleVertexId, destVertexId}));
        }
    }

    /**
     * Search for three or more hops: loads neighbor sets outward from both endpoints, one level
     * per iteration, then enumerates paths over the collected adjacency map.
     *
     * @param foundPaths     list that receives the discovered paths
     * @param sourceVertexId id of the vertex paths start at
     * @param destVertexId   id of the vertex paths end at
     * @param hops           maximum number of hops allowed in a path
     */
    private void findPathsBreadthFirst(List<Path> foundPaths, String sourceVertexId, String destVertexId, int hops) {
        Map<String, Set<String>> connectedVertexIds = getConnectedVertexIds(sourceVertexId, destVertexId);
        // Starts at 2: the initial call above already loaded the source's and destination's neighbors.
        for (int hop = 2; hop < hops; hop++) {
            // Divide as doubles; int / int here is always 0 because hop < hops.
            this.progressCallback.progress((double) hop / (double) hops, ProgressCallback.Step.FINDING_PATH);
            Set<String> vertexIdsToSearch = new HashSet<>();
            for (Set<String> neighbors : connectedVertexIds.values()) {
                vertexIdsToSearch.addAll(neighbors);
            }
            // Skip vertices whose neighbor sets are already in the map.
            vertexIdsToSearch.removeAll(connectedVertexIds.keySet());
            connectedVertexIds.putAll(getConnectedVertexIds(vertexIdsToSearch));
        }
        this.progressCallback.progress(0.9d, ProgressCallback.Step.ADDING_PATHS);
        findPathsRecursive(connectedVertexIds, foundPaths, sourceVertexId, destVertexId, hops, new HashSet<>(), new Path(new String[]{sourceVertexId}), this.progressCallback);
    }

    /**
     * Depth-limited walk over {@code connectedVertexIds} that records every path reaching
     * {@code destVertexId} without revisiting a vertex already on the current path.
     *
     * @param connectedVertexIds adjacency map: vertex id to the ids of its neighbors
     * @param foundPaths         list that receives the discovered paths
     * @param currentVertexId    vertex the walk is currently at
     * @param destVertexId       vertex that terminates a path
     * @param hopsRemaining      number of further hops the walk may take from here
     * @param seenVertexIds      ids on the current path; restored to its incoming state on return
     * @param currentPath        path walked so far, ending at {@code currentVertexId}
     * @param progressCallback   passed through to recursive calls; not otherwise used here
     */
    private void findPathsRecursive(Map<String, Set<String>> connectedVertexIds, List<Path> foundPaths, String currentVertexId, String destVertexId, int hopsRemaining, Set<String> seenVertexIds, Path currentPath, ProgressCallback progressCallback) {
        // In "any path" mode, stop as soon as one path has been recorded.
        if (this.options.isGetAnyPath() && !foundPaths.isEmpty()) {
            return;
        }
        seenVertexIds.add(currentVertexId);
        if (currentVertexId.equals(destVertexId)) {
            foundPaths.add(currentPath);
        } else if (hopsRemaining > 0) {
            Set<String> neighbors = connectedVertexIds.get(currentVertexId);
            if (neighbors != null) {
                for (String neighborId : neighbors) {
                    if (!seenVertexIds.contains(neighborId)) {
                        findPathsRecursive(connectedVertexIds, foundPaths, neighborId, destVertexId, hopsRemaining - 1, seenVertexIds, new Path(currentPath, neighborId), progressCallback);
                    }
                }
            }
        }
        // Backtrack so sibling branches may pass through this vertex.
        seenVertexIds.remove(currentVertexId);
    }

    /**
     * Returns the neighbor set of a single vertex.
     *
     * @return the neighbor ids, or a new empty set when the lookup has no entry for {@code vertexId}
     */
    private Set<String> getConnectedVertexIds(String vertexId) {
        Set<String> vertexIds = new HashSet<>();
        vertexIds.add(vertexId);
        Set<String> connected = getConnectedVertexIds(vertexIds).get(vertexId);
        if (connected == null) {
            return new HashSet<>();
        }
        return connected;
    }

    /**
     * Returns the neighbor sets of two vertices, fetched with a single lookup.
     */
    private Map<String, Set<String>> getConnectedVertexIds(String vertexId1, String vertexId2) {
        Set<String> vertexIds = new HashSet<>();
        vertexIds.add(vertexId1);
        vertexIds.add(vertexId2);
        return getConnectedVertexIds(vertexIds);
    }

    /**
     * Scans the given vertices with a {@link ConnectedVertexIdsIterator} attached and returns,
     * per scanned row, the decoded connected vertex ids that {@code doVerticesExist} reports
     * as existing.
     *
     * <p>The included/excluded label filters computed in the constructor are applied to the
     * scan iterator. The whole call runs inside a trace span named "getConnectedVertexIds".
     *
     * @param vertexIds ids of the vertices to look up
     * @return a new mutable map keyed by the row of each scanned entry; empty when
     *         {@code vertexIds} is empty. Ids for which the scan produces no entry have no key.
     * @throws VertexiumException if the connected ids of a row cannot be decoded
     */
    private Map<String, Set<String>> getConnectedVertexIds(Set<String> vertexIds) {
        Span trace = Trace.start("getConnectedVertexIds");
        try {
            if (LOGGER.isTraceEnabled()) {
                LOGGER.trace("getConnectedVertexIds:\n  %s", new Object[]{IterableUtils.join(vertexIds, "\n  ")});
            }
            if (vertexIds.isEmpty()) {
                return new HashMap<>();
            }

            List<org.apache.accumulo.core.data.Range> ranges = new ArrayList<>();
            for (String vertexId : vertexIds) {
                ranges.add(RangeUtils.createRangeFromString(vertexId));
            }

            ScannerBase scanner = this.graph.createElementScanner(FetchHint.EDGE_REFS, ElementType.VERTEX, 1, null, null, ranges, false, this.authorizations);
            IteratorSetting iteratorSetting = new IteratorSetting(1000, ConnectedVertexIdsIterator.class.getSimpleName(), ConnectedVertexIdsIterator.class);
            ConnectedVertexIdsIterator.setLabels(iteratorSetting, this.deflatedLabels);
            ConnectedVertexIdsIterator.setExcludedLabels(iteratorSetting, this.deflatedExcludedLabels);
            scanner.addScanIterator(iteratorSetting);

            long startTime = System.currentTimeMillis();
            try {
                Map<String, Set<String>> results = new HashMap<>();
                for (Map.Entry<Key, Value> row : scanner) {
                    try {
                        Map<String, Boolean> verticesExist = this.graph.doVerticesExist(ConnectedVertexIdsIterator.decodeValue(row.getValue()), this.authorizations);
                        // Keep only ids reported as existing; Boolean.TRUE.equals tolerates a null value.
                        Set<String> existingVertexIds = verticesExist.entrySet().stream()
                            .filter(existsEntry -> Boolean.TRUE.equals(existsEntry.getValue()))
                            .map(Map.Entry::getKey)
                            .collect(Collectors.toSet());
                        results.put(row.getKey().getRow().toString(), existingVertexIds);
                    } catch (IOException e) {
                        throw new VertexiumException("Could not decode vertex ids for row: " + row.getKey().toString(), e);
                    }
                }
                return results;
            } finally {
                scanner.close();
                AccumuloGraph.GRAPH_LOGGER.logEndIterator(System.currentTimeMillis() - startTime);
            }
        } finally {
            // Single exit point for the span, covering normal returns and exceptions alike.
            trace.stop();
        }
    }
}
