package ai.h2o.xgboost4j.java;

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.UUID;
import java.util.stream.Collectors;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.fs.FSDataInputStream;
import org.apache.hadoop.fs.FSDataOutputStream;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.logging.log4j.util.ProcessIdUtil;

/* loaded from: input_file:ai/h2o/xgboost4j/java/ExternalCheckpointManager.class */
/**
 * Stores and retrieves XGBoost model checkpoints as {@code <version>.model} files inside a
 * single checkpoint directory, using the supplied Hadoop {@link FileSystem} for all I/O.
 *
 * <p>Checkpoint versions are mapped to training rounds by integer division by two (see
 * {@link #getCheckpointRounds(int, int)} and {@link #cleanUpHigherVersions(int)}).
 * NOTE(review): the "version == 2 * round" convention is inferred from that arithmetic only;
 * confirm against how {@code Booster} increments its version.
 */
public class ExternalCheckpointManager {
    private final Log logger = LogFactory.getLog("ExternalCheckpointManager");
    private final String modelSuffix = ".model";
    private final Path checkpointPath;
    private final FileSystem fs;

    /**
     * Creates a manager for the checkpoint directory {@code checkpointPath}.
     *
     * @param checkpointPath directory holding the checkpoint files; must be non-null and non-empty
     * @param fileSystem     file system used for every checkpoint read, write and delete
     * @throws XGBoostError if {@code checkpointPath} is null or empty
     */
    public ExternalCheckpointManager(String checkpointPath, FileSystem fileSystem) throws XGBoostError {
        if (checkpointPath == null || checkpointPath.isEmpty()) {
            throw new XGBoostError("cannot create ExternalCheckpointManager with null or empty checkpoint path");
        }
        this.checkpointPath = new Path(checkpointPath);
        this.fs = fileSystem;
    }

    /**
     * Builds the file path of the checkpoint with the given version.
     *
     * <p>Only the path component of {@link #checkpointPath} is kept (any scheme/authority is
     * dropped via {@code toUri().getPath()}), so the result is resolved against whichever file
     * system it is later used with — presumably {@link #fs}.
     */
    private String getPath(int version) {
        return this.checkpointPath.toUri().getPath() + "/" + version + this.modelSuffix;
    }

    /**
     * Lists the versions of all checkpoints currently stored in the checkpoint directory.
     *
     * <p>Only entries named {@code <integer>.model} count as checkpoints. Other {@code .model}
     * files are skipped with a warning rather than aborting the whole operation with a
     * {@link NumberFormatException}. Temporary files written by {@link #updateCheckpoint} are
     * named {@code <n>.model-<uuid>}; they do not end with the suffix and are ignored.
     *
     * @return the versions found, in listing order; empty if the directory does not exist
     * @throws IOException if the file system cannot be queried
     */
    private List<Integer> getExistingVersions() throws IOException {
        if (!this.fs.exists(this.checkpointPath)) {
            return new ArrayList<>();
        }
        return Arrays.stream(this.fs.listStatus(this.checkpointPath))
                .map(status -> status.getPath().getName())
                .filter(name -> name.endsWith(this.modelSuffix))
                .map(this::parseVersion)
                .filter(version -> version != null)
                .collect(Collectors.toList());
    }

    /**
     * Parses {@code "<n>.model"} into {@code n}.
     *
     * @return the version, or {@code null} (after logging a warning) if the stem is not an int
     */
    private Integer parseVersion(String fileName) {
        String stem = fileName.substring(0, fileName.length() - this.modelSuffix.length());
        try {
            return Integer.valueOf(stem);
        } catch (NumberFormatException e) {
            this.logger.warn("ignoring non-checkpoint file " + fileName + " in " + this.checkpointPath);
            return null;
        }
    }

    /**
     * Recursively deletes the whole checkpoint directory.
     *
     * @throws IOException if the file system reports an error while deleting
     */
    public void cleanPath() throws IOException {
        this.fs.delete(this.checkpointPath, true);
    }

    /**
     * Loads the checkpoint with the highest version number.
     *
     * @return the loaded booster with its version set to the checkpoint's version, or
     *         {@code null} if no checkpoint exists
     * @throws IOException  if the checkpoint file cannot be read
     * @throws XGBoostError if XGBoost fails to deserialize the model
     */
    public Booster loadCheckpointAsBooster() throws IOException, XGBoostError {
        List<Integer> existingVersions = getExistingVersions();
        if (existingVersions.isEmpty()) {
            return null;
        }
        int latestVersion = existingVersions.stream().mapToInt(Integer::intValue).max().getAsInt();
        String path = getPath(latestVersion);
        Booster booster;
        // try-with-resources: the stream used to be leaked on every call.
        try (InputStream in = this.fs.open(new Path(path))) {
            booster = XGBoost.loadModel(in);
        }
        this.logger.info("loaded checkpoint from " + path);
        booster.setVersion(latestVersion);
        return booster;
    }

    /**
     * Saves {@code booster} as the checkpoint for its current version, then deletes all
     * previously existing checkpoints.
     *
     * <p>The model is first written to a uniquely named temporary file. That file is fully
     * closed before it is renamed to its final name, so a reader never sees a partially written
     * checkpoint under the final name. Outdated checkpoints are removed only after the new one
     * has been published successfully.
     *
     * @param booster model to persist; its {@code getVersion()} determines the file name
     * @throws IOException  if writing or publishing the checkpoint fails (older checkpoints are
     *                      then left in place)
     * @throws XGBoostError if XGBoost fails to serialize the model
     */
    public void updateCheckpoint(Booster booster) throws IOException, XGBoostError {
        final int version = booster.getVersion();
        final String eventualPath = getPath(version);
        // Snapshot the checkpoints that exist before this write. The target path itself is
        // excluded: re-saving an existing version must not delete the freshly written file.
        final List<String> outdatedPaths = getExistingVersions().stream()
                .map(this::getPath)
                .filter(path -> !path.equals(eventualPath))
                .collect(Collectors.toList());
        final Path targetPath = new Path(eventualPath);
        final Path tempPath = new Path(eventualPath + "-" + UUID.randomUUID());

        boolean published = false;
        try {
            try (OutputStream out = this.fs.create(tempPath, true)) {
                booster.saveModel(out);
            }
            // Hadoop rename does not overwrite an existing destination on every FileSystem
            // (e.g. HDFS returns false), so clear a same-version checkpoint first.
            if (this.fs.exists(targetPath)) {
                this.fs.delete(targetPath, true);
            }
            if (!this.fs.rename(tempPath, targetPath)) {
                throw new IOException("failed to rename checkpoint " + tempPath + " to " + targetPath);
            }
            published = true;
        } finally {
            if (!published) {
                deleteQuietly(tempPath, "failed to delete temporary checkpoint at ");
            }
        }
        this.logger.info("saving checkpoint with version " + version);
        for (String outdated : outdatedPaths) {
            deleteQuietly(new Path(outdated), "failed to delete outdated checkpoint at ");
        }
    }

    /** Best-effort recursive delete; failures are logged as errors and otherwise ignored. */
    private void deleteQuietly(Path path, String failureMessage) {
        try {
            this.fs.delete(path, true);
        } catch (IOException e) {
            this.logger.error(failureMessage + path, e);
        }
    }

    /**
     * Deletes every checkpoint whose round ({@code version / 2}) is at or beyond
     * {@code currentRound}. Individual delete failures are logged and do not stop the sweep.
     *
     * @param currentRound first round whose checkpoints should be removed
     * @throws IOException if the existing checkpoints cannot be listed
     */
    public void cleanUpHigherVersions(int currentRound) throws IOException {
        for (Integer version : getExistingVersions()) {
            if (version / 2 >= currentRound) {
                try {
                    this.fs.delete(new Path(getPath(version)), true);
                } catch (IOException e) {
                    this.logger.error("failed to clean checkpoint from other training instance", e);
                }
            }
        }
    }

    /**
     * Computes the training rounds at which a checkpoint should be written.
     *
     * <p>If {@code checkpointInterval <= 0}, the result is just {@code [numOfRounds]} and the
     * file system is not consulted. Otherwise the rounds start one interval after the latest
     * existing checkpoint round (or after round 0 when none exist). They are spaced
     * {@code checkpointInterval} apart, do not exceed {@code numOfRounds}, and are always
     * followed by {@code numOfRounds}. That value therefore appears twice when the sequence
     * lands on it exactly.
     *
     * @param checkpointInterval rounds between checkpoints; {@code <= 0} disables intermediate ones
     * @param numOfRounds        total number of training rounds
     * @return the checkpoint rounds in ascending order, always ending with {@code numOfRounds}
     * @throws IOException if the existing checkpoints cannot be listed
     */
    public List<Integer> getCheckpointRounds(int checkpointInterval, int numOfRounds) throws IOException {
        List<Integer> rounds = new ArrayList<>();
        if (checkpointInterval > 0) {
            // max over {existing rounds} ∪ {0}
            int lastCheckpointedRound = getExistingVersions().stream()
                    .mapToInt(version -> version / 2)
                    .reduce(0, Math::max);
            // long arithmetic: an int counter overflows (and loops forever) near Integer.MAX_VALUE
            for (long round = (long) lastCheckpointedRound + checkpointInterval;
                    round <= numOfRounds;
                    round += checkpointInterval) {
                rounds.add((int) round);
            }
        }
        rounds.add(numOfRounds);
        return rounds;
    }
}
