package com.google.adk.flows.llmflows;

import com.google.adk.Telemetry;
import com.google.adk.agents.BaseAgent;
import com.google.adk.agents.CallbackContext;
import com.google.adk.agents.InvocationContext;
import com.google.adk.agents.LlmAgent;
import com.google.adk.agents.RunConfig;
import com.google.adk.events.Event;
import com.google.adk.exceptions.LlmCallsLimitExceededException;
import com.google.adk.flows.BaseFlow;
import com.google.adk.flows.llmflows.RequestProcessor;
import com.google.adk.flows.llmflows.ResponseProcessor;
import com.google.adk.models.BaseLlm;
import com.google.adk.models.BaseLlmConnection;
import com.google.adk.models.LlmRegistry;
import com.google.adk.models.LlmRequest;
import com.google.adk.models.LlmResponse;
import com.google.adk.tools.BaseTool;
import com.google.adk.tools.ToolContext;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
import com.google.genai.types.FunctionCall;
import com.google.genai.types.FunctionResponse;
import io.opentelemetry.api.trace.Span;
import io.opentelemetry.api.trace.StatusCode;
import io.opentelemetry.context.Scope;
import io.reactivex.rxjava3.core.Completable;
import io.reactivex.rxjava3.core.Flowable;
import io.reactivex.rxjava3.core.Maybe;
import io.reactivex.rxjava3.core.Scheduler;
import io.reactivex.rxjava3.core.Single;
import io.reactivex.rxjava3.disposables.Disposable;
import io.reactivex.rxjava3.observers.DisposableCompletableObserver;
import io.reactivex.rxjava3.schedulers.Schedulers;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import org.reactivestreams.Publisher;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Base implementation of an LLM-driven {@link BaseFlow}.
 *
 * <p>A step of this flow runs the configured {@link RequestProcessor}s over an initially empty
 * {@link LlmRequest}, lets every tool of the current {@link LlmAgent} contribute to the request,
 * calls the agent's model (wrapped in the agent's before/after model callbacks and an
 * OpenTelemetry span), and runs the configured {@link ResponseProcessor}s over each
 * {@link LlmResponse}. Function calls in a model response are executed and may request a transfer
 * to another agent.
 *
 * <p>{@link #run} repeats steps until a step emits no events or its last event is a final
 * response; {@link #runLive} drives a bidirectional live model connection instead.
 */
public abstract class BaseLlmFlow implements BaseFlow {
    private static final Logger logger = LoggerFactory.getLogger(BaseLlmFlow.class);

    /** Processors applied, in list order, to the outgoing request of every step. */
    protected final List<RequestProcessor> requestProcessors;

    /** Processors applied, in list order, to every incoming model response. */
    protected final List<ResponseProcessor> responseProcessors;

    /**
     * Creates a flow with the given processor chains.
     *
     * @param requestProcessors processors applied to each request, in list order
     * @param responseProcessors processors applied to each response, in list order
     */
    public BaseLlmFlow(
            List<RequestProcessor> requestProcessors, List<ResponseProcessor> responseProcessors) {
        this.requestProcessors = requestProcessors;
        this.responseProcessors = responseProcessors;
    }

    /**
     * Runs every request processor over {@code llmRequest}, then lets each tool of the current
     * agent update the request.
     *
     * <p>Processors run sequentially and each result is awaited with {@code blockingGet()}, so this
     * method blocks the calling thread until all processors have finished.
     *
     * @param context the current invocation; its agent must be an {@link LlmAgent}
     * @param llmRequest the request to start from
     * @return the final request together with the concatenated events emitted by the processors
     */
    protected RequestProcessor.RequestProcessingResult preprocess(
            InvocationContext context, LlmRequest llmRequest) {
        List<Iterable<Event>> eventIterables = new ArrayList<>();
        LlmRequest updatedRequest = llmRequest;
        for (RequestProcessor processor : requestProcessors) {
            RequestProcessor.RequestProcessingResult result =
                    processor.processRequest(context, updatedRequest).blockingGet();
            if (result.events() != null) {
                eventIterables.add(result.events());
            }
            updatedRequest = result.updatedRequest();
        }

        LlmAgent agent = (LlmAgent) context.agent();
        LlmRequest.Builder builder = updatedRequest.toBuilder();
        for (BaseTool tool : agent.tools()) {
            // NOTE(review): any value returned by processLlmRequest is ignored here; if it is a lazy
            // type (e.g. a Completable), confirm the tool applies its changes eagerly.
            tool.processLlmRequest(builder, ToolContext.builder(context).build());
        }
        return RequestProcessor.RequestProcessingResult.create(
                builder.build(), Iterables.concat(eventIterables));
    }

    /**
     * Runs every response processor over {@code llmResponse} and turns the result into events.
     *
     * <p>If the processed response has no content, no error code, and neither the interrupted nor
     * the turn-complete flag set, only the processors' events are returned. Otherwise a
     * model-response event is built from {@code baseEvent}. If that event contains function calls,
     * they are executed via {@code Functions.handleFunctionCalls}. Any resulting event is appended,
     * and its {@code transferToAgent} action becomes the result's transfer target.
     *
     * @param context the current invocation
     * @param baseEvent template (id, author, branch, ...) for the model-response event
     * @param llmRequest the request that produced {@code llmResponse}; its tools are used to
     *     execute function calls
     * @param llmResponse the raw model response
     * @return the processed response, all generated events, and an optional transfer target
     */
    protected Single<ResponseProcessor.ResponseProcessingResult> postprocess(
            InvocationContext context, Event baseEvent, LlmRequest llmRequest, LlmResponse llmResponse) {
        List<Iterable<Event>> eventIterables = new ArrayList<>();
        LlmResponse processedResponse = llmResponse;
        for (ResponseProcessor processor : responseProcessors) {
            ResponseProcessor.ResponseProcessingResult result =
                    processor.processResponse(context, processedResponse).blockingGet();
            if (result.events() != null) {
                eventIterables.add(result.events());
            }
            processedResponse = result.updatedResponse();
        }
        // Effectively-final copy so the lambdas below can capture it.
        LlmResponse finalResponse = processedResponse;

        boolean hasPayload =
                finalResponse.content().isPresent()
                        || finalResponse.errorCode().isPresent()
                        || finalResponse.interrupted().orElse(false)
                        || finalResponse.turnComplete().orElse(false);
        if (!hasPayload) {
            return Single.just(
                    ResponseProcessor.ResponseProcessingResult.create(
                            finalResponse, Iterables.concat(eventIterables), Optional.empty()));
        }

        logger.debug("Response after processors: {}", finalResponse);
        Event modelResponseEvent = buildModelResponseEvent(baseEvent, llmRequest, finalResponse);
        eventIterables.add(Collections.singleton(modelResponseEvent));
        // Guarded so the event is only serialized to JSON when debug logging is enabled.
        if (logger.isDebugEnabled()) {
            logger.debug("Model response event: {}", modelResponseEvent.toJson());
        }

        Maybe<Event> functionResponseEvent =
                modelResponseEvent.functionCalls().isEmpty()
                        ? Maybe.empty()
                        : Functions.handleFunctionCalls(context, modelResponseEvent, llmRequest.tools());
        return functionResponseEvent
                .map(
                        fnEvent -> {
                            if (logger.isDebugEnabled()) {
                                logger.debug("Function call event generated: {}", fnEvent.toJson());
                            }
                            eventIterables.add(Collections.singleton(fnEvent));
                            return fnEvent.actions().transferToAgent();
                        })
                .defaultIfEmpty(Optional.empty())
                .map(
                        transferToAgent ->
                                ResponseProcessor.ResponseProcessingResult.create(
                                        finalResponse, Iterables.concat(eventIterables), transferToAgent));
    }

    /**
     * Returns the agent's resolved model instance, or looks the model up by name in
     * {@link LlmRegistry} when only a name was resolved.
     */
    private static BaseLlm resolveLlm(LlmAgent agent) {
        return agent.resolvedModel().model().isPresent()
                ? agent.resolvedModel().model().get()
                : LlmRegistry.getLlm(agent.resolvedModel().modelName().get());
    }

    /**
     * Calls the model for {@code llmRequest}, honouring the agent's model callbacks.
     *
     * <p>If the before-model callback yields a response, that response is emitted and the model is
     * not called. Otherwise the model is invoked inside a {@code call_llm} span; SSE streaming is
     * requested when the run config's streaming mode is {@code SSE}. Each model response is traced
     * and then passed through the after-model callback.
     */
    private Flowable<LlmResponse> callLlm(
            InvocationContext context, LlmRequest llmRequest, Event baseEvent) {
        LlmAgent agent = (LlmAgent) context.agent();
        return handleBeforeModelCallback(context, llmRequest, baseEvent)
                .flatMapPublisher(
                        beforeResponse -> {
                            if (beforeResponse.isPresent()) {
                                return Flowable.just(beforeResponse.get());
                            }
                            BaseLlm llm = resolveLlm(agent);
                            return Flowable.defer(
                                            () -> {
                                                Span span = Telemetry.getTracer().spanBuilder("call_llm").startSpan();
                                                try (Scope scope = span.makeCurrent()) {
                                                    boolean stream =
                                                            context.runConfig().streamingMode()
                                                                    == RunConfig.StreamingMode.SSE;
                                                    return llm.generateContent(llmRequest, stream)
                                                            .doOnNext(
                                                                    response -> {
                                                                        // Re-enter the span: onNext may run on another thread.
                                                                        try (Scope innerScope = span.makeCurrent()) {
                                                                            Telemetry.traceCallLlm(
                                                                                    context, baseEvent.id(), llmRequest, response);
                                                                        }
                                                                    })
                                                            .doOnError(
                                                                    error -> {
                                                                        span.setStatus(StatusCode.ERROR, error.getMessage());
                                                                        span.recordException(error);
                                                                    })
                                                            .doFinally(span::end);
                                                }
                                            })
                                    .concatMap(
                                            response ->
                                                    handleAfterModelCallback(context, response, baseEvent).toFlowable());
                        });
    }

    /**
     * Invokes the agent's before-model callback, if configured.
     *
     * @return a single emitting the callback's replacement response, or {@link Optional#empty()}
     *     when there is no callback or it completes without a value
     */
    private Single<Optional<LlmResponse>> handleBeforeModelCallback(
            InvocationContext context, LlmRequest llmRequest, Event baseEvent) {
        LlmAgent agent = (LlmAgent) context.agent();
        Event callbackEvent = baseEvent.toBuilder().build();
        return agent.beforeModelCallback()
                .map(
                        callback ->
                                callback
                                        .call(new CallbackContext(context, callbackEvent.actions()), llmRequest)
                                        .map(Optional::of)
                                        .defaultIfEmpty(Optional.empty()))
                .orElse(Single.just(Optional.empty()));
    }

    /**
     * Invokes the agent's after-model callback, if configured.
     *
     * @return a single emitting the callback's replacement response, or {@code llmResponse} itself
     *     when there is no callback or it completes without a value
     */
    private Single<LlmResponse> handleAfterModelCallback(
            InvocationContext context, LlmResponse llmResponse, Event baseEvent) {
        LlmAgent agent = (LlmAgent) context.agent();
        Event callbackEvent = baseEvent.toBuilder().content(llmResponse.content()).build();
        return agent.afterModelCallback()
                .map(
                        callback ->
                                callback
                                        .call(new CallbackContext(context, callbackEvent.actions()), llmResponse)
                                        .defaultIfEmpty(llmResponse))
                .orElse(Single.just(llmResponse));
    }

    /**
     * Executes one request/response cycle.
     *
     * <p>The returned stream starts with the preprocessing events. If preprocessing ended the
     * invocation, it stops there. If the LLM call limit is exceeded, the stream emits those events
     * and then fails with the {@link LlmCallsLimitExceededException}.
     */
    private Flowable<Event> runOneStep(InvocationContext context) {
        RequestProcessor.RequestProcessingResult preResult =
                preprocess(context, LlmRequest.builder().build());
        LlmRequest llmRequest = preResult.updatedRequest();
        Iterable<Event> preEvents = preResult.events();
        logger.debug("Pre-processing result: {}", preResult);

        if (context.endInvocation()) {
            logger.debug("End invocation requested during preprocessing.");
            return Flowable.fromIterable(preEvents);
        }

        try {
            context.incrementLlmCallsCount();
        } catch (LlmCallsLimitExceededException e) {
            logger.error("LLM calls limit exceeded.", e);
            return Flowable.fromIterable(preEvents).concatWith(Flowable.error(e));
        }

        // Mutable template for this step's model-response events. Its id is replaced after each
        // post-processed response, so every LlmResponse of the step yields an event with its own id.
        Event eventTemplate =
                Event.builder()
                        .id(Event.generateEventId())
                        .invocationId(context.invocationId())
                        .author(context.agent().name())
                        .branch(context.branch())
                        .build();
        logger.debug("Starting LLM call with request: {}", llmRequest);
        return callLlm(context, llmRequest, eventTemplate)
                .concatMap(
                        llmResponse -> {
                            logger.debug("Processing LlmResponse with Event ID: {}", eventTemplate.id());
                            logger.debug("LLM response for current step: {}", llmResponse);
                            return postprocess(context, eventTemplate, llmRequest, llmResponse)
                                    .doOnSuccess(
                                            ignored -> {
                                                String previousId = eventTemplate.id();
                                                eventTemplate.setId(Event.generateEventId());
                                                logger.debug(
                                                        "Updated mutableEventTemplate ID from {} to {} for next LlmResponse",
                                                        previousId,
                                                        eventTemplate.id());
                                            })
                                    .toFlowable();
                        })
                .concatMap(result -> emitWithTransfer(context, result))
                .startWithIterable(preEvents);
    }

    /**
     * Emits a post-processing result's events. If the result requests a transfer, the events are
     * followed by the target agent's {@code runAsync}; if the target cannot be found, they are
     * followed by an {@link IllegalStateException} error.
     */
    private Flowable<Event> emitWithTransfer(
            InvocationContext context, ResponseProcessor.ResponseProcessingResult result) {
        logger.debug("Post-processing result: {}", result);
        Flowable<Event> events = Flowable.fromIterable(result.events());
        if (!result.transferToAgent().isPresent()) {
            return events;
        }
        String agentName = result.transferToAgent().get();
        logger.debug("Transferring to agent: {}", agentName);
        BaseAgent targetAgent = context.agent().rootAgent().findAgent(agentName);
        if (targetAgent == null) {
            String message = "Agent not found for transfer: " + agentName;
            logger.error(message);
            return events.concatWith(Flowable.error(new IllegalStateException(message)));
        }
        return events.concatWith(Flowable.defer(() -> targetAgent.runAsync(context)));
    }

    /**
     * Runs steps until a step emits no events or its last event is a final response.
     *
     * <p>Each step is cached so its events can be emitted downstream and also collected to decide
     * whether to continue, without executing the step twice.
     */
    @Override // com.google.adk.flows.BaseFlow
    public Flowable<Event> run(InvocationContext invocationContext) {
        Flowable<Event> currentStep = runOneStep(invocationContext).cache();
        return currentStep.concatWith(
                currentStep
                        .toList()
                        .flatMapPublisher(
                                stepEvents -> {
                                    if (stepEvents.isEmpty() || Iterables.getLast(stepEvents).finalResponse()) {
                                        logger.debug(
                                                "Ending flow execution based on final response or empty event list.");
                                        return Flowable.<Event>empty();
                                    }
                                    logger.debug("Continuing to next step of the flow.");
                                    return Flowable.defer(() -> run(invocationContext));
                                }));
    }

    /**
     * Runs the agent over a live, bidirectional model connection.
     *
     * <p>After preprocessing, a connection is opened with the preprocessed request, and its
     * contents (if any) are sent as history inside a {@code send_data} span. Completion of that send
     * is observed on the agent's executor if one is configured, otherwise on
     * {@code Schedulers.io()}.
     *
     * <p>Each request from the invocation's live request queue is then forwarded:
     * <ul>
     *   <li>content is sent with {@code sendContent};
     *   <li>blobs are sent with {@code sendRealtime};
     *   <li>a request carrying neither closes the connection.
     * </ul>
     *
     * <p>The returned stream emits the preprocessing events, followed by the post-processed events
     * of every response received from the connection. Events carrying function responses are fed
     * back into the live request queue. A {@code transferToAgent} function response disposes the
     * sender and closes the connection. If a requested transfer target does not exist, the stream
     * fails with {@link IllegalStateException}.
     */
    @Override // com.google.adk.flows.BaseFlow
    public Flowable<Event> runLive(InvocationContext invocationContext) {
        LlmRequest initialRequest = LlmRequest.builder().build();
        String sendDataEventId = Event.generateEventId();

        RequestProcessor.RequestProcessingResult preResult =
                preprocess(invocationContext, initialRequest);
        LlmRequest llmRequest = preResult.updatedRequest();
        if (invocationContext.endInvocation()) {
            return Flowable.fromIterable(preResult.events());
        }

        LlmAgent agent = (LlmAgent) invocationContext.agent();
        BaseLlmConnection connection = resolveLlm(agent).connect(llmRequest);

        Completable historySent =
                llmRequest.contents().isEmpty()
                        ? Completable.complete()
                        : Completable.defer(
                                () -> sendHistoryWithTracing(
                                        invocationContext, connection, llmRequest, sendDataEventId));
        Scheduler sendScheduler = agent.executor().map(Schedulers::from).orElse(Schedulers.io());

        Disposable sendTask =
                historySent
                        .observeOn(sendScheduler)
                        .andThen(
                                invocationContext
                                        .liveRequestQueue()
                                        .get()
                                        .get()
                                        .concatMapCompletable(
                                                liveRequest -> {
                                                    if (liveRequest.content().isPresent()) {
                                                        return connection.sendContent(liveRequest.content().get());
                                                    }
                                                    if (liveRequest.blob().isPresent()) {
                                                        return connection.sendRealtime(liveRequest.blob().get());
                                                    }
                                                    return Completable.fromAction(connection::close);
                                                }))
                        .subscribeWith(
                                new DisposableCompletableObserver() {
                                    @Override
                                    public void onComplete() {
                                        connection.close();
                                    }

                                    @Override
                                    public void onError(Throwable error) {
                                        connection.close(error);
                                    }
                                });

        Event.Builder eventTemplate =
                Event.builder()
                        .invocationId(invocationContext.invocationId())
                        .author(invocationContext.agent().name())
                        .branch(invocationContext.branch());
        return connection
                .receive()
                .flatMapSingle(
                        llmResponse ->
                                postprocess(
                                        invocationContext,
                                        eventTemplate.id(Event.generateEventId()).build(),
                                        llmRequest,
                                        llmResponse))
                .flatMap(
                        result -> {
                            Flowable<Event> events = Flowable.fromIterable(result.events());
                            if (!result.transferToAgent().isPresent()) {
                                return events;
                            }
                            String agentName = result.transferToAgent().get();
                            BaseAgent targetAgent = invocationContext.agent().rootAgent().findAgent(agentName);
                            if (targetAgent == null) {
                                throw new IllegalStateException("Agent not found: " + agentName);
                            }
                            return Flowable.concat(events, targetAgent.runLive(invocationContext));
                        })
                .doOnNext(
                        event -> {
                            ImmutableList<FunctionResponse> functionResponses = event.functionResponses();
                            if (!functionResponses.isEmpty()) {
                                // Send function results back to the model through the live request queue.
                                invocationContext.liveRequestQueue().get().content(event.content().get());
                            }
                            boolean transferRequested =
                                    functionResponses.stream()
                                            .anyMatch(
                                                    response -> response.name().orElse("").equals("transferToAgent"));
                            if (transferRequested) {
                                sendTask.dispose();
                                connection.close();
                            }
                        })
                .startWithIterable(preResult.events());
    }

    /**
     * Sends the request's contents as history over {@code connection} inside a {@code send_data}
     * span. The send is traced on both completion and error, and the span is ended when the send
     * terminates.
     */
    private static Completable sendHistoryWithTracing(
            InvocationContext invocationContext,
            BaseLlmConnection connection,
            LlmRequest llmRequest,
            String eventId) {
        Span span = Telemetry.getTracer().spanBuilder("send_data").startSpan();
        try (Scope scope = span.makeCurrent()) {
            return connection
                    .sendHistory(llmRequest.contents())
                    .doOnComplete(() -> traceSendDataInSpan(span, invocationContext, eventId, llmRequest))
                    .doOnError(
                            error -> {
                                span.setStatus(StatusCode.ERROR, error.getMessage());
                                span.recordException(error);
                                traceSendDataInSpan(span, invocationContext, eventId, llmRequest);
                            })
                    .doFinally(span::end);
        }
    }

    /** Records a send-data trace for {@code llmRequest}'s contents with {@code span} made current. */
    private static void traceSendDataInSpan(
            Span span, InvocationContext invocationContext, String eventId, LlmRequest llmRequest) {
        try (Scope scope = span.makeCurrent()) {
            Telemetry.traceSendData(invocationContext, eventId, llmRequest.contents());
        }
    }

    /**
     * Builds the model-response event by copying {@code baseEvent} and filling in the response's
     * content, partial flag, error code and message, interrupted and turn-complete flags, and
     * grounding metadata.
     *
     * <p>When the event contains function calls, {@code Functions.populateClientFunctionCallId} is
     * applied to it. If any of the calls is long-running according to the request's tools, the
     * event's long-running tool ids are set to those calls.
     */
    private Event buildModelResponseEvent(
            Event baseEvent, LlmRequest llmRequest, LlmResponse llmResponse) {
        Event event =
                baseEvent.toBuilder()
                        .content(llmResponse.content())
                        .partial(llmResponse.partial())
                        .errorCode(llmResponse.errorCode())
                        .errorMessage(llmResponse.errorMessage())
                        .interrupted(llmResponse.interrupted())
                        .turnComplete(llmResponse.turnComplete())
                        .groundingMetadata(llmResponse.groundingMetadata())
                        .build();
        ImmutableList<FunctionCall> functionCalls = event.functionCalls();
        if (!functionCalls.isEmpty()) {
            Functions.populateClientFunctionCallId(event);
            Set<String> longRunningIds =
                    Functions.getLongRunningFunctionCalls(functionCalls, llmRequest.tools());
            if (!longRunningIds.isEmpty()) {
                event.setLongRunningToolIds(Optional.of(longRunningIds));
            }
        }
        return event;
    }
}
