package com.omega.engine.database;

import com.omega.common.data.Tensor;
import com.omega.common.utils.MathUtils;
import com.omega.engine.nn.data.DataSet;
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

/* loaded from: input_file:com/omega/engine/database/DataLoader.class */
/**
 * Splits a {@link DataSet} into training and validation index pools and serves them as
 * fixed-size mini-batches of sample indices.
 *
 * <p>The split is done once, in the constructor. A fraction {@code validSize} of the samples
 * (rounded down) goes to the validation pool, and the remainder goes to the training pool. Each
 * call to {@link #getTrainIndex()} / {@link #getValidIndex()} reshuffles the corresponding pool
 * in place and returns {@code ceil(poolSize / batchSize)} batches. The last partial batch is
 * padded with samples taken from the start of the shuffled pool.
 *
 * <p>Not thread-safe: the index pools are shuffled in place on every batch request.
 */
public class DataLoader {
    private final DataSet trainData;
    private final int batchSize;
    private final float validSize;
    private List<Integer> trainIndex;
    private List<Integer> validIndex;
    private final int channel;
    private final int height;
    private final int width;
    private final int labelSize;

    /**
     * Creates a loader over {@code dataSet} and immediately splits it into train/valid pools.
     *
     * @param dataSet   source data; its {@code channel}, {@code height}, {@code width},
     *                  {@code labelSize} and {@code number} fields are read here
     * @param batchSize number of sample indices per batch; must be positive
     * @param validSize fraction of samples reserved for validation, in {@code [0, 1]}
     * @throws IllegalArgumentException if {@code batchSize <= 0} or {@code validSize} is not in
     *                                  {@code [0, 1]} (including NaN)
     * @throws IllegalStateException    if generating the random sample order fails
     */
    public DataLoader(DataSet dataSet, int batchSize, float validSize) {
        if (batchSize <= 0) {
            throw new IllegalArgumentException("batchSize must be positive: " + batchSize);
        }
        // Written as a negated range check so that NaN is rejected too.
        if (!(validSize >= 0f && validSize <= 1f)) {
            throw new IllegalArgumentException("validSize must be in [0, 1]: " + validSize);
        }
        this.trainData = dataSet;
        this.batchSize = batchSize;
        this.validSize = validSize;
        this.channel = dataSet.channel;
        this.height = dataSet.height;
        this.width = dataSet.width;
        this.labelSize = dataSet.labelSize;
        initDataset(validSize);
    }

    /**
     * Copies the samples selected by {@code iArr} into consecutive batch slots.
     *
     * <p>For each {@code i}, sample {@code iArr[i]} (a run of {@code channel * height * width}
     * values) is copied to slot {@code i} of {@code tensor.data}. Likewise, its
     * {@code labelSize} label values are copied to slot {@code i} of {@code tensor2.data}.
     *
     * <p>NOTE(review): source and destination are the same array in both copies. If a slot
     * {@code j < i} has already been written and {@code iArr[i] == j}, the overwritten data is
     * read. The source was presumably meant to be the dataset's own feature/label tensors.
     * TODO confirm against {@code DataSet} and the callers.
     *
     * @param iArr    sample indices for this batch, e.g. one row of {@link #getTrainIndex()}
     * @param tensor  feature tensor (read and written)
     * @param tensor2 label tensor (read and written)
     */
    public void loadTrainData(int[] iArr, Tensor tensor, Tensor tensor2) {
        int sampleSize = this.channel * this.height * this.width;
        for (int i = 0; i < iArr.length; i++) {
            int src = iArr[i];
            System.arraycopy(tensor.data, src * sampleSize, tensor.data, i * sampleSize, sampleSize);
            System.arraycopy(tensor2.data, src * this.labelSize, tensor2.data, i * this.labelSize, this.labelSize);
        }
    }

    /**
     * Reshuffles the training pool and returns it as batches of sample indices.
     *
     * @return {@code ceil(trainCount / batchSize)} rows of exactly {@code batchSize} indices;
     *         empty if the training pool is empty
     */
    public int[][] getTrainIndex() {
        return toBatches(this.trainIndex);
    }

    /**
     * Reshuffles the validation pool and returns it as batches of sample indices.
     *
     * @return {@code ceil(validCount / batchSize)} rows of exactly {@code batchSize} indices;
     *         empty if the validation pool is empty
     */
    public int[][] getValidIndex() {
        return toBatches(this.validIndex);
    }

    /**
     * Shuffles {@code pool} in place and cuts it into rows of {@code batchSize} indices. The last
     * row is padded with entries from the start of the shuffled pool.
     */
    private int[][] toBatches(List<Integer> pool) {
        int size = pool.size();
        // Ceil division without risking int overflow in (size + batchSize - 1).
        int batchCount = size / this.batchSize + (size % this.batchSize == 0 ? 0 : 1);
        int[][] batches = new int[batchCount][this.batchSize];
        Collections.shuffle(pool);
        for (int b = 0; b < batchCount; b++) {
            int base = b * this.batchSize;
            for (int j = 0; j < this.batchSize; j++) {
                int pos = base + j;
                // j % size (rather than j) keeps padding in range when the pool is
                // smaller than a single batch.
                batches[b][j] = pool.get(pos < size ? pos : j % size);
            }
        }
        return batches;
    }

    /**
     * Computes the train/valid sample counts and fills both index pools from a random order.
     *
     * @param validFraction fraction of {@code trainData.number} reserved for validation
     * @throws IllegalStateException if building the random order fails
     */
    private void initDataset(float validFraction) {
        int total = this.trainData.number;
        // Float.toString keeps the decimal the caller wrote (0.2f -> "0.2"). This avoids the
        // binary expansion new BigDecimal(float) would yield, so e.g. 100 * 0.2 is exactly 20.
        int validCount = new BigDecimal(total)
                .multiply(new BigDecimal(Float.toString(validFraction)))
                .intValue();
        int trainCount = total - validCount;
        try {
            // Presumably a random permutation of 0..total-1; TODO confirm MathUtils.randomInts.
            Integer[] order = MathUtils.randomInts(total);
            this.trainIndex = new ArrayList<>(trainCount);
            this.validIndex = new ArrayList<>(validCount);
            loadDataset(order, trainCount, this.trainIndex, this.validIndex);
        } catch (Exception e) {
            // Previously swallowed, which left the pools null and failed later with an NPE.
            throw new IllegalStateException("Failed to split dataset of " + total + " samples", e);
        }
    }

    /**
     * Appends the first {@code trainCount} entries of {@code order} to {@code train} and all
     * remaining entries to {@code valid}.
     */
    private static void loadDataset(Integer[] order, int trainCount, List<Integer> train, List<Integer> valid) {
        for (int k = 0; k < order.length; k++) {
            if (k < trainCount) {
                train.add(order[k]);
            } else {
                valid.add(order[k]);
            }
        }
    }
}
