/*
 * Decompiled with CFR 0.152.
 */
package com.alibaba.cloud.ai.memory.jdbc;

import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Timestamp;
import java.time.Instant;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicLong;
import org.springframework.ai.chat.memory.ChatMemoryRepository;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.MessageType;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.ToolResponseMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.jdbc.core.BatchPreparedStatementSetter;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.jdbc.core.RowMapper;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;

public abstract class JdbcChatMemoryRepository
implements ChatMemoryRepository {
    public static final String TABLE_NAME = "ai_chat_memory";
    private static final String QUERY_GET_IDS = "SELECT DISTINCT conversation_id FROM ai_chat_memory\n";
    private static final String QUERY_ADD = "INSERT INTO ai_chat_memory (conversation_id, content, type, \"timestamp\") VALUES (?, ?, ?, ?)\n";
    private static final String QUERY_GET = "SELECT content, type FROM ai_chat_memory WHERE conversation_id = ? ORDER BY \"timestamp\"\n";
    private static final String QUERY_CLEAR = "DELETE FROM ai_chat_memory WHERE conversation_id = ?";
    private final JdbcTemplate jdbcTemplate;

    public JdbcChatMemoryRepository(JdbcTemplate jdbcTemplate) {
        Assert.notNull((Object)jdbcTemplate, (String)"jdbcTemplate cannot be null");
        this.jdbcTemplate = jdbcTemplate;
        this.checkAndCreateTable();
    }

    private void checkAndCreateTable() {
        if (!((Boolean)this.jdbcTemplate.query(this.hasTableSql(TABLE_NAME), ResultSet::next)).booleanValue()) {
            this.jdbcTemplate.execute(this.createTableSql(TABLE_NAME));
        }
    }

    public List<String> findConversationIds() {
        List<String> conversationIds = (List<String>)this.jdbcTemplate.query(QUERY_GET_IDS, rs -> {
            ArrayList<String> ids = new ArrayList<String>();
            while (rs.next()) {
                ids.add(rs.getString(1));
            }
            return ids;
        });
        return conversationIds != null ? conversationIds : List.of();
    }

    public List<Message> findByConversationId(String conversationId) {
        Assert.hasText((String)conversationId, (String)"conversationId cannot be null or empty");
        return this.jdbcTemplate.query(this.getGetSql(), (RowMapper)new MessageRowMapper(), new Object[]{conversationId});
    }

    public void saveAll(String conversationId, List<Message> messages) {
        Assert.hasText((String)conversationId, (String)"conversationId cannot be null or empty");
        Assert.notNull(messages, (String)"messages cannot be null");
        Assert.noNullElements(messages, (String)"messages cannot contain null elements");
        this.deleteByConversationId(conversationId);
        this.jdbcTemplate.batchUpdate(this.getAddSql(), (BatchPreparedStatementSetter)new AddBatchPreparedStatement(conversationId, messages));
    }

    public void deleteByConversationId(String conversationId) {
        Assert.hasText((String)conversationId, (String)"conversationId cannot be null or empty");
        this.jdbcTemplate.update(QUERY_CLEAR, new Object[]{conversationId});
    }

    protected String getAddSql() {
        return QUERY_ADD;
    }

    protected String getGetSql() {
        return QUERY_GET;
    }

    protected abstract String hasTableSql(String var1);

    protected abstract String createTableSql(String var1);

    private static class MessageRowMapper
    implements RowMapper<Message> {
        private MessageRowMapper() {
        }

        @Nullable
        public Message mapRow(ResultSet rs, int i) throws SQLException {
            String content = rs.getString(1);
            MessageType type = MessageType.valueOf((String)rs.getString(2));
            return switch (type) {
                default -> throw new IncompatibleClassChangeError();
                case MessageType.USER -> new UserMessage(content);
                case MessageType.ASSISTANT -> new AssistantMessage(content);
                case MessageType.SYSTEM -> new SystemMessage(content);
                case MessageType.TOOL -> new ToolResponseMessage(List.of());
            };
        }
    }

    private record AddBatchPreparedStatement(String conversationId, List<Message> messages, AtomicLong instantSeq) implements BatchPreparedStatementSetter
    {
        private AddBatchPreparedStatement(String conversationId, List<Message> messages) {
            this(conversationId, messages, new AtomicLong(Instant.now().toEpochMilli()));
        }

        public void setValues(PreparedStatement ps, int i) throws SQLException {
            Message message = this.messages.get(i);
            ps.setString(1, this.conversationId);
            ps.setString(2, message.getText());
            ps.setString(3, message.getMessageType().name());
            ps.setTimestamp(4, new Timestamp(this.instantSeq.getAndIncrement()));
        }

        public int getBatchSize() {
            return this.messages.size();
        }
    }
}

