/*
 * Decompiled with CFR 0.152.
 */
package com.alibaba.cloud.ai.graph.agent.tools;

import com.alibaba.cloud.ai.graph.RunnableConfig;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.nio.file.FileVisitOption;
import java.nio.file.Files;
import java.nio.file.LinkOption;
import java.nio.file.Path;
import java.nio.file.attribute.FileAttribute;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Stream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class ShellSessionManager {
    /** Logger used by the manager and by its nested session class. */
    private static final Logger log = LoggerFactory.getLogger(ShellSessionManager.class);
    /** Prefix of the per-command marker line the shell prints so the end of a command (and its exit status) can be detected. */
    private static final String DONE_MARKER_PREFIX = "__LC_SHELL_DONE__";
    /** Key under which the live shell session is stored in the {@code RunnableConfig} context. */
    private static final String SESSION_INSTANCE_CONTEXT_KEY = "_SHELL_SESSION_";
    /** Key under which the temporary workspace directory (if one was created) is stored in the {@code RunnableConfig} context. */
    private static final String SESSION_PATH_CONTEXT_KEY = "_SHELL_PATH_";
    /** Configured workspace directory; {@code null} when none was set on the builder. */
    private final Path workspaceRoot;
    /** {@code true} when no workspace root was configured, in which case a temporary directory is created on initialize. */
    private final boolean useTemporaryWorkspace;
    /** Commands run, in order, right after the shell starts (and again after an explicit restart). */
    private final List<String> startupCommands;
    /** Commands run, in order, during cleanup before the shell is stopped. */
    private final List<String> shutdownCommands;
    /** Time limit in milliseconds for regular and shutdown commands. */
    private final long commandTimeout;
    /** Time limit in milliseconds for each startup command. */
    private final long startupTimeout;
    /** How long, in milliseconds, to wait for the shell to exit before it is destroyed forcibly. */
    private final long terminationTimeout;
    /** Maximum number of output lines kept per command; further lines are counted but dropped. */
    private final int maxOutputLines;
    /** Maximum number of output bytes kept per command, or {@code null} for no byte limit. */
    private final Long maxOutputBytes;
    /** Command line used to launch the shell process. */
    private final List<String> shellCommand;
    /** Extra environment variables added to the shell process environment. */
    private final Map<String, String> environment;
    /** Redaction rules applied, in order, to the output of every command run through {@code executeCommand}. */
    private final List<RedactionRule> redactionRules;

    /**
     * Creates a manager from the settings collected by {@code builder}.
     *
     * <p>Every collection is copied, so later changes to the builder do not
     * affect this instance.
     *
     * @param builder the builder holding the desired configuration
     */
    private ShellSessionManager(Builder builder) {
        // Workspace: no configured root means a temporary directory is used instead.
        Path configuredRoot = builder.workspaceRoot;
        this.workspaceRoot = configuredRoot;
        this.useTemporaryWorkspace = configuredRoot == null;

        // Shell process settings.
        this.shellCommand = new ArrayList<>(builder.shellCommand);
        this.environment = new HashMap<>(builder.environment);

        // Lifecycle commands.
        this.startupCommands = new ArrayList<>(builder.startupCommands);
        this.shutdownCommands = new ArrayList<>(builder.shutdownCommands);

        // Timeouts (milliseconds).
        this.commandTimeout = builder.commandTimeout;
        this.startupTimeout = builder.startupTimeout;
        this.terminationTimeout = builder.terminationTimeout;

        // Output limits and post-processing.
        this.maxOutputLines = builder.maxOutputLines;
        this.maxOutputBytes = builder.maxOutputBytes;
        this.redactionRules = new ArrayList<>(builder.redactionRules);
    }

    /**
     * Entry point for configuring a {@code ShellSessionManager}.
     *
     * @return a new builder holding the default settings
     */
    public static Builder builder() {
        Builder fresh = new Builder();
        return fresh;
    }

    /**
     * Prepares the workspace, starts a shell session in it, stores the session in
     * the context of {@code config} and runs every startup command.
     *
     * <p>When no workspace root was configured, a temporary directory is created
     * and remembered in the context so that cleanup can delete it later.
     *
     * @param config the run configuration whose context receives the session
     * @throws RuntimeException if anything fails, including a startup command that
     *         times out or exits with a non-zero code; {@link #cleanup} is run
     *         before the exception is thrown
     */
    public void initialize(RunnableConfig config) {
        try {
            final Path workspace;
            if (useTemporaryWorkspace) {
                workspace = Files.createTempDirectory("shell_tool_");
                config.context().put(SESSION_PATH_CONTEXT_KEY, workspace);
            } else {
                workspace = workspaceRoot;
                Files.createDirectories(workspace);
            }

            ShellSession session = new ShellSession(workspace, shellCommand, environment);
            session.start();
            config.context().put(SESSION_INSTANCE_CONTEXT_KEY, session);
            log.info("Started shell session in workspace: {}", workspace);

            for (String startupCommand : startupCommands) {
                CommandResult outcome = session.execute(startupCommand, startupTimeout, maxOutputLines, maxOutputBytes);
                Integer exitCode = outcome.getExitCode();
                // A missing exit code is tolerated; a timeout or non-zero code is not.
                boolean succeeded = !outcome.isTimedOut() && (exitCode == null || exitCode == 0);
                if (!succeeded) {
                    throw new RuntimeException("Startup command failed: " + startupCommand + ", exit code: " + outcome.getExitCode());
                }
            }
        } catch (Exception e) {
            // Tear down whatever was set up so far before reporting the failure.
            cleanup(config);
            throw new RuntimeException("Failed to initialize shell session", e);
        }
    }

    /**
     * Runs the configured shutdown commands in the session stored in
     * {@code config} (if there is one), then always stops the session and removes
     * any temporary workspace.
     *
     * <p>A shutdown command that throws is logged and does not prevent the
     * remaining commands or the final teardown from running.
     *
     * <p>NOTE(review): this file was produced by a decompiler, which reported
     * removing a self-catching try block from this method; compare with the
     * original source if the exact exception behaviour matters.
     *
     * @param config the run configuration whose context holds the session
     */
    public void cleanup(RunnableConfig config) {
        try {
            ShellSession session = (ShellSession) config.context().get(SESSION_INSTANCE_CONTEXT_KEY);
            if (session == null) {
                // Nothing was started; the finally block still performs teardown.
                return;
            }
            for (String shutdownCommand : shutdownCommands) {
                try {
                    session.execute(shutdownCommand, commandTimeout, maxOutputLines, maxOutputBytes);
                } catch (Exception e) {
                    log.warn("Shutdown command failed: {}", shutdownCommand, e);
                }
            }
        } finally {
            doCleanup(config);
        }
    }

    /**
     * Stops the session stored in the context of {@code config} and deletes the
     * temporary workspace recorded there, removing both context entries.
     *
     * <p>Either entry may be absent; each is handled independently. A failure to
     * delete the directory is logged rather than propagated.
     *
     * @param config the run configuration whose context is cleaned
     */
    private void doCleanup(RunnableConfig config) {
        ShellSession session = (ShellSession) config.context().get(SESSION_INSTANCE_CONTEXT_KEY);
        if (session != null) {
            session.stop(terminationTimeout);
            config.context().remove(SESSION_INSTANCE_CONTEXT_KEY);
        }

        Path tempDir = (Path) config.context().get(SESSION_PATH_CONTEXT_KEY);
        if (tempDir == null) {
            return;
        }
        try {
            deleteDirectory(tempDir);
        } catch (IOException e) {
            log.warn("Failed to delete temporary directory: {}", tempDir, e);
        }
        // The entry is dropped even when deletion failed.
        config.context().remove(SESSION_PATH_CONTEXT_KEY);
    }

    /**
     * Runs {@code command} in the session stored in {@code config} and returns its
     * result with every redaction rule applied to the output.
     *
     * <p>Rules are applied in order, each one seeing the output already rewritten
     * by the rules before it. The matches each rule reports are grouped by the
     * rule's PII type in the returned result.
     *
     * @param command the shell command to run
     * @param config the run configuration whose context holds the session
     * @return the command result with redacted output and the collected matches
     * @throws IllegalStateException if no session has been initialized
     */
    public CommandResult executeCommand(String command, RunnableConfig config) {
        ShellSession session = (ShellSession) config.context().get(SESSION_INSTANCE_CONTEXT_KEY);
        if (session == null) {
            throw new IllegalStateException("Shell session not initialized. Call initialize() first.");
        }

        log.info("Executing shell command: {}", command);
        CommandResult raw = session.execute(command, commandTimeout, maxOutputLines, maxOutputBytes);

        String redacted = raw.getOutput();
        Map<String, List<String>> matchesByType = new HashMap<String, List<String>>();
        for (RedactionRule rule : redactionRules) {
            RedactionResult applied = rule.applyWithMatches(redacted);
            redacted = applied.getRedactedContent();
            List<String> found = applied.getMatches();
            if (!found.isEmpty()) {
                matchesByType.computeIfAbsent(rule.getPiiType(), key -> new ArrayList<String>()).addAll(found);
            }
        }

        return new CommandResult(
                redacted,
                raw.getExitCode(),
                raw.isTimedOut(),
                raw.isTruncatedByLines(),
                raw.isTruncatedByBytes(),
                raw.getTotalLines(),
                raw.getTotalBytes(),
                matchesByType);
    }

    /**
     * Restarts the shell process of the session stored in {@code config} and
     * re-runs every startup command in the fresh shell.
     *
     * <p>Unlike {@link #initialize}, a startup command that fails here does not
     * abort the restart; the failure is logged as a warning so it is not lost
     * silently, and the remaining startup commands are still run.
     *
     * @param config the run configuration whose context holds the session
     * @throws IllegalStateException if no session has been initialized
     */
    public void restartSession(RunnableConfig config) {
        ShellSession session = (ShellSession) config.context().get(SESSION_INSTANCE_CONTEXT_KEY);
        if (session == null) {
            throw new IllegalStateException("Shell session not initialized.");
        }
        log.info("Restarting shell session");
        session.restart();
        for (String command : this.startupCommands) {
            CommandResult result = session.execute(command, this.startupTimeout, this.maxOutputLines, this.maxOutputBytes);
            if (!result.isSuccess()) {
                // Best effort: report the failure but keep the restarted session usable.
                log.warn("Startup command failed after restart: {}, exit code: {}, timed out: {}",
                        command, result.getExitCode(), result.isTimedOut());
            }
        }
    }

    /**
     * @return the maximum number of output lines kept per command
     */
    public int getMaxOutputLines() {
        return maxOutputLines;
    }

    /**
     * @return the maximum number of output bytes kept per command, or
     *         {@code null} when no byte limit is configured
     */
    public Long getMaxOutputBytes() {
        return maxOutputBytes;
    }

    /**
     * Deletes {@code directory} and everything beneath it; does nothing when the
     * path does not exist.
     *
     * <p>Entries that cannot be deleted are logged and skipped, so the method
     * continues with the remaining entries.
     *
     * @param directory the root of the tree to remove
     * @throws IOException if the directory tree cannot be walked
     */
    private void deleteDirectory(Path directory) throws IOException {
        if (!Files.exists(directory)) {
            return;
        }
        try (Stream<Path> entries = Files.walk(directory)) {
            // Reverse order puts children ahead of their parent directories.
            entries.sorted(Comparator.reverseOrder()).forEach(entry -> {
                try {
                    Files.delete(entry);
                } catch (IOException e) {
                    log.warn("Failed to delete: {}", entry, e);
                }
            });
        }
    }

    /**
     * Fluent builder for {@code ShellSessionManager}.
     *
     * <p>Defaults: 30 s command timeout, 10 s startup timeout, 5 s termination
     * timeout, at most 1000 output lines, no byte limit, {@code /bin/bash} as the
     * shell, and no workspace root (a temporary directory is used instead).
     */
    public static class Builder {
        private Path workspaceRoot;
        private List<String> shellCommand = Arrays.asList("/bin/bash");
        private final Map<String, String> environment = new HashMap<>();
        private final List<String> startupCommands = new ArrayList<>();
        private final List<String> shutdownCommands = new ArrayList<>();
        private long commandTimeout = 30000L;
        private long startupTimeout = 10000L;
        private long terminationTimeout = 5000L;
        private int maxOutputLines = 1000;
        private Long maxOutputBytes = null;
        private final List<RedactionRule> redactionRules = new ArrayList<>();

        /** Sets the workspace directory from a path string. */
        public Builder workspaceRoot(String path) {
            workspaceRoot = Path.of(path);
            return this;
        }

        /** Sets the workspace directory. */
        public Builder workspaceRoot(Path path) {
            workspaceRoot = path;
            return this;
        }

        /** Appends one command to run after the shell starts. */
        public Builder addStartupCommand(String command) {
            startupCommands.add(command);
            return this;
        }

        /** Appends one command to run before the shell is stopped. */
        public Builder addShutdownCommand(String command) {
            shutdownCommands.add(command);
            return this;
        }

        /** Appends all given commands to the startup commands (existing ones are kept). */
        public Builder setStartupCommand(List<String> commands) {
            startupCommands.addAll(commands);
            return this;
        }

        /** Appends all given commands to the shutdown commands (existing ones are kept). */
        public Builder setShutdownCommand(List<String> commands) {
            shutdownCommands.addAll(commands);
            return this;
        }

        /** Sets the time limit, in milliseconds, for regular and shutdown commands. */
        public Builder commandTimeout(long millis) {
            commandTimeout = millis;
            return this;
        }

        /** Sets the time limit, in milliseconds, for each startup command. */
        public Builder startupTimeout(long millis) {
            startupTimeout = millis;
            return this;
        }

        /** Sets how long, in milliseconds, to wait for the shell to exit when stopping it. */
        public Builder terminationTimeout(long millis) {
            terminationTimeout = millis;
            return this;
        }

        /** Sets the maximum number of output lines kept per command. */
        public Builder maxOutputLines(int lines) {
            maxOutputLines = lines;
            return this;
        }

        /** Sets the maximum number of output bytes kept per command. */
        public Builder maxOutputBytes(long bytes) {
            maxOutputBytes = Long.valueOf(bytes);
            return this;
        }

        /** Replaces the command line used to launch the shell with a copy of {@code command}. */
        public Builder shellCommand(List<String> command) {
            shellCommand = new ArrayList<>(command);
            return this;
        }

        /** Adds all entries of {@code env} to the environment variables for the shell process. */
        public Builder environment(Map<String, String> env) {
            environment.putAll(env);
            return this;
        }

        /** Appends a rule used to redact command output. */
        public Builder addRedactionRule(RedactionRule rule) {
            redactionRules.add(rule);
            return this;
        }

        /** Creates the manager from the current settings. */
        public ShellSessionManager build() {
            return new ShellSessionManager(this);
        }
    }

    /**
     * One long-lived shell process together with the plumbing needed to run
     * commands in it.
     *
     * <p>Commands are written to the shell's standard input, followed by a
     * {@code printf} of a unique marker and {@code $?}. Two background threads
     * push every stdout/stderr line into a shared queue; the caller drains that
     * queue until the marker line appears on stdout or the deadline passes.
     *
     * <p>NOTE(review): nothing here synchronizes {@code execute}, {@code stop} and
     * {@code restart}; the class appears to assume one caller at a time.
     */
    private class ShellSession {
        private final Path workspace;
        private final List<String> command;
        private final Map<String, String> env;
        /** Lines from both reader threads in arrival order; an entry with {@code null} content marks end of stream. */
        private final BlockingQueue<OutputLine> outputQueue;
        private Process process;
        private BufferedWriter stdin;
        /** Set when {@link #stop(long)} begins shutting the process down; not read anywhere in this class. */
        private volatile boolean terminated;

        /**
         * @param workspace working directory for the shell process
         * @param command command line used to launch the shell
         * @param env environment variables added to the process environment
         */
        ShellSession(Path workspace, List<String> command, Map<String, String> env) {
            this.workspace = workspace;
            this.command = command;
            this.env = env;
            this.outputQueue = new LinkedBlockingQueue<>();
        }

        /**
         * Launches the shell process and starts one reader thread per output stream.
         *
         * @throws IOException if the process cannot be started
         */
        void start() throws IOException {
            ProcessBuilder pb = new ProcessBuilder(this.command);
            pb.directory(this.workspace.toFile());
            pb.environment().putAll(this.env);
            // stdout and stderr stay separate so stderr lines can be labelled.
            pb.redirectErrorStream(false);
            this.process = pb.start();
            this.stdin = new BufferedWriter(new OutputStreamWriter(this.process.getOutputStream()));
            this.startReader(this.process, false, "shell-stdout-reader");
            this.startReader(this.process, true, "shell-stderr-reader");
        }

        /**
         * Starts a background thread that copies every line of one output stream of
         * {@code source} into the queue and finishes with an end-of-stream entry.
         *
         * @param source the process to read from
         * @param errorStream {@code true} to read stderr, {@code false} to read stdout
         * @param threadName name given to the reader thread
         */
        private void startReader(Process source, boolean errorStream, String threadName) {
            String label = errorStream ? "stderr" : "stdout";
            Thread reader = new Thread(() -> {
                try (BufferedReader in = new BufferedReader(new InputStreamReader(
                        errorStream ? source.getErrorStream() : source.getInputStream()))) {
                    String line;
                    while ((line = in.readLine()) != null) {
                        this.outputQueue.offer(new OutputLine(label, line));
                    }
                } catch (IOException e) {
                    log.debug(errorStream ? "Stderr reader terminated" : "Stdout reader terminated", e);
                } finally {
                    // Null content tells the consumer that this stream has ended.
                    this.outputQueue.offer(new OutputLine(label, null));
                }
            }, threadName);
            // Daemon, so a session that was never cleaned up cannot keep the JVM alive.
            reader.setDaemon(true);
            reader.start();
        }

        /**
         * Stops the current process (waiting up to the manager's termination
         * timeout) and starts a new one.
         *
         * @throws RuntimeException if the new process cannot be started
         */
        void restart() {
            this.stop(ShellSessionManager.this.terminationTimeout);
            try {
                this.start();
            } catch (IOException e) {
                throw new RuntimeException("Failed to restart shell session", e);
            }
        }

        /**
         * Asks the shell to exit, waits up to {@code timeoutMs} for it to do so,
         * destroys it forcibly otherwise, and closes its standard input.
         * Does nothing when no process is running.
         *
         * @param timeoutMs how long to wait for a graceful exit, in milliseconds
         */
        void stop(long timeoutMs) {
            if (this.process == null || !this.process.isAlive()) {
                return;
            }
            this.terminated = true;
            try {
                this.stdin.write("exit\n");
                this.stdin.flush();
            } catch (IOException e) {
                log.debug("Failed to send exit command", e);
            }
            try {
                if (!this.process.waitFor(timeoutMs, TimeUnit.MILLISECONDS)) {
                    this.process.destroyForcibly();
                }
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                this.process.destroyForcibly();
            }
            try {
                this.stdin.close();
            } catch (IOException e) {
                log.debug("Failed to close stdin", e);
            }
        }

        /**
         * Runs {@code command} in the shell and collects its output.
         *
         * @param command the command text to send to the shell
         * @param timeoutMs time limit in milliseconds, measured from this call
         * @param maxOutputLines maximum number of lines kept in the result
         * @param maxOutputBytes maximum number of bytes kept, or {@code null} for no limit
         * @return the collected output, exit code and truncation/timeout flags
         * @throws IllegalStateException if the shell process is not running
         * @throws RuntimeException if writing to the shell fails
         */
        CommandResult execute(String command, long timeoutMs, int maxOutputLines, Long maxOutputBytes) {
            if (this.process == null || !this.process.isAlive()) {
                throw new IllegalStateException("Shell session is not running");
            }
            String marker = ShellSessionManager.DONE_MARKER_PREFIX + UUID.randomUUID().toString().replace("-", "");
            long deadline = System.currentTimeMillis() + timeoutMs;
            try {
                // Drop anything left over from an earlier command.
                this.outputQueue.clear();
                this.stdin.write(command);
                if (!command.endsWith("\n")) {
                    this.stdin.write("\n");
                }
                // Prints "<marker> <exit status>" once the command has finished.
                this.stdin.write(String.format("printf '%s %%s\\n' $?\n", marker));
                this.stdin.flush();
                return this.collectOutput(marker, deadline, maxOutputLines, maxOutputBytes);
            } catch (IOException e) {
                throw new RuntimeException("Failed to execute command", e);
            }
        }

        /**
         * Drains the output queue until the marker line arrives on stdout, the
         * deadline passes, or the calling thread is interrupted.
         *
         * <p>On timeout the session is restarted and the result is flagged as timed
         * out. Every line is counted towards the totals; lines beyond the line or
         * byte limit are dropped from the returned text and the matching truncation
         * flag is set. Stderr lines are prefixed with {@code "[stderr] "}.
         *
         * @param marker the unique marker that ends this command's output
         * @param deadline absolute time, in epoch milliseconds, at which to give up
         * @param maxOutputLines maximum number of lines kept
         * @param maxOutputBytes maximum number of bytes kept, or {@code null} for no limit
         * @return the collected result; the exit code is {@code null} when it could
         *         not be determined
         */
        private CommandResult collectOutput(String marker, long deadline, int maxOutputLines, Long maxOutputBytes) {
            List<String> lines = new ArrayList<>();
            int totalLines = 0;
            long totalBytes = 0L;
            boolean truncatedByLines = false;
            boolean truncatedByBytes = false;
            Integer exitCode = null;
            boolean timedOut = false;
            while (true) {
                long remaining = deadline - System.currentTimeMillis();
                OutputLine outputLine;
                try {
                    // Once the deadline has passed, do not poll at all.
                    outputLine = remaining > 0L ? this.outputQueue.poll(remaining, TimeUnit.MILLISECONDS) : null;
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                    break;
                }
                if (outputLine == null) {
                    timedOut = true;
                    log.warn("Command timed out, restarting session");
                    this.restart();
                    break;
                }
                String line = outputLine.content;
                if (line == null) {
                    // End-of-stream entry from a reader thread; keep waiting.
                    continue;
                }
                if ("stdout".equals(outputLine.source) && line.startsWith(marker)) {
                    String[] parts = line.split(" ", 2);
                    if (parts.length > 1) {
                        try {
                            exitCode = Integer.parseInt(parts[1].trim());
                        } catch (NumberFormatException ignored) {
                            // Unparseable status: leave the exit code unknown.
                        }
                    }
                    break;
                }
                ++totalLines;
                String formattedLine = "stderr".equals(outputLine.source) ? "[stderr] " + line : line;
                // +1 accounts for the newline that joins the lines.
                totalBytes += formattedLine.getBytes().length + 1;
                if (totalLines > maxOutputLines) {
                    truncatedByLines = true;
                } else if (maxOutputBytes != null && totalBytes > maxOutputBytes) {
                    truncatedByBytes = true;
                } else {
                    lines.add(formattedLine);
                }
            }
            String output = String.join("\n", lines);
            return new CommandResult(output, exitCode, timedOut, truncatedByLines, truncatedByBytes, totalLines, totalBytes);
        }
    }

    /**
     * Immutable outcome of one shell command: its (possibly truncated and
     * redacted) output, exit code, timeout/truncation flags, totals, and the
     * redaction matches grouped by PII type.
     */
    public static class CommandResult {
        private final String output;
        private final Integer exitCode;
        private final boolean timedOut;
        private final boolean truncatedByLines;
        private final boolean truncatedByBytes;
        private final int totalLines;
        private final long totalBytes;
        private final Map<String, List<String>> redactionMatches;

        /**
         * Creates a result without any redaction matches.
         *
         * @param output the collected output text
         * @param exitCode the exit code, or {@code null} when unknown
         * @param timedOut whether the command timed out
         * @param truncatedByLines whether lines were dropped because of the line limit
         * @param truncatedByBytes whether lines were dropped because of the byte limit
         * @param totalLines total number of lines produced
         * @param totalBytes total number of bytes produced
         */
        public CommandResult(String output, Integer exitCode, boolean timedOut, boolean truncatedByLines, boolean truncatedByBytes, int totalLines, long totalBytes) {
            this(output, exitCode, timedOut, truncatedByLines, truncatedByBytes, totalLines, totalBytes, new HashMap<>());
        }

        /**
         * Creates a result including redaction matches; the map is copied.
         *
         * @param output the collected output text
         * @param exitCode the exit code, or {@code null} when unknown
         * @param timedOut whether the command timed out
         * @param truncatedByLines whether lines were dropped because of the line limit
         * @param truncatedByBytes whether lines were dropped because of the byte limit
         * @param totalLines total number of lines produced
         * @param totalBytes total number of bytes produced
         * @param redactionMatches matched values grouped by PII type
         */
        public CommandResult(String output, Integer exitCode, boolean timedOut, boolean truncatedByLines, boolean truncatedByBytes, int totalLines, long totalBytes, Map<String, List<String>> redactionMatches) {
            this.redactionMatches = new HashMap<>(redactionMatches);
            this.totalBytes = totalBytes;
            this.totalLines = totalLines;
            this.truncatedByBytes = truncatedByBytes;
            this.truncatedByLines = truncatedByLines;
            this.timedOut = timedOut;
            this.exitCode = exitCode;
            this.output = output;
        }

        /** @return the collected output text */
        public String getOutput() {
            return output;
        }

        /** @return the exit code, or {@code null} when it is unknown */
        public Integer getExitCode() {
            return exitCode;
        }

        /** @return whether the command timed out */
        public boolean isTimedOut() {
            return timedOut;
        }

        /** @return whether output was dropped because of the line limit */
        public boolean isTruncatedByLines() {
            return truncatedByLines;
        }

        /** @return whether output was dropped because of the byte limit */
        public boolean isTruncatedByBytes() {
            return truncatedByBytes;
        }

        /** @return the total number of lines produced, including dropped ones */
        public int getTotalLines() {
            return totalLines;
        }

        /** @return the total number of bytes produced, including dropped ones */
        public long getTotalBytes() {
            return totalBytes;
        }

        /** @return a copy of the redaction matches grouped by PII type */
        public Map<String, List<String>> getRedactionMatches() {
            return new HashMap<>(redactionMatches);
        }

        /**
         * @return {@code true} when the command did not time out and its exit code
         *         is either unknown or zero
         */
        public boolean isSuccess() {
            if (timedOut) {
                return false;
            }
            return exitCode == null || exitCode == 0;
        }
    }

    public static interface RedactionRule {
        public RedactionResult applyWithMatches(String var1);

        public String getPiiType();
    }

    /**
     * Outcome of applying one redaction rule: the rewritten text and the values
     * that were matched. The match list is copied on the way in and on the way
     * out.
     */
    public static class RedactionResult {
        private final String redactedContent;
        private final List<String> matches;

        /**
         * @param redactedContent the text after redaction
         * @param matches the values that were matched; copied
         */
        public RedactionResult(String redactedContent, List<String> matches) {
            List<String> snapshot = new ArrayList<>(matches);
            this.redactedContent = redactedContent;
            this.matches = snapshot;
        }

        /** @return the text after redaction */
        public String getRedactedContent() {
            return redactedContent;
        }

        /** @return a fresh copy of the matched values */
        public List<String> getMatches() {
            List<String> copy = new ArrayList<>(matches);
            return copy;
        }
    }

    /**
     * Redaction rule driven by a regular expression: every match is recorded and
     * replaced with a fixed replacement string.
     */
    public static class PatternRedactionRule implements RedactionRule {
        private final Pattern pattern;
        private final String replacement;
        private final String piiType;

        /**
         * @param pattern regular expression to search for; compiled once here
         * @param replacement replacement passed to {@link Matcher#replaceAll(String)}
         * @param piiType label under which matches of this rule are grouped
         */
        public PatternRedactionRule(String pattern, String replacement, String piiType) {
            this.piiType = piiType;
            this.replacement = replacement;
            this.pattern = Pattern.compile(pattern);
        }

        /**
         * Collects every match of the pattern in {@code content} and returns the
         * text with all matches replaced.
         *
         * @param content the text to redact
         * @return the rewritten text and the matched substrings in order of appearance
         */
        @Override
        public RedactionResult applyWithMatches(String content) {
            Matcher scanner = pattern.matcher(content);
            List<String> found = new ArrayList<>();
            for (boolean hit = scanner.find(); hit; hit = scanner.find()) {
                found.add(scanner.group());
            }
            // replaceAll resets the matcher itself before replacing.
            return new RedactionResult(scanner.replaceAll(replacement), found);
        }

        /** @return the label under which matches of this rule are grouped */
        @Override
        public String getPiiType() {
            return piiType;
        }
    }

    /**
     * One line read from the shell, tagged with the stream it came from. A
     * {@code null} content is used by the reader threads to signal end of stream.
     */
    private static class OutputLine {
        /** Stream label: {@code "stdout"} or {@code "stderr"}. */
        final String source;
        /** The line text, or {@code null} for the end-of-stream signal. */
        final String content;

        /**
         * @param origin label of the stream the line was read from
         * @param text the line text, or {@code null} to signal end of stream
         */
        OutputLine(String origin, String text) {
            this.content = text;
            this.source = origin;
        }
    }
}

