/*
 * Decompiled with CFR 0.152.
 */
package com.alibaba.cloud.ai.controller;

import com.alibaba.cloud.ai.connector.config.DbConfig;
import com.alibaba.cloud.ai.entity.AgentDatasource;
import com.alibaba.cloud.ai.entity.Datasource;
import com.alibaba.cloud.ai.graph.CompiledGraph;
import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.StateGraph;
import com.alibaba.cloud.ai.graph.async.AsyncGenerator;
import com.alibaba.cloud.ai.graph.exception.GraphStateException;
import com.alibaba.cloud.ai.graph.streaming.StreamingOutput;
import com.alibaba.cloud.ai.request.SchemaInitRequest;
import com.alibaba.cloud.ai.service.DatasourceService;
import com.alibaba.cloud.ai.service.simple.SimpleVectorStoreService;
import com.alibaba.fastjson.JSON;
import jakarta.servlet.http.HttpServletResponse;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.http.codec.ServerSentEvent;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Sinks;

@RestController
@RequestMapping(value={"nl2sql"})
public class Nl2sqlForGraphController {
    private static final Logger logger = LoggerFactory.getLogger(Nl2sqlForGraphController.class);
    private final CompiledGraph compiledGraph;
    private final SimpleVectorStoreService simpleVectorStoreService;
    private final DatasourceService datasourceService;

    public Nl2sqlForGraphController(@Qualifier(value="nl2sqlGraph") StateGraph stateGraph, SimpleVectorStoreService simpleVectorStoreService, DatasourceService datasourceService) throws GraphStateException {
        this.compiledGraph = stateGraph.compile();
        this.compiledGraph.setMaxIterations(100);
        this.simpleVectorStoreService = simpleVectorStoreService;
        this.datasourceService = datasourceService;
    }

    @GetMapping(value={"/search"})
    public String search(@RequestParam String query, @RequestParam String dataSetId, @RequestParam String agentId) throws Exception {
        DbConfig dbConfig = this.getDbConfigForAgent(Integer.valueOf(agentId));
        SchemaInitRequest schemaInitRequest = new SchemaInitRequest();
        schemaInitRequest.setDbConfig(dbConfig);
        schemaInitRequest.setTables(Arrays.asList("categories", "order_items", "orders", "products", "users", "product_categories"));
        this.simpleVectorStoreService.schema(schemaInitRequest);
        Optional invoke = this.compiledGraph.invoke(Map.of("input", query, "agentId", dataSetId, "agentId", agentId));
        OverAllState overAllState = (OverAllState)invoke.get();
        return overAllState.value("result").get().toString();
    }

    @GetMapping(value={"/init"})
    public void init(@RequestParam(required=false, defaultValue="1") Integer agentId) throws Exception {
        DbConfig dbConfig = this.getDbConfigForAgent(agentId);
        SchemaInitRequest schemaInitRequest = new SchemaInitRequest();
        schemaInitRequest.setDbConfig(dbConfig);
        schemaInitRequest.setTables(Arrays.asList("categories", "order_items", "orders", "products", "users", "product_categories"));
        this.simpleVectorStoreService.schema(schemaInitRequest);
    }

    private DbConfig getDbConfigForAgent(Integer agentId) {
        try {
            List agentDatasources = this.datasourceService.getAgentDatasources(agentId);
            AgentDatasource activeDatasource = agentDatasources.stream().filter(ad -> ad.getIsActive() == 1).findFirst().orElseThrow(() -> new RuntimeException("\u667a\u80fd\u4f53 " + agentId + " \u672a\u914d\u7f6e\u542f\u7528\u7684\u6570\u636e\u6e90"));
            return this.createDbConfigFromDatasource(activeDatasource.getDatasource());
        }
        catch (Exception e) {
            logger.error("Failed to get agent datasource config for agent: {}", (Object)agentId, (Object)e);
            throw new RuntimeException("\u83b7\u53d6\u667a\u80fd\u4f53\u6570\u636e\u6e90\u914d\u7f6e\u5931\u8d25: " + e.getMessage(), e);
        }
    }

    private DbConfig createDbConfigFromDatasource(Datasource datasource) {
        DbConfig dbConfig = new DbConfig();
        dbConfig.setUrl(datasource.getConnectionUrl());
        dbConfig.setUsername(datasource.getUsername());
        dbConfig.setPassword(datasource.getPassword());
        if ("mysql".equalsIgnoreCase(datasource.getType())) {
            dbConfig.setConnectionType("jdbc");
            dbConfig.setDialectType("mysql");
        } else if ("postgresql".equalsIgnoreCase(datasource.getType())) {
            dbConfig.setConnectionType("jdbc");
            dbConfig.setDialectType("postgresql");
        } else {
            throw new RuntimeException("\u4e0d\u652f\u6301\u7684\u6570\u636e\u5e93\u7c7b\u578b: " + datasource.getType());
        }
        dbConfig.setSchema(datasource.getDatabaseName());
        return dbConfig;
    }

    @GetMapping(value={"/stream/search"}, produces={"text/event-stream"})
    public Flux<ServerSentEvent<String>> streamSearch(@RequestParam String query, @RequestParam String agentId, HttpServletResponse response) throws Exception {
        response.setCharacterEncoding("UTF-8");
        response.setContentType("text/event-stream");
        response.setHeader("Cache-Control", "no-cache");
        response.setHeader("Connection", "keep-alive");
        response.setHeader("Access-Control-Allow-Origin", "*");
        response.setHeader("Access-Control-Allow-Headers", "Cache-Control");
        logger.info("Starting stream search for query: {} with agentId: {}", (Object)query, (Object)agentId);
        Sinks.Many sink = Sinks.many().unicast().onBackpressureBuffer();
        AsyncGenerator generator = this.compiledGraph.stream(Map.of("input", query, "agentId", agentId));
        CompletableFuture.runAsync(() -> {
            try {
                ((CompletableFuture)generator.forEachAsync(output -> {
                    try {
                        logger.debug("Received output: {}", (Object)output.getClass().getSimpleName());
                        if (output instanceof StreamingOutput) {
                            StreamingOutput streamingOutput = (StreamingOutput)output;
                            String chunk = streamingOutput.chunk();
                            if (chunk != null && !chunk.trim().isEmpty()) {
                                logger.debug("Emitting chunk: {}", (Object)chunk);
                                ServerSentEvent event = ServerSentEvent.builder((Object)JSON.toJSONString((Object)chunk)).build();
                                sink.tryEmitNext((Object)event);
                            } else {
                                logger.warn("ReceFenerator: mapResult called, finalResultived null or empty chunk from streaming output");
                            }
                        } else {
                            logger.debug("Non-streaming output received: {}", output);
                        }
                    }
                    catch (Exception e) {
                        logger.error("Error processing streaming output: ", (Throwable)e);
                    }
                }).thenAccept(v -> {
                    logger.info("Stream processing completed successfully");
                    sink.tryEmitNext((Object)ServerSentEvent.builder((Object)"complete").event("complete").build());
                    sink.tryEmitComplete();
                })).exceptionally(e -> {
                    logger.error("Error in stream processing: ", e);
                    sink.tryEmitNext((Object)ServerSentEvent.builder((Object)("error: " + e.getMessage())).event("error").build());
                    sink.tryEmitComplete();
                    return null;
                });
            }
            catch (Exception e2) {
                logger.error("Error starting stream processing: ", (Throwable)e2);
                sink.tryEmitError((Throwable)e2);
            }
        });
        return sink.asFlux().doOnSubscribe(subscription -> logger.info("Client subscribed to stream")).doOnCancel(() -> logger.info("Client disconnected from stream")).doOnError(e -> logger.error("Error occurred during streaming: ", e)).doOnComplete(() -> logger.info("Stream completed successfully"));
    }
}

