package fi.evolver.ai.spring.provider.replicate;

import com.amazonaws.services.sns.model.InvalidParameterException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.JsonNode;
import fi.evolver.ai.spring.ApiResponseException;
import fi.evolver.ai.spring.Model;
import fi.evolver.ai.spring.Tokenizer;
import fi.evolver.ai.spring.config.ApiConfigurationService;
import fi.evolver.ai.spring.config.ApiEndpointParameters;
import fi.evolver.ai.spring.image.ImageApi;
import fi.evolver.ai.spring.image.ImageResponse;
import fi.evolver.ai.spring.image.prompt.ImageGenerationPrompt;
import fi.evolver.ai.spring.image.prompt.ImageVariationPrompt;
import fi.evolver.ai.spring.prompt.Prompt;
import fi.evolver.ai.spring.provider.replicate.response.RStatus;
import fi.evolver.ai.spring.provider.replicate.response.ReplicateFluxImageResponse;
import fi.evolver.ai.spring.util.Json;
import fi.evolver.basics.spring.http.LoggingHttpClient;
import fi.evolver.basics.spring.log.MessageLogService;
import fi.evolver.utils.ContextUtils;
import jakarta.annotation.PreDestroy;
import java.io.IOException;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.lang.runtime.ObjectMethods;
import java.net.MalformedURLException;
import java.net.URI;
import java.net.URL;
import java.net.http.HttpClient;
import java.net.http.HttpHeaders;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.time.Duration;
import java.time.LocalDateTime;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;

@Component
/* loaded from: input_file:fi/evolver/ai/spring/provider/replicate/ReplicateService.class */
public class ReplicateService implements ImageApi {
    private static final Logger LOG = LoggerFactory.getLogger(ReplicateService.class);

    /** Model descriptor for {@code black-forest-labs/flux-dev}. */
    public static final Model<ImageApi> FLUX_DEV = new Model<>("black-forest-labs/flux-dev", Integer.MAX_VALUE, Tokenizer.CL100K_BASE);

    /** Model descriptor for {@code black-forest-labs/flux-pro}. */
    public static final Model<ImageApi> FLUX_PRO = new Model<>("black-forest-labs/flux-pro", Integer.MAX_VALUE, Tokenizer.CL100K_BASE);

    /** Model descriptor for {@code black-forest-labs/flux-schnell}. */
    public static final Model<ImageApi> FLUX_SCHNELL = new Model<>("black-forest-labs/flux-schnell", Integer.MAX_VALUE, Tokenizer.CL100K_BASE);

    /**
     * Most recently fetched version per model name, together with the fetch time.
     * The map is static, so it is shared by every instance of this service.
     */
    private static final Map<String, ModelVersionInfo> latestModelVersions = new HashMap<>();

    /** HTTP client used for the logged Replicate calls. */
    private final LoggingHttpClient httpClient;

    /** Source of endpoint URIs and headers. */
    private final ApiConfigurationService apiConfigurationService;

    /** Maximum age, in hours, before a cached model version is fetched again. */
    private final int maxModelVersionAge;

    /** Runs the fixed-rate status polling tasks. */
    private final ScheduledExecutorService scheduler = Executors.newScheduledThreadPool(2);

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:fi/evolver/ai/spring/provider/replicate/ReplicateService$ModelVersionInfo.class */
    public static final class ModelVersionInfo extends Record {
        private final String version;
        private final LocalDateTime fetchedAt;

        private ModelVersionInfo(String str, LocalDateTime localDateTime) {
            this.version = str;
            this.fetchedAt = localDateTime;
        }

        @Override // java.lang.Record
        public final String toString() {
            return (String) ObjectMethods.bootstrap(MethodHandles.lookup(), "toString", MethodType.methodType(String.class, ModelVersionInfo.class), ModelVersionInfo.class, "version;fetchedAt", "FIELD:Lfi/evolver/ai/spring/provider/replicate/ReplicateService$ModelVersionInfo;->version:Ljava/lang/String;", "FIELD:Lfi/evolver/ai/spring/provider/replicate/ReplicateService$ModelVersionInfo;->fetchedAt:Ljava/time/LocalDateTime;").dynamicInvoker().invoke(this) /* invoke-custom */;
        }

        @Override // java.lang.Record
        public final int hashCode() {
            return (int) ObjectMethods.bootstrap(MethodHandles.lookup(), "hashCode", MethodType.methodType(Integer.TYPE, ModelVersionInfo.class), ModelVersionInfo.class, "version;fetchedAt", "FIELD:Lfi/evolver/ai/spring/provider/replicate/ReplicateService$ModelVersionInfo;->version:Ljava/lang/String;", "FIELD:Lfi/evolver/ai/spring/provider/replicate/ReplicateService$ModelVersionInfo;->fetchedAt:Ljava/time/LocalDateTime;").dynamicInvoker().invoke(this) /* invoke-custom */;
        }

        @Override // java.lang.Record
        public final boolean equals(Object obj) {
            return (boolean) ObjectMethods.bootstrap(MethodHandles.lookup(), "equals", MethodType.methodType(Boolean.TYPE, ModelVersionInfo.class, Object.class), ModelVersionInfo.class, "version;fetchedAt", "FIELD:Lfi/evolver/ai/spring/provider/replicate/ReplicateService$ModelVersionInfo;->version:Ljava/lang/String;", "FIELD:Lfi/evolver/ai/spring/provider/replicate/ReplicateService$ModelVersionInfo;->fetchedAt:Ljava/time/LocalDateTime;").dynamicInvoker().invoke(this, obj) /* invoke-custom */;
        }

        public String version() {
            return this.version;
        }

        public LocalDateTime fetchedAt() {
            return this.fetchedAt;
        }
    }

    /**
     * Creates the service and its HTTP client.
     *
     * @param messageLogService handed to the {@link LoggingHttpClient}
     * @param i connection timeout in seconds
     *        (property {@code evolver.replicate.connection.timeout.seconds}, default 5)
     * @param i2 maximum age of a cached model version in hours
     *        (property {@code evolver.replicate.max-model-version-age-h}, default 24)
     * @param apiConfigurationService source of endpoint URIs and headers
     */
    @Autowired
    public ReplicateService(MessageLogService messageLogService, @Value("${evolver.replicate.connection.timeout.seconds:5}") int i, @Value("${evolver.replicate.max-model-version-age-h:24}") int i2, ApiConfigurationService apiConfigurationService) {
        HttpClient delegate = HttpClient.newBuilder()
                .connectTimeout(Duration.ofSeconds(i))
                .build();
        this.httpClient = new LoggingHttpClient(messageLogService, delegate);
        this.apiConfigurationService = apiConfigurationService;
        this.maxModelVersionAge = i2;
    }

    /**
     * Shuts the polling scheduler down when the bean is destroyed.
     * Uses {@code shutdown()}, so no new polling tasks are accepted afterwards.
     */
    @PreDestroy
    private void shutdownScheduler() {
        scheduler.shutdown();
    }

    /**
     * Reads the optional {@code "provider"} string property of a prompt.
     *
     * @param prompt the prompt to inspect
     * @return the value of the {@code "provider"} property as returned by the prompt
     */
    private static Optional<String> getProviderName(Prompt prompt) {
        Optional<String> provider = prompt.getStringProperty("provider");
        return provider;
    }

    /**
     * Generates images for the given prompt with one of the supported Flux models.
     *
     * <p>The model is validated before anything else so that an unsupported model is rejected
     * without first performing the remote model-version lookup.
     *
     * @param imageGenerationPrompt the prompt; its model must be {@link #FLUX_PRO},
     *        {@link #FLUX_DEV} or {@link #FLUX_SCHNELL}
     * @return the response built from the finished prediction(s)
     * @throws InvalidParameterException if the prompt's model is not one of the supported models
     */
    @Override
    public ImageResponse send(ImageGenerationPrompt imageGenerationPrompt) {
        String modelName = imageGenerationPrompt.model().name();
        boolean fluxPro = FLUX_PRO.name().equals(modelName);
        boolean fluxDevOrSchnell = FLUX_DEV.name().equals(modelName) || FLUX_SCHNELL.name().equals(modelName);
        if (!fluxPro && !fluxDevOrSchnell) {
            throw new InvalidParameterException("Unsupported model \"%s\"".formatted(modelName));
        }

        ApiEndpointParameters endpointParameters = this.apiConfigurationService.getEndpointParameters(
                ReplicateService.class, getProviderName(imageGenerationPrompt), ImageApi.class);
        HttpRequest.Builder requestBuilder = HttpRequest.newBuilder(endpointParameters.prepareUri("predictions"))
                .header("Content-Type", "application/json")
                .timeout(imageGenerationPrompt.timeout().orElse(ImageApi.DEFAULT_TIMEOUT))
                .POST(HttpRequest.BodyPublishers.ofString(
                        ReplicateRequestGenerator.generate(imageGenerationPrompt, getOrFetchModelVersion(endpointParameters, modelName))));
        // Endpoint-specific headers from the configuration are added on top of Content-Type.
        endpointParameters.headers().forEach(requestBuilder::header);
        HttpRequest request = requestBuilder.build();

        return fluxPro
                ? sendFluxProImageRequest(request, imageGenerationPrompt)
                : sendFluxDevShnellImageRequest(request, imageGenerationPrompt);
    }

    /**
     * Returns the cached version of a model, fetching it again when there is no cache entry or
     * the entry is older than {@code maxModelVersionAge} hours.
     *
     * <p>The cache is a static {@code HashMap} reachable from every caller of {@code send}, so
     * each access is guarded by the map's monitor. The remote fetch itself runs outside the
     * lock; concurrent callers may therefore fetch the same model more than once, which only
     * costs an extra request.
     *
     * @param apiEndpointParameters endpoint used when the version has to be fetched
     * @param str the model name
     * @return the version id to use for the model
     */
    private String getOrFetchModelVersion(ApiEndpointParameters apiEndpointParameters, String str) {
        LocalDateTime now = LocalDateTime.now();
        ModelVersionInfo cached;
        synchronized (latestModelVersions) {
            cached = latestModelVersions.get(str);
        }
        if (cached != null && cached.fetchedAt().isAfter(now.minusHours(this.maxModelVersionAge))) {
            return cached.version();
        }

        String fetchedVersion = fetchModelVersion(apiEndpointParameters, str);
        synchronized (latestModelVersions) {
            latestModelVersions.put(str, new ModelVersionInfo(fetchedVersion, now));
        }
        return fetchedVersion;
    }

    /**
     * Requests the model info from Replicate and extracts {@code latest_version.id}.
     *
     * @param apiEndpointParameters endpoint providing the URI and headers
     * @param str the model name, appended to the {@code models} path
     * @return the id of the model's latest version
     * @throws ApiResponseException if the request fails, the HTTP status is not 2xx, or the
     *         response does not contain a version id
     */
    private String fetchModelVersion(ApiEndpointParameters apiEndpointParameters, String str) {
        LoggingHttpClient.LogParameters logParameters = new LoggingHttpClient.LogParameters("ModelInfoRequest");
        try {
            HttpRequest.Builder requestBuilder = HttpRequest.newBuilder(apiEndpointParameters.prepareUri("models", str))
                    .timeout(ImageApi.DEFAULT_TIMEOUT)
                    .GET();
            apiEndpointParameters.headers().forEach(requestBuilder::header);

            HttpResponse<String> response = this.httpClient.send(requestBuilder.build(), HttpResponse.BodyHandlers.ofString(), logParameters);
            int status = response.statusCode();
            if (status < 200 || status >= 300) {
                throw new ApiResponseException("Failed Replicate model info request. HTTP status %d. Response:\n%s", status, response.body());
            }

            // Walk the tree defensively: a missing or JSON-null node must end in the
            // "could not find" error below rather than in a NullPointerException.
            JsonNode root = Json.OBJECT_MAPPER.readTree(response.body());
            JsonNode latestVersion = root == null ? null : root.get("latest_version");
            JsonNode versionId = latestVersion == null ? null : latestVersion.get("id");
            if (versionId == null || versionId.isNull()) {
                throw new ApiResponseException("Failed Replicate model info request. Could not find model version");
            }
            return versionId.asText();
        } catch (ApiResponseException e) {
            // Already carries the specific reason (HTTP status / missing version); do not re-wrap.
            throw e;
        } catch (InterruptedException e) {
            // Keep the interrupt visible to the caller's thread.
            Thread.currentThread().interrupt();
            throw new ApiResponseException(e, "Failed Replicate model info request");
        } catch (Exception e) {
            throw new ApiResponseException(e, "Failed Replicate model info request");
        }
    }

    /**
     * Sends a flux-dev / flux-schnell prediction request and waits for its final status.
     *
     * <p>Any final status that is not a success ends in an {@link ApiResponseException}; the
     * output list is only read for successful predictions.
     *
     * @param httpRequest the prepared prediction request
     * @param imageGenerationPrompt the prompt; its timeout (or {@code DEFAULT_TIMEOUT}) bounds the polling
     * @return the response holding every output that could be parsed as a URL
     * @throws ApiResponseException if the prediction does not succeed or waiting for it fails
     */
    private ReplicateFluxImageResponse sendFluxDevShnellImageRequest(HttpRequest httpRequest, ImageGenerationPrompt imageGenerationPrompt) {
        try {
            RStatus<List<String>> initialStatus = sendImageRequest(this.httpClient, httpRequest);
            long timeoutMillis = imageGenerationPrompt.timeout().orElse(DEFAULT_TIMEOUT).toMillis();
            RStatus<List<String>> finalStatus = pollStatus(initialStatus, httpRequest.headers(), timeoutMillis);

            if (!finalStatus.isSuccess()) {
                LOG.error(finalStatus.logs());
                if (finalStatus.isCancelled()) {
                    throw new ApiResponseException(new InterruptedException("Request cancelled"), "Failed Replicate image request");
                }
                if (finalStatus.isFailed()) {
                    throw new ApiResponseException(new Exception(finalStatus.error()), "Failed Replicate image request");
                }
                // Neither succeeded, failed nor cancelled: there is no usable output to return.
                throw new ApiResponseException("Failed Replicate image request");
            }

            List<URL> imageUrls = finalStatus.output().stream()
                    .map(ReplicateService::toUrlNoThrow)
                    .filter(Objects::nonNull)
                    .toList();
            return new ReplicateFluxImageResponse(imageGenerationPrompt, imageUrls, finalStatus.completedAt().toString());
        } catch (InterruptedException e) {
            // Restore the flag so callers further up can still observe the interruption.
            Thread.currentThread().interrupt();
            throw new ApiResponseException(e, "Failed Replicate image request");
        } catch (ExecutionException e) {
            throw new ApiResponseException(e, "Failed Replicate image request");
        }
    }

    /**
     * Polls the status of one prediction every 500 ms until it is no longer in progress.
     *
     * <p>If no final status arrives within {@code j} milliseconds a cancellation request is sent
     * and its response is returned instead. The polling task is cancelled on every exit path.
     *
     * @param rStatus the status returned when the prediction was created; supplies the get and cancel URIs
     * @param httpHeaders headers copied onto the status and cancellation requests
     * @param j maximum time to wait, in milliseconds
     * @return the first status that is not in progress, or the cancellation response after a timeout
     * @throws InterruptedException if the waiting thread is interrupted
     * @throws ExecutionException if a status fetch failed inside the polling task
     */
    private RStatus<List<String>> pollStatus(RStatus<List<String>> rStatus, HttpHeaders httpHeaders, long j) throws InterruptedException, ExecutionException {
        CompletableFuture<RStatus<List<String>>> result = new CompletableFuture<>();
        // Captured here so the polling task can run with the caller's context.
        ContextUtils.Context context = ContextUtils.getContext();
        ScheduledFuture<?> poller = this.scheduler.scheduleAtFixedRate(() -> {
            try {
                ContextUtils.ContextCloser contextCloser = ContextUtils.ensureContext(context);
                try {
                    if (result.isDone()) {
                        return;
                    }
                    RStatus<List<String>> current = fetchRequestStatus(this.httpClient, rStatus.urls().get(), httpHeaders);
                    if (!current.isInProgress()) {
                        result.complete(current);
                    }
                } finally {
                    // Close the context on every path, including when the fetch throws.
                    if (contextCloser != null) {
                        contextCloser.close();
                    }
                }
            } catch (Exception e) {
                result.completeExceptionally(e);
            }
        }, 0L, 500L, TimeUnit.MILLISECONDS);

        try {
            return result.get(j, TimeUnit.MILLISECONDS);
        } catch (TimeoutException e) {
            return cancelRequest(this.httpClient, rStatus.urls().cancel(), httpHeaders);
        } finally {
            poller.cancel(false);
        }
    }

    /**
     * Sends one flux-pro prediction request per requested output and waits for all of them.
     *
     * <p>The number of requests is the prompt's {@code NUM_OUTPUTS} property (default 1). Logs of
     * failed predictions are written to the error log. The response contains the outputs of the
     * successful predictions that could be parsed as URLs, plus the latest completion time found
     * among the predictions, or {@code null} if none reported one.
     *
     * @param httpRequest the prepared prediction request, sent once per output
     * @param imageGenerationPrompt the prompt; its timeout (or {@code DEFAULT_TIMEOUT}) bounds the polling
     * @return the combined response of the successful predictions
     * @throws ApiResponseException if no prediction succeeds or waiting for them fails
     */
    private ReplicateFluxImageResponse sendFluxProImageRequest(HttpRequest httpRequest, ImageGenerationPrompt imageGenerationPrompt) {
        try {
            int outputCount = imageGenerationPrompt.getIntProperty(ReplicateRequestParameters.NUM_OUTPUTS).orElse(1);
            Map<Integer, RStatus<String>> statuses = new HashMap<>();
            for (int i = 0; i < outputCount; i++) {
                statuses.put(i, sendImageRequest(this.httpClient, httpRequest));
            }
            pollStatuses(statuses, httpRequest.headers(), imageGenerationPrompt.timeout().orElse(DEFAULT_TIMEOUT).toMillis());

            statuses.values().stream()
                    .filter(RStatus::isFailed)
                    .forEach(failed -> LOG.error(failed.logs()));
            if (statuses.values().stream().noneMatch(RStatus::isSuccess)) {
                throw new ApiResponseException("Failed Replicate image request");
            }

            List<URL> imageUrls = statuses.values().stream()
                    .filter(RStatus::isSuccess)
                    .map(RStatus::output)
                    .map(ReplicateService::toUrlNoThrow)
                    .filter(Objects::nonNull)
                    .toList();
            // Unwrap the Optional before converting to text; calling toString() on the Optional
            // itself would yield "Optional[...]". Predictions without a completion time are
            // skipped so the comparison cannot hit a null.
            String completedAt = statuses.values().stream()
                    .map(RStatus::completedAt)
                    .filter(Objects::nonNull)
                    .max((left, right) -> left.compareTo(right))
                    .map(Object::toString)
                    .orElse(null);
            return new ReplicateFluxImageResponse(imageGenerationPrompt, imageUrls, completedAt);
        } catch (InterruptedException e) {
            // Restore the flag so callers further up can still observe the interruption.
            Thread.currentThread().interrupt();
            throw new ApiResponseException(e, "Failed Replicate image request");
        } catch (ExecutionException e) {
            throw new ApiResponseException(e, "Failed Replicate image request");
        }
    }

    /**
     * Polls every in-progress prediction in {@code map} every 500 ms, replacing each entry with
     * its latest status, until none is in progress.
     *
     * <p>If that does not happen within {@code j} milliseconds, a cancellation request is sent for
     * every prediction still in progress and the entry is replaced with the cancellation response.
     * The polling task is cancelled on every exit path.
     *
     * @param map predictions by index; updated in place
     * @param httpHeaders headers copied onto the status and cancellation requests
     * @param j maximum time to wait, in milliseconds
     * @throws InterruptedException if the waiting thread is interrupted
     * @throws ExecutionException if a status fetch failed inside the polling task
     */
    private void pollStatuses(Map<Integer, RStatus<String>> map, HttpHeaders httpHeaders, long j) throws InterruptedException, ExecutionException {
        CompletableFuture<Void> allFinished = new CompletableFuture<>();
        // Captured here so the polling task can run with the caller's context.
        ContextUtils.Context context = ContextUtils.getContext();
        ScheduledFuture<?> poller = this.scheduler.scheduleAtFixedRate(() -> {
            try {
                ContextUtils.ContextCloser contextCloser = ContextUtils.ensureContext(context);
                try {
                    if (allFinished.isDone()) {
                        return;
                    }
                    // Snapshot the pending entries first; the map is updated while looping.
                    List<Map.Entry<Integer, RStatus<String>>> pending = map.entrySet().stream()
                            .filter(candidate -> candidate.getValue().isInProgress())
                            .toList();
                    for (Map.Entry<Integer, RStatus<String>> entry : pending) {
                        map.put(entry.getKey(), fetchRequestStatus(this.httpClient, entry.getValue().urls().get(), httpHeaders));
                    }
                    if (map.values().stream().noneMatch(RStatus::isInProgress)) {
                        allFinished.complete(null);
                    }
                } finally {
                    // Close the context on every path, including when a fetch throws.
                    if (contextCloser != null) {
                        contextCloser.close();
                    }
                }
            } catch (Exception e) {
                allFinished.completeExceptionally(e);
            }
        }, 0L, 500L, TimeUnit.MILLISECONDS);

        try {
            allFinished.get(j, TimeUnit.MILLISECONDS);
        } catch (TimeoutException e) {
            List<Map.Entry<Integer, RStatus<String>>> unfinished = map.entrySet().stream()
                    .filter(candidate -> candidate.getValue().isInProgress())
                    .toList();
            for (Map.Entry<Integer, RStatus<String>> entry : unfinished) {
                map.put(entry.getKey(), cancelRequest(this.httpClient, entry.getValue().urls().cancel(), httpHeaders));
            }
        } finally {
            poller.cancel(false);
        }
    }

    /**
     * Sends a prediction request and parses the response body into an {@link RStatus}.
     *
     * @param loggingHttpClient client used to send the request (logged as {@code ImageGenerationRequest})
     * @param httpRequest the prepared prediction request
     * @param <T> type of the prediction output
     * @return the status parsed from the response body
     * @throws ApiResponseException if the HTTP status is not 2xx, or sending/parsing fails
     */
    private static <T> RStatus<T> sendImageRequest(LoggingHttpClient loggingHttpClient, HttpRequest httpRequest) {
        try {
            HttpResponse<String> response = loggingHttpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString(), new LoggingHttpClient.LogParameters("ImageGenerationRequest"));
            int status = response.statusCode();
            if (status < 200 || status >= 300) {
                throw new ApiResponseException("Failed Replicate image request. HTTP status %d. Response:\n%s", status, response.body());
            }
            return Json.OBJECT_MAPPER.readValue(response.body(), new TypeReference<RStatus<T>>() {
            });
        } catch (InterruptedException e) {
            // Restore the flag so callers further up can still observe the interruption.
            Thread.currentThread().interrupt();
            throw new ApiResponseException(e, "Failed Replicate image request");
        } catch (IOException e) {
            throw new ApiResponseException(e, "Failed Replicate image request");
        }
    }

    /**
     * Fetches the current status of a prediction with a GET request.
     *
     * <p>Uses the plain two-argument {@code HttpClient.send}, i.e. without log parameters.
     *
     * @param httpClient client used to send the request
     * @param uri the URI to fetch the status from
     * @param httpHeaders headers to copy onto the request, value by value
     * @param <T> type of the prediction output
     * @return the status parsed from the response body
     * @throws ApiResponseException if the HTTP status is not 2xx, or sending/parsing fails
     */
    private static <T> RStatus<T> fetchRequestStatus(HttpClient httpClient, URI uri, HttpHeaders httpHeaders) {
        try {
            HttpRequest.Builder requestBuilder = HttpRequest.newBuilder(uri).GET();
            // A header may carry several values; each one is added under its header name.
            httpHeaders.map().forEach((name, values) ->
                    values.forEach(value -> requestBuilder.header(name, value)));

            HttpResponse<String> response = httpClient.send(requestBuilder.build(), HttpResponse.BodyHandlers.ofString());
            int status = response.statusCode();
            if (status < 200 || status >= 300) {
                throw new ApiResponseException("Failed Replicate status request. HTTP status %d. Response:\n%s", status, response.body());
            }
            return Json.OBJECT_MAPPER.readValue(response.body(), new TypeReference<RStatus<T>>() {
            });
        } catch (InterruptedException e) {
            // Restore the flag so callers further up can still observe the interruption.
            Thread.currentThread().interrupt();
            throw new ApiResponseException(e, "Failed Replicate status request");
        } catch (IOException e) {
            throw new ApiResponseException(e, "Failed Replicate status request");
        }
    }

    /**
     * Asks Replicate to cancel a prediction with a body-less POST request.
     *
     * <p>A 404 response is not treated as an error; the status is fetched instead.
     *
     * @param loggingHttpClient client used to send the request (logged as {@code ImageCancelRequest})
     * @param uri the cancellation URI
     * @param httpHeaders headers to copy onto the request, value by value
     * @param <T> type of the prediction output
     * @return the status parsed from the cancellation response, or the fetched status after a 404
     * @throws ApiResponseException if the HTTP status is neither 2xx nor 404, or sending/parsing fails
     */
    private static <T> RStatus<T> cancelRequest(LoggingHttpClient loggingHttpClient, URI uri, HttpHeaders httpHeaders) {
        LoggingHttpClient.LogParameters logParameters = new LoggingHttpClient.LogParameters("ImageCancelRequest");
        try {
            HttpRequest.Builder requestBuilder = HttpRequest.newBuilder(uri).POST(HttpRequest.BodyPublishers.noBody());
            // A header may carry several values; each one is added under its header name.
            httpHeaders.map().forEach((name, values) ->
                    values.forEach(value -> requestBuilder.header(name, value)));

            HttpResponse<String> response = loggingHttpClient.send(requestBuilder.build(), HttpResponse.BodyHandlers.ofString(), logParameters);
            int status = response.statusCode();
            if (status == 404) {
                // NOTE(review): this GET goes to the same URI that was just POSTed to (the
                // cancel URI), not to the prediction's get URI - confirm that is intended.
                return fetchRequestStatus(loggingHttpClient, uri, httpHeaders);
            }
            if (status < 200 || status >= 300) {
                throw new ApiResponseException("Failed Replicate cancellation request. HTTP status %d. Response:\n%s", status, response.body());
            }
            return Json.OBJECT_MAPPER.readValue(response.body(), new TypeReference<RStatus<T>>() {
            });
        } catch (InterruptedException e) {
            // Restore the flag so callers further up can still observe the interruption.
            Thread.currentThread().interrupt();
            throw new ApiResponseException(e, "Failed Replicate cancellation request");
        } catch (IOException e) {
            throw new ApiResponseException(e, "Failed Replicate cancellation request");
        }
    }

    /**
     * Parses a string into a {@link URL} without throwing.
     *
     * <p>Callers filter out the {@code null} results, so a malformed value would otherwise
     * disappear from the response without a trace; it is therefore logged as a warning.
     *
     * @param str the text to parse
     * @return the parsed URL, or {@code null} if {@code str} is not a valid URL
     */
    private static URL toUrlNoThrow(String str) {
        try {
            return new URL(str);
        } catch (MalformedURLException e) {
            LOG.warn("Ignoring malformed Replicate output URL: {}", str);
            return null;
        }
    }

    /**
     * Image variations are not supported by this service.
     *
     * @param imageVariationPrompt ignored
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    @Override
    public ImageResponse send(ImageVariationPrompt imageVariationPrompt) {
        final String message = "Unsupported method 'send ImageVariationPrompt'";
        throw new UnsupportedOperationException(message);
    }
}
