package com.alibaba.cloud.ai.dashscope.image;

import com.alibaba.cloud.ai.dashscope.api.DashScopeImageApi;
import com.alibaba.cloud.ai.dashscope.common.DashScopeApiConstants;
import io.micrometer.observation.Observation;
import io.micrometer.observation.ObservationRegistry;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.image.Image;
import org.springframework.ai.image.ImageGeneration;
import org.springframework.ai.image.ImageMessage;
import org.springframework.ai.image.ImageModel;
import org.springframework.ai.image.ImageOptions;
import org.springframework.ai.image.ImagePrompt;
import org.springframework.ai.image.ImageResponse;
import org.springframework.ai.image.ImageResponseMetadata;
import org.springframework.ai.image.observation.DefaultImageModelObservationConvention;
import org.springframework.ai.image.observation.ImageModelObservationContext;
import org.springframework.ai.image.observation.ImageModelObservationConvention;
import org.springframework.ai.image.observation.ImageModelObservationDocumentation;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.http.ResponseEntity;
import org.springframework.retry.backoff.FixedBackOffPolicy;
import org.springframework.retry.policy.SimpleRetryPolicy;
import org.springframework.retry.support.RetryTemplate;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;

/* loaded from: input_file:com/alibaba/cloud/ai/dashscope/image/DashScopeImageModel.class */
/**
 * {@link ImageModel} implementation that generates images through {@link DashScopeImageApi}.
 *
 * <p>A call first submits a generation task and then polls the task result through a
 * {@link RetryTemplate}. Polling ends when the task reports {@code SUCCEEDED}, {@code FAILED}
 * or {@code UNKNOWN}; any other state is treated as "still pending" and retried. When the
 * retry attempts are exhausted an empty response flagged {@code TIMED_OUT} is returned.
 */
public class DashScopeImageModel implements ImageModel {

    private static final Logger logger = LoggerFactory.getLogger(DashScopeImageModel.class);

    /** Model name used by the builder defaults and when a prompt carries no options. */
    private static final String DEFAULT_MODEL = "wanx-v1";

    /** Maximum number of polling attempts handed to {@link SimpleRetryPolicy}. */
    private static final int MAX_RETRY_COUNT = 10;

    /** Fixed pause between two polling attempts, in milliseconds. */
    private static final long POLL_INTERVAL_MILLIS = 15000L;

    /** Fallback convention passed to the observation when no custom one applies. */
    private static final ImageModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultImageModelObservationConvention();

    private final DashScopeImageApi dashScopeImageApi;

    private final DashScopeImageOptions defaultOptions;

    private final RetryTemplate retryTemplate;

    private final ObservationRegistry observationRegistry;

    private ImageModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;

    /**
     * Creates a model with the given default options and retry template and a no-op
     * observation registry.
     */
    public DashScopeImageModel(DashScopeImageApi dashScopeImageApi, DashScopeImageOptions dashScopeImageOptions, RetryTemplate retryTemplate) {
        this(dashScopeImageApi, dashScopeImageOptions, retryTemplate, ObservationRegistry.NOOP);
    }

    /**
     * Creates a model that uses {@code DashScopeImageApi.DEFAULT_IMAGE_MODEL}, the default
     * retry template and a no-op observation registry.
     */
    public DashScopeImageModel(DashScopeImageApi dashScopeImageApi) {
        this(dashScopeImageApi, DashScopeImageOptions.builder().withModel(DashScopeImageApi.DEFAULT_IMAGE_MODEL).build(), RetryUtils.DEFAULT_RETRY_TEMPLATE, ObservationRegistry.NOOP);
    }

    /**
     * Creates a model with the given default options, the default retry template and a
     * no-op observation registry.
     */
    public DashScopeImageModel(DashScopeImageApi dashScopeImageApi, DashScopeImageOptions dashScopeImageOptions) {
        this(dashScopeImageApi, dashScopeImageOptions, RetryUtils.DEFAULT_RETRY_TEMPLATE, ObservationRegistry.NOOP);
    }

    /**
     * Creates a model that uses {@code DashScopeImageApi.DEFAULT_IMAGE_MODEL}, the default
     * retry template and the given observation registry.
     */
    public DashScopeImageModel(DashScopeImageApi dashScopeImageApi, ObservationRegistry observationRegistry) {
        this(dashScopeImageApi, DashScopeImageOptions.builder().withModel(DashScopeImageApi.DEFAULT_IMAGE_MODEL).build(), RetryUtils.DEFAULT_RETRY_TEMPLATE, observationRegistry);
    }

    /**
     * Creates a fully configured model.
     *
     * @param dashScopeImageApi client used to submit and query generation tasks
     * @param dashScopeImageOptions default options merged into every request
     * @param retryTemplate template used to poll the task result; its retry and back-off
     *        policies are replaced by the polling policies of this class, except that the
     *        shared {@link RetryUtils#DEFAULT_RETRY_TEMPLATE} is never modified
     * @param observationRegistry registry the call observations are reported to
     * @throws IllegalArgumentException if any argument is {@code null}
     */
    public DashScopeImageModel(DashScopeImageApi dashScopeImageApi, DashScopeImageOptions dashScopeImageOptions, RetryTemplate retryTemplate, ObservationRegistry observationRegistry) {
        Assert.notNull(dashScopeImageApi, "DashScopeImageApi must not be null");
        Assert.notNull(dashScopeImageOptions, "options must not be null");
        Assert.notNull(retryTemplate, "retryTemplate must not be null");
        Assert.notNull(observationRegistry, "observationRegistry must not be null");
        this.dashScopeImageApi = dashScopeImageApi;
        this.defaultOptions = dashScopeImageOptions;
        this.retryTemplate = configurePolling(retryTemplate);
        this.observationRegistry = observationRegistry;
    }

    /**
     * Returns a new {@link Builder}.
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Applies the polling policies (at most {@link #MAX_RETRY_COUNT} attempts, fixed
     * {@link #POLL_INTERVAL_MILLIS} back-off) to the template used by this model.
     *
     * <p>{@link RetryUtils#DEFAULT_RETRY_TEMPLATE} is a static instance that other code may
     * use as well; reconfiguring it would change retry behaviour outside this class, so a
     * private template is created in that case. A caller-supplied template is configured in
     * place, as before.
     * NOTE(review): listeners registered on the shared default template, if any, are not
     * carried over to the private template.
     */
    private static RetryTemplate configurePolling(RetryTemplate retryTemplate) {
        RetryTemplate pollingTemplate = (retryTemplate == RetryUtils.DEFAULT_RETRY_TEMPLATE) ? new RetryTemplate() : retryTemplate;
        FixedBackOffPolicy backOffPolicy = new FixedBackOffPolicy();
        backOffPolicy.setBackOffPeriod(POLL_INTERVAL_MILLIS);
        pollingTemplate.setRetryPolicy(new SimpleRetryPolicy(MAX_RETRY_COUNT));
        pollingTemplate.setBackOffPolicy(backOffPolicy);
        return pollingTemplate;
    }

    /**
     * Submits an image generation task for the prompt and polls until it finishes.
     *
     * @param imagePrompt prompt to generate images for; must contain at least one message
     * @return the generated images on success; otherwise an empty result list whose metadata
     *         carries the task status ({@code NO_TASK_ID} when submission yielded no task id,
     *         {@code TIMED_OUT} when polling was exhausted, or the status reported by the API)
     * @throws IllegalArgumentException if the prompt is {@code null} or has no messages
     */
    @Override
    public ImageResponse call(ImagePrompt imagePrompt) {
        Assert.notNull(imagePrompt, "Prompt must not be null");
        Assert.isTrue(!CollectionUtils.isEmpty(imagePrompt.getInstructions()), "Prompt messages must not be empty");

        String taskId = submitImageGenTask(imagePrompt);
        if (taskId == null) {
            return new ImageResponse(List.of(), toMetadataEmpty());
        }

        ImageModelObservationContext observationContext = ImageModelObservationContext.builder().imagePrompt(imagePrompt).provider(DashScopeApiConstants.PROVIDER_NAME).build();
        Observation observation = ImageModelObservationDocumentation.IMAGE_MODEL_OPERATION.observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, this.observationRegistry);

        return observation.observe(() -> {
            return this.retryTemplate.execute(retryContext -> {
                observation.lowCardinalityKeyValue("retry.attempt", String.valueOf(retryContext.getRetryCount()));
                return pollTask(observation, taskId);
            }, recoveryContext -> {
                // Reached once the retry policy gives up: report a timeout instead of failing.
                observation.lowCardinalityKeyValue("timeout", "true");
                return new ImageResponse(List.of(), toMetadataTimeout(taskId));
            });
        });
    }

    /**
     * Queries the task once and maps a terminal status to a response.
     *
     * @throws RuntimeException when the task has not reached a terminal status yet, which
     *         makes the surrounding retry template try again
     */
    private ImageResponse pollTask(Observation observation, String taskId) {
        DashScopeImageApi.DashScopeImageAsyncReponse task = getImageGenTask(taskId);
        DashScopeImageApi.DashScopeImageAsyncReponse.DashScopeImageAsyncReponseOutput output = (task == null) ? null : task.output();
        String taskStatus = (output == null) ? null : output.taskStatus();
        // A missing output or status is handled like a pending task and retried.
        if (taskStatus != null) {
            observation.lowCardinalityKeyValue("task.status", taskStatus);
            switch (taskStatus) {
                case "SUCCEEDED":
                    return toImageResponse(task);
                case "FAILED":
                case "UNKNOWN":
                    return new ImageResponse(List.of(), toMetadata(task));
                default:
                    break;
            }
        }
        throw new RuntimeException("Image generation still pending");
    }

    /**
     * Submits a generation task built from the prompt and the merged options.
     *
     * @param imagePrompt prompt whose first message supplies the text
     * @return the task id, or {@code null} when the API returned no response body or a body
     *         without an output section
     */
    public String submitImageGenTask(ImagePrompt imagePrompt) {
        DashScopeImageOptions imageOptions = toImageOptions(imagePrompt.getOptions());
        logger.debug("Image options: {}", imageOptions);
        ResponseEntity<DashScopeImageApi.DashScopeImageAsyncReponse> response = this.dashScopeImageApi.submitImageGenTask(constructImageRequest(imagePrompt, imageOptions));
        DashScopeImageApi.DashScopeImageAsyncReponse body = (response == null) ? null : response.getBody();
        if (body != null && body.output() != null) {
            return body.output().taskId();
        }
        logger.warn("Submit imageGen error,request: {}", imagePrompt);
        return null;
    }

    /**
     * Converts the prompt options to {@link DashScopeImageOptions} and merges them with the
     * default options of this model. When the prompt has no options, options holding only
     * {@link #DEFAULT_MODEL} are merged instead.
     * NOTE(review): the precedence applied by {@code ModelOptionsUtils.merge} is not visible
     * here; presumably values of the first argument win - confirm.
     */
    private DashScopeImageOptions toImageOptions(ImageOptions imageOptions) {
        DashScopeImageOptions runtimeOptions;
        if (Objects.nonNull(imageOptions)) {
            runtimeOptions = ModelOptionsUtils.copyToTarget(imageOptions, ImageOptions.class, DashScopeImageOptions.class);
        }
        else {
            runtimeOptions = DashScopeImageOptions.builder().withModel(DEFAULT_MODEL).build();
        }
        return ModelOptionsUtils.merge(runtimeOptions, this.defaultOptions, DashScopeImageOptions.class);
    }

    /**
     * Fetches the current state of a generation task.
     *
     * @param str id of the task to query
     * @return the response body, or {@code null} when the API returned no body
     */
    public DashScopeImageApi.DashScopeImageAsyncReponse getImageGenTask(String str) {
        ResponseEntity<DashScopeImageApi.DashScopeImageAsyncReponse> response = this.dashScopeImageApi.getImageGenTaskResult(str);
        if (response == null || response.getBody() == null) {
            logger.warn("No image response returned for taskId: {}", str);
            return null;
        }
        return response.getBody();
    }

    /**
     * Returns the default options this model was created with.
     */
    public DashScopeImageOptions getOptions() {
        return this.defaultOptions;
    }

    /**
     * Maps every result of the task to an {@link ImageGeneration} holding the result URL;
     * a task without results yields an empty list.
     */
    private ImageResponse toImageResponse(DashScopeImageApi.DashScopeImageAsyncReponse dashScopeImageAsyncReponse) {
        List<DashScopeImageApi.DashScopeImageAsyncReponse.DashScopeImageAsyncReponseResult> results = dashScopeImageAsyncReponse.output().results();
        List<ImageGeneration> generations;
        if (results == null) {
            generations = List.of();
        }
        else {
            // Only the URL is populated; the second Image argument is left null.
            generations = results.stream().map(result -> new ImageGeneration(new Image(result.url(), null))).toList();
        }
        return new ImageResponse(generations, toMetadata(dashScopeImageAsyncReponse));
    }

    /**
     * Builds the API request from the text of the first prompt message and the given options.
     */
    private DashScopeImageApi.DashScopeImageRequest constructImageRequest(ImagePrompt imagePrompt, DashScopeImageOptions dashScopeImageOptions) {
        ImageMessage firstMessage = imagePrompt.getInstructions().get(0);
        DashScopeImageApi.DashScopeImageRequest.DashScopeImageRequestInput input = new DashScopeImageApi.DashScopeImageRequest.DashScopeImageRequestInput(
                firstMessage.getText(),
                dashScopeImageOptions.getNegativePrompt(),
                dashScopeImageOptions.getRefImg(),
                dashScopeImageOptions.getFunction(),
                dashScopeImageOptions.getBaseImageUrl(),
                dashScopeImageOptions.getMaskImageUrl(),
                dashScopeImageOptions.getSketchImageUrl());
        DashScopeImageApi.DashScopeImageRequest.DashScopeImageRequestParameter parameter = new DashScopeImageApi.DashScopeImageRequest.DashScopeImageRequestParameter(
                dashScopeImageOptions.getStyle(),
                dashScopeImageOptions.getSize(),
                dashScopeImageOptions.getN(),
                dashScopeImageOptions.getSeed(),
                dashScopeImageOptions.getRefStrength(),
                dashScopeImageOptions.getRefMode(),
                dashScopeImageOptions.getPromptExtend(),
                dashScopeImageOptions.getWatermark(),
                dashScopeImageOptions.getSketchWeight(),
                dashScopeImageOptions.getSketchExtraction(),
                dashScopeImageOptions.getSketchColor(),
                dashScopeImageOptions.getMaskColor());
        return new DashScopeImageApi.DashScopeImageRequest(dashScopeImageOptions.getModel(), input, parameter);
    }

    /**
     * Collects request id, task status, and - when present - usage, task metrics, error code
     * and error message of a task response into response metadata.
     */
    private ImageResponseMetadata toMetadata(DashScopeImageApi.DashScopeImageAsyncReponse dashScopeImageAsyncReponse) {
        DashScopeImageApi.DashScopeImageAsyncReponse.DashScopeImageAsyncReponseOutput output = dashScopeImageAsyncReponse.output();
        DashScopeImageApi.DashScopeImageAsyncReponse.DashScopeImageAsyncReponseTaskMetrics taskMetrics = output.taskMetrics();
        DashScopeImageApi.DashScopeImageAsyncReponse.DashScopeImageAsyncReponseUsage usage = dashScopeImageAsyncReponse.usage();

        ImageResponseMetadata metadata = new ImageResponseMetadata();
        Optional.ofNullable(usage).map(u -> u.imageCount()).ifPresent(count -> metadata.put("imageCount", count));
        if (taskMetrics != null) {
            metadata.put("taskTotal", taskMetrics.total());
            metadata.put("taskSucceeded", taskMetrics.SUCCEEDED());
            metadata.put("taskFailed", taskMetrics.FAILED());
        }
        metadata.put("requestId", dashScopeImageAsyncReponse.requestId());
        metadata.put("taskStatus", output.taskStatus());
        Optional.ofNullable(output.code()).ifPresent(code -> metadata.put(DashScopeApiConstants.CODE, code));
        Optional.ofNullable(output.message()).ifPresent(message -> metadata.put(DashScopeApiConstants.MESSAGE, message));
        return metadata;
    }

    /**
     * Metadata for the case where task submission produced no task id.
     */
    private ImageResponseMetadata toMetadataEmpty() {
        ImageResponseMetadata metadata = new ImageResponseMetadata();
        metadata.put("taskStatus", "NO_TASK_ID");
        return metadata;
    }

    /**
     * Metadata for the case where polling was exhausted before the task finished.
     */
    private ImageResponseMetadata toMetadataTimeout(String str) {
        ImageResponseMetadata metadata = new ImageResponseMetadata();
        metadata.put("taskId", str);
        metadata.put("taskStatus", "TIMED_OUT");
        return metadata;
    }

    /**
     * Replaces the observation convention used for subsequent calls.
     *
     * @param imageModelObservationConvention the convention to use; must not be {@code null}
     * @throws IllegalArgumentException if the argument is {@code null}
     */
    public void setObservationConvention(ImageModelObservationConvention imageModelObservationConvention) {
        Assert.notNull(imageModelObservationConvention, "observationConvention cannot be null");
        this.observationConvention = imageModelObservationConvention;
    }

    /**
     * Fluent builder for {@link DashScopeImageModel}. Unless overridden it uses options with
     * model {@code wanx-v1} and {@code n = 1}, {@link RetryUtils#DEFAULT_RETRY_TEMPLATE} and
     * {@link ObservationRegistry#NOOP}.
     */
    public static final class Builder {

        private DashScopeImageApi dashScopeImageApi;

        private DashScopeImageOptions defaultOptions = DashScopeImageOptions.builder().withModel(DashScopeImageModel.DEFAULT_MODEL).withN(1).build();

        private RetryTemplate retryTemplate = RetryUtils.DEFAULT_RETRY_TEMPLATE;

        private ObservationRegistry observationRegistry = ObservationRegistry.NOOP;

        private Builder() {
        }

        /** Sets the API client; required, {@link #build()} rejects {@code null}. */
        public Builder dashScopeApi(DashScopeImageApi dashScopeImageApi) {
            this.dashScopeImageApi = dashScopeImageApi;
            return this;
        }

        /** Sets the default options merged into every request. */
        public Builder defaultOptions(DashScopeImageOptions dashScopeImageOptions) {
            this.defaultOptions = dashScopeImageOptions;
            return this;
        }

        /** Sets the retry template used to poll the task result. */
        public Builder retryTemplate(RetryTemplate retryTemplate) {
            this.retryTemplate = retryTemplate;
            return this;
        }

        /** Sets the registry observations are reported to. */
        public Builder observationRegistry(ObservationRegistry observationRegistry) {
            this.observationRegistry = observationRegistry;
            return this;
        }

        /**
         * Creates the model from the collected settings.
         *
         * @throws IllegalArgumentException if any setting is {@code null}
         */
        public DashScopeImageModel build() {
            return new DashScopeImageModel(this.dashScopeImageApi, this.defaultOptions, this.retryTemplate, this.observationRegistry);
        }
    }
}
