package com.eshore.kg.qa.extract;

import com.eshore.framework.OnEnable;
import com.eshore.framework.StandardComponent;
import com.eshore.framework.StandardProperty;
import com.eshore.tensorflow.BertTokenizer;
import com.eshore.tensorflow.TensorFlowInferenceInterface;
import com.eshore.utils.CallbackException;
import java.io.IOException;
import java.io.InputStream;
import java.nio.IntBuffer;
import java.util.Iterator;
import java.util.function.Consumer;
import org.tensorflow.DataType;
import org.tensorflow.Graph;
import org.tensorflow.Operation;
import org.tensorflow.Output;
import org.tensorflow.Tensor;

/**
 * Produces a summary of a piece of text with a BERT-based TensorFlow graph.
 *
 * <p>The text is normalized and tokenized by {@link BertTokenizer}. The resulting input ids and
 * segment ids are fed to the graph through the {@code inputIdLayer} and {@code segmentIdLayer}
 * operations. The token ids fetched from {@code finalOutputLayer} are then decoded back into text.
 *
 * <p>NOTE(review): each call to {@link #getSummary(String)} writes into the shared tokenizer
 * buffers and the shared model feeds, so one instance does not look safe for concurrent use.
 * Confirm this before sharing an instance across threads.
 */
@StandardComponent("基于BERT提取摘要")
/* loaded from: input_file:com/eshore/kg/qa/extract/AbstractiveSummarizationModel.class */
public class AbstractiveSummarizationModel {

    /** WordPiece marker for a sub-word that continues the previous token. */
    private static final String CONTINUATION_PREFIX = "##";

    /** Token whose id terminates decoding of the model output. */
    private static final String SEPARATOR_TOKEN = "[SEP]";

    /** Name of the graph operation whose output holds the generated token ids. */
    @StandardProperty(name = "finalOutputLayer", description = "finalOutputLayer", defaultValue = "strided_slice_5")
    private String finalOutputLayer;

    /** Name of the graph operation that receives the input token ids. */
    @StandardProperty(name = "inputIdLayer", description = "inputIdLayer", defaultValue = "Placeholder")
    private String inputIdLayer;

    /** Name of the graph operation that receives the segment ids. */
    @StandardProperty(name = "segmentIdLayer", description = "segmentIdLayer", defaultValue = "Placeholder_1")
    private String segmentIdLayer;

    /**
     * Source of the serialized model ("text sequence recognition model"). It is invoked with a
     * consumer that receives the stream passed to {@link TensorFlowInferenceInterface}.
     */
    @StandardProperty(name = "文字序列识别模型", description = "文字序列识别模型", required = true)
    private Consumer<Consumer<InputStream>> modelData;

    /** Tokenizer used to normalize, encode and decode text. */
    @StandardProperty(name = "tokenizer of chinese BERT", description = "tokenizer of chinese BERT", required = true)
    private BertTokenizer tokenizer;

    /** When true, every graph operation is printed to stdout after loading ("print model structure"). */
    @StandardProperty(name = "输出模型结构", description = "输出模型结构", defaultValue = "false")
    private boolean printGraph;

    /** Loaded model; set by {@link #onEnable()}. */
    private TensorFlowInferenceInterface model;

    /** Scratch array handed out and reused by {@link #getBuffer(int)}. */
    private float[] buffer;

    /**
     * Loads the TensorFlow model from {@link #modelData} and optionally dumps the graph structure.
     *
     * @throws IOException if constructing the model failed with an {@link IOException}, or if
     *     {@code modelData} never supplied a stream
     */
    @OnEnable
    private void onEnable() throws IOException {
        try {
            this.modelData.accept(inputStream -> {
                try {
                    this.model = new TensorFlowInferenceInterface(inputStream);
                } catch (Exception e) {
                    // Consumer cannot throw checked exceptions; tunnel the failure out of the callback.
                    throw new CallbackException(e);
                }
            });
        } catch (CallbackException e) {
            IOException ioFailure = (IOException) e.getInnerException(IOException.class);
            if (ioFailure != null) {
                throw ioFailure;
            }
            // NOTE(review): getInnerException's contract for non-IO causes is not visible here.
            // Rethrow the original rather than `throw null`, which would lose the cause.
            throw e;
        }
        if (this.model == null) {
            // Fail at enable time instead of with an NPE on first use.
            throw new IOException("modelData did not supply a model input stream");
        }
        if (this.printGraph) {
            dumpGraph(this.model.getGraph());
        }
    }

    /** Prints each operation of {@code graph} with the dtype and shape of every output. */
    private static void dumpGraph(Graph graph) {
        Iterator<Operation> operations = graph.operations();
        while (operations.hasNext()) {
            Operation operation = operations.next();
            System.out.println("# " + operation.name() + " of type: " + operation.type());
            for (int i = 0; i < operation.numOutputs(); i++) {
                Output<?> output = operation.output(i);
                DataType dataType = null;
                try {
                    dataType = output.dataType();
                } catch (RuntimeException ignored) {
                    // Best-effort diagnostic dump: an output without a resolvable dtype prints "null".
                }
                System.out.println(dataType + ": " + output.shape());
            }
        }
    }

    /**
     * Produces a summary of {@code str}.
     *
     * <p>The text is normalized by the tokenizer. If it is longer than the tokenizer's sequence
     * length, it is truncated to that many UTF-16 units (never splitting a surrogate pair) before
     * inference. The output ids are decoded in order up to the first {@code [SEP]} id. Ids smaller
     * than the {@code [SEP]} id are skipped, and {@code ##} word pieces are glued to the preceding
     * text.
     *
     * @param str the text to summarize
     * @return the decoded summary; empty if the model produced no decodable tokens
     */
    public String getSummary(String str) {
        String text = this.tokenizer.normalize(str);
        int maxLength = this.tokenizer.getSequeceLength();
        if (text.length() > maxLength) {
            int end = maxLength;
            // Do not leave half of a surrogate pair (e.g. supplementary ideographs) at the cut point.
            if (end > 0 && Character.isHighSurrogate(text.charAt(end - 1))) {
                end--;
            }
            text = text.substring(0, end);
        }

        Tensor<?> output = processOne(text);
        int[] tokenIds = new int[TensorFlowInferenceInterface.length(output.shape())];
        output.writeTo(IntBuffer.wrap(tokenIds));

        int separatorId = this.tokenizer.getId(SEPARATOR_TOKEN);
        StringBuilder summary = new StringBuilder();
        for (int tokenId : tokenIds) {
            if (tokenId == separatorId) {
                break;
            }
            // NOTE(review): presumably skips special/reserved ids ([PAD], [CLS], ...) that sort
            // before [SEP] in the vocabulary; confirm against the vocab file.
            if (tokenId < separatorId) {
                continue;
            }
            appendToken(summary, this.tokenizer.getWord(tokenId));
        }
        return summary.toString();
    }

    /** Appends one decoded word piece: {@code ##} continuations are joined, other words may be spaced. */
    private void appendToken(StringBuilder sb, String word) {
        if (word.startsWith(CONTINUATION_PREFIX)) {
            sb.append(word, CONTINUATION_PREFIX.length(), word.length());
            return;
        }
        if (shouldInsertSpace(sb, word)) {
            sb.append(' ');
        }
        sb.append(word);
    }

    /**
     * Returns whether a space is needed between the text built so far and the next word. This is
     * true only when the last character of {@code sb} and the first character of {@code str} are
     * both upper/lower-case letters or decimal digits. Adjacent alphanumeric words are therefore
     * separated, while CJK ideographs (category OTHER_LETTER) stay unspaced.
     */
    private boolean shouldInsertSpace(StringBuilder sb, String str) {
        if (sb.length() == 0 || str.isEmpty()) {
            return false;
        }
        return isSpacedWordChar(sb.charAt(sb.length() - 1)) && isSpacedWordChar(str.charAt(0));
    }

    /** True for characters whose Unicode category is UPPERCASE_LETTER, LOWERCASE_LETTER or DECIMAL_DIGIT_NUMBER. */
    private static boolean isSpacedWordChar(char c) {
        int type = Character.getType(c);
        return type == Character.UPPERCASE_LETTER
                || type == Character.LOWERCASE_LETTER
                || type == Character.DECIMAL_DIGIT_NUMBER;
    }

    /**
     * Tokenizes {@code str} into the shared tokenizer buffers. It then runs the graph on a batch of
     * one sequence with shape {@code [1, length]} and returns the tensor fetched from
     * {@link #finalOutputLayer}.
     *
     * <p>NOTE(review): neither the feed tensors nor the returned tensor are closed here, and it is
     * not visible whether {@link TensorFlowInferenceInterface} owns and releases them. TensorFlow
     * tensors hold native memory, so confirm this to rule out a leak.
     */
    private Tensor<?> processOne(String str) {
        this.tokenizer.process(str);
        int length = this.tokenizer.getLength();
        this.model
                .addFeed(this.inputIdLayer,
                        Tensor.create(new long[] {1, length}, IntBuffer.wrap(this.tokenizer.getInputIdBuffer(), 0, length)))
                .addFeed(this.segmentIdLayer,
                        Tensor.create(new long[] {1, length}, IntBuffer.wrap(this.tokenizer.getSegmentIdBuffer(), 0, length)))
                .run(new String[] {this.finalOutputLayer});
        return this.model.getTensor(this.finalOutputLayer);
    }

    /**
     * Returns a reusable scratch array holding at least {@code minLength} elements.
     *
     * <p>The same array is returned across calls and is only reallocated when it is too small, so
     * it may be longer than requested and may contain data from earlier use.
     *
     * @param minLength the minimum number of elements required
     * @return the shared scratch array
     */
    public float[] getBuffer(int minLength) {
        if (this.buffer == null || this.buffer.length < minLength) {
            this.buffer = new float[minLength];
        }
        return this.buffer;
    }
}
