/*
 * Decompiled with CFR 0.152.
 */
package org.rcsb.strucmotif.core;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ExecutionException;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.rcsb.strucmotif.core.TargetAssembler;
import org.rcsb.strucmotif.core.ThreadPool;
import org.rcsb.strucmotif.domain.Pair;
import org.rcsb.strucmotif.domain.StructureSearchContext;
import org.rcsb.strucmotif.domain.bucket.InvertedIndexBucket;
import org.rcsb.strucmotif.domain.motif.IndexSelectionResiduePairIdentifier;
import org.rcsb.strucmotif.domain.motif.InvertedIndexResiduePairIdentifier;
import org.rcsb.strucmotif.domain.motif.Overlap;
import org.rcsb.strucmotif.domain.motif.ResiduePairDescriptor;
import org.rcsb.strucmotif.domain.motif.ResiduePairOccurrence;
import org.rcsb.strucmotif.domain.query.StructureParameters;
import org.rcsb.strucmotif.domain.query.StructureQuery;
import org.rcsb.strucmotif.domain.query.StructureQueryStructure;
import org.rcsb.strucmotif.domain.result.StructureSearchResult;
import org.rcsb.strucmotif.domain.result.TargetStructure;
import org.rcsb.strucmotif.domain.structure.IndexSelection;
import org.rcsb.strucmotif.domain.structure.LabelSelection;
import org.rcsb.strucmotif.domain.structure.ResidueType;
import org.rcsb.strucmotif.io.InvertedIndex;
import org.rcsb.strucmotif.io.StructureIndexProvider;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

/**
 * Assembles candidate target structures for a structure query by processing the query's residue pairs one at a time.
 * <p>
 * The first residue pair seeds the target structures from inverted-index hits. Every following residue pair is looked
 * up in the inverted index as well. Its hits are handed to {@link TargetStructure#consume}, and structures without a
 * hit, or for which {@code consume} returns {@code false}, are dropped. Between steps, the set of structures that may
 * still be reported is narrowed, so later lookups skip eliminated structures.
 */
@Service
public class TargetAssemblerImpl implements TargetAssembler {
    private static final Logger logger = LoggerFactory.getLogger(TargetAssemblerImpl.class);
    private final ThreadPool threadPool;
    private final StructureIndexProvider structureIndexProvider;

    /**
     * Creates a new target assembler.
     *
     * @param threadPool pool on which inverted-index lookups and target-structure updates are executed
     * @param structureIndexProvider used to map allowed/excluded structure identifiers and the results-content type
     *                               to structure indices
     */
    @Autowired
    public TargetAssemblerImpl(ThreadPool threadPool, StructureIndexProvider structureIndexProvider) {
        this.threadPool = threadPool;
        this.structureIndexProvider = structureIndexProvider;
    }

    /**
     * Determines the target structures that provide hits for every residue pair of the query and records path and
     * structure counts on the search result.
     * <p>
     * Residue pairs are processed in query order. Before every step except the last, the set of allowed structures
     * is narrowed to the current target structures. As soon as no target structure is left, the remaining steps are
     * skipped.
     *
     * @param context search context providing the query and inverted index; its result is populated in place
     * @throws ExecutionException if a task submitted to the thread pool fails
     * @throws InterruptedException if interrupted while waiting for a thread-pool task
     */
    @Override
    public void assemble(StructureSearchContext context) throws ExecutionException, InterruptedException {
        StructureQuery query = context.getQuery();
        StructureQueryStructure queryStructure = query.getQueryStructure();
        StructureParameters parameters = query.getParameters();
        InvertedIndex invertedIndex = context.getInvertedIndex();
        StructureSearchResult result = context.getResult();
        int backboneDistanceTolerance = parameters.getBackboneDistanceTolerance();
        int sideChainDistanceTolerance = parameters.getSideChainDistanceTolerance();
        int angleTolerance = parameters.getAngleTolerance();

        // translate label-based exchange definitions into the index-based selections of the query structure
        Map<LabelSelection, Set<ResidueType>> labelSelectionExchanges = query.getExchanges();
        Map<IndexSelection, Set<ResidueType>> exchanges = labelSelectionExchanges.entrySet().stream()
                .collect(Collectors.toMap(entry -> {
                    LabelSelection labelSelection = entry.getKey();
                    int residueIndex = queryStructure.getStructure().getResidueIndex(labelSelection.getLabelAsymId(), labelSelection.getLabelSeqId());
                    return new IndexSelection(labelSelection.getStructOperId(), residueIndex);
                }, Map.Entry::getValue));

        // null means "not restricted by results-content type" (see select)
        Set<Integer> searchSpace = structureIndexProvider.selectByResultsContentType(query.getResultsContentType());
        // 'allowed' is narrowed in place after each step, so it must be a mutable set (Collectors.toSet() does not
        // guarantee that). An empty set means "no restriction" to select().
        Set<Integer> allowed = query.getAllowedStructures().stream()
                .map(structureIndexProvider::selectStructureIndex)
                .collect(Collectors.toCollection(HashSet::new));
        Set<Integer> ignored = query.getExcludedStructures().stream()
                .map(structureIndexProvider::selectStructureIndex)
                .collect(Collectors.toCollection(HashSet::new));

        result.getTimings().pathsStart();
        int steps = queryStructure.getResiduePairOccurrences().size();
        for (int i = 0; i < steps; i++) {
            long start = System.nanoTime();
            ResiduePairOccurrence residuePairOccurrence = queryStructure.getResiduePairOccurrences().get(i);
            ResiduePairDescriptor residuePairDescriptor = residuePairOccurrence.getResiduePairDescriptor();

            // look up every descriptor within tolerance; hits for the same structure are concatenated
            Map<Integer, InvertedIndexResiduePairIdentifier[]> residuePairIdentifiers = threadPool.submit(() -> residuePairOccurrence
                    .residuePairDescriptorsByTolerance(backboneDistanceTolerance, sideChainDistanceTolerance, angleTolerance, exchanges)
                    .flatMap(descriptor -> select(invertedIndex, descriptor, searchSpace, allowed, ignored))
                    .collect(Collectors.toMap(Pair::getFirst, Pair::getSecond, TargetAssemblerImpl::concat))).get();
            consume(context, residuePairIdentifiers);

            boolean moreSteps = i + 1 < steps;
            if (moreSteps) {
                // restrict the next lookups to structures that are still target structures
                Set<Integer> keys = result.getTargetStructures().keySet();
                if (i == 0 && allowed.isEmpty()) {
                    allowed.addAll(keys);
                } else {
                    allowed.removeIf(v -> !keys.contains(v));
                }
            }

            logger.info("[{}] Consumed {} in {} ms - {} valid target structures remaining", context.getId(), residuePairDescriptor, (System.nanoTime() - start) / 1000L / 1000L, result.getTargetStructures().size());

            // Once no target structure is left, 'allowed' is empty, which select() would read as "no restriction".
            // The next step would then scan the whole inverted index (ignoring any user allow-list) only to discard
            // every hit. So stop immediately, including right after the first step.
            if (moreSteps && allowed.isEmpty()) {
                logger.info("[{}] No more valid extensions - terminating early", context.getId());
                break;
            }
        }
        result.getTimings().pathsStop();

        int pathCount = result.getTargetStructures().values().stream().mapToInt(TargetStructure::getNumberOfValidPaths).sum();
        int structureCount = result.getTargetStructures().size();
        logger.info("[{}] Found {} valid paths ({} target structures) in {} ms", context.getId(), pathCount, structureCount, result.getTimings().getPathsTime());
        result.setNumberOfPaths(pathCount);
        result.setNumberOfTargetStructures(structureCount);
    }

    /**
     * Concatenates two arrays. Used to merge the hits that different tolerance descriptors produce for the same
     * structure.
     *
     * @param first elements placed first
     * @param second elements appended after {@code first}
     * @param <T> component type
     * @return a new array containing all elements of {@code first} followed by all elements of {@code second}
     */
    private static <T> T[] concat(T[] first, T[] second) {
        T[] result = Arrays.copyOf(first, first.length + second.length);
        System.arraycopy(second, 0, result, first.length, second.length);
        return result;
    }

    /**
     * Reads the inverted-index bucket of a descriptor and returns the residue-pair identifiers of every structure
     * that passes the filters.
     * <p>
     * A structure is skipped if {@code allowed} is non-empty and does not contain it, if {@code ignored} contains
     * it, or if {@code searchSpace} is non-null and does not contain it.
     *
     * @param invertedIndex index to read from
     * @param descriptor residue pair descriptor to look up
     * @param searchSpace structures to consider, or {@code null} for no restriction
     * @param allowed structures to consider; empty means no restriction
     * @param ignored structures to skip
     * @return one pair of (structure index, identifiers) per accepted structure
     */
    private Stream<Pair<Integer, InvertedIndexResiduePairIdentifier[]>> select(InvertedIndex invertedIndex, ResiduePairDescriptor descriptor, Set<Integer> searchSpace, Set<Integer> allowed, Set<Integer> ignored) {
        InvertedIndexBucket bucket = invertedIndex.select(descriptor);
        boolean ambiguous = descriptor.isAmbiguous();
        List<Pair<Integer, InvertedIndexResiduePairIdentifier[]>> out = new ArrayList<>(bucket.getStructureCount());
        while (bucket.hasNextStructure()) {
            bucket.moveStructure();
            int structureIndex = bucket.getStructureIndex();
            boolean skip = (!allowed.isEmpty() && !allowed.contains(structureIndex))
                    || ignored.contains(structureIndex)
                    || (searchSpace != null && !searchSpace.contains(structureIndex));
            if (skip) {
                continue;
            }

            int[] occurrencePositions = bucket.getOccurrencePositions();
            InvertedIndexResiduePairIdentifier[] identifiers;
            if (ambiguous) {
                identifiers = createAmbiguousResiduePairIdentifiers(bucket, occurrencePositions);
            } else {
                identifiers = new InvertedIndexResiduePairIdentifier[occurrencePositions.length];
                for (int j = 0; j < occurrencePositions.length; j++) {
                    identifiers[j] = createResiduePairIdentifier(bucket, descriptor.isFlipped(), occurrencePositions[j]);
                }
            }
            out.add(new Pair<>(structureIndex, identifiers));
        }
        return out.stream();
    }

    /**
     * Emits every occurrence of an ambiguous descriptor in both residue orders. The identifier for positions
     * (o, o + 1) is stored at index 2j and its swapped counterpart at 2j + 1.
     *
     * @param bucket bucket positioned on the current structure
     * @param occurrencePositions bucket positions of the first residue of each occurrence
     * @return identifiers, two per occurrence
     */
    private static InvertedIndexResiduePairIdentifier[] createAmbiguousResiduePairIdentifiers(InvertedIndexBucket bucket, int[] occurrencePositions) {
        InvertedIndexResiduePairIdentifier[] identifiers = new InvertedIndexResiduePairIdentifier[occurrencePositions.length * 2];
        for (int j = 0; j < occurrencePositions.length; j++) {
            int o = occurrencePositions[j];
            int indexA = bucket.getIndex(o);
            int indexB = bucket.getIndex(o + 1);
            String structOperIdA = bucket.getStructOperId(o);
            String structOperIdB = bucket.getStructOperId(o + 1);
            identifiers[2 * j] = new InvertedIndexResiduePairIdentifier(indexA, indexB, structOperIdA, structOperIdB);
            identifiers[2 * j + 1] = new InvertedIndexResiduePairIdentifier(indexB, indexA, structOperIdB, structOperIdA);
        }
        return identifiers;
    }

    /**
     * Creates the identifier for the occurrence whose residues sit at bucket positions {@code i} and {@code i + 1}.
     *
     * @param bucket bucket positioned on the current structure
     * @param flipped if {@code true}, the two residues are swapped
     * @param i bucket position of the first residue
     * @return the identifier in the requested orientation
     */
    private InvertedIndexResiduePairIdentifier createResiduePairIdentifier(InvertedIndexBucket bucket, boolean flipped, int i) {
        if (!flipped) {
            return new InvertedIndexResiduePairIdentifier(bucket.getIndex(i), bucket.getIndex(i + 1), bucket.getStructOperId(i), bucket.getStructOperId(i + 1));
        }
        return new InvertedIndexResiduePairIdentifier(bucket.getIndex(i + 1), bucket.getIndex(i), bucket.getStructOperId(i + 1), bucket.getStructOperId(i));
    }

    /**
     * Integrates the hits of one step into the result.
     * <p>
     * On the first call (no target structures yet), every structure with a hit becomes a target structure.
     * Afterwards, only structures that have hits in {@code data} and for which {@link TargetStructure#consume}
     * returns {@code true} are kept.
     *
     * @param context search context whose result is updated
     * @param data hits of the current step, keyed by structure index
     * @throws ExecutionException if the thread-pool task fails
     * @throws InterruptedException if interrupted while waiting for the thread-pool task
     */
    private void consume(StructureSearchContext context, Map<Integer, InvertedIndexResiduePairIdentifier[]> data) throws ExecutionException, InterruptedException {
        StructureQuery query = context.getQuery();
        StructureSearchResult result = context.getResult();
        Map<Integer, TargetStructure> targetStructures = result.getTargetStructures();
        StructureQueryStructure queryStructure = query.getQueryStructure();

        if (targetStructures == null) {
            result.setTargetStructures(data.entrySet().stream()
                    .collect(Collectors.toMap(Map.Entry::getKey, entry -> new TargetStructure(entry.getKey(), entry.getValue()))));
            return;
        }

        int pathGeneration = result.incrementAndGetPathGeneration();
        // compare each earlier query residue pair (index < pathGeneration) with the current one
        // NOTE(review): assumes pathGeneration equals the current step index - confirm
        Overlap[] overlapProfile = new Overlap[pathGeneration];
        for (int i = 0; i < pathGeneration; i++) {
            overlapProfile[i] = Overlap.ofResiduePairIdentifiers((IndexSelectionResiduePairIdentifier) queryStructure.getResiduePairIdentifiers().get(i), (IndexSelectionResiduePairIdentifier) queryStructure.getResiduePairIdentifiers().get(pathGeneration));
        }

        result.setTargetStructures(threadPool.submit(() -> targetStructures.entrySet().parallelStream()
                .filter(entry -> {
                    InvertedIndexResiduePairIdentifier[] residuePairIdentifiers = data.get(entry.getKey());
                    // a structure without a hit for the current residue pair is dropped
                    return residuePairIdentifiers != null && entry.getValue().consume(residuePairIdentifiers, overlapProfile);
                })
                .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue))).get());
    }
}

