package org.springframework.ai.chat.client.advisor;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import org.springframework.ai.chat.client.advisor.api.AdvisedRequest;
import org.springframework.ai.chat.client.advisor.api.AdvisedResponse;
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor;
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain;
import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisor;
import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisorChain;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.ai.document.Document;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.ai.vectorstore.filter.Filter;
import org.springframework.ai.vectorstore.filter.FilterExpressionTextParser;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Schedulers;

/* loaded from: input_file:org/springframework/ai/chat/client/advisor/QuestionAnswerAdvisor.class */
/**
 * Advisor that enriches a chat request with documents retrieved from a {@link VectorStore}.
 *
 * <p>Before the request is passed down the advisor chain, a similarity search is run using
 * the rendered user text as the query. The retrieved documents are stored in the advise
 * context under {@link #RETRIEVED_DOCUMENTS}, their text is exposed to the prompt as the
 * {@code question_answer_context} user parameter, and the configured advise text is appended
 * to the user text. After the chain returns, the retrieved documents are also attached to
 * the {@link ChatResponse} metadata under {@link #RETRIEVED_DOCUMENTS}.
 *
 * <p>The filter expression used for the search can be overridden per request by placing a
 * textual expression in the advise context under {@link #FILTER_EXPRESSION}.
 */
public class QuestionAnswerAdvisor implements CallAroundAdvisor, StreamAroundAdvisor {

    /** Advise-context and response-metadata key under which the retrieved documents are stored. */
    public static final String RETRIEVED_DOCUMENTS = "qa_retrieved_documents";

    /** Advise-context key holding an optional textual filter expression for the search. */
    public static final String FILTER_EXPRESSION = "qa_filter_expression";

    /** Default text appended to the user text; contains the {@code question_answer_context} placeholder. */
    private static final String DEFAULT_USER_TEXT_ADVISE = "\nContext information is below, surrounded by ---------------------\n\n---------------------\n{question_answer_context}\n---------------------\n\nGiven the context and provided history information and not prior knowledge,\nreply to the user comment. If the answer is not in the context, inform\nthe user that you can't answer the question.\n";

    /** Order value used when none is supplied. */
    private static final int DEFAULT_ORDER = 0;

    private final VectorStore vectorStore;
    private final String userTextAdvise;
    private final SearchRequest searchRequest;
    private final boolean protectFromBlocking;
    private final int order;

    /**
     * Fluent builder for {@link QuestionAnswerAdvisor}. Obtain an instance via
     * {@link QuestionAnswerAdvisor#builder(VectorStore)}.
     */
    public static final class Builder {
        private final VectorStore vectorStore;
        private SearchRequest searchRequest = SearchRequest.builder().build();
        private String userTextAdvise = QuestionAnswerAdvisor.DEFAULT_USER_TEXT_ADVISE;
        private boolean protectFromBlocking = true;
        private int order = QuestionAnswerAdvisor.DEFAULT_ORDER;

        private Builder(VectorStore vectorStore) {
            Assert.notNull(vectorStore, "The vectorStore must not be null!");
            this.vectorStore = vectorStore;
        }

        /**
         * Sets the base search request used for the similarity search.
         *
         * @param searchRequest the search request; must not be {@code null}
         * @return this builder
         */
        public Builder searchRequest(SearchRequest searchRequest) {
            Assert.notNull(searchRequest, "The searchRequest must not be null!");
            this.searchRequest = searchRequest;
            return this;
        }

        /**
         * Sets the text appended to the user text of every advised request.
         *
         * @param str the advise text; must contain text
         * @return this builder
         */
        public Builder userTextAdvise(String str) {
            Assert.hasText(str, "The userTextAdvise must not be empty!");
            this.userTextAdvise = str;
            return this;
        }

        /**
         * Sets whether, for streaming calls, the retrieval step is moved onto
         * {@link Schedulers#boundedElastic()} instead of running on the calling thread.
         *
         * @param z {@code true} to run the retrieval step on the bounded-elastic scheduler
         * @return this builder
         */
        public Builder protectFromBlocking(boolean z) {
            this.protectFromBlocking = z;
            return this;
        }

        /**
         * Sets the value returned by {@link QuestionAnswerAdvisor#getOrder()}.
         *
         * @param i the order value
         * @return this builder
         */
        public Builder order(int i) {
            this.order = i;
            return this;
        }

        /**
         * @deprecated use {@link #searchRequest(SearchRequest)} instead
         */
        @Deprecated(forRemoval = true, since = "1.0.0-M5")
        public Builder withSearchRequest(SearchRequest searchRequest) {
            return searchRequest(searchRequest);
        }

        /**
         * @deprecated use {@link #userTextAdvise(String)} instead
         */
        @Deprecated(forRemoval = true, since = "1.0.0-M5")
        public Builder withUserTextAdvise(String str) {
            return userTextAdvise(str);
        }

        /**
         * @deprecated use {@link #protectFromBlocking(boolean)} instead
         */
        @Deprecated(forRemoval = true, since = "1.0.0-M5")
        public Builder withProtectFromBlocking(boolean z) {
            return protectFromBlocking(z);
        }

        /**
         * @deprecated use {@link #order(int)} instead
         */
        @Deprecated(forRemoval = true, since = "1.0.0-M5")
        public Builder withOrder(int i) {
            return order(i);
        }

        /**
         * Creates the advisor from the values collected so far.
         *
         * @return a new {@link QuestionAnswerAdvisor}
         */
        public QuestionAnswerAdvisor build() {
            return new QuestionAnswerAdvisor(this.vectorStore, this.searchRequest, this.userTextAdvise, this.protectFromBlocking, this.order);
        }
    }

    /**
     * Creates an advisor with a default search request and the default advise text.
     *
     * @param vectorStore the store to search; must not be {@code null}
     */
    public QuestionAnswerAdvisor(VectorStore vectorStore) {
        this(vectorStore, SearchRequest.builder().build(), DEFAULT_USER_TEXT_ADVISE);
    }

    /**
     * Creates an advisor with the given search request and the default advise text.
     *
     * @param vectorStore the store to search; must not be {@code null}
     * @param searchRequest the base search request; must not be {@code null}
     */
    public QuestionAnswerAdvisor(VectorStore vectorStore, SearchRequest searchRequest) {
        this(vectorStore, searchRequest, DEFAULT_USER_TEXT_ADVISE);
    }

    /**
     * Creates an advisor with blocking protection enabled.
     *
     * @param vectorStore the store to search; must not be {@code null}
     * @param searchRequest the base search request; must not be {@code null}
     * @param str the advise text appended to the user text; must contain text
     */
    public QuestionAnswerAdvisor(VectorStore vectorStore, SearchRequest searchRequest, String str) {
        this(vectorStore, searchRequest, str, true);
    }

    /**
     * Creates an advisor with the default order.
     *
     * @param vectorStore the store to search; must not be {@code null}
     * @param searchRequest the base search request; must not be {@code null}
     * @param str the advise text appended to the user text; must contain text
     * @param z whether the streaming retrieval step runs on the bounded-elastic scheduler
     */
    public QuestionAnswerAdvisor(VectorStore vectorStore, SearchRequest searchRequest, String str, boolean z) {
        this(vectorStore, searchRequest, str, z, DEFAULT_ORDER);
    }

    /**
     * Creates a fully configured advisor.
     *
     * @param vectorStore the store to search; must not be {@code null}
     * @param searchRequest the base search request; must not be {@code null}
     * @param str the advise text appended to the user text; must contain text
     * @param z whether the streaming retrieval step runs on the bounded-elastic scheduler
     * @param i the value returned by {@link #getOrder()}
     * @throws IllegalArgumentException if a required argument is missing or empty
     */
    public QuestionAnswerAdvisor(VectorStore vectorStore, SearchRequest searchRequest, String str, boolean z, int i) {
        Assert.notNull(vectorStore, "The vectorStore must not be null!");
        Assert.notNull(searchRequest, "The searchRequest must not be null!");
        Assert.hasText(str, "The userTextAdvise must not be empty!");
        this.vectorStore = vectorStore;
        this.searchRequest = searchRequest;
        this.userTextAdvise = str;
        this.protectFromBlocking = z;
        this.order = i;
    }

    /**
     * Starts building an advisor for the given vector store.
     *
     * @param vectorStore the store to search; must not be {@code null}
     * @return a new builder
     */
    public static Builder builder(VectorStore vectorStore) {
        return new Builder(vectorStore);
    }

    /**
     * Returns the simple class name of this advisor.
     */
    @Override // org.springframework.ai.chat.client.advisor.api.Advisor
    public String getName() {
        return getClass().getSimpleName();
    }

    /**
     * Returns the order value this advisor was configured with.
     */
    public int getOrder() {
        return this.order;
    }

    /**
     * Enriches the request with retrieved documents, delegates to the next advisor in the
     * chain, and attaches the retrieved documents to the response metadata.
     */
    @Override // org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor
    public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain callAroundAdvisorChain) {
        AdvisedRequest enrichedRequest = before(advisedRequest);
        return after(callAroundAdvisorChain.nextAroundCall(enrichedRequest));
    }

    /**
     * Streaming counterpart of {@link #aroundCall}. Only stream elements that carry a finish
     * reason get the retrieved documents attached; all other elements pass through untouched.
     */
    @Override // org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisor
    public Flux<AdvisedResponse> aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain streamAroundAdvisorChain) {
        // When protection is enabled the retrieval step is deferred to subscription time and
        // executed on the bounded-elastic scheduler; otherwise it runs right here, on the caller.
        Flux<AdvisedResponse> responses = this.protectFromBlocking
                ? Mono.just(advisedRequest)
                    .publishOn(Schedulers.boundedElastic())
                    .map(this::before)
                    .flatMapMany(enrichedRequest -> streamAroundAdvisorChain.nextAroundStream(enrichedRequest))
                : streamAroundAdvisorChain.nextAroundStream(before(advisedRequest));

        Predicate<AdvisedResponse> hasFinishReason = onFinishReason();
        return responses.map(response -> hasFinishReason.test(response) ? after(response) : response);
    }

    /**
     * Runs the similarity search and builds the enriched request.
     *
     * <p>The query is the user text rendered with the user params. The returned request has
     * the advise text appended to the user text, the joined document texts added as the
     * {@code question_answer_context} user param, and the documents stored in the advise
     * context under {@link #RETRIEVED_DOCUMENTS}. The incoming request's maps are copied,
     * not modified.
     */
    private AdvisedRequest before(AdvisedRequest advisedRequest) {
        Map<String, Object> context = new HashMap<>(advisedRequest.adviseContext());

        String query = new PromptTemplate(advisedRequest.userText(), advisedRequest.userParams()).render();
        SearchRequest effectiveRequest = SearchRequest.from(this.searchRequest)
            .query(query)
            .filterExpression(doGetFilterExpression(context))
            .build();
        List<Document> documents = this.vectorStore.similaritySearch(effectiveRequest);
        context.put(RETRIEVED_DOCUMENTS, documents);

        String documentContext = documents.stream()
            .map(Document::getText)
            .collect(Collectors.joining(System.lineSeparator()));

        Map<String, Object> advisedUserParams = new HashMap<>(advisedRequest.userParams());
        advisedUserParams.put("question_answer_context", documentContext);

        String advisedUserText = advisedRequest.userText() + System.lineSeparator() + this.userTextAdvise;

        return AdvisedRequest.from(advisedRequest)
            .userText(advisedUserText)
            .userParams(advisedUserParams)
            .adviseContext(context)
            .build();
    }

    /**
     * Copies the retrieved documents from the advise context into the chat response metadata.
     * A response without a chat response is returned unchanged, since there is nothing to
     * attach the metadata to.
     */
    private AdvisedResponse after(AdvisedResponse advisedResponse) {
        if (advisedResponse.response() == null) {
            return advisedResponse;
        }
        ChatResponse.Builder responseBuilder = ChatResponse.builder().from(advisedResponse.response());
        responseBuilder.withMetadata(RETRIEVED_DOCUMENTS, advisedResponse.adviseContext().get(RETRIEVED_DOCUMENTS));
        return new AdvisedResponse(responseBuilder.build(), advisedResponse.adviseContext());
    }

    /**
     * Resolves the filter expression for the search.
     *
     * @param map the advise context; a value under {@link #FILTER_EXPRESSION} whose string
     * form contains text is parsed and takes precedence
     * @return the parsed per-request expression, or the configured search request's filter
     * expression when the context entry is absent, {@code null}, or blank
     */
    protected Filter.Expression doGetFilterExpression(Map<String, Object> map) {
        // A single get() covers both "key absent" and "key mapped to null".
        Object requestedFilter = map.get(FILTER_EXPRESSION);
        if (requestedFilter == null) {
            return this.searchRequest.getFilterExpression();
        }
        String filterText = requestedFilter.toString();
        if (!StringUtils.hasText(filterText)) {
            return this.searchRequest.getFilterExpression();
        }
        return new FilterExpressionTextParser().parse(filterText);
    }

    /**
     * Returns a predicate that matches responses in which at least one generation reports a
     * non-blank finish reason. Responses without a chat response never match.
     */
    private Predicate<AdvisedResponse> onFinishReason() {
        return advisedResponse -> advisedResponse.response() != null
                && advisedResponse.response().getResults().stream()
                    .anyMatch(generation -> generation != null
                            && generation.getMetadata() != null
                            && StringUtils.hasText(generation.getMetadata().getFinishReason()));
    }
}
