package hex.deepwater;

import com.feedzai.openml.h2o.algos.H2ODeepLearningUtils;
import deepwater.backends.BackendModel;
import deepwater.backends.BackendParams;
import deepwater.backends.BackendTrain;
import deepwater.backends.RuntimeOptions;
import deepwater.datasets.ImageDataSet;
import hex.DataInfo;
import hex.deepwater.DeepWaterParameters;
import hex.genmodel.algos.deepwater.DeepwaterMojoModel;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.FilenameFilter;
import java.io.IOException;
import java.util.Arrays;
import javassist.compiler.TokenId;
import org.apache.commons.lang.StringUtils;
import water.H2O;
import water.Iced;
import water.Key;
import water.exceptions.H2OIllegalArgumentException;
import water.util.Log;
import water.util.PrettyPrint;
import water.util.TwoDimTable;

/* loaded from: input_file:hex/deepwater/DeepWaterModelInfo.class */
public final class DeepWaterModelInfo extends Iced {
    // Number of output classes passed to the constructor; a value > 1 marks the model as classification.
    private int _classes;
    // Bytes of the file the backend writes via saveModel() (filled in by nativeToJava()).
    byte[] _network;
    // Bytes of the file the backend writes via saveParam() (filled in by nativeToJava()).
    byte[] _modelparams;
    // Last table built by createSummaryTable(); rendered by toString().
    private TwoDimTable summaryTable;
    // Native backend handle; transient, re-created by javaToNative() when it is null.
    transient BackendTrain _backend;
    // Input dimensions handed to ImageDataSet; the constructor derives them from the problem type.
    int _height;
    int _width;
    int _channels;
    // Mean data taken from the ImageDataSet in setupNativeBackend().
    float[] _meanData;
    // Not assigned anywhere in this file; when non-null, createSummaryTable() uses fullN() as the input-neuron count.
    DataInfo _dataInfo;
    // Private clone of the user-supplied parameters (made in the constructor).
    public DeepWaterParameters parameters;
    // Sample counters; all access in this class goes through the synchronized accessors below.
    private long processed_global;
    private long processed_local;
    // True when _classes > 1 (fixed at construction time).
    private final boolean _classification;
    // Decompiled form of the assertion-status flag; set in the static initializer of this class.
    static final /* synthetic */ boolean $assertionsDisabled;
    // Per-thread native model handle; transient, lazily re-created by getModel() if null.
    transient ThreadLocal<BackendModel> _model = new ThreadLocal<>();
    // Not read or written elsewhere in this file. NOTE(review): presumably set by the training code to flag divergence - confirm.
    volatile boolean _unstable = false;

    /**
     * Returns the per-thread holder of the native model handle, creating the
     * holder on demand when the transient field is {@code null}.
     *
     * @return the (never {@code null}) thread-local model holder
     */
    public ThreadLocal<BackendModel> getModel() {
        ThreadLocal<BackendModel> holder = _model;
        if (holder == null) {
            holder = new ThreadLocal<>();
            _model = holder;
        }
        return holder;
    }

    /**
     * Releases the calling thread's native model (if a backend and a model are
     * present) and clears the thread-local slot.
     *
     * <p>Decompiler note: visibility was widened from package-private.
     */
    public void nukeModel() {
        if (_backend != null) {
            // getModel() never returns null, so only the stored handle needs checking.
            final BackendModel current = getModel().get();
            if (current != null) {
                _backend.delete(current);
            }
        }
        getModel().set(null);
    }

    /**
     * Drops the calling thread's native model via {@link #nukeModel()} and then
     * forgets the backend reference.
     *
     * <p>Decompiler note: visibility was widened from package-private.
     */
    public void nukeBackend() {
        nukeModel();
        _backend = null;
    }

    /**
     * Asks the backend to write the current thread's model to
     * {@code <str>.json} and its parameters to {@code <str>.<i>.params}.
     *
     * <p>Decompiler note: visibility was widened from package-private.
     *
     * @param str path prefix for both output files
     * @param i   number embedded in the parameter file name
     */
    public void saveNativeState(String str, int i) {
        if (!$assertionsDisabled) {
            if (_backend == null) {
                throw new AssertionError();
            }
            if (getModel() == null) {
                throw new AssertionError();
            }
        }
        final String networkPath = str + ".json";
        _backend.saveModel(getModel().get(), networkPath);
        final String paramsPath = str + "." + i + ".params";
        _backend.saveParam(getModel().get(), paramsPath);
    }

    /**
     * Delegates to the backend's {@code predict} with the calling thread's model.
     *
     * <p>Decompiler note: visibility was widened from package-private.
     *
     * @param fArr input values forwarded unchanged to the backend
     * @return whatever the backend returns for this input
     */
    public float[] predict(float[] fArr) {
        if (!$assertionsDisabled) {
            if (_backend == null) {
                throw new AssertionError();
            }
            if (getModel() == null) {
                throw new AssertionError();
            }
        }
        return _backend.predict(getModel().get(), fArr);
    }

    /**
     * Delegates to the backend's {@code extractLayer} with the calling thread's model.
     *
     * <p>Decompiler note: visibility was widened from package-private.
     *
     * @param str  layer identifier forwarded to the backend
     * @param fArr input values forwarded unchanged to the backend
     * @return whatever the backend returns for this layer and input
     */
    public float[] extractLayer(String str, float[] fArr) {
        if (!$assertionsDisabled) {
            if (_backend == null) {
                throw new AssertionError();
            }
            if (getModel() == null) {
                throw new AssertionError();
            }
        }
        return _backend.extractLayer(getModel().get(), str, fArr);
    }

    /**
     * Delegates to the backend's {@code listAllLayers} with the calling thread's model.
     *
     * <p>Decompiler note: visibility was widened from package-private.
     *
     * @return the backend's layer listing
     */
    public String listAllLayers() {
        if (!$assertionsDisabled) {
            if (_backend == null) {
                throw new AssertionError();
            }
            if (getModel() == null) {
                throw new AssertionError();
            }
        }
        return _backend.listAllLayers(getModel().get());
    }

    /**
     * Hash over the serialized state: the sum of the array hashes of the
     * network bytes and the parameter bytes.
     */
    @Override
    public int hashCode() {
        final int networkHash = Arrays.hashCode(_network);
        final int paramsHash = Arrays.hashCode(_modelparams);
        return networkHash + paramsHash;
    }

    /**
     * Total number of bytes held in the serialized network and parameter
     * arrays; an absent array counts as zero.
     *
     * @return combined length of {@code _network} and {@code _modelparams}
     */
    public long size() {
        final long networkBytes = (_network == null) ? 0L : _network.length;
        final long paramBytes = (_modelparams == null) ? 0L : _modelparams.length;
        return networkBytes + paramBytes;
    }

    /**
     * @return the parameter object held by this model info (the clone made in the constructor)
     */
    public final DeepWaterParameters get_params() {
        return parameters;
    }

    /**
     * Reads the global sample counter under this object's monitor.
     *
     * <p>Decompiler note: visibility was widened from package-private.
     *
     * @return current value of {@code processed_global}
     */
    public synchronized long get_processed_global() {
        final long current = processed_global;
        return current;
    }

    /**
     * Overwrites the global sample counter under this object's monitor.
     *
     * <p>Decompiler note: visibility was widened from package-private.
     *
     * @param j new value of {@code processed_global}
     */
    public synchronized void set_processed_global(long j) {
        processed_global = j;
    }

    /**
     * Adds {@code j} to the global sample counter under this object's monitor.
     *
     * <p>Decompiler note: visibility was widened from package-private.
     *
     * @param j increment (may be negative; no range check is performed)
     */
    public synchronized void add_processed_global(long j) {
        processed_global = processed_global + j;
    }

    /**
     * Reads the local sample counter under this object's monitor.
     *
     * <p>Decompiler note: visibility was widened from package-private.
     *
     * @return current value of {@code processed_local}
     */
    public synchronized long get_processed_local() {
        final long current = processed_local;
        return current;
    }

    /**
     * Overwrites the local sample counter under this object's monitor.
     *
     * <p>Decompiler note: visibility was widened from package-private.
     *
     * @param j new value of {@code processed_local}
     */
    public synchronized void set_processed_local(long j) {
        processed_local = j;
    }

    /**
     * Adds {@code j} to the local sample counter under this object's monitor.
     *
     * <p>Decompiler note: visibility was widened from package-private.
     *
     * @param j increment (may be negative; no range check is performed)
     */
    public synchronized void add_processed_local(long j) {
        processed_local = processed_local + j;
    }

    /**
     * Returns the sum of the global and local sample counters, read together
     * under this object's monitor.
     *
     * <p>Decompiler note: visibility was widened from package-private.
     *
     * @return {@code processed_global + processed_local}
     */
    public synchronized long get_processed_total() {
        final long global = processed_global;
        final long local = processed_local;
        return global + local;
    }

    /**
     * Builds the backend runtime options (seed, GPU flag, device id) from the
     * model parameters. The seed is narrowed to {@code int} by a plain cast.
     */
    private RuntimeOptions getRuntimeOptions() {
        final DeepWaterParameters userParams = get_params();
        final RuntimeOptions options = new RuntimeOptions();
        final int seed = (int) userParams.getOrMakeRealSeed();
        options.setSeed(seed);
        options.setUseGPU(userParams._gpu);
        options.setDeviceID(userParams._device_id);
        return options;
    }

    /**
     * Builds the backend parameter bag. Mini-batch size and gradient clipping
     * are always set; when no network is configured, the activation, hidden
     * layer and dropout settings are added as well.
     *
     * <p>Only activations whose name starts with "Rectifier" or "Tanh" are
     * mapped (to "relu" / "tanh"); anything else ends in {@code H2O.unimpl()}.
     */
    private BackendParams getBackendParams() {
        final BackendParams result = new BackendParams();
        result.set("mini_batch_size", Integer.valueOf(get_params()._mini_batch_size));
        result.set("clip_gradient", Double.valueOf(get_params()._clip_gradient));

        final String networkName = (parameters._network != null) ? parameters._network.toString() : null;
        if (networkName != null) {
            // A named network is configured: the layer settings below do not apply.
            return result;
        }

        if (!$assertionsDisabled) {
            if (parameters._activation == null) {
                throw new AssertionError();
            }
            if (parameters._hidden == null) {
                throw new AssertionError();
            }
        }

        // One activation entry per hidden layer, all set to the same backend name.
        final String[] activations = new String[parameters._hidden.length];
        final String activationName = parameters._activation.toString();
        final String backendActivation;
        if (activationName.startsWith("Rectifier")) {
            backendActivation = "relu";
        } else if (activationName.startsWith("Tanh")) {
            backendActivation = "tanh";
        } else {
            throw H2O.unimpl();
        }
        Arrays.fill(activations, backendActivation);

        result.set("activations", activations);
        result.set(H2ODeepLearningUtils.HIDDEN, parameters._hidden);
        result.set("input_dropout_ratio", Double.valueOf(parameters._input_dropout_ratio));
        result.set("hidden_dropout_ratios", parameters._hidden_dropout_ratios);
        return result;
    }

    /**
     * Creates a fresh {@link ImageDataSet} describing the current input
     * dimensions (width, height, channels) and the number of classes.
     */
    private ImageDataSet getImageDataSet() {
        final ImageDataSet dataSet = new ImageDataSet(_width, _height, _channels, _classes);
        return dataSet;
    }

    /**
     * Creates the model info, derives the input dimensions from the problem
     * type and initializes the native backend.
     *
     * <ul>
     *   <li>{@code image}: width/height come from {@code _image_shape}; when
     *       either is 0 a per-network default is used (a user-defined network
     *       must specify them explicitly).</li>
     *   <li>{@code dataset}: width starts as {@code i2}; positive
     *       {@code _image_shape} entries override width/height, and channels
     *       are only taken over when both end up positive.</li>
     *   <li>{@code text}: width is fixed to 56.</li>
     * </ul>
     *
     * <p>Decompiler note: visibility was widened from package-private.
     * NOTE(review): {@code m1957clone()} is a decompiler-renamed method name
     * (presumably {@code clone()}); left untouched here - confirm against the
     * DeepWaterParameters source this file is compiled with.
     *
     * @param deepWaterParameters user parameters; a clone is stored, the argument is not retained
     * @param i  number of classes; values greater than 1 mark the model as classification
     * @param i2 initial width used for non-image problem types
     * @throws H2OIllegalArgumentException if an image model uses a user-defined
     *         network without specifying width and height
     */
    public DeepWaterModelInfo(DeepWaterParameters deepWaterParameters, int i, int i2) {
        this._classes = i;
        this._classification = this._classes > 1;
        this.parameters = (DeepWaterParameters) deepWaterParameters.m1957clone();
        this._width = i2;
        this._height = 0;
        this._channels = 0;
        if (this.parameters._problem_type == DeepWaterParameters.ProblemType.image) {
            this._width = this.parameters._image_shape[0];
            this._height = this.parameters._image_shape[1];
            this._channels = this.parameters._channels;
            if (this._width == 0 || this._height == 0) {
                // No explicit shape: fall back to the default input size of the chosen network.
                switch (this.parameters._network) {
                    case lenet:
                        this._width = 28;
                        this._height = 28;
                        break;
                    case auto:
                    case alexnet:
                    case googlenet:
                    case resnet:
                        this._width = 224;
                        this._height = 224;
                        break;
                    case inception_bn:
                        this._width = 299;
                        this._height = 299;
                        break;
                    case vgg:
                        // Plain literal: the decompiler had substituted the unrelated
                        // javassist TokenId.IF constant, which has the same value (320).
                        this._width = 320;
                        this._height = 320;
                        break;
                    case user:
                        throw new H2OIllegalArgumentException("Please specify width and height for user-given model definition.");
                    default:
                        throw H2O.unimpl("Unknown network type: " + this.parameters._network);
                }
            }
            if (!$assertionsDisabled && this._width <= 0) {
                throw new AssertionError();
            }
            if (!$assertionsDisabled && this._height <= 0) {
                throw new AssertionError();
            }
        } else if (this.parameters._problem_type == DeepWaterParameters.ProblemType.dataset) {
            if (this.parameters._image_shape != null) {
                if (this.parameters._image_shape[0] > 0) {
                    this._width = this.parameters._image_shape[0];
                }
                if (this.parameters._image_shape[1] > 0) {
                    this._height = this.parameters._image_shape[1];
                }
                // Channels only make sense when both dimensions are known.
                if (this._width <= 0 || this._height <= 0) {
                    this._channels = 0;
                } else {
                    this._channels = this.parameters._channels;
                }
            }
        } else if (this.parameters._problem_type == DeepWaterParameters.ProblemType.text) {
            this._width = 56;
        } else {
            Log.warn("unknown problem_type:", this.parameters._problem_type);
            throw H2O.unimpl();
        }
        setupNativeBackend();
    }

    /**
     * Creates the native backend and builds the initial model.
     *
     * <p>Steps, in order: create the backend named by the parameters; build a
     * net for every network type except {@code user} ("MLP" when no network is
     * set); if a network definition file is given, build the net from that file
     * instead; optionally load a mean image; optionally load weights/biases
     * from the parameter file; finally copy the native state to Java via
     * {@link #nativeToJava()}.
     *
     * @throws RuntimeException wrapping any failure; the original throwable is
     *         attached as the cause so its type and stack trace stay available
     */
    private void setupNativeBackend() {
        try {
            _backend = DeepwaterMojoModel.createDeepWaterBackend(parameters._backend.toString());
            if (_backend == null) {
                throw new IllegalArgumentException("No backend found. Cannot build a Deep Water model.");
            }
            final ImageDataSet imageDataSet = getImageDataSet();
            final RuntimeOptions runtimeOptions = getRuntimeOptions();
            final BackendParams backendParams = getBackendParams();

            if (parameters._network != DeepWaterParameters.Network.user) {
                final String network = parameters._network == null ? null : parameters._network.toString();
                // No configured network means a plain multi-layer perceptron.
                final String netName = network != null ? network : "MLP";
                Log.info("Creating a fresh model of the following network type: " + netName);
                getModel().set(_backend.buildNet(imageDataSet, runtimeOptions, backendParams, _classes, netName));
            }

            final String definitionPath = parameters._network_definition_file;
            if (definitionPath != null && !definitionPath.isEmpty()) {
                final File definitionFile = new File(definitionPath);
                if (!definitionFile.exists() || definitionFile.isDirectory()) {
                    throw new RuntimeException("Network definition file " + definitionFile + " not found.");
                }
                Log.info("Loading the network from: " + definitionFile.getAbsolutePath());
                Log.info("Setting the optimizer and initializing the first and last layer.");
                getModel().set(_backend.buildNet(imageDataSet, runtimeOptions, backendParams, _classes, definitionFile.getAbsolutePath()));
            }

            if (parameters._mean_image_file != null && !parameters._mean_image_file.isEmpty()) {
                imageDataSet.setMeanData(_backend.loadMeanImage(getModel().get(), parameters._mean_image_file));
            }
            _meanData = imageDataSet.getMeanData();

            final String paramsPath = parameters._network_parameters_file;
            if (paramsPath == null || paramsPath.isEmpty()) {
                Log.warn("No network parameters file specified. Starting from scratch.");
            } else {
                final File paramsFile = new File(paramsPath);
                if (!paramFilesExist(paramsPath)) {
                    throw new RuntimeException("Network parameter file " + paramsFile + " not found.");
                }
                Log.info("Loading the parameters (weights/biases) from: " + paramsFile.getAbsolutePath());
                if (!$assertionsDisabled && getModel() == null) {
                    throw new AssertionError();
                }
                _backend.loadParam(getModel().get(), paramsFile.getAbsolutePath());
            }
            nativeToJava();
        } catch (Throwable th) {
            // Keep the message format, but chain the cause: dropping it hid the
            // failing frame and the exception type from whoever reads the log.
            throw new RuntimeException("Unable to initialize the native Deep Learning backend: " + th.getMessage(), th);
        }
    }

    /**
     * Checks whether parameter files for the given path are present.
     *
     * <p>Returns {@code true} when the path is an existing non-directory, or
     * when it does not exist itself but its parent directory contains at least
     * one entry whose name contains the path's file name. A directory always
     * yields {@code false}.
     *
     * <p>The parent directory is resolved through the absolute path, so a bare
     * relative name such as {@code "weights.params"} (which has no parent
     * component) is looked up in the working directory instead of failing with
     * a {@code NullPointerException}.
     *
     * <p>Decompiler note: visibility was widened from package-private.
     *
     * @param str path (or path prefix) of the parameter file(s)
     * @return whether a matching file exists
     */
    public static boolean paramFilesExist(String str) {
        final File file = new File(str);
        if (file.isDirectory()) {
            return false;
        }
        if (file.exists()) {
            return true;
        }
        final File parent = file.getAbsoluteFile().getParentFile();
        if (parent == null) {
            return false;
        }
        final String baseName = file.getName();
        final String[] matches = parent.list(new FilenameFilter() {
            @Override
            public boolean accept(File dir, String name) {
                return name.contains(baseName);
            }
        });
        // list() returns null when the parent is not a readable directory.
        return matches != null && matches.length > 0;
    }

    /**
     * @return the directory used for temporary model files: the value of the
     *         {@code java.io.tmpdir} system property
     */
    String getBasePath() {
        final String tmpDir = System.getProperty("java.io.tmpdir");
        return tmpDir;
    }

    /**
     * Copies the native model state into the Java-side byte arrays.
     *
     * <p>Does nothing without a backend. The network definition is captured
     * only while {@code _network} is still {@code null}; the parameters are
     * captured on every call. Both go through a temporary file under
     * {@link #getBasePath()}, which is removed through the backend exactly
     * once, whether or not the step succeeded.
     *
     * <p>I/O failures are reported via {@code printStackTrace()} and otherwise
     * ignored (best effort). A failed network read leaves {@code _network}
     * {@code null}, so the next call retries.
     *
     * <p>Decompiler note: visibility was widened from package-private.
     */
    public void nativeToJava() {
        if (_backend == null) {
            return;
        }
        Log.info("Native backend -> Java.");
        final long start = System.currentTimeMillis();

        if (_network == null) {
            File networkFile = null;
            try {
                networkFile = new File(getBasePath(), Key.make().toString());
                _backend.saveModel(getModel().get(), networkFile.toString());
                // Assign only after a complete read so a failure cannot leave a
                // half-filled array behind.
                _network = readFileFully(networkFile);
            } catch (IOException e) {
                e.printStackTrace();
            } finally {
                if (networkFile != null) {
                    _backend.deleteSavedModel(networkFile.toString());
                }
            }
        }

        File paramFile = null;
        try {
            paramFile = new File(getBasePath(), Key.make().toString());
            _backend.saveParam(getModel().get(), paramFile.toString());
            _modelparams = _backend.readBytes(paramFile);
        } catch (IOException e) {
            e.printStackTrace();
        } finally {
            if (paramFile != null) {
                _backend.deleteSavedParam(paramFile.toString());
            }
        }
        Log.info("Took: " + PrettyPrint.msecs(System.currentTimeMillis() - start, true));
    }

    /** Reads the whole file into memory, looping until every byte has arrived, and always closes the stream. */
    private static byte[] readFileFully(File source) throws IOException {
        final long length = source.length();
        if (length > Integer.MAX_VALUE) {
            throw new IOException("File too large to load into memory: " + source);
        }
        final byte[] contents = new byte[(int) length];
        try (FileInputStream in = new FileInputStream(source)) {
            int offset = 0;
            while (offset < contents.length) {
                // A single read() may return fewer bytes than requested.
                final int count = in.read(contents, offset, contents.length - offset);
                if (count < 0) {
                    throw new IOException("Unexpected end of file after " + offset + " of " + contents.length + " bytes: " + source);
                }
                offset += count;
            }
        }
        return contents;
    }

    /**
     * Pushes the Java-side state to the native backend using the stored
     * network and parameter bytes (both arguments of the two-argument
     * overload are passed as {@code null}).
     *
     * <p>Decompiler note: visibility was widened from package-private.
     */
    public void javaToNative() {
        javaToNative((byte[]) null, (byte[]) null);
    }

    /**
     * Moves serialized state into the native backend.
     *
     * <p>With a live backend, nothing happens when each argument is either
     * {@code null} or equal to the stored bytes. Without a backend, one is
     * created first. {@code null} arguments then fall back to the stored
     * arrays; if either is still {@code null} the method returns silently.
     *
     * @param bArr  network definition bytes, or {@code null} to use {@code _network}
     * @param bArr2 parameter bytes, or {@code null} to use {@code _modelparams}
     * @throws IllegalArgumentException if no backend can be created
     */
    private void javaToNative(byte[] bArr, byte[] bArr2) {
        final long start = System.currentTimeMillis();

        if (_backend == null) {
            _backend = DeepwaterMojoModel.createDeepWaterBackend(get_params()._backend.toString());
            if (_backend == null) {
                throw new IllegalArgumentException("No backend found. Cannot build a Deep Water model.");
            }
        } else {
            final boolean sameNetwork = bArr == null || Arrays.equals(bArr, _network);
            if (sameNetwork && (bArr2 == null || Arrays.equals(bArr2, _modelparams))) {
                Log.warn("No need to move the state from Java to native.");
                return;
            }
        }

        final byte[] network = (bArr != null) ? bArr : _network;
        final byte[] params = (bArr2 != null) ? bArr2 : _modelparams;
        if (network == null || params == null) {
            return;
        }

        Log.info("Java state -> native backend.");
        initModel(network, params);
        Log.info("Took: " + PrettyPrint.msecs(System.currentTimeMillis() - start, true));
    }

    /**
     * Rebuilds the native model from the stored network and parameter bytes.
     *
     * <p>Decompiler note: visibility was widened from package-private.
     */
    public void initModel() {
        final byte[] network = _network;
        final byte[] params = _modelparams;
        initModel(network, params);
    }

    /**
     * Rebuilds the native model from serialized bytes.
     *
     * <p>Phase 1 writes the network definition to a temporary {@code .json}
     * file under {@link #getBasePath()} and lets the backend build the net
     * from it. Phase 2 writes the parameters to a second temporary file and
     * loads them into the model. Each temporary file is removed through the
     * backend exactly once, whether or not its phase succeeded.
     *
     * <p>I/O failures are reported via {@code printStackTrace()} and otherwise
     * ignored (best effort); phase 2 still runs after a failed phase 1.
     *
     * @param bArr  network definition bytes
     * @param bArr2 parameter (weights/biases) bytes
     */
    private void initModel(byte[] bArr, byte[] bArr2) {
        File networkFile = null;
        try {
            networkFile = new File(getBasePath(), Key.make().toString() + ".json");
            // try-with-resources: the stream is closed even if write() fails.
            try (FileOutputStream out = new FileOutputStream(networkFile)) {
                out.write(bArr);
            }
            getModel().set(_backend.buildNet(getImageDataSet(), getRuntimeOptions(), getBackendParams(), _classes, networkFile.toString()));
        } catch (IOException e) {
            e.printStackTrace();
        } finally {
            if (networkFile != null) {
                _backend.deleteSavedModel(networkFile.toString());
            }
        }

        // Separate variable: a failure here must never trigger cleanup of the
        // phase-1 path through deleteSavedParam().
        File paramFile = null;
        try {
            paramFile = new File(getBasePath(), Key.make().toString());
            _backend.writeBytes(paramFile, bArr2);
            _backend.loadParam(getModel().get(), paramFile.toString());
        } catch (IOException e) {
            e.printStackTrace();
        } finally {
            if (paramFile != null) {
                _backend.deleteSavedParam(paramFile.toString());
            }
        }
    }

    /**
     * Builds a one-row status table (input neurons, rate, momentum), stores it
     * in {@code summaryTable} and returns it.
     *
     * <p>The table description lists the network (or "MLP: " plus the hidden
     * layout), the serialized size, the prediction target and task kind, the
     * number of globally processed samples and the mini-batch size.
     *
     * <p>Decompiler note: visibility was widened from package-private.
     *
     * @return the freshly built table
     */
    public TwoDimTable createSummaryTable() {
        final StringBuilder description = new StringBuilder();

        // Network name, or the hidden-layer layout for a plain MLP.
        if (get_params()._network == null) {
            description.append("MLP: " + Arrays.toString(get_params()._hidden));
        } else {
            description.append(get_params()._network.toString());
        }
        description.append(", ").append(PrettyPrint.bytes(size())).append(", ");

        // Prediction target (omitted for auto-encoders), followed by the task kind.
        description.append(!get_params()._autoencoder ? "predicting " + get_params()._response_column + ", " : StringUtils.EMPTY);
        if (get_params()._autoencoder) {
            description.append("auto-encoder");
        } else if (_classification) {
            description.append(_classes + "-class classification");
        } else {
            description.append("regression");
        }

        description.append(", ")
                .append(String.format("%,d", Long.valueOf(get_processed_global())))
                .append(" training samples, mini-batch size ")
                .append(String.format("%,d", Integer.valueOf(get_params()._mini_batch_size)));

        final String[] rowHeaders = new String[1];
        final String[] columnHeaders = {"Input Neurons", "Rate", "Momentum"};
        final String[] columnTypes = {"int", "double", "double"};
        final String[] columnFormats = {"%d", "%5f", "%5f"};
        final TwoDimTable table = new TwoDimTable("Status of Deep Learning Model", description.toString(), rowHeaders, columnHeaders, columnTypes, columnFormats, StringUtils.EMPTY);

        table.set(0, 0, Integer.valueOf(_dataInfo != null ? _dataInfo.fullN() : _width * _height * _channels));
        table.set(0, 1, Float.valueOf(get_params().learningRate(get_processed_global())));
        table.set(0, 2, Float.valueOf(get_params().momentum(get_processed_global())));

        summaryTable = table;
        return summaryTable;
    }

    /**
     * Renders the summary table, or an empty string in quiet mode. The table
     * is rebuilt on every call.
     */
    @Override
    public String toString() {
        final StringBuilder out = new StringBuilder();
        final boolean quiet = get_params()._quiet_mode;
        if (quiet) {
            return out.toString();
        }
        createSummaryTable();
        if (summaryTable != null) {
            out.append(summaryTable.toString(1));
        }
        return out.toString();
    }

    /**
     * Returns {@link #toString()} followed by one line each for the global,
     * local and total processed-sample counters.
     */
    public String toStringAll() {
        return toString()
                + "\nprocessed global: " + get_processed_global()
                + "\nprocessed local:  " + get_processed_local()
                + "\nprocessed total:  " + get_processed_total()
                + "\n";
    }

    /**
     * Not implemented: always throws the exception produced by {@code H2O.unimpl()}.
     *
     * @param deepWaterModelInfo ignored
     */
    public void add(DeepWaterModelInfo deepWaterModelInfo) {
        throw H2O.unimpl();
    }

    /**
     * Not implemented: always throws the exception produced by {@code H2O.unimpl()}.
     *
     * @param d ignored
     */
    public void mult(double d) {
        throw H2O.unimpl();
    }

    /**
     * Not implemented: always throws the exception produced by {@code H2O.unimpl()}.
     *
     * @param d ignored
     */
    public void div(double d) {
        throw H2O.unimpl();
    }

    // Decompiled assertion bookkeeping: the flag is true when assertions are NOT
    // enabled for this class, which is what the "!$assertionsDisabled && ..."
    // checks throughout the class test before throwing AssertionError.
    static {
        $assertionsDisabled = !DeepWaterModelInfo.class.desiredAssertionStatus();
    }
}
