IEEE.org     |     IEEE Xplore Digital Library     |     IEEE Standards     |     IEEE Spectrum     |     More Sites

Verified Commit fcaa38ae authored by Emi Simpson's avatar Emi Simpson
Browse files

[api] Added an endpoint to refresh a session

parent 41521993
Pipeline #1086 failed with stage
in 46 seconds
......@@ -158,6 +158,10 @@ By looking at the code associated with the error, more information can be learne
* Authenticate with direct authentication `POST /auth/sso`
* `session: <session>`
* Returns a `Session`
* Renew an existing session `PUT /auth/refresh`
* Takes no arguments, but should be authenticated
* Returns a `Session`
* The old session is now invalidated
## Authentication
......
......@@ -8,7 +8,7 @@ from mystic import coordination, queries
from mystic.api import after_running_execute_coordinator_requests, after_running_execute_queries, is_r, json_request, make_response, parse_session_token, query_args_request, R
from mystic.api.v1 import auth
from mystic.api.v1.errors import ApiErrorCode, BadSessionError, EndpointDNEError, ExpiredSessionError, handle_json_parse_error, MalformedArgumentError, MethodNotAllowedError, MissingFieldsError, MismatchedTypeError, NameTakenError, NoPermissionError, RedundantEntityError, UnknownError
from mystic.api.v1.types import Instance, Job, job_from_coordinator_job, lookup_backend, Project, project_from_database_project, Source, Success, SUCCESS, User, user_from_database_user
from mystic.api.v1.types import create_session, Instance, Job, job_from_coordinator_job, lookup_backend, Project, project_from_database_project, Session, Source, Success, SUCCESS, User, user_from_database_user
from mystic.api.v1.util import check_all_present, first_error, get_typechecked, unk_err
from mystic.coordination import CoordinatorFlow
from mystic.queries import Query
......@@ -38,9 +38,51 @@ GetUser: TypeAlias = Query[UserID, R[ExpiredSessionError]]
An alias for the type of query made available by :func:`authenticate`
"""
Req = TypeVar('Req', ApiJsonRequest, ApiGetRequest)
def try_extract_session(req: Req) -> BadSessionError | bytes:
authorization = req.headers.get('Authorization', None)
# Case 1: The authization header wasn't even provided
if authorization is None:
return BadSessionError(
code = ApiErrorCode.BadSession,
message = "Attempted to access an endpoint that requires authentication, but the Authorization header wasn't present",
malformed = False,)
# Case 2: The client is trying to use some other authorization method than bearer
elif not authorization.lower().startswith('bearer '):
if parse_session_token(authorization):
# Case 2.1: It looks like the client might've just forgotten to type
# Bearer, so we can give them a push in the right direction
return BadSessionError(
code = ApiErrorCode.BadSession,
message =
'It looks like you sent a session token but left out the '
+ '<auth-scheme> parameter in your Authorization header. Perhaps '
+f'you meant to send "Authorization: Bearer {authorization}" '
+ 'instead? See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Authorization#syntax',
malformed = False,)
else:
# Case 2.2: Standard response
return BadSessionError(
code = ApiErrorCode.BadSession,
message = "Attempted to use an Authorization header with a method other than Bearer",
malformed = False,)
else:
# We found the token, so lets try decoding it
encoded_token = authorization[7:].lstrip()
token = parse_session_token(encoded_token)
if token is None:
# Case 4: The token was malformed
return BadSessionError(
code = ApiErrorCode.BadSession,
message = "The Bearer (session) token was malformed",
malformed = True,)
else:
# Case 5: Everything looks good!
return token
P = ParamSpec('P')
T = TypeVar('T')
Req = TypeVar('Req', ApiJsonRequest, ApiGetRequest)
def authenticate(func: Callable[Concatenate[Req, GetUser, P], T]) -> Callable[Concatenate[Req, P], T | R[BadSessionError]]:
"""
A decorator which wraps an authenticated endpoint
......@@ -58,52 +100,18 @@ def authenticate(func: Callable[Concatenate[Req, GetUser, P], T]) -> Callable[Co
"""
@wraps(func)
def wrapper(req: Req, *args: P.args, **kwargs: P.kwargs) -> T | R[BadSessionError]:
authorization = req.headers.get('Authorization', None)
# Case 1: The authization header wasn't even provided
if authorization is None:
return R(HTTPStatus.UNAUTHORIZED, BadSessionError(
code = ApiErrorCode.BadSession,
message = "Attempted to access an endpoint that requires authentication, but the Authorization header wasn't present",
malformed = False,))
# Case 2: The client is trying to use some other authorization method than bearer
elif not authorization.lower().startswith('bearer '):
if parse_session_token(authorization):
# Case 2.1: It looks like the client might've just forgotten to type
# Bearer, so we can give them a push in the right direction
return R(HTTPStatus.UNAUTHORIZED, BadSessionError(
code = ApiErrorCode.BadSession,
message =
'It looks like you sent a session token but left out the '
+ '<auth-scheme> parameter in your Authorization header. Perhaps '
+f'you meant to send "Authorization: Bearer {authorization}" '
+ 'instead? See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Authorization#syntax',
malformed = False,))
else:
# Case 2.2: Standard response
return R(HTTPStatus.UNAUTHORIZED, BadSessionError(
code = ApiErrorCode.BadSession,
message = "Attempted to use an Authorization header with a method other than Bearer",
malformed = False,))
token = try_extract_session(req)
if not isinstance(token, bytes):
return R(HTTPStatus.UNAUTHORIZED, token)
else:
# We found the token, so lets try decoding it
encoded_token = authorization[7:].lstrip()
token = parse_session_token(encoded_token)
if token is None:
# Case 4: The token was malformed
return R(HTTPStatus.UNAUTHORIZED, BadSessionError(
code = ApiErrorCode.BadSession,
message = "The Bearer (session) token was malformed",
malformed = True,))
else:
# Case 5: Everything looks good! Prepare a query which validates the
# session, and pass it to the function
return func(req, queries.MappedQuery(
queries.ValidateSession(token),
on_success=lambda uid: uid,
on_error=lambda _: R(HTTPStatus.UNAUTHORIZED, ExpiredSessionError(
code = ApiErrorCode.ExpiredSession,
message = "This session has expired, please reauthenticate"))
), *args, **kwargs)
# Prepare a query which validates the session, and pass it to the function
return func(req, queries.MappedQuery(
queries.ValidateSession(token),
on_success=lambda uid: uid,
on_error=lambda _: R(HTTPStatus.UNAUTHORIZED, ExpiredSessionError(
code = ApiErrorCode.ExpiredSession,
message = "This session has expired, please reauthenticate"))
), *args, **kwargs)
return wrapper
......@@ -486,6 +494,47 @@ def delete_source(_: ApiGetRequest, get_user: GetUser, pid: int, sid: SourceID)
code=ApiErrorCode.NoPermission,
message="Either this source does not exist, or you lack permission to manage it"))))
@bp.put('/auth/refresh')
@make_response
@query_args_request
@after_running_execute_queries(is_r)
def refresh_session(r: ApiGetRequest) -> R[BadSessionError] | Query[R[Session], R[ExpiredSessionError | UnknownError]]:
token = try_extract_session(r)
if not isinstance(token, bytes):
return R(HTTPStatus.UNAUTHORIZED, token)
else:
return queries.Transaction(
queries.BoundQuery(
queries.MappedQuery(
# Invalidate the old session
queries.InvalidateSession(token),
# Passing on the user ID for later use
on_success=lambda user: user,
# And producing an ExpiredSessionError if there was any issue
on_error=lambda _: R(HTTPStatus.UNAUTHORIZED, ExpiredSessionError(
code=ApiErrorCode.ExpiredSession,
message='Attempted to renew a session that had already expired, or simply never existed'))),
# Then, using the user ID from before...
transformation=lambda user: queries.BoundQuery(
queries.MappedQuery(
# Create a new session
queries.CreateSession(user),
# Passing on the new session token
on_success=lambda token: token,
# (and this should never fail)
on_error=lambda _: unk_err('CreateSession claims a uid provided by DeleteSession does not exist :<')),
# Finally, take the token,
transformation=lambda token: queries.MappedQuery(
# And some information about the user
queries.RetrieveUserInfo(user),
# And bundle it together for the client
on_success=lambda user_info: R(HTTPStatus.OK, create_session(
token,
user_from_database_user(user_info))),
# (again, this should never fail)
on_error=lambda _: unk_err('uid by DeleteSession validated by CreateSession but not RetrieveUserInfo')))),
read_only=False)
def get_bp(auth_mod: auth.AuthModule) -> Blueprint:
bp = Blueprint("api-v1", __name__, url_prefix="/api/v1/")
......
......@@ -12,18 +12,12 @@ from mystic.queries import direct_auth
from mystic.queries.direct_auth import ValidatePasswordError
from mystic.api import after_running_execute_queries, is_r, json_request, make_response, query_args_request, R
from mystic.api.v1.errors import ApiErrorCode, BadPasswordError, handle_json_parse_error, InternalError, MalformedArgumentError, MismatchedTypeError, MissingFieldsError, UnknownError, UserDneError, NameTakenError
from mystic.api.v1.types import Auth, Session, User, user_from_database_user
from mystic.api.v1.types import Auth, create_session, Session, User, user_from_database_user
from mystic.api.v1.util import check_all_present, first_error, get_typechecked, unk_err
from mystic.queries import Query, direct_auth
from mystic.types import ApiGetRequest, ApiJsonRequest, Url, UserID
from mystic.utils import panic
def create_session(token: bytes, user: User) -> Session:
"""
A shorthand for creating a :class:`Session` from an unencoded session token
"""
return Session(token=b64encode(token).decode('ASCII'), user=user)
def _add_query_param(url: Url, key: str, value: str) -> Url:
"""
A quick-and-dirty algorithm for adding a query parameter to a URL
......
......@@ -276,3 +276,10 @@ class Session(TypedDict):
Equivilent to a call to /whoami
"""
def create_session(token: bytes, user: User) -> Session:
"""
A shorthand for creating a :class:`Session` from an unencoded session token
"""
from base64 import b64encode
return Session(token=b64encode(token).decode('ASCII'), user=user)
......@@ -20,8 +20,8 @@ def setup_database(c: Cursor) -> None:
CREATE TABLE IF NOT EXISTS sessions (
token BINARY(18) PRIMARY KEY,
user_id INTEGER NOT NULL,
created TIMESTAMP NOT NULL
DEFAULT CURRENT_TIMESTAMP,
expires TIMESTAMP NOT NULL
DEFAULT (CURRENT_TIMESTAMP + INTERVAL 1 WEEK),
FOREIGN KEY (user_id)
REFERENCES users(user_id)
ON DELETE CASCADE
......
......@@ -1121,9 +1121,8 @@ class CreateSession(NamedTuple):
"""
A :class:`Query` to create a new session for a given user
The created session will be marked as having been created on this day, thereby making
it valid for no less than one week. During this span of time, the session can be
passed to :class:`ValidateSession` to yield this user's ID.
By default, the session will be marked to expire in one week, although this can be
overriden using the :attr:`lifespan` parameter.
If successful, this will produce exactly an exactly 18-byte session token.
If the user with the provided ID does not exist, then this returns None
......@@ -1132,11 +1131,25 @@ class CreateSession(NamedTuple):
"""
The ID of the user to create a session for
"""
lifespan: int = 24 * 7
"""
The how long the token should be valid for, in hours
Non-positive values will cause a panic upon execution
"""
def get_query(self) -> QueryRequest:
assert self.lifespan >= 0, "Lifespan must be greater than zero"
from random import randbytes
# TODO switch to using RANDBYTES(18) once MariaDB 10.10 comes out
return QueryRequest('INSERT INTO sessions (token, user_id) VALUES (%s, %s) RETURNING token;',
(randbytes(18), self.user_id,))
return QueryRequest('''
INSERT INTO sessions (token, user_id, expires)
VALUES (%s, %s, CURRENT_TIMESTAMP + INTERVAL %s HOUR)
RETURNING token;
''', (randbytes(18), self.user_id, self.lifespan))
def handle_results(self, results: QueryResult | SqlIntegrityError) -> Finished[bytes] | Unfinished[bytes, None] | Error[None]:
if isinstance(results, SqlIntegrityError):
match results.error_code:
......@@ -1155,6 +1168,43 @@ class CreateSession(NamedTuple):
assert isinstance(token, bytes), 'This statement is expected to return bytes'
return Finished(token)
class InvalidateSession(NamedTuple):
"""
A :class:`Query` which marks a session as invalid
After execution, the session will no longer be usable to log in or authenticate any
request that uses the :class:`ValidateSession` query.
A successful result means that the token was deleted successfully, whereas an error
result means that there was never a token to begin with.
If successful, the :class:`UserID` of the user the session used to belong to is
returned.
"""
token: bytes
"""
The eighteen byte session token to void
Will raise an exception upon execution if the token is not exactly 18 bytes
"""
def get_query(self) -> QueryRequest:
assert len(self.token) == 18, "Invalid token passed"
return QueryRequest(
'DELETE FROM sessions WHERE token = %s RETURNING user_id;',
(self.token,))
def handle_results(self, results: QueryResult | SqlIntegrityError) -> Finished[UserID] | Error[None]:
if isinstance(results, SqlIntegrityError):
raise Exception(f'Unexpected sql error: {results}')
else:
res: Tuple[UserID] | None = results.next()
if res:
return Finished(res[0])
else:
return Error(None)
class CreateAccount(NamedTuple):
"""
A :class:`Query` to register a new account
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment