/*
 * Decompiled with CFR 0.152.
 */
package hex.genmodel.easy;

import hex.ModelCategory;
import hex.genmodel.GenModel;
import hex.genmodel.IClusteringModel;
import hex.genmodel.algos.deepwater.DeepwaterMojoModel;
import hex.genmodel.algos.word2vec.WordEmbeddingModel;
import hex.genmodel.easy.RowData;
import hex.genmodel.easy.exception.PredictException;
import hex.genmodel.easy.exception.PredictNumberFormatException;
import hex.genmodel.easy.exception.PredictUnknownCategoricalLevelException;
import hex.genmodel.easy.exception.PredictUnknownTypeException;
import hex.genmodel.easy.prediction.AbstractPrediction;
import hex.genmodel.easy.prediction.AutoEncoderModelPrediction;
import hex.genmodel.easy.prediction.BinomialModelPrediction;
import hex.genmodel.easy.prediction.ClusteringModelPrediction;
import hex.genmodel.easy.prediction.DimReductionModelPrediction;
import hex.genmodel.easy.prediction.MultinomialModelPrediction;
import hex.genmodel.easy.prediction.RegressionModelPrediction;
import hex.genmodel.easy.prediction.SortedClassProbability;
import hex.genmodel.easy.prediction.Word2VecPrediction;
import java.awt.image.BufferedImage;
import java.io.ByteArrayInputStream;
import java.io.File;
import java.io.IOException;
import java.io.Serializable;
import java.net.URL;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;
import javax.imageio.ImageIO;

public class EasyPredictModelWrapper
implements Serializable {
    public final GenModel m;
    private final HashMap<String, Integer> modelColumnNameToIndexMap;
    public final HashMap<Integer, HashMap<String, Integer>> domainMap;
    private final boolean convertUnknownCategoricalLevelsToNa;
    private final boolean convertInvalidNumbersToNa;
    private final boolean useExtendedOutput;
    private final ConcurrentHashMap<String, AtomicLong> unknownCategoricalLevelsSeenPerColumn;

    public EasyPredictModelWrapper(Config config) {
        int i;
        this.m = config.getModel();
        this.modelColumnNameToIndexMap = new HashMap();
        String[] modelColumnNames = this.m.getNames();
        for (i = 0; i < modelColumnNames.length; ++i) {
            this.modelColumnNameToIndexMap.put(modelColumnNames[i], i);
        }
        this.unknownCategoricalLevelsSeenPerColumn = new ConcurrentHashMap();
        this.convertUnknownCategoricalLevelsToNa = config.getConvertUnknownCategoricalLevelsToNa();
        this.convertInvalidNumbersToNa = config.getConvertInvalidNumbersToNa();
        this.useExtendedOutput = config.getUseExtendedOutput();
        this.setupConvertUnknownCategoricalLevelsToNa();
        this.domainMap = new HashMap();
        for (i = 0; i < this.m.getNumCols(); ++i) {
            String[] domainValues = this.m.getDomainValues(i);
            if (domainValues == null) continue;
            HashMap<String, Integer> m = new HashMap<String, Integer>();
            for (int j = 0; j < domainValues.length; ++j) {
                m.put(domainValues[j], j);
            }
            this.domainMap.put(i, m);
        }
    }

    public EasyPredictModelWrapper(GenModel model) {
        this(new Config().setModel(model));
    }

    public long getTotalUnknownCategoricalLevelsSeen() {
        ConcurrentHashMap<String, AtomicLong> map = this.getUnknownCategoricalLevelsSeenPerColumn();
        long total = 0L;
        for (AtomicLong l : map.values()) {
            total += l.get();
        }
        return total;
    }

    public ConcurrentHashMap<String, AtomicLong> getUnknownCategoricalLevelsSeenPerColumn() {
        return this.unknownCategoricalLevelsSeenPerColumn;
    }

    public AbstractPrediction predict(RowData data, ModelCategory mc) throws PredictException {
        switch (mc) {
            case AutoEncoder: {
                return this.predictAutoEncoder(data);
            }
            case Binomial: {
                return this.predictBinomial(data);
            }
            case Multinomial: {
                return this.predictMultinomial(data);
            }
            case Clustering: {
                return this.predictClustering(data);
            }
            case Regression: {
                return this.predictRegression(data);
            }
            case DimReduction: {
                return this.predictDimReduction(data);
            }
            case WordEmbedding: {
                return this.predictWord2Vec(data);
            }
            case Unknown: {
                throw new PredictException("Unknown model category");
            }
        }
        throw new PredictException("Unhandled model category (" + (Object)((Object)this.m.getModelCategory()) + ") in switch statement");
    }

    public AbstractPrediction predict(RowData data) throws PredictException {
        return this.predict(data, this.m.getModelCategory());
    }

    public AutoEncoderModelPrediction predictAutoEncoder(RowData data) throws PredictException {
        this.validateModelCategory(ModelCategory.AutoEncoder);
        int size = this.m.getPredsSize(ModelCategory.AutoEncoder);
        double[] output = new double[size];
        double[] rawData = EasyPredictModelWrapper.nanArray(this.m.nfeatures());
        rawData = this.fillRawData(data, rawData);
        output = this.m.score0(rawData, output);
        AutoEncoderModelPrediction p = new AutoEncoderModelPrediction();
        p.original = this.expandRawData(rawData, output.length);
        p.reconstructed = output;
        p.reconstructedRowData = this.reconstructedToRowData(output);
        return p;
    }

    private double[] expandRawData(double[] data, int size) {
        double[] expanded = new double[size];
        int pos = 0;
        for (int i = 0; i < data.length; ++i) {
            if (this.m._domains[i] == null) {
                expanded[pos] = data[i];
                ++pos;
                continue;
            }
            int idx = Double.isNaN(data[i]) ? this.m._domains[i].length : (int)data[i];
            expanded[pos + idx] = 1.0;
            pos += this.m._domains[i].length + 1;
        }
        return expanded;
    }

    private RowData reconstructedToRowData(double[] reconstructed) {
        RowData rd = new RowData();
        int pos = 0;
        for (int i = 0; i < this.m.nfeatures(); ++i) {
            Object value;
            if (this.m._domains[i] == null) {
                value = reconstructed[pos++];
            } else {
                value = EasyPredictModelWrapper.catValuesAsMap(this.m._domains[i], reconstructed, pos);
                pos += this.m._domains[i].length + 1;
            }
            rd.put(this.m._names[i], value);
        }
        return rd;
    }

    private static Map<String, Double> catValuesAsMap(String[] cats, double[] reconstructed, int offset) {
        HashMap<String, Double> result = new HashMap<String, Double>(cats.length + 1);
        for (int i = 0; i < cats.length; ++i) {
            result.put(cats[i], reconstructed[i + offset]);
        }
        result.put(null, reconstructed[offset + cats.length]);
        return result;
    }

    public DimReductionModelPrediction predictDimReduction(RowData data) throws PredictException {
        double[] preds = this.preamble(ModelCategory.DimReduction, data);
        DimReductionModelPrediction p = new DimReductionModelPrediction();
        p.dimensions = preds;
        return p;
    }

    public Word2VecPrediction predictWord2Vec(RowData data) throws PredictException {
        this.validateModelCategory(ModelCategory.WordEmbedding);
        if (!(this.m instanceof WordEmbeddingModel)) {
            throw new PredictException("Model is not of the expected type, class = " + this.m.getClass().getSimpleName());
        }
        WordEmbeddingModel weModel = (WordEmbeddingModel)((Object)this.m);
        int vecSize = weModel.getVecSize();
        HashMap<String, float[]> embeddings = new HashMap<String, float[]>(data.size());
        for (String wordKey : data.keySet()) {
            Object value = data.get(wordKey);
            if (!(value instanceof String)) continue;
            String word = (String)value;
            embeddings.put(wordKey, weModel.transform0(word, new float[vecSize]));
        }
        Word2VecPrediction p = new Word2VecPrediction();
        p.wordEmbeddings = embeddings;
        return p;
    }

    public BinomialModelPrediction predictBinomial(RowData data) throws PredictException {
        double[] preds = this.preamble(ModelCategory.Binomial, data);
        BinomialModelPrediction p = new BinomialModelPrediction();
        double d = preds[0];
        p.labelIndex = (int)d;
        String[] domainValues = this.m.getDomainValues(this.m.getResponseIdx());
        if (domainValues == null && this.m.getNumResponseClasses() == 2) {
            domainValues = new String[]{"0", "1"};
        }
        p.label = domainValues[p.labelIndex];
        p.classProbabilities = new double[this.m.getNumResponseClasses()];
        System.arraycopy(preds, 1, p.classProbabilities, 0, p.classProbabilities.length);
        if (this.m.calibrateClassProbabilities(preds)) {
            p.calibratedClassProbabilities = new double[this.m.getNumResponseClasses()];
            System.arraycopy(preds, 1, p.calibratedClassProbabilities, 0, p.calibratedClassProbabilities.length);
        }
        return p;
    }

    public MultinomialModelPrediction predictMultinomial(RowData data) throws PredictException {
        double[] preds = this.preamble(ModelCategory.Multinomial, data);
        MultinomialModelPrediction p = new MultinomialModelPrediction();
        p.classProbabilities = new double[this.m.getNumResponseClasses()];
        p.labelIndex = (int)preds[0];
        String[] domainValues = this.m.getDomainValues(this.m.getResponseIdx());
        p.label = domainValues[p.labelIndex];
        System.arraycopy(preds, 1, p.classProbabilities, 0, p.classProbabilities.length);
        return p;
    }

    private SortedClassProbability[] sortByDescendingClassProbability(String[] domainValues, double[] classProbabilities) {
        assert (classProbabilities.length == domainValues.length);
        SortedClassProbability[] arr = new SortedClassProbability[domainValues.length];
        for (int i = 0; i < domainValues.length; ++i) {
            arr[i] = new SortedClassProbability();
            arr[i].name = domainValues[i];
            arr[i].probability = classProbabilities[i];
        }
        Arrays.sort(arr, Collections.reverseOrder());
        return arr;
    }

    public SortedClassProbability[] sortByDescendingClassProbability(BinomialModelPrediction p) {
        String[] domainValues = this.m.getDomainValues(this.m.getResponseIdx());
        double[] classProbabilities = p.classProbabilities;
        return this.sortByDescendingClassProbability(domainValues, classProbabilities);
    }

    public SortedClassProbability[] sortByDescendingClassProbability(MultinomialModelPrediction p) {
        String[] domainValues = this.m.getDomainValues(this.m.getResponseIdx());
        double[] classProbabilities = p.classProbabilities;
        return this.sortByDescendingClassProbability(domainValues, classProbabilities);
    }

    public ClusteringModelPrediction predictClustering(RowData data) throws PredictException {
        ClusteringModelPrediction p = new ClusteringModelPrediction();
        if (this.useExtendedOutput && this.m instanceof IClusteringModel) {
            IClusteringModel cm = (IClusteringModel)((Object)this.m);
            double[] rawData = EasyPredictModelWrapper.nanArray(this.m.nfeatures());
            rawData = this.fillRawData(data, rawData);
            int k = cm.getNumClusters();
            p.distances = new double[k];
            p.cluster = cm.distances(rawData, p.distances);
        } else {
            double[] preds = this.preamble(ModelCategory.Clustering, data);
            p.cluster = (int)preds[0];
        }
        return p;
    }

    public RegressionModelPrediction predictRegression(RowData data) throws PredictException {
        double[] preds = this.preamble(ModelCategory.Regression, data);
        RegressionModelPrediction p = new RegressionModelPrediction();
        p.value = preds[0];
        return p;
    }

    public ModelCategory getModelCategory() {
        return this.m.getModelCategory();
    }

    public String[] getResponseDomainValues() {
        return this.m.getDomainValues(this.m.getResponseIdx());
    }

    public String getHeader() {
        return this.m.getHeader();
    }

    private void setupConvertUnknownCategoricalLevelsToNa() {
        if (this.convertUnknownCategoricalLevelsToNa) {
            for (int i = 0; i < this.m.getNumCols(); ++i) {
                String[] domainValues = this.m.getDomainValues(i);
                if (domainValues == null) continue;
                String columnName = this.m.getNames()[i];
                this.unknownCategoricalLevelsSeenPerColumn.put(columnName, new AtomicLong());
            }
        } else {
            this.unknownCategoricalLevelsSeenPerColumn.clear();
        }
    }

    private void validateModelCategory(ModelCategory c) throws PredictException {
        if (!this.m.getModelCategories().contains((Object)c)) {
            throw new PredictException((Object)((Object)c) + " prediction type is not supported for this model.");
        }
    }

    protected double[] preamble(ModelCategory c, RowData data) throws PredictException {
        this.validateModelCategory(c);
        return this.predict(data, new double[this.m.getPredsSize(c)]);
    }

    private static double[] nanArray(int len) {
        double[] arr = new double[len];
        for (int i = 0; i < len; ++i) {
            arr[i] = Double.NaN;
        }
        return arr;
    }

    /*
     * Unable to fully structure code
     */
    protected double[] fillRawData(RowData data, double[] rawData) throws PredictException {
        isImage = this.m instanceof DeepwaterMojoModel != false && ((DeepwaterMojoModel)this.m)._problem_type.equals("image") != false;
        isText = this.m instanceof DeepwaterMojoModel != false && ((DeepwaterMojoModel)this.m)._problem_type.equals("text") != false;
        for (String dataColumnName : data.keySet()) {
            block27: {
                block26: {
                    index = this.modelColumnNameToIndexMap.get(dataColumnName);
                    if (index == null || index >= rawData.length) continue;
                    img = null;
                    domainValues = this.m.getDomainValues(index);
                    if (domainValues == null) {
                        value = NaN;
                        o = data.get(dataColumnName);
                        if (o instanceof String) {
                            s = ((String)o).trim();
                            if (isImage) {
                                isURL = s.matches("^(https?|ftp|file)://[-a-zA-Z0-9+&@#/%?=~_|!:,.;]*[-a-zA-Z0-9+&@#/%=~_|]");
                                try {
                                    img = isURL != false ? ImageIO.read(new URL(s)) : ImageIO.read(new File(s));
                                }
                                catch (IOException e) {
                                    throw new PredictException("Couldn't read image from " + s);
                                }
                            } else {
                                if (isText) {
                                    throw new PredictException("MOJO scoring for text classification is not yet implemented.");
                                }
                                try {
                                    value = Double.parseDouble(s);
                                }
                                catch (NumberFormatException nfe) {
                                    if (this.convertInvalidNumbersToNa) ** GOTO lbl39
                                    throw new PredictNumberFormatException("Unable to parse value: " + s + ", from column: " + dataColumnName + ", as Double; " + nfe.getMessage());
                                }
                            }
                        } else if (o instanceof Double) {
                            value = (Double)o;
                        } else if (o instanceof byte[] && isImage) {
                            is = new ByteArrayInputStream((byte[])o);
                            try {
                                img = ImageIO.read(is);
                            }
                            catch (IOException e) {
                                throw new PredictException("Couldn't interpret raw bytes as an image.");
                            }
                        } else {
                            throw new PredictUnknownTypeException("Unexpected object type " + o.getClass().getName() + " for numeric column " + dataColumnName);
                        }
lbl39:
                        // 5 sources

                        if (isImage && img != null) {
                            dwm = (DeepwaterMojoModel)this.m;
                            W = dwm._width;
                            H = dwm._height;
                            C = dwm._channels;
                            _destData = new float[W * H * C];
                            try {
                                GenModel.img2pixels(img, W, H, C, _destData, 0, dwm._meanImageData);
                            }
                            catch (IOException e) {
                                e.printStackTrace();
                                throw new PredictException("Couldn't vectorize image.");
                            }
                            rawData = new double[_destData.length];
                            for (i = 0; i < rawData.length; ++i) {
                                rawData[i] = _destData[i];
                            }
                            return rawData;
                        }
                        rawData[index.intValue()] = value;
                        continue;
                    }
                    o = data.get(dataColumnName);
                    if (!(o instanceof String)) break block26;
                    levelName = (String)o;
                    columnDomainMap = this.domainMap.get(index);
                    levelIndex = columnDomainMap.get(levelName);
                    if (levelIndex == null) {
                        levelIndex = columnDomainMap.get(dataColumnName + "." + levelName);
                    }
                    if (levelIndex != null) ** GOTO lbl72
                    if (this.convertUnknownCategoricalLevelsToNa) {
                        value = NaN;
                        this.unknownCategoricalLevelsSeenPerColumn.get(dataColumnName).incrementAndGet();
                    } else {
                        throw new PredictUnknownCategoricalLevelException("Unknown categorical level (" + dataColumnName + "," + levelName + ")", dataColumnName, levelName);
lbl72:
                        // 1 sources

                        value = levelIndex.intValue();
                    }
                    break block27;
                }
                if (o instanceof Double && Double.isNaN((Double)o)) {
                    value = (Double)o;
                } else {
                    throw new PredictUnknownTypeException("Unexpected object type " + o.getClass().getName() + " for categorical column " + dataColumnName);
                }
            }
            rawData[index.intValue()] = value;
        }
        return rawData;
    }

    protected double[] predict(RowData data, double[] preds) throws PredictException {
        double[] rawData = EasyPredictModelWrapper.nanArray(this.m.nfeatures());
        rawData = this.fillRawData(data, rawData);
        preds = this.m.score0(rawData, preds);
        return preds;
    }

    public static class Config {
        private GenModel model;
        private boolean convertUnknownCategoricalLevelsToNa = false;
        private boolean convertInvalidNumbersToNa = false;
        private boolean useExtendedOutput = false;

        public Config setModel(GenModel value) {
            this.model = value;
            return this;
        }

        public GenModel getModel() {
            return this.model;
        }

        public Config setConvertUnknownCategoricalLevelsToNa(boolean value) {
            this.convertUnknownCategoricalLevelsToNa = value;
            return this;
        }

        public boolean getConvertUnknownCategoricalLevelsToNa() {
            return this.convertUnknownCategoricalLevelsToNa;
        }

        public Config setConvertInvalidNumbersToNa(boolean value) {
            this.convertInvalidNumbersToNa = value;
            return this;
        }

        public boolean getConvertInvalidNumbersToNa() {
            return this.convertInvalidNumbersToNa;
        }

        public Config setUseExtendedOutput(boolean value) {
            this.useExtendedOutput = value;
            return this;
        }

        public boolean getUseExtendedOutput() {
            return this.useExtendedOutput;
        }
    }
}

