From 3a4a8025418ba36a3fc57f0b1aced7fed0c1ca92 Mon Sep 17 00:00:00 2001 From: Ali Aliyev Date: Mon, 30 Sep 2024 20:34:00 +0400 Subject: [PATCH 01/57] feat: preparing for next major release --- .pre-commit-config.yaml | 2 +- aioauth/errors.py | 39 +++++++++--------- aioauth/grant_type.py | 68 +++++++++++++++++-------------- aioauth/models.py | 11 ++--- aioauth/oidc/core/grant_type.py | 15 ++++--- aioauth/oidc/core/requests.py | 3 -- aioauth/requests.py | 19 +++++---- aioauth/response_type.py | 49 ++++++++++++----------- aioauth/server.py | 71 +++++++++++++++++---------------- aioauth/storage.py | 44 ++++++++++---------- aioauth/types.py | 2 +- tests/classes.py | 49 +++++++++++++---------- tests/conftest.py | 10 +++-- tests/factories.py | 41 +++++++++---------- tests/test_db.py | 9 +++-- tests/test_endpoint.py | 24 +++++------ tests/test_flow.py | 8 ++-- tests/test_grant_type.py | 4 +- tests/test_request_validator.py | 10 ++--- 19 files changed, 247 insertions(+), 231 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 85495a1..a2a7d07 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -23,7 +23,7 @@ repos: rev: v0.950 hooks: - id: mypy - exclude: ^(docs/|setup\.py|tests/) + exclude: ^(docs/|setup\.py) - repo: https://github.com/pycqa/flake8 rev: 4.0.1 diff --git a/aioauth/errors.py b/aioauth/errors.py index 9e96601..90be9aa 100644 --- a/aioauth/errors.py +++ b/aioauth/errors.py @@ -9,17 +9,18 @@ """ from http import HTTPStatus -from typing import Generic, Optional +from typing import Optional from urllib.parse import urljoin from typing_extensions import Literal +from .requests import Request + from .collections import HTTPHeaderDict from .constances import default_headers -from .requests import TRequest from .types import ErrorType -class OAuth2Error(Exception, Generic[TRequest]): +class OAuth2Error(Exception): """Base exception that all other exceptions inherit from.""" error: ErrorType @@ -31,7 +32,7 @@ class 
OAuth2Error(Exception, Generic[TRequest]): def __init__( self, - request: TRequest, + request: Request, description: Optional[str] = None, headers: Optional[HTTPHeaderDict] = None, state: Optional[str] = None, @@ -53,7 +54,7 @@ def __init__( super().__init__(f"({self.error}) {self.description}") -class MethodNotAllowedError(OAuth2Error[TRequest]): +class MethodNotAllowedError(OAuth2Error): """ The request is valid, but the method trying to be accessed is not available to the resource owner. @@ -64,7 +65,7 @@ class MethodNotAllowedError(OAuth2Error[TRequest]): error: ErrorType = "method_is_not_allowed" -class InvalidRequestError(OAuth2Error[TRequest]): +class InvalidRequestError(OAuth2Error): """ The request is missing a required parameter, includes an invalid parameter value, includes a parameter more than once, or is @@ -74,7 +75,7 @@ class InvalidRequestError(OAuth2Error[TRequest]): error: Literal["invalid_request"] = "invalid_request" -class InvalidClientError(OAuth2Error[TRequest]): +class InvalidClientError(OAuth2Error): """ Client authentication failed (e.g. unknown client, no client authentication included, or unsupported authentication method). @@ -92,7 +93,7 @@ class InvalidClientError(OAuth2Error[TRequest]): def __init__( self, - request: TRequest, + request: Request, description: Optional[str] = None, headers: Optional[HTTPHeaderDict] = None, state: Optional[str] = None, @@ -109,14 +110,14 @@ def __init__( self.headers["WWW-Authenticate"] = "Basic " + ", ".join(auth_values) -class InsecureTransportError(OAuth2Error[TRequest]): +class InsecureTransportError(OAuth2Error): """An exception will be thrown if the current request is not secure.""" description = "OAuth 2 MUST utilize https." error: ErrorType = "insecure_transport" -class UnsupportedGrantTypeError(OAuth2Error[TRequest]): +class UnsupportedGrantTypeError(OAuth2Error): """ The authorization grant type is not supported by the authorization server. 
@@ -125,7 +126,7 @@ class UnsupportedGrantTypeError(OAuth2Error[TRequest]): error: ErrorType = "unsupported_grant_type" -class UnsupportedResponseTypeError(OAuth2Error[TRequest]): +class UnsupportedResponseTypeError(OAuth2Error): """ The authorization server does not support obtaining an authorization code using this method. @@ -134,7 +135,7 @@ class UnsupportedResponseTypeError(OAuth2Error[TRequest]): error: ErrorType = "unsupported_response_type" -class InvalidGrantError(OAuth2Error[TRequest]): +class InvalidGrantError(OAuth2Error): """ The provided authorization grant (e.g. authorization code, resource owner credentials) or refresh token is invalid, expired, revoked, does @@ -147,14 +148,14 @@ class InvalidGrantError(OAuth2Error[TRequest]): error: ErrorType = "invalid_grant" -class MismatchingStateError(OAuth2Error[TRequest]): +class MismatchingStateError(OAuth2Error): """Unable to securely verify the integrity of the request and response.""" description = "CSRF Warning! State not equal in request and response." error: Literal["mismatching_state"] = "mismatching_state" -class UnauthorizedClientError(OAuth2Error[TRequest]): +class UnauthorizedClientError(OAuth2Error): """ The authenticated client is not authorized to use this authorization grant type. @@ -163,7 +164,7 @@ class UnauthorizedClientError(OAuth2Error[TRequest]): error: ErrorType = "unauthorized_client" -class InvalidScopeError(OAuth2Error[TRequest]): +class InvalidScopeError(OAuth2Error): """ The requested scope is invalid, unknown, or malformed, or exceeds the scope granted by the resource owner. @@ -174,7 +175,7 @@ class InvalidScopeError(OAuth2Error[TRequest]): error: ErrorType = "invalid_scope" -class ServerError(OAuth2Error[TRequest]): +class ServerError(OAuth2Error): """ The authorization server encountered an unexpected condition that prevented it from fulfilling the request. 
(This error code is needed @@ -185,7 +186,7 @@ class ServerError(OAuth2Error[TRequest]): error: ErrorType = "server_error" -class TemporarilyUnavailableError(OAuth2Error[TRequest]): +class TemporarilyUnavailableError(OAuth2Error): """ The authorization server is currently unable to handle the request due to a temporary overloading or maintenance of the server. @@ -196,7 +197,7 @@ class TemporarilyUnavailableError(OAuth2Error[TRequest]): error: ErrorType = "temporarily_unavailable" -class InvalidRedirectURIError(OAuth2Error[TRequest]): +class InvalidRedirectURIError(OAuth2Error): """ The requested redirect URI is missing or not allowed. """ @@ -204,7 +205,7 @@ class InvalidRedirectURIError(OAuth2Error[TRequest]): error: ErrorType = "invalid_request" -class UnsupportedTokenTypeError(OAuth2Error[TRequest]): +class UnsupportedTokenTypeError(OAuth2Error): """ The authorization server does not support the revocation of the presented token type. That is, the client tried to revoke an access token on a server diff --git a/aioauth/grant_type.py b/aioauth/grant_type.py index 51e80ee..44eb154 100644 --- a/aioauth/grant_type.py +++ b/aioauth/grant_type.py @@ -8,6 +8,9 @@ ---- """ from typing import Generic, Optional + +from .requests import Request, TUser +from .storage import BaseStorage from .errors import ( InvalidClientError, InvalidGrantError, @@ -18,23 +21,23 @@ UnauthorizedClientError, ) from .models import Client -from .requests import TRequest from .responses import TokenResponse -from .storage import TStorage from .utils import enforce_list, enforce_str, generate_token -class GrantTypeBase(Generic[TRequest, TStorage]): +class GrantTypeBase(Generic[TUser]): """Base grant type that all other grant types inherit from.""" - def __init__(self, storage: TStorage, client_id: str, client_secret: Optional[str]): + def __init__( + self, storage: BaseStorage[TUser], client_id: str, client_secret: Optional[str] + ): self.storage = storage self.client_id = client_id 
self.client_secret = client_secret self.scope: Optional[str] = None async def create_token_response( - self, request: TRequest, client: Client + self, request: Request, client: Client ) -> TokenResponse: """Creates token response to reply to client.""" if self.scope is None: @@ -57,28 +60,28 @@ async def create_token_response( token_type=token.token_type, ) - async def validate_request(self, request: TRequest) -> Client: + async def validate_request(self, request: Request) -> Client: """Validates the client request to ensure it is valid.""" client = await self.storage.get_client( request, client_id=self.client_id, client_secret=self.client_secret ) if not client: - raise InvalidClientError[TRequest]( + raise InvalidClientError( request=request, description="Invalid client_id parameter value." ) if not client.check_grant_type(request.post.grant_type): - raise UnauthorizedClientError[TRequest](request=request) + raise UnauthorizedClientError(request=request) if not client.check_scope(request.post.scope): - raise InvalidScopeError[TRequest](request=request) + raise InvalidScopeError(request=request) self.scope = request.post.scope return client -class AuthorizationCodeGrantType(GrantTypeBase[TRequest, TStorage]): +class AuthorizationCodeGrantType(GrantTypeBase[TUser]): """ The Authorization Code grant type is used by confidential and public clients to exchange an authorization code for an access token. After @@ -94,21 +97,21 @@ class AuthorizationCodeGrantType(GrantTypeBase[TRequest, TStorage]): See `RFC 6749 section 1.3.1 `_. """ - async def validate_request(self, request: TRequest) -> Client: + async def validate_request(self, request: Request) -> Client: client = await super().validate_request(request) if not request.post.redirect_uri: - raise InvalidRedirectURIError[TRequest]( + raise InvalidRedirectURIError( request=request, description="Mismatching redirect URI." 
) if not client.check_redirect_uri(request.post.redirect_uri): - raise InvalidRedirectURIError[TRequest]( + raise InvalidRedirectURIError( request=request, description="Invalid redirect URI." ) if not request.post.code: - raise InvalidRequestError[TRequest]( + raise InvalidRequestError( request=request, description="Missing code parameter." ) @@ -117,14 +120,14 @@ async def validate_request(self, request: TRequest) -> Client: ) if not authorization_code: - raise InvalidGrantError[TRequest](request=request) + raise InvalidGrantError(request=request) if ( authorization_code.code_challenge and authorization_code.code_challenge_method ): if not request.post.code_verifier: - raise InvalidRequestError[TRequest]( + raise InvalidRequestError( request=request, description="Code verifier required." ) @@ -132,19 +135,22 @@ async def validate_request(self, request: TRequest) -> Client: request.post.code_verifier ) if not is_valid_code_challenge: - raise MismatchingStateError[TRequest](request=request) + raise MismatchingStateError(request=request) if authorization_code.is_expired: - raise InvalidGrantError[TRequest](request=request) + raise InvalidGrantError(request=request) self.scope = authorization_code.scope return client async def create_token_response( - self, request: TRequest, client: Client + self, request: Request, client: Client ) -> TokenResponse: token_response = await super().create_token_response(request, client) + if request.post.code is None: + raise InvalidRequestError(request=request) + await self.storage.delete_authorization_code( request, client.client_id, @@ -154,7 +160,7 @@ async def create_token_response( return token_response -class PasswordGrantType(GrantTypeBase[TRequest, TStorage]): +class PasswordGrantType(GrantTypeBase[TUser]): """ The Password grant type is a way to exchange a user's credentials for an access token. Because the client application has to collect @@ -165,25 +171,25 @@ class PasswordGrantType(GrantTypeBase[TRequest, TStorage]): disallows the password grant entirely.
""" - async def validate_request(self, request: TRequest) -> Client: + async def validate_request(self, request: Request) -> Client: client = await super().validate_request(request) if not request.post.username or not request.post.password: - raise InvalidRequestError[TRequest]( + raise InvalidRequestError( request=request, description="Invalid credentials given." ) user = await self.storage.authenticate(request) if not user: - raise InvalidRequestError[TRequest]( + raise InvalidRequestError( request=request, description="Invalid credentials given." ) return client -class RefreshTokenGrantType(GrantTypeBase[TRequest, TStorage]): +class RefreshTokenGrantType(GrantTypeBase[TUser]): """ The Refresh Token grant type is used by clients to exchange a refresh token for an access token when the access token has expired. @@ -193,7 +199,7 @@ class RefreshTokenGrantType(GrantTypeBase[TRequest, TStorage]): """ async def create_token_response( - self, request: TRequest, client: Client + self, request: Request, client: Client ) -> TokenResponse: """Validate token request and create token response.""" old_token = await self.storage.get_token( @@ -203,7 +209,7 @@ async def create_token_response( ) if not old_token or old_token.revoked or old_token.refresh_token_expired: - raise InvalidGrantError[TRequest](request=request) + raise InvalidGrantError(request=request) # Revoke old token await self.storage.revoke_token( @@ -235,18 +241,18 @@ async def create_token_response( token_type=token.token_type, ) - async def validate_request(self, request: TRequest) -> Client: + async def validate_request(self, request: Request) -> Client: client = await super().validate_request(request) if not request.post.refresh_token: - raise InvalidRequestError[TRequest]( + raise InvalidRequestError( request=request, description="Missing refresh token parameter." 
) return client -class ClientCredentialsGrantType(GrantTypeBase[TRequest, TStorage]): +class ClientCredentialsGrantType(GrantTypeBase[TUser]): """ The Client Credentials grant type is used by clients to obtain an access token outside of the context of a user. This is typically @@ -255,9 +261,9 @@ class ClientCredentialsGrantType(GrantTypeBase[TRequest, TStorage]): See `RFC 6749 section 4.4 `_. """ - async def validate_request(self, request: TRequest) -> Client: + async def validate_request(self, request: Request) -> Client: # client_credentials grant requires a client_secret if self.client_secret is None: - raise InvalidClientError[TRequest](request) + raise InvalidClientError(request) return await super().validate_request(request) diff --git a/aioauth/models.py b/aioauth/models.py index b40f5bf..a6a496c 100644 --- a/aioauth/models.py +++ b/aioauth/models.py @@ -9,9 +9,9 @@ """ from dataclasses import dataclass import time -from typing import Any, List, Optional, TypeVar, Union +from typing import Any, List, Optional, Union -from .types import CodeChallengeMethod, GrantType, ResponseType +from .types import CodeChallengeMethod, GrantType, ResponseType, TokenType from .utils import create_s256_code_challenge, enforce_list, enforce_str @@ -254,7 +254,7 @@ class Token: clients that the authorization server handles. """ - token_type: str = "Bearer" + token_type: TokenType = "Bearer" """ Type of token expected. 
""" @@ -278,8 +278,3 @@ def is_expired(self) -> bool: def refresh_token_expired(self) -> bool: """Checks if refresh token has expired.""" return (self.issued_at + self.refresh_token_expires_in) < time.time() - - -TToken = TypeVar("TToken", bound=Token) -TClient = TypeVar("TClient", bound=Client) -TAuthorizationCode = TypeVar("TAuthorizationCode", bound=AuthorizationCode) diff --git a/aioauth/oidc/core/grant_type.py b/aioauth/oidc/core/grant_type.py index 45be00e..ed453bf 100644 --- a/aioauth/oidc/core/grant_type.py +++ b/aioauth/oidc/core/grant_type.py @@ -9,17 +9,16 @@ """ from typing import TYPE_CHECKING -from aioauth.grant_type import ( +from ...grant_type import ( AuthorizationCodeGrantType as OAuth2AuthorizationCodeGrantType, ) -from aioauth.models import Client -from aioauth.oidc.core.responses import TokenResponse -from aioauth.oidc.core.requests import TRequest -from aioauth.storage import TStorage -from aioauth.utils import generate_token +from ...models import Client +from ...oidc.core.responses import TokenResponse +from ...requests import Request, TUser +from ...utils import generate_token -class AuthorizationCodeGrantType(OAuth2AuthorizationCodeGrantType[TRequest, TStorage]): +class AuthorizationCodeGrantType(OAuth2AuthorizationCodeGrantType[TUser]): """ The Authorization Code grant type is used by confidential and public clients to exchange an authorization code for an access token. After @@ -36,7 +35,7 @@ class AuthorizationCodeGrantType(OAuth2AuthorizationCodeGrantType[TRequest, TSto """ async def create_token_response( - self, request: TRequest, client: Client + self, request: Request, client: Client ) -> TokenResponse: """ Creates token response to reply to client. 
diff --git a/aioauth/oidc/core/requests.py b/aioauth/oidc/core/requests.py index 88bdb8c..03d1f94 100644 --- a/aioauth/oidc/core/requests.py +++ b/aioauth/oidc/core/requests.py @@ -37,9 +37,6 @@ class BaseRequest(BaseOAuth2Request[TQuery, TPost, TUser]): user: Optional[TUser] = None -TRequest = TypeVar("TRequest", bound=BaseRequest) - - @dataclass class Request(BaseRequest[Query, Post, Any]): """Object that contains a client's complete request.""" diff --git a/aioauth/requests.py b/aioauth/requests.py index 7f2f84c..b7c4f52 100644 --- a/aioauth/requests.py +++ b/aioauth/requests.py @@ -8,11 +8,17 @@ ---- """ from dataclasses import dataclass, field -from typing import Any, Generic, Optional, TypeVar +from typing import Generic, Optional, TypeVar from .collections import HTTPHeaderDict from .config import Settings -from .types import CodeChallengeMethod, GrantType, RequestMethod, ResponseMode +from .types import ( + CodeChallengeMethod, + GrantType, + RequestMethod, + ResponseMode, + TokenType, +) @dataclass @@ -50,7 +56,7 @@ class Post: refresh_token: Optional[str] = None code: Optional[str] = None token: Optional[str] = None - token_type_hint: Optional[str] = None + token_type_hint: Optional[TokenType] = None code_verifier: Optional[str] = None @@ -70,13 +76,10 @@ class BaseRequest(Generic[TQuery, TPost, TUser]): settings: Settings = field(default_factory=Settings) -TRequest = TypeVar("TRequest", bound=BaseRequest) - - @dataclass -class Request(BaseRequest[Query, Post, Any]): +class Request(Generic[TUser], BaseRequest[Query, Post, TUser]): """Object that contains a client's complete request.""" query: Query = field(default_factory=Query) post: Post = field(default_factory=Post) - user: Optional[Any] = None + user: Optional[TUser] = None diff --git a/aioauth/response_type.py b/aioauth/response_type.py index e2e975a..3db01fc 100644 --- a/aioauth/response_type.py +++ b/aioauth/response_type.py @@ -10,6 +10,9 @@ import sys from typing import Generic, Tuple +from 
.requests import Request, TUser +from .storage import BaseStorage + if sys.version_info >= (3, 8): from typing import get_args else: @@ -24,24 +27,22 @@ UnsupportedResponseTypeError, ) from .models import Client -from .requests import TRequest from .responses import ( AuthorizationCodeResponse, IdTokenResponse, NoneResponse, TokenResponse, ) -from .storage import TStorage from .types import CodeChallengeMethod -class ResponseTypeBase(Generic[TRequest, TStorage]): +class ResponseTypeBase(Generic[TUser]): """Base response type that all other exceptions inherit from.""" - def __init__(self, storage: TStorage): + def __init__(self, storage: BaseStorage[TUser]): self.storage = storage - async def validate_request(self, request: TRequest) -> Client: + async def validate_request(self, request: Request) -> Client: state = request.query.state code_challenge_methods: Tuple[CodeChallengeMethod, ...] = get_args( @@ -49,7 +50,7 @@ async def validate_request(self, request: TRequest) -> Client: ) if not request.query.client_id: - raise InvalidClientError[TRequest]( + raise InvalidClientError( request=request, description="Missing client_id parameter.", state=state ) @@ -58,54 +59,54 @@ async def validate_request(self, request: TRequest) -> Client: ) if not client: - raise InvalidClientError[TRequest]( + raise InvalidClientError( request=request, description="Invalid client_id parameter value.", state=state, ) if not request.query.redirect_uri: - raise InvalidRedirectURIError[TRequest]( + raise InvalidRedirectURIError( request=request, description="Mismatching redirect URI.", state=state ) if not client.check_redirect_uri(request.query.redirect_uri): - raise InvalidRedirectURIError[TRequest]( + raise InvalidRedirectURIError( request=request, description="Invalid redirect URI.", state=state ) if request.query.code_challenge_method: if request.query.code_challenge_method not in code_challenge_methods: - raise InvalidRequestError[TRequest]( + raise InvalidRequestError( 
request=request, description="Transform algorithm not supported.", state=state, ) if not request.query.code_challenge: - raise InvalidRequestError[TRequest]( + raise InvalidRequestError( request=request, description="Code challenge required.", state=state ) if not client.check_response_type(request.query.response_type): - raise UnsupportedResponseTypeError[TRequest](request=request, state=state) + raise UnsupportedResponseTypeError(request=request, state=state) if not client.check_scope(request.query.scope): - raise InvalidScopeError[TRequest](request=request, state=state) + raise InvalidScopeError(request=request, state=state) if not request.user: - raise InvalidClientError[TRequest]( + raise InvalidClientError( request=request, description="User is not authorized", state=state ) return client -class ResponseTypeToken(ResponseTypeBase[TRequest, TStorage]): +class ResponseTypeToken(ResponseTypeBase[TUser]): """Response type that contains a token.""" async def create_authorization_response( - self, request: TRequest, client: Client + self, request: Request, client: Client ) -> TokenResponse: token = await self.storage.create_token( request, @@ -124,11 +125,11 @@ async def create_authorization_response( ) -class ResponseTypeAuthorizationCode(ResponseTypeBase[TRequest, TStorage]): +class ResponseTypeAuthorizationCode(ResponseTypeBase[TUser]): """Response type that contains an authorization code.""" async def create_authorization_response( - self, request: TRequest, client: Client + self, request: Request, client: Client ) -> AuthorizationCodeResponse: authorization_code = await self.storage.create_authorization_code( client_id=client.client_id, @@ -147,13 +148,13 @@ async def create_authorization_response( ) -class ResponseTypeIdToken(ResponseTypeBase[TRequest, TStorage]): - async def validate_request(self, request: TRequest) -> Client: +class ResponseTypeIdToken(ResponseTypeBase[TUser]): + async def validate_request(self, request: Request) -> Client: client = await 
super().validate_request(request) # nonce is required for id_token if not request.query.nonce: - raise InvalidRequestError[TRequest]( + raise InvalidRequestError( request=request, description="Nonce required for response_type id_token.", state=request.query.state, @@ -161,7 +162,7 @@ async def validate_request(self, request: TRequest) -> Client: return client async def create_authorization_response( - self, request: TRequest, client: Client + self, request: Request, client: Client ) -> IdTokenResponse: id_token = await self.storage.get_id_token( request, @@ -175,8 +176,8 @@ async def create_authorization_response( return IdTokenResponse(id_token=id_token) -class ResponseTypeNone(ResponseTypeBase[TRequest, TStorage]): +class ResponseTypeNone(ResponseTypeBase[TUser]): async def create_authorization_response( - self, request: TRequest, client: Client + self, request: Request, client: Client ) -> NoneResponse: return NoneResponse() diff --git a/aioauth/server.py b/aioauth/server.py index b16f9af..47aee18 100644 --- a/aioauth/server.py +++ b/aioauth/server.py @@ -21,6 +21,9 @@ from http import HTTPStatus from typing import Any, Dict, Generic, List, Optional, Tuple, Type, Union +from .requests import Request, TUser +from .storage import BaseStorage + if sys.version_info >= (3, 8): from typing import get_args @@ -47,7 +50,6 @@ PasswordGrantType, RefreshTokenGrantType, ) -from .requests import TRequest from .response_type import ( ResponseTypeAuthorizationCode, ResponseTypeIdToken, @@ -59,7 +61,6 @@ TokenActiveIntrospectionResponse, TokenInactiveIntrospectionResponse, ) -from .storage import TStorage from .types import ( GrantType, RequestMethod, @@ -74,25 +75,25 @@ ) -class AuthorizationServer(Generic[TRequest, TStorage]): +class AuthorizationServer(Generic[TUser]): """Interface for initializing an OAuth 2.0 server.""" response_types: Dict[ResponseType, Any] = { - "token": ResponseTypeToken[TRequest, TStorage], - "code": ResponseTypeAuthorizationCode[TRequest, TStorage], 
- "none": ResponseTypeNone[TRequest, TStorage], - "id_token": ResponseTypeIdToken[TRequest, TStorage], + "token": ResponseTypeToken[TUser], + "code": ResponseTypeAuthorizationCode[TUser], + "none": ResponseTypeNone[TUser], + "id_token": ResponseTypeIdToken[TUser], } grant_types: Dict[GrantType, Any] = { - "authorization_code": AuthorizationCodeGrantType[TRequest, TStorage], - "client_credentials": ClientCredentialsGrantType[TRequest, TStorage], - "password": PasswordGrantType[TRequest, TStorage], - "refresh_token": RefreshTokenGrantType[TRequest, TStorage], + "authorization_code": AuthorizationCodeGrantType[TUser], + "client_credentials": ClientCredentialsGrantType[TUser], + "password": PasswordGrantType[TUser], + "refresh_token": RefreshTokenGrantType[TUser], } def __init__( self, - storage: TStorage, + storage: BaseStorage[TUser], response_types: Optional[Dict] = None, grant_types: Optional[Dict] = None, ): @@ -104,7 +105,7 @@ def __init__( if grant_types is not None: self.grant_types = grant_types - def is_secure_transport(self, request: TRequest) -> bool: + def is_secure_transport(self, request: Request) -> bool: """ Verifies the request was sent via a protected SSL tunnel. 
@@ -121,21 +122,21 @@ def is_secure_transport(self, request: TRequest) -> bool: return True return request.url.lower().startswith("https://") - def validate_request(self, request: TRequest, allowed_methods: List[RequestMethod]): + def validate_request(self, request: Request, allowed_methods: List[RequestMethod]): if not request.settings.AVAILABLE: - raise TemporarilyUnavailableError[TRequest](request=request) + raise TemporarilyUnavailableError(request=request) if not self.is_secure_transport(request): - raise InsecureTransportError[TRequest](request=request) + raise InsecureTransportError(request=request) if request.method not in allowed_methods: headers = HTTPHeaderDict( {**default_headers, "allow": ", ".join(allowed_methods)} ) - raise MethodNotAllowedError[TRequest](request=request, headers=headers) + raise MethodNotAllowedError(request=request, headers=headers) @catch_errors_and_unavailability() - async def create_token_introspection_response(self, request: TRequest) -> Response: + async def create_token_introspection_response(self, request: Request) -> Response: """ Returns a response object with introspection of the passed token. For more information see `RFC7662 section 2.1 `_. @@ -176,7 +177,7 @@ async def introspect(request: fastapi.Request) -> fastapi.Response: ) if not client: - raise InvalidClientError[TRequest](request) + raise InvalidClientError(request) token_types: Tuple[TokenType, ...] 
= get_args(TokenType) token_type: TokenType = "refresh_token" @@ -220,7 +221,7 @@ async def introspect(request: fastapi.Request) -> fastapi.Response: ) def get_client_credentials( - self, request: TRequest, secret_required: bool + self, request: Request, secret_required: bool ) -> Tuple[str, str]: client_id = request.post.client_id client_secret = request.post.client_secret @@ -235,7 +236,7 @@ def get_client_credentials( if client_id is None or secret_required: # Either we didn't find a client ID at all, or we found # a client ID but no secret and a secret is required. - raise InvalidClientError[TRequest]( + raise InvalidClientError( description="Invalid client_id parameter value.", request=request, ) from exc @@ -248,7 +249,7 @@ def get_client_credentials( return client_id, client_secret @catch_errors_and_unavailability() - async def create_token_response(self, request: TRequest) -> Response: + async def create_token_response(self, request: Request) -> Response: """Endpoint to obtain an access and/or ID token by presenting an authorization grant or refresh token. Validates a token request and creates a token response. @@ -300,17 +301,17 @@ async def token(request: fastapi.Request) -> fastapi.Response: if not request.post.grant_type: # grant_type request value is empty - raise InvalidRequestError[TRequest]( + raise InvalidRequestError( request=request, description="Request is missing grant type." 
) GrantTypeClass: Type[ Union[ - GrantTypeBase[TRequest, TStorage], - AuthorizationCodeGrantType[TRequest, TStorage], - PasswordGrantType[TRequest, TStorage], - RefreshTokenGrantType[TRequest, TStorage], - ClientCredentialsGrantType[TRequest, TStorage], + GrantTypeBase[TUser], + AuthorizationCodeGrantType[TUser], + PasswordGrantType[TUser], + RefreshTokenGrantType[TUser], + ClientCredentialsGrantType[TUser], ] ] @@ -318,7 +319,7 @@ async def token(request: fastapi.Request) -> fastapi.Response: GrantTypeClass = self.grant_types[request.post.grant_type] except KeyError as exc: # grant_type request value is invalid - raise UnsupportedGrantTypeError[TRequest](request=request) from exc + raise UnsupportedGrantTypeError(request=request) from exc grant_type = GrantTypeClass( storage=self.storage, client_id=client_id, client_secret=client_secret @@ -340,7 +341,7 @@ async def token(request: fastapi.Request) -> fastapi.Response: InvalidRedirectURIError, ) ) - async def create_authorization_response(self, request: TRequest) -> Response: + async def create_authorization_response(self, request: Request) -> Response: """ Endpoint to interact with the resource owner and obtain an authorization grant. @@ -394,7 +395,7 @@ async def authorize(request: fastapi.Request) -> fastapi.Response: state = request.query.state if not response_type_list: - raise InvalidRequestError[TRequest]( + raise InvalidRequestError( request=request, description="Missing response_type parameter.", state=state, @@ -463,7 +464,7 @@ async def authorize(request: fastapi.Request) -> fastapi.Response: ) @catch_errors_and_unavailability() - async def revoke_token(self, request: TRequest) -> Response: + async def revoke_token(self, request: Request) -> Response: """Endpoint to revoke an access token or refresh token. For more information see `RFC7009 `_. 
@@ -503,10 +504,10 @@ async def revoke(request: fastapi.Request) -> fastapi.Response: ) if not client: - raise InvalidClientError[TRequest](request) + raise InvalidClientError(request) if not request.post.token: - raise InvalidRequestError[TRequest]( + raise InvalidRequestError( request=request, description="Request is missing token." ) @@ -514,7 +515,7 @@ async def revoke(request: fastapi.Request) -> fastapi.Response: "refresh_token", "access_token", }: - raise UnsupportedTokenTypeError[TRequest](request=request) + raise UnsupportedTokenTypeError(request=request) access_token = ( request.post.token diff --git a/aioauth/storage.py b/aioauth/storage.py index d6d6c45..76260c0 100644 --- a/aioauth/storage.py +++ b/aioauth/storage.py @@ -10,22 +10,23 @@ ---- """ -from typing import Optional, Generic, TypeVar +from typing import Optional, Generic + +from .models import AuthorizationCode, Client, Token from .types import CodeChallengeMethod, ResponseType, TokenType -from .models import TToken, TClient, TAuthorizationCode -from .requests import TRequest +from .requests import Request, TUser -class BaseStorage(Generic[TToken, TClient, TAuthorizationCode, TRequest]): +class BaseStorage(Generic[TUser]): async def create_token( self, - request: TRequest, + request: Request[TUser], client_id: str, scope: str, access_token: str, refresh_token: str, - ) -> TToken: + ) -> Token: """Generates a user token and stores it in the database. Warning: @@ -44,12 +45,12 @@ async def create_token( async def get_token( self, - request: TRequest, + request: Request, client_id: str, token_type: Optional[TokenType] = "refresh_token", access_token: Optional[str] = None, refresh_token: Optional[str] = None, - ) -> Optional[TToken]: + ) -> Optional[Token]: """Gets existing token from the database. 
Note: @@ -68,7 +69,7 @@ async def get_token( async def create_authorization_code( self, - request: TRequest, + request: Request, client_id: str, scope: str, response_type: ResponseType, @@ -77,7 +78,7 @@ async def create_authorization_code( code_challenge: Optional[str], code: str, **kwargs, - ) -> TAuthorizationCode: + ) -> AuthorizationCode: """Generates an authorization token and stores it in the database. Warning: @@ -102,11 +103,12 @@ async def create_authorization_code( async def get_id_token( self, - request: TRequest, + request: Request, client_id: str, scope: str, response_type: ResponseType, redirect_uri: str, + nonce: Optional[str], **kwargs, ) -> str: """Returns an id_token. @@ -119,8 +121,11 @@ async def get_id_token( raise NotImplementedError("get_id_token must be implemented.") async def get_client( - self, request: TRequest, client_id: str, client_secret: Optional[str] = None - ) -> Optional[TClient]: + self, + request: Request[TUser], + client_id: str, + client_secret: Optional[str] = None, + ) -> Optional[Client]: """Gets existing client from the database if it exists. Warning: @@ -139,7 +144,7 @@ async def get_client( """ raise NotImplementedError("Method get_client must be implemented") - async def authenticate(self, request: TRequest) -> bool: + async def authenticate(self, request: Request[TUser]) -> bool: """Authenticates a user. Note: @@ -154,8 +159,8 @@ async def authenticate(self, request: TRequest) -> bool: raise NotImplementedError("Method authenticate must be implemented") async def get_authorization_code( - self, request: TRequest, client_id: str, code: str - ) -> Optional[TAuthorizationCode]: + self, request: Request[TUser], client_id: str, code: str + ) -> Optional[AuthorizationCode]: """Gets existing authorization code from the database if it exists. 
Warning: @@ -177,7 +182,7 @@ async def get_authorization_code( ) async def delete_authorization_code( - self, request: TRequest, client_id: str, code: str + self, request: Request[TUser], client_id: str, code: str ) -> None: """Deletes authorization code from database. @@ -195,7 +200,7 @@ async def delete_authorization_code( async def revoke_token( self, - request: TRequest, + request: Request[TUser], token_type: Optional[TokenType] = "refresh_token", access_token: Optional[str] = None, refresh_token: Optional[str] = None, @@ -213,6 +218,3 @@ async def revoke_token( raise NotImplementedError( "Method revoke_token must be implemented for RefreshTokenGrantType" ) - - -TStorage = TypeVar("TStorage", bound=BaseStorage) diff --git a/aioauth/types.py b/aioauth/types.py index 75c6fff..a5c6cf7 100644 --- a/aioauth/types.py +++ b/aioauth/types.py @@ -64,4 +64,4 @@ ] -TokenType = Literal["access_token", "refresh_token"] +TokenType = Literal["access_token", "refresh_token", "Bearer"] diff --git a/tests/classes.py b/tests/classes.py index f36cdec..caa0d7a 100644 --- a/tests/classes.py +++ b/tests/classes.py @@ -1,13 +1,15 @@ import time import sys -from typing import Any, Dict, List, Optional +from typing import Dict, List, Optional, Type from dataclasses import replace, dataclass from aioauth.config import Settings +from aioauth.grant_type import GrantTypeBase from aioauth.models import AuthorizationCode, Client, Token -from aioauth.requests import BaseRequest, Post, Query, TRequest +from aioauth.requests import Request +from aioauth.response_type import ResponseTypeBase from aioauth.server import AuthorizationServer from aioauth.storage import BaseStorage from aioauth.types import CodeChallengeMethod, GrantType, ResponseType, TokenType @@ -24,12 +26,12 @@ class User: last_name: str -@dataclass -class Request(BaseRequest[Query, Post, User]): - ... +# @dataclass +# class Request(BaseRequest[Query, Post, User]): +# ... 
-class Storage(BaseStorage[Token, Client, AuthorizationCode, Request]): +class Storage(BaseStorage[User]): def __init__( self, authorization_codes: List[AuthorizationCode], @@ -53,7 +55,10 @@ def _get_by_client_id(self, client_id: str): return client async def get_client( - self, request: Request, client_id: str, client_secret: Optional[str] = None + self, + request: Request[User], + client_id: str, + client_secret: Optional[str] = None, ) -> Optional[Client]: if client_secret is not None: return self._get_by_client_secret(client_id, client_secret) @@ -62,7 +67,7 @@ async def get_client( async def create_token( self, - request: Request, + request: Request[User], client_id: str, scope: str, access_token: str, @@ -83,7 +88,7 @@ async def create_token( async def revoke_token( self, - request: Request, + request: Request[User], token_type: Optional[TokenType] = "refresh_token", access_token: Optional[str] = None, refresh_token: Optional[str] = None, @@ -97,7 +102,7 @@ async def revoke_token( async def get_token( self, - request: Request, + request: Request[User], client_id: str, token_type: Optional[TokenType] = "refresh_token", access_token: Optional[str] = None, @@ -117,14 +122,14 @@ async def get_token( ): return token_ - async def authenticate(self, request: Request) -> bool: + async def authenticate(self, request: Request[User]) -> bool: password = request.post.password username = request.post.username return username in self.users and self.users[username] == password async def create_authorization_code( self, - request: Request, + request: Request[User], client_id: str, scope: str, response_type: str, @@ -152,7 +157,7 @@ async def create_authorization_code( return authorization_code async def get_authorization_code( - self, request: Request, client_id: str, code: str + self, request: Request[User], client_id: str, code: str ) -> Optional[AuthorizationCode]: for authorization_code in self.authorization_codes: if ( @@ -163,7 +168,7 @@ async def 
get_authorization_code( async def delete_authorization_code( self, - request: Request, + request: Request[User], client_id: str, code: str, ): @@ -177,12 +182,12 @@ async def delete_authorization_code( async def get_id_token( self, - request: Request, + request: Request[User], client_id: str, scope: str, - response_type: str, + response_type: ResponseType, redirect_uri: str, - nonce: str, + nonce: Optional[str], **kwargs, ) -> str: return "generated id token" @@ -192,12 +197,14 @@ class AuthorizationContext: def __init__( self, clients: Optional[List[Client]] = None, - grant_types: Optional[Dict[GrantType, Any]] = None, + grant_types: Optional[Dict[GrantType, Type[GrantTypeBase[User]]]] = None, initial_authorization_codes: Optional[List[AuthorizationCode]] = None, initial_tokens: Optional[List[Token]] = None, - response_types: Optional[Dict[ResponseType, Any]] = None, + response_types: Optional[ + Dict[ResponseType, Type[ResponseTypeBase[User]]] + ] = None, settings: Optional[Settings] = None, - users: Dict[str, str] = None, + users: Optional[Dict[str, str]] = None, ): self.initial_authorization_codes = initial_authorization_codes or [] self.initial_tokens = initial_tokens or [] @@ -209,7 +216,7 @@ def __init__( self.users = users or {} @cached_property - def server(self) -> AuthorizationServer[TRequest, Storage]: + def server(self) -> AuthorizationServer[User]: return AuthorizationServer( grant_types=self.grant_types, response_types=self.response_types, diff --git a/tests/conftest.py b/tests/conftest.py index e2b4cbe..e032f22 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,10 +1,10 @@ +from typing import Any, Generator import pytest -from aioauth.requests import Request from aioauth.server import AuthorizationServer from tests import factories -from tests.classes import AuthorizationContext, Storage +from tests.classes import AuthorizationContext, User @pytest.fixture @@ -13,10 +13,12 @@ def context_factory(): @pytest.fixture -def context() -> 
AuthorizationContext: +def context() -> Generator[AuthorizationContext, Any, Any]: yield factories.context_factory() @pytest.fixture -def server(context) -> AuthorizationServer[Request, Storage]: +def server( + context: AuthorizationContext, +) -> Generator[AuthorizationServer[User], Any, Any]: yield context.server diff --git a/tests/factories.py b/tests/factories.py index 318161d..a458aec 100644 --- a/tests/factories.py +++ b/tests/factories.py @@ -1,25 +1,26 @@ import time -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Type from aioauth.config import Settings from aioauth.grant_type import ( AuthorizationCodeGrantType, ClientCredentialsGrantType, + GrantTypeBase, PasswordGrantType, RefreshTokenGrantType, ) from aioauth.models import AuthorizationCode, Client, Token -from aioauth.requests import Request from aioauth.response_type import ( ResponseTypeAuthorizationCode, + ResponseTypeBase, ResponseTypeIdToken, ResponseTypeNone, ResponseTypeToken, ) -from aioauth.types import GrantType, ResponseType +from aioauth.types import CodeChallengeMethod, GrantType, ResponseType from aioauth.utils import generate_token -from tests.classes import AuthorizationContext, Storage +from tests.classes import AuthorizationContext, User def access_token_factory() -> str: @@ -38,7 +39,7 @@ def client_secret_factory() -> str: return generate_token(48) -def authorization_code_factory() -> str: +def generate_code() -> str: return generate_token(5) @@ -46,21 +47,21 @@ def auth_time_factory() -> int: return int(time.time()) -def grant_types_factory() -> Dict[str, GrantType]: +def grant_types_factory() -> Dict[GrantType, Type[GrantTypeBase[User]]]: return { - "authorization_code": AuthorizationCodeGrantType[Request, Storage], - "client_credentials": ClientCredentialsGrantType[Request, Storage], - "password": PasswordGrantType[Request, Storage], - "refresh_token": RefreshTokenGrantType[Request, Storage], + "authorization_code": 
AuthorizationCodeGrantType[User], + "client_credentials": ClientCredentialsGrantType[User], + "password": PasswordGrantType[User], + "refresh_token": RefreshTokenGrantType[User], } -def response_types_factory() -> Dict[str, ResponseType]: +def response_types_factory() -> Dict[ResponseType, Type[ResponseTypeBase[User]]]: return { - "code": ResponseTypeAuthorizationCode[Request, Storage], - "id_token": ResponseTypeIdToken[Request, Storage], - "none": ResponseTypeNone[Request, Storage], - "token": ResponseTypeToken[Request, Storage], + "code": ResponseTypeAuthorizationCode[User], + "id_token": ResponseTypeIdToken[User], + "none": ResponseTypeNone[User], + "token": ResponseTypeToken[User], } @@ -97,8 +98,8 @@ def client_factory( def authorization_code_factory( auth_time: int = auth_time_factory(), client_id: str = client_id_factory(), - code: str = authorization_code_factory(), - code_challenge_method: str = "plain", + code: str = generate_code(), + code_challenge_method: CodeChallengeMethod = "plain", expires_in: int = 10, redirect_uri: str = "http://redirect.uri", response_type: str = "code", @@ -138,10 +139,10 @@ def token_factory( def context_factory( clients: Optional[List[Client]] = None, - grant_types: Optional[Dict[str, GrantType]] = None, + grant_types: Optional[Dict[GrantType, Type[GrantTypeBase[User]]]] = None, initial_authorization_codes: Optional[List[AuthorizationCode]] = None, initial_tokens: Optional[List[Token]] = None, - response_types: Optional[Dict[str, ResponseType]] = None, + response_types: Optional[Dict[ResponseType, Type[ResponseTypeBase[User]]]] = None, settings: Optional[Settings] = None, users: Optional[Dict[str, str]] = None, ) -> AuthorizationContext: @@ -152,7 +153,7 @@ def context_factory( _initial_authorization_codes = initial_authorization_codes or [ authorization_code_factory( client_id=client.client_id, - redirect_uri=client.redirect_uris if client.redirect_uris else "", + redirect_uri=client.redirect_uris[0] if client.redirect_uris 
else "", scope=client.scope, ) for client in _clients diff --git a/tests/test_db.py b/tests/test_db.py index 5b5b9c0..3838b20 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -1,3 +1,4 @@ +from typing import Any import pytest from aioauth.models import AuthorizationCode, Client, Token @@ -8,9 +9,9 @@ @pytest.mark.asyncio -async def test_storage_class(): - db = BaseStorage() - request = Request(method="POST") +async def test_storage_class() -> None: + db = BaseStorage[Any]() + request = Request[Any](method="POST") client: Client = factories.client_factory() token: Token = factories.token_factory() authorization_code: AuthorizationCode = factories.authorization_code_factory() @@ -29,7 +30,7 @@ async def test_storage_class(): request=request, client_id=client.client_id, scope="", - response_type="", + response_type="token", redirect_uri="", code_challenge_method=None, code_challenge=None, diff --git a/tests/test_endpoint.py b/tests/test_endpoint.py index 3ba1efc..e44e8f4 100644 --- a/tests/test_endpoint.py +++ b/tests/test_endpoint.py @@ -12,11 +12,11 @@ ) from tests import factories -from tests.classes import AuthorizationContext +from tests.classes import AuthorizationContext, User @pytest.mark.asyncio -async def test_internal_server_error(): +async def test_internal_server_error() -> None: class EndpointClass: available: Optional[bool] = True @@ -43,7 +43,7 @@ async def test_invalid_token(context: AuthorizationContext): token = "invalid token" post = Post(token=token) - request = Request( + request = Request[User]( url=request_url, post=post, method="POST", @@ -99,7 +99,7 @@ async def test_valid_token(context: AuthorizationContext): server = context.server post = Post(token=token.refresh_token) - request = Request( + request = Request[User]( post=post, method="POST", headers=encode_auth_headers(client_id, client_secret), @@ -128,7 +128,7 @@ async def test_introspect_revoked_token(context: AuthorizationContext): grant_type="refresh_token", 
refresh_token=token.refresh_token, ) - request = Request( + request = Request[User]( settings=settings, url=request_url, post=post, @@ -139,7 +139,7 @@ async def test_introspect_revoked_token(context: AuthorizationContext): # Check that refreshed token was revoked post = Post(token=token.access_token, token_type_hint="access_token") - request = Request( + request = Request[User]( settings=settings, post=post, method="POST", @@ -171,7 +171,7 @@ async def test_introspect_token_with_wrong_client_secret(context: AuthorizationC server = context.server post = Post(token=token.refresh_token) - request = Request( + request = Request[User]( post=post, method="POST", headers=encode_auth_headers(client_id, f"not {client_secret}"), @@ -216,7 +216,7 @@ async def test_revoke_refresh_token(context: AuthorizationContext): server = context.server post = Post(token=token.refresh_token, token_type_hint="refresh_token") - request = Request( + request = Request[User]( post=post, method="POST", headers=encode_auth_headers(client_id, client_secret), @@ -227,7 +227,7 @@ async def test_revoke_refresh_token(context: AuthorizationContext): assert response.status_code == HTTPStatus.NO_CONTENT # Check that the token was revoked - request = Request( + request = Request[User]( settings=settings, post=post, method="POST", @@ -248,7 +248,7 @@ async def test_revoke_access_token(context: AuthorizationContext): server = context.server post = Post(token=token.access_token, token_type_hint="access_token") - request = Request( + request = Request[User]( post=post, method="POST", headers=encode_auth_headers(client_id, client_secret), @@ -259,7 +259,7 @@ async def test_revoke_access_token(context: AuthorizationContext): assert response.status_code == HTTPStatus.NO_CONTENT # Check that the token was revoked - request = Request( + request = Request[User]( settings=settings, post=post, method="POST", @@ -307,7 +307,7 @@ async def test_revoke_access_token_with_wrong_client_secret( server = context.server post 
= Post(token=token.access_token, token_type_hint="access_token") - request = Request( + request = Request[User]( post=post, method="POST", headers=encode_auth_headers(client_id, f"not {client_secret}"), diff --git a/tests/test_flow.py b/tests/test_flow.py index 39339d6..46a1422 100644 --- a/tests/test_flow.py +++ b/tests/test_flow.py @@ -13,7 +13,7 @@ ) from tests import factories -from tests.classes import AuthorizationContext +from tests.classes import AuthorizationContext, User from tests.utils import check_request_validators @@ -43,7 +43,7 @@ async def test_authorization_code_flow_plain_code_challenge(): scope=scope, ) - request = Request( + request = Request[User]( url=request_url, query=query, method="GET", @@ -465,7 +465,7 @@ async def test_client_credentials_flow_post_data(context: AuthorizationContext): scope=client.scope, ) - request = Request(url=request_url, post=post, method="POST") + request = Request[User](url=request_url, post=post, method="POST") await check_request_validators(request, server.create_token_response) @@ -484,7 +484,7 @@ async def test_client_credentials_flow_auth_header(context: AuthorizationContext scope=client.scope, ) - request = Request( + request = Request[User]( url=request_url, post=post, method="POST", diff --git a/tests/test_grant_type.py b/tests/test_grant_type.py index 8d7c9c0..a6871bb 100644 --- a/tests/test_grant_type.py +++ b/tests/test_grant_type.py @@ -5,7 +5,7 @@ from aioauth.requests import Post, Request from aioauth.utils import encode_auth_headers -from tests.classes import Storage +from tests.classes import User @pytest.mark.asyncio @@ -32,7 +32,7 @@ async def test_refresh_token_grant_type(context): headers=encode_auth_headers(client_id, client_secret), ) - grant_type = RefreshTokenGrantType[Request, Storage]( + grant_type = RefreshTokenGrantType[User]( db, client_id=client_id, client_secret=client_secret ) diff --git a/tests/test_request_validator.py b/tests/test_request_validator.py index b6aea0b..6a81f2f 
100644 --- a/tests/test_request_validator.py +++ b/tests/test_request_validator.py @@ -13,14 +13,14 @@ ) from tests import factories -from tests.classes import AuthorizationContext +from tests.classes import AuthorizationContext, User @pytest.mark.asyncio async def test_insecure_transport_error(server: AuthorizationServer): request_url = "http://localhost" - request = Request(url=request_url, method="GET") + request = Request[User](url=request_url, method="GET") response = await server.create_authorization_response(request) assert response.status_code == HTTPStatus.FOUND @@ -182,7 +182,7 @@ async def test_anonymous_user(context: AuthorizationContext): code_challenge=code_challenge, ) - request = Request(url=request_url, query=query, method="GET") + request = Request[User](url=request_url, query=query, method="GET") response = await server.create_authorization_response(request) assert response.status_code == HTTPStatus.UNAUTHORIZED assert response.content["error"] == "invalid_client" @@ -193,7 +193,7 @@ async def test_expired_authorization_code(): settings = factories.settings_factory() client = factories.client_factory(client_secret="") authorization_code = factories.authorization_code_factory( - auth_time=(time.time() - settings.AUTHORIZATION_CODE_EXPIRES_IN), + auth_time=(int(time.time()) - settings.AUTHORIZATION_CODE_EXPIRES_IN), ) context = factories.context_factory( clients=[client], @@ -225,7 +225,7 @@ async def test_expired_refresh_token(): settings = factories.settings_factory() client = factories.client_factory(client_secret="") token = factories.token_factory( - issued_at=(time.time() - (settings.TOKEN_EXPIRES_IN * 2)) + issued_at=(int(time.time()) - (settings.TOKEN_EXPIRES_IN * 2)) ) refresh_token = token.refresh_token context = factories.context_factory( From 3d1b95587a535b75a7994f3f9ea871c073302abd Mon Sep 17 00:00:00 2001 From: Ali Aliyev Date: Tue, 1 Oct 2024 00:04:52 +0400 Subject: [PATCH 02/57] Refactor grant_type.py and response_type.py for type 
hinting consistency --- aioauth/grant_type.py | 10 +- aioauth/oidc/core/grant_type.py | 2 +- aioauth/response_type.py | 8 +- aioauth/server.py | 18 ++-- aioauth/storage.py | 159 +++++++++++++++++--------------- 5 files changed, 108 insertions(+), 89 deletions(-) diff --git a/aioauth/grant_type.py b/aioauth/grant_type.py index 44eb154..578d394 100644 --- a/aioauth/grant_type.py +++ b/aioauth/grant_type.py @@ -37,7 +37,7 @@ def __init__( self.scope: Optional[str] = None async def create_token_response( - self, request: Request, client: Client + self, request: Request[TUser], client: Client ) -> TokenResponse: """Creates token response to reply to client.""" if self.scope is None: @@ -97,7 +97,7 @@ class AuthorizationCodeGrantType(GrantTypeBase[TUser]): See `RFC 6749 section 1.3.1 `_. """ - async def validate_request(self, request: Request) -> Client: + async def validate_request(self, request: Request[TUser]) -> Client: client = await super().validate_request(request) if not request.post.redirect_uri: @@ -144,7 +144,7 @@ async def validate_request(self, request: Request) -> Client: return client async def create_token_response( - self, request: Request, client: Client + self, request: Request[TUser], client: Client ) -> TokenResponse: token_response = await super().create_token_response(request, client) @@ -199,7 +199,7 @@ class RefreshTokenGrantType(GrantTypeBase[TUser]): """ async def create_token_response( - self, request: Request, client: Client + self, request: Request[TUser], client: Client ) -> TokenResponse: """Validate token request and create token response.""" old_token = await self.storage.get_token( @@ -261,7 +261,7 @@ class ClientCredentialsGrantType(GrantTypeBase[TUser]): See `RFC 6749 section 4.4 `_. 
""" - async def validate_request(self, request: Request) -> Client: + async def validate_request(self, request: Request[TUser]) -> Client: # client_credentials grant requires a client_secret if self.client_secret is None: raise InvalidClientError(request) diff --git a/aioauth/oidc/core/grant_type.py b/aioauth/oidc/core/grant_type.py index ed453bf..eaaee31 100644 --- a/aioauth/oidc/core/grant_type.py +++ b/aioauth/oidc/core/grant_type.py @@ -35,7 +35,7 @@ class AuthorizationCodeGrantType(OAuth2AuthorizationCodeGrantType[TUser]): """ async def create_token_response( - self, request: Request, client: Client + self, request: Request[TUser], client: Client ) -> TokenResponse: """ Creates token response to reply to client. diff --git a/aioauth/response_type.py b/aioauth/response_type.py index 3db01fc..f598d8b 100644 --- a/aioauth/response_type.py +++ b/aioauth/response_type.py @@ -106,7 +106,7 @@ class ResponseTypeToken(ResponseTypeBase[TUser]): """Response type that contains a token.""" async def create_authorization_response( - self, request: Request, client: Client + self, request: Request[TUser], client: Client ) -> TokenResponse: token = await self.storage.create_token( request, @@ -129,7 +129,7 @@ class ResponseTypeAuthorizationCode(ResponseTypeBase[TUser]): """Response type that contains an authorization code.""" async def create_authorization_response( - self, request: Request, client: Client + self, request: Request[TUser], client: Client ) -> AuthorizationCodeResponse: authorization_code = await self.storage.create_authorization_code( client_id=client.client_id, @@ -162,7 +162,7 @@ async def validate_request(self, request: Request) -> Client: return client async def create_authorization_response( - self, request: Request, client: Client + self, request: Request[TUser], client: Client ) -> IdTokenResponse: id_token = await self.storage.get_id_token( request, @@ -178,6 +178,6 @@ async def create_authorization_response( class 
ResponseTypeNone(ResponseTypeBase[TUser]): async def create_authorization_response( - self, request: Request, client: Client + self, request: Request[TUser], client: Client ) -> NoneResponse: return NoneResponse() diff --git a/aioauth/server.py b/aioauth/server.py index 47aee18..2e56dd4 100644 --- a/aioauth/server.py +++ b/aioauth/server.py @@ -105,7 +105,7 @@ def __init__( if grant_types is not None: self.grant_types = grant_types - def is_secure_transport(self, request: Request) -> bool: + def is_secure_transport(self, request: Request[TUser]) -> bool: """ Verifies the request was sent via a protected SSL tunnel. @@ -122,7 +122,9 @@ def is_secure_transport(self, request: Request) -> bool: return True return request.url.lower().startswith("https://") - def validate_request(self, request: Request, allowed_methods: List[RequestMethod]): + def validate_request( + self, request: Request[TUser], allowed_methods: List[RequestMethod] + ): if not request.settings.AVAILABLE: raise TemporarilyUnavailableError(request=request) @@ -136,7 +138,9 @@ def validate_request(self, request: Request, allowed_methods: List[RequestMethod raise MethodNotAllowedError(request=request, headers=headers) @catch_errors_and_unavailability() - async def create_token_introspection_response(self, request: Request) -> Response: + async def create_token_introspection_response( + self, request: Request[TUser] + ) -> Response: """ Returns a response object with introspection of the passed token. For more information see `RFC7662 section 2.1 `_. 
@@ -221,7 +225,7 @@ async def introspect(request: fastapi.Request) -> fastapi.Response: ) def get_client_credentials( - self, request: Request, secret_required: bool + self, request: Request[TUser], secret_required: bool ) -> Tuple[str, str]: client_id = request.post.client_id client_secret = request.post.client_secret @@ -249,7 +253,7 @@ def get_client_credentials( return client_id, client_secret @catch_errors_and_unavailability() - async def create_token_response(self, request: Request) -> Response: + async def create_token_response(self, request: Request[TUser]) -> Response: """Endpoint to obtain an access and/or ID token by presenting an authorization grant or refresh token. Validates a token request and creates a token response. @@ -341,7 +345,7 @@ async def token(request: fastapi.Request) -> fastapi.Response: InvalidRedirectURIError, ) ) - async def create_authorization_response(self, request: Request) -> Response: + async def create_authorization_response(self, request: Request[TUser]) -> Response: """ Endpoint to interact with the resource owner and obtain an authorization grant. @@ -464,7 +468,7 @@ async def authorize(request: fastapi.Request) -> fastapi.Response: ) @catch_errors_and_unavailability() - async def revoke_token(self, request: Request) -> Response: + async def revoke_token(self, request: Request[TUser]) -> Response: """Endpoint to revoke an access token or refresh token. For more information see `RFC7009 `_. diff --git a/aioauth/storage.py b/aioauth/storage.py index 76260c0..8487328 100644 --- a/aioauth/storage.py +++ b/aioauth/storage.py @@ -18,7 +18,7 @@ from .requests import Request, TUser -class BaseStorage(Generic[TUser]): +class TokenStorage(Generic[TUser]): async def create_token( self, request: Request[TUser], @@ -29,6 +29,13 @@ async def create_token( ) -> Token: """Generates a user token and stores it in the database. 
+ Used by: + - `ResponseTypeToken` + - `AuthorizationCodeGrantType` + - `PasswordGrantType` + - `ClientCredentialsGrantType` + - `RefreshTokenGrantType` + Warning: Generated token *must* be stored in the database. Note: @@ -45,7 +52,7 @@ async def create_token( async def get_token( self, - request: Request, + request: Request[TUser], client_id: str, token_type: Optional[TokenType] = "refresh_token", access_token: Optional[str] = None, @@ -67,9 +74,21 @@ async def get_token( """ raise NotImplementedError("Method get_token must be implemented") + async def revoke_token( + self, + request: Request[TUser], + token_type: Optional[TokenType] = "refresh_token", + access_token: Optional[str] = None, + refresh_token: Optional[str] = None, + ) -> None: + """Revokes a token from the database.""" + raise NotImplementedError + + +class AuthorizationCodeStorage(Generic[TUser]): async def create_authorization_code( self, - request: Request, + request: Request[TUser], client_id: str, scope: str, response_type: ResponseType, @@ -101,25 +120,48 @@ async def create_authorization_code( "Method create_authorization_code must be implemented" ) - async def get_id_token( - self, - request: Request, - client_id: str, - scope: str, - response_type: ResponseType, - redirect_uri: str, - nonce: Optional[str], - **kwargs, - ) -> str: - """Returns an id_token. - For more information see `OpenID Connect Core 1.0 incorporating errata set 1 section 2 `_. + async def get_authorization_code( + self, request: Request[TUser], client_id: str, code: str + ) -> Optional[AuthorizationCode]: + """Gets existing authorization code from the database if it exists. + Warning: + If authorization code does not exists this function *must* + return ``None`` to indicate to the validator that the + requested authorization code does not exist or is invalid. 
Note: - Method is used by response type :py:class:`aioauth.response_type.ResponseTypeIdToken` - and :py:class:`aioauth.oidc.core.grant_type.AuthorizationCodeGrantType`. + This method is used by the grant type + :py:class:`aioauth.grant_type.AuthorizationCodeGrantType`. + Args: + request: An :py:class:`aioauth.requests.Request`. + client_id: A user client ID. + code: An authorization code. + Returns: + An optional :py:class:`aioauth.models.AuthorizationCode`. """ - raise NotImplementedError("get_id_token must be implemented.") + raise NotImplementedError( + "Method get_authorization_code must be implemented for AuthorizationCodeGrantType" + ) + + async def delete_authorization_code( + self, request: Request[TUser], client_id: str, code: str + ) -> None: + """Deletes authorization code from database. + + Note: + This method is used by the grant type + :py:class:`aioauth.grant_type.AuthorizationCodeGrantType`. + Args: + request: An :py:class:`aioauth.requests.Request`. + client_id: A user client ID. + code: An authorization code. + """ + raise NotImplementedError( + "Method delete_authorization_code must be implemented for AuthorizationCodeGrantType" + ) + +class ClientStorage(Generic[TUser]): async def get_client( self, request: Request[TUser], @@ -144,6 +186,8 @@ async def get_client( """ raise NotImplementedError("Method get_client must be implemented") + +class Authentication(Generic[TUser]): async def authenticate(self, request: Request[TUser]) -> bool: """Authenticates a user. @@ -158,63 +202,34 @@ async def authenticate(self, request: Request[TUser]) -> bool: """ raise NotImplementedError("Method authenticate must be implemented") - async def get_authorization_code( - self, request: Request[TUser], client_id: str, code: str - ) -> Optional[AuthorizationCode]: - """Gets existing authorization code from the database if it exists. 
- - Warning: - If authorization code does not exists this function *must* - return ``None`` to indicate to the validator that the - requested authorization code does not exist or is invalid. - Note: - This method is used by the grant type - :py:class:`aioauth.grant_type.AuthorizationCodeGrantType`. - Args: - request: An :py:class:`aioauth.requests.Request`. - client_id: A user client ID. - code: An authorization code. - Returns: - An optional :py:class:`aioauth.models.AuthorizationCode`. - """ - raise NotImplementedError( - "Method get_authorization_code must be implemented for AuthorizationCodeGrantType" - ) - - async def delete_authorization_code( - self, request: Request[TUser], client_id: str, code: str - ) -> None: - """Deletes authorization code from database. - - Note: - This method is used by the grant type - :py:class:`aioauth.grant_type.AuthorizationCodeGrantType`. - Args: - request: An :py:class:`aioauth.requests.Request`. - client_id: A user client ID. - code: An authorization code. - """ - raise NotImplementedError( - "Method delete_authorization_code must be implemented for AuthorizationCodeGrantType" - ) - async def revoke_token( +class IDTokenStorage(Generic[TUser]): + async def get_id_token( self, request: Request[TUser], - token_type: Optional[TokenType] = "refresh_token", - access_token: Optional[str] = None, - refresh_token: Optional[str] = None, - ) -> None: - """Revokes a token's from the database. + client_id: str, + scope: str, + response_type: ResponseType, + redirect_uri: str, + nonce: Optional[str], + **kwargs, + ) -> str: + """Returns an id_token. + For more information see `OpenID Connect Core 1.0 incorporating errata set 1 section 2 `_. Note: - This method *must* set ``revoked`` to ``True`` for an - existing token record. This method is used by the grant type - :py:class:`aioauth.grant_types.RefreshTokenGrantType`. - Args: - request: An :py:class:`aioauth.requests.Request`. - refresh_token: The user refresh token. 
+ Method is used by response type :py:class:`aioauth.response_type.ResponseTypeIdToken` + and :py:class:`aioauth.oidc.core.grant_type.AuthorizationCodeGrantType`. """ - raise NotImplementedError( - "Method revoke_token must be implemented for RefreshTokenGrantType" - ) + raise NotImplementedError("get_id_token must be implemented.") + + +class BaseStorage( + Generic[TUser], + TokenStorage[TUser], + AuthorizationCodeStorage[TUser], + ClientStorage[TUser], + Authentication[TUser], + IDTokenStorage[TUser], +): + ... From f8d0fe9548d957c3886e85833afa15345f65dea8 Mon Sep 17 00:00:00 2001 From: Ali Aliyev Date: Mon, 7 Oct 2024 00:44:19 +0400 Subject: [PATCH 03/57] feat: removed redundant types, added debug mode --- aioauth/config.py | 2 ++ aioauth/errors.py | 1 + aioauth/oidc/core/requests.py | 26 +++----------------------- aioauth/requests.py | 14 ++++---------- aioauth/utils.py | 15 ++++++++++----- tests/oidc/core/test_flow.py | 5 +++-- tests/test_flow.py | 2 +- 7 files changed, 24 insertions(+), 41 deletions(-) diff --git a/aioauth/config.py b/aioauth/config.py index 2874af1..ab204df 100644 --- a/aioauth/config.py +++ b/aioauth/config.py @@ -38,3 +38,5 @@ class Settings: AVAILABLE: bool = True """Boolean indicating whether or not the server is available.""" + + DEBUG: bool = False diff --git a/aioauth/errors.py b/aioauth/errors.py index 90be9aa..1bec33d 100644 --- a/aioauth/errors.py +++ b/aioauth/errors.py @@ -184,6 +184,7 @@ class ServerError(OAuth2Error): """ error: ErrorType = "server_error" + status_code: HTTPStatus = HTTPStatus.BAD_REQUEST class TemporarilyUnavailableError(OAuth2Error): diff --git a/aioauth/oidc/core/requests.py b/aioauth/oidc/core/requests.py index 03d1f94..4cfe118 100644 --- a/aioauth/oidc/core/requests.py +++ b/aioauth/oidc/core/requests.py @@ -1,12 +1,10 @@ from dataclasses import dataclass, field -from typing import Any, Optional, TypeVar +from typing import Optional from aioauth.requests import ( - BaseRequest as BaseOAuth2Request, + 
BaseRequest, Query as OAuth2Query, - Post, - TPost, TUser, ) @@ -21,26 +19,8 @@ class Query(OAuth2Query): prompt: Optional[str] = None -TQuery = TypeVar("TQuery", bound=Query) - - -@dataclass -class BaseRequest(BaseOAuth2Request[TQuery, TPost, TUser]): - """ - Object that contains a client's complete request with extensions as defined - by OpenID Core. - https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest - """ - - query: TQuery - post: TPost - user: Optional[TUser] = None - - @dataclass -class Request(BaseRequest[Query, Post, Any]): +class Request(BaseRequest[TUser]): """Object that contains a client's complete request.""" query: Query = field(default_factory=Query) - post: Post = field(default_factory=Post) - user: Optional[Any] = None diff --git a/aioauth/requests.py b/aioauth/requests.py index b7c4f52..1bd76b5 100644 --- a/aioauth/requests.py +++ b/aioauth/requests.py @@ -60,16 +60,14 @@ class Post: code_verifier: Optional[str] = None -TQuery = TypeVar("TQuery", bound=Query) -TPost = TypeVar("TPost", bound=Post) TUser = TypeVar("TUser") @dataclass -class BaseRequest(Generic[TQuery, TPost, TUser]): +class BaseRequest(Generic[TUser]): method: RequestMethod - query: TQuery - post: TPost + query: Query = field(default_factory=Query) + post: Post = field(default_factory=Post) headers: HTTPHeaderDict = field(default_factory=HTTPHeaderDict) url: str = "" user: Optional[TUser] = None @@ -77,9 +75,5 @@ class BaseRequest(Generic[TQuery, TPost, TUser]): @dataclass -class Request(Generic[TUser], BaseRequest[Query, Post, TUser]): +class Request(Generic[TUser], BaseRequest[TUser]): """Object that contains a client's complete request.""" - - query: Query = field(default_factory=Query) - post: Post = field(default_factory=Post) - user: Optional[TUser] = None diff --git a/aioauth/utils.py b/aioauth/utils.py index 1257c08..cd9cf0d 100644 --- a/aioauth/utils.py +++ b/aioauth/utils.py @@ -34,6 +34,8 @@ ) from urllib.parse import quote, urlencode, urlparse, urlunsplit 
+from aioauth.requests import Request + from .collections import HTTPHeaderDict from .errors import ( OAuth2Error, @@ -153,8 +155,8 @@ def build_uri( parsed_url.scheme, parsed_url.netloc, parsed_url.path, - urlencode(query_params, quote_via=quote), # type: ignore - urlencode(fragment, quote_via=quote), # type: ignore + urlencode(query_params, quote_via=quote), + urlencode(fragment, quote_via=quote), ) ) return uri @@ -239,7 +241,7 @@ def catch_errors_and_unavailability( def decorator(f) -> Callable[..., Coroutine[Any, Any, Response]]: @functools.wraps(f) - async def wrapper(self, request, *args, **kwargs) -> Response: + async def wrapper(self, request: Request, *args, **kwargs) -> Response: error: Union[TemporarilyUnavailableError, ServerError] try: @@ -268,8 +270,11 @@ async def wrapper(self, request, *args, **kwargs) -> Response: status_code=HTTPStatus.FOUND, headers=HTTPHeaderDict({"location": location}), ) - except Exception: - error = ServerError(request=request) + except Exception as exc: + error = ServerError( + request=request, + description=str(exc) if request.settings.DEBUG else "", + ) log.exception("Exception caught while processing request.") content = ErrorResponse( error=error.error, description=error.description diff --git a/tests/oidc/core/test_flow.py b/tests/oidc/core/test_flow.py index ec94aa7..57047be 100644 --- a/tests/oidc/core/test_flow.py +++ b/tests/oidc/core/test_flow.py @@ -8,6 +8,7 @@ generate_token, ) +from tests.classes import User from tests.utils import check_request_validators @@ -21,7 +22,7 @@ ) async def test_authorization_endpoint_allows_prompt_query_param( expected_status_code: HTTPStatus, - user: Optional[str], + user: Optional[User], context_factory, ): context = context_factory(users={user, "password"}) @@ -39,7 +40,7 @@ async def test_authorization_endpoint_allows_prompt_query_param( state=generate_token(10), ) - request = Request( + request = Request[User]( url=request_url, query=query, method="GET", diff --git 
a/tests/test_flow.py b/tests/test_flow.py index 46a1422..51465b3 100644 --- a/tests/test_flow.py +++ b/tests/test_flow.py @@ -47,7 +47,7 @@ async def test_authorization_code_flow_plain_code_challenge(): url=request_url, query=query, method="GET", - user=username, + user=User(first_name="A", last_name="B"), ) await check_request_validators(request, server.create_authorization_response) From 10dd03a7baa600a7bd9ed54aeb8196e3e3108891 Mon Sep 17 00:00:00 2001 From: Ali Aliyev Date: Thu, 24 Oct 2024 11:03:32 +0400 Subject: [PATCH 04/57] fix: typing issues and mypy --- .pre-commit-config.yaml | 9 ++++-- aioauth/grant_type.py | 34 +++++++++++++---------- aioauth/models.py | 16 +++++------ aioauth/oidc/core/grant_type.py | 7 +++-- aioauth/oidc/core/requests.py | 6 ++-- aioauth/requests.py | 12 ++++---- aioauth/response_type.py | 27 +++++++++--------- aioauth/server.py | 49 +++++++++++++++++---------------- aioauth/storage.py | 45 +++++++++++++++--------------- aioauth/types.py | 8 ++++++ setup.cfg | 2 +- tests/classes.py | 9 ++---- tests/oidc/core/test_flow.py | 5 +++- 13 files changed, 124 insertions(+), 105 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a2a7d07..76dfa83 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -19,10 +19,15 @@ repos: - id: check-merge-conflict - id: detect-private-key - - repo: https://github.com/pre-commit/mirrors-mypy - rev: v0.950 + - repo: local hooks: - id: mypy + name: mypy + entry: "mypy" + language: system + types: [python] + require_serial: true + verbose: false exclude: ^(docs/|setup\.py) - repo: https://github.com/pycqa/flake8 diff --git a/aioauth/grant_type.py b/aioauth/grant_type.py index 578d394..c628a44 100644 --- a/aioauth/grant_type.py +++ b/aioauth/grant_type.py @@ -9,7 +9,8 @@ """ from typing import Generic, Optional -from .requests import Request, TUser +from .requests import Request +from .types import UserType from .storage import BaseStorage from .errors import ( 
InvalidClientError, @@ -25,11 +26,14 @@ from .utils import enforce_list, enforce_str, generate_token -class GrantTypeBase(Generic[TUser]): +class GrantTypeBase(Generic[UserType]): """Base grant type that all other grant types inherit from.""" def __init__( - self, storage: BaseStorage[TUser], client_id: str, client_secret: Optional[str] + self, + storage: BaseStorage[UserType], + client_id: str, + client_secret: Optional[str], ): self.storage = storage self.client_id = client_id @@ -37,7 +41,7 @@ def __init__( self.scope: Optional[str] = None async def create_token_response( - self, request: Request[TUser], client: Client + self, request: Request[UserType], client: Client[UserType] ) -> TokenResponse: """Creates token response to reply to client.""" if self.scope is None: @@ -60,7 +64,7 @@ async def create_token_response( token_type=token.token_type, ) - async def validate_request(self, request: Request) -> Client: + async def validate_request(self, request: Request[UserType]) -> Client[UserType]: """Validates the client request to ensure it is valid.""" client = await self.storage.get_client( request, client_id=self.client_id, client_secret=self.client_secret @@ -81,7 +85,7 @@ async def validate_request(self, request: Request) -> Client: return client -class AuthorizationCodeGrantType(GrantTypeBase[TUser]): +class AuthorizationCodeGrantType(GrantTypeBase[UserType]): """ The Authorization Code grant type is used by confidential and public clients to exchange an authorization code for an access token. After @@ -97,7 +101,7 @@ class AuthorizationCodeGrantType(GrantTypeBase[TUser]): See `RFC 6749 section 1.3.1 `_. 
""" - async def validate_request(self, request: Request[TUser]) -> Client: + async def validate_request(self, request: Request[UserType]) -> Client[UserType]: client = await super().validate_request(request) if not request.post.redirect_uri: @@ -144,7 +148,7 @@ async def validate_request(self, request: Request[TUser]) -> Client: return client async def create_token_response( - self, request: Request[TUser], client: Client + self, request: Request[UserType], client: Client[UserType] ) -> TokenResponse: token_response = await super().create_token_response(request, client) @@ -160,7 +164,7 @@ async def create_token_response( return token_response -class PasswordGrantType(GrantTypeBase[TUser]): +class PasswordGrantType(GrantTypeBase[UserType]): """ The Password grant type is a way to exchange a user's credentials for an access token. Because the client application has to collect @@ -171,7 +175,7 @@ class PasswordGrantType(GrantTypeBase[TUser]): disallows the password grant entirely. """ - async def validate_request(self, request: Request) -> Client: + async def validate_request(self, request: Request[UserType]) -> Client[UserType]: client = await super().validate_request(request) if not request.post.username or not request.post.password: @@ -189,7 +193,7 @@ async def validate_request(self, request: Request) -> Client: return client -class RefreshTokenGrantType(GrantTypeBase[TUser]): +class RefreshTokenGrantType(GrantTypeBase[UserType]): """ The Refresh Token grant type is used by clients to exchange a refresh token for an access token when the access token has expired. 
@@ -199,7 +203,7 @@ class RefreshTokenGrantType(GrantTypeBase[TUser]): """ async def create_token_response( - self, request: Request[TUser], client: Client + self, request: Request[UserType], client: Client[UserType] ) -> TokenResponse: """Validate token request and create token response.""" old_token = await self.storage.get_token( @@ -241,7 +245,7 @@ async def create_token_response( token_type=token.token_type, ) - async def validate_request(self, request: Request) -> Client: + async def validate_request(self, request: Request[UserType]) -> Client[UserType]: client = await super().validate_request(request) if not request.post.refresh_token: @@ -252,7 +256,7 @@ async def validate_request(self, request: Request) -> Client: return client -class ClientCredentialsGrantType(GrantTypeBase[TUser]): +class ClientCredentialsGrantType(GrantTypeBase[UserType]): """ The Client Credentials grant type is used by clients to obtain an access token outside of the context of a user. This is typically @@ -261,7 +265,7 @@ class ClientCredentialsGrantType(GrantTypeBase[TUser]): See `RFC 6749 section 4.4 `_. 
""" - async def validate_request(self, request: Request[TUser]) -> Client: + async def validate_request(self, request: Request[UserType]) -> Client[UserType]: # client_credentials grant requires a client_secret if self.client_secret is None: raise InvalidClientError(request) diff --git a/aioauth/models.py b/aioauth/models.py index a6a496c..6b359e3 100644 --- a/aioauth/models.py +++ b/aioauth/models.py @@ -9,14 +9,14 @@ """ from dataclasses import dataclass import time -from typing import Any, List, Optional, Union +from typing import Generic, List, Optional, Union -from .types import CodeChallengeMethod, GrantType, ResponseType, TokenType +from .types import CodeChallengeMethod, GrantType, ResponseType, TokenType, UserType from .utils import create_s256_code_challenge, enforce_list, enforce_str @dataclass -class Client: +class Client(Generic[UserType]): """OAuth2.0 client model object.""" client_id: str @@ -62,7 +62,7 @@ class Client: scopes granted. """ - user: Optional[Any] = None + user: Optional[UserType] = None """ The user who is the creator of the Client. This optional attribute can be useful if you are creating a server that @@ -112,7 +112,7 @@ def check_scope(self, scope: str) -> bool: @dataclass -class AuthorizationCode: +class AuthorizationCode(Generic[UserType]): code: str """ Authorization code that the client previously received from the @@ -184,7 +184,7 @@ class AuthorizationCode: Random piece of data. """ - user: Optional[Any] = None + user: Optional[UserType] = None """ The user who owns the AuthorizationCode. """ @@ -211,7 +211,7 @@ def is_expired(self) -> bool: @dataclass -class Token: +class Token(Generic[UserType]): access_token: str """ Token that clients use to make API requests on behalf of the @@ -264,7 +264,7 @@ class Token: Flag that indicates whether or not the token has been revoked. """ - user: Optional[Any] = None + user: Optional[UserType] = None """ The user who owns the Token. 
""" diff --git a/aioauth/oidc/core/grant_type.py b/aioauth/oidc/core/grant_type.py index eaaee31..383c68d 100644 --- a/aioauth/oidc/core/grant_type.py +++ b/aioauth/oidc/core/grant_type.py @@ -14,11 +14,12 @@ ) from ...models import Client from ...oidc.core.responses import TokenResponse -from ...requests import Request, TUser +from ...requests import Request +from ...types import UserType from ...utils import generate_token -class AuthorizationCodeGrantType(OAuth2AuthorizationCodeGrantType[TUser]): +class AuthorizationCodeGrantType(OAuth2AuthorizationCodeGrantType[UserType]): """ The Authorization Code grant type is used by confidential and public clients to exchange an authorization code for an access token. After @@ -35,7 +36,7 @@ class AuthorizationCodeGrantType(OAuth2AuthorizationCodeGrantType[TUser]): """ async def create_token_response( - self, request: Request[TUser], client: Client + self, request: Request[UserType], client: Client[UserType] ) -> TokenResponse: """ Creates token response to reply to client. 
diff --git a/aioauth/oidc/core/requests.py b/aioauth/oidc/core/requests.py index 4cfe118..ed9b4be 100644 --- a/aioauth/oidc/core/requests.py +++ b/aioauth/oidc/core/requests.py @@ -2,11 +2,11 @@ from typing import Optional -from aioauth.requests import ( +from ...requests import ( BaseRequest, Query as OAuth2Query, - TUser, ) +from ...types import UserType @dataclass @@ -20,7 +20,7 @@ class Query(OAuth2Query): @dataclass -class Request(BaseRequest[TUser]): +class Request(BaseRequest[UserType]): """Object that contains a client's complete request.""" query: Query = field(default_factory=Query) diff --git a/aioauth/requests.py b/aioauth/requests.py index 1bd76b5..2db1de7 100644 --- a/aioauth/requests.py +++ b/aioauth/requests.py @@ -8,7 +8,7 @@ ---- """ from dataclasses import dataclass, field -from typing import Generic, Optional, TypeVar +from typing import Generic, Optional from .collections import HTTPHeaderDict from .config import Settings @@ -18,6 +18,7 @@ RequestMethod, ResponseMode, TokenType, + UserType, ) @@ -60,20 +61,17 @@ class Post: code_verifier: Optional[str] = None -TUser = TypeVar("TUser") - - @dataclass -class BaseRequest(Generic[TUser]): +class BaseRequest(Generic[UserType]): method: RequestMethod query: Query = field(default_factory=Query) post: Post = field(default_factory=Post) headers: HTTPHeaderDict = field(default_factory=HTTPHeaderDict) url: str = "" - user: Optional[TUser] = None + user: Optional[UserType] = None settings: Settings = field(default_factory=Settings) @dataclass -class Request(Generic[TUser], BaseRequest[TUser]): +class Request(Generic[UserType], BaseRequest[UserType]): """Object that contains a client's complete request.""" diff --git a/aioauth/response_type.py b/aioauth/response_type.py index f598d8b..023dee4 100644 --- a/aioauth/response_type.py +++ b/aioauth/response_type.py @@ -10,7 +10,8 @@ import sys from typing import Generic, Tuple -from .requests import Request, TUser +from .requests import Request +from .types 
import UserType from .storage import BaseStorage if sys.version_info >= (3, 8): @@ -36,13 +37,13 @@ from .types import CodeChallengeMethod -class ResponseTypeBase(Generic[TUser]): +class ResponseTypeBase(Generic[UserType]): """Base response type that all other exceptions inherit from.""" - def __init__(self, storage: BaseStorage[TUser]): + def __init__(self, storage: BaseStorage[UserType]): self.storage = storage - async def validate_request(self, request: Request) -> Client: + async def validate_request(self, request: Request[UserType]) -> Client[UserType]: state = request.query.state code_challenge_methods: Tuple[CodeChallengeMethod, ...] = get_args( @@ -102,11 +103,11 @@ async def validate_request(self, request: Request) -> Client: return client -class ResponseTypeToken(ResponseTypeBase[TUser]): +class ResponseTypeToken(ResponseTypeBase[UserType]): """Response type that contains a token.""" async def create_authorization_response( - self, request: Request[TUser], client: Client + self, request: Request[UserType], client: Client[UserType] ) -> TokenResponse: token = await self.storage.create_token( request, @@ -125,11 +126,11 @@ async def create_authorization_response( ) -class ResponseTypeAuthorizationCode(ResponseTypeBase[TUser]): +class ResponseTypeAuthorizationCode(ResponseTypeBase[UserType]): """Response type that contains an authorization code.""" async def create_authorization_response( - self, request: Request[TUser], client: Client + self, request: Request[UserType], client: Client[UserType] ) -> AuthorizationCodeResponse: authorization_code = await self.storage.create_authorization_code( client_id=client.client_id, @@ -148,8 +149,8 @@ async def create_authorization_response( ) -class ResponseTypeIdToken(ResponseTypeBase[TUser]): - async def validate_request(self, request: Request) -> Client: +class ResponseTypeIdToken(ResponseTypeBase[UserType]): + async def validate_request(self, request: Request[UserType]) -> Client[UserType]: client = await 
super().validate_request(request) # nonce is required for id_token @@ -162,7 +163,7 @@ async def validate_request(self, request: Request) -> Client: return client async def create_authorization_response( - self, request: Request[TUser], client: Client + self, request: Request[UserType], client: Client[UserType] ) -> IdTokenResponse: id_token = await self.storage.get_id_token( request, @@ -176,8 +177,8 @@ async def create_authorization_response( return IdTokenResponse(id_token=id_token) -class ResponseTypeNone(ResponseTypeBase[TUser]): +class ResponseTypeNone(ResponseTypeBase[UserType]): async def create_authorization_response( - self, request: Request[TUser], client: Client + self, request: Request[UserType], client: Client[UserType] ) -> NoneResponse: return NoneResponse() diff --git a/aioauth/server.py b/aioauth/server.py index 2e56dd4..5f27288 100644 --- a/aioauth/server.py +++ b/aioauth/server.py @@ -21,7 +21,8 @@ from http import HTTPStatus from typing import Any, Dict, Generic, List, Optional, Tuple, Type, Union -from .requests import Request, TUser +from .requests import Request +from .types import UserType from .storage import BaseStorage @@ -75,25 +76,25 @@ ) -class AuthorizationServer(Generic[TUser]): +class AuthorizationServer(Generic[UserType]): """Interface for initializing an OAuth 2.0 server.""" response_types: Dict[ResponseType, Any] = { - "token": ResponseTypeToken[TUser], - "code": ResponseTypeAuthorizationCode[TUser], - "none": ResponseTypeNone[TUser], - "id_token": ResponseTypeIdToken[TUser], + "token": ResponseTypeToken[UserType], + "code": ResponseTypeAuthorizationCode[UserType], + "none": ResponseTypeNone[UserType], + "id_token": ResponseTypeIdToken[UserType], } grant_types: Dict[GrantType, Any] = { - "authorization_code": AuthorizationCodeGrantType[TUser], - "client_credentials": ClientCredentialsGrantType[TUser], - "password": PasswordGrantType[TUser], - "refresh_token": RefreshTokenGrantType[TUser], + "authorization_code": 
AuthorizationCodeGrantType[UserType], + "client_credentials": ClientCredentialsGrantType[UserType], + "password": PasswordGrantType[UserType], + "refresh_token": RefreshTokenGrantType[UserType], } def __init__( self, - storage: BaseStorage[TUser], + storage: BaseStorage[UserType], response_types: Optional[Dict] = None, grant_types: Optional[Dict] = None, ): @@ -105,7 +106,7 @@ def __init__( if grant_types is not None: self.grant_types = grant_types - def is_secure_transport(self, request: Request[TUser]) -> bool: + def is_secure_transport(self, request: Request[UserType]) -> bool: """ Verifies the request was sent via a protected SSL tunnel. @@ -123,7 +124,7 @@ def is_secure_transport(self, request: Request[TUser]) -> bool: return request.url.lower().startswith("https://") def validate_request( - self, request: Request[TUser], allowed_methods: List[RequestMethod] + self, request: Request[UserType], allowed_methods: List[RequestMethod] ): if not request.settings.AVAILABLE: raise TemporarilyUnavailableError(request=request) @@ -139,7 +140,7 @@ def validate_request( @catch_errors_and_unavailability() async def create_token_introspection_response( - self, request: Request[TUser] + self, request: Request[UserType] ) -> Response: """ Returns a response object with introspection of the passed token. 
@@ -225,7 +226,7 @@ async def introspect(request: fastapi.Request) -> fastapi.Response: ) def get_client_credentials( - self, request: Request[TUser], secret_required: bool + self, request: Request[UserType], secret_required: bool ) -> Tuple[str, str]: client_id = request.post.client_id client_secret = request.post.client_secret @@ -253,7 +254,7 @@ def get_client_credentials( return client_id, client_secret @catch_errors_and_unavailability() - async def create_token_response(self, request: Request[TUser]) -> Response: + async def create_token_response(self, request: Request[UserType]) -> Response: """Endpoint to obtain an access and/or ID token by presenting an authorization grant or refresh token. Validates a token request and creates a token response. @@ -311,11 +312,11 @@ async def token(request: fastapi.Request) -> fastapi.Response: GrantTypeClass: Type[ Union[ - GrantTypeBase[TUser], - AuthorizationCodeGrantType[TUser], - PasswordGrantType[TUser], - RefreshTokenGrantType[TUser], - ClientCredentialsGrantType[TUser], + GrantTypeBase[UserType], + AuthorizationCodeGrantType[UserType], + PasswordGrantType[UserType], + RefreshTokenGrantType[UserType], + ClientCredentialsGrantType[UserType], ] ] @@ -345,7 +346,9 @@ async def token(request: fastapi.Request) -> fastapi.Response: InvalidRedirectURIError, ) ) - async def create_authorization_response(self, request: Request[TUser]) -> Response: + async def create_authorization_response( + self, request: Request[UserType] + ) -> Response: """ Endpoint to interact with the resource owner and obtain an authorization grant. @@ -468,7 +471,7 @@ async def authorize(request: fastapi.Request) -> fastapi.Response: ) @catch_errors_and_unavailability() - async def revoke_token(self, request: Request[TUser]) -> Response: + async def revoke_token(self, request: Request[UserType]) -> Response: """Endpoint to revoke an access token or refresh token. For more information see `RFC7009 `_. 
diff --git a/aioauth/storage.py b/aioauth/storage.py index 8487328..40f2128 100644 --- a/aioauth/storage.py +++ b/aioauth/storage.py @@ -15,13 +15,14 @@ from .models import AuthorizationCode, Client, Token from .types import CodeChallengeMethod, ResponseType, TokenType -from .requests import Request, TUser +from .requests import Request +from .types import UserType -class TokenStorage(Generic[TUser]): +class TokenStorage(Generic[UserType]): async def create_token( self, - request: Request[TUser], + request: Request[UserType], client_id: str, scope: str, access_token: str, @@ -52,7 +53,7 @@ async def create_token( async def get_token( self, - request: Request[TUser], + request: Request[UserType], client_id: str, token_type: Optional[TokenType] = "refresh_token", access_token: Optional[str] = None, @@ -76,7 +77,7 @@ async def get_token( async def revoke_token( self, - request: Request[TUser], + request: Request[UserType], token_type: Optional[TokenType] = "refresh_token", access_token: Optional[str] = None, refresh_token: Optional[str] = None, @@ -85,10 +86,10 @@ async def revoke_token( raise NotImplementedError -class AuthorizationCodeStorage(Generic[TUser]): +class AuthorizationCodeStorage(Generic[UserType]): async def create_authorization_code( self, - request: Request[TUser], + request: Request[UserType], client_id: str, scope: str, response_type: ResponseType, @@ -121,7 +122,7 @@ async def create_authorization_code( ) async def get_authorization_code( - self, request: Request[TUser], client_id: str, code: str + self, request: Request[UserType], client_id: str, code: str ) -> Optional[AuthorizationCode]: """Gets existing authorization code from the database if it exists. @@ -144,7 +145,7 @@ async def get_authorization_code( ) async def delete_authorization_code( - self, request: Request[TUser], client_id: str, code: str + self, request: Request[UserType], client_id: str, code: str ) -> None: """Deletes authorization code from database. 
@@ -161,13 +162,13 @@ async def delete_authorization_code( ) -class ClientStorage(Generic[TUser]): +class ClientStorage(Generic[UserType]): async def get_client( self, - request: Request[TUser], + request: Request[UserType], client_id: str, client_secret: Optional[str] = None, - ) -> Optional[Client]: + ) -> Optional[Client[UserType]]: """Gets existing client from the database if it exists. Warning: @@ -187,8 +188,8 @@ async def get_client( raise NotImplementedError("Method get_client must be implemented") -class Authentication(Generic[TUser]): - async def authenticate(self, request: Request[TUser]) -> bool: +class Authentication(Generic[UserType]): + async def authenticate(self, request: Request[UserType]) -> bool: """Authenticates a user. Note: @@ -203,10 +204,10 @@ async def authenticate(self, request: Request[TUser]) -> bool: raise NotImplementedError("Method authenticate must be implemented") -class IDTokenStorage(Generic[TUser]): +class IDTokenStorage(Generic[UserType]): async def get_id_token( self, - request: Request[TUser], + request: Request[UserType], client_id: str, scope: str, response_type: ResponseType, @@ -225,11 +226,11 @@ async def get_id_token( class BaseStorage( - Generic[TUser], - TokenStorage[TUser], - AuthorizationCodeStorage[TUser], - ClientStorage[TUser], - Authentication[TUser], - IDTokenStorage[TUser], + Generic[UserType], + TokenStorage[UserType], + AuthorizationCodeStorage[UserType], + ClientStorage[UserType], + Authentication[UserType], + IDTokenStorage[UserType], ): ... 
diff --git a/aioauth/types.py b/aioauth/types.py index a5c6cf7..661e5ec 100644 --- a/aioauth/types.py +++ b/aioauth/types.py @@ -8,6 +8,12 @@ ---- """ import sys +from typing import Any + +if sys.version_info >= (3, 13): + from typing import TypeVar +else: + from typing_extensions import TypeVar if sys.version_info >= (3, 8): from typing import Literal @@ -65,3 +71,5 @@ TokenType = Literal["access_token", "refresh_token", "Bearer"] + +UserType = TypeVar("UserType", default=Any) diff --git a/setup.cfg b/setup.cfg index 6b40ce4..8ec772f 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,5 +1,5 @@ [tool:pytest] -addopts = -s --strict-markers -vv --cache-clear --maxfail=1 --cov=aioauth --cov-report=term --cov-report=html --cov-branch --no-cov-on-fail +addopts = -s --strict-markers -vv --cache-clear --maxfail=1 [coverage:run] branch = True diff --git a/tests/classes.py b/tests/classes.py index caa0d7a..6ddc65a 100644 --- a/tests/classes.py +++ b/tests/classes.py @@ -20,17 +20,12 @@ from backports.cached_property import cached_property -@dataclass +@dataclass(frozen=True) class User: first_name: str last_name: str -# @dataclass -# class Request(BaseRequest[Query, Post, User]): -# ... 
- - class Storage(BaseStorage[User]): def __init__( self, @@ -73,7 +68,7 @@ async def create_token( access_token: str, refresh_token: str, ): - token = Token( + token: Token[User] = Token( client_id=client_id, expires_in=request.settings.TOKEN_EXPIRES_IN, refresh_token_expires_in=request.settings.REFRESH_TOKEN_EXPIRES_IN, diff --git a/tests/oidc/core/test_flow.py b/tests/oidc/core/test_flow.py index 57047be..dedf556 100644 --- a/tests/oidc/core/test_flow.py +++ b/tests/oidc/core/test_flow.py @@ -25,7 +25,10 @@ async def test_authorization_endpoint_allows_prompt_query_param( user: Optional[User], context_factory, ): - context = context_factory(users={user, "password"}) + if user is None: + context = context_factory() + else: + context = context_factory(users={user: "password"}) server = context.server client = context.clients[0] client_id = client.client_id From e705785de4724ac0ffe4ee53a79ea14050ba1b55 Mon Sep 17 00:00:00 2001 From: Ali Aliyev Date: Thu, 24 Oct 2024 11:09:44 +0400 Subject: [PATCH 05/57] fix: added mypy --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index 088a7eb..acbe541 100644 --- a/setup.py +++ b/setup.py @@ -49,6 +49,7 @@ "testfixtures==6.18.3", "twine==3.7.1", "wheel", + "mypy==1.13.0", ] require_docs = [ From 02dcd3eaa315734acd6f4a21034be565c508ac77 Mon Sep 17 00:00:00 2001 From: Ali Aliyev Date: Thu, 24 Oct 2024 11:24:56 +0400 Subject: [PATCH 06/57] fix: upd python version >= 3.9 --- .github/workflows/cd.yml | 2 +- .github/workflows/ci.yml | 2 +- setup.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/cd.yml b/.github/workflows/cd.yml index 3bc145d..ff49fc2 100644 --- a/.github/workflows/cd.yml +++ b/.github/workflows/cd.yml @@ -17,7 +17,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v2 with: - python-version: '3.8' + python-version: '3.11' - name: Install dependencies run: | make dev-install diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml 
index 297ff7e..6c8918f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -15,7 +15,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"] + python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] steps: - uses: actions/checkout@v2 diff --git a/setup.py b/setup.py index acbe541..6d390a5 100644 --- a/setup.py +++ b/setup.py @@ -71,7 +71,7 @@ url=about["__url__"], license=about["__license__"], package_data={"aioauth": ["py.typed"]}, - python_requires=">=3.7.0", + python_requires=">=3.9.0", classifiers=classifiers, install_requires=["typing_extensions"], extras_require={ From 3ea1ecaecbee393a6078afa7842907817c8cda83 Mon Sep 17 00:00:00 2001 From: Ali Aliyev Date: Thu, 24 Oct 2024 11:26:18 +0400 Subject: [PATCH 07/57] fix: coverage report --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index 8ec772f..6b40ce4 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,5 +1,5 @@ [tool:pytest] -addopts = -s --strict-markers -vv --cache-clear --maxfail=1 +addopts = -s --strict-markers -vv --cache-clear --maxfail=1 --cov=aioauth --cov-report=term --cov-report=html --cov-branch --no-cov-on-fail [coverage:run] branch = True From 9c07eed50870ab22398ea71787eccbabdfb707bc Mon Sep 17 00:00:00 2001 From: Ali Aliyev Date: Thu, 24 Oct 2024 23:54:45 +0400 Subject: [PATCH 08/57] fix: formatting --- .pre-commit-config.yaml | 6 +++--- aioauth/grant_type.py | 5 +++-- aioauth/models.py | 1 + aioauth/oidc/core/grant_type.py | 1 + aioauth/requests.py | 1 + aioauth/response_type.py | 1 + aioauth/responses.py | 1 + aioauth/server.py | 1 + aioauth/storage.py | 13 ++++++------- aioauth/types.py | 1 + setup.cfg | 5 ++++- setup.py | 2 +- tests/classes.py | 14 ++++++++++---- tests/test_db.py | 2 +- tests/test_flow.py | 2 +- 15 files changed, 36 insertions(+), 20 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 76dfa83..9095212 100644 --- 
a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -5,12 +5,12 @@ fail_fast: true repos: - repo: https://github.com/psf/black - rev: 22.3.0 + rev: 24.10.0 hooks: - id: black - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.2.0 + rev: v5.0.0 hooks: - id: trailing-whitespace exclude: ^(setup\.cfg) @@ -31,6 +31,6 @@ repos: exclude: ^(docs/|setup\.py) - repo: https://github.com/pycqa/flake8 - rev: 4.0.1 + rev: 7.1.1 hooks: - id: flake8 diff --git a/aioauth/grant_type.py b/aioauth/grant_type.py index c628a44..fd07531 100644 --- a/aioauth/grant_type.py +++ b/aioauth/grant_type.py @@ -7,6 +7,7 @@ ---- """ + from typing import Generic, Optional from .requests import Request @@ -183,9 +184,9 @@ async def validate_request(self, request: Request[UserType]) -> Client[UserType] request=request, description="Invalid credentials given." ) - user = await self.storage.authenticate(request) + user = await self.storage.get_user(request) - if not user: + if user is None: raise InvalidRequestError( request=request, description="Invalid credentials given." 
) diff --git a/aioauth/models.py b/aioauth/models.py index 6b359e3..8b19273 100644 --- a/aioauth/models.py +++ b/aioauth/models.py @@ -7,6 +7,7 @@ ---- """ + from dataclasses import dataclass import time from typing import Generic, List, Optional, Union diff --git a/aioauth/oidc/core/grant_type.py b/aioauth/oidc/core/grant_type.py index 383c68d..2c4a5a8 100644 --- a/aioauth/oidc/core/grant_type.py +++ b/aioauth/oidc/core/grant_type.py @@ -7,6 +7,7 @@ ---- """ + from typing import TYPE_CHECKING from ...grant_type import ( diff --git a/aioauth/requests.py b/aioauth/requests.py index 2db1de7..65c3bc7 100644 --- a/aioauth/requests.py +++ b/aioauth/requests.py @@ -7,6 +7,7 @@ ---- """ + from dataclasses import dataclass, field from typing import Generic, Optional diff --git a/aioauth/response_type.py b/aioauth/response_type.py index 023dee4..23c2987 100644 --- a/aioauth/response_type.py +++ b/aioauth/response_type.py @@ -7,6 +7,7 @@ ---- """ + import sys from typing import Generic, Tuple diff --git a/aioauth/responses.py b/aioauth/responses.py index 63c02c7..086608a 100644 --- a/aioauth/responses.py +++ b/aioauth/responses.py @@ -7,6 +7,7 @@ ---- """ + from dataclasses import dataclass, field from http import HTTPStatus from typing import Dict diff --git a/aioauth/server.py b/aioauth/server.py index 5f27288..bf4ab30 100644 --- a/aioauth/server.py +++ b/aioauth/server.py @@ -16,6 +16,7 @@ ---- """ + import sys from dataclasses import asdict from http import HTTPStatus diff --git a/aioauth/storage.py b/aioauth/storage.py index 40f2128..1344603 100644 --- a/aioauth/storage.py +++ b/aioauth/storage.py @@ -188,9 +188,9 @@ async def get_client( raise NotImplementedError("Method get_client must be implemented") -class Authentication(Generic[UserType]): - async def authenticate(self, request: Request[UserType]) -> bool: - """Authenticates a user. 
+class UserStorage(Generic[UserType]): + async def get_user(self, request: Request[UserType]) -> Optional[UserType]: + """Returns a user. Note: This method is used by the grant type @@ -201,7 +201,7 @@ async def authenticate(self, request: Request[UserType]) -> bool: Boolean indicating whether or not the user was authenticated successfully. """ - raise NotImplementedError("Method authenticate must be implemented") + raise NotImplementedError("Method get_user must be implemented") class IDTokenStorage(Generic[UserType]): @@ -230,7 +230,6 @@ class BaseStorage( TokenStorage[UserType], AuthorizationCodeStorage[UserType], ClientStorage[UserType], - Authentication[UserType], + UserStorage[UserType], IDTokenStorage[UserType], -): - ... +): ... diff --git a/aioauth/types.py b/aioauth/types.py index 661e5ec..54fb14c 100644 --- a/aioauth/types.py +++ b/aioauth/types.py @@ -7,6 +7,7 @@ ---- """ + import sys from typing import Any diff --git a/setup.cfg b/setup.cfg index 6b40ce4..522a045 100644 --- a/setup.cfg +++ b/setup.cfg @@ -16,8 +16,11 @@ allow_redefinition = True [bdist_wheel] universal = 1 +[tool:isort] +profile = "black" + [flake8] -ignore = D10,E203,E501,W503,D205,D400,A001,D210,D401 +ignore = D10,E203,E501,W503,D205,D400,A001,D210,D401,E701 max-line-length = 88 select = A,B,C4,D,E,F,M,Q,T,W,ABS,BLK exclude = versions/* diff --git a/setup.py b/setup.py index 6d390a5..841e4e5 100644 --- a/setup.py +++ b/setup.py @@ -40,7 +40,7 @@ require_dev = [ "async-asgi-testclient==1.4.8", "backports.cached-property==1.0.2", - "pre-commit==2.16.0", + "pre_commit==4.0.1", "pytest==6.2.5", "pytest-asyncio==0.16.0", "pytest-cov==3.0.0", diff --git a/tests/classes.py b/tests/classes.py index 6ddc65a..13e4577 100644 --- a/tests/classes.py +++ b/tests/classes.py @@ -22,8 +22,7 @@ @dataclass(frozen=True) class User: - first_name: str - last_name: str + username: str class Storage(BaseStorage[User]): @@ -117,10 +116,17 @@ async def get_token( ): return token_ - async def 
authenticate(self, request: Request[User]) -> bool: + async def get_user(self, request: Request[User]) -> Optional[User]: password = request.post.password username = request.post.username - return username in self.users and self.users[username] == password + + if username is None or password is None: + return None + + user_exists = username in self.users and self.users[username] == password + + if user_exists: + return User(username=username) async def create_authorization_code( self, diff --git a/tests/test_db.py b/tests/test_db.py index 3838b20..d45bbad 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -51,7 +51,7 @@ async def test_storage_class() -> None: client_secret=client.client_secret, ) with pytest.raises(NotImplementedError): - await db.authenticate(request=request) + await db.get_user(request=request) with pytest.raises(NotImplementedError): await db.get_authorization_code( request=request, client_id=client.client_id, code=authorization_code.code diff --git a/tests/test_flow.py b/tests/test_flow.py index 51465b3..69b4923 100644 --- a/tests/test_flow.py +++ b/tests/test_flow.py @@ -47,7 +47,7 @@ async def test_authorization_code_flow_plain_code_challenge(): url=request_url, query=query, method="GET", - user=User(first_name="A", last_name="B"), + user=User(username="A"), ) await check_request_validators(request, server.create_authorization_response) From 09998b0db9551dd5b3b271899cebde2eaddcd30f Mon Sep 17 00:00:00 2001 From: Ali Aliyev Date: Sat, 26 Oct 2024 10:45:52 +0400 Subject: [PATCH 09/57] fix: the BaseRequest was renamed to Request --- aioauth/oidc/core/requests.py | 6 +++--- aioauth/requests.py | 9 +++------ tests/utils.py | 6 +++--- 3 files changed, 9 insertions(+), 12 deletions(-) diff --git a/aioauth/oidc/core/requests.py b/aioauth/oidc/core/requests.py index ed9b4be..8d58234 100644 --- a/aioauth/oidc/core/requests.py +++ b/aioauth/oidc/core/requests.py @@ -3,14 +3,14 @@ from typing import Optional from ...requests import ( - BaseRequest, 
- Query as OAuth2Query, + Request as BaseRequest, + Query as BaseQuery, ) from ...types import UserType @dataclass -class Query(OAuth2Query): +class Query(BaseQuery): # Space delimited, case sensitive list of ASCII string values that # specifies whether the Authorization Server prompts the End-User for # reauthentication and consent. The defined values are: none, login, diff --git a/aioauth/requests.py b/aioauth/requests.py index 65c3bc7..e69d560 100644 --- a/aioauth/requests.py +++ b/aioauth/requests.py @@ -63,7 +63,9 @@ class Post: @dataclass -class BaseRequest(Generic[UserType]): +class Request(Generic[UserType]): + """Object that contains a client's complete request.""" + method: RequestMethod query: Query = field(default_factory=Query) post: Post = field(default_factory=Post) @@ -71,8 +73,3 @@ class BaseRequest(Generic[UserType]): url: str = "" user: Optional[UserType] = None settings: Settings = field(default_factory=Settings) - - -@dataclass -class Request(Generic[UserType], BaseRequest[UserType]): - """Object that contains a client's complete request.""" diff --git a/tests/utils.py b/tests/utils.py index 08fe5e0..ad65aa9 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -4,7 +4,7 @@ from aioauth.collections import HTTPHeaderDict from aioauth.constances import default_headers -from aioauth.requests import BaseRequest, Post, Query +from aioauth.requests import Request, Post, Query from aioauth.responses import ErrorResponse, Response EMPTY_KEYS = { @@ -298,7 +298,7 @@ def get_keys(query: Union[Query, Post]) -> Dict[str, Any]: async def check_query_values( - request: BaseRequest, responses, query_dict: Dict, endpoint_func, value + request: Request, responses, query_dict: Dict, endpoint_func, value ): keys = set(query_dict.keys()) & set(responses.keys()) @@ -333,7 +333,7 @@ async def check_query_values( async def check_request_validators( - request: BaseRequest, + request: Request, endpoint_func: Callable, ): query_dict = {} From 
792720f74f80a70b1d8d1876321e7c7f8d4f683d Mon Sep 17 00:00:00 2001 From: Ali Aliyev Date: Sun, 3 Nov 2024 08:57:31 +0400 Subject: [PATCH 10/57] feat: added bandit to pre-commit --- .pre-commit-config.yaml | 6 ++++++ setup.py | 1 + 2 files changed, 7 insertions(+) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 9095212..3ac4649 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -30,6 +30,12 @@ repos: verbose: false exclude: ^(docs/|setup\.py) + - repo: https://github.com/PyCQA/bandit + rev: 1.7.10 + hooks: + - id: bandit + args: ["--skip", "B101"] + - repo: https://github.com/pycqa/flake8 rev: 7.1.1 hooks: diff --git a/setup.py b/setup.py index 841e4e5..a28cdcc 100644 --- a/setup.py +++ b/setup.py @@ -50,6 +50,7 @@ "twine==3.7.1", "wheel", "mypy==1.13.0", + "bandit==1.7.10", ] require_docs = [ From 26ef51fedb1bbce13890731cac8c378dddd025d7 Mon Sep 17 00:00:00 2001 From: Ali Aliyev Date: Sun, 3 Nov 2024 09:09:47 +0400 Subject: [PATCH 11/57] feat: added steps to ci.yml --- .github/workflows/ci.yml | 47 +++++++++++++++++++++++++++++++--------- 1 file changed, 37 insertions(+), 10 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 6c8918f..3c796b1 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -1,6 +1,3 @@ -# This workflow will install Python dependencies, run tests and lint with a variety of Python versions -# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions - name: CI on: @@ -10,29 +7,59 @@ on: branches: [ master ] jobs: - build: + checkout: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + setup-python: runs-on: ubuntu-latest + needs: checkout strategy: matrix: python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] - steps: - - uses: actions/checkout@v2 - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v2 with: python-version: ${{ matrix.python-version }} + 
+ install-dependencies: + runs-on: ubuntu-latest + needs: setup-python + strategy: + matrix: + python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] + steps: - name: Install dependencies run: | make dev-install pip install codecov + + lint: + runs-on: ubuntu-latest + needs: install-dependencies + strategy: + matrix: + python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] + steps: - name: Run lint - run: | - make lint + run: make lint + + test: + runs-on: ubuntu-latest + needs: lint + strategy: + matrix: + python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] + steps: - name: Run tests - run: | - make test + run: make test + + upload-coverage: + runs-on: ubuntu-latest + needs: test + steps: - name: Upload test coverage run: codecov env: From 87f0c2383b847746ebd573c903cdc8312c243d22 Mon Sep 17 00:00:00 2001 From: Ali Aliyev Date: Sun, 3 Nov 2024 09:16:19 +0400 Subject: [PATCH 12/57] fix: "needs" dependency for test job --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 3c796b1..0d6b792 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -48,7 +48,7 @@ jobs: test: runs-on: ubuntu-latest - needs: lint + needs: install-dependencies strategy: matrix: python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] From efed05cff396f40ea1e585dd340003b6d411a5dd Mon Sep 17 00:00:00 2001 From: Ali Aliyev Date: Sun, 3 Nov 2024 09:24:00 +0400 Subject: [PATCH 13/57] feat: added test analytics for codecov and updated ci jobs --- .github/workflows/ci.yml | 13 +++++++++---- .gitignore | 1 + setup.cfg | 2 +- 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 0d6b792..0dc0549 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -60,7 +60,12 @@ jobs: runs-on: ubuntu-latest needs: test steps: - - name: Upload test coverage - run: codecov - env: - CODECOV_TOKEN: ${{ 
secrets.CODECOV_TOKEN }} + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v4 + with: + token: ${{ secrets.CODECOV_TOKEN }} + - name: Upload test results to Codecov + if: ${{ !cancelled() }} + uses: codecov/test-results-action@v1 + with: + token: ${{ secrets.CODECOV_TOKEN }} diff --git a/.gitignore b/.gitignore index edf536b..749f0d1 100644 --- a/.gitignore +++ b/.gitignore @@ -47,6 +47,7 @@ htmlcov/ .cache nosetests.xml coverage.xml +junit.xml *.cover .hypothesis/ .pytest_cache/ diff --git a/setup.cfg b/setup.cfg index 522a045..ab1ad96 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,5 +1,5 @@ [tool:pytest] -addopts = -s --strict-markers -vv --cache-clear --maxfail=1 --cov=aioauth --cov-report=term --cov-report=html --cov-branch --no-cov-on-fail +addopts = -s --strict-markers -vv --cache-clear --maxfail=1 --cov=aioauth --cov-report=term --cov-report=html --cov-branch --no-cov-on-fail --junitxml=junit.xml -o junit_family=legacy [coverage:run] branch = True From 1421cef9f8c40a0292abda91cdd54f80655e0389 Mon Sep 17 00:00:00 2001 From: Ali Aliyev Date: Sun, 10 Nov 2024 13:16:26 +0400 Subject: [PATCH 14/57] fix: nosec --- .pre-commit-config.yaml | 11 ++++++----- aioauth/errors.py | 36 ++++++++++++++++++------------------ aioauth/grant_type.py | 30 +++++++++++++++--------------- aioauth/response_type.py | 20 ++++++++++---------- aioauth/server.py | 32 ++++++++++++++++---------------- aioauth/types.py | 19 ++++++++++++------- 6 files changed, 77 insertions(+), 71 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3ac4649..e5ea067 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -29,12 +29,13 @@ repos: require_serial: true verbose: false exclude: ^(docs/|setup\.py) - - - repo: https://github.com/PyCQA/bandit - rev: 1.7.10 - hooks: - id: bandit - args: ["--skip", "B101"] + name: bandit + entry: "bandit" + language: system + types: [python] + require_serial: true + verbose: false - repo: 
https://github.com/pycqa/flake8 rev: 7.1.1 diff --git a/aioauth/errors.py b/aioauth/errors.py index 1bec33d..2837dd4 100644 --- a/aioauth/errors.py +++ b/aioauth/errors.py @@ -9,7 +9,7 @@ """ from http import HTTPStatus -from typing import Optional +from typing import Generic, Optional from urllib.parse import urljoin from typing_extensions import Literal @@ -17,10 +17,10 @@ from .collections import HTTPHeaderDict from .constances import default_headers -from .types import ErrorType +from .types import ErrorType, UserType -class OAuth2Error(Exception): +class OAuth2Error(Generic[UserType], Exception): """Base exception that all other exceptions inherit from.""" error: ErrorType @@ -32,7 +32,7 @@ class OAuth2Error(Exception): def __init__( self, - request: Request, + request: Request[UserType], description: Optional[str] = None, headers: Optional[HTTPHeaderDict] = None, state: Optional[str] = None, @@ -54,7 +54,7 @@ def __init__( super().__init__(f"({self.error}) {self.description}") -class MethodNotAllowedError(OAuth2Error): +class MethodNotAllowedError(Generic[UserType], OAuth2Error[UserType]): """ The request is valid, but the method trying to be accessed is not available to the resource owner. @@ -65,7 +65,7 @@ class MethodNotAllowedError(OAuth2Error): error: ErrorType = "method_is_not_allowed" -class InvalidRequestError(OAuth2Error): +class InvalidRequestError(Generic[UserType], OAuth2Error[UserType]): """ The request is missing a required parameter, includes an invalid parameter value, includes a parameter more than once, or is @@ -75,7 +75,7 @@ class InvalidRequestError(OAuth2Error): error: Literal["invalid_request"] = "invalid_request" -class InvalidClientError(OAuth2Error): +class InvalidClientError(Generic[UserType], OAuth2Error[UserType]): """ Client authentication failed (e.g. unknown client, no client authentication included, or unsupported authentication method). 
@@ -110,14 +110,14 @@ def __init__( self.headers["WWW-Authenticate"] = "Basic " + ", ".join(auth_values) -class InsecureTransportError(OAuth2Error): +class InsecureTransportError(Generic[UserType], OAuth2Error[UserType]): """An exception will be thrown if the current request is not secure.""" description = "OAuth 2 MUST utilize https." error: ErrorType = "insecure_transport" -class UnsupportedGrantTypeError(OAuth2Error): +class UnsupportedGrantTypeError(Generic[UserType], OAuth2Error[UserType]): """ The authorization grant type is not supported by the authorization server. @@ -126,7 +126,7 @@ class UnsupportedGrantTypeError(OAuth2Error): error: ErrorType = "unsupported_grant_type" -class UnsupportedResponseTypeError(OAuth2Error): +class UnsupportedResponseTypeError(Generic[UserType], OAuth2Error[UserType]): """ The authorization server does not support obtaining an authorization code using this method. @@ -135,7 +135,7 @@ class UnsupportedResponseTypeError(OAuth2Error): error: ErrorType = "unsupported_response_type" -class InvalidGrantError(OAuth2Error): +class InvalidGrantError(Generic[UserType], OAuth2Error[UserType]): """ The provided authorization grant (e.g. authorization code, resource owner credentials) or refresh token is invalid, expired, revoked, does @@ -148,14 +148,14 @@ class InvalidGrantError(OAuth2Error): error: ErrorType = "invalid_grant" -class MismatchingStateError(OAuth2Error): +class MismatchingStateError(Generic[UserType], OAuth2Error[UserType]): """Unable to securely verify the integrity of the request and response.""" description = "CSRF Warning! State not equal in request and response." error: Literal["mismatching_state"] = "mismatching_state" -class UnauthorizedClientError(OAuth2Error): +class UnauthorizedClientError(Generic[UserType], OAuth2Error[UserType]): """ The authenticated client is not authorized to use this authorization grant type. 
@@ -164,7 +164,7 @@ class UnauthorizedClientError(OAuth2Error): error: ErrorType = "unauthorized_client" -class InvalidScopeError(OAuth2Error): +class InvalidScopeError(Generic[UserType], OAuth2Error[UserType]): """ The requested scope is invalid, unknown, or malformed, or exceeds the scope granted by the resource owner. @@ -175,7 +175,7 @@ class InvalidScopeError(OAuth2Error): error: ErrorType = "invalid_scope" -class ServerError(OAuth2Error): +class ServerError(Generic[UserType], OAuth2Error[UserType]): """ The authorization server encountered an unexpected condition that prevented it from fulfilling the request. (This error code is needed @@ -187,7 +187,7 @@ class ServerError(OAuth2Error): status_code: HTTPStatus = HTTPStatus.BAD_REQUEST -class TemporarilyUnavailableError(OAuth2Error): +class TemporarilyUnavailableError(Generic[UserType], OAuth2Error[UserType]): """ The authorization server is currently unable to handle the request due to a temporary overloading or maintenance of the server. @@ -198,7 +198,7 @@ class TemporarilyUnavailableError(OAuth2Error): error: ErrorType = "temporarily_unavailable" -class InvalidRedirectURIError(OAuth2Error): +class InvalidRedirectURIError(Generic[UserType], OAuth2Error[UserType]): """ The requested redirect URI is missing or not allowed. """ @@ -206,7 +206,7 @@ class InvalidRedirectURIError(OAuth2Error): error: ErrorType = "invalid_request" -class UnsupportedTokenTypeError(OAuth2Error): +class UnsupportedTokenTypeError(Generic[UserType], OAuth2Error[UserType]): """ The authorization server does not support the revocation of the presented token type. 
That is, the client tried to revoke an access token on a server diff --git a/aioauth/grant_type.py b/aioauth/grant_type.py index fd07531..7924ed8 100644 --- a/aioauth/grant_type.py +++ b/aioauth/grant_type.py @@ -72,15 +72,15 @@ async def validate_request(self, request: Request[UserType]) -> Client[UserType] ) if not client: - raise InvalidClientError( + raise InvalidClientError[UserType]( request=request, description="Invalid client_id parameter value." ) if not client.check_grant_type(request.post.grant_type): - raise UnauthorizedClientError(request=request) + raise UnauthorizedClientError[UserType](request=request) if not client.check_scope(request.post.scope): - raise InvalidScopeError(request=request) + raise InvalidScopeError[UserType](request=request) self.scope = request.post.scope return client @@ -106,17 +106,17 @@ async def validate_request(self, request: Request[UserType]) -> Client[UserType] client = await super().validate_request(request) if not request.post.redirect_uri: - raise InvalidRedirectURIError( + raise InvalidRedirectURIError[UserType]( request=request, description="Mismatching redirect URI." ) if not client.check_redirect_uri(request.post.redirect_uri): - raise InvalidRedirectURIError( + raise InvalidRedirectURIError[UserType]( request=request, description="Invalid redirect URI." ) if not request.post.code: - raise InvalidRequestError( + raise InvalidRequestError[UserType]( request=request, description="Missing code parameter." ) @@ -125,14 +125,14 @@ async def validate_request(self, request: Request[UserType]) -> Client[UserType] ) if not authorization_code: - raise InvalidGrantError(request=request) + raise InvalidGrantError[UserType](request=request) if ( authorization_code.code_challenge and authorization_code.code_challenge_method ): if not request.post.code_verifier: - raise InvalidRequestError( + raise InvalidRequestError[UserType]( request=request, description="Code verifier required." 
) @@ -140,10 +140,10 @@ async def validate_request(self, request: Request[UserType]) -> Client[UserType] request.post.code_verifier ) if not is_valid_code_challenge: - raise MismatchingStateError(request=request) + raise MismatchingStateError[UserType](request=request) if authorization_code.is_expired: - raise InvalidGrantError(request=request) + raise InvalidGrantError[UserType](request=request) self.scope = authorization_code.scope return client @@ -180,14 +180,14 @@ async def validate_request(self, request: Request[UserType]) -> Client[UserType] client = await super().validate_request(request) if not request.post.username or not request.post.password: - raise InvalidRequestError( + raise InvalidRequestError[UserType]( request=request, description="Invalid credentials given." ) user = await self.storage.get_user(request) if user is None: - raise InvalidRequestError( + raise InvalidRequestError[UserType]( request=request, description="Invalid credentials given." ) @@ -214,7 +214,7 @@ async def create_token_response( ) if not old_token or old_token.revoked or old_token.refresh_token_expired: - raise InvalidGrantError(request=request) + raise InvalidGrantError[UserType](request=request) # Revoke old token await self.storage.revoke_token( @@ -250,7 +250,7 @@ async def validate_request(self, request: Request[UserType]) -> Client[UserType] client = await super().validate_request(request) if not request.post.refresh_token: - raise InvalidRequestError( + raise InvalidRequestError[UserType]( request=request, description="Missing refresh token parameter." 
) @@ -269,6 +269,6 @@ class ClientCredentialsGrantType(GrantTypeBase[UserType]): async def validate_request(self, request: Request[UserType]) -> Client[UserType]: # client_credentials grant requires a client_secret if self.client_secret is None: - raise InvalidClientError(request) + raise InvalidClientError[UserType](request) return await super().validate_request(request) diff --git a/aioauth/response_type.py b/aioauth/response_type.py index 23c2987..b73dc24 100644 --- a/aioauth/response_type.py +++ b/aioauth/response_type.py @@ -52,7 +52,7 @@ async def validate_request(self, request: Request[UserType]) -> Client[UserType] ) if not request.query.client_id: - raise InvalidClientError( + raise InvalidClientError[UserType]( request=request, description="Missing client_id parameter.", state=state ) @@ -61,43 +61,43 @@ async def validate_request(self, request: Request[UserType]) -> Client[UserType] ) if not client: - raise InvalidClientError( + raise InvalidClientError[UserType]( request=request, description="Invalid client_id parameter value.", state=state, ) if not request.query.redirect_uri: - raise InvalidRedirectURIError( + raise InvalidRedirectURIError[UserType]( request=request, description="Mismatching redirect URI.", state=state ) if not client.check_redirect_uri(request.query.redirect_uri): - raise InvalidRedirectURIError( + raise InvalidRedirectURIError[UserType]( request=request, description="Invalid redirect URI.", state=state ) if request.query.code_challenge_method: if request.query.code_challenge_method not in code_challenge_methods: - raise InvalidRequestError( + raise InvalidRequestError[UserType]( request=request, description="Transform algorithm not supported.", state=state, ) if not request.query.code_challenge: - raise InvalidRequestError( + raise InvalidRequestError[UserType]( request=request, description="Code challenge required.", state=state ) if not client.check_response_type(request.query.response_type): - raise 
UnsupportedResponseTypeError(request=request, state=state) + raise UnsupportedResponseTypeError[UserType](request=request, state=state) if not client.check_scope(request.query.scope): - raise InvalidScopeError(request=request, state=state) + raise InvalidScopeError[UserType](request=request, state=state) if not request.user: - raise InvalidClientError( + raise InvalidClientError[UserType]( request=request, description="User is not authorized", state=state ) @@ -156,7 +156,7 @@ async def validate_request(self, request: Request[UserType]) -> Client[UserType] # nonce is required for id_token if not request.query.nonce: - raise InvalidRequestError( + raise InvalidRequestError[UserType]( request=request, description="Nonce required for response_type id_token.", state=request.query.state, diff --git a/aioauth/server.py b/aioauth/server.py index bf4ab30..744b887 100644 --- a/aioauth/server.py +++ b/aioauth/server.py @@ -128,16 +128,16 @@ def validate_request( self, request: Request[UserType], allowed_methods: List[RequestMethod] ): if not request.settings.AVAILABLE: - raise TemporarilyUnavailableError(request=request) + raise TemporarilyUnavailableError[UserType](request=request) if not self.is_secure_transport(request): - raise InsecureTransportError(request=request) + raise InsecureTransportError[UserType](request=request) if request.method not in allowed_methods: headers = HTTPHeaderDict( {**default_headers, "allow": ", ".join(allowed_methods)} ) - raise MethodNotAllowedError(request=request, headers=headers) + raise MethodNotAllowedError[UserType](request=request, headers=headers) @catch_errors_and_unavailability() async def create_token_introspection_response( @@ -183,7 +183,7 @@ async def introspect(request: fastapi.Request) -> fastapi.Response: ) if not client: - raise InvalidClientError(request) + raise InvalidClientError[UserType](request) token_types: Tuple[TokenType, ...] 
= get_args(TokenType) token_type: TokenType = "refresh_token" @@ -194,7 +194,7 @@ async def introspect(request: fastapi.Request) -> fastapi.Response: if request.post.token_type_hint in token_types: token_type = request.post.token_type_hint - if token_type == "access_token": + if token_type == "access_token": # nosec access_token = request.post.token refresh_token = None @@ -242,7 +242,7 @@ def get_client_credentials( if client_id is None or secret_required: # Either we didn't find a client ID at all, or we found # a client ID but no secret and a secret is required. - raise InvalidClientError( + raise InvalidClientError[UserType]( description="Invalid client_id parameter value.", request=request, ) from exc @@ -250,7 +250,7 @@ def get_client_credentials( # client_secret must not be None. When client_secret is None, # storage.get_client will not run standard checks on the client_secret if client_secret is None: - client_secret = "" + client_secret = "" # nosec return client_id, client_secret @@ -307,7 +307,7 @@ async def token(request: fastapi.Request) -> fastapi.Response: if not request.post.grant_type: # grant_type request value is empty - raise InvalidRequestError( + raise InvalidRequestError[UserType]( request=request, description="Request is missing grant type." 
) @@ -325,7 +325,7 @@ async def token(request: fastapi.Request) -> fastapi.Response: GrantTypeClass = self.grant_types[request.post.grant_type] except KeyError as exc: # grant_type request value is invalid - raise UnsupportedGrantTypeError(request=request) from exc + raise UnsupportedGrantTypeError[UserType](request=request) from exc grant_type = GrantTypeClass( storage=self.storage, client_id=client_id, client_secret=client_secret @@ -403,7 +403,7 @@ async def authorize(request: fastapi.Request) -> fastapi.Response: state = request.query.state if not response_type_list: - raise InvalidRequestError( + raise InvalidRequestError[UserType]( request=request, description="Missing response_type parameter.", state=state, @@ -418,7 +418,7 @@ async def authorize(request: fastapi.Request) -> fastapi.Response: response_type_classes.add(ResponseTypeClass) if not response_type_classes: - raise UnsupportedResponseTypeError(request=request, state=state) + raise UnsupportedResponseTypeError[UserType](request=request, state=state) for ResponseTypeClass in response_type_classes: response_type = ResponseTypeClass(storage=self.storage) @@ -512,10 +512,10 @@ async def revoke(request: fastapi.Request) -> fastapi.Response: ) if not client: - raise InvalidClientError(request) + raise InvalidClientError[UserType](request) if not request.post.token: - raise InvalidRequestError( + raise InvalidRequestError[UserType]( request=request, description="Request is missing token." 
) @@ -523,16 +523,16 @@ async def revoke(request: fastapi.Request) -> fastapi.Response: "refresh_token", "access_token", }: - raise UnsupportedTokenTypeError(request=request) + raise UnsupportedTokenTypeError[UserType](request=request) access_token = ( request.post.token - if request.post.token_type_hint != "refresh_token" + if request.post.token_type_hint != "refresh_token" # nosec else None ) refresh_token = ( request.post.token - if request.post.token_type_hint != "access_token" + if request.post.token_type_hint != "access_token" # nosec else None ) diff --git a/aioauth/types.py b/aioauth/types.py index 54fb14c..7cc8dbd 100644 --- a/aioauth/types.py +++ b/aioauth/types.py @@ -16,13 +16,18 @@ else: from typing_extensions import TypeVar +if sys.version_info >= (3, 11): + from typing import TypeAlias +else: + from typing_extensions import TypeAlias + if sys.version_info >= (3, 8): from typing import Literal else: from typing_extensions import Literal -ErrorType = Literal[ +ErrorType: TypeAlias = Literal[ "invalid_request", "invalid_client", "invalid_grant", @@ -39,7 +44,7 @@ ] -GrantType = Literal[ +GrantType: TypeAlias = Literal[ "authorization_code", "password", "client_credentials", @@ -47,7 +52,7 @@ ] -ResponseType = Literal[ +ResponseType: TypeAlias = Literal[ "token", "code", "none", @@ -55,22 +60,22 @@ ] -RequestMethod = Literal["GET", "POST"] +RequestMethod: TypeAlias = Literal["GET", "POST"] -CodeChallengeMethod = Literal[ +CodeChallengeMethod: TypeAlias = Literal[ "plain", "S256", ] -ResponseMode = Literal[ +ResponseMode: TypeAlias = Literal[ "query", "form_post", "fragment", ] -TokenType = Literal["access_token", "refresh_token", "Bearer"] +TokenType: TypeAlias = Literal["access_token", "refresh_token", "Bearer"] UserType = TypeVar("UserType", default=Any) From 129952a21c509cc1894163fc9da46bb8a88b08eb Mon Sep 17 00:00:00 2001 From: Ali Aliyev Date: Sun, 10 Nov 2024 14:04:09 +0400 Subject: [PATCH 15/57] chore: migrated to pyproject.toml --- Makefile | 4 
+- aioauth/__init__.py | 2 + aioauth/__version__.py | 8 --- pyproject.toml | 112 +++++++++++++++++++++++++++++++++++++++++ setup.cfg | 27 ---------- setup.py | 87 -------------------------------- 6 files changed, 116 insertions(+), 124 deletions(-) delete mode 100644 aioauth/__version__.py create mode 100644 pyproject.toml delete mode 100644 setup.cfg delete mode 100644 setup.py diff --git a/Makefile b/Makefile index 975f658..a4a0516 100644 --- a/Makefile +++ b/Makefile @@ -43,6 +43,7 @@ clean-pyc: ## remove Python file artifacts clean-test: ## remove test and coverage artifacts rm -f .coverage + rm -f coverage.xml rm -fr htmlcov/ rm -fr .pytest_cache @@ -56,8 +57,7 @@ release: dist ## package and upload a release twine upload dist/* dist: clean ## builds source and wheel package - python setup.py sdist - python setup.py bdist_wheel + python -m build ls -l dist install: clean ## install the package to the active Python's site-packages diff --git a/aioauth/__init__.py b/aioauth/__init__.py index 5a3f7fd..a3c5a6f 100644 --- a/aioauth/__init__.py +++ b/aioauth/__init__.py @@ -1,3 +1,5 @@ import logging +__version__ = "2.0.0" + logging.getLogger("aioauth").addHandler(logging.NullHandler()) diff --git a/aioauth/__version__.py b/aioauth/__version__.py deleted file mode 100644 index 5e81091..0000000 --- a/aioauth/__version__.py +++ /dev/null @@ -1,8 +0,0 @@ -__title__ = "aioauth" -__description__ = "Asynchronous OAuth 2.0 framework for Python 3." 
-__url__ = "https://github.com/aliev/aioauth" -__version__ = "1.7.0" -__author__ = "Ali Aliyev" -__author_email__ = "ali@aliev.me" -__license__ = "The MIT License (MIT)" -__copyright__ = "Copyright 2022 Ali Aliyev" diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..8fe452e --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,112 @@ +[build-system] +requires = ["setuptools>=42", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "aioauth" +dynamic = ["version"] +description = "Asynchronous OAuth 2.0 framework for Python 3." +readme = "README.md" +requires-python = ">=3.9.0" +authors = [ + { name = "Ali Aliyev", email = "ali@aliev.me" }, +] +classifiers = [ + "Intended Audience :: Information Technology", + "Intended Audience :: System Administrators", + "Operating System :: OS Independent", + "Programming Language :: Python :: 3", + "Programming Language :: Python", + "Topic :: Internet", + "Topic :: Software Development :: Libraries :: Application Frameworks", + "Topic :: Software Development :: Libraries :: Python Modules", + "Topic :: Software Development :: Libraries", + "Topic :: Software Development", + "Typing :: Typed", + "Development Status :: 1 - Planning", + "Environment :: Web Environment", + "Framework :: AsyncIO", + "Intended Audience :: Developers", + "License :: OSI Approved :: MIT License", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Topic :: Internet :: WWW/HTTP :: HTTP Servers", + "Topic :: Internet :: WWW/HTTP", +] +dependencies = [ + "typing_extensions" +] + +[project.optional-dependencies] +dev = [ + "build", + "twine", + "pytest", + "coverage", + "pytest-asyncio", + "mypy", + "bandit", + "pre-commit", +] + +docs = [ + "sphinx", + "sphinx-copybutton", + "sphinx-autobuild", + "m2r2", + "sphinx-rtd-theme", +] + +fastapi = [ + 
"aioauth-fastapi>=0.0.1" +] + +[project.urls] +homepage = "https://github.com/aliev/aioauth" + +[tool.setuptools.dynamic] +version = { attr = "aioauth.__version__" } + +[tool.setuptools.packages.find] +include = ["aioauth", "aioauth.*"] + +[tool.setuptools.package-data] +"aioauth" = ["py.typed"] + +[tool.pytest.ini_options] +addopts = "-s --strict-markers -vv --cache-clear --maxfail=1" +pythonpath = ["."] + +[tool.mypy] +python_version = "3.9" +warn_no_return = false +disallow_untyped_defs = false +allow_redefinition = true +namespace_packages = true +explicit_package_bases = true + +[tool.mypy-packages] +ignore_missing_imports = true + +[tool.flake8] +ignore = ["D10", "E203", "E501", "W503", "D205", "D400", "A001", "D210", "D401", "E701"] +max-line-length = 88 +select = ["A", "B", "C4", "D", "E", "F", "M", "Q", "T", "W", "ABS", "BLK"] +exclude = ["versions/*"] +inline-quotes = "\"" + +[tool.isort] +profile = "black" + +[tool.coverage.run] +relative_files = true +source = ["aioauth"] +branch = true + +[tool.coverage.report] +include = [ + "aioauth/*", +] diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index ab1ad96..0000000 --- a/setup.cfg +++ /dev/null @@ -1,27 +0,0 @@ -[tool:pytest] -addopts = -s --strict-markers -vv --cache-clear --maxfail=1 --cov=aioauth --cov-report=term --cov-report=html --cov-branch --no-cov-on-fail --junitxml=junit.xml -o junit_family=legacy - -[coverage:run] -branch = True -omit = - site-packages - aioauth/__version__.py - -[mypy] -python_version = 3.8 -warn_no_return = False -disallow_untyped_defs = False -allow_redefinition = True - -[bdist_wheel] -universal = 1 - -[tool:isort] -profile = "black" - -[flake8] -ignore = D10,E203,E501,W503,D205,D400,A001,D210,D401,E701 -max-line-length = 88 -select = A,B,C4,D,E,F,M,Q,T,W,ABS,BLK -exclude = versions/* -inline-quotes = " diff --git a/setup.py b/setup.py deleted file mode 100644 index a28cdcc..0000000 --- a/setup.py +++ /dev/null @@ -1,87 +0,0 @@ -from pathlib import Path - -from 
setuptools import setup, find_namespace_packages - -here = Path(__file__).parent -about = {} - -with open(here / "aioauth" / "__version__.py", "r") as f: - exec(f.read(), about) - -with open("README.md") as readme_file: - readme = readme_file.read() - -classifiers = [ - "Intended Audience :: Information Technology", - "Intended Audience :: System Administrators", - "Operating System :: OS Independent", - "Programming Language :: Python :: 3", - "Programming Language :: Python", - "Topic :: Internet", - "Topic :: Software Development :: Libraries :: Application Frameworks", - "Topic :: Software Development :: Libraries :: Python Modules", - "Topic :: Software Development :: Libraries", - "Topic :: Software Development", - "Typing :: Typed", - "Development Status :: 1 - Planning", - "Environment :: Web Environment", - "Framework :: AsyncIO", - "Intended Audience :: Developers", - "License :: OSI Approved :: MIT License", - "Programming Language :: Python :: 3 :: Only", - "Programming Language :: Python :: 3.7", - "Programming Language :: Python :: 3.8", - "Programming Language :: Python :: 3.9", - "Programming Language :: Python :: 3.10", - "Topic :: Internet :: WWW/HTTP :: HTTP Servers", - "Topic :: Internet :: WWW/HTTP", -] - -require_dev = [ - "async-asgi-testclient==1.4.8", - "backports.cached-property==1.0.2", - "pre_commit==4.0.1", - "pytest==6.2.5", - "pytest-asyncio==0.16.0", - "pytest-cov==3.0.0", - "pytest-env==0.6.2", - "pytest-sugar==0.9.4", - "testfixtures==6.18.3", - "twine==3.7.1", - "wheel", - "mypy==1.13.0", - "bandit==1.7.10", -] - -require_docs = [ - "sphinx", - "sphinx-copybutton", - "sphinx-autobuild", - "m2r2", - "sphinx-rtd-theme", -] - -setup( - name=about["__title__"], - version=about["__version__"], - description=about["__description__"], - long_description=readme, - long_description_content_type="text/markdown", - author=about["__author__"], - author_email=about["__author_email__"], - url=about["__url__"], - license=about["__license__"], - 
package_data={"aioauth": ["py.typed"]}, - python_requires=">=3.9.0", - classifiers=classifiers, - install_requires=["typing_extensions"], - extras_require={ - "fastapi": ["aioauth-fastapi>=0.0.1"], - "dev": require_dev, - "docs": require_docs + require_dev, - }, - include_package_data=True, - keywords="asyncio oauth2 oauth", - packages=find_namespace_packages(include=["aioauth", "aioauth.*"]), - project_urls={"Source": about["__url__"]}, -) From 56bfddcf86cd06946eedbc37d293ba96bac83fbe Mon Sep 17 00:00:00 2001 From: Ali Aliyev Date: Sat, 16 Nov 2024 11:28:21 +0400 Subject: [PATCH 16/57] fix: trying to fix the pipelines --- .github/workflows/ci.yml | 37 +++++-------------------------------- 1 file changed, 5 insertions(+), 32 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 0dc0549..8cd6ee9 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -7,63 +7,36 @@ on: branches: [ master ] jobs: - checkout: + test-and-lint: runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v2 - - setup-python: - runs-on: ubuntu-latest - needs: checkout strategy: matrix: python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] steps: + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v2 with: python-version: ${{ matrix.python-version }} - install-dependencies: - runs-on: ubuntu-latest - needs: setup-python - strategy: - matrix: - python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] - steps: - name: Install dependencies run: | + python -m pip install --upgrade pip make dev-install pip install codecov - lint: - runs-on: ubuntu-latest - needs: install-dependencies - strategy: - matrix: - python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] - steps: - name: Run lint run: make lint - test: - runs-on: ubuntu-latest - needs: install-dependencies - strategy: - matrix: - python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] - steps: - name: Run tests run: make test - 
upload-coverage: - runs-on: ubuntu-latest - needs: test - steps: - name: Upload coverage to Codecov uses: codecov/codecov-action@v4 with: token: ${{ secrets.CODECOV_TOKEN }} + - name: Upload test results to Codecov if: ${{ !cancelled() }} uses: codecov/test-results-action@v1 From 5ae65da135bcdd2df1bc2128e4cd921f65a504c1 Mon Sep 17 00:00:00 2001 From: Ali Aliyev Date: Sat, 16 Nov 2024 11:32:28 +0400 Subject: [PATCH 17/57] fix: temporary disables bandit and fixing steps --- .github/workflows/ci.yml | 2 +- .pre-commit-config.yaml | 34 +++++++++++++++++----------------- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 8cd6ee9..7d10710 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -7,7 +7,7 @@ on: branches: [ master ] jobs: - test-and-lint: + install-and-test: runs-on: ubuntu-latest strategy: matrix: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e5ea067..518cf49 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -19,23 +19,23 @@ repos: - id: check-merge-conflict - id: detect-private-key - - repo: local - hooks: - - id: mypy - name: mypy - entry: "mypy" - language: system - types: [python] - require_serial: true - verbose: false - exclude: ^(docs/|setup\.py) - - id: bandit - name: bandit - entry: "bandit" - language: system - types: [python] - require_serial: true - verbose: false + # - repo: local + # hooks: + # - id: mypy + # name: mypy + # entry: "mypy" + # language: system + # types: [python] + # require_serial: true + # verbose: false + # exclude: ^(docs/|setup\.py) + # - id: bandit + # name: bandit + # entry: "bandit" + # language: system + # types: [python] + # require_serial: true + # verbose: false - repo: https://github.com/pycqa/flake8 rev: 7.1.1 From 117397cfedb0eeab76bb96956eb1f5412ee5097b Mon Sep 17 00:00:00 2001 From: Ali Aliyev Date: Sun, 17 Nov 2024 02:20:42 +0400 Subject: [PATCH 18/57] chore: added typed arguments 
for storages --- .pre-commit-config.yaml | 23 +++--- aioauth/grant_type.py | 33 +++++--- aioauth/oidc/core/grant_type.py | 16 ++-- aioauth/response_type.py | 22 +++--- aioauth/server.py | 4 +- aioauth/storage.py | 129 ++++++++++++++++++++++---------- tests/classes.py | 99 ++++++++++++------------ tests/test_db.py | 8 +- tests/test_flow.py | 16 +++- 9 files changed, 215 insertions(+), 135 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 518cf49..e051583 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -19,16 +19,16 @@ repos: - id: check-merge-conflict - id: detect-private-key - # - repo: local - # hooks: - # - id: mypy - # name: mypy - # entry: "mypy" - # language: system - # types: [python] - # require_serial: true - # verbose: false - # exclude: ^(docs/|setup\.py) + - repo: local + hooks: + - id: mypy + name: mypy + entry: "mypy" + language: system + types: [python] + require_serial: true + verbose: false + exclude: ^(docs/|setup\.py) # - id: bandit # name: bandit # entry: "bandit" @@ -40,4 +40,5 @@ repos: - repo: https://github.com/pycqa/flake8 rev: 7.1.1 hooks: - - id: flake8 + - id: flake8 + additional_dependencies: [flake8-pyproject] diff --git a/aioauth/grant_type.py b/aioauth/grant_type.py index 7924ed8..9b9bb6b 100644 --- a/aioauth/grant_type.py +++ b/aioauth/grant_type.py @@ -49,11 +49,11 @@ async def create_token_response( raise RuntimeError("validate_request() must be called first") token = await self.storage.create_token( - request, - client.client_id, - self.scope, - generate_token(42), - generate_token(48), + request=request, + client_id=client.client_id, + scope=self.scope, + access_token=generate_token(42), + refresh_token=generate_token(48), ) return TokenResponse( @@ -68,7 +68,7 @@ async def create_token_response( async def validate_request(self, request: Request[UserType]) -> Client[UserType]: """Validates the client request to ensure it is valid.""" client = await self.storage.get_client( - 
request, client_id=self.client_id, client_secret=self.client_secret + request=request, client_id=self.client_id, client_secret=self.client_secret ) if not client: @@ -121,7 +121,7 @@ async def validate_request(self, request: Request[UserType]) -> Client[UserType] ) authorization_code = await self.storage.get_authorization_code( - request, client.client_id, request.post.code + request=request, client_id=client.client_id, code=request.post.code ) if not authorization_code: @@ -157,9 +157,9 @@ async def create_token_response( raise await self.storage.delete_authorization_code( - request, - client.client_id, - request.post.code, + request=request, + client_id=client.client_id, + code=request.post.code, ) return token_response @@ -211,6 +211,8 @@ async def create_token_response( request=request, client_id=client.client_id, refresh_token=request.post.refresh_token, + access_token=None, + token_type="refresh_token", ) if not old_token or old_token.revoked or old_token.refresh_token_expired: @@ -218,7 +220,10 @@ async def create_token_response( # Revoke old token await self.storage.revoke_token( - request=request, refresh_token=old_token.refresh_token + request=request, + refresh_token=old_token.refresh_token, + token_type="refresh_token", + access_token=None, ) # new token should have at max the same scope as the old token @@ -234,7 +239,11 @@ async def create_token_response( ) token = await self.storage.create_token( - request, client.client_id, new_scope, generate_token(42), generate_token(48) + request=request, + client_id=client.client_id, + scope=new_scope, + access_token=generate_token(42), + refresh_token=generate_token(48), ) return TokenResponse( diff --git a/aioauth/oidc/core/grant_type.py b/aioauth/oidc/core/grant_type.py index 2c4a5a8..3828658 100644 --- a/aioauth/oidc/core/grant_type.py +++ b/aioauth/oidc/core/grant_type.py @@ -50,11 +50,11 @@ async def create_token_response( raise RuntimeError("validate_request() must be called first") token = await 
self.storage.create_token( - request, - client.client_id, - self.scope, - generate_token(42), - generate_token(48), + request=request, + client_id=client.client_id, + scope=self.scope, + access_token=generate_token(42), + refresh_token=generate_token(48), ) if TYPE_CHECKING: @@ -81,9 +81,9 @@ async def create_token_response( ) await self.storage.delete_authorization_code( - request, - client.client_id, - request.post.code, + request=request, + client_id=client.client_id, + code=request.post.code, ) return TokenResponse( diff --git a/aioauth/response_type.py b/aioauth/response_type.py index b73dc24..a6dc5cb 100644 --- a/aioauth/response_type.py +++ b/aioauth/response_type.py @@ -111,11 +111,11 @@ async def create_authorization_response( self, request: Request[UserType], client: Client[UserType] ) -> TokenResponse: token = await self.storage.create_token( - request, - client.client_id, - request.query.scope, - generate_token(42), - generate_token(48), + request=request, + client_id=client.client_id, + scope=request.query.scope, + access_token=generate_token(42), + refresh_token=generate_token(48), ) return TokenResponse( expires_in=token.expires_in, @@ -167,12 +167,12 @@ async def create_authorization_response( self, request: Request[UserType], client: Client[UserType] ) -> IdTokenResponse: id_token = await self.storage.get_id_token( - request, - client.client_id, - request.query.scope, - request.query.response_type, # type: ignore - request.query.redirect_uri, - nonce=request.query.nonce, # type: ignore + request=request, + client_id=client.client_id, + scope=request.query.scope, + response_type=request.query.response_type, + redirect_uri=request.query.redirect_uri, + nonce=request.query.nonce, ) return IdTokenResponse(id_token=id_token) diff --git a/aioauth/server.py b/aioauth/server.py index 744b887..f68dffd 100644 --- a/aioauth/server.py +++ b/aioauth/server.py @@ -179,7 +179,7 @@ async def introspect(request: fastapi.Request) -> fastapi.Response: ) client = 
await self.storage.get_client( - request, client_id=client_id, client_secret=client_secret + request=request, client_id=client_id, client_secret=client_secret ) if not client: @@ -508,7 +508,7 @@ async def revoke(request: fastapi.Request) -> fastapi.Response: ) client = await self.storage.get_client( - request, client_id=client_id, client_secret=client_secret + request=request, client_id=client_id, client_secret=client_secret ) if not client: diff --git a/aioauth/storage.py b/aioauth/storage.py index 1344603..2dffcdf 100644 --- a/aioauth/storage.py +++ b/aioauth/storage.py @@ -10,23 +10,95 @@ ---- """ -from typing import Optional, Generic +import sys +from typing import TYPE_CHECKING, Optional, Generic from .models import AuthorizationCode, Client, Token -from .types import CodeChallengeMethod, ResponseType, TokenType +from .types import CodeChallengeMethod, TokenType from .requests import Request from .types import UserType +if sys.version_info >= (3, 11): + from typing import Unpack, NotRequired +else: + from typing_extensions import Unpack, NotRequired + +if sys.version_info >= (3, 11): + from typing import TypedDict +else: + from typing_extensions import TypedDict as _TypedDict + + # NOTE: workaround for Python < 3.11 + # https://github.com/python/cpython/issues/89026 + if TYPE_CHECKING: + + class TypedDict(Generic[UserType], _TypedDict): ... + + else: + + class TypedDict(Generic[UserType]): ... 
+ + +class AuthorizationCodeGet(TypedDict[UserType]): + request: Request[UserType] + client_id: str + code: str + + +class ClientStorageGetClient(TypedDict[UserType]): + request: Request[UserType] + client_id: str + client_secret: NotRequired[Optional[str]] + + +class IDTokenGetIdToken(TypedDict[UserType]): + request: Request[UserType] + client_id: str + scope: str + response_type: Optional[str] + redirect_uri: str + nonce: Optional[str] + + +class ArgsAuthorizationCode(TypedDict[UserType]): + request: Request[UserType] + client_id: str + scope: str + response_type: str + redirect_uri: str + code_challenge_method: Optional[CodeChallengeMethod] + code_challenge: Optional[str] + code: str + nonce: NotRequired[Optional[str]] + + +class TokenStorageCreateToken(TypedDict[UserType]): + request: Request[UserType] + client_id: str + scope: str + access_token: str + refresh_token: str + + +class TokenStorageGetToken(TypedDict[UserType]): + request: Request[UserType] + client_id: str + token_type: Optional[TokenType] # default is "refresh_token" + access_token: Optional[str] # default is None + refresh_token: Optional[str] # default is None + + +class TokenStorageRevokeToken(TypedDict[UserType]): + request: Request[UserType] + refresh_token: Optional[str] + token_type: Optional[TokenType] + access_token: Optional[str] + class TokenStorage(Generic[UserType]): async def create_token( - self, - request: Request[UserType], - client_id: str, - scope: str, - access_token: str, - refresh_token: str, + self, **kwargs: Unpack[TokenStorageCreateToken[UserType]] ) -> Token: """Generates a user token and stores it in the database. 
@@ -52,12 +124,7 @@ async def create_token( raise NotImplementedError("Method create_token must be implemented") async def get_token( - self, - request: Request[UserType], - client_id: str, - token_type: Optional[TokenType] = "refresh_token", - access_token: Optional[str] = None, - refresh_token: Optional[str] = None, + self, **kwargs: Unpack[TokenStorageGetToken[UserType]] ) -> Optional[Token]: """Gets existing token from the database. @@ -76,11 +143,7 @@ async def get_token( raise NotImplementedError("Method get_token must be implemented") async def revoke_token( - self, - request: Request[UserType], - token_type: Optional[TokenType] = "refresh_token", - access_token: Optional[str] = None, - refresh_token: Optional[str] = None, + self, **kwargs: Unpack[TokenStorageRevokeToken[UserType]] ) -> None: """Revokes a token from the database.""" raise NotImplementedError @@ -89,15 +152,7 @@ async def revoke_token( class AuthorizationCodeStorage(Generic[UserType]): async def create_authorization_code( self, - request: Request[UserType], - client_id: str, - scope: str, - response_type: ResponseType, - redirect_uri: str, - code_challenge_method: Optional[CodeChallengeMethod], - code_challenge: Optional[str], - code: str, - **kwargs, + **kwargs: Unpack[ArgsAuthorizationCode[UserType]], ) -> AuthorizationCode: """Generates an authorization token and stores it in the database. @@ -122,7 +177,8 @@ async def create_authorization_code( ) async def get_authorization_code( - self, request: Request[UserType], client_id: str, code: str + self, + **kwargs: Unpack[AuthorizationCodeGet[UserType]], ) -> Optional[AuthorizationCode]: """Gets existing authorization code from the database if it exists. @@ -145,7 +201,8 @@ async def get_authorization_code( ) async def delete_authorization_code( - self, request: Request[UserType], client_id: str, code: str + self, + **kwargs: Unpack[AuthorizationCodeGet[UserType]], ) -> None: """Deletes authorization code from database. 
@@ -165,9 +222,7 @@ async def delete_authorization_code( class ClientStorage(Generic[UserType]): async def get_client( self, - request: Request[UserType], - client_id: str, - client_secret: Optional[str] = None, + **kwargs: Unpack[ClientStorageGetClient[UserType]], ) -> Optional[Client[UserType]]: """Gets existing client from the database if it exists. @@ -207,13 +262,7 @@ async def get_user(self, request: Request[UserType]) -> Optional[UserType]: class IDTokenStorage(Generic[UserType]): async def get_id_token( self, - request: Request[UserType], - client_id: str, - scope: str, - response_type: ResponseType, - redirect_uri: str, - nonce: Optional[str], - **kwargs, + **kwargs: Unpack[IDTokenGetIdToken[UserType]], ) -> str: """Returns an id_token. For more information see `OpenID Connect Core 1.0 incorporating errata set 1 section 2 `_. diff --git a/tests/classes.py b/tests/classes.py index 13e4577..3678baf 100644 --- a/tests/classes.py +++ b/tests/classes.py @@ -11,14 +11,28 @@ from aioauth.requests import Request from aioauth.response_type import ResponseTypeBase from aioauth.server import AuthorizationServer -from aioauth.storage import BaseStorage -from aioauth.types import CodeChallengeMethod, GrantType, ResponseType, TokenType +from aioauth.storage import ( + ArgsAuthorizationCode, + AuthorizationCodeGet, + BaseStorage, + ClientStorageGetClient, + IDTokenGetIdToken, + TokenStorageCreateToken, + TokenStorageGetToken, + TokenStorageRevokeToken, +) +from aioauth.types import GrantType, ResponseType if sys.version_info >= (3, 8): from functools import cached_property else: from backports.cached_property import cached_property +if sys.version_info >= (3, 11): + from typing import Unpack +else: + from typing_extensions import Unpack + @dataclass(frozen=True) class User: @@ -50,23 +64,22 @@ def _get_by_client_id(self, client_id: str): async def get_client( self, - request: Request[User], - client_id: str, - client_secret: Optional[str] = None, + **kwargs: 
Unpack[ClientStorageGetClient[User]], ) -> Optional[Client]: + client_secret = kwargs.get("client_secret") + client_id = kwargs["client_id"] + if client_secret is not None: return self._get_by_client_secret(client_id, client_secret) return self._get_by_client_id(client_id) - async def create_token( - self, - request: Request[User], - client_id: str, - scope: str, - access_token: str, - refresh_token: str, - ): + async def create_token(self, **kwargs: Unpack[TokenStorageCreateToken[User]]): + client_id = kwargs["client_id"] + request = kwargs["request"] + access_token = kwargs["access_token"] + refresh_token = kwargs["refresh_token"] + scope = kwargs["scope"] token: Token[User] = Token( client_id=client_id, expires_in=request.settings.TOKEN_EXPIRES_IN, @@ -81,13 +94,11 @@ async def create_token( return token async def revoke_token( - self, - request: Request[User], - token_type: Optional[TokenType] = "refresh_token", - access_token: Optional[str] = None, - refresh_token: Optional[str] = None, + self, **kwargs: Unpack[TokenStorageRevokeToken[User]] ) -> None: tokens = self.tokens + refresh_token = kwargs["refresh_token"] + access_token = kwargs["access_token"] for key, token_ in enumerate(tokens): if token_.refresh_token == refresh_token: tokens[key] = replace(token_, revoked=True) @@ -95,13 +106,12 @@ async def revoke_token( tokens[key] = replace(token_, revoked=True) async def get_token( - self, - request: Request[User], - client_id: str, - token_type: Optional[TokenType] = "refresh_token", - access_token: Optional[str] = None, - refresh_token: Optional[str] = None, + self, **kwargs: Unpack[TokenStorageGetToken[User]] ) -> Optional[Token]: + refresh_token = kwargs["refresh_token"] + access_token = kwargs["access_token"] + client_id = kwargs["client_id"] + for token_ in self.tokens: if ( refresh_token is not None @@ -130,17 +140,17 @@ async def get_user(self, request: Request[User]) -> Optional[User]: async def create_authorization_code( self, - request: 
Request[User], - client_id: str, - scope: str, - response_type: str, - redirect_uri: str, - code_challenge_method: Optional[CodeChallengeMethod], - code_challenge: Optional[str], - code: str, - **kwargs, + **kwargs: Unpack[ArgsAuthorizationCode[User]], ): + request = kwargs["request"] nonce = kwargs.get("nonce") + code = kwargs["code"] + client_id = kwargs["client_id"] + redirect_uri = kwargs["redirect_uri"] + response_type = kwargs["response_type"] + scope = kwargs["scope"] + code_challenge_method = kwargs["code_challenge_method"] + code_challenge = kwargs["code_challenge"] authorization_code = AuthorizationCode( code=code, client_id=client_id, @@ -158,8 +168,12 @@ async def create_authorization_code( return authorization_code async def get_authorization_code( - self, request: Request[User], client_id: str, code: str + self, + **kwargs: Unpack[AuthorizationCodeGet[User]], ) -> Optional[AuthorizationCode]: + code = kwargs["code"] + client_id = kwargs["client_id"] + for authorization_code in self.authorization_codes: if ( authorization_code.code == code @@ -169,10 +183,10 @@ async def get_authorization_code( async def delete_authorization_code( self, - request: Request[User], - client_id: str, - code: str, + **kwargs: Unpack[AuthorizationCodeGet[User]], ): + code = kwargs["code"] + client_id = kwargs["client_id"] authorization_codes = self.authorization_codes for authorization_code in authorization_codes: if ( @@ -181,16 +195,7 @@ async def delete_authorization_code( ): authorization_codes.remove(authorization_code) - async def get_id_token( - self, - request: Request[User], - client_id: str, - scope: str, - response_type: ResponseType, - redirect_uri: str, - nonce: Optional[str], - **kwargs, - ) -> str: + async def get_id_token(self, **kwargs: Unpack[IDTokenGetIdToken[User]]) -> str: return "generated id token" diff --git a/tests/test_db.py b/tests/test_db.py index d45bbad..355eff7 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -43,6 +43,7 @@ async def 
test_storage_class() -> None: client_id=client.client_id, access_token=token.access_token, refresh_token=token.refresh_token, + token_type="refresh_token", ) with pytest.raises(NotImplementedError): await db.get_client( @@ -61,7 +62,12 @@ async def test_storage_class() -> None: request=request, client_id=client.client_id, code=authorization_code.code ) with pytest.raises(NotImplementedError): - await db.revoke_token(request=request, refresh_token=token.refresh_token) + await db.revoke_token( + request=request, + refresh_token=token.refresh_token, + token_type=None, + access_token=None, + ) with pytest.raises(NotImplementedError): await db.get_id_token( diff --git a/tests/test_flow.py b/tests/test_flow.py index 69b4923..26efafc 100644 --- a/tests/test_flow.py +++ b/tests/test_flow.py @@ -57,7 +57,9 @@ async def test_authorization_code_flow_plain_code_challenge(): location = urlparse(location) query = dict(parse_qsl(location.query)) assert query["scope"] == scope - assert await db.get_authorization_code(request, client_id, query["code"]) + assert await db.get_authorization_code( + request=request, client_id=client_id, code=query["code"] + ) assert "code" in query location = response.headers["location"] @@ -93,6 +95,7 @@ async def test_authorization_code_flow_plain_code_challenge(): client_id=client_id, access_token=response.content["access_token"], refresh_token=response.content["refresh_token"], + token_type="Bearer", ) access_token = response.content["access_token"] @@ -122,6 +125,7 @@ async def test_authorization_code_flow_plain_code_challenge(): client_id=client_id, access_token=response.content["access_token"], refresh_token=response.content["refresh_token"], + token_type="access_token", ) # Check that previous token was revoken token_in_db = await db.get_token( @@ -129,6 +133,7 @@ async def test_authorization_code_flow_plain_code_challenge(): client_id=client_id, access_token=access_token, refresh_token=refresh_token, + token_type="access_token", ) assert 
token_in_db.revoked # type: ignore @@ -138,6 +143,7 @@ async def test_authorization_code_flow_plain_code_challenge(): client_id=client_id, access_token=response.content["access_token"], refresh_token=response.content["refresh_token"], + token_type="access_token", ) assert set(enforce_list(new_token.scope)) == set(enforce_list(token_in_db.scope)) # type: ignore @@ -201,7 +207,9 @@ async def test_authorization_code_flow_pkce_code_challenge(): await check_request_validators(request, server.create_token_response) - code_record = await db.get_authorization_code(request, client_id, code) + code_record = await db.get_authorization_code( + request=request, client_id=client_id, code=code + ) assert code_record response = await server.create_token_response(request) @@ -210,7 +218,9 @@ async def test_authorization_code_flow_pkce_code_challenge(): assert response.content["scope"] == scope assert response.content["token_type"] == "Bearer" - code_record = await db.get_authorization_code(request, client_id, code) + code_record = await db.get_authorization_code( + request=request, client_id=client_id, code=code + ) assert not code_record From 69a0b30747ac24537bdd9a4eeca9a83d40ded63a Mon Sep 17 00:00:00 2001 From: Ali Aliyev Date: Sun, 17 Nov 2024 02:28:02 +0400 Subject: [PATCH 19/57] fix: set the naming convention for Unpack arguments --- aioauth/storage.py | 34 +++++++++++++++------------------- tests/classes.py | 34 +++++++++++++++------------------- 2 files changed, 30 insertions(+), 38 deletions(-) diff --git a/aioauth/storage.py b/aioauth/storage.py index 2dffcdf..59c32be 100644 --- a/aioauth/storage.py +++ b/aioauth/storage.py @@ -40,19 +40,19 @@ class TypedDict(Generic[UserType], _TypedDict): ... class TypedDict(Generic[UserType]): ... 
-class AuthorizationCodeGet(TypedDict[UserType]): +class GetAuthorizationCodeArgs(TypedDict[UserType]): request: Request[UserType] client_id: str code: str -class ClientStorageGetClient(TypedDict[UserType]): +class GetClientArgs(TypedDict[UserType]): request: Request[UserType] client_id: str client_secret: NotRequired[Optional[str]] -class IDTokenGetIdToken(TypedDict[UserType]): +class GetIdTokenArgs(TypedDict[UserType]): request: Request[UserType] client_id: str scope: str @@ -61,7 +61,7 @@ class IDTokenGetIdToken(TypedDict[UserType]): nonce: Optional[str] -class ArgsAuthorizationCode(TypedDict[UserType]): +class CreateAuthorizationCodeArgs(TypedDict[UserType]): request: Request[UserType] client_id: str scope: str @@ -73,7 +73,7 @@ class ArgsAuthorizationCode(TypedDict[UserType]): nonce: NotRequired[Optional[str]] -class TokenStorageCreateToken(TypedDict[UserType]): +class CreateTokenArgs(TypedDict[UserType]): request: Request[UserType] client_id: str scope: str @@ -81,7 +81,7 @@ class TokenStorageCreateToken(TypedDict[UserType]): refresh_token: str -class TokenStorageGetToken(TypedDict[UserType]): +class GetTokenArgs(TypedDict[UserType]): request: Request[UserType] client_id: str token_type: Optional[TokenType] # default is "refresh_token" @@ -89,7 +89,7 @@ class TokenStorageGetToken(TypedDict[UserType]): refresh_token: Optional[str] # default is None -class TokenStorageRevokeToken(TypedDict[UserType]): +class RevokeTokenArgs(TypedDict[UserType]): request: Request[UserType] refresh_token: Optional[str] token_type: Optional[TokenType] @@ -97,9 +97,7 @@ class TokenStorageRevokeToken(TypedDict[UserType]): class TokenStorage(Generic[UserType]): - async def create_token( - self, **kwargs: Unpack[TokenStorageCreateToken[UserType]] - ) -> Token: + async def create_token(self, **kwargs: Unpack[CreateTokenArgs[UserType]]) -> Token: """Generates a user token and stores it in the database. 
Used by: @@ -124,7 +122,7 @@ async def create_token( raise NotImplementedError("Method create_token must be implemented") async def get_token( - self, **kwargs: Unpack[TokenStorageGetToken[UserType]] + self, **kwargs: Unpack[GetTokenArgs[UserType]] ) -> Optional[Token]: """Gets existing token from the database. @@ -142,9 +140,7 @@ async def get_token( """ raise NotImplementedError("Method get_token must be implemented") - async def revoke_token( - self, **kwargs: Unpack[TokenStorageRevokeToken[UserType]] - ) -> None: + async def revoke_token(self, **kwargs: Unpack[RevokeTokenArgs[UserType]]) -> None: """Revokes a token from the database.""" raise NotImplementedError @@ -152,7 +148,7 @@ async def revoke_token( class AuthorizationCodeStorage(Generic[UserType]): async def create_authorization_code( self, - **kwargs: Unpack[ArgsAuthorizationCode[UserType]], + **kwargs: Unpack[CreateAuthorizationCodeArgs[UserType]], ) -> AuthorizationCode: """Generates an authorization token and stores it in the database. @@ -178,7 +174,7 @@ async def create_authorization_code( async def get_authorization_code( self, - **kwargs: Unpack[AuthorizationCodeGet[UserType]], + **kwargs: Unpack[GetAuthorizationCodeArgs[UserType]], ) -> Optional[AuthorizationCode]: """Gets existing authorization code from the database if it exists. @@ -202,7 +198,7 @@ async def get_authorization_code( async def delete_authorization_code( self, - **kwargs: Unpack[AuthorizationCodeGet[UserType]], + **kwargs: Unpack[GetAuthorizationCodeArgs[UserType]], ) -> None: """Deletes authorization code from database. @@ -222,7 +218,7 @@ async def delete_authorization_code( class ClientStorage(Generic[UserType]): async def get_client( self, - **kwargs: Unpack[ClientStorageGetClient[UserType]], + **kwargs: Unpack[GetClientArgs[UserType]], ) -> Optional[Client[UserType]]: """Gets existing client from the database if it exists. 
@@ -262,7 +258,7 @@ async def get_user(self, request: Request[UserType]) -> Optional[UserType]: class IDTokenStorage(Generic[UserType]): async def get_id_token( self, - **kwargs: Unpack[IDTokenGetIdToken[UserType]], + **kwargs: Unpack[GetIdTokenArgs[UserType]], ) -> str: """Returns an id_token. For more information see `OpenID Connect Core 1.0 incorporating errata set 1 section 2 `_. diff --git a/tests/classes.py b/tests/classes.py index 3678baf..7ded610 100644 --- a/tests/classes.py +++ b/tests/classes.py @@ -12,14 +12,14 @@ from aioauth.response_type import ResponseTypeBase from aioauth.server import AuthorizationServer from aioauth.storage import ( - ArgsAuthorizationCode, - AuthorizationCodeGet, BaseStorage, - ClientStorageGetClient, - IDTokenGetIdToken, - TokenStorageCreateToken, - TokenStorageGetToken, - TokenStorageRevokeToken, + CreateAuthorizationCodeArgs, + CreateTokenArgs, + GetAuthorizationCodeArgs, + GetClientArgs, + GetIdTokenArgs, + GetTokenArgs, + RevokeTokenArgs, ) from aioauth.types import GrantType, ResponseType @@ -64,7 +64,7 @@ def _get_by_client_id(self, client_id: str): async def get_client( self, - **kwargs: Unpack[ClientStorageGetClient[User]], + **kwargs: Unpack[GetClientArgs[User]], ) -> Optional[Client]: client_secret = kwargs.get("client_secret") client_id = kwargs["client_id"] @@ -74,7 +74,7 @@ async def get_client( return self._get_by_client_id(client_id) - async def create_token(self, **kwargs: Unpack[TokenStorageCreateToken[User]]): + async def create_token(self, **kwargs: Unpack[CreateTokenArgs[User]]): client_id = kwargs["client_id"] request = kwargs["request"] access_token = kwargs["access_token"] @@ -93,9 +93,7 @@ async def create_token(self, **kwargs: Unpack[TokenStorageCreateToken[User]]): self.tokens.append(token) return token - async def revoke_token( - self, **kwargs: Unpack[TokenStorageRevokeToken[User]] - ) -> None: + async def revoke_token(self, **kwargs: Unpack[RevokeTokenArgs[User]]) -> None: tokens = self.tokens 
refresh_token = kwargs["refresh_token"] access_token = kwargs["access_token"] @@ -105,9 +103,7 @@ async def revoke_token( elif token_.access_token == access_token: tokens[key] = replace(token_, revoked=True) - async def get_token( - self, **kwargs: Unpack[TokenStorageGetToken[User]] - ) -> Optional[Token]: + async def get_token(self, **kwargs: Unpack[GetTokenArgs[User]]) -> Optional[Token]: refresh_token = kwargs["refresh_token"] access_token = kwargs["access_token"] client_id = kwargs["client_id"] @@ -140,7 +136,7 @@ async def get_user(self, request: Request[User]) -> Optional[User]: async def create_authorization_code( self, - **kwargs: Unpack[ArgsAuthorizationCode[User]], + **kwargs: Unpack[CreateAuthorizationCodeArgs[User]], ): request = kwargs["request"] nonce = kwargs.get("nonce") @@ -169,7 +165,7 @@ async def create_authorization_code( async def get_authorization_code( self, - **kwargs: Unpack[AuthorizationCodeGet[User]], + **kwargs: Unpack[GetAuthorizationCodeArgs[User]], ) -> Optional[AuthorizationCode]: code = kwargs["code"] client_id = kwargs["client_id"] @@ -183,7 +179,7 @@ async def get_authorization_code( async def delete_authorization_code( self, - **kwargs: Unpack[AuthorizationCodeGet[User]], + **kwargs: Unpack[GetAuthorizationCodeArgs[User]], ): code = kwargs["code"] client_id = kwargs["client_id"] @@ -195,7 +191,7 @@ async def delete_authorization_code( ): authorization_codes.remove(authorization_code) - async def get_id_token(self, **kwargs: Unpack[IDTokenGetIdToken[User]]) -> str: + async def get_id_token(self, **kwargs: Unpack[GetIdTokenArgs[User]]) -> str: return "generated id token" From 01c1f3402d9f8be8f67635c2590794810dfcf30e Mon Sep 17 00:00:00 2001 From: Ali Aliyev Date: Sun, 17 Nov 2024 02:28:13 +0400 Subject: [PATCH 20/57] fix: set the default fixture loop to fix the warning --- pyproject.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 8fe452e..5988e78 100644 --- a/pyproject.toml +++ 
b/pyproject.toml @@ -79,6 +79,8 @@ include = ["aioauth", "aioauth.*"] [tool.pytest.ini_options] addopts = "-s --strict-markers -vv --cache-clear --maxfail=1" pythonpath = ["."] +asyncio_mode = "auto" +asyncio_default_fixture_loop_scope = "function" [tool.mypy] python_version = "3.9" From 986f45b6b2c2cecbf268f38fd32d097abcafc7d3 Mon Sep 17 00:00:00 2001 From: Ali Aliyev Date: Sun, 17 Nov 2024 10:16:18 +0400 Subject: [PATCH 21/57] chore: removed redundant if sys.version_info... --- aioauth/response_type.py | 8 +------- aioauth/server.py | 8 +------- aioauth/storage.py | 13 ++++++------- aioauth/types.py | 7 +------ tests/classes.py | 6 +----- 5 files changed, 10 insertions(+), 32 deletions(-) diff --git a/aioauth/response_type.py b/aioauth/response_type.py index a6dc5cb..08aca2e 100644 --- a/aioauth/response_type.py +++ b/aioauth/response_type.py @@ -8,18 +8,12 @@ ---- """ -import sys -from typing import Generic, Tuple +from typing import Generic, Tuple, get_args from .requests import Request from .types import UserType from .storage import BaseStorage -if sys.version_info >= (3, 8): - from typing import get_args -else: - from typing_extensions import get_args - from .utils import generate_token from .errors import ( InvalidClientError, diff --git a/aioauth/server.py b/aioauth/server.py index f68dffd..c6e89b0 100644 --- a/aioauth/server.py +++ b/aioauth/server.py @@ -17,21 +17,15 @@ ---- """ -import sys from dataclasses import asdict from http import HTTPStatus -from typing import Any, Dict, Generic, List, Optional, Tuple, Type, Union +from typing import Any, Dict, Generic, List, Optional, Tuple, Type, Union, get_args from .requests import Request from .types import UserType from .storage import BaseStorage -if sys.version_info >= (3, 8): - from typing import get_args -else: - from typing_extensions import get_args - from .collections import HTTPHeaderDict from .constances import default_headers from .errors import ( diff --git a/aioauth/storage.py 
b/aioauth/storage.py index 59c32be..afd8586 100644 --- a/aioauth/storage.py +++ b/aioauth/storage.py @@ -20,14 +20,13 @@ from .types import UserType if sys.version_info >= (3, 11): - from typing import Unpack, NotRequired + from typing import Unpack, NotRequired, TypedDict else: - from typing_extensions import Unpack, NotRequired - -if sys.version_info >= (3, 11): - from typing import TypedDict -else: - from typing_extensions import TypedDict as _TypedDict + from typing_extensions import ( + TypedDict as _TypedDict, + Unpack, + NotRequired, + ) # NOTE: workaround for Python < 3.11 # https://github.com/python/cpython/issues/89026 diff --git a/aioauth/types.py b/aioauth/types.py index 7cc8dbd..f7defe2 100644 --- a/aioauth/types.py +++ b/aioauth/types.py @@ -9,7 +9,7 @@ """ import sys -from typing import Any +from typing import Any, Literal if sys.version_info >= (3, 13): from typing import TypeVar @@ -21,11 +21,6 @@ else: from typing_extensions import TypeAlias -if sys.version_info >= (3, 8): - from typing import Literal -else: - from typing_extensions import Literal - ErrorType: TypeAlias = Literal[ "invalid_request", diff --git a/tests/classes.py b/tests/classes.py index 7ded610..5cf1fb0 100644 --- a/tests/classes.py +++ b/tests/classes.py @@ -2,6 +2,7 @@ import sys from typing import Dict, List, Optional, Type +from functools import cached_property from dataclasses import replace, dataclass @@ -23,11 +24,6 @@ ) from aioauth.types import GrantType, ResponseType -if sys.version_info >= (3, 8): - from functools import cached_property -else: - from backports.cached_property import cached_property - if sys.version_info >= (3, 11): from typing import Unpack else: From b5950e8f56041a9bbc7bbd9dc0bcd79ae9b2ab61 Mon Sep 17 00:00:00 2001 From: Ali Aliyev Date: Sun, 17 Nov 2024 11:09:12 +0400 Subject: [PATCH 22/57] fix: rolled back ci.yml changes --- .github/workflows/ci.yml | 31 ++++++++++++------------------- 1 file changed, 12 insertions(+), 19 deletions(-) diff --git 
a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 7d10710..fa6c7f9 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -1,3 +1,6 @@ +# This workflow will install Python dependencies, run tests and lint with a variety of Python versions +# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions + name: CI on: @@ -7,38 +10,28 @@ on: branches: [ master ] jobs: - install-and-test: + build: runs-on: ubuntu-latest strategy: matrix: python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] steps: - uses: actions/checkout@v2 - - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v2 with: python-version: ${{ matrix.python-version }} - - name: Install dependencies run: | - python -m pip install --upgrade pip make dev-install pip install codecov - - name: Run lint - run: make lint - + run: | + make lint - name: Run tests - run: make test - - - name: Upload coverage to Codecov - uses: codecov/codecov-action@v4 - with: - token: ${{ secrets.CODECOV_TOKEN }} - - - name: Upload test results to Codecov - if: ${{ !cancelled() }} - uses: codecov/test-results-action@v1 - with: - token: ${{ secrets.CODECOV_TOKEN }} + run: | + make test + - name: Upload test coverage + run: codecov + env: + CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} From e30e12fadff3f5fa297d9b6b35cdd3900f4f0733 Mon Sep 17 00:00:00 2001 From: Ali Aliyev Date: Sun, 17 Nov 2024 11:25:16 +0400 Subject: [PATCH 23/57] fix: TypedDict generic workaround --- aioauth/storage.py | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/aioauth/storage.py b/aioauth/storage.py index afd8586..4b779a4 100644 --- a/aioauth/storage.py +++ b/aioauth/storage.py @@ -20,23 +20,21 @@ from .types import UserType if sys.version_info >= (3, 11): - from typing import Unpack, NotRequired, TypedDict + from typing import NotRequired, Unpack else: - from typing_extensions import ( - TypedDict as 
_TypedDict, - Unpack, - NotRequired, - ) + from typing_extensions import NotRequired, Unpack - # NOTE: workaround for Python < 3.11 - # https://github.com/python/cpython/issues/89026 - if TYPE_CHECKING: +from typing import TypedDict as _TypedDict - class TypedDict(Generic[UserType], _TypedDict): ... +# NOTE: workaround for generic TypedDict support +# https://github.com/python/cpython/issues/89026 +if TYPE_CHECKING: - else: + class TypedDict(Generic[UserType], _TypedDict): ... - class TypedDict(Generic[UserType]): ... +else: + + class TypedDict(Generic[UserType]): ... class GetAuthorizationCodeArgs(TypedDict[UserType]): From d4a7ae3d70a1890b1a5547ccee7d4a7c6080d635 Mon Sep 17 00:00:00 2001 From: Ali Aliyev Date: Fri, 22 Nov 2024 23:32:12 +0400 Subject: [PATCH 24/57] fix: passing the client_id to revoke_token --- aioauth/grant_type.py | 1 + aioauth/server.py | 1 + aioauth/storage.py | 1 + tests/classes.py | 2 +- tests/test_db.py | 1 + 5 files changed, 5 insertions(+), 1 deletion(-) diff --git a/aioauth/grant_type.py b/aioauth/grant_type.py index 9b9bb6b..996ce49 100644 --- a/aioauth/grant_type.py +++ b/aioauth/grant_type.py @@ -221,6 +221,7 @@ async def create_token_response( # Revoke old token await self.storage.revoke_token( request=request, + client_id=client.client_id, refresh_token=old_token.refresh_token, token_type="refresh_token", access_token=None, diff --git a/aioauth/server.py b/aioauth/server.py index c6e89b0..2879ee1 100644 --- a/aioauth/server.py +++ b/aioauth/server.py @@ -541,6 +541,7 @@ async def revoke(request: fastapi.Request) -> fastapi.Response: if token: await self.storage.revoke_token( request=request, + client_id=client_id, access_token=access_token, refresh_token=refresh_token, token_type=request.post.token_type_hint, diff --git a/aioauth/storage.py b/aioauth/storage.py index 4b779a4..b778ce2 100644 --- a/aioauth/storage.py +++ b/aioauth/storage.py @@ -88,6 +88,7 @@ class GetTokenArgs(TypedDict[UserType]): class 
RevokeTokenArgs(TypedDict[UserType]): request: Request[UserType] + client_id: str refresh_token: Optional[str] token_type: Optional[TokenType] access_token: Optional[str] diff --git a/tests/classes.py b/tests/classes.py index 5cf1fb0..3495380 100644 --- a/tests/classes.py +++ b/tests/classes.py @@ -61,7 +61,7 @@ def _get_by_client_id(self, client_id: str): async def get_client( self, **kwargs: Unpack[GetClientArgs[User]], - ) -> Optional[Client]: + ) -> Optional[Client[User]]: client_secret = kwargs.get("client_secret") client_id = kwargs["client_id"] diff --git a/tests/test_db.py b/tests/test_db.py index 355eff7..d4e5940 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -64,6 +64,7 @@ async def test_storage_class() -> None: with pytest.raises(NotImplementedError): await db.revoke_token( request=request, + client_id=client.client_id, refresh_token=token.refresh_token, token_type=None, access_token=None, From 41fca276bad8667cc3e068dfb0c397e97815a2ad Mon Sep 17 00:00:00 2001 From: imgurbot12 Date: Thu, 21 Nov 2024 14:21:49 -0700 Subject: [PATCH 25/57] feat: simplified fastapi example --- examples/fastapi_example.py | 103 +++++++++++++++++ examples/requirements.txt | 5 + examples/store/__init__.py | 59 ++++++++++ examples/store/models.py | 61 ++++++++++ examples/store/storage.py | 215 ++++++++++++++++++++++++++++++++++++ 5 files changed, 443 insertions(+) create mode 100644 examples/fastapi_example.py create mode 100644 examples/requirements.txt create mode 100644 examples/store/__init__.py create mode 100644 examples/store/models.py create mode 100644 examples/store/storage.py diff --git a/examples/fastapi_example.py b/examples/fastapi_example.py new file mode 100644 index 0000000..4ceb2e1 --- /dev/null +++ b/examples/fastapi_example.py @@ -0,0 +1,103 @@ +""" +Bare Minimum Example of FastAPI Implementation of AioAuth + +(Supports AuthCode/Token/RefreshToken ONLY) +""" +import json +from http import HTTPStatus +from typing import cast + +from fastapi import 
FastAPI, Request, Depends, Response +from fastapi.responses import RedirectResponse +from fastapi_extras.session import SessionMiddleware +from sqlmodel.ext.asyncio.session import AsyncSession + +from aioauth.collections import HTTPHeaderDict +from aioauth.config import Settings +from aioauth.requests import Post, Query +from aioauth.requests import Request as OAuthRequest +from aioauth.responses import Response as OAuthResponse +from aioauth.types import RequestMethod + +from store import AuthServer, BackendStore, engine, auto_login, lifespan + +app = FastAPI(lifespan=lifespan) +settings = Settings(INSECURE_TRANSPORT=True) + +app.add_middleware(SessionMiddleware) + +async def get_auth_server() -> AuthServer: + """ + initialize oauth authorization server + """ + session = AsyncSession(engine) + storage = BackendStore(session) + return AuthServer(storage) + +async def to_request(request: Request) -> OAuthRequest: + """ + convert fastapi request to aioauth oauth2 request + """ + user = request.session.get('user', None) + form = await request.form() + return OAuthRequest( + headers=HTTPHeaderDict(**request.headers), + method=cast(RequestMethod, request.method), + post=Post(**form), #type: ignore + query=Query(**request.query_params), #type: ignore + settings=settings, + url=str(request.url), + user=user, + ) + +def to_response(response: OAuthResponse) -> Response: + """ + convert aioauth oauth2 response into fastapi response + """ + return Response( + content=json.dumps(response.content), + headers=dict(response.headers), + status_code=response.status_code + ) + +@app.get('/oauth/authorize') +async def authorize( + request: Request, + oauth: AuthServer = Depends(get_auth_server) +) -> Response: + """ + oauth2 authorization endpoint using aioauth + """ + oauthreq = await to_request(request) + response = await oauth.create_authorization_response(oauthreq) + if response.status_code == HTTPStatus.UNAUTHORIZED: + request.session['oauth'] = oauthreq + return 
RedirectResponse('/login') + return to_response(response) + +@app.post('/oauth/tokenize') +async def tokenize( + request: Request, + oauth: AuthServer = Depends(get_auth_server), +): + """ + oauth2 tokenization endpoint using aioauth + """ + oauthreq = await to_request(request) + response = await oauth.create_token_response(oauthreq) + return to_response(response) + +@app.get('/login') +async def login( + request: Request, + oauth: AuthServer = Depends(get_auth_server) +): + """ + barebones "login" page, redirected to when authorize is called before login + """ + # sign in user + oauthreq = request.session['oauth'] + oauthreq.user = await auto_login() + # process authorize request + response = await oauth.create_authorization_response(oauthreq) + return to_response(response) diff --git a/examples/requirements.txt b/examples/requirements.txt new file mode 100644 index 0000000..3fb198f --- /dev/null +++ b/examples/requirements.txt @@ -0,0 +1,5 @@ +aioauth==2.0.0 +fastapi==0.115.5 +fastapi_extras3==0.2.0 +SQLAlchemy==2.0.36 +sqlmodel==0.0.22 diff --git a/examples/store/__init__.py b/examples/store/__init__.py new file mode 100644 index 0000000..6845ee0 --- /dev/null +++ b/examples/store/__init__.py @@ -0,0 +1,59 @@ +""" +Utilis and Implementation for AioAuth Storage Interfaces +""" +from contextlib import asynccontextmanager + +from aioauth.server import AuthorizationServer + +from sqlmodel import SQLModel, select +from sqlmodel.ext.asyncio.session import AsyncSession +from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine + +from .models import Client +from .storage import BackendStore, User + +#** Variables **# +__all__ = ['AuthServer', 'BackendStore', 'engine', 'auto_login', 'lifespan'] + +engine: AsyncEngine = create_async_engine('sqlite+aiosqlite:///:memory:', echo=False, future=True) + +async def auto_login() -> User: + """ + return test user-account simulating login + """ + async with AsyncSession(engine) as conn: + sql = 
select(User).where(User.username == 'test') + return (await conn.exec(sql)).one() + +@asynccontextmanager +async def lifespan(*_): + """ + async database startup/shutdown context-manager + """ + global oauth + # spawn connection pool and ensure tables are made + async with engine.begin() as conn: + await conn.run_sync(SQLModel.metadata.create_all) + # create test records + async with AsyncSession(engine) as session: + user = User( + username='test', + password='password', + ) + client = Client( + client_id='test_client', + client_secret='password', + grant_types='authorization_code,refresh_token', + redirect_uris='http://localhost:3000/redirect', + response_types='code', + scope='email' + ) + session.add(user) + session.add(client) + await session.commit() + yield + # close connections on app closure + await engine.dispose() + +class AuthServer(AuthorizationServer[User]): + pass diff --git a/examples/store/models.py b/examples/store/models.py new file mode 100644 index 0000000..8b9e951 --- /dev/null +++ b/examples/store/models.py @@ -0,0 +1,61 @@ +""" +Database Models for OAuth2 Data Storage +""" +from typing import Optional, List +from sqlmodel import Field, SQLModel, Relationship + +class User(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + + username: str = Field(unique=True, index=True) + password: Optional[str] = None + + user_clients: List['Client'] = Relationship(back_populates='user') + user_auth_codes: List['AuthorizationCode'] = Relationship(back_populates='user') + user_tokens: List['Token'] = Relationship(back_populates='user') + +class Client(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + + client_id: str = Field(unique=True, index=True) + client_secret: Optional[str] + grant_types: str + response_types: str + redirect_uris: str + scope: str + + user_id: Optional[int] = Field(default=None, foreign_key='user.id') + user: User = Relationship(back_populates='user_clients') + +class 
AuthorizationCode(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + + code: str + client_id: str + redirect_uri: str + response_type: str + scope: str + auth_time: int + expires_in: int + code_challenge: Optional[str] + code_challenge_method: Optional[str] + nonce: Optional[str] + + user_id: Optional[int] = Field(default=None, foreign_key='user.id') + user: User = Relationship(back_populates='user_auth_codes') + +class Token(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + + access_token: str + refresh_token: str + scope: str + issued_at: int + expires_in: int + refresh_token_expires_in: int + client_id: str + token_type: str + revoked: bool + + user_id: Optional[int] = Field(default=None, foreign_key='user.id') + user: User = Relationship(back_populates='user_tokens') diff --git a/examples/store/storage.py b/examples/store/storage.py new file mode 100644 index 0000000..eeac6d3 --- /dev/null +++ b/examples/store/storage.py @@ -0,0 +1,215 @@ +""" +Storage Interface Implementations for AioOAuth using SqlModels for Backend +""" +from datetime import datetime, timezone +from aioauth.storage import * +from aioauth.types import ResponseType + +from sqlmodel import select +from sqlmodel.ext.asyncio.session import AsyncSession + +from .models import User +from .models import Client as ClientTable +from .models import AuthorizationCode as AuthCodeTable +from .models import Token as TokenTable + +#** Classes **# + +class ClientStore(ClientStorage[User]): + + def __init__(self, session: AsyncSession): + self.session = session + + async def get_client(self, + request: Request[User], + client_id: str, + client_secret: Optional[str] = None + ) -> Optional[Client[User]]: + """ + """ + sql = select(ClientTable).where(ClientTable.client_id == client_id) + async with self.session: + record = (await self.session.exec(sql)).one_or_none() + if record is None: + return + if client_secret is not None and 
record.client_secret is not None: + if client_secret != record.client_secret: + return + return Client( + client_id=record.client_id, + client_secret=record.client_secret or '', + grant_types=record.grant_types.split(','), #type: ignore + response_types=record.response_types.split(','), #type: ignore + redirect_uris=record.redirect_uris.split(','), + scope=record.scope, + ) + +class AuthCodeStore(AuthorizationCodeStorage[User]): + + def __init__(self, session: AsyncSession): + self.session = session + + async def create_authorization_code(self, + request: Request[User], + client_id: str, + scope: str, + response_type: ResponseType, + redirect_uri: str, + code_challenge_method: Optional[CodeChallengeMethod], + code_challenge: Optional[str], + code: str, + **kwargs + ) -> AuthorizationCode: + """""" + auth_code = AuthorizationCode( + code=code, + client_id=client_id, + redirect_uri=redirect_uri, + response_type=response_type, + scope=scope, + auth_time=int(datetime.now(tz=timezone.utc).timestamp()), + expires_in=300, + code_challenge=code_challenge, + code_challenge_method=code_challenge_method, + user=request.user, + **kwargs, + ) + record = AuthCodeTable( + code=auth_code.code, + client_id=auth_code.client_id, + redirect_uri=auth_code.redirect_uri, + response_type=auth_code.response_type, + scope=auth_code.scope, + auth_time=auth_code.auth_time, + expires_in=auth_code.expires_in, + code_challenge=auth_code.code_challenge, + code_challenge_method=auth_code.code_challenge_method, + nonce=auth_code.nonce, + user_id=request.user.id if request.user else None, + ) + async with self.session: + self.session.add(record) + await self.session.commit() + return auth_code + + async def get_authorization_code(self, + request: Request[User], + client_id: str, + code: str + ) -> Optional[AuthorizationCode]: + """ + """ + async with self.session: + sql = select(AuthCodeTable).where(AuthCodeTable.client_id == client_id) + result = (await self.session.exec(sql)).one_or_none() + if 
result is not None: + return AuthorizationCode( + code=result.code, + client_id=result.client_id, + redirect_uri=result.redirect_uri, + response_type=result.response_type, + scope=result.scope, + auth_time=result.auth_time, + expires_in=result.expires_in, + code_challenge=result.code_challenge, + code_challenge_method=result.code_challenge_method, #type: ignore + nonce=result.nonce, + ) + + async def delete_authorization_code(self, + request: Request[User], client_id: str, code: str) -> None: + """ + """ + async with self.session: + sql = select(AuthCodeTable).where(AuthCodeTable.client_id == client_id) + result = (await self.session.exec(sql)).one() + await self.session.delete(result) + await self.session.commit() + +class TokenStore(TokenStorage[User]): + + def __init__(self, session: AsyncSession): + self.session = session + + async def create_token(self, + request: Request[User], + client_id: str, + scope: str, + access_token: str, + refresh_token: str + ) -> Token: + """ + """ + token = Token( + client_id=client_id, + access_token=access_token, + refresh_token=refresh_token, + scope=scope, + issued_at=int(datetime.now(tz=timezone.utc).timestamp()), + expires_in=300, + refresh_token_expires_in=900, + user=request.user, + ) + record = TokenTable( + client_id=token.client_id, + access_token=token.access_token, + refresh_token=token.refresh_token, + scope=token.scope, + issued_at=token.issued_at, + expires_in=token.expires_in, + refresh_token_expires_in=token.refresh_token_expires_in, + token_type=token.token_type, + revoked=token.revoked, + user_id=token.user.id if token.user else None, + ) + async with self.session: + self.session.add(record) + await self.session.commit() + return token + + async def get_token(self, + request: Request[User], + client_id: str, + token_type: Optional[TokenType] = 'refresh_token', + access_token: Optional[str] = None, + refresh_token: Optional[str] = None + ) -> Optional[Token]: + """ + """ + sql = select(TokenTable) + sql = 
sql.where(TokenTable.refresh_token == refresh_token) \ + if token_type == 'refresh_token' else \ + sql.where(TokenTable.access_token == access_token) + async with self.session: + result = (await self.session.exec(sql)).one_or_none() + if result is not None: + return Token( + client_id=result.client_id, + access_token=result.access_token, + refresh_token=result.refresh_token, + scope=result.scope, + issued_at=result.issued_at, + expires_in=result.expires_in, + refresh_token_expires_in=result.refresh_token_expires_in, + user=result.user, + ) + + async def revoke_token(self, + request: Request[User], + token_type: Optional[TokenType] = 'refresh_token', + access_token: Optional[str] = None, + refresh_token: Optional[str] = None + ) -> None: + """ + """ + sql = select(TokenTable) + sql = sql.where(TokenTable.refresh_token == refresh_token) \ + if token_type == 'refresh_token' else \ + sql.where(TokenTable.access_token == access_token) + async with self.session: + result = (await self.session.exec(sql)).one() + await self.session.delete(result) + await self.session.commit() + +class BackendStore(ClientStore, AuthCodeStore, TokenStore, BaseStorage[User]): + pass From 4f44bc6d2da1b3bf098595b79932b38edf6c0aca Mon Sep 17 00:00:00 2001 From: imgurbot12 Date: Mon, 25 Nov 2024 15:20:54 -0700 Subject: [PATCH 26/57] feat: formatting example and minor tweaks --- examples/README.md | 11 ++ examples/config.json | 21 +++ .../example.py} | 59 ++++--- examples/fastapi/requirements.txt | 5 + examples/pyproject.toml | 29 ++++ examples/requirements.txt | 5 - examples/{store => shared}/__init__.py | 46 ++++-- examples/shared/config.py | 27 ++++ examples/shared/models.py | 66 ++++++++ examples/{store => shared}/storage.py | 145 ++++++++++-------- examples/store/models.py | 61 -------- 11 files changed, 310 insertions(+), 165 deletions(-) create mode 100644 examples/README.md create mode 100644 examples/config.json rename examples/{fastapi_example.py => fastapi/example.py} (67%) create mode 
100644 examples/fastapi/requirements.txt create mode 100644 examples/pyproject.toml delete mode 100644 examples/requirements.txt rename examples/{store => shared}/__init__.py (57%) create mode 100644 examples/shared/config.py create mode 100644 examples/shared/models.py rename examples/{store => shared}/storage.py (67%) delete mode 100644 examples/store/models.py diff --git a/examples/README.md b/examples/README.md new file mode 100644 index 0000000..94ff375 --- /dev/null +++ b/examples/README.md @@ -0,0 +1,11 @@ +## Aioauth Examples + +### FastAPI Example + +Usage: + +```bash +$ cd fastapi +$ pip install -r requirements.txt +$ python3 example.py +``` diff --git a/examples/config.json b/examples/config.json new file mode 100644 index 0000000..a52825e --- /dev/null +++ b/examples/config.json @@ -0,0 +1,21 @@ +{ + "fixtures": { + "users": [ + {"username": "test", "password": "password"} + ], + "clients": [ + { + "client_id": "test_client", + "client_secret": "password", + "grant_types": "authorization_code,refresh_token", + "redirect_uris": "http://localhost:3000/redirect", + "response_types": "code", + "scope": "email" + } + ] + }, + "settings": { + "DEBUG": true, + "INSECURE_TRANSPORT": true + } +} diff --git a/examples/fastapi_example.py b/examples/fastapi/example.py similarity index 67% rename from examples/fastapi_example.py rename to examples/fastapi/example.py index 4ceb2e1..c3d05b5 100644 --- a/examples/fastapi_example.py +++ b/examples/fastapi/example.py @@ -3,6 +3,17 @@ (Supports AuthCode/Token/RefreshToken ONLY) """ + +import os +import sys + +BASE_DIR = os.path.dirname(__file__) +EXAMPLES_DIR = os.path.relpath(os.path.join(BASE_DIR, "../")) +AIOAUTH_DIR = os.path.relpath(os.path.join(BASE_DIR, "../../")) + +sys.path.insert(0, EXAMPLES_DIR) # to import `aioauth/shared` +sys.path.insert(0, AIOAUTH_DIR) # to import `aioauth/aioauth` + import json from http import HTTPStatus from typing import cast @@ -13,19 +24,18 @@ from sqlmodel.ext.asyncio.session import 
AsyncSession from aioauth.collections import HTTPHeaderDict -from aioauth.config import Settings from aioauth.requests import Post, Query from aioauth.requests import Request as OAuthRequest from aioauth.responses import Response as OAuthResponse from aioauth.types import RequestMethod -from store import AuthServer, BackendStore, engine, auto_login, lifespan +from shared import AuthServer, BackendStore, engine, settings, auto_login, lifespan -app = FastAPI(lifespan=lifespan) -settings = Settings(INSECURE_TRANSPORT=True) +app = FastAPI(lifespan=lifespan) app.add_middleware(SessionMiddleware) + async def get_auth_server() -> AuthServer: """ initialize oauth authorization server @@ -34,22 +44,24 @@ async def get_auth_server() -> AuthServer: storage = BackendStore(session) return AuthServer(storage) + async def to_request(request: Request) -> OAuthRequest: """ convert fastapi request to aioauth oauth2 request """ - user = request.session.get('user', None) + user = request.session.get("user", None) form = await request.form() return OAuthRequest( headers=HTTPHeaderDict(**request.headers), method=cast(RequestMethod, request.method), - post=Post(**form), #type: ignore - query=Query(**request.query_params), #type: ignore + post=Post(**form), # type: ignore + query=Query(**request.query_params), # type: ignore settings=settings, url=str(request.url), user=user, ) + def to_response(response: OAuthResponse) -> Response: """ convert aioauth oauth2 response into fastapi response @@ -57,13 +69,13 @@ def to_response(response: OAuthResponse) -> Response: return Response( content=json.dumps(response.content), headers=dict(response.headers), - status_code=response.status_code + status_code=response.status_code, ) -@app.get('/oauth/authorize') + +@app.get("/oauth/authorize") async def authorize( - request: Request, - oauth: AuthServer = Depends(get_auth_server) + request: Request, oauth: AuthServer = Depends(get_auth_server) ) -> Response: """ oauth2 authorization endpoint using 
aioauth @@ -71,14 +83,15 @@ async def authorize( oauthreq = await to_request(request) response = await oauth.create_authorization_response(oauthreq) if response.status_code == HTTPStatus.UNAUTHORIZED: - request.session['oauth'] = oauthreq - return RedirectResponse('/login') + request.session["oauth"] = oauthreq + return RedirectResponse("/login") return to_response(response) -@app.post('/oauth/tokenize') + +@app.post("/oauth/tokenize") async def tokenize( request: Request, - oauth: AuthServer = Depends(get_auth_server), + oauth: AuthServer = Depends(get_auth_server), ): """ oauth2 tokenization endpoint using aioauth @@ -87,17 +100,21 @@ async def tokenize( response = await oauth.create_token_response(oauthreq) return to_response(response) -@app.get('/login') -async def login( - request: Request, - oauth: AuthServer = Depends(get_auth_server) -): + +@app.get("/login") +async def login(request: Request, oauth: AuthServer = Depends(get_auth_server)): """ barebones "login" page, redirected to when authorize is called before login """ # sign in user - oauthreq = request.session['oauth'] + oauthreq = request.session["oauth"] oauthreq.user = await auto_login() # process authorize request response = await oauth.create_authorization_response(oauthreq) return to_response(response) + + +if __name__ == "__main__": + import uvicorn + + uvicorn.run(app) diff --git a/examples/fastapi/requirements.txt b/examples/fastapi/requirements.txt new file mode 100644 index 0000000..36890d5 --- /dev/null +++ b/examples/fastapi/requirements.txt @@ -0,0 +1,5 @@ +uvicorn==0.32.1 +fastapi==0.115.5 +fastapi_extras3==0.3.0 +sqlmodel==0.0.22 +aiosqlite==0.20.0 diff --git a/examples/pyproject.toml b/examples/pyproject.toml new file mode 100644 index 0000000..4fc361d --- /dev/null +++ b/examples/pyproject.toml @@ -0,0 +1,29 @@ +[build-system] +requires = ['setuptools', 'setuptools-scm'] +build-backend = 'setuptools.build_meta' + +[project] +name = 'aioauth_examples' +version = '0.1.0' 
+requires-python = '>=3.9' +dependencies = [ + 'sqlmodel>=0.0.22', + 'aiosqlite>=0.20.0' +] +authors = [ + {name = 'Andrew Scott', email = 'imgurbot12@gmail.com'}, +] +description = 'Simple Aioauth Project Examples.' +classifiers = [ + 'Framework :: FastAPI', + 'Intended Audience :: Developers', + 'License :: OSI Approved :: MIT License', + 'Operating System :: OS Independent', + 'Programming Language :: Python :: 3', + 'Programming Language :: Python :: 3', + 'Programming Language :: Python :: 3.8', + 'Programming Language :: Python :: 3.9', + 'Programming Language :: Python :: 3.10', + 'Programming Language :: Python :: 3.11', + 'Programming Language :: Python :: 3.12', +] diff --git a/examples/requirements.txt b/examples/requirements.txt deleted file mode 100644 index 3fb198f..0000000 --- a/examples/requirements.txt +++ /dev/null @@ -1,5 +0,0 @@ -aioauth==2.0.0 -fastapi==0.115.5 -fastapi_extras3==0.2.0 -SQLAlchemy==2.0.36 -sqlmodel==0.0.22 diff --git a/examples/store/__init__.py b/examples/shared/__init__.py similarity index 57% rename from examples/store/__init__.py rename to examples/shared/__init__.py index 6845ee0..a57334d 100644 --- a/examples/store/__init__.py +++ b/examples/shared/__init__.py @@ -1,30 +1,49 @@ """ -Utilis and Implementation for AioAuth Storage Interfaces +Shared Utilites and Implementation for AioAuth Storage Interfaces """ + from contextlib import asynccontextmanager +import os from aioauth.server import AuthorizationServer from sqlmodel import SQLModel, select from sqlmodel.ext.asyncio.session import AsyncSession from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine +from .config import load_config from .models import Client from .storage import BackendStore, User -#** Variables **# -__all__ = ['AuthServer', 'BackendStore', 'engine', 'auto_login', 'lifespan'] +__all__ = [ + "AuthServer", + "BackendStore", + "engine", + "config", + "settings", + "auto_login", + "lifespan", +] + +CONFIG_PATH = 
os.path.join(os.path.dirname(__file__), "../config.json") + +engine: AsyncEngine = create_async_engine( + "sqlite+aiosqlite:///:memory:", echo=False, future=True +) + +config = load_config(CONFIG_PATH) +settings = config.settings -engine: AsyncEngine = create_async_engine('sqlite+aiosqlite:///:memory:', echo=False, future=True) async def auto_login() -> User: """ return test user-account simulating login """ async with AsyncSession(engine) as conn: - sql = select(User).where(User.username == 'test') + sql = select(User).where(User.username == "test") return (await conn.exec(sql)).one() + @asynccontextmanager async def lifespan(*_): """ @@ -37,16 +56,16 @@ async def lifespan(*_): # create test records async with AsyncSession(engine) as session: user = User( - username='test', - password='password', + username="test", + password="password", ) client = Client( - client_id='test_client', - client_secret='password', - grant_types='authorization_code,refresh_token', - redirect_uris='http://localhost:3000/redirect', - response_types='code', - scope='email' + client_id="test_client", + client_secret="password", + grant_types="authorization_code,refresh_token", + redirect_uris="http://localhost:3000/redirect", + response_types="code", + scope="email", ) session.add(user) session.add(client) @@ -55,5 +74,6 @@ async def lifespan(*_): # close connections on app closure await engine.dispose() + class AuthServer(AuthorizationServer[User]): pass diff --git a/examples/shared/config.py b/examples/shared/config.py new file mode 100644 index 0000000..ccf89b6 --- /dev/null +++ b/examples/shared/config.py @@ -0,0 +1,27 @@ +""" +Global Example Configuration Settings +""" + +from typing import List + +from pydantic import BaseModel +from aioauth.config import Settings + +from .models import User, Client + + +def load_config(fpath: str) -> "Config": + """load configuration from filepath""" + with open(fpath, "r") as f: + json = f.read() + return Config.model_validate_json(json) + + +class 
Fixtures(BaseModel): + users: List[User] + clients: List[Client] + + +class Config(BaseModel): + fixtures: Fixtures + settings: Settings diff --git a/examples/shared/models.py b/examples/shared/models.py new file mode 100644 index 0000000..f854065 --- /dev/null +++ b/examples/shared/models.py @@ -0,0 +1,66 @@ +""" +Database Models for OAuth2 Data Storage +""" + +from typing import Optional, List +from sqlmodel import Field, SQLModel, Relationship + + +class User(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + + username: str = Field(unique=True, index=True) + password: Optional[str] = None + + user_clients: List["Client"] = Relationship(back_populates="user") + user_auth_codes: List["AuthorizationCode"] = Relationship(back_populates="user") + user_tokens: List["Token"] = Relationship(back_populates="user") + + +class Client(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + + client_id: str = Field(unique=True, index=True) + client_secret: Optional[str] + grant_types: str + response_types: str + redirect_uris: str + scope: str + + user_id: Optional[int] = Field(default=None, foreign_key="user.id") + user: User = Relationship(back_populates="user_clients") + + +class AuthorizationCode(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + + code: str + client_id: str + redirect_uri: str + response_type: str + scope: str + auth_time: int + expires_in: int + code_challenge: Optional[str] + code_challenge_method: Optional[str] + nonce: Optional[str] + + user_id: Optional[int] = Field(default=None, foreign_key="user.id") + user: User = Relationship(back_populates="user_auth_codes") + + +class Token(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + + access_token: str + refresh_token: str + scope: str + issued_at: int + expires_in: int + refresh_token_expires_in: int + client_id: str + token_type: str + revoked: bool + + user_id: 
Optional[int] = Field(default=None, foreign_key="user.id") + user: User = Relationship(back_populates="user_tokens") diff --git a/examples/store/storage.py b/examples/shared/storage.py similarity index 67% rename from examples/store/storage.py rename to examples/shared/storage.py index eeac6d3..ffe28c8 100644 --- a/examples/store/storage.py +++ b/examples/shared/storage.py @@ -1,64 +1,75 @@ """ Storage Interface Implementations for AioOAuth using SqlModels for Backend """ + from datetime import datetime, timezone -from aioauth.storage import * -from aioauth.types import ResponseType +from typing import Optional from sqlmodel import select from sqlmodel.ext.asyncio.session import AsyncSession +from aioauth.models import AuthorizationCode, Client, Token +from aioauth.requests import Request +from aioauth.storage import ( + BaseStorage, + ClientStorage, + AuthorizationCodeStorage, + TokenStorage, +) +from aioauth.types import CodeChallengeMethod, TokenType + from .models import User from .models import Client as ClientTable from .models import AuthorizationCode as AuthCodeTable from .models import Token as TokenTable -#** Classes **# class ClientStore(ClientStorage[User]): def __init__(self, session: AsyncSession): self.session = session - async def get_client(self, - request: Request[User], - client_id: str, - client_secret: Optional[str] = None + async def get_client( + self, + request: Request[User], + client_id: str, + client_secret: Optional[str] = None, ) -> Optional[Client[User]]: - """ - """ + """ """ sql = select(ClientTable).where(ClientTable.client_id == client_id) async with self.session: record = (await self.session.exec(sql)).one_or_none() if record is None: - return + return None if client_secret is not None and record.client_secret is not None: if client_secret != record.client_secret: - return + return None return Client( client_id=record.client_id, - client_secret=record.client_secret or '', - grant_types=record.grant_types.split(','), #type: ignore 
- response_types=record.response_types.split(','), #type: ignore - redirect_uris=record.redirect_uris.split(','), + client_secret=record.client_secret or "", + grant_types=record.grant_types.split(","), # type: ignore + response_types=record.response_types.split(","), # type: ignore + redirect_uris=record.redirect_uris.split(","), scope=record.scope, ) + class AuthCodeStore(AuthorizationCodeStorage[User]): def __init__(self, session: AsyncSession): self.session = session - async def create_authorization_code(self, - request: Request[User], - client_id: str, - scope: str, - response_type: ResponseType, - redirect_uri: str, + async def create_authorization_code( + self, + request: Request[User], + client_id: str, + scope: str, + response_type: str, + redirect_uri: str, code_challenge_method: Optional[CodeChallengeMethod], - code_challenge: Optional[str], + code_challenge: Optional[str], code: str, - **kwargs + **kwargs, ) -> AuthorizationCode: """""" auth_code = AuthorizationCode( @@ -92,15 +103,12 @@ async def create_authorization_code(self, await self.session.commit() return auth_code - async def get_authorization_code(self, - request: Request[User], - client_id: str, - code: str + async def get_authorization_code( + self, request: Request[User], client_id: str, code: str ) -> Optional[AuthorizationCode]: - """ - """ + """ """ async with self.session: - sql = select(AuthCodeTable).where(AuthCodeTable.client_id == client_id) + sql = select(AuthCodeTable).where(AuthCodeTable.client_id == client_id) result = (await self.session.exec(sql)).one_or_none() if result is not None: return AuthorizationCode( @@ -112,34 +120,35 @@ async def get_authorization_code(self, auth_time=result.auth_time, expires_in=result.expires_in, code_challenge=result.code_challenge, - code_challenge_method=result.code_challenge_method, #type: ignore + code_challenge_method=result.code_challenge_method, # type: ignore nonce=result.nonce, ) - async def delete_authorization_code(self, - request: 
Request[User], client_id: str, code: str) -> None: - """ - """ + async def delete_authorization_code( + self, request: Request[User], client_id: str, code: str + ) -> None: + """ """ async with self.session: - sql = select(AuthCodeTable).where(AuthCodeTable.client_id == client_id) + sql = select(AuthCodeTable).where(AuthCodeTable.client_id == client_id) result = (await self.session.exec(sql)).one() await self.session.delete(result) await self.session.commit() + class TokenStore(TokenStorage[User]): def __init__(self, session: AsyncSession): self.session = session - async def create_token(self, - request: Request[User], - client_id: str, - scope: str, - access_token: str, - refresh_token: str + async def create_token( + self, + request: Request[User], + client_id: str, + scope: str, + access_token: str, + refresh_token: str, ) -> Token: - """ - """ + """ """ token = Token( client_id=client_id, access_token=access_token, @@ -167,19 +176,21 @@ async def create_token(self, await self.session.commit() return token - async def get_token(self, - request: Request[User], - client_id: str, - token_type: Optional[TokenType] = 'refresh_token', - access_token: Optional[str] = None, - refresh_token: Optional[str] = None + async def get_token( + self, + request: Request[User], + client_id: str, + token_type: Optional[TokenType] = "refresh_token", + access_token: Optional[str] = None, + refresh_token: Optional[str] = None, ) -> Optional[Token]: - """ - """ + """ """ sql = select(TokenTable) - sql = sql.where(TokenTable.refresh_token == refresh_token) \ - if token_type == 'refresh_token' else \ - sql.where(TokenTable.access_token == access_token) + sql = ( + sql.where(TokenTable.refresh_token == refresh_token) + if token_type == "refresh_token" + else sql.where(TokenTable.access_token == access_token) + ) async with self.session: result = (await self.session.exec(sql)).one_or_none() if result is not None: @@ -194,22 +205,26 @@ async def get_token(self, user=result.user, ) - async 
def revoke_token(self, - request: Request[User], - token_type: Optional[TokenType] = 'refresh_token', - access_token: Optional[str] = None, - refresh_token: Optional[str] = None + async def revoke_token( + self, + request: Request[User], + client_id: str, + token_type: Optional[TokenType] = "refresh_token", + access_token: Optional[str] = None, + refresh_token: Optional[str] = None, ) -> None: - """ - """ + """ """ sql = select(TokenTable) - sql = sql.where(TokenTable.refresh_token == refresh_token) \ - if token_type == 'refresh_token' else \ - sql.where(TokenTable.access_token == access_token) + sql = ( + sql.where(TokenTable.refresh_token == refresh_token) + if token_type == "refresh_token" + else sql.where(TokenTable.access_token == access_token) + ) async with self.session: result = (await self.session.exec(sql)).one() await self.session.delete(result) await self.session.commit() + class BackendStore(ClientStore, AuthCodeStore, TokenStore, BaseStorage[User]): pass diff --git a/examples/store/models.py b/examples/store/models.py deleted file mode 100644 index 8b9e951..0000000 --- a/examples/store/models.py +++ /dev/null @@ -1,61 +0,0 @@ -""" -Database Models for OAuth2 Data Storage -""" -from typing import Optional, List -from sqlmodel import Field, SQLModel, Relationship - -class User(SQLModel, table=True): - id: Optional[int] = Field(default=None, primary_key=True) - - username: str = Field(unique=True, index=True) - password: Optional[str] = None - - user_clients: List['Client'] = Relationship(back_populates='user') - user_auth_codes: List['AuthorizationCode'] = Relationship(back_populates='user') - user_tokens: List['Token'] = Relationship(back_populates='user') - -class Client(SQLModel, table=True): - id: Optional[int] = Field(default=None, primary_key=True) - - client_id: str = Field(unique=True, index=True) - client_secret: Optional[str] - grant_types: str - response_types: str - redirect_uris: str - scope: str - - user_id: Optional[int] = 
Field(default=None, foreign_key='user.id') - user: User = Relationship(back_populates='user_clients') - -class AuthorizationCode(SQLModel, table=True): - id: Optional[int] = Field(default=None, primary_key=True) - - code: str - client_id: str - redirect_uri: str - response_type: str - scope: str - auth_time: int - expires_in: int - code_challenge: Optional[str] - code_challenge_method: Optional[str] - nonce: Optional[str] - - user_id: Optional[int] = Field(default=None, foreign_key='user.id') - user: User = Relationship(back_populates='user_auth_codes') - -class Token(SQLModel, table=True): - id: Optional[int] = Field(default=None, primary_key=True) - - access_token: str - refresh_token: str - scope: str - issued_at: int - expires_in: int - refresh_token_expires_in: int - client_id: str - token_type: str - revoked: bool - - user_id: Optional[int] = Field(default=None, foreign_key='user.id') - user: User = Relationship(back_populates='user_tokens') From 6c4a8c405783b9e2afc1cf50df45bbe40d154401 Mon Sep 17 00:00:00 2001 From: imgurbot12 Date: Thu, 5 Dec 2024 13:45:39 -0700 Subject: [PATCH 27/57] chore: reorganize files with flat fastapi-example and requirements --- examples/{fastapi/example.py => fastapi_example.py} | 11 ----------- examples/{fastapi => }/requirements.txt | 0 2 files changed, 11 deletions(-) rename examples/{fastapi/example.py => fastapi_example.py} (90%) rename examples/{fastapi => }/requirements.txt (100%) diff --git a/examples/fastapi/example.py b/examples/fastapi_example.py similarity index 90% rename from examples/fastapi/example.py rename to examples/fastapi_example.py index c3d05b5..1387cc3 100644 --- a/examples/fastapi/example.py +++ b/examples/fastapi_example.py @@ -3,17 +3,6 @@ (Supports AuthCode/Token/RefreshToken ONLY) """ - -import os -import sys - -BASE_DIR = os.path.dirname(__file__) -EXAMPLES_DIR = os.path.relpath(os.path.join(BASE_DIR, "../")) -AIOAUTH_DIR = os.path.relpath(os.path.join(BASE_DIR, "../../")) - -sys.path.insert(0, 
EXAMPLES_DIR) # to import `aioauth/shared` -sys.path.insert(0, AIOAUTH_DIR) # to import `aioauth/aioauth` - import json from http import HTTPStatus from typing import cast diff --git a/examples/fastapi/requirements.txt b/examples/requirements.txt similarity index 100% rename from examples/fastapi/requirements.txt rename to examples/requirements.txt From b7ebb68ab56c07ce40528d2b87b83727a915b19a Mon Sep 17 00:00:00 2001 From: imgurbot12 Date: Fri, 6 Dec 2024 09:35:48 -0700 Subject: [PATCH 28/57] chore: remove unnecessary pyproject.toml --- examples/pyproject.toml | 29 ----------------------------- 1 file changed, 29 deletions(-) delete mode 100644 examples/pyproject.toml diff --git a/examples/pyproject.toml b/examples/pyproject.toml deleted file mode 100644 index 4fc361d..0000000 --- a/examples/pyproject.toml +++ /dev/null @@ -1,29 +0,0 @@ -[build-system] -requires = ['setuptools', 'setuptools-scm'] -build-backend = 'setuptools.build_meta' - -[project] -name = 'aioauth_examples' -version = '0.1.0' -requires-python = '>=3.9' -dependencies = [ - 'sqlmodel>=0.0.22', - 'aiosqlite>=0.20.0' -] -authors = [ - {name = 'Andrew Scott', email = 'imgurbot12@gmail.com'}, -] -description = 'Simple Aioauth Project Examples.' 
-classifiers = [ - 'Framework :: FastAPI', - 'Intended Audience :: Developers', - 'License :: OSI Approved :: MIT License', - 'Operating System :: OS Independent', - 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.8', - 'Programming Language :: Python :: 3.9', - 'Programming Language :: Python :: 3.10', - 'Programming Language :: Python :: 3.11', - 'Programming Language :: Python :: 3.12', -] From e5af02ea970234bc1eb70164c4a1b87a9cb0dc71 Mon Sep 17 00:00:00 2001 From: Ali Aliyev Date: Sat, 7 Dec 2024 18:07:53 +0400 Subject: [PATCH 29/57] fix: added python-multipart as a requirement for fastapi example --- examples/fastapi_example.py | 1 + examples/requirements.txt | 1 + 2 files changed, 2 insertions(+) diff --git a/examples/fastapi_example.py b/examples/fastapi_example.py index 1387cc3..0582f8a 100644 --- a/examples/fastapi_example.py +++ b/examples/fastapi_example.py @@ -3,6 +3,7 @@ (Supports AuthCode/Token/RefreshToken ONLY) """ + import json from http import HTTPStatus from typing import cast diff --git a/examples/requirements.txt b/examples/requirements.txt index 36890d5..3297e66 100644 --- a/examples/requirements.txt +++ b/examples/requirements.txt @@ -3,3 +3,4 @@ fastapi==0.115.5 fastapi_extras3==0.3.0 sqlmodel==0.0.22 aiosqlite==0.20.0 +python-multipart==0.0.19 From 172aae181420de84dce04b3c865f4e928a29ff5d Mon Sep 17 00:00:00 2001 From: imgurbot12 Date: Sun, 8 Dec 2024 20:41:22 -0700 Subject: [PATCH 30/57] fix: avoid mypy dependency resolution error within example on pre-commit check --- pyproject.toml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 5988e78..5ee7304 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -93,6 +93,10 @@ explicit_package_bases = true [tool.mypy-packages] ignore_missing_imports = true +[[tool.mypy.overrides]] +module = 'shared' +ignore_missing_imports = true + [tool.flake8] ignore = ["D10", "E203", "E501", "W503", 
"D205", "D400", "A001", "D210", "D401", "E701"] max-line-length = 88 From 64769d86fbd84a5cb09f71b14b1a0f4362bb0a8a Mon Sep 17 00:00:00 2001 From: "Jose M. Prieto" Date: Sat, 30 Nov 2024 18:49:23 +0100 Subject: [PATCH 31/57] Optional refresh token in implicit grant Solution proposal for issue #98. This commit includes additional logic to make the generation of refresh tokens optional in implicit grant flows so that this package becomes strictly compliant with the OAuth 2.0 specification in section 4.2. For further information, please refer to issue #98. To avoid breaking backward compatibility, this new option has been included as a new optional setting flag, `ISSUE_REFRESH_TOKEN_IMPLICIT_GRANT`, which, when set to `False`, deactivates the refresh token generation in implicit flows. --- aioauth/config.py | 27 +++++++++++++ aioauth/response_type.py | 14 ++++++- aioauth/responses.py | 6 +-- aioauth/server.py | 12 +++++- aioauth/storage.py | 2 +- tests/test_flow.py | 84 ++++++++++++++++++++++++++++------------ 6 files changed, 115 insertions(+), 30 deletions(-) diff --git a/aioauth/config.py b/aioauth/config.py index 2874af1..c6e335e 100644 --- a/aioauth/config.py +++ b/aioauth/config.py @@ -21,6 +21,33 @@ class Settings: REFRESH_TOKEN_EXPIRES_IN: int = TOKEN_EXPIRES_IN * 2 """Refresh token lifetime in seconds. Defaults to TOKEN_EXPIRES_IN * 2 (48 hours).""" + ISSUE_REFRESH_TOKEN_IMPLICIT_GRANT: bool = True + """Issue refresh tokens during implicit grant dialog. + + Note: + Set this flag to `False` to strictly meet the requirements + described in section 4.2 of the RFC 6749 regarding the issuance of refresh + tokens during grant type "Implicit Grant". In particular, as stated in section + 4.2.2 of that RFC: + + > 4.2.2.
Access Token Response + > + > If the resource owner grants the access request, the authorization + > server issues an access token and delivers it to the client by adding + > the following parameters to the fragment component of the redirection + > URI using the "application/x-www-form-urlencoded" format, per + > Appendix B: + > + > [...] + > + > The authorization server MUST NOT issue a refresh token. + + Reference links: + + - https://datatracker.ietf.org/doc/html/rfc6749#section-4.2 + - https://datatracker.ietf.org/doc/html/rfc6749#section-4.2.2 + """ + AUTHORIZATION_CODE_EXPIRES_IN: int = 5 * 60 """Authorization code lifetime in seconds. Defaults to 5 minutes.""" diff --git a/aioauth/response_type.py b/aioauth/response_type.py index e2e975a..b23d79e 100644 --- a/aioauth/response_type.py +++ b/aioauth/response_type.py @@ -7,6 +7,7 @@ ---- """ + import sys from typing import Generic, Tuple @@ -112,8 +113,19 @@ async def create_authorization_response( client.client_id, request.query.scope, generate_token(42), - generate_token(48), + ( + generate_token(48) + if request.settings.ISSUE_REFRESH_TOKEN_IMPLICIT_GRANT + else None + ), ) + if not request.settings.ISSUE_REFRESH_TOKEN_IMPLICIT_GRANT: + return TokenResponse( + expires_in=token.expires_in, + access_token=token.access_token, + scope=token.scope, + token_type=token.token_type, + ) return TokenResponse( expires_in=token.expires_in, refresh_token_expires_in=token.refresh_token_expires_in, diff --git a/aioauth/responses.py b/aioauth/responses.py index 63c02c7..269b9fc 100644 --- a/aioauth/responses.py +++ b/aioauth/responses.py @@ -9,7 +9,7 @@ """ from dataclasses import dataclass, field from http import HTTPStatus -from typing import Dict +from typing import Dict, Optional from .collections import HTTPHeaderDict from .constances import default_headers @@ -52,10 +52,10 @@ class TokenResponse: """ expires_in: int - refresh_token_expires_in: int access_token: str - refresh_token: str scope: str + 
refresh_token_expires_in: Optional[int] = None + refresh_token: Optional[str] = None token_type: str = "Bearer" diff --git a/aioauth/server.py b/aioauth/server.py index b16f9af..736db0c 100644 --- a/aioauth/server.py +++ b/aioauth/server.py @@ -16,6 +16,7 @@ ---- """ + import sys from dataclasses import asdict from http import HTTPStatus @@ -417,7 +418,16 @@ async def authorize(request: fastapi.Request) -> fastapi.Response: response = await response_type.create_authorization_response( request, client ) - responses.update(asdict(response)) + response_asdict = asdict(response) + if ( + isinstance(response_type, ResponseTypeToken) + and not request.settings.ISSUE_REFRESH_TOKEN_IMPLICIT_GRANT + ): + # This is the implicit grant where the generation of refresh token has + # been disabled in settings + response_asdict.pop("refresh_token") + response_asdict.pop("refresh_token_expires_in") + responses.update(response_asdict) # See: https://openid.net/specs/oauth-v2-multiple-response-types-1_0.html#Combinations if "code" in response_type_list: diff --git a/aioauth/storage.py b/aioauth/storage.py index d6d6c45..8f3ec6e 100644 --- a/aioauth/storage.py +++ b/aioauth/storage.py @@ -24,7 +24,7 @@ async def create_token( client_id: str, scope: str, access_token: str, - refresh_token: str, + refresh_token: Optional[str] = None, ) -> TToken: """Generates a user token and stores it in the database. 
diff --git a/tests/test_flow.py b/tests/test_flow.py index 39339d6..0ed1a52 100644 --- a/tests/test_flow.py +++ b/tests/test_flow.py @@ -3,6 +3,7 @@ import pytest +from aioauth.config import Settings from aioauth.constances import default_headers from aioauth.requests import Post, Query, Request from aioauth.utils import ( @@ -215,9 +216,17 @@ async def test_authorization_code_flow_pkce_code_challenge(): @pytest.mark.asyncio -async def test_implicit_flow(context_factory): +@pytest.mark.parametrize( + ids=["default_settings", "no_issue_refresh_token_implicit"], + argnames="settings", + argvalues=[None, Settings(ISSUE_REFRESH_TOKEN_IMPLICIT_GRANT=False)], +) +async def test_implicit_flow(context_factory, settings): username = "username" - context = context_factory(users={username: "password"}) + context = context_factory( + users={username: "password"}, + settings=settings, + ) server = context.server client = context.clients[0] request_url = "https://localhost" @@ -237,6 +246,7 @@ async def test_implicit_flow(context_factory): query=query, method="GET", user=username, + settings=context.settings, ) response = await server.create_authorization_response(request) @@ -500,9 +510,17 @@ async def test_client_credentials_flow_auth_header(context: AuthorizationContext @pytest.mark.asyncio -async def test_multiple_response_types(context_factory): +@pytest.mark.parametrize( + ids=["default_settings", "no_issue_refresh_token_implicit"], + argnames="settings", + argvalues=[None, Settings(ISSUE_REFRESH_TOKEN_IMPLICIT_GRANT=False)], +) +async def test_multiple_response_types(context_factory, settings): username = "username" - context = context_factory(users={username: "password"}) + context = context_factory( + users={username: "password"}, + settings=Settings(ISSUE_REFRESH_TOKEN_IMPLICIT_GRANT=False), + ) server = context.server client = context.clients[0] request_url = "https://localhost" @@ -520,6 +538,7 @@ async def test_multiple_response_types(context_factory): query=query, 
method="GET", user=username, + settings=context.settings, ) await check_request_validators(request, server.create_authorization_response) @@ -532,12 +551,16 @@ async def test_multiple_response_types(context_factory): assert "state" in fragment assert "expires_in" in fragment - assert "refresh_token_expires_in" in fragment assert "access_token" in fragment - assert "refresh_token" in fragment assert "scope" in fragment assert "token_type" in fragment assert "code" in fragment + if context.settings.ISSUE_REFRESH_TOKEN_IMPLICIT_GRANT: + assert "refresh_token_expires_in" in fragment + assert "refresh_token" in fragment + else: + assert "refresh_token_expires_in" not in fragment + assert "refresh_token" not in fragment @pytest.mark.asyncio @@ -576,6 +599,11 @@ async def test_response_type_none(context_factory): @pytest.mark.asyncio +@pytest.mark.parametrize( + ids=["default_settings", "no_issue_refresh_token_implicit"], + argnames="settings", + argvalues=[None, Settings(ISSUE_REFRESH_TOKEN_IMPLICIT_GRANT=False)], +) @pytest.mark.parametrize( "response_mode,", [ @@ -585,9 +613,12 @@ async def test_response_type_none(context_factory): None, ], ) -async def test_response_type_id_token(context_factory, response_mode): +async def test_response_type_id_token(context_factory, response_mode, settings): username = "username" - context = context_factory(users={username: "password"}) + context = context_factory( + users={username: "password"}, + settings=settings, + ) server = context.server client = context.clients[0] request_url = "https://localhost" @@ -607,6 +638,7 @@ async def test_response_type_id_token(context_factory, response_mode): query=query, method="GET", user=username, + settings=context.settings, ) await check_request_validators(request, server.create_authorization_response) @@ -618,43 +650,47 @@ async def test_response_type_id_token(context_factory, response_mode): fragment = dict(parse_qsl(location.fragment)) query = dict(parse_qsl(location.query)) - if 
response_mode == "fragment": + if response_mode == "fragment" or response_mode is None: assert "state" in fragment assert "expires_in" in fragment - assert "refresh_token_expires_in" in fragment assert "access_token" in fragment - assert "refresh_token" in fragment assert "scope" in fragment assert "token_type" in fragment assert "code" in fragment assert "id_token" in fragment + if context.settings.ISSUE_REFRESH_TOKEN_IMPLICIT_GRANT: + assert "refresh_token_expires_in" in fragment + assert "refresh_token" in fragment + else: + assert "refresh_token_expires_in" not in fragment + assert "refresh_token" not in fragment elif response_mode == "form_post": assert "state" in response.content assert "expires_in" in response.content - assert "refresh_token_expires_in" in response.content assert "access_token" in response.content - assert "refresh_token" in response.content assert "scope" in response.content assert "token_type" in response.content assert "code" in response.content assert "id_token" in response.content + if context.settings.ISSUE_REFRESH_TOKEN_IMPLICIT_GRANT: + assert "refresh_token" in response.content + assert "refresh_token_expires_in" in response.content + else: + assert "refresh_token" not in response.content + assert "refresh_token_expires_in" not in response.content elif response_mode == "query": assert "state" in query assert "expires_in" in query - assert "refresh_token_expires_in" in query assert "access_token" in query - assert "refresh_token" in query assert "scope" in query assert "token_type" in query assert "code" in query assert "id_token" in query + if context.settings.ISSUE_REFRESH_TOKEN_IMPLICIT_GRANT: + assert "refresh_token" in query + assert "refresh_token_expires_in" in query + else: + assert "refresh_token" not in query + assert "refresh_token_expires_in" not in query else: - assert "state" in fragment - assert "expires_in" in fragment - assert "refresh_token_expires_in" in fragment - assert "access_token" in fragment - assert 
"refresh_token" in fragment - assert "scope" in fragment - assert "token_type" in fragment - assert "code" in fragment - assert "id_token" in fragment + raise AssertionError("Unexpected value of response_mode") From fa3c677e82034f12cea89c86ef438daed178ce50 Mon Sep 17 00:00:00 2001 From: Ali Aliyev Date: Sat, 21 Dec 2024 23:05:56 +0400 Subject: [PATCH 32/57] fix: for development environment install examples dependencies --- Makefile | 1 + 1 file changed, 1 insertion(+) diff --git a/Makefile b/Makefile index a4a0516..c6ceac4 100644 --- a/Makefile +++ b/Makefile @@ -66,6 +66,7 @@ install: clean ## install the package to the active Python's site-packages dev-install: clean ## install the package and test dependencies for local development python -m pip install --upgrade pip pip install -e ."[dev]" + pip install -r examples/requirements.txt pre-commit install docs-install: ## install packages for local documentation. From d451267b747bdae6c1575296db064ce3c4d88362 Mon Sep 17 00:00:00 2001 From: Ali Aliyev Date: Sat, 21 Dec 2024 23:11:17 +0400 Subject: [PATCH 33/57] fix: typing issues in examples after merging master branch --- examples/shared/models.py | 2 +- examples/shared/storage.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/shared/models.py b/examples/shared/models.py index f854065..a2e33a7 100644 --- a/examples/shared/models.py +++ b/examples/shared/models.py @@ -53,7 +53,7 @@ class Token(SQLModel, table=True): id: Optional[int] = Field(default=None, primary_key=True) access_token: str - refresh_token: str + refresh_token: Optional[str] scope: str issued_at: int expires_in: int diff --git a/examples/shared/storage.py b/examples/shared/storage.py index ffe28c8..efa0f94 100644 --- a/examples/shared/storage.py +++ b/examples/shared/storage.py @@ -142,11 +142,12 @@ def __init__(self, session: AsyncSession): async def create_token( self, + *, request: Request[User], client_id: str, scope: str, access_token: str, - refresh_token: str, + 
refresh_token: Optional[str], ) -> Token: """ """ token = Token( From 4b0c2b8baafce383bdd581d00baa558a3d8fd93f Mon Sep 17 00:00:00 2001 From: imgurbot12 Date: Sat, 21 Dec 2024 19:03:08 -0700 Subject: [PATCH 34/57] chore: rebase with rebase/2.0.0 --- examples/README.md | 73 ++++++++++++++++++- examples/config.json | 4 +- examples/fastapi_example.py | 97 +++++++++++++++++++++++--- examples/screenshots/approve-form.png | Bin 0 -> 3110 bytes examples/screenshots/login-form.png | Bin 0 -> 2882 bytes examples/shared/__init__.py | 33 ++++----- 6 files changed, 176 insertions(+), 31 deletions(-) create mode 100644 examples/screenshots/approve-form.png create mode 100644 examples/screenshots/login-form.png diff --git a/examples/README.md b/examples/README.md index 94ff375..9af9e22 100644 --- a/examples/README.md +++ b/examples/README.md @@ -7,5 +7,76 @@ Usage: ```bash $ cd fastapi $ pip install -r requirements.txt -$ python3 example.py +$ python3 fastapi_example.py +``` + +### Testing + +Initialize an `authorization_code` request with the example server. + +``` +http://localhost:8000/oauth/authorize?client_id=test_client&redirect_uri=https%3A%2F%2Fwww.example.com%2Fredirect&response_type=code&state=somestate&scope=email +``` + +The oauth server authenticates the resource owner (via a login form). + +![login-form](./screenshots/login-form.png) + +The oauth server then checks whether the resource owner approves or +denies the client's access request. + +![login-form](./screenshots/approve-form.png) + +The oauth server will then generate a response as a redirect to the +specified `redirect_uri` in the initial request. 
If there is an error +with the initial client request, or the resource owner denies the +request, the server will respond with an error; otherwise it will +return a success with a generated "authorization-code". + +An error response would look something like this: + +``` +https://www.example.com/redirect?error=access_denied&state=somestate +``` + +Whilst a success looks like this: + +``` +https://www.example.com/redirect?state=somestate&code=EJKOGQhY7KcWjNGI2UbCnOrqAGtRiCEJnAYNwYJ8M5&scope=email +``` + +The client can then request an access-token in exchange for the +authorization-code using the server's token endpoint. + +```bash +curl localhost:8000/oauth/tokenize \ + -u 'test_client:password' \ + -d 'grant_type=authorization_code' \ + -d 'code=EJKOGQhY7KcWjNGI2UbCnOrqAGtRiCEJnAYNwYJ8M5'\ + -d 'redirect_uri=https://www.example.com/redirect' +``` + +The server then responds with the associated `access_token`, `refresh_token`, +and its relevant data: + +```json +{ + "expires_in": 300, + "refresh_token_expires_in": 900, + "access_token": "TIQdQv5FCyBoFtoeGt1tAJ37EJdggl8xgSvCVbdjqD", + "refresh_token": "iJD7Yf4SFuSljmXOhyfjfZelc5J0uIe2P4hwGm4wORCDJyrT", + "scope": "email", + "token_type": "Bearer" +} +``` + +The access-token may be replaced/renewed using the specified `refresh_token` +using the `refresh_token` grant type, which returns the same set of data +as before, with new tokens.
+ +```bash +curl localhost:8000/oauth/tokenize \ + -u 'test_client:password' \ + -d 'grant_type=refresh_token' \ + -d 'refresh_token=iJD7Yf4SFuSljmXOhyfjfZelc5J0uIe2P4hwGm4wORCDJyrT' ``` diff --git a/examples/config.json b/examples/config.json index a52825e..2dcbc6f 100644 --- a/examples/config.json +++ b/examples/config.json @@ -1,14 +1,14 @@ { "fixtures": { "users": [ - {"username": "test", "password": "password"} + {"username": "admin", "password": "admin"} ], "clients": [ { "client_id": "test_client", "client_secret": "password", "grant_types": "authorization_code,refresh_token", - "redirect_uris": "http://localhost:3000/redirect", + "redirect_uris": "https://www.example.com/redirect", "response_types": "code", "scope": "email" } diff --git a/examples/fastapi_example.py b/examples/fastapi_example.py index 0582f8a..e3fb5df 100644 --- a/examples/fastapi_example.py +++ b/examples/fastapi_example.py @@ -5,11 +5,12 @@ """ import json +import html from http import HTTPStatus -from typing import cast +from typing import Optional, cast -from fastapi import FastAPI, Request, Depends, Response -from fastapi.responses import RedirectResponse +from fastapi import FastAPI, Form, Request, Depends, Response +from fastapi.responses import HTMLResponse, RedirectResponse from fastapi_extras.session import SessionMiddleware from sqlmodel.ext.asyncio.session import AsyncSession @@ -19,7 +20,7 @@ from aioauth.responses import Response as OAuthResponse from aioauth.types import RequestMethod -from shared import AuthServer, BackendStore, engine, settings, auto_login, lifespan +from shared import AuthServer, BackendStore, engine, settings, try_login, lifespan app = FastAPI(lifespan=lifespan) @@ -92,13 +93,93 @@ async def tokenize( @app.get("/login") -async def login(request: Request, oauth: AuthServer = Depends(get_auth_server)): +async def login(request: Request, error: Optional[str] = None): """ - barebones "login" page, redirected to when authorize is called before login + 
barebones login page, redirects to approval after completion """ - # sign in user + if "oauth" not in request.session and error is None: + error = "Cannot Login without OAuth Session" + error = html.escape(error) if error else "" # never trust user-input + content = f""" + + +

Login Form

+

{error}

+
+ + + + + + + + + + + + + +
+
+ + + """ + return HTMLResponse(content, status_code=400 if error else 200) + + +@app.post("/login") +async def login_submit( + request: Request, + username: str = Form(), + password: str = Form(), +): + """ + login form submission handler, redirects to approval on success + """ + user = await try_login(username, password) + if user is None: + return await login(request, error="Invalid Username or Password") + request.session["user"] = user + redirect = request.url_for("approve") + return RedirectResponse(redirect, status_code=303) + # # sign in user + + +@app.get("/approve") +async def approve(request: Request): + """ + barebones approval page, finalizes response after completion + """ + if "user" not in request.session: + redirect = request.url_for("login") + return RedirectResponse(redirect) + oauthreq: OAuthRequest = request.session["oauth"] + content = f""" + + +

{oauthreq.query.client_id} would like permissions.

+
+ + +
+ + + """ + return HTMLResponse(content) + + +@app.post("/approve") +async def approve_submit( + request: Request, + approval: int = Form(), + oauth: AuthServer = Depends(get_auth_server), +): + """ """ oauthreq = request.session["oauth"] - oauthreq.user = await auto_login() + oauthreq.user = request.session["user"] + if not approval: + # TODO: generate `permission_denied` response + return await approve(request) # process authorize request response = await oauth.create_authorization_response(oauthreq) return to_response(response) diff --git a/examples/screenshots/approve-form.png b/examples/screenshots/approve-form.png new file mode 100644 index 0000000000000000000000000000000000000000..966649fd01a4e2dad7d68a6adedd525f01f8040b GIT binary patch literal 3110 zcma)8`8U*$_a0dqg9sx_WKEeO`%-9Row5zamWatNl=VHbC#ET7&4`4FF*A&1Y=x|Y z5Hb^4lCm|J3Q6Mi>HYnF{(#SOp7Y#$&b_}q_uTVj*x6VJKqMg`5J@U|3Le$Q zIso=)J`$V&z|OaTfcTNcYXHF>47|q~@N&#OGysD@K=*A$ZqVtN0bo2;{V)i`CE@?_ z004Gg-%GMJE(LX_&rApw^F}w{qfu1 z05G3#A^=)1bLdYu=jP$(RE_Dv-?r6}6%Y{MmxP%FSDZSjC@U$V8{HvxqY3!hq$7Lu zsF<|MB{XFT013HM_JE26R6tfgu-6qL$EoL$7UQ4?EEc<*guqOaf!&>D_P6mICm}VL z57h}0f_&KKF9!#66<(+1!@Ah7ao3IIqxGdE^h3%}5{zl)c+K2z@?kc%%>lsrIJJLqpLFfm%fVG7 zKd55qlZuJ+Kb5`i@2jkw=rKt;mqV{Q?$|VZ1w`F zjoq8mOEc0uwZr}wxw*MfQBk=wZ1(!F0=k+#jy*4QNm>p~TxPSU^DIP7-~t-mEVjZi zGwe7Ud-14}sLec^J;`40XOD`U!(iOeEra(TRaeaYJ=QN)rf|m?jiqC>>_cP7bfa&`^iLq1^ingf6q64f z(RT=C>~5cDtTToR?HkfgK;S&bQy>iidwX$=YQx$e-_48@e${^CrBah$mmpBz67pOk z?MfjN_!DTQ0X?48ph)8$EZ-60!u<2=pN=wWllv=I#1pHsa~!biJ}2aG?}+YioL3n7 z-GQOS1K+TrNpjZH@XyLx;`nLHtjE?*b80HgfBUtpDBC`d)r*vTz11TzrAhBiR8nxW zxvZ*G@B4m(`^afVi-ZS(LdRqa;r^>EXHG=DS$#rX@9W(VfnxiHeOFv znSQ0jLyjyL)YqA64bgm~jcm=*6e4{b29x$@ixr^D|) z8&qji7B_}>I*;eM5Q3ky-b%p4i@0bUxw1OGJum0lW*v!UoZ+_^5Uf|oj4+(%zG2Z> zz)=7BS~M}^yAJ9!5&PMw-PAG|I}=HLd^Y3SJDVxq1I5bE_`$dUhpXY69hiEg-C$`^^I0rDt2{6~UFWpmu(BTbTqC%mw$4=dR^$UP6Q z@A-;XfFZRhPC@>ARBYCY$(alALL0Ja7o)bdxcJUfCOI;g%#~Cp5~JW7J%aimm?8JU-bN2JQmo7^LOb8dnuY5zsO1l!uOOuyZJ01|zV|2SR 
zpi>WLHeLs6-`|PJ%Om;H><6Baqo0M}nXeVB%rzfhTdSHAG>7bTK0q9QRh2#7h1F>o3uVDbzDmFRTBeZnE|s$Nm&4-4{7FP?xpU~; z$Lb7dbo;zebWGLFk=f~pH+O2{2h)q4jQeYoL`Qa=2$=NVX5Uu?e;HpQ&tfX z527r86f||8**_%Ivz=T;zLxSlpZ;QDBDef`i2nU-xiPHublPD&x6AWpMGJ+LWm)e7Qi=(4tB{b*`#XAijQ z`!G8cHyB+&!_<}->lF)Mt_!Hu@HtAk^R>9xg>Wu$awniEB;w( zL z<`Gh$uGaMk(om@Djg}BYlmn%6kVm8Bs4I7MO873M&NUelZTY85-hsjv2J}29h-$JXDc`ifj8c&ToV+X^TkL0j za6FMp?AQ`kcLQ*)JKXd0Juk2!e*%dGZJ&n>|N^=qMwlgo5 z2s)9;ocn^R0rk-GhTl)%4Ib^ey5BM!XH(*`LTPFws?hMG#NXR2O;-ezwqTdO)S;d0 z7IK1yMR(Y%h&&!##5?M&-;U*#Fe-{y%3VFO7Hb<{WYj1V5NnEdMxn>TA-rQA*@A5#b?AaDJN@Jr2~KmIvFTSh@P{rYal)@nVN0gM`4b&_VK3vF`l^v^KW;9 zRZUbmCN3Mh8Q^Y+?$ya^Up+i5ARp+|+WsLlc{LoCjWD_qXQlgL?0`|n;JbDR9hX6Q z_T(m+YA~A6QHvyH_*sK`kO$6U_KwE>uyURdFGcf1v_q}ZS8ZD}PBcv+j6%E~xp%H7 zba<>4I(EmO_wXDmzjgeNxU$b`s5ON=if9>75*meM7JYwpsyLL+(gUWSr*QmD<~$|` qOovGTBIjIv_0wQ3?0+kanizq;(?SxgSvkVL`QFOR#vB{L{$Bx`n049PQ3JgBVE zD1?lRWXrw{3e#XPmiWJV-sgRu@A;m4?m73|-~HWtzUOo9BP)xGN4X`rK_Jl4OQvXR z5C{xAoTowA4o6OLVh0FxWY_94#za(1=1>3twBI*SlHfaB1%Y{lWX_nxwF95iFPwmi z8{K)w0D#G-o+f9cAz(H>%pXhun9ekUg1FI{3ji>jpv!f*i(SIM2_XAQdcI^FYNh{? zji~j(OikD&HKcbYbq@eouL0HqB|GRr?ILh6QA!}R?*XjOL}?)bL0(=dbu<5b>ft#H z&1q0RaZXW>hIs(Qemvv@0CeGWx!BmDDo)W+L%Rp>AKD`oMUhUi=0wH;gPP%qI{6=T zT7T`$6P$N!jjCqpFQ52a)RxWcV_gvd1N1uYV{9flv%r^lH3-PIXc1ryB_g3vWYQ1- zEasa&t1eA2gV~W7RGgCi3)W$mt+y38PTvNGldeSNhr^-qUx32fCeSn1kMgRiu^N@G zA`rExp*T1@w{g}q@N@|Ps8R5^p>lg6V7bUnfCCtg56@`mTv*b5ILLRk4qq>)jFD@=EO-~isO71uMiF;NX z-mO_DN2qZK%AP@51(gmk0UphG#&T=4Ccl2-5Q~I?U%u5HXSUK+7+li*y@MdA6=bI> zD2h&`uF;_~M)=k}HZf%-ebXcg@G9DAZ@75!u${g3(dX{B48^oOUvyKtYhq2OnCts2 z<^r`MsB!&SA2U2W{A)SkN!$1RuNASr*UlVO1gf2dIN(t|fWHw=;9R_?jiIHzodw!t zvMK^)Y#C8Y(^EExC1r+q=sLT)UqETIx}Pu@tfWM;kpPlj<{%}3E$S z2zfQ)jF4AUP|J?To(v!xaLbq-Rn*~u1C>txk82&BE-T3R)$Re|O2RHRfSdG5oM)a4 z(!IiT}x|=EtD|BSzvfwsxXM#$z6T^0HW1^*p5~gu!4G-#tp? 
zwxSVm9|<@#-55nXw2fIZ+`^fSW@2OR1mV7DXG!FOs|%C)#U&6zHfd6b5FT>{NhL3L zm2^|Hv9#2R0&>kMm`smh7hbwMlLNJmqhv)`Yrwqq#O&(db%Dk=vHAmp`ee=8-1^ii z7_%x?Y(+_QU65DFHr%#qmcu^1+_BQPKF;VJ)cOn}Ite1{-{(8;F7BxFi-_56d>4u*S2(i`iQ(^{ zI-_+^`%w#!a%V&6|&4YgwmQ`^&u+a8wmLh+Sn7zpdtmV}|s z2uw6^ug?>U?M#v4zl~Q0mu<4G_RtMxS-xM+#vJ(OYJvkm(aUY5;|752rJApJ4|$9H zr`6FT>&wH!ArIGbmo_0V_kF+n%Q>ca50b66w^2hGsGpA2ANLA4RrM;rO4wW6s7rM}qKwy){ ze3@PUs?Lc3cnKdm!(7yfC@3WD0jpZMPm{-N3M6;j_|&`v0R5%QxILhN7oBK@;jg^ zFG~cTSfY#TA6kwaQ{{s#BBNNopp~imvnt27{F5D@rt9cwnfDl%C~aKoMyXRaWL>do z2{HOb6&X&9UX(xoJp~+RD1Yn?!tohX4QZs;Ik%#%HI1&G-wG165bE)Lk7ZYBvg(hk zs_O<)HljuR)e#Oy6zNCUz+wVt?bfd&#w>`EDv^~f8K)xx za<6;b9f?7qx2p&xJ|NWAlpxJB~Vdw&~WuQb=<*D|?MF2Q8tw^I;!|L`I07J0+f z!xF<($5Z+mAHP98v@a6x$cx{qfNZ-orJ!)WBgwxdexG9#CpcS!#zLaqMJifb!8RW| zXRc{Z#c=Prk&nd_cqF#2z-E)TiEf90=rHpAYkoZ*A>#xPz^E6pR z?{YhCH3km*)9z(KgU%dS&r+=!&=gGXjh$g?MMy1e^_}5Wb00DxTLly!-yl%2OQ2puAhPM~&z;A4)q|u$gu`yk zv@y%||MA)D6<$9MTQ6oB>cXB68_=ptcT=v*t={U@h|mijn&df{Ifb?6g!v#BPGfkE zmP2nFqXcG%ynxSL#q{1S2x?V>?$;q8Zly?Hsuue>GIejW->#g&s8EGlUo)`9^IGtuceVMYXck*#6RjG?P=$Y#4Tc#3js!=XesH$kHgPXdWv?RxS z=zjBbIT*if+{%{|?`cffvMv6@GD&B7-Ai`v$f%avfgaN=ks(B@Q==y;eh8#chnR1h zZ$630>iOK0*ftTFP^$rVek#Jy43NJJS4Y=v#msMSG`cx_BB^)^O%8di;M@e}o-xg4 z#a+dxBvp0}*e0|`BAtv^HkNkrO|zi#$fk`4{fy@i@Vc4PA;DX3L+vZW7u#c-R4t!J z%JpqekGJie%Qvd8J9%smyzzl&-@nu(N2*~J37u04mNIC*Ln@=xHhcbF<1-SFzql#* zH6nj2kkzb&&Q)#)Kb$z-It6iCeNJut`Du?=9Ayr{WwmzdbcgwMmV&HQD09>|0VQ5C z?{8p9#S_)#z1pDn-?B7(SS?(joBjr4HtjsOQ;rrF7f@QuJqt}mS_nTZyZ*+>7E$q6 z8cM8C#;!q_4{v35UDqS`uRcrI+1`CetG@c|7g;!re&L4b$B`2pHk<6Tp1vF>pQka>nK5Pj#HkbP7AT#WYFec>ADsMZSQ`LW4$42lKO`m zRQ*57d$(TVPBU$ucpa!~ePZXpezipj8PO{TJLs;z(5ExS&*(K`n!U6D literal 0 HcmV?d00001 diff --git a/examples/shared/__init__.py b/examples/shared/__init__.py index a57334d..3af6b51 100644 --- a/examples/shared/__init__.py +++ b/examples/shared/__init__.py @@ -5,6 +5,7 @@ from contextlib import asynccontextmanager import os +from typing import Optional from aioauth.server import AuthorizationServer 
from sqlmodel import SQLModel, select @@ -12,7 +13,6 @@ from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine from .config import load_config -from .models import Client from .storage import BackendStore, User __all__ = [ @@ -21,7 +21,7 @@ "engine", "config", "settings", - "auto_login", + "try_login", "lifespan", ] @@ -35,13 +35,16 @@ settings = config.settings -async def auto_login() -> User: +async def try_login(username: str, password: str) -> Optional[User]: """ - return test user-account simulating login + try username and password against user fixtures in database """ async with AsyncSession(engine) as conn: - sql = select(User).where(User.username == "test") - return (await conn.exec(sql)).one() + sql = select(User).where( + User.username == username and User.password == password + ) + record = await conn.exec(sql) + return record.first() @asynccontextmanager @@ -55,20 +58,10 @@ async def lifespan(*_): await conn.run_sync(SQLModel.metadata.create_all) # create test records async with AsyncSession(engine) as session: - user = User( - username="test", - password="password", - ) - client = Client( - client_id="test_client", - client_secret="password", - grant_types="authorization_code,refresh_token", - redirect_uris="http://localhost:3000/redirect", - response_types="code", - scope="email", - ) - session.add(user) - session.add(client) + for user in config.fixtures.users: + session.add(user) + for client in config.fixtures.clients: + session.add(client) await session.commit() yield # close connections on app closure From ad273e7c75b5730c3520330bfcf65b064265ccec Mon Sep 17 00:00:00 2001 From: imgurbot12 Date: Fri, 20 Dec 2024 13:46:44 -0700 Subject: [PATCH 35/57] feat: allow passing extra arbitrary data within oauth request --- aioauth/requests.py | 1 + 1 file changed, 1 insertion(+) diff --git a/aioauth/requests.py b/aioauth/requests.py index e69d560..238adf9 100644 --- a/aioauth/requests.py +++ b/aioauth/requests.py @@ -73,3 +73,4 @@ class 
Request(Generic[UserType]): url: str = "" user: Optional[UserType] = None settings: Settings = field(default_factory=Settings) + extra: dict = field(default_factory=dict) From bda9a0322f8c166904ca597be066c59219c74073 Mon Sep 17 00:00:00 2001 From: imgurbot12 Date: Thu, 5 Dec 2024 14:51:30 -0700 Subject: [PATCH 36/57] feat: support for building custom error-responses --- aioauth/errors.py | 8 +++++ aioauth/types.py | 1 + aioauth/utils.py | 92 +++++++++++++++++++++++++++-------------------- 3 files changed, 62 insertions(+), 39 deletions(-) diff --git a/aioauth/errors.py b/aioauth/errors.py index 2837dd4..bb62daf 100644 --- a/aioauth/errors.py +++ b/aioauth/errors.py @@ -214,3 +214,11 @@ class UnsupportedTokenTypeError(Generic[UserType], OAuth2Error[UserType]): """ error: ErrorType = "unsupported_token_type" + + +class AccessDeniedError(Generic[UserType], OAuth2Error[UserType]): + """ + The resource owner or authorization server denied the request + """ + + error: ErrorType = "access_denied" diff --git a/aioauth/types.py b/aioauth/types.py index f7defe2..986fec4 100644 --- a/aioauth/types.py +++ b/aioauth/types.py @@ -36,6 +36,7 @@ "method_is_not_allowed", "server_error", "temporarily_unavailable", + "access_denied", ] diff --git a/aioauth/utils.py b/aioauth/utils.py index cd9cf0d..d380e71 100644 --- a/aioauth/utils.py +++ b/aioauth/utils.py @@ -227,6 +227,57 @@ def create_s256_code_challenge(code_verifier: str) -> str: return base64.urlsafe_b64encode(data).rstrip(b"=").decode() +def build_error_response( + exc: Exception, + request: Request, + skip_redirect_on_exc: Tuple[Type[OAuth2Error], ...] 
= (OAuth2Error,), +) -> Response: + """ + Generate an OAuth HTTP response from the given exception + + Args: + exc: Exception used to generate HTTP response + request: oauth request object + skip_redirect_on_exc: Exception types to skip redirect on + Returns: + OAuth HTTP response + """ + error: Union[TemporarilyUnavailableError, ServerError] + if isinstance(exc, skip_redirect_on_exc): + content = ErrorResponse(error=exc.error, description=exc.description) + log.debug("%s %r", exc, request) + return Response( + content=asdict(content), + status_code=exc.status_code, + headers=exc.headers, + ) + if isinstance(exc, OAuth2Error): + log.debug("%s %r", exc, request) + query: Dict[str, str] = {"error": exc.error} + if exc.description: + query["error_description"] = exc.description + if request.settings.ERROR_URI: + query["error_uri"] = request.settings.ERROR_URI + if exc.state: + query["state"] = exc.state + location = build_uri(request.query.redirect_uri, query) + return Response( + status_code=HTTPStatus.FOUND, + headers=HTTPHeaderDict({"location": location}), + ) + error = ServerError( + request=request, + description=str(exc) if request.settings.DEBUG else "", + ) + log.exception("Exception caught while processing request.", exc_info=exc) + content = ErrorResponse(error=error.error, description=error.description) + return Response( + content=asdict(content), + status_code=error.status_code, + headers=error.headers, + ) + + def catch_errors_and_unavailability( skip_redirect_on_exc: Tuple[Type[OAuth2Error], ...] 
= (OAuth2Error,) ) -> Callable[..., Callable[..., Coroutine[Any, Any, Response]]]: @@ -242,49 +293,12 @@ def catch_errors_and_unavailability( def decorator(f) -> Callable[..., Coroutine[Any, Any, Response]]: @functools.wraps(f) async def wrapper(self, request: Request, *args, **kwargs) -> Response: - error: Union[TemporarilyUnavailableError, ServerError] - try: response = await f(self, request, *args, **kwargs) - except skip_redirect_on_exc as exc: - content = ErrorResponse(error=exc.error, description=exc.description) - log.debug("%s %r", exc, request) - return Response( - content=asdict(content), - status_code=exc.status_code, - headers=exc.headers, - ) - except OAuth2Error as exc: - log.debug("%s %r", exc, request) - query: Dict[str, str] = { - "error": exc.error, - } - if exc.description: - query["error_description"] = exc.description - if request.settings.ERROR_URI: - query["error_uri"] = request.settings.ERROR_URI - if exc.state: - query["state"] = exc.state - location = build_uri(request.query.redirect_uri, query) - return Response( - status_code=HTTPStatus.FOUND, - headers=HTTPHeaderDict({"location": location}), - ) except Exception as exc: - error = ServerError( - request=request, - description=str(exc) if request.settings.DEBUG else "", - ) - log.exception("Exception caught while processing request.") - content = ErrorResponse( - error=error.error, description=error.description + response = build_error_response( + exc=exc, request=request, skip_redirect_on_exc=skip_redirect_on_exc ) - return Response( - content=asdict(content), - status_code=error.status_code, - headers=error.headers, - ) - return response return wrapper From 888e0a98812c2640e2c29efdc6a018bcc469a8e0 Mon Sep 17 00:00:00 2001 From: imgurbot12 Date: Mon, 23 Dec 2024 13:28:41 -0700 Subject: [PATCH 37/57] feat: access-denied exception on scope denial --- examples/fastapi_example.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/examples/fastapi_example.py 
b/examples/fastapi_example.py index e3fb5df..6e823ed 100644 --- a/examples/fastapi_example.py +++ b/examples/fastapi_example.py @@ -15,10 +15,12 @@ from sqlmodel.ext.asyncio.session import AsyncSession from aioauth.collections import HTTPHeaderDict +from aioauth.errors import AccessDeniedError from aioauth.requests import Post, Query from aioauth.requests import Request as OAuthRequest from aioauth.responses import Response as OAuthResponse from aioauth.types import RequestMethod +from aioauth.utils import build_error_response from shared import AuthServer, BackendStore, engine, settings, try_login, lifespan @@ -174,14 +176,18 @@ async def approve_submit( approval: int = Form(), oauth: AuthServer = Depends(get_auth_server), ): - """ """ + """ + scope approval form submission handler + """ oauthreq = request.session["oauth"] oauthreq.user = request.session["user"] if not approval: - # TODO: generate `permission_denied` response - return await approve(request) - # process authorize request - response = await oauth.create_authorization_response(oauthreq) + # generate error response on deny + error = AccessDeniedError(oauthreq, 'User rejected scopes') + response = build_error_response(error, oauthreq, skip_redirect_on_exc=()) + else: + # process authorize request + response = await oauth.create_authorization_response(oauthreq) return to_response(response) From fcab1effb326c102d7c8ff3065d25612456e8e1f Mon Sep 17 00:00:00 2001 From: Ali Aliyev Date: Tue, 24 Dec 2024 11:57:10 +0400 Subject: [PATCH 38/57] fix: black reformatting issue to pass the pipeline --- examples/fastapi_example.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/fastapi_example.py b/examples/fastapi_example.py index 6e823ed..6bdc8d0 100644 --- a/examples/fastapi_example.py +++ b/examples/fastapi_example.py @@ -183,7 +183,7 @@ async def approve_submit( oauthreq.user = request.session["user"] if not approval: # generate error response on deny - error = 
AccessDeniedError(oauthreq, 'User rejected scopes') + error = AccessDeniedError(oauthreq, "User rejected scopes") response = build_error_response(error, oauthreq, skip_redirect_on_exc=()) else: # process authorize request From 8ad300102888159df7a4fe837d71d5565f8022f8 Mon Sep 17 00:00:00 2001 From: Ali Aliyev Date: Sat, 18 Jan 2025 14:40:06 +0400 Subject: [PATCH 39/57] feat: refactor storage methods to use keyword-only arguments for clarity This PR removes the use of `Unpack` and `TypedDict` in favor of named parameters, as these introduce additional "hacks" to support older Python versions, making the code more complex. Discussion thread: https://github.com/aliev/aioauth/pull/106#issuecomment-2484151019# Please enter the commit message for your changes. Lines starting --- aioauth/storage.py | 139 +++++++++++++++---------------------- examples/shared/storage.py | 25 +++++-- tests/classes.py | 118 +++++++++++++++++-------------- 3 files changed, 139 insertions(+), 143 deletions(-) diff --git a/aioauth/storage.py b/aioauth/storage.py index a51bedc..6682626 100644 --- a/aioauth/storage.py +++ b/aioauth/storage.py @@ -10,8 +10,7 @@ ---- """ -import sys -from typing import TYPE_CHECKING, Optional, Generic +from typing import Optional, Generic from .models import AuthorizationCode, Client, Token from .types import CodeChallengeMethod, TokenType @@ -19,83 +18,17 @@ from .requests import Request from .types import UserType -if sys.version_info >= (3, 11): - from typing import NotRequired, Unpack -else: - from typing_extensions import NotRequired, Unpack - -from typing import TypedDict as _TypedDict - -# NOTE: workaround for generic TypedDict support -# https://github.com/python/cpython/issues/89026 -if TYPE_CHECKING: - - class TypedDict(Generic[UserType], _TypedDict): ... - -else: - - class TypedDict(Generic[UserType]): ... 
- - -class GetAuthorizationCodeArgs(TypedDict[UserType]): - request: Request[UserType] - client_id: str - code: str - - -class GetClientArgs(TypedDict[UserType]): - request: Request[UserType] - client_id: str - client_secret: NotRequired[Optional[str]] - - -class GetIdTokenArgs(TypedDict[UserType]): - request: Request[UserType] - client_id: str - scope: str - response_type: Optional[str] - redirect_uri: str - nonce: Optional[str] - - -class CreateAuthorizationCodeArgs(TypedDict[UserType]): - request: Request[UserType] - client_id: str - scope: str - response_type: str - redirect_uri: str - code_challenge_method: Optional[CodeChallengeMethod] - code_challenge: Optional[str] - code: str - nonce: NotRequired[Optional[str]] - - -class CreateTokenArgs(TypedDict[UserType]): - request: Request[UserType] - client_id: str - scope: str - access_token: str - refresh_token: Optional[str] - - -class GetTokenArgs(TypedDict[UserType]): - request: Request[UserType] - client_id: str - token_type: Optional[TokenType] # default is "refresh_token" - access_token: Optional[str] # default is None - refresh_token: Optional[str] # default is None - - -class RevokeTokenArgs(TypedDict[UserType]): - request: Request[UserType] - client_id: str - refresh_token: Optional[str] - token_type: Optional[TokenType] - access_token: Optional[str] - class TokenStorage(Generic[UserType]): - async def create_token(self, **kwargs: Unpack[CreateTokenArgs[UserType]]) -> Token: + async def create_token( + self, + *, + request: Request[UserType], + client_id: str, + scope: str, + access_token: str, + refresh_token: Optional[str] = None, + ) -> Token: """Generates a user token and stores it in the database. 
Used by: @@ -120,7 +53,13 @@ async def create_token(self, **kwargs: Unpack[CreateTokenArgs[UserType]]) -> Tok raise NotImplementedError("Method create_token must be implemented") async def get_token( - self, **kwargs: Unpack[GetTokenArgs[UserType]] + self, + *, + request: Request[UserType], + client_id: str, + token_type: Optional[TokenType] = None, + access_token: Optional[str] = None, + refresh_token: Optional[str] = None, ) -> Optional[Token]: """Gets existing token from the database. @@ -138,7 +77,15 @@ async def get_token( """ raise NotImplementedError("Method get_token must be implemented") - async def revoke_token(self, **kwargs: Unpack[RevokeTokenArgs[UserType]]) -> None: + async def revoke_token( + self, + *, + request: Request[UserType], + client_id: str, + refresh_token: Optional[str] = None, + token_type: Optional[TokenType] = None, + access_token: Optional[str] = None, + ) -> None: """Revokes a token from the database.""" raise NotImplementedError @@ -146,7 +93,16 @@ async def revoke_token(self, **kwargs: Unpack[RevokeTokenArgs[UserType]]) -> Non class AuthorizationCodeStorage(Generic[UserType]): async def create_authorization_code( self, - **kwargs: Unpack[CreateAuthorizationCodeArgs[UserType]], + *, + request: Request[UserType], + client_id: str, + scope: str, + response_type: str, + redirect_uri: str, + code: str, + code_challenge_method: Optional[CodeChallengeMethod] = None, + code_challenge: Optional[str] = None, + nonce: Optional[str] = None, ) -> AuthorizationCode: """Generates an authorization token and stores it in the database. @@ -172,7 +128,10 @@ async def create_authorization_code( async def get_authorization_code( self, - **kwargs: Unpack[GetAuthorizationCodeArgs[UserType]], + *, + request: Request[UserType], + client_id: str, + code: str, ) -> Optional[AuthorizationCode]: """Gets existing authorization code from the database if it exists. 
@@ -196,7 +155,10 @@ async def get_authorization_code( async def delete_authorization_code( self, - **kwargs: Unpack[GetAuthorizationCodeArgs[UserType]], + *, + request: Request[UserType], + client_id: str, + code: str, ) -> None: """Deletes authorization code from database. @@ -216,7 +178,10 @@ async def delete_authorization_code( class ClientStorage(Generic[UserType]): async def get_client( self, - **kwargs: Unpack[GetClientArgs[UserType]], + *, + request: Request[UserType], + client_id: str, + client_secret: Optional[str] = None, ) -> Optional[Client[UserType]]: """Gets existing client from the database if it exists. @@ -256,7 +221,13 @@ async def get_user(self, request: Request[UserType]) -> Optional[UserType]: class IDTokenStorage(Generic[UserType]): async def get_id_token( self, - **kwargs: Unpack[GetIdTokenArgs[UserType]], + *, + request: Request[UserType], + client_id: str, + scope: str, + redirect_uri: str, + response_type: Optional[str] = None, + nonce: Optional[str] = None, ) -> str: """Returns an id_token. For more information see `OpenID Connect Core 1.0 incorporating errata set 1 section 2 `_. 
diff --git a/examples/shared/storage.py b/examples/shared/storage.py index efa0f94..00e3c8a 100644 --- a/examples/shared/storage.py +++ b/examples/shared/storage.py @@ -31,6 +31,7 @@ def __init__(self, session: AsyncSession): async def get_client( self, + *, request: Request[User], client_id: str, client_secret: Optional[str] = None, @@ -61,15 +62,16 @@ def __init__(self, session: AsyncSession): async def create_authorization_code( self, + *, request: Request[User], client_id: str, scope: str, response_type: str, redirect_uri: str, - code_challenge_method: Optional[CodeChallengeMethod], - code_challenge: Optional[str], code: str, - **kwargs, + code_challenge_method: Optional[CodeChallengeMethod] = None, + code_challenge: Optional[str] = None, + nonce: Optional[str] = None, ) -> AuthorizationCode: """""" auth_code = AuthorizationCode( @@ -83,7 +85,6 @@ async def create_authorization_code( code_challenge=code_challenge, code_challenge_method=code_challenge_method, user=request.user, - **kwargs, ) record = AuthCodeTable( code=auth_code.code, @@ -104,7 +105,11 @@ async def create_authorization_code( return auth_code async def get_authorization_code( - self, request: Request[User], client_id: str, code: str + self, + *, + request: Request[User], + client_id: str, + code: str, ) -> Optional[AuthorizationCode]: """ """ async with self.session: @@ -125,7 +130,11 @@ async def get_authorization_code( ) async def delete_authorization_code( - self, request: Request[User], client_id: str, code: str + self, + *, + request: Request[User], + client_id: str, + code: str, ) -> None: """ """ async with self.session: @@ -147,7 +156,7 @@ async def create_token( client_id: str, scope: str, access_token: str, - refresh_token: Optional[str], + refresh_token: Optional[str] = None, ) -> Token: """ """ token = Token( @@ -179,6 +188,7 @@ async def create_token( async def get_token( self, + *, request: Request[User], client_id: str, token_type: Optional[TokenType] = "refresh_token", @@ -208,6 
+218,7 @@ async def get_token( async def revoke_token( self, + *, request: Request[User], client_id: str, token_type: Optional[TokenType] = "refresh_token", diff --git a/tests/classes.py b/tests/classes.py index 3495380..4e75039 100644 --- a/tests/classes.py +++ b/tests/classes.py @@ -1,5 +1,4 @@ import time -import sys from typing import Dict, List, Optional, Type from functools import cached_property @@ -12,22 +11,14 @@ from aioauth.requests import Request from aioauth.response_type import ResponseTypeBase from aioauth.server import AuthorizationServer -from aioauth.storage import ( - BaseStorage, - CreateAuthorizationCodeArgs, - CreateTokenArgs, - GetAuthorizationCodeArgs, - GetClientArgs, - GetIdTokenArgs, - GetTokenArgs, - RevokeTokenArgs, +from aioauth.storage import BaseStorage +from aioauth.types import ( + CodeChallengeMethod, + GrantType, + ResponseType, + TokenType, + UserType, ) -from aioauth.types import GrantType, ResponseType - -if sys.version_info >= (3, 11): - from typing import Unpack -else: - from typing_extensions import Unpack @dataclass(frozen=True) @@ -60,22 +51,25 @@ def _get_by_client_id(self, client_id: str): async def get_client( self, - **kwargs: Unpack[GetClientArgs[User]], + *, + request: Request[UserType], + client_id: str, + client_secret: Optional[str] = None, ) -> Optional[Client[User]]: - client_secret = kwargs.get("client_secret") - client_id = kwargs["client_id"] - if client_secret is not None: return self._get_by_client_secret(client_id, client_secret) return self._get_by_client_id(client_id) - async def create_token(self, **kwargs: Unpack[CreateTokenArgs[User]]): - client_id = kwargs["client_id"] - request = kwargs["request"] - access_token = kwargs["access_token"] - refresh_token = kwargs["refresh_token"] - scope = kwargs["scope"] + async def create_token( + self, + *, + request: Request[UserType], + client_id: str, + scope: str, + access_token: str, + refresh_token: Optional[str] = None, + ): token: Token[User] = Token( 
client_id=client_id, expires_in=request.settings.TOKEN_EXPIRES_IN, @@ -89,21 +83,31 @@ async def create_token(self, **kwargs: Unpack[CreateTokenArgs[User]]): self.tokens.append(token) return token - async def revoke_token(self, **kwargs: Unpack[RevokeTokenArgs[User]]) -> None: + async def revoke_token( + self, + *, + request: Request[UserType], + client_id: str, + refresh_token: Optional[str] = None, + token_type: Optional[TokenType] = None, + access_token: Optional[str] = None, + ) -> None: tokens = self.tokens - refresh_token = kwargs["refresh_token"] - access_token = kwargs["access_token"] for key, token_ in enumerate(tokens): if token_.refresh_token == refresh_token: tokens[key] = replace(token_, revoked=True) elif token_.access_token == access_token: tokens[key] = replace(token_, revoked=True) - async def get_token(self, **kwargs: Unpack[GetTokenArgs[User]]) -> Optional[Token]: - refresh_token = kwargs["refresh_token"] - access_token = kwargs["access_token"] - client_id = kwargs["client_id"] - + async def get_token( + self, + *, + request: Request[UserType], + client_id: str, + token_type: Optional[TokenType] = None, + access_token: Optional[str] = None, + refresh_token: Optional[str] = None, + ) -> Optional[Token]: for token_ in self.tokens: if ( refresh_token is not None @@ -132,17 +136,17 @@ async def get_user(self, request: Request[User]) -> Optional[User]: async def create_authorization_code( self, - **kwargs: Unpack[CreateAuthorizationCodeArgs[User]], + *, + request: Request[UserType], + client_id: str, + scope: str, + response_type: str, + redirect_uri: str, + code: str, + code_challenge_method: Optional[CodeChallengeMethod] = None, + code_challenge: Optional[str] = None, + nonce: Optional[str] = None, ): - request = kwargs["request"] - nonce = kwargs.get("nonce") - code = kwargs["code"] - client_id = kwargs["client_id"] - redirect_uri = kwargs["redirect_uri"] - response_type = kwargs["response_type"] - scope = kwargs["scope"] - code_challenge_method = 
kwargs["code_challenge_method"] - code_challenge = kwargs["code_challenge"] authorization_code = AuthorizationCode( code=code, client_id=client_id, @@ -161,11 +165,11 @@ async def create_authorization_code( async def get_authorization_code( self, - **kwargs: Unpack[GetAuthorizationCodeArgs[User]], + *, + request: Request[UserType], + client_id: str, + code: str, ) -> Optional[AuthorizationCode]: - code = kwargs["code"] - client_id = kwargs["client_id"] - for authorization_code in self.authorization_codes: if ( authorization_code.code == code @@ -175,10 +179,11 @@ async def get_authorization_code( async def delete_authorization_code( self, - **kwargs: Unpack[GetAuthorizationCodeArgs[User]], + *, + request: Request[UserType], + client_id: str, + code: str, ): - code = kwargs["code"] - client_id = kwargs["client_id"] authorization_codes = self.authorization_codes for authorization_code in authorization_codes: if ( @@ -187,7 +192,16 @@ async def delete_authorization_code( ): authorization_codes.remove(authorization_code) - async def get_id_token(self, **kwargs: Unpack[GetIdTokenArgs[User]]) -> str: + async def get_id_token( + self, + *, + request: Request[UserType], + client_id: str, + scope: str, + redirect_uri: str, + response_type: Optional[str] = None, + nonce: Optional[str] = None, + ) -> str: return "generated id token" From 301036b693549007f8f29e9a71f1aebae204bc0f Mon Sep 17 00:00:00 2001 From: imgurbot12 Date: Mon, 20 Jan 2025 20:56:08 -0700 Subject: [PATCH 40/57] chore: merge with master, refactor for 2.0.0 release --- aioauth/response_type.py | 12 ++-- aioauth/server.py | 120 ++++++++++++++++++++++++++++-------- examples/fastapi_example.py | 21 ++++--- 3 files changed, 116 insertions(+), 37 deletions(-) diff --git a/aioauth/response_type.py b/aioauth/response_type.py index 6bd1063..39031f7 100644 --- a/aioauth/response_type.py +++ b/aioauth/response_type.py @@ -38,7 +38,9 @@ class ResponseTypeBase(Generic[UserType]): def __init__(self, storage: 
BaseStorage[UserType]): self.storage = storage - async def validate_request(self, request: Request[UserType]) -> Client[UserType]: + async def validate_request( + self, request: Request[UserType], skip_user: bool = False + ) -> Client[UserType]: state = request.query.state code_challenge_methods: Tuple[CodeChallengeMethod, ...] = get_args( @@ -90,7 +92,7 @@ async def validate_request(self, request: Request[UserType]) -> Client[UserType] if not client.check_scope(request.query.scope): raise InvalidScopeError[UserType](request=request, state=state) - if not request.user: + if not skip_user and not request.user: raise InvalidClientError[UserType]( request=request, description="User is not authorized", state=state ) @@ -156,8 +158,10 @@ async def create_authorization_response( class ResponseTypeIdToken(ResponseTypeBase[UserType]): - async def validate_request(self, request: Request[UserType]) -> Client[UserType]: - client = await super().validate_request(request) + async def validate_request( + self, request: Request[UserType], skip_user: bool = False + ) -> Client[UserType]: + client = await super().validate_request(request, skip_user) # nonce is required for id_token if not request.query.nonce: diff --git a/aioauth/server.py b/aioauth/server.py index 0f133b3..9d9bddc 100644 --- a/aioauth/server.py +++ b/aioauth/server.py @@ -17,10 +17,11 @@ ---- """ -from dataclasses import asdict +from dataclasses import asdict, dataclass from http import HTTPStatus from typing import Any, Dict, Generic, List, Optional, Tuple, Type, Union, get_args +from .models import Client from .requests import Request from .types import UserType from .storage import BaseStorage @@ -71,6 +72,20 @@ ) +@dataclass +class AuthorizationState(Generic[UserType]): + """AuthorizationServer state object used in Authorization Code process.""" + + request: Request[UserType] + """OAuth2.0 Authorization Code Request Object""" + + response_type_list: List[ResponseType] + """Supported ResponseTypes Collected 
During Initial Request Validation""" + + grants: List[Tuple[ResponseTypeAuthorizationCode[UserType], Client]] + """Collection of Supported GrantType Handlers and The Parsed Clients""" + + class AuthorizationServer(Generic[UserType]): """Interface for initializing an OAuth 2.0 server.""" @@ -341,13 +356,14 @@ async def token(request: fastapi.Request) -> fastapi.Response: InvalidRedirectURIError, ) ) - async def create_authorization_response( + async def validate_authorization_request( self, request: Request[UserType] - ) -> Response: + ) -> Union[Response, AuthorizationState]: """ Endpoint to interact with the resource owner and obtain an - authorization grant. - Validate authorization request and create authorization response. + authoriation grant. + Validate authorization request and return valid authorization + state for later response generation. For more information see `RFC6749 section 4.1.1 `_. @@ -365,8 +381,10 @@ async def create_authorization_response( async def authorize(request: fastapi.Request) -> fastapi.Response: # Converts a fastapi.Request to an aioauth.Request. oauth2_request: aioauth.Request = await to_oauth2_request(request) + # Validate the oauth request + auth_state: aioauth.AuthState = await server.validate_authorization_request(oauth2_request) # Creates the response via this function call. - oauth2_response: aioauth.Response = await server.create_authorization_response(oauth2_request) + oauth2_response: aioauth.Response = await server.create_authorization_response(auth_state) # Converts an aioauth.Response to a fastapi.Response. response: fastapi.Response = await to_fastapi_response(oauth2_response) return response @@ -375,25 +393,12 @@ async def authorize(request: fastapi.Request) -> fastapi.Response: request: An :py:class:`aioauth.requests.Request` object. Returns: - response: An :py:class:`aioauth.responses.Response` object. + state: An :py:class:`aioauth.server.AuthState` object. 
""" self.validate_request(request, ["GET", "POST"]) response_type_list = enforce_list(request.query.response_type) response_type_classes = set() - - # Combined responses - responses = {} - - # URI fragment - fragment = {} - - # URI query params - query = {} - - # Response content - content = {} - state = request.query.state if not response_type_list: @@ -403,9 +408,6 @@ async def authorize(request: fastapi.Request) -> fastapi.Response: state=state, ) - if state: - responses["state"] = state - for response_type in response_type_list: ResponseTypeClass = self.response_types.get(response_type) if ResponseTypeClass: @@ -414,9 +416,79 @@ async def authorize(request: fastapi.Request) -> fastapi.Response: if not response_type_classes: raise UnsupportedResponseTypeError[UserType](request=request, state=state) + auth_state = AuthorizationState(request, response_type_list, grants=[]) for ResponseTypeClass in response_type_classes: response_type = ResponseTypeClass(storage=self.storage) - client = await response_type.validate_request(request) + client = await response_type.validate_request(request, skip_user=True) + auth_state.grants.append((response_type, client)) + return auth_state + + @catch_errors_and_unavailability( + skip_redirect_on_exc=( + MethodNotAllowedError, + InvalidClientError, + InvalidRedirectURIError, + ) + ) + async def create_authorization_response( + self, + auth_state: AuthorizationState[UserType], + ) -> Response: + """ + Endpoint to interact with the resource owner and obtain an + authorization grant. + Create an authorization response after validation. + For more information see + `RFC6749 section 4.1.1 `_. + + Example: + Below is an example utilizing FastAPI as the server framework. + .. code-block:: python + + from aioauth.fastapi.utils import to_oauth2_request, to_fastapi_response + + @app.post("/authorize") + async def authorize(request: fastapi.Request) -> fastapi.Response: + # Converts a fastapi.Request to an aioauth.Request. 
+ oauth2_request: aioauth.Request = await to_oauth2_request(request) + # Validate the oauth request + auth_state: aioauth.AuthState = await server.validate_authorization_request(oauth2_request) + # Creates the response via this function call. + oauth2_response: aioauth.Response = await server.create_authorization_response(auth_state) + # Converts an aioauth.Response to a fastapi.Response. + response: fastapi.Response = await to_fastapi_response(oauth2_response) + return response + + Args: + auth_state: An :py:class:`aioauth.server.AuthState` object. + + Returns: + response: An :py:class:`aioauth.responses.Response` object. + """ + request = auth_state.request + state = auth_state.request.query.state + response_type_list = auth_state.response_type_list + if request.user: + raise InvalidClientError[UserType]( + request=request, description="User is not authorized", state=state + ) + + # Combined responses + responses = {} + + # URI fragment + fragment = {} + + # URI query params + query = {} + + # Response content + content = {} + + if state: + responses["state"] = state + + for response_type, client in auth_state.grants: response = await response_type.create_authorization_response( request, client ) diff --git a/examples/fastapi_example.py b/examples/fastapi_example.py index 6bdc8d0..9a2f66e 100644 --- a/examples/fastapi_example.py +++ b/examples/fastapi_example.py @@ -6,7 +6,6 @@ import json import html -from http import HTTPStatus from typing import Optional, cast from fastapi import FastAPI, Form, Request, Depends, Response @@ -19,6 +18,7 @@ from aioauth.requests import Post, Query from aioauth.requests import Request as OAuthRequest from aioauth.responses import Response as OAuthResponse +from aioauth.server import AuthorizationState as OAuthState from aioauth.types import RequestMethod from aioauth.utils import build_error_response @@ -74,10 +74,13 @@ async def authorize( oauth2 authorization endpoint using aioauth """ oauthreq = await to_request(request) - 
response = await oauth.create_authorization_response(oauthreq) - if response.status_code == HTTPStatus.UNAUTHORIZED: - request.session["oauth"] = oauthreq + auth_state = await oauth.validate_authorization_request(oauthreq) + if isinstance(auth_state, OAuthResponse): + return to_response(auth_state) + if "user" not in request.session: + request.session["oauth"] = auth_state return RedirectResponse("/login") + response = await oauth.create_authorization_response(auth_state) return to_response(response) @@ -155,11 +158,11 @@ async def approve(request: Request): if "user" not in request.session: redirect = request.url_for("login") return RedirectResponse(redirect) - oauthreq: OAuthRequest = request.session["oauth"] + state: OAuthState = request.session["oauth"] content = f""" -

{oauthreq.query.client_id} would like permissions.

+

{state.request.query.client_id} would like permissions.

@@ -179,15 +182,15 @@ async def approve_submit( """ scope approval form submission handler """ - oauthreq = request.session["oauth"] - oauthreq.user = request.session["user"] + state: OAuthState = request.session["oauth"] + oauthreq: OAuthRequest = state.request if not approval: # generate error response on deny error = AccessDeniedError(oauthreq, "User rejected scopes") response = build_error_response(error, oauthreq, skip_redirect_on_exc=()) else: # process authorize request - response = await oauth.create_authorization_response(oauthreq) + response = await oauth.create_authorization_response(state) return to_response(response) From 1d5e6b72d8159c0ad52ff7769619228bc836160a Mon Sep 17 00:00:00 2001 From: Ali Aliyev Date: Thu, 23 Jan 2025 23:21:58 +0400 Subject: [PATCH 41/57] chore: removed unnecessary Literal from errors --- aioauth/errors.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/aioauth/errors.py b/aioauth/errors.py index bb62daf..b3f4041 100644 --- a/aioauth/errors.py +++ b/aioauth/errors.py @@ -11,8 +11,6 @@ from http import HTTPStatus from typing import Generic, Optional from urllib.parse import urljoin -from typing_extensions import Literal - from .requests import Request from .collections import HTTPHeaderDict @@ -72,7 +70,7 @@ class InvalidRequestError(Generic[UserType], OAuth2Error[UserType]): otherwise malformed. """ - error: Literal["invalid_request"] = "invalid_request" + error: ErrorType = "invalid_request" class InvalidClientError(Generic[UserType], OAuth2Error[UserType]): @@ -152,7 +150,7 @@ class MismatchingStateError(Generic[UserType], OAuth2Error[UserType]): """Unable to securely verify the integrity of the request and response.""" description = "CSRF Warning! State not equal in request and response." 
- error: Literal["mismatching_state"] = "mismatching_state" + error: ErrorType = "mismatching_state" class UnauthorizedClientError(Generic[UserType], OAuth2Error[UserType]): From 08939bfb91c745bff29246484b51877168bbf276 Mon Sep 17 00:00:00 2001 From: Ali Aliyev Date: Sun, 26 Jan 2025 02:37:55 +0400 Subject: [PATCH 42/57] fix: codecov support, docs simplify --- .github/workflows/ci.yml | 10 ++-- Makefile | 3 +- README.md | 2 +- docs/source/conf.py | 12 ++-- docs/source/contents.rst | 14 ----- docs/source/index.rst | 2 +- docs/source/sections/examples/aiohttp.rst | 2 - docs/source/sections/examples/fastapi.rst | 58 ------------------- docs/source/sections/using/configuration.rst | 48 --------------- .../source/sections/using/server_database.rst | 2 - 10 files changed, 16 insertions(+), 137 deletions(-) delete mode 100644 docs/source/sections/examples/aiohttp.rst delete mode 100644 docs/source/sections/examples/fastapi.rst delete mode 100644 docs/source/sections/using/configuration.rst delete mode 100644 docs/source/sections/using/server_database.rst diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index fa6c7f9..d893076 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -24,14 +24,14 @@ jobs: - name: Install dependencies run: | make dev-install - pip install codecov - name: Run lint run: | make lint - name: Run tests run: | make test - - name: Upload test coverage - run: codecov - env: - CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} + - name: Upload test results to Codecov + if: ${{ !cancelled() }} + uses: codecov/test-results-action@v1 + with: + token: ${{ secrets.CODECOV_TOKEN }} diff --git a/Makefile b/Makefile index c6ceac4..6c858f1 100644 --- a/Makefile +++ b/Makefile @@ -51,7 +51,8 @@ lint: ## check style with flake8 pre-commit run --all-files test: ## run tests quickly with the default Python - pytest tests + coverage run -m pytest tests + coverage xml -o junit.xml release: dist ## package and upload a release twine upload 
dist/* diff --git a/README.md b/README.md index 2a1801a..d2a7a8b 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ [![Coverage](https://badgen.net/codecov/c/github/aliev/aioauth)](https://app.codecov.io/gh/aliev/aioauth) [![License](https://img.shields.io/github/license/aliev/aioauth)](https://github.com/aliev/aioauth/blob/master/LICENSE) [![PyPi](https://badgen.net/pypi/v/aioauth)](https://pypi.org/project/aioauth/) -[![Python 3.7](https://img.shields.io/badge/python-3.7-blue.svg)](https://www.python.org/downloads/release/python-370/) +[![Python 3.9](https://img.shields.io/badge/python-3.9-blue.svg)](https://www.python.org/downloads/release/python-390/) `aioauth` implements [OAuth 2.0 protocol](https://tools.ietf.org/html/rfc6749) and can be used in asynchronous frameworks like [FastAPI / Starlette](https://github.com/tiangolo/fastapi), [aiohttp](https://github.com/aio-libs/aiohttp). It can work with any databases like `MongoDB`, `PostgreSQL`, `MySQL` and ORMs like [gino](https://python-gino.org/), [sqlalchemy](https://www.sqlalchemy.org/) or [databases](https://pypi.org/project/databases/) over simple [BaseStorage](aioauth/storage.py) interface. diff --git a/docs/source/conf.py b/docs/source/conf.py index 426224d..8a2c900 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -1,20 +1,22 @@ # -- Path setup -------------------------------------------------------------- from pathlib import Path +import tomllib +from aioauth import __version__ # Project root folder. root = Path(__file__).parent.parent.parent # Loads __version__ file. 
about = {} -with open(root / "aioauth" / "__version__.py", "r") as f: - exec(f.read(), about) +with open(root / "pyproject.toml", "rb") as f: + about = tomllib.load(f) # -- Project information ----------------------------------------------------- -project = about["__title__"] -author = about["__author__"] -release = about["__version__"] +project = about["project"]["description"] +author = about["project"]["authors"][0]["name"] +release = __version__ # -- General configuration --------------------------------------------------- diff --git a/docs/source/contents.rst b/docs/source/contents.rst index 9eb2ab2..92b125e 100644 --- a/docs/source/contents.rst +++ b/docs/source/contents.rst @@ -15,20 +15,6 @@ :glob: :maxdepth: 2 - sections/using/* - -.. toctree:: - :caption: Examples - :glob: - :maxdepth: 2 - - sections/examples/* - -.. toctree:: - :caption: Documentation - :glob: - :maxdepth: 3 - sections/documentation/* .. toctree:: diff --git a/docs/source/index.rst b/docs/source/index.rst index 1ccef20..aa73431 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -33,7 +33,7 @@ aioauth supports the following RFCs: Pages ----- -* `Github Project `_ +* `Github Project `_ * `Issues `_ * `Discussion `_ diff --git a/docs/source/sections/examples/aiohttp.rst b/docs/source/sections/examples/aiohttp.rst deleted file mode 100644 index fce8490..0000000 --- a/docs/source/sections/examples/aiohttp.rst +++ /dev/null @@ -1,2 +0,0 @@ -Aiohttp -======= diff --git a/docs/source/sections/examples/fastapi.rst b/docs/source/sections/examples/fastapi.rst deleted file mode 100644 index 2deda94..0000000 --- a/docs/source/sections/examples/fastapi.rst +++ /dev/null @@ -1,58 +0,0 @@ -FastAPI -======= - -Installing ----------- - -To install aioauth with FastAPI at the command line: - -.. code-block:: - - $ pip install aioauth[fastapi] - -Usage example - -.. 
code-block:: python - - from dataclasses import dataclasses - from aioauth_fastapi.router import get_oauth2_router - from aioauth.storage import BaseStorage - from aioauth.requests import BaseRequest, Query, Post - from aioauth.models import AuthorizationCode, Client, Token - from aioauth.config import Settings - from aioauth.server import AuthorizationServer - from fastapi import FastAPI - - app = FastAPI() - - @dataclasses - class User: - """Custom user model""" - first_name: str - last_name: str - - - class Request(BaseRequest[Query, Post, User]): - """Custom Request model""" - - - class Storage(BaseStorage[Token, Client, AuthorizationCode, Request]): - """ - Storage methods must be implemented here. - """ - - storage = Storage() - authorization_server = AuthorizationServer[Request, Storage](storage) - - # NOTE: Redefinition of the default aioauth settings - # INSECURE_TRANSPORT must be enabled for local development only! - settings = Settings( - INSECURE_TRANSPORT=True, - ) - - # Include FastAPI router with oauth2 endpoints. - app.include_router( - get_oauth2_router(authorization_server, settings), - prefix="/oauth2", - tags=["oauth2"], - ) diff --git a/docs/source/sections/using/configuration.rst b/docs/source/sections/using/configuration.rst deleted file mode 100644 index b592e7a..0000000 --- a/docs/source/sections/using/configuration.rst +++ /dev/null @@ -1,48 +0,0 @@ -Configuration -============= - -All aioauth settings are made through :py:class:`aioauth.config.Settings` class. - -Defaults - -+----------------------------------------+---------------+----------------------------------------------------------------+ -| Setting | Default value | Description | -| | | | -+========================================+===============+================================================================+ -| TOKEN_EXPIRES_IN | 86400 | Access token lifetime. 
| -+----------------------------------------+---------------+----------------------------------------------------------------+ -| AUTHORIZATION_CODE_EXPIRES_IN | 300 | Authorization code lifetime. | -+----------------------------------------+---------------+----------------------------------------------------------------+ -| INSECURE_TRANSPORT | False | Allow connections over SSL only. When this option is disabled | -| | | server will raise "HTTP method is not allowed" error. | -+----------------------------------------+---------------+----------------------------------------------------------------+ - -the default settings can be changed as follows: - -.. code-block:: python - - import os - from aioauth.config import Settings - - settings = Settings( - INSECURE_TRANSPORT=not os.getenv('DEBUG', False) - ) - -this example disables checking for insecure transport, depending on the debug mode of the current environment. - -The :py:class:`aioauth.requests.Request` consumes an instance of the :py:class:`aioauth.config.Settings` class: - -.. code-block:: python - - import os - from aioauth.config import Settings - from aioauth.requests import Request - - settings = Settings( - INSECURE_TRANSPORT=not os.getenv('DEBUG', False) - ) - - request = Request( - settings=settings, - ... 
- ) diff --git a/docs/source/sections/using/server_database.rst b/docs/source/sections/using/server_database.rst deleted file mode 100644 index 30f7304..0000000 --- a/docs/source/sections/using/server_database.rst +++ /dev/null @@ -1,2 +0,0 @@ -Server & Database -================= From 805cf3c3d6494d007b3715eeee5bd6d136fb11f7 Mon Sep 17 00:00:00 2001 From: Ali Aliyev Date: Sun, 26 Jan 2025 02:40:47 +0400 Subject: [PATCH 43/57] upd: codecov badge url --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index d2a7a8b..8f5abea 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ ## Asynchronous OAuth 2.0 framework for Python 3 [![Build Status](https://github.com/aliev/aioauth/workflows/CI/badge.svg?branch=master)](https://github.com/aliev/aioauth/actions/workflows/ci.yml?query=branch%3Amaster) -[![Coverage](https://badgen.net/codecov/c/github/aliev/aioauth)](https://app.codecov.io/gh/aliev/aioauth) +[![codecov](https://codecov.io/gh/aliev/aioauth/graph/badge.svg?token=NREOWPB586)](https://codecov.io/gh/aliev/aioauth) [![License](https://img.shields.io/github/license/aliev/aioauth)](https://github.com/aliev/aioauth/blob/master/LICENSE) [![PyPi](https://badgen.net/pypi/v/aioauth)](https://pypi.org/project/aioauth/) [![Python 3.9](https://img.shields.io/badge/python-3.9-blue.svg)](https://www.python.org/downloads/release/python-390/) From 11e519f5f4291c4711835e10e2d26e757d8f863f Mon Sep 17 00:00:00 2001 From: Ali Aliyev Date: Sun, 26 Jan 2025 02:51:50 +0400 Subject: [PATCH 44/57] upd: tests configuration --- Makefile | 3 +-- pyproject.toml | 1 + 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Makefile b/Makefile index 6c858f1..2a2f8fb 100644 --- a/Makefile +++ b/Makefile @@ -51,8 +51,7 @@ lint: ## check style with flake8 pre-commit run --all-files test: ## run tests quickly with the default Python - coverage run -m pytest tests - coverage xml -o junit.xml + pytest --cov --junitxml=junit.xml -o 
junit_family=legacy release: dist ## package and upload a release twine upload dist/* diff --git a/pyproject.toml b/pyproject.toml index 5ee7304..c1a81e8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,6 +50,7 @@ dev = [ "mypy", "bandit", "pre-commit", + "pytest-cov", ] docs = [ From e380fe258ec5b10ca92150d8053d4d2dbc69db80 Mon Sep 17 00:00:00 2001 From: Ali Aliyev Date: Sun, 26 Jan 2025 16:02:44 +0400 Subject: [PATCH 45/57] feat: remove user --- aioauth/errors.py | 38 ++++++++-------- aioauth/grant_type.py | 61 +++++++++++++------------- aioauth/models.py | 27 +++--------- aioauth/oidc/core/grant_type.py | 5 +-- aioauth/oidc/core/requests.py | 3 +- aioauth/requests.py | 6 +-- aioauth/response_type.py | 50 ++++++++++----------- aioauth/server.py | 77 +++++++++++++++------------------ aioauth/storage.py | 44 +++++++++---------- aioauth/types.py | 9 +--- examples/fastapi_example.py | 2 +- examples/shared/__init__.py | 5 ++- examples/shared/storage.py | 35 ++++++++------- tests/classes.py | 35 +++++++-------- tests/conftest.py | 4 +- tests/factories.py | 26 +++++------ tests/oidc/core/test_flow.py | 5 +-- tests/test_db.py | 5 +-- tests/test_endpoint.py | 22 +++++----- tests/test_flow.py | 17 ++------ tests/test_grant_type.py | 4 +- tests/test_request_validator.py | 10 ++--- 22 files changed, 214 insertions(+), 276 deletions(-) diff --git a/aioauth/errors.py b/aioauth/errors.py index b3f4041..15ec982 100644 --- a/aioauth/errors.py +++ b/aioauth/errors.py @@ -9,16 +9,16 @@ """ from http import HTTPStatus -from typing import Generic, Optional +from typing import Optional from urllib.parse import urljoin from .requests import Request from .collections import HTTPHeaderDict from .constances import default_headers -from .types import ErrorType, UserType +from .types import ErrorType -class OAuth2Error(Generic[UserType], Exception): +class OAuth2Error(Exception): """Base exception that all other exceptions inherit from.""" error: ErrorType @@ -30,7 +30,7 @@ class 
OAuth2Error(Generic[UserType], Exception): def __init__( self, - request: Request[UserType], + request: Request, description: Optional[str] = None, headers: Optional[HTTPHeaderDict] = None, state: Optional[str] = None, @@ -52,7 +52,7 @@ def __init__( super().__init__(f"({self.error}) {self.description}") -class MethodNotAllowedError(Generic[UserType], OAuth2Error[UserType]): +class MethodNotAllowedError(OAuth2Error): """ The request is valid, but the method trying to be accessed is not available to the resource owner. @@ -63,7 +63,7 @@ class MethodNotAllowedError(Generic[UserType], OAuth2Error[UserType]): error: ErrorType = "method_is_not_allowed" -class InvalidRequestError(Generic[UserType], OAuth2Error[UserType]): +class InvalidRequestError(OAuth2Error): """ The request is missing a required parameter, includes an invalid parameter value, includes a parameter more than once, or is @@ -73,7 +73,7 @@ class InvalidRequestError(Generic[UserType], OAuth2Error[UserType]): error: ErrorType = "invalid_request" -class InvalidClientError(Generic[UserType], OAuth2Error[UserType]): +class InvalidClientError(OAuth2Error): """ Client authentication failed (e.g. unknown client, no client authentication included, or unsupported authentication method). @@ -108,14 +108,14 @@ def __init__( self.headers["WWW-Authenticate"] = "Basic " + ", ".join(auth_values) -class InsecureTransportError(Generic[UserType], OAuth2Error[UserType]): +class InsecureTransportError(OAuth2Error): """An exception will be thrown if the current request is not secure.""" description = "OAuth 2 MUST utilize https." error: ErrorType = "insecure_transport" -class UnsupportedGrantTypeError(Generic[UserType], OAuth2Error[UserType]): +class UnsupportedGrantTypeError(OAuth2Error): """ The authorization grant type is not supported by the authorization server. 
@@ -124,7 +124,7 @@ class UnsupportedGrantTypeError(Generic[UserType], OAuth2Error[UserType]): error: ErrorType = "unsupported_grant_type" -class UnsupportedResponseTypeError(Generic[UserType], OAuth2Error[UserType]): +class UnsupportedResponseTypeError(OAuth2Error): """ The authorization server does not support obtaining an authorization code using this method. @@ -133,7 +133,7 @@ class UnsupportedResponseTypeError(Generic[UserType], OAuth2Error[UserType]): error: ErrorType = "unsupported_response_type" -class InvalidGrantError(Generic[UserType], OAuth2Error[UserType]): +class InvalidGrantError(OAuth2Error): """ The provided authorization grant (e.g. authorization code, resource owner credentials) or refresh token is invalid, expired, revoked, does @@ -146,14 +146,14 @@ class InvalidGrantError(Generic[UserType], OAuth2Error[UserType]): error: ErrorType = "invalid_grant" -class MismatchingStateError(Generic[UserType], OAuth2Error[UserType]): +class MismatchingStateError(OAuth2Error): """Unable to securely verify the integrity of the request and response.""" description = "CSRF Warning! State not equal in request and response." error: ErrorType = "mismatching_state" -class UnauthorizedClientError(Generic[UserType], OAuth2Error[UserType]): +class UnauthorizedClientError(OAuth2Error): """ The authenticated client is not authorized to use this authorization grant type. @@ -162,7 +162,7 @@ class UnauthorizedClientError(Generic[UserType], OAuth2Error[UserType]): error: ErrorType = "unauthorized_client" -class InvalidScopeError(Generic[UserType], OAuth2Error[UserType]): +class InvalidScopeError(OAuth2Error): """ The requested scope is invalid, unknown, or malformed, or exceeds the scope granted by the resource owner. 
@@ -173,7 +173,7 @@ class InvalidScopeError(Generic[UserType], OAuth2Error[UserType]): error: ErrorType = "invalid_scope" -class ServerError(Generic[UserType], OAuth2Error[UserType]): +class ServerError(OAuth2Error): """ The authorization server encountered an unexpected condition that prevented it from fulfilling the request. (This error code is needed @@ -185,7 +185,7 @@ class ServerError(Generic[UserType], OAuth2Error[UserType]): status_code: HTTPStatus = HTTPStatus.BAD_REQUEST -class TemporarilyUnavailableError(Generic[UserType], OAuth2Error[UserType]): +class TemporarilyUnavailableError(OAuth2Error): """ The authorization server is currently unable to handle the request due to a temporary overloading or maintenance of the server. @@ -196,7 +196,7 @@ class TemporarilyUnavailableError(Generic[UserType], OAuth2Error[UserType]): error: ErrorType = "temporarily_unavailable" -class InvalidRedirectURIError(Generic[UserType], OAuth2Error[UserType]): +class InvalidRedirectURIError(OAuth2Error): """ The requested redirect URI is missing or not allowed. """ @@ -204,7 +204,7 @@ class InvalidRedirectURIError(Generic[UserType], OAuth2Error[UserType]): error: ErrorType = "invalid_request" -class UnsupportedTokenTypeError(Generic[UserType], OAuth2Error[UserType]): +class UnsupportedTokenTypeError(OAuth2Error): """ The authorization server does not support the revocation of the presented token type. 
That is, the client tried to revoke an access token on a server @@ -214,7 +214,7 @@ class UnsupportedTokenTypeError(Generic[UserType], OAuth2Error[UserType]): error: ErrorType = "unsupported_token_type" -class AccessDeniedError(Generic[UserType], OAuth2Error[UserType]): +class AccessDeniedError(OAuth2Error): """ The resource owner or authorization server denied the request """ diff --git a/aioauth/grant_type.py b/aioauth/grant_type.py index 996ce49..49fe43e 100644 --- a/aioauth/grant_type.py +++ b/aioauth/grant_type.py @@ -8,10 +8,9 @@ ---- """ -from typing import Generic, Optional +from typing import Optional from .requests import Request -from .types import UserType from .storage import BaseStorage from .errors import ( InvalidClientError, @@ -27,12 +26,12 @@ from .utils import enforce_list, enforce_str, generate_token -class GrantTypeBase(Generic[UserType]): +class GrantTypeBase: """Base grant type that all other grant types inherit from.""" def __init__( self, - storage: BaseStorage[UserType], + storage: BaseStorage, client_id: str, client_secret: Optional[str], ): @@ -42,7 +41,7 @@ def __init__( self.scope: Optional[str] = None async def create_token_response( - self, request: Request[UserType], client: Client[UserType] + self, request: Request, client: Client ) -> TokenResponse: """Creates token response to reply to client.""" if self.scope is None: @@ -65,28 +64,28 @@ async def create_token_response( token_type=token.token_type, ) - async def validate_request(self, request: Request[UserType]) -> Client[UserType]: + async def validate_request(self, request: Request) -> Client: """Validates the client request to ensure it is valid.""" client = await self.storage.get_client( request=request, client_id=self.client_id, client_secret=self.client_secret ) if not client: - raise InvalidClientError[UserType]( + raise InvalidClientError( request=request, description="Invalid client_id parameter value." 
) if not client.check_grant_type(request.post.grant_type): - raise UnauthorizedClientError[UserType](request=request) + raise UnauthorizedClientError(request=request) if not client.check_scope(request.post.scope): - raise InvalidScopeError[UserType](request=request) + raise InvalidScopeError(request=request) self.scope = request.post.scope return client -class AuthorizationCodeGrantType(GrantTypeBase[UserType]): +class AuthorizationCodeGrantType(GrantTypeBase): """ The Authorization Code grant type is used by confidential and public clients to exchange an authorization code for an access token. After @@ -102,21 +101,21 @@ class AuthorizationCodeGrantType(GrantTypeBase[UserType]): See `RFC 6749 section 1.3.1 `_. """ - async def validate_request(self, request: Request[UserType]) -> Client[UserType]: + async def validate_request(self, request: Request) -> Client: client = await super().validate_request(request) if not request.post.redirect_uri: - raise InvalidRedirectURIError[UserType]( + raise InvalidRedirectURIError( request=request, description="Mismatching redirect URI." ) if not client.check_redirect_uri(request.post.redirect_uri): - raise InvalidRedirectURIError[UserType]( + raise InvalidRedirectURIError( request=request, description="Invalid redirect URI." ) if not request.post.code: - raise InvalidRequestError[UserType]( + raise InvalidRequestError( request=request, description="Missing code parameter." ) @@ -125,14 +124,14 @@ async def validate_request(self, request: Request[UserType]) -> Client[UserType] ) if not authorization_code: - raise InvalidGrantError[UserType](request=request) + raise InvalidGrantError(request=request) if ( authorization_code.code_challenge and authorization_code.code_challenge_method ): if not request.post.code_verifier: - raise InvalidRequestError[UserType]( + raise InvalidRequestError( request=request, description="Code verifier required." 
) @@ -140,16 +139,16 @@ async def validate_request(self, request: Request[UserType]) -> Client[UserType] request.post.code_verifier ) if not is_valid_code_challenge: - raise MismatchingStateError[UserType](request=request) + raise MismatchingStateError(request=request) if authorization_code.is_expired: - raise InvalidGrantError[UserType](request=request) + raise InvalidGrantError(request=request) self.scope = authorization_code.scope return client async def create_token_response( - self, request: Request[UserType], client: Client[UserType] + self, request: Request, client: Client ) -> TokenResponse: token_response = await super().create_token_response(request, client) @@ -165,7 +164,7 @@ async def create_token_response( return token_response -class PasswordGrantType(GrantTypeBase[UserType]): +class PasswordGrantType(GrantTypeBase): """ The Password grant type is a way to exchange a user's credentials for an access token. Because the client application has to collect @@ -176,25 +175,25 @@ class PasswordGrantType(GrantTypeBase[UserType]): disallows the password grant entirely. """ - async def validate_request(self, request: Request[UserType]) -> Client[UserType]: + async def validate_request(self, request: Request) -> Client: client = await super().validate_request(request) if not request.post.username or not request.post.password: - raise InvalidRequestError[UserType]( + raise InvalidRequestError( request=request, description="Invalid credentials given." ) user = await self.storage.get_user(request) if user is None: - raise InvalidRequestError[UserType]( + raise InvalidRequestError( request=request, description="Invalid credentials given." ) return client -class RefreshTokenGrantType(GrantTypeBase[UserType]): +class RefreshTokenGrantType(GrantTypeBase): """ The Refresh Token grant type is used by clients to exchange a refresh token for an access token when the access token has expired. 
@@ -204,7 +203,7 @@ class RefreshTokenGrantType(GrantTypeBase[UserType]): """ async def create_token_response( - self, request: Request[UserType], client: Client[UserType] + self, request: Request, client: Client ) -> TokenResponse: """Validate token request and create token response.""" old_token = await self.storage.get_token( @@ -216,7 +215,7 @@ async def create_token_response( ) if not old_token or old_token.revoked or old_token.refresh_token_expired: - raise InvalidGrantError[UserType](request=request) + raise InvalidGrantError(request=request) # Revoke old token await self.storage.revoke_token( @@ -256,18 +255,18 @@ async def create_token_response( token_type=token.token_type, ) - async def validate_request(self, request: Request[UserType]) -> Client[UserType]: + async def validate_request(self, request: Request) -> Client: client = await super().validate_request(request) if not request.post.refresh_token: - raise InvalidRequestError[UserType]( + raise InvalidRequestError( request=request, description="Missing refresh token parameter." ) return client -class ClientCredentialsGrantType(GrantTypeBase[UserType]): +class ClientCredentialsGrantType(GrantTypeBase): """ The Client Credentials grant type is used by clients to obtain an access token outside of the context of a user. This is typically @@ -276,9 +275,9 @@ class ClientCredentialsGrantType(GrantTypeBase[UserType]): See `RFC 6749 section 4.4 `_. 
""" - async def validate_request(self, request: Request[UserType]) -> Client[UserType]: + async def validate_request(self, request: Request) -> Client: # client_credentials grant requires a client_secret if self.client_secret is None: - raise InvalidClientError[UserType](request) + raise InvalidClientError(request) return await super().validate_request(request) diff --git a/aioauth/models.py b/aioauth/models.py index da685b9..20ac30e 100644 --- a/aioauth/models.py +++ b/aioauth/models.py @@ -10,14 +10,14 @@ from dataclasses import dataclass import time -from typing import Generic, List, Optional, Union +from typing import List, Optional, Union -from .types import CodeChallengeMethod, GrantType, ResponseType, TokenType, UserType +from .types import CodeChallengeMethod, GrantType, ResponseType, TokenType from .utils import create_s256_code_challenge, enforce_list, enforce_str @dataclass -class Client(Generic[UserType]): +class Client: """OAuth2.0 client model object.""" client_id: str @@ -63,13 +63,6 @@ class Client(Generic[UserType]): scopes granted. """ - user: Optional[UserType] = None - """ - The user who is the creator of the Client. - This optional attribute can be useful if you are creating a server that - can be managed by multiple users. - """ - def check_redirect_uri(self, redirect_uri) -> bool: """ Verifies passed ``redirect_uri`` is part of the Clients's @@ -113,7 +106,7 @@ def check_scope(self, scope: str) -> bool: @dataclass -class AuthorizationCode(Generic[UserType]): +class AuthorizationCode: code: str """ Authorization code that the client previously received from the @@ -185,11 +178,6 @@ class AuthorizationCode(Generic[UserType]): Random piece of data. """ - user: Optional[UserType] = None - """ - The user who owns the AuthorizationCode. 
- """ - def check_code_challenge(self, code_verifier: str) -> bool: is_valid_code_challenge = False @@ -212,7 +200,7 @@ def is_expired(self) -> bool: @dataclass -class Token(Generic[UserType]): +class Token: access_token: str """ Token that clients use to make API requests on behalf of the @@ -265,11 +253,6 @@ class Token(Generic[UserType]): Flag that indicates whether or not the token has been revoked. """ - user: Optional[UserType] = None - """ - The user who owns the Token. - """ - @property def is_expired(self) -> bool: """Checks if the token has expired.""" diff --git a/aioauth/oidc/core/grant_type.py b/aioauth/oidc/core/grant_type.py index 3828658..1a5b081 100644 --- a/aioauth/oidc/core/grant_type.py +++ b/aioauth/oidc/core/grant_type.py @@ -16,11 +16,10 @@ from ...models import Client from ...oidc.core.responses import TokenResponse from ...requests import Request -from ...types import UserType from ...utils import generate_token -class AuthorizationCodeGrantType(OAuth2AuthorizationCodeGrantType[UserType]): +class AuthorizationCodeGrantType(OAuth2AuthorizationCodeGrantType): """ The Authorization Code grant type is used by confidential and public clients to exchange an authorization code for an access token. After @@ -37,7 +36,7 @@ class AuthorizationCodeGrantType(OAuth2AuthorizationCodeGrantType[UserType]): """ async def create_token_response( - self, request: Request[UserType], client: Client[UserType] + self, request: Request, client: Client ) -> TokenResponse: """ Creates token response to reply to client. 
diff --git a/aioauth/oidc/core/requests.py b/aioauth/oidc/core/requests.py index 8d58234..7d67e78 100644 --- a/aioauth/oidc/core/requests.py +++ b/aioauth/oidc/core/requests.py @@ -6,7 +6,6 @@ Request as BaseRequest, Query as BaseQuery, ) -from ...types import UserType @dataclass @@ -20,7 +19,7 @@ class Query(BaseQuery): @dataclass -class Request(BaseRequest[UserType]): +class Request(BaseRequest): """Object that contains a client's complete request.""" query: Query = field(default_factory=Query) diff --git a/aioauth/requests.py b/aioauth/requests.py index 238adf9..d7ede89 100644 --- a/aioauth/requests.py +++ b/aioauth/requests.py @@ -9,7 +9,7 @@ """ from dataclasses import dataclass, field -from typing import Generic, Optional +from typing import Optional from .collections import HTTPHeaderDict from .config import Settings @@ -19,7 +19,6 @@ RequestMethod, ResponseMode, TokenType, - UserType, ) @@ -63,7 +62,7 @@ class Post: @dataclass -class Request(Generic[UserType]): +class Request: """Object that contains a client's complete request.""" method: RequestMethod @@ -71,6 +70,5 @@ class Request(Generic[UserType]): post: Post = field(default_factory=Post) headers: HTTPHeaderDict = field(default_factory=HTTPHeaderDict) url: str = "" - user: Optional[UserType] = None settings: Settings = field(default_factory=Settings) extra: dict = field(default_factory=dict) diff --git a/aioauth/response_type.py b/aioauth/response_type.py index 6bd1063..010f76b 100644 --- a/aioauth/response_type.py +++ b/aioauth/response_type.py @@ -8,10 +8,9 @@ ---- """ -from typing import Generic, Tuple, get_args +from typing import Tuple, get_args from .requests import Request -from .types import UserType from .storage import BaseStorage from .utils import generate_token @@ -32,13 +31,13 @@ from .types import CodeChallengeMethod -class ResponseTypeBase(Generic[UserType]): +class ResponseTypeBase: """Base response type that all other exceptions inherit from.""" - def __init__(self, storage: 
BaseStorage[UserType]): + def __init__(self, storage: BaseStorage): self.storage = storage - async def validate_request(self, request: Request[UserType]) -> Client[UserType]: + async def validate_request(self, request: Request) -> Client: state = request.query.state code_challenge_methods: Tuple[CodeChallengeMethod, ...] = get_args( @@ -46,7 +45,7 @@ async def validate_request(self, request: Request[UserType]) -> Client[UserType] ) if not request.query.client_id: - raise InvalidClientError[UserType]( + raise InvalidClientError( request=request, description="Missing client_id parameter.", state=state ) @@ -55,54 +54,49 @@ async def validate_request(self, request: Request[UserType]) -> Client[UserType] ) if not client: - raise InvalidClientError[UserType]( + raise InvalidClientError( request=request, description="Invalid client_id parameter value.", state=state, ) if not request.query.redirect_uri: - raise InvalidRedirectURIError[UserType]( + raise InvalidRedirectURIError( request=request, description="Mismatching redirect URI.", state=state ) if not client.check_redirect_uri(request.query.redirect_uri): - raise InvalidRedirectURIError[UserType]( + raise InvalidRedirectURIError( request=request, description="Invalid redirect URI.", state=state ) if request.query.code_challenge_method: if request.query.code_challenge_method not in code_challenge_methods: - raise InvalidRequestError[UserType]( + raise InvalidRequestError( request=request, description="Transform algorithm not supported.", state=state, ) if not request.query.code_challenge: - raise InvalidRequestError[UserType]( + raise InvalidRequestError( request=request, description="Code challenge required.", state=state ) if not client.check_response_type(request.query.response_type): - raise UnsupportedResponseTypeError[UserType](request=request, state=state) + raise UnsupportedResponseTypeError(request=request, state=state) if not client.check_scope(request.query.scope): - raise 
InvalidScopeError[UserType](request=request, state=state) - - if not request.user: - raise InvalidClientError[UserType]( - request=request, description="User is not authorized", state=state - ) + raise InvalidScopeError(request=request, state=state) return client -class ResponseTypeToken(ResponseTypeBase[UserType]): +class ResponseTypeToken(ResponseTypeBase): """Response type that contains a token.""" async def create_authorization_response( - self, request: Request[UserType], client: Client[UserType] + self, request: Request, client: Client ) -> TokenResponse: token = await self.storage.create_token( request=request, @@ -132,11 +126,11 @@ async def create_authorization_response( ) -class ResponseTypeAuthorizationCode(ResponseTypeBase[UserType]): +class ResponseTypeAuthorizationCode(ResponseTypeBase): """Response type that contains an authorization code.""" async def create_authorization_response( - self, request: Request[UserType], client: Client[UserType] + self, request: Request, client: Client ) -> AuthorizationCodeResponse: authorization_code = await self.storage.create_authorization_code( client_id=client.client_id, @@ -155,13 +149,13 @@ async def create_authorization_response( ) -class ResponseTypeIdToken(ResponseTypeBase[UserType]): - async def validate_request(self, request: Request[UserType]) -> Client[UserType]: +class ResponseTypeIdToken(ResponseTypeBase): + async def validate_request(self, request: Request) -> Client: client = await super().validate_request(request) # nonce is required for id_token if not request.query.nonce: - raise InvalidRequestError[UserType]( + raise InvalidRequestError( request=request, description="Nonce required for response_type id_token.", state=request.query.state, @@ -169,7 +163,7 @@ async def validate_request(self, request: Request[UserType]) -> Client[UserType] return client async def create_authorization_response( - self, request: Request[UserType], client: Client[UserType] + self, request: Request, client: Client ) -> 
IdTokenResponse: id_token = await self.storage.get_id_token( request=request, @@ -183,8 +177,8 @@ async def create_authorization_response( return IdTokenResponse(id_token=id_token) -class ResponseTypeNone(ResponseTypeBase[UserType]): +class ResponseTypeNone(ResponseTypeBase): async def create_authorization_response( - self, request: Request[UserType], client: Client[UserType] + self, request: Request, client: Client ) -> NoneResponse: return NoneResponse() diff --git a/aioauth/server.py b/aioauth/server.py index 0f133b3..de1a4f1 100644 --- a/aioauth/server.py +++ b/aioauth/server.py @@ -19,10 +19,9 @@ from dataclasses import asdict from http import HTTPStatus -from typing import Any, Dict, Generic, List, Optional, Tuple, Type, Union, get_args +from typing import Any, Dict, List, Optional, Tuple, Type, Union, get_args from .requests import Request -from .types import UserType from .storage import BaseStorage @@ -71,25 +70,25 @@ ) -class AuthorizationServer(Generic[UserType]): +class AuthorizationServer: """Interface for initializing an OAuth 2.0 server.""" response_types: Dict[ResponseType, Any] = { - "token": ResponseTypeToken[UserType], - "code": ResponseTypeAuthorizationCode[UserType], - "none": ResponseTypeNone[UserType], - "id_token": ResponseTypeIdToken[UserType], + "token": ResponseTypeToken, + "code": ResponseTypeAuthorizationCode, + "none": ResponseTypeNone, + "id_token": ResponseTypeIdToken, } grant_types: Dict[GrantType, Any] = { - "authorization_code": AuthorizationCodeGrantType[UserType], - "client_credentials": ClientCredentialsGrantType[UserType], - "password": PasswordGrantType[UserType], - "refresh_token": RefreshTokenGrantType[UserType], + "authorization_code": AuthorizationCodeGrantType, + "client_credentials": ClientCredentialsGrantType, + "password": PasswordGrantType, + "refresh_token": RefreshTokenGrantType, } def __init__( self, - storage: BaseStorage[UserType], + storage: BaseStorage, response_types: Optional[Dict] = None, grant_types: 
Optional[Dict] = None, ): @@ -101,7 +100,7 @@ def __init__( if grant_types is not None: self.grant_types = grant_types - def is_secure_transport(self, request: Request[UserType]) -> bool: + def is_secure_transport(self, request: Request) -> bool: """ Verifies the request was sent via a protected SSL tunnel. @@ -118,25 +117,21 @@ def is_secure_transport(self, request: Request[UserType]) -> bool: return True return request.url.lower().startswith("https://") - def validate_request( - self, request: Request[UserType], allowed_methods: List[RequestMethod] - ): + def validate_request(self, request: Request, allowed_methods: List[RequestMethod]): if not request.settings.AVAILABLE: - raise TemporarilyUnavailableError[UserType](request=request) + raise TemporarilyUnavailableError(request=request) if not self.is_secure_transport(request): - raise InsecureTransportError[UserType](request=request) + raise InsecureTransportError(request=request) if request.method not in allowed_methods: headers = HTTPHeaderDict( {**default_headers, "allow": ", ".join(allowed_methods)} ) - raise MethodNotAllowedError[UserType](request=request, headers=headers) + raise MethodNotAllowedError(request=request, headers=headers) @catch_errors_and_unavailability() - async def create_token_introspection_response( - self, request: Request[UserType] - ) -> Response: + async def create_token_introspection_response(self, request: Request) -> Response: """ Returns a response object with introspection of the passed token. For more information see `RFC7662 section 2.1 `_. @@ -177,7 +172,7 @@ async def introspect(request: fastapi.Request) -> fastapi.Response: ) if not client: - raise InvalidClientError[UserType](request) + raise InvalidClientError(request) token_types: Tuple[TokenType, ...] 
= get_args(TokenType) token_type: TokenType = "refresh_token" @@ -221,7 +216,7 @@ async def introspect(request: fastapi.Request) -> fastapi.Response: ) def get_client_credentials( - self, request: Request[UserType], secret_required: bool + self, request: Request, secret_required: bool ) -> Tuple[str, str]: client_id = request.post.client_id client_secret = request.post.client_secret @@ -236,7 +231,7 @@ def get_client_credentials( if client_id is None or secret_required: # Either we didn't find a client ID at all, or we found # a client ID but no secret and a secret is required. - raise InvalidClientError[UserType]( + raise InvalidClientError( description="Invalid client_id parameter value.", request=request, ) from exc @@ -249,7 +244,7 @@ def get_client_credentials( return client_id, client_secret @catch_errors_and_unavailability() - async def create_token_response(self, request: Request[UserType]) -> Response: + async def create_token_response(self, request: Request) -> Response: """Endpoint to obtain an access and/or ID token by presenting an authorization grant or refresh token. Validates a token request and creates a token response. @@ -301,17 +296,17 @@ async def token(request: fastapi.Request) -> fastapi.Response: if not request.post.grant_type: # grant_type request value is empty - raise InvalidRequestError[UserType]( + raise InvalidRequestError( request=request, description="Request is missing grant type." 
) GrantTypeClass: Type[ Union[ - GrantTypeBase[UserType], - AuthorizationCodeGrantType[UserType], - PasswordGrantType[UserType], - RefreshTokenGrantType[UserType], - ClientCredentialsGrantType[UserType], + GrantTypeBase, + AuthorizationCodeGrantType, + PasswordGrantType, + RefreshTokenGrantType, + ClientCredentialsGrantType, ] ] @@ -319,7 +314,7 @@ async def token(request: fastapi.Request) -> fastapi.Response: GrantTypeClass = self.grant_types[request.post.grant_type] except KeyError as exc: # grant_type request value is invalid - raise UnsupportedGrantTypeError[UserType](request=request) from exc + raise UnsupportedGrantTypeError(request=request) from exc grant_type = GrantTypeClass( storage=self.storage, client_id=client_id, client_secret=client_secret @@ -341,9 +336,7 @@ async def token(request: fastapi.Request) -> fastapi.Response: InvalidRedirectURIError, ) ) - async def create_authorization_response( - self, request: Request[UserType] - ) -> Response: + async def create_authorization_response(self, request: Request) -> Response: """ Endpoint to interact with the resource owner and obtain an authorization grant. 
@@ -397,7 +390,7 @@ async def authorize(request: fastapi.Request) -> fastapi.Response: state = request.query.state if not response_type_list: - raise InvalidRequestError[UserType]( + raise InvalidRequestError( request=request, description="Missing response_type parameter.", state=state, @@ -412,7 +405,7 @@ async def authorize(request: fastapi.Request) -> fastapi.Response: response_type_classes.add(ResponseTypeClass) if not response_type_classes: - raise UnsupportedResponseTypeError[UserType](request=request, state=state) + raise UnsupportedResponseTypeError(request=request, state=state) for ResponseTypeClass in response_type_classes: response_type = ResponseTypeClass(storage=self.storage) @@ -475,7 +468,7 @@ async def authorize(request: fastapi.Request) -> fastapi.Response: ) @catch_errors_and_unavailability() - async def revoke_token(self, request: Request[UserType]) -> Response: + async def revoke_token(self, request: Request) -> Response: """Endpoint to revoke an access token or refresh token. For more information see `RFC7009 `_. @@ -515,10 +508,10 @@ async def revoke(request: fastapi.Request) -> fastapi.Response: ) if not client: - raise InvalidClientError[UserType](request) + raise InvalidClientError(request) if not request.post.token: - raise InvalidRequestError[UserType]( + raise InvalidRequestError( request=request, description="Request is missing token." 
) @@ -526,7 +519,7 @@ async def revoke(request: fastapi.Request) -> fastapi.Response: "refresh_token", "access_token", }: - raise UnsupportedTokenTypeError[UserType](request=request) + raise UnsupportedTokenTypeError(request=request) access_token = ( request.post.token diff --git a/aioauth/storage.py b/aioauth/storage.py index 6682626..61bd10e 100644 --- a/aioauth/storage.py +++ b/aioauth/storage.py @@ -10,20 +10,19 @@ ---- """ -from typing import Optional, Generic +from typing import Any, Optional from .models import AuthorizationCode, Client, Token from .types import CodeChallengeMethod, TokenType from .requests import Request -from .types import UserType -class TokenStorage(Generic[UserType]): +class TokenStorage: async def create_token( self, *, - request: Request[UserType], + request: Request, client_id: str, scope: str, access_token: str, @@ -55,7 +54,7 @@ async def create_token( async def get_token( self, *, - request: Request[UserType], + request: Request, client_id: str, token_type: Optional[TokenType] = None, access_token: Optional[str] = None, @@ -80,7 +79,7 @@ async def get_token( async def revoke_token( self, *, - request: Request[UserType], + request: Request, client_id: str, refresh_token: Optional[str] = None, token_type: Optional[TokenType] = None, @@ -90,11 +89,11 @@ async def revoke_token( raise NotImplementedError -class AuthorizationCodeStorage(Generic[UserType]): +class AuthorizationCodeStorage: async def create_authorization_code( self, *, - request: Request[UserType], + request: Request, client_id: str, scope: str, response_type: str, @@ -129,7 +128,7 @@ async def create_authorization_code( async def get_authorization_code( self, *, - request: Request[UserType], + request: Request, client_id: str, code: str, ) -> Optional[AuthorizationCode]: @@ -156,7 +155,7 @@ async def get_authorization_code( async def delete_authorization_code( self, *, - request: Request[UserType], + request: Request, client_id: str, code: str, ) -> None: @@ -175,14 
+174,14 @@ async def delete_authorization_code( ) -class ClientStorage(Generic[UserType]): +class ClientStorage: async def get_client( self, *, - request: Request[UserType], + request: Request, client_id: str, client_secret: Optional[str] = None, - ) -> Optional[Client[UserType]]: + ) -> Optional[Client]: """Gets existing client from the database if it exists. Warning: @@ -202,8 +201,8 @@ async def get_client( raise NotImplementedError("Method get_client must be implemented") -class UserStorage(Generic[UserType]): - async def get_user(self, request: Request[UserType]) -> Optional[UserType]: +class UserStorage: + async def get_user(self, request: Request) -> Optional[Any]: """Returns a user. Note: @@ -218,11 +217,11 @@ async def get_user(self, request: Request[UserType]) -> Optional[UserType]: raise NotImplementedError("Method get_user must be implemented") -class IDTokenStorage(Generic[UserType]): +class IDTokenStorage: async def get_id_token( self, *, - request: Request[UserType], + request: Request, client_id: str, scope: str, redirect_uri: str, @@ -240,10 +239,9 @@ async def get_id_token( class BaseStorage( - Generic[UserType], - TokenStorage[UserType], - AuthorizationCodeStorage[UserType], - ClientStorage[UserType], - UserStorage[UserType], - IDTokenStorage[UserType], + TokenStorage, + AuthorizationCodeStorage, + ClientStorage, + UserStorage, + IDTokenStorage, ): ... 
diff --git a/aioauth/types.py b/aioauth/types.py index 986fec4..d2eadbc 100644 --- a/aioauth/types.py +++ b/aioauth/types.py @@ -9,12 +9,7 @@ """ import sys -from typing import Any, Literal - -if sys.version_info >= (3, 13): - from typing import TypeVar -else: - from typing_extensions import TypeVar +from typing import Literal if sys.version_info >= (3, 11): from typing import TypeAlias @@ -73,5 +68,3 @@ TokenType: TypeAlias = Literal["access_token", "refresh_token", "Bearer"] - -UserType = TypeVar("UserType", default=Any) diff --git a/examples/fastapi_example.py b/examples/fastapi_example.py index 6bdc8d0..fc5aa67 100644 --- a/examples/fastapi_example.py +++ b/examples/fastapi_example.py @@ -51,7 +51,7 @@ async def to_request(request: Request) -> OAuthRequest: query=Query(**request.query_params), # type: ignore settings=settings, url=str(request.url), - user=user, + extra={"user": user}, ) diff --git a/examples/shared/__init__.py b/examples/shared/__init__.py index 3af6b51..ec88b6e 100644 --- a/examples/shared/__init__.py +++ b/examples/shared/__init__.py @@ -13,7 +13,8 @@ from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine from .config import load_config -from .storage import BackendStore, User +from .storage import BackendStore +from .models import User __all__ = [ "AuthServer", @@ -68,5 +69,5 @@ async def lifespan(*_): await engine.dispose() -class AuthServer(AuthorizationServer[User]): +class AuthServer(AuthorizationServer): pass diff --git a/examples/shared/storage.py b/examples/shared/storage.py index 00e3c8a..adfc50d 100644 --- a/examples/shared/storage.py +++ b/examples/shared/storage.py @@ -18,13 +18,12 @@ ) from aioauth.types import CodeChallengeMethod, TokenType -from .models import User from .models import Client as ClientTable from .models import AuthorizationCode as AuthCodeTable from .models import Token as TokenTable -class ClientStore(ClientStorage[User]): +class ClientStore(ClientStorage): def __init__(self, session: 
AsyncSession): self.session = session @@ -32,10 +31,10 @@ def __init__(self, session: AsyncSession): async def get_client( self, *, - request: Request[User], + request: Request, client_id: str, client_secret: Optional[str] = None, - ) -> Optional[Client[User]]: + ) -> Optional[Client]: """ """ sql = select(ClientTable).where(ClientTable.client_id == client_id) async with self.session: @@ -55,7 +54,7 @@ async def get_client( ) -class AuthCodeStore(AuthorizationCodeStorage[User]): +class AuthCodeStore(AuthorizationCodeStorage): def __init__(self, session: AsyncSession): self.session = session @@ -63,7 +62,7 @@ def __init__(self, session: AsyncSession): async def create_authorization_code( self, *, - request: Request[User], + request: Request, client_id: str, scope: str, response_type: str, @@ -74,6 +73,8 @@ async def create_authorization_code( nonce: Optional[str] = None, ) -> AuthorizationCode: """""" + user = request.extra.get("user", None) + auth_code = AuthorizationCode( code=code, client_id=client_id, @@ -84,7 +85,6 @@ async def create_authorization_code( expires_in=300, code_challenge=code_challenge, code_challenge_method=code_challenge_method, - user=request.user, ) record = AuthCodeTable( code=auth_code.code, @@ -97,7 +97,7 @@ async def create_authorization_code( code_challenge=auth_code.code_challenge, code_challenge_method=auth_code.code_challenge_method, nonce=auth_code.nonce, - user_id=request.user.id if request.user else None, + user_id=user.id if user is not None else None, ) async with self.session: self.session.add(record) @@ -107,7 +107,7 @@ async def create_authorization_code( async def get_authorization_code( self, *, - request: Request[User], + request: Request, client_id: str, code: str, ) -> Optional[AuthorizationCode]: @@ -132,7 +132,7 @@ async def get_authorization_code( async def delete_authorization_code( self, *, - request: Request[User], + request: Request, client_id: str, code: str, ) -> None: @@ -144,7 +144,7 @@ async def 
delete_authorization_code( await self.session.commit() -class TokenStore(TokenStorage[User]): +class TokenStore(TokenStorage): def __init__(self, session: AsyncSession): self.session = session @@ -152,13 +152,14 @@ def __init__(self, session: AsyncSession): async def create_token( self, *, - request: Request[User], + request: Request, client_id: str, scope: str, access_token: str, refresh_token: Optional[str] = None, ) -> Token: """ """ + user = request.extra.get("user", None) token = Token( client_id=client_id, access_token=access_token, @@ -167,7 +168,6 @@ async def create_token( issued_at=int(datetime.now(tz=timezone.utc).timestamp()), expires_in=300, refresh_token_expires_in=900, - user=request.user, ) record = TokenTable( client_id=token.client_id, @@ -179,7 +179,7 @@ async def create_token( refresh_token_expires_in=token.refresh_token_expires_in, token_type=token.token_type, revoked=token.revoked, - user_id=token.user.id if token.user else None, + user_id=user.id if user is not None else None, ) async with self.session: self.session.add(record) @@ -189,7 +189,7 @@ async def create_token( async def get_token( self, *, - request: Request[User], + request: Request, client_id: str, token_type: Optional[TokenType] = "refresh_token", access_token: Optional[str] = None, @@ -213,13 +213,12 @@ async def get_token( issued_at=result.issued_at, expires_in=result.expires_in, refresh_token_expires_in=result.refresh_token_expires_in, - user=result.user, ) async def revoke_token( self, *, - request: Request[User], + request: Request, client_id: str, token_type: Optional[TokenType] = "refresh_token", access_token: Optional[str] = None, @@ -238,5 +237,5 @@ async def revoke_token( await self.session.commit() -class BackendStore(ClientStore, AuthCodeStore, TokenStore, BaseStorage[User]): +class BackendStore(ClientStore, AuthCodeStore, TokenStore, BaseStorage): pass diff --git a/tests/classes.py b/tests/classes.py index 4e75039..d3422f5 100644 --- a/tests/classes.py +++ 
b/tests/classes.py @@ -1,6 +1,6 @@ import time -from typing import Dict, List, Optional, Type +from typing import Any, Dict, List, Optional, Type from functools import cached_property from dataclasses import replace, dataclass @@ -17,7 +17,6 @@ GrantType, ResponseType, TokenType, - UserType, ) @@ -26,7 +25,7 @@ class User: username: str -class Storage(BaseStorage[User]): +class Storage(BaseStorage): def __init__( self, authorization_codes: List[AuthorizationCode], @@ -52,10 +51,10 @@ def _get_by_client_id(self, client_id: str): async def get_client( self, *, - request: Request[UserType], + request: Request, client_id: str, client_secret: Optional[str] = None, - ) -> Optional[Client[User]]: + ) -> Optional[Client]: if client_secret is not None: return self._get_by_client_secret(client_id, client_secret) @@ -64,13 +63,13 @@ async def get_client( async def create_token( self, *, - request: Request[UserType], + request: Request, client_id: str, scope: str, access_token: str, refresh_token: Optional[str] = None, ): - token: Token[User] = Token( + token: Token = Token( client_id=client_id, expires_in=request.settings.TOKEN_EXPIRES_IN, refresh_token_expires_in=request.settings.REFRESH_TOKEN_EXPIRES_IN, @@ -86,7 +85,7 @@ async def create_token( async def revoke_token( self, *, - request: Request[UserType], + request: Request, client_id: str, refresh_token: Optional[str] = None, token_type: Optional[TokenType] = None, @@ -102,7 +101,7 @@ async def revoke_token( async def get_token( self, *, - request: Request[UserType], + request: Request, client_id: str, token_type: Optional[TokenType] = None, access_token: Optional[str] = None, @@ -122,7 +121,7 @@ async def get_token( ): return token_ - async def get_user(self, request: Request[User]) -> Optional[User]: + async def get_user(self, request: Request) -> Any: password = request.post.password username = request.post.username @@ -137,7 +136,7 @@ async def get_user(self, request: Request[User]) -> Optional[User]: async def 
create_authorization_code( self, *, - request: Request[UserType], + request: Request, client_id: str, scope: str, response_type: str, @@ -166,7 +165,7 @@ async def create_authorization_code( async def get_authorization_code( self, *, - request: Request[UserType], + request: Request, client_id: str, code: str, ) -> Optional[AuthorizationCode]: @@ -180,7 +179,7 @@ async def get_authorization_code( async def delete_authorization_code( self, *, - request: Request[UserType], + request: Request, client_id: str, code: str, ): @@ -195,7 +194,7 @@ async def delete_authorization_code( async def get_id_token( self, *, - request: Request[UserType], + request: Request, client_id: str, scope: str, redirect_uri: str, @@ -209,12 +208,10 @@ class AuthorizationContext: def __init__( self, clients: Optional[List[Client]] = None, - grant_types: Optional[Dict[GrantType, Type[GrantTypeBase[User]]]] = None, + grant_types: Optional[Dict[GrantType, Type[GrantTypeBase]]] = None, initial_authorization_codes: Optional[List[AuthorizationCode]] = None, initial_tokens: Optional[List[Token]] = None, - response_types: Optional[ - Dict[ResponseType, Type[ResponseTypeBase[User]]] - ] = None, + response_types: Optional[Dict[ResponseType, Type[ResponseTypeBase]]] = None, settings: Optional[Settings] = None, users: Optional[Dict[str, str]] = None, ): @@ -228,7 +225,7 @@ def __init__( self.users = users or {} @cached_property - def server(self) -> AuthorizationServer[User]: + def server(self) -> AuthorizationServer: return AuthorizationServer( grant_types=self.grant_types, response_types=self.response_types, diff --git a/tests/conftest.py b/tests/conftest.py index e032f22..d995e0a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,7 +4,7 @@ from aioauth.server import AuthorizationServer from tests import factories -from tests.classes import AuthorizationContext, User +from tests.classes import AuthorizationContext @pytest.fixture @@ -20,5 +20,5 @@ def context() -> 
Generator[AuthorizationContext, Any, Any]: @pytest.fixture def server( context: AuthorizationContext, -) -> Generator[AuthorizationServer[User], Any, Any]: +) -> Generator[AuthorizationServer, Any, Any]: yield context.server diff --git a/tests/factories.py b/tests/factories.py index a458aec..e3c90b0 100644 --- a/tests/factories.py +++ b/tests/factories.py @@ -20,7 +20,7 @@ from aioauth.types import CodeChallengeMethod, GrantType, ResponseType from aioauth.utils import generate_token -from tests.classes import AuthorizationContext, User +from tests.classes import AuthorizationContext def access_token_factory() -> str: @@ -47,21 +47,21 @@ def auth_time_factory() -> int: return int(time.time()) -def grant_types_factory() -> Dict[GrantType, Type[GrantTypeBase[User]]]: +def grant_types_factory() -> Dict[GrantType, Type[GrantTypeBase]]: return { - "authorization_code": AuthorizationCodeGrantType[User], - "client_credentials": ClientCredentialsGrantType[User], - "password": PasswordGrantType[User], - "refresh_token": RefreshTokenGrantType[User], + "authorization_code": AuthorizationCodeGrantType, + "client_credentials": ClientCredentialsGrantType, + "password": PasswordGrantType, + "refresh_token": RefreshTokenGrantType, } -def response_types_factory() -> Dict[ResponseType, Type[ResponseTypeBase[User]]]: +def response_types_factory() -> Dict[ResponseType, Type[ResponseTypeBase]]: return { - "code": ResponseTypeAuthorizationCode[User], - "id_token": ResponseTypeIdToken[User], - "none": ResponseTypeNone[User], - "token": ResponseTypeToken[User], + "code": ResponseTypeAuthorizationCode, + "id_token": ResponseTypeIdToken, + "none": ResponseTypeNone, + "token": ResponseTypeToken, } @@ -139,10 +139,10 @@ def token_factory( def context_factory( clients: Optional[List[Client]] = None, - grant_types: Optional[Dict[GrantType, Type[GrantTypeBase[User]]]] = None, + grant_types: Optional[Dict[GrantType, Type[GrantTypeBase]]] = None, initial_authorization_codes: 
Optional[List[AuthorizationCode]] = None, initial_tokens: Optional[List[Token]] = None, - response_types: Optional[Dict[ResponseType, Type[ResponseTypeBase[User]]]] = None, + response_types: Optional[Dict[ResponseType, Type[ResponseTypeBase]]] = None, settings: Optional[Settings] = None, users: Optional[Dict[str, str]] = None, ) -> AuthorizationContext: diff --git a/tests/oidc/core/test_flow.py b/tests/oidc/core/test_flow.py index dedf556..6ef920c 100644 --- a/tests/oidc/core/test_flow.py +++ b/tests/oidc/core/test_flow.py @@ -17,7 +17,7 @@ "user, expected_status_code", [ ("username", HTTPStatus.FOUND), - (None, HTTPStatus.UNAUTHORIZED), + (None, HTTPStatus.FOUND), ], ) async def test_authorization_endpoint_allows_prompt_query_param( @@ -43,11 +43,10 @@ async def test_authorization_endpoint_allows_prompt_query_param( state=generate_token(10), ) - request = Request[User]( + request = Request( url=request_url, query=query, method="GET", - user=user, ) await check_request_validators(request, server.create_authorization_response) diff --git a/tests/test_db.py b/tests/test_db.py index d4e5940..95475bb 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -1,4 +1,3 @@ -from typing import Any import pytest from aioauth.models import AuthorizationCode, Client, Token @@ -10,8 +9,8 @@ @pytest.mark.asyncio async def test_storage_class() -> None: - db = BaseStorage[Any]() - request = Request[Any](method="POST") + db = BaseStorage() + request = Request(method="POST") client: Client = factories.client_factory() token: Token = factories.token_factory() authorization_code: AuthorizationCode = factories.authorization_code_factory() diff --git a/tests/test_endpoint.py b/tests/test_endpoint.py index e44e8f4..21a9866 100644 --- a/tests/test_endpoint.py +++ b/tests/test_endpoint.py @@ -12,7 +12,7 @@ ) from tests import factories -from tests.classes import AuthorizationContext, User +from tests.classes import AuthorizationContext @pytest.mark.asyncio @@ -43,7 +43,7 @@ async def 
test_invalid_token(context: AuthorizationContext): token = "invalid token" post = Post(token=token) - request = Request[User]( + request = Request( url=request_url, post=post, method="POST", @@ -99,7 +99,7 @@ async def test_valid_token(context: AuthorizationContext): server = context.server post = Post(token=token.refresh_token) - request = Request[User]( + request = Request( post=post, method="POST", headers=encode_auth_headers(client_id, client_secret), @@ -128,7 +128,7 @@ async def test_introspect_revoked_token(context: AuthorizationContext): grant_type="refresh_token", refresh_token=token.refresh_token, ) - request = Request[User]( + request = Request( settings=settings, url=request_url, post=post, @@ -139,7 +139,7 @@ async def test_introspect_revoked_token(context: AuthorizationContext): # Check that refreshed token was revoked post = Post(token=token.access_token, token_type_hint="access_token") - request = Request[User]( + request = Request( settings=settings, post=post, method="POST", @@ -171,7 +171,7 @@ async def test_introspect_token_with_wrong_client_secret(context: AuthorizationC server = context.server post = Post(token=token.refresh_token) - request = Request[User]( + request = Request( post=post, method="POST", headers=encode_auth_headers(client_id, f"not {client_secret}"), @@ -216,7 +216,7 @@ async def test_revoke_refresh_token(context: AuthorizationContext): server = context.server post = Post(token=token.refresh_token, token_type_hint="refresh_token") - request = Request[User]( + request = Request( post=post, method="POST", headers=encode_auth_headers(client_id, client_secret), @@ -227,7 +227,7 @@ async def test_revoke_refresh_token(context: AuthorizationContext): assert response.status_code == HTTPStatus.NO_CONTENT # Check that the token was revoked - request = Request[User]( + request = Request( settings=settings, post=post, method="POST", @@ -248,7 +248,7 @@ async def test_revoke_access_token(context: AuthorizationContext): server = 
context.server post = Post(token=token.access_token, token_type_hint="access_token") - request = Request[User]( + request = Request( post=post, method="POST", headers=encode_auth_headers(client_id, client_secret), @@ -259,7 +259,7 @@ async def test_revoke_access_token(context: AuthorizationContext): assert response.status_code == HTTPStatus.NO_CONTENT # Check that the token was revoked - request = Request[User]( + request = Request( settings=settings, post=post, method="POST", @@ -307,7 +307,7 @@ async def test_revoke_access_token_with_wrong_client_secret( server = context.server post = Post(token=token.access_token, token_type_hint="access_token") - request = Request[User]( + request = Request( post=post, method="POST", headers=encode_auth_headers(client_id, f"not {client_secret}"), diff --git a/tests/test_flow.py b/tests/test_flow.py index 4b138df..b99d3e8 100644 --- a/tests/test_flow.py +++ b/tests/test_flow.py @@ -14,7 +14,7 @@ ) from tests import factories -from tests.classes import AuthorizationContext, User +from tests.classes import AuthorizationContext from tests.utils import check_request_validators @@ -44,11 +44,10 @@ async def test_authorization_code_flow_plain_code_challenge(): scope=scope, ) - request = Request[User]( + request = Request( url=request_url, query=query, method="GET", - user=User(username="A"), ) await check_request_validators(request, server.create_authorization_response) @@ -162,7 +161,6 @@ async def test_authorization_code_flow_pkce_code_challenge(): code_challenge = create_s256_code_challenge(code_verifier) redirect_uri = client.redirect_uris[0] request_url = "https://localhost" - user = "username" state = generate_token(10) query = Query( @@ -179,7 +177,6 @@ async def test_authorization_code_flow_pkce_code_challenge(): url=request_url, query=query, method="GET", - user=user, ) response = await server.create_authorization_response(request) assert response.status_code == HTTPStatus.FOUND @@ -255,7 +252,6 @@ async def 
test_implicit_flow(context_factory, settings): url=request_url, query=query, method="GET", - user=username, settings=context.settings, ) @@ -385,7 +381,6 @@ async def test_authorization_code_flow(): url=request_url, query=query, method="GET", - user=username, ) await check_request_validators(request, server.create_authorization_response) @@ -442,7 +437,6 @@ async def test_authorization_code_flow_credentials_in_post(): url=request_url, query=query, method="GET", - user=username, ) await check_request_validators(request, server.create_authorization_response) @@ -485,7 +479,7 @@ async def test_client_credentials_flow_post_data(context: AuthorizationContext): scope=client.scope, ) - request = Request[User](url=request_url, post=post, method="POST") + request = Request(url=request_url, post=post, method="POST") await check_request_validators(request, server.create_token_response) @@ -504,7 +498,7 @@ async def test_client_credentials_flow_auth_header(context: AuthorizationContext scope=client.scope, ) - request = Request[User]( + request = Request( url=request_url, post=post, method="POST", @@ -547,7 +541,6 @@ async def test_multiple_response_types(context_factory, settings): url=request_url, query=query, method="GET", - user=username, settings=context.settings, ) @@ -593,7 +586,6 @@ async def test_response_type_none(context_factory): url=request_url, query=query, method="GET", - user=username, ) await check_request_validators(request, server.create_authorization_response) @@ -647,7 +639,6 @@ async def test_response_type_id_token(context_factory, response_mode, settings): url=request_url, query=query, method="GET", - user=username, settings=context.settings, ) diff --git a/tests/test_grant_type.py b/tests/test_grant_type.py index a6871bb..c198968 100644 --- a/tests/test_grant_type.py +++ b/tests/test_grant_type.py @@ -5,8 +5,6 @@ from aioauth.requests import Post, Request from aioauth.utils import encode_auth_headers -from tests.classes import User - @pytest.mark.asyncio 
async def test_refresh_token_grant_type(context): @@ -32,7 +30,7 @@ async def test_refresh_token_grant_type(context): headers=encode_auth_headers(client_id, client_secret), ) - grant_type = RefreshTokenGrantType[User]( + grant_type = RefreshTokenGrantType( db, client_id=client_id, client_secret=client_secret ) diff --git a/tests/test_request_validator.py b/tests/test_request_validator.py index 6a81f2f..5f9e847 100644 --- a/tests/test_request_validator.py +++ b/tests/test_request_validator.py @@ -13,14 +13,14 @@ ) from tests import factories -from tests.classes import AuthorizationContext, User +from tests.classes import AuthorizationContext @pytest.mark.asyncio async def test_insecure_transport_error(server: AuthorizationServer): request_url = "http://localhost" - request = Request[User](url=request_url, method="GET") + request = Request(url=request_url, method="GET") response = await server.create_authorization_response(request) assert response.status_code == HTTPStatus.FOUND @@ -156,7 +156,6 @@ async def test_invalid_response_type(): url=request_url, query=query, method="GET", - user=username, ) response = await server.create_authorization_response(request) assert response.status_code == HTTPStatus.FOUND @@ -182,10 +181,9 @@ async def test_anonymous_user(context: AuthorizationContext): code_challenge=code_challenge, ) - request = Request[User](url=request_url, query=query, method="GET") + request = Request(url=request_url, query=query, method="GET") response = await server.create_authorization_response(request) - assert response.status_code == HTTPStatus.UNAUTHORIZED - assert response.content["error"] == "invalid_client" + assert response.status_code == HTTPStatus.FOUND @pytest.mark.asyncio From 73525027279032ae3d2fcb65f339943f1f136003 Mon Sep 17 00:00:00 2001 From: Ali Aliyev Date: Sun, 26 Jan 2025 17:22:55 +0400 Subject: [PATCH 46/57] fix: explicitly show the response error --- examples/fastapi_example.py | 62 ++++++++++++++++++++++++++++-------- 
examples/shared/__init__.py | 10 +++--- examples/shared/config.py | 14 ++++---- examples/shared/storage.py | 12 +++++-- tests/oidc/core/test_flow.py | 22 ++----------- 5 files changed, 73 insertions(+), 47 deletions(-) diff --git a/examples/fastapi_example.py b/examples/fastapi_example.py index fc5aa67..7adb5fd 100644 --- a/examples/fastapi_example.py +++ b/examples/fastapi_example.py @@ -6,7 +6,7 @@ import json import html -from http import HTTPStatus +import logging from typing import Optional, cast from fastapi import FastAPI, Form, Request, Depends, Response @@ -29,6 +29,13 @@ app.add_middleware(SessionMiddleware) +logging.basicConfig( + level=logging.DEBUG, + format="%(asctime)s,%(msecs)d %(levelname)s: %(message)s", + datefmt="%H:%M:%S", +) + + async def get_auth_server() -> AuthServer: """ initialize oauth authorization server @@ -74,8 +81,24 @@ async def authorize( oauth2 authorization endpoint using aioauth """ oauthreq = await to_request(request) + user = request.session.get("user", None) + response = await oauth.create_authorization_response(oauthreq) - if response.status_code == HTTPStatus.UNAUTHORIZED: + + # A demonstration example of request validation before checking the user's credentials. + # See a discussion here: https://github.com/aliev/aioauth/issues/101 + if response.status_code >= 400: + content = f""" + + +

{response.content['error']}

+

{response.content['description']}

+ + + """ + return HTMLResponse(content, status_code=response.status_code) + + if user is None: request.session["oauth"] = oauthreq return RedirectResponse("/login") return to_response(response) @@ -155,18 +178,29 @@ async def approve(request: Request): if "user" not in request.session: redirect = request.url_for("login") return RedirectResponse(redirect) - oauthreq: OAuthRequest = request.session["oauth"] - content = f""" - - -

{oauthreq.query.client_id} would like permissions.

- - - - - - - """ + + oauth = request.session.get("oauth", None) + if oauth: + oauthreq: OAuthRequest = request.session["oauth"] + content = f""" + + +

{oauthreq.query.client_id} would like permissions.

+
+ + +
+ + + """ + else: + content = f""" + + +

Hello, {request.session['user'].username}.

+ + + """ return HTMLResponse(content) diff --git a/examples/shared/__init__.py b/examples/shared/__init__.py index ec88b6e..7d0e4a7 100644 --- a/examples/shared/__init__.py +++ b/examples/shared/__init__.py @@ -20,7 +20,7 @@ "AuthServer", "BackendStore", "engine", - "config", + "app_config", "settings", "try_login", "lifespan", @@ -32,8 +32,8 @@ "sqlite+aiosqlite:///:memory:", echo=False, future=True ) -config = load_config(CONFIG_PATH) -settings = config.settings +app_config = load_config(CONFIG_PATH) +settings = app_config.settings async def try_login(username: str, password: str) -> Optional[User]: @@ -59,9 +59,9 @@ async def lifespan(*_): await conn.run_sync(SQLModel.metadata.create_all) # create test records async with AsyncSession(engine) as session: - for user in config.fixtures.users: + for user in app_config.fixtures.users: session.add(user) - for client in config.fixtures.clients: + for client in app_config.fixtures.clients: session.add(client) await session.commit() yield diff --git a/examples/shared/config.py b/examples/shared/config.py index ccf89b6..6f10c24 100644 --- a/examples/shared/config.py +++ b/examples/shared/config.py @@ -10,13 +10,6 @@ from .models import User, Client -def load_config(fpath: str) -> "Config": - """load configuration from filepath""" - with open(fpath, "r") as f: - json = f.read() - return Config.model_validate_json(json) - - class Fixtures(BaseModel): users: List[User] clients: List[Client] @@ -25,3 +18,10 @@ class Fixtures(BaseModel): class Config(BaseModel): fixtures: Fixtures settings: Settings + + +def load_config(fpath: str) -> Config: + """load configuration from filepath""" + with open(fpath, "r") as f: + json = f.read() + return Config.model_validate_json(json) diff --git a/examples/shared/storage.py b/examples/shared/storage.py index adfc50d..4d5f442 100644 --- a/examples/shared/storage.py +++ b/examples/shared/storage.py @@ -113,7 +113,11 @@ async def get_authorization_code( ) -> Optional[AuthorizationCode]: """ 
""" async with self.session: - sql = select(AuthCodeTable).where(AuthCodeTable.client_id == client_id) + sql = ( + select(AuthCodeTable) + .where(AuthCodeTable.client_id == client_id) + .where(AuthCodeTable.code == code) + ) result = (await self.session.exec(sql)).one_or_none() if result is not None: return AuthorizationCode( @@ -138,7 +142,11 @@ async def delete_authorization_code( ) -> None: """ """ async with self.session: - sql = select(AuthCodeTable).where(AuthCodeTable.client_id == client_id) + sql = ( + select(AuthCodeTable) + .where(AuthCodeTable.client_id == client_id) + .where(AuthCodeTable.code == code) + ) result = (await self.session.exec(sql)).one() await self.session.delete(result) await self.session.commit() diff --git a/tests/oidc/core/test_flow.py b/tests/oidc/core/test_flow.py index 6ef920c..d793a7f 100644 --- a/tests/oidc/core/test_flow.py +++ b/tests/oidc/core/test_flow.py @@ -1,5 +1,4 @@ from http import HTTPStatus -from typing import Optional import pytest @@ -8,27 +7,12 @@ generate_token, ) -from tests.classes import User from tests.utils import check_request_validators @pytest.mark.asyncio -@pytest.mark.parametrize( - "user, expected_status_code", - [ - ("username", HTTPStatus.FOUND), - (None, HTTPStatus.FOUND), - ], -) -async def test_authorization_endpoint_allows_prompt_query_param( - expected_status_code: HTTPStatus, - user: Optional[User], - context_factory, -): - if user is None: - context = context_factory() - else: - context = context_factory(users={user: "password"}) +async def test_authorization_endpoint_allows_prompt_query_param(context_factory): + context = context_factory() server = context.server client = context.clients[0] client_id = client.client_id @@ -52,4 +36,4 @@ async def test_authorization_endpoint_allows_prompt_query_param( await check_request_validators(request, server.create_authorization_response) response = await server.create_authorization_response(request) - assert response.status_code == expected_status_code + 
assert response.status_code == HTTPStatus.FOUND From c215ca416cda8f47ea207b7fde2ece0159f8e194 Mon Sep 17 00:00:00 2001 From: Ali Aliyev Date: Sun, 26 Jan 2025 20:20:18 +0400 Subject: [PATCH 47/57] fix: passing a user through the extra argument --- examples/fastapi_example.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/examples/fastapi_example.py b/examples/fastapi_example.py index 7adb5fd..f736883 100644 --- a/examples/fastapi_example.py +++ b/examples/fastapi_example.py @@ -167,7 +167,6 @@ async def login_submit( request.session["user"] = user redirect = request.url_for("approve") return RedirectResponse(redirect, status_code=303) - # # sign in user @app.get("/approve") @@ -213,8 +212,8 @@ async def approve_submit( """ scope approval form submission handler """ - oauthreq = request.session["oauth"] - oauthreq.user = request.session["user"] + oauthreq: OAuthRequest = request.session["oauth"] + oauthreq.extra["user"] = request.session["user"] if not approval: # generate error response on deny error = AccessDeniedError(oauthreq, "User rejected scopes") From 27809eab819842bec8ba1b6641610aad5aee454b Mon Sep 17 00:00:00 2001 From: Ali Aliyev Date: Sat, 1 Feb 2025 19:20:50 +0400 Subject: [PATCH 48/57] chore: removed "type: ignore" in response_type to be more explicit --- aioauth/response_type.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/aioauth/response_type.py b/aioauth/response_type.py index 6bd1063..8ec8a92 100644 --- a/aioauth/response_type.py +++ b/aioauth/response_type.py @@ -138,6 +138,10 @@ class ResponseTypeAuthorizationCode(ResponseTypeBase[UserType]): async def create_authorization_response( self, request: Request[UserType], client: Client[UserType] ) -> AuthorizationCodeResponse: + assert request.query.response_type, ( + "`response_type` cannot be an empty string or `None`. " + "Please make sure you call `validate_request` before calling this method." 
+ ) authorization_code = await self.storage.create_authorization_code( client_id=client.client_id, code=generate_token(42), @@ -146,7 +150,7 @@ async def create_authorization_response( nonce=request.query.nonce, redirect_uri=request.query.redirect_uri, request=request, - response_type=request.query.response_type, # type: ignore + response_type=request.query.response_type, scope=request.query.scope, ) return AuthorizationCodeResponse( From bfddde34bd3d84543053a8327b8f8ebe81acbd19 Mon Sep 17 00:00:00 2001 From: imgurbot12 Date: Sun, 2 Feb 2025 13:46:40 -0700 Subject: [PATCH 49/57] fix: support dynamic type assignmnt from wrapped function --- aioauth/utils.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/aioauth/utils.py b/aioauth/utils.py index d380e71..088b6fd 100644 --- a/aioauth/utils.py +++ b/aioauth/utils.py @@ -30,6 +30,7 @@ Set, Tuple, Type, + TypeVar, Union, ) from urllib.parse import quote, urlencode, urlparse, urlunsplit @@ -278,9 +279,15 @@ def build_error_response( ) +T = TypeVar("T") + + def catch_errors_and_unavailability( skip_redirect_on_exc: Tuple[Type[OAuth2Error], ...] = (OAuth2Error,) -) -> Callable[..., Callable[..., Coroutine[Any, Any, Response]]]: +) -> Callable[ + [Callable[..., Coroutine[Any, Any, T]]], + Callable[..., Coroutine[Any, Any, Union[T, Response]]], +]: """ Decorator that adds error catching to the function passed. @@ -290,11 +297,15 @@ def catch_errors_and_unavailability( A callable with error catching capabilities. 
""" - def decorator(f) -> Callable[..., Coroutine[Any, Any, Response]]: + def decorator( + f: Callable[..., Coroutine[Any, Any, T]] + ) -> Callable[..., Coroutine[Any, Any, Union[T, Response]]]: @functools.wraps(f) - async def wrapper(self, request: Request, *args, **kwargs) -> Response: + async def wrapper( + self, request: Request, *args, **kwargs + ) -> Union[T, Response]: try: - response = await f(self, request, *args, **kwargs) + response: Union[T, Response] = await f(self, request, *args, **kwargs) except Exception as exc: response = build_error_response( exc=exc, request=request, skip_redirect_on_exc=skip_redirect_on_exc From f921eedc155462273b21d2ad8595e2c7928ad8cc Mon Sep 17 00:00:00 2001 From: imgurbot12 Date: Sun, 2 Feb 2025 16:20:37 -0700 Subject: [PATCH 50/57] fix: allow error handling on resposne-gen up to developer --- aioauth/server.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/aioauth/server.py b/aioauth/server.py index 954ed2e..3a05dd7 100644 --- a/aioauth/server.py +++ b/aioauth/server.py @@ -418,13 +418,6 @@ async def authorize(request: fastapi.Request) -> fastapi.Response: auth_state.grants.append((response_type, client)) return auth_state - @catch_errors_and_unavailability( - skip_redirect_on_exc=( - MethodNotAllowedError, - InvalidClientError, - InvalidRedirectURIError, - ) - ) async def create_authorization_response( self, auth_state: AuthorizationState, From d9377ac3cbad6e3f910310670cce10a96a96a340 Mon Sep 17 00:00:00 2001 From: Ali Aliyev Date: Tue, 4 Feb 2025 01:53:57 +0400 Subject: [PATCH 51/57] feat: splitting create_authorization_response into two methods --- aioauth/server.py | 26 +++++++++++++++----------- examples/fastapi_example.py | 14 +++++++++----- 2 files changed, 24 insertions(+), 16 deletions(-) diff --git a/aioauth/server.py b/aioauth/server.py index 3a05dd7..8adfcdf 100644 --- a/aioauth/server.py +++ b/aioauth/server.py @@ -344,16 +344,9 @@ async def token(request: fastapi.Request) -> fastapi.Response: 
content=content, status_code=HTTPStatus.OK, headers=default_headers ) - @catch_errors_and_unavailability( - skip_redirect_on_exc=( - MethodNotAllowedError, - InvalidClientError, - InvalidRedirectURIError, - ) - ) async def validate_authorization_request( self, request: Request - ) -> Union[Response, AuthorizationState]: + ) -> AuthorizationState: """ Endpoint to interact with the resource owner and obtain an authoriation grant. @@ -412,15 +405,15 @@ async def authorize(request: fastapi.Request) -> fastapi.Response: raise UnsupportedResponseTypeError(request=request, state=state) auth_state = AuthorizationState(request, response_type_list, grants=[]) + for ResponseTypeClass in response_type_classes: response_type = ResponseTypeClass(storage=self.storage) client = await response_type.validate_request(request) auth_state.grants.append((response_type, client)) return auth_state - async def create_authorization_response( - self, - auth_state: AuthorizationState, + async def _create_authorization_response( + self, auth_state: AuthorizationState ) -> Response: """ Endpoint to interact with the resource owner and obtain an @@ -529,6 +522,17 @@ async def authorize(request: fastapi.Request) -> fastapi.Response: content=content, ) + @catch_errors_and_unavailability( + skip_redirect_on_exc=( + MethodNotAllowedError, + InvalidClientError, + InvalidRedirectURIError, + ) + ) + async def create_authorization_response(self, request: Request) -> Response: + auth_state = await self.validate_authorization_request(request) + return await self._create_authorization_response(auth_state) + @catch_errors_and_unavailability() async def revoke_token(self, request: Request) -> Response: """Endpoint to revoke an access token or refresh token. 
diff --git a/examples/fastapi_example.py b/examples/fastapi_example.py index fa720db..05aa19c 100644 --- a/examples/fastapi_example.py +++ b/examples/fastapi_example.py @@ -15,7 +15,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession from aioauth.collections import HTTPHeaderDict -from aioauth.errors import AccessDeniedError +from aioauth.errors import AccessDeniedError, OAuth2Error from aioauth.requests import Post, Query from aioauth.requests import Request as OAuthRequest from aioauth.responses import Response as OAuthResponse @@ -83,12 +83,16 @@ async def authorize( """ # validate initial request and return error response (if supplied) oauthreq = await to_request(request) - response = await oauth.validate_authorization_request(oauthreq) - if isinstance(response, OAuthResponse): + + try: + state = await oauth.validate_authorization_request(oauthreq) + except OAuth2Error as exc: + response = build_error_response(exc=exc, request=oauthreq) return to_response(response) + # redirect to login if user information is missing user = request.session.get("user", None) - request.session["oauth"] = response + request.session["oauth"] = state if user is None: return RedirectResponse("/login") # otherwise redirect to approval @@ -213,7 +217,7 @@ async def approve_submit( response = build_error_response(error, state.request, skip_redirect_on_exc=()) else: # process authorize request - response = await oauth.create_authorization_response(state) + response = await oauth._create_authorization_response(state) return to_response(response) From e7a8bd988038b443332d2a7d53d37405e8fc26e4 Mon Sep 17 00:00:00 2001 From: Ali Aliyev Date: Sat, 8 Feb 2025 13:19:37 +0400 Subject: [PATCH 52/57] chore: renamed _create_authorization_response to finalize_authorization_response --- aioauth/server.py | 71 ++++++++++++++++++++++--------------- examples/fastapi_example.py | 2 +- 2 files changed, 44 insertions(+), 29 deletions(-) diff --git a/aioauth/server.py b/aioauth/server.py index 
8adfcdf..39e32ab 100644 --- a/aioauth/server.py +++ b/aioauth/server.py @@ -29,8 +29,6 @@ from .collections import HTTPHeaderDict from .constances import default_headers from .errors import ( - InsecureTransportError, - InvalidClientError, InvalidRedirectURIError, InvalidRequestError, MethodNotAllowedError, @@ -46,6 +44,10 @@ PasswordGrantType, RefreshTokenGrantType, ) +from .errors import ( + InsecureTransportError, + InvalidClientError, +) from .response_type import ( ResponseTypeAuthorizationCode, ResponseTypeIdToken, @@ -412,39 +414,21 @@ async def authorize(request: fastapi.Request) -> fastapi.Response: auth_state.grants.append((response_type, client)) return auth_state - async def _create_authorization_response( + async def finalize_authorization_response( self, auth_state: AuthorizationState ) -> Response: """ - Endpoint to interact with the resource owner and obtain an - authorization grant. - Create an authorization response after validation. - For more information see - `RFC6749 section 4.1.1 `_. - - Example: - Below is an example utilizing FastAPI as the server framework. - .. code-block:: python + Finalizes the authorization response based on the provided authorization state. - from aioauth.fastapi.utils import to_oauth2_request, to_fastapi_response - - @app.post("/authorize") - async def authorize(request: fastapi.Request) -> fastapi.Response: - # Converts a fastapi.Request to an aioauth.Request. - oauth2_request: aioauth.Request = await to_oauth2_request(request) - # Validate the oauth request - auth_state: aioauth.AuthState = await server.validate_authorization_request(oauth2_request) - # Creates the response via this function call. - oauth2_response: aioauth.Response = await server.create_authorization_response(auth_state) - # Converts an aioauth.Response to a fastapi.Response. 
- response: fastapi.Response = await to_fastapi_response(oauth2_response) - return response + This is the final step in creating an authorization response before sending it to + the client. Args: - auth_state: An :py:class:`aioauth.server.AuthState` object. + auth_state (AuthorizationState): The current authorization state, including the + original request, response types, and associated grants. Returns: - response: An :py:class:`aioauth.responses.Response` object. + Response: An HTTP response object with the appropriate redirection headers and content. """ request = auth_state.request state = auth_state.request.query.state @@ -530,8 +514,39 @@ async def authorize(request: fastapi.Request) -> fastapi.Response: ) ) async def create_authorization_response(self, request: Request) -> Response: + """ + Endpoint to interact with the resource owner and obtain an + authorization grant. + Create an authorization response after validation. + For more information see + `RFC6749 section 4.1.1 `_. + + Example: + Below is an example utilizing FastAPI as the server framework. + .. code-block:: python + + from aioauth.fastapi.utils import to_oauth2_request, to_fastapi_response + + @app.post("/authorize") + async def authorize(request: fastapi.Request) -> fastapi.Response: + # Converts a fastapi.Request to an aioauth.Request. + oauth2_request: aioauth.Request = await to_oauth2_request(request) + # Validate the oauth request + auth_state: aioauth.AuthState = await server.validate_authorization_request(oauth2_request) + # Creates the response via this function call. + oauth2_response: aioauth.Response = await server.create_authorization_response(auth_state) + # Converts an aioauth.Response to a fastapi.Response. + response: fastapi.Response = await to_fastapi_response(oauth2_response) + return response + + Args: + auth_state: An :py:class:`aioauth.server.AuthState` object. + + Returns: + response: An :py:class:`aioauth.responses.Response` object. 
+ """ auth_state = await self.validate_authorization_request(request) - return await self._create_authorization_response(auth_state) + return await self.finalize_authorization_response(auth_state) @catch_errors_and_unavailability() async def revoke_token(self, request: Request) -> Response: diff --git a/examples/fastapi_example.py b/examples/fastapi_example.py index 05aa19c..f3de6b0 100644 --- a/examples/fastapi_example.py +++ b/examples/fastapi_example.py @@ -217,7 +217,7 @@ async def approve_submit( response = build_error_response(error, state.request, skip_redirect_on_exc=()) else: # process authorize request - response = await oauth._create_authorization_response(state) + response = await oauth.finalize_authorization_response(state) return to_response(response) From f42764dc954a3d8ed6ad9af43e3650682a6dce90 Mon Sep 17 00:00:00 2001 From: Ali Aliyev Date: Sat, 8 Feb 2025 22:44:48 +0400 Subject: [PATCH 53/57] refactor: enhance type annotations for grants and response types in AuthorizationServer --- aioauth/server.py | 43 ++++++++++++++++++++++++++++++++++++++----- 1 file changed, 38 insertions(+), 5 deletions(-) diff --git a/aioauth/server.py b/aioauth/server.py index 39e32ab..c92de67 100644 --- a/aioauth/server.py +++ b/aioauth/server.py @@ -19,7 +19,7 @@ from dataclasses import asdict, dataclass from http import HTTPStatus -from typing import Any, Dict, List, Optional, Tuple, Type, Union, get_args +from typing import Dict, List, Optional, Tuple, Type, Union, get_args, Set from .models import Client from .requests import Request @@ -83,20 +83,46 @@ class AuthorizationState: response_type_list: List[ResponseType] """Supported ResponseTypes Collected During Initial Request Validation""" - grants: List[Tuple[ResponseTypeAuthorizationCode, Client]] + grants: List[ + Tuple[ + Union[ + ResponseTypeToken, + ResponseTypeAuthorizationCode, + ResponseTypeNone, + ResponseTypeIdToken, + ], + Client, + ] + ] """Collection of Supported GrantType Handlers and The Parsed 
Clients""" class AuthorizationServer: """Interface for initializing an OAuth 2.0 server.""" - response_types: Dict[ResponseType, Any] = { + response_types: Dict[ + ResponseType, + Union[ + type[ResponseTypeToken], + type[ResponseTypeAuthorizationCode], + type[ResponseTypeNone], + type[ResponseTypeIdToken], + ], + ] = { "token": ResponseTypeToken, "code": ResponseTypeAuthorizationCode, "none": ResponseTypeNone, "id_token": ResponseTypeIdToken, } - grant_types: Dict[GrantType, Any] = { + grant_types: Dict[ + GrantType, + Union[ + type[AuthorizationCodeGrantType], + type[ClientCredentialsGrantType], + type[PasswordGrantType], + type[RefreshTokenGrantType], + ], + ] = { "authorization_code": AuthorizationCodeGrantType, "client_credentials": ClientCredentialsGrantType, "password": PasswordGrantType, @@ -388,7 +414,14 @@ async def authorize(request: fastapi.Request) -> fastapi.Response: self.validate_request(request, ["GET", "POST"]) response_type_list = enforce_list(request.query.response_type) - response_type_classes = set() + response_type_classes: Set[ + Union[ + type[ResponseTypeToken], + type[ResponseTypeAuthorizationCode], + type[ResponseTypeNone], + type[ResponseTypeIdToken], + ] + ] = set() state = request.query.state if not response_type_list: From c6e772f6130953ec117eb6b3aa292b198364059b Mon Sep 17 00:00:00 2001 From: Ali Aliyev Date: Sun, 9 Feb 2025 14:19:22 +0400 Subject: [PATCH 54/57] chore: updated the link to the FastAPI example in README.md --- README.md | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index 8f5abea..12a9055 100644 --- a/README.md +++ b/README.md @@ -25,14 +25,8 @@ There are few great OAuth frameworks for Python like [oauthlib](https://github.c python -m pip install aioauth ``` -## FastAPI +## Examples -FastAPI integration stored on separated [aioauth-fastapi](https://github.com/aliev/aioauth-fastapi) repository and can be installed via the command: - -``` -python -m pip install 
aioauth[fastapi] -``` - -[aioauth-fastapi](https://github.com/aliev/aioauth-fastapi) repository contains demo example which I recommend to look. +The project example is located in the [examples](examples) directory and uses FastAPI as the server. ## [API Reference and User Guide](https://aliev.me/aioauth/) From 815975fbe4f4c108d70f9ce90f97b67cd7847fa7 Mon Sep 17 00:00:00 2001 From: Ali Aliyev Date: Sun, 9 Feb 2025 14:36:04 +0400 Subject: [PATCH 55/57] fix: removed the dynamic type assignment from wrapped function --- aioauth/utils.py | 19 +++++-------------- 1 file changed, 5 insertions(+), 14 deletions(-) diff --git a/aioauth/utils.py b/aioauth/utils.py index 088b6fd..2a943b9 100644 --- a/aioauth/utils.py +++ b/aioauth/utils.py @@ -30,7 +30,6 @@ Set, Tuple, Type, - TypeVar, Union, ) from urllib.parse import quote, urlencode, urlparse, urlunsplit @@ -279,15 +278,9 @@ def build_error_response( ) -T = TypeVar("T") - - def catch_errors_and_unavailability( skip_redirect_on_exc: Tuple[Type[OAuth2Error], ...] = (OAuth2Error,) -) -> Callable[ - [Callable[..., Coroutine[Any, Any, T]]], - Callable[..., Coroutine[Any, Any, Union[T, Response]]], -]: +) -> Callable[..., Callable[..., Coroutine[Any, Any, Response]]]: """ Decorator that adds error catching to the function passed. 
@@ -298,14 +291,12 @@ def catch_errors_and_unavailability( """ def decorator( - f: Callable[..., Coroutine[Any, Any, T]] - ) -> Callable[..., Coroutine[Any, Any, Union[T, Response]]]: + f: Callable[..., Coroutine[Any, Any, Response]] + ) -> Callable[..., Coroutine[Any, Any, Response]]: @functools.wraps(f) - async def wrapper( - self, request: Request, *args, **kwargs - ) -> Union[T, Response]: + async def wrapper(self, request: Request, *args, **kwargs) -> Response: try: - response: Union[T, Response] = await f(self, request, *args, **kwargs) + response = await f(self, request, *args, **kwargs) except Exception as exc: response = build_error_response( exc=exc, request=request, skip_redirect_on_exc=skip_redirect_on_exc From 09a85ce0dd7f926dac653264b98d0db3945ea76f Mon Sep 17 00:00:00 2001 From: Ali Aliyev Date: Sun, 9 Feb 2025 21:11:05 +0400 Subject: [PATCH 56/57] fix: checkout version in docs-publish.yml --- .github/workflows/docs-publish.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/docs-publish.yml b/.github/workflows/docs-publish.yml index 816b57c..2cfde9a 100644 --- a/.github/workflows/docs-publish.yml +++ b/.github/workflows/docs-publish.yml @@ -11,7 +11,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Checks out repo - uses: actions/checkout@v1 + uses: actions/checkout@v2 - name: Generates HTML documentation uses: synchronizing/sphinx-action@master From 78f08763adce6666ea67400872925c69e956468c Mon Sep 17 00:00:00 2001 From: Ali Aliyev Date: Sun, 9 Feb 2025 21:12:21 +0400 Subject: [PATCH 57/57] fix: upload-artifact version in docs-publish.yml --- .github/workflows/docs-publish.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/docs-publish.yml b/.github/workflows/docs-publish.yml index 2cfde9a..b462c40 100644 --- a/.github/workflows/docs-publish.yml +++ b/.github/workflows/docs-publish.yml @@ -27,12 +27,12 @@ jobs: docs-folder: "docs/" - name: Saves the HTML build documentation 
- uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v4 with: path: docs/build/html/ - name: Saves the PDF build documentation - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v4 with: path: docs/build/latex/aioauth.pdf