package fi.evolver.ai.spring.prompt.template;

import com.knuddels.jtokkit.api.EncodingType;
import fi.evolver.ai.spring.Api;
import fi.evolver.ai.spring.JtokkitTokenizer;
import fi.evolver.ai.spring.Model;
import fi.evolver.ai.spring.Tokenizer;
import fi.evolver.ai.spring.prompt.template.model.Section;
import fi.evolver.ai.spring.prompt.template.model.SectionProperty;
import fi.evolver.ai.spring.util.DurationUtils;
import freemarker.template.Configuration;
import freemarker.template.Template;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.Reader;
import java.io.UncheckedIOException;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.lang.runtime.ObjectMethods;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.regex.Pattern;
import java.util.stream.Stream;
import org.apache.commons.codec.digest.DigestUtils;
import org.apache.commons.io.IOUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:fi/evolver/ai/spring/prompt/template/TemplateUtils.class */
public class TemplateUtils {
    public static final String META_PROPERTY_MODEL = "model";
    private static final Configuration FREEMARKER_CONFIGURATION;
    public static final String SECTION_META = "META";
    public static final String SECTION_COMMENT = "COMMENT";
    public static final String SECTION_FUNCTION = "FUNCTION";
    public static final String SECTION_ASSISTANT_MESSAGE = "ASSISTANT_MESSAGE";
    public static final String SECTION_SYSTEM_MESSAGE = "SYSTEM_MESSAGE";
    public static final String SECTION_USER_MESSAGE = "USER_MESSAGE";
    public static final String SECTION_HISTORY = "HISTORY";
    public static final String SECTION_PROMPT = "PROMPT";
    private static final Logger LOG = LoggerFactory.getLogger(TemplateUtils.class);
    private static final Map<String, Template> TEMPLATE_CACHE = new ConcurrentHashMap();
    private static final Pattern REGEX_MODEL_WITH_CL100K_BASE = Pattern.compile("(?:gpt-3.5-turbo|gpt-4)(?:[-o].*)?");
    public static final HistoryTag TAG_HISTORY = new HistoryTag();

    /* loaded from: input_file:fi/evolver/ai/spring/prompt/template/TemplateUtils$MetaDetails.class */
    public static final class MetaDetails<T extends Api> extends Record {
        private final Model<T> model;
        private final Optional<Duration> timeout;
        private final Map<String, String> properties;

        public MetaDetails(Model<T> model, Optional<Duration> optional, Map<String, String> map) {
            this.model = model;
            this.timeout = optional;
            this.properties = map;
        }

        @Override // java.lang.Record
        public final String toString() {
            return (String) ObjectMethods.bootstrap(MethodHandles.lookup(), "toString", MethodType.methodType(String.class, MetaDetails.class), MetaDetails.class, "model;timeout;properties", "FIELD:Lfi/evolver/ai/spring/prompt/template/TemplateUtils$MetaDetails;->model:Lfi/evolver/ai/spring/Model;", "FIELD:Lfi/evolver/ai/spring/prompt/template/TemplateUtils$MetaDetails;->timeout:Ljava/util/Optional;", "FIELD:Lfi/evolver/ai/spring/prompt/template/TemplateUtils$MetaDetails;->properties:Ljava/util/Map;").dynamicInvoker().invoke(this) /* invoke-custom */;
        }

        @Override // java.lang.Record
        public final int hashCode() {
            return (int) ObjectMethods.bootstrap(MethodHandles.lookup(), "hashCode", MethodType.methodType(Integer.TYPE, MetaDetails.class), MetaDetails.class, "model;timeout;properties", "FIELD:Lfi/evolver/ai/spring/prompt/template/TemplateUtils$MetaDetails;->model:Lfi/evolver/ai/spring/Model;", "FIELD:Lfi/evolver/ai/spring/prompt/template/TemplateUtils$MetaDetails;->timeout:Ljava/util/Optional;", "FIELD:Lfi/evolver/ai/spring/prompt/template/TemplateUtils$MetaDetails;->properties:Ljava/util/Map;").dynamicInvoker().invoke(this) /* invoke-custom */;
        }

        @Override // java.lang.Record
        public final boolean equals(Object obj) {
            return (boolean) ObjectMethods.bootstrap(MethodHandles.lookup(), "equals", MethodType.methodType(Boolean.TYPE, MetaDetails.class, Object.class), MetaDetails.class, "model;timeout;properties", "FIELD:Lfi/evolver/ai/spring/prompt/template/TemplateUtils$MetaDetails;->model:Lfi/evolver/ai/spring/Model;", "FIELD:Lfi/evolver/ai/spring/prompt/template/TemplateUtils$MetaDetails;->timeout:Ljava/util/Optional;", "FIELD:Lfi/evolver/ai/spring/prompt/template/TemplateUtils$MetaDetails;->properties:Ljava/util/Map;").dynamicInvoker().invoke(this, obj) /* invoke-custom */;
        }

        public Model<T> model() {
            return this.model;
        }

        public Optional<Duration> timeout() {
            return this.timeout;
        }

        public Map<String, String> properties() {
            return this.properties;
        }
    }

    public static Template getTemplate(Section section) throws IOException {
        Optional<String> property = section.getProperty("template");
        return getTemplate(property.isPresent() ? readResource(property.get()) : section.content());
    }

    private static Template getTemplate(String str) throws IOException {
        String sha1Hex = DigestUtils.sha1Hex(str);
        Template template = TEMPLATE_CACHE.get(sha1Hex);
        if (template == null) {
            template = createTemplate(sha1Hex, str);
            TEMPLATE_CACHE.put(sha1Hex, template);
        }
        return template;
    }

    private static Template createTemplate(String str, String str2) throws IOException {
        return new Template(str, str2, FREEMARKER_CONFIGURATION);
    }

    private static String readResource(String str) throws IOException {
        InputStreamReader inputStreamReader = new InputStreamReader(TemplateUtils.class.getResourceAsStream(str), StandardCharsets.UTF_8);
        try {
            String iOUtils = IOUtils.toString(inputStreamReader);
            inputStreamReader.close();
            return iOUtils;
        } catch (Throwable th) {
            try {
                inputStreamReader.close();
            } catch (Throwable th2) {
                th.addSuppressed(th2);
            }
            throw th;
        }
    }

    private static void disableFreemarkerLogging() {
        try {
            freemarker.log.Logger.selectLoggerLibrary(0);
        } catch (ClassNotFoundException e) {
            LOG.warn("Could not disable Freemarker log spam");
        }
    }

    public static <T extends Api> MetaDetails<T> parseMetaDetails(List<Section> list) {
        return parseMetaDetails(list, null);
    }

    public static <T extends Api> MetaDetails<T> parseMetaDetails(List<Section> list, Model<T> model) {
        Map<String, String> metaProperties = getMetaProperties(list);
        String remove = metaProperties.remove("model");
        if (remove == null && model != null) {
            remove = model.name();
        }
        if (remove == null) {
            throw new PromptTemplateException(SECTION_META, "missing the required 'model' property", new Object[0]);
        }
        return new MetaDetails<>(new Model(remove, inferTokenLimit(remove, metaProperties.remove("tokenLimit")), inferTokenizer(remove, metaProperties.remove("tokenizer"))), Optional.ofNullable(metaProperties.remove("timeout")).map((v0) -> {
            return v0.toString();
        }).map(DurationUtils::parseDurationWithUnit), metaProperties);
    }

    public static int inferTokenLimit(String str, String str2) {
        if (str2 != null) {
            return Integer.parseInt(str2);
        }
        if (str.startsWith("gpt-4-turbo") || str.startsWith("gpt-4-1106") || str.startsWith("gpt-4o")) {
            return 128000;
        }
        if (str.startsWith("gpt-4")) {
            return 8192;
        }
        if (!str.startsWith("gpt-3.5-turbo")) {
            if (str.startsWith("text-embedding-ada-002")) {
                return 8192;
            }
            throw new PromptTemplateException(SECTION_META, "missing property token_limit, could not infer the value", str);
        }
        String[] split = str.split("-");
        String str3 = split.length >= 4 ? split[3] : null;
        if (str3 == null) {
            return 4096;
        }
        if ("16k".equals(str3)) {
            return 16385;
        }
        return (!str3.matches("\\d{4}") || "0125".compareTo(str3) < 0) ? 4096 : 16385;
    }

    public static Tokenizer inferTokenizer(String str, String str2) {
        if (str2 != null) {
            return JtokkitTokenizer.of((EncodingType) EncodingType.fromName(str2).orElseThrow(() -> {
                return new PromptTemplateException(SECTION_META, "unsupported tokenizer '%s'", str2);
            }));
        }
        if (REGEX_MODEL_WITH_CL100K_BASE.matcher(str).matches()) {
            return Tokenizer.CL100K_BASE;
        }
        throw new PromptTemplateException(SECTION_META, "missing property tokenizer, could not infer the value", str);
    }

    private static Map<String, String> getMetaProperties(List<Section> list) {
        LinkedHashMap linkedHashMap = new LinkedHashMap();
        Stream<R> map = list.stream().filter(section -> {
            return SECTION_META.equals(section.type());
        }).map((v0) -> {
            return v0.properties();
        });
        Objects.requireNonNull(linkedHashMap);
        map.forEach(linkedHashMap::putAll);
        return linkedHashMap;
    }

    public static List<Section> parseTemplate(Reader reader) {
        try {
            TemplateLineStream templateLineStream = new TemplateLineStream(reader);
            try {
                List<Section> parseTemplate = parseTemplate(templateLineStream);
                templateLineStream.close();
                return parseTemplate;
            } finally {
            }
        } catch (IOException e) {
            throw new UncheckedIOException("Failed reading prompt template", e);
        }
    }

    private static List<Section> parseTemplate(TemplateLineStream templateLineStream) throws IOException {
        ArrayList arrayList = new ArrayList();
        while (templateLineStream.hasSectionHeader()) {
            arrayList.add(parseSection(templateLineStream));
        }
        return arrayList;
    }

    private static Section parseSection(TemplateLineStream templateLineStream) throws IOException {
        String expectSectionHeader = templateLineStream.expectSectionHeader();
        LinkedHashMap linkedHashMap = new LinkedHashMap();
        while (templateLineStream.hasProperty()) {
            SectionProperty expectProperty = templateLineStream.expectProperty();
            linkedHashMap.put(expectProperty.key(), expectProperty.value());
        }
        StringBuilder sb = new StringBuilder();
        while (templateLineStream.hasNext() && !templateLineStream.hasSectionHeader()) {
            sb.append(templateLineStream.next()).append("\n");
        }
        return new Section(expectSectionHeader, linkedHashMap, sb.toString());
    }

    static {
        disableFreemarkerLogging();
        FREEMARKER_CONFIGURATION = new Configuration(Configuration.VERSION_2_3_31);
        FREEMARKER_CONFIGURATION.setSharedVariable("history", TAG_HISTORY);
    }
}
