package de.datexis.encoder.impl;

import com.google.common.hash.Funnels;
import de.datexis.common.Resource;
import de.datexis.common.WordHelpers;
import de.datexis.hash.BitArrayBloomFilter;
import de.datexis.hash.BitArrayBloomFilterStrategy;
import de.datexis.model.Document;
import de.datexis.model.Span;
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.charset.Charset;
import java.nio.file.Files;
import java.nio.file.StandardCopyOption;
import java.util.Collection;
import java.util.Iterator;
import java.util.zip.ZipEntry;
import java.util.zip.ZipFile;
import java.util.zip.ZipOutputStream;
import org.apache.commons.io.output.CloseShieldOutputStream;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:de/datexis/encoder/impl/BloomEncoder.class */
/**
 * Bag-of-words encoder that represents text by the bit arrays a Bloom filter
 * produces for each (preprocessed) token.
 *
 * <p>The vocabulary handling is inherited from {@link BagOfWordsEncoder}; this class
 * additionally maintains a {@link BitArrayBloomFilter} and persists both parts in a
 * single ZIP archive with the entries {@code vocab.tsv} and {@code bloom.bin}.
 *
 * <p>NOTE(review): tokens are hashed through
 * {@code Funnels.stringFunnel(Charset.defaultCharset())}, i.e. with the platform default
 * charset. A model saved on one machine and loaded on another with a different default
 * charset may hash non-ASCII tokens differently. Consider pinning a fixed charset once
 * compatibility with existing models has been checked.
 */
public class BloomEncoder extends BagOfWordsEncoder {

    /** Identifier used when the caller does not supply one. */
    private static final String DEFAULT_ID = "BLM";

    /** Second argument of {@code BitArrayBloomFilter.create} used by the no-arg constructor. */
    private static final int DEFAULT_BIT_SIZE = 4096;

    /** Third argument of {@code BitArrayBloomFilter.create} used by the no-arg constructor. */
    private static final int DEFAULT_HASH_FUNCTIONS = 5;

    /** Name of the archive entry that holds the vocabulary written by the superclass. */
    private static final String VOCAB_ENTRY = "vocab.tsv";

    /** Name of the archive entry that holds the serialized Bloom filter. */
    private static final String BLOOM_ENTRY = "bloom.bin";

    /** Buffer size in bytes used when copying temp files into the archive. */
    private static final int COPY_BUFFER_SIZE = 8192;

    /** The Bloom filter that maps tokens to bit arrays. */
    protected BitArrayBloomFilter<CharSequence> bloom;

    /** Creates an encoder with the default id and default filter parameters. */
    public BloomEncoder() {
        this(DEFAULT_ID, DEFAULT_BIT_SIZE, DEFAULT_HASH_FUNCTIONS);
    }

    /**
     * Creates an encoder with the default id.
     *
     * @param bitSize       see {@link #BloomEncoder(String, int, int)}
     * @param hashFunctions see {@link #BloomEncoder(String, int, int)}
     */
    public BloomEncoder(int bitSize, int hashFunctions) {
        this(DEFAULT_ID, bitSize, hashFunctions);
    }

    /**
     * Creates an encoder with an empty Bloom filter.
     *
     * @param id            identifier passed to the superclass constructor
     * @param bitSize       second argument of {@code BitArrayBloomFilter.create};
     *                      NOTE(review): assumed to be the filter size in bits - confirm
     * @param hashFunctions third argument of {@code BitArrayBloomFilter.create};
     *                      NOTE(review): assumed to be the number of hash functions - confirm
     */
    public BloomEncoder(String id, int bitSize, int hashFunctions) {
        super(id);
        this.log = LoggerFactory.getLogger(BloomEncoder.class);
        this.bloom = BitArrayBloomFilter.create(
                Funnels.stringFunnel(Charset.defaultCharset()),
                bitSize,
                hashFunctions,
                (BitArrayBloomFilter.Strategy) new BitArrayBloomFilterStrategy());
    }

    /** @return the human-readable name of this component */
    @Override
    public String getName() {
        return "Bloom Filter Encoder";
    }

    /**
     * Trains on the given documents by delegating to
     * {@link #trainModel(Collection, int, WordHelpers.Language)} with the value {@code 1}
     * and {@link WordHelpers.Language#EN}.
     *
     * @param documents the training documents
     */
    @Override
    public void trainModel(Collection<Document> documents) {
        trainModel(documents, 1, WordHelpers.Language.EN);
    }

    /**
     * Trains the inherited vocabulary, then inserts every vocabulary word into the
     * Bloom filter and appends a summary line to the training log.
     *
     * @param documents        the training documents, forwarded to the superclass
     * @param minWordFrequency forwarded unchanged to the superclass;
     *                         NOTE(review): name assumed, semantics are defined by
     *                         {@link BagOfWordsEncoder} - confirm
     * @param language         forwarded unchanged to the superclass
     */
    @Override
    public void trainModel(Collection<Document> documents, int minWordFrequency, WordHelpers.Language language) {
        super.trainModel(documents, minWordFrequency, language);
        for (String word : getWords()) {
            this.bloom.put(word);
        }
        long bits = this.bloom.bitSize();
        long words = this.vocab.numWords();
        // Guard the integer division: an empty vocabulary would otherwise throw ArithmeticException.
        String ratio = words > 0 ? String.valueOf(bits / words) : "n/a";
        appendTrainLog("trained Bloom filter over " + words + " words into " + bits
                + " bits (ratio: " + ratio + ")");
    }

    /** @return the Bloom filter's bit size, which is the length of every encoded vector */
    @Override
    public long getEmbeddingVectorSize() {
        return this.bloom.bitSize();
    }

    /**
     * Encodes a sequence of spans as the sum of their Bloom filter bit arrays.
     *
     * @param spans the spans to encode; each span's text is preprocessed before hashing
     * @return an array of shape {@code (getEmbeddingVectorSize(), 1)}; all zeros when
     *         {@code spans} is empty
     */
    @Override
    public INDArray encode(Iterable<? extends Span> spans) {
        INDArray sum = Nd4j.zeros(getEmbeddingVectorSize(), 1L);
        for (Span span : spans) {
            sum.addi(Nd4j.create(this.bloom.getBitArray(preprocessor.preProcess(span.getText()))).transposei());
        }
        return sum;
    }

    /**
     * Encodes an array of words as the sum of their Bloom filter bit arrays.
     *
     * @param words the words to encode; each word is preprocessed before hashing
     * @return an array of shape {@code (getEmbeddingVectorSize(), 1)}; all zeros when
     *         {@code words} is empty
     */
    @Override
    public INDArray encode(String[] words) {
        INDArray sum = Nd4j.zeros(getEmbeddingVectorSize(), 1L);
        for (String word : words) {
            sum.addi(Nd4j.create(this.bloom.getBitArray(preprocessor.preProcess(word))).transposei());
        }
        return sum;
    }

    /**
     * Copies all remaining bytes of {@code inputStream} into the current entry of
     * {@code zipOutputStream}. Neither stream is closed.
     *
     * @throws IOException if reading or writing fails
     */
    private static void writeEntry(InputStream inputStream, ZipOutputStream zipOutputStream) throws IOException {
        byte[] buffer = new byte[COPY_BUFFER_SIZE];
        int read;
        while ((read = inputStream.read(buffer)) != -1) {
            zipOutputStream.write(buffer, 0, read);
        }
    }

    /**
     * Extracts one named entry of {@code zip} into {@code target}, replacing an existing file.
     *
     * @throws IOException if the entry is missing or cannot be copied
     */
    private static void extractEntry(ZipFile zip, String entryName, Resource target) throws IOException {
        ZipEntry entry = zip.getEntry(entryName);
        if (entry == null) {
            // Without this check getInputStream(null) fails with an NPE that the callers' IOException handler misses.
            throw new IOException("model archive " + zip.getName() + " has no entry '" + entryName + "'");
        }
        try (InputStream in = zip.getInputStream(entry)) {
            Files.copy(in, target.getPath(), StandardCopyOption.REPLACE_EXISTING);
        }
    }

    /**
     * Saves vocabulary and Bloom filter into {@code <name>.zip} below {@code resource}.
     *
     * <p>Both parts are first written to a temporary directory and then copied into the
     * archive, so a failure while producing them does not leave a half-written archive.
     * The model is only marked as available after the archive has been closed.
     * I/O failures are logged and not rethrown.
     *
     * @param resource the directory to save into
     * @param name     the archive file name without the {@code .zip} extension
     */
    @Override
    public void saveModel(Resource resource, String name) {
        Resource archive = resource.resolve(name + ".zip");
        try {
            Resource tempDir = Resource.createTempDirectory();

            // The superclass writes the vocabulary; the file read back below is "vocab.tsv.gz".
            super.saveModel(tempDir, "vocab");
            try (OutputStream bloomOut = tempDir.resolve(BLOOM_ENTRY).getOutputStream()) {
                this.bloom.writeTo(bloomOut);
                bloomOut.flush();
            }

            try (OutputStream out = archive.getOutputStream();
                 // The shield keeps closing the zip stream from also closing 'out', which is closed by this try itself.
                 ZipOutputStream zip = new ZipOutputStream(new BufferedOutputStream(new CloseShieldOutputStream(out)))) {
                // NOTE(review): the entry is named "vocab.tsv" but is filled from "vocab.tsv.gz";
                // kept as-is so existing archives stay readable - confirm this is intended.
                zip.putNextEntry(new ZipEntry(VOCAB_ENTRY));
                try (InputStream vocabIn = new BufferedInputStream(tempDir.resolve("vocab.tsv.gz").getInputStream())) {
                    writeEntry(vocabIn, zip);
                }
                zip.closeEntry();

                zip.putNextEntry(new ZipEntry(BLOOM_ENTRY));
                try (InputStream bloomIn = new BufferedInputStream(tempDir.resolve(BLOOM_ENTRY).getInputStream())) {
                    writeEntry(bloomIn, zip);
                }
                zip.closeEntry();
            }

            setModel(archive);
            setModelAvailable(true);
            this.log.info("saved bloom filter");
        } catch (IOException e) {
            // Pass the exception itself so the stack trace is not lost.
            this.log.error("could not save bloom filter: " + e, e);
        }
    }

    /**
     * Loads vocabulary and Bloom filter from a ZIP archive written by
     * {@link #saveModel(Resource, String)}.
     *
     * <p>Both entries are extracted to a temporary directory before anything is loaded,
     * so an incomplete archive is rejected without touching the current state.
     * I/O failures (including missing entries) are logged and not rethrown.
     *
     * @param resource the ZIP archive to load
     */
    @Override
    public void loadModel(Resource resource) {
        try {
            Resource tempDir = Resource.createTempDirectory();
            Resource vocabFile = tempDir.resolve(VOCAB_ENTRY);
            Resource bloomFile = tempDir.resolve(BLOOM_ENTRY);

            try (ZipFile zip = new ZipFile(resource.toFile())) {
                extractEntry(zip, VOCAB_ENTRY, vocabFile);
                extractEntry(zip, BLOOM_ENTRY, bloomFile);
            }

            super.loadModel(vocabFile);
            try (InputStream bloomIn = bloomFile.getInputStream()) {
                this.bloom = BitArrayBloomFilter.readFrom(
                        bloomIn,
                        Funnels.stringFunnel(Charset.defaultCharset()),
                        new BitArrayBloomFilterStrategy());
            }

            // Point the model at the archive rather than at the extracted temp file.
            setModel(resource);
            setModelAvailable(true);
            this.log.info("loaded bloom filter with size " + getEmbeddingVectorSize());
        } catch (IOException e) {
            // Pass the exception itself so the stack trace is not lost.
            this.log.error("could not load bloom filter: " + e, e);
        }
    }
}
