package tech.molecules.deep.conformers;

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;
import org.jetbrains.bio.npy.NpyFile;
import tech.molecules.deep.conformers.SpaceSampler;

/* loaded from: input_file:tech/molecules/deep/conformers/NpyExporter.class */
/**
 * Exports batches of {@link SpaceSampler.SampledSpaceOneHot} samples to NumPy {@code .npy} files
 * using {@link NpyFile}.
 *
 * <p>Each batch exporter stacks the per-sample arrays along a new leading batch axis. It flattens
 * the result with the last index varying fastest and passes the flat data together with its shape
 * to {@code NpyFile.write}. Every array must be rectangular, and every sample in a batch must have
 * exactly the dimensions of the first one. Otherwise an {@link IllegalArgumentException} is thrown
 * rather than writing a file whose data length disagrees with its shape. {@code double} values are
 * narrowed to {@code float} before writing.
 */
public class NpyExporter {

    /**
     * Writes the four arrays of every sample into four {@code .npy} files inside
     * {@code outputDir}: {@code <prefix>_x.npy}, {@code <prefix>_target.npy},
     * {@code <prefix>_structureInfo.npy} and {@code <prefix>_bondsType.npy}.
     *
     * @param samples   the batch to export; must be non-empty and shape-consistent
     * @param outputDir directory for the output files; created together with any missing parents
     * @param prefix    file-name prefix shared by the four output files
     * @throws IOException if {@code outputDir} cannot be created, for example because a regular
     *     file already exists at that path
     * @throws IllegalArgumentException if {@code samples} is empty or its arrays are ragged or
     *     differ in shape between samples
     */
    public static void exportSamplesSpacesBatch(
            List<SpaceSampler.SampledSpaceOneHot> samples, String outputDir, String prefix)
            throws IOException {
        // Unlike File.mkdirs(), whose boolean result used to be ignored, this is a no-op for an
        // existing directory and reports any real failure as an IOException.
        Files.createDirectories(Paths.get(outputDir));

        List<double[][][][]> inputs = new ArrayList<>(samples.size());
        List<double[][][][]> targets = new ArrayList<>(samples.size());
        List<boolean[][]> structureInfos = new ArrayList<>(samples.size());
        List<boolean[][][]> bondTypes = new ArrayList<>(samples.size());
        for (SpaceSampler.SampledSpaceOneHot sample : samples) {
            inputs.add(sample.x);
            targets.add(sample.x_target);
            structureInfos.add(sample.structureInfo);
            bondTypes.add(sample.bondsType);
        }
        exportToNpy_ListX(inputs, outputDir, prefix + "_x.npy");
        exportToNpy_ListX(targets, outputDir, prefix + "_target.npy");
        exportToNpy_ListAtomInfo(structureInfos, outputDir, prefix + "_structureInfo.npy");
        exportToNpy_ListBondInfo(bondTypes, outputDir, prefix + "_bondsType.npy");
    }

    /**
     * Stacks equally shaped 4-D arrays into one array of shape
     * {@code [list.size(), d0, d1, d2, d3]} and writes it to {@code dir/fileName}.
     * Values are narrowed to {@code float}.
     *
     * @param list     the arrays to stack; must be non-empty, every array rectangular and of the
     *                 same dimensions as the first one
     * @param dir      output directory (this method does not create it)
     * @param fileName name of the {@code .npy} file inside {@code dir}
     * @throws IllegalArgumentException if {@code list} is empty, an array is ragged or differs in
     *     shape from the first one, or the total element count exceeds {@link Integer#MAX_VALUE}
     */
    public static void exportToNpy_ListX(List<double[][][][]> list, String dir, String fileName) {
        int[] itemShape = shape4(firstItem(list, fileName));
        int[] shape = {list.size(), itemShape[0], itemShape[1], itemShape[2], itemShape[3]};
        float[] data = new float[elementCount(shape, fileName)];
        int pos = 0;
        int index = 0;
        for (double[][][][] item : list) {
            pos = flatten4(item, itemShape, data, pos, fileName + "[" + index++ + "]");
        }
        NpyFile.write(Paths.get(dir, fileName), data, shape);
    }

    /**
     * Stacks equally shaped 2-D boolean arrays into one array of shape
     * {@code [list.size(), rows, cols]} and writes it to {@code dir/fileName}.
     *
     * @param list     the arrays to stack; must be non-empty, every array rectangular and of the
     *                 same dimensions as the first one
     * @param dir      output directory (this method does not create it)
     * @param fileName name of the {@code .npy} file inside {@code dir}
     * @throws IllegalArgumentException if {@code list} is empty, an array is ragged or differs in
     *     shape from the first one, or the total element count exceeds {@link Integer#MAX_VALUE}
     */
    public static void exportToNpy_ListAtomInfo(List<boolean[][]> list, String dir, String fileName) {
        boolean[][] first = firstItem(list, fileName);
        int rows = first.length;
        // An empty first level leaves the column count unobservable; 0 keeps the volume at 0.
        int cols = rows > 0 ? first[0].length : 0;
        int[] shape = {list.size(), rows, cols};
        boolean[] data = new boolean[elementCount(shape, fileName)];
        int pos = 0;
        int index = 0;
        for (boolean[][] item : list) {
            String context = fileName + "[" + index++ + "]";
            requireLength(item.length, rows, context);
            for (boolean[] row : item) {
                requireLength(row.length, cols, context);
                System.arraycopy(row, 0, data, pos, cols);
                pos += cols;
            }
        }
        NpyFile.write(Paths.get(dir, fileName), data, shape);
    }

    /**
     * Stacks equally shaped 3-D boolean arrays into one array of shape
     * {@code [list.size(), d0, d1, d2]} and writes it to {@code dir/fileName}.
     *
     * @param list     the arrays to stack; must be non-empty, every array rectangular and of the
     *                 same dimensions as the first one
     * @param dir      output directory (this method does not create it)
     * @param fileName name of the {@code .npy} file inside {@code dir}
     * @throws IllegalArgumentException if {@code list} is empty, an array is ragged or differs in
     *     shape from the first one, or the total element count exceeds {@link Integer#MAX_VALUE}
     */
    public static void exportToNpy_ListBondInfo(List<boolean[][][]> list, String dir, String fileName) {
        boolean[][][] first = firstItem(list, fileName);
        int d0 = first.length;
        // Deeper extents are unobservable once a level is empty; 0 keeps the volume at 0.
        int d1 = d0 > 0 ? first[0].length : 0;
        int d2 = d1 > 0 ? first[0][0].length : 0;
        int[] shape = {list.size(), d0, d1, d2};
        boolean[] data = new boolean[elementCount(shape, fileName)];
        int pos = 0;
        int index = 0;
        for (boolean[][][] item : list) {
            String context = fileName + "[" + index++ + "]";
            requireLength(item.length, d0, context);
            for (boolean[][] plane : item) {
                requireLength(plane.length, d1, context);
                for (boolean[] row : plane) {
                    requireLength(row.length, d2, context);
                    System.arraycopy(row, 0, data, pos, d2);
                    pos += d2;
                }
            }
        }
        NpyFile.write(Paths.get(dir, fileName), data, shape);
    }

    /**
     * Unboxes a list of floats into a primitive array, preserving order.
     *
     * @throws NullPointerException if the list or any element is {@code null}
     */
    public static float[] convertListToArray_Float(List<Float> list) {
        float[] out = new float[list.size()];
        int i = 0;
        // Iterate instead of get(i) so that linked lists stay O(n).
        for (Float value : list) {
            out[i++] = value;
        }
        return out;
    }

    /**
     * Unboxes a list of integers into a primitive array, preserving order.
     *
     * @throws NullPointerException if the list or any element is {@code null}
     */
    public static int[] convertListToArray_Int(List<Integer> list) {
        int[] out = new int[list.size()];
        int i = 0;
        for (Integer value : list) {
            out[i++] = value;
        }
        return out;
    }

    /**
     * Unboxes a list of booleans into a primitive array, preserving order.
     *
     * @throws NullPointerException if the list or any element is {@code null}
     */
    public static boolean[] convertListToArray_Bool(List<Boolean> list) {
        boolean[] out = new boolean[list.size()];
        int i = 0;
        for (Boolean value : list) {
            out[i++] = value;
        }
        return out;
    }

    /** Demo entry point: writes a small 2x2x2x2 array to {@code data.npy} in the working directory. */
    public static void main(String[] strArr) {
        double[][][][] demo = {
            {{{1.1d, 1.2d}, {1.3d, 1.4d}}, {{2.1d, 2.2d}, {2.3d, 2.4d}}},
            {{{3.1d, 3.2d}, {3.3d, 3.4d}}, {{4.1d, 4.2d}, {4.3d, 4.4d}}}
        };
        write(Paths.get("data.npy"), demo);
    }

    /**
     * Writes a single rectangular 4-D array to {@code path}. The npy shape is the array's four
     * dimensions, and values are narrowed to {@code float}.
     *
     * @throws IllegalArgumentException if {@code dArr} is ragged or has more than
     *     {@link Integer#MAX_VALUE} elements
     */
    public static void write(Path path, double[][][][] dArr) {
        String context = path.toString();
        int[] shape = shape4(dArr);
        float[] data = new float[elementCount(shape, context)];
        flatten4(dArr, shape, data, 0, context);
        NpyFile.write(path, data, shape);
    }

    /**
     * Returns the four extents of {@code a}, read along its first entries. Once a level is empty
     * the deeper extents cannot be observed and are reported as 0, which keeps the element count
     * at 0 instead of throwing.
     */
    private static int[] shape4(double[][][][] a) {
        int d0 = a.length;
        int d1 = d0 > 0 ? a[0].length : 0;
        int d2 = d1 > 0 ? a[0][0].length : 0;
        int d3 = d2 > 0 ? a[0][0][0].length : 0;
        return new int[] {d0, d1, d2, d3};
    }

    /**
     * Copies {@code a} into {@code out} starting at {@code pos}, last index fastest, narrowing
     * each value to float. Every level is first checked against {@code shape}.
     *
     * @return the index just past the last value written
     */
    private static int flatten4(double[][][][] a, int[] shape, float[] out, int pos, String context) {
        requireLength(a.length, shape[0], context);
        for (double[][][] level1 : a) {
            requireLength(level1.length, shape[1], context);
            for (double[][] level2 : level1) {
                requireLength(level2.length, shape[2], context);
                for (double[] level3 : level2) {
                    requireLength(level3.length, shape[3], context);
                    for (double value : level3) {
                        out[pos++] = (float) value;
                    }
                }
            }
        }
        return pos;
    }

    /** Returns the first element of a batch, rejecting empty batches whose item shape is unknown. */
    private static <T> T firstItem(List<T> batch, String fileName) {
        if (batch.isEmpty()) {
            throw new IllegalArgumentException(
                    "Cannot export an empty batch to " + fileName + ": the item shape is unknown");
        }
        return batch.get(0);
    }

    /** Returns the product of {@code shape}, rejecting totals that do not fit a Java array. */
    private static int elementCount(int[] shape, String context) {
        long count = 1;
        for (int extent : shape) {
            // Both factors are <= Integer.MAX_VALUE here, so the long product cannot overflow.
            count *= extent;
            if (count > Integer.MAX_VALUE) {
                throw new IllegalArgumentException(
                        "Too many elements to export to " + context + ": exceeds Integer.MAX_VALUE");
            }
        }
        return (int) count;
    }

    /** Throws if an array level does not have the extent recorded in the npy shape. */
    private static void requireLength(int actual, int expected, String context) {
        if (actual != expected) {
            throw new IllegalArgumentException(
                    "Ragged array in " + context + ": expected length " + expected
                            + " but found " + actual);
        }
    }
}
