/*
 * Decompiled with CFR 0.152.
 */
package com.alibaba.cloud.ai.node;

import com.alibaba.cloud.ai.dto.BusinessKnowledgeDTO;
import com.alibaba.cloud.ai.dto.SemanticModelDTO;
import com.alibaba.cloud.ai.dto.schema.SchemaDTO;
import com.alibaba.cloud.ai.enums.StreamResponseType;
import com.alibaba.cloud.ai.graph.NodeOutput;
import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.action.NodeAction;
import com.alibaba.cloud.ai.graph.async.AsyncGenerator;
import com.alibaba.cloud.ai.prompt.PromptHelper;
import com.alibaba.cloud.ai.service.base.BaseNl2SqlService;
import com.alibaba.cloud.ai.service.base.BaseSchemaService;
import com.alibaba.cloud.ai.service.business.BusinessKnowledgeRecallService;
import com.alibaba.cloud.ai.service.semantic.SemanticModelRecallService;
import com.alibaba.cloud.ai.util.ChatResponseUtil;
import com.alibaba.cloud.ai.util.StateUtils;
import com.alibaba.cloud.ai.util.StreamingChatGeneratorUtil;
import java.util.List;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.document.Document;
import reactor.core.publisher.Flux;

public class TableRelationNode
implements NodeAction {
    private static final Logger logger = LoggerFactory.getLogger(TableRelationNode.class);
    private final BaseSchemaService baseSchemaService;
    private final BaseNl2SqlService baseNl2SqlService;
    private final BusinessKnowledgeRecallService businessKnowledgeRecallService;
    private final SemanticModelRecallService semanticModelRecallService;

    public TableRelationNode(BaseSchemaService baseSchemaService, BaseNl2SqlService baseNl2SqlService, BusinessKnowledgeRecallService businessKnowledgeRecallService, SemanticModelRecallService semanticModelRecallService) {
        this.baseSchemaService = baseSchemaService;
        this.baseNl2SqlService = baseNl2SqlService;
        this.businessKnowledgeRecallService = businessKnowledgeRecallService;
        this.semanticModelRecallService = semanticModelRecallService;
    }

    public Map<String, Object> apply(OverAllState state) throws Exception {
        logger.info("Entering {} node", (Object)this.getClass().getSimpleName());
        String input = StateUtils.getStringValue(state, "input");
        List<String> evidenceList = StateUtils.getListValue(state, "EVIDENCES");
        List<Document> tableDocuments = StateUtils.getDocumentList(state, "TABLE_DOCUMENTS_FOR_SCHEMA");
        List<List<Document>> columnDocumentsByKeywords = StateUtils.getDocumentListList(state, "COLUMN_DOCUMENTS_BY_KEYWORDS_OUTPUT");
        String dataSetId = StateUtils.getStringValue(state, "agentId");
        String agentIdStr = StateUtils.getStringValue(state, "agentId");
        long agentId = -1L;
        if (!agentIdStr.isEmpty()) {
            agentId = Long.parseLong(agentIdStr);
        }
        SchemaDTO schemaDTO = this.buildInitialSchema(columnDocumentsByKeywords, tableDocuments);
        SchemaDTO result = this.processSchemaSelection(schemaDTO, input, evidenceList, state);
        List<BusinessKnowledgeDTO> businessKnowledges = this.businessKnowledgeRecallService.getFieldByDataSetId(dataSetId);
        List<SemanticModelDTO> semanticModel = this.semanticModelRecallService.getFieldByDataSetId(String.valueOf(agentId));
        String businessKnowledgePrompt = PromptHelper.buildBusinessKnowledgePrompt(businessKnowledges);
        String semanticModelPrompt = PromptHelper.buildSemanticModelPrompt(semanticModel);
        logger.info("[{}] Schema processing result: {}", (Object)this.getClass().getSimpleName(), (Object)result);
        Flux displayFlux = Flux.create(emitter -> {
            emitter.next((Object)ChatResponseUtil.createStatusResponse("\u5f00\u59cb\u6784\u5efa\u521d\u59cbSchema..."));
            emitter.next((Object)ChatResponseUtil.createStatusResponse("\u521d\u59cbSchema\u6784\u5efa\u5b8c\u6210."));
            emitter.next((Object)ChatResponseUtil.createStatusResponse("\u5f00\u59cb\u5904\u7406Schema\u9009\u62e9..."));
            emitter.next((Object)ChatResponseUtil.createStatusResponse("Schema\u9009\u62e9\u5904\u7406\u5b8c\u6210."));
            emitter.complete();
        });
        AsyncGenerator<? extends NodeOutput> generator = StreamingChatGeneratorUtil.createStreamingGeneratorWithMessages(this.getClass(), state, v -> Map.of("TABLE_RELATION_OUTPUT", result, "BUSINESS_KNOWLEDGE", businessKnowledgePrompt, "SEMANTIC_MODEL", semanticModelPrompt), (Flux<ChatResponse>)displayFlux, StreamResponseType.SCHEMA_DEEP_RECALL);
        return Map.of("TABLE_RELATION_OUTPUT", generator, "BUSINESS_KNOWLEDGE", businessKnowledgePrompt, "SEMANTIC_MODEL", semanticModelPrompt);
    }

    private SchemaDTO buildInitialSchema(List<List<Document>> columnDocumentsByKeywords, List<Document> tableDocuments) {
        SchemaDTO schemaDTO = new SchemaDTO();
        this.baseSchemaService.extractDatabaseName(schemaDTO);
        this.baseSchemaService.buildSchemaFromDocuments(columnDocumentsByKeywords, tableDocuments, schemaDTO);
        return schemaDTO;
    }

    private SchemaDTO processSchemaSelection(SchemaDTO schemaDTO, String input, List<String> evidenceList, OverAllState state) {
        String schemaAdvice = StateUtils.getStringValue(state, "SQL_GENERATE_SCHEMA_MISSING_ADVICE", null);
        if (schemaAdvice != null) {
            logger.info("[{}] Processing with schema supplement advice: {}", (Object)this.getClass().getSimpleName(), (Object)schemaAdvice);
            return this.baseNl2SqlService.fineSelect(schemaDTO, input, evidenceList, schemaAdvice);
        }
        logger.info("[{}] Executing regular schema selection", (Object)this.getClass().getSimpleName());
        return this.baseNl2SqlService.fineSelect(schemaDTO, input, evidenceList);
    }
}

