/*
 * Decompiled with CFR 0.152.
 */
package ml.dmlc.xgboost4j.java;

import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.IEvaluation;
import ml.dmlc.xgboost4j.java.IObjective;
import ml.dmlc.xgboost4j.java.Rabit;
import ml.dmlc.xgboost4j.java.XGBoostError;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

public class XGBoost {
    private static final Log logger = LogFactory.getLog(XGBoost.class);

    /**
     * Loads a booster from a saved model file.
     *
     * @param modelPath path of the saved model
     * @return the loaded booster
     * @throws XGBoostError propagated from {@link Booster#loadModel(String)}
     */
    public static Booster loadModel(String modelPath) throws XGBoostError {
        return Booster.loadModel(modelPath);
    }

    /**
     * Loads a booster from a stream containing a saved model.
     *
     * @param in stream holding the serialized model
     * @return the loaded booster
     * @throws XGBoostError propagated from {@link Booster#loadModel(InputStream)}
     * @throws IOException if reading the stream fails
     */
    public static Booster loadModel(InputStream in) throws XGBoostError, IOException {
        return Booster.loadModel(in);
    }

    /**
     * Trains a new booster without recording a metric history and without early stopping.
     *
     * @see #train(DMatrix, Map, int, Map, float[][], IObjective, IEvaluation, int, Booster)
     */
    public static Booster train(DMatrix dtrain, Map<String, Object> params, int round, Map<String, DMatrix> watches, IObjective obj, IEvaluation eval2) throws XGBoostError {
        return XGBoost.train(dtrain, params, round, watches, null, obj, eval2, 0);
    }

    /**
     * Trains a new booster, optionally recording metrics and stopping early.
     *
     * @see #train(DMatrix, Map, int, Map, float[][], IObjective, IEvaluation, int, Booster)
     */
    public static Booster train(DMatrix dtrain, Map<String, Object> params, int round, Map<String, DMatrix> watches, float[][] metrics, IObjective obj, IEvaluation eval2, int earlyStoppingRound) throws XGBoostError {
        return XGBoost.train(dtrain, params, round, watches, metrics, obj, eval2, earlyStoppingRound, null);
    }

    /**
     * Runs up to {@code round} boosting rounds. After each round it evaluates the watched
     * matrices and may stop early.
     *
     * @param dtrain training data
     * @param params booster parameters
     * @param round maximum number of boosting rounds
     * @param watches named matrices evaluated after every round; may be {@code null} or empty.
     *     Their iteration order fixes the row order of {@code metrics}, and the LAST entry
     *     drives early stopping, so pass an ordered map (e.g. {@link java.util.LinkedHashMap})
     *     when that matters
     * @param metrics receives {@code metrics[w][iter]}, the value reported for watch {@code w}
     *     at round {@code iter}. It is allocated internally when {@code null}; otherwise it
     *     must have at least {@code watches.size()} rows of length {@code round}
     * @param obj custom objective, or {@code null} to use {@code Booster.update(dtrain, iter)}
     * @param eval2 custom evaluation, or {@code null} for the booster's built-in evaluation
     * @param earlyStoppingRound stop once the last watch has gone this many consecutive rounds
     *     without a strictly lower value than its best so far (lower is treated as better).
     *     Values {@code <= 0} disable early stopping
     * @param booster booster to continue training (its params are reset to {@code params}), or
     *     {@code null} to create a new one and call {@code loadRabitCheckpoint()} on it
     * @return the trained booster
     * @throws XGBoostError propagated from the booster / Rabit calls
     */
    public static Booster train(DMatrix dtrain, Map<String, Object> params, int round, Map<String, DMatrix> watches, float[][] metrics, IObjective obj, IEvaluation eval2, int earlyStoppingRound, Booster booster) throws XGBoostError {
        List<String> names = new ArrayList<String>();
        List<DMatrix> mats = new ArrayList<DMatrix>();
        if (watches != null) {
            for (Map.Entry<String, DMatrix> evalEntry : watches.entrySet()) {
                names.add(evalEntry.getKey());
                mats.add(evalEntry.getValue());
            }
        }
        String[] evalNames = names.toArray(new String[names.size()]);
        DMatrix[] evalMats = mats.toArray(new DMatrix[mats.size()]);
        if (metrics == null) {
            metrics = new float[evalNames.length][round];
        }

        // The booster is built over the training matrix first, followed by every watched matrix.
        DMatrix[] allMats = new DMatrix[evalMats.length + 1];
        allMats[0] = dtrain;
        System.arraycopy(evalMats, 0, allMats, 1, evalMats.length);

        if (booster == null) {
            booster = new Booster(params, allMats);
            booster.loadRabitCheckpoint();
        } else {
            booster.setParams(params);
        }

        // Best (lowest) score seen so far on the last watch, and the round it was reached in.
        float bestScore = Float.NaN;
        int bestIter = -1;

        // NOTE(review): the loop assumes every saveRabitCheckpoint() bumps getVersion() by one.
        // Under that assumption, version / 2 is the number of completed rounds, and an odd
        // version means this round's update was checkpointed but its evaluation was not.
        for (int iter = booster.getVersion() / 2; iter < round; ++iter) {
            if (booster.getVersion() % 2 == 0) {
                if (obj != null) {
                    booster.update(dtrain, obj);
                } else {
                    booster.update(dtrain, iter);
                }
                booster.saveRabitCheckpoint();
            }
            if (evalMats.length > 0) {
                float[] metricsOut = new float[evalMats.length];
                String evalInfo = eval2 != null
                        ? booster.evalSet(evalMats, evalNames, eval2, metricsOut)
                        : booster.evalSet(evalMats, evalNames, iter, metricsOut);
                for (int i = 0; i < metricsOut.length; ++i) {
                    metrics[i][iter] = metricsOut[i];
                }
                if (Rabit.getRank() == 0) {
                    Rabit.trackerPrint(evalInfo + '\n');
                }

                // Compare against the best score so far rather than only the previous round.
                // Otherwise earlyStoppingRound would be ignored and any single uptick would stop
                // training.
                float score = metricsOut[metricsOut.length - 1];
                if (bestIter < 0 || Float.isNaN(bestScore) || score < bestScore) {
                    bestScore = score;
                    bestIter = iter;
                } else if (earlyStoppingRound > 0 && iter - bestIter >= earlyStoppingRound) {
                    Rabit.trackerPrint(String.format("early stopping after %d rounds without improvement", earlyStoppingRound));
                    break;
                }
            }
            booster.saveRabitCheckpoint();
        }
        return booster;
    }

    /**
     * Runs {@code nfold}-fold cross validation for {@code round} boosting rounds.
     *
     * @param data full data set; rows are shuffled randomly and split into disjoint folds
     * @param params booster parameters used for every fold
     * @param round number of boosting rounds
     * @param nfold number of folds, between 2 and the number of rows of {@code data}
     * @param metrics values passed one by one as {@code "eval_metric"} to each fold's booster,
     *     or {@code null}
     * @param obj custom objective, or {@code null} for the built-in update
     * @param eval2 custom evaluation, or {@code null} for built-in evaluation
     * @return one aggregated line per round. Each line is the first tab-separated field of
     *     fold 0's result, followed by {@code \tcv-<name>:<mean over folds>} for each reported
     *     metric
     * @throws XGBoostError propagated from the booster calls
     * @throws IllegalArgumentException if {@code nfold} is out of range
     */
    public static String[] crossValidation(DMatrix data, Map<String, Object> params, int round, int nfold, String[] metrics, IObjective obj, IEvaluation eval2) throws XGBoostError {
        CVPack[] cvPacks = XGBoost.makeNFold(data, nfold, params, metrics);
        try {
            String[] evalHist = new String[round];
            String[] results = new String[cvPacks.length];
            for (int i = 0; i < round; ++i) {
                for (CVPack cvPack : cvPacks) {
                    if (obj != null) {
                        cvPack.update(obj);
                    } else {
                        cvPack.update(i);
                    }
                }
                for (int j = 0; j < cvPacks.length; ++j) {
                    results[j] = eval2 != null ? cvPacks[j].eval(eval2) : cvPacks[j].eval(i);
                }
                evalHist[i] = XGBoost.aggCVResults(results);
                logger.info(evalHist[i]);
            }
            return evalHist;
        } finally {
            // The per-fold boosters and sliced matrices are owned here; free their native handles.
            XGBoost.disposeAll(cvPacks);
        }
    }

    /**
     * Shuffles the rows of {@code data} and builds one {@link CVPack} per fold. Fold
     * {@code f} uses permutation positions {@code [f*n/nfold, (f+1)*n/nfold)} as its test set
     * and all remaining rows as its training set, so every row is tested exactly once.
     *
     * @throws IllegalArgumentException if {@code nfold < 2} or {@code nfold} exceeds the row count
     */
    private static CVPack[] makeNFold(DMatrix data, int nfold, Map<String, Object> params, String[] evalMetrics) throws XGBoostError {
        List<Integer> samples = XGBoost.genRandPermutationNums(0, (int) data.rowNum());
        int numRows = samples.size();
        if (nfold < 2 || nfold > numRows) {
            throw new IllegalArgumentException(String.format("nfold must be between 2 and the number of rows (%d), got %d", numRows, nfold));
        }
        CVPack[] cvPacks = new CVPack[nfold];
        boolean complete = false;
        try {
            for (int fold = 0; fold < nfold; ++fold) {
                // long arithmetic so fold * numRows cannot overflow int.
                int testStart = (int) ((long) fold * numRows / nfold);
                int testEnd = (int) ((long) (fold + 1) * numRows / nfold);
                // Fresh arrays per fold so no slice shares a buffer with another fold.
                int[] testSlice = new int[testEnd - testStart];
                int[] trainSlice = new int[numRows - testSlice.length];
                int testId = 0;
                int trainId = 0;
                for (int j = 0; j < numRows; ++j) {
                    if (j >= testStart && j < testEnd) {
                        testSlice[testId++] = samples.get(j);
                    } else {
                        trainSlice[trainId++] = samples.get(j);
                    }
                }
                CVPack cvPack = new CVPack(data.slice(trainSlice), data.slice(testSlice), params);
                // Store before configuring so the pack is disposed if setParam fails.
                cvPacks[fold] = cvPack;
                if (evalMetrics != null) {
                    // NOTE(review): setParam is called once per metric; presumably the native side
                    // accumulates "eval_metric" values rather than overwriting them — confirm.
                    for (String type : evalMetrics) {
                        cvPack.booster.setParam("eval_metric", type);
                    }
                }
            }
            complete = true;
            return cvPacks;
        } finally {
            if (!complete) {
                XGBoost.disposeAll(cvPacks);
            }
        }
    }

    /** Returns the integers in {@code [start2, end)} in a random order. */
    private static List<Integer> genRandPermutationNums(int start2, int end) {
        List<Integer> samples = new ArrayList<Integer>(Math.max(0, end - start2));
        for (int i = start2; i < end; ++i) {
            samples.add(i);
        }
        Collections.shuffle(samples);
        return samples;
    }

    /**
     * Averages per-fold evaluation lines of the form {@code label\tname:value\tname:value...}.
     * Metrics keep the order in which they first appear, and numbers are formatted with
     * {@link Locale#ROOT} so the decimal separator is always {@code '.'}.
     */
    private static String aggCVResults(String[] results) {
        Map<String, List<Float>> cvMap = new LinkedHashMap<String, List<Float>>();
        StringBuilder aggResult = new StringBuilder(results[0].split("\t")[0]);
        for (String result : results) {
            String[] items = result.split("\t");
            for (int i = 1; i < items.length; ++i) {
                String[] tup = items[i].split(":");
                String key = tup[0];
                List<Float> values = cvMap.get(key);
                if (values == null) {
                    values = new ArrayList<Float>();
                    cvMap.put(key, values);
                }
                values.add(Float.valueOf(tup[1]));
            }
        }
        for (Map.Entry<String, List<Float>> entry : cvMap.entrySet()) {
            List<Float> values = entry.getValue();
            float sum = 0.0f;
            for (Float v : values) {
                sum += v.floatValue();
            }
            float mean = sum / (float) values.size();
            aggResult.append(String.format(Locale.ROOT, "\tcv-%s:%f", entry.getKey(), mean));
        }
        return aggResult.toString();
    }

    /** Disposes every non-null pack in {@code cvPacks}. */
    private static void disposeAll(CVPack[] cvPacks) {
        for (CVPack cvPack : cvPacks) {
            if (cvPack != null) {
                cvPack.dispose();
            }
        }
    }

    /** One cross-validation fold: its train/test matrices and the booster trained on them. */
    private static class CVPack {
        final DMatrix dtrain;
        final DMatrix dtest;
        final DMatrix[] dmats;
        final String[] names;
        final Booster booster;

        public CVPack(DMatrix dtrain, DMatrix dtest, Map<String, Object> params) throws XGBoostError {
            this.dmats = new DMatrix[]{dtrain, dtest};
            this.booster = new Booster(params, this.dmats);
            this.names = new String[]{"train", "test"};
            this.dtrain = dtrain;
            this.dtest = dtest;
        }

        /** Runs one built-in boosting round on the fold's training matrix. */
        public void update(int iter) throws XGBoostError {
            this.booster.update(this.dtrain, iter);
        }

        /** Runs one boosting round with a custom objective on the fold's training matrix. */
        public void update(IObjective obj) throws XGBoostError {
            this.booster.update(this.dtrain, obj);
        }

        /** Evaluates on train and test with the built-in metrics. */
        public String eval(int iter) throws XGBoostError {
            return this.booster.evalSet(this.dmats, this.names, iter);
        }

        /** Evaluates on train and test with a custom evaluation. */
        public String eval(IEvaluation eval2) throws XGBoostError {
            return this.booster.evalSet(this.dmats, this.names, eval2);
        }

        /** Releases the booster and this fold's sliced matrices; both are owned by this pack. */
        void dispose() {
            this.booster.dispose();
            this.dtrain.dispose();
            this.dtest.dispose();
        }
    }
}

