/*
 * Decompiled with CFR 0.152.
 */
package hex.genmodel.algos.svm;

import com.google.common.io.ByteStreams;
import hex.genmodel.GenModel;
import hex.genmodel.MojoModel;
import hex.genmodel.MojoReaderBackend;
import hex.genmodel.algos.svm.SvmMojoReader;
import hex.genmodel.easy.EasyPredictModelWrapper;
import hex.genmodel.easy.RowData;
import hex.genmodel.easy.prediction.BinomialModelPrediction;
import java.io.BufferedReader;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

/**
 * Checks that the SVM MOJO stored as classpath resources next to this class predicts the expected
 * class label for a few hand-written rows, both through {@link EasyPredictModelWrapper} and through
 * the raw {@code score0} entry point of {@link MojoModel}.
 */
public class SvmMojoModelTest {
    /** Model under test, reloaded from classpath resources before each test. */
    private MojoModel _mojo;
    /** Raw numeric input rows, indexed by model column index. */
    private double[][] _rows;
    /** {@link #_rows} converted to {@link RowData} for the easy-predict API. */
    private RowData[] _rowData;
    /** Expected label index for each row of {@link #_rows}. */
    private double[] _expectedPreds;

    @Before
    public void setup() throws IOException {
        _mojo = SvmMojoReader.readFrom(new ClasspathReaderBackend());
        _rows = new double[][]{
                {0.0, 0.0, 0.0, 0.0, 0.0, 0.0},
                {1.0, 1.0, 1.0, 1.0, 1.0, 1.0}
        };
        _expectedPreds = new double[]{1.0, 0.0};
        _rowData = new RowData[_rows.length];
        for (int i = 0; i < _rows.length; i++) {
            _rowData[i] = toRowData(_mojo, _rows[i]);
        }
    }

    @Test
    public void testPredict() throws Exception {
        EasyPredictModelWrapper wrapper = new EasyPredictModelWrapper(_mojo);
        for (int i = 0; i < _rows.length; i++) {
            int expectedLabel = (int) _expectedPreds[i];

            BinomialModelPrediction p = (BinomialModelPrediction) wrapper.predict(_rowData[i]);
            Assert.assertEquals("easy-predict label index for row " + i, expectedLabel, p.labelIndex);

            // NOTE(review): sized 3, presumably [label, p0, p1] for a binomial model -- confirm.
            double[] preds = new double[3];
            _mojo.score0(_rows[i], preds);
            Assert.assertEquals("score0 label for row " + i, _expectedPreds[i], preds[0], 0.0);
        }
    }

    /**
     * Converts a raw numeric row into {@link RowData} keyed by column name.
     *
     * <p>Columns that have a domain (categoricals) are stored as the domain level selected by the
     * row value. All other columns are stored as the boxed numeric value. The last entry of
     * {@code mojo._names} is skipped (presumably the response column -- confirm).
     *
     * @param mojo model supplying column names and domains
     * @param row  raw values, indexed by model column index
     * @return the row as name/value pairs
     */
    private static RowData toRowData(MojoModel mojo, double[] row) {
        RowData rowData = new RowData();
        for (int i = 0; i < mojo._names.length - 1; i++) {
            String name = mojo._names[i];
            int idx = mojo.getColIdx(name);
            String[] domain = mojo.getDomainValues(idx);
            if (domain != null) {
                rowData.put(name, domain[(int) row[idx]]);
            } else {
                rowData.put(name, row[idx]);
            }
        }
        return rowData;
    }

    /**
     * {@link MojoReaderBackend} that resolves MOJO files as classpath resources, relative to
     * {@code SvmMojoModelTest}.
     */
    private static class ClasspathReaderBackend implements MojoReaderBackend {

        public BufferedReader getTextFile(String filename) throws IOException {
            // Not closed here: the returned reader owns the stream and is handed to the caller.
            return new BufferedReader(
                    new InputStreamReader(openResource(filename), StandardCharsets.UTF_8));
        }

        public byte[] getBinaryFile(String filename) throws IOException {
            // Fully consumed here, so close the stream instead of leaking it.
            try (InputStream is = openResource(filename)) {
                return ByteStreams.toByteArray(is);
            }
        }

        public boolean exists(String filename) {
            // Report the truth so that optional files can be skipped by the reader.
            return SvmMojoModelTest.class.getResource(filename) != null;
        }

        /**
         * Opens a classpath resource. Fails with a descriptive exception instead of returning
         * {@code null}.
         */
        private static InputStream openResource(String filename) throws FileNotFoundException {
            InputStream is = SvmMojoModelTest.class.getResourceAsStream(filename);
            if (is == null) {
                throw new FileNotFoundException("Classpath resource not found: " + filename);
            }
            return is;
        }
    }
}

