Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement PostgreSQL sessions #231

Open
wants to merge 34 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
5d197da
add psycopg2 dependency
giuppep Mar 27, 2024
a56e68c
setup basic structure
giuppep Mar 27, 2024
6cdc30f
copy queries from flask-pg-sessions repo
giuppep Mar 27, 2024
77046f7
implement abc methods
giuppep Mar 28, 2024
161b2c9
add imports in __init__.py
giuppep Mar 28, 2024
eef255d
pass ttl to upsert query and compute expiry time based on database time
giuppep Mar 28, 2024
4441ca6
sort imports
giuppep Mar 28, 2024
a3a4e7b
linting fixes
giuppep Mar 28, 2024
58b445a
move defaults to dedicated module
giuppep Mar 28, 2024
a9e65e3
tidy up
giuppep Mar 28, 2024
3c3dbff
fix indentation
giuppep Mar 28, 2024
e79befe
create schema and table if they don't exist
giuppep Mar 28, 2024
6d0ece6
use store_id arg in upsert_session
giuppep Mar 28, 2024
ae93eaf
fix: misunderstood ttl flag
giuppep Mar 28, 2024
24d3732
pass threadedconnectionpool to constructor rather than connection par…
giuppep Apr 3, 2024
bc3ee17
Remove leftover 'pass'
giuppep Apr 3, 2024
1ede448
docstring
giuppep Apr 3, 2024
b21760e
add postgres session to api docs
giuppep Apr 3, 2024
b8ece7f
initialisation of postrges session from flask config
giuppep Apr 3, 2024
4b19f7f
update docs requirements
giuppep Apr 3, 2024
d8ec548
query for dropping sessions table - useful for tests
giuppep Apr 3, 2024
86b5b9c
add postgres to docker compose
giuppep Apr 3, 2024
7cbfd26
basic tests for postgres sessions
giuppep Apr 3, 2024
7d68f45
bug fixes
giuppep Apr 3, 2024
297eaa1
add cookie test
giuppep Apr 3, 2024
b4f488d
add postgres service to test gh actions
giuppep Apr 3, 2024
616edd0
undo formatting changes
giuppep Apr 3, 2024
563d5a4
rename postgres -> postgresql
giuppep Apr 3, 2024
db586ab
undo formatting changes
giuppep Apr 3, 2024
38c1d81
undo formatting changes
giuppep Apr 3, 2024
ab21ea1
undo formatting changes
giuppep Apr 3, 2024
bd4623c
remove debug print statement
giuppep Apr 3, 2024
bb4259b
replace pipes in type annotations with Optional
giuppep Apr 16, 2024
1108c74
fix missing return type
giuppep Apr 16, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,21 @@ jobs:
image: amazon/dynamodb-local
ports:
- 8000:8000

postgresql:
image: postgres:latest
ports:
- 5433:5432
env:
POSTGRES_PASSWORD: pwd
POSTGRES_USER: root
POSTGRES_DB: dummy
# Set health checks to wait until postgres has started
options: >-
--health-cmd pg_isready
--health-interval 10s
--health-timeout 5s
--health-retries 5
steps:
- uses: actions/checkout@v4
- uses: supercharge/redis-github-action@1.5.0
Expand Down
11 changes: 11 additions & 0 deletions docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,17 @@ services:
ports:
- "11211:11211"

postgres:
image: postgres:latest
environment:
- POSTGRES_USER=root
- POSTGRES_PASSWORD=pwd
- POSTGRES_DB=dummy
ports:
- "5433:5432"
volumes:
- postgres_data:/var/lib/postgresql/data

volumes:
postgres_data:
mongo_data:
Expand Down
3 changes: 2 additions & 1 deletion docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,5 @@ Anything documented here is part of the public API that Flask-Session provides,
.. autoclass:: flask_session.cachelib.CacheLibSessionInterface
.. autoclass:: flask_session.mongodb.MongoDBSessionInterface
.. autoclass:: flask_session.sqlalchemy.SqlAlchemySessionInterface
.. autoclass:: flask_session.dynamodb.DynamoDBSessionInterface
.. autoclass:: flask_session.dynamodb.DynamoDBSessionInterface
.. autoclass:: flask_session.postgresql.PostgreSqlSessionInterface
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -88,4 +88,5 @@ dev-dependencies = [
"boto3>=1.34.68",
"mypy_boto3_dynamodb>=1.34.67",
"pymemcache>=4.0.0",
"psycopg2-binary>=2",
]
1 change: 1 addition & 0 deletions requirements/dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,5 @@ Flask-SQLAlchemy
pymongo
boto3
mypy_boto3_dynamodb
psycopg2-binary

3 changes: 2 additions & 1 deletion requirements/docs.in
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,5 @@ pymongo
flask_sqlalchemy
pymemcache
boto3
mypy_boto3_dynamodb
mypy_boto3_dynamodb
psycopg2-binary
6 changes: 6 additions & 0 deletions requirements/docs.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
#
alabaster==0.7.13
# via sphinx
async-timeout==4.0.3
# via redis
babel==2.12.1
# via sphinx
beautifulsoup4==4.12.3
Expand Down Expand Up @@ -36,6 +38,8 @@ flask-sqlalchemy==3.1.1
# via -r requirements/docs.in
furo==2024.1.29
# via -r requirements/docs.in
greenlet==3.0.3
# via sqlalchemy
idna==3.4
# via requests
imagesize==1.4.1
Expand All @@ -58,6 +62,8 @@ mypy-boto3-dynamodb==1.34.67
# via -r requirements/docs.in
packaging==23.1
# via sphinx
psycopg2-binary==2.9.9
# via -r requirements/docs.in
pygments==2.15.1
# via
# furo
Expand Down
30 changes: 27 additions & 3 deletions src/flask_session/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,16 +100,29 @@ def _get_interface(self, app):
SESSION_SQLALCHEMY_BIND_KEY = config.get(
"SESSION_SQLALCHEMY_BIND_KEY", Defaults.SESSION_SQLALCHEMY_BIND_KEY
)
SESSION_CLEANUP_N_REQUESTS = config.get(
"SESSION_CLEANUP_N_REQUESTS", Defaults.SESSION_CLEANUP_N_REQUESTS
)

# DynamoDB settings
SESSION_DYNAMODB = config.get("SESSION_DYNAMODB", Defaults.SESSION_DYNAMODB)
SESSION_DYNAMODB_TABLE = config.get(
"SESSION_DYNAMODB_TABLE", Defaults.SESSION_DYNAMODB_TABLE
)

# PostgreSQL settings
SESSION_POSTGRESQL = config.get(
"SESSION_POSTGRESQL", Defaults.SESSION_POSTGRESQL
)
SESSION_POSTGRESQL_TABLE = config.get(
"SESSION_POSTGRESQL_TABLE", Defaults.SESSION_POSTGRESQL_TABLE
)
SESSION_POSTGRESQL_SCHEMA = config.get(
"SESSION_POSTGRESQL_SCHEMA", Defaults.SESSION_POSTGRESQL_SCHEMA
)

# Shared settings
SESSION_CLEANUP_N_REQUESTS = config.get(
"SESSION_CLEANUP_N_REQUESTS", Defaults.SESSION_CLEANUP_N_REQUESTS
)

common_params = {
"app": app,
"key_prefix": SESSION_KEY_PREFIX,
Expand Down Expand Up @@ -180,6 +193,17 @@ def _get_interface(self, app):
table_name=SESSION_DYNAMODB_TABLE,
)

elif SESSION_TYPE == "postgresql":
from .postgresql import PostgreSqlSessionInterface

session_interface = PostgreSqlSessionInterface(
**common_params,
pool=SESSION_POSTGRESQL,
table=SESSION_POSTGRESQL_TABLE,
schema=SESSION_POSTGRESQL_SCHEMA,
cleanup_n_requests=SESSION_CLEANUP_N_REQUESTS,
)

else:
raise ValueError(f"Unrecognized value for SESSION_TYPE: {SESSION_TYPE}")

Expand Down
5 changes: 5 additions & 0 deletions src/flask_session/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,8 @@ class Defaults:
# DynamoDB settings
SESSION_DYNAMODB = None
SESSION_DYNAMODB_TABLE = "Sessions"

# PostgreSQL settings
SESSION_POSTGRESQL = None
SESSION_POSTGRESQL_TABLE = "flask_sessions"
SESSION_POSTGRESQL_SCHEMA = "public"
1 change: 1 addition & 0 deletions src/flask_session/postgresql/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .postgresql import PostgreSqlSession, PostgreSqlSessionInterface # noqa: F401
84 changes: 84 additions & 0 deletions src/flask_session/postgresql/_queries.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
from psycopg2 import sql


class Queries:
def __init__(self, schema: str, table: str) -> None:
"""Class to hold all the queries used by the session interface.

Args:
schema (str): The name of the schema to use for the session data.
table (str): The name of the table to use for the session data.
"""
self.schema = schema
self.table = table

@property
def create_schema(self) -> str:
return sql.SQL("CREATE SCHEMA IF NOT EXISTS {schema};").format(
schema=sql.Identifier(self.schema)
)

@property
def create_table(self) -> str:
uq_idx = sql.Identifier(f"uq_{self.table}_session_id")
expiry_idx = sql.Identifier(f"{self.table}_expiry_idx")
return sql.SQL(
"""CREATE TABLE IF NOT EXISTS {schema}.{table} (
session_id VARCHAR(255) NOT NULL PRIMARY KEY,
created TIMESTAMP WITHOUT TIME ZONE DEFAULT (NOW() AT TIME ZONE 'utc'),
data BYTEA,
expiry TIMESTAMP WITHOUT TIME ZONE
);

--- Unique session_id
CREATE UNIQUE INDEX IF NOT EXISTS
{uq_idx} ON {schema}.{table} (session_id);

--- Index for expiry timestamp
CREATE INDEX IF NOT EXISTS
{expiry_idx} ON {schema}.{table} (expiry);"""
).format(
schema=sql.Identifier(self.schema),
table=sql.Identifier(self.table),
uq_idx=uq_idx,
expiry_idx=expiry_idx,
)

@property
def retrieve_session_data(self) -> str:
return sql.SQL(
"""--- If the current sessions is expired, delete it
DELETE FROM {schema}.{table}
WHERE session_id = %(session_id)s AND expiry < NOW();
--- Else retrieve it
SELECT data FROM {schema}.{table} WHERE session_id = %(session_id)s;
"""
).format(schema=sql.Identifier(self.schema), table=sql.Identifier(self.table))

@property
def upsert_session(self) -> str:
return sql.SQL(
"""INSERT INTO {schema}.{table} (session_id, data, expiry)
VALUES (%(session_id)s, %(data)s, NOW() + %(ttl)s)
ON CONFLICT (session_id)
DO UPDATE SET data = %(data)s, expiry = NOW() + %(ttl)s;
"""
).format(schema=sql.Identifier(self.schema), table=sql.Identifier(self.table))

@property
def delete_expired_sessions(self) -> str:
return sql.SQL("DELETE FROM {schema}.{table} WHERE expiry < NOW();").format(
schema=sql.Identifier(self.schema), table=sql.Identifier(self.table)
)

@property
def delete_session(self) -> str:
return sql.SQL(
"DELETE FROM {schema}.{table} WHERE session_id = %(session_id)s;"
).format(schema=sql.Identifier(self.schema), table=sql.Identifier(self.table))

@property
def drop_sessions_table(self) -> str:
return sql.SQL("DROP TABLE IF EXISTS {schema}.{table};").format(
schema=sql.Identifier(self.schema), table=sql.Identifier(self.table)
)
145 changes: 145 additions & 0 deletions src/flask_session/postgresql/postgresql.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
from __future__ import annotations

from contextlib import contextmanager
from datetime import timedelta as TimeDelta
from typing import Generator, Optional

from flask import Flask
from itsdangerous import want_bytes
from psycopg2.extensions import connection as PsycoPg2Connection
from psycopg2.extensions import cursor as PsycoPg2Cursor
from psycopg2.pool import ThreadedConnectionPool

from .._utils import retry_query
from ..base import ServerSideSession, ServerSideSessionInterface
from ..defaults import Defaults
from ._queries import Queries


class PostgreSqlSession(ServerSideSession):
pass


class PostgreSqlSessionInterface(ServerSideSessionInterface):
"""A Session interface that uses PostgreSQL as a session storage. (`psycopg2` required)

:param pool: A ``psycopg2.pool.ThreadedConnectionPool`` instance.
:param key_prefix: A prefix that is added to all storage keys.
:param use_signer: Whether to sign the session id cookie or not.
:param permanent: Whether to use permanent session or not.
:param sid_length: The length of the generated session id in bytes.
:param serialization_format: The serialization format to use for the session data.
:param table: The table name you want to use.
:param schema: The db schema to use.
:param cleanup_n_requests: Delete expired sessions on average every N requests.
"""

session_class = PostgreSqlSession
ttl = False

def __init__(
self,
app: Flask,
pool: Optional[ThreadedConnectionPool] = Defaults.SESSION_POSTGRESQL,
key_prefix: str = Defaults.SESSION_KEY_PREFIX,
use_signer: bool = Defaults.SESSION_USE_SIGNER,
permanent: bool = Defaults.SESSION_PERMANENT,
sid_length: int = Defaults.SESSION_ID_LENGTH,
serialization_format: str = Defaults.SESSION_SERIALIZATION_FORMAT,
table: str = Defaults.SESSION_POSTGRESQL_TABLE,
schema: str = Defaults.SESSION_POSTGRESQL_SCHEMA,
cleanup_n_requests: Optional[int] = Defaults.SESSION_CLEANUP_N_REQUESTS,
) -> None:
if not isinstance(pool, ThreadedConnectionPool):
raise TypeError("No valid ThreadedConnectionPool instance provided.")

self.pool = pool

self._table = table
self._schema = schema

self._queries = Queries(schema=self._schema, table=self._table)

self._create_schema_and_table()

super().__init__(
app,
key_prefix,
use_signer,
permanent,
sid_length,
serialization_format,
cleanup_n_requests,
)

@contextmanager
def _get_cursor(
self, conn: Optional[PsycoPg2Connection] = None
) -> Generator[PsycoPg2Cursor, None, None]:
_conn: PsycoPg2Connection = conn or self.pool.getconn()

assert isinstance(_conn, PsycoPg2Connection)
try:
with _conn, _conn.cursor() as cur:
yield cur
except Exception:
raise
finally:
self.pool.putconn(_conn)

@retry_query(max_attempts=3)
def _create_schema_and_table(self) -> None:
with self._get_cursor() as cur:
cur.execute(self._queries.create_schema)
cur.execute(self._queries.create_table)

def _delete_expired_sessions(self) -> None:
"""Delete all expired sessions from the database."""
with self._get_cursor() as cur:
cur.execute(self._queries.delete_expired_sessions)

@retry_query(max_attempts=3)
def _delete_session(self, store_id: str) -> None:
with self._get_cursor() as cur:
cur.execute(
self._queries.delete_session,
dict(session_id=store_id),
)

@retry_query(max_attempts=3)
def _retrieve_session_data(self, store_id: str) -> Optional[dict]:
with self._get_cursor() as cur:
cur.execute(
self._queries.retrieve_session_data,
dict(session_id=store_id),
)
session_data = cur.fetchone()

if session_data is not None:
serialized_session_data = want_bytes(session_data[0])
return self.serializer.decode(serialized_session_data)
return None

@retry_query(max_attempts=3)
def _upsert_session(
self, session_lifetime: TimeDelta, session: ServerSideSession, store_id: str
) -> None:

serialized_session_data = self.serializer.encode(session)

if session.sid is not None:
assert session.sid == store_id.removeprefix(self.key_prefix)

with self._get_cursor() as cur:
cur.execute(
self._queries.upsert_session,
dict(
session_id=store_id,
data=serialized_session_data,
ttl=session_lifetime,
),
)

def _drop_table(self) -> None:
with self._get_cursor() as cur:
cur.execute(self._queries.drop_sessions_table)
Loading
Loading