-
Notifications
You must be signed in to change notification settings - Fork 240
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'development' into ft-dynamodb-extended
- Loading branch information
Showing
14 changed files
with
351 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,4 +17,5 @@ Flask-SQLAlchemy | |
pymongo | ||
boto3 | ||
mypy_boto3_dynamodb | ||
psycopg2-binary | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,4 +9,5 @@ pymongo | |
flask_sqlalchemy | ||
pymemcache | ||
boto3 | ||
mypy_boto3_dynamodb | ||
mypy_boto3_dynamodb | ||
psycopg2-binary |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .postgresql import PostgreSqlSession, PostgreSqlSessionInterface # noqa: F401 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
from psycopg2 import sql | ||
|
||
|
||
class Queries: | ||
def __init__(self, schema: str, table: str) -> None: | ||
"""Class to hold all the queries used by the session interface. | ||
Args: | ||
schema (str): The name of the schema to use for the session data. | ||
table (str): The name of the table to use for the session data. | ||
""" | ||
self.schema = schema | ||
self.table = table | ||
|
||
@property | ||
def create_schema(self) -> str: | ||
return sql.SQL("CREATE SCHEMA IF NOT EXISTS {schema};").format( | ||
schema=sql.Identifier(self.schema) | ||
) | ||
|
||
@property | ||
def create_table(self) -> str: | ||
uq_idx = sql.Identifier(f"uq_{self.table}_session_id") | ||
expiry_idx = sql.Identifier(f"{self.table}_expiry_idx") | ||
return sql.SQL( | ||
"""CREATE TABLE IF NOT EXISTS {schema}.{table} ( | ||
session_id VARCHAR(255) NOT NULL PRIMARY KEY, | ||
created TIMESTAMP WITHOUT TIME ZONE DEFAULT (NOW() AT TIME ZONE 'utc'), | ||
data BYTEA, | ||
expiry TIMESTAMP WITHOUT TIME ZONE | ||
); | ||
--- Unique session_id | ||
CREATE UNIQUE INDEX IF NOT EXISTS | ||
{uq_idx} ON {schema}.{table} (session_id); | ||
--- Index for expiry timestamp | ||
CREATE INDEX IF NOT EXISTS | ||
{expiry_idx} ON {schema}.{table} (expiry);""" | ||
).format( | ||
schema=sql.Identifier(self.schema), | ||
table=sql.Identifier(self.table), | ||
uq_idx=uq_idx, | ||
expiry_idx=expiry_idx, | ||
) | ||
|
||
@property | ||
def retrieve_session_data(self) -> str: | ||
return sql.SQL( | ||
"""--- If the current sessions is expired, delete it | ||
DELETE FROM {schema}.{table} | ||
WHERE session_id = %(session_id)s AND expiry < NOW(); | ||
--- Else retrieve it | ||
SELECT data FROM {schema}.{table} WHERE session_id = %(session_id)s; | ||
""" | ||
).format(schema=sql.Identifier(self.schema), table=sql.Identifier(self.table)) | ||
|
||
@property | ||
def upsert_session(self) -> str: | ||
return sql.SQL( | ||
"""INSERT INTO {schema}.{table} (session_id, data, expiry) | ||
VALUES (%(session_id)s, %(data)s, NOW() + %(ttl)s) | ||
ON CONFLICT (session_id) | ||
DO UPDATE SET data = %(data)s, expiry = NOW() + %(ttl)s; | ||
""" | ||
).format(schema=sql.Identifier(self.schema), table=sql.Identifier(self.table)) | ||
|
||
@property | ||
def delete_expired_sessions(self) -> str: | ||
return sql.SQL("DELETE FROM {schema}.{table} WHERE expiry < NOW();").format( | ||
schema=sql.Identifier(self.schema), table=sql.Identifier(self.table) | ||
) | ||
|
||
@property | ||
def delete_session(self) -> str: | ||
return sql.SQL( | ||
"DELETE FROM {schema}.{table} WHERE session_id = %(session_id)s;" | ||
).format(schema=sql.Identifier(self.schema), table=sql.Identifier(self.table)) | ||
|
||
@property | ||
def drop_sessions_table(self) -> str: | ||
return sql.SQL("DROP TABLE IF EXISTS {schema}.{table};").format( | ||
schema=sql.Identifier(self.schema), table=sql.Identifier(self.table) | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
from __future__ import annotations | ||
|
||
from contextlib import contextmanager | ||
from datetime import timedelta as TimeDelta | ||
from typing import Generator, Optional | ||
|
||
from flask import Flask | ||
from itsdangerous import want_bytes | ||
from psycopg2.extensions import connection as PsycoPg2Connection | ||
from psycopg2.extensions import cursor as PsycoPg2Cursor | ||
from psycopg2.pool import ThreadedConnectionPool | ||
|
||
from .._utils import retry_query | ||
from ..base import ServerSideSession, ServerSideSessionInterface | ||
from ..defaults import Defaults | ||
from ._queries import Queries | ||
|
||
|
||
class PostgreSqlSession(ServerSideSession): | ||
pass | ||
|
||
|
||
class PostgreSqlSessionInterface(ServerSideSessionInterface): | ||
"""A Session interface that uses PostgreSQL as a session storage. (`psycopg2` required) | ||
:param pool: A ``psycopg2.pool.ThreadedConnectionPool`` instance. | ||
:param key_prefix: A prefix that is added to all storage keys. | ||
:param use_signer: Whether to sign the session id cookie or not. | ||
:param permanent: Whether to use permanent session or not. | ||
:param sid_length: The length of the generated session id in bytes. | ||
:param serialization_format: The serialization format to use for the session data. | ||
:param table: The table name you want to use. | ||
:param schema: The db schema to use. | ||
:param cleanup_n_requests: Delete expired sessions on average every N requests. | ||
""" | ||
|
||
session_class = PostgreSqlSession | ||
ttl = False | ||
|
||
def __init__( | ||
self, | ||
app: Flask, | ||
pool: Optional[ThreadedConnectionPool] = Defaults.SESSION_POSTGRESQL, | ||
key_prefix: str = Defaults.SESSION_KEY_PREFIX, | ||
use_signer: bool = Defaults.SESSION_USE_SIGNER, | ||
permanent: bool = Defaults.SESSION_PERMANENT, | ||
sid_length: int = Defaults.SESSION_ID_LENGTH, | ||
serialization_format: str = Defaults.SESSION_SERIALIZATION_FORMAT, | ||
table: str = Defaults.SESSION_POSTGRESQL_TABLE, | ||
schema: str = Defaults.SESSION_POSTGRESQL_SCHEMA, | ||
cleanup_n_requests: Optional[int] = Defaults.SESSION_CLEANUP_N_REQUESTS, | ||
) -> None: | ||
if not isinstance(pool, ThreadedConnectionPool): | ||
raise TypeError("No valid ThreadedConnectionPool instance provided.") | ||
|
||
self.pool = pool | ||
|
||
self._table = table | ||
self._schema = schema | ||
|
||
self._queries = Queries(schema=self._schema, table=self._table) | ||
|
||
self._create_schema_and_table() | ||
|
||
super().__init__( | ||
app, | ||
key_prefix, | ||
use_signer, | ||
permanent, | ||
sid_length, | ||
serialization_format, | ||
cleanup_n_requests, | ||
) | ||
|
||
@contextmanager | ||
def _get_cursor( | ||
self, conn: Optional[PsycoPg2Connection] = None | ||
) -> Generator[PsycoPg2Cursor, None, None]: | ||
_conn: PsycoPg2Connection = conn or self.pool.getconn() | ||
|
||
assert isinstance(_conn, PsycoPg2Connection) | ||
try: | ||
with _conn, _conn.cursor() as cur: | ||
yield cur | ||
except Exception: | ||
raise | ||
finally: | ||
self.pool.putconn(_conn) | ||
|
||
@retry_query(max_attempts=3) | ||
def _create_schema_and_table(self) -> None: | ||
with self._get_cursor() as cur: | ||
cur.execute(self._queries.create_schema) | ||
cur.execute(self._queries.create_table) | ||
|
||
def _delete_expired_sessions(self) -> None: | ||
"""Delete all expired sessions from the database.""" | ||
with self._get_cursor() as cur: | ||
cur.execute(self._queries.delete_expired_sessions) | ||
|
||
@retry_query(max_attempts=3) | ||
def _delete_session(self, store_id: str) -> None: | ||
with self._get_cursor() as cur: | ||
cur.execute( | ||
self._queries.delete_session, | ||
dict(session_id=store_id), | ||
) | ||
|
||
@retry_query(max_attempts=3) | ||
def _retrieve_session_data(self, store_id: str) -> Optional[dict]: | ||
with self._get_cursor() as cur: | ||
cur.execute( | ||
self._queries.retrieve_session_data, | ||
dict(session_id=store_id), | ||
) | ||
session_data = cur.fetchone() | ||
|
||
if session_data is not None: | ||
serialized_session_data = want_bytes(session_data[0]) | ||
return self.serializer.loads(serialized_session_data) | ||
return None | ||
|
||
@retry_query(max_attempts=3) | ||
def _upsert_session( | ||
self, session_lifetime: TimeDelta, session: ServerSideSession, store_id: str | ||
) -> None: | ||
|
||
serialized_session_data = self.serializer.dumps(session) | ||
|
||
if session.sid is not None: | ||
assert session.sid == store_id.removeprefix(self.key_prefix) | ||
|
||
with self._get_cursor() as cur: | ||
cur.execute( | ||
self._queries.upsert_session, | ||
dict( | ||
session_id=store_id, | ||
data=serialized_session_data, | ||
ttl=session_lifetime, | ||
), | ||
) | ||
|
||
def _drop_table(self) -> None: | ||
with self._get_cursor() as cur: | ||
cur.execute(self._queries.drop_sessions_table) |
Oops, something went wrong.