package hex;

import org.apache.commons.math3.optimization.direct.CMAESOptimizer;
import water.Iced;
import water.Key;
import water.MRTask;
import water.Scope;
import water.fvec.Chunk;
import water.fvec.Vec;

/* loaded from: input_file:hex/MeanResidualDeviance.class */
/**
 * Computes the mean residual deviance of a prediction vector against a vector of actual targets,
 * optionally weighted per observation, using the deviance function of a {@link Distribution}.
 *
 * <p>{@link #exec()} stores {@code sum(dist.deviance(w, actual, pred)) / sum(w)} in
 * {@link #meanResidualDeviance}. The sums run over rows where the actual, the prediction and (if
 * present) the weight are all non-missing, and where the weight is non-zero.
 */
public class MeanResidualDeviance extends Iced {
    /** Actual target values; must be numeric. */
    public Vec _actuals;
    /** Predicted values; must not be categorical. May be replaced by an aligned copy in exec(). */
    public Vec _preds;
    /** Optional observation weights ({@code null} means weight 1 for every row); must be numeric. */
    public Vec _weights;
    /** Distribution whose {@code deviance} function scores each row. */
    public Distribution _dist;
    /** Result written by {@link #exec()}. */
    public double meanResidualDeviance;

    /**
     * Distributed task that accumulates the per-row deviance sum and the weight sum, and divides
     * them once all chunks have been reduced.
     */
    public static class MeanResidualBuilder extends MRTask<MeanResidualBuilder> {
        /** Running deviance sum during map/reduce; holds the weighted mean after postGlobal(). */
        public double _mean_residual_deviance;
        /** Running sum of the weights of the rows that contributed. */
        private double _wcount;
        /** Distribution providing the per-row deviance. */
        private Distribution _dist;

        MeanResidualBuilder(Distribution distribution) {
            this._dist = distribution;
        }

        /** Unweighted variant: every row is treated as having weight 1. */
        @Override
        public void map(Chunk actuals, Chunk preds) {
            // The explicit cast is required: a bare null would resolve to a more specific
            // MRTask.map overload instead of the (Chunk, Chunk, Chunk) one below.
            map(actuals, preds, (Chunk) null);
        }

        /**
         * Accumulates the deviance of every usable row in this chunk.
         *
         * @param actuals chunk of actual target values
         * @param preds chunk of predicted values
         * @param weights chunk of observation weights, or {@code null} for unit weights
         */
        @Override
        public void map(Chunk actuals, Chunk preds, Chunk weights) {
            _mean_residual_deviance = 0.0;
            _wcount = 0.0;
            final int len = Math.min(actuals._len, preds._len);
            for (int row = 0; row < len; row++) {
                if (actuals.isNA(row) || preds.isNA(row)) {
                    continue;
                }
                // A missing weight would otherwise arrive as NaN and poison the whole sum.
                if (weights != null && weights.isNA(row)) {
                    continue;
                }
                final double w = weights != null ? weights.atd(row) : 1.0;
                perRow(preds.atd(row), actuals.atd(row), w);
            }
        }

        /**
         * Adds one observation to the running sums. Rows with zero weight are ignored.
         *
         * @param pr predicted value (must not be NaN)
         * @param a actual target value (must not be NaN)
         * @param w observation weight (must not be NaN)
         */
        public void perRow(double pr, double a, double w) {
            if (w == 0.0) {
                return;
            }
            assert !Double.isNaN(pr);
            assert !Double.isNaN(a);
            assert !Double.isNaN(w);
            // Argument order mirrors Distribution.deviance(weight, actual, prediction).
            _mean_residual_deviance += _dist.deviance(w, a, pr);
            _wcount += w;
        }

        /** Merges the partial sums of another task instance into this one. */
        @Override
        public void reduce(MeanResidualBuilder other) {
            _mean_residual_deviance += other._mean_residual_deviance;
            _wcount += other._wcount;
        }

        /**
         * Turns the global deviance sum into the weighted mean. If the total weight is zero, the
         * IEEE-754 division yields NaN or an infinity.
         */
        @Override
        public void postGlobal() {
            _mean_residual_deviance /= _wcount;
        }
    }

    /**
     * @param distribution distribution whose deviance function is used
     * @param preds predicted values
     * @param actuals actual target values
     * @param weights observation weights, or {@code null} for unit weights
     */
    public MeanResidualDeviance(Distribution distribution, Vec preds, Vec actuals, Vec weights) {
        this._preds = preds;
        this._actuals = actuals;
        this._weights = weights;
        this._dist = distribution;
    }

    /**
     * Validates the inputs, then aligns predictions and weights to the actuals' vector group.
     * Aligned copies are registered with {@link Scope} so they are removed when the enclosing
     * scope exits.
     *
     * @throws IllegalArgumentException if inputs are missing, of the wrong type or length
     */
    private void init() throws IllegalArgumentException {
        if (_actuals == null || _preds == null) {
            throw new IllegalArgumentException("Missing actual targets or predicted values!");
        }
        if (_actuals.length() != _preds.length()) {
            throw new IllegalArgumentException("Both arguments must have the same length (" + _actuals.length() + "!=" + _preds.length() + ")!");
        }
        if (!_actuals.isNumeric()) {
            throw new IllegalArgumentException("Actual target column must be numeric!");
        }
        if (_preds.isCategorical()) {
            throw new IllegalArgumentException("Predicted targets cannot be class labels, expect continuous values.");
        }
        if (_weights != null && !_weights.isNumeric()) {
            throw new IllegalArgumentException("Observation weights must be numeric.");
        }
        if (_weights != null && _weights.length() != _actuals.length()) {
            throw new IllegalArgumentException("Observation weights must have the same length as the actual targets (" + _weights.length() + "!=" + _actuals.length() + ")!");
        }
        // Align each vec independently: weights may live in a different group even when the
        // predictions already share the actuals' group (and vice versa).
        if (!_actuals.group().equals(_preds.group())) {
            _preds = _actuals.align(_preds);
            Scope.track(_preds);
        }
        if (_weights != null && !_actuals.group().equals(_weights.group())) {
            _weights = _actuals.align(_weights);
            Scope.track(_weights);
        }
    }

    /**
     * Validates inputs, runs the distributed computation and stores the result in
     * {@link #meanResidualDeviance}.
     *
     * @return this instance, for chaining
     * @throws IllegalArgumentException if the inputs are invalid (see {@link #init()})
     */
    public MeanResidualDeviance exec() {
        Scope.enter();
        try {
            // init() runs inside the try so a validation failure still balances Scope.enter()
            // and deletes any aligned vec it already tracked.
            init();
            MeanResidualBuilder builder = new MeanResidualBuilder(_dist);
            builder = _weights != null
                    ? builder.doAll(_actuals, _preds, _weights)
                    : builder.doAll(_actuals, _preds);
            meanResidualDeviance = builder._mean_residual_deviance;
            return this;
        } finally {
            Scope.exit();
        }
    }
}
