IEEE.org     |     IEEE Xplore Digital Library     |     IEEE Standards     |     IEEE Spectrum     |     More Sites

Verified Commit f3e95e4f authored by Emi Simpson's avatar Emi Simpson
Browse files

[api] Add user info to the Session object

parent 95b65d0d
Pipeline #1071 passed with stage
in 52 seconds
......@@ -4,50 +4,24 @@ from werkzeug.wrappers.response import Response
from base64 import b64encode
from dataclasses import dataclass
from http import HTTPStatus
from typing import Optional, TypedDict
from typing import Optional
from mystic import queries
from mystic.queries import direct_auth
from mystic.queries.direct_auth import ValidatePasswordError
from mystic.api import after_running_execute_queries, is_r, json_request, make_response, R
from mystic.api.v1.errors import ApiErrorCode, BadPasswordError, handle_json_parse_error, InternalError, MalformedArgumentError, MismatchedTypeError, MissingFieldsError, UserDneError, NameTakenError
from mystic.api.v1.types import Auth, Session, User, user_from_database_user
from mystic.api.v1.util import check_all_present, first_error, get_typechecked
from mystic.queries import Query, direct_auth
from mystic.types import ApiJsonRequest
from mystic.utils import panic
class Auth(TypedDict):
"""
Information about what authentication / signup methods the server supports
"""
direct: bool
"""
Indicates that the server supports the `/api/v1/auth/direct` endpoint for direct auth
"""
sso: bool
"""
Indicates that the server supports the `/api/v1/auth/sso` endpoint for single sign on
"""
class Session(TypedDict):
"""
Wraps a session token
"""
token: str
"""
A session token which can be used for auth later on
This is guaranteed to be valid base64 and decode to exactly 18 bytes, meaning that its
length is guaranteed to be exactly 24 characters.
"""
def create_session(token: bytes) -> Session:
def create_session(token: bytes, user: User) -> Session:
"""
A shorthand for creating a :class:`Session` from an unencoded session token
"""
return Session(token=b64encode(token).decode('ASCII'))
return Session(token=b64encode(token).decode('ASCII'), user=user)
@dataclass(frozen=True)
class AuthModule:
......@@ -126,11 +100,13 @@ class AuthModule:
message=f"The database has a malformed password hash on file for this user"))
}[error]),
# Once we have the user's ID...
transformation=lambda user_id: queries.MappedQuery(
transformation=lambda user: queries.MappedQuery(
# Create a session token for the user...
queries.CreateSession(user_id),
queries.CreateSession(user.user_id),
# And return it as a successful response
on_success=lambda token: R(HTTPStatus.OK, create_session(token)),
on_success=lambda token: R(
HTTPStatus.OK,
create_session(token, user_from_database_user(user))),
on_error=lambda noreturn: noreturn)
)
......@@ -176,7 +152,15 @@ class AuthModule:
# Create a session for the new account
queries.CreateSession(uid),
# And return it in as JSON
on_success=lambda token: R(HTTPStatus.OK, create_session(token)),
on_success=lambda token: R(
HTTPStatus.OK,
create_session(token, User(
#Including the information about the user
id=uid,
username=username,
first=first_name,
last=last_name,
avatar='unimplemented'))),
on_error=lambda noreturn: noreturn))),
# Tell the SQL server that this transaction does perform IO
read_only = False)
......
......@@ -230,3 +230,38 @@ class Instance(TypedDict):
"""
The backend types which are currently supported by this server
"""
class Auth(TypedDict):
"""
Information about what authentication / signup methods the server supports
"""
direct: bool
"""
Indicates that the server supports the `/api/v1/auth/direct` endpoint for direct auth
"""
sso: bool
"""
Indicates that the server supports the `/api/v1/auth/sso` endpoint for single sign on
"""
class Session(TypedDict):
"""
Wraps a session token
"""
token: str
"""
A session token which can be used for auth later on
This is guaranteed to be valid base64 and decode to exactly 18 bytes, meaning that its
length is guaranteed to be exactly 24 characters.
"""
user: User
"""
Information about the currently authenticated user
Equivilent to a call to /whoami
"""
from mystic.queries import Error, Finished, QueryRequest, QueryResult, SqlErrorCode, SqlIntegrityError
from mystic.queries import Error, Finished, QueryRequest, QueryResult, SqlErrorCode, SqlIntegrityError, UserInfo
from mystic.types import UserID
from enum import auto, Enum
from typing import Final, Literal, NamedTuple, Optional
......@@ -71,12 +71,12 @@ class ValidatePassword(NamedTuple):
def get_query(self) -> QueryRequest:
return QueryRequest('''
SELECT user_id, pass_hash
SELECT user_id, first_name, last_name, pass_hash
FROM users
NATURAL JOIN passwords
WHERE username = %s;
''', (self.username,))
def handle_results(self, results: QueryResult | SqlIntegrityError) -> Finished[UserID] | Error[ValidatePasswordError]:
def handle_results(self, results: QueryResult | SqlIntegrityError) -> Finished[UserInfo] | Error[ValidatePasswordError]:
if isinstance(results, SqlIntegrityError):
raise Exception(f"Unexpected IntegrityError in GetSources: {results}")
else:
......@@ -85,12 +85,14 @@ class ValidatePassword(NamedTuple):
return Error(ValidatePasswordError.BadUsername)
else:
user_id: UserID = record[0]
pass_hash: bytes = record[1]
first_name: str = record[1]
last_name: str = record[2]
pass_hash: bytes = record[3]
validity = try_validate_password(pass_hash, self.password)
if validity is not None:
return Error(validity)
else:
return Finished(user_id)
return Finished(UserInfo(user_id, self.username, first_name, last_name))
class SetPassword(NamedTuple):
"""
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment