package water.util;

import com.google.gson.Gson;
import com.google.gson.JsonSyntaxException;
import com.google.gson.reflect.TypeToken;
import hex.genmodel.GenModel;
import hex.genmodel.MojoModel;
import hex.genmodel.easy.EasyPredictModelWrapper;
import hex.genmodel.easy.RowData;
import hex.genmodel.easy.exception.PredictException;
import java.io.BufferedReader;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.io.PrintStream;
import java.io.PrintWriter;
import java.io.StringWriter;
import java.lang.reflect.Type;
import java.net.MalformedURLException;
import java.net.URL;
import java.net.URLClassLoader;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import org.apache.commons.lang.StringUtils;

/* loaded from: input_file:water/util/H2OPredictor.class */
/**
 * Scores JSON-encoded rows against an H2O model loaded either from a POJO jar ({@code .jar}) or a
 * MOJO archive ({@code .zip}) and returns the predictions serialized as JSON.
 *
 * <p>Usable as a command-line tool (see {@link #main}) or through the static {@link #predict2} /
 * {@link #predict3} entry points, which never throw: failures are reported as a JSON object of
 * the form {@code {"error": "<stack trace>"}}.
 */
public class H2OPredictor {
    /** Type token used to re-open a serialized prediction as a map so extra keys can be added. */
    private static final Type MAP_TYPE = new TypeToken<Map<String, Object>>() {
    }.getType();

    private static final Gson gson = new Gson();

    /**
     * When true, each prediction is augmented with a {@code "responseDomainValues"} entry. Set by
     * the {@code -l} command-line flag. It is static, so it applies to every instance. It is read
     * both when a model is loaded and when each row is predicted.
     */
    private static boolean useLabels = false;

    /** Response domain of the loaded model; only populated when {@link #useLabels} is set. */
    private String[] labels = null;

    /** The loaded model, or null until {@link #init} succeeds. */
    private EasyPredictModelWrapper model = null;

    /**
     * Loads the model stored in {@code ojoFileName}.
     *
     * <p>On any failure this prints the stack trace and terminates the JVM with exit status 1.
     * Callers that need a recoverable error should use {@link #predict2} or {@link #predict3}.
     *
     * @param ojoFileName path to a POJO {@code .jar} or a MOJO {@code .zip}
     * @param modelName   fully-qualified model class name inside the jar (ignored for MOJOs)
     */
    public H2OPredictor(String ojoFileName, String modelName) {
        try {
            init(ojoFileName, modelName);
        } catch (Exception e) {
            e.printStackTrace();
            System.exit(1);
        }
    }

    /** Creates an empty predictor; callers must invoke {@link #init} before predicting. */
    private H2OPredictor() {
    }

    /**
     * Loads the model (POJO or MOJO, chosen by file extension) and, when labels are requested,
     * caches its response domain.
     *
     * @throws IllegalArgumentException if the file name is null or has an unsupported extension
     * @throws Exception                if the model cannot be loaded
     */
    private void init(String ojoFileName, String modelName) throws Exception {
        if (ojoFileName == null) {
            throw new IllegalArgumentException("file name can't be null");
        }
        if (ojoFileName.endsWith(".jar")) {
            loadPojo(ojoFileName, modelName);
        } else if (ojoFileName.endsWith(".zip")) {
            loadMojo(ojoFileName);
        } else {
            throw new IllegalArgumentException("unknown model archive type");
        }
        if (useLabels) {
            labels = model.getResponseDomainValues();
        }
    }

    /**
     * Loads class {@code modelName} from the jar at {@code jarFileName} and instantiates it through
     * its no-arg constructor.
     *
     * @throws FileNotFoundException if {@code jarFileName} is not a regular file
     * @throws Exception             if the class cannot be found or instantiated (cause preserved)
     */
    private GenModel loadClassFromJar(String jarFileName, String modelName) throws Exception {
        File jarFile = new File(jarFileName);
        if (!jarFile.isFile()) {
            throw new FileNotFoundException("Can't read " + jarFileName);
        }
        if (modelName == null) {
            throw new IllegalArgumentException("model name can't be null");
        }
        try {
            URL[] urls = {jarFile.toURI().toURL()};
            // The loader is deliberately not closed: the returned model may load more classes lazily.
            ClassLoader loader = URLClassLoader.newInstance(urls, getClass().getClassLoader());
            return loader.loadClass(modelName)
                    .asSubclass(GenModel.class)
                    .getDeclaredConstructor()
                    .newInstance();
        } catch (MalformedURLException e) {
            throw new Exception("Can't use Jar file " + jarFileName, e);
        } catch (ReflectiveOperationException e) {
            throw new Exception("Can't find model " + modelName + " in jar file " + jarFileName, e);
        }
    }

    /** Loads a POJO model class from a jar and wraps it for easy prediction. */
    private void loadPojo(String jarFileName, String modelName) throws Exception {
        this.model = new EasyPredictModelWrapper(loadClassFromJar(jarFileName, modelName));
    }

    /** Loads a MOJO archive and wraps it for easy prediction. */
    private void loadMojo(String mojoFileName) throws Exception {
        this.model = new EasyPredictModelWrapper(MojoModel.load(mojoFileName));
    }

    /** Parses a single JSON object into a row; the original parse error is kept as the cause. */
    private RowData jsonToRowData(String json) {
        try {
            return gson.fromJson(json, RowData.class);
        } catch (JsonSyntaxException e) {
            throw new JsonSyntaxException("Malformed JSON", e);
        }
    }

    /** Parses a JSON array of objects into rows; the original parse error is kept as the cause. */
    private RowData[] jsonToRowDataArray(String json) {
        try {
            return gson.fromJson(json, RowData[].class);
        } catch (JsonSyntaxException e) {
            throw new JsonSyntaxException("Malformed JSON Array", e);
        }
    }

    /**
     * Predicts one row and serializes the prediction as JSON, adding
     * {@code "responseDomainValues"} when labels are enabled.
     *
     * @throws PredictException if no model is loaded, the row is null, or prediction fails
     */
    private String predictRow(RowData rowData) throws PredictException {
        if (model == null) {
            throw new PredictException("No model loaded");
        }
        if (rowData == null) {
            throw new PredictException("No row data");
        }
        String json = gson.toJson(model.predict(rowData));
        if (useLabels) {
            Map<String, Object> prediction = gson.fromJson(json, MAP_TYPE);
            prediction.put("responseDomainValues", labels);
            json = gson.toJson(prediction);
        }
        return json;
    }

    /** Predicts every row of a JSON array and joins the results as {@code "[ a, b ]"}. */
    private String predictArray(String json) throws PredictException {
        RowData[] rows = jsonToRowDataArray(json);
        StringBuilder out = new StringBuilder("[ ");
        for (int i = 0; i < rows.length; i++) {
            if (i > 0) {
                out.append(", ");
            }
            out.append(predictRow(rows[i]));
        }
        return out.append(" ]").toString();
    }

    /**
     * Loads the model, scores {@code jsonArgs} and returns the predictions as JSON.
     *
     * @param ojoFileName path to a POJO {@code .jar} or a MOJO {@code .zip}
     * @param modelName   model class name inside the jar (ignored for MOJOs)
     * @param jsonArgs    a JSON object or array of objects, or a path to a file containing one
     * @return a single prediction object, an array {@code "[ ... ]"} of predictions, or
     *         {@code {"error": "<stack trace>"}} on any failure
     */
    public static String predict3(String ojoFileName, String modelName, String jsonArgs) {
        try {
            // Build through init() rather than the public constructor. The model is then loaded
            // once, and load failures become JSON errors instead of System.exit.
            H2OPredictor predictor = new H2OPredictor();
            predictor.init(ojoFileName, modelName);
            String json = resolveJsonArgument(jsonArgs);
            if (json.trim().charAt(0) == '[') {
                return predictor.predictArray(json);
            }
            return predictor.predictRow(predictor.jsonToRowData(json));
        } catch (Exception e) {
            return errorJson(e);
        }
    }

    /**
     * Returns {@code jsonArgs} itself if it already starts like a JSON object or array. Otherwise
     * it is treated as a file path and the file's (UTF-8) contents are returned.
     *
     * @throws IllegalArgumentException if the argument is empty or the file holds no JSON object/array
     * @throws IOException              if the file cannot be read
     */
    private static String resolveJsonArgument(String jsonArgs) throws IOException {
        if (jsonArgs == null || jsonArgs.trim().isEmpty()) {
            throw new IllegalArgumentException("empty json argument");
        }
        if (startsLikeJson(jsonArgs)) {
            return jsonArgs;
        }
        String contents = readFile(jsonArgs);
        if (!startsLikeJson(contents)) {
            throw new IllegalArgumentException(
                    "file " + jsonArgs + " does not contain a JSON object or array");
        }
        return contents;
    }

    /** True when the first non-whitespace character is '{' or '['. */
    private static boolean startsLikeJson(String text) {
        String trimmed = text.trim();
        return !trimmed.isEmpty() && (trimmed.charAt(0) == '{' || trimmed.charAt(0) == '[');
    }

    /**
     * Predicts a single JSON row with the already-loaded model.
     *
     * @return the prediction as JSON, or {@code {"error": "<stack trace>"}} on failure
     */
    public String pred(String json) {
        try {
            return predictRow(jsonToRowData(json));
        } catch (Exception e) {
            return errorJson(e);
        }
    }

    /**
     * Like {@link #predict3}, but derives the model name from the file name. It takes the base
     * name and strips a trailing {@code .zip} or {@code .jar}.
     */
    public static String predict2(String ojoFileName, String jsonArgs) {
        if (ojoFileName == null) {
            // Let predict3 report the missing file name as a JSON error.
            return predict3(null, null, jsonArgs);
        }
        String modelName = ojoFileName;
        int separator = Math.max(modelName.lastIndexOf('/'), modelName.lastIndexOf(File.separatorChar));
        if (separator != -1) {
            modelName = modelName.substring(separator + 1);
        }
        if (modelName.endsWith(".zip") || modelName.endsWith(".jar")) {
            modelName = modelName.substring(0, modelName.length() - 4);
        }
        return predict3(ojoFileName, modelName, jsonArgs);
    }

    /** Reads a whole file as UTF-8 text (JSON's mandated encoding). */
    private static String readFile(String path) throws IOException {
        return new String(Files.readAllBytes(Paths.get(path)), StandardCharsets.UTF_8);
    }

    /** Serializes {@code {"error": "<stack trace of t>"}} as valid, properly escaped JSON. */
    private static String errorJson(Throwable t) {
        Map<String, String> error = new HashMap<>();
        error.put("error", stackTraceToString(t));
        return gson.toJson(error);
    }

    /** Renders the full stack trace of {@code th} as a string. */
    private static String stackTraceToString(Throwable th) {
        StringWriter buffer = new StringWriter();
        PrintWriter writer = new PrintWriter(buffer);
        th.printStackTrace(writer);
        writer.flush();
        return buffer.toString();
    }

    /** Removes every backslash from a command-line JSON argument (as the original tool did). */
    private static String stripBackslashes(String arg) {
        // NOTE(review): presumably undoes shell escaping of quotes; this also breaks JSON escapes.
        return arg.replace("\\", "");
    }

    /**
     * Command-line entry point.
     *
     * <p>Usage: {@code [-l] mojoFile json} or {@code [-l] jarFile modelName json}, where
     * {@code json} is a JSON object/array or a path to a file containing one. {@code -l} adds
     * response domain labels to every prediction. The result is printed to stdout.
     */
    public static void main(String[] args) {
        if (args.length > 0 && args[0].equals("-l")) {
            useLabels = true;
            args = Arrays.copyOfRange(args, 1, args.length);
        }
        String result;
        if (args.length == 2) {
            result = predict2(args[0], stripBackslashes(args[1]));
        } else if (args.length == 3) {
            result = predict3(args[0], args[1], stripBackslashes(args[2]));
        } else {
            result = "{ \"error\": \"Need 2 or 3 args, have " + args.length + "\", "
                    + "\"usage\": \"[-l] mojoFile jsonString  or: [-l] jarFile modelName jsonString\" }";
        }
        System.out.println(result);
    }
}
