Skip to content

Commit

Permalink
Fix subscription state detection for users based on phone numbers, em…
Browse files Browse the repository at this point in the history
…ails (#633)

* Fix subscription state detection for users based on phone numbers, emails
* Fix unit tests for api_user4
* Use a single method for determining subscription from user
* Pass user object, rather than user.email for getting subscription state
  • Loading branch information
sabaimran authored Jan 31, 2024
1 parent fc4b57d commit 4daac33
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
2 changes: 1 addition & 1 deletion src/khoj/configure.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ async def authenticate(self, request: HTTPConnection):
if subscribed:
return (
AuthCredentials(["authenticated", "premium"]),
AuthenticatedKhojUser(user),
AuthenticatedKhojUser(user, client_application),
)
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user, client_application)
if state.anonymous_mode:
Expand Down
6 changes: 4 additions & 2 deletions src/khoj/database/adapters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,8 @@ def subscription_to_state(subscription: Subscription) -> str:
return SubscriptionState.TRIAL.value
elif subscription.is_recurring and subscription.renewal_date >= datetime.now(tz=timezone.utc):
return SubscriptionState.SUBSCRIBED.value
elif not subscription.is_recurring and subscription.renewal_date is None:
return SubscriptionState.EXPIRED.value
elif not subscription.is_recurring and subscription.renewal_date >= datetime.now(tz=timezone.utc):
return SubscriptionState.UNSUBSCRIBED.value
elif not subscription.is_recurring and subscription.renewal_date < datetime.now(tz=timezone.utc):
Expand All @@ -222,11 +224,11 @@ def get_user_subscription_state(email: str) -> str:
return subscription_to_state(user_subscription)


async def aget_user_subscription_state(email: str) -> str:
async def aget_user_subscription_state(user: KhojUser) -> str:
"""Get subscription state of user
Valid state transitions: trial -> subscribed <-> unsubscribed OR expired
"""
user_subscription = await Subscription.objects.filter(user__email=email).afirst()
user_subscription = await Subscription.objects.filter(user=user).afirst()
return subscription_to_state(user_subscription)


Expand Down

0 comments on commit 4daac33

Please sign in to comment.