package hex.psvm;

import hex.Model;
import hex.ModelCategory;
import hex.ModelMetrics;
import hex.ModelMetricsSupervised;
import hex.genmodel.algos.psvm.KernelParameters;
import hex.genmodel.algos.psvm.KernelType;
import hex.genmodel.algos.psvm.ScorerFactory;
import hex.genmodel.algos.psvm.SupportVectorScorer;
import hex.psvm.psvm.Kernel;
import hex.psvm.psvm.KernelFactory;
import hex.psvm.psvm.PrimalDualIPM;
import java.util.Arrays;
import water.Futures;
import water.Key;
import water.Keyed;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.util.Log;

/* loaded from: input_file:hex/psvm/PSVMModel.class */
/**
 * Support Vector Machine model.
 *
 * <p>The model is always reported as binomial (see {@link PSVMModelOutput#getModelCategory()}).
 * A row is classified by taking the raw support-vector score, adding the bias term
 * {@link PSVMModelOutput#_rho}, and using the sign of the result to pick class 0 or class 1
 * (see {@link #makePreds(double, double[])}).
 */
public class PSVMModel extends Model<PSVMModel, PSVMParameters, PSVMModelOutput> {
    /** Single-row scorer, created on first use by {@link #getScorer()}. */
    private transient SupportVectorScorer _scorer;

    /** Output of a PSVM model: the solver results plus the serialized support vectors. */
    public static class PSVMModelOutput extends Model.Output {
        /** Support-vector count; handed to the bulk scorer factory (narrowed to {@code int} there). */
        public long _svs_count;
        /** NOTE(review): presumably the number of bounded support vectors - confirm against the builder. */
        public long _bsv_count;
        /** Bias term added to the raw score before the sign test in {@code makePreds}. */
        public double _rho;
        /** Key of an auxiliary frame that is removed together with the model. */
        public Key<Frame> _alpha_key;
        /** Serialized support vectors consumed by both the single-row and the bulk scorer factories. */
        public byte[] _compressed_svs;

        /**
         * Creates the output and installs the response domain.
         *
         * @param psvm   the model builder
         * @param frame  the frame forwarded to the {@link Model.Output} constructor
         * @param strArr response domain to use; when {@code null} the domain {@code {"-1", "+1"}} is used
         */
        public PSVMModelOutput(PSVM psvm, Frame frame, String[] strArr) {
            super(psvm, frame);
            // The last entry of _domains is overwritten with the response domain.
            String[] responseDomain = strArr != null ? strArr : new String[]{"-1", "+1"};
            this._domains[this._domains.length - 1] = responseDomain;
        }

        /** @return always {@link ModelCategory#Binomial} */
        @Override // hex.Model.Output
        public ModelCategory getModelCategory() {
            return ModelCategory.Binomial;
        }
    }

    /** User-facing parameters of the PSVM algorithm. */
    public static class PSVMParameters extends Model.Parameters {
        /** Source of the default values for the interior-point-method settings below. */
        private static final PrimalDualIPM.Parms IPM_DEFAULTS = new PrimalDualIPM.Parms();

        public long _seed = -1;
        /** Penalty parameter; multiplied by the per-class weights in {@link #c_pos()} / {@link #c_neg()}. */
        public double _hyper_param = 1.0d;
        public double _positive_weight = 1.0d;
        public double _negative_weight = 1.0d;
        public double _sv_threshold = 1.0E-4d;
        /** Passed to the IPM solver as {@code _x_epsilon}. */
        public double _zero_threshold = 1.0E-9d;
        public boolean _disable_training_metrics = true;
        public KernelType _kernel_type = KernelType.gaussian;
        /** NOTE(review): a negative value presumably means "choose automatically" - confirm against the builder. */
        public double _gamma = -1.0d;
        /** NOTE(review): a negative value presumably means "choose automatically" - confirm against the builder. */
        public double _rank_ratio = -1.0d;
        public double _fact_threshold = 1.0E-5d;
        public int _max_iterations = IPM_DEFAULTS._max_iter;
        public double _feasible_threshold = IPM_DEFAULTS._feasible_threshold;
        // Fixed: this default used to be copied from IPM_DEFAULTS._feasible_threshold, although the
        // value is forwarded to Parms._sgap_threshold in ipmParms().
        public double _surrogate_gap_threshold = IPM_DEFAULTS._sgap_threshold;
        public double _mu_factor = IPM_DEFAULTS._mu_factor;

        @Override // hex.Model.Parameters
        public String algoName() {
            return "PSVM";
        }

        @Override // hex.Model.Parameters
        public String fullName() {
            return "PSVM";
        }

        @Override // hex.Model.Parameters
        public String javaName() {
            return PSVMModel.class.getName();
        }

        @Override // hex.Model.Parameters
        public long progressUnits() {
            return 1L;
        }

        /** @return a kernel of type {@link #_kernel_type} configured with {@link #kernelParms()} */
        public Kernel kernel() {
            return KernelFactory.make(this._kernel_type, kernelParms());
        }

        /** @return a fresh {@link KernelParameters} carrying the configured {@link #_gamma} */
        KernelParameters kernelParms() {
            KernelParameters kernelParameters = new KernelParameters();
            kernelParameters._gamma = this._gamma;
            return kernelParameters;
        }

        /**
         * Translates the user-facing parameters into the solver's own parameter object.
         *
         * @return a new {@link PrimalDualIPM.Parms} populated from this parameter set
         */
        public PrimalDualIPM.Parms ipmParms() {
            PrimalDualIPM.Parms ipm = new PrimalDualIPM.Parms();
            ipm._max_iter = this._max_iterations;
            ipm._mu_factor = this._mu_factor;
            ipm._feasible_threshold = this._feasible_threshold;
            ipm._sgap_threshold = this._surrogate_gap_threshold;
            ipm._x_epsilon = this._zero_threshold;
            // Reuse the accessors so the class-weighted penalties are defined in exactly one place.
            ipm._c_pos = c_pos();
            ipm._c_neg = c_neg();
            return ipm;
        }

        /** @return the penalty applied to the positive class: {@code _hyper_param * _positive_weight} */
        public double c_pos() {
            return this._hyper_param * this._positive_weight;
        }

        /** @return the penalty applied to the negative class: {@code _hyper_param * _negative_weight} */
        public double c_neg() {
            return this._hyper_param * this._negative_weight;
        }
    }

    /** Per-chunk predictor that turns precomputed raw scores into prediction rows. */
    private class SVMBigScoreChunkPredict implements Model.BigScoreChunkPredict {
        /** Raw scores, indexed by the row index passed to {@code score0}. */
        private final double[] _scores;

        private SVMBigScoreChunkPredict(double[] dArr) {
            this._scores = dArr;
        }

        /**
         * Looks up the precomputed score of row {@code i} and converts it into a prediction.
         *
         * @return {@code dArr2}, filled in by {@link PSVMModel#makePreds(double, double[])}
         */
        @Override // hex.Model.BigScoreChunkPredict
        public double[] score0(Chunk[] chunkArr, double d, int i, double[] dArr, double[] dArr2) {
            return PSVMModel.this.makePreds(this._scores[i], dArr2);
        }

        /** Nothing to release. */
        @Override // hex.Model.BigScoreChunkPredict, java.lang.AutoCloseable
        public void close() {
        }
    }

    /** Bulk predictor: scores a whole set of chunks at once and hands the result to a chunk predictor. */
    private class SVMBigScorePredict implements Model.BigScorePredict {
        private final BulkSupportVectorScorer _bulkScorer;

        SVMBigScorePredict(BulkSupportVectorScorer bulkSupportVectorScorer) {
            this._bulkScorer = bulkSupportVectorScorer;
        }

        /** Scores the given chunks in bulk and wraps the resulting raw scores. */
        @Override // hex.Model.BigScorePredict
        public Model.BigScoreChunkPredict initMap(Frame frame, Chunk[] chunkArr) {
            double[] rawScores = this._bulkScorer.bulkScore0(chunkArr);
            return new SVMBigScoreChunkPredict(rawScores);
        }
    }

    /**
     * Creates the model.
     *
     * @param key             key of the model
     * @param pSVMParameters  parameters the model was built with
     * @param pSVMModelOutput model output
     */
    public PSVMModel(Key<PSVMModel> key, PSVMParameters pSVMParameters, PSVMModelOutput pSVMModelOutput) {
        super(key, pSVMParameters, pSVMModelOutput);
        // Sanity check (active only with -ea): the stored key must carry the same bytes as the requested one.
        assert Arrays.equals(this._key._kb, key._kb);
    }

    /**
     * Scores {@code frame2}, builds supervised metrics from the result and labels them.
     *
     * @param frame  frame passed through to the metric builder
     * @param frame2 frame that is scored via {@code scoreMetrics}
     * @param str    description stored on the returned metrics (also logged)
     * @return the metrics with {@code _description} set to {@code str}
     */
    public ModelMetricsSupervised makeModelMetrics(Frame frame, Frame frame2, String str) {
        Log.info("Making metrics: " + str);
        ModelMetrics.MetricBuilder builder = scoreMetrics(frame2);
        ModelMetricsSupervised metrics = (ModelMetricsSupervised) builder.makeModelMetrics(this, frame, frame2, null);
        metrics._description = str;
        return metrics;
    }

    /** Scores a single row with the lazily created scorer and converts the raw score into a prediction. */
    @Override // hex.Model
    protected double[] score0(double[] dArr, double[] dArr2) {
        double rawScore = getScorer().score0(dArr);
        return makePreds(rawScore, dArr2);
    }

    /**
     * Converts a raw score into a prediction row.
     *
     * @param d    raw score, before the bias term {@code _rho} is added
     * @param dArr output array; index 0 receives the predicted label (0 or 1), index {@code 1 + label}
     *             receives 1.0 and the other class slot receives 0.0
     * @return {@code dArr}
     */
    public double[] makePreds(double d, double[] dArr) {
        int label = d + ((PSVMModelOutput) this._output)._rho < 0.0d ? 0 : 1;
        dArr[0] = label;
        dArr[1 + label] = 1.0d; // slot of the predicted class
        dArr[2 - label] = 0.0d; // slot of the other class
        return dArr;
    }

    /** Builds a bulk scorer from the model's kernel settings and serialized support vectors. */
    @Override // hex.Model
    protected Model.BigScorePredict setupBigScorePredict(Model<PSVMModel, PSVMParameters, PSVMModelOutput>.BigScore bigScore) {
        PSVMParameters parms = (PSVMParameters) this._parms;
        PSVMModelOutput output = (PSVMModelOutput) this._output;
        BulkSupportVectorScorer bulkScorer = BulkScorerFactory.makeScorer(
                parms._kernel_type, parms.kernelParms(), output._compressed_svs, (int) output._svs_count, true);
        return new SVMBigScorePredict(bulkScorer);
    }

    /** Returns the cached single-row scorer, creating and caching it on first use. */
    private SupportVectorScorer getScorer() {
        SupportVectorScorer scorer = this._scorer;
        if (scorer == null) {
            PSVMParameters parms = (PSVMParameters) this._parms;
            PSVMModelOutput output = (PSVMModelOutput) this._output;
            scorer = ScorerFactory.makeScorer(parms._kernel_type, parms.kernelParms(), output._compressed_svs);
            this._scorer = scorer;
        }
        return scorer;
    }

    @Override // hex.Model
    public ModelMetrics.MetricBuilder makeMetricBuilder(String[] strArr) {
        return new MetricBuilderPSVM(strArr);
    }

    /** Removes the frame referenced by {@code _alpha_key} before delegating to the superclass. */
    @Override // hex.Model, water.Keyed
    public Futures remove_impl(Futures futures, boolean z) {
        Keyed.remove(((PSVMModelOutput) this._output)._alpha_key, futures, true);
        return super.remove_impl(futures, z);
    }
}
