/*
 * Decompiled with CFR 0.152.
 */
package com.alibaba.cloud.ai.node;

import com.alibaba.cloud.ai.connector.accessor.Accessor;
import com.alibaba.cloud.ai.connector.bo.DbQueryParameter;
import com.alibaba.cloud.ai.connector.bo.ResultSetBO;
import com.alibaba.cloud.ai.connector.config.DbConfig;
import com.alibaba.cloud.ai.entity.AgentDatasource;
import com.alibaba.cloud.ai.entity.Datasource;
import com.alibaba.cloud.ai.enums.StreamResponseType;
import com.alibaba.cloud.ai.graph.GraphResponse;
import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.streaming.StreamingOutput;
import com.alibaba.cloud.ai.model.execution.ExecutionStep;
import com.alibaba.cloud.ai.node.AbstractPlanBasedNode;
import com.alibaba.cloud.ai.service.DatasourceService;
import com.alibaba.cloud.ai.util.ChatResponseUtil;
import com.alibaba.cloud.ai.util.StateUtils;
import com.alibaba.cloud.ai.util.StepResultUtils;
import com.alibaba.cloud.ai.util.StreamingChatGeneratorUtil;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.model.ChatResponse;
import reactor.core.publisher.Flux;

/**
 * Graph node that executes the SQL query carried by the current plan step and returns a streaming
 * generator that emits status messages and, on completion, the state update for this step.
 *
 * <p>The target database is resolved per agent from the {@code agentId} state value; when no agent
 * id is present the injected default {@link DbConfig} is used.
 */
public class SqlExecuteNode extends AbstractPlanBasedNode {

    private static final Logger logger = LoggerFactory.getLogger(SqlExecuteNode.class);

    /**
     * Offset subtracted from the agent id for a second datasource lookup when the first lookup finds
     * nothing. NOTE(review): the meaning of this offset is not visible from this class — confirm
     * with the code that registers agent datasources.
     */
    private static final int FALLBACK_AGENT_ID_OFFSET = 999999;

    /**
     * Datasource types accepted by this node (matched case-insensitively). Each maps to a JDBC
     * connection whose dialect is the canonical lowercase name listed here.
     */
    private static final List<String> SUPPORTED_JDBC_DIALECTS = List.of("mysql", "postgresql", "h2");

    private final Accessor dbAccessor;
    private final DatasourceService datasourceService;
    private final DbConfig dbConfig;

    /**
     * @param dbAccessor        accessor used to run the SQL
     * @param datasourceService service used to look up an agent's datasources
     * @param dbConfig          default configuration used when the state carries no agent id
     */
    public SqlExecuteNode(Accessor dbAccessor, DatasourceService datasourceService, DbConfig dbConfig) {
        this.dbAccessor = dbAccessor;
        this.datasourceService = datasourceService;
        this.dbConfig = dbConfig;
    }

    /**
     * Executes the SQL query of the current execution step.
     *
     * @param state the shared graph state
     * @return a state update holding a streaming generator under {@code SQL_EXECUTE_NODE_OUTPUT} on
     *     success or {@code SQL_EXECUTE_NODE_EXCEPTION_OUTPUT} on SQL failure
     * @throws RuntimeException if the agent's datasource configuration cannot be resolved
     */
    public Map<String, Object> apply(OverAllState state) throws Exception {
        this.logNodeEntry();
        ExecutionStep executionStep = this.getCurrentExecutionStep(state);
        Integer currentStep = this.getCurrentStepNumber(state);
        ExecutionStep.ToolParameters toolParameters = executionStep.getToolParameters();
        String sqlQuery = toolParameters.getSqlQuery();
        logger.info("Executing SQL query: {}", sqlQuery);
        logger.info("Step description: {}", toolParameters.getDescription());
        DbConfig effectiveConfig = this.getAgentDbConfig(state);
        return this.executeSqlQuery(state, currentStep, sqlQuery, effectiveConfig);
    }

    /**
     * Resolves the database configuration for the agent named by the {@code agentId} state value.
     *
     * @return the default config when no agent id is present; otherwise a config built from the
     *     first active datasource of the agent
     * @throws RuntimeException wrapping any failure (unparseable id, no active datasource,
     *     unsupported database type, lookup error)
     */
    private DbConfig getAgentDbConfig(OverAllState state) {
        try {
            String agentIdStr = StateUtils.getStringValue(state, "agentId");
            if (agentIdStr == null || agentIdStr.trim().isEmpty()) {
                return this.dbConfig;
            }
            // Parse the trimmed value so the blank check above and the parse agree on whitespace.
            Integer agentId = Integer.valueOf(agentIdStr.trim());
            logger.info("Getting datasource config for agent: {}", agentId);
            List<AgentDatasource> agentDatasources = this.datasourceService.getAgentDatasources(agentId);
            if (agentDatasources == null || agentDatasources.isEmpty()) {
                agentDatasources = this.datasourceService.getAgentDatasources(agentId - FALLBACK_AGENT_ID_OFFSET);
            }
            if (agentDatasources == null) {
                agentDatasources = List.of();
            }
            // Message: "Agent <id> has no enabled datasource configured".
            AgentDatasource activeDatasource = agentDatasources.stream()
                    .filter(ad -> ad.getIsActive() == 1)
                    .findFirst()
                    .orElseThrow(() -> new RuntimeException("\u667a\u80fd\u4f53 " + agentId + " \u672a\u914d\u7f6e\u542f\u7528\u7684\u6570\u636e\u6e90"));
            DbConfig agentConfig = this.createDbConfigFromDatasource(activeDatasource.getDatasource());
            logger.info("Successfully created DbConfig for agent {}: url={}, schema={}, type={}",
                    agentId, agentConfig.getUrl(), agentConfig.getSchema(), agentConfig.getDialectType());
            return agentConfig;
        }
        catch (Exception e) {
            logger.error("Failed to get agent datasource config", e);
            // Message: "Failed to get agent datasource config: <cause>".
            throw new RuntimeException("\u83b7\u53d6\u667a\u80fd\u4f53\u6570\u636e\u6e90\u914d\u7f6e\u5931\u8d25: " + e.getMessage(), e);
        }
    }

    /**
     * Builds a JDBC {@link DbConfig} from a stored datasource.
     *
     * @throws RuntimeException if the datasource type is not one of {@link #SUPPORTED_JDBC_DIALECTS}
     */
    private DbConfig createDbConfigFromDatasource(Datasource datasource) {
        String type = datasource.getType();
        // Message: "Unsupported database type: <type>".
        String dialect = SUPPORTED_JDBC_DIALECTS.stream()
                .filter(candidate -> candidate.equalsIgnoreCase(type))
                .findFirst()
                .orElseThrow(() -> new RuntimeException("\u4e0d\u652f\u6301\u7684\u6570\u636e\u5e93\u7c7b\u578b: " + type));
        DbConfig config = new DbConfig();
        config.setUrl(datasource.getConnectionUrl());
        config.setUsername(datasource.getUsername());
        config.setPassword(datasource.getPassword());
        config.setConnectionType("jdbc");
        config.setDialectType(dialect);
        config.setSchema(datasource.getDatabaseName());
        return config;
    }

    /**
     * Runs {@code sqlQuery} against {@code dbConfig} and wraps the outcome in a streaming generator.
     *
     * <p>On success the generator's final state update records the JSON result for
     * {@code currentStep}, clears the exception output and stores the result rows. On failure it
     * records the error message under {@code SQL_EXECUTE_NODE_EXCEPTION_OUTPUT}.
     *
     * <p>NOTE(review): the SQL text is executed as-is; no read-only or statement-type restriction is
     * visible in this class — confirm it is enforced upstream.
     */
    private Map<String, Object> executeSqlQuery(OverAllState state, Integer currentStep, String sqlQuery, DbConfig dbConfig) {
        DbQueryParameter dbQueryParameter = new DbQueryParameter();
        dbQueryParameter.setSql(sqlQuery);
        dbQueryParameter.setSchema(dbConfig.getSchema());
        try {
            ResultSetBO resultSetBO = this.dbAccessor.executeSqlAndReturnObject(dbConfig, dbQueryParameter);
            String jsonStr = resultSetBO.toJsonStr();
            @SuppressWarnings("unchecked") // state stores step results as a String -> String map
            Map<String, String> existingResults = StateUtils.getObjectValue(state, "SQL_EXECUTE_NODE_OUTPUT", Map.class, new HashMap<String, String>());
            Map<String, String> updatedResults = StepResultUtils.addStepResult(existingResults, currentStep, jsonStr);
            boolean hasRows = resultSetBO.getData() != null;
            logger.info("SQL execution successful, result count: {}", hasRows ? resultSetBO.getData().size() : 0);
            // Map.of rejects null values, so a null row list would turn a successful query into an NPE.
            Object rows = hasRows ? resultSetBO.getData() : List.of();
            Map<String, Object> result = Map.of(
                    "SQL_EXECUTE_NODE_OUTPUT", updatedResults,
                    "SQL_EXECUTE_NODE_EXCEPTION_OUTPUT", "",
                    "SQL_RESULT_LIST_MEMORY", rows);
            // Status messages: "Start executing SQL...", "Execute SQL query", <sql>, "SQL execution done".
            Flux<ChatResponse> displayFlux = statusFlux(
                    "\u5f00\u59cb\u6267\u884cSQL...",
                    "\u6267\u884cSQL\u67e5\u8be2",
                    "```" + sqlQuery + "```",
                    "\u6267\u884cSQL\u5b8c\u6210");
            Flux<GraphResponse<StreamingOutput>> generator = StreamingChatGeneratorUtil.createStreamingGeneratorWithMessages(
                    this.getClass(), state, v -> result, displayFlux, StreamResponseType.EXECUTE_SQL);
            return Map.of("SQL_EXECUTE_NODE_OUTPUT", generator);
        }
        catch (Exception e) {
            // getMessage() is null for many exceptions (e.g. NullPointerException) and Map.of rejects
            // null values; fall back to toString() so the error handler cannot itself throw.
            String errorMessage = e.getMessage() != null ? e.getMessage() : e.toString();
            logger.error("SQL execution failed - SQL: [{}] ", sqlQuery, e);
            Map<String, Object> errorResult = Map.of("SQL_EXECUTE_NODE_EXCEPTION_OUTPUT", errorMessage);
            // Status messages: "Start executing SQL...", "Execute SQL query", "SQL execution failed: <msg>".
            Flux<ChatResponse> errorDisplayFlux = statusFlux(
                    "\u5f00\u59cb\u6267\u884cSQL...",
                    "\u6267\u884cSQL\u67e5\u8be2",
                    "SQL\u6267\u884c\u5931\u8d25: " + errorMessage);
            Flux<GraphResponse<StreamingOutput>> generator = StreamingChatGeneratorUtil.createStreamingGeneratorWithMessages(
                    this.getClass(), state, v -> errorResult, errorDisplayFlux, StreamResponseType.EXECUTE_SQL);
            return Map.of("SQL_EXECUTE_NODE_EXCEPTION_OUTPUT", generator);
        }
    }

    /** Builds a flux emitting one status response per message, in order, created on subscription. */
    private static Flux<ChatResponse> statusFlux(String... messages) {
        return Flux.fromArray(messages).<ChatResponse>map(ChatResponseUtil::createStatusResponse);
    }
}

