/*
 * Decompiled with CFR 0.152.
 */
package com.alibaba.cloud.ai.node;

import com.alibaba.cloud.ai.dto.schema.SchemaDTO;
import com.alibaba.cloud.ai.enums.StreamResponseType;
import com.alibaba.cloud.ai.graph.GraphResponse;
import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.action.NodeAction;
import com.alibaba.cloud.ai.graph.streaming.StreamingOutput;
import com.alibaba.cloud.ai.prompt.PromptConstant;
import com.alibaba.cloud.ai.prompt.PromptHelper;
import com.alibaba.cloud.ai.util.StateUtils;
import com.alibaba.cloud.ai.util.StreamingChatGeneratorUtil;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.model.ChatResponse;
import reactor.core.publisher.Flux;

public class PlannerNode
implements NodeAction {
    private static final Logger logger = LoggerFactory.getLogger(PlannerNode.class);
    private final ChatClient chatClient;

    public PlannerNode(ChatClient.Builder chatClientBuilder) {
        this.chatClient = chatClientBuilder.build();
    }

    public Map<String, Object> apply(OverAllState state) throws Exception {
        String input = (String)state.value("input").orElseThrow();
        String processedQuery = StateUtils.getStringValue(state, "QUERY_REWRITE_NODE_OUTPUT", input);
        logger.info("Using processed query for planning: {}", (Object)processedQuery);
        Boolean onlyNl2sql = (Boolean)state.value("IS_ONLY_NL2SQL", (Object)false);
        String validationError = StateUtils.getStringValue(state, "PLAN_VALIDATION_ERROR", null);
        if (validationError != null) {
            logger.info("Regenerating plan with user feedback: {}", (Object)validationError);
        } else {
            logger.info("Generating initial plan");
        }
        String businessKnowledge = state.value("BUSINESS_KNOWLEDGE").orElse("");
        String semanticModel = state.value("SEMANTIC_MODEL").orElse("");
        SchemaDTO schemaDTO = StateUtils.getObjectValue(state, "TABLE_RELATION_OUTPUT", SchemaDTO.class);
        String schemaStr = PromptHelper.buildMixMacSqlDbPrompt(schemaDTO, true);
        String userPrompt = this.buildUserPrompt(processedQuery, validationError, state);
        Map<String, String> params = Map.of("user_question", userPrompt, "schema", schemaStr, "business_knowledge", businessKnowledge, "semantic_model", semanticModel, "plan_validation_error", this.formatValidationError(validationError));
        String plannerPrompt = (onlyNl2sql != false ? PromptConstant.getPlannerNl2sqlOnlyTemplate() : PromptConstant.getPlannerPromptTemplate()).render(params);
        Flux chatResponseFlux = this.chatClient.prompt().user(plannerPrompt).stream().chatResponse();
        Flux<GraphResponse<StreamingOutput>> generator = StreamingChatGeneratorUtil.createStreamingGeneratorWithMessages(this.getClass(), state, v -> Map.of("PLANNER_NODE_OUTPUT", v), (Flux<ChatResponse>)chatResponseFlux, StreamResponseType.PLAN_GENERATION);
        return Map.of("PLANNER_NODE_OUTPUT", generator);
    }

    private String buildUserPrompt(String input, String validationError, OverAllState state) {
        if (validationError == null) {
            return input;
        }
        String previousPlan = StateUtils.getStringValue(state, "PLANNER_NODE_OUTPUT", "");
        return String.format("IMPORTANT: User rejected previous plan with feedback: \"%s\"\n\nOriginal question: %s\n\nPrevious rejected plan:\n%s\n\nCRITICAL: Generate new plan incorporating user feedback (\"%s\")", validationError, input, previousPlan, validationError);
    }

    private String formatValidationError(String validationError) {
        return validationError != null ? String.format("**USER FEEDBACK (CRITICAL)**: %s\n\n**Must incorporate this feedback.**", validationError) : "";
    }
}

