Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions src/strands/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,9 @@ def __init__(
self.record_direct_tool_call = record_direct_tool_call
self.load_tools_from_directory = load_tools_from_directory

# Create internal cancel signal for graceful cancellation using threading.Event
self._cancel_signal = threading.Event()

self.tool_registry = ToolRegistry()

# Process tool list if provided
Expand Down Expand Up @@ -327,6 +330,37 @@ def __init__(

self.hooks.invoke_callbacks(AgentInitializedEvent(agent=self))

def cancel(self) -> None:
"""Cancel the currently running agent invocation.

This method is thread-safe and can be called from any context
(e.g., another thread, web request handler, background task).

The agent will stop gracefully at the next checkpoint:
- During model response streaming
- Before tool execution

The agent will return a result with stop_reason="cancelled".

Example:
```python
agent = Agent(model=model)

# Start agent in background
task = asyncio.create_task(agent.invoke_async("Hello"))

# Cancel from another context
agent.cancel()

result = await task
assert result.stop_reason == "cancelled"
```

Note:
Multiple calls to cancel() are safe and idempotent.
"""
self._cancel_signal.set()

@property
def system_prompt(self) -> str | None:
"""Get the system prompt as a string for backwards compatibility.
Expand Down Expand Up @@ -724,6 +758,9 @@ async def stream_async(
if invocation_state is not None:
merged_state = invocation_state

# Add cancel signal to invocation state for streaming access
merged_state["cancel_signal"] = self._cancel_signal

callback_handler = self.callback_handler
if kwargs:
callback_handler = kwargs.get("callback_handler", self.callback_handler)
Expand Down Expand Up @@ -756,6 +793,9 @@ async def stream_async(
raise

finally:
# Clear cancel signal to allow agent reuse after cancellation
self._cancel_signal.clear()

if self._invocation_lock.locked():
self._invocation_lock.release()

Expand Down
41 changes: 41 additions & 0 deletions src/strands/event_loop/event_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,47 @@ async def _handle_tool_execution(
tool_uses = [tool_use for tool_use in tool_uses if tool_use["toolUseId"] not in tool_use_ids]

interrupts = []

# Check for cancellation before tool execution
# Add tool_result for each tool_use to maintain valid conversation state
if agent._cancel_signal.is_set():
logger.debug("tool_count=<%d> | cancellation detected before tool execution", len(tool_uses))

# Create cancellation tool_result for each tool_use to avoid invalid message state
# (tool_use without tool_result would be rejected on next invocation)
for tool_use in tool_uses:
cancel_result: ToolResult = {
"toolUseId": str(tool_use.get("toolUseId")),
"status": "error",
"content": [{"text": "Tool execution cancelled"}],
}
tool_results.append(cancel_result)

# Add tool results message to conversation if any tools were cancelled
cancelled_tool_result_message: Message | None = None
if tool_results:
_cancelled_msg: Message = {
"role": "user",
"content": [{"toolResult": result} for result in tool_results],
}
cancelled_tool_result_message = _cancelled_msg
agent.messages.append(_cancelled_msg)
await agent.hooks.invoke_callbacks_async(MessageAddedEvent(agent=agent, message=_cancelled_msg))
yield ToolResultMessageEvent(message=_cancelled_msg)

agent.event_loop_metrics.end_cycle(cycle_start_time, cycle_trace)
yield EventLoopStopEvent(
"cancelled",
message,
agent.event_loop_metrics,
invocation_state["request_state"],
)
if cycle_span:
tracer.end_event_loop_cycle_span(
span=cycle_span, message=message, tool_result_message=cancelled_tool_result_message
)
return

tool_events = agent.tool_executor._execute(
agent, tool_uses, tool_results, cycle_trace, cycle_span, invocation_state, structured_output_context
)
Expand Down
22 changes: 20 additions & 2 deletions src/strands/event_loop/streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import json
import logging
import threading
import time
import warnings
from collections.abc import AsyncGenerator, AsyncIterable
Expand Down Expand Up @@ -368,13 +369,16 @@ def extract_usage_metrics(event: MetadataEvent, time_to_first_byte_ms: int | Non


async def process_stream(
chunks: AsyncIterable[StreamEvent], start_time: float | None = None
chunks: AsyncIterable[StreamEvent],
start_time: float | None = None,
cancel_signal: threading.Event | None = None,
) -> AsyncGenerator[TypedEvent, None]:
"""Processes the response stream from the API, constructing the final message and extracting usage metrics.

Args:
chunks: The chunks of the response stream from the model.
start_time: Time when the model request is initiated
cancel_signal: Optional threading.Event to check for cancellation during streaming.

Yields:
The reason for stopping, the constructed message, and the usage metrics.
Expand All @@ -395,6 +399,19 @@ async def process_stream(
metrics: Metrics = Metrics(latencyMs=0, timeToFirstByteMs=0)

async for chunk in chunks:
# Check for cancellation during stream processing
if cancel_signal and cancel_signal.is_set():
logger.debug("cancellation detected during stream processing")
# Return cancelled stop reason with cancellation message
# The incomplete message in state["message"] is discarded and never added to agent.messages
yield ModelStopReason(
stop_reason="cancelled",
message={"role": "assistant", "content": [{"text": "Cancelled by user"}]},
usage=usage,
metrics=metrics,
)
return

# Track first byte time when we get first content
if first_byte_time is None and ("contentBlockDelta" in chunk or "contentBlockStart" in chunk):
first_byte_time = time.time()
Expand Down Expand Up @@ -463,5 +480,6 @@ async def stream_messages(
invocation_state=invocation_state,
)

async for event in process_stream(chunks, start_time):
cancel_signal = invocation_state.get("cancel_signal") if invocation_state else None
async for event in process_stream(chunks, start_time, cancel_signal):
yield event
4 changes: 2 additions & 2 deletions src/strands/session/repository_session_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,8 @@ def sync_agent(self, agent: "Agent", **kwargs: Any) -> None:
else:
state_changed = current_state_version != last_synced.get("state_version")
internal_state_changed = current_interrupt_state_version != last_synced.get("interrupt_state_version")
conversation_manager_state_changed = (
current_conversation_manager_state != last_synced.get("conversation_manager_state")
conversation_manager_state_changed = current_conversation_manager_state != last_synced.get(
"conversation_manager_state"
)

if not state_changed and not internal_state_changed and not conversation_manager_state_changed:
Expand Down
2 changes: 2 additions & 0 deletions src/strands/types/event_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@
timeToFirstByteMs: int


StopReason = Literal[

Check warning on line 39 in src/strands/types/event_loop.py

View workflow job for this annotation

GitHub Actions / check-api

StopReason

Attribute value was changed: `Literal['content_filtered', 'end_turn', 'guardrail_intervened', 'interrupt', 'max_tokens', 'stop_sequence', 'tool_use']` -> `Literal['cancelled', 'content_filtered', 'end_turn', 'guardrail_intervened', 'interrupt', 'max_tokens', 'stop_sequence', 'tool_use']`
"cancelled",
"content_filtered",
"end_turn",
"guardrail_intervened",
Expand All @@ -47,6 +48,7 @@
]
"""Reason for the model ending its response generation.

- "cancelled": Agent execution was cancelled via agent.cancel()
- "content_filtered": Content was filtered due to policy violation
- "end_turn": Normal completion of the response
- "guardrail_intervened": Guardrail system intervened
Expand Down
9 changes: 5 additions & 4 deletions tests/strands/agent/hooks/test_agent_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def mock_sleep():
"event_loop_cycle_span": ANY,
"event_loop_cycle_trace": ANY,
"request_state": {},
"cancel_signal": ANY,
}


Expand Down Expand Up @@ -116,7 +117,7 @@ async def test_stream_e2e_success(alist):
tru_events = await alist(stream)
exp_events = [
# Cycle 1: Initialize and invoke normal_tool
{"arg1": 1013, "init_event_loop": True},
{"arg1": 1013, "init_event_loop": True, "cancel_signal": ANY},
{"start": True},
{"start_event_loop": True},
{"event": {"messageStart": {"role": "assistant"}}},
Expand Down Expand Up @@ -354,7 +355,7 @@ async def test_stream_e2e_throttle_and_redact(alist, mock_sleep):

tru_events = await alist(stream)
exp_events = [
{"arg1": 1013, "init_event_loop": True},
{"arg1": 1013, "init_event_loop": True, "cancel_signal": ANY},
{"start": True},
{"start_event_loop": True},
{"event_loop_throttled_delay": 4, **throttle_props},
Expand Down Expand Up @@ -413,7 +414,7 @@ async def test_stream_e2e_reasoning_redacted_content(alist):

tru_events = await alist(stream)
exp_events = [
{"init_event_loop": True},
{"init_event_loop": True, "cancel_signal": ANY},
{"start": True},
{"start_event_loop": True},
{"event": {"messageStart": {"role": "assistant"}}},
Expand Down Expand Up @@ -503,7 +504,7 @@ async def test_event_loop_cycle_text_response_throttling_early_end(
}

exp_events = [
{"init_event_loop": True, "arg1": 1013},
{"init_event_loop": True, "arg1": 1013, "cancel_signal": ANY},
{"start": True},
{"start_event_loop": True},
{"event_loop_throttled_delay": 4, **common_props},
Expand Down
10 changes: 7 additions & 3 deletions tests/strands/agent/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -715,7 +715,7 @@ def test_agent__call__callback(mock_model, agent, callback_handler, agenerator):

agent("test")
assert callback_handler.call_args_list == [
unittest.mock.call(init_event_loop=True),
unittest.mock.call(init_event_loop=True, cancel_signal=agent._cancel_signal),
unittest.mock.call(start=True),
unittest.mock.call(start_event_loop=True),
unittest.mock.call(event={"contentBlockStart": {"start": {"toolUse": {"toolUseId": "123", "name": "test"}}}}),
Expand All @@ -729,6 +729,7 @@ def test_agent__call__callback(mock_model, agent, callback_handler, agenerator):
event_loop_cycle_span=unittest.mock.ANY,
event_loop_cycle_trace=unittest.mock.ANY,
request_state={},
cancel_signal=agent._cancel_signal,
),
unittest.mock.call(event={"contentBlockStop": {}}),
unittest.mock.call(event={"contentBlockStart": {"start": {}}}),
Expand All @@ -742,6 +743,7 @@ def test_agent__call__callback(mock_model, agent, callback_handler, agenerator):
reasoning=True,
reasoningText="value",
request_state={},
cancel_signal=agent._cancel_signal,
),
unittest.mock.call(event={"contentBlockDelta": {"delta": {"reasoningContent": {"signature": "value"}}}}),
unittest.mock.call(
Expand All @@ -753,6 +755,7 @@ def test_agent__call__callback(mock_model, agent, callback_handler, agenerator):
reasoning=True,
reasoning_signature="value",
request_state={},
cancel_signal=agent._cancel_signal,
),
unittest.mock.call(event={"contentBlockStop": {}}),
unittest.mock.call(event={"contentBlockStart": {"start": {}}}),
Expand All @@ -765,6 +768,7 @@ def test_agent__call__callback(mock_model, agent, callback_handler, agenerator):
event_loop_cycle_span=unittest.mock.ANY,
event_loop_cycle_trace=unittest.mock.ANY,
request_state={},
cancel_signal=agent._cancel_signal,
),
unittest.mock.call(event={"contentBlockStop": {}}),
unittest.mock.call(
Expand Down Expand Up @@ -1075,7 +1079,7 @@ async def test_event_loop(*args, **kwargs):

tru_events = await alist(stream)
exp_events = [
{"init_event_loop": True, "callback_handler": mock_callback},
{"init_event_loop": True, "callback_handler": mock_callback, "cancel_signal": agent._cancel_signal},
{"data": "First chunk"},
{"data": "Second chunk"},
{"complete": True, "data": "Final chunk"},
Expand Down Expand Up @@ -1190,7 +1194,7 @@ async def check_invocation_state(**kwargs):

tru_events = await alist(stream)
exp_events = [
{"init_event_loop": True, "some_value": "a_value"},
{"init_event_loop": True, "some_value": "a_value", "cancel_signal": agent._cancel_signal},
{
"result": AgentResult(
stop_reason="stop",
Expand Down
Loading
Loading