package tech.molecules.deep.conformers;

import com.actelion.research.chem.StereoMolecule;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
import java.util.Random;

/* loaded from: input_file:tech/molecules/deep/conformers/SpaceSampler.class */
/**
 * Rasterises one 3D conformer into per-atom occupancy grids, plus one-hot atom/bond encodings.
 *
 * <p>On construction the molecule is centred and rotated by random Euler angles <b>in place</b>
 * (the caller's {@link StereoMolecule} is modified). {@link #sampleSpace(Random)} then evaluates,
 * for every atom and every point of a {@code gridSize^3} grid spanning a cube of edge
 * {@code cubeLengthAngstrom} centred on the origin, a clamped linear-falloff occupancy:
 * 1 inside a sphere of the given radius, decreasing by 1 per unit distance beyond it, never below 0.
 * Two grids are produced: {@code x} (grid points jittered by up to {@code varianceA * offsetA},
 * radius {@code varianceA}) and {@code x_target} (exact grid points, radius {@code varianceB}).
 */
public class SpaceSampler {

    /** One-hot slots per atom: atom codes are 1..9, slot 0 is left unused for padding. */
    private static final int ATOM_ONE_HOT_WIDTH = 10;

    /** One-hot slots per bond: bond codes are 1..6, slot 0 is left unused for "no bond". */
    private static final int BOND_ONE_HOT_WIDTH = 7;

    /** Atomic number -> dense atom code (1..9). */
    private static final Map<Integer, Integer> ATOM_ENCODING = createAtomEncoding();

    /** Molecule bond type -> dense bond code (1..6). */
    private static final Map<Integer, Integer> BOND_ENCODING = createBondEncoding();

    private final StereoMolecule mi;
    private final int maxAtoms;
    private final double cubeLengthAngstrom;
    private final int gridSize;
    private final double varianceA;
    private final double varianceB;
    private final double offsetA;

    /** Raw (integer-coded) sampling result. */
    public static class SampledSpace {
        public final double[][][][] x;
        public final int[] structureInfo;
        public final int[][] bondsType;

        public SampledSpace(double[][][][] x, int[] structureInfo, int[][] bondsType) {
            this.x = x;
            this.structureInfo = structureInfo;
            this.bondsType = bondsType;
        }
    }

    /** Sampling result with atom and bond codes expanded to boolean one-hot vectors. */
    public static class SampledSpaceOneHot {
        public final double[][][][] x;
        public final double[][][][] x_target;
        public final boolean[][] structureInfo;
        public final boolean[][][] bondsType;

        /**
         * @param x         jittered occupancy grid, indexed [ix][iy][iz][atom]
         * @param xTarget   un-jittered occupancy grid, same indexing as {@code x}
         * @param atomCodes per-atom codes (0 = padding, left all-false)
         * @param bondCodes per-atom-pair bond codes (0 = no bond, left all-false)
         */
        public SampledSpaceOneHot(double[][][][] x, double[][][][] xTarget, int[] atomCodes, int[][] bondCodes) {
            this.x = x;
            this.x_target = xTarget;
            this.structureInfo = new boolean[atomCodes.length][ATOM_ONE_HOT_WIDTH];
            for (int i = 0; i < atomCodes.length; i++) {
                if (atomCodes[i] != 0) {
                    this.structureInfo[i][atomCodes[i]] = true;
                }
            }
            // Allocate each row from its own length so non-square input cannot overrun.
            this.bondsType = new boolean[bondCodes.length][][];
            for (int i = 0; i < bondCodes.length; i++) {
                this.bondsType[i] = new boolean[bondCodes[i].length][BOND_ONE_HOT_WIDTH];
                for (int j = 0; j < bondCodes[i].length; j++) {
                    if (bondCodes[i][j] != 0) {
                        this.bondsType[i][j][bondCodes[i][j]] = true;
                    }
                }
            }
        }
    }

    /** Same as {@link SampledSpaceOneHot} but with 0/1 ints instead of booleans. */
    public static class SampledSpaceOneHot2 {
        public final double[][][][] x;
        public final double[][][][] x_target;
        public final int[][] structureInfo;
        public final int[][][] bondsType;

        /**
         * @param x         jittered occupancy grid, indexed [ix][iy][iz][atom]
         * @param xTarget   un-jittered occupancy grid, same indexing as {@code x}
         * @param atomCodes per-atom codes (0 = padding, left all-zero)
         * @param bondCodes per-atom-pair bond codes (0 = no bond, left all-zero)
         */
        public SampledSpaceOneHot2(double[][][][] x, double[][][][] xTarget, int[] atomCodes, int[][] bondCodes) {
            this.x = x;
            this.x_target = xTarget;
            // Width must be ATOM_ONE_HOT_WIDTH (10): code 9 (Br) would overflow a 9-wide row.
            this.structureInfo = new int[atomCodes.length][ATOM_ONE_HOT_WIDTH];
            for (int i = 0; i < atomCodes.length; i++) {
                if (atomCodes[i] != 0) {
                    this.structureInfo[i][atomCodes[i]] = 1;
                }
            }
            this.bondsType = new int[bondCodes.length][][];
            for (int i = 0; i < bondCodes.length; i++) {
                this.bondsType[i] = new int[bondCodes[i].length][BOND_ONE_HOT_WIDTH];
                for (int j = 0; j < bondCodes[i].length; j++) {
                    if (bondCodes[i][j] != 0) {
                        this.bondsType[i][j][bondCodes[i][j]] = 1;
                    }
                }
            }
        }
    }

    /**
     * Creates a sampler whose random rotation is drawn from a fresh {@link Random}.
     *
     * <p>Note: {@code stereoMolecule} is centred and rotated in place.
     *
     * @param stereoMolecule     conformer to sample (modified in place)
     * @param maxAtoms           atom-channel capacity of the output grids and encodings
     * @param cubeLengthAngstrom edge length of the sampled cube, centred on the origin
     * @param gridSize           number of grid points per axis
     * @param varianceA          full-occupancy radius for the jittered grid {@code x}
     * @param varianceB          full-occupancy radius for the target grid {@code x_target}
     * @param offsetA            jitter scale factor (jitter per axis is in [0, varianceA * offsetA))
     */
    public SpaceSampler(StereoMolecule stereoMolecule, int maxAtoms, double cubeLengthAngstrom, int gridSize,
            double varianceA, double varianceB, double offsetA) {
        this(stereoMolecule, maxAtoms, cubeLengthAngstrom, gridSize, varianceA, varianceB, offsetA, new Random());
    }

    /**
     * Creates a sampler whose random rotation is drawn from {@code rotationRandom}, so that the
     * orientation is reproducible when a seeded generator is supplied.
     *
     * <p>Note: {@code stereoMolecule} is centred and rotated in place.
     *
     * @param rotationRandom source of the three Euler angles (each uniform in [0, 360) degrees)
     * @see #SpaceSampler(StereoMolecule, int, double, int, double, double, double)
     */
    public SpaceSampler(StereoMolecule stereoMolecule, int maxAtoms, double cubeLengthAngstrom, int gridSize,
            double varianceA, double varianceB, double offsetA, Random rotationRandom) {
        this.mi = stereoMolecule;
        this.maxAtoms = maxAtoms;
        this.cubeLengthAngstrom = cubeLengthAngstrom;
        this.gridSize = gridSize;
        this.varianceA = varianceA;
        this.varianceB = varianceB;
        this.offsetA = offsetA;

        this.mi.center();
        // rotate3D expects degrees; nextDouble() alone would rotate by less than one degree.
        // NOTE(review): uniform Euler angles are not uniform over all rotations — fine for augmentation?
        double angleX = 360.0 * rotationRandom.nextDouble();
        double angleY = 360.0 * rotationRandom.nextDouble();
        double angleZ = 360.0 * rotationRandom.nextDouble();
        for (int atom = 0; atom < this.mi.getAllAtoms(); atom++) {
            double[] rotated = rotate3D(this.mi.getAtomX(atom), this.mi.getAtomY(atom), this.mi.getAtomZ(atom),
                    angleX, angleY, angleZ);
            this.mi.setAtomX(atom, rotated[0]);
            this.mi.setAtomY(atom, rotated[1]);
            this.mi.setAtomZ(atom, rotated[2]);
        }
    }

    /** Builds the atomic-number -> atom-code table. */
    private static Map<Integer, Integer> createAtomEncoding() {
        Map<Integer, Integer> encoding = new HashMap<>();
        encoding.put(1, 1);   // H
        encoding.put(6, 2);   // C
        encoding.put(7, 3);   // N
        encoding.put(8, 4);   // O
        encoding.put(9, 5);   // F
        encoding.put(15, 6);  // P
        encoding.put(16, 7);  // S
        encoding.put(17, 8);  // Cl
        encoding.put(35, 9);  // Br
        return encoding;
    }

    /** Builds the bond-type -> bond-code table. */
    private static Map<Integer, Integer> createBondEncoding() {
        // Keys presumably are OpenChemLib bond-type constants (single=1, double=2, triple=4,
        // delocalized=64, down=129, up=257) — confirm against Molecule.cBondType*.
        Map<Integer, Integer> encoding = new HashMap<>();
        encoding.put(1, 1);
        encoding.put(64, 2);
        encoding.put(2, 3);
        encoding.put(4, 4);
        encoding.put(257, 5);
        encoding.put(129, 6);
        return encoding;
    }

    /** Maps an atomic number to its dense code; throws for elements outside the supported set. */
    private static int encodeAtom(int atomicNo) {
        Integer code = ATOM_ENCODING.get(atomicNo);
        if (code == null) {
            throw new IllegalArgumentException("Unknown element: " + atomicNo);
        }
        return code;
    }

    /** Maps a molecule bond type to its dense code; throws for unsupported bond types. */
    private static int encodeBond(int bondType) {
        Integer code = BOND_ENCODING.get(bondType);
        if (code == null) {
            throw new IllegalArgumentException("Unknown bond type: " + bondType);
        }
        return code;
    }

    /**
     * Rasterises the (already centred and rotated) molecule onto the grid.
     *
     * @param random source of the per-grid-point jitter applied to the {@code x} grid
     * @return occupancy grids indexed [ix][iy][iz][atom] plus one-hot atom and bond encodings
     * @throws Exception if an element or bond type is unsupported, if the molecule has more than
     *                   {@code maxAtoms} atoms, or if any atom lies outside the sampled cube
     */
    public SampledSpaceOneHot sampleSpace(Random random) throws Exception {
        // NOTE(review): 31 presumably requests all OpenChemLib helper data (Molecule.cHelperCIP) — confirm.
        this.mi.ensureHelperArrays(31);
        int atomCount = this.mi.getAllAtoms();
        if (atomCount > this.maxAtoms) {
            throw new IllegalStateException("Too many atoms: " + atomCount + " > maxAtoms=" + this.maxAtoms);
        }
        int[] atomCodes = encodeAtoms();
        int[][] bondCodes = encodeBonds();
        checkFitsInCube();

        double[][][][] sampled = new double[this.gridSize][this.gridSize][this.gridSize][this.maxAtoms];
        double[][][][] target = new double[this.gridSize][this.gridSize][this.gridSize][this.maxAtoms];
        double origin = -0.5 * this.cubeLengthAngstrom;
        double step = this.cubeLengthAngstrom / this.gridSize;

        for (int atom = 0; atom < atomCount; atom++) {
            double ax = this.mi.getAtomX(atom);
            double ay = this.mi.getAtomY(atom);
            double az = this.mi.getAtomZ(atom);
            for (int ix = 0; ix < this.gridSize; ix++) {
                double gx = origin + ix * step;
                for (int iy = 0; iy < this.gridSize; iy++) {
                    double gy = origin + iy * step;
                    for (int iz = 0; iz < this.gridSize; iz++) {
                        double gz = origin + iz * step;
                        // NOTE(review): jitter is one-sided, in [0, varianceA * offsetA) per axis,
                        // so sample points shift towards +x/+y/+z — confirm this is intended.
                        double sx = gx + random.nextDouble() * this.varianceA * this.offsetA;
                        double sy = gy + random.nextDouble() * this.varianceA * this.offsetA;
                        double sz = gz + random.nextDouble() * this.varianceA * this.offsetA;
                        sampled[ix][iy][iz][atom] = occupancy(distance(ax, ay, az, sx, sy, sz), this.varianceA);
                        target[ix][iy][iz][atom] = occupancy(distance(ax, ay, az, gx, gy, gz), this.varianceB);
                    }
                }
            }
        }
        return new SampledSpaceOneHot(sampled, target, atomCodes, bondCodes);
    }

    /**
     * Rejects molecules with any atom coordinate outside the sampled cube [-L/2, L/2].
     * Uses getAllAtoms() so it checks exactly the atoms that sampleSpace rasterises.
     */
    private void checkFitsInCube() {
        double halfCube = 0.5 * this.cubeLengthAngstrom;
        double maxAbs = 0.0;
        for (int atom = 0; atom < this.mi.getAllAtoms(); atom++) {
            maxAbs = Math.max(maxAbs, Math.abs(this.mi.getAtomX(atom)));
            maxAbs = Math.max(maxAbs, Math.abs(this.mi.getAtomY(atom)));
            maxAbs = Math.max(maxAbs, Math.abs(this.mi.getAtomZ(atom)));
        }
        if (maxAbs > halfCube) {
            throw new IllegalStateException("Too big: " + maxAbs);
        }
    }

    /** Euclidean distance between two points. */
    private static double distance(double ax, double ay, double az, double bx, double by, double bz) {
        double dx = ax - bx;
        double dy = ay - by;
        double dz = az - bz;
        return Math.sqrt((dx * dx) + (dy * dy) + (dz * dz));
    }

    /** Occupancy: 1 within {@code radius}, then linear falloff of 1 per unit distance, clamped at 0. */
    private static double occupancy(double distance, double radius) {
        if (distance < radius) {
            return 1.0;
        }
        return Math.max(0.0, 1.0 - (distance - radius));
    }

    /** Encodes every atom into a {@code maxAtoms}-long array; unused slots stay 0. */
    private int[] encodeAtoms() {
        int[] codes = new int[this.maxAtoms];
        for (int atom = 0; atom < this.mi.getAllAtoms(); atom++) {
            codes[atom] = encodeAtom(this.mi.getAtomicNo(atom));
        }
        return codes;
    }

    /** Encodes every bond into a symmetric {@code maxAtoms x maxAtoms} matrix; 0 means no bond. */
    private int[][] encodeBonds() {
        int[][] codes = new int[this.maxAtoms][this.maxAtoms];
        for (int bond = 0; bond < this.mi.getAllBonds(); bond++) {
            int atom1 = this.mi.getBondAtom(0, bond);
            int atom2 = this.mi.getBondAtom(1, bond);
            int code = encodeBond(this.mi.getBondType(bond));
            codes[atom1][atom2] = code;
            codes[atom2][atom1] = code;
        }
        return codes;
    }

    /**
     * Rotates the point (x, y, z) by the matrix Rz(angleZ) * Ry(angleY) * Rx(angleX).
     *
     * @param angleXDeg rotation about x, in degrees
     * @param angleYDeg rotation about y, in degrees
     * @param angleZDeg rotation about z, in degrees
     * @return the rotated point as {x, y, z}
     */
    public static double[] rotate3D(double x, double y, double z, double angleXDeg, double angleYDeg, double angleZDeg) {
        double a = Math.toRadians(angleXDeg);
        double b = Math.toRadians(angleYDeg);
        double c = Math.toRadians(angleZDeg);
        double ca = Math.cos(a);
        double sa = Math.sin(a);
        double cb = Math.cos(b);
        double sb = Math.sin(b);
        double cc = Math.cos(c);
        double sc = Math.sin(c);
        // Same operand grouping as the expanded matrix product, so results are bit-identical.
        double rx = x * cb * cc - y * (ca * sc - sa * sb * cc) + z * (sa * sc + ca * sb * cc);
        double ry = x * cb * sc + y * (ca * cc + sa * sb * sc) - z * (sa * cc - ca * sb * sc);
        double rz = -x * sb + y * sa * cb + z * ca * cb;
        return new double[] {rx, ry, rz};
    }
}
