/*
 * Decompiled with CFR 0.152.
 */
package com.alibaba.cloud.ai.node;

import com.alibaba.cloud.ai.connector.config.DbConfig;
import com.alibaba.cloud.ai.dto.BusinessKnowledgeDTO;
import com.alibaba.cloud.ai.dto.SemanticModelDTO;
import com.alibaba.cloud.ai.dto.schema.SchemaDTO;
import com.alibaba.cloud.ai.entity.AgentDatasource;
import com.alibaba.cloud.ai.entity.Datasource;
import com.alibaba.cloud.ai.enums.StreamResponseType;
import com.alibaba.cloud.ai.graph.GraphResponse;
import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.action.NodeAction;
import com.alibaba.cloud.ai.graph.streaming.StreamingOutput;
import com.alibaba.cloud.ai.prompt.PromptHelper;
import com.alibaba.cloud.ai.service.DatasourceService;
import com.alibaba.cloud.ai.service.base.BaseNl2SqlService;
import com.alibaba.cloud.ai.service.base.BaseSchemaService;
import com.alibaba.cloud.ai.service.business.BusinessKnowledgeRecallService;
import com.alibaba.cloud.ai.service.semantic.SemanticModelRecallService;
import com.alibaba.cloud.ai.util.ChatResponseUtil;
import com.alibaba.cloud.ai.util.StateUtils;
import com.alibaba.cloud.ai.util.StreamingChatGeneratorUtil;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.document.Document;
import org.springframework.dao.DataAccessException;
import reactor.core.publisher.Flux;

/**
 * Graph node that prepares the schema context used by the downstream SQL-generation steps.
 *
 * <p>The node does four things:
 * <ol>
 *   <li>Builds an initial {@link SchemaDTO} from the recalled table and column documents.</li>
 *   <li>Refines it through {@link BaseNl2SqlService#fineSelect}, optionally using an
 *       agent-specific {@link DbConfig}.</li>
 *   <li>Loads business-knowledge and semantic-model entries for the current agent.</li>
 *   <li>Publishes the results to the graph state through a streaming generator.</li>
 * </ol>
 *
 * <p>A {@link DataAccessException} raised while loading the knowledge entries is not thrown.
 * Instead it is reported through {@code TABLE_RELATION_EXCEPTION_OUTPUT}, together with an
 * incremented {@code TABLE_RELATION_RETRY_COUNT}.
 */
public class TableRelationNode implements NodeAction {

    private static final Logger logger = LoggerFactory.getLogger(TableRelationNode.class);

    // Graph-state keys read or written by this node.
    private static final String AGENT_ID_KEY = "agentId";
    private static final String TABLE_RELATION_OUTPUT = "TABLE_RELATION_OUTPUT";
    private static final String BUSINESS_KNOWLEDGE = "BUSINESS_KNOWLEDGE";
    private static final String SEMANTIC_MODEL = "SEMANTIC_MODEL";
    private static final String RETRY_COUNT = "TABLE_RELATION_RETRY_COUNT";
    private static final String EXCEPTION_OUTPUT = "TABLE_RELATION_EXCEPTION_OUTPUT";

    /** Agent id used when the state carries no usable agent id. */
    private static final long UNKNOWN_AGENT_ID = -1L;

    /**
     * Offset subtracted from the agent id for a second datasource lookup when the first lookup
     * returns nothing. NOTE(review): the meaning of this offset is not visible here; confirm it
     * against the code that registers agent datasources.
     */
    private static final int FALLBACK_AGENT_ID_OFFSET = 999999;

    private final BaseSchemaService baseSchemaService;
    private final BaseNl2SqlService baseNl2SqlService;
    private final BusinessKnowledgeRecallService businessKnowledgeRecallService;
    private final SemanticModelRecallService semanticModelRecallService;
    private final DatasourceService datasourceService;

    public TableRelationNode(BaseSchemaService baseSchemaService, BaseNl2SqlService baseNl2SqlService, BusinessKnowledgeRecallService businessKnowledgeRecallService, SemanticModelRecallService semanticModelRecallService, DatasourceService datasourceService) {
        this.baseSchemaService = baseSchemaService;
        this.baseNl2SqlService = baseNl2SqlService;
        this.businessKnowledgeRecallService = businessKnowledgeRecallService;
        this.semanticModelRecallService = semanticModelRecallService;
        this.datasourceService = datasourceService;
    }

    /**
     * Runs the table-relation step.
     *
     * @param state graph state. Reads {@code input}, {@code EVIDENCES},
     *     {@code TABLE_DOCUMENTS_FOR_SCHEMA}, {@code COLUMN_DOCUMENTS_BY_KEYWORDS_OUTPUT},
     *     {@code agentId}, {@code TABLE_RELATION_RETRY_COUNT} and (via schema selection)
     *     {@code SQL_GENERATE_SCHEMA_MISSING_ADVICE}.
     * @return on success: the streaming generator under {@code TABLE_RELATION_OUTPUT}, the
     *     business-knowledge and semantic-model prompts, a reset retry count and an empty
     *     exception output. On a {@link DataAccessException}: the classified error message and
     *     the incremented retry count.
     * @throws Exception propagated from schema selection or the streaming helper
     */
    public Map<String, Object> apply(OverAllState state) throws Exception {
        String nodeName = this.getClass().getSimpleName();
        logger.info("Entering {} node", nodeName);

        int retryCount = StateUtils.getObjectValue(state, RETRY_COUNT, Integer.class, 0);
        String input = StateUtils.getStringValue(state, "input");
        List<String> evidenceList = StateUtils.getListValue(state, "EVIDENCES");
        List<Document> tableDocuments = StateUtils.getDocumentList(state, "TABLE_DOCUMENTS_FOR_SCHEMA");
        List<List<Document>> columnDocumentsByKeywords = StateUtils.getDocumentListList(state, "COLUMN_DOCUMENTS_BY_KEYWORDS_OUTPUT");

        // The raw agent id string doubles as the business-knowledge data set id.
        // The semantic model is looked up by the parsed numeric id instead.
        String agentIdStr = StateUtils.getStringValue(state, AGENT_ID_KEY);
        long agentId = parseAgentId(agentIdStr);

        SchemaDTO schemaDTO = this.buildInitialSchema(columnDocumentsByKeywords, tableDocuments);
        SchemaDTO result = this.processSchemaSelection(schemaDTO, input, evidenceList, state);

        List<BusinessKnowledgeDTO> businessKnowledges;
        List<SemanticModelDTO> semanticModel;
        try {
            businessKnowledges = this.businessKnowledgeRecallService.getFieldByDataSetId(agentIdStr);
            semanticModel = this.semanticModelRecallService.getFieldByDataSetId(String.valueOf(agentId));
        } catch (DataAccessException e) {
            // Report the failure through the state so the graph can decide whether to retry.
            logger.warn("Database query failed (attempt {}): {}", retryCount + 1, e.getMessage());
            String errorType = this.classifyDatabaseError(e);
            return Map.of(EXCEPTION_OUTPUT, errorType + ": " + e.getMessage(), RETRY_COUNT, retryCount + 1);
        }

        String businessKnowledgePrompt = PromptHelper.buildBusinessKnowledgePrompt(businessKnowledges);
        String semanticModelPrompt = PromptHelper.buildSemanticModelPrompt(semanticModel);
        logger.info("[{}] Schema processing result: {}", nodeName, result);

        // Progress messages streamed to the client. NOTE(review): the work they describe has
        // already completed by the time this flux is subscribed.
        Flux<ChatResponse> displayFlux = Flux.create(emitter -> {
            // "Starting to build the initial schema..."
            emitter.next(ChatResponseUtil.createStatusResponse("\u5f00\u59cb\u6784\u5efa\u521d\u59cbSchema..."));
            // "Initial schema built."
            emitter.next(ChatResponseUtil.createStatusResponse("\u521d\u59cbSchema\u6784\u5efa\u5b8c\u6210."));
            // "Starting schema selection..."
            emitter.next(ChatResponseUtil.createStatusResponse("\u5f00\u59cb\u5904\u7406Schema\u9009\u62e9..."));
            // "Schema selection complete."
            emitter.next(ChatResponseUtil.createStatusResponse("Schema\u9009\u62e9\u5904\u7406\u5b8c\u6210."));
            emitter.complete();
        });

        Flux<GraphResponse<StreamingOutput>> generator = StreamingChatGeneratorUtil.createStreamingGeneratorWithMessages(
                this.getClass(),
                state,
                v -> Map.of(TABLE_RELATION_OUTPUT, result, BUSINESS_KNOWLEDGE, businessKnowledgePrompt, SEMANTIC_MODEL, semanticModelPrompt),
                displayFlux,
                StreamResponseType.SCHEMA_DEEP_RECALL);

        return Map.of(
                TABLE_RELATION_OUTPUT, generator,
                BUSINESS_KNOWLEDGE, businessKnowledgePrompt,
                SEMANTIC_MODEL, semanticModelPrompt,
                RETRY_COUNT, 0,
                EXCEPTION_OUTPUT, "");
    }

    /**
     * Parses the agent id from the state.
     *
     * Mirrors the leniency of {@link #getAgentDbConfig}: a missing, blank or non-numeric value
     * yields {@link #UNKNOWN_AGENT_ID} instead of failing the whole node.
     *
     * @param agentIdStr raw agent id read from the state; may be {@code null}
     * @return the parsed id, or {@code -1} when it cannot be parsed
     */
    private static long parseAgentId(String agentIdStr) {
        if (agentIdStr == null || agentIdStr.isBlank()) {
            return UNKNOWN_AGENT_ID;
        }
        try {
            return Long.parseLong(agentIdStr.trim());
        } catch (NumberFormatException e) {
            logger.warn("Invalid agentId '{}', falling back to {}", agentIdStr, UNKNOWN_AGENT_ID);
            return UNKNOWN_AGENT_ID;
        }
    }

    /**
     * Classifies a database failure as {@code "RETRYABLE"} or {@code "NON_RETRYABLE"}.
     *
     * The check is a case-insensitive keyword match for timeout, connection and network
     * problems.
     *
     * @param e the failure raised while loading knowledge entries
     * @return {@code "RETRYABLE"} for transient-looking failures, otherwise {@code "NON_RETRYABLE"}
     */
    private String classifyDatabaseError(DataAccessException e) {
        // Depending on the Spring version, getMessage() may not include the wrapped cause's text.
        // Also inspect the most specific cause, e.g. the driver or socket exception.
        boolean retryable = isTransientMessage(e.getMessage())
                || isTransientMessage(e.getMostSpecificCause().getMessage());
        return retryable ? "RETRYABLE" : "NON_RETRYABLE";
    }

    /**
     * Returns whether an exception message mentions a timeout, connection or network problem.
     *
     * @param message exception message; may be {@code null}
     * @return {@code true} if a transient-failure keyword is present, ignoring case
     */
    private static boolean isTransientMessage(String message) {
        if (message == null) {
            return false;
        }
        // Locale.ROOT keeps lower-casing stable regardless of the JVM default locale.
        String normalized = message.toLowerCase(Locale.ROOT);
        return normalized.contains("timeout")
                || normalized.contains("timed out")
                || normalized.contains("connection")
                || normalized.contains("network");
    }

    /**
     * Builds the initial schema from recalled documents.
     *
     * @param columnDocumentsByKeywords column documents grouped per keyword
     * @param tableDocuments table documents recalled for the schema
     * @return a new {@link SchemaDTO} populated by {@link BaseSchemaService}
     */
    private SchemaDTO buildInitialSchema(List<List<Document>> columnDocumentsByKeywords, List<Document> tableDocuments) {
        SchemaDTO schemaDTO = new SchemaDTO();
        this.baseSchemaService.extractDatabaseName(schemaDTO);
        this.baseSchemaService.buildSchemaFromDocuments(columnDocumentsByKeywords, tableDocuments, schemaDTO);
        return schemaDTO;
    }

    /**
     * Resolves the database configuration of the agent's active datasource.
     *
     * This lookup is best-effort: any failure is logged and {@code null} is returned, so the
     * caller falls back to the default configuration.
     *
     * @param state graph state holding the {@code agentId}
     * @return the agent-specific config, or {@code null} to use the default one
     */
    private DbConfig getAgentDbConfig(OverAllState state) {
        try {
            String agentIdStr = StateUtils.getStringValue(state, AGENT_ID_KEY, null);
            if (agentIdStr == null || agentIdStr.trim().isEmpty()) {
                logger.debug("AgentId is null or empty, will use default dbConfig");
                return null;
            }
            Integer agentId = Integer.valueOf(agentIdStr.trim());
            logger.debug("Getting datasource config for agent: {}", agentId);

            List<AgentDatasource> agentDatasources = this.datasourceService.getAgentDatasources(agentId);
            if (agentDatasources.isEmpty()) {
                // Second lookup under an offset id; see FALLBACK_AGENT_ID_OFFSET.
                agentDatasources = this.datasourceService.getAgentDatasources(agentId - FALLBACK_AGENT_ID_OFFSET);
            }

            // An isActive flag equal to 1 marks the datasource to use.
            AgentDatasource activeDatasource = agentDatasources.stream()
                    .filter(ad -> ad.getIsActive() == 1)
                    .findFirst()
                    .orElse(null);
            if (activeDatasource == null) {
                logger.debug("Agent {} has no active datasource, will use default dbConfig", agentId);
                return null;
            }

            DbConfig dbConfig = this.createDbConfigFromDatasource(activeDatasource.getDatasource());
            logger.debug("Successfully created DbConfig for agent {}: url={}, schema={}, type={}",
                    agentId, dbConfig.getUrl(), dbConfig.getSchema(), dbConfig.getDialectType());
            return dbConfig;
        } catch (Exception e) {
            // Deliberate best-effort fallback. Keep the exception so the cause is not lost.
            logger.warn("Failed to get agent datasource config, will use default dbConfig: {}", e.getMessage(), e);
            return null;
        }
    }

    /**
     * Converts a stored datasource into a JDBC {@link DbConfig}.
     *
     * @param datasource the datasource entity; its type must be {@code mysql} or
     *     {@code postgresql}, compared case-insensitively
     * @return the populated configuration
     * @throws IllegalArgumentException if the datasource type is not supported
     */
    private DbConfig createDbConfigFromDatasource(Datasource datasource) {
        String type = datasource.getType();
        String dialect;
        if ("mysql".equalsIgnoreCase(type)) {
            dialect = "mysql";
        } else if ("postgresql".equalsIgnoreCase(type)) {
            dialect = "postgresql";
        } else {
            // Message: "Unsupported database type: <type>"
            throw new IllegalArgumentException("\u4e0d\u652f\u6301\u7684\u6570\u636e\u5e93\u7c7b\u578b: " + type);
        }
        DbConfig dbConfig = new DbConfig();
        dbConfig.setUrl(datasource.getConnectionUrl());
        dbConfig.setUsername(datasource.getUsername());
        dbConfig.setPassword(datasource.getPassword());
        dbConfig.setConnectionType("jdbc");
        dbConfig.setDialectType(dialect);
        dbConfig.setSchema(datasource.getDatabaseName());
        return dbConfig;
    }

    /**
     * Refines the schema through {@link BaseNl2SqlService#fineSelect}.
     *
     * Uses the schema-supplement advice from {@code SQL_GENERATE_SCHEMA_MISSING_ADVICE} when
     * the state provides one.
     *
     * @param schemaDTO the initial schema
     * @param input the user's question
     * @param evidenceList evidences gathered earlier in the graph
     * @param state graph state (advice and agent id)
     * @return the refined schema returned by {@code fineSelect}
     */
    private SchemaDTO processSchemaSelection(SchemaDTO schemaDTO, String input, List<String> evidenceList, OverAllState state) {
        String schemaAdvice = StateUtils.getStringValue(state, "SQL_GENERATE_SCHEMA_MISSING_ADVICE", null);
        DbConfig agentDbConfig = this.getAgentDbConfig(state);
        logger.debug("Using agent-specific dbConfig: {}", agentDbConfig != null ? agentDbConfig.getUrl() : "default");

        String nodeName = this.getClass().getSimpleName();
        if (schemaAdvice != null) {
            logger.info("[{}] Processing with schema supplement advice: {}", nodeName, schemaAdvice);
        } else {
            logger.info("[{}] Executing regular schema selection", nodeName);
        }
        // Both paths make the same call; a null advice means regular selection.
        return this.baseNl2SqlService.fineSelect(schemaDTO, input, evidenceList, schemaAdvice, agentDbConfig);
    }
}

