package tech.molecules.deep;

import com.actelion.research.chem.StereoMolecule;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.BitSet;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import tech.molecules.leet.chem.ChemUtils;

/* loaded from: input_file:tech/molecules/deep/MolEncoder.class */
/**
 * Encodes a {@link StereoMolecule} as a list of fixed-width bit vectors, one per atom slot.
 *
 * <p>Each slot vector consists of an atom part ({@link #getAtomEncodingLength()} bits: a one-hot
 * element code followed by two parity bits) followed by {@link #N} connection parts of
 * {@link #getConnectionEncodingLengthPerConnection()} bits each, describing the bond between
 * this slot's atom and the atom of every other slot.
 *
 * <p>Instances hold the molecule being encoded as mutable state and are not designed for
 * concurrent use.
 */
public class MolEncoder {
    /** Molecule currently being encoded; assigned through {@link #setMolecule(StereoMolecule)}. */
    private StereoMolecule m;

    /** Number of atom slots in the encoding; slots beyond the molecule's atom count are all-zero. */
    int N = 32;

    /** Atomic numbers that receive a one-hot bit; the array index is the bit position. */
    int[] mappedAtomicNo = {1, 6, 7, 8, 9, 15, 16, 17, 35};

    /** Atomic number -> one-hot template. Templates are shared and must never be mutated. */
    Map<Integer, BitSet> encoding_AtomOneHot = new HashMap<>();

    /** Declared but not populated or read anywhere in this class. */
    Map<Integer, BitSet> encoding_BondTypeSimple = new HashMap<>();

    /** Creates an encoder and builds the one-hot templates for {@link #mappedAtomicNo}. */
    public MolEncoder() {
        init();
    }

    /** Fills {@link #encoding_AtomOneHot} with one single-bit template per mapped atomic number. */
    private void init() {
        for (int pos = 0; pos < this.mappedAtomicNo.length; pos++) {
            BitSet template = new BitSet();
            template.set(pos);
            this.encoding_AtomOneHot.put(this.mappedAtomicNo[pos], template);
        }
    }

    /**
     * Sets the molecule that subsequent encode calls operate on.
     *
     * @param stereoMolecule molecule to encode
     */
    public void setMolecule(StereoMolecule stereoMolecule) {
        this.m = stereoMolecule;
    }

    /**
     * Encodes the current molecule using the identity slot order (slot k holds atom k).
     *
     * @return {@link #N} bit vectors, one per slot
     * @throws IllegalStateException if no molecule has been set
     */
    public List<BitSet> encodeMolecule() {
        int[] identity = new int[this.N];
        for (int slot = 0; slot < identity.length; slot++) {
            identity[slot] = slot;
        }
        return encodeMolecule(identity);
    }

    /**
     * Encodes the current molecule using the given slot-to-atom assignment.
     *
     * @param iArr atom index for every slot; an index greater than or equal to the molecule's
     *     atom count marks an empty (all-zero) slot. Must contain at least {@link #N} entries,
     *     because every slot is related to the first {@code N} slots.
     * @return one bit vector per entry of {@code iArr}
     * @throws IllegalStateException if no molecule has been set
     */
    public List<BitSet> encodeMolecule(int[] iArr) {
        if (this.m == null) {
            throw new IllegalStateException("No molecule set; call setMolecule() before encoding");
        }
        // Prepared here (not only in the no-arg overload) so both entry points behave the same.
        // NOTE(review): 31 is presumably the library's "all helper arrays up to CIP" mode, which
        // the neighbour/parity queries below rely on -- confirm against Molecule constants.
        this.m.ensureHelperArrays(31);

        List<BitSet> encoded = new ArrayList<>(iArr.length);
        for (int slot = 0; slot < iArr.length; slot++) {
            encoded.add(encodeAtomWithConnections(slot, iArr));
        }
        return encoded;
    }

    /** @return total number of bits over all {@link #N} slot vectors */
    public int getMoleculeEncodingLength() {
        return this.N * (getAtomEncodingLength() + (this.N * getConnectionEncodingLengthPerConnection()));
    }

    /** @return bits used for the atom part: one per mapped element plus two parity bits */
    public int getAtomEncodingLength() {
        return this.mappedAtomicNo.length + 2;
    }

    /** @return bits used to describe the connection between two slots */
    public int getConnectionEncodingLengthPerConnection() {
        return 7;
    }

    /**
     * Encodes one slot: the atom part followed by the connection parts towards the first
     * {@link #N} slots.
     *
     * @param i slot position within {@code iArr}
     * @param iArr atom index for every slot (see {@link #encodeMolecule(int[])})
     * @return the slot's bit vector; empty if the slot's atom index is outside the molecule
     */
    public BitSet encodeAtomWithConnections(int i, int[] iArr) {
        int atomBits = getAtomEncodingLength();
        int connectionBits = getConnectionEncodingLengthPerConnection();
        BitSet encoded = new BitSet(atomBits + (this.N * connectionBits));

        int atom = iArr[i];
        if (atom >= this.m.getAtoms()) {
            // Padding slot: nothing to encode.
            return encoded;
        }

        encoded.or(encodeAtom(atom));
        for (int other = 0; other < this.N; other++) {
            BitSet connection = encodeConnection(atom, iArr[other]);
            int offset = atomBits + (other * connectionBits);
            for (int bit = 0; bit < connectionBits; bit++) {
                encoded.set(offset + bit, connection.get(bit));
            }
        }
        return encoded;
    }

    /**
     * Encodes a single atom as element one-hot bits followed by two parity bits.
     *
     * <p>The result is always a fresh {@link BitSet}. The cached one-hot templates are copied,
     * never modified; writing parity bits into the shared template would leak them into every
     * later atom of the same element. Atoms whose atomic number is not in
     * {@link #mappedAtomicNo} get no element bit set.
     *
     * @param i atom index in the current molecule
     * @return bit vector of logical length {@link #getAtomEncodingLength()}
     */
    public BitSet encodeAtom(int i) {
        BitSet encoded = new BitSet(getAtomEncodingLength());
        BitSet template = this.encoding_AtomOneHot.get(this.m.getAtomicNo(i));
        if (template != null) {
            encoded.or(template);
        }

        int parity = this.m.getAtomParity(i);
        if (parity == 1) {
            encoded.set(this.mappedAtomicNo.length);
        } else if (parity == 2) {
            encoded.set(this.mappedAtomicNo.length + 1);
        }
        return encoded;
    }

    /**
     * Encodes the bond between two atoms of the current molecule.
     *
     * <p>Bits 0-3 are set for simple bond types 1, 2, 4 and 64 respectively; for simple type 1,
     * bits 4-6 are additionally set for full bond types 257, 129 and 386.
     *
     * @param i first atom index
     * @param i2 second atom index
     * @return bit vector of logical length {@link #getConnectionEncodingLengthPerConnection()};
     *     empty if both indices are equal or the molecule reports no bond between them
     */
    public BitSet encodeConnection(int i, int i2) {
        BitSet encoded = new BitSet(getConnectionEncodingLengthPerConnection());
        if (i == i2) {
            return encoded;
        }
        int bond = this.m.getBond(i, i2);
        if (bond < 0) {
            return encoded;
        }

        int simpleType = this.m.getBondTypeSimple(bond);
        if (simpleType == 1) {
            encoded.set(0);
        }
        if (simpleType == 2) {
            encoded.set(1);
        }
        if (simpleType == 4) {
            encoded.set(2);
        }
        if (simpleType == 64) {
            encoded.set(3);
        }
        if (simpleType == 1) {
            int fullType = this.m.getBondType(bond);
            if (fullType == 257) {
                encoded.set(4);
            }
            if (fullType == 129) {
                encoded.set(5);
            }
            // NOTE(review): 386 may denote a stereo variant of a double bond, in which case this
            // check can never match inside the simple-type-1 branch -- confirm against the
            // Molecule bond type constants.
            if (fullType == 386) {
                encoded.set(6);
            }
        }
        return encoded;
    }

    /**
     * Creates a bit set with exactly one bit set.
     *
     * @param i index of the bit to set
     * @param i2 initial capacity hint, in bits
     * @return the new bit set
     */
    public static BitSet oneHot(int i, int i2) {
        BitSet bits = new BitSet(i2);
        bits.set(i);
        return bits;
    }

    /**
     * Creates a bit set with all listed bits set.
     *
     * @param iArr indices of the bits to set
     * @param i initial capacity hint, in bits
     * @return the new bit set
     */
    public static BitSet multiHot(int[] iArr, int i) {
        BitSet bits = new BitSet(i);
        Arrays.stream(iArr).forEach(bits::set);
        return bits;
    }

    /**
     * Prints every bit set as a line of {@code 0}/{@code 1} characters to standard output, each
     * line preceded by a newline.
     *
     * @param list bit sets to print
     * @param i number of bits to print per bit set
     */
    public static void print(List<BitSet> list, int i) {
        for (BitSet bits : list) {
            System.out.print("\n");
            for (int pos = 0; pos < i; pos++) {
                System.out.print(bits.get(pos) ? '1' : '0');
            }
        }
    }

    /** Demo entry point: encodes a sample SMILES and prints the per-slot bit vectors. */
    public static void main(String[] strArr) {
        StereoMolecule parseSmiles = ChemUtils.parseSmiles("C(=C/c1cccc2ccccc12)\\c1cc[nH+]cc1");
        MolEncoder molEncoder = new MolEncoder();
        molEncoder.setMolecule(parseSmiles);
        print(molEncoder.encodeMolecule(), molEncoder.getAtomEncodingLength() + (molEncoder.N * molEncoder.getConnectionEncodingLengthPerConnection()));
        System.out.println("");
    }
}
