/*
 * Decompiled with CFR 0.152.
 */
package hex.genmodel.easy;

import hex.ModelCategory;
import hex.genmodel.GenModel;
import hex.genmodel.MojoModel;
import hex.genmodel.algos.word2vec.WordEmbeddingModel;
import hex.genmodel.easy.EasyPredictModelWrapper;
import hex.genmodel.easy.RowData;
import hex.genmodel.easy.exception.PredictUnknownCategoricalLevelException;
import hex.genmodel.easy.prediction.AbstractPrediction;
import hex.genmodel.easy.prediction.AutoEncoderModelPrediction;
import hex.genmodel.easy.prediction.BinomialModelPrediction;
import hex.genmodel.easy.prediction.SortedClassProbability;
import hex.genmodel.easy.prediction.Word2VecPrediction;
import java.util.HashMap;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;
import org.junit.Assert;
import org.junit.Test;

/**
 * Unit tests for {@link EasyPredictModelWrapper}, driven by small hand-written stub models whose
 * {@code score0}/{@code transform0} return fixed outputs.
 */
public class EasyPredictModelWrapperTest {

    /**
     * Builds a binomial stub model with two categorical predictors ({@code C1}, {@code C2}) and a
     * categorical response ({@code RESPONSE} with levels {@code NO}/{@code YES}).
     */
    private static MyModel makeModel() {
        String[] names = new String[]{"C1", "C2", "RESPONSE"};
        String[][] domains = new String[][]{
            {"c1level1", "c1level2"},
            {"c2level1", "c2level2", "c2level3"},
            {"NO", "YES"}
        };
        return new MyModel(names, domains);
    }

    /**
     * Sums all per-column unknown-categorical-level counters tracked by the wrapper.
     *
     * @param m wrapper to inspect
     * @return total of every {@link AtomicLong} in
     *     {@link EasyPredictModelWrapper#getUnknownCategoricalLevelsSeenPerColumn()}
     */
    private static long sumUnknownLevelsPerColumn(EasyPredictModelWrapper m) {
        // Typed local: iterating the map's values requires a parameterized type to compile.
        ConcurrentHashMap<String, AtomicLong> perColumn = m.getUnknownCategoricalLevelsSeenPerColumn();
        long total = 0L;
        for (AtomicLong count : perColumn.values()) {
            total += count.get();
        }
        return total;
    }

    @Test
    public void testUnknownCategoricalLevels() throws Exception {
        MyModel rawModel = makeModel();

        // Default configuration: a row with only known levels scores without error
        // (any exception propagates and fails the test) and nothing is counted.
        EasyPredictModelWrapper m = new EasyPredictModelWrapper(rawModel);
        RowData row = new RowData();
        row.put("C1", "c1level1");
        m.predictBinomial(row);
        Assert.assertEquals(0L, sumUnknownLevelsPerColumn(m));

        // Default configuration: an unknown level is rejected with an exception and is not counted.
        row = new RowData();
        row.put("C1", "c1level1");
        row.put("C2", "unknownLevel");
        try {
            m.predictBinomial(row);
            Assert.fail("Expected PredictUnknownCategoricalLevelException for unknown level in C2");
        } catch (PredictUnknownCategoricalLevelException expected) {
            // Expected: unknown levels are not converted to NA by default.
        }
        Assert.assertEquals(0L, sumUnknownLevelsPerColumn(m));

        // Lenient configuration: unknown levels are converted to NA and counted per column.
        m = new EasyPredictModelWrapper(new EasyPredictModelWrapper.Config()
                .setModel(rawModel)
                .setConvertUnknownCategoricalLevelsToNa(true)
                .setConvertInvalidNumbersToNa(true));

        m.predict(new RowData());
        Assert.assertEquals(0L, m.getTotalUnknownCategoricalLevelsSeen());

        RowData row1 = new RowData();
        row1.put("C1", "c1level1");
        row1.put("C2", "unknownLevel");
        m.predictBinomial(row1);
        Assert.assertEquals(1L, m.getTotalUnknownCategoricalLevelsSeen());

        RowData row2 = new RowData();
        row2.put("C1", "c1level1");
        row2.put("C2", "c2level3");
        m.predictBinomial(row2);
        Assert.assertEquals(1L, m.getTotalUnknownCategoricalLevelsSeen());

        // A column the model does not know about does not count as an unknown level.
        RowData row3 = new RowData();
        row3.put("C1", "c1level1");
        row3.put("unknownColumn", "unknownLevel");
        m.predictBinomial(row3);
        Assert.assertEquals(1L, m.getTotalUnknownCategoricalLevelsSeen());

        m.predictBinomial(row1);
        m.predictBinomial(row1);
        Assert.assertEquals(3L, m.getTotalUnknownCategoricalLevelsSeen());

        RowData row4 = new RowData();
        row4.put("C1", "unknownLevel");
        m.predictBinomial(row4);
        Assert.assertEquals(4L, m.getTotalUnknownCategoricalLevelsSeen());

        ConcurrentHashMap<String, AtomicLong> perColumn = m.getUnknownCategoricalLevelsSeenPerColumn();
        Assert.assertEquals(1L, perColumn.get("C1").get());
        Assert.assertEquals(3L, perColumn.get("C2").get());
    }

    @Test
    public void testSortedClassProbability() throws Exception {
        MyModel rawModel = makeModel();
        EasyPredictModelWrapper m = new EasyPredictModelWrapper(rawModel);
        RowData row = new RowData();
        row.put("C1", "c1level1");
        BinomialModelPrediction p = m.predictBinomial(row);

        // MyModel.score0 assigns probability 1.0 to NO and 0.0 to YES.
        SortedClassProbability[] arr = m.sortByDescendingClassProbability(p);
        Assert.assertEquals("NO", arr[0].name);
        Assert.assertEquals(1.0, arr[0].probability, 0.001);
        Assert.assertEquals("YES", arr[1].name);
        Assert.assertEquals(0.0, arr[1].probability, 0.001);
    }

    @Test
    public void testWordEmbeddingModel() throws Exception {
        MyWordEmbeddingModel rawModel = new MyWordEmbeddingModel();
        EasyPredictModelWrapper m = new EasyPredictModelWrapper(rawModel);
        RowData row = new RowData();
        row.put("C0", -1); // non-string value; asserted below to be absent from the embeddings
        row.put("C1", "0.9,0.1");
        row.put("C2", "0.1,0.9");
        row.put("C3", "NA"); // the stub's transform0 returns null for "NA"
        Word2VecPrediction p = m.predictWord2Vec(row);

        Assert.assertFalse(p.wordEmbeddings.containsKey("C0"));
        Assert.assertArrayEquals(new float[]{0.9f, 0.1f}, (float[]) p.wordEmbeddings.get("C1"), 1.0E-4f);
        Assert.assertArrayEquals(new float[]{0.1f, 0.9f}, (float[]) p.wordEmbeddings.get("C2"), 1.0E-4f);
        Assert.assertTrue(p.wordEmbeddings.containsKey("C3"));
        Assert.assertNull(p.wordEmbeddings.get("C3"));
    }

    @Test
    public void testAutoEncoderModel() throws Exception {
        MyAutoEncoderModel rawModel = new MyAutoEncoderModel();
        EasyPredictModelWrapper m = new EasyPredictModelWrapper(rawModel);
        RowData row = new RowData();
        row.put("Species", "versicolor");
        row.put("Sepal.Length", 7.0);
        row.put("Sepal.Width", 3.2);
        row.put("Petal.Length", 4.7);
        row.put("Petal.Width", 1.4);

        AbstractPrediction p = m.predict(row);
        Assert.assertTrue(p instanceof AutoEncoderModelPrediction);
        AutoEncoderModelPrediction aep = (AutoEncoderModelPrediction) p;

        // The first four slots hold the Species one-hot encoding; the rest are the numeric columns.
        Assert.assertArrayEquals(new double[]{0.0, 1.0, 0.0, 0.0, 7.0, 3.2, 4.7, 1.4}, aep.original, 0.01);
        Assert.assertArrayEquals(
                new double[]{0.0, 1.3124, 0.4864, 0.0, 6.1729, 3.0573, 17.8372, 1.1993},
                aep.reconstructed, 0.001);

        // Plain maps instead of double-brace initialisation (which creates anonymous subclasses
        // holding a reference to this test instance). Map equality is entry-based, so the
        // concrete class does not matter.
        HashMap<String, Object> expectedSpecies = new HashMap<String, Object>();
        expectedSpecies.put(null, 0.0); // presumably the NA-level slot (4th one-hot position) — confirm
        expectedSpecies.put("setosa", 0.0);
        expectedSpecies.put("virginica", 0.4864);
        expectedSpecies.put("versicolor", 1.3124);

        HashMap<String, Object> expected = new HashMap<String, Object>();
        expected.put("Petal.Length", 17.8372);
        expected.put("Petal.Width", 1.1993);
        expected.put("Sepal.Width", 3.0573);
        expected.put("Sepal.Length", 6.1729);
        expected.put("Species", expectedSpecies);

        Assert.assertEquals(expected, aep.reconstructedRowData);
    }

    /**
     * Stub autoencoder over an iris-like schema (one categorical + four numeric columns). It
     * always returns a fixed 8-wide reconstruction and asserts the encoded input row it receives.
     */
    private static class MyAutoEncoderModel extends GenModel {
        private static final String[][] DOMAINS = new String[][]{
            {"setosa", "versicolor", "virginica"}, null, null, null, null
        };

        private MyAutoEncoderModel() {
            super(new String[]{"Species", "Sepal.Length", "Sepal.Width", "Petal.Length", "Petal.Width"},
                  DOMAINS, null);
        }

        @Override
        public ModelCategory getModelCategory() {
            return ModelCategory.AutoEncoder;
        }

        @Override
        public boolean isSupervised() {
            return false;
        }

        @Override
        public int nfeatures() {
            return 5;
        }

        /** Reports 8 outputs so that {@link #getPredsSize()} matches the 8-wide reconstruction. */
        @Override
        public int nclasses() {
            return 8;
        }

        @Override
        public String getUUID() {
            return null;
        }

        @Override
        public int getPredsSize() {
            return this.nclasses();
        }

        /**
         * Checks that the wrapper encoded the input row as expected ("versicolor" as level
         * index 1.0, followed by the raw numeric values). Then copies the fixed reconstruction
         * into {@code preds}.
         */
        @Override
        public double[] score0(double[] row, double[] preds) {
            double[] result = new double[]{0.0, 1.3124, 0.4864, 0.0, 6.1729, 3.0573, 17.8372, 1.1993};
            Assert.assertArrayEquals(new double[]{1.0, 7.0, 3.2, 4.7, 1.4}, row, 1.0E-4);
            Assert.assertEquals(result.length, preds.length);
            System.arraycopy(result, 0, preds, 0, result.length);
            return result;
        }
    }

    /**
     * Stub word-embedding model. A "word" is a comma-separated list of floats that is parsed
     * directly into the embedding vector; the word "NA" has no embedding.
     */
    private static class MyWordEmbeddingModel extends MojoModel implements WordEmbeddingModel {
        public MyWordEmbeddingModel() {
            super(new String[0], new String[0][], null);
        }

        @Override
        public int getVecSize() {
            return 2;
        }

        /**
         * Parses {@code word} into {@code output}.
         *
         * @return {@code output}, or {@code null} when {@code word} is {@code "NA"}
         */
        @Override
        public float[] transform0(String word, float[] output) {
            if (word.equals("NA")) {
                return null;
            }
            String[] parts = word.split(",");
            for (int i = 0; i < parts.length; ++i) {
                output[i] = Float.parseFloat(parts[i]);
            }
            return output;
        }

        @Override
        public double[] score0(double[] row, double[] preds) {
            throw new IllegalStateException("Should never be called");
        }

        @Override
        public ModelCategory getModelCategory() {
            return ModelCategory.WordEmbedding;
        }
    }

    /**
     * Stub binomial model. It always predicts label index 0 with probabilities {NO=1.0, YES=0.0},
     * regardless of input.
     */
    private static class MyModel extends GenModel {
        MyModel(String[] names, String[][] domains) {
            super(names, domains, null);
        }

        @Override
        public int nclasses() {
            return 2;
        }

        @Override
        public boolean isSupervised() {
            return true;
        }

        /** Fills {@code preds} as [label, p(NO), p(YES)] = [0, 1, 0]. */
        @Override
        public double[] score0(double[] data, double[] preds) {
            Assert.assertEquals(3, preds.length);
            preds[0] = 0.0;
            preds[1] = 1.0;
            preds[2] = 0.0;
            return preds;
        }

        @Override
        public ModelCategory getModelCategory() {
            return ModelCategory.Binomial;
        }

        @Override
        public String getUUID() {
            return null;
        }
    }
}

