diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index 3fa907995..bab4031ed 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -339,7 +339,7 @@ def _get_additional_request_fields(self, tool_choice: ToolChoice | None) -> dict return {"additionalModelRequestFields": additional_fields} def _inject_cache_point(self, messages: list[dict[str, Any]]) -> None: - """Inject a cache point at the end of the last assistant message. + """Inject a cache point at the end of the last user message. Args: messages: List of messages to inject cache point into (modified in place). @@ -347,7 +347,7 @@ def _inject_cache_point(self, messages: list[dict[str, Any]]) -> None: if not messages: return - last_assistant_idx: int | None = None + last_user_idx: int | None = None for msg_idx, msg in enumerate(messages): content = msg.get("content", []) for block_idx, block in reversed(list(enumerate(content))): @@ -358,12 +358,12 @@ def _inject_cache_point(self, messages: list[dict[str, Any]]) -> None: msg_idx, block_idx, ) - if msg.get("role") == "assistant": - last_assistant_idx = msg_idx + if msg.get("role") == "user": + last_user_idx = msg_idx - if last_assistant_idx is not None and messages[last_assistant_idx].get("content"): - messages[last_assistant_idx]["content"].append({"cachePoint": {"type": "default"}}) - logger.debug("msg_idx=<%s> | added cache point to last assistant message", last_assistant_idx) + if last_user_idx is not None and messages[last_user_idx].get("content"): + messages[last_user_idx]["content"].append({"cachePoint": {"type": "default"}}) + logger.debug("msg_idx=<%s> | added cache point to last user message", last_user_idx) def _find_last_user_text_message_index(self, messages: Messages) -> int | None: """Find the index of the last user message containing text or image content. diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index 66fe8ab00..89c4df70d 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -2597,8 +2597,8 @@ def test_cache_strategy_none_for_non_claude(bedrock_client): assert model._cache_strategy is None -def test_inject_cache_point_adds_to_last_assistant(bedrock_client): - """Test that _inject_cache_point adds cache point to last assistant message.""" +def test_inject_cache_point_adds_to_last_user(bedrock_client): + """Test that _inject_cache_point adds cache point to last user message.""" model = BedrockModel( model_id="us.anthropic.claude-sonnet-4-20250514-v1:0", cache_config=CacheConfig(strategy="auto") ) @@ -2611,13 +2611,14 @@ def test_inject_cache_point_adds_to_last_assistant(bedrock_client): model._inject_cache_point(cleaned_messages) - assert len(cleaned_messages[1]["content"]) == 2 - assert "cachePoint" in cleaned_messages[1]["content"][-1] - assert cleaned_messages[1]["content"][-1]["cachePoint"]["type"] == "default" + assert len(cleaned_messages[2]["content"]) == 2 + assert "cachePoint" in cleaned_messages[2]["content"][-1] + assert cleaned_messages[2]["content"][-1]["cachePoint"]["type"] == "default" + assert len(cleaned_messages[1]["content"]) == 1 -def test_inject_cache_point_no_assistant_message(bedrock_client): - """Test that _inject_cache_point does nothing when no assistant message exists.""" +def test_inject_cache_point_single_user_message(bedrock_client): + """Test that _inject_cache_point adds cache point to single user message.""" model = BedrockModel( model_id="us.anthropic.claude-sonnet-4-20250514-v1:0", cache_config=CacheConfig(strategy="auto") ) @@ -2629,6 +2630,39 @@ def test_inject_cache_point_no_assistant_message(bedrock_client): model._inject_cache_point(cleaned_messages) assert len(cleaned_messages) == 1 + assert len(cleaned_messages[0]["content"]) == 2 + assert "cachePoint" in cleaned_messages[0]["content"][-1] + + +def test_inject_cache_point_empty_messages(bedrock_client): + """Test that _inject_cache_point handles empty messages list.""" + model = BedrockModel( + model_id="us.anthropic.claude-sonnet-4-20250514-v1:0", cache_config=CacheConfig(strategy="auto") + ) + + cleaned_messages = [] + model._inject_cache_point(cleaned_messages) + + assert cleaned_messages == [] + + +def test_inject_cache_point_with_tool_result_last_user(bedrock_client): + """Test that cache point is added to last user message even when it contains toolResult.""" + model = BedrockModel( + model_id="us.anthropic.claude-sonnet-4-20250514-v1:0", cache_config=CacheConfig(strategy="auto") + ) + + cleaned_messages = [ + {"role": "user", "content": [{"text": "Use the tool"}]}, + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "t1", "name": "test_tool", "input": {}}}]}, + {"role": "user", "content": [{"toolResult": {"toolUseId": "t1", "content": [{"text": "Result"}]}}]}, + ] + + model._inject_cache_point(cleaned_messages) + + assert len(cleaned_messages[2]["content"]) == 2 + assert "cachePoint" in cleaned_messages[2]["content"][-1] + assert cleaned_messages[2]["content"][-1]["cachePoint"]["type"] == "default" assert len(cleaned_messages[0]["content"]) == 1 @@ -2643,6 +2677,8 @@ def test_inject_cache_point_skipped_for_non_claude(bedrock_client): formatted = model._format_bedrock_messages(messages) + assert len(formatted[0]["content"]) == 1 + assert "cachePoint" not in formatted[0]["content"][0] assert len(formatted[1]["content"]) == 1 assert "cachePoint" not in formatted[1]["content"][0] @@ -2664,8 +2700,8 @@ def test_format_bedrock_messages_does_not_mutate_original(bedrock_client): formatted = model._format_bedrock_messages(original_messages) assert original_messages == messages_before - assert "cachePoint" not in original_messages[1]["content"][-1] - assert "cachePoint" in formatted[1]["content"][-1] + assert "cachePoint" not in original_messages[2]["content"][-1] + assert "cachePoint" in formatted[2]["content"][-1] def test_inject_cache_point_strips_existing_cache_points(bedrock_client): @@ -2685,12 +2721,13 @@ def test_inject_cache_point_strips_existing_cache_points(bedrock_client): model._inject_cache_point(cleaned_messages) # All old cache points should be stripped - assert len(cleaned_messages[0]["content"]) == 1 # user: only text + assert len(cleaned_messages[0]["content"]) == 1 # first user: only text assert len(cleaned_messages[1]["content"]) == 1 # first assistant: only text + assert len(cleaned_messages[3]["content"]) == 1 # last assistant: only text - # New cache point should be at end of last assistant message - assert len(cleaned_messages[3]["content"]) == 2 - assert "cachePoint" in cleaned_messages[3]["content"][-1] + # New cache point should be at end of last user message + assert len(cleaned_messages[2]["content"]) == 2 + assert "cachePoint" in cleaned_messages[2]["content"][-1] def test_inject_cache_point_anthropic_strategy_skips_model_check(bedrock_client): @@ -2707,9 +2744,10 @@ def test_inject_cache_point_anthropic_strategy_skips_model_check(bedrock_client) formatted = model._format_bedrock_messages(messages) - assert len(formatted[1]["content"]) == 2 - assert "cachePoint" in formatted[1]["content"][-1] - assert formatted[1]["content"][-1]["cachePoint"]["type"] == "default" + assert len(formatted[0]["content"]) == 2 + assert "cachePoint" in formatted[0]["content"][-1] + assert formatted[0]["content"][-1]["cachePoint"]["type"] == "default" + assert len(formatted[1]["content"]) == 1 def test_inject_cache_point_auto_strategy_resolves_to_anthropic_for_claude(bedrock_client): @@ -2725,8 +2763,9 @@ def test_inject_cache_point_auto_strategy_resolves_to_anthropic_for_claude(bedro formatted = model._format_bedrock_messages(messages) - assert len(formatted[1]["content"]) == 2 - assert "cachePoint" in formatted[1]["content"][-1] + assert len(formatted[0]["content"]) == 2 + assert "cachePoint" in formatted[0]["content"][-1] + assert len(formatted[1]["content"]) == 1 def test_find_last_user_text_message_index_no_user_messages(bedrock_client):