package de.datexis.tagger;

import de.datexis.annotator.AnnotatorComponent;
import de.datexis.common.Resource;
import de.datexis.encoder.Encoder;
import de.datexis.model.Dataset;
import de.datexis.model.Document;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.commons.io.IOUtils;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.shade.jackson.annotation.JsonIgnore;
import org.nd4j.shade.jackson.databind.JsonNode;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:de/datexis/tagger/Tagger.class */
public abstract class Tagger extends AnnotatorComponent {
    protected static final Logger log = LoggerFactory.getLogger(Tagger.class);
    protected long inputVectorSize;
    protected long embeddingVectorSize;
    protected long outputVectorSize;
    private List<Encoder> encoders;
    protected int batchSize;
    protected int maxTimeSeriesLength;
    protected int numExamples;
    protected int numEpochs;
    protected boolean randomize;
    protected int embeddingLayerSize;
    protected Model net;

    public Tagger(String str) {
        super(false);
        this.encoders = new ArrayList();
        this.batchSize = 16;
        this.maxTimeSeriesLength = -1;
        this.numExamples = -1;
        this.numEpochs = 1;
        this.randomize = true;
        this.id = str;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    public Tagger(long j, long j2) {
        super(false);
        this.encoders = new ArrayList();
        this.batchSize = 16;
        this.maxTimeSeriesLength = -1;
        this.numExamples = -1;
        this.numEpochs = 1;
        this.randomize = true;
        this.inputVectorSize = j;
        this.outputVectorSize = j2;
    }

    protected Tagger(Resource resource) {
        super(false);
        this.encoders = new ArrayList();
        this.batchSize = 16;
        this.maxTimeSeriesLength = -1;
        this.numExamples = -1;
        this.numEpochs = 1;
        this.randomize = true;
        loadModel(resource);
    }

    public int getBatchSize() {
        return this.batchSize;
    }

    public void setBatchSize(int i) {
        this.batchSize = i;
    }

    public int getNumEpochs() {
        return this.numEpochs;
    }

    public void setNumEpochs(int i) {
        this.numEpochs = i;
    }

    public boolean isRandomize() {
        return this.randomize;
    }

    public void setRandomize(boolean z) {
        this.randomize = z;
    }

    public int getMaxTimeSeriesLength() {
        return this.maxTimeSeriesLength;
    }

    public void setMaxTimeSeriesLength(int i) {
        this.maxTimeSeriesLength = i;
    }

    public int getEmbeddingLayerSize() {
        return this.embeddingLayerSize;
    }

    public void setEmbeddingLayerSize(int i) {
        this.embeddingLayerSize = i;
    }

    @JsonIgnore
    public Model getNN() {
        return this.net;
    }

    @Override // de.datexis.annotator.IComponent
    public void setEncoders(List<Encoder> list) {
        this.encoders = list;
        long j = 0;
        Iterator<Encoder> it = list.iterator();
        while (it.hasNext()) {
            j += it.next().getEmbeddingVectorSize();
        }
        this.inputVectorSize = j;
    }

    @Override // de.datexis.annotator.IComponent
    @JsonIgnore
    public List<Encoder> getEncoders() {
        return this.encoders;
    }

    public ComputationGraphConfiguration getGraphConfiguration() {
        if (this.net != null && (this.net instanceof ComputationGraph)) {
            return this.net.getConfiguration();
        }
        return null;
    }

    public void setGraphConfiguration(JsonNode jsonNode) {
        String jsonNode2;
        if (jsonNode == null || (jsonNode2 = jsonNode.toString()) == null || jsonNode2.equals("null")) {
            return;
        }
        this.net = new ComputationGraph(ComputationGraphConfiguration.fromJson(jsonNode2));
        this.net.init();
    }

    public MultiLayerConfiguration getLayerConfiguration() {
        if (this.net != null && (this.net instanceof MultiLayerNetwork)) {
            return this.net.getLayerWiseConfigurations();
        }
        return null;
    }

    public void setLayerConfiguration(JsonNode jsonNode) {
        String jsonNode2;
        if (jsonNode == null || (jsonNode2 = jsonNode.toString()) == null || jsonNode2.equals("null")) {
            return;
        }
        this.net = new MultiLayerNetwork(MultiLayerConfiguration.fromJson(jsonNode2));
        this.net.init();
    }

    public void setListeners(IterationListener... iterationListenerArr) {
        this.net.setListeners(iterationListenerArr);
    }

    @Override // de.datexis.annotator.AnnotatorComponent
    @JsonIgnore
    public boolean isModelAvailableInChildren() {
        return this.encoders.stream().allMatch(encoder -> {
            return encoder.isModelAvailable();
        });
    }

    public void setTrainingParams(int i, int i2, int i3, int i4, boolean z) {
        this.numExamples = i;
        this.maxTimeSeriesLength = i2;
        this.batchSize = i3;
        this.numEpochs = i4;
        this.randomize = z;
    }

    @Override // de.datexis.annotator.IComponent
    public void saveModel(Resource resource, String str) {
        if (this.net instanceof ComputationGraph) {
            Resource resolve = resource.resolve(str + ".zip");
            try {
                OutputStream outputStream = resolve.getOutputStream();
                Throwable th = null;
                try {
                    try {
                        ModelSerializer.writeModel(this.net, outputStream, true);
                        setModel(resolve);
                        if (outputStream != null) {
                            if (0 != 0) {
                                try {
                                    outputStream.close();
                                } catch (Throwable th2) {
                                    th.addSuppressed(th2);
                                }
                            } else {
                                outputStream.close();
                            }
                        }
                        return;
                    } catch (Throwable th3) {
                        th = th3;
                        throw th3;
                    }
                } finally {
                }
            } catch (IOException e) {
                log.error(e.toString());
                return;
            }
        }
        if (this.net instanceof MultiLayerNetwork) {
            Resource resolve2 = resource.resolve(str + ".bin.gz");
            try {
                DataOutputStream dataOutputStream = new DataOutputStream(resolve2.getGZIPOutputStream());
                Throwable th4 = null;
                try {
                    try {
                        Nd4j.write(this.net.params(), dataOutputStream);
                        dataOutputStream.flush();
                        setModel(resolve2);
                        if (dataOutputStream != null) {
                            if (0 != 0) {
                                try {
                                    dataOutputStream.close();
                                } catch (Throwable th5) {
                                    th4.addSuppressed(th5);
                                }
                            } else {
                                dataOutputStream.close();
                            }
                        }
                    } catch (Throwable th6) {
                        th4 = th6;
                        throw th6;
                    }
                } finally {
                }
            } catch (IOException e2) {
                log.error(e2.toString());
            }
        }
    }

    @Override // de.datexis.annotator.IComponent
    public void loadModel(Resource resource) {
        if (resource.getFileName().endsWith("zip")) {
            try {
                InputStream inputStream = resource.getInputStream();
                Throwable th = null;
                try {
                    try {
                        this.net = ModelSerializer.restoreComputationGraph(inputStream, true);
                        setModel(resource);
                        setModelAvailable(true);
                        log.info("loaded ComputationGraph from " + resource.getFileName());
                        if (inputStream != null) {
                            if (0 != 0) {
                                try {
                                    inputStream.close();
                                } catch (Throwable th2) {
                                    th.addSuppressed(th2);
                                }
                            } else {
                                inputStream.close();
                            }
                        }
                        return;
                    } finally {
                    }
                } catch (Throwable th3) {
                    th = th3;
                    throw th3;
                }
            } catch (IOException e) {
                log.error(e.toString());
                return;
            }
        }
        try {
            DataInputStream dataInputStream = new DataInputStream(resource.getInputStream());
            Throwable th4 = null;
            try {
                try {
                    this.net.setParameters(Nd4j.read(dataInputStream));
                    setModel(resource);
                    setModelAvailable(true);
                    log.info("loaded MultiLayerNetwork from " + resource.getFileName());
                    if (dataInputStream != null) {
                        if (0 != 0) {
                            try {
                                dataInputStream.close();
                            } catch (Throwable th5) {
                                th4.addSuppressed(th5);
                            }
                        } else {
                            dataInputStream.close();
                        }
                    }
                } catch (Throwable th6) {
                    th4 = th6;
                    throw th6;
                }
            } finally {
            }
        } catch (IOException e2) {
            log.error(e2.toString());
        }
    }

    @Deprecated
    public void saveUpdater(Resource resource, String str) {
        Resource resolve = resource.resolve(str + ".bin.gz");
        INDArray iNDArray = null;
        if (this.net instanceof MultiLayerNetwork) {
            iNDArray = this.net.getUpdater().getStateViewArray();
        } else if (this.net instanceof ComputationGraph) {
            iNDArray = this.net.getUpdater().getStateViewArray();
        }
        if (iNDArray != null) {
            try {
                DataOutputStream dataOutputStream = new DataOutputStream(resolve.getGZIPOutputStream());
                Throwable th = null;
                try {
                    Nd4j.write(iNDArray, dataOutputStream);
                    dataOutputStream.flush();
                    if (dataOutputStream != null) {
                        if (0 != 0) {
                            try {
                                dataOutputStream.close();
                            } catch (Throwable th2) {
                                th.addSuppressed(th2);
                            }
                        } else {
                            dataOutputStream.close();
                        }
                    }
                } finally {
                }
            } catch (IOException e) {
                log.error(e.toString());
            }
        }
    }

    public void loadConf(Resource resource) {
        try {
            this.net = new MultiLayerNetwork(MultiLayerConfiguration.fromJson(IOUtils.toString(resource.getInputStream())));
            this.net.init();
        } catch (IOException e) {
            log.error(e.toString());
        }
    }

    public void trainModel(Dataset dataset) {
        throw new UnsupportedOperationException("Training not implemented");
    }

    public void testModel(Dataset dataset) {
        throw new UnsupportedOperationException("Testing not implemented");
    }

    @Deprecated
    public void tag(Stream<Document> stream) {
        tag((Collection<Document>) stream.collect(Collectors.toList()));
    }

    @Deprecated
    public abstract void tag(Collection<Document> collection);
}
