diff --git a/pyproject.toml b/pyproject.toml index 7b7247c..160b6f2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,7 @@ dependencies = [ "langchain>=0.3.7", "openai>=1.58.1", "pydantic>=2.9.2", - "og-test-v2-x402==0.0.11" + "og-test-v2-x402==0.0.12.dev3" ] [project.scripts] diff --git a/requirements.txt b/requirements.txt index df03caa..4a6d40e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,4 +7,4 @@ requests>=2.32.3 langchain>=0.3.7 openai>=1.58.1 pydantic>=2.9.2 -og-test-v2-x402==0.0.11 \ No newline at end of file +og-test-v2-x402==0.0.12.dev3 \ No newline at end of file diff --git a/src/opengradient/agents/__init__.py b/src/opengradient/agents/__init__.py index 082f706..aa6a35d 100644 --- a/src/opengradient/agents/__init__.py +++ b/src/opengradient/agents/__init__.py @@ -6,15 +6,22 @@ into existing applications and agent frameworks. """ +from ..client.llm import LLM from ..types import TEE_LLM, x402SettlementMode from .og_langchain import * def langchain_adapter( - private_key: str, - model_cid: TEE_LLM, + private_key: str | None = None, + model_cid: TEE_LLM | str | None = None, + model: TEE_LLM | str | None = None, max_tokens: int = 300, + temperature: float = 0.0, x402_settlement_mode: x402SettlementMode = x402SettlementMode.BATCH_HASHED, + client: LLM | None = None, + rpc_url: str | None = None, + tee_registry_address: str | None = None, + llm_server_url: str | None = None, ) -> OpenGradientChatModel: """ Returns an OpenGradient LLM that implements LangChain's LLM interface @@ -22,9 +29,14 @@ def langchain_adapter( """ return OpenGradientChatModel( private_key=private_key, - model_cid=model_cid, + client=client, + model_cid=model_cid or model, max_tokens=max_tokens, + temperature=temperature, x402_settlement_mode=x402_settlement_mode, + rpc_url=rpc_url, + tee_registry_address=tee_registry_address, + llm_server_url=llm_server_url, ) diff --git a/src/opengradient/agents/og_langchain.py b/src/opengradient/agents/og_langchain.py index 4f238a5..bcb02f4 100644 --- a/src/opengradient/agents/og_langchain.py +++ b/src/opengradient/agents/og_langchain.py @@ -1,29 +1,34 @@ # mypy: ignore-errors import asyncio import json -from typing import Any, Callable, Dict, List, Optional, Sequence, Union +from enum import Enum +from typing import Any, AsyncIterator, Awaitable, Callable, Dict, Iterator, List, Optional, Sequence, Union, cast -from langchain_core.callbacks.manager import CallbackManagerForLLMRun +from langchain_core.callbacks.manager import AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun from langchain_core.language_models.base import LanguageModelInput from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.messages import ( AIMessage, + AIMessageChunk, BaseMessage, + ChatMessage, HumanMessage, SystemMessage, ToolCall, ) -from langchain_core.messages.tool import ToolMessage +from langchain_core.messages.tool import ToolCallChunk, ToolMessage from langchain_core.outputs import ( ChatGeneration, + ChatGenerationChunk, ChatResult, ) from langchain_core.runnables import Runnable from langchain_core.tools import BaseTool +from langchain_core.utils.function_calling import convert_to_openai_tool from pydantic import PrivateAttr from ..client.llm import LLM -from ..types import TEE_LLM, x402SettlementMode +from ..types import StreamChunk, TEE_LLM, TextGenerationOutput, x402SettlementMode __all__ = ["OpenGradientChatModel"] @@ -47,7 +52,29 @@ def _extract_content(content: Any) -> str: return str(content) if content else "" -def _parse_tool_call(tool_call: Dict) -> ToolCall: +def _parse_tool_args(raw_args: Any) -> Dict[str, Any]: + if isinstance(raw_args, dict): + return raw_args + if raw_args is None or raw_args == "": + return {} + if isinstance(raw_args, str): + try: + parsed = json.loads(raw_args) + return parsed if isinstance(parsed, dict) else {} + except json.JSONDecodeError: + return {} + return {} + + +def _serialize_tool_args(raw_args: Any) -> str: + if raw_args is None: + return "{}" + if isinstance(raw_args, str): + return raw_args + return json.dumps(raw_args) + + +def _parse_tool_call(tool_call: Dict[str, Any]) -> ToolCall: """Parse a tool call from the API response. Handles both flat format {"id", "name", "arguments"} and @@ -58,86 +85,191 @@ def _parse_tool_call(tool_call: Dict) -> ToolCall: return ToolCall( id=tool_call.get("id", ""), name=func["name"], - args=json.loads(func.get("arguments", "{}")), + args=_parse_tool_args(func.get("arguments")), ) return ToolCall( id=tool_call.get("id", ""), name=tool_call["name"], - args=json.loads(tool_call.get("arguments", "{}")), + args=_parse_tool_args(tool_call.get("arguments")), + ) + + +def _parse_tool_call_chunk(tool_call: Dict[str, Any], default_index: int) -> ToolCallChunk: + if "function" in tool_call: + func = tool_call.get("function", {}) + name = func.get("name") + raw_args = func.get("arguments") + else: + name = tool_call.get("name") + raw_args = tool_call.get("arguments") + + args: Optional[str] + if raw_args is None: + args = None + elif isinstance(raw_args, str): + args = raw_args + else: + args = json.dumps(raw_args) + + return ToolCallChunk( + id=tool_call.get("id"), + index=tool_call.get("index", default_index), + name=name, + args=args, ) +def _run_coro_sync(coro_factory: Callable[[], Awaitable[Any]]) -> Any: + try: + asyncio.get_running_loop() + except RuntimeError: + return asyncio.run(coro_factory()) + + raise RuntimeError( + "Synchronous LangChain calls cannot run inside an active event loop for this adapter. " + "Use `ainvoke`/`astream` instead of `invoke`/`stream`." + ) + + +def _validate_model_string(model: Union[TEE_LLM, str]) -> Union[TEE_LLM, str]: + if isinstance(model, Enum): + model_str = str(model.value) + else: + model_str = str(model) + if "/" not in model_str: + raise ValueError( + f"Unsupported model value '{model_str}'. " + "Expected provider/model format (for example: 'openai/gpt-5')." + ) + return model + + class OpenGradientChatModel(BaseChatModel): """OpenGradient adapter class for LangChain chat model""" - model_cid: str + model_cid: Union[TEE_LLM, str] max_tokens: int = 300 - x402_settlement_mode: Optional[str] = x402SettlementMode.BATCH_HASHED + temperature: float = 0.0 + x402_settlement_mode: x402SettlementMode = x402SettlementMode.BATCH_HASHED _llm: LLM = PrivateAttr() + _owns_client: bool = PrivateAttr(default=False) _tools: List[Dict] = PrivateAttr(default_factory=list) + _tool_choice: Optional[str] = PrivateAttr(default=None) def __init__( self, - private_key: str, - model_cid: TEE_LLM, + private_key: Optional[str] = None, + model_cid: Optional[Union[TEE_LLM, str]] = None, + model: Optional[Union[TEE_LLM, str]] = None, max_tokens: int = 300, - x402_settlement_mode: Optional[x402SettlementMode] = x402SettlementMode.BATCH_HASHED, + temperature: float = 0.0, + x402_settlement_mode: x402SettlementMode = x402SettlementMode.BATCH_HASHED, + client: Optional[LLM] = None, + rpc_url: Optional[str] = None, + tee_registry_address: Optional[str] = None, + llm_server_url: Optional[str] = None, **kwargs, ): + resolved_model_cid = model_cid or model + if resolved_model_cid is None: + raise ValueError("model_cid (or model) is required.") + resolved_model_cid = _validate_model_string(resolved_model_cid) super().__init__( - model_cid=model_cid, + model_cid=resolved_model_cid, max_tokens=max_tokens, + temperature=temperature, x402_settlement_mode=x402_settlement_mode, **kwargs, ) - self._llm = LLM(private_key=private_key) + + if client is not None: + self._llm = client + self._owns_client = False + return + + if not private_key: + raise ValueError("private_key is required when client is not provided.") + + llm_kwargs: Dict[str, Any] = {} + if rpc_url is not None: + llm_kwargs["rpc_url"] = rpc_url + if tee_registry_address is not None: + llm_kwargs["tee_registry_address"] = tee_registry_address + if llm_server_url is not None: + llm_kwargs["llm_server_url"] = llm_server_url + + self._llm = LLM(private_key=private_key, **llm_kwargs) + self._owns_client = True @property def _llm_type(self) -> str: return "opengradient" + async def aclose(self) -> None: + if self._owns_client: + await self._llm.close() + + def close(self) -> None: + if self._owns_client: + _run_coro_sync(self._llm.close) + def bind_tools( self, tools: Sequence[ Union[Dict[str, Any], type, Callable, BaseTool] # noqa: UP006 ], + *, + tool_choice: Optional[str] = None, **kwargs: Any, ) -> Runnable[LanguageModelInput, BaseMessage]: """Bind tools to the model.""" - tool_dicts: List[Dict] = [] + strict = kwargs.get("strict") + self._tools = [convert_to_openai_tool(tool, strict=strict) for tool in tools] + self._tool_choice = tool_choice or kwargs.get("tool_choice") - for tool in tools: - if isinstance(tool, BaseTool): - tool_dicts.append( - { - "type": "function", - "function": { - "name": tool.name, - "description": tool.description, - "parameters": ( - tool.args_schema.model_json_schema() - if hasattr(tool, "args_schema") and tool.args_schema is not None - else {} - ), - }, - } - ) - else: - tool_dicts.append(tool) + return self - self._tools = tool_dicts + @staticmethod + def _stream_chunk_to_generation(chunk: StreamChunk) -> ChatGenerationChunk: + choice = chunk.choices[0] if chunk.choices else None + delta = choice.delta if choice else None - return self + usage = None + if chunk.usage is not None: + usage = { + "input_tokens": chunk.usage.prompt_tokens, + "output_tokens": chunk.usage.completion_tokens, + "total_tokens": chunk.usage.total_tokens, + } - def _generate( - self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> ChatResult: - sdk_messages = [] + tool_call_chunks: List[ToolCallChunk] = [] + if delta and delta.tool_calls: + for index, tool_call in enumerate(delta.tool_calls): + tool_call_chunks.append(_parse_tool_call_chunk(tool_call, index)) + + message_chunk = AIMessageChunk( + content=_extract_content(delta.content if delta else ""), + tool_call_chunks=tool_call_chunks, + usage_metadata=usage, + ) + + generation_info: Dict[str, Any] = {} + if choice and choice.finish_reason is not None: + generation_info["finish_reason"] = choice.finish_reason + + for key in ["tee_signature", "tee_timestamp", "tee_id", "tee_endpoint", "tee_payment_address"]: + value = getattr(chunk, key, None) + if value is not None: + generation_info[key] = value + + return ChatGenerationChunk( + message=message_chunk, + generation_info=generation_info or None, + ) + + def _convert_messages_to_sdk(self, messages: List[BaseMessage]) -> List[Dict[str, Any]]: + sdk_messages: List[Dict[str, Any]] = [] for message in messages: if isinstance(message, SystemMessage): sdk_messages.append({"role": "system", "content": _extract_content(message.content)}) @@ -148,9 +280,12 @@ def _generate( if message.tool_calls: msg["tool_calls"] = [ { - "id": call["id"], + "id": call.get("id", ""), "type": "function", - "function": {"name": call["name"], "arguments": json.dumps(call["args"])}, + "function": { + "name": call["name"], + "arguments": _serialize_tool_args(call.get("args")), + }, } for call in message.tool_calls ] @@ -163,33 +298,125 @@ def _generate( "tool_call_id": message.tool_call_id, } ) + elif isinstance(message, ChatMessage): + sdk_messages.append({"role": message.role, "content": _extract_content(message.content)}) else: raise ValueError(f"Unexpected message type: {message}") + return sdk_messages - chat_output = asyncio.run( - self._llm.chat( - model=self.model_cid, - messages=sdk_messages, - stop_sequence=stop, - max_tokens=self.max_tokens, - tools=self._tools, - x402_settlement_mode=self.x402_settlement_mode, - ) - ) + def _build_chat_kwargs(self, sdk_messages: List[Dict[str, Any]], stop: Optional[List[str]], stream: bool, **kwargs: Any) -> Dict[str, Any]: + x402_settlement_mode = kwargs.get("x402_settlement_mode", self.x402_settlement_mode) + if isinstance(x402_settlement_mode, str): + x402_settlement_mode = x402SettlementMode(x402_settlement_mode) + model = kwargs.get("model", self.model_cid) + model = _validate_model_string(model) + return { + "model": model, + "messages": sdk_messages, + "stop_sequence": stop, + "max_tokens": kwargs.get("max_tokens", self.max_tokens), + "temperature": kwargs.get("temperature", self.temperature), + "tools": kwargs.get("tools", self._tools), + "tool_choice": kwargs.get("tool_choice", self._tool_choice), + "x402_settlement_mode": x402_settlement_mode, + "stream": stream, + } + + @staticmethod + def _build_chat_result(chat_output: TextGenerationOutput) -> ChatResult: finish_reason = chat_output.finish_reason or "" chat_response = chat_output.chat_output or {} + response_content = _extract_content(chat_response.get("content", "")) if chat_response.get("tool_calls"): tool_calls = [_parse_tool_call(tc) for tc in chat_response["tool_calls"]] - ai_message = AIMessage(content="", tool_calls=tool_calls) + ai_message = AIMessage(content=response_content, tool_calls=tool_calls) + else: + ai_message = AIMessage(content=response_content) + + generation_info = {"finish_reason": finish_reason} if finish_reason else {} + return ChatResult(generations=[ChatGeneration(message=ai_message, generation_info=generation_info)]) + + def _generate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> ChatResult: + sdk_messages = self._convert_messages_to_sdk(messages) + chat_kwargs = self._build_chat_kwargs(sdk_messages, stop, stream=False, **kwargs) + chat_output = _run_coro_sync(lambda: self._llm.chat(**chat_kwargs)) + if not isinstance(chat_output, TextGenerationOutput): + raise RuntimeError("Expected non-streaming chat output but received streaming generator.") + return self._build_chat_result(chat_output) + + async def _agenerate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> ChatResult: + sdk_messages = self._convert_messages_to_sdk(messages) + chat_kwargs = self._build_chat_kwargs(sdk_messages, stop, stream=False, **kwargs) + chat_output = await self._llm.chat(**chat_kwargs) + if not isinstance(chat_output, TextGenerationOutput): + raise RuntimeError("Expected non-streaming chat output but received streaming generator.") + return self._build_chat_result(chat_output) + + def _stream( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> Iterator[ChatGenerationChunk]: + sdk_messages = self._convert_messages_to_sdk(messages) + chat_kwargs = self._build_chat_kwargs(sdk_messages, stop, stream=True, **kwargs) + try: + asyncio.get_running_loop() + except RuntimeError: + pass else: - ai_message = AIMessage(content=_extract_content(chat_response.get("content", ""))) + raise RuntimeError( + "Synchronous stream cannot run inside an active event loop for this adapter. " + "Use `astream` instead." + ) + + loop = asyncio.new_event_loop() + try: + stream = loop.run_until_complete(self._llm.chat(**chat_kwargs)) + stream_iter = cast(AsyncIterator[StreamChunk], stream) + + while True: + try: + chunk = loop.run_until_complete(stream_iter.__anext__()) + except StopAsyncIteration: + break + yield self._stream_chunk_to_generation(chunk) + finally: + loop.run_until_complete(loop.shutdown_asyncgens()) + loop.close() - return ChatResult(generations=[ChatGeneration(message=ai_message, generation_info={"finish_reason": finish_reason})]) + async def _astream( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> AsyncIterator[ChatGenerationChunk]: + sdk_messages = self._convert_messages_to_sdk(messages) + chat_kwargs = self._build_chat_kwargs(sdk_messages, stop, stream=True, **kwargs) + stream = await self._llm.chat(**chat_kwargs) + async for chunk in cast(AsyncIterator[StreamChunk], stream): + yield self._stream_chunk_to_generation(chunk) @property def _identifying_params(self) -> Dict[str, Any]: return { "model_name": self.model_cid, + "temperature": self.temperature, + "max_tokens": self.max_tokens, } diff --git a/src/opengradient/client/llm.py b/src/opengradient/client/llm.py index a345caa..1be1306 100644 --- a/src/opengradient/client/llm.py +++ b/src/opengradient/client/llm.py @@ -3,8 +3,9 @@ import json import logging import ssl +import threading from dataclasses import dataclass -from typing import AsyncGenerator, Dict, List, Optional, Union +from typing import Any, AsyncGenerator, Awaitable, Callable, Dict, List, Optional, TypeVar, Union from eth_account import Account from eth_account.account import LocalAccount @@ -19,6 +20,7 @@ from .tee_registry import TEERegistry, build_ssl_context_from_der logger = logging.getLogger(__name__) +T = TypeVar("T") DEFAULT_RPC_URL = "https://ogevmdevnet.opengradient.ai" DEFAULT_TEE_REGISTRY_ADDRESS = "0x4e72238852f3c918f4E4e57AeC9280dDB0c80248" @@ -91,8 +93,13 @@ def __init__( ssl_ctx = build_ssl_context_from_der(tls_cert_der) if tls_cert_der else None self._tls_verify: Union[ssl.SSLContext, bool] = ssl_ctx if ssl_ctx else True + self._reset_lock = threading.Lock() - # x402 client and signer + # x402 client/signer/http stack + self._init_x402_stack() + + def _init_x402_stack(self) -> None: + """Initialize x402 signer/client/http stack.""" signer = EthAccountSignerv2(self._wallet_account) self._x402_client = x402Clientv2() register_exact_evm_clientv2(self._x402_client, signer, networks=[BASE_TESTNET_NETWORK]) @@ -100,6 +107,51 @@ def __init__( # httpx.AsyncClient subclass - construction is sync, connections open lazily self._http_client = x402HttpxClientv2(self._x402_client, verify=self._tls_verify) + async def _reset_x402_stack(self) -> None: + """Reset x402 state and underlying HTTP client.""" + with self._reset_lock: + old_http_client = self._http_client + self._init_x402_stack() + + try: + await old_http_client.aclose() + except Exception: + logger.debug("Failed to close previous x402 HTTP client during reset.", exc_info=True) + + @staticmethod + def _is_invalid_payment_required_error(exc: Exception) -> bool: + """Detect the known stale-session x402 failure mode.""" + visited: set[int] = set() + current: Optional[BaseException] = exc + + while current is not None and id(current) not in visited: + visited.add(id(current)) + msg = str(current).lower() + if "invalid payment required response" in msg: + return True + current = current.__cause__ or current.__context__ + return False + + async def _retry_once_on_invalid_payment_required( + self, + operation_name: str, + call: Callable[[], Awaitable[T]], + ) -> T: + """Retry once after resetting x402 state for recoverable payment errors.""" + try: + return await call() + except Exception as first_error: + if not self._is_invalid_payment_required_error(first_error): + raise + + logger.warning( + "Recoverable x402 payment error during %s; resetting x402 client and retrying once: %s", + operation_name, + first_error, + ) + await self._reset_x402_stack() + return await call() + # ── TEE resolution ────────────────────────────────────────────────── @staticmethod @@ -239,7 +291,7 @@ async def completion( if stop_sequence: payload["stop"] = stop_sequence - try: + async def _request() -> TextGenerationOutput: response = await self._http_client.post( self._tee_endpoint + _COMPLETION_ENDPOINT, json=payload, @@ -256,6 +308,9 @@ async def completion( tee_timestamp=result.get("tee_timestamp"), **self._tee_metadata(), ) + + try: + return await self._retry_once_on_invalid_payment_required("completion", _request) except RuntimeError: raise except Exception as e: @@ -326,7 +381,7 @@ async def _chat_request(self, params: _ChatParams, messages: List[Dict]) -> Text headers = self._headers(params.x402_settlement_mode) payload = self._chat_payload(params, messages) - try: + async def _request() -> TextGenerationOutput: response = await self._http_client.post( self._tee_endpoint + _CHAT_ENDPOINT, json=payload, @@ -356,6 +411,9 @@ async def _chat_request(self, params: _ChatParams, messages: List[Dict]) -> Text tee_timestamp=result.get("tee_timestamp"), **self._tee_metadata(), ) + + try: + return await self._retry_once_on_invalid_payment_required("chat", _request) except RuntimeError: raise except Exception as e: @@ -391,15 +449,29 @@ async def _chat_stream(self, params: _ChatParams, messages: List[Dict]) -> Async headers = self._headers(params.x402_settlement_mode) payload = self._chat_payload(params, messages, stream=True) - async with self._http_client.stream( - "POST", - self._tee_endpoint + _CHAT_ENDPOINT, - json=payload, - headers=headers, - timeout=_REQUEST_TIMEOUT, - ) as response: - async for chunk in self._parse_sse_response(response): - yield chunk + retried = False + while True: + try: + async with self._http_client.stream( + "POST", + self._tee_endpoint + _CHAT_ENDPOINT, + json=payload, + headers=headers, + timeout=_REQUEST_TIMEOUT, + ) as response: + async for chunk in self._parse_sse_response(response): + yield chunk + return + except Exception as e: + if (not retried) and self._is_invalid_payment_required_error(e): + retried = True + logger.warning( + "Recoverable x402 payment error during stream; resetting x402 client and retrying once: %s", + e, + ) + await self._reset_x402_stack() + continue + raise async def _parse_sse_response(self, response) -> AsyncGenerator[StreamChunk, None]: """Parse an SSE response stream into StreamChunk objects.""" diff --git a/tests/langchain_adapter_test.py b/tests/langchain_adapter_test.py index e651ab4..1747c1d 100644 --- a/tests/langchain_adapter_test.py +++ b/tests/langchain_adapter_test.py @@ -1,3 +1,4 @@ +import asyncio import json import os import sys @@ -11,7 +12,7 @@ sys.path.append(os.path.join(os.path.dirname(__file__), "..")) from src.opengradient.agents.og_langchain import OpenGradientChatModel, _extract_content, _parse_tool_call -from src.opengradient.types import TEE_LLM, TextGenerationOutput, x402SettlementMode +from src.opengradient.types import StreamChoice, StreamChunk, StreamDelta, TEE_LLM, TextGenerationOutput, x402SettlementMode @pytest.fixture @@ -52,9 +53,24 @@ def test_initialization_custom_settlement_mode(self, mock_llm_client): ) assert model.x402_settlement_mode == x402SettlementMode.PRIVATE + def test_initialization_with_existing_client(self): + with patch("src.opengradient.agents.og_langchain.LLM") as MockLLM: + existing_client = MagicMock() + model = OpenGradientChatModel(private_key=None, client=existing_client, model_cid=TEE_LLM.GPT_5) + assert model._llm is existing_client + MockLLM.assert_not_called() + + def test_initialization_without_private_key_or_client_raises(self): + with pytest.raises(ValueError, match="private_key is required"): + OpenGradientChatModel(private_key=None, model_cid=TEE_LLM.GPT_5) + + def test_initialization_with_invalid_model_string_raises(self): + with pytest.raises(ValueError, match="provider/model format"): + OpenGradientChatModel(private_key="0x" + "a" * 64, model_cid="gpt-5") + def test_identifying_params(self, model): """Test _identifying_params returns model name.""" - assert model._identifying_params == {"model_name": TEE_LLM.GPT_5} + assert model._identifying_params == {"model_name": TEE_LLM.GPT_5, "temperature": 0.0, "max_tokens": 300} class TestGenerate: @@ -156,6 +172,24 @@ def test_empty_chat_output(self, model, mock_llm_client): assert result.generations[0].message.content == "" + def test_generate_with_invalid_model_kwarg_raises(self, model): + with pytest.raises(ValueError, match="provider/model format"): + model._generate([HumanMessage(content="Hi")], model="gpt-5") + + def test_sync_generate_inside_running_loop_raises(self, model): + async def run_test(): + with pytest.raises(RuntimeError, match="Use `ainvoke`/`astream`"): + model._generate([HumanMessage(content="Hi")]) + + asyncio.run(run_test()) + + def test_sync_stream_inside_running_loop_raises(self, model): + async def run_test(): + with pytest.raises(RuntimeError, match="Use `astream`"): + next(model._stream([HumanMessage(content="Hi")])) + + asyncio.run(run_test()) + class TestMessageConversion: def test_converts_all_message_types(self, model, mock_llm_client): @@ -215,8 +249,11 @@ def test_passes_correct_params_to_client(self, model, mock_llm_client): messages=[{"role": "user", "content": "Hi"}], stop_sequence=["END"], max_tokens=300, + temperature=0.0, tools=[], + tool_choice=None, x402_settlement_mode=x402SettlementMode.BATCH_HASHED, + stream=False, ) @@ -306,3 +343,77 @@ def test_nested_function_format(self): assert tc["name"] == "bar" assert tc["args"] == {"y": 2} assert tc["id"] == "2" + + +class TestAsyncPaths: + def test_agenerate(self, model, mock_llm_client): + mock_llm_client.chat.return_value = TextGenerationOutput( + transaction_hash="external", + finish_reason="stop", + chat_output={"role": "assistant", "content": "Hello async!"}, + ) + + result = asyncio.run(model._agenerate([HumanMessage(content="Hi")])) + assert result.generations[0].message.content == "Hello async!" + + def test_ainvoke(self, model, mock_llm_client): + mock_llm_client.chat.return_value = TextGenerationOutput( + transaction_hash="external", + finish_reason="stop", + chat_output={"role": "assistant", "content": "pong"}, + ) + + message = asyncio.run(model.ainvoke([HumanMessage(content="ping")])) + assert message.content == "pong" + + def test_astream(self, model, mock_llm_client): + async def stream(): + yield StreamChunk( + choices=[StreamChoice(delta=StreamDelta(role="assistant", content="Hel"), index=0)], + model="gpt-5", + ) + yield StreamChunk( + choices=[StreamChoice(delta=StreamDelta(content="lo"), index=0, finish_reason="stop")], + model="gpt-5", + is_final=True, + ) + + mock_llm_client.chat.return_value = stream() + + async def collect_chunks(): + return [chunk async for chunk in model.astream([HumanMessage(content="Hi")])] + + chunks = asyncio.run(collect_chunks()) + output_text = "".join(chunk.content for chunk in chunks if chunk.content) + assert output_text == "Hello" + + def test_astream_tool_call_chunk(self, model, mock_llm_client): + async def stream(): + yield StreamChunk( + choices=[ + StreamChoice( + delta=StreamDelta( + tool_calls=[ + { + "id": "call_1", + "type": "function", + "function": {"name": "search", "arguments": '{"q":"test"}'}, + } + ] + ), + index=0, + finish_reason="tool_calls", + ) + ], + model="gpt-5", + is_final=True, + ) + + mock_llm_client.chat.return_value = stream() + + async def collect_chunks(): + return [chunk async for chunk in model.astream([HumanMessage(content="Hi")])] + + chunks = asyncio.run(collect_chunks()) + assert chunks[0].tool_call_chunks[0]["id"] == "call_1" + assert chunks[0].tool_call_chunks[0]["name"] == "search" diff --git a/tests/llm_test.py b/tests/llm_test.py index 3f068f3..eed3042 100644 --- a/tests/llm_test.py +++ b/tests/llm_test.py @@ -7,7 +7,7 @@ import json from contextlib import asynccontextmanager from typing import List -from unittest.mock import MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import httpx import pytest @@ -361,6 +361,32 @@ async def test_http_error_raises_opengradient_error(self, fake_http): with pytest.raises(RuntimeError, match="TEE LLM chat failed"): await llm.chat(model=TEE_LLM.GPT_5, messages=[{"role": "user", "content": "Hi"}]) + async def test_retries_once_on_invalid_payment_required(self, fake_http): + fake_http.set_response( + 200, + { + "choices": [{"message": {"role": "assistant", "content": "recovered"}, "finish_reason": "stop"}], + }, + ) + llm = _make_llm() + llm._reset_x402_stack = AsyncMock(return_value=None) + original_post = llm._http_client.post + attempts = {"count": 0} + + async def flaky_post(*args, **kwargs): + attempts["count"] += 1 + if attempts["count"] == 1: + raise RuntimeError("Failed to handle payment: Invalid payment required response") + return await original_post(*args, **kwargs) + + llm._http_client.post = flaky_post + + result = await llm.chat(model=TEE_LLM.GPT_5, messages=[{"role": "user", "content": "Hi"}]) + + assert result.chat_output["content"] == "recovered" + assert attempts["count"] == 2 + llm._reset_x402_stack.assert_awaited_once() + # ── Streaming tests ────────────────────────────────────────────────── @@ -469,6 +495,40 @@ async def test_tools_with_stream_falls_back_to_single_chunk(self, fake_http): assert chunks[0].choices[0].delta.tool_calls == [{"id": "tc1"}] assert chunks[0].choices[0].finish_reason == "tool_calls" + async def test_stream_retries_once_on_invalid_payment_required(self, fake_http): + fake_http.set_stream_response( + 200, + [ + b'data: {"model":"gpt-5","choices":[{"index":0,"delta":{"content":"ok"},"finish_reason":"stop"}]}\n\n', + b"data: [DONE]\n\n", + ], + ) + llm = _make_llm() + llm._reset_x402_stack = AsyncMock(return_value=None) + original_stream = llm._http_client.stream + attempts = {"count": 0} + + @asynccontextmanager + async def flaky_stream(*args, **kwargs): + attempts["count"] += 1 + if attempts["count"] == 1: + raise RuntimeError("Failed to handle payment: Invalid payment required response") + async with original_stream(*args, **kwargs) as response: + yield response + + llm._http_client.stream = flaky_stream + + gen = await llm.chat( + model=TEE_LLM.GPT_5, + messages=[{"role": "user", "content": "Hi"}], + stream=True, + ) + chunks = [chunk async for chunk in gen] + + assert attempts["count"] == 2 + assert chunks[-1].choices[0].delta.content == "ok" + llm._reset_x402_stack.assert_awaited_once() + # ── ensure_opg_approval tests ──────────────────────────────────────── diff --git a/uv.lock b/uv.lock index bb26279..5e1a9dc 100644 --- a/uv.lock +++ b/uv.lock @@ -1610,15 +1610,15 @@ wheels = [ [[package]] name = "og-test-v2-x402" -version = "0.0.9" +version = "0.0.12.dev3" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "pydantic" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/a0/de/fd26c297113c483f62f3a5ee5fc535e81f9413edc68d1bf9d2db4ba62dd4/og_test_v2_x402-0.0.9.tar.gz", hash = "sha256:f5353be907c7224371214d40ec8dc125ee0633e3dbd9deadf6e43c904a7a9328", size = 892006, upload-time = "2026-02-17T16:16:34.028Z" } +sdist = { url = "https://files.pythonhosted.org/packages/8b/37/30cb7b742aa61df537ea5ff4d4d6ef6945044fe353b7e200b82242dfbda6/og_test_v2_x402-0.0.12.dev3.tar.gz", hash = "sha256:564dff796fcf3cf2974bdd00c2ee08ee2b54e1f88e5c321010fb3960a08966f0", size = 899129, upload-time = "2026-03-12T12:53:53.397Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/54/da/bd0f6670d9d577b10d13a9a68371b8fd40e1b24fd30dd690b2aa497eea81/og_test_v2_x402-0.0.9-py3-none-any.whl", hash = "sha256:80257701e8a1909ec5fba434482aa2cdcd9de2a7868b99cff70cf1763c8a53b0", size = 945014, upload-time = "2026-02-17T16:16:32.107Z" }, + { url = "https://files.pythonhosted.org/packages/be/be/3486fd0c55f256bf7e9872d17913eca5b71e6ab357a237a8055a11f7f0de/og_test_v2_x402-0.0.12.dev3-py3-none-any.whl", hash = "sha256:ee23e17b301bbaece78c3aadee9f12c157cfbb561432b6561bde1d6cd305e90f", size = 952165, upload-time = "2026-03-12T12:53:51.452Z" }, ] [[package]] @@ -1642,7 +1642,7 @@ wheels = [ [[package]] name = "opengradient" -version = "0.7.3" +version = "0.8.0" source = { editable = "." } dependencies = [ { name = "click" }, @@ -1664,7 +1664,7 @@ requires-dist = [ { name = "firebase-rest-api", specifier = ">=1.11.0" }, { name = "langchain", specifier = ">=0.3.7" }, { name = "numpy", specifier = ">=1.26.4" }, - { name = "og-test-v2-x402", specifier = "==0.0.9" }, + { name = "og-test-v2-x402", specifier = "==0.0.12.dev3" }, { name = "openai", specifier = ">=1.58.1" }, { name = "pydantic", specifier = ">=2.9.2" }, { name = "requests", specifier = ">=2.32.3" },