From eb7e43137f525f1d08222c67713a4bf24e2ee94a Mon Sep 17 00:00:00 2001 From: Vlad Stefan Munteanu Date: Thu, 25 Jun 2020 11:28:34 +0300 Subject: [PATCH 1/5] Allow usage of async partial methods --- starlette/routing.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/starlette/routing.py b/starlette/routing.py index ac48169b9..9fd2a02b1 100644 --- a/starlette/routing.py +++ b/starlette/routing.py @@ -1,4 +1,5 @@ import asyncio +import functools import inspect import re import traceback @@ -33,7 +34,10 @@ def request_response(func: typing.Callable) -> ASGIApp: Takes a function or coroutine `func(request) -> response`, and returns an ASGI application. """ - is_coroutine = asyncio.iscoroutinefunction(func) + if isinstance(func, functools.partial): + is_coroutine = asyncio.iscoroutinefunction(func.func) + else: + is_coroutine = asyncio.iscoroutinefunction(func) async def app(scope: Scope, receive: Receive, send: Send) -> None: request = Request(scope, receive=receive, send=send) @@ -169,7 +173,10 @@ def __init__( self.name = get_name(endpoint) if name is None else name self.include_in_schema = include_in_schema - if inspect.isfunction(endpoint) or inspect.ismethod(endpoint): + endpoint_handler = endpoint + if isinstance(endpoint, functools.partial): + endpoint_handler = endpoint.func + if inspect.isfunction(endpoint_handler) or inspect.ismethod(endpoint_handler): # Endpoint is function or method. Treat it as `func(request) -> response`. self.app = request_response(endpoint) if methods is None: From c8a6160010c5f6881fd825abf78c45e399a2854d Mon Sep 17 00:00:00 2001 From: Vlad Stefan Munteanu Date: Thu, 25 Jun 2020 12:16:25 +0300 Subject: [PATCH 2/5] Added test for partial async endpoint --- tests/test_routing.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/tests/test_routing.py b/tests/test_routing.py index 36b3a69e7..a71c7eab8 100644 --- a/tests/test_routing.py +++ b/tests/test_routing.py @@ -1,3 +1,4 @@ +import functools import uuid import pytest @@ -488,3 +489,22 @@ def test_standalone_ws_route_does_not_match(): client = TestClient(app) with pytest.raises(WebSocketDisconnect): client.websocket_connect("/invalid") + + +async def _partial_async_endpoint(arg, request): + return JSONResponse({"arg": arg}) + + +partial_async_endpoint = functools.partial(_partial_async_endpoint, "foo") + +partial_async_app = Router( + routes=[ + Route('/', partial_async_endpoint) + ] +) + + +def test_partial_async_endpoint(): + response = TestClient(partial_async_app).get('/') + assert response.status_code == 200 + assert response.json() == {"arg": "foo"} From 2230bab505a584fee71b9f0b7e995d9c73e2212f Mon Sep 17 00:00:00 2001 From: Vlad Stefan Munteanu Date: Sun, 28 Jun 2020 11:40:01 +0300 Subject: [PATCH 3/5] Double quotes vs single quotes --- tests/test_routing.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/tests/test_routing.py b/tests/test_routing.py index a71c7eab8..f32de9ec0 100644 --- a/tests/test_routing.py +++ b/tests/test_routing.py @@ -497,14 +497,10 @@ async def _partial_async_endpoint(arg, request): partial_async_endpoint = functools.partial(_partial_async_endpoint, "foo") -partial_async_app = Router( - routes=[ - Route('/', partial_async_endpoint) - ] -) +partial_async_app = Router(routes=[Route("/", partial_async_endpoint)]) def test_partial_async_endpoint(): - response = TestClient(partial_async_app).get('/') + response = TestClient(partial_async_app).get("/") assert response.status_code == 200 assert response.json() == {"arg": "foo"} From a739da4eb4cadb5fb24f274bcb492e261ba0b865 Mon Sep 17 00:00:00 2001 From: Vlad Stefan Munteanu Date: Wed, 1 Jul 2020 20:32:04 +0300 Subject: [PATCH 4/5] Support multiple levels of partials, check Python < 3.8 --- starlette/routing.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/starlette/routing.py b/starlette/routing.py index 9fd2a02b1..9e53916ca 100644 --- a/starlette/routing.py +++ b/starlette/routing.py @@ -2,6 +2,7 @@ import functools import inspect import re +import sys import traceback import typing from enum import Enum @@ -29,15 +30,23 @@ class Match(Enum): FULL = 2 +def iscoroutinefunction_or_partial(obj: typing.Any) -> bool: + """ + Correctly determines if an object is a coroutine function, + with a fix for partials on Python < 3.8. + """ + if sys.version_info < (3, 8): + while isinstance(obj, functools.partial): + obj = obj.func + return inspect.iscoroutinefunction(obj) + + def request_response(func: typing.Callable) -> ASGIApp: """ Takes a function or coroutine `func(request) -> response`, and returns an ASGI application. """ - if isinstance(func, functools.partial): - is_coroutine = asyncio.iscoroutinefunction(func.func) - else: - is_coroutine = asyncio.iscoroutinefunction(func) + is_coroutine = iscoroutinefunction_or_partial(func) async def app(scope: Scope, receive: Receive, send: Send) -> None: request = Request(scope, receive=receive, send=send) @@ -174,8 +183,8 @@ def __init__( self.include_in_schema = include_in_schema endpoint_handler = endpoint - if isinstance(endpoint, functools.partial): - endpoint_handler = endpoint.func + while isinstance(endpoint_handler, functools.partial): + endpoint_handler = endpoint_handler.func if inspect.isfunction(endpoint_handler) or inspect.ismethod(endpoint_handler): # Endpoint is function or method. Treat it as `func(request) -> response`. self.app = request_response(endpoint) From 141ff2e8ef34fda4cc1535856e40f7956d82f1cf Mon Sep 17 00:00:00 2001 From: Vlad Stefan Munteanu Date: Fri, 24 Jul 2020 21:18:56 +0300 Subject: [PATCH 5/5] Skip coverage for py3.8 branch --- starlette/routing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/starlette/routing.py b/starlette/routing.py index 9e53916ca..3b4e97c6a 100644 --- a/starlette/routing.py +++ b/starlette/routing.py @@ -35,7 +35,7 @@ def iscoroutinefunction_or_partial(obj: typing.Any) -> bool: Correctly determines if an object is a coroutine function, with a fix for partials on Python < 3.8. """ - if sys.version_info < (3, 8): + if sys.version_info < (3, 8): # pragma: no cover while isinstance(obj, functools.partial): obj = obj.func return inspect.iscoroutinefunction(obj)