/*
 * Decompiled with CFR 0.152.
 */
package com.alibaba.cloud.ai.node;

import com.alibaba.cloud.ai.dto.KeywordExtractionResult;
import com.alibaba.cloud.ai.enums.StreamResponseType;
import com.alibaba.cloud.ai.graph.GraphResponse;
import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.action.NodeAction;
import com.alibaba.cloud.ai.graph.streaming.StreamingOutput;
import com.alibaba.cloud.ai.service.base.BaseNl2SqlService;
import com.alibaba.cloud.ai.util.ChatResponseUtil;
import com.alibaba.cloud.ai.util.StateUtils;
import com.alibaba.cloud.ai.util.StreamingChatGeneratorUtil;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.model.ChatResponse;
import reactor.core.publisher.Flux;

/**
 * Graph node that extracts keywords and evidences for the current user question.
 *
 * <p>The node first asks {@link BaseNl2SqlService#expandQuestion} for variants of the
 * question, extracts evidences and keywords for every variant, and merges the
 * per-variant results. If anything in that "enhanced" path throws, it falls back to
 * extracting directly from the single input question.
 *
 * <p>In both paths the value returned under {@code "KEYWORD_EXTRACT_NODE_OUTPUT"} is a
 * streaming generator built by {@link StreamingChatGeneratorUtil}; the final keywords and
 * evidences are supplied to that generator through a mapping function.
 */
public class KeywordExtractNode implements NodeAction {
    private static final Logger logger = LoggerFactory.getLogger(KeywordExtractNode.class);
    private final BaseNl2SqlService baseNl2SqlService;

    /**
     * Creates the node.
     *
     * @param baseNl2SqlService service used to expand questions and to extract evidences
     *                          and keywords
     */
    public KeywordExtractNode(BaseNl2SqlService baseNl2SqlService) {
        this.baseNl2SqlService = baseNl2SqlService;
    }

    /**
     * Runs evidence and keyword extraction for every question variant.
     *
     * <p>A failure on one variant does not abort the others: it is recorded as an
     * unsuccessful {@link KeywordExtractionResult}.
     *
     * @param questions question variants to process
     * @return one result per variant, collected with {@code Collectors.toList()}
     */
    private List<KeywordExtractionResult> processMultipleQuestions(List<String> questions) {
        // NOTE(review): parallelStream() runs on the common fork-join pool; if the service
        // calls are blocking remote calls a dedicated executor may be preferable - confirm.
        return questions.parallelStream()
                .map(this::extractForQuestion)
                .collect(Collectors.toList());
    }

    /**
     * Extracts evidences and keywords for a single question variant.
     *
     * @param question the question variant
     * @return a result carrying the evidences and keywords, or a result flagged as failed
     *         when the service throws
     */
    private KeywordExtractionResult extractForQuestion(String question) {
        try {
            List<String> evidences = this.baseNl2SqlService.extractEvidences(question);
            List<String> keywords = this.baseNl2SqlService.extractKeywords(question, evidences);
            logger.info("\u6210\u529f\u4ece\u95ee\u9898\u53d8\u4f53\u63d0\u53d6\u5173\u952e\u8bcd: \u95ee\u9898=\"{}\", \u5173\u952e\u8bcd={}", question, keywords);
            return new KeywordExtractionResult(question, evidences, keywords);
        }
        catch (Exception e) {
            // The trailing throwable makes SLF4J log the stack trace in addition to the message.
            logger.warn("\u4ece\u95ee\u9898\u53d8\u4f53\u63d0\u53d6\u5173\u952e\u8bcd\u5931\u8d25: \u95ee\u9898=\"{}\", \u9519\u8bef={}", question, e.getMessage(), e);
            return new KeywordExtractionResult(question, false);
        }
    }

    /**
     * Merges the keywords of all successful results, without duplicates.
     *
     * <p>Keywords of the first successful result whose question equals
     * {@code originalQuestion} come first, followed by the keywords of the successful
     * results for other questions, in list order.
     *
     * @param extractionResults per-variant extraction results
     * @param originalQuestion  the question whose keywords take priority in the ordering
     * @return the merged keywords; an immutable empty list when there are no results
     */
    private List<String> mergeKeywords(List<KeywordExtractionResult> extractionResults, String originalQuestion) {
        if (extractionResults.isEmpty()) {
            return List.of();
        }
        LinkedHashSet<String> mergedKeywords = new LinkedHashSet<>();
        // First pass: only the first successful match for the original question.
        for (KeywordExtractionResult result : extractionResults) {
            if (result.isSuccessful() && result.getQuestion().equals(originalQuestion)) {
                mergedKeywords.addAll(result.getKeywords());
                break;
            }
        }
        // Second pass: every successful result for a different question.
        for (KeywordExtractionResult result : extractionResults) {
            if (result.isSuccessful() && !result.getQuestion().equals(originalQuestion)) {
                mergedKeywords.addAll(result.getKeywords());
            }
        }
        return new ArrayList<>(mergedKeywords);
    }

    /**
     * Merges the evidences of all successful results, without duplicates.
     *
     * @param extractionResults per-variant extraction results
     * @return the distinct evidences in first-seen order
     */
    private List<String> mergeEvidences(List<KeywordExtractionResult> extractionResults) {
        // Insertion-ordered set keeps the output order stable between runs.
        LinkedHashSet<String> mergedEvidences = new LinkedHashSet<>();
        for (KeywordExtractionResult result : extractionResults) {
            if (result.isSuccessful()) {
                mergedEvidences.addAll(result.getEvidences());
            }
        }
        return new ArrayList<>(mergedEvidences);
    }

    /**
     * Node entry point.
     *
     * <p>Reads the question from state key {@code "QUERY_REWRITE_NODE_OUTPUT"}, passing the
     * value of {@code "input"} as the third argument of {@code StateUtils.getStringValue}
     * (presumably the default when the first key is absent - confirm against StateUtils).
     *
     * @param state the graph state
     * @return a map with the single key {@code "KEYWORD_EXTRACT_NODE_OUTPUT"} holding the
     *         streaming generator
     * @throws Exception if the fallback extraction itself fails
     */
    public Map<String, Object> apply(OverAllState state) throws Exception {
        logger.info("Entering {} node", this.getClass().getSimpleName());
        String input = StateUtils.getStringValue(state, "QUERY_REWRITE_NODE_OUTPUT", StateUtils.getStringValue(state, "input"));
        try {
            logger.info("\u5f00\u59cb\u589e\u5f3a\u5173\u952e\u8bcd\u63d0\u53d6\u5904\u7406...");
            List<String> expandedQuestions = this.baseNl2SqlService.expandQuestion(input);
            logger.info("\u95ee\u9898\u6269\u5c55\u7ed3\u679c: {}", expandedQuestions);
            List<KeywordExtractionResult> extractionResults = this.processMultipleQuestions(expandedQuestions);
            List<String> mergedKeywords = this.mergeKeywords(extractionResults, input);
            List<String> mergedEvidences = this.mergeEvidences(extractionResults);
            logger.info("[{}] \u589e\u5f3a\u63d0\u53d6\u7ed3\u679c - \u8bc1\u636e: {}, \u5173\u952e\u8bcd: {}", this.getClass().getSimpleName(), mergedEvidences, mergedKeywords);
            Flux<ChatResponse> displayFlux = this.createEnhancedDisplayFlux(extractionResults, mergedKeywords, mergedEvidences);
            return this.buildNodeOutput(state, mergedKeywords, mergedEvidences, displayFlux);
        }
        catch (Exception e) {
            // Any failure in the enhanced path degrades to single-question extraction.
            logger.warn("\u589e\u5f3a\u5173\u952e\u8bcd\u63d0\u53d6\u5931\u8d25\uff0c\u56de\u9000\u5230\u539f\u59cb\u5904\u7406\u65b9\u6cd5: {}", e.getMessage(), e);
            return this.fallbackToOriginalProcessing(state, input);
        }
    }

    /**
     * Builds the status-message stream shown for the enhanced extraction path.
     *
     * <p>Emits a fixed preamble, three messages per successful variant (question,
     * evidences, keywords), the merged evidences and keywords, and a completion message.
     *
     * @param extractionResults per-variant results; unsuccessful ones are skipped
     * @param mergedKeywords    merged keywords to display
     * @param mergedEvidences   merged evidences to display
     * @return a flux that emits the status responses and then completes
     */
    private Flux<ChatResponse> createEnhancedDisplayFlux(List<KeywordExtractionResult> extractionResults, List<String> mergedKeywords, List<String> mergedEvidences) {
        return Flux.create(emitter -> {
            emitter.next(ChatResponseUtil.createStatusResponse("\u5f00\u59cb\u589e\u5f3a\u5173\u952e\u8bcd\u63d0\u53d6..."));
            emitter.next(ChatResponseUtil.createStatusResponse("\u6b63\u5728\u6269\u5c55\u95ee\u9898\u7406\u89e3..."));
            for (KeywordExtractionResult result : extractionResults) {
                if (!result.isSuccessful()) {
                    continue;
                }
                emitter.next(ChatResponseUtil.createStatusResponse("\u5904\u7406\u95ee\u9898\u53d8\u4f53: \"" + result.getQuestion() + "\""));
                emitter.next(ChatResponseUtil.createStatusResponse("\u63d0\u53d6\u7684\u8bc1\u636e: " + String.join(", ", result.getEvidences())));
                emitter.next(ChatResponseUtil.createStatusResponse("\u63d0\u53d6\u7684\u5173\u952e\u8bcd: " + String.join(", ", result.getKeywords())));
            }
            emitter.next(ChatResponseUtil.createStatusResponse("\u5408\u5e76\u591a\u4e2a\u95ee\u9898\u53d8\u4f53\u7684\u7ed3\u679c..."));
            emitter.next(ChatResponseUtil.createStatusResponse("\u5408\u5e76\u540e\u7684\u8bc1\u636e: " + String.join(", ", mergedEvidences)));
            emitter.next(ChatResponseUtil.createStatusResponse("\u5408\u5e76\u540e\u7684\u5173\u952e\u8bcd: " + String.join(", ", mergedKeywords)));
            emitter.next(ChatResponseUtil.createStatusResponse("\u5173\u952e\u8bcd\u63d0\u53d6\u5b8c\u6210."));
            emitter.complete();
        });
    }

    /**
     * Extracts evidences and keywords directly from the input question, without question
     * expansion. Used when the enhanced path throws.
     *
     * @param state the graph state
     * @param input the question to extract from
     * @return a map with the single key {@code "KEYWORD_EXTRACT_NODE_OUTPUT"} holding the
     *         streaming generator
     * @throws Exception if the service calls fail
     */
    private Map<String, Object> fallbackToOriginalProcessing(OverAllState state, String input) throws Exception {
        List<String> evidences = this.baseNl2SqlService.extractEvidences(input);
        List<String> keywords = this.baseNl2SqlService.extractKeywords(input, evidences);
        logger.info("[{}] \u539f\u59cb\u63d0\u53d6\u7ed3\u679c - \u8bc1\u636e: {}, \u5173\u952e\u8bcd: {}", this.getClass().getSimpleName(), evidences, keywords);
        // NOTE(review): this path uses createCustomStatusResponse while the enhanced path
        // uses createStatusResponse; kept as-is, confirm whether the difference is intended.
        Flux<ChatResponse> displayFlux = Flux.create(emitter -> {
            emitter.next(ChatResponseUtil.createCustomStatusResponse("\u5f00\u59cb\u63d0\u53d6\u5173\u952e\u8bcd..."));
            emitter.next(ChatResponseUtil.createCustomStatusResponse("\u6b63\u5728\u63d0\u53d6\u8bc1\u636e..."));
            emitter.next(ChatResponseUtil.createCustomStatusResponse("\u63d0\u53d6\u7684\u8bc1\u636e: " + String.join(", ", evidences)));
            emitter.next(ChatResponseUtil.createCustomStatusResponse("\u6b63\u5728\u63d0\u53d6\u5173\u952e\u8bcd..."));
            emitter.next(ChatResponseUtil.createCustomStatusResponse("\u63d0\u53d6\u7684\u5173\u952e\u8bcd: " + String.join(", ", keywords)));
            emitter.next(ChatResponseUtil.createCustomStatusResponse("\u5173\u952e\u8bcd\u63d0\u53d6\u5b8c\u6210."));
            emitter.complete();
        });
        return this.buildNodeOutput(state, keywords, evidences, displayFlux);
    }

    /**
     * Wraps the display stream and the final values into the node's return map.
     *
     * <p>The mapping function handed to the generator yields {@code keywords} under
     * {@code "KEYWORD_EXTRACT_NODE_OUTPUT"} and {@code "result"}, and {@code evidences}
     * under {@code "EVIDENCES"}.
     *
     * @param state       the graph state
     * @param keywords    final keywords
     * @param evidences   final evidences
     * @param displayFlux status messages to stream
     * @return a map with the single key {@code "KEYWORD_EXTRACT_NODE_OUTPUT"} holding the
     *         streaming generator
     */
    private Map<String, Object> buildNodeOutput(OverAllState state, List<String> keywords, List<String> evidences, Flux<ChatResponse> displayFlux) {
        Flux<GraphResponse<StreamingOutput>> generator = StreamingChatGeneratorUtil.createStreamingGeneratorWithMessages(this.getClass(), state, v -> Map.of("KEYWORD_EXTRACT_NODE_OUTPUT", keywords, "EVIDENCES", evidences, "result", keywords), displayFlux, StreamResponseType.KEYWORD_EXTRACT);
        return Map.of("KEYWORD_EXTRACT_NODE_OUTPUT", generator);
    }
}

