package gov.sandia.cognition.learning.algorithm.tree;

import gov.sandia.cognition.evaluator.Evaluator;
import gov.sandia.cognition.learning.algorithm.BatchLearner;
import gov.sandia.cognition.learning.algorithm.SupervisedBatchLearner;
import gov.sandia.cognition.learning.data.DatasetUtil;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.learning.function.categorization.Categorizer;
import gov.sandia.cognition.util.ArgumentChecker;
import gov.sandia.cognition.util.ObjectUtil;
import java.util.Collection;

/* loaded from: input_file:gov/sandia/cognition/learning/algorithm/tree/RegressionTreeLearner.class */
/**
 * Learns a {@link RegressionTree} by recursively partitioning labeled data.
 * <p>
 * Interior nodes are built from a categorizer produced by the decider learner.
 * Every node records the mean output of the examples that reach it. Leaf nodes
 * may also hold a regression function, trained on the leaf's examples, when a
 * regression learner is configured.
 *
 * @param <InputType> the type of input the learned tree evaluates
 */
public class RegressionTreeLearner<InputType> extends AbstractDecisionTreeLearner<InputType, Double> implements SupervisedBatchLearner<InputType, Double, RegressionTree<InputType>> {

    /** Default size at or below which a node's data becomes a leaf. */
    public static final int DEFAULT_LEAF_COUNT_THRESHOLD = 4;

    /** Default maximum tree depth; any value {@code <= 0} disables the depth limit. */
    public static final int DEFAULT_MAX_DEPTH = -1;

    /** Optional learner applied to each leaf's data; {@code null} means leaves have no regression function. */
    protected BatchLearner<? super Collection<? extends InputOutputPair<? extends InputType, Double>>, ? extends Evaluator<? super InputType, Double>> regressionLearner;

    /** A node whose data size is at most this value is turned into a leaf. */
    protected int leafCountThreshold;

    /** Maximum depth of the tree; values {@code <= 0} mean no depth limit. */
    protected int maxDepth;

    /**
     * Creates a learner with no decider learner, no leaf regression learner,
     * and default thresholds.
     */
    public RegressionTreeLearner() {
        this(null);
    }

    /**
     * Creates a learner with the given decider learner, no leaf regression
     * learner, and default thresholds.
     *
     * @param deciderLearner learner used to build the split at each interior node
     */
    public RegressionTreeLearner(DeciderLearner<? super InputType, Double, ?, ?> deciderLearner) {
        this(deciderLearner, null);
    }

    /**
     * Creates a learner with the given decider and leaf regression learners and
     * default thresholds.
     *
     * @param deciderLearner learner used to build the split at each interior node
     * @param batchLearner optional learner trained on each leaf's data; may be {@code null}
     */
    public RegressionTreeLearner(DeciderLearner<? super InputType, Double, ?, ?> deciderLearner, BatchLearner<? super Collection<? extends InputOutputPair<? extends InputType, Double>>, ? extends Evaluator<? super InputType, Double>> batchLearner) {
        // Use the named defaults rather than magic literals so the constants stay authoritative.
        this(deciderLearner, batchLearner, DEFAULT_LEAF_COUNT_THRESHOLD, DEFAULT_MAX_DEPTH);
    }

    /**
     * Creates a fully configured learner.
     *
     * @param deciderLearner learner used to build the split at each interior node
     * @param batchLearner optional learner trained on each leaf's data; may be {@code null}
     * @param leafCountThreshold data size at or below which a node becomes a leaf; must be non-negative
     * @param maxDepth maximum tree depth; any value {@code <= 0} means unlimited
     */
    public RegressionTreeLearner(DeciderLearner<? super InputType, Double, ?, ?> deciderLearner, BatchLearner<? super Collection<? extends InputOutputPair<? extends InputType, Double>>, ? extends Evaluator<? super InputType, Double>> batchLearner, int leafCountThreshold, int maxDepth) {
        super(deciderLearner);
        setRegressionLearner(batchLearner);
        setLeafCountThreshold(leafCountThreshold);
        setMaxDepth(maxDepth);
    }

    /**
     * Clones this learner, also cloning the leaf regression learner (if any).
     *
     * @return a copy of this learner
     */
    @Override
    @SuppressWarnings({"unchecked", "rawtypes"})
    public RegressionTreeLearner<InputType> mo130clone() {
        final RegressionTreeLearner<InputType> copy = (RegressionTreeLearner<InputType>) super.mo130clone();
        copy.regressionLearner = (BatchLearner) ObjectUtil.cloneSafe(this.regressionLearner);
        return copy;
    }

    /**
     * Learns a regression tree from the given labeled data.
     *
     * @param collection the training data; may be {@code null}
     * @return the learned tree, or {@code null} when {@code collection} is {@code null}
     */
    @Override
    public RegressionTree<InputType> learn(Collection<? extends InputOutputPair<? extends InputType, Double>> collection) {
        if (collection == null) {
            return null;
        }
        // The root node has no parent.
        return new RegressionTree<>(learnNode(collection, null));
    }

    /**
     * Recursively learns the node (and its subtree) for the given data.
     * <p>
     * A leaf is produced when any of these hold:
     * <ul>
     *   <li>the data size is at most {@link #leafCountThreshold};</li>
     *   <li>the depth limit is reached;</li>
     *   <li>all outputs are equal;</li>
     *   <li>the decider learner returns {@code null}.</li>
     * </ul>
     * Otherwise an interior node is built and its children are learned.
     *
     * @param collection the data reaching this node; may be {@code null} or empty
     * @param abstractDecisionTreeNode the parent node, or {@code null} for the root
     * @return the learned node, or {@code null} when {@code collection} is {@code null} or empty
     */
    @Override
    @SuppressWarnings({"unchecked", "rawtypes"})
    public RegressionTreeNode<InputType, ?> learnNode(Collection<? extends InputOutputPair<? extends InputType, Double>> collection, AbstractDecisionTreeNode<InputType, Double, ?> abstractDecisionTreeNode) {
        // Fixed: the decompiled code compared the collection with the int 0,
        // which does not compile; a null check is what was intended.
        if (collection == null || collection.isEmpty()) {
            return null;
        }

        // Depth of the node being built: the root is at depth 1.
        final int depth = abstractDecisionTreeNode == null ? 1 : abstractDecisionTreeNode.getDepth() + 1;
        final boolean makeLeaf = collection.size() <= this.leafCountThreshold
            || (this.maxDepth > 0 && depth >= this.maxDepth)
            || areAllOutputsEqual(collection);

        // Both interior nodes and leaves record the mean output of their data.
        final double mean = DatasetUtil.computeOutputMean(collection);

        if (!makeLeaf) {
            final Categorizer categorizer = (Categorizer) getDeciderLearner().learn(collection);
            // A null categorizer means no usable split was found, so fall through to a leaf.
            if (categorizer != null) {
                final RegressionTreeNode<InputType, ?> node = new RegressionTreeNode<>(
                    (DecisionTreeNode) abstractDecisionTreeNode,
                    (Categorizer<? super InputType, ? extends Object>) categorizer,
                    mean);
                learnChildNodes(node, collection, categorizer);
                return node;
            }
        }

        // Leaf: optionally fit a regression function to this leaf's data.
        Evaluator<? super InputType, Double> evaluator = null;
        if (this.regressionLearner != null) {
            evaluator = this.regressionLearner.learn(collection);
        }
        return new RegressionTreeNode<>(abstractDecisionTreeNode, evaluator, mean);
    }

    /**
     * @return the learner applied to each leaf's data, or {@code null} if none
     */
    public BatchLearner<? super Collection<? extends InputOutputPair<? extends InputType, Double>>, ? extends Evaluator<? super InputType, Double>> getRegressionLearner() {
        return this.regressionLearner;
    }

    /**
     * Sets the learner applied to each leaf's data.
     *
     * @param batchLearner the leaf regression learner; {@code null} disables leaf regression
     */
    public void setRegressionLearner(BatchLearner<? super Collection<? extends InputOutputPair<? extends InputType, Double>>, ? extends Evaluator<? super InputType, Double>> batchLearner) {
        this.regressionLearner = batchLearner;
    }

    /**
     * @return the data size at or below which a node becomes a leaf
     */
    public int getLeafCountThreshold() {
        return this.leafCountThreshold;
    }

    /**
     * Sets the data size at or below which a node becomes a leaf.
     *
     * @param leafCountThreshold the threshold; must be non-negative
     */
    public void setLeafCountThreshold(int leafCountThreshold) {
        ArgumentChecker.assertIsNonNegative("leafCountThreshold", leafCountThreshold);
        this.leafCountThreshold = leafCountThreshold;
    }

    /**
     * @return the maximum tree depth; values {@code <= 0} mean unlimited
     */
    public int getMaxDepth() {
        return this.maxDepth;
    }

    /**
     * Sets the maximum tree depth.
     *
     * @param maxDepth the maximum depth; any value {@code <= 0} disables the limit
     */
    public void setMaxDepth(int maxDepth) {
        this.maxDepth = maxDepth;
    }
}
