Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions src/strands/models/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,15 +339,15 @@ def _get_additional_request_fields(self, tool_choice: ToolChoice | None) -> dict
return {"additionalModelRequestFields": additional_fields}

def _inject_cache_point(self, messages: list[dict[str, Any]]) -> None:
"""Inject a cache point at the end of the last assistant message.
"""Inject a cache point at the end of the last user message.

Args:
messages: List of messages to inject cache point into (modified in place).
"""
if not messages:
return

last_assistant_idx: int | None = None
last_user_idx: int | None = None
for msg_idx, msg in enumerate(messages):
content = msg.get("content", [])
for block_idx, block in reversed(list(enumerate(content))):
Expand All @@ -358,12 +358,12 @@ def _inject_cache_point(self, messages: list[dict[str, Any]]) -> None:
msg_idx,
block_idx,
)
if msg.get("role") == "assistant":
last_assistant_idx = msg_idx
if msg.get("role") == "user":
last_user_idx = msg_idx

if last_assistant_idx is not None and messages[last_assistant_idx].get("content"):
messages[last_assistant_idx]["content"].append({"cachePoint": {"type": "default"}})
logger.debug("msg_idx=<%s> | added cache point to last assistant message", last_assistant_idx)
if last_user_idx is not None and messages[last_user_idx].get("content"):
messages[last_user_idx]["content"].append({"cachePoint": {"type": "default"}})
logger.debug("msg_idx=<%s> | added cache point to last user message", last_user_idx)

def _find_last_user_text_message_index(self, messages: Messages) -> int | None:
"""Find the index of the last user message containing text or image content.
Expand Down
75 changes: 57 additions & 18 deletions tests/strands/models/test_bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -2597,8 +2597,8 @@ def test_cache_strategy_none_for_non_claude(bedrock_client):
assert model._cache_strategy is None


def test_inject_cache_point_adds_to_last_assistant(bedrock_client):
"""Test that _inject_cache_point adds cache point to last assistant message."""
def test_inject_cache_point_adds_to_last_user(bedrock_client):
"""Test that _inject_cache_point adds cache point to last user message."""
model = BedrockModel(
model_id="us.anthropic.claude-sonnet-4-20250514-v1:0", cache_config=CacheConfig(strategy="auto")
)
Expand All @@ -2611,13 +2611,14 @@ def test_inject_cache_point_adds_to_last_assistant(bedrock_client):

model._inject_cache_point(cleaned_messages)

assert len(cleaned_messages[1]["content"]) == 2
assert "cachePoint" in cleaned_messages[1]["content"][-1]
assert cleaned_messages[1]["content"][-1]["cachePoint"]["type"] == "default"
assert len(cleaned_messages[2]["content"]) == 2
assert "cachePoint" in cleaned_messages[2]["content"][-1]
assert cleaned_messages[2]["content"][-1]["cachePoint"]["type"] == "default"
assert len(cleaned_messages[1]["content"]) == 1


def test_inject_cache_point_no_assistant_message(bedrock_client):
"""Test that _inject_cache_point does nothing when no assistant message exists."""
def test_inject_cache_point_single_user_message(bedrock_client):
"""Test that _inject_cache_point adds cache point to single user message."""
model = BedrockModel(
model_id="us.anthropic.claude-sonnet-4-20250514-v1:0", cache_config=CacheConfig(strategy="auto")
)
Expand All @@ -2629,6 +2630,39 @@ def test_inject_cache_point_no_assistant_message(bedrock_client):
model._inject_cache_point(cleaned_messages)

assert len(cleaned_messages) == 1
assert len(cleaned_messages[0]["content"]) == 2
assert "cachePoint" in cleaned_messages[0]["content"][-1]


def test_inject_cache_point_empty_messages(bedrock_client):
"""Test that _inject_cache_point handles empty messages list."""
model = BedrockModel(
model_id="us.anthropic.claude-sonnet-4-20250514-v1:0", cache_config=CacheConfig(strategy="auto")
)

cleaned_messages = []
model._inject_cache_point(cleaned_messages)

assert cleaned_messages == []


def test_inject_cache_point_with_tool_result_last_user(bedrock_client):
"""Test that cache point is added to last user message even when it contains toolResult."""
model = BedrockModel(
model_id="us.anthropic.claude-sonnet-4-20250514-v1:0", cache_config=CacheConfig(strategy="auto")
)

cleaned_messages = [
{"role": "user", "content": [{"text": "Use the tool"}]},
{"role": "assistant", "content": [{"toolUse": {"toolUseId": "t1", "name": "test_tool", "input": {}}}]},
{"role": "user", "content": [{"toolResult": {"toolUseId": "t1", "content": [{"text": "Result"}]}}]},
]

model._inject_cache_point(cleaned_messages)

assert len(cleaned_messages[2]["content"]) == 2
assert "cachePoint" in cleaned_messages[2]["content"][-1]
assert cleaned_messages[2]["content"][-1]["cachePoint"]["type"] == "default"
assert len(cleaned_messages[0]["content"]) == 1


Expand All @@ -2643,6 +2677,8 @@ def test_inject_cache_point_skipped_for_non_claude(bedrock_client):

formatted = model._format_bedrock_messages(messages)

assert len(formatted[0]["content"]) == 1
assert "cachePoint" not in formatted[0]["content"][0]
assert len(formatted[1]["content"]) == 1
assert "cachePoint" not in formatted[1]["content"][0]

Expand All @@ -2664,8 +2700,8 @@ def test_format_bedrock_messages_does_not_mutate_original(bedrock_client):
formatted = model._format_bedrock_messages(original_messages)

assert original_messages == messages_before
assert "cachePoint" not in original_messages[1]["content"][-1]
assert "cachePoint" in formatted[1]["content"][-1]
assert "cachePoint" not in original_messages[2]["content"][-1]
assert "cachePoint" in formatted[2]["content"][-1]


def test_inject_cache_point_strips_existing_cache_points(bedrock_client):
Expand All @@ -2685,12 +2721,13 @@ def test_inject_cache_point_strips_existing_cache_points(bedrock_client):
model._inject_cache_point(cleaned_messages)

# All old cache points should be stripped
assert len(cleaned_messages[0]["content"]) == 1 # user: only text
assert len(cleaned_messages[0]["content"]) == 1 # first user: only text
assert len(cleaned_messages[1]["content"]) == 1 # first assistant: only text
assert len(cleaned_messages[3]["content"]) == 1 # last assistant: only text

# New cache point should be at end of last assistant message
assert len(cleaned_messages[3]["content"]) == 2
assert "cachePoint" in cleaned_messages[3]["content"][-1]
# New cache point should be at end of last user message
assert len(cleaned_messages[2]["content"]) == 2
assert "cachePoint" in cleaned_messages[2]["content"][-1]


def test_inject_cache_point_anthropic_strategy_skips_model_check(bedrock_client):
Expand All @@ -2707,9 +2744,10 @@ def test_inject_cache_point_anthropic_strategy_skips_model_check(bedrock_client)

formatted = model._format_bedrock_messages(messages)

assert len(formatted[1]["content"]) == 2
assert "cachePoint" in formatted[1]["content"][-1]
assert formatted[1]["content"][-1]["cachePoint"]["type"] == "default"
assert len(formatted[0]["content"]) == 2
assert "cachePoint" in formatted[0]["content"][-1]
assert formatted[0]["content"][-1]["cachePoint"]["type"] == "default"
assert len(formatted[1]["content"]) == 1


def test_inject_cache_point_auto_strategy_resolves_to_anthropic_for_claude(bedrock_client):
Expand All @@ -2725,8 +2763,9 @@ def test_inject_cache_point_auto_strategy_resolves_to_anthropic_for_claude(bedro

formatted = model._format_bedrock_messages(messages)

assert len(formatted[1]["content"]) == 2
assert "cachePoint" in formatted[1]["content"][-1]
assert len(formatted[0]["content"]) == 2
assert "cachePoint" in formatted[0]["content"][-1]
assert len(formatted[1]["content"]) == 1


def test_find_last_user_text_message_index_no_user_messages(bedrock_client):
Expand Down
Loading