/*
 * Decompiled with CFR 0.152.
 */
package com.alibaba.cloud.ai.node;

import com.alibaba.cloud.ai.dto.schema.SchemaDTO;
import com.alibaba.cloud.ai.graph.GraphResponse;
import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.action.NodeAction;
import com.alibaba.cloud.ai.graph.streaming.StreamingOutput;
import com.alibaba.cloud.ai.model.execution.ExecutionStep;
import com.alibaba.cloud.ai.model.execution.Plan;
import com.alibaba.cloud.ai.service.base.BaseNl2SqlService;
import com.alibaba.cloud.ai.util.ChatResponseUtil;
import com.alibaba.cloud.ai.util.MarkdownParser;
import com.alibaba.cloud.ai.util.StateUtils;
import com.alibaba.cloud.ai.util.StreamingChatGeneratorUtil;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.regex.Pattern;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.converter.BeanOutputConverter;
import org.springframework.core.ParameterizedTypeReference;
import reactor.core.publisher.Flux;

public class SqlGenerateNode
implements NodeAction {
    private static final Logger logger = LoggerFactory.getLogger(SqlGenerateNode.class);
    private static final int MAX_RETRY_COUNT = 3;
    private static final int MAX_OPTIMIZATION_ROUNDS = 3;
    private final BaseNl2SqlService baseNl2SqlService;
    private final BeanOutputConverter<Plan> converter;
    private final ChatClient chatClient;

    public SqlGenerateNode(ChatClient.Builder chatClientBuilder, BaseNl2SqlService baseNl2SqlService) {
        this.chatClient = chatClientBuilder.build();
        this.baseNl2SqlService = baseNl2SqlService;
        this.converter = new BeanOutputConverter((ParameterizedTypeReference)new ParameterizedTypeReference<Plan>(){});
    }

    public Map<String, Object> apply(OverAllState state) throws Exception {
        Map<String, String> result;
        String displayMessage;
        logger.info("Entering {} node", (Object)this.getClass().getSimpleName());
        String plannerNodeOutput = StateUtils.getStringValue(state, "PLANNER_NODE_OUTPUT");
        Plan plan = (Plan)this.converter.convert(plannerNodeOutput);
        Integer currentStep = StateUtils.getObjectValue(state, "PLAN_CURRENT_STEP", Integer.class, 1);
        List<ExecutionStep> executionPlan = plan.getExecutionPlan();
        ExecutionStep executionStep = executionPlan.get(currentStep - 1);
        ExecutionStep.ToolParameters toolParameters = executionStep.getToolParameters();
        if (StateUtils.hasValue(state, "SQL_EXECUTE_NODE_EXCEPTION_OUTPUT")) {
            displayMessage = "\u68c0\u6d4b\u5230SQL\u6267\u884c\u5f02\u5e38\uff0c\u5f00\u59cb\u91cd\u65b0\u751f\u6210SQL...";
            newSql = this.handleSqlExecutionException(state, plan, toolParameters);
            toolParameters.setSqlQuery(newSql);
            result = Map.of("SQL_GENERATE_OUTPUT", "SQL_EXECUTE_NODE", "PLANNER_NODE_OUTPUT", plan.toJsonStr());
            logger.info("[{}] Regenerated SQL due to execution exception: {}", (Object)this.getClass().getSimpleName(), (Object)newSql);
        } else if (this.isSemanticConsistencyFailed(state)) {
            displayMessage = "\u8bed\u4e49\u4e00\u81f4\u6027\u6821\u9a8c\u672a\u901a\u8fc7\uff0c\u5f00\u59cb\u91cd\u65b0\u751f\u6210SQL...";
            newSql = this.handleSemanticConsistencyFailure(state, toolParameters);
            result = Map.of("SQL_GENERATE_OUTPUT", newSql, "result", newSql);
            logger.info("[{}] Regenerated SQL due to semantic consistency failure: {}", (Object)this.getClass().getSimpleName(), (Object)newSql);
        } else {
            throw new IllegalStateException("SQL generation node was called unexpectedly");
        }
        Flux displayFlux = Flux.create(emitter -> {
            emitter.next((Object)ChatResponseUtil.createCustomStatusResponse(displayMessage));
            if (result.containsKey("result")) {
                emitter.next((Object)ChatResponseUtil.createCustomStatusResponse("\u91cd\u65b0\u751f\u6210\u7684SQL: " + result.get("result")));
            } else if (result.containsKey("SQL_GENERATE_OUTPUT") && result.get("SQL_GENERATE_OUTPUT").equals("SQL_EXECUTE_NODE")) {
                emitter.next((Object)ChatResponseUtil.createCustomStatusResponse("SQL\u91cd\u65b0\u751f\u6210\u5b8c\u6210\uff0c\u51c6\u5907\u6267\u884c"));
            }
            emitter.complete();
        });
        Flux<GraphResponse<StreamingOutput>> generator = StreamingChatGeneratorUtil.createStreamingGeneratorWithMessages(this.getClass(), state, v -> result, (Flux<ChatResponse>)displayFlux);
        return Map.of("SQL_GENERATE_OUTPUT", generator);
    }

    private String handleSqlExecutionException(OverAllState state, Plan plan, ExecutionStep.ToolParameters toolParameters) throws Exception {
        String sqlException = StateUtils.getStringValue(state, "SQL_EXECUTE_NODE_EXCEPTION_OUTPUT");
        logger.info("Detected SQL execution exception, starting to regenerate SQL: {}", (Object)sqlException);
        List<String> evidenceList = StateUtils.getListValue(state, "EVIDENCES");
        SchemaDTO schemaDTO = StateUtils.getObjectValue(state, "TABLE_RELATION_OUTPUT", SchemaDTO.class);
        return this.regenerateSql(state, toolParameters.toJsonStr(), evidenceList, schemaDTO, "SQL_EXECUTE_NODE_EXCEPTION_OUTPUT", toolParameters.getSqlQuery());
    }

    private String handleSemanticConsistencyFailure(OverAllState state, ExecutionStep.ToolParameters toolParameters) throws Exception {
        logger.info("Semantic consistency validation failed, starting to regenerate SQL");
        List<String> evidenceList = StateUtils.getListValue(state, "EVIDENCES");
        SchemaDTO schemaDTO = StateUtils.getObjectValue(state, "TABLE_RELATION_OUTPUT", SchemaDTO.class);
        return this.regenerateSql(state, toolParameters.toJsonStr(), evidenceList, schemaDTO, "SEMANTIC_CONSISTENCY_NODE_RECOMMEND_OUTPUT", toolParameters.getSqlQuery());
    }

    private boolean isSemanticConsistencyFailed(OverAllState state) {
        return StateUtils.getObjectValue(state, "SEMANTIC_CONSISTENCY_NODE_OUTPUT", Boolean.class, true) == false;
    }

    private String regenerateSql(OverAllState state, String input, List<String> evidenceList, SchemaDTO schemaDTO, String exceptionOutputKey, String originalSql) throws Exception {
        String exceptionMessage = StateUtils.getStringValue(state, exceptionOutputKey);
        logger.info("\u5f00\u59cb\u589e\u5f3aSQL\u751f\u6210\u6d41\u7a0b - \u539f\u59cbSQL: {}, \u5f02\u5e38\u4fe1\u606f: {}", (Object)originalSql, (Object)exceptionMessage);
        String bestSql = originalSql;
        double bestScore = 0.0;
        for (int round = 1; round <= 3; ++round) {
            logger.info("\u5f00\u59cb\u7b2c{}\u8f6eSQL\u4f18\u5316", (Object)round);
            try {
                String currentSql = round == 1 ? this.baseNl2SqlService.generateSql(evidenceList, input, schemaDTO, originalSql, exceptionMessage) : this.generateOptimizedSql(bestSql, exceptionMessage, round);
                if (currentSql == null || currentSql.trim().isEmpty()) {
                    logger.warn("\u7b2c{}\u8f6eSQL\u751f\u6210\u7ed3\u679c\u4e3a\u7a7a\uff0c\u8df3\u8fc7", (Object)round);
                    continue;
                }
                SqlQualityScore score = this.evaluateSqlQuality(currentSql, schemaDTO);
                logger.info("\u7b2c{}\u8f6eSQL\u8bc4\u5206: \u8bed\u6cd5={}, \u5b89\u5168={}, \u6027\u80fd={}, \u603b\u5206={}", new Object[]{round, score.syntaxScore, score.securityScore, score.performanceScore, score.totalScore});
                if (score.totalScore > bestScore) {
                    bestSql = currentSql;
                    bestScore = score.totalScore;
                    logger.info("\u7b2c{}\u8f6e\u4ea7\u751f\u4e86\u66f4\u597d\u7684SQL\uff0c\u603b\u5206\u63d0\u5347\u5230{}", (Object)round, (Object)score.totalScore);
                }
                if (!(score.totalScore >= 0.95)) continue;
                logger.info("SQL\u8d28\u91cf\u5206\u6570\u8fbe\u5230{}\uff0c\u63d0\u524d\u7ed3\u675f\u4f18\u5316", (Object)score.totalScore);
                break;
            }
            catch (Exception e) {
                logger.warn("\u7b2c{}\u8f6eSQL\u4f18\u5316\u5931\u8d25: {}", (Object)round, (Object)e.getMessage());
            }
        }
        bestSql = this.performFinalValidation(bestSql);
        logger.info("\u589e\u5f3aSQL\u751f\u6210\u5b8c\u6210\uff0c\u6700\u7ec8SQL: {}, \u6700\u7ec8\u8bc4\u5206: {}", (Object)bestSql, (Object)bestScore);
        return bestSql;
    }

    private String generateOptimizedSql(String previousSql, String exceptionMessage, int round) {
        try {
            StringBuilder prompt = new StringBuilder();
            prompt.append("\u8bf7\u5bf9\u4ee5\u4e0bSQL\u8fdb\u884c\u7b2c").append(round).append("\u8f6e\u4f18\u5316:\n\n");
            prompt.append("\u5f53\u524dSQL:\n").append(previousSql).append("\n\n");
            if (exceptionMessage != null && !exceptionMessage.trim().isEmpty()) {
                prompt.append("\u9700\u8981\u89e3\u51b3\u7684\u95ee\u9898:\n").append(exceptionMessage).append("\n\n");
            }
            prompt.append("\u4f18\u5316\u76ee\u6807:\n");
            prompt.append("1. \u4fee\u590d\u4efb\u4f55\u8bed\u6cd5\u9519\u8bef\n");
            prompt.append("2. \u63d0\u5347\u67e5\u8be2\u6027\u80fd\n");
            prompt.append("3. \u786e\u4fdd\u67e5\u8be2\u5b89\u5168\u6027\n");
            prompt.append("4. \u4f18\u5316\u53ef\u8bfb\u6027\n\n");
            prompt.append("\u8bf7\u53ea\u8fd4\u56de\u4f18\u5316\u540e\u7684SQL\u8bed\u53e5\uff0c\u4e0d\u8981\u5305\u542b\u5176\u4ed6\u8bf4\u660e\u3002");
            String response = this.chatClient.prompt().user(prompt.toString()).call().content();
            return MarkdownParser.extractRawText(response).trim();
        }
        catch (Exception e) {
            logger.error("\u4f7f\u7528ChatClient\u4f18\u5316SQL\u5931\u8d25: {}", (Object)e.getMessage());
            return previousSql;
        }
    }

    private SqlQualityScore evaluateSqlQuality(String sql, SchemaDTO schemaDTO) {
        SqlQualityScore score = new SqlQualityScore();
        score.syntaxScore = this.validateSqlSyntax(sql);
        score.securityScore = this.validateSqlSecurity(sql);
        score.performanceScore = this.evaluateSqlPerformance(sql);
        score.totalScore = score.syntaxScore * 0.4 + score.securityScore * 0.3 + score.performanceScore * 0.3;
        return score;
    }

    private double validateSqlSyntax(String sql) {
        long singleQuotes;
        long closeParens;
        long openParens;
        if (sql == null || sql.trim().isEmpty()) {
            return 0.0;
        }
        double score = 1.0;
        String upperSql = sql.toUpperCase();
        if (!upperSql.contains("SELECT")) {
            score -= 0.3;
        }
        if (!upperSql.contains("FROM")) {
            score -= 0.3;
        }
        if ((openParens = sql.chars().filter(ch -> ch == 40).count()) != (closeParens = sql.chars().filter(ch -> ch == 41).count())) {
            score -= 0.2;
        }
        if ((singleQuotes = sql.chars().filter(ch -> ch == 39).count()) % 2L != 0L) {
            score -= 0.2;
        }
        return Math.max(0.0, score);
    }

    private double validateSqlSecurity(String sql) {
        String[] injectionPatterns;
        String[] dangerousKeywords;
        if (sql == null) {
            return 0.0;
        }
        double score = 1.0;
        String upperSql = sql.toUpperCase();
        for (String keyword : dangerousKeywords = new String[]{"DROP", "DELETE", "UPDATE", "INSERT", "ALTER", "CREATE", "TRUNCATE"}) {
            if (!upperSql.contains(keyword)) continue;
            score -= 0.3;
            logger.warn("\u68c0\u6d4b\u5230\u6f5c\u5728\u5371\u9669SQL\u64cd\u4f5c: {}", (Object)keyword);
        }
        for (String pattern : injectionPatterns = new String[]{"--", "/*", "*/", "UNION", "OR 1=1", "OR '1'='1'"}) {
            if (!upperSql.contains(pattern.toUpperCase())) continue;
            score -= 0.2;
            logger.warn("\u68c0\u6d4b\u5230\u6f5c\u5728SQL\u6ce8\u5165\u6a21\u5f0f: {}", (Object)pattern);
        }
        return Math.max(0.0, score);
    }

    private double evaluateSqlPerformance(String sql) {
        if (sql == null) {
            return 0.0;
        }
        double score = 1.0;
        String upperSql = sql.toUpperCase();
        if (upperSql.contains("SELECT *")) {
            score -= 0.2;
            logger.warn("\u68c0\u6d4b\u5230SELECT *\uff0c\u5efa\u8bae\u660e\u786e\u6307\u5b9a\u5b57\u6bb5");
        }
        if (!upperSql.contains("WHERE")) {
            score -= 0.3;
            logger.warn("\u67e5\u8be2\u7f3a\u5c11WHERE\u6761\u4ef6\uff0c\u53ef\u80fd\u5f71\u54cd\u6027\u80fd");
        }
        return Math.max(0.0, score);
    }

    private String performFinalValidation(String sql) {
        if (sql == null || ((String)sql).trim().isEmpty()) {
            throw new IllegalArgumentException("\u751f\u6210\u7684SQL\u4e3a\u7a7a");
        }
        if (!((String)(sql = ((String)sql).trim())).endsWith(";")) {
            sql = (String)sql + ";";
        }
        if (this.validateSqlSecurity((String)sql) < 0.5) {
            logger.warn("\u751f\u6210\u7684SQL\u5b58\u5728\u5b89\u5168\u98ce\u9669\uff0c\u4f46\u7ee7\u7eed\u6267\u884c");
        }
        return sql;
    }

    private Map<String, Object> handleUnsatisfiedRecallInfo(OverAllState state, String recallInfoSatisfyRequirement) {
        int sqlGenerateCount = StateUtils.getObjectValue(state, "SQL_GENERATE_COUNT", Integer.class, 0) + 1;
        logger.info(sqlGenerateCount == 1 ? "First time generating SQL" : "SQL generation count: {}", (Object)sqlGenerateCount);
        if (sqlGenerateCount <= 3) {
            return this.buildRetryResult(state, recallInfoSatisfyRequirement, sqlGenerateCount);
        }
        logger.info("Recall information doesn't satisfy requirements, retry limit reached, ending SQL generation");
        return Map.of("result", recallInfoSatisfyRequirement, "SQL_GENERATE_OUTPUT", "__END__", "SQL_GENERATE_COUNT", 0);
    }

    private Map<String, Object> buildRetryResult(OverAllState state, String recallInfoSatisfyRequirement, int sqlGenerateCount) {
        logger.info("Recall information doesn't satisfy requirements, starting to regenerate SQL");
        HashMap<String, Object> result = new HashMap<String, Object>();
        result.put("SQL_GENERATE_COUNT", sqlGenerateCount);
        result.put("SQL_GENERATE_OUTPUT", "SQL_GENERATE_SCHEMA_MISSING");
        String newAdvice = StateUtils.getStringValue(state, "SQL_GENERATE_SCHEMA_MISSING_ADVICE", "") + (StateUtils.hasValue(state, "SQL_GENERATE_SCHEMA_MISSING_ADVICE") ? "\n" : "") + recallInfoSatisfyRequirement;
        result.put("SQL_GENERATE_SCHEMA_MISSING_ADVICE", newAdvice);
        if (!StateUtils.hasValue(state, "SQL_GENERATE_SCHEMA_MISSING_ADVICE")) {
            logger.info("Recall information doesn't satisfy requirements, need to supplement Schema information");
        }
        return result;
    }

    private static class SqlQualityScore {
        double syntaxScore = 0.0;
        double securityScore = 0.0;
        double performanceScore = 0.0;
        double totalScore = 0.0;

        private SqlQualityScore() {
        }
    }
}

