Add ExceptionHandler type by kristjanvalur · Pull Request #2048 · encode/starlette · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

Add ExceptionHandler type #2048

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 15 additions & 5 deletions starlette/applications.py
< 8000 td id="diff-3eb07c0eea72d98f1b844e07844c2a68b7c11c63956b67d5afdd9c7f79ab946fR52" data-line-number="52" class="blob-num blob-num-addition js-linkable-line-number js-blob-rnum">
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from starlette.middleware import Middleware
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.errors import ServerErrorMiddleware
from starlette.middleware.exceptions import ExceptionMiddleware
from starlette.middleware.exceptions import ExceptionHandler, ExceptionMiddleware
from starlette.requests import Request
from starlette.responses import Response
from starlette.routing import BaseRoute, Router
Expand Down Expand Up @@ -47,9 +47,12 @@ def __init__(
exception_handlers: typing.Optional[
typing.Mapping[
typing.Any,
typing.Callable[
[Request, Exception],
typing.Union[Response, typing.Awaitable[Response]],
typing.Union[
ExceptionHandler,
typing.Callable[
[Request, Exception],
typing.Union[Response, typing.Awaitable[Response]],
],
],
]
] = None,
Expand Down Expand Up @@ -80,7 +83,14 @@ def build_middleware_stack(self) -> ASGIApp:
debug = self.debug
error_handler = None
exception_handlers: typing.Dict[
typing.Any, typing.Callable[[Request, Exception], Response]
typing.Any,
typing.Union[
ExceptionHandler,
typing.Callable[
[Request, Exception],
typing.Union[Response, typing.Awaitable[Response]],
],
],
] = {}

for key, value in self.exception_handlers.items():
Expand Down
11 changes: 9 additions & 2 deletions starlette/middleware/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from starlette._utils import is_async_callable
from starlette.concurrency import run_in_threadpool
from starlette.middleware.exceptions import ExceptionHandler
from starlette.requests import Request
from starlette.responses import HTMLResponse, PlainTextResponse, Response
from starlette.types import ASGIApp, Message, Receive, Scope, Send
Expand Down Expand Up @@ -137,11 +138,17 @@ class ServerErrorMiddleware:
def __init__(
self,
app: ASGIApp,
handler: typing.Optional[typing.Callable] = None,
handler: typing.Optional[
typing.Union[
ExceptionHandler, typing.Callable[[typing.Any, typing.Any], typing.Any]
]
] = None,
debug: bool = False,
) -> None:
self.app = app
self.handler = handler
self.handler = (
handler.handler if isinstance(handler, ExceptionHandler) else handler
)
self.debug = debug

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
Expand Down
54 changes: 52 additions & 2 deletions starlette/middleware/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,43 @@
from starlette.types import ASGIApp, Message, Receive, Scope, Send
from starlette.websockets import WebSocket

# Type variable bound to Exception. Used in the overloads below so that a
# handler whose second parameter is annotated with a specific exception
# subclass (e.g. ``exc: HTTPException``) is accepted by the type checker.
_ET = typing.TypeVar("_ET", bound=Exception)


class ExceptionHandler:
    """Typed wrapper around an exception-handler callable.

    The ``__init__`` overloads only affect static type checking. They accept
    one of two forms:

    * an HTTP handler taking ``(Request, exc)`` and returning a ``Response``
      or an awaitable of a ``Response``;
    * a WebSocket handler taking ``(WebSocket, exc)`` and returning ``None``
      or an awaitable of ``None``.

    At runtime the callable is stored unchanged in ``self.handler``. No
    validation or wrapping is performed.
    """

    # HTTP form: (Request, exc) -> Response | Awaitable[Response].
    @typing.overload
    def __init__(
        self,
        handler: typing.Callable[
            [Request, _ET], typing.Union[Response, typing.Awaitable[Response]]
        ],
    ) -> None:
        ...

    # WebSocket form: (WebSocket, exc) -> None | Awaitable[None].
    @typing.overload
    def __init__(
        self,
        handler: typing.Callable[
            [WebSocket, _ET], typing.Union[None, typing.Awaitable[None]]
        ],
    ) -> None:
        ...

    def __init__(self, handler: typing.Callable[[typing.Any, _ET], typing.Any]) -> None:
        """Store ``handler`` as-is on the instance.

        The attribute is annotated loosely, with ``Any`` arguments, so that
        code unwrapping ``.handler`` can call it without knowing which
        overload was used.
        """
        self.handler: typing.Callable[[typing.Any, typing.Any], typing.Any] = handler


class ExceptionMiddleware:
def __init__(
self,
app: ASGIApp,
handlers: typing.Optional[
typing.Mapping[typing.Any, typing.Callable[[Request, Exception], Response]]
typing.Mapping[
typing.Any,
typing.Union[
ExceptionHandler, typing.Callable[[Request, Exception], Response]
],
]
] = None,
debug: bool = False,
) -> None:
Expand All @@ -31,11 +61,31 @@ def __init__(
for key, value in handlers.items():
self.add_exception_handler(key, value)

@typing.overload
def add_exception_handler(
self,
exc_class_or_status_code: typing.Union[int, typing.Type[Exception]],
handler: ExceptionHandler,
) -> None:
...

@typing.overload
def add_exception_handler(
self,
exc_class_or_status_code: typing.Union[int, typing.Type[Exception]],
handler: typing.Callable[[Request, Exception], Response],
) -> None:
...

def add_exception_handler(
self,
exc_class_or_status_code: typing.Union[int, typing.Type[Exception]],
handler: typing.Union[
ExceptionHandler, typing.Callable[[typing.Any, typing.Any], typing.Any]
],
) -> None:
if isinstance(handler, ExceptionHandler):
handler = handler.handler
if isinstance(exc_class_or_status_code, int):
self._status_handlers[exc_class_or_status_code] = handler
else:
Expand All @@ -44,7 +94,7 @@ def add_exception_handler(

def _lookup_exception_handler(
self, exc: Exception
) -> typing.Optional[typing.Callable]:
) -> typing.Optional[typing.Callable[[typing.Any, typing.Any], typing.Any]]:
for cls in type(exc).__mro__:
if cls in self._exception_handlers:
return self._exception_handlers[cls]
Expand Down
33 changes: 18 additions & 15 deletions tests/test_applications.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,27 +7,28 @@
import pytest

from starlette import status
from starlette.applications import Starlette
from starlette.applications import ExceptionHandler, Starlette
from starlette.endpoints import HTTPEndpoint
from starlette.exceptions import HTTPException, WebSocketException
from starlette.middleware import Middleware
from starlette.middleware.trustedhost import TrustedHostMiddleware
from starlette.responses import JSONResponse, PlainTextResponse
from starlette.requests import Request
from starlette.responses import JSONResponse, PlainTextResponse, Response
from starlette.routing import Host, Mount, Route, Router, WebSocketRoute
from starlette.staticfiles import StaticFiles
from starlette.types import ASGIApp
from starlette.websockets import WebSocket


async def error_500(request, exc):
async def error_500(request: Request, exc: Exception) -> Response:
return JSONResponse({"detail": "Server Error"}, status_code=500)


async def method_not_allowed(request, exc):
async def method_not_allowed(request: Request, exc: Exception) -> Response:
return JSONResponse({"detail": "Custom message"}, status_code=405)


async def http_exception(request, exc):
async def http_exception(request: Request, exc: HTTPException) -> Response:
return JSONResponse({"detail": exc.detail}, status_code=exc.status_code)


Expand Down Expand Up @@ -67,7 +68,7 @@ async def websocket_endpoint(session):
await session.close()


async def websocket_raise_websocket(websocket: WebSocket):
async def websocket_raise_websocket(websocket: WebSocket) -> None:
await websocket.accept()
raise WebSocketException(code=status.WS_1003_UNSUPPORTED_DATA)

Expand All @@ -76,12 +77,12 @@ class CustomWSException(Exception):
pass


async def websocket_raise_custom(websocket: WebSocket):
async def websocket_raise_custom(websocket: WebSocket) -> None:
await websocket.accept()
raise CustomWSException()


def custom_ws_exception_handler(websocket: WebSocket, exc: CustomWSException):
def custom_ws_exception_handler(websocket: WebSocket, exc: CustomWSException) -> None:
anyio.from_thread.run(websocket.close, status.WS_1013_TRY_AGAIN_LATER)


Expand All @@ -99,10 +100,10 @@ def custom_ws_exception_handler(websocket: WebSocket, exc: CustomWSException):
)

exception_handlers = {
500: error_500,
405: method_not_allowed,
HTTPException: http_exception,
CustomWSException: custom_ws_exception_handler,
500: ExceptionHandler(error_500),
405: ExceptionHandler(method_not_allowed),
HTTPException: ExceptionHandler(http_exception),
CustomWSException: ExceptionHandler(custom_ws_exception_handler),
}

middleware = [
Expand Down Expand Up @@ -491,12 +492,14 @@ async def startup():
assert len(record) == 1


def test_middleware_stack_init(test_client_factory: Callable[[ASGIApp], httpx.Client]):
def test_middleware_stack_init(
test_client_factory: Callable[[ASGIApp], httpx.Client]
) -> None:
class NoOpMiddleware:
def __init__(self, app: ASGIApp):
self.app = app

async def __call__(self, *args: Any):
async def __call__(self, *args: Any) -> None:
await self.app(*args)

class SimpleInitializableMiddleware:
Expand All @@ -506,7 +509,7 @@ def __init__(self, app: ASGIApp):
self.app = app
SimpleInitializableMiddleware.counter += 1

async def __call__(self, *args: Any):
async def __call__(self, *args: Any) -> None:
await self.app(*args)

def get_app() -> ASGIApp:
Expand Down
0