package cn.langpy.nlp2cron.core;

import java.util.List;
import java.util.Objects;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;
import org.tensorflow.Tensor;

/* loaded from: input_file:cn/langpy/nlp2cron/core/CrondModel.class */
/**
 * Static facade over a TensorFlow encoder/decoder SavedModel pair (presumably translating
 * natural-language text into a cron expression, per the package name).
 *
 * <p>The models live in static fields. The class initialiser tries to load them from the default
 * {@code model} directory; callers can (re)load them with one of the {@code init} methods and
 * release them with {@link #close()}. NOTE(review): all state is static and unsynchronised, so
 * concurrent {@code init}/{@code close}/{@code predict} calls are not coordinated.
 */
public class CrondModel {
    private static final Logger LOG = Logger.getLogger(CrondModel.class.getName());

    /** Directory the class initialiser tries to load the default models from. */
    private static final String DEFAULT_MODEL_DIR = "model";

    /** SavedModel tag used when loading both the encoder and the decoder. */
    private static final String SERVE_TAG = "serve";

    /** Number of input characters encoded per request; presumably fixed by the exported encoder. */
    private static final int MAX_INPUT_LENGTH = 15;

    /** Width of each input character's one-hot vector (indices come from {@code word2id}). */
    private static final int SOURCE_VOCAB_SIZE = 334;

    /** Width of the decoder token vector (indices map through {@code id2str}/{@code str2id}). */
    private static final int TARGET_VOCAB_SIZE = 17;

    /** Step limit used by {@link #predict(String)} so a model that never emits "E" cannot hang. */
    private static final int DEFAULT_MAX_DECODE_STEPS = 1000;

    static SavedModelBundle encoder;
    static SavedModelBundle decoder;
    static CrondConfig config;

    /**
     * Loads the encoder and decoder from {@code str + "/encoder_model"} and
     * {@code str + "/decoder_model"}; the current {@code config} is left unchanged. Failures are
     * logged and leave any previously loaded models in place.
     *
     * @param str directory containing the {@code encoder_model} and {@code decoder_model} SavedModels
     */
    public static void init(String str) {
        loadModel(str, Level.WARNING);
    }

    /**
     * Installs {@code crondConfig} as the active configuration and loads the models from its
     * model path. Load failures are logged and leave any previously loaded models in place.
     *
     * @param crondConfig configuration to use from now on
     * @throws NullPointerException if {@code crondConfig} is null (the current config is kept)
     */
    public static void init(CrondConfig crondConfig) {
        Objects.requireNonNull(crondConfig, "crondConfig");
        config = crondConfig;
        loadModel(config.getModelPath(), Level.WARNING);
    }

    /**
     * Reloads the models from the model path of the current configuration.
     *
     * @throws IllegalStateException if no configuration is available
     */
    public static void init() {
        if (config == null) {
            throw new IllegalStateException("No CrondConfig available; call init(CrondConfig) first");
        }
        loadModel(config.getModelPath(), Level.WARNING);
    }

    /**
     * Loads both SavedModels from {@code dir}. Only if both succeed are they swapped in, after
     * which the previously loaded bundles are closed. On failure the previous models stay
     * active, a half-loaded encoder is closed, and the error is logged at {@code failureLevel}
     * (best-effort: no exception reaches the {@code init} callers).
     */
    private static void loadModel(String dir, Level failureLevel) {
        SavedModelBundle newEncoder = null;
        try {
            newEncoder = SavedModelBundle.load(dir + "/encoder_model", SERVE_TAG);
            SavedModelBundle newDecoder = SavedModelBundle.load(dir + "/decoder_model", SERVE_TAG);
            SavedModelBundle oldEncoder = encoder;
            SavedModelBundle oldDecoder = decoder;
            encoder = newEncoder;
            decoder = newDecoder;
            closeBundles(oldEncoder, oldDecoder);
        } catch (Exception e) {
            // Decoder failed after the encoder loaded: don't leak the orphaned encoder.
            if (newEncoder != null && newEncoder != encoder) {
                newEncoder.close();
            }
            LOG.log(failureLevel, "Failed to load CrondModel from " + dir, e);
        }
    }

    /**
     * Same as {@link #predict(String, int)} with a default step limit of
     * {@code DEFAULT_MAX_DECODE_STEPS} (1000).
     */
    public static String predict(String str) {
        return predict(str, DEFAULT_MAX_DECODE_STEPS);
    }

    /**
     * Encodes {@code str} and greedily decodes one token per step, starting from a one-hot
     * {@code "S"} token, until the decoder's best token is {@code "E"}.
     *
     * <p>Each step's raw decoder output (not a one-hot of the chosen token) is fed back as the
     * next decoder input together with the two returned state tensors. NOTE(review): confirm
     * this matches how the model was trained. Every tensor is closed as soon as it is superseded,
     * and the remaining ones are closed on exit, including when an exception is thrown.
     *
     * @param str input text; only the first {@code MAX_INPUT_LENGTH} (15) characters are encoded
     * @param maxSteps maximum number of decoder steps before giving up
     * @return the concatenated decoded tokens, excluding the terminating {@code "E"}
     * @throws IllegalArgumentException if {@code maxSteps} is not positive
     * @throws IllegalStateException if the models are not loaded, or no {@code "E"} is produced
     *     within {@code maxSteps} steps
     */
    public static String predict(String str, int maxSteps) {
        if (maxSteps <= 0) {
            throw new IllegalArgumentException("maxSteps must be positive: " + maxSteps);
        }
        if (encoder == null || decoder == null) {
            throw new IllegalStateException("Model is not loaded; call CrondModel.init(...) first");
        }
        float[][][] startToken = new float[1][1][TARGET_VOCAB_SIZE];
        startToken[0][0][CrondConfig.str2id.get("S").intValue()] = 1.0f;

        List<Tensor<?>> encoderOutputs = encoderPredict(encoder.session(), str);
        Tensor<?> input = null;
        Tensor<?> state1 = encoderOutputs.get(0);
        Tensor<?> state2 = encoderOutputs.get(1);
        try {
            input = Tensor.create(startToken);
            Session session = decoder.session();
            StringBuilder result = new StringBuilder();
            for (int step = 0; step < maxSteps; step++) {
                List<Tensor<?>> outputs = decoderPredict(session, input, state1, state2);
                // Take ownership of the new outputs before releasing the previous step's tensors,
                // so the finally block always sees exactly the tensors still alive.
                Tensor<?> previousInput = input;
                Tensor<?> previousState1 = state1;
                Tensor<?> previousState2 = state2;
                input = outputs.get(0);
                state1 = outputs.get(1);
                state2 = outputs.get(2);
                closeAll(previousInput, previousState1, previousState2);

                String token = argmax(input.copyTo(new float[1][1][TARGET_VOCAB_SIZE])[0]);
                if ("E".equals(token)) {
                    return result.toString();
                }
                result.append(token);
            }
            throw new IllegalStateException("Decoder did not emit \"E\" within " + maxSteps
                    + " steps; partial output: " + result);
        } finally {
            closeAll(input, state1, state2);
        }
    }

    /**
     * Runs the encoder on the one-hot encoding of {@code str}. The temporary input tensor is
     * always closed, even if the run fails.
     *
     * @param session encoder session
     * @param str input text (see {@link #predict(String, int)} for truncation)
     * @return the fetched outputs {@code StatefulPartitionedCall:0} and {@code :1}; the caller
     *     owns them and must close them
     */
    public static List<Tensor<?>> encoderPredict(Session session, String str) {
        try (Tensor<?> vec = toVec(str, new float[1][MAX_INPUT_LENGTH][SOURCE_VOCAB_SIZE])) {
            return session.runner()
                    .feed("serving_default_input_1:0", vec)
                    .fetch("StatefulPartitionedCall:0")
                    .fetch("StatefulPartitionedCall:1")
                    .run();
        }
    }

    /**
     * Runs one decoder step.
     *
     * @param session decoder session
     * @param input previous token tensor, fed to {@code serving_default_input_2:0}
     * @param state1 tensor fed to {@code serving_default_input_3:0}
     * @param state2 tensor fed to {@code serving_default_input_4:0}
     * @return the fetched outputs {@code StatefulPartitionedCall:0}, {@code :1} and {@code :2};
     *     the caller owns them and must close them
     */
    public static List<Tensor<?>> decoderPredict(
            Session session, Tensor<?> input, Tensor<?> state1, Tensor<?> state2) {
        return session.runner()
                .feed("serving_default_input_2:0", input)
                .feed("serving_default_input_3:0", state1)
                .feed("serving_default_input_4:0", state2)
                .fetch("StatefulPartitionedCall:0")
                .fetch("StatefulPartitionedCall:1")
                .fetch("StatefulPartitionedCall:2")
                .run();
    }

    /**
     * For each row, appends the {@code id2str} token of the highest-scoring index. Processing
     * stops, after appending {@code "#"}, at the first row whose best index is 0 (which is also
     * the result when no score is positive). NOTE(review): index 0 presumably denotes padding.
     */
    private static String argmax(float[][] fArr) {
        StringBuilder sb = new StringBuilder();
        for (float[] row : fArr) {
            int best = 0;
            float bestScore = 0.0f;
            for (int k = 0; k < row.length; k++) {
                if (row[k] > bestScore) {
                    bestScore = row[k];
                    best = k;
                }
            }
            if (best == 0) {
                sb.append("#");
                break;
            }
            sb.append(CrondConfig.id2str.get(Integer.valueOf(best)));
        }
        return sb.toString();
    }

    /**
     * One-hot encodes {@code str} into {@code fArr[0]}, one UTF-16 char per position, and wraps
     * the array in a new tensor that the caller must close. Characters missing from
     * {@code word2id} use the {@code "<UNK>"} id. Positions past the end of {@code str} stay
     * all-zero, and characters beyond {@code fArr[0].length} are dropped.
     *
     * @throws IllegalStateException if no configuration is available
     */
    private static Tensor<?> toVec(String str, float[][][] fArr) {
        if (config == null) {
            throw new IllegalStateException("No CrondConfig available; call init(CrondConfig) first");
        }
        float[][] positions = fArr[0];
        int length = Math.min(str.length(), positions.length);
        for (int i = 0; i < length; i++) {
            String ch = str.substring(i, i + 1);
            Object id = config.getWord2id().containsKey(ch)
                    ? config.getWord2id().get(ch)
                    : config.getWord2id().get("<UNK>");
            positions[i][Integer.parseInt(String.valueOf(id))] = 1.0f;
        }
        return Tensor.create(fArr);
    }

    /** Closes every non-null tensor; always attempts all of them. */
    private static void closeAll(Tensor<?>... tensors) {
        for (Tensor<?> tensor : tensors) {
            if (tensor != null) {
                tensor.close();
            }
        }
    }

    /** Closes both bundles if non-null; the second is closed even if closing the first throws. */
    private static void closeBundles(SavedModelBundle first, SavedModelBundle second) {
        try {
            if (first != null) {
                first.close();
            }
        } finally {
            if (second != null) {
                second.close();
            }
        }
    }

    /**
     * Closes the loaded encoder and decoder, if any, and clears them. A second call is a no-op,
     * and {@code predict} fails fast until one of the {@code init} methods is called again.
     */
    public static void close() {
        SavedModelBundle oldEncoder = encoder;
        SavedModelBundle oldDecoder = decoder;
        encoder = null;
        decoder = null;
        closeBundles(oldEncoder, oldDecoder);
    }

    static {
        // Best-effort default setup. A missing default model is not fatal because callers can
        // still supply one via init(...), so failures are only logged at FINE.
        try {
            config = new CrondConfig();
        } catch (Exception e) {
            LOG.log(Level.FINE, "Could not create default CrondConfig", e);
        }
        if (config != null) {
            loadModel(DEFAULT_MODEL_DIR, Level.FINE);
        }
    }
}
