From db37c446080169758d7bb4aaf29999510ac52c5b Mon Sep 17 00:00:00 2001 From: Owen Kaplan Date: Fri, 6 Mar 2026 14:15:39 -0500 Subject: [PATCH] feat: add shell tool and tests --- .gitignore | 2 + strands-command/__init__.py | 0 strands-command/scripts/__init__.py | 1 + strands-command/scripts/python/__init__.py | 1 + strands-command/scripts/python/shell_tool.py | 310 +++++++++++++++ .../scripts/python/tests/__init__.py | 1 + .../scripts/python/tests/integ/__init__.py | 1 + .../python/tests/integ/test_shell_tool.py | 191 ++++++++++ .../scripts/python/tests/requirements.txt | 0 .../scripts/python/tests/test_shell_tool.py | 354 ++++++++++++++++++ 10 files changed, 861 insertions(+) create mode 100644 strands-command/__init__.py create mode 100644 strands-command/scripts/__init__.py create mode 100644 strands-command/scripts/python/__init__.py create mode 100644 strands-command/scripts/python/shell_tool.py create mode 100644 strands-command/scripts/python/tests/__init__.py create mode 100644 strands-command/scripts/python/tests/integ/__init__.py create mode 100644 strands-command/scripts/python/tests/integ/test_shell_tool.py create mode 100644 strands-command/scripts/python/tests/requirements.txt create mode 100644 strands-command/scripts/python/tests/test_shell_tool.py diff --git a/.gitignore b/.gitignore index 1be538f..adce3c6 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,8 @@ # Dependencies node_modules/ +__pycache__ + # CDK output cdk.out/ diff --git a/strands-command/__init__.py b/strands-command/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/strands-command/scripts/__init__.py b/strands-command/scripts/__init__.py new file mode 100644 index 0000000..7994ee4 --- /dev/null +++ b/strands-command/scripts/__init__.py @@ -0,0 +1 @@ +# Scripts package diff --git a/strands-command/scripts/python/__init__.py b/strands-command/scripts/python/__init__.py new file mode 100644 index 0000000..6a3bd7c --- /dev/null +++ b/strands-command/scripts/python/__init__.py @@ -0,0 +1 @@ +# Python tools diff --git a/strands-command/scripts/python/shell_tool.py b/strands-command/scripts/python/shell_tool.py new file mode 100644 index 0000000..c47ab00 --- /dev/null +++ b/strands-command/scripts/python/shell_tool.py @@ -0,0 +1,310 @@ +import os +import subprocess +import time +import threading +import weakref +from strands import tool +from strands.types.tools import ToolContext + + +# Module-level session registry with automatic cleanup when agents are GC'd +_sessions = weakref.WeakKeyDictionary() + + +class ShellSession: + """Manages a persistent shell process using plain pipes. + + Architecture: + - One long-lived shell process per session + - stderr merged into stdout for simplified stream handling + - Single long-lived reader thread (not per-command threads) + - Binary mode with manual decode to avoid text buffering issues + - Buffer offset tracking for clean per-command output extraction + - Single-flight execution with lock to prevent command interleaving + """ + + def __init__(self, timeout: int = 30): + self._timeout = timeout + self._process = None + self._alive = False + + # Single-flight execution lock + self._run_lock = threading.Lock() + + # Shared output buffer with synchronization + self._output_buffer = bytearray() + self._buffer_lock = threading.Lock() + self._buffer_condition = threading.Condition(self._buffer_lock) + + # Reader thread + self._reader_thread = None + self._stop_reader = False + + self._start_process() + + def __del__(self): + """Ensure OS processes and threads are cleaned up if the object is garbage collected.""" + try: + self.stop() + except Exception: + pass + + def _start_process(self): + """Start the shell process with clean configuration.""" + # default to bash + shell = os.environ.get("SHELL", "/bin/bash") + + # Configure shell for clean startup (no rc files) + if shell.endswith("bash"): + argv = [shell, "--noprofile", "--norc"] + elif shell.endswith("zsh"): + argv = [shell, "-f"] + else: + argv = [shell] + + # Start process with merged stderr->stdout, binary mode + self._process = subprocess.Popen( + argv, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, # Merge stderr into stdout + env={**os.environ, "PS1": "", "PS2": "", "PROMPT": ""}, + ) + + self._alive = True + self._stop_reader = False + + # Start long-lived reader thread + self._reader_thread = threading.Thread(target=self._reader_loop, daemon=True) + self._reader_thread.start() + + def _reader_loop(self): + """Long-lived reader thread that continuously reads from stdout. + + This runs for the entire lifetime of the shell process, not per-command. + Reads fixed-size chunks (not readline!) to avoid blocking on newlines. + + Note: os.read() will block until data is available, which is fine for a + daemon thread. This approach is simpler and cross-platform (Windows compatible). + We avoid select() which doesn't work with file descriptors on Windows. + """ + READ_CHUNK_SIZE = 4096 + try: + fd = self._process.stdout.fileno() + while not self._stop_reader and self._process and self._process.poll() is None: + # Block until data is available (or EOF) + # This is safe in a daemon thread and works on all platforms + chunk = os.read(fd, READ_CHUNK_SIZE) + if not chunk: + # EOF - process died + break + + # Append to shared buffer and notify waiters + with self._buffer_condition: + self._output_buffer.extend(chunk) + self._buffer_condition.notify_all() + except Exception: + # Process died or other error + pass + finally: + with self._buffer_condition: + self._alive = False + self._buffer_condition.notify_all() + + def run(self, command: str, timeout: int | None = None) -> str: + """Execute a command in the persistent session. + + Args: + command: The command to execute + timeout: Optional timeout in seconds + + Returns: + Command output with exit code appended if non-zero + """ + # Single-flight execution - only one command at a time + with self._run_lock: + if not self._alive or not self._process or self._process.poll() is not None: + raise Exception("Shell session is not running") + + effective_timeout = timeout if timeout is not None else self._timeout + + # Generate unique sentinel hash + hash = f"{time.time_ns()}_{os.urandom(4).hex()}" + sentinel = f"__CMD_DONE__:{hash}:" + + # Record buffer position before command + with self._buffer_lock: + start_offset = len(self._output_buffer) + + # Write command with sentinel + try: + wrapped_command = f"{command}\n__EXIT_CODE=$?\nprintf '\\n{sentinel}%s\\n' \"$__EXIT_CODE\"\n" + self._process.stdin.write(wrapped_command.encode('utf-8')) + self._process.stdin.flush() + except (BrokenPipeError, OSError) as e: + self._alive = False + raise Exception(f"Failed to write to shell: {e}") + + # Wait for sentinel with timeout + deadline = time.time() + effective_timeout + sentinel_bytes = sentinel.encode('utf-8') + + while True: + with self._buffer_condition: + # Check if sentinel appeared after start_offset + buffer_view = bytes(self._output_buffer[start_offset:]) + if sentinel_bytes in buffer_view: + # Found sentinel. Extract output + output = buffer_view.decode('utf-8', errors='replace') + break + + # Check timeout + remaining = deadline - time.time() + if remaining <= 0: + # Timeout - kill session (not trustworthy after timeout) + self.stop() + raise TimeoutError(f"Command timed out after {effective_timeout} seconds") + + # Check if session died + if not self._alive: + raise Exception("Shell process died unexpectedly") + + # Wait for more output + self._buffer_condition.wait(timeout=min(remaining, 0.1)) + + # Prune the buffer to prevent memory leaks + # This is critical for long-lived sessions with many commands + with self._buffer_lock: + # Find the end of the sentinel line to safely truncate + sentinel_idx = self._output_buffer.find(sentinel_bytes, start_offset) + if sentinel_idx != -1: + # Find the newline after the sentinel + nl_idx = self._output_buffer.find(b'\n', sentinel_idx) + if nl_idx != -1: + # Delete everything up to and including the sentinel line + del self._output_buffer[:nl_idx + 1] + else: + # No newline found, just delete up to end of sentinel + del self._output_buffer[:sentinel_idx + len(sentinel_bytes)] + + # Parse output and extract exit code + exit_code = -1 + lines = output.split('\n') + filtered_lines = [] + + for line in lines: + if sentinel in line: + # Extract exit code from sentinel line + parts = line.split(':') + if len(parts) >= 3: + try: + exit_code = int(parts[2]) + except ValueError: + pass + # Don't include sentinel line in output + continue + filtered_lines.append(line) + + output = '\n'.join(filtered_lines).strip() + + # Append exit code if non-zero + if exit_code != 0: + output += f"\n\nExit code: {exit_code}" + + return output + + def stop(self): + """Stop the shell process and reader thread.""" + self._stop_reader = True + self._alive = False + + if self._process: + self._process.terminate() + try: + self._process.wait(timeout=1) + except subprocess.TimeoutExpired: + self._process.kill() + self._process.wait() + self._process = None + + if self._reader_thread and self._reader_thread.is_alive(): + self._reader_thread.join(timeout=1) + + def restart(self): + """Restart the shell session.""" + self.stop() + self._output_buffer.clear() + self._start_process() + + +@tool(context=True) +def shell_tool( + command: str, + timeout: int | None = None, + restart: bool = False, + tool_context: ToolContext = None +) -> str: + """ + Execute a shell command in a persistent shell session. + + The shell session preserves state across commands: + - Working directory (cd persists) + - Exported environment variables + - Shell variables + - Sourced shell state + + Uses the system default shell ($SHELL, defaulting to /bin/bash) with clean + startup configuration (--noprofile --norc for bash, -f for zsh). + + **Supported commands:** + - Standard shell commands + - Build/test commands + - Shell pipelines and normal non-interactive commands + + **Unsupported/unreliable:** + - Interactive programs: vim, less, top, nano + - REPLs: python, node, irb + - Password prompts or TTY-required programs + - Full-screen TUIs + - Background jobs that continue writing after command returns + + Args: + command: The shell command to execute + timeout: Optional timeout in seconds (default: 30) + restart: If True, restart the shell session before running the command + + Returns: + The command output, with exit code appended if non-zero + """ + agent = tool_context.agent + + # Handle restart without command - just recreate session and return + if restart and (not command or command.strip() == ""): + if agent in _sessions: + _sessions[agent].stop() + _sessions[agent] = ShellSession() + return "Shell session restarted" + + # Handle restart with command - stop old session and create fresh one + if restart: + if agent in _sessions: + _sessions[agent].stop() + _sessions[agent] = ShellSession() + + # Get or create session (normal case) + if agent not in _sessions: + _sessions[agent] = ShellSession() + + session = _sessions[agent] + + try: + return session.run(command, timeout=timeout) + except TimeoutError as e: + # Session is dead after timeout, recreate on next call + return f"Error: {str(e)}" + except Exception as e: + # Only restart if process actually died + if session._process is None or session._process.poll() is not None: + session.stop() + _sessions[agent] = ShellSession() + return f"Error: {str(e)}" diff --git a/strands-command/scripts/python/tests/__init__.py b/strands-command/scripts/python/tests/__init__.py new file mode 100644 index 0000000..d4839a6 --- /dev/null +++ b/strands-command/scripts/python/tests/__init__.py @@ -0,0 +1 @@ +# Tests package diff --git a/strands-command/scripts/python/tests/integ/__init__.py b/strands-command/scripts/python/tests/integ/__init__.py new file mode 100644 index 0000000..0ca287e --- /dev/null +++ b/strands-command/scripts/python/tests/integ/__init__.py @@ -0,0 +1 @@ +# Integration tests diff --git a/strands-command/scripts/python/tests/integ/test_shell_tool.py b/strands-command/scripts/python/tests/integ/test_shell_tool.py new file mode 100644 index 0000000..132dbc8 --- /dev/null +++ b/strands-command/scripts/python/tests/integ/test_shell_tool.py @@ -0,0 +1,191 @@ +"""Integration tests for shell_tool with Strands Agent.""" +import os +import pytest +import tempfile +from strands import Agent +from ...shell_tool import shell_tool + + +@pytest.fixture +def temp_dir(): + """Create a temporary directory for testing.""" + with tempfile.TemporaryDirectory() as tmp: + yield tmp + + +def assert_shell_tool_called(agent, expected_command_substring=None): + """Helper to validate shell_tool was called in agent messages. + + Args: + agent: The agent instance + expected_command_substring: Optional substring to check in the command parameter + + Note: Validates across ALL shell_tool calls in the conversation, not just the last one. + """ + # Check that there are messages (at least user message + assistant response) + assert len(agent.messages) >= 2, f"Expected at least 2 messages, got {len(agent.messages)}" + + # Collect all shell_tool calls + shell_tool_calls = [] + for msg in agent.messages: + content = msg.get('content', []) if isinstance(msg, dict) else [] + for block in content: + if isinstance(block, dict) and 'toolUse' in block: + tool_use = block['toolUse'] + if tool_use.get('name') == 'shell_tool': + command = tool_use.get('input', {}).get('command', '') + shell_tool_calls.append(command) + + # Validate shell_tool was called at least once + assert len(shell_tool_calls) > 0, f"shell_tool was not called. Messages: {agent.messages}" + + # If checking for specific command substring, validate at least one call contains it + if expected_command_substring: + matching_calls = [cmd for cmd in shell_tool_calls if expected_command_substring in cmd] + assert len(matching_calls) > 0, \ + f"Expected '{expected_command_substring}' in at least one shell_tool call. All calls: {shell_tool_calls}" + + +def test_basic_execution(): + """Test basic shell command execution through agent.""" + agent = Agent(tools=[shell_tool]) + result = agent("Run: echo 'hello'") + result_str = str(result) + print(f"\nBasic execution result: {result_str[:300]}") + assert result is not None + + # Validate shell_tool was called with echo command + assert_shell_tool_called(agent, expected_command_substring="echo") + + assert "hello" in result_str.lower() + + +def test_complex_multi_step_workflow(temp_dir): + """Test a complex multi-step workflow with state management.""" + agent = Agent(tools=[shell_tool]) + print(f"\n--- Complex Multi-Step Workflow Test (temp dir: {temp_dir}) ---") + + # Step 1: Change to temp directory and verify by creating a marker file + result1 = agent(f"Run: cd {temp_dir} && touch cd_marker.txt && pwd") + result1_str = str(result1) + print(f"Step 1 - Change to temp dir: {result1_str[:300]}") + + # Verify cd worked by checking marker file exists in temp_dir + marker_path = os.path.join(temp_dir, "cd_marker.txt") + assert os.path.exists(marker_path), f"cd failed - marker file not found at {marker_path}" + print(f"✓ Verified cd worked - marker file exists at {marker_path}") + + # Step 2: Create files with shell loops (should be in temp_dir due to persistence) + result2 = agent("""In the current directory, run: for i in 1 2 3; do echo "Line $i" > file$i.txt; done +Then list the files to confirm they were created.""") + result2_str = str(result2) + print(f"Step 2 - Create files: {result2_str[:300]}") + + # Verify files actually exist on disk + file1_path = os.path.join(temp_dir, "file1.txt") + file2_path = os.path.join(temp_dir, "file2.txt") + file3_path = os.path.join(temp_dir, "file3.txt") + assert os.path.exists(file1_path), f"file1.txt not found at {file1_path}" + assert os.path.exists(file2_path), f"file2.txt not found at {file2_path}" + assert os.path.exists(file3_path), f"file3.txt not found at {file3_path}" + print(f"✓ Verified all 3 files exist in {temp_dir}") + + # Verify file contents + with open(file2_path, 'r') as f: + content = f.read() + assert "Line 2" in content, f"file2.txt has wrong content: {content}" + print(f"✓ Verified file2.txt contains 'Line 2'") + + # Step 3: Use pipes and command substitution + result3 = agent("Run: cat file*.txt | wc -l") + result3_str = str(result3) + print(f"Step 3 - Count lines: {result3_str[:300]}") + assert "3" in result3_str, f"Expected 3 lines, got: {result3_str}" + + # Step 4: Use grep and conditionals + result4 = agent("""Run this command: if grep -q "Line 2" file2.txt; then echo "FOUND"; else echo "NOT FOUND"; fi""") + result4_str = str(result4) + print(f"Step 4 - Grep and conditional: {result4_str[:300]}") + assert "FOUND" in result4_str, f"Conditional failed: {result4_str}" + + # Step 5: Verify persistence - we should still be in the same directory + result5 = agent("Run: pwd") + result5_str = str(result5) + print(f"Step 5 - Verify pwd persistence: {result5_str[:300]}") + assert temp_dir in result5_str, f"Lost directory context: {result5_str}" + + # Validate shell_tool was used throughout (check last message for pwd command) + assert_shell_tool_called(agent, expected_command_substring="pwd") + + print("✓ All steps completed successfully") + + +def test_shell_functions_and_persistence(temp_dir): + """Test that shell functions persist across commands.""" + agent = Agent(tools=[shell_tool]) + print("\n--- Shell Functions Persistence Test ---") + + # Define a shell function that writes to a file + result1 = agent(f"""Run: cd {temp_dir} && greet() {{ echo "Hello, $1!" > greeting_$1.txt; }} && greet World""") + result1_str = str(result1) + print(f"Function definition result: {result1_str[:300]}") + + # Verify function wrote the file + greeting1_path = os.path.join(temp_dir, "greeting_World.txt") + assert os.path.exists(greeting1_path), f"greeting_World.txt not found at {greeting1_path}" + with open(greeting1_path, 'r') as f: + content = f.read().strip() + assert content == "Hello, World!", f"Wrong content: {content}" + print(f"✓ Verified function wrote 'Hello, World!' to {greeting1_path}") + + # Verify function persists in next command by calling it again + result2 = agent("Run: greet Testing") + result2_str = str(result2) + print(f"Function persistence result: {result2_str[:300]}") + + # Verify the persisted function wrote the new file + greeting2_path = os.path.join(temp_dir, "greeting_Testing.txt") + assert os.path.exists(greeting2_path), f"greeting_Testing.txt not found - function didn't persist" + with open(greeting2_path, 'r') as f: + content = f.read().strip() + assert content == "Hello, Testing!", f"Wrong content: {content}" + print(f"✓ Verified persisted function wrote 'Hello, Testing!' to {greeting2_path}") + + # Validate shell_tool was called with greet command + assert_shell_tool_called(agent, expected_command_substring="greet") + + print("✓ Shell function persisted across commands") + + +def test_error_handling_and_recovery(temp_dir): + """Test that agent can recover from errors and continue.""" + agent = Agent(tools=[shell_tool]) + print("\n--- Error Handling and Recovery Test ---") + + # Run a failing command that tries to write to a non-existent directory + bad_path = "/this/does/not/exist/test.txt" + result1 = agent(f"Run: echo 'test' > {bad_path}") + result1_str = str(result1) + print(f"Error result: {result1_str[:300]}") + + # Verify the file was NOT created (command failed) + assert not os.path.exists(bad_path), f"File should not exist - command should have failed" + print(f"✓ Verified command failed - no file created at {bad_path}") + + # Verify shell still works after error by creating a recovery file + recovery_path = os.path.join(temp_dir, "recovery_success.txt") + result2 = agent(f"Run: echo 'RECOVERED' > {recovery_path}") + result2_str = str(result2) + print(f"Recovery result: {result2_str[:300]}") + + # Verify recovery by checking the file exists and has correct content + assert os.path.exists(recovery_path), f"Recovery failed - file not created at {recovery_path}" + with open(recovery_path, 'r') as f: + content = f.read().strip() + assert content == "RECOVERED", f"Wrong content in recovery file: {content}" + print(f"✓ Verified shell recovered - created file at {recovery_path} with correct content") + + # Validate shell_tool was called + assert_shell_tool_called(agent, expected_command_substring="echo") + + print("✓ Successfully recovered from error") diff --git a/strands-command/scripts/python/tests/requirements.txt b/strands-command/scripts/python/tests/requirements.txt new file mode 100644 index 0000000..e69de29 diff --git a/strands-command/scripts/python/tests/test_shell_tool.py b/strands-command/scripts/python/tests/test_shell_tool.py new file mode 100644 index 0000000..711a0e8 --- /dev/null +++ b/strands-command/scripts/python/tests/test_shell_tool.py @@ -0,0 +1,354 @@ +"""Tests for the persistent shell tool.""" +import pytest +from unittest.mock import Mock +from strands.types.tools import ToolContext +from ..shell_tool import ShellSession, shell_tool + + +"""Tests for ShellSession class.""" + +def test_basic_command(): + """Test basic command execution.""" + session = ShellSession() + try: + output = session.run("echo hello") + assert "hello" in output + finally: + session.stop() + +def test_exit_code_success(): + """Test successful command exit code.""" + session = ShellSession() + try: + output = session.run("true") + # Exit code 0 should not be appended + assert "Exit code:" not in output + finally: + session.stop() + +def test_exit_code_failure(): + """Test failed command exit code.""" + session = ShellSession() + try: + output = session.run("false") + assert "Exit code: 1" in output + finally: + session.stop() + +def test_multiline_output(): + """Test command with multiple output lines.""" + session = ShellSession() + try: + output = session.run("echo line1; echo line2; echo line3") + assert "line1" in output + assert "line2" in output + assert "line3" in output + finally: + session.stop() + +def test_cd_persistence(): + """Test that cd command persists across commands.""" + session = ShellSession() + try: + # Create a temp directory + session.run("mkdir -p /tmp/shell_test_$$") + session.run("cd /tmp/shell_test_$$") + + # Check we're in the right directory + output = session.run("pwd") + assert "shell_test" in output + + # Verify persistence + output2 = session.run("pwd") + assert output.strip() == output2.strip() + finally: + session.stop() + +def test_env_var_persistence(): + """Test that exported environment variables persist.""" + session = ShellSession() + try: + session.run("export TEST_VAR=hello123") + output = session.run("echo $TEST_VAR") + assert "hello123" in output + finally: + session.stop() + +def test_shell_variable_persistence(): + """Test that shell variables persist.""" + session = ShellSession() + try: + session.run("MY_VAR=testing456") + output = session.run("echo $MY_VAR") + assert "testing456" in output + finally: + session.stop() + +def test_stderr_merged_into_stdout(): + """Test that stderr is merged into output.""" + session = ShellSession() + try: + # Command that writes to stderr + output = session.run("echo error message >&2") + assert "error message" in output + finally: + session.stop() + +def test_no_newline_output(): + """Test command that produces output without trailing newline.""" + session = ShellSession() + try: + output = session.run("printf 'no newline'") + assert "no newline" in output + finally: + session.stop() + +def test_large_output(): + """Test command with large output.""" + session = ShellSession() + try: + # Generate ~10KB of output + output = session.run("for i in {1..200}; do echo 'Line number '$i; done") + assert "Line number 1" in output + assert "Line number 200" in output + finally: + session.stop() + +def test_timeout(): + """Test command timeout handling.""" + session = ShellSession(timeout=1) + try: + with pytest.raises(TimeoutError): + session.run("sleep 10") + + # Session should be dead after timeout + assert not session._alive + finally: + # Clean up even if assertion fails + if session._alive: + session.stop() + +def test_sequential_commands(): + """Test multiple sequential commands.""" + session = ShellSession() + try: + session.run("echo first") + session.run("echo second") + output = session.run("echo third") + assert "third" in output + finally: + session.stop() + +def test_pipe_commands(): + """Test commands with pipes.""" + session = ShellSession() + try: + output = session.run("echo 'hello world' | grep hello") + assert "hello world" in output + finally: + session.stop() + +def test_command_substitution(): + """Test command substitution.""" + session = ShellSession() + try: + output = session.run("echo $(echo nested)") + assert "nested" in output + finally: + session.stop() + +def test_exit_code_propagation(): + """Test that non-zero exit codes are properly captured.""" + session = ShellSession() + try: + output = session.run("ls /nonexistent 2>&1") + assert "Exit code:" in output + # ls should fail with non-zero exit code + assert "Exit code: 0" not in output + finally: + session.stop() + +def test_restart(): + """Test session restart.""" + session = ShellSession() + try: + session.run("export TEST_VAR=before") + session.restart() + # Variable should be gone after restart + output = session.run("echo ${TEST_VAR:-empty}") + assert "empty" in output + finally: + session.stop() + +def test_special_characters_in_output(): + """Test handling of special characters.""" + session = ShellSession() + try: + output = session.run("echo '$HOME' '\\n' '\\t' '|' '&'") + assert "$HOME" in output + finally: + session.stop() + + +"""Tests for shell_tool function.""" + +def create_mock_context(): + """Create a mock tool context with properly configured agent.""" + # Create a simple object that can have attributes set on it + class MockAgent: + pass + + mock_agent = MockAgent() + mock_context = Mock(spec=ToolContext) + mock_context.agent = mock_agent + return mock_context + +def test_tool_basic_usage(): + """Test basic tool usage.""" + context = create_mock_context() + output = shell_tool("echo test", tool_context=context) + assert "test" in output + +def test_tool_creates_session(): + """Test that tool creates session in registry.""" + from ..shell_tool import _sessions + context = create_mock_context() + shell_tool("echo test", tool_context=context) + assert context.agent in _sessions + +def test_tool_reuses_session(): + """Test that tool reuses existing session.""" + context = create_mock_context() + shell_tool("export VAR=value", tool_context=context) + output = shell_tool("echo $VAR", tool_context=context) + assert "value" in output + +def test_tool_restart_flag(): + """Test tool restart functionality.""" + context = create_mock_context() + shell_tool("export VAR=before", tool_context=context) + output = shell_tool("echo start", restart=True, tool_context=context) + assert "start" in output + + # Variable should be gone after restart + output = shell_tool("echo ${VAR:-gone}", tool_context=context) + assert "gone" in output + +def test_tool_restart_only(): + """Test restarting without command.""" + context = create_mock_context() + shell_tool("echo test", tool_context=context) + output = shell_tool("", restart=True, tool_context=context) + assert "restarted" in output.lower() + +def test_tool_timeout_parameter(): + """Test custom timeout parameter.""" + context = create_mock_context() + output = shell_tool("sleep 0.1", timeout=5, tool_context=context) + # Should complete successfully + assert "Error" not in output + +def test_tool_timeout_error(): + """Test timeout error handling.""" + context = create_mock_context() + output = shell_tool("sleep 10", timeout=1, tool_context=context) + assert "timeout" in output.lower() or "Error" in output + +def test_tool_exit_code_in_output(): + """Test that non-zero exit codes appear in output.""" + context = create_mock_context() + output = shell_tool("false", tool_context=context) + assert "Exit code: 1" in output + +def test_tool_persistence_across_calls(): + """Test that state persists across multiple tool calls.""" + context = create_mock_context() + shell_tool("cd /tmp", tool_context=context) + shell_tool("export MY_VAR=persistent", tool_context=context) + output = shell_tool("pwd; echo $MY_VAR", tool_context=context) + assert "/tmp" in output + assert "persistent" in output + +def test_tool_session_cleanup_on_error(): + """Test that session is recreated after fatal error.""" + from ..shell_tool import _sessions + context = create_mock_context() + + # First call succeeds + shell_tool("echo first", tool_context=context) + session = _sessions[context.agent] + + # Force kill the process + session._process.kill() + session._process.wait() + + # Next call should recreate session + output = shell_tool("echo recovered", tool_context=context) + assert "recovered" in output or "Error" in output + + +"""Tests to verify architectural properties.""" + +def test_no_readline_dependency(): + """Verify that output without newlines works correctly.""" + session = ShellSession() + try: + # This would fail if using readline() + output = session.run("printf 'line1'; printf 'line2'") + assert "line1" in output and "line2" in output + finally: + session.stop() + +def test_sentinel_uniqueness(): + """Test that concurrent commands would use unique sentinels.""" + session = ShellSession() + try: + # Run multiple commands - each should have unique sentinel + outputs = [] + for i in range(5): + output = session.run(f"echo test{i}") + outputs.append(output) + assert f"test{i}" in output + + # All outputs should be distinct + assert len(set(outputs)) == len(outputs) + finally: + session.stop() + +def test_binary_mode_handling(): + """Test that binary mode handles various encodings.""" + session = ShellSession() + try: + # Test with UTF-8 characters + output = session.run("echo 'Hello 世界'") + # Should decode with replacement, not crash + assert isinstance(output, str) + finally: + session.stop() + +def test_buffer_offset_isolation(): + """Test that commands don't see each other's output.""" + session = ShellSession() + try: + output1 = session.run("echo first") + output2 = session.run("echo second") + + # Second command should not include first command's output + assert "first" not in output2 + assert "second" in output2 + finally: + session.stop() + +def test_merged_stderr_stdout(): + """Test that stderr and stdout are properly merged.""" + session = ShellSession() + try: + # Command with interleaved stdout and stderr + output = session.run("echo out1; echo err1 >&2; echo out2; echo err2 >&2") + # All output should be present + assert "out1" in output + assert "err1" in output + assert "out2" in output + assert "err2" in output + finally: + session.stop()