package org.springframework.ai.rag.preretrieval.query.expansion;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.ai.rag.Query;
import org.springframework.ai.util.PromptAssert;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;

/**
 * A {@link QueryExpander} that asks a chat model to rewrite the input query into several
 * variants. Each variant becomes a separate {@link Query} that inherits everything from the
 * input except its text.
 * <p>
 * The prompt template must contain the {@code number} and {@code query} placeholders. They are
 * bound to the requested number of variants and to the original query text.
 * <p>
 * The model is expected to return one variant per line. Blank lines and surrounding whitespace
 * are ignored. If the response is {@code null}, or does not contain exactly the requested number
 * of non-blank lines, the input query is returned unchanged as a single-element list.
 */
public final class MultiQueryExpander implements QueryExpander {

    private static final Logger logger = LoggerFactory.getLogger(MultiQueryExpander.class);

    private static final PromptTemplate DEFAULT_PROMPT_TEMPLATE = new PromptTemplate(
            "You are an expert at information retrieval and search optimization.\n"
                    + "Your task is to generate {number} different versions of the given query.\n"
                    + "\n"
                    + "Each variant must cover different perspectives or aspects of the topic,\n"
                    + "while maintaining the core intent of the original query. The goal is to\n"
                    + "expand the search space and improve the chances of finding relevant information.\n"
                    + "\n"
                    + "Do not explain your choices or add any other text.\n"
                    + "Provide the query variants separated by newlines.\n"
                    + "\n"
                    + "Original query: {query}\n"
                    + "\n"
                    + "Query variants:\n");

    private static final boolean DEFAULT_INCLUDE_ORIGINAL = true;

    private static final int DEFAULT_NUMBER_OF_QUERIES = 3;

    private final ChatClient chatClient;

    private final PromptTemplate promptTemplate;

    private final boolean includeOriginal;

    private final int numberOfQueries;

    /**
     * Fluent builder for {@link MultiQueryExpander}. Any option left unset (or set to
     * {@code null}) falls back to the expander's default.
     */
    public static final class Builder {

        @Nullable
        private ChatClient.Builder chatClientBuilder;

        @Nullable
        private PromptTemplate promptTemplate;

        @Nullable
        private Boolean includeOriginal;

        @Nullable
        private Integer numberOfQueries;

        private Builder() {
        }

        /**
         * Sets the builder used to create the {@link ChatClient} that generates the variants.
         * This option is required.
         * @param chatClientBuilder the chat client builder
         * @return this builder
         */
        public Builder chatClientBuilder(ChatClient.Builder chatClientBuilder) {
            this.chatClientBuilder = chatClientBuilder;
            return this;
        }

        /**
         * Sets the prompt template. It must contain the {@code number} and {@code query}
         * placeholders.
         * @param promptTemplate the template, or {@code null} to use the default
         * @return this builder
         */
        public Builder promptTemplate(PromptTemplate promptTemplate) {
            this.promptTemplate = promptTemplate;
            return this;
        }

        /**
         * Sets whether the original query is placed first in the expansion result.
         * @param includeOriginal the flag, or {@code null} to use the default ({@code true})
         * @return this builder
         */
        public Builder includeOriginal(Boolean includeOriginal) {
            this.includeOriginal = includeOriginal;
            return this;
        }

        /**
         * Sets how many variants to request from the model.
         * @param numberOfQueries a positive count, or {@code null} to use the default ({@code 3})
         * @return this builder
         */
        public Builder numberOfQueries(Integer numberOfQueries) {
            this.numberOfQueries = numberOfQueries;
            return this;
        }

        /**
         * Creates the expander from the configured options.
         * @return a new {@link MultiQueryExpander}
         */
        public MultiQueryExpander build() {
            return new MultiQueryExpander(this.chatClientBuilder, this.promptTemplate, this.includeOriginal,
                    this.numberOfQueries);
        }

    }

    /**
     * Creates a new expander.
     * @param chatClientBuilder builder for the chat client used to generate variants; must not be
     * {@code null}
     * @param promptTemplate template containing the {@code number} and {@code query}
     * placeholders, or {@code null} to use the default template
     * @param includeOriginal whether to put the original query first in the result, or
     * {@code null} for the default ({@code true})
     * @param numberOfQueries number of variants to request, or {@code null} for the default
     * ({@code 3}); must be positive
     * @throws IllegalArgumentException if {@code chatClientBuilder} is {@code null} or
     * {@code numberOfQueries} is not positive
     */
    public MultiQueryExpander(ChatClient.Builder chatClientBuilder, @Nullable PromptTemplate promptTemplate,
            @Nullable Boolean includeOriginal, @Nullable Integer numberOfQueries) {
        Assert.notNull(chatClientBuilder, "chatClientBuilder cannot be null");

        this.promptTemplate = promptTemplate != null ? promptTemplate : DEFAULT_PROMPT_TEMPLATE;
        this.includeOriginal = includeOriginal != null ? includeOriginal : DEFAULT_INCLUDE_ORIGINAL;
        this.numberOfQueries = numberOfQueries != null ? numberOfQueries : DEFAULT_NUMBER_OF_QUERIES;

        // A non-positive count can never match the parsed response, so every call would
        // silently fall back to the input query. Reject it up front instead.
        Assert.isTrue(this.numberOfQueries > 0, "numberOfQueries must be greater than 0");
        PromptAssert.templateHasRequiredPlaceholders(this.promptTemplate, "number", "query");

        this.chatClient = chatClientBuilder.build();
    }

    /**
     * Expands the given query into model-generated variants.
     * @param query the query to expand; must not be {@code null}
     * @return if the model returned exactly the requested number of non-blank lines, one
     * {@link Query} per variant, preceded by {@code query} itself when {@code includeOriginal}
     * is enabled; otherwise a single-element list containing {@code query}
     * @throws IllegalArgumentException if {@code query} is {@code null}
     */
    @Override
    public List<Query> expand(Query query) {
        Assert.notNull(query, "query cannot be null");

        logger.debug("Generating {} query variants", this.numberOfQueries);

        String response = this.chatClient.prompt()
            .user(user -> user.text(this.promptTemplate.getTemplate())
                .param("number", this.numberOfQueries)
                .param("query", query.text()))
            .call()
            .content();

        if (response == null) {
            logger.warn("Query expansion result is null. Returning the input query unchanged.");
            return List.of(query);
        }

        // Count only meaningful lines. Blank separator lines must neither cause a valid
        // response to be rejected nor let a short response slip through the check.
        List<String> variants = parseVariants(response);
        if (variants.size() != this.numberOfQueries) {
            logger.warn(
                    "Query expansion result does not contain the requested {} variants. Returning the input query unchanged.",
                    this.numberOfQueries);
            return List.of(query);
        }

        // Build an explicitly mutable list. Collectors.toList() does not guarantee
        // mutability, so calling add() on its result is unsafe.
        List<Query> queries = new ArrayList<>(variants.size() + 1);
        if (this.includeOriginal) {
            logger.debug("Including the original query in the result");
            queries.add(query);
        }
        for (String variant : variants) {
            queries.add(query.mutate().text(variant).build());
        }
        return queries;
    }

    /**
     * Splits a model response into trimmed, non-blank lines. Any line terminator is accepted
     * ({@code \n}, {@code \r\n}, {@code \r}, ...), so CRLF output does not leave a stray
     * {@code \r} in the variant text.
     */
    private static List<String> parseVariants(String response) {
        return Arrays.stream(response.split("\\R"))
            .map(String::trim)
            .filter(StringUtils::hasText)
            .collect(Collectors.toList());
    }

    /**
     * Returns a new builder for {@link MultiQueryExpander}.
     * @return a fresh builder with no options set
     */
    public static Builder builder() {
        return new Builder();
    }

}
