package cn.langpy.nlp2cron.core;

import com.alibaba.fastjson.JSONObject;
import java.util.ArrayList;
import java.util.List;
import org.tensorflow.RawTensor;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.ndarray.StdArrays;
import org.tensorflow.ndarray.buffer.FloatDataBuffer;
import org.tensorflow.types.TFloat32;

/**
 * Static wrapper around a TensorFlow SavedModel: an input string is encoded as character
 * ids, run through the model, and the per-step output scores are decoded into a
 * {@code "#"}-joined sequence of tokens.
 *
 * <p>Load the model with one of the {@code init} overloads, query it with
 * {@link #predict(String)}, and release it with {@link #close()}. All state is static and
 * no synchronization is performed.
 */
public class CrondModel {
    /** Signature input node fed with the encoded character ids. */
    private static final String INPUT_NODE = "serving_default_input_1:0";
    /** Signature output node whose float scores are decoded by {@link #predict(String)}. */
    private static final String OUTPUT_NODE = "StatefulPartitionedCall:0";
    /** SavedModel tag used when loading. */
    private static final String SERVING_TAG = "serve";
    /** Number of character positions in the model input tensor ({@code [1, INPUT_LENGTH]}). */
    private static final int INPUT_LENGTH = 40;
    /**
     * Number of consecutive scores forming one output step; presumably the output vocabulary
     * size of the model — TODO confirm against the exported model's output shape.
     */
    private static final int OUTPUT_VOCAB_SIZE = 127;
    /** Vocabulary key used for characters missing from the input vocabulary. */
    private static final String UNKNOWN_TOKEN = "<UNK>";
    /** Separator placed between decoded output tokens. */
    private static final String OUTPUT_SEPARATOR = "#";

    static Session model = null;
    static CrondConfig config = new CrondConfig();

    /** Bundle that owns {@link #model}; retained so {@link #close()} can release all of it. */
    private static SavedModelBundle bundle = null;

    /**
     * Loads the SavedModel found at {@code str}, keeping the current {@link #config}.
     *
     * @param str SavedModel export directory
     */
    public static void init(String str) {
        loadModel(str);
    }

    /**
     * Replaces the configuration and loads the model from {@code crondConfig.getModelPath()}.
     *
     * @param crondConfig configuration supplying the model path and vocabularies
     */
    public static void init(CrondConfig crondConfig) {
        config = crondConfig;
        loadModel(config.getModelPath());
    }

    /** Loads the model from the current configuration's model path. */
    public static void init() {
        loadModel(config.getModelPath());
    }

    /**
     * Loads the SavedModel at {@code str} with the serving tag and makes its session current.
     * A previously loaded bundle is closed only after the new one has loaded successfully.
     * Load failures are printed and leave the previous model (if any) untouched.
     */
    private static void loadModel(String str) {
        try {
            SavedModelBundle loaded = SavedModelBundle.load(str, SERVING_TAG);
            SavedModelBundle previous = bundle;
            bundle = loaded;
            model = loaded.session();
            if (previous != null) {
                previous.close();
            }
        } catch (Exception e) {
            // Best-effort load kept from the original; predict() reports a missing model explicitly.
            e.printStackTrace();
        }
    }

    /**
     * Runs the model on {@code str} and decodes its output.
     *
     * <p>The output buffer is read in consecutive windows of {@code OUTPUT_VOCAB_SIZE} scores.
     * For each window, the index of the highest score is looked up (as a string key) in
     * {@code config.getOutputId2Word()}. Trailing scores that do not fill a whole window are
     * ignored.
     *
     * @param str input text, at most {@code INPUT_LENGTH} UTF-16 chars
     * @return decoded tokens joined by {@code "#"}
     * @throws IllegalStateException if no model is loaded (never initialized, failed, or closed)
     * @throws IllegalArgumentException if {@code str} is longer than the model input
     */
    public static String predict(String str) {
        Session session = requireModel();
        JSONObject outputId2Word = config.getOutputId2Word();
        List<String> words = new ArrayList<>();
        // Both tensors own native memory: close them even if run() or decoding throws.
        try (Tensor input = toVec(str);
             RawTensor output = session.runner()
                     .feed(INPUT_NODE, input)
                     .fetch(OUTPUT_NODE)
                     .run()
                     .get(0)
                     .asRawTensor()) {
            FloatDataBuffer scores = output.data().asFloats();
            long steps = scores.size() / OUTPUT_VOCAB_SIZE;
            for (long step = 0; step < steps; step++) {
                // Every score of the window takes part, including the last one (index 126),
                // which the previous loop skipped.
                int best = argmax(scores, step * OUTPUT_VOCAB_SIZE, OUTPUT_VOCAB_SIZE);
                words.add(outputId2Word.getString(String.valueOf(best)));
            }
        }
        return String.join(OUTPUT_SEPARATOR, words);
    }

    /** Returns the loaded session, failing fast with a clear message when there is none. */
    private static Session requireModel() {
        if (model == null) {
            throw new IllegalStateException("model is not loaded; call CrondModel.init(...) first");
        }
        return model;
    }

    /**
     * Returns the offset (0-based, relative to {@code offset}) of the largest score in
     * {@code scores[offset, offset + length)}. Ties keep the first occurrence; returns 0 when
     * {@code length} is 0 or no score compares greater than negative infinity (e.g. all NaN).
     */
    private static int argmax(FloatDataBuffer scores, long offset, int length) {
        int best = 0;
        // Start below any real score so all-negative windows (e.g. logits) are handled.
        float bestScore = Float.NEGATIVE_INFINITY;
        for (int k = 0; k < length; k++) {
            float score = scores.getFloat(offset + k);
            if (score > bestScore) {
                bestScore = score;
                best = k;
            }
        }
        return best;
    }

    /**
     * Encodes {@code str} as a {@code [1, INPUT_LENGTH]} float tensor of character ids taken
     * from {@code config.getInputWord2Id()}. Characters missing from the vocabulary use the
     * {@code "<UNK>"} id; positions past the end of {@code str} stay 0. The caller owns the
     * returned tensor and must close it.
     *
     * @throws IllegalArgumentException if {@code str} has more than {@code INPUT_LENGTH} chars
     */
    private static Tensor toVec(String str) {
        if (str.length() > INPUT_LENGTH) {
            throw new IllegalArgumentException("input has " + str.length()
                    + " chars but the model accepts at most " + INPUT_LENGTH);
        }
        JSONObject inputWord2Id = config.getInputWord2Id();
        float[][] ids = new float[1][INPUT_LENGTH];
        for (int i = 0; i < str.length(); i++) {
            String key = String.valueOf(str.charAt(i));
            String lookup = inputWord2Id.containsKey(key) ? key : UNKNOWN_TOKEN;
            ids[0][i] = inputWord2Id.getInteger(lookup).intValue();
        }
        return TFloat32.tensorOf(StdArrays.ndCopyOf(ids));
    }

    /**
     * Releases the loaded model: the whole SavedModelBundle when it is known, otherwise the
     * bare session. Safe to call when nothing is loaded; afterwards {@link #predict(String)}
     * throws {@link IllegalStateException} until the model is loaded again.
     */
    public static void close() {
        if (bundle != null) {
            bundle.close();
        } else if (model != null) {
            model.close();
        }
        bundle = null;
        model = null;
    }
}
