diff --git a/src/a2a/client/transports/jsonrpc.py b/src/a2a/client/transports/jsonrpc.py index a58a7cab..bfa0b951 100644 --- a/src/a2a/client/transports/jsonrpc.py +++ b/src/a2a/client/transports/jsonrpc.py @@ -184,6 +184,8 @@ async def send_message_streaming( if isinstance(response.root, JSONRPCErrorResponse): raise A2AClientJSONRPCError(response.root) yield response.root.result + except httpx.TimeoutException as e: + raise A2AClientTimeoutError('Client Request timed out') from e except httpx.HTTPStatusError as e: raise A2AClientHTTPError(e.response.status_code, str(e)) from e except SSEError as e: @@ -208,7 +210,7 @@ async def _send_request( ) response.raise_for_status() return response.json() - except httpx.ReadTimeout as e: + except httpx.TimeoutException as e: raise A2AClientTimeoutError('Client Request timed out') from e except httpx.HTTPStatusError as e: raise A2AClientHTTPError(e.response.status_code, str(e)) from e @@ -365,6 +367,8 @@ async def resubscribe( if isinstance(response.root, JSONRPCErrorResponse): raise A2AClientJSONRPCError(response.root) yield response.root.result + except httpx.TimeoutException as e: + raise A2AClientTimeoutError('Client Request timed out') from e except SSEError as e: raise A2AClientHTTPError( 400, f'Invalid SSE response or protocol error: {e}' diff --git a/src/a2a/client/transports/rest.py b/src/a2a/client/transports/rest.py index 96df1e02..7a826cd6 100644 --- a/src/a2a/client/transports/rest.py +++ b/src/a2a/client/transports/rest.py @@ -10,7 +10,11 @@ from httpx_sse import SSEError, aconnect_sse from a2a.client.card_resolver import A2ACardResolver -from a2a.client.errors import A2AClientHTTPError, A2AClientJSONError +from a2a.client.errors import ( + A2AClientHTTPError, + A2AClientJSONError, + A2AClientTimeoutError, +) from a2a.client.middleware import ClientCallContext, ClientCallInterceptor from a2a.client.transports.base import ClientTransport from a2a.extensions.common import update_extension_header @@ -159,6 +163,8 @@ async def send_message_streaming( event = a2a_pb2.StreamResponse() Parse(sse.data, event) yield proto_utils.FromProto.stream_response(event) + except httpx.TimeoutException as e: + raise A2AClientTimeoutError('Client Request timed out') from e except httpx.HTTPStatusError as e: raise A2AClientHTTPError(e.response.status_code, str(e)) from e except SSEError as e: @@ -177,6 +183,8 @@ async def _send_request(self, request: httpx.Request) -> dict[str, Any]: response = await self.httpx_client.send(request) response.raise_for_status() return response.json() + except httpx.TimeoutException as e: + raise A2AClientTimeoutError('Client Request timed out') from e except httpx.HTTPStatusError as e: raise A2AClientHTTPError(e.response.status_code, str(e)) from e except json.JSONDecodeError as e: @@ -357,6 +365,8 @@ async def resubscribe( event = a2a_pb2.StreamResponse() Parse(sse.data, event) yield proto_utils.FromProto.stream_response(event) + except httpx.TimeoutException as e: + raise A2AClientTimeoutError('Client Request timed out') from e except SSEError as e: raise A2AClientHTTPError( 400, f'Invalid SSE response or protocol error: {e}' diff --git a/tests/client/transports/test_jsonrpc_client.py b/tests/client/transports/test_jsonrpc_client.py index 0f6bba5b..9725273f 100644 --- a/tests/client/transports/test_jsonrpc_client.py +++ b/tests/client/transports/test_jsonrpc_client.py @@ -599,6 +599,38 @@ async def test_send_message_client_timeout( assert 'Client Request timed out' in str(exc_info.value) + @pytest.mark.asyncio + @patch('a2a.client.transports.jsonrpc.aconnect_sse') + async def test_send_message_streaming_timeout( + self, + mock_aconnect_sse: AsyncMock, + mock_httpx_client: AsyncMock, + mock_agent_card: MagicMock, + ): + client = JsonRpcTransport( + httpx_client=mock_httpx_client, agent_card=mock_agent_card + ) + params = MessageSendParams( + message=create_text_message_object(content='Hello stream') + ) + mock_event_source = AsyncMock(spec=EventSource) + mock_event_source.response = MagicMock(spec=httpx.Response) + mock_event_source.response.raise_for_status.return_value = None + mock_event_source.aiter_sse.side_effect = httpx.TimeoutException( + 'Read timed out' + ) + mock_aconnect_sse.return_value.__aenter__.return_value = ( + mock_event_source + ) + + with pytest.raises(A2AClientTimeoutError) as exc_info: + _ = [ + item + async for item in client.send_message_streaming(request=params) + ] + + assert 'Client Request timed out' in str(exc_info.value) + @pytest.mark.asyncio async def test_get_task_success( self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock diff --git a/tests/client/transports/test_rest_client.py b/tests/client/transports/test_rest_client.py index c889ebaf..8f2232fb 100644 --- a/tests/client/transports/test_rest_client.py +++ b/tests/client/transports/test_rest_client.py @@ -9,7 +9,7 @@ from httpx_sse import EventSource, ServerSentEvent from a2a.client import create_text_message_object -from a2a.client.errors import A2AClientHTTPError +from a2a.client.errors import A2AClientHTTPError, A2AClientTimeoutError from a2a.client.transports.rest import RestTransport from a2a.extensions.common import HTTP_EXTENSION_HEADER from a2a.grpc import a2a_pb2 @@ -50,6 +50,40 @@ def _assert_extensions_header(mock_kwargs: dict, expected_extensions: set[str]): assert actual_extensions == expected_extensions +class TestRestTransport: + @pytest.mark.asyncio + @patch('a2a.client.transports.rest.aconnect_sse') + async def test_send_message_streaming_timeout( + self, + mock_aconnect_sse: AsyncMock, + mock_httpx_client: AsyncMock, + mock_agent_card: MagicMock, + ): + client = RestTransport( + httpx_client=mock_httpx_client, agent_card=mock_agent_card + ) + params = MessageSendParams( + message=create_text_message_object(content='Hello stream') + ) + mock_event_source = AsyncMock(spec=EventSource) + mock_event_source.response = MagicMock(spec=httpx.Response) + mock_event_source.response.raise_for_status.return_value = None + mock_event_source.aiter_sse.side_effect = httpx.TimeoutException( + 'Read timed out' + ) + mock_aconnect_sse.return_value.__aenter__.return_value = ( + mock_event_source + ) + + with pytest.raises(A2AClientTimeoutError) as exc_info: + _ = [ + item + async for item in client.send_message_streaming(request=params) + ] + + assert 'Client Request timed out' in str(exc_info.value) + + class TestRestTransportExtensions: @pytest.mark.asyncio async def test_send_message_with_default_extensions(