package org.rcsb.strucmotif.update;

import org.rcsb.cif.CifIO;
import org.rcsb.cif.schema.StandardSchemata;
import org.rcsb.cif.schema.mm.MmCifFile;
import org.rcsb.cif.schema.mm.PdbxAuditRevisionHistory;
import org.rcsb.cif.schema.mm.PdbxStructAssemblyGen;
import org.rcsb.strucmotif.config.MotifSearchConfig;
import org.rcsb.strucmotif.core.ThreadPool;
import org.rcsb.strucmotif.domain.ResidueGraph;
import org.rcsb.strucmotif.domain.Revision;
import org.rcsb.strucmotif.domain.StructureInformation;
import org.rcsb.strucmotif.domain.identifier.StructureIdentifier;
import org.rcsb.strucmotif.domain.motif.ResiduePairDescriptor;
import org.rcsb.strucmotif.domain.motif.ResiduePairIdentifier;
import org.rcsb.strucmotif.domain.structure.Structure;
import org.rcsb.strucmotif.io.StructureDataProvider;
import org.rcsb.strucmotif.persistence.InvertedIndex;
import org.rcsb.strucmotif.persistence.StateRepository;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.CommandLineRunner;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.boot.autoconfigure.data.mongo.MongoDataAutoConfiguration;
import org.springframework.boot.autoconfigure.domain.EntityScan;
import org.springframework.boot.autoconfigure.mongo.MongoAutoConfiguration;
import org.springframework.context.annotation.ComponentScan;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.UncheckedIOException;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.regex.MatchResult;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;

/**
 * Command-line application that maintains the data used by motif search.
 * <p>
 * Depending on the requested {@link Operation}, it either writes renumbered structure files and feeds their residue
 * pairs into the {@link InvertedIndex} ({@code ADD}), removes entries from all stores ({@code REMOVE}), or removes
 * entries that a previous, interrupted run left marked as dirty ({@code RECOVER}). Progress is tracked in the
 * {@link StateRepository} as "known" and "dirty" entries.
 */
@SpringBootApplication(exclude = { MongoAutoConfiguration.class, MongoDataAutoConfiguration.class })
@ComponentScan({"org.rcsb.strucmotif"})
@EntityScan("org.rcsb.strucmotif")
public class MotifSearchUpdate implements CommandLineRunner {
    private static final Logger logger = LoggerFactory.getLogger(MotifSearchUpdate.class);

    /**
     * Starts the Spring application, which subsequently invokes {@link #run(String[])}.
     * @param args command-line arguments: an operation followed by optional entry ids
     */
    public static void main(String[] args) {
        SpringApplication.run(MotifSearchUpdate.class, args);
    }

    private final StateRepository stateRepository;
    private final StructureDataProvider structureDataProvider;
    private final InvertedIndex invertedIndex;
    private final MotifSearchConfig motifSearchConfig;
    private final ThreadPool threadPool;

    @Autowired
    public MotifSearchUpdate(StateRepository stateRepository, StructureDataProvider structureDataProvider, InvertedIndex invertedIndex, MotifSearchConfig motifSearchConfig, ThreadPool threadPool) {
        this.stateRepository = stateRepository;
        this.structureDataProvider = structureDataProvider;
        this.invertedIndex = invertedIndex;
        this.motifSearchConfig = motifSearchConfig;
        this.threadPool = threadPool;
    }

    /**
     * Executes the requested update operation.
     * <p>
     * {@code args[0]} names the {@link Operation}. The remaining arguments are entry ids, or the single argument
     * {@code full} to request every entry currently reported by RCSB PDB. Unless the operation is {@code RECOVER},
     * entries still marked as dirty by a previous run are removed before the operation starts.
     * @param args operation followed by entry ids
     * @throws Exception if retrieving identifiers or processing structures fails
     */
    @Override
    public void run(String[] args) throws Exception {
        if (args.length < 1) {
            System.out.println("Too few arguments");
            System.out.println("Usage: java -Xmx12G -jar update.jar operation ...");
            System.out.println("Valid operation values: " + Arrays.toString(Operation.values()));
            System.out.println("Optionally: list of entry ids - (no argument performs null operation, use single argument 'full' for complete update)");
            System.out.println("If you want to update entries you have to explicitly remove them first");
            System.out.println("Example: java -Xmx12G -jar update.jar ADD 1acj 1exr 4hhb");
            return;
        }

        // determine identifiers requested by user - either provided collection or all currently reported identifiers by RCSB PDB
        Operation operation = Operation.resolve(args[0]);
        String[] ids = new String[args.length - 1];
        System.arraycopy(args, 1, ids, 0, ids.length);
        List<StructureIdentifier> requested;
        if (ids.length == 1 && ids[0].equalsIgnoreCase("full")) {
            requested = getAllIdentifiers();
        } else {
            requested = Arrays.stream(ids).map(StructureIdentifier::new).collect(Collectors.toList());
        }

        // check for sanity of internal state
        if (operation != Operation.RECOVER) {
            Collection<StructureIdentifier> dirtyStructureIdentifiers = stateRepository.selectDirty();
            if (!dirtyStructureIdentifiers.isEmpty()) {
                logger.warn("Update state is dirty - problematic identifiers:\n{}",
                        dirtyStructureIdentifiers);
                logger.info("Recovering from dirty state");
                // remove exactly the identifiers that were just logged instead of querying the repository again
                remove(dirtyStructureIdentifiers);
            }
        }

        logger.info("Starting update - Operation: {}, {} ids ({})",
                operation,
                requested.size(),
                requested.stream()
                        .limit(5)
                        .map(id -> "\"" + id.getPdbId() + "\"")
                        .collect(Collectors.joining(", ",
                                "[",
                                requested.size() > 5 ? ", ...]" : "]")));

        switch (operation) {
            case ADD:
                add(getDeltaPlusIdentifiers(requested));
                break;
            case REMOVE:
                remove(getDeltaMinusIdentifiers(requested));
                break;
            case RECOVER:
                remove(stateRepository.selectDirty());
                break;
        }

        logger.info("Finished update operation");
    }

    /**
     * Adds the given structures to the archive, processing them in partitions of the configured chunk size.
     * Each partition is processed in parallel, marked as dirty, and then persisted.
     * @param identifiers the structures to add
     * @throws ExecutionException if processing a structure fails
     * @throws InterruptedException if interrupted while waiting for a partition to finish
     */
    public void add(Collection<StructureIdentifier> identifiers) throws ExecutionException, InterruptedException {
        long target = identifiers.size();
        logger.info("{} files to process in total", target);

        Partition<StructureIdentifier> partitions = new Partition<>(identifiers, motifSearchConfig.getUpdateChunkSize());
        logger.info("Formed {} partitions of {} structures",
                partitions.size(),
                motifSearchConfig.getUpdateChunkSize());

        Context context = new Context();

        // split into partitions and process
        for (int i = 0; i < partitions.size(); i++) {
            context.partitionContext = (i + 1) + " / " + partitions.size();

            List<StructureIdentifier> partition = partitions.get(i);
            logger.info("[{}] Start processing partition", context.partitionContext);

            context.structureCounter = new AtomicInteger();
            context.partitionSize = partition.size();
            context.buffer = new ConcurrentHashMap<>();
            threadPool.submit(() -> {
                partition.parallelStream().forEach(id -> handleStructureIdentifier(id, context));
                return null;
            }).get();

            // mark as dirty only around index update
            stateRepository.insertDirty(partition);
            persist(context);
        }
    }

    /**
     * Mutable state shared by all workers while one partition is processed.
     */
    static class Context {
        /** Structures processed in the current partition; cleared after each persist. */
        final Set<StructureInformation> processed;
        /** "current / total" partition label used as log prefix. */
        String partitionContext;
        /** Number of structures in the current partition, used for progress logging. */
        int partitionSize;
        /** Residue pair occurrences of the current partition, grouped by descriptor and then by structure. */
        Map<ResiduePairDescriptor, Map<StructureIdentifier, Collection<ResiduePairIdentifier>>> buffer;
        /** Counts structures handled so far in the current partition. */
        AtomicInteger structureCounter;

        public Context() {
            this.processed = Collections.synchronizedSet(new HashSet<>());
        }
    }

    /**
     * Processes one structure. It parses the original file, writes its renumbered version and records revision and
     * assembly information in {@code context.processed}. It then adds all residue pairs of the renumbered structure to
     * {@code context.buffer}. A structure whose renumbered file cannot be read, or which has no valid polymer chain,
     * is logged and skipped.
     * @param structureIdentifier the structure to process
     * @param context the shared state of the current partition
     * @throws UncheckedIOException if reading, parsing or writing the structure fails
     * @throws RuntimeException if residue pair extraction fails, which aborts the whole update
     */
    private void handleStructureIdentifier(StructureIdentifier structureIdentifier, Context context) {
        int count = context.structureCounter.incrementAndGet();
        // the last partition may be smaller than the configured chunk size, so report the actual partition size
        String structureContext = count + " / " + context.partitionSize + "] [" + structureIdentifier.getPdbId();

        try {
            MmCifFile mmCifFile;
            // close the original stream once it has been parsed - nothing else in this method closes it
            try (InputStream inputStream = structureDataProvider.getOriginalInputStream(structureIdentifier)) {
                mmCifFile = CifIO.readFromInputStream(inputStream).as(StandardSchemata.MMCIF);
            }
            Revision revision = getRevision(mmCifFile);
            Map<String, List<String>> assemblyInformation = getAssemblyInformation(mmCifFile);
            // write renumbered structure
            structureDataProvider.writeRenumbered(structureIdentifier, mmCifFile);
            context.processed.add(new StructureInformation(structureIdentifier, revision, assemblyInformation));
        } catch (IOException e) {
            throw new UncheckedIOException("cif parsing failed for " + structureIdentifier, e);
        }

        // fails when file is missing (should not happen) or does not contain valid polymer chain
        Structure structure;
        try {
            structure = structureDataProvider.readRenumbered(structureIdentifier);
        } catch (UncheckedIOException e) {
            // can 'safely' happen when obsolete entry was dropped from bcif data but still lingers in list
            logger.warn("[{}] [{}] Source file missing unexpectedly - obsolete entry?",
                    context.partitionContext,
                    structureContext,
                    e);
            return;
        } catch (UnsupportedOperationException e) {
            logger.warn("[{}] [{}] No valid polymer chains",
                    context.partitionContext,
                    structureContext);
            return;
        }

        try {
            ResidueGraph residueGraph = new ResidueGraph(structure, motifSearchConfig.getSquaredDistanceCutoff());

            // extract motifs
            AtomicInteger structureMotifCounter = new AtomicInteger();
            threadPool.submit(() -> {
                residueGraph.residuePairOccurrencesParallel()
                        .forEach(motifOccurrence -> {
                            ResiduePairDescriptor motifDescriptor = motifOccurrence.getResiduePairDescriptor();
                            ResiduePairIdentifier targetIdentifier = motifOccurrence.getResidueIdentifier();

                            Map<StructureIdentifier, Collection<ResiduePairIdentifier>> groupedTargetIdentifiers = context.buffer.computeIfAbsent(motifDescriptor, k -> Collections.synchronizedMap(new HashMap<>()));
                            Collection<ResiduePairIdentifier> targetIdentifiers = groupedTargetIdentifiers.computeIfAbsent(structureIdentifier, k -> Collections.synchronizedSet(new HashSet<>()));
                            targetIdentifiers.add(targetIdentifier);
                            structureMotifCounter.incrementAndGet();
                        });
                return null;
            }).get();
            logger.info("[{}] [{}] Extracted {} residue pairs",
                    context.partitionContext,
                    structureContext,
                    structureMotifCounter.get());
        } catch (Exception e) {
            if (e instanceof InterruptedException) {
                // keep the interrupt visible to the calling thread instead of silently clearing it
                Thread.currentThread().interrupt();
            }
            logger.warn("[{}] [{}] Residue graph determination failed",
                    context.partitionContext,
                    structureContext,
                    e);
            // fail complete update
            throw new RuntimeException(e);
        }
    }

    /**
     * Collects, for every assembly id in {@code pdbx_struct_assembly_gen}, the identifiers built by
     * {@link #getOperList(String, String)}. An assembly that spans several rows gets the identifiers of all its rows.
     * @param mmCifFile the parsed structure
     * @return assembly id mapped to its identifiers, in order of first appearance; empty if the category is absent
     */
    private Map<String, List<String>> getAssemblyInformation(MmCifFile mmCifFile) {
        // TODO maybe this functionality should be part of Structures?
        /*
        loop_
        _pdbx_struct_assembly_gen.assembly_id
        _pdbx_struct_assembly_gen.oper_expression
        _pdbx_struct_assembly_gen.asym_id_list
        1 '(1-60)(61-88)'           A,B,C
        2 '(61-88)'                 A,B,C
        3 '(1-5)(61-88)'            A,B,C
        4 '(1,2,6,10,23,24)(61-88)' A,B,C
        5 '(1-5)(63-68)'            A,B,C
        6 '(1,10,23)(61,62,69-88)'  A,B,C
        7 '(P)(61-88)'              A,B,C
        #
         */
        PdbxStructAssemblyGen pdbxStructAssemblyGen = mmCifFile.getFirstBlock().getPdbxStructAssemblyGen();
        Map<String, List<String>> assemblyInformation = new LinkedHashMap<>();
        if (pdbxStructAssemblyGen.isDefined()) {
            for (int i = 0; i < pdbxStructAssemblyGen.getRowCount(); i++) {
                String assemblyId = pdbxStructAssemblyGen.getAssemblyId().get(i);
                String operExpression = pdbxStructAssemblyGen.getOperExpression().get(i);
                String asymIdList = pdbxStructAssemblyGen.getAsymIdList().get(i);
                List<String> operList = getOperList(operExpression, asymIdList);
                // the same assembly id may occur in several rows (different chains/operators) - merge instead of overwrite
                assemblyInformation.computeIfAbsent(assemblyId, k -> new ArrayList<>()).addAll(operList);
            }
        }
        return assemblyInformation;
    }

    private static final Pattern OPERATION_PATTERN = Pattern.compile("\\)\\(");
    private static final Pattern LIST_PATTERN = Pattern.compile(",");

    /**
     * Expands an oper_expression / asym_id_list pair into identifiers of the form {@code chain_op} or, for a
     * product of two operator lists such as {@code (1-60)(61-88)}, {@code chain_op1xop2}.
     * NOTE(review): only the first two factors of a product expression are considered.
     * @param operExpression the oper_expression value
     * @param asymIdList comma-separated chain ids
     * @return one identifier per chain and operator (combination)
     */
    private List<String> getOperList(String operExpression, String asymIdList) {
        List<String> operations = new ArrayList<>();
        List<String> chains = LIST_PATTERN.splitAsStream(asymIdList).collect(Collectors.toList());
        String[] split = OPERATION_PATTERN.split(operExpression);
        if (split.length > 1) {
            List<String> ids1 = extractTransformationIds(split[0]);
            List<String> ids2 = extractTransformationIds(split[1]);
            for (String id1 : ids1) {
                for (String id2 : ids2) {
                    for (String chain : chains) {
                        operations.add(chain + "_" + id1 + "x" + id2);
                    }
                }
            }
        } else {
            for (String id : extractTransformationIds(operExpression)) {
                for (String chain : chains) {
                    operations.add(chain + "_" + id);
                }
            }
        }

        return operations;
    }

    /**
     * Strips parentheses and quotes from a single operator list and expands it into individual operator ids,
     * e.g. {@code (1,3-5)} becomes {@code [1, 3, 4, 5]}.
     * @param rawOperation one factor of an oper_expression
     * @return the operator ids in order
     */
    private List<String> extractTransformationIds(String rawOperation) {
        String prepared = rawOperation.replace("(", "")
                .replace(")", "")
                .replace("'", "");

        return LIST_PATTERN.splitAsStream(prepared)
                .flatMap(this::extractTransformationRanges)
                .collect(Collectors.toList());
    }

    private static final Pattern RANGE_PATTERN = Pattern.compile("-");

    /**
     * Expands a numeric range such as {@code 3-5} into {@code 3, 4, 5}; any other token is returned as-is.
     * @param raw a single operator token
     * @return the expanded operator ids
     */
    private Stream<String> extractTransformationRanges(String raw) {
        String[] s = RANGE_PATTERN.split(raw);
        if (s.length == 1) {
            return Stream.of(raw);
        } else {
            return IntStream.range(Integer.parseInt(s[0]), Integer.parseInt(s[1]) + 1)
                    .mapToObj(String::valueOf);
        }
    }

    /**
     * Reads the revision from the last row of {@code pdbx_audit_revision_history}.
     * NOTE(review): behavior for files without revision history rows depends on the cif column implementation.
     * @param mmCifFile the parsed structure
     * @return major and minor revision of the last history row
     */
    private Revision getRevision(MmCifFile mmCifFile) {
        PdbxAuditRevisionHistory pdbxAuditRevisionHistory = mmCifFile.getFirstBlock().getPdbxAuditRevisionHistory();
        int last = pdbxAuditRevisionHistory.getRowCount() - 1;
        return new Revision(pdbxAuditRevisionHistory.getMajorRevision().get(last), pdbxAuditRevisionHistory.getMinorRevision().get(last));
    }

    /**
     * Writes the buffered residue pairs of the current partition to the inverted index. It then records the processed
     * structures as known, clears their dirty flag, and empties the buffer and the processed set.
     * @param context the state of the partition to persist
     * @throws ExecutionException if an index insert fails
     * @throws InterruptedException if interrupted while waiting for the inserts
     */
    private void persist(Context context) throws ExecutionException, InterruptedException {
        logger.info("[{}] Persisting {} unique residue pair descriptors",
                context.partitionContext,
                context.buffer.size());

        final int bufferTotal = context.buffer.size();
        AtomicInteger bufferCount = new AtomicInteger();
        threadPool.submit(() -> {
            context.buffer.entrySet().parallelStream().forEach(entry -> {
                ResiduePairDescriptor full = entry.getKey();
                Map<StructureIdentifier, Collection<ResiduePairIdentifier>> output = entry.getValue();

                if (bufferCount.incrementAndGet() % 100000 == 0) {
                    logger.info("[{}] {} / {}",
                            context.partitionContext,
                            bufferCount,
                            bufferTotal);
                }

                invertedIndex.insert(full, output);

                // writing takes additional heap - ease burden by dropping processed output bins
                output.clear();
            });
            return null;
        }).get();

        context.buffer.clear();

        // processed contains all StructureIdentifiers + corresponding revision
        stateRepository.insertKnown(context.processed);
        stateRepository.deleteDirty(context.processed.stream().map(StructureInformation::getStructureIdentifier).collect(Collectors.toSet()));
        context.processed.clear();
    }

    /**
     * Removes the given entries. Renumbered files are deleted one by one. The inverted index and the known and dirty
     * sets of the state repository are then updated in a single batch.
     * @param identifiers the entries to remove
     */
    public void remove(Collection<StructureIdentifier> identifiers) {
        AtomicInteger counter = new AtomicInteger();
        for (StructureIdentifier structureIdentifier : identifiers) {
            logger.info("[{}] Removing renumbered structure for entry: {}",
                    counter.incrementAndGet() + " / " + identifiers.size(),
                    structureIdentifier);
            structureDataProvider.deleteRenumbered(structureIdentifier);
        }

        // inverted index is expensive and should be done as batch
        if (!identifiers.isEmpty()) {
            invertedIndex.delete(identifiers);
            stateRepository.deleteKnown(identifiers);
            stateRepository.deleteDirty(identifiers);
        }
        logger.info("Finished removal operation");
    }

    private static final Pattern ENTRY_ID_PATTERN = Pattern.compile("[0-9][0-9A-Z]{3}");

    /**
     * Fetches the current entry list from {@link MotifSearchConfig#RCSB_ENTRY_LIST} and extracts all 4-character
     * entry ids from it.
     * @return lower-case identifiers in the order they appear in the response
     * @throws IOException if the list cannot be retrieved
     */
    public List<StructureIdentifier> getAllIdentifiers() throws IOException {
        logger.info("Retrieving current entry list from {}", MotifSearchConfig.RCSB_ENTRY_LIST);
        String response;
        // decode explicitly as UTF-8 instead of relying on the platform default charset
        try (InputStream inputStream = new URL(MotifSearchConfig.RCSB_ENTRY_LIST).openStream();
             BufferedReader bufferedReader = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8))) {
            response = bufferedReader.lines().collect(Collectors.joining(System.lineSeparator()));
        }
        return ENTRY_ID_PATTERN.matcher(response)
                .results()
                .map(MatchResult::group)
                .map(String::toLowerCase)
                .map(StructureIdentifier::new)
                .collect(Collectors.toList());
    }

    /**
     * Determine all IDs that need to be added to the archive, i.e. requested IDs that are not known yet.
     * @param requested the requested update
     * @return the distinct IDs that need to be processed
     */
    public Collection<StructureIdentifier> getDeltaPlusIdentifiers(Collection<StructureIdentifier> requested) {
        Collection<StructureIdentifier> known = stateRepository.selectKnown().stream().map(StructureInformation::getStructureIdentifier).collect(Collectors.toSet());
        if (known.isEmpty()) {
            logger.warn("No existing data - starting from scratch");
            // deduplicate (keeping request order) so that no structure is processed twice concurrently
            return new LinkedHashSet<>(requested);
        } else {
            return requested.stream()
                    .filter(id -> !known.contains(id))
                    .collect(Collectors.toSet());
        }
    }

    /**
     * Determine all IDs that need to be removed from the archive, i.e. known IDs that were requested.
     * @param requested the requested update
     * @return the IDs that need to be removed
     */
    public Collection<StructureIdentifier> getDeltaMinusIdentifiers(Collection<StructureIdentifier> requested) {
        Collection<StructureIdentifier> known = stateRepository.selectKnown().stream().map(StructureInformation::getStructureIdentifier).collect(Collectors.toSet());
        if (known.isEmpty()) {
            logger.warn("No existing data - no need for cleanup of obsolete entries");
            return Collections.emptySet();
        } else {
            // hash-based lookup: 'requested' can be the whole archive, so List.contains would be quadratic
            Set<StructureIdentifier> requestedSet = new HashSet<>(requested);
            return known.stream()
                    .filter(requestedSet::contains)
                    .collect(Collectors.toSet());
        }
    }
}
