From ae9dc8897885ad26461083682dd7ba008d5af3cb Mon Sep 17 00:00:00 2001 From: Carlos Chinchilla Corbacho <188046461+cchinchilla-dev@users.noreply.github.com> Date: Thu, 12 Feb 2026 18:40:25 +0100 Subject: [PATCH 1/2] feat: add async context manager support to BaseClient (#688) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # Description Thank you for opening a Pull Request! Before submitting your PR, there are a few things you can do to make sure it goes smoothly: - [x] Follow the [`CONTRIBUTING` Guide](https://github.com/a2aproject/a2a-python/blob/main/CONTRIBUTING.md). - [x] Make your Pull Request title in the specification. - Important Prefixes for [release-please](https://github.com/googleapis/release-please): - `fix:` which represents bug fixes, and correlates to a [SemVer](https://semver.org/) patch. - `feat:` represents a new feature, and correlates to a SemVer minor. - `feat!:`, or `fix!:`, `refactor!:`, etc., which represent a breaking change (indicated by the `!`) and will result in a SemVer major. - [x] Ensure the tests and linter pass (Run `bash scripts/format.sh` from the repository root to format) - [x] Appropriate docs were updated (if necessary) Follow-up to #682, as suggested by @ishymko in the [review](https://github.com/a2aproject/a2a-python/pull/682#pullrequestreview-3789126544). This extends the async context manager pattern to `BaseClient`, which wraps `ClientTransport` and also exposes a `close()` method. Fixes #674 🦕 ## Problem `BaseClient` delegates resource cleanup to its underlying `ClientTransport` via `close()`, but doesn't implement `__aenter__`/`__aexit__`. This means clients cannot be used with `async with`, leading to the same resource leak risk that #682 solved for transports: ```python client = BaseClient(card=card, config=config, transport=transport, consumers=[], middleware=[]) result = await client.send_message(msg) # if this raises, close() is never called await client.close() ``` ## Fix Added `__aenter__` and `__aexit__` methods to `BaseClient` in `src/a2a/client/base_client.py`: `__aenter__` returns `self` `__aexit__ `awaits `close()` This enables the standard async context manager pattern: ```python async with BaseClient(card=card, config=config, transport=transport, consumers=[], middleware=[]) as client: async for event in client.send_message(msg): ... # close() called automatically, even on exceptions ``` This is a non-breaking, additive change. Calling `close()` manually or via `try/finally` continues to work exactly as before. ## Test Tests were added to `tests/client/test_base_client.py`, following the same approach as the `ClientTransport` tests from #682. Release-As: 0.3.23 --- src/a2a/client/base_client.py | 16 ++++++++++++++++ tests/client/test_base_client.py | 20 ++++++++++++++++++++ 2 files changed, 36 insertions(+) diff --git a/src/a2a/client/base_client.py b/src/a2a/client/base_client.py index c870f329..09b2891d 100644 --- a/src/a2a/client/base_client.py +++ b/src/a2a/client/base_client.py @@ -1,6 +1,9 @@ from collections.abc import AsyncIterator, Callable +from types import TracebackType from typing import Any +from typing_extensions import Self + from a2a.client.client import ( Client, ClientCallContext, @@ -43,6 +46,19 @@ def __init__( self._config = config self._transport = transport + async def __aenter__(self) -> Self: + """Enters the async context manager, returning the client itself.""" + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + """Exits the async context manager, ensuring close() is called.""" + await self.close() + async def send_message( self, request: Message, diff --git a/tests/client/test_base_client.py b/tests/client/test_base_client.py index 04bebb3b..4fd6ff9c 100644 --- a/tests/client/test_base_client.py +++ b/tests/client/test_base_client.py @@ -87,6 +87,26 @@ async def test_transport_async_context_manager_on_exception() -> None: transport.close.assert_awaited_once() +@pytest.mark.asyncio +async def test_base_client_async_context_manager( + base_client: BaseClient, mock_transport: AsyncMock +) -> None: + async with base_client as client: + assert client is base_client + mock_transport.close.assert_not_awaited() + mock_transport.close.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_base_client_async_context_manager_on_exception( + base_client: BaseClient, mock_transport: AsyncMock +) -> None: + with pytest.raises(RuntimeError, match='boom'): + async with base_client: + raise RuntimeError('boom') + mock_transport.close.assert_awaited_once() + + @pytest.mark.asyncio async def test_send_message_streaming( base_client: BaseClient, mock_transport: MagicMock, sample_message: Message From 2acd838796d44ab9bfe6ba8c8b4ea0c2571a59dc Mon Sep 17 00:00:00 2001 From: Guglielmo Colombo Date: Fri, 13 Feb 2026 16:23:52 +0100 Subject: [PATCH 2/2] fix: Improve error handling for Timeout exceptions on REST and JSON-RPC clients (#690) This PR standardizes timeout error handling across the JSON-RPC and REST clients. Previously, only the JSON-RPC client (in non-streaming mode) handled `ReadTimeout` exceptions, while streaming calls and the REST client catch them incorrectly. Updating both `JsonRpcTransport` and `RestTransport` to catch the base httpx.TimeoutException, all timeout types (Read, Connect, Write, Pool) are consistently caught and wrapped in A2AClientTimeoutError. This ensures consistent behavior for API consumers regardless of the transport (REST vs JSON-RPC) or mode (Streaming vs Non-Streaming) being used, preventing generic errors when network timeouts occur. --- src/a2a/client/transports/jsonrpc.py | 6 +++- src/a2a/client/transports/rest.py | 12 ++++++- .../client/transports/test_jsonrpc_client.py | 32 +++++++++++++++++ tests/client/transports/test_rest_client.py | 36 ++++++++++++++++++- 4 files changed, 83 insertions(+), 3 deletions(-) diff --git a/src/a2a/client/transports/jsonrpc.py b/src/a2a/client/transports/jsonrpc.py index a58a7cab..bfa0b951 100644 --- a/src/a2a/client/transports/jsonrpc.py +++ b/src/a2a/client/transports/jsonrpc.py @@ -184,6 +184,8 @@ async def send_message_streaming( if isinstance(response.root, JSONRPCErrorResponse): raise A2AClientJSONRPCError(response.root) yield response.root.result + except httpx.TimeoutException as e: + raise A2AClientTimeoutError('Client Request timed out') from e except httpx.HTTPStatusError as e: raise A2AClientHTTPError(e.response.status_code, str(e)) from e except SSEError as e: @@ -208,7 +210,7 @@ async def _send_request( ) response.raise_for_status() return response.json() - except httpx.ReadTimeout as e: + except httpx.TimeoutException as e: raise A2AClientTimeoutError('Client Request timed out') from e except httpx.HTTPStatusError as e: raise A2AClientHTTPError(e.response.status_code, str(e)) from e @@ -365,6 +367,8 @@ async def resubscribe( if isinstance(response.root, JSONRPCErrorResponse): raise A2AClientJSONRPCError(response.root) yield response.root.result + except httpx.TimeoutException as e: + raise A2AClientTimeoutError('Client Request timed out') from e except SSEError as e: raise A2AClientHTTPError( 400, f'Invalid SSE response or protocol error: {e}' diff --git a/src/a2a/client/transports/rest.py b/src/a2a/client/transports/rest.py index 96df1e02..7a826cd6 100644 --- a/src/a2a/client/transports/rest.py +++ b/src/a2a/client/transports/rest.py @@ -10,7 +10,11 @@ from httpx_sse import SSEError, aconnect_sse from a2a.client.card_resolver import A2ACardResolver -from a2a.client.errors import A2AClientHTTPError, A2AClientJSONError +from a2a.client.errors import ( + A2AClientHTTPError, + A2AClientJSONError, + A2AClientTimeoutError, +) from a2a.client.middleware import ClientCallContext, ClientCallInterceptor from a2a.client.transports.base import ClientTransport from a2a.extensions.common import update_extension_header @@ -159,6 +163,8 @@ async def send_message_streaming( event = a2a_pb2.StreamResponse() Parse(sse.data, event) yield proto_utils.FromProto.stream_response(event) + except httpx.TimeoutException as e: + raise A2AClientTimeoutError('Client Request timed out') from e except httpx.HTTPStatusError as e: raise A2AClientHTTPError(e.response.status_code, str(e)) from e except SSEError as e: @@ -177,6 +183,8 @@ async def _send_request(self, request: httpx.Request) -> dict[str, Any]: response = await self.httpx_client.send(request) response.raise_for_status() return response.json() + except httpx.TimeoutException as e: + raise A2AClientTimeoutError('Client Request timed out') from e except httpx.HTTPStatusError as e: raise A2AClientHTTPError(e.response.status_code, str(e)) from e except json.JSONDecodeError as e: @@ -357,6 +365,8 @@ async def resubscribe( event = a2a_pb2.StreamResponse() Parse(sse.data, event) yield proto_utils.FromProto.stream_response(event) + except httpx.TimeoutException as e: + raise A2AClientTimeoutError('Client Request timed out') from e except SSEError as e: raise A2AClientHTTPError( 400, f'Invalid SSE response or protocol error: {e}' diff --git a/tests/client/transports/test_jsonrpc_client.py b/tests/client/transports/test_jsonrpc_client.py index 0f6bba5b..9725273f 100644 --- a/tests/client/transports/test_jsonrpc_client.py +++ b/tests/client/transports/test_jsonrpc_client.py @@ -599,6 +599,38 @@ async def test_send_message_client_timeout( assert 'Client Request timed out' in str(exc_info.value) + @pytest.mark.asyncio + @patch('a2a.client.transports.jsonrpc.aconnect_sse') + async def test_send_message_streaming_timeout( + self, + mock_aconnect_sse: AsyncMock, + mock_httpx_client: AsyncMock, + mock_agent_card: MagicMock, + ): + client = JsonRpcTransport( + httpx_client=mock_httpx_client, agent_card=mock_agent_card + ) + params = MessageSendParams( + message=create_text_message_object(content='Hello stream') + ) + mock_event_source = AsyncMock(spec=EventSource) + mock_event_source.response = MagicMock(spec=httpx.Response) + mock_event_source.response.raise_for_status.return_value = None + mock_event_source.aiter_sse.side_effect = httpx.TimeoutException( + 'Read timed out' + ) + mock_aconnect_sse.return_value.__aenter__.return_value = ( + mock_event_source + ) + + with pytest.raises(A2AClientTimeoutError) as exc_info: + _ = [ + item + async for item in client.send_message_streaming(request=params) + ] + + assert 'Client Request timed out' in str(exc_info.value) + @pytest.mark.asyncio async def test_get_task_success( self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock diff --git a/tests/client/transports/test_rest_client.py b/tests/client/transports/test_rest_client.py index c889ebaf..8f2232fb 100644 --- a/tests/client/transports/test_rest_client.py +++ b/tests/client/transports/test_rest_client.py @@ -9,7 +9,7 @@ from httpx_sse import EventSource, ServerSentEvent from a2a.client import create_text_message_object -from a2a.client.errors import A2AClientHTTPError +from a2a.client.errors import A2AClientHTTPError, A2AClientTimeoutError from a2a.client.transports.rest import RestTransport from a2a.extensions.common import HTTP_EXTENSION_HEADER from a2a.grpc import a2a_pb2 @@ -50,6 +50,40 @@ def _assert_extensions_header(mock_kwargs: dict, expected_extensions: set[str]): assert actual_extensions == expected_extensions +class TestRestTransport: + @pytest.mark.asyncio + @patch('a2a.client.transports.rest.aconnect_sse') + async def test_send_message_streaming_timeout( + self, + mock_aconnect_sse: AsyncMock, + mock_httpx_client: AsyncMock, + mock_agent_card: MagicMock, + ): + client = RestTransport( + httpx_client=mock_httpx_client, agent_card=mock_agent_card + ) + params = MessageSendParams( + message=create_text_message_object(content='Hello stream') + ) + mock_event_source = AsyncMock(spec=EventSource) + mock_event_source.response = MagicMock(spec=httpx.Response) + mock_event_source.response.raise_for_status.return_value = None + mock_event_source.aiter_sse.side_effect = httpx.TimeoutException( + 'Read timed out' + ) + mock_aconnect_sse.return_value.__aenter__.return_value = ( + mock_event_source + ) + + with pytest.raises(A2AClientTimeoutError) as exc_info: + _ = [ + item + async for item in client.send_message_streaming(request=params) + ] + + assert 'Client Request timed out' in str(exc_info.value) + + class TestRestTransportExtensions: @pytest.mark.asyncio async def test_send_message_with_default_extensions(