Skip to content

Commit

Permalink
add expiration condition to milestonegroup update
Browse files Browse the repository at this point in the history
  • Loading branch information
MaHaWo committed Dec 5, 2024
1 parent 66136f0 commit 36449b6
Showing 1 changed file with 49 additions and 25 deletions.
74 changes: 49 additions & 25 deletions mondey_backend/src/mondey_backend/routers/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from collections.abc import Sequence

import numpy as np
from sqlalchemy import and_
from sqlmodel import col
from sqlmodel import select

Expand Down Expand Up @@ -40,16 +41,20 @@ def _finalize_statistics(
mean: float | np.ndarray[float],
m2: float | np.ndarray[float],
) -> tuple[int | np.ndarray[int], float | np.ndarray[float], float | np.ndarray[float]]:
if isinstance(count, int):
if isinstance(count, int) and isinstance(mean, float) and isinstance(m2, float):
if count < 2:
return count, mean, 0.0
else:
variance = m2 / (count - 1)
return count, mean, np.sqrt(variance)
elif isinstance(count, np.ndarray):
var: float = m2 / (count - 1)
return count, mean, np.sqrt(var)
elif (
isinstance(count, np.ndarray)
and isinstance(mean, np.ndarray)
and isinstance(m2, np.ndarray)
):
with np.errstate(invalid="ignore"):
valid_counts = count >= 2
variance = m2
variance: np.ndarray = m2
variance[valid_counts] /= count[valid_counts] - 1
variance[np.invert(valid_counts)] = 0.0
return count, np.nan_to_num(mean), np.nan_to_num(np.sqrt(variance))
Expand Down Expand Up @@ -91,9 +96,8 @@ def _get_statistics_by_age(


def calculate_milestone_statistics_by_age(
session: SessionDep,
milestone_id: int,
) -> MilestoneAgeScoreCollection:
session: SessionDep, milestone_id: int, session_expired_days: int = 7
) -> MilestoneAgeScoreCollection | None:
# get the newest statistics for the milestone
last_statistics = session.exec(
select(MilestoneAgeScoreCollection)
Expand All @@ -112,22 +116,32 @@ def calculate_milestone_statistics_by_age(
stddev_scores = np.array([score.stddev_score for score in last_scores])

child_ages = _get_answer_session_child_ages_in_months(session)
expiration_data = datetime.datetime.now() - datetime.timedelta(
days=session_expired_days
)

if last_statistics is None:
answers_query = select(MilestoneAnswer).where(
MilestoneAnswer.milestone_id == milestone_id
# no statistics exists yet -> all answers from expired sessions are relevant
answers_query = (
select(MilestoneAnswer)
.where(MilestoneAnswer.milestone_id == milestone_id)
.where(MilestoneAnswerSession.created_at < expiration_data)
)
else:
# fetch all answers that have been added in answersessions after the last statistics were calculated
# we calculate the statistics with an online algorithm, so we only consider new data
# that has not been included in the last statistics but which stems from sessions that are expired
answers_query = (
select(MilestoneAnswer)
.join(
MilestoneAnswerSession,
MilestoneAnswer.answer_session_id == MilestoneAnswerSession.id,
)
.where(MilestoneAnswer.milestone_id == milestone_id)
.where(
MilestoneAnswer.milestone_id == milestone_id,
MilestoneAnswerSession.created_at > last_statistics.created_at,
and_(
MilestoneAnswerSession.created_at > last_statistics.created_at,
MilestoneAnswerSession.created_at <= expiration_data,
) # expired session only which are not in the last statistics
)
)
answers = session.exec(answers_query).all()
Expand All @@ -139,7 +153,7 @@ def calculate_milestone_statistics_by_age(
)
expected_age = _get_expected_age_from_scores(avg_scores)

# overwrite last_statistics with updated stuff
# overwrite last_statistics with updated stuff --> set primary keys explicitly
return MilestoneAgeScoreCollection(
milestone_id=milestone_id,
expected_age=expected_age,
Expand All @@ -161,9 +175,8 @@ def calculate_milestone_statistics_by_age(


def calculate_milestonegroup_statistics_by_age(
session: SessionDep,
milestonegroup_id,
) -> MilestoneGroupAgeScoreCollection:
session: SessionDep, milestonegroup_id, session_expired_days: int = 7
) -> MilestoneGroupAgeScoreCollection | None:
# get the newest statistics for the milestonegroup
last_statistics = session.exec(
select(MilestoneGroupAgeScoreCollection)
Expand All @@ -189,25 +202,36 @@ def calculate_milestonegroup_statistics_by_age(
)

child_ages = _get_answer_session_child_ages_in_months(session)

expiration_data = datetime.datetime.now() - datetime.timedelta(
days=session_expired_days
)
if last_statistics is None:
answer_query = select(MilestoneAnswer).where(
col(MilestoneAnswer.milestone_group_id) == milestonegroup_id
# no statistics exists yet -> all answers from expired sessions are relevant
answer_query = (
select(MilestoneAnswer)
.where(col(MilestoneAnswer.milestone_group_id) == milestonegroup_id)
.where(
MilestoneAnswerSession.created_at
< expiration_data # expired session only
)
)
else:
# fetch all answers that have been added in answersessions after the last statistics were calculated
# we calculate the statistics with an online algorithm, so we only consider new data
# that has not been included in the last statistics but which stems from sessions that are expired
answer_query = (
select(MilestoneAnswer)
.join(
MilestoneAnswerSession,
MilestoneAnswer.answer_session_id == MilestoneAnswerSession.id,
)
.where(MilestoneAnswer.milestone_group_id == milestonegroup_id)
.where(
MilestoneAnswer.milestone_group_id == milestonegroup_id,
MilestoneAnswerSession.created_at > last_statistics.created_at,
)
and_(
MilestoneAnswerSession.created_at > last_statistics.created_at,
MilestoneAnswerSession.created_at <= expiration_data,
)
) # expired session only which are not in the last statistics
)

answers = session.exec(answer_query).all()
if len(answers) == 0:
return last_statistics
Expand Down

0 comments on commit 36449b6

Please sign in to comment.