package ml.dmlc.xgboost4j.java.example;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.IEvaluation;
import ml.dmlc.xgboost4j.java.IObjective;
import ml.dmlc.xgboost4j.java.XGBoost;
import ml.dmlc.xgboost4j.java.XGBoostError;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

/* loaded from: input_file:ml/dmlc/xgboost4j/java/example/CustomObjective.class */
public class CustomObjective {

    /* loaded from: input_file:ml/dmlc/xgboost4j/java/example/CustomObjective$EvalError.class */
    public static class EvalError implements IEvaluation {
        private static final Log logger = LogFactory.getLog(EvalError.class);
        String evalMetric = "custom_error";

        public String getMetric() {
            return this.evalMetric;
        }

        public float eval(float[][] fArr, DMatrix dMatrix) {
            float f = 0.0f;
            try {
                float[] label = dMatrix.getLabel();
                int length = fArr.length;
                for (int i = 0; i < length; i++) {
                    if (label[i] == 0.0f && fArr[i][0] > 0.0f) {
                        f += 1.0f;
                    } else if (label[i] == 1.0f && fArr[i][0] <= 0.0f) {
                        f += 1.0f;
                    }
                }
                return f / label.length;
            } catch (XGBoostError e) {
                logger.error(e);
                return -1.0f;
            }
        }
    }

    /* loaded from: input_file:ml/dmlc/xgboost4j/java/example/CustomObjective$LogRegObj.class */
    public static class LogRegObj implements IObjective {
        private static final Log logger = LogFactory.getLog(LogRegObj.class);

        public float sigmoid(float f) {
            return (float) (1.0d / (1.0d + Math.exp(-f)));
        }

        public float[][] transform(float[][] fArr) {
            int length = fArr.length;
            float[][] fArr2 = new float[length][1];
            for (int i = 0; i < length; i++) {
                fArr2[i][0] = sigmoid(fArr[i][0]);
            }
            return fArr2;
        }

        public List<float[]> getGradient(float[][] fArr, DMatrix dMatrix) {
            int length = fArr.length;
            ArrayList arrayList = new ArrayList();
            try {
                float[] label = dMatrix.getLabel();
                float[] fArr2 = new float[length];
                float[] fArr3 = new float[length];
                float[][] transform = transform(fArr);
                for (int i = 0; i < length; i++) {
                    float f = transform[i][0];
                    fArr2[i] = f - label[i];
                    fArr3[i] = f * (1.0f - f);
                }
                arrayList.add(fArr2);
                arrayList.add(fArr3);
                return arrayList;
            } catch (XGBoostError e) {
                logger.error(e);
                return null;
            }
        }
    }

    public static void main(String[] strArr) throws XGBoostError {
        DMatrix dMatrix = new DMatrix("../../demo/data/agaricus.txt.train");
        DMatrix dMatrix2 = new DMatrix("../../demo/data/agaricus.txt.test");
        HashMap hashMap = new HashMap();
        hashMap.put("eta", Double.valueOf(1.0d));
        hashMap.put("max_depth", 2);
        hashMap.put("silent", 1);
        HashMap hashMap2 = new HashMap();
        hashMap2.put("train", dMatrix);
        hashMap2.put("test", dMatrix2);
        LogRegObj logRegObj = new LogRegObj();
        EvalError evalError = new EvalError();
        System.out.println("begin to train the booster model");
        XGBoost.train(dMatrix, hashMap, 2, hashMap2, logRegObj, evalError);
    }
}
