/*
 * Decompiled with CFR 0.152.
 */
package org.rcsb.strucmotif.domain.structure;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.rcsb.strucmotif.align.QuaternionAlignmentService;
import org.rcsb.strucmotif.config.StrucmotifConfig;
import org.rcsb.strucmotif.domain.Pair;
import org.rcsb.strucmotif.domain.Transformation;
import org.rcsb.strucmotif.domain.motif.AngleType;
import org.rcsb.strucmotif.domain.motif.DistanceType;
import org.rcsb.strucmotif.domain.motif.IndexSelectionResiduePairIdentifier;
import org.rcsb.strucmotif.domain.motif.ResiduePairDescriptor;
import org.rcsb.strucmotif.domain.motif.ResiduePairOccurrence;
import org.rcsb.strucmotif.domain.structure.IndexSelection;
import org.rcsb.strucmotif.domain.structure.LabelAtomId;
import org.rcsb.strucmotif.domain.structure.LabelSelection;
import org.rcsb.strucmotif.domain.structure.ResidueGrid;
import org.rcsb.strucmotif.domain.structure.ResidueType;
import org.rcsb.strucmotif.domain.structure.Structure;
import org.rcsb.strucmotif.math.Algebra;

/**
 * Residue-pair graph of a {@link Structure}.
 *
 * <p>For every residue contact reported by a {@link ResidueGrid} (built over the backbone reference
 * coordinates with the configured squared distance cutoff) that passes the selected
 * {@link ResidueGraphMode} filter, this graph stores three descriptors:
 * <ul>
 *   <li>the backbone distance;</li>
 *   <li>the side-chain distance;</li>
 *   <li>the angle between the two backbone-to-side-chain unit vectors.</li>
 * </ul>
 *
 * <p>Backbone reference atom: CA, or C4' as a fallback. Side-chain reference atom: CB, or C1' as a
 * fallback; glycine uses a virtual CB. All state is populated by the constructor and not modified
 * afterwards.
 */
public class ResidueGraph {
    private final Structure structure;
    private final Map<IndexSelection, Map<IndexSelection, Float>> backboneDistances;
    private final Map<IndexSelection, Map<IndexSelection, Float>> sideChainDistances;
    private final Map<IndexSelection, Map<IndexSelection, Float>> angles;
    private final int numberOfResidues;
    private final int numberOfPairings;
    /** Reference coordinates of N, CA and C (in that order), used to place a virtual CB. */
    private static final List<float[]> REFERENCE_BACKBONE = List.of(new float[]{-0.698f, 0.184f, 1.008f}, new float[]{0.525f, 0.109f, 0.2f}, new float[]{0.174f, -0.292f, -1.208f});
    /** CB position in the same reference frame as {@link #REFERENCE_BACKBONE}. */
    private static final float[] REFERENCE_CB = new float[]{1.472f, -0.929f, 0.804f};
    /** Centroid passed for the reference backbone: the origin. */
    private static final float[] REFERENCE_CENTROID = new float[3];

    /**
     * Builds a graph over an explicit list of residues. All residue pairs are considered
     * ({@link ResidueGraphOptions#all()}), subject to the assembly co-occurrence check.
     *
     * @param structure the structure the residues belong to
     * @param labelSelections the residues to include (label_asym_id, label_seq_id, struct_oper_id)
     * @param residues atom coordinates for each entry of {@code labelSelections}, in the same order
     * @param strucmotifConfig supplies the squared distance cutoff and the undefined-assembly settings
     */
    public ResidueGraph(Structure structure, List<LabelSelection> labelSelections, List<Map<LabelAtomId, float[]>> residues, StrucmotifConfig strucmotifConfig) {
        this.structure = structure;
        this.backboneDistances = new HashMap<>();
        this.sideChainDistances = new HashMap<>();
        this.angles = new HashMap<>();
        Map<IndexSelection, float[]> normalVectorMap = new LinkedHashMap<>();
        Map<IndexSelection, float[]> backboneVectors = new LinkedHashMap<>();
        Map<IndexSelection, float[]> sideChainVectors = new LinkedHashMap<>();
        for (int i = 0; i < labelSelections.size(); i++) {
            LabelSelection labelSelection = labelSelections.get(i);
            Map<LabelAtomId, float[]> residue = residues.get(i);
            int residueIndex = structure.getResidueIndex(labelSelection.getLabelAsymId(), labelSelection.getLabelSeqId());
            IndexSelection indexSelection = new IndexSelection(labelSelection.getStructOperId(), residueIndex);
            ResidueType residueType = structure.getResidueType(residueIndex);
            float[] backbone = getBackboneCoords(residue);
            // glycine has no CB: derive a virtual one from N, CA and C
            float[] sideChain = residueType == ResidueType.GLYCINE ? getVirtualCB(residue) : getSideChainCoords(residue);
            if (backbone == null || sideChain == null) {
                // the residue lacks the reference atoms needed to describe it
                continue;
            }
            backboneVectors.put(indexSelection, backbone);
            sideChainVectors.put(indexSelection, sideChain);
            normalVectorMap.put(indexSelection, normalVector(backbone, sideChain));
        }
        Map<String, String[]> assemblyMap = structure.getAssemblies();
        if (strucmotifConfig.isUndefinedAssemblies() && assemblyMap.isEmpty()) {
            // NOTE(review): this writes into the map returned by structure.getAssemblies(); if that is
            // the structure's own map, the structure is modified as a side effect - confirm intent
            assemblyMap.put(strucmotifConfig.getUndefinedAssemblyIdentifier(), structure.getLabelSelections().stream().map(LabelSelection::getLabelAsymId).distinct().map(c -> c + "_1").toArray(String[]::new));
        }
        this.numberOfResidues = backboneVectors.size();
        this.numberOfPairings = fillResidueGrid(backboneVectors, sideChainVectors, normalVectorMap, strucmotifConfig.getSquaredDistanceCutoff(), ResidueGraphOptions.all(), assemblyMap);
    }

    /**
     * Builds a graph over all residues of a structure.
     *
     * <p>Every chain expression ({@code labelAsymId_structOperId}) referenced by any assembly is
     * expanded by applying its transformation to the residue coordinates. Which pairs are kept is
     * controlled by {@code options}.
     *
     * @param structure the structure to index
     * @param strucmotifConfig supplies the squared distance cutoff and the undefined-assembly settings
     * @param options the pair-selection mode
     * @throws IllegalArgumentException in {@link ResidueGraphMode#ASSEMBLY} mode, if the requested
     *         assembly identifier is not among the structure's assemblies
     */
    public ResidueGraph(Structure structure, StrucmotifConfig strucmotifConfig, ResidueGraphOptions options) {
        this.structure = structure;
        this.backboneDistances = new HashMap<>();
        this.sideChainDistances = new HashMap<>();
        this.angles = new HashMap<>();
        Map<String, List<LabelSelection>> chainMap = structure.getLabelSelections().stream().collect(Collectors.groupingBy(LabelSelection::getLabelAsymId));
        Map<String, String[]> assemblyMap = structure.getAssemblies();
        if (strucmotifConfig.isUndefinedAssemblies() && assemblyMap.isEmpty()) {
            // NOTE(review): this writes into the map returned by structure.getAssemblies(); if that is
            // the structure's own map, the structure is modified as a side effect - confirm intent
            assemblyMap.put(strucmotifConfig.getUndefinedAssemblyIdentifier(), chainMap.keySet().stream().map(c -> c + "_1").toArray(String[]::new));
        }

        // reference coordinates in the untransformed frame, indexed by residue index (null if missing)
        int residueCount = structure.getResidueCount();
        List<float[]> originalBackboneVectors = new ArrayList<>(residueCount);
        List<float[]> originalSideChainVectors = new ArrayList<>(residueCount);
        for (int i = 0; i < residueCount; i++) {
            ResidueType residueType = structure.getResidueType(i);
            Map<LabelAtomId, float[]> residue = structure.manifestResidue(i);
            originalBackboneVectors.add(getBackboneCoords(residue));
            originalSideChainVectors.add(residueType == ResidueType.GLYCINE ? getVirtualCB(residue) : getSideChainCoords(residue));
        }

        List<String> assemblyInformation = assemblyMap.values().stream().flatMap(Arrays::stream).distinct().collect(Collectors.toList());
        Map<IndexSelection, float[]> normalVectorMap = new LinkedHashMap<>();
        Map<IndexSelection, float[]> transformedBackboneVectors = new LinkedHashMap<>();
        Map<IndexSelection, float[]> transformedSideChainVectors = new LinkedHashMap<>();
        for (String chainExpr : assemblyInformation) {
            String[] split = chainExpr.split("_");
            String labelAsymId = split[0];
            String oper = split[1];
            Transformation transformation = structure.getTransformation(oper);
            List<LabelSelection> chain = chainMap.get(labelAsymId);
            if (chain == null) {
                continue;
            }
            for (LabelSelection labelSelection : chain) {
                int residueIndex = structure.getResidueIndex(labelSelection.getLabelAsymId(), labelSelection.getLabelSeqId());
                float[] originalBackbone = originalBackboneVectors.get(residueIndex);
                float[] originalSideChain = originalSideChainVectors.get(residueIndex);
                if (originalBackbone == null || originalSideChain == null) {
                    continue;
                }
                IndexSelection key = new IndexSelection(oper, residueIndex);
                float[] backbone = new float[3];
                float[] sideChain = new float[3];
                transformation.transform(backbone, originalBackbone);
                transformation.transform(sideChain, originalSideChain);
                transformedBackboneVectors.put(key, backbone);
                transformedSideChainVectors.put(key, sideChain);
                normalVectorMap.put(key, normalVector(backbone, sideChain));
            }
        }
        this.numberOfResidues = transformedBackboneVectors.size();
        this.numberOfPairings = fillResidueGrid(transformedBackboneVectors, transformedSideChainVectors, normalVectorMap, strucmotifConfig.getSquaredDistanceCutoff(), options, assemblyMap);
    }

    /**
     * Finds residue contacts and stores the descriptors of every accepted pair.
     *
     * @param backboneVectors backbone reference coordinates per residue; its iteration order defines
     *        the grid indices
     * @param sideChainVectors side-chain reference coordinates per residue
     * @param normalVectorMap backbone-to-side-chain unit vectors per residue
     * @param squaredCutoff squared distance cutoff handed to the {@link ResidueGrid}
     * @param options pair-selection mode
     * @param assemblies assembly identifier to chain expressions ({@code labelAsymId_structOperId})
     * @return the number of stored pairs
     */
    private int fillResidueGrid(Map<IndexSelection, float[]> backboneVectors, Map<IndexSelection, float[]> sideChainVectors, Map<IndexSelection, float[]> normalVectorMap, float squaredCutoff, ResidueGraphOptions options, Map<String, String[]> assemblies) {
        // Keys and coordinates come from the same map, so grid indices always map back to the right
        // residue (a separately maintained key list could drift if a residue was added twice).
        List<IndexSelection> indexSelections = new ArrayList<>(backboneVectors.keySet());
        ResidueGrid residueGrid = new ResidueGrid(new ArrayList<float[]>(backboneVectors.values()), squaredCutoff);

        // ordered chain expressions per assembly (order matters for DEPOSITED operator selection)
        Map<String, List<String>> assemblyMap = assemblies.entrySet().stream()
                .collect(Collectors.toMap(Map.Entry::getKey, e -> Arrays.asList(e.getValue())));
        // hash sets for the per-contact membership tests
        List<Set<String>> assemblyChainSets = new ArrayList<>(assemblyMap.size());
        for (List<String> chainExprs : assemblyMap.values()) {
            assemblyChainSets.add(new HashSet<>(chainExprs));
        }

        ResidueGraphMode mode = options.mode;
        Set<String> acceptedOperators = new HashSet<>();
        if (mode == ResidueGraphMode.DEPOSITED || mode == ResidueGraphMode.DEPOSITED_AND_CONTACTS) {
            // accept only the first chain expression encountered for each label_asym_id
            Set<String> acceptedChains = new HashSet<>();
            for (List<String> chainExprs : assemblyMap.values()) {
                for (String chainExpr : chainExprs) {
                    if (acceptedChains.add(chainExpr.split("_")[0])) {
                        acceptedOperators.add(chainExpr);
                    }
                }
            }
        }

        Set<String> requestChains = new HashSet<>();
        if (mode == ResidueGraphMode.ASSEMBLY) {
            List<String> chainExprs = assemblyMap.get(options.assemblyIdentifier);
            if (chainExprs == null) {
                throw new IllegalArgumentException("no assembly with identifier '" + options.assemblyIdentifier + "'");
            }
            requestChains.addAll(chainExprs);
        }

        List<LabelSelection> labelSelections = this.structure.getLabelSelections();
        int size = 0;
        for (ResidueGrid.ResidueContact residueContact : residueGrid.getIndicesContacts()) {
            // handle each unordered pair once; self-contacts are ignored
            if (residueContact.getI() >= residueContact.getJ()) {
                continue;
            }
            IndexSelection residueKey1 = indexSelections.get(residueContact.getI());
            IndexSelection residueKey2 = indexSelections.get(residueContact.getJ());
            String chainExpr1 = chainExpression(labelSelections, residueKey1);
            String chainExpr2 = chainExpression(labelSelections, residueKey2);
            if (!isAccepted(mode, chainExpr1, chainExpr2, acceptedOperators, requestChains)) {
                continue;
            }
            // both residues must co-occur in at least one assembly
            if (!sharesAssembly(assemblyChainSets, chainExpr1, chainExpr2)) {
                continue;
            }
            float[] sideChainCoordinates1 = sideChainVectors.get(residueKey1);
            float[] sideChainCoordinates2 = sideChainVectors.get(residueKey2);
            if (sideChainCoordinates1 == null || sideChainCoordinates2 == null) {
                continue;
            }
            float[] normalVector1 = normalVectorMap.get(residueKey1);
            float[] normalVector2 = normalVectorMap.get(residueKey2);
            this.backboneDistances.computeIfAbsent(residueKey1, k -> new HashMap<>()).put(residueKey2, Float.valueOf(residueContact.getDistance()));
            this.sideChainDistances.computeIfAbsent(residueKey1, k -> new HashMap<>()).put(residueKey2, Float.valueOf(Algebra.distance3d(sideChainCoordinates1, sideChainCoordinates2)));
            this.angles.computeIfAbsent(residueKey1, k -> new HashMap<>()).put(residueKey2, Float.valueOf(angle(normalVector1, normalVector2)));
            size++;
        }
        return size;
    }

    /** Chain expression {@code labelAsymId_structOperId} of a residue. */
    private static String chainExpression(List<LabelSelection> labelSelections, IndexSelection indexSelection) {
        return labelSelections.get(indexSelection.getIndex()).getLabelAsymId() + "_" + indexSelection.getStructOperId();
    }

    /**
     * Applies the mode-specific filter to a contact between two chain expressions.
     *
     * <p>In {@link ResidueGraphMode#DEPOSITED_AND_CONTACTS} mode, a pair is kept if either residue
     * belongs to a deposited chain. Which residue comes first is only an artifact of grid index
     * order, so the test must be symmetric.
     */
    private static boolean isAccepted(ResidueGraphMode mode, String chainExpr1, String chainExpr2, Set<String> acceptedOperators, Set<String> requestChains) {
        switch (mode) {
            case DEPOSITED:
                return acceptedOperators.contains(chainExpr1) && acceptedOperators.contains(chainExpr2);
            case DEPOSITED_AND_CONTACTS:
                return acceptedOperators.contains(chainExpr1) || acceptedOperators.contains(chainExpr2);
            case ASSEMBLY:
                return requestChains.contains(chainExpr1) && requestChains.contains(chainExpr2);
            default:
                return true;
        }
    }

    /** Whether some assembly contains both chain expressions. */
    private static boolean sharesAssembly(List<Set<String>> assemblyChainSets, String chainExpr1, String chainExpr2) {
        for (Set<String> chainExprs : assemblyChainSets) {
            if (chainExprs.contains(chainExpr1) && chainExprs.contains(chainExpr2)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Computes a virtual CB position by superimposing the reference N/CA/C triple onto the residue's
     * N/CA/C atoms and applying the resulting transformation to {@link #REFERENCE_CB}.
     *
     * @param residue atom coordinates of the residue
     * @return the virtual CB coordinates, or {@code null} if N, CA or C is missing
     */
    static float[] getVirtualCB(Map<LabelAtomId, float[]> residue) {
        float[] n = residue.get(LabelAtomId.N);
        float[] ca = residue.get(LabelAtomId.CA);
        float[] c = residue.get(LabelAtomId.C);
        if (n == null || ca == null || c == null) {
            return null;
        }
        List<float[]> coords = List.of(n, ca, c);
        float[] v = Algebra.centroid3d(coords);
        Transformation transformation = QuaternionAlignmentService.align(coords, v, REFERENCE_BACKBONE, REFERENCE_CENTROID).getFirst();
        // v is reused as the output buffer for the transformed reference CB
        Algebra.multiply4d(v, transformation.getTransformationMatrix(), REFERENCE_CB);
        return v;
    }

    /** Backbone reference coordinates: CA, else C4', else {@code null}. */
    private static float[] getBackboneCoords(Map<LabelAtomId, float[]> residue) {
        if (residue.containsKey(LabelAtomId.CA)) {
            return residue.get(LabelAtomId.CA);
        }
        if (residue.containsKey(LabelAtomId.C4_PRIME)) {
            return residue.get(LabelAtomId.C4_PRIME);
        }
        return null;
    }

    /** Side-chain reference coordinates: CB, else C1', else {@code null}. */
    private static float[] getSideChainCoords(Map<LabelAtomId, float[]> residue) {
        if (residue.containsKey(LabelAtomId.CB)) {
            return residue.get(LabelAtomId.CB);
        }
        if (residue.containsKey(LabelAtomId.C1_PRIME)) {
            return residue.get(LabelAtomId.C1_PRIME);
        }
        return null;
    }

    /** Unit vector pointing from {@code a} to {@code b}. */
    private static float[] normalVector(float[] a, float[] b) {
        float[] ba = new float[3];
        Algebra.subtract3d(ba, b, a);
        Algebra.normalize3d(ba, ba);
        return ba;
    }

    /**
     * Angle between two vectors in degrees, in [0, 180].
     *
     * <p>The dot product is clamped to [-1, 1] to guard {@code acos} against rounding. Assumes unit
     * vectors, as produced by {@code normalVector}.
     */
    static float angle(float[] v1, float[] v2) {
        float vDot = Algebra.dotProduct3d(v1, v2);
        return (float) Math.toDegrees(Math.acos(Algebra.capToInterval(-1.0f, vDot, 1.0f)));
    }

    /**
     * Backbone distance of a residue pair, in either order.
     *
     * @return the stored distance, or {@link Float#MAX_VALUE} if the pair is not in the graph
     */
    public float getBackboneDistance(IndexSelection residue1, IndexSelection residue2) {
        return tryGet(this.backboneDistances, residue1, residue2);
    }

    /**
     * Side-chain distance of a residue pair, in either order.
     *
     * @return the stored distance, or {@link Float#MAX_VALUE} if the pair is not in the graph
     */
    public float getSideChainDistance(IndexSelection residue1, IndexSelection residue2) {
        return tryGet(this.sideChainDistances, residue1, residue2);
    }

    /**
     * Angle (degrees) of a residue pair, in either order.
     *
     * @return the stored angle, or {@link Float#MAX_VALUE} if the pair is not in the graph
     */
    public float getAngle(IndexSelection residue1, IndexSelection residue2) {
        return tryGet(this.angles, residue1, residue2);
    }

    /** Looks up a pair in either orientation; {@link Float#MAX_VALUE} if absent. */
    private float tryGet(Map<IndexSelection, Map<IndexSelection, Float>> map, IndexSelection i1, IndexSelection i2) {
        Map<IndexSelection, Float> inner = map.get(i1);
        if (inner != null) {
            Float value = inner.get(i2);
            if (value != null) {
                return value;
            }
        }
        inner = map.get(i2);
        if (inner != null) {
            Float value = inner.get(i1);
            if (value != null) {
                return value;
            }
        }
        return Float.MAX_VALUE;
    }

    /** @return the number of residue pairs stored in this graph */
    public int getNumberOfPairings() {
        return this.numberOfPairings;
    }

    /** Sequentially streams the stored residue pairs, each in the orientation it was stored in. */
    public Stream<Pair<IndexSelection, IndexSelection>> pairingsSequential() {
        return this.backboneDistances.keySet().stream().flatMap(e -> pairs(e, false));
    }

    /** Streams the stored residue pairs in parallel, each in the orientation it was stored in. */
    public Stream<Pair<IndexSelection, IndexSelection>> pairingsParallel() {
        return this.backboneDistances.keySet().parallelStream().flatMap(e -> pairs(e, true));
    }

    /** Streams a {@link ResiduePairOccurrence} for every stored pair (parallel). */
    public Stream<ResiduePairOccurrence> residuePairOccurrencesParallel() {
        return pairingsParallel().map(this::createMotifOccurrence);
    }

    /** Streams a {@link ResiduePairOccurrence} for every stored pair (sequential). */
    public Stream<ResiduePairOccurrence> residuePairOccurrencesSequential() {
        return pairingsSequential().map(this::createMotifOccurrence);
    }

    /**
     * Creates the occurrence for a pair.
     *
     * <p>The pair is first put into canonical orientation, with the residue whose internal code
     * compares smaller placed first.
     */
    private ResiduePairOccurrence createMotifOccurrence(Pair<IndexSelection, IndexSelection> pair) {
        IndexSelection indexSelection1 = pair.getFirst();
        IndexSelection indexSelection2 = pair.getSecond();
        ResidueType residueType1 = this.structure.getResidueType(indexSelection1.getIndex());
        ResidueType residueType2 = this.structure.getResidueType(indexSelection2.getIndex());
        if (residueType1.getInternalCode().compareTo(residueType2.getInternalCode()) > 0) {
            IndexSelection swappedSelection = indexSelection1;
            indexSelection1 = indexSelection2;
            indexSelection2 = swappedSelection;
            ResidueType swappedType = residueType1;
            residueType1 = residueType2;
            residueType2 = swappedType;
        }
        DistanceType backboneDistance = DistanceType.ofDistance(getBackboneDistance(indexSelection1, indexSelection2));
        DistanceType sideChainDistance = DistanceType.ofDistance(getSideChainDistance(indexSelection1, indexSelection2));
        AngleType angle = AngleType.ofAngle(getAngle(indexSelection1, indexSelection2));
        ResiduePairDescriptor residuePairDescriptor = new ResiduePairDescriptor(residueType1, residueType2, backboneDistance, sideChainDistance, angle);
        IndexSelectionResiduePairIdentifier residuePairIdentifier = new IndexSelectionResiduePairIdentifier(indexSelection1, indexSelection2);
        return new ResiduePairOccurrence(residuePairDescriptor, residuePairIdentifier);
    }

    /** All stored partners of {@code residue1}, as pairs with {@code residue1} first. */
    private Stream<Pair<IndexSelection, IndexSelection>> pairs(IndexSelection residue1, boolean parallel) {
        Map<IndexSelection, Float> map = this.backboneDistances.get(residue1);
        if (map == null) {
            return Stream.empty();
        }
        Stream<IndexSelection> partners = parallel ? map.keySet().parallelStream() : map.keySet().stream();
        return partners.map(residue2 -> new Pair<IndexSelection, IndexSelection>(residue1, residue2));
    }

    /** @return the number of residues (with usable reference atoms) included in this graph */
    public int getNumberOfResidues() {
        return this.numberOfResidues;
    }

    /** Selects which residue pairs a {@link ResidueGraph} keeps. */
    public static class ResidueGraphOptions {
        final ResidueGraphMode mode;
        /** Only meaningful for {@link ResidueGraphMode#ASSEMBLY}; {@code null} otherwise. */
        final String assemblyIdentifier;

        private ResidueGraphOptions(ResidueGraphMode mode, String assemblyIdentifier) {
            this.mode = mode;
            this.assemblyIdentifier = assemblyIdentifier;
        }

        /** Keeps only pairs in which both residues belong to a deposited chain expression. */
        public static ResidueGraphOptions deposited() {
            return new ResidueGraphOptions(ResidueGraphMode.DEPOSITED, null);
        }

        /** Keeps pairs in which at least one residue belongs to a deposited chain expression. */
        public static ResidueGraphOptions depositedAndContacts() {
            return new ResidueGraphOptions(ResidueGraphMode.DEPOSITED_AND_CONTACTS, null);
        }

        /** Keeps only pairs in which both residues belong to the given assembly. */
        public static ResidueGraphOptions assembly(String assemblyIdentifier) {
            return new ResidueGraphOptions(ResidueGraphMode.ASSEMBLY, assemblyIdentifier);
        }

        /** Keeps all pairs whose residues co-occur in some assembly. */
        public static ResidueGraphOptions all() {
            return new ResidueGraphOptions(ResidueGraphMode.ALL, null);
        }
    }

    /** Pair-selection modes; see the factory methods of {@link ResidueGraphOptions}. */
    public static enum ResidueGraphMode {
        DEPOSITED,
        DEPOSITED_AND_CONTACTS,
        ASSEMBLY,
        ALL
    }
}

