package ml.dmlc.xgboost4j.scala.example;

import java.io.File;
import java.io.IOException;
import java.io.PrintWriter;
import java.io.UncheckedIOException;
import java.util.Arrays;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.example.util.DataLoader;
import ml.dmlc.xgboost4j.scala.Booster;
import ml.dmlc.xgboost4j.scala.DMatrix;
import ml.dmlc.xgboost4j.scala.XGBoost$;
import scala.Predef$;
import scala.Predef$ArrowAssoc$;
import scala.collection.mutable.ArrayOps;
import scala.collection.mutable.HashMap;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.NonLocalReturnControl;
import scala.runtime.RichInt$;

/* compiled from: BasicWalkThrough.scala */
/* loaded from: input_file:ml/dmlc/xgboost4j/scala/example/BasicWalkThrough$.class */
public final class BasicWalkThrough$ {
    public static BasicWalkThrough$ MODULE$;

    static {
        new BasicWalkThrough$();
    }

    public void saveDumpModel(String str, String[] strArr) {
        PrintWriter printWriter = new PrintWriter(str, "UTF-8");
        RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), strArr.length).foreach$mVc$sp(i -> {
            printWriter.print(new StringBuilder(11).append("booster[").append(i).append("]:\n").toString());
            printWriter.print(strArr[i]);
        });
        printWriter.close();
    }

    public void main(String[] strArr) {
        DMatrix dMatrix = new DMatrix("../../demo/data/agaricus.txt.train");
        DMatrix dMatrix2 = new DMatrix("../../demo/data/agaricus.txt.test");
        HashMap hashMap = new HashMap();
        hashMap.$plus$eq(Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc("eta"), BoxesRunTime.boxToDouble(1.0d)));
        hashMap.$plus$eq(Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc("max_depth"), BoxesRunTime.boxToInteger(2)));
        hashMap.$plus$eq(Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc("silent"), BoxesRunTime.boxToInteger(1)));
        hashMap.$plus$eq(Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc("objective"), "binary:logistic"));
        HashMap hashMap2 = new HashMap();
        hashMap2.$plus$eq(Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc("train"), dMatrix));
        hashMap2.$plus$eq(Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc("test"), dMatrix2));
        Booster train = XGBoost$.MODULE$.train(dMatrix, hashMap.toMap(Predef$.MODULE$.$conforms()), 2, hashMap2.toMap(Predef$.MODULE$.$conforms()), XGBoost$.MODULE$.train$default$5(), XGBoost$.MODULE$.train$default$6(), XGBoost$.MODULE$.train$default$7(), XGBoost$.MODULE$.train$default$8(), XGBoost$.MODULE$.train$default$9());
        float[][] predict = train.predict(dMatrix2, train.predict$default$2(), train.predict$default$3());
        File file = new File("./model");
        if (file.exists()) {
            BoxedUnit boxedUnit = BoxedUnit.UNIT;
        } else {
            BoxesRunTime.boxToBoolean(file.mkdirs());
        }
        train.saveModel(new StringBuilder(10).append(file.getAbsolutePath()).append("/xgb.model").toString());
        saveDumpModel(new StringBuilder(13).append(file.getAbsolutePath()).append("/dump.raw.txt").toString(), train.getModelDump(new StringBuilder(12).append(file.getAbsolutePath()).append("/featmap.txt").toString(), false, train.getModelDump$default$3()));
        dMatrix2.saveBinary(new StringBuilder(13).append(file.getAbsolutePath()).append("/dtest.buffer").toString());
        Booster loadModel = XGBoost$.MODULE$.loadModel(new StringBuilder(10).append(file.getAbsolutePath()).append("/xgb.model").toString());
        DMatrix dMatrix3 = new DMatrix(new StringBuilder(13).append(file.getAbsolutePath()).append("/dtest.buffer").toString());
        Predef$.MODULE$.println(BoxesRunTime.boxToBoolean(checkPredicts(predict, loadModel.predict(dMatrix3, loadModel.predict$default$2(), loadModel.predict$default$3()))));
        Predef$.MODULE$.println("start build dmatrix from csr sparse data ...");
        DataLoader.CSRSparseData loadSVMFile = DataLoader.loadSVMFile("../../demo/data/agaricus.txt.train");
        DMatrix dMatrix4 = new DMatrix(loadSVMFile.rowHeaders, loadSVMFile.colIndex, loadSVMFile.data, DMatrix.SparseType.CSR);
        dMatrix4.setLabel(loadSVMFile.labels);
        HashMap hashMap3 = new HashMap();
        hashMap3.$plus$eq(Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc("train"), dMatrix4));
        hashMap3.$plus$eq(Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc("test"), dMatrix3));
        Booster train2 = XGBoost$.MODULE$.train(dMatrix4, hashMap.toMap(Predef$.MODULE$.$conforms()), 2, hashMap3.toMap(Predef$.MODULE$.$conforms()), XGBoost$.MODULE$.train$default$5(), XGBoost$.MODULE$.train$default$6(), XGBoost$.MODULE$.train$default$7(), XGBoost$.MODULE$.train$default$8(), XGBoost$.MODULE$.train$default$9());
        Predef$.MODULE$.println(BoxesRunTime.boxToBoolean(checkPredicts(predict, train2.predict(dMatrix3, train2.predict$default$2(), train2.predict$default$3()))));
    }

    public boolean checkPredicts(float[][] fArr, float[][] fArr2) {
        Object obj = new Object();
        try {
            Predef$.MODULE$.require(fArr.length == fArr2.length, () -> {
                return "the comparing predicts must be with the same length";
            });
            new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(fArr)).indices().foreach$mVc$sp(i -> {
                if (!Arrays.equals(fArr[i], fArr2[i])) {
                    throw new NonLocalReturnControl.mcZ.sp(obj, false);
                }
            });
            return true;
        } catch (NonLocalReturnControl e) {
            if (e.key() == obj) {
                return e.value$mcZ$sp();
            }
            throw e;
        }
    }

    private BasicWalkThrough$() {
        MODULE$ = this;
    }
}
