/*
 * Copyright 2024-2025 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.alibaba.cloud.ai.dashscope.api;

import com.alibaba.cloud.ai.dashscope.spec.DashScopeApiSpec;
import com.alibaba.cloud.ai.dashscope.spec.DashScopeApiSpec.ChatCompletion;
import com.alibaba.cloud.ai.dashscope.spec.DashScopeApiSpec.ChatCompletionChunk;
import com.alibaba.cloud.ai.dashscope.spec.DashScopeApiSpec.ChatCompletionFinishReason;
import com.alibaba.cloud.ai.dashscope.spec.DashScopeApiSpec.ChatCompletionMessage;
import com.alibaba.cloud.ai.dashscope.spec.DashScopeApiSpec.ChatCompletionMessage.ChatCompletionFunction;
import com.alibaba.cloud.ai.dashscope.spec.DashScopeApiSpec.ChatCompletionMessage.Role;
import com.alibaba.cloud.ai.dashscope.spec.DashScopeApiSpec.ChatCompletionMessage.ToolCall;
import com.alibaba.cloud.ai.dashscope.spec.DashScopeApiSpec.ChatCompletionOutput;
import com.alibaba.cloud.ai.dashscope.spec.DashScopeApiSpec.ChatCompletionOutput.Choice;
import com.alibaba.cloud.ai.dashscope.spec.DashScopeApiSpec.TokenUsage;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;

import java.util.ArrayList;
import java.util.List;

/**
 * Helper class to support streaming function calling. It merges streamed
 * {@link ChatCompletionChunk}s so that a tool (function) call whose pieces arrive in
 * several chunks can be reassembled into a single message.
 *
 * @author Ken
 */
public class DashScopeAiStreamFunctionCallingHelper {

	/**
	 * Whether chunks are treated as incremental deltas. Kept as a primitive so that a
	 * {@code null} passed to the constructor can never cause an unboxing
	 * {@link NullPointerException} later on.
	 */
	private final boolean incrementalOutput;

	/**
	 * Create a helper with incremental output disabled.
	 */
	public DashScopeAiStreamFunctionCallingHelper() {
		this(Boolean.FALSE);
	}

	/**
	 * Create a helper.
	 * @param incrementalOutput whether chunks are incremental deltas; {@code null} is
	 * treated as {@code false}
	 */
	public DashScopeAiStreamFunctionCallingHelper(Boolean incrementalOutput) {
		this.incrementalOutput = Boolean.TRUE.equals(incrementalOutput);
	}

	/**
	 * Merge the previous and current ChatCompletionChunk into a single one. Only the
	 * first choice of each chunk is considered.
	 * @param previous the previous ChatCompletionChunk, may be {@code null}
	 * @param current the current ChatCompletionChunk, may be {@code null}
	 * @return the merged ChatCompletionChunk; the other argument when one of them is
	 * {@code null}
	 * @throws IllegalStateException if the current message carries more than one tool
	 * call
	 */
	public ChatCompletionChunk merge(ChatCompletionChunk previous, ChatCompletionChunk current) {
		if (previous == null) {
			return current;
		}
		if (current == null) {
			return previous;
		}

		// The most recent non-null value wins for the scalar chunk attributes.
		String id = (current.requestId() != null ? current.requestId() : previous.requestId());
		TokenUsage usage = (current.usage() != null ? current.usage() : previous.usage());

		Choice previousChoice = firstChoice(previous);
		Choice currentChoice = firstChoice(current);

		// Compatibility of incremental_output=false for streaming function calls: the
		// previous chunk is not merged in. Intermediate tool-call chunks are replaced by
		// a chunk without choices, and only the chunk that finishes the tool call is
		// passed on with its own first choice.
		// NOTE(review): presumably each non-incremental chunk already carries the
		// accumulated tool call, so merging would duplicate arguments - confirm.
		if (!this.incrementalOutput && isStreamingToolFunctionCall(current)) {
			Choice emitted = isStreamingToolFunctionCallFinish(current) ? currentChoice : null;
			return newChunk(id, emitted, usage);
		}

		return newChunk(id, merge(previousChoice, currentChoice), usage);
	}

	/**
	 * Return the first choice of the given chunk.
	 * @param chunk the chunk to inspect, must not be {@code null}
	 * @return the first choice, or {@code null} if the chunk has no output or no choices
	 */
	private static Choice firstChoice(ChatCompletionChunk chunk) {
		ChatCompletionOutput output = chunk.output();
		if (output == null || CollectionUtils.isEmpty(output.choices())) {
			return null;
		}
		return output.choices().get(0);
	}

	/**
	 * Build a chunk that holds the given choice, or no choices when it is {@code null}.
	 */
	private static ChatCompletionChunk newChunk(String id, Choice choice, TokenUsage usage) {
		List<Choice> choices = (choice == null ? List.of() : List.of(choice));
		return new ChatCompletionChunk(id, new ChatCompletionOutput(null, choices, null), usage, null);
	}

	/**
	 * Merge two choices. A missing previous choice yields the current one; a missing
	 * current choice yields {@code null}.
	 */
	private Choice merge(Choice previous, Choice current) {
		if (previous == null) {
			return current;
		}
		if (current == null) {
			return null;
		}

		ChatCompletionFinishReason finishReason = (current.finishReason() != null ? current.finishReason()
				: previous.finishReason());
		ChatCompletionMessage message = merge(previous.message(), current.message());
		DashScopeApiSpec.ChatCompletionLogprobs logprobs = (current.logprobs() != null ? current.logprobs()
				: previous.logprobs());

		return new Choice(finishReason, message, logprobs);
	}

	/**
	 * Merge two messages. Scalar attributes take the current value when present and
	 * fall back to the previous one. Tool calls are combined: a current tool call that
	 * carries an id is appended as a new call, one without an id is merged into the
	 * last previous tool call.
	 * @throws IllegalStateException if the current message carries more than one tool
	 * call
	 */
	private ChatCompletionMessage merge(ChatCompletionMessage previous, ChatCompletionMessage current) {
		// A choice may come without a message; keep whichever side is available.
		if (previous == null) {
			return current;
		}
		if (current == null) {
			return previous;
		}

		Object content = (current.content() != null ? current.content()
				: (previous.content() != null) ? previous.content() : "");
		Role role = (current.role() != null ? current.role() : previous.role());
		role = (role != null ? role : Role.ASSISTANT); // default to ASSISTANT (if null)
		String name = (StringUtils.hasText(current.name()) ? current.name() : previous.name());
		String toolCallId = (StringUtils.hasText(current.toolCallId()) ? current.toolCallId() : previous.toolCallId());
		String reasoningContent = (current.reasoningContent() != null ? current.reasoningContent()
				: previous.reasoningContent());
		Boolean partial = (current.partial() != null ? current.partial() : previous.partial());
		List<DashScopeApiSpec.ChatCompletionAnnotations> annotations = (current.annotations() != null
				? current.annotations() : previous.annotations());
		String status = (current.status() != null ? current.status() : previous.status());
		String phase = (current.phase() != null ? current.phase() : previous.phase());

		List<ToolCall> toolCalls = new ArrayList<>();
		ToolCall lastPreviousToolCall = null;
		List<ToolCall> previousToolCalls = previous.toolCalls();
		// A previously merged message holds a non-null but possibly empty list, so an
		// emptiness check (not just a null check) is required before taking the last
		// element.
		if (!CollectionUtils.isEmpty(previousToolCalls)) {
			int lastPosition = previousToolCalls.size() - 1;
			lastPreviousToolCall = previousToolCalls.get(lastPosition);
			// All but the last previous tool call are complete and carried over as-is.
			toolCalls.addAll(previousToolCalls.subList(0, lastPosition));
		}

		if (!CollectionUtils.isEmpty(current.toolCalls())) {
			if (current.toolCalls().size() > 1) {
				throw new IllegalStateException("Currently only one tool call is supported per message!");
			}
			ToolCall currentToolCall = current.toolCalls().get(0);
			if (StringUtils.hasText(currentToolCall.id())) {
				// An id marks a new tool call: keep the last previous one and append.
				if (lastPreviousToolCall != null) {
					toolCalls.add(lastPreviousToolCall);
				}
				toolCalls.add(currentToolCall);
			}
			else {
				// No id: treat it as a continuation of the last previous tool call.
				toolCalls.add(merge(lastPreviousToolCall, currentToolCall));
			}
		}
		else if (lastPreviousToolCall != null) {
			toolCalls.add(lastPreviousToolCall);
		}

		return new ChatCompletionMessage(content, role, name, toolCallId, toolCalls, reasoningContent, partial, phase,
				annotations, status);
	}

	/**
	 * Merge two tool calls: id and type take the current value when it has text, the
	 * index takes the current value when it is neither {@code null} nor zero, and the
	 * functions are merged.
	 */
	private ToolCall merge(ToolCall previous, ToolCall current) {
		if (previous == null) {
			return current;
		}
		if (current == null) {
			return previous;
		}

		String id = (StringUtils.hasText(current.id()) ? current.id() : previous.id());
		String type = (StringUtils.hasText(current.type()) ? current.type() : previous.type());
		// Read the index once into a boxed local so a missing index cannot trigger an
		// unboxing NullPointerException in the comparison below.
		Integer currentIndex = current.index();
		Integer index = (currentIndex != null && currentIndex != 0) ? currentIndex : previous.index();

		ChatCompletionFunction function = merge(previous.function(), current.function());
		return new ToolCall(id, type, function, index);
	}

	/**
	 * Merge two function descriptors: the name takes the current value when it has
	 * text, and the argument fragments are concatenated (previous first).
	 */
	private ChatCompletionFunction merge(ChatCompletionFunction previous, ChatCompletionFunction current) {
		if (previous == null) {
			return current;
		}
		if (current == null) {
			return previous;
		}

		String name = (StringUtils.hasText(current.name()) ? current.name() : previous.name());
		StringBuilder arguments = new StringBuilder();
		if (previous.arguments() != null) {
			arguments.append(previous.arguments());
		}
		if (current.arguments() != null) {
			arguments.append(current.arguments());
		}
		return new ChatCompletionFunction(name, arguments.toString());
	}

	/**
	 * Check whether the first choice of the chunk carries at least one tool call.
	 * @param chatCompletion the ChatCompletionChunk to check
	 * @return true if the ChatCompletionChunk is a streaming tool function call.
	 */
	public boolean isStreamingToolFunctionCall(ChatCompletionChunk chatCompletion) {
		Choice choice = checkChatCompletionChunk(chatCompletion);
		if (choice == null) {
			return false;
		}
		return !CollectionUtils.isEmpty(choice.message().toolCalls());
	}

	/**
	 * Check whether the first choice of the chunk has the finish reason
	 * {@link ChatCompletionFinishReason#TOOL_CALLS}.
	 * @param chatCompletion the ChatCompletionChunk to check
	 * @return true if the ChatCompletionChunk is a streaming tool function call and it is
	 * the last one.
	 */
	public boolean isStreamingToolFunctionCallFinish(ChatCompletionChunk chatCompletion) {
		Choice choice = checkChatCompletionChunk(chatCompletion);
		if (choice == null) {
			return false;
		}
		return choice.finishReason() == ChatCompletionFinishReason.TOOL_CALLS;
	}

	/**
	 * Convert the ChatCompletionChunk into a ChatCompletion. The request id, output and
	 * usage are carried over from the chunk unchanged.
	 * @param chunk the ChatCompletionChunk to convert
	 * @return the ChatCompletion
	 */
	public ChatCompletion chunkToChatCompletion(ChatCompletionChunk chunk) {
		return new ChatCompletion(chunk.requestId(), chunk.output(), chunk.usage());
	}

	/**
	 * Return the first choice of the chunk if it exists and has a message.
	 * @param chatCompletion the chunk to inspect, may be {@code null}
	 * @return the first choice, or {@code null} if the chunk, its output, its choices,
	 * the first choice or that choice's message is missing
	 */
	private Choice checkChatCompletionChunk(ChatCompletionChunk chatCompletion) {
		if (chatCompletion == null) {
			return null;
		}

		Choice choice = firstChoice(chatCompletion);
		if (choice == null || choice.message() == null) {
			return null;
		}
		return choice;
	}

}
