Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions src/a2a/client/base_client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from collections.abc import AsyncIterator, Callable
from types import TracebackType
from typing import Any

from typing_extensions import Self

from a2a.client.client import (
Client,
ClientCallContext,
Expand Down Expand Up @@ -43,6 +46,19 @@ def __init__(
self._config = config
self._transport = transport

async def __aenter__(self) -> Self:
"""Enters the async context manager, returning the client itself."""
return self

async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
"""Exits the async context manager, ensuring close() is called."""
await self.close()

async def send_message(
self,
request: Message,
Expand Down
6 changes: 5 additions & 1 deletion src/a2a/client/transports/jsonrpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,46 +184,48 @@
if isinstance(response.root, JSONRPCErrorResponse):
raise A2AClientJSONRPCError(response.root)
yield response.root.result
except httpx.TimeoutException as e:
raise A2AClientTimeoutError('Client Request timed out') from e
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

low

The error message 'Client Request timed out' is hardcoded here and in other places within this file, as well as in src/a2a/client/transports/rest.py. To improve maintainability and avoid magic strings, consider defining this as a module-level constant in each file. For example:

_CLIENT_REQUEST_TIMEOUT_MESSAGE = 'Client Request timed out'

This constant can then be used whenever A2AClientTimeoutError is raised in the respective transport files.

except httpx.HTTPStatusError as e:
raise A2AClientHTTPError(e.response.status_code, str(e)) from e
except SSEError as e:
raise A2AClientHTTPError(
400, f'Invalid SSE response or protocol error: {e}'
) from e
except json.JSONDecodeError as e:
raise A2AClientJSONError(str(e)) from e
except httpx.RequestError as e:
raise A2AClientHTTPError(
503, f'Network communication error: {e}'
) from e

async def _send_request(
self,
rpc_request_payload: dict[str, Any],
http_kwargs: dict[str, Any] | None = None,
) -> dict[str, Any]:
try:
response = await self.httpx_client.post(

Check notice on line 208 in src/a2a/client/transports/jsonrpc.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Copy/pasted code

see src/a2a/client/transports/rest.py (170-185)
self.url, json=rpc_request_payload, **(http_kwargs or {})
)
response.raise_for_status()
return response.json()
except httpx.ReadTimeout as e:
except httpx.TimeoutException as e:
raise A2AClientTimeoutError('Client Request timed out') from e
except httpx.HTTPStatusError as e:
raise A2AClientHTTPError(e.response.status_code, str(e)) from e
except json.JSONDecodeError as e:
raise A2AClientJSONError(str(e)) from e
except httpx.RequestError as e:
raise A2AClientHTTPError(
503, f'Network communication error: {e}'
) from e

async def get_task(
self,
request: TaskQueryParams,
*,
context: ClientCallContext | None = None,

Check notice on line 228 in src/a2a/client/transports/jsonrpc.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Copy/pasted code

see src/a2a/client/transports/rest.py (187-201)
extensions: list[str] | None = None,
) -> Task:
"""Retrieves the current state and history of a specific task."""
Expand Down Expand Up @@ -365,6 +367,8 @@
if isinstance(response.root, JSONRPCErrorResponse):
raise A2AClientJSONRPCError(response.root)
yield response.root.result
except httpx.TimeoutException as e:
raise A2AClientTimeoutError('Client Request timed out') from e
except SSEError as e:
raise A2AClientHTTPError(
400, f'Invalid SSE response or protocol error: {e}'
Expand Down
12 changes: 11 additions & 1 deletion src/a2a/client/transports/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@
from httpx_sse import SSEError, aconnect_sse

from a2a.client.card_resolver import A2ACardResolver
from a2a.client.errors import A2AClientHTTPError, A2AClientJSONError
from a2a.client.errors import (
A2AClientHTTPError,
A2AClientJSONError,
A2AClientTimeoutError,
)
from a2a.client.middleware import ClientCallContext, ClientCallInterceptor
from a2a.client.transports.base import ClientTransport
from a2a.extensions.common import update_extension_header
Expand Down Expand Up @@ -159,38 +163,42 @@
event = a2a_pb2.StreamResponse()
Parse(sse.data, event)
yield proto_utils.FromProto.stream_response(event)
except httpx.TimeoutException as e:
raise A2AClientTimeoutError('Client Request timed out') from e
except httpx.HTTPStatusError as e:
raise A2AClientHTTPError(e.response.status_code, str(e)) from e
except SSEError as e:
raise A2AClientHTTPError(
400, f'Invalid SSE response or protocol error: {e}'
) from e
except json.JSONDecodeError as e:
raise A2AClientJSONError(str(e)) from e
except httpx.RequestError as e:
raise A2AClientHTTPError(
503, f'Network communication error: {e}'
) from e

async def _send_request(self, request: httpx.Request) -> dict[str, Any]:
try:
response = await self.httpx_client.send(request)
response.raise_for_status()
return response.json()

Check notice on line 185 in src/a2a/client/transports/rest.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Copy/pasted code

see src/a2a/client/transports/jsonrpc.py (191-208)

Check notice on line 185 in src/a2a/client/transports/rest.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Copy/pasted code

see src/a2a/client/transports/rest.py (395-407)
except httpx.TimeoutException as e:
raise A2AClientTimeoutError('Client Request timed out') from e
except httpx.HTTPStatusError as e:
raise A2AClientHTTPError(e.response.status_code, str(e)) from e
except json.JSONDecodeError as e:
raise A2AClientJSONError(str(e)) from e
except httpx.RequestError as e:
raise A2AClientHTTPError(
503, f'Network communication error: {e}'
) from e

async def _send_post_request(
self,
target: str,
rpc_request_payload: dict[str, Any],
http_kwargs: dict[str, Any] | None = None,

Check notice on line 201 in src/a2a/client/transports/rest.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Copy/pasted code

see src/a2a/client/transports/jsonrpc.py (214-228)
) -> dict[str, Any]:
return await self._send_request(
self.httpx_client.build_request(
Expand Down Expand Up @@ -357,6 +365,8 @@
event = a2a_pb2.StreamResponse()
Parse(sse.data, event)
yield proto_utils.FromProto.stream_response(event)
except httpx.TimeoutException as e:
raise A2AClientTimeoutError('Client Request timed out') from e
except SSEError as e:
raise A2AClientHTTPError(
400, f'Invalid SSE response or protocol error: {e}'
Expand Down
20 changes: 20 additions & 0 deletions tests/client/test_base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,26 @@ async def test_transport_async_context_manager_on_exception() -> None:
transport.close.assert_awaited_once()


@pytest.mark.asyncio
async def test_base_client_async_context_manager(
base_client: BaseClient, mock_transport: AsyncMock
) -> None:
async with base_client as client:
assert client is base_client
mock_transport.close.assert_not_awaited()
mock_transport.close.assert_awaited_once()


@pytest.mark.asyncio
async def test_base_client_async_context_manager_on_exception(
base_client: BaseClient, mock_transport: AsyncMock
) -> None:
with pytest.raises(RuntimeError, match='boom'):
async with base_client:
raise RuntimeError('boom')
mock_transport.close.assert_awaited_once()


@pytest.mark.asyncio
async def test_send_message_streaming(
base_client: BaseClient, mock_transport: MagicMock, sample_message: Message
Expand Down
32 changes: 32 additions & 0 deletions tests/client/transports/test_jsonrpc_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -599,6 +599,38 @@ async def test_send_message_client_timeout(

assert 'Client Request timed out' in str(exc_info.value)

@pytest.mark.asyncio
@patch('a2a.client.transports.jsonrpc.aconnect_sse')
async def test_send_message_streaming_timeout(
self,
mock_aconnect_sse: AsyncMock,
mock_httpx_client: AsyncMock,
mock_agent_card: MagicMock,
):
client = JsonRpcTransport(
httpx_client=mock_httpx_client, agent_card=mock_agent_card
)
params = MessageSendParams(
message=create_text_message_object(content='Hello stream')
)
mock_event_source = AsyncMock(spec=EventSource)
mock_event_source.response = MagicMock(spec=httpx.Response)
mock_event_source.response.raise_for_status.return_value = None
mock_event_source.aiter_sse.side_effect = httpx.TimeoutException(
'Read timed out'
)
mock_aconnect_sse.return_value.__aenter__.return_value = (
mock_event_source
)

with pytest.raises(A2AClientTimeoutError) as exc_info:
_ = [
item
async for item in client.send_message_streaming(request=params)
]

assert 'Client Request timed out' in str(exc_info.value)

@pytest.mark.asyncio
async def test_get_task_success(
self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock
Expand Down
36 changes: 35 additions & 1 deletion tests/client/transports/test_rest_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from httpx_sse import EventSource, ServerSentEvent

from a2a.client import create_text_message_object
from a2a.client.errors import A2AClientHTTPError
from a2a.client.errors import A2AClientHTTPError, A2AClientTimeoutError
from a2a.client.transports.rest import RestTransport
from a2a.extensions.common import HTTP_EXTENSION_HEADER
from a2a.grpc import a2a_pb2
Expand Down Expand Up @@ -50,6 +50,40 @@ def _assert_extensions_header(mock_kwargs: dict, expected_extensions: set[str]):
assert actual_extensions == expected_extensions


class TestRestTransport:
@pytest.mark.asyncio
@patch('a2a.client.transports.rest.aconnect_sse')
async def test_send_message_streaming_timeout(
self,
mock_aconnect_sse: AsyncMock,
mock_httpx_client: AsyncMock,
mock_agent_card: MagicMock,
):
client = RestTransport(
httpx_client=mock_httpx_client, agent_card=mock_agent_card
)
params = MessageSendParams(
message=create_text_message_object(content='Hello stream')
)
mock_event_source = AsyncMock(spec=EventSource)
mock_event_source.response = MagicMock(spec=httpx.Response)
mock_event_source.response.raise_for_status.return_value = None
mock_event_source.aiter_sse.side_effect = httpx.TimeoutException(
'Read timed out'
)
mock_aconnect_sse.return_value.__aenter__.return_value = (
mock_event_source
)

with pytest.raises(A2AClientTimeoutError) as exc_info:
_ = [
item
async for item in client.send_message_streaming(request=params)
]

assert 'Client Request timed out' in str(exc_info.value)


class TestRestTransportExtensions:
@pytest.mark.asyncio
async def test_send_message_with_default_extensions(
Expand Down
Loading