feat(py/dotpromptz): add resolve_json_schema and test case for ResolverCallable that returns an asyncio.Future by yesudeep · Pull Request #211 · google/dotprompt · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

feat(py/dotpromptz): add resolve_json_schema and test case for ResolverCallable that returns an asyncio.Future #211

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 34 additions & 9 deletions python/dotpromptz/src/dotpromptz/resolvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,20 @@

## Key Operations

| Function | Description |
|-------------------|----------------------------------------------------------------------------|
| `resolve` | Core async function to resolve a named object using a given resolver. |
| | Handles both sync/async resolvers and sync functions returning awaitables. |
| `resolve_tool` | Helper async function specifically for resolving tool names. |
| `resolve_partial` | Helper async function specifically for resolving partial names. |
| Function | Description |
|-----------------------|----------------------------------------------------------------------------|
| `resolve` | Core async function to resolve a named object using a given resolver. |
| | Handles both sync/async resolvers and sync functions returning awaitables. |
| `resolve_tool` | Helper async function specifically for resolving tool names. |
| `resolve_partial` | Helper async function specifically for resolving partial names. |
| `resolve_json_schema` | Helper async function specifically for resolving JSON schemas. |

The `resolve` function handles both sync and async resolvers. If the resolver is
sync, it is run in a thread pool to avoid blocking the event loop. If the
resolver is async, it is awaited directly.

The `resolve_tool` and `resolve_partial` functions are convenience wrappers around
`resolve` that handle the specific types of resolvers for tools and partials.
The `resolve_*` functions are convenience wrappers around `resolve` that handle
the specific types of resolvers for tools, partials, and schemas.
"""

import inspect
Expand All @@ -44,7 +45,13 @@
import anyio

from dotpromptz.errors import ResolverFailedError
from dotpromptz.typing import PartialResolver, ToolDefinition, ToolResolver
from dotpromptz.typing import (
JsonSchema,
PartialResolver,
SchemaResolver,
ToolDefinition,
ToolResolver,
)

# For compatibility with Python 3.10.
ResolverCallable = Callable[[str], Awaitable[Any] | Any]
Expand Down Expand Up @@ -147,3 +154,21 @@ async def resolve_partial(name: str, resolver: PartialResolver) -> str:
TypeError: If the resolver is not callable or returns an invalid type.
"""
return await resolve(name, 'partial', resolver)


async def resolve_json_schema(name: str, resolver: SchemaResolver) -> JsonSchema:
    """Look up a JSON schema by name through a schema resolver.

    This is a thin wrapper over `resolve` that tags the lookup with the
    'schema' kind, so failures are reported as schema-resolver failures.

    Args:
        name: Name of the JSON schema to look up.
        resolver: Callable that maps a schema name to its JSON schema.

    Returns:
        The JSON schema produced by the resolver.

    Raises:
        LookupError: If the resolver returns None for the schema.
        ResolverFailedError: For exceptions raised by the resolver.
        TypeError: If the resolver is not callable or returns an invalid type.
    """
    schema: JsonSchema = await resolve(name, 'schema', resolver)
    return schema
73 changes: 69 additions & 4 deletions python/dotpromptz/tests/dotpromptz/resolvers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,11 @@
raising `LookupError`.
* Correctly wrapping exceptions from sync resolvers in `ResolverFailedError`.
* Correctly wrapping exceptions from async resolvers in `ResolverFailedError`.
* Handling synchronous resolvers returning an `asyncio.Future`.

## `resolve_tool` & `resolve_partial`
## `resolve_*` functions

* Successful resolution via the core `resolve` function.
* Successful resolution to the correct type via the core `resolve` function.
* Correct propagation of errors (e.g., `ResolverFailedError`, `LookupError`)
from the core `resolve` function.
"""
Expand All @@ -40,8 +41,8 @@
from typing import Any

from dotpromptz.errors import ResolverFailedError
from dotpromptz.resolvers import resolve, resolve_partial, resolve_tool
from dotpromptz.typing import ToolDefinition
from dotpromptz.resolvers import resolve, resolve_json_schema, resolve_partial, resolve_tool
from dotpromptz.typing import JsonSchema, ToolDefinition


class MockSyncResolver:
Expand Down Expand Up @@ -81,6 +82,25 @@ def __call__(self, name: str) -> Awaitable[Any] | None:
return None


class MockSyncReturningFutureResolver:
"""Mock sync resolver that returns an asyncio.Future."""

def __init__(self, data: dict[str, Any], loop: asyncio.AbstractEventLoop) -> None:
"""Initialize the mock resolver."""
self._data = data
self._loop = loop

def __call__(self, name: str) -> asyncio.Future[Any] | None:
"""Return a future object if name is found."""
value = self._data.get(name)
if value is not None:
future: asyncio.Future[Any] = self._loop.create_future()
# Use call_soon_threadsafe to set the result in the event loop.
self._loop.call_soon_threadsafe(future.set_result, value)
return future
return None


class MockAsyncResolver:
"""Mock async resolver callable."""

Expand All @@ -100,6 +120,7 @@ async def __call__(self, name: str) -> Any:

mock_tool_def = ToolDefinition(name='test_tool', inputSchema={})
mock_partial_content = 'This is a partial.'
mock_json_schema: JsonSchema = {'type': 'string', 'description': 'A test schema'}


class TestResolve(unittest.IsolatedAsyncioTestCase):
Expand Down Expand Up @@ -158,6 +179,22 @@ async def test_resolve_async_resolver_raises_error(self) -> None:
await resolve('obj', 'test', resolver)
self.assertIs(cm.exception.__cause__, original_error)

async def test_resolve_sync_resolver_returns_future(self) -> None:
    """Test successful resolution with a sync resolver returning a Future."""
    running_loop = asyncio.get_running_loop()
    future_resolver = MockSyncReturningFutureResolver({'obj_future': 'value_future'}, running_loop)
    # `resolve` must await the Future returned by the sync resolver.
    resolved: str = await resolve('obj_future', 'test', future_resolver)
    self.assertEqual(resolved, 'value_future')

async def test_resolve_resolver_none(self) -> None:
    """Test LookupError when either a sync or an async resolver returns None."""
    empty_resolvers = (MockSyncResolver({}), MockAsyncResolver({}))
    for empty_resolver in empty_resolvers:
        with self.assertRaisesRegex(LookupError, "test resolver for 'not_found' returned None"):
            await resolve('not_found', 'test', empty_resolver)


class TestResolveTool(unittest.IsolatedAsyncioTestCase):
"""Tests for tool resolver functions."""
Expand Down Expand Up @@ -190,5 +227,33 @@ async def test_resolve_partial_fails(self) -> None:
await resolve_partial('missing_partial', MockSyncResolver({}))


class TestResolveJsonSchema(unittest.IsolatedAsyncioTestCase):
    """Exercise `resolve_json_schema` against sync and async mock resolvers."""

    async def test_resolve_json_schema_success_sync(self) -> None:
        """A sync resolver's schema is returned unchanged."""
        sync_resolver = MockSyncResolver({'MySchema': mock_json_schema})
        self.assertEqual(await resolve_json_schema('MySchema', sync_resolver), mock_json_schema)

    async def test_resolve_json_schema_success_async(self) -> None:
        """An async resolver's schema is returned unchanged."""
        async_resolver = MockAsyncResolver({'MySchema': mock_json_schema})
        self.assertEqual(await resolve_json_schema('MySchema', async_resolver), mock_json_schema)

    async def test_resolve_json_schema_fails_error(self) -> None:
        """An exception raised by the resolver surfaces as ResolverFailedError."""
        failing_resolver = MockSyncResolver({}, error=TypeError('Schema Error'))
        expected_message = r'schema resolver failed for bad_schema; Schema Error'
        with self.assertRaisesRegex(ResolverFailedError, expected_message):
            await resolve_json_schema('bad_schema', failing_resolver)

    async def test_resolve_json_schema_fails_none(self) -> None:
        """A resolver returning None surfaces as LookupError."""
        empty_resolver = MockAsyncResolver({})
        expected_message = r"schema resolver for 'missing_schema' returned None"
        with self.assertRaisesRegex(LookupError, expected_message):
            await resolve_json_schema('missing_schema', empty_resolver)


if __name__ == '__main__':
    # Allow running this test module directly, e.g. `python resolvers_test.py`.
    unittest.main()
Loading
0