package hex.util;

import hex.Model;
import hex.ModelBuilder;
import java.lang.reflect.Field;
import java.util.Arrays;
import org.apache.log4j.xml.XmlConfiguration;
import water.Value;
import water.exceptions.H2OIllegalArgumentException;
import water.util.ArrayUtils;
import water.util.PojoUtils;

/* loaded from: input_file:hex/util/CheckpointUtils.class */
/**
 * Helpers for validating a checkpointed model against the parameters and training frame of a
 * {@link ModelBuilder} that wants to continue training from that checkpoint.
 */
public class CheckpointUtils {

    /**
     * Checks that every public parameter field named in {@code nonModifiableFields} has the same value
     * in {@code parms} as in {@code checkpointParms}.
     *
     * <p>A field is compared only when an equal {@link Field} (same declaring class, name and type, per
     * {@link Field#equals(Object)}) is among the public fields of the checkpoint's parameter class;
     * fields without such a match are skipped.
     *
     * @param parms               parameters of the new build
     * @param nonModifiableFields names of fields that must not change relative to the checkpoint
     * @param checkpointParms     input parameters of the checkpointed model
     * @throws H2OIllegalArgumentException on the first non-modifiable field whose value differs
     */
    private static void validateWithCheckpoint(Model.Parameters parms, String[] nonModifiableFields,
                                               Model.Parameters checkpointParms) {
        // Class#getFields() returns a fresh array on every call, so fetch it once rather than per field.
        final Field[] checkpointFields = checkpointParms.getClass().getFields();
        for (Field field : parms.getClass().getFields()) {
            final String name = field.getName();
            if (!ArrayUtils.contains(nonModifiableFields, name)) {
                continue;
            }
            for (Field checkpointField : checkpointFields) {
                if (!checkpointField.equals(field)) {
                    continue;
                }
                try {
                    if (!PojoUtils.equals(parms, field, checkpointParms, checkpointParms.getClass().getField(name))) {
                        throw new H2OIllegalArgumentException(name, "TreeBuilder", "Field " + name + " cannot be modified if checkpoint is specified!");
                    }
                } catch (NoSuchFieldException e) {
                    // NOTE(review): appears unreachable in practice — the field was just found among the
                    // checkpoint class's public fields, so getField(name) should resolve it.
                    throw new H2OIllegalArgumentException(name, "TreeBuilder", "Field " + name + " is not supported by checkpoint!");
                }
            }
        }
    }

    /**
     * Records an error on {@code _ntrees} unless the requested number of trees is strictly greater
     * than the number of trees already present in the checkpoint.
     *
     * @param builder          builder on which the error is recorded
     * @param parms            parameters of the new build (requested tree count)
     * @param checkpointOutput output of the checkpointed model (existing tree count)
     */
    private static void validateNTrees(ModelBuilder<?, ?, ?> builder, Model.GetNTrees parms, Model.GetNTrees checkpointOutput) {
        final int checkpointNTrees = checkpointOutput.getNTrees();
        // Same condition as "requested < checkpoint + 1", written so it cannot overflow.
        if (parms.getNTrees() <= checkpointNTrees) {
            // Report the bound that is actually enforced: requested must exceed the checkpoint's count.
            builder.error("_ntrees", "If checkpoint is specified then requested ntrees must be higher than " + checkpointNTrees);
        }
    }

    /**
     * Loads the checkpointed model from {@code checkpointValue} and verifies that {@code builder} may
     * continue training from it.
     *
     * <p>A modified non-modifiable parameter and an insufficient tree count are recorded through
     * {@link ModelBuilder#error(String, String)} rather than thrown; only the first modified
     * non-modifiable parameter is reported. Mismatches with the training frame are thrown.
     *
     * @param builder             builder whose parameters and training frame are checked
     * @param nonModifiableFields names of parameter fields that must equal the checkpoint's input parameters
     * @param checkpointValue     value holding the checkpointed model
     * @param <M> model type
     * @param <P> parameters type
     * @param <O> output type
     * @return the checkpointed model
     * @throws IllegalArgumentException if classifier-ness, training column names, or categorical domains
     *                                  differ from those of the checkpointed model
     */
    public static <M extends Model<M, P, O>, P extends Model.Parameters, O extends Model.Output> M getAndValidateCheckpointModel(
            ModelBuilder<M, P, O> builder, String[] nonModifiableFields, Value checkpointValue) {
        @SuppressWarnings("unchecked")
        final M checkpoint = (M) checkpointValue.get();
        try {
            validateWithCheckpoint(builder._parms, nonModifiableFields, checkpoint._input_parms);
        } catch (H2OIllegalArgumentException e) {
            // Convert to a builder error: "argument" holds the field name, "value" the human-readable message.
            builder.error(String.valueOf(e.values.get("argument")), String.valueOf(e.values.get("value")));
        }
        if (builder.isClassifier() != checkpoint._output.isClassifier()) {
            throw new IllegalArgumentException("Response type must be the same as for the checkpointed model.");
        }
        if (!Arrays.equals(builder.train().names(), checkpoint._output._names)) {
            throw new IllegalArgumentException("The columns of the training data must be the same as for the checkpointed model");
        }
        if (!Arrays.deepEquals(builder.train().domains(), checkpoint._output._domains)) {
            throw new IllegalArgumentException("Categorical factor levels of the training data must be the same as for the checkpointed model");
        }
        // Tree count is only checked when both the new parameters and the checkpoint output expose it.
        if (builder._parms instanceof Model.GetNTrees && checkpoint._output instanceof Model.GetNTrees) {
            validateNTrees(builder, (Model.GetNTrees) builder._parms, (Model.GetNTrees) checkpoint._output);
        }
        return checkpoint;
    }
}
