package ml.dmlc.xgboost4j.java;

import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Random;
import junit.framework.TestCase;
import ml.dmlc.xgboost4j.LabeledPoint;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.util.BigDenseMatrix;
import org.junit.Assert;
import org.junit.Test;

/* loaded from: input_file:ml/dmlc/xgboost4j/java/DMatrixTest.class */
/**
 * JUnit tests for {@link DMatrix}.
 *
 * <p>Covers construction from:
 * <ul>
 *   <li>a {@link LabeledPoint} iterator;</li>
 *   <li>a text file copied out of the test resources;</li>
 *   <li>CSR/CSC sparse arrays;</li>
 *   <li>row-major dense {@code float[]} data;</li>
 *   <li>a {@link BigDenseMatrix}.</li>
 * </ul>
 * It also covers the label/weight/group setters and getters, and a small end-to-end
 * training run.
 */
public class DMatrixTest {

    /** Row count used by the dense-matrix tests. */
    private static final int DENSE_NROW = 10;
    /** Column count used by the dense-matrix tests. */
    private static final int DENSE_NCOL = 5;
    /** Sentinel passed to DMatrix as the "missing value" in the dense tests that use one. */
    private static final float MISSING = -0.1f;
    /** Fixed seed so that any failure involving random data is reproducible. */
    private static final long SEED = 42L;

    /** Labels read back from an iterator-built DMatrix equal the labels fed in, in order. */
    @Test
    public void testCreateFromDataIterator() throws XGBoostError {
        final int numPoints = 3000;
        List<LabeledPoint> points = new ArrayList<>(numPoints);
        float[] expectedLabels = new float[numPoints];
        for (int i = 0; i < numPoints; i++) {
            LabeledPoint point =
                new LabeledPoint(0.1f + i, 4, new int[]{0, 2, 3}, new float[]{3.0f, 4.0f, 5.0f});
            points.add(point);
            expectedLabels[i] = point.label();
        }
        DMatrix dmat = new DMatrix(points.iterator(), (String) null);
        // Compare whole arrays: a short or empty label vector must fail, not pass vacuously.
        Assert.assertArrayEquals(expectedLabels, dmat.getLabel(), 0f);
    }

    /** Points whose declared feature size changes mid-stream must be rejected with XGBoostError. */
    @Test
    public void testCreateFromDataIteratorWithDiffFeatureSize() throws XGBoostError {
        List<LabeledPoint> points = new ArrayList<>();
        int featureSize = 4;
        for (int i = 0; i < 3000; i++) {
            if (i % 10 == 1) {
                featureSize = 5; // first triggers at i == 1; every later row keeps size 5
            }
            points.add(
                new LabeledPoint(0.1f + i, featureSize, new int[]{0, 2, 3}, new float[]{3.0f, 4.0f, 5.0f}));
        }
        try {
            new DMatrix(points.iterator(), (String) null);
            Assert.fail("expected XGBoostError for rows with inconsistent feature size");
        } catch (XGBoostError expected) {
            // expected: construction must reject the inconsistent rows
        }
    }

    /**
     * Tests a DMatrix loaded from the bundled "/agaricus.txt.test" resource.
     *
     * <p>The matrix has one label per row, and weights written to it read back unchanged.
     */
    @Test
    public void testCreateFromFile() throws XGBoostError {
        DMatrix dmat = new DMatrix(writeResourceIntoTempFile("/agaricus.txt.test"));
        float[] labels = dmat.getLabel();
        Assert.assertEquals((long) labels.length, dmat.rowNum());
        // The labels are reused as weights only because they are a known, row-sized payload.
        float[] weights = Arrays.copyOf(labels, labels.length);
        dmat.setWeight(weights);
        Assert.assertArrayEquals(weights, dmat.getWeight(), 0f);
    }

    /** A 3-row matrix given as CSR arrays (4 row offsets, 11 entries) reports 3 rows and round-trips labels. */
    @Test
    public void testCreateFromCSR() throws XGBoostError {
        DMatrix dmat = new DMatrix(
            new long[]{0, 3, 7, 11},
            new int[]{0, 2, 3, 0, 2, 3, 4, 0, 1, 2, 3},
            new float[]{1f, 2f, 3f, 4f, 2f, 3f, 5f, 3f, 1f, 2f, 5f},
            DMatrix.SparseType.CSR);
        Assert.assertEquals(3, dmat.rowNum());
        float[] labels = {1.0f, 0.0f, 1.0f};
        dmat.setLabel(labels);
        Assert.assertArrayEquals(labels, dmat.getLabel(), 0f);
    }

    /**
     * Same as {@link #testCreateFromCSR()}, but uses the constructor overload with a trailing
     * shape argument.
     *
     * <p>That argument (5) is presumably the column count for CSR; confirm against DMatrix.
     */
    @Test
    public void testCreateFromCSREx() throws XGBoostError {
        DMatrix dmat = new DMatrix(
            new long[]{0, 3, 7, 11},
            new int[]{0, 2, 3, 0, 2, 3, 4, 0, 1, 2, 3},
            new float[]{1f, 2f, 3f, 4f, 2f, 3f, 5f, 3f, 1f, 2f, 5f},
            DMatrix.SparseType.CSR,
            5);
        Assert.assertEquals(3, dmat.rowNum());
        float[] labels = {1.0f, 0.0f, 1.0f};
        dmat.setLabel(labels);
        Assert.assertArrayEquals(labels, dmat.getLabel(), 0f);
    }

    /** Tests a matrix given as CSC arrays whose row indices span 0..4: it reports 5 rows and round-trips labels. */
    @Test
    public void testCreateFromCSC() throws XGBoostError {
        DMatrix dmat = new DMatrix(
            new long[]{0, 4, 7, 11},
            new int[]{0, 1, 3, 4, 2, 3, 4, 0, 1, 2, 3},
            new float[]{1f, 3f, 5f, 2f, 2f, 3f, 5f, 2f, 4f, 3f, 1f},
            DMatrix.SparseType.CSC);
        Assert.assertEquals(5, dmat.rowNum());
        float[] labels = {1.0f, 0.0f, 1.0f, 1.0f, 1.0f};
        dmat.setLabel(labels);
        Assert.assertArrayEquals(labels, dmat.getLabel(), 0f);
    }

    /**
     * Same as {@link #testCreateFromCSC()}, but uses the constructor overload with a trailing
     * shape argument.
     *
     * <p>That argument (5) is presumably the row count for CSC; confirm against DMatrix.
     */
    @Test
    public void testCreateFromCSCEx() throws XGBoostError {
        DMatrix dmat = new DMatrix(
            new long[]{0, 4, 7, 11},
            new int[]{0, 1, 3, 4, 2, 3, 4, 0, 1, 2, 3},
            new float[]{1f, 3f, 5f, 2f, 2f, 3f, 5f, 2f, 4f, 3f, 1f},
            DMatrix.SparseType.CSC,
            5);
        Assert.assertEquals(5, dmat.rowNum());
        float[] labels = {1.0f, 0.0f, 1.0f, 1.0f, 1.0f};
        dmat.setLabel(labels);
        Assert.assertArrayEquals(labels, dmat.getLabel(), 0f);
    }

    /** A dense 10x5 float[] matrix reports 10 rows; its labels and weights round-trip. */
    @Test
    public void testCreateFromDenseMatrix() throws XGBoostError {
        Random random = new Random(SEED);
        DMatrix dmat = new DMatrix(
            randomFloats(random, DENSE_NROW * DENSE_NCOL), DENSE_NROW, DENSE_NCOL, Float.NaN);
        dmat.setLabel(randomFloats(random, DENSE_NROW));
        Assert.assertEquals(DENSE_NROW, dmat.rowNum());
        Assert.assertEquals(DENSE_NROW, dmat.getLabel().length);
        float[] weights = randomFloats(random, DENSE_NROW);
        dmat.setWeight(weights);
        Assert.assertArrayEquals(weights, dmat.getWeight(), 0f);
    }

    /** Cells equal to the declared missing value do not change the row count of a dense matrix. */
    @Test
    public void testCreateFromDenseMatrixWithMissingValue() throws XGBoostError {
        Random random = new Random(SEED);
        float[] data = new float[DENSE_NROW * DENSE_NCOL];
        for (int i = 0; i < data.length; i++) {
            // Every 10th cell holds the sentinel that DMatrix is told to treat as missing.
            data[i] = (i % 10 == 0) ? MISSING : random.nextFloat();
        }
        DMatrix dmat = new DMatrix(data, DENSE_NROW, DENSE_NCOL, MISSING);
        dmat.setLabel(randomFloats(random, DENSE_NROW));
        Assert.assertEquals(DENSE_NROW, dmat.rowNum());
        Assert.assertEquals(DENSE_NROW, dmat.getLabel().length);
    }

    /** A DMatrix built over a {@link BigDenseMatrix} reports its rows and accepts labels. */
    @Test
    public void testCreateFromDenseMatrixRef() throws XGBoostError {
        BigDenseMatrix data = null;
        DMatrix dmat = null;
        try {
            data = new BigDenseMatrix(DENSE_NROW, DENSE_NCOL);
            Random random = new Random(SEED);
            for (int i = 0; i < DENSE_NROW * DENSE_NCOL; i++) {
                data.set(i, random.nextFloat());
            }
            dmat = new DMatrix(data, Float.NaN);
            dmat.setLabel(randomFloats(random, DENSE_NROW));
            Assert.assertEquals(DENSE_NROW, dmat.rowNum());
            Assert.assertEquals(DENSE_NROW, dmat.getLabel().length);
        } finally {
            disposeDenseRef(dmat, data);
        }
    }

    /**
     * Trains a small regression model on a 3x2 {@link BigDenseMatrix}.
     *
     * <p>Each training row must predict its own label within 0.01.
     */
    @Test
    public void testTrainWithDenseMatrixRef() throws XGBoostError {
        Map<String, String> rabitEnv = new HashMap<>();
        rabitEnv.put("DMLC_TASK_ID", "0");
        Rabit.init(rabitEnv);
        BigDenseMatrix data = null;
        DMatrix dtrain = null;
        try {
            // Trivial dataset: (4,5) -> 1, (3,1) -> 2, (2,3) -> 3.
            float[][] rows = {
                {4.0f, 5.0f},
                {3.0f, 1.0f},
                {2.0f, 3.0f}
            };
            data = new BigDenseMatrix(rows.length, rows[0].length);
            for (int i = 0; i < data.nrow; i++) {
                for (int j = 0; j < data.ncol; j++) {
                    data.set(i, j, rows[i][j]);
                }
            }
            dtrain = new DMatrix(data, Float.NaN);
            dtrain.setLabel(new float[]{1.0f, 2.0f, 3.0f});

            Map<String, Object> params = new HashMap<>();
            params.put("eta", 1);
            params.put("max_depth", 5);
            params.put("silent", 1);
            params.put("objective", "reg:linear");
            params.put("seed", 123);
            Map<String, DMatrix> watches = new HashMap<>();
            watches.put("train", dtrain);
            Booster booster =
                XGBoost.train(dtrain, params, 10, watches, (IObjective) null, (IEvaluation) null);

            for (int i = 0; i < rows.length; i++) {
                // Each single-row DMatrix owns native memory; release it before the next row.
                DMatrix single = new DMatrix(rows[i], 1, rows[i].length, Float.NaN);
                try {
                    float[][] predict = booster.predict(single);
                    Assert.assertEquals(1, predict.length);
                    Assert.assertArrayEquals(new float[]{i + 1}, predict[0], 0.01f);
                } finally {
                    single.dispose();
                }
            }
        } finally {
            disposeDenseRef(dtrain, data);
            Rabit.shutdown();
        }
    }

    /**
     * Cleans up after a dense-ref test.
     *
     * <p>Disposes {@code dmat} if it was created; otherwise disposes {@code data}, if it was
     * created.
     *
     * <p>NOTE(review): this assumes disposing a DMatrix built from a BigDenseMatrix also takes
     * care of the BigDenseMatrix. Confirm against DMatrix's ownership rules.
     *
     * @param dmat matrix built from {@code data}, or {@code null} if construction never happened
     * @param data backing dense matrix, or {@code null} if allocation never happened
     */
    private static void disposeDenseRef(DMatrix dmat, BigDenseMatrix data) throws XGBoostError {
        if (dmat != null) {
            dmat.dispose();
        } else if (data != null) {
            data.dispose();
        }
    }

    /**
     * Returns an array of {@code count} values drawn from {@code random.nextFloat()}.
     *
     * <p>Every value lies in [0, 1).
     */
    private static float[] randomFloats(Random random, int count) {
        float[] values = new float[count];
        for (int i = 0; i < count; i++) {
            values[i] = random.nextFloat();
        }
        return values;
    }

    /**
     * Copies a classpath resource into a fresh temp file.
     *
     * <p>This lets APIs that need a filesystem path read the resource.
     *
     * @param str classpath resource name, e.g. {@code "/agaricus.txt.test"}
     * @return absolute path of the temp file, which is scheduled for deletion when the JVM exits
     * @throws IllegalArgumentException if the resource does not exist
     * @throws RuntimeException if the temp file cannot be created or written
     */
    private String writeResourceIntoTempFile(String str) {
        InputStream resource = getClass().getResourceAsStream(str);
        if (resource == null) {
            throw new IllegalArgumentException("Resource " + str + " does not exist.");
        }
        // try-with-resources closes both streams on every path, including write failures.
        try (InputStream in = resource) {
            File tmpFile = File.createTempFile("junit", ".test");
            tmpFile.deleteOnExit();
            try (FileOutputStream out = new FileOutputStream(tmpFile)) {
                byte[] buffer = new byte[1024];
                int n;
                while ((n = in.read(buffer)) != -1) {
                    out.write(buffer, 0, n);
                }
            }
            return tmpFile.getAbsolutePath();
        } catch (IOException e) {
            throw new RuntimeException("Unable to write to temp file.", e);
        }
    }

    /** setGroup takes group sizes; getGroup returns the cumulative group boundaries starting at 0. */
    @Test
    public void testSetAndGetGroup() throws XGBoostError {
        Random random = new Random(SEED);
        DMatrix dmat = new DMatrix(
            randomFloats(random, DENSE_NROW * DENSE_NCOL), DENSE_NROW, DENSE_NCOL, MISSING);
        dmat.setLabel(randomFloats(random, DENSE_NROW));
        dmat.setGroup(new int[]{5, 5});
        Assert.assertArrayEquals(new int[]{0, 5, 10}, dmat.getGroup());
    }
}
