diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index ebead3b7d..3284d9acf 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -777,49 +777,66 @@ async def _run_loop( Yields: Events from the event loop cycle. """ - before_invocation_event, _interrupts = await self.hooks.invoke_callbacks_async( - BeforeInvocationEvent(agent=self, invocation_state=invocation_state, messages=messages) - ) - messages = before_invocation_event.messages if before_invocation_event.messages is not None else messages + current_messages: Messages | None = messages - agent_result: AgentResult | None = None - try: - yield InitEventLoopEvent() + while current_messages is not None: + before_invocation_event, _interrupts = await self.hooks.invoke_callbacks_async( + BeforeInvocationEvent(agent=self, invocation_state=invocation_state, messages=current_messages) + ) + current_messages = ( + before_invocation_event.messages if before_invocation_event.messages is not None else current_messages + ) - await self._append_messages(*messages) + agent_result: AgentResult | None = None + try: + yield InitEventLoopEvent() - structured_output_context = StructuredOutputContext( - structured_output_model or self._default_structured_output_model, - structured_output_prompt=structured_output_prompt or self._structured_output_prompt, - ) + await self._append_messages(*current_messages) - # Execute the event loop cycle with retry logic for context limits - events = self._execute_event_loop_cycle(invocation_state, structured_output_context) - async for event in events: - # Signal from the model provider that the message sent by the user should be redacted, - # likely due to a guardrail. - if ( - isinstance(event, ModelStreamChunkEvent) - and event.chunk - and event.chunk.get("redactContent") - and event.chunk["redactContent"].get("redactUserContentMessage") - ): - self.messages[-1]["content"] = self._redact_user_content( - self.messages[-1]["content"], str(event.chunk["redactContent"]["redactUserContentMessage"]) - ) - if self._session_manager: - self._session_manager.redact_latest_message(self.messages[-1], self) - yield event + structured_output_context = StructuredOutputContext( + structured_output_model or self._default_structured_output_model, + structured_output_prompt=structured_output_prompt or self._structured_output_prompt, + ) - # Capture the result from the final event if available - if isinstance(event, EventLoopStopEvent): - agent_result = AgentResult(*event["stop"]) + # Execute the event loop cycle with retry logic for context limits + events = self._execute_event_loop_cycle(invocation_state, structured_output_context) + async for event in events: + # Signal from the model provider that the message sent by the user should be redacted, + # likely due to a guardrail. + if ( + isinstance(event, ModelStreamChunkEvent) + and event.chunk + and event.chunk.get("redactContent") + and event.chunk["redactContent"].get("redactUserContentMessage") + ): + self.messages[-1]["content"] = self._redact_user_content( + self.messages[-1]["content"], + str(event.chunk["redactContent"]["redactUserContentMessage"]), + ) + if self._session_manager: + self._session_manager.redact_latest_message(self.messages[-1], self) + yield event + + # Capture the result from the final event if available + if isinstance(event, EventLoopStopEvent): + agent_result = AgentResult(*event["stop"]) - finally: - self.conversation_manager.apply_management(self) - await self.hooks.invoke_callbacks_async( - AfterInvocationEvent(agent=self, invocation_state=invocation_state, result=agent_result) - ) + finally: + self.conversation_manager.apply_management(self) + after_invocation_event, _interrupts = await self.hooks.invoke_callbacks_async( + AfterInvocationEvent(agent=self, invocation_state=invocation_state, result=agent_result) + ) + + # Convert resume input to messages for next iteration, or None to stop + if after_invocation_event.resume is not None: + logger.debug("resume= | hook requested agent resume with new input") + # If in interrupt state, process interrupt responses before continuing. + # This mirrors the _interrupt_state.resume() call in stream_async and will + # raise TypeError if the resume input is not valid interrupt responses. + self._interrupt_state.resume(after_invocation_event.resume) + current_messages = await self._convert_prompt_to_messages(after_invocation_event.resume) + else: + current_messages = None async def _execute_event_loop_cycle( self, invocation_state: dict[str, Any], structured_output_context: StructuredOutputContext | None = None diff --git a/src/strands/hooks/events.py b/src/strands/hooks/events.py index 8d3e5d280..9186e0e70 100644 --- a/src/strands/hooks/events.py +++ b/src/strands/hooks/events.py @@ -12,6 +12,7 @@ if TYPE_CHECKING: from ..agent.agent_result import AgentResult +from ..types.agent import AgentInput from ..types.content import Message, Messages from ..types.interrupt import _Interruptible from ..types.streaming import StopReason @@ -78,6 +79,13 @@ class AfterInvocationEvent(HookEvent): - Agent.stream_async - Agent.structured_output + Resume: + When ``resume`` is set to a non-None value by a hook callback, the agent will + automatically re-invoke itself with the provided input. This enables hooks to + implement autonomous looping patterns where the agent continues processing + based on its previous result. The resume triggers a full new invocation cycle + including ``BeforeInvocationEvent``. + Attributes: invocation_state: State and configuration passed through the agent invocation. This can include shared context for multi-agent coordination, request tracking, @@ -85,10 +93,17 @@ class AfterInvocationEvent(HookEvent): result: The result of the agent invocation, if available. This will be None when invoked from structured_output methods, as those return typed output directly rather than AgentResult. + resume: When set to a non-None agent input by a hook callback, the agent will + re-invoke itself with this input. The value can be any valid AgentInput + (str, content blocks, messages, etc.). Defaults to None (no resume). """ invocation_state: dict[str, Any] = field(default_factory=dict) result: "AgentResult | None" = None + resume: AgentInput = None + + def _can_write(self, name: str) -> bool: + return name == "resume" @property def should_reverse_callbacks(self) -> bool: diff --git a/tests/strands/agent/hooks/test_events.py b/tests/strands/agent/hooks/test_events.py index de551d137..0e03fbbcd 100644 --- a/tests/strands/agent/hooks/test_events.py +++ b/tests/strands/agent/hooks/test_events.py @@ -230,3 +230,33 @@ def test_before_invocation_event_agent_not_writable(start_request_event_with_mes """Test that BeforeInvocationEvent.agent is not writable.""" with pytest.raises(AttributeError, match="Property agent is not writable"): start_request_event_with_messages.agent = Mock() + + +def test_after_invocation_event_resume_defaults_to_none(agent): + """Test that AfterInvocationEvent.resume defaults to None.""" + event = AfterInvocationEvent(agent=agent, result=None) + assert event.resume is None + + +def test_after_invocation_event_resume_is_writable(agent): + """Test that AfterInvocationEvent.resume can be set by hooks.""" + event = AfterInvocationEvent(agent=agent, result=None) + event.resume = "continue with this input" + assert event.resume == "continue with this input" + + +def test_after_invocation_event_resume_accepts_various_input_types(agent): + """Test that resume accepts all AgentInput types.""" + event = AfterInvocationEvent(agent=agent, result=None) + + # String input + event.resume = "hello" + assert event.resume == "hello" + + # Content block list + event.resume = [{"text": "hello"}] + assert event.resume == [{"text": "hello"}] + + # None to stop + event.resume = None + assert event.resume is None diff --git a/tests/strands/agent/test_agent_hooks.py b/tests/strands/agent/test_agent_hooks.py index 4397b9628..1da245d70 100644 --- a/tests/strands/agent/test_agent_hooks.py +++ b/tests/strands/agent/test_agent_hooks.py @@ -694,3 +694,330 @@ async def capture_messages_hook(event: BeforeInvocationEvent): # structured_output_async uses deprecated path that doesn't pass messages assert received_messages is None + + +def test_after_invocation_resume_triggers_new_invocation(): + """Test that setting resume on AfterInvocationEvent re-invokes the agent.""" + mock_provider = MockedModelProvider( + [ + {"role": "assistant", "content": [{"text": "First response"}]}, + {"role": "assistant", "content": [{"text": "Second response"}]}, + ] + ) + + resume_count = 0 + + async def resume_once(event: AfterInvocationEvent): + nonlocal resume_count + if resume_count == 0: + resume_count += 1 + event.resume = "continue" + + agent = Agent(model=mock_provider) + agent.hooks.add_callback(AfterInvocationEvent, resume_once) + + result = agent("start") + + # Agent should have been invoked twice + assert resume_count == 1 + assert result.message["content"][0]["text"] == "Second response" + # 4 messages: user1, assistant1, user2 (resume), assistant2 + assert len(agent.messages) == 4 + assert agent.messages[0]["content"][0]["text"] == "start" + assert agent.messages[2]["content"][0]["text"] == "continue" + + +def test_after_invocation_resume_none_does_not_loop(): + """Test that resume=None (default) does not re-invoke the agent.""" + mock_provider = MockedModelProvider( + [ + {"role": "assistant", "content": [{"text": "Only response"}]}, + ] + ) + + call_count = 0 + + async def no_resume(event: AfterInvocationEvent): + nonlocal call_count + call_count += 1 + # Don't set resume - should remain None + + agent = Agent(model=mock_provider) + agent.hooks.add_callback(AfterInvocationEvent, no_resume) + + result = agent("hello") + + assert call_count == 1 + assert result.message["content"][0]["text"] == "Only response" + + +def test_after_invocation_resume_fires_before_invocation_event(): + """Test that resume triggers BeforeInvocationEvent on each iteration.""" + mock_provider = MockedModelProvider( + [ + {"role": "assistant", "content": [{"text": "First"}]}, + {"role": "assistant", "content": [{"text": "Second"}]}, + ] + ) + + before_invocation_count = 0 + after_invocation_count = 0 + + async def count_before(event: BeforeInvocationEvent): + nonlocal before_invocation_count + before_invocation_count += 1 + + async def resume_once(event: AfterInvocationEvent): + nonlocal after_invocation_count + after_invocation_count += 1 + if after_invocation_count == 1: + event.resume = "next" + + agent = Agent(model=mock_provider) + agent.hooks.add_callback(BeforeInvocationEvent, count_before) + agent.hooks.add_callback(AfterInvocationEvent, resume_once) + + agent("start") + + # BeforeInvocationEvent should fire for both the initial and resumed invocation + assert before_invocation_count == 2 + assert after_invocation_count == 2 + + +def test_after_invocation_resume_multiple_times(): + """Test that resume can chain multiple re-invocations.""" + mock_provider = MockedModelProvider( + [ + {"role": "assistant", "content": [{"text": "Response 1"}]}, + {"role": "assistant", "content": [{"text": "Response 2"}]}, + {"role": "assistant", "content": [{"text": "Response 3"}]}, + ] + ) + + resume_count = 0 + + async def resume_twice(event: AfterInvocationEvent): + nonlocal resume_count + if resume_count < 2: + resume_count += 1 + event.resume = f"iteration {resume_count + 1}" + + agent = Agent(model=mock_provider) + agent.hooks.add_callback(AfterInvocationEvent, resume_twice) + + result = agent("iteration 1") + + assert resume_count == 2 + assert result.message["content"][0]["text"] == "Response 3" + # 6 messages: 3 user + 3 assistant + assert len(agent.messages) == 6 + + +def test_after_invocation_resume_handles_interrupt_with_responses(): + """Test that a hook can handle an interrupt by resuming with interrupt responses.""" + + @strands.tools.tool(name="interruptable_tool") + def interruptable_tool(value: str) -> str: + return value + + tool_use_id = "tool-1" + mock_provider = MockedModelProvider( + [ + # First invocation: model calls the tool, which will be interrupted + { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": tool_use_id, + "name": "interruptable_tool", + "input": {"value": "test"}, + } + } + ], + }, + # Second invocation (after interrupt resume): model gives final response + {"role": "assistant", "content": [{"text": "Completed after interrupt"}]}, + ] + ) + + def interrupt_tool(event: BeforeToolCallEvent): + """Interrupt before tool execution; returns stored response on second call.""" + if event.tool_use["name"] == "interruptable_tool": + event.interrupt("approval_needed", reason="Need human approval") + + async def handle_interrupt_via_resume(event: AfterInvocationEvent): + """Hook that automatically handles interrupts by resuming with responses.""" + if event.result and event.result.stop_reason == "interrupt": + responses = [] + for interrupt in event.result.interrupts: + responses.append({"interruptResponse": {"interruptId": interrupt.id, "response": "approved"}}) + event.resume = responses + + agent = Agent(model=mock_provider, tools=[interruptable_tool], callback_handler=None) + agent.hooks.add_callback(BeforeToolCallEvent, interrupt_tool) + agent.hooks.add_callback(AfterInvocationEvent, handle_interrupt_via_resume) + + result = agent("do something") + + # The hook handled the interrupt automatically — agent completed normally + assert result.stop_reason == "end_turn" + assert result.message["content"][0]["text"] == "Completed after interrupt" + # Interrupt state should be cleared after successful resume + assert agent._interrupt_state.activated is False + + +def test_after_invocation_resume_with_invalid_input_during_interrupt(): + """Test that resuming with non-interrupt input while interrupt is active raises TypeError.""" + + @strands.tools.tool(name="interruptable_tool") + def interruptable_tool(value: str) -> str: + return value + + tool_use_id = "tool-1" + mock_provider = MockedModelProvider( + [ + # First invocation: model calls the tool, which will be interrupted + { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": tool_use_id, + "name": "interruptable_tool", + "input": {"value": "test"}, + } + } + ], + }, + ] + ) + + def interrupt_tool(event: BeforeToolCallEvent): + if event.tool_use["name"] == "interruptable_tool": + event.interrupt("approval_needed", reason="Need approval") + + async def resume_with_bad_input(event: AfterInvocationEvent): + """Hook that incorrectly tries to resume with a plain string during interrupt.""" + if event.result and event.result.stop_reason == "interrupt": + event.resume = "this is wrong" + + agent = Agent(model=mock_provider, tools=[interruptable_tool], callback_handler=None) + agent.hooks.add_callback(BeforeToolCallEvent, interrupt_tool) + agent.hooks.add_callback(AfterInvocationEvent, resume_with_bad_input) + + with pytest.raises(TypeError, match="must resume from interrupt with list of interruptResponse's"): + agent("do something") + + +def test_after_invocation_resume_interrupt_without_resume_returns_to_caller(): + """Test that an interrupt without resume set returns the interrupt to the caller.""" + + @strands.tools.tool(name="interruptable_tool") + def interruptable_tool(value: str) -> str: + return value + + tool_use_id = "tool-1" + mock_provider = MockedModelProvider( + [ + # First invocation: model calls the tool, which will be interrupted + { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": tool_use_id, + "name": "interruptable_tool", + "input": {"value": "test"}, + } + } + ], + }, + # Second invocation (caller resumes manually): final response + {"role": "assistant", "content": [{"text": "Done after manual resume"}]}, + ] + ) + + def interrupt_tool(event: BeforeToolCallEvent): + if event.tool_use["name"] == "interruptable_tool": + event.interrupt("approval_needed", reason="Need approval") + + agent = Agent(model=mock_provider, tools=[interruptable_tool], callback_handler=None) + agent.hooks.add_callback(BeforeToolCallEvent, interrupt_tool) + + # First call: hits interrupt, no hook handles it, returns to caller + result = agent("do something") + assert result.stop_reason == "interrupt" + assert len(result.interrupts) == 1 + assert result.interrupts[0].name == "approval_needed" + assert agent._interrupt_state.activated is True + + # Caller manually resumes with interrupt responses + interrupt_id = result.interrupts[0].id + result = agent([{"interruptResponse": {"interruptId": interrupt_id, "response": "yes"}}]) + assert result.stop_reason == "end_turn" + assert result.message["content"][0]["text"] == "Done after manual resume" + assert agent._interrupt_state.activated is False + + +def test_after_invocation_resume_interrupt_during_resumed_invocation(): + """Test that an interrupt during a resumed invocation can be handled by the hook.""" + + @strands.tools.tool(name="interruptable_tool") + def interruptable_tool(value: str) -> str: + return value + + tool_use_id = "tool-1" + mock_provider = MockedModelProvider( + [ + # First invocation: simple text response (no tool call) + {"role": "assistant", "content": [{"text": "First response"}]}, + # Second invocation (resumed): triggers a tool call which will be interrupted + { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": tool_use_id, + "name": "interruptable_tool", + "input": {"value": "test"}, + } + } + ], + }, + # Third invocation (after interrupt handled via resume): final response + {"role": "assistant", "content": [{"text": "Final response"}]}, + ] + ) + + invocation_count = 0 + + async def resume_hook(event: AfterInvocationEvent): + """Resume with new input on first call, handle interrupt on second.""" + nonlocal invocation_count + invocation_count += 1 + if invocation_count == 1: + # First invocation done, resume with new input + event.resume = "continue" + elif event.result and event.result.stop_reason == "interrupt": + # Second invocation hit interrupt, handle it + responses = [] + for interrupt in event.result.interrupts: + responses.append({"interruptResponse": {"interruptId": interrupt.id, "response": "approved"}}) + event.resume = responses + + def interrupt_tool(event: BeforeToolCallEvent): + if event.tool_use["name"] == "interruptable_tool": + event.interrupt("approval_needed", reason="Need approval") + + agent = Agent(model=mock_provider, tools=[interruptable_tool], callback_handler=None) + agent.hooks.add_callback(AfterInvocationEvent, resume_hook) + agent.hooks.add_callback(BeforeToolCallEvent, interrupt_tool) + + result = agent("start") + + # All three invocations happened within a single agent call + assert invocation_count == 3 + assert result.stop_reason == "end_turn" + assert result.message["content"][0]["text"] == "Final response" + assert agent._interrupt_state.activated is False